The mean shift algorithm [@fukunagahostetlermeanshift; @chengmeanshift] is a recursive algorithm that allows us to perform nonparametric mode-based clustering, i.e. clustering data on the basis of a kernel density estimate of the probability density function associated with the data-generating process. A good introduction to the mean shift algorithm is available at https://normaldeviate.wordpress.com/2012/07/20/the-amazing-mean-shift-algorithm/.
In its standard form, the mean shift algorithm works as follows. We observe $X_1, \dots, X_n$, a sample of i.i.d. random variables valued in $\mathbb{R}^d$ generated from an unknown probability density $p$. We fix a kernel function $K$ and a bandwidth parameter $h$ and we apply the update rule
$$
x \leftarrow \frac{\sum_{i=1}^n K \left( \frac{\|X_i-x\|}{h} \right)X_i}{\sum_{i=1}^n K \left( \frac{\|X_i-x\|}{h} \right)}
$$
to an arbitrary initial point $x=x_0 \in \mathbb{R}^d$ until convergence. The discrete sequence of points $\{x_0, x_1, \dots, x_k, \dots\}$ generated by the application of the above rule approximates the continuous gradient flow trajectory $\pi_x$ satisfying
$$
\begin{cases}
\pi_x(0)=x_0 \\
\pi_x'(t)=\nabla \hat p(\pi_x(t))
\end{cases}
$$
where $\hat p$ is a kernel density estimator of $p$ based on another kernel function (the "shadow" kernel of $K$). In turn, $\pi_x$ is an estimate of the population gradient flow line $\tau_x$ satisfying
$$
\begin{cases}
\tau_x(0)=x_0 \\
\tau_x'(t)=\nabla p(\tau_x(t))
\end{cases}
$$
associated to the population gradient flow based on $p$.

Under some assumptions on $p$ and $K$, it can be shown that $\pi_x(t)$ and $\tau_x(t)$ converge respectively to a mode (a local maximum) of $\hat p$ and $p$ as $t \to \infty$. Furthermore, for any initial point $x \in \mathbb{R}^d$, there is a unique $\pi_x$ and a unique $\tau_x$, and the collections $\{\tau_x\}_{x \in \mathbb{R}^d}$ and $\{\pi_x\}_{x \in \mathbb{R}^d}$ both partition $\mathbb{R}^d$, thus inducing respectively a population and an empirical clustering. More specifically, a set $M$ in the population partition (or "population clustering") induced by $\{\tau_x\}_{x \in \mathbb{R}^d}$ can be described as the subset of points in $\mathbb{R}^d$ such that $\tau_x(t) \to m$ as $t \to \infty$, where $m$ is a mode of $p$, i.e. $M=\{x \in \mathbb{R}^d: \lim_{t \to \infty} \tau_x(t) = m \}$. In a similar way, $\hat M=\{x \in \mathbb{R}^d: \lim_{t \to \infty} \pi_x(t) = \hat m \}$ defines an "empirical cluster". For more details, see for instance @ariascastromeanshiftgradient and @chacon2015population.
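To make the update rule concrete, the following base R sketch iterates it from a single starting point. It is only an illustration under stated assumptions: the Gaussian kernel, the function names `gaussian_kernel` and `mean_shift_point`, the arguments `tol` and `max_steps`, and the simulated sample are all our own choices here and are not taken from the `MeanShift` package.

```r
## Gaussian kernel evaluated at a scaled distance u >= 0
## (an assumed choice for this sketch)
gaussian_kernel <- function(u) exp(-u^2 / 2)

## Iterate the mean shift update from x0; X stores the sample by column.
mean_shift_point <- function(x0, X, h, tol = 1e-8, max_steps = 500) {
  x <- x0
  for (k in seq_len(max_steps)) {
    ## distances ||X_i - x|| for every column X_i
    d <- sqrt(colSums((X - x)^2))
    w <- gaussian_kernel(d / h)
    ## kernel-weighted average of the sample points
    x_new <- as.vector(X %*% w) / sum(w)
    if (sqrt(sum((x_new - x)^2)) < tol) {
      return(x_new)
    }
    x <- x_new
  }
  x
}

## simulated toy sample: two well-separated groups in R^2
set.seed(1)
X <- cbind(matrix(rnorm(100, mean = 0, sd = 0.3), nrow = 2),
           matrix(rnorm(100, mean = 3, sd = 0.3), nrow = 2))

## trajectories started near each group end close to that group's mode
end_a <- mean_shift_point(c(0.5, 0.5), X, h = 0.5)
end_b <- mean_shift_point(c(2.5, 2.5), X, h = 0.5)
print(rbind(end_a, end_b))
```

Starting points that lie in the same basin of attraction end at (numerically) the same point, which is what induces the empirical clustering described above.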
From a practical point of view, one is particularly interested in the case $x=x_0 \in \{X_1,\dots,X_n\}$, as we want to group the sample data into "sample clusters". The `MeanShift` package is designed to accomplish this goal.

The `MeanShift` package contains two implementations of the mean shift algorithm: the standard mean shift algorithm and its "blurring" version, which is an approximation to the standard algorithm that is often substantially faster.
The standard implementation of the mean shift algorithm is provided by the function `msClustering`. The user needs to supply:

- `X`: the data matrix containing the sample points $\{X_1, \dots, X_n\}$ by column.
- `h`: the bandwidth parameter.
- `kernel`: the type of kernel function $K$.
- `tol.stop`: a tolerance parameter; the mean shift iteration is stopped at iteration $k$ if $\|x_k-x_{k-1}\|<$ `tol.stop`.
- `tol.epsilon`: another tolerance parameter; once the mean shift algorithm has been applied to all the columns of `X`, the point $X_i$ is assigned to the cluster corresponding to the mode $\hat m$ if the end point of its mean shift trajectory lies within `tol.epsilon` of $\hat m$. These assignments are implemented using an efficient algorithm to identify connected components. See @carreira2015review for more details.
- `multi.core`: a logical parameter that allows the algorithm to be parallelized across multiple cores.

In our implementation, convergence is achieved at iteration $k$ if $\|x_k - x_{k-1}\|<$ `tol.stop`.
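The role of `tol.epsilon` can be illustrated with a small base R sketch that groups trajectory end points into connected components. The function `group_endpoints` and the toy end points below are hypothetical and are not part of the `MeanShift` package; the sketch uses a simple quadratic-time search and should not be read as the efficient algorithm mentioned above.

```r
## Group end points (stored by column) by linking any two points whose
## distance is below eps and taking connected components of that graph.
group_endpoints <- function(E, eps) {
  n <- ncol(E)
  adjacent <- as.matrix(dist(t(E))) < eps
  labels <- integer(n)            # 0 means "not yet assigned"
  current <- 0L
  for (i in seq_len(n)) {
    if (labels[i] != 0L) next
    current <- current + 1L
    frontier <- i
    ## breadth-first search over the points linked to point i
    while (length(frontier) > 0) {
      labels[frontier] <- current
      reachable <- which(colSums(adjacent[frontier, , drop = FALSE]) > 0)
      frontier <- reachable[labels[reachable] == 0L]
    }
  }
  labels
}

## made-up end points: three near (0, 0) and two near (3, 3)
E <- cbind(c(0, 0), c(1e-4, 0), c(3, 3), c(3, 3 + 1e-4), c(0, 2e-4))
labels <- group_endpoints(E, eps = 1e-3)
print(labels)
```

A larger `eps` merges end points more aggressively, while a value below the numerical accuracy of the iteration (governed by `tol.stop`) would split points that converged to the same mode.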
The blurring mean shift algorithm is a variant of the mean shift algorithm in which the sample $\{X_1,\dots,X_n\}$ is updated at each mean shift iteration. In particular, for every $i \in \{1,\dots,n\}$, the update
$$
X_i \leftarrow \frac{\sum_{j=1}^n K \left( \frac{\|X_j-X_i\|}{h} \right)X_j}{\sum_{j=1}^n K \left( \frac{\|X_j-X_i\|}{h} \right)}
$$
is recursively applied until convergence. In the `MeanShift` package, the blurring mean shift algorithm is available through the function `bmsClustering`, which takes the following input arguments:

- `X`, `h`, `kernel`, `tol.stop`, `tol.epsilon`: same as in `msClustering`.
- `max.iter`: a maximum number of iterations; if convergence does not occur within `max.iter` iterations, the algorithm is interrupted.

In the context of `bmsClustering`, convergence occurs at the $k$-th iteration if $\max_i \|X_{i,k}-X_{i,k-1}\|<$ `tol.stop`.
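The blurring update and its stopping rule can be sketched in a few lines of base R. This is a minimal illustration, not the package's implementation: the Gaussian kernel, the function name `blurring_mean_shift`, the default values, and the simulated sample are assumptions made for this example, while `tol.stop` and `max.iter` mirror the arguments described above.

```r
## Gaussian kernel evaluated at a scaled distance u >= 0
## (an assumed choice for this sketch)
gaussian_kernel <- function(u) exp(-u^2 / 2)

## Blurring mean shift: every sample point (column of X) is replaced by a
## kernel-weighted average of the current sample at each iteration.
blurring_mean_shift <- function(X, h, tol.stop = 1e-6, max.iter = 200) {
  for (k in seq_len(max.iter)) {
    ## n x n matrix of weights K(||X_j - X_i|| / h)
    W <- gaussian_kernel(as.matrix(dist(t(X))) / h)
    ## column i of X_new is the weighted average of all X_j
    X_new <- sweep(X %*% W, 2, colSums(W), "/")
    ## largest displacement of any point in this iteration
    shift <- max(sqrt(colSums((X_new - X)^2)))
    X <- X_new
    if (shift < tol.stop) break
  }
  list(points = X, iterations = k, converged = shift < tol.stop)
}

## simulated toy sample: two well-separated groups in R^2
set.seed(1)
X <- cbind(matrix(rnorm(100, mean = 0, sd = 0.3), nrow = 2),
           matrix(rnorm(100, mean = 3, sd = 0.3), nrow = 2))
res <- blurring_mean_shift(X, h = 0.5)
print(res$iterations)
print(res$points[, c(1, 51)])
```

Because the whole sample moves at every iteration, points in the same group collapse onto a common location, and the loop stops once the largest displacement falls below `tol.stop` or `max.iter` is reached.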
We illustrate the use of the `MeanShift` package by applying it to the `seeds` dataset, available at https://archive.ics.uci.edu/ml/datasets/seeds. The `seeds` dataset gives measurements of geometrical properties of wheat grains belonging to 3 different varieties.

Our goal is to demonstrate the use of the `msClustering` and `bmsClustering` functions by clustering the wheat varieties on the basis of the 7 quantitative variables contained in the dataset.
```r
## load "MeanShift" package
library( MeanShift )

## load `seeds` dataset
load( "seeds.RData" )

## wheat variety labels
seeds.labels <- seeds[,"variety"]

## organize data by columns
seeds.data <- t( seeds[,c( "area", "perimeter", "compactness",
"length", "width", "asymmetry", "groove.length" )] )
print( dim( seeds.data ) )

## standardize the variables
seeds.data <- seeds.data / apply( seeds.data, 1, sd )

## form a set of candidate bandwidths
h.cand <- quantile( dist( t( seeds.data ) ), seq( 0.05, 0.40, by=0.05 ) )

## perform mean shift clustering with the blurring version of the algorithm
system.time( bms.clustering <- lapply( h.cand,
function( h ){ bmsClustering( seeds.data, h=h ) } ) )
```
```r
## swap cluster labels 3 and 4 in the third clustering
## ("pink" is only a temporary placeholder during the swap)
tmp.labels3 <- bms.clustering[[3]]$labels
tmp.labels3[tmp.labels3==3] <- "pink"
tmp.labels3[tmp.labels3==4] <- 3
tmp.labels3[tmp.labels3=="pink"] <- 4
bms.clustering[[3]]$labels <- as.integer( tmp.labels3 )

## reorder the modes accordingly and restore the original column names
bms.clustering[[3]]$components <-
bms.clustering[[3]]$components[,c( 1, 2, 4, 3, 5 )]
colnames( bms.clustering[[3]]$components ) <-
colnames( bms.clustering[[3]]$components )[c( 1, 2, 4, 3, 5 )]
```
```r
## the resulting object is a list with names "components" and "labels"
class( bms.clustering[[1]] )
names( bms.clustering[[1]] )

## extract the cluster labels
ms.labels1 <- bms.clustering[[1]]$labels
print( ms.labels1 )

## extract the cluster modes/representatives
ms.modes1 <- bms.clustering[[1]]$components
print( ms.modes1 )

## plot
par( mfrow=c( 1, 2 ) )
plot( seeds.data[5,], seeds.data[6,], col=bms.clustering[[1]]$labels,
xlab=names( seeds )[5], ylab=names( seeds )[6],
main="Mean shift labels", cex=0.65, pch=16 )
plot( seeds.data[5,], seeds.data[6,], col=seeds.labels,
xlab=names( seeds )[5], ylab=names( seeds )[6],
main="True labels", cex=0.65, pch=16 )

## bandwidth h is too small -> "overclustering"
```

```r
## extract the cluster labels
ms.labels6 <- bms.clustering[[6]]$labels
print( ms.labels6 )

## extract the cluster modes/representatives
ms.modes6 <- bms.clustering[[6]]$components
print( ms.modes6 )

## plot
par( mfrow=c( 1, 2 ) )
plot( seeds.data[5,], seeds.data[6,], col=bms.clustering[[6]]$labels,
xlab=names( seeds )[5], ylab=names( seeds )[6],
main="Mean shift labels", cex=0.65, pch=16 )
plot( seeds.data[5,], seeds.data[6,], col=seeds.labels,
xlab=names( seeds )[5], ylab=names( seeds )[6],
main="True labels", cex=0.65, pch=16 )

## bandwidth h is too large -> "underclustering"
```

```r
## extract the cluster labels
ms.labels3 <- bms.clustering[[3]]$labels
print( ms.labels3 )

## extract the cluster modes/representatives
ms.modes3 <- bms.clustering[[3]]$components
print( ms.modes3 )

## plot
par( mfrow=c( 1, 2 ) )
plot( seeds.data[5,], seeds.data[6,], col=bms.clustering[[3]]$labels,
xlab=names( seeds )[5], ylab=names( seeds )[6],
main="Mean shift labels", cex=0.65, pch=16 )

## add estimated cluster modes to the plot
points( ms.modes3[5,], ms.modes3[6,], col=1:ncol( ms.modes3 ),
pch="+", cex=3 )
plot( seeds.data[5,], seeds.data[6,], col=seeds.labels,
xlab=names( seeds )[5], ylab=names( seeds )[6],
main="True labels", cex=0.65, pch=16 )

## just right!
```
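The plots above compare estimated and true labels visually. A numerical complement, sketched here under our own assumptions, is to cross-tabulate the two label vectors with `table`. The label vectors below are made up purely for illustration and are not output from the `seeds` analysis; the "purity" summary is one common choice and is not computed by the `MeanShift` package.

```r
## hypothetical labels for illustration only (not results from `seeds`)
true.labels      <- c( 1, 1, 1, 2, 2, 2, 3, 3, 3 )
estimated.labels <- c( 2, 2, 2, 1, 1, 3, 3, 3, 3 )

## rows: estimated clusters, columns: true classes
confusion <- table( estimated=estimated.labels, true=true.labels )
print( confusion )

## purity: fraction of points that belong to the majority true class
## of their estimated cluster (cluster numbering does not matter)
purity <- sum( apply( confusion, 1, max ) ) / sum( confusion )
print( purity )
```

With real output, `estimated.labels` would be replaced by a label vector such as `bms.clustering[[3]]$labels` and `true.labels` by `seeds.labels`.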