knitr::opts_chunk$set(dpi=300, collapse = T, comment = "#>")
options(tibble.print_min = 6L, tibble.print_max = 6L)
knitr::opts_chunk$set(fig.align="center")
library(MagmaClustR)
library(dplyr)
library(ggplot2)
if(rlang::is_installed("gridExtra")){library(gridExtra)}

Data and purpose

Context

To explore the features of MagmaClust, we use the weight dataset (available here), coming from the GUSTO cohort study (https://www.gusto.sg/) and more thoroughly studied here. This dataset contains `r nrow(weight)` weight measurements of Singaporean children (`r length(unique(weight %>% filter(sex=="Male") %>% pull(ID)))` boys and `r length(unique(weight %>% filter(sex=="Female") %>% pull(ID)))` girls, respectively).

Throughout this example, we use MagmaClust to model and forecast the evolution of children’s weight during early childhood. This extension of the Magma model has been proposed to take into account the presence of group structures within the data (a similar vignette for Magma is available here) and thereby enhance predictions.

More generally, our aim is to train a model on a dataset containing multiple individuals and predict a new individual’s weight thanks to shared information within clusters.

To get a better idea of the weight dataset content, we display an illustration of the observed data for 5 children. Do the children’s weights evolve according to the same pattern? Can we identify group structures in the data, and define appropriate clusters? These are the kinds of questions we aim to tackle in the following.

set.seed(10)
list_ID <- weight %>% pull(ID) %>%  sample()

ggplot2::ggplot(data=weight %>% filter(ID %in% list_ID[1:5]),
                ggplot2::aes(x=Input,y=Output,colour=factor(ID)))+
  ggplot2::geom_point(size=3) +
  ggplot2::theme_classic() +
  ggplot2::guides(colour="none")

This plot underlines the regularity of our functional data. The values of weight are observed at the same ages for all children, as measurements are collected in the context of a standardised cohort study. Let us note that MagmaClust can also be used in the case of irregular input grids without any change.

Overall, we can notice some differences in children's profiles; some are plump at birth before their weight gain slows down quickly, while others may be born slightly thinner but grow faster. Therefore, it might be important to take these profiles into account if one wants to improve the quality of weight predictions.

Data formatting

The weight dataset contains 4 columns: ID, sex, Input and Output. Each row corresponds to the weight measured for one child at a given age. More specifically, each column contains:

- ID, the identifying number associated with each child;
- sex, a factor indicating the biological gender of the child (Male or Female);
- Input, the age of the child (in months);
- Output, the weight of the child (in kilograms).

knitr::kable(head(weight)) 

If new variables (such as height, morphology, percentage of muscle mass, ...) had been observed, any additional column would be treated as a covariate, resulting in a model with multi-dimensional inputs.

Before starting to use MagmaClust, we must ensure that our dataset contains at least the three following mandatory columns with adequate types:

- ID: character or factor;
- Input: numeric;
- Output: numeric.
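As a quick sanity check before training, we can verify these requirements programmatically. The helper below is a minimal sketch (the function name `check_magma_format` is ours, not part of the package):

```r
# Sketch: verify the mandatory columns and their types before training.
# check_magma_format() is a hypothetical helper, not a MagmaClustR function.
check_magma_format <- function(data) {
  stopifnot(all(c("ID", "Input", "Output") %in% names(data)))
  stopifnot(is.character(data$ID) || is.factor(data$ID))
  stopifnot(is.numeric(data$Input), is.numeric(data$Output))
  invisible(TRUE)
}

check_magma_format(weight)
```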

Since children's weight is not particularly affected by gender between 0 and 5 years old, we choose to keep all individuals in the same dataset. The sex column is therefore removed below for simplicity.

weight <- weight %>% select(-sex)

Classical pipeline

The overall process of the MagmaClust algorithm can be decomposed into 3 main steps: training, prediction and display of results. The corresponding functions are train_magmaclust(), pred_magmaclust() and plot_magmaclust().

Apply MagmaClust on the weight database

Data organisation

Once the data organisation step is done, we randomly select two types of children: a set of training individuals, used to learn the model, and a new individual (hereafter called James) whose weight we aim to predict.

To limit computation time in this illustrative example, we randomly select 20 children for training. Although the performance of MagmaClust increases with the number of training individuals, 20 are more than enough to get a clear idea of how the algorithm works.

weight_train <- weight %>% filter(ID %in% list_ID[1:20])
weight_pred <- weight %>% filter(ID == list_ID[261]) %>% filter(Input<20)
weight_test <- weight %>% filter(ID == list_ID[261]) %>% filter(Input>20)
ggplot2::ggplot(data = weight_train,
       mapping = ggplot2::aes(x=Input,y=Output,colour=factor(ID)))+
  ggplot2::geom_point(size=1.5,alpha=0.3)+
  ggplot2::geom_point(data = weight_pred,
             size=3,
             shape=17,
             col="black") +
  ggplot2::theme_classic() +
  ggplot2::guides(colour="none")

The triangles correspond to James’ weight at different ages, whereas the coloured dots are the children's data we use for training.

Training

It's now time to train the model thanks to train_magmaclust(), for which several arguments can be specified:

The most commonly used kernels and their properties are discussed in the kernel cookbook. Details of available kernels and how to combine them are available in help(train_magma).

As for any GP method, the initialisation of the HPs may influence the final optimisation and lead to inadequate predictions in pathological cases. Therefore, users may explicitly define specific initial values through the dedicated arguments: ini_hp_k and ini_hp_i.

Other parameters can also be specified; see help(train_magmaclust) for details.
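As an optional illustration of explicit initialisation, the sketch below draws initial hyper-parameters with the package's hp() helper and passes them to training. The exact arguments of hp() and the expected format of ini_hp_k/ini_hp_i may differ across versions; check help(hp) and help(train_magmaclust) before relying on this:

```r
# Sketch (assumptions: hp() accepts kern, list_ID and noise arguments, and
# returns a tibble in the format expected by ini_hp_k / ini_hp_i).
ini_k <- MagmaClustR::hp(kern = "SE", list_ID = paste0("K", 1:3))
ini_i <- MagmaClustR::hp(kern = "SE",
                         list_ID = unique(weight_train$ID),
                         noise = TRUE)

model_clust_custom <- train_magmaclust(data = weight_train,
                                       nb_cluster = 3,
                                       ini_hp_k = ini_k,
                                       ini_hp_i = ini_i)
```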

set.seed(10)
model_clust <- train_magmaclust(data = weight_train,
                        nb_cluster = 3,
                        kern_k = "SE",
                        kern_i = "SE",
                        common_hp_k = TRUE,
                        common_hp_i = TRUE)

Prediction for James

Now that the MagmaClust model is trained, we can predict the evolution of James' weight. To perform prediction, we need to specify two main arguments of the pred_magmaclust() function:

Let us mention that we aim to study weight curves during early childhood and compare our predictions to the true values observed in the dataset. Therefore, the evolution of James’ weight is predicted between 0 and 6 years (72 months), using only his weight data from 0 to 20 months. Hence the argument grid_inputs = seq(0,72,0.1).

pred_clust <- pred_magmaclust(data = weight_pred,
                              trained_model = model_clust,
                              grid_inputs = seq(0,72,0.1),
                              plot = FALSE,
                              get_hyperpost = TRUE)

Here, we specified get_hyperpost = TRUE (to return the hyper-posterior distribution of all mean processes), as it will allow us to plot the average GP of each cluster in the plot_magmaclust() function. However, we do not need it if we merely want to display the individual's prediction.
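Before plotting, it can be instructive to inspect the returned object; element names may vary across package versions, so str() is a safe first step:

```r
# Top-level structure of the prediction object: typically the cluster-specific
# predictions, the cluster membership probabilities and (here) the
# hyper-posterior we requested with get_hyperpost = TRUE.
str(pred_clust, max.level = 1)

# Probabilities of James belonging to each cluster (the 'mixture' element name
# is our reading of the docs; adapt to what str() shows in your version).
pred_clust$mixture
```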

Plots

With the plot_magmaclust() function, we can display the prediction for the evolution of James' weight over time.

plot_magmaclust(pred = pred_clust,
                data = weight_pred,
                prior_mean = pred_clust$hyperpost$mean,
                data_train = weight_train,
                size_data = 5)

In the above figure, we can observe the prediction for James (mean curve and associated uncertainty), the mean processes of the clusters, and the training data in the background.

Customise graphs

In order to customise the graphical representations of results, we first need to distinguish 2 cases:

We can also display training individuals in the same colour as their most probable cluster by using the data_allocate_cluster() function beforehand, and specifying the following arguments: data_train = data_train_with_clust and col_clust = TRUE. In order to evaluate the quality of the prediction, we can also represent (in grey below) the true observations for this individual, which we didn't use in the algorithm and kept in the weight_test variable for testing purposes.

data_train_with_clust <- data_allocate_cluster(model_clust)

plot_magmaclust(pred = pred_clust,
                cluster = "all",
                data = weight_pred,
                data_train = data_train_with_clust,
                prior_mean = pred_clust$hyperpost$mean,
                heatmap = TRUE,
                y_grid = seq(0,37,0.1),
                col_clust = TRUE,
                size_data = 5) + 
  ggplot2::geom_point(data = weight_test,
                      ggplot2::aes(x = Input, y = Output),
                      color = 'grey', size = 3)

MagmaClust vs Magma

The MagmaClustR package also provides an implementation for the Magma algorithm, allowing us to compare MagmaClust and Magma predictions. For more details on Magma, see the dedicated vignette.

The graphs below correspond to James' weight prediction with MagmaClust (left) and Magma (right).

set.seed(11)
model_magma <- train_magma(data = weight_train,
                        kern_0 = "SE",
                        kern_i = "SE",
                        common_hp = FALSE)

pred_magma <- pred_magma(data=weight_pred,
                   grid_inputs = seq(0,72,0.1),
                   trained_model = model_magma,
                   plot = FALSE,
                   get_hyperpost = TRUE) 
gg2 <- plot_gp(pred_gp = pred_magma,
        data = weight_pred,
        prior_mean = pred_magma$hyperpost$mean,
        data_train = weight_train,
        heatmap = TRUE,
        y_grid = seq(3,36,0.1),
        size_data = 5) + 
  ggplot2::geom_point(data = weight_test,
                      ggplot2::aes(x = Input, y = Output),
                      color = 'grey', size = 3) +
  ggplot2::scale_y_continuous(limits = c(0,37))

data_train_with_clust <- data_allocate_cluster(model_clust)

gg1 <- plot_magmaclust(pred = pred_clust,
                data = weight_pred,
                data_train = data_train_with_clust,
                prior_mean = pred_clust$hyperpost$mean,
                heatmap = TRUE,
                y_grid = seq(3,36,0.1),
                col_clust = TRUE,
                size_data = 5) + 
  ggplot2::geom_point(data = weight_test,
                      ggplot2::aes(x = Input, y = Output),
                      color = 'grey', size = 3) +
  ggplot2::scale_y_continuous(limits = c(0,37))

if(rlang::is_installed("gridExtra")){
  gridExtra::grid.arrange(gg1, gg2, ncol = 2)
} else {
  gg1
  gg2
}

On intervals of unobserved timestamps containing data points from the training dataset ($t \in ]20, 72]$), Magma takes advantage of its multi-task component to share knowledge across individuals by estimating a unique mean process. However, this unique mean process appears unable to accurately recover the evolution trend. In this example, MagmaClust offers a significant improvement in this long-term (more than 4 years ahead) forecasting task. By leveraging group structures among children, the algorithm shares more knowledge across individuals that are similar, resulting in more precise and specific predictions based on a mixture of GPs.
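To complement the visual comparison, one could compute a simple error measure on the held-out observations kept in weight_test. The sketch below assumes that each prediction object exposes a tibble with Input and Mean columns (pred_magma$pred for Magma and pred_clust$mixture_pred for MagmaClust); these element names are our reading of the docs, so check str() on your objects first:

```r
# Hedged sketch: mean squared error on held-out test points, assuming the
# prediction tibbles contain 'Input' and 'Mean' columns (verify with str()).
mse_on_test <- function(pred_tbl, test) {
  # for each test age, use the prediction at the closest grid point
  idx <- sapply(test$Input, function(t) which.min(abs(pred_tbl$Input - t)))
  mean((pred_tbl$Mean[idx] - test$Output)^2)
}

mse_on_test(pred_clust$mixture_pred, weight_test)
mse_on_test(pred_magma$pred, weight_test)
```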

Reference

Overall, the multi-task framework combined with a clustering component provided by MagmaClust offers reliable probabilistic predictions for a new individual (a child in our example) on a wide range of timestamps. Moreover, the uncertainty quantification inherent to GP-based methods allows practitioners to maintain an adequate degree of caution in their decision-making process.

For further details, the complete derivation of the algorithm and experiments are published and available in Cluster-Specific Predictions with Multi-Task Gaussian Processes.



ArthurLeroy/MagmaClustR documentation built on Dec. 24, 2024, 11:18 a.m.