clustra: clustering trajectories

knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
start_knit = proc.time()
data.table::setDTthreads(1) # manage data.table threads

The clustra package was built to cluster longitudinal trajectories (time series) on a common time axis. For example, a number of individuals are started on a specific drug regimen and their blood pressure data are collected for varying amounts of time before and after the start of the medication. Observations can be unequally spaced, of unequal length, and only partially overlapping across individuals.

Clustering proceeds by an EM algorithm that alternates between fitting a thin plate spline (TPS) to the combined responses within each cluster (M-step) and reassigning cluster membership based on the nearest fitted TPS (E-step). The fitting is done with the mgcv package function bam, which scales well to very large data sets.
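To make the alternation concrete, here is a minimal sketch of such an EM loop. This is not clustra's actual implementation: the em_sketch name is hypothetical, the id/time/response column names follow the data generated below, and details such as empty-cluster handling (which clustra manages) are omitted.

# Sketch of an EM loop for trajectory clustering (illustration only)
em_sketch = function(data, k, maxdf, max_iter = 10) {
  ids = unique(data$id)
  grp = setNames(sample(k, length(ids), replace = TRUE), ids) # random start
  for (it in seq_len(max_iter)) {
    # M-step: fit one thin plate spline per cluster to its pooled responses
    fits = lapply(seq_len(k), function(g)
      mgcv::bam(response ~ s(time, k = maxdf),
                data = data[data$id %in% names(grp)[grp == g], ]))
    # E-step: mean squared loss of every id against every cluster curve
    loss = sapply(fits, function(f)
      tapply((data$response - predict(f, newdata = data))^2, data$id, mean))
    new_grp = max.col(-loss) # column (cluster) with the smallest loss per id
    changes = sum(new_grp != grp[rownames(loss)])
    grp[rownames(loss)] = new_grp
    if (changes == 0) break # converged: no membership changes
  }
  list(group = grp, fits = fits)
}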

For this vignette, we begin by generating a data set with the gen_traj_data() function. Given its parameters, the function generates groups of ids (their sizes given by the vector n_id) and, for each id, a random number of observations based on the Poisson($\lambda =$ m_obs) distribution plus 3. The 3 additional observations guarantee one before the intervention at time start, one at the intervention time 0, and one after the intervention at time end. The start time is Uniform(s_range) and the end time is Uniform(e_range). The remaining times are Uniform(start, end). The time units are arbitrary and depend on your application. Up to 3 group shapes are implemented so far: sine, sigmoid, and constant.

The code below generates the data and looks at a few observations. The mc variable sets core use and will be assigned to the mccores parameter through the rest of the vignette. By default, 1 core is assigned. Parallel sections are implemented with parallel::mclapply(), so on Unix and Mac platforms it is recommended to use the full number of available cores for faster performance. Default initialization of the clusters is set to "random" (see the clustra help file for the other option, "distant"). We also set the seed for reproducibility.

library(clustra)
mc = 2 # If running on a Unix or Mac platform, increase up to the number of cores
if (.Platform$OS.type == "windows") mc = 1
init = "random"
set.seed(12345)
data = gen_traj_data(n_id = c(500, 1000, 1500, 2000), types = c(2, 1, 3, 2), 
                     intercepts = c(70, 130, 120, 130), m_obs = 25, 
                     s_range = c(-365, -14), e_range = c(0.5*365, 2*365),
                     noise = c(0, 15))
head(data)

The histogram below shows the distribution of generated series lengths (observations per id). The short ones will be the most difficult to cluster correctly.
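A minimal way to draw such a histogram, assuming data is the data.table returned by gen_traj_data() (so .N counts rows within each id):

hist(data[, .N, by = id][["N"]], breaks = 30, main = "",
     xlab = "Number of observations per id")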

Select a few random ids and show their scatterplots.

plot_sample(data[id %in% sample(unique(data[, id]), 9)], group = "true_group")

Next, cluster the trajectories. Set k = 4 (we will consider selection of k later), set the spline maximum degrees of freedom to 10, and set conv to a maximum of 10 iterations with convergence declared when 0 changes occur. mccores sets the number of cores to use in various components of the code. Note that this does not work on Windows operating systems, where it should be set to 1 (the default). In the code that follows, we use verbose output to get information from each iteration.

set.seed(12345)
cl4 = clustra(data, k = 4, maxdf = 10, conv = c(10, 0), mccores = mc,
             verbose = TRUE)

Each iteration displays components of the M-step and the E-step followed by its duration in seconds, the number of classification changes in the E-step, the current counts in each cluster, and the deviance.
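After convergence, the assignments are available in the returned object; cl4$data_group gives the assigned cluster for each row of data (it is compared with data[, true_group] the same way below), so the final sizes can be checked directly:

table(cl4$data_group) # observations per cluster after the final E-step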

Next, plot the raw data (sampled if more than 10,000 points), first without cluster information (group = NULL) and then with the resulting spline fits, colored by cluster.

plot_smooths(data, group = NULL)
plot_smooths(data, cl4$tps)

The Rand index for comparing with true_group is

MixSim::RandIndex(cl4$data_group, data[, true_group])
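In this output, R is the plain Rand index and AR the Adjusted Rand index, which corrects the Rand index for chance agreement:

$$AR = \frac{RI - E[RI]}{\max(RI) - E[RI]},$$

so two unrelated clusterings score near 0 while identical clusterings score 1.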

An AR of 0.827 when comparing with the true groups used to generate the data is quite good, considering that the short series are easily misclassified and that k-means-style algorithms often find a local minimum. Let's double the error standard deviation in the data generation and repeat.

set.seed(12345)
data2 = gen_traj_data(n_id = c(500, 1000, 1500, 2000), types = c(2, 1, 3, 2), 
                     intercepts = c(70, 130, 120, 130), m_obs = 25,
                     s_range = c(-365, -14), e_range = c(60, 2*365), 
                     noise = c(0, 30))
plot_sample(data2[id %in% sample(unique(data2[, id]), 9)], group = "true_group")

cl4a = clustra(data2, k = 4, maxdf = 10, conv = c(10, 0), mccores = mc,
             verbose = TRUE)
MixSim::RandIndex(cl4a$data_group, data2[, true_group])

This time the AR of 0.815 is lower but still respectable. The clustering recovers the trajectory means quite well, as we see in the following plots: the first without cluster colors (obtained by setting group = NULL), showing the mass of points, and the second with cluster means and cluster colors.

plot_smooths(data2, group = NULL)
plot_smooths(data2, cl4a$tps)

The average silhouette value is a way to select the number of clusters, and a silhouette plot provides a way for a deeper evaluation (Rousseeuw 1987). As the silhouette requires distances between individual subjects, which are not available here due to unequal subject sampling (short of fitting a separate trajectory model for each subject id), we use subject distances to cluster mean trajectories as a proxy in the clustra_sil() function. The structure returned by the clustra() function contains the matrix loss, which has all the information needed to construct these proxy silhouette plots. The function clustra_sil() performs clustering for a number of k values and outputs information for the silhouette plots that are displayed next. We relax the convergence criterion in conv to at most 7 iterations and 1% changes (instead of the 0 used earlier) for faster processing. We use the first data set with noise = c(0, 15).
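For each subject, the proxy silhouette width follows the usual definition: if $a$ is the loss against the subject's own (nearest) cluster curve and $b$ is the smallest loss against any other cluster curve, the width is $(b - a)/\max(a, b)$. A minimal sketch of that computation, assuming loss is a matrix with one row per id and one column per cluster (the layout implied above, not necessarily clustra's exact internal code):

sil_widths = function(loss) {
  a = apply(loss, 1, min)                    # loss to the assigned (nearest) cluster
  b = apply(loss, 1, function(x) sort(x)[2]) # loss to the second-nearest cluster
  (b - a) / pmax(a, b)                       # silhouette width in [-1, 1]
}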

set.seed(12345)
sil = clustra_sil(data, kv = c(2, 3, 4, 5), mccores = mc, maxdf = 10,
                  conv = c(7, 1), verbose = TRUE)
lapply(sil, plot_silhouette)

The plots for 3 or 4 clusters give the best average width. Usually we take the larger one, 4, which is supported here also by the minimum AIC and BIC scores. We also note that the final deviance drops substantially at 4 clusters and barely moves when 5 clusters are fit, further corroborating k = 4.

If we don't want to recluster the data, we can reuse a previous clustra run directly and produce a silhouette plot for it, as we do next for the clustra run above whose results are in cl4.

sil = clustra_sil(cl4)
lapply(sil, plot_silhouette)

Another way to select the number of clusters is the Rand index comparing different random starts and different numbers of clusters. When we replicate clustering with different random seeds, the replicability is an indicator of how stable the results are for a given number of clusters k. For this demonstration, we look at k = c(2, 3, 4, 5) and 10 replicates for each k. To run this long-running chunk, set eval = TRUE.

set.seed(12345)
ran = clustra_rand(data, k = c(2, 3, 4, 5), mccores = mc,
                   replicates = 10, maxdf = 10, conv = c(7, 1), verbose = TRUE)
rand_plot(ran)

The plot shows the AR similarity level between all pairs of the 40 clusterings (10 random starts for each of 2, 3, 4, and 5 clusters). It is difficult to distinguish between the 3-, 4-, and 5-cluster results, but the 4-cluster result has the largest block of complete agreement.

Here, we can try running clustra with the "distant" starts option. The sequential selection of above-median length series that are most distant from previous selections introduces less initial variability. To run this long-running chunk, set eval = TRUE.

set.seed(12345)
ran = clustra_rand(data, k = c(2, 3, 4, 5), starts = "distant", mccores = mc,
                   replicates = 10, maxdf = 10, conv = c(7, 1), verbose = TRUE)
rand_plot(ran)

In this case, k = 4 comes with complete agreement between the 10 starts.

Another possible evaluation of the number of clusters is to first ask clustra for a large number of clusters, evaluate the cluster centers on a common set of time points, and feed the resulting matrix to a hierarchical clustering function. Below, we ask for 40 clusters on the data2 data set but actually get back only 17 because several become empty or too small for maxdf. The hclust() function then clusters the 17 resulting cluster means, each evaluated on 100 time points.

set.seed(12345)
cl40 = clustra(data2, k = 40, maxdf = 10, conv = c(10, 0), mccores = mc,
               verbose = TRUE)
# Evaluate a fitted TPS at new time points on the response scale
gpred = function(tps, newdata)
  as.numeric(mgcv::predict.bam(tps, newdata, type = "response",
                               newdata.guaranteed = TRUE))
# One row per surviving cluster mean, evaluated at 100 common time points
resp = do.call(rbind, lapply(cl40$tps, gpred, newdata = data.frame(
  time = seq(min(data2$time), max(data2$time), length.out = 100))))
plot(hclust(dist(resp)))

The dendrogram clearly indicates 4 clusters.
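If an explicit grouping of the cluster means is wanted rather than a visual read of the dendrogram, stats::cutree() can cut the tree at four groups (a usage sketch):

hc = hclust(dist(resp))
cutree(hc, k = 4) # merged-group label for each remaining cluster mean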

When we use starts = "distant", the selected distant starts are more likely to persist into a nearby local minimum, retaining the full 40 specified clusters.

set.seed(12345)
cl40 = clustra(data2, k = 40, starts = "distant", maxdf = 10, conv = c(10, 0),
               mccores = mc, verbose = TRUE)
resp = do.call(rbind, lapply(cl40$tps, gpred, newdata = data.frame(
  time = seq(min(data2$time), max(data2$time), length.out = 100))))
plot(hclust(dist(resp)))

Here again, if we treat cluster 24 as an outlier, we get 4 clusters.

cat("clustra vignette run time:\n")
print(proc.time() - start_knit)

