24  Machine learning

24.1 Background

Machine learning (ML) is a branch of artificial intelligence. While there are several definitions of machine learning, they generally involve using computational methods (i.e., machines) to identify patterns in data (i.e., learning).

Machine learning can be divided into supervised and unsupervised learning. Supervised ML is used to predict outcomes based on labeled data, whereas unsupervised ML is used to discover unknown patterns and structures within the data.

24.1.1 Unsupervised machine learning

“Unsupervised” means that the outcomes (e.g., patient status) are not known by the model during its training, and patterns are learned based solely on the data, such as an abundance table.

Common tasks in unsupervised machine learning include dimension reduction and clustering, tasks discussed in Chapter 14 and Chapter 15, respectively.

24.1.2 Supervised machine learning

“Supervised” means that a model is trained on observations paired with a known output (e.g., patient status: healthy or diseased). During training, the model learns patterns from a portion of the data, and it is then evaluated on its ability to generalize to new data. This process involves splitting the collected data into training and test sets, commonly in an 80/20 ratio, although other proportions can be used depending on the size of the dataset and the application.
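The 80/20 split can be illustrated with a minimal base-R sketch on hypothetical toy data. The names `toy`, `status` and `split_80_20()` are illustrative and not part of any package. Sampling 80% of the rows within each class (a stratified split) is a choice made here so that both sets keep similar class proportions; it is not necessarily what every package does.

```r
# Hypothetical toy data: 100 samples, 3 features and a binary outcome
set.seed(42)
toy <- data.frame(
    taxon_a = rnorm(100),
    taxon_b = rnorm(100),
    taxon_c = rnorm(100),
    status  = factor(rep(c("healthy", "disease"), times = c(60, 40)))
)

# Stratified 80/20 split: sample 80% of the row indices within each class
split_80_20 <- function(outcome, training_frac = 0.8) {
    idx_by_class <- split(seq_along(outcome), outcome)
    train_idx <- unlist(lapply(idx_by_class, function(idx) {
        idx[sample.int(length(idx), size = round(training_frac * length(idx)))]
    }))
    sort(unname(train_idx))
}

train_idx <- split_80_20(toy$status)
train_set <- toy[train_idx, ]   # 80 samples used for training
test_set  <- toy[-train_idx, ]  # 20 samples held out for testing

table(test_set$status)  # 8 disease, 12 healthy (same 40/60 ratio as the full data)
```

Because the split is random, a different seed gives a different test set, which is one reason why performance estimates from a single split should be interpreted with care.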

Training is usually combined with cross-validation, which is used to tune the model and to obtain more robust performance estimates. However, when the dataset is small, splitting it into training and test sets might not be feasible. In such cases, cross-validation alone can be used to provide a rough estimate of the model’s performance. This strategy involves dividing the data into K folds (or subsets) of similar size. The model is trained on K-1 folds and tested on the remaining fold. This process is repeated K times, so that each fold serves as the test set once. While this approach is not as reliable as having a separate test set, it can still give valuable insights into how well the model might perform on new data.
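The K-fold procedure can be sketched in a few lines of base R. This is a minimal illustration under stated assumptions: the data are hypothetical, and a linear regression (`lm()`) stands in for whatever model is being evaluated; RMSE is used as the performance metric.

```r
# Hypothetical toy data: 30 samples with one predictor and a numeric outcome
set.seed(1)
n <- 30
toy <- data.frame(x = runif(n, 0, 10))
toy$y <- 2 * toy$x + rnorm(n)

K <- 5
# Assign each sample to one of K folds of equal size, in random order
folds <- sample(rep(seq_len(K), length.out = n))

# Train on K-1 folds, evaluate on the held-out fold, repeat K times
fold_rmse <- vapply(seq_len(K), function(k) {
    train <- toy[folds != k, ]
    test  <- toy[folds == k, ]
    fit   <- lm(y ~ x, data = train)
    pred  <- predict(fit, newdata = test)
    sqrt(mean((test$y - pred)^2))
}, numeric(1))

# The cross-validated estimate is the average performance across folds
cv_rmse <- mean(fold_rmse)
```

Each sample is used for testing exactly once, so the K fold-level estimates can also be inspected to see how much performance varies between folds.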

Common tasks for supervised machine learning include classification (i.e., predicting categorical variables) and regression (i.e., predicting continuous variables). This chapter discusses two supervised ML algorithms that can be applied to both classification and regression tasks.

Note: ML in multi-omics data analysis

ML applications for the integration of multi-omic datasets are covered in Chapter 22 and Chapter 23.

24.2 Setup

Published fecal microbiome data (Qin et al. 2012) will be used to illustrate how to deploy supervised machine learning algorithms to address classification and regression problems. For classification, two ML models will be used to classify subjects into two groups: type II diabetes patients or healthy individuals (encoded as T2D and healthy in the metadata). For regression, a model will be used to predict the body mass index (BMI) of each subject. In both tasks, the models will be trained on the participants’ gut microbiome data.

Ridaura, Vanessa K., Jeremiah J. Faith, Federico E. Rey, Jiye Cheng, Alexis E. Duncan, Andrew L. Kau, Nicholas W. Griffin, et al. 2013. “Gut Microbiota from Twins Discordant for Obesity Modulate Metabolism in Mice.” Science 341 (6150): 1241214. https://doi.org/10.1126/science.1241214.

Mounting evidence shows associations between the gut microbiome and the onset of type II diabetes. Thus, it has been suggested that microbiome data could be used to discriminate between patients and healthy individuals. Indeed, that hypothesis was explored in the research article where the dataset we will work with was first described (Qin et al. 2012). In addition, experiments conducted with twins have shown that, while transplanting the gut microbiome of the obese twin induces obesity in mice, the gut microbiome of the lean twin does not (Ridaura et al. 2013). Thus, predicting the health status of a person or their BMI from their gut microbiome seems a plausible task, and an interesting opportunity to learn about supervised ML algorithms using real-world data.

To do so, the R package mikropml (Topçuoğlu et al. 2020, 2021) will be used throughout this chapter. This package was developed to offer a user-friendly interface to supervised ML algorithms implemented in the caret package. The mikropml documentation provides a list of the models it supports well.

Topçuoğlu, Begüm D., Nicholas A. Lesniak, Mack T. Ruffin, Jenna Wiens, and Patrick D. Schloss. 2020. “A Framework for Effective Application of Machine Learning to Microbiome-Based Classification Problems.” mBio 11 (3): e00434-20. https://doi.org/10.1128/mbio.00434-20.
Topçuoğlu, Begüm D., Zena Lapp, Kelly L. Sovacool, Evan Snitkin, Jenna Wiens, and Patrick D. Schloss. 2021. “mikropml: User-Friendly R Package for Supervised Machine Learning Pipelines.” Journal of Open Source Software 6 (61): 3073. https://doi.org/10.21105/joss.03073.

The code below retrieves the data and shows the number of participants in each category.

library(curatedMetagenomicData)

samples <- sampleMetadata[ sampleMetadata[["study_name"]] == "QinJ_2012", ]

tse <- returnSamples(
    samples, 
    dataType = "relative_abundance",
    counts = TRUE, # use counts instead of rel abundances
    rownames = "short"
)

# Change assay's name to reflect its content
assayNames(tse) <- "counts"

table(tse[["disease"]]) |> t() |> knitr::kable()
T2D healthy
170 193

24.3 Data preprocessing

Before applying any ML algorithm, the data must be preprocessed. Preprocessing speeds up model training by reducing the number of features analysed, a desirable outcome when working with high-dimensional microbiome data. In addition to faster performance, common preprocessing steps have biological justifications. For instance:

  • Collapse highly correlated features: In a microbial community, it’s common for the abundance of two or more taxa to be highly correlated. Thus, removing or collapsing correlated taxa allows the model to analyse them as one group.
  • Remove features with near-zero variance: Features that barely vary across samples can hardly help in discerning between groups, as they carry little biologically relevant information. Additionally, under certain data splits, these features can show zero variance in the training data.
  • Remove features with low prevalence: Microbiome data are sparse, and taxa present in just a few samples of each group hardly provide useful biological information for classifying the samples.
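The three preprocessing ideas listed above can be sketched in base R on a small hypothetical count table. This is only an illustration of the logic, not how mikropml or mia implement it: the 20% prevalence cut-off and the 0.95 correlation threshold are arbitrary choices for this toy example, and only the zero-variance case is shown (near-zero variance filters such as caret's `nearZeroVar()` use more elaborate frequency-based rules).

```r
# Hypothetical count table: 6 samples (rows) x 5 taxa (columns)
counts <- matrix(
    c(10, 0, 5, 20, 3,
      12, 0, 5, 24, 0,
       8, 1, 5, 16, 0,
      15, 0, 5, 30, 0,
       9, 0, 5, 18, 0,
      11, 0, 5, 22, 0),
    nrow = 6, byrow = TRUE,
    dimnames = list(paste0("S", 1:6), paste0("taxon_", letters[1:5]))
)

# 1. Prevalence: fraction of samples in which each taxon is detected (> 0)
prevalence <- colMeans(counts > 0)
keep_prev <- prevalence > 0.2          # keep taxa present in > 20% of samples

# 2. Zero variance: taxa with the same abundance in every sample
keep_var <- apply(counts, 2, var) > 0

filtered <- counts[, keep_prev & keep_var, drop = FALSE]
colnames(filtered)                     # "taxon_a" "taxon_d"

# 3. Correlated features: pairs with |Spearman rho| above the threshold
rho <- cor(filtered, method = "spearman")
high <- which(abs(rho) > 0.95 & upper.tri(rho), arr.ind = TRUE)
data.frame(feat1 = rownames(rho)[high[, 1]], feat2 = colnames(rho)[high[, 2]])
```

Here taxon_b and taxon_e are dropped for low prevalence, taxon_c for having zero variance, and taxon_a and taxon_d are flagged as a correlated pair that could be collapsed into one group.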
Important: Data leakage is a common pitfall.

Data leakage occurs when information from the test set influences the model training process, leading to overly optimistic performance estimates. This can happen, for example, if preprocessing steps like scaling are applied to the training and test data together, allowing the model to indirectly “see” the test data during training. Fortunately, there are questions we can ask ourselves to avoid data leakage (Bernett et al. 2024).

Bernett, Judith, David B. Blumenthal, Dominik G. Grimm, Florian Haselbeck, Roman Joeres, Olga V. Kalinina, and Markus List. 2024. “Guiding Questions to Avoid Data Leakage in Biological Machine Learning Applications.” Nature Methods 21 (8): 1444–53. https://doi.org/10.1038/s41592-024-02362-y.
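The scaling example of data leakage can be made concrete with a minimal base-R sketch on hypothetical data. The leakage-free version learns the centring and scaling parameters from the training set only and then reuses those same parameters on the test set.

```r
# Hypothetical feature matrix and a fixed train/test split
set.seed(7)
x <- matrix(rnorm(50 * 3, mean = 10, sd = 2), ncol = 3,
            dimnames = list(NULL, c("f1", "f2", "f3")))
train_idx <- sample.int(50, 40)
x_train <- x[train_idx, ]
x_test  <- x[-train_idx, ]

# Leaky: centring/scaling parameters computed on ALL samples, test set included
x_leaky <- scale(x)

# Leakage-free: learn the parameters from the training set only...
mu  <- colMeans(x_train)
sdv <- apply(x_train, 2, sd)
x_train_scaled <- scale(x_train, center = mu, scale = sdv)
# ...then apply the same training-derived parameters to the test set
x_test_scaled  <- scale(x_test, center = mu, scale = sdv)
```

The same principle applies to any data-dependent step (filtering, feature selection, imputation): it should be fitted on the training data and only applied to the test data.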

The code below removes rare taxa, applies a centered log-ratio (CLR) transformation, and joins the abundance of each microbial taxon with different alpha diversity indices that provide ecosystem-level information. See Chapter 13 for a discussion on alpha diversity.

library(mia)
library(mikropml)

# Keep taxa present in more than 10% of samples
tse_prev <- subsetByPrevalent(
    x = tse,
    assay.type = "counts",
    prevalence = 10/100
)

# Get alpha diversity metrics to add them as inputs
alpha_divs <- getAlpha(
    tse_prev, 
    assay.type = "counts"
)

# Apply CLR transformation 
tse_prev <- transformAssay(
    x = tse_prev,
    assay.type = "counts",
    method = "clr",
    MARGIN = "cols",
    pseudocount = TRUE
)

# Get CLR assay
assay <- assay(tse_prev, "clr")
# Transpose assay
assay <- t(assay)
# Join CLR abundances and alpha diversity metrics
assay_alpha <- cbind(assay, alpha_divs)
raw_df <- as.data.frame(assay_alpha)

# Make new names of groups of correlated features 
# by concatenating the names of the group members
names_grp_feats <- group_correlated_features(
    features = raw_df, 
    group_neg_corr = FALSE
)
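To show what the CLR transformation and an alpha diversity index compute, here is a base-R sketch on a hypothetical count matrix. It is not the mia implementation: a pseudocount of 1 is assumed here to handle zeros, whereas mia's `pseudocount = TRUE` chooses its own value, and only the Shannon index is computed.

```r
# Hypothetical counts: 3 taxa (rows) x 2 samples (columns), as in a TreeSE assay
counts <- matrix(c(10, 20, 70,
                    0,  5, 95),
                 nrow = 3,
                 dimnames = list(c("taxon_a", "taxon_b", "taxon_c"), c("S1", "S2")))

# CLR per sample: log of each value minus the mean log value of that sample
# (pseudocount of 1 assumed for illustration)
clr <- function(x, pseudocount = 1) {
    lx <- log(x + pseudocount)
    lx - mean(lx)
}
clr_counts <- apply(counts, 2, clr)

# Shannon diversity per sample: -sum(p * log(p)) over detected taxa
shannon <- function(x) {
    p <- x[x > 0] / sum(x)
    -sum(p * log(p))
}
alpha <- apply(counts, 2, shannon)

# Samples as rows: CLR-transformed taxa plus the alpha diversity column
features <- cbind(t(clr_counts), shannon = alpha)
```

By construction, the CLR values of each sample sum to zero, which is why they are interpreted relative to the other taxa in the same sample.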

The next step is to join the microbial abundances and alpha diversities with the outcome of interest (either diagnosis status or BMI) for each observation.

24.3.1 Preprocess for classification task

The code below joins the diagnosis status of each participant (either ‘T2D’ or ‘healthy’) and then preprocesses the data by removing zero-variance features and collapsing correlated features (low-prevalence taxa were already removed above). The resulting preprocessed dataset will be used to train and test our model in the following section.

# Add labels, i.e., known outcome
labels <- tse_prev[["disease"]]
raw_df[["diagnosis"]] <- labels

# Preprocess data for classification
prep_classification <- preprocess_data(
    dataset = raw_df,
    outcome_colname = "diagnosis",
    method = NULL, # Skip normalization as CLR was performed 
    remove_var = "zv",
    collapse_corr_feats = TRUE,
    group_neg_corr = FALSE
)

# rename grouped features with names generated in previous chunk
new_names <- c("diagnosis", names_grp_feats)
colnames(prep_classification[["dat_transformed"]]) <- new_names

# get preprocessed data
df_classification <- prep_classification[["dat_transformed"]]

24.3.2 Preprocess for regression task

In the code below, we join the BMI of each participant before preprocessing the data. Again, the resulting preprocessed dataset will then be used to train and test our model.

# Remove 'diagnosis' used before for classification
raw_df <- within(raw_df, rm("diagnosis"))

# Add BMI for regression task
BMI <- tse_prev[["BMI"]]
raw_df[["BMI"]] <- BMI

# Preprocess data for regression task
prep_regression <- preprocess_data(
    dataset = raw_df,
    outcome_colname = "BMI",
    method = NULL, # Skip normalization as CLR was performed 
    remove_var = "zv",
    collapse_corr_feats = TRUE,
    group_neg_corr = FALSE
)

# rename grouped features with names generated above
new_names <- c("BMI", names_grp_feats)
colnames(prep_regression[["dat_transformed"]]) <- new_names

# get preprocessed data
df_regression <- prep_regression[["dat_transformed"]]
Note: Preprocessing strategies must follow research goals

The code above should be regarded as a reference. Preprocessing strategies impact the performance of ML models, and thus the steps used should always be tailored to specific research goals. Other preprocessing steps can be implemented using the preprocess_data function, as described in the mikropml documentation.

24.4 Model training

Now we can deploy supervised ML models on the preprocessed data. In this section, two algorithms will be covered:

  • Random Forest
  • XGBoost

These are among the most widely used supervised ML algorithms in the microbiome field. Possible reasons are that they can be used for both classification and regression tasks, and that they show a good balance between performance and interpretability. Although the focus of this book is the implementation and interpretation of these models, other resources are suggested for an introduction to the mathematical underpinnings of these (and other) models (James et al. 2013).

James, Gareth, Daniela Witten, Trevor Hastie, and Robert Tibshirani. 2013. An Introduction to Statistical Learning: With Applications in R. New York: Springer. https://search.library.wisc.edu/catalog/9910207152902121.

24.4.1 Random forest

Random forest (RF) is an ensemble algorithm: it combines the outputs of multiple decision trees, by majority voting (classification) or by averaging the predicted numerical outcomes (regression). Since each tree is trained on a random (bootstrap) sample of the observations and considers a random subset of features at each split, RF reduces overfitting and enhances generalization to new data.
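The way a forest combines its trees can be illustrated with a toy base-R sketch. The tree outputs below are hypothetical numbers typed in by hand, not predictions from a fitted model, and this is not how the randomForest package is implemented internally; it only shows majority voting, averaging, and bootstrap sampling.

```r
# Hypothetical class predictions of 5 trees (rows) for 3 observations (columns)
tree_votes <- matrix(
    c("T2D",     "healthy", "T2D",
      "T2D",     "healthy", "healthy",
      "healthy", "healthy", "T2D",
      "T2D",     "T2D",     "T2D",
      "T2D",     "healthy", "healthy"),
    nrow = 5, byrow = TRUE
)

# Classification: the forest predicts the most frequent class (majority vote)
majority_vote <- apply(tree_votes, 2, function(v) names(which.max(table(v))))

# Regression: the forest predicts the average of the trees' numeric outputs
tree_bmi <- matrix(c(24, 27, 22,
                     26, 29, 21,
                     25, 28, 23), nrow = 3, byrow = TRUE)
forest_bmi <- colMeans(tree_bmi)   # 25, 28, 22

# Each tree is grown on a bootstrap sample: n row indices drawn with replacement
set.seed(3)
boot_rows <- sample.int(10, size = 10, replace = TRUE)
```

Because each tree sees a slightly different bootstrap sample, individual trees make different errors, and combining them tends to cancel part of those errors out.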

When applied to classification problems, each tree predicts the class of an observation (e.g., healthy or T2D) based on the values of the features (e.g., microbial taxa). At each node, a tree chooses the split that most reduces how mixed the classes are, often measured with metrics like entropy or Gini impurity.
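Gini impurity and entropy can be computed directly in base R. The sketch below uses a hypothetical node with 4 T2D and 4 healthy samples and evaluates one candidate split on a single taxon; the threshold of 5 is an arbitrary choice for illustration.

```r
# Gini impurity: 1 - sum(p_k^2); entropy: -sum(p_k * log2(p_k))
gini <- function(labels) {
    p <- table(labels) / length(labels)
    1 - sum(p^2)
}
entropy <- function(labels) {
    p <- table(labels) / length(labels)
    p <- p[p > 0]
    -sum(p * log2(p))
}

# Impurity after a split: child impurities weighted by child size
split_impurity <- function(labels, goes_left, impurity = gini) {
    n <- length(labels)
    left  <- labels[goes_left]
    right <- labels[!goes_left]
    (length(left) / n) * impurity(left) + (length(right) / n) * impurity(right)
}

# Hypothetical node: diagnosis and one taxon's abundance for 8 samples
status <- c("T2D", "T2D", "T2D", "T2D", "healthy", "healthy", "healthy", "healthy")
taxon  <- c(9, 8, 7, 2, 3, 1, 2, 1)

gini(status)                               # 0.5: classes maximally mixed
split_impurity(status, taxon > 5)          # 0.2: lower impurity after the split
split_impurity(status, taxon > 5, entropy)
```

A tree evaluates many candidate thresholds and features in this way and keeps the split with the lowest weighted impurity.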

Below, the RF algorithm is used for a classification task:

# Train random forest for classification
rf_classification <- run_ml(
    dataset = df_classification, 
    method = "rf", 
    outcome_colname = "diagnosis", 
    seed = 1,
    kfold = 2, 
    cv_times = 2, 
    training_frac = .8, 
    find_feature_importance = FALSE
)

For regression tasks, each tree predicts a numerical outcome (e.g., BMI) based on the values of the features (e.g., microbial taxa). Each tree splits the data to minimize the difference between predicted and real (observed) values, often using metrics like the mean squared error (MSE). One strength of RF is that it can learn non-linear patterns in the data that simpler regression algorithms might not capture.
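A regression split can be sketched in base R by trying candidate thresholds on one feature and keeping the one with the lowest MSE, where each child node predicts its mean outcome. The taxon abundances and BMI values below are hypothetical.

```r
# Hypothetical data: one taxon's CLR abundance and the BMI of 8 participants
taxon <- c(-1.2, -0.8, -0.5, -0.1, 0.3, 0.6, 1.0, 1.4)
bmi   <- c(21, 22, 21, 23, 27, 28, 26, 29)

# Candidate thresholds: midpoints between consecutive sorted feature values
xs <- sort(unique(taxon))
thresholds <- (head(xs, -1) + tail(xs, -1)) / 2

# Each child node predicts its mean BMI; the split is scored by the MSE
# of those predictions over all samples
split_mse <- sapply(thresholds, function(t) {
    left  <- bmi[taxon <= t]
    right <- bmi[taxon > t]
    (sum((left - mean(left))^2) + sum((right - mean(right))^2)) / length(bmi)
})

best_threshold <- thresholds[which.min(split_mse)]
best_threshold   # 0.1: separates the four lowest from the four highest BMIs
```

Repeating this search recursively within each child node is what lets a tree, and therefore a forest, capture non-linear relationships.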

Regarding its implementation, notice that just by changing the outcome from a categorical to a continuous variable, RF models can now be used in regression tasks:

# Train random forest for regression
rf_regression <- run_ml(
    dataset = df_regression, 
    method = "rf", 
    outcome_colname = "BMI", 
    seed = 1,
    kfold = 2, 
    cv_times = 2, 
    training_frac = .8, 
    find_feature_importance = FALSE
)

24.4.2 XGBoost

Extreme Gradient Boosting (XGBoost) is another ensemble algorithm, in which decision trees are built sequentially. In this strategy, each new tree is trained to correct the errors made by the trees built before it.

When applied to classification, the trees also use feature values (e.g., microbial taxa) to separate the classes (e.g., healthy or T2D). However, instead of independent trees ‘voting’ as in random forest, XGBoost fits each new tree to the errors (more precisely, the gradients of the loss function) of the ensemble built so far. As a result, misclassified or poorly predicted observations have the greatest influence on the next tree.
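The sequential idea behind boosting can be sketched in base R. This is a deliberately simplified gradient boosting for a numeric outcome with squared-error loss, using one-split trees ("stumps") on hypothetical data; XGBoost itself adds regularization, second-order gradient information and many other refinements, and for classification it uses the gradients of a classification loss instead of plain residuals.

```r
# Hypothetical 1-D data with a non-linear relationship
set.seed(10)
x <- seq(0, 10, length.out = 100)
y <- sin(x) + rnorm(100, sd = 0.2)

# A regression stump: the single split on x that minimises squared error
fit_stump <- function(x, r) {
    cuts <- unique(x)[-1]
    sse <- sapply(cuts, function(cut) {
        left <- r[x < cut]
        right <- r[x >= cut]
        sum((left - mean(left))^2) + sum((right - mean(right))^2)
    })
    cut <- cuts[which.min(sse)]
    list(cut = cut, left = mean(r[x < cut]), right = mean(r[x >= cut]))
}
predict_stump <- function(s, x) ifelse(x < s$cut, s$left, s$right)

# Boosting: start from the mean; each new stump is fit to the current
# residuals (the negative gradient of squared error) and added with a
# small learning rate, so each tree corrects the errors of the previous ones
learning_rate <- 0.1
pred <- rep(mean(y), length(y))
mse_history <- numeric(100)
for (i in 1:100) {
    residuals <- y - pred
    stump <- fit_stump(x, residuals)
    pred <- pred + learning_rate * predict_stump(stump, x)
    mse_history[i] <- mean((y - pred)^2)
}
```

The training error in `mse_history` decreases with every added tree; in practice, the number of trees and the learning rate are tuned (e.g., by cross-validation) so that this decrease does not turn into overfitting.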

It must be noted that XGBoost is a more complex model, often described as one of the best-performing models for tabular data (Shwartz-Ziv and Armon 2022), such as the count tables used in microbiome data analysis.

Shwartz-Ziv, Ravid, and Amitai Armon. 2022. “Tabular Data: Deep Learning Is Not All You Need.” Information Fusion 81: 84–90. https://doi.org/10.1016/j.inffus.2021.11.011.
# Train XGBoost for classification
xgb <- run_ml(
    dataset = df_classification, 
    method = "xgbTree", 
    outcome_colname = "diagnosis", 
    seed = 1,
    kfold = 2, 
    cv_times = 2, 
    training_frac = .8,
    find_feature_importance = FALSE
)

Although not demonstrated in this chapter, XGBoost can be used in regression tasks with code similar to the one shown above: only the dataset and the outcome column change. See the RF regression example for reference.

24.5 Model performance metrics

Under the hood, the function run_ml splits the data into a training and a test set; with training_frac = 0.8, as used above, this is an 80/20 split. That means that after the model is trained (i.e., learns patterns) on 80% of the data, the remaining 20% (not seen by the model during training) is used to assess how well the model generalizes the patterns it learned to new data.

Therefore, the goals of this section are to discuss different metrics used to assess model performance, to compare the metrics of two models used to classify the same observations, and to contrast the results of these models with previous analyses in the published literature.

Tip: Assess model performance with multiple data splits

Notice that each call to run_ml performs only one 80/20 split. Thus, the output metrics represent the performance of the model in just one of many possible splits. Researchers are often interested in generating multiple splits, calculating the performance of the models trained on each split, and then looking at the variability across iterations. This gives a more reliable assessment of the model’s performance. That approach is discussed in the documentation of mikropml.
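The idea of repeating the split can be sketched in base R. This is not mikropml's procedure: the data are hypothetical, a logistic regression (`glm()`) stands in for the trained model, and the AUC is computed with the rank-based (Mann-Whitney) formula.

```r
# Hypothetical data: 200 samples, one informative and one noise feature
set.seed(2024)
n <- 200
toy <- data.frame(signal = rnorm(n), noise = rnorm(n))
toy$status <- rbinom(n, 1, plogis(1.5 * toy$signal))

# Rank-based AUC: probability that a random positive outranks a random negative
auc <- function(score, truth) {
    r <- rank(score)
    n_pos <- sum(truth == 1)
    n_neg <- sum(truth == 0)
    (sum(r[truth == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Repeat the 80/20 split with 20 different seeds and record the test AUC
test_auc <- sapply(1:20, function(seed) {
    set.seed(seed)
    train_idx <- sample.int(n, size = 0.8 * n)
    fit <- glm(status ~ signal + noise, family = binomial, data = toy[train_idx, ])
    p <- predict(fit, newdata = toy[-train_idx, ], type = "response")
    auc(p, toy$status[-train_idx])
})

summary(test_auc)   # spread of the test AUC across the 20 splits
sd(test_auc)
```

Reporting the mean together with the spread (e.g., standard deviation or range) across splits conveys how much a single-split estimate could have varied.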

24.5.1 Classification metrics

Two models (RF and XGBoost) were used to perform a classification task on the same microbiome dataset. Since both were used for classification, the same performance metrics are reported for both:

Metrics of the RF model:

# RF performance in classification tasks
rf_classification[["performance"]] |> knitr::kable()
cv_metric_AUC logLoss AUC prAUC Accuracy Kappa F1 Sensitivity Specificity Pos_Pred_Value Neg_Pred_Value Precision Recall Detection_Rate Balanced_Accuracy method seed
0.6579 0.5856 0.8061 0.7497 0.75 0.5015 0.75 0.7105 0.7941 0.7941 0.7105 0.7941 0.7105 0.375 0.7523 rf 1

Metrics of the XGBoost model:

# XGBoost performance in classification tasks
xgb[["performance"]] |> knitr::kable()
cv_metric_AUC logLoss AUC prAUC Accuracy Kappa F1 Sensitivity Specificity Pos_Pred_Value Neg_Pred_Value Precision Recall Detection_Rate Balanced_Accuracy method seed
0.6894 0.6015 0.7825 0.7517 0.6944 0.387 0.7105 0.7105 0.6765 0.7105 0.6765 0.7105 0.7105 0.375 0.6935 xgbTree 1

A common metric to quickly assess model performance in binary classification tasks is the area under the receiver operating characteristic (ROC) curve, or ‘AUC’. Notice that the result tables include two types of AUC metrics: cv_metric_AUC, the AUC estimated by cross-validation within the 80% of data used for training, and AUC, the AUC on the 20% used for testing. When the performance of the model during training greatly exceeds its performance on the test set, the model is likely overfitted, meaning that it fails to generalize to new data. Small drops in performance on the test data compared with the training data are expected, however, and thus shouldn’t be a concern.
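The AUC has a convenient interpretation that can be computed directly: it is the probability that a randomly chosen positive sample (here, a T2D patient) receives a higher predicted score than a randomly chosen negative one. The base-R sketch below uses hypothetical predicted probabilities for 6 test samples.

```r
# AUC as the fraction of (positive, negative) pairs ranked correctly;
# ties count as half a correct pair
auc_pairwise <- function(score, is_positive) {
    pos <- score[is_positive]
    neg <- score[!is_positive]
    comparisons <- outer(pos, neg, function(p, q) (p > q) + 0.5 * (p == q))
    mean(comparisons)
}

# Hypothetical predicted probabilities of T2D for 6 test samples
prob_t2d <- c(0.9, 0.8, 0.4, 0.7, 0.3, 0.2)
is_t2d   <- c(TRUE, TRUE, TRUE, FALSE, FALSE, FALSE)

auc_pairwise(prob_t2d, is_t2d)  # 8 of 9 pairs correctly ordered: 0.8888889
```

An AUC of 0.5 corresponds to random ordering and 1 to perfect separation, which is why values around 0.8, as obtained above, indicate a reasonably good ranking of patients versus healthy individuals.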

Now let’s take a look at the results. First, the AUC values suggest relatively good performance, considering that predicting outcomes from complex microbiome data is typically a challenging task. In addition, the performances of RF and XGBoost are similar (RF AUC = 0.8061 and XGBoost AUC = 0.7825). This illustrates an important point in ML: simpler models can sometimes perform as well as more complex ones. Thus, it is often a good idea to deploy different models and compare their performance on the same task.

Now let’s interpret our results in light of previous research. In the largest multi-cohort analysis of gut microbiome associations with T2D published to date (Mei et al. 2024), the authors report AUC values similar to ours when training RF models to discriminate between healthy participants and T2D patients using microbiome and anthropometric variables like age, sex and BMI. Interestingly, they included the study we are working with in their analysis. Notice that they reported an AUC of 0.74, which is close to ours. This further supports the potential of the gut microbiome to discriminate between healthy individuals and T2D patients.

Mei, Zhendong, Fenglei Wang, Amrisha Bhosle, Danyue Dong, Raaj Mehta, Andrew Ghazi, Yancong Zhang, et al. 2024. “Strain-Specific Gut Microbial Signatures in Type 2 Diabetes Identified in a Cross-Cohort Analysis of 8,117 Metagenomes.” Nature Medicine 30 (8): 2265–76. https://doi.org/10.1038/s41591-024-03067-7.

Although the task we addressed required binary classification (i.e., classifying samples either as patients or controls), some research questions will require classifying observations into more than two groups, a task known as multiclass classification. For instance, we might have been interested in classifying participants as (1) healthy individuals, (2) T2D patients with obesity, and (3) T2D patients without obesity by integrating BMI information. In such cases, metrics like AUC (developed for binary classification) can be generalized, as in the case of the multiclass One-vs-Rest AUC. This approach evaluates how well each class is distinguished from all the others, enabling the assessment of model performance in multiclass problems. Multiclass classification can be implemented with run_ml too, as described in the mikropml documentation.
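The One-vs-Rest AUC can be sketched in base R: each class in turn is treated as the positive class (all others as negative), samples are scored by that class's predicted probability, and the per-class AUCs are averaged (a macro average, chosen here for simplicity). The probabilities and the three class names are hypothetical, mirroring the example above.

```r
# Rank-based binary AUC (Mann-Whitney formulation)
auc_binary <- function(score, is_positive) {
    r <- rank(score)
    n_pos <- sum(is_positive)
    n_neg <- sum(!is_positive)
    (sum(r[is_positive]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# One-vs-Rest: one binary AUC per class, then the unweighted mean
auc_one_vs_rest <- function(prob, truth) {
    per_class <- sapply(colnames(prob), function(k) auc_binary(prob[, k], truth == k))
    list(per_class = per_class, macro_auc = mean(per_class))
}

# Hypothetical predicted class probabilities for 6 samples (rows sum to 1)
prob <- matrix(
    c(0.70, 0.20, 0.10,
      0.60, 0.30, 0.10,
      0.45, 0.25, 0.30,
      0.10, 0.70, 0.20,
      0.20, 0.20, 0.60,
      0.10, 0.30, 0.60),
    ncol = 3, byrow = TRUE,
    dimnames = list(NULL, c("healthy", "T2D_obese", "T2D_lean"))
)
truth <- c("healthy", "healthy", "T2D_obese", "T2D_obese", "T2D_lean", "T2D_lean")

auc_one_vs_rest(prob, truth)  # per class: 1, 0.75, 1; macro AUC: 0.9166667
```

The per-class values are often more informative than the average, since they reveal which class the model struggles to separate (here, T2D_obese).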

Tip: Always use multiple performance metrics

Regardless of the type of classification task performed, it is often desirable to look at different metrics to accurately assess the model’s performance. This is particularly relevant in cases where classes are imbalanced. Suppose a dataset consists of 95% healthy individuals, and only 5% T2D patients. In such a case, a model can easily achieve 95% accuracy by just predicting (labelling) all samples as healthy, despite being useless for the classification of T2D patients.

This is an extreme case, and hopefully we won’t encounter datasets like that. However, it highlights that class imbalance can lead to misleading interpretations of performance metrics like accuracy and AUC. Thus, when dealing with imbalanced classes, other metrics like the F1-score and the area under the precision-recall curve (prAUC) might be more appropriate.
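The 95/5 example above can be reproduced with a few lines of base R on hypothetical labels: a model that labels everyone as healthy reaches 95% accuracy, while its recall and F1-score for the T2D class are zero. Treating F1 as 0 when no positives are predicted is a common convention adopted here.

```r
# Hypothetical imbalanced test set: 95 healthy and 5 T2D individuals
lvls  <- c("healthy", "T2D")
truth <- factor(c(rep("healthy", 95), rep("T2D", 5)), levels = lvls)

# A useless model that labels every sample as healthy
pred <- factor(rep("healthy", 100), levels = lvls)

accuracy <- mean(pred == truth)

# Metrics for the minority class (T2D) treated as the positive class
tp <- sum(pred == "T2D" & truth == "T2D")
fp <- sum(pred == "T2D" & truth == "healthy")
fn <- sum(pred == "healthy" & truth == "T2D")
recall    <- tp / (tp + fn)                            # sensitivity
precision <- if (tp + fp == 0) NA else tp / (tp + fp)  # undefined: no T2D predicted
f1 <- if (is.na(precision) || precision + recall == 0) {
    0
} else {
    2 * precision * recall / (precision + recall)
}

c(accuracy = accuracy, recall = recall, f1 = f1)  # 0.95, 0, 0
```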

In addition to relying on other performance metrics, different strategies for handling class-imbalanced datasets have been discussed (Papoutsoglou et al. 2023) and applied (Diez Lopez et al. 2022) in microbiome data analysis.

Diez Lopez, Celia, Diego Montiel Gonzalez, Athina Vidaki, and Manfred Kayser. 2022. “Prediction of Smoking Habits from Class-Imbalanced Saliva Microbiome Data Using Data Augmentation and Machine Learning.” Frontiers in Microbiology 13. https://doi.org/10.3389/fmicb.2022.886201.
Papoutsoglou, Georgios, Sonia Tarazona, Marta B. Lopes, Thomas Klammsteiner, Eliana Ibrahimi, Julia Eckenberger, Pierfrancesco Novielli, et al. 2023. “Machine Learning Approaches in Microbiome Research: Challenges and Best Practices.” Frontiers in Microbiology 14. https://doi.org/10.3389/fmicb.2023.1261889.

24.5.2 Regression metrics

When RF is used in regression tasks, the performance metrics differ from those discussed for classification.

# RF performance in regression tasks
rf_regression[["performance"]] |> knitr::kable()
cv_metric_RMSE RMSE Rsquared MAE method seed
3.722 3.126 0.0553 2.664 rf 1

A commonly used metric to quickly assess model performance in regression tasks is the root mean square error (RMSE). As its name suggests, this metric is the square root of the mean squared difference between the observed and predicted values of the outcome variable (participants’ BMI in this example). Lower values indicate better performance, and the RMSE is expressed in the same units as the outcome (here, kg/m²).
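The three regression metrics in the table above can be computed by hand. The sketch below uses hypothetical observed and predicted BMI values for 5 participants; R-squared is computed as the squared Pearson correlation between observed and predicted values, which is how caret's `postResample()` reports it.

```r
# Hypothetical observed and predicted BMI values for 5 test participants
observed  <- c(22.1, 25.4, 19.8, 28.3, 24.0)
predicted <- c(23.5, 24.1, 23.0, 25.2, 24.3)

errors <- observed - predicted
rmse <- sqrt(mean(errors^2))          # root mean square error
mae  <- mean(abs(errors))             # mean absolute error
rsq  <- cor(observed, predicted)^2    # squared correlation

round(c(RMSE = rmse, MAE = mae, Rsquared = rsq), 3)
```

Because errors are squared before averaging, the RMSE penalizes a few large errors more than the MAE does; comparing both gives a sense of whether errors are uniform or driven by a few participants.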

Similarly to AUC in classification, the results table includes two types of RMSE metrics: cv_metric_RMSE, estimated by cross-validation within the 80% of data used for training, and RMSE, the metric for the 20% used for testing. Since lower RMSE values are better, the test RMSE (3.126) being slightly lower than the training RMSE (3.722) means that the model did not lose performance on new data. As discussed in the classification results, big drops in the model’s performance on the test data compared with the training data are indicative of overfitting, which is not the case here.

24.6 Visualizing model’s performance

We explored the different performance metrics used in classification and regression tasks. However, researchers are often interested in generating visual representations of the performance of their models. Thus, the goal of this section is to show code to create those visualizations, as well as to provide insights into their interpretation.

24.6.1 Classification

The area under the receiver operating characteristic (ROC) curve, also called ‘AUC’, was introduced above as a common metric to assess a model’s performance in binary classification tasks. As its name suggests, AUC is the area under a curve. Since different ROC curves can have similar AUCs, visualizing the curve itself can give complementary information.

Complementary to ROC curves, precision-recall curves (PRC) are preferred for datasets with class imbalance, where AUC can be misleading, as discussed above.

The code below is used to calculate the metrics required to generate both curves for the two models generated (RF and XGBoost):

# Calculate RF model metrics required for plotting
rf_senspec <- calc_model_sensspec(
        rf_classification[["trained_model"]],
        rf_classification[["test_data"]]
)
# Add model label to data
rf_senspec$model <- "rf"

# Calculate XGBoost model metrics required for plotting
xgb_senspec <- calc_model_sensspec(
        xgb[["trained_model"]],
        xgb[["test_data"]]
)
# Add model label to data
xgb_senspec$model <- "XGBoost"

# Combine model metrics
senspec <- rbind(rf_senspec, xgb_senspec)

# Inspect part of the output
senspec |> head() |> knitr::kable()
T2D healthy actual tp fp sensitivity fpr specificity precision model
109 0.870 0.130 healthy 0 1 0.0000 0.0263 0.9737 0.0000 rf
328 0.760 0.240 T2D 1 1 0.0294 0.0263 0.9737 0.5000 rf
291 0.714 0.286 healthy 1 2 0.0294 0.0526 0.9474 0.3333 rf
158 0.700 0.300 T2D 2 2 0.0588 0.0526 0.9474 0.5000 rf
119 0.696 0.304 T2D 3 2 0.0882 0.0526 0.9474 0.6000 rf
326 0.696 0.304 T2D 4 2 0.1176 0.0526 0.9474 0.6667 rf

The metrics ‘Sensitivity’ and ‘Specificity’ are used in ROC curves (plotted as Sensitivity against FPR, i.e., 1 - Specificity), and ‘Sensitivity’ and ‘Precision’ are used in PRC curves. Note that ‘Sensitivity’ and ‘Recall’ are the same metric: while the term ‘Sensitivity’ is preferred in the biomedical literature, ‘Recall’ is more prevalent in other fields.
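The columns of the table above (tp, fp, sensitivity, fpr, precision) can be reproduced with a base-R sketch. The predicted probabilities and labels below are hypothetical, T2D is treated as the positive class, and this is an illustration of the idea rather than the code of calc_model_sensspec: samples are sorted by predicted probability, the threshold is lowered one sample at a time, and true and false positives are counted cumulatively.

```r
# Hypothetical test set: predicted probability of T2D and the true label
prob_t2d <- c(0.87, 0.76, 0.71, 0.70, 0.55, 0.40, 0.35, 0.20)
actual   <- c("healthy", "T2D", "healthy", "T2D", "T2D", "healthy", "T2D", "healthy")

# Sort from most to least likely T2D and accumulate TP and FP counts
ord <- order(prob_t2d, decreasing = TRUE)
is_pos <- actual[ord] == "T2D"
tp <- cumsum(is_pos)
fp <- cumsum(!is_pos)

roc_pr <- data.frame(
    T2D = prob_t2d[ord],
    actual = actual[ord],
    tp = tp,
    fp = fp,
    sensitivity = tp / sum(is_pos),   # recall: y-axis of ROC, x-axis of PRC
    fpr = fp / sum(!is_pos),          # 1 - specificity: x-axis of ROC
    precision = tp / (tp + fp)        # y-axis of PRC
)

# Trapezoidal area under the ROC curve, starting at the (0, 0) corner
x <- c(0, roc_pr$fpr)
y <- c(0, roc_pr$sensitivity)
roc_auc <- sum(diff(x) * (head(y, -1) + tail(y, -1)) / 2)   # 0.5 for these data
```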

The code below generates the ROC and PRC curves of both models and shows them side by side:

library(ggplot2)
library(patchwork)

# 1. Plot the ROC curve of each model
roc_p <- ggplot(data = senspec) +
    geom_path(aes(x = fpr, y = sensitivity, color = model))
# 1.1 Add line representing 'random guess' performance
roc_p <- roc_p + 
    geom_abline(color = "grey30", linetype = "dashed")
# 1.2 Add axis titles and custom theme
roc_p <- roc_p + 
    labs(x = "FPR or (1 - Specificity)", y = "Sensitivity", title = "ROC") +
    theme_bw()


# 2. Plot the PRC curve of each model
prc_p <- ggplot(data = senspec) +
    geom_path(aes(x = sensitivity, y = precision, color = model)) 
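# 2.1 Add line representing 'random guess' performance: for a PRC,
# this is the proportion of positive (T2D) samples in the test set
prc_baseline <- mean(rf_classification[["test_data"]][["diagnosis"]] == "T2D")
prc_p <- prc_p +
    geom_hline(color = "grey30", linetype = "dashed", yintercept = prc_baseline)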
# 2.2 Add axis titles and custom theme
prc_p <- prc_p +
    labs(x = "Recall or Sensitivity", y = "Precision", title = "PRC") +
    theme_bw()


# 3. Plot ROC and PRC side by side
roc_p + prc_p + plot_layout(guides = "collect")

Before describing the plots and their meaning, it is worth noting that the ROC curves of both models resemble the curve presented in the article where this dataset was first analysed (Qin et al. 2012) (see Figure 4B). Interestingly, the authors used a different supervised ML algorithm, trained on a set of 50 microbial genes (instead of taxa and alpha diversity metrics, as we did). Even so, concordant AUCs and ROC curve shapes were obtained using different microbiome-derived information.

Qin, Junjie, Yingrui Li, Zhiming Cai, Shenghui Li, Jianfeng Zhu, Fan Zhang, Suisha Liang, et al. 2012. “A Metagenome-Wide Association Study of Gut Microbiota in Type 2 Diabetes.” Nature 490: 55–60. https://doi.org/10.1038/nature11450.

Regarding our figures, note the dashed grey lines in both plots, which represent the expected performance of a model that classifies samples randomly (for the PRC, this is the proportion of T2D patients in the test set, close to 0.5 in this roughly balanced dataset). Therefore, the farther a model’s curve lies above that reference, the better; perfect performance is achieved when the curve is as far above it as possible.
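Why the two reference lines sit where they do can be checked with a small simulation in base R. The labels are hypothetical, with a 30% prevalence of the positive class, and the "model" assigns random scores unrelated to the labels. Its ROC AUC is close to 0.5 (the diagonal), while its precision is close to the prevalence rather than to 0.5, so a horizontal PRC reference at 0.5 is only appropriate when classes are roughly balanced.

```r
set.seed(99)
n <- 10000
# Hypothetical labels: 30% positives
is_pos <- runif(n) < 0.3
# A 'random guess' model: scores unrelated to the labels
score <- runif(n)

# ROC: rank-based AUC of the random model, close to 0.5
r <- rank(score)
n_pos <- sum(is_pos)
n_neg <- sum(!is_pos)
auc_random <- (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

# PRC: precision at an (arbitrary) threshold of 0.5, close to the prevalence
precision_top_half <- mean(is_pos[score > 0.5])

c(auc = auc_random, precision = precision_top_half, prevalence = mean(is_pos))
```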

In our example, a model with perfect performance is such that, by learning patterns from the microbiome data, it can correctly identify T2D patients with perfect sensitivity (i.e., all T2D patients are classified as such) while maintaining a false positive rate (FPR) of 0 (i.e., no healthy individual is misclassified as T2D). In the ROC plot, this means the model consistently achieves a Sensitivity of 1 and a FPR of 0, pushing the curve towards the top left corner.

In terms of the PRC plot, the perfect model would consistently achieve a Recall and Precision of 1, pushing the curve towards the top right corner. That, in turn, would mean that the model has perfect Sensitivity (i.e., all T2D patients are classified as such), while maintaining a perfect Precision (i.e., every individual classified as T2D is actually a T2D patient).

Finally, note that the ROC curves of RF and XGBoost have a similar shape. This suggests a similar performance in classifying T2D patients. This isn’t surprising as the AUC of both models were very similar too (RF AUC = 0.8061 and XGBoost AUC = 0.7825).

Note: Visualization of multiclass ROC curves

It was discussed above that ROC curves (and AUCs) can be extended to multiclass classification tasks. One example of such an approach is an article (Su et al. 2022) in which the authors used gut microbiome data to train several supervised classification algorithms to discriminate between patients with 9 different diseases. Notice that multiclass ROC curves can be visualized in a similar way to what was described in this chapter. In addition, the R package pROC provides handy functions to build multiclass visualizations.

Su, Qi, Qin Liu, Raphaela Iris Lau, Jingwan Zhang, Zhilu Xu, Yun Kit Yeoh, Thomas W. H. Leung, et al. 2022. “Faecal Microbiome-Based Machine Learning for Multi-Class Disease Diagnosis.” Nature Communications 13 (1). https://doi.org/10.1038/s41467-022-34405-3.

24.6.2 Regression

In this chapter, RF was used to learn patterns in the gut microbiome to predict the BMI of the participants. The most common diagnostic visualization of a regression model’s performance compares observed against predicted BMI values.

The code below is used to generate that plot using the observations in the test data.

# Get the test set: data not used in training the model
test_data <- rf_regression[["test_data"]]

# Get the trained model
model <- rf_regression[["trained_model"]][["finalModel"]]

# Use the model to predict BMI values in the test set
pred <- predict(model, test_data)

# Add predicted BMI values to the test dataset
test_data$pred <- pred
# Add diagnosis status to test data.
# This variable was created when preprocessing for classification
ids <- rownames(test_data) |> as.numeric()
test_data$diagnosis <- labels[ids]

table(test_data$diagnosis)
##  
##      T2D healthy 
##       30      38

# Plot actual (y-axis) vs predicted (x-axis) BMI values
obs_vs_pred <- ggplot(data = test_data) +
    geom_point(aes(x = pred, y = BMI, color = diagnosis), size = 2)
# Add line showing perfect prediction
obs_vs_pred <- obs_vs_pred + geom_abline(color = "grey30", linetype = "dashed")
# Make axes of same scale
obs_vs_pred <- obs_vs_pred + coord_equal(xlim = c(17, 30), ylim = c(17, 30))
# Custom axes titles
obs_vs_pred <- obs_vs_pred + labs(x = "Predicted BMI", y = "Observed BMI")
# Add custom theme and visualize the plot
obs_vs_pred + theme_bw()

The dashed grey line in the plot above marks perfect agreement between the observed and model-predicted BMI values of each participant (i.e., predicted = observed). Thus, the line indicates perfect performance of the model. We can see that while the predictions cluster around the mean BMI (close to 24), the observed values range approximately between 17 and 29, showing poor agreement between observed and predicted BMI values. Note that both groups show a similar spread in BMI.

This plot clearly indicates poor performance of our model, consistent with the very low R-squared (0.0553) reported above. Poor performance usually means either that the gut microbiome is not related to BMI, making it impossible for the model to use microbiome patterns to predict BMI, or that our model is overfitted. Since the RMSE for training (3.722) and test (3.126) data are similar, overfitting seems unlikely, and we may conclude that, in this dataset, the gut microbiome features used here are not informative of BMI.

This illustrates another important aspect of supervised ML: regardless of the complexity of the models deployed, their predictive performance depends on their ability to learn patterns from the data. If such patterns don’t exist, the model cannot make accurate predictions. This is most likely what happened in this regression task. Notice that this can certainly happen in classification tasks too.

24.7 Model interpretation

Besides the performance of the model, researchers are often interested in understanding which features drive the model’s predictions, a property of ML models called interpretability.

ML models have different degrees of interpretability. One way to understand how different features (e.g., microbial taxa) affect a model’s performance is to randomly permute the values of one feature across samples and then measure the change in a performance metric. If the model relies on that feature, shuffling it breaks its relationship with the outcome and the metric drops; if not, the metric barely changes.

For simplicity, the code below shows how to determine feature importance only with the RF model trained for classification. However, minimal changes are required to determine feature importance for other models built in this chapter, like the XGBoost model used in classification or the RF model used in regression.

# estimate features importance
feat_imp <- get_feature_importance(
    trained_model = rf_classification[["trained_model"]],
    test_data = rf_classification[["test_data"]],
    outcome_colname = "diagnosis",
    perf_metric_function = caret::multiClassSummary,
    perf_metric_name = "AUC",
    class_probs = TRUE,
    method = "rf",
    seed = 1,
    nperms = 5 # 1/20 of default to speed calculations up
)

# Identify the 10 most important features
ordered_features <- order(feat_imp$perf_metric_diff, decreasing = TRUE)
ordered_features <- ordered_features[1:10]

# Retain only those features
feat_imp <- feat_imp[ordered_features, ]

# Convert column 'feat' into a factor to fix its order in the plot
feat_imp$feat <- factor(feat_imp$feat, levels = rev(feat_imp$feat))

# Plot the decrease in AUC caused by permuting each feature
ggplot(feat_imp) +
    geom_col(aes(x = perf_metric_diff, y = feat)) +
    labs(x = "Decrease in performance (AUC)", y = "Feature") +
    theme_classic()
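The permutation idea behind these importance values can be sketched in base R on hypothetical data. This is not mikropml's implementation: a logistic regression (`glm()`) stands in for the trained model, `perm_importance()` is an illustrative helper, and the AUC drop is averaged over 20 permutations of each feature in the test set.

```r
# Hypothetical data: 'signal' drives the outcome, 'noise' does not
set.seed(5)
n <- 300
toy <- data.frame(signal = rnorm(n), noise = rnorm(n))
toy$status <- rbinom(n, 1, plogis(2 * toy$signal))
train <- toy[1:200, ]
test  <- toy[201:300, ]

# Rank-based AUC
auc <- function(score, truth) {
    r <- rank(score)
    n_pos <- sum(truth == 1)
    n_neg <- sum(truth == 0)
    (sum(r[truth == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Stand-in model and its AUC on the untouched test set
fit <- glm(status ~ signal + noise, family = binomial, data = train)
base_auc <- auc(predict(fit, test, type = "response"), test$status)

# Shuffle one feature in the test set, re-predict, and average the AUC drop
perm_importance <- function(feature, n_perm = 20) {
    drops <- replicate(n_perm, {
        shuffled <- test
        shuffled[[feature]] <- sample(shuffled[[feature]])
        base_auc - auc(predict(fit, shuffled, type = "response"), test$status)
    })
    mean(drops)
}

importance <- sapply(c("signal", "noise"), perm_importance)
importance   # large drop for 'signal', close to zero for 'noise'
```

Since the model is not refitted, permutation importance measures how much this particular trained model depends on each feature, not whether the feature is causally related to the outcome.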

Important: A call to avoid oversimplifications

Feature importance plots are often interpreted in the literature with a reductionist mindset. Thus, researchers often conclude that a single feature (e.g., a bacterial taxon) is responsible for a complex clinical outcome of interest. However, more often than not, microbes interact with one another, and thus the effect of a single taxon doesn’t always inform about ecosystem-level properties of the microbiome. It is important to bear that in mind when interpreting the results of our models.

Exercises

Goal: After completing these exercises, you should be able to fit a supervised machine learning model using multi-assay data.

Exercise 1: Supervised ML

  1. Load any of the example datasets mentioned in Section 4.2.

  2. Observe colData and check that the metadata includes outcome variables that you want to model.

  3. Visualize the selected outcome variable with a histogram or a bar plot. What is the distribution? If the distribution is skewed or imbalanced, how can this affect the training of the model?

  4. Apply CLR transformation.

  5. Preprocess data by removing features with near-zero variance and by grouping correlated features.

  6. Fit a random forest model with find_feature_importance = TRUE.

  7. Visualize the results. Does the model perform well, i.e., does it predict the outcome with high accuracy?

  8. What features are the most important for predicting the outcome?

Useful functions:

data(), colData(), plotHistogram(), plotBarplot(), subsetByPrevalent(), transformAssay(), mikropml::preprocess_data(), mikropml::run_ml(), mikropml::plot_model_performance()
