Vignette 1 - Clustering via the Mean Shift Algorithm

The mean shift algorithm and modal clustering

The mean shift algorithm [@fukunagahostetlermeanshift; @chengmeanshift] is a recursive algorithm that allows us to perform nonparametric mode-based clustering, i.e. clustering data on the basis of a kernel density estimate of the probability density function associated with the data-generating process. A great introduction to the mean shift algorithm is available at https://normaldeviate.wordpress.com/2012/07/20/the-amazing-mean-shift-algorithm/.

In its standard form, the mean shift algorithm works as follows. We observe $X_1, \dots, X_n$, a sample of i.i.d. random variables valued in $\mathbb{R}^d$ generated from an unknown probability density $p$. We fix a kernel function $K$ and a bandwidth parameter $h$ and we apply the update rule $$ x \leftarrow \frac{\sum_{i=1}^n K \left( \frac{\|X_i-x\|}{h} \right)X_i}{\sum_{i=1}^n K \left( \frac{\|X_i-x\|}{h} \right)} $$ to an arbitrary initial point $x=x_0 \in \mathbb{R}^d$ until convergence. The discrete sequence of points $\{x_0, x_1, \dots, x_k, \dots\}$ generated by the application of the above rule approximates the continuous gradient flow trajectory $\pi_x$ satisfying $$ \begin{cases} \pi_x(0)=x_0 \\ \pi_x'(t)=\nabla \hat p(\pi_x(t)) \end{cases} $$ where $\hat p$ is a kernel density estimator of $p$ based on another kernel function (the "shadow" kernel of $K$). In turn, $\pi_x$ is an estimate of the population gradient flow line $\tau_x$ satisfying $$ \begin{cases} \tau_x(0)=x_0 \\ \tau_x'(t)=\nabla p(\tau_x(t)) \end{cases} $$ associated to the population gradient flow based on $p$.

Under some assumptions on $p$ and $K$, it can be shown that $\pi_x(t)$ and $\tau_x(t)$ converge respectively to a mode (a local maximum) of $\hat p$ and $p$ as $t \to \infty$. Furthermore, for any initial point $x \in \mathbb{R}^d$, there is a unique $\pi_x$ and a unique $\tau_x$, and the collections $\{\tau_x\}_{x \in \mathbb{R}^d}$ and $\{\pi_x\}_{x \in \mathbb{R}^d}$ both partition $\mathbb{R}^d$, thus inducing respectively a population and an empirical clustering. More specifically, a set $M$ in the population partition (or "population clustering") induced by $\{\tau_x\}_{x \in \mathbb{R}^d}$ can be described as the subset of points in $\mathbb{R}^d$ such that $\tau_x(t) \to m$ as $t \to \infty$, where $m$ is a mode of $p$, i.e. $M=\{x \in \mathbb{R}^d: \lim_{t \to \infty} \tau_x(t) = m \}$. In a similar way, $\hat M=\{x \in \mathbb{R}^d: \lim_{t \to \infty} \pi_x(t) = \hat m \}$ defines an "empirical cluster". For more details, see for instance @ariascastromeanshiftgradient and @chacon2015population.
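The update rule above can be sketched in a few lines of base R. This is only an illustration under stated assumptions, not the code used inside the MeanShift package: the Gaussian kernel is an arbitrary choice, and the helper names `gaussianKernel`, `msUpdate` and `msIterate`, as well as the arguments `tol` and `max.iter`, are hypothetical. The data are stored with observations by columns.

```r
## Gaussian kernel evaluated at a scaled distance u = ||X_i - x|| / h
## (illustrative choice; any valid kernel K could be substituted)
gaussianKernel <- function( u ){ exp( -u^2 / 2 ) }

## one application of the update rule: kernel-weighted mean of the sample
## X is a d x n matrix (observations by columns), x a vector of length d
msUpdate <- function( x, X, h ){
  dists <- sqrt( colSums( ( X - x )^2 ) )   # ||X_i - x|| for every i
  w <- gaussianKernel( dists / h )
  as.vector( X %*% w ) / sum( w )
}

## iterate from x0 until two successive iterates are closer than tol
## (tol and max.iter are hypothetical arguments of this sketch)
msIterate <- function( x0, X, h, tol=1e-6, max.iter=500 ){
  x <- x0
  for( k in seq_len( max.iter ) ){
    x.new <- msUpdate( x, X, h )
    converged <- sqrt( sum( ( x.new - x )^2 ) ) < tol
    x <- x.new
    if( converged ) break
  }
  x
}

## toy sample: two well-separated groups in R^2
set.seed( 1 )
X <- cbind( matrix( rnorm( 100, mean=0, sd=0.3 ), nrow=2 ),
            matrix( rnorm( 100, mean=5, sd=0.3 ), nrow=2 ) )

## each starting point should climb towards the mode of its own group
mode.a <- msIterate( c( 0.5, 0.5 ), X, h=0.5 )
mode.b <- msIterate( c( 4.5, 4.5 ), X, h=0.5 )
print( rbind( mode.a, mode.b ) )
```

Starting the iteration from each sample point in turn and grouping the points by the mode they reach would give the "sample clusters" discussed in this vignette.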

From a practical point of view, it is clear that one is particularly interested in the case $x=x_0 \in \{X_1,\dots,X_n\}$ as we want to group the sample data into "sample clusters". The MeanShift package is designed to accomplish this goal.

The "MeanShift" package

The MeanShift package contains two implementations of the mean shift algorithm: the standard mean shift algorithm and its "blurring" version, which is an approximation to the standard algorithm that is often substantially faster.

The standard implementation of the mean shift algorithm comes with the function msClustering. The user needs to input

In our implementation, convergence is achieved at iteration $k$ if $\|x_k - x_{k-1}\|<$tol.stop.

The blurring mean shift algorithm is a variant of the mean shift algorithm in which the sample $\{X_1,\dots,X_n\}$ is updated at each mean shift iteration. In particular, $\forall i \in \{1,\dots,n\}$, the update $$ X_i \leftarrow \frac{\sum_{j=1}^n K \left( \frac{\|X_j-X_i\|}{h} \right)X_j}{\sum_{j=1}^n K \left( \frac{\|X_j-X_i\|}{h} \right)} $$ is recursively applied until convergence. In the MeanShift package, the blurring mean shift algorithm is available with the function bmsClustering which takes the following input arguments:

In the context of bmsClustering, convergence occurs at the $k$-th iteration if $\max_i \|X_{i,k}-X_{i,k-1}\|<$ tol.stop.
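The blurring update and its stopping rule can be sketched in base R as follows. This is a minimal sketch under stated assumptions, not the implementation behind bmsClustering: the Gaussian kernel is an arbitrary choice, the helper names `bmsStep` and `bmsIterate` and the argument `max.iter` are hypothetical, and the final grouping of collapsed points via single-linkage `hclust` is just one convenient way to read off labels.

```r
## Gaussian kernel on a scaled distance (illustrative choice)
gaussianKernel <- function( u ){ exp( -u^2 / 2 ) }

## one blurring step: every X_i is replaced by the kernel-weighted mean
## of the current sample; X is a d x n matrix (observations by columns)
bmsStep <- function( X, h ){
  D <- as.matrix( dist( t( X ) ) )         # n x n pairwise distances
  W <- gaussianKernel( D / h )             # W[j,i] = K( ||X_j - X_i|| / h )
  sweep( X %*% W, 2, colSums( W ), "/" )   # column i = updated X_i
}

## repeat until max_i ||X_{i,k} - X_{i,k-1}|| < tol.stop
bmsIterate <- function( X, h, tol.stop=1e-6, max.iter=100 ){
  for( k in seq_len( max.iter ) ){
    X.new <- bmsStep( X, h )
    max.move <- max( sqrt( colSums( ( X.new - X )^2 ) ) )
    X <- X.new
    if( max.move < tol.stop ) break
  }
  X
}

## toy sample: two well-separated groups of 20 points in R^2
set.seed( 1 )
X <- cbind( matrix( rnorm( 40, mean=0, sd=0.3 ), nrow=2 ),
            matrix( rnorm( 40, mean=5, sd=0.3 ), nrow=2 ) )
X.blurred <- bmsIterate( X, h=0.5 )

## points that collapsed onto (numerically) the same location share a label
bms.labels <- cutree( hclust( dist( t( X.blurred ) ), method="single" ),
                      h=1e-3 )
print( table( bms.labels ) )
```

Because every point moves at every iteration, the whole sample contracts towards a small number of locations, which is why the blurring version tends to need fewer iterations than the standard algorithm.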

Example: clustering wheat grain varieties

We illustrate the use of the MeanShift package by applying it to the seeds dataset at https://archive.ics.uci.edu/ml/datasets/seeds. The seeds dataset gives measurements of geometrical properties of wheat grains belonging to 3 different varieties.

Our goal is to demonstrate the use of the msClustering and bmsClustering functions by clustering the wheat varieties on the basis of the 7 quantitative variables contained in the dataset.

## load "MeanShift" package
library( MeanShift )

## load `seeds` dataset
load( "seeds.RData" )
## wheat variety labels
seeds.labels <- seeds[,"variety"]

## organize data by columns (one column per observation, one row per variable)
seeds.data <- t( seeds[,c( "area", "perimeter", "compactness", 
                      "length", "width", "asymmetry", 
                      "groove.length" )] )

print( dim( seeds.data ) )

## standardize the variables (scale each row, i.e. each variable, to unit standard deviation)
seeds.data <- seeds.data / apply( seeds.data, 1, sd )

## form a set of candidate bandwidths
h.cand <- quantile( dist( t( seeds.data ) ), seq( 0.05, 0.40, by=0.05 ) )
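## perform mean shift clustering with the blurring version of the algorithm
## (one clustering per candidate bandwidth)
system.time( bms.clustering <- lapply( h.cand,
function( h ){ bmsClustering( seeds.data, h=h ) } ) )

## swap cluster labels 3 and 4 in the third clustering
## ("pink" is only a temporary placeholder during the swap)
tmp.labels3 <- bms.clustering[[3]]$labels
tmp.labels3[tmp.labels3==3] <- "pink"
tmp.labels3[tmp.labels3==4] <- 3
tmp.labels3[tmp.labels3=="pink"] <- 4
bms.clustering[[3]]$labels <- as.integer( tmp.labels3 )

## reorder the modes to match, so that label k still refers to column k,
## then put the column names back in their original order
bms.clustering[[3]]$components <- bms.clustering[[3]]$components[,c( 1, 2, 4, 3, 5 )]
colnames( bms.clustering[[3]]$components ) <- colnames( bms.clustering[[3]]$components )[c( 1, 2, 4, 3, 5 )]
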
## the resulting object is a list with names "components" and "labels"
class( bms.clustering[[1]] )
names( bms.clustering[[1]] )

## extract the cluster labels
ms.labels1 <- bms.clustering[[1]]$labels
print( ms.labels1 )

## extract the cluster modes/representatives
ms.modes1 <- bms.clustering[[1]]$components
print( ms.modes1 )

## plot
par( mfrow=c( 1, 2 ) )
plot( seeds.data[5,], seeds.data[6,], col=bms.clustering[[1]]$labels,
xlab=names( seeds )[5], ylab=names( seeds )[6], main="Mean shift labels",
cex=0.65, pch=16 )
plot( seeds.data[5,], seeds.data[6,], col=seeds.labels,
xlab=names( seeds )[5], ylab=names( seeds )[6], main="True labels",
cex=0.65, pch=16 )

## bandwidth h is too small -> "overclustering"

## extract the cluster labels
ms.labels6 <- bms.clustering[[6]]$labels
print( ms.labels6 )

## extract the cluster modes/representatives
ms.modes6 <- bms.clustering[[6]]$components
print( ms.modes6 )

## plot
par( mfrow=c( 1, 2 ) )
plot( seeds.data[5,], seeds.data[6,], col=bms.clustering[[6]]$labels,
xlab=names( seeds )[5], ylab=names( seeds )[6], main="Mean shift labels",
cex=0.65, pch=16 )
plot( seeds.data[5,], seeds.data[6,], col=seeds.labels,
xlab=names( seeds )[5], ylab=names( seeds )[6], main="True labels",
cex=0.65, pch=16 )

## bandwidth h is too large -> "underclustering"

## extract the cluster labels
ms.labels3 <- bms.clustering[[3]]$labels
print( ms.labels3 )

## extract the cluster modes/representatives
ms.modes3 <- bms.clustering[[3]]$components
print( ms.modes3 )

## plot
par( mfrow=c( 1, 2 ) )
plot( seeds.data[5,], seeds.data[6,], col=bms.clustering[[3]]$labels,
xlab=names( seeds )[5], ylab=names( seeds )[6], main="Mean shift labels",
cex=0.65, pch=16 )
## add estimated cluster modes to the plot
points( ms.modes3[5,], ms.modes3[6,], col=1:ncol( ms.modes3 ),
pch="+", cex=3 )
plot( seeds.data[5,], seeds.data[6,], col=seeds.labels,
xlab=names( seeds )[5], ylab=names( seeds )[6], main="True labels",
cex=0.65, pch=16 )

## just right!
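The visual comparison of mean shift labels and true labels can be complemented by a cross-tabulation. The sketch below uses small made-up label vectors (they are not results from the seeds data) so that it runs on its own; the names `est.labels`, `true.labels`, `confusion` and `purity` are hypothetical. With the objects from the example, one would presumably tabulate `ms.labels3` against `seeds.labels` in the same way.

```r
## toy label vectors, made up for illustration (NOT results from the seeds data)
est.labels  <- c( 1, 1, 1, 2, 2, 2, 3, 3, 1 )
true.labels <- c( "A", "A", "A", "B", "B", "B", "C", "C", "C" )

## rows: estimated clusters, columns: true classes
confusion <- table( estimated=est.labels, truth=true.labels )
print( confusion )

## purity: share of points that fall in the majority class of their cluster
purity <- sum( apply( confusion, 1, max ) ) / sum( confusion )
print( purity )
```

A table with one dominant entry per row indicates that each estimated cluster corresponds mostly to a single variety; the cluster numbering itself is arbitrary, so only the pattern of the table matters.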



MeanShift documentation built on May 29, 2017, 2:27 p.m.