clustra: clustering trajectories

knitr::opts_knit$set(
  collapse = TRUE,
  comment = "#>"
)
start_knit = proc.time()

The clustra package was built to cluster longitudinal trajectories (time series) on a common time axis. For example, a number of individuals are started on a specific drug regimen, and their blood pressure data is collected for a varying amount of time before and after the start of the medication. Trajectories can be unequally spaced, of unequal length, and only partially overlapping.

Clustering proceeds by an EM algorithm that alternates between fitting a B-spline to the pooled responses within each cluster (M-step) and reassigning each trajectory to the cluster with the nearest fitted B-spline (E-step). Initial cluster assignments are random. The fitting is done with the bam() function from the mgcv package, which scales well to very large data sets.
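The alternating scheme can be sketched in a few lines of base R. This is a simplified illustration under stated assumptions, not clustra's implementation: lm() with a splines::bs() basis stands in for mgcv::bam(), the E-step loss is assumed to be the mean squared residual of each id against each cluster's spline, and em_traj_sketch(), make_id(), and the toy data are hypothetical.

```r
library(splines)

# Sketch of the alternating M-step / E-step loop; NOT clustra's code.
# Assumption: lm() with a bs() basis stands in for mgcv::bam().
# `dat` is a data.frame with columns id, time, response.
em_traj_sketch = function(dat, k, df = 6, max_iter = 10) {
  ids = sort(unique(dat$id))
  row_id = match(dat$id, ids)                  # data row -> index of its id
  rng = range(dat$time)                        # shared boundary knots for all fits
  grp = sample(k, length(ids), replace = TRUE) # random initial cluster per id
  for (iter in seq_len(max_iter)) {
    # M-step: one B-spline regression on the pooled responses of each cluster
    fits = lapply(seq_len(k), function(j) {
      sub = dat[grp[row_id] == j, ]
      if (nrow(sub) <= df + 1) return(NULL)    # too few points: cluster dropped
      lm(response ~ bs(time, df = df, Boundary.knots = rng), data = sub)
    })
    # E-step: mean squared residual of every id against every cluster's spline
    loss = sapply(seq_len(k), function(j) {
      if (is.null(fits[[j]])) return(rep(Inf, length(ids)))
      r2 = (dat$response - predict(fits[[j]], newdata = dat))^2
      as.numeric(tapply(r2, row_id, mean))
    })
    new_grp = apply(loss, 1, which.min)        # move each id to its nearest spline
    changes = sum(new_grp != grp)
    grp = new_grp
    if (changes == 0) break                    # converged: no id switched cluster
  }
  list(group = grp, ids = ids, iterations = iter)
}

# Hypothetical toy data: sine, sigmoid, and constant groups, 30 ids each
set.seed(1)
make_id = function(id, shape) {
  tm = sort(runif(20, -100, 300))
  mu = switch(shape, sin = 30 * sin(tm / 50), sig = 60 / (1 + exp(-tm / 30)), 0 * tm)
  data.frame(id = id, time = tm, response = mu + rnorm(20, sd = 3))
}
truth = rep(c("sin", "sig", "const"), each = 30)
dat = do.call(rbind, Map(make_id, 1:90, truth))
res = em_traj_sketch(dat, k = 3)
table(res$group, truth)
```

Because assignments are hard (each id belongs to exactly one cluster), this is the "classification" flavor of EM; a random start can still merge two groups, which is one reason the vignette later compares several random starts.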

For this vignette, we begin by generating a data set with the gen_traj_data() function. Given its parameters, the function generates groups of ids (their sizes given by the vector n_id) and, for each id, a random number of observations drawn from a Poisson($\lambda =$ m_obs) distribution plus 3. The 3 additional observations guarantee one observation before the intervention at time start, one at the intervention time 0, and one after the intervention at time end. The start time is Uniform(s_range), the end time is Uniform(e_range), and the remaining observation times are Uniform(start, end). The time units are arbitrary and depend on your application. Up to 3 groups are implemented so far, with sine, sigmoid, and constant forms.
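The observation-time scheme for a single id can be illustrated as follows. This is a sketch of the description above, not the source of gen_traj_data(); sample_times() is a hypothetical helper.

```r
# Sketch of the per-id observation times described above (not gen_traj_data()'s code)
sample_times = function(m_obs, s_range, e_range) {
  n = rpois(1, m_obs) + 3                    # Poisson count plus 3 guaranteed points
  start = runif(1, s_range[1], s_range[2])   # before the intervention
  end = runif(1, e_range[1], e_range[2])     # after the intervention
  # start, intervention time 0, end, and the remaining times in between
  sort(c(start, 0, end, runif(n - 3, start, end)))
}
set.seed(1)
tt = sample_times(m_obs = 25, s_range = c(-365, -14), e_range = c(0.5 * 365, 2 * 365))
```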

We also set the random seed for reproducibility. The code below generates the data and looks at a few of its observations. The mc variable sets the number of cores and is passed to the mccores parameter throughout the rest of the vignette. By default, 1 core is used. Parallel sections are implemented with parallel::mclapply(), so on Unix and Mac platforms it is recommended to use the full number of available cores for faster performance.

library(clustra)
mc = 1 # On Unix or Mac platforms, increase (up to 2x the number of cores)
set.seed(12345)
data = gen_traj_data(n_id = c(400, 800, 1600), m_obs = 25, 
                     s_range = c(-365, -14), e_range = c(0.5*365, 2*365),
                     noise = c(0, 5))
head(data)
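One way to pick mc automatically is sketched below. The choice of all detected cores on Unix-alikes (including macOS) and 1 elsewhere is an assumption that follows the advice above; pick_cores() is a hypothetical helper, while parallel::detectCores() and parallel::mclapply() are standard R functions.

```r
library(parallel)

# Hypothetical helper: all detected cores where mclapply() can fork, else 1
pick_cores = function() {
  if (.Platform$OS.type != "unix") return(1L)   # Windows: mclapply cannot fork
  max(1L, detectCores(), na.rm = TRUE)          # detectCores() may return NA
}
mc = pick_cores()
res = mclapply(1:4, sqrt, mc.cores = mc)        # same results as lapply(1:4, sqrt)
```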

Select 9 random ids and plot their observations.

library(ggplot2)
ggplot(data[id %in% sample(unique(data[, id]), 9)],
       aes(x = time, y = response)) + facet_wrap(~ id) + geom_point()

Next, cluster the trajectories. We set k = 3 clusters, the maximum spline degrees of freedom (maxdf) to 30, and conv to a maximum of 10 iterations, with convergence declared when 0 changes in cluster membership occur. mccores sets the number of cores used in various components of the code; it does not work on Windows, where it should be left at 1. In the code that follows, we use verbose output to get information from each iteration.

set.seed(1234737)
cl = clustra(data, k = 3, maxdf = 30, conv = c(10, 0), mccores = mc,
             verbose = TRUE)
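The conv argument combines an iteration cap with a change threshold. The stopping rule below is an assumption based on the descriptions in this vignette (conv[1] = maximum iterations, conv[2] = percentage of ids that changed cluster), not clustra's source; should_stop() is hypothetical.

```r
# Hypothetical stopping rule for conv = c(max_iter, pct_changes)
should_stop = function(iter, changes, n_ids, conv) {
  iter >= conv[1] || 100 * changes / n_ids <= conv[2]
}
a = should_stop(iter = 4, changes = 0, n_ids = 2800, conv = c(10, 0))   # no changes left
b = should_stop(iter = 4, changes = 20, n_ids = 2800, conv = c(10, 0))  # 20 ids still moving
d = should_stop(iter = 4, changes = 20, n_ids = 2800, conv = c(7, 1))   # 0.71% <= 1%
```

With conv = c(7, 1), as used later for clustra_sil() and clustra_rand(), a run stops once at most 1% of ids switch clusters, which saves iterations when only a few boundary ids keep moving.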

Next, plot the raw data (a sample of 10,000 points if there are more). Then repeat the plot with the resulting spline fits overlaid, colored by cluster.

sdt = data[, group := factor(cl$data_group)]  # cluster of each observation
if(nrow(sdt) > 10000)
  sdt = sdt[sample(nrow(sdt), 10000)]
ggplot(sdt, aes(x = time, y = response)) + geom_point(pch = ".")

np = 100
k = length(cl$tps)
ntime = seq(data[, min(time)], data[, max(time)], length.out = np)
pdata = expand.grid(time = ntime, group = factor(1:k))
pdata = subset(pdata, group %in% which(lengths(cl$tps) > 0))
pred = vector("list", k)
for(i in 1:k)  # empty clusters keep a NULL entry, which c() drops below
  if(!is.null(cl$tps[[i]]))
    pred[[i]] = mgcv::predict.bam(cl$tps[[i]], newdata = list(time = ntime),
                                  type = "response")
pdata$pred = do.call(c, pred)
ggplot(pdata, aes(x = time, y = pred, color = group)) + 
  geom_point(data = sdt, aes(y = response), pch = ".") + geom_line()

The Rand index comparing the clustering with the true groups (true_group) is

MixSim::RandIndex(cl$data_group, data[, true_group])

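A perfect score! Let's make the problem harder by quadrupling the error standard deviation in data generation (noise = c(0, 20) instead of c(0, 5), a 16-fold increase in variance) ...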

set.seed(1234567)
data2 = gen_traj_data(n_id = c(500, 1000, 2000), m_obs = 25, s_range = c(-365, -14),
                     e_range = c(60, 2*365), noise = c(0, 20))
iplot = sample(unique(data2$id), 9)
sampobs = match(data2$id, iplot, nomatch = 0) > 0
ggplot(data2[sampobs], aes(x = time, y = response)) + 
  facet_wrap(~ id) + geom_point()
cl = clustra(data2, k = 3, maxdf = 30, conv = c(10, 0), mccores = mc, verbose = TRUE)
MixSim::RandIndex(cl$data_group, data2[, true_group])

The result is less than perfect but still a good score. Now the plots:

sdt = data2[, group := factor(cl$data_group)]  # cluster of each observation
if(nrow(sdt) > 10000)
  sdt = sdt[sample(nrow(sdt), 10000)]
ggplot(sdt, aes(x = time, y = response)) + geom_point(pch = ".")

np = 100
k = length(cl$tps)
ntime = seq(data2[, min(time)], data2[, max(time)], length.out = np)
pdata = expand.grid(time = ntime, group = factor(1:k))
pdata = subset(pdata, group %in% which(lengths(cl$tps) > 0))
pred = vector("list", k)
for(i in 1:k)  # empty clusters keep a NULL entry, which c() drops below
  if(!is.null(cl$tps[[i]]))
    pred[[i]] = mgcv::predict.bam(cl$tps[[i]], newdata = list(time = ntime),
                                  type = "response")
pdata$pred = do.call(c, pred)
ggplot(pdata, aes(x = time, y = pred, color = group)) + 
  geom_point(data = sdt, aes(y = response), pch = ".") + geom_line()
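Both comparisons above call MixSim::RandIndex(). For intuition, the unadjusted Rand index is the fraction of pairs of items (here observation rows, since cl$data_group is compared row by row with true_group) that two labelings treat the same way: together in both or apart in both. The minimal base-R sketch below computes it from the contingency table; rand_index() is a hypothetical name, not MixSim's implementation.

```r
# Unadjusted Rand index from the contingency table of two labelings
rand_index = function(x, y) {
  tab = table(x, y)
  pairs = function(v) sum(choose(v, 2))   # number of within-group pairs
  n2 = choose(length(x), 2)               # total number of pairs
  # together in both + apart in both, divided by all pairs
  (n2 + 2 * pairs(tab) - pairs(rowSums(tab)) - pairs(colSums(tab))) / n2
}
r1 = rand_index(c(1, 1, 2, 2), c(2, 2, 1, 1))  # same partition, labels swapped
r2 = rand_index(c(1, 1, 2, 2), c(1, 2, 1, 2))  # agrees on 2 of 6 pairs
```

Cluster labels are arbitrary, which is why r1 equals 1 even though no label matches.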

The average silhouette width is one way to select the number of clusters, and a silhouette plot supports a deeper evaluation (Rousseeuw 1987). The silhouette requires distances between individual trajectories, which cannot be computed directly from unequally sampled trajectories without fitting a separate model for each id. As a proxy, the clustra_sil() function uses the distance from each trajectory to each cluster's mean spline trajectory. The structure returned by clustra() contains the matrix loss, which holds all the information needed to construct these proxy silhouette plots. clustra_sil() performs clustering for a number of k values and returns the information for the silhouette plots displayed next. We relax the convergence criterion in conv to 1% of changes (instead of the 0 used earlier) with at most 7 iterations. We use the first data set, generated with noise = c(0, 5).

set.seed(1234737)
sil = clustra_sil(data, k = c(2, 3, 4), mccores = mc, conv = c(7, 1),
                  verbose = TRUE)
plot_sil = function(x) {
  msil = round(mean(x$silhouette), 2)
  ggplot(x, aes(id, silhouette, color = cluster, fill = cluster)) + geom_col() +
    ggtitle(paste("Average Width:", msil)) +
    scale_x_discrete(breaks = NULL) + scale_y_continuous("Silhouette Width") +
    geom_hline(yintercept = msil, linetype = "dashed", color = "red")
}
lapply(sil, plot_sil)

The plots show that 3 clusters give the highest average silhouette width.

To avoid reclustering, we can pass the result of a previous clustra() run directly to clustra_sil() and produce its silhouette plot. We do this now for cl, the clustering of the noisier data2 above.

sil = clustra_sil(cl)
lapply(sil, plot_sil)
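The proxy silhouette can be sketched from a loss matrix alone. This assumes loss[i, j] is the loss of trajectory i against the mean spline of cluster j and that the standard Rousseeuw formula s = (b - a) / max(a, b) is applied, with a the loss to the own cluster and b the smallest loss to any other cluster; clustra_sil()'s exact computation may differ, and proxy_silhouette() and the toy matrix are hypothetical.

```r
# Hypothetical proxy silhouette from an id-by-cluster loss matrix
proxy_silhouette = function(loss, cluster) {
  idx = cbind(seq_len(nrow(loss)), cluster)
  a = loss[idx]                 # loss to the assigned cluster's mean spline
  other = loss
  other[idx] = Inf              # exclude the assigned cluster
  b = apply(other, 1, min)      # loss to the nearest other cluster
  (b - a) / pmax(a, b)
}
loss = matrix(c(1, 4, 9,
                2, 1, 5,
                6, 3, 3), nrow = 3, byrow = TRUE)
s = proxy_silhouette(loss, cluster = c(1, 2, 2))
mean(s)                         # average width, as in the plot titles
```

Values near 1 mark trajectories much closer to their own cluster's spline than to any other; a value of 0 (the third row) marks a trajectory equally close to two splines.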

Another way to select the number of clusters uses the Rand index to compare clusterings from different random starts and different numbers of clusters. When we replicate clustering with different random seeds, the agreement between replicates (their "replicability") indicates how stable the results are for a given number of clusters k. For this demonstration, we look at k = c(2, 3, 4) with 10 replicates for each k.

set.seed(1234737)
ran = clustra_rand(data, k = c(2, 3, 4), mccores = mc, replicates = 10,
                   conv = c(7, 1), verbose = TRUE)
rand_plot(ran)

The plot shows the Adjusted Rand Index similarity between all pairs of the 30 clusterings (10 random starts for each of 2, 3, and 4 clusters). The ten random starts agree the most for k = 3. The deviance values shown during iterations also indicate that all of the k = 3 clusterings come close to the best deviance attainable even with k = 4. Several of the k = 4 runs converged to only three clusters, which agree with the k = 3 results.
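The pairwise comparison behind such a plot can be sketched as a symmetric matrix of Adjusted Rand Index values (Hubert and Arabie's correction of the Rand index for chance agreement). This is an illustration with hypothetical labelings and hypothetical function names, not the code of clustra_rand() or rand_plot().

```r
# Adjusted Rand Index from the contingency table of two labelings
ari = function(x, y) {
  tab = table(x, y)
  pairs = function(v) sum(choose(v, 2))
  index = pairs(tab)
  sx = pairs(rowSums(tab))
  sy = pairs(colSums(tab))
  expected = sx * sy / choose(length(x), 2)     # agreement expected by chance
  (index - expected) / ((sx + sy) / 2 - expected)
}
# Symmetric matrix of ARI values between all pairs of replicate labelings
ari_matrix = function(labelings) {
  m = length(labelings)
  out = diag(1, m)
  for (i in seq_len(m - 1)) for (j in (i + 1):m)
    out[i, j] = out[j, i] = ari(labelings[[i]], labelings[[j]])
  out
}
# Hypothetical replicates: the first two agree up to relabeling
reps = list(c(1, 1, 2, 2, 3, 3), c(2, 2, 1, 1, 3, 3), c(1, 1, 1, 2, 2, 2))
am = ari_matrix(reps)
```

A block of values near 1 among the replicates for one k is the stability signal the vignette reads off the plot for k = 3.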

Another way to evaluate the number of clusters is to first ask clustra for a large number of clusters, evaluate the cluster centers on a common set of time points, and feed the resulting matrix to a hierarchical clustering function. Below, we ask for 40 clusters on the data2 data set but get back only 26, because several clusters become empty or too small to support maxdf. The hclust() function then clusters the 26 resulting cluster means, each evaluated on 100 time points.

set.seed(12347)
cl = clustra(data2, k = 40, maxdf = 30, conv = c(10, 0), mccores = mc,
             verbose = TRUE)
gpred = function(tps, newdata) 
  as.numeric(mgcv::predict.bam(tps, newdata, type = "response",
                               newdata.guaranteed = TRUE))
resp = do.call(rbind, lapply(Filter(Negate(is.null), cl$tps), gpred, # skip empty clusters
  newdata = data.frame(time = seq(min(data2$time), max(data2$time), length.out = 100))))
plot(hclust(dist(resp)))

The cluster dendrogram clearly indicates there are only three clusters. Making the cut at a height of roughly 300 groups the 26 clusters into only three.
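The cut itself can be made with stats::cutree(), either by height (h) or by number of groups (k). The self-contained sketch below uses hypothetical synthetic curves (26 noisy copies of three shapes on 100 time points) as a stand-in for the fitted cluster means, so it runs without the vignette's data.

```r
set.seed(42)
tgrid = seq(-365, 730, length.out = 100)
# Hypothetical stand-ins for cluster means: noisy copies of three shapes
shapes = rbind(sin = 40 * sin(tgrid / 120),
               sig = 80 / (1 + exp(-tgrid / 60)),
               const = rep(20, length(tgrid)))
truth = rep(1:3, times = c(8, 9, 9))
resp = shapes[truth, ] + matrix(rnorm(26 * 100, sd = 2), nrow = 26)
hc = hclust(dist(resp))      # complete linkage on Euclidean curve distances
grp = cutree(hc, k = 3)      # or cutree(hc, h = ...) to cut at a chosen height
table(grp, truth)
```

Each row of the resulting table should contain a single nonzero count, meaning the three merged groups coincide with the three underlying shapes.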

cat("clustra vignette run time:\n")
print(proc.time() - start_knit)



clustra documentation built on Jan. 16, 2022, 9:06 a.m.