R/navmix.R

Defines functions navmix_K row_standardise C_vMF col_norm row_norm rad_plot par_plot hm_plot navmix

Documented in navmix

#'Noise-augmented directional clustering
#'
#'Performs directional clustering by fitting a noise-augmented von Mises-Fisher mixture model
#'
#'@param x Matrix of values where rows represent observations and columns represent features.
#'@param K The number of clusters to fit.
#'@param select_K If TRUE (the default setting), the number of clusters will be chosen by BIC, with K the maximum number of clusters
#'considered. If FALSE, then a model with K clusters will be fit.
#'@param common_kappa If TRUE, then the model will force the kappa parameter to be equal for all clusters, except the
#'noise cluster.
#'@param pj_ini The initial proportion of observations which belong in the noise cluster. Must be a number greater or
#'equal to 0 and strictly less than 1. The default value is 0.05. If set to 0, no observations will be placed in the
#'noise cluster.
#'@param no_ini The number of times the algorithm is run with different initialisations. Must be a number greater than
#'or equal to 1. The default value is 5.
#'@param tol The tolerance threshold for convergence of the EM algorithm. Must be a number greater than 0. The default
#'value is 1.0e-4.
#'@param max_iter The maximum number of iterations of the EM algorithm. Must be a number greater than 0. The default
#'value is 100.
#'@param plot Plots of the results will be produced if set to TRUE. Default is FALSE.
#'@param plot_heat Produces a heatmap of the results if plot is set to TRUE. The heatmap will also be returned as a
#'ggplot object.
#'@param plot_radial Produces (a) radial plot(s) of the results if plot is set to TRUE.
#'@param plot_radial_options A named list of options for the radial plot(s), with the following elements.
#'plot_radial_separate: if set to FALSE (the default value), the fitted means of each cluster are plotted on the
#'same radial plot; if set to TRUE, they are plotted on separate radial plots.
#'radial_legend_pos: adjusts the position of the legend for a radial plot with all fitted means plotted together
#'(the default value is c(-2.5, 2.7)).
#'radial_separate_col: the number of columns in the grid of plots when the radial plots are drawn separately
#'(the default value is 2).
#'@return Returned are the BIC values for each model fitted ($BIC), the final fitted model ($fit) and, if produced, the
#'heatmap as a ggplot object ($heatmap_plot). The fitted model has the following.
#'\item{mu}{A matrix where each column represents the mean of the fitted von Mises-Fisher distribution for each cluster.}
#'\item{kappa}{A row vector where each element represents the kappa parameter of the fitted von Mises-Fisher distribution
#'for each cluster.}
#'\item{g}{A matrix of probabilities for each observation belonging to each cluster. The value in the jth row and kth
#'column represents the probability that the jth observation belongs to the kth cluster.}
#'\item{z}{A vector of the cluster membership of each observation when allocated according to the cluster for which it
#'has the highest probability of membership (hard clustering).}
#'\item{bic}{The BIC for the fitted model.}
#'\item{l}{The value of the log-likelihood function at the estimated parameters.}

navmix = function(x, K = 10, select_K = TRUE, common_kappa = FALSE, pj_ini = 0.05, no_ini = 5, tol = 1.0e-4,
                  max_iter = 100, plot = FALSE, plot_heat = TRUE, reorder_traits = TRUE, plot_heat_mu = FALSE,
                  plot_parallel = TRUE, plot_radial = FALSE, plot_radial_options = list("plot_radial_separate" = FALSE,
                                                                  "radial_legend_pos" = c(-2.5, 2.7),
                                                                  "radial_separate_col" = 2)){
  # Validates the inputs, fits the mixture model through navmix_K() (either for K clusters or for 1..K clusters
  # with the fit of highest BIC retained), labels the fitted matrices and optionally draws plots.
  # Returns a list with elements BIC and fit, plus heatmap_plot / heatmap_mu_plot / parallel_plot when produced.

  # ---- Argument validation ----
  # A usable tuning argument is a single, non-missing number.
  is_scalar_number = function(v){is.numeric(v) && length(v) == 1 && !is.na(v)}
  # pj_ini = 1 would put all initial weight on the noise cluster, so the upper bound is strict.
  if (!is_scalar_number(pj_ini) || pj_ini < 0 || pj_ini >= 1){
    stop('pj_ini must be a number greater or equal to 0 and strictly less than 1')
  }
  if (!is_scalar_number(no_ini) || no_ini < 1){stop('no_ini must be a number greater or equal to 1')}
  if (!is_scalar_number(tol) || tol <= 0){stop('tol must be a number greater than 0')}
  if (!is_scalar_number(max_iter) || max_iter <= 0){stop('max_iter must be a number greater than 0')}
  if (!is_scalar_number(K) || K < 1){stop('K must be a number greater or equal to 1')}

  # ---- Coerce x to a numeric matrix ----
  # A plain vector becomes a one-column matrix, which is then rejected by the column check below.
  x = as.matrix(x)
  n_input = nrow(x)
  m = ncol(x)
  if (is.numeric(x) == FALSE){stop('x must be able to be coerced into a numeric matrix')}
  if (anyNA(x)){
    x_complete = complete.cases(x)
    # drop = FALSE keeps x a matrix even if a single complete row remains.
    x = x[x_complete, , drop = FALSE]
    warning('Rows in x with missing values have been removed.')
  } else {
    x_complete = rep(TRUE, n_input)
  }
  n = nrow(x)
  if (m < 2){
    stop('x must have at least two columns.')
  }
  if (n < 1){
    stop('x must have at least one row without missing values.')
  }
  if (K > n){
    if (select_K == TRUE){
      K = n
      warning('Cannot fit more clusters than observations. Maximum K has been set to the number of observations.')
    } else {
      stop('Cannot fit more clusters than observations.')
    }
  }

  # ---- Row / column labels ----
  if (is.null(rownames(x))){
    # Default labels are numbered by position in the ORIGINAL input, so rows that were dropped for
    # missing values leave gaps rather than shifting the numbering.
    snps_names = sprintf("snp %0d", seq_len(n_input))[x_complete]
  } else {
    snps_names = rownames(x)
  }
  if (is.null(colnames(x))){
    trait_names = sprintf("trait %0d", seq_len(m))
  } else {
    trait_names = colnames(x)
  }

  # ---- Model fitting ----
  if (select_K == TRUE){
    all_fit = vector(mode = "list", length = K)
    bic = rep(NA_real_, K)
    for (j in 1:K){
      all_fit[[j]] = navmix_K(x, j, pj_ini = pj_ini, common_kappa = common_kappa, no_ini = no_ini, tol = tol,
                              max_iter = max_iter)
      bic[j] = all_fit[[j]]$bic
      # Stop early once the BIC has decreased on two consecutive increments of the cluster count.
      if (j > 2 && bic[j] < bic[(j-1)] && bic[(j-1)] < bic[(j-2)]){
        bic = bic[1:j]
        break
      }
    }
    fit = all_fit[[which.max(bic)]]
  } else {
    fit = navmix_K(x, K, pj_ini = pj_ini, common_kappa = common_kappa, no_ini = no_ini, tol = tol,
                   max_iter = max_iter)
    bic = fit$bic
  }
  rownames(fit$mu) = trait_names
  rownames(fit$g) = snps_names
  navmix_out = list(BIC = bic, fit = fit)

  # ---- Optional plots ----
  if (plot == TRUE){
    x_prop = row_norm(x)
    rownames(x_prop) = snps_names
    colnames(x_prop) = trait_names
    # The noise cluster is always the last column of the membership matrix.
    noise_cl = ncol(fit$g)
    if (plot_heat_mu == TRUE || plot_parallel == TRUE || plot_radial == TRUE){
      # Fitted mean directions of the non-noise clusters, one column per cluster.
      mu = fit$mu[, -noise_cl, drop = FALSE]
      rownames(mu) = trait_names
      colnames(mu) = seq_len(noise_cl - 1)
    }
    if (plot_heat == TRUE){
      in_cluster = fit$z != noise_cl
      heatmap_plot = hm_plot(x_prop[in_cluster, , drop = FALSE], fit$z[in_cluster], reorder_traits = reorder_traits,
                             print = TRUE)
      navmix_out$heatmap_plot = heatmap_plot
    }
    if (plot_heat_mu == TRUE){
      heatmap_mu_plot = hm_plot(t(mu), seq(1, (noise_cl-1)), reorder_traits = FALSE, print = FALSE)
      navmix_out$heatmap_mu_plot = heatmap_mu_plot + xlab("Cluster") +
        theme(axis.title.y = element_blank(), axis.text.x = element_text(size = 7), strip.text = element_blank())
    }
    if (plot_parallel == TRUE){
      navmix_out$parallel_plot = par_plot(mu)
    }
    if (plot_radial == TRUE){
      # Fill in any option the caller left out of a partially specified list.
      if (is.null(plot_radial_options$plot_radial_separate)){plot_radial_options$plot_radial_separate = FALSE}
      if (is.null(plot_radial_options$radial_legend_pos)){plot_radial_options$radial_legend_pos = c(-2.5, 2.7)}
      if (is.null(plot_radial_options$radial_separate_col)){plot_radial_options$radial_separate_col = 2}
      rad_plot(mu, plot_radial_separate = plot_radial_options$plot_radial_separate,
               radial_legend_pos = plot_radial_options$radial_legend_pos,
               radial_separate_col = plot_radial_options$radial_separate_col)
    }
  }
  return(navmix_out)
}

hm_plot = function(B, z, reorder_traits = TRUE, print = TRUE){
  # Tile heatmap of the matrix B (rows on the x axis as "Variant", columns on the y axis as "Trait"),
  # split into one facet per value of z.
  #   B               numeric matrix; row / column names are used as labels when present.
  #   z               vector of group labels, one per row of B.
  #   reorder_traits  if TRUE, columns are ordered by hierarchical clustering of the columns of B.
  #   print           if TRUE the plot is printed; the ggplot object is returned either way
  #                   (invisibly when printed).

  # Fall back to generated labels so unnamed input still yields valid factor levels.
  trait_ids = colnames(B)
  if (is.null(trait_ids)){trait_ids = sprintf("trait %0d", seq_len(ncol(B)))}
  variant_ids = rownames(B)
  if (is.null(variant_ids)){variant_ids = seq_len(nrow(B))}

  heat_df = data.frame(B)
  names(heat_df) = trait_ids
  # Clustering needs at least two columns; with a single trait there is nothing to reorder.
  if (reorder_traits == TRUE && ncol(B) > 1){
    d = dist(t(heat_df))
    hc = hclust(d)
    heat_colord = hc$order
  } else {
    heat_colord = seq_len(ncol(B))
  }
  heat_df = heat_df[, heat_colord, drop = FALSE]
  heat_df$Variant = variant_ids
  heat_df$clust = z
  # The first ncol(B) columns are the traits; the two appended columns are kept as identifiers.
  heat_df = pivot_longer(heat_df, 1:ncol(B))
  names(heat_df) = c("Variant", "clust", "Trait", "Value")
  heat_df$Trait = factor(heat_df$Trait, levels = trait_ids[heat_colord])
  # Explicit levels keep the variants in the row order of B.
  heat_df$Variant = factor(heat_df$Variant, levels = variant_ids)
  heatmap_plot = ggplot(heat_df, aes(x = Variant, y = Trait, fill = Value)) +
    geom_tile() + facet_grid(cols = vars(clust), scales = "free_x", space = 'free') +
    theme(axis.text.x = element_blank(), axis.ticks.x = element_blank(), axis.text.y = element_text(size = 7),
          axis.title.y = element_text(size = 7), axis.title.x = element_text(size = 7),
          legend.text = element_text(size = 7), legend.title = element_text(size = 7),
          strip.text = element_text(size = 7)) +
    scale_fill_distiller(palette="RdYlBu") + scale_y_discrete(expand = c(0, 0)) + scale_x_discrete(expand = c(0, 0))
  if (print == TRUE) {print(heatmap_plot)} else {heatmap_plot}
}

par_plot = function(mu){
  # Line plot of the columns of mu: one line per column ("Cluster"), with the rows of mu ("Trait")
  # along the x axis. Returns the ggplot object.
  n_clust = ncol(mu)
  cluster_ids = colnames(mu)
  # Traits are laid out in increasing order of their variance across the columns of mu.
  trait_names = rownames(mu)[order(apply(mu, 1, var))]

  # Long format: one row per (cluster, trait) pair, the measurement in column `value`.
  par_df = data.frame(t(mu))
  names(par_df) = rownames(mu)
  par_df$Cluster = factor(cluster_ids, levels = cluster_ids)
  par_df = pivot_longer(par_df, seq(1, nrow(mu)), names_to = "Trait")
  par_df$Trait = factor(par_df$Trait, levels = trait_names)

  # Brewer "Dark2" palette for up to eight clusters, evenly spaced hcl hues beyond that.
  if (n_clust > 8){
    line_col = hcl(seq(15, 375, length.out = (n_clust + 1)))
  } else {
    line_col = brewer.pal(max(3, n_clust), "Dark2")
  }

  # Symmetric y range with a 10% margin around the largest absolute value.
  y_bound = 1.1 * max(abs(mu))
  small_text = element_text(size = 6)

  base_plot = ggplot(par_df, aes(x = Trait, y = value, group = Cluster, color = Cluster)) + geom_line()
  base_plot +
    scale_color_manual(values = line_col) + ylim(c(-y_bound, y_bound)) + theme_bw() +
    theme(axis.text = small_text, legend.text = small_text,
          legend.title = small_text, axis.title.x = small_text,
          axis.title.y = small_text) +
    scale_x_discrete(labels = trait_names) +
    ylab("Mean proportional association") + geom_hline(yintercept = 0, size = 0.25)
}

rad_plot = function(mu, plot_radial_separate = FALSE, plot_radial_par = NULL, radial_legend_pos = c(-2.5, 2.7),
                    radial_separate_col = 2){
  # Radial plot(s) of the columns of mu, with the rows of mu as the labelled spokes.
  #   mu                    matrix with one column per cluster; row names are used as labels.
  #   plot_radial_separate  FALSE: all columns on one plot with a legend; TRUE: one plot per column.
  #   plot_radial_par       graphical parameters handed to par() when the function finishes; if NULL,
  #                         small text settings are applied for plotting and the previous settings
  #                         are put back afterwards.
  #   radial_legend_pos     x and y position of the legend (combined plot only).
  #   radial_separate_col   number of columns in the plot grid (separate plots only).
  # Returns, invisibly, the value of the final par() call.
  if (is.null(plot_radial_par)){
    # par() returns the previous values of the parameters it changes.
    current_par = par(cex.axis = 0.5, cex.lab = 0.5, cex.main = 0.75, oma = c(0, 0, 0, 0), mfrow = c(1, 1))
  } else {
    # NOTE(review): a supplied plot_radial_par is not applied before plotting, only once plotting has
    # finished. This is the existing behaviour - confirm it is intended.
    current_par = plot_radial_par
  }
  # Safety net: make sure the graphical parameters are also reset if plotting stops with an error.
  on.exit(par(current_par), add = TRUE)
  K = ncol(mu)
  if (K <= 8){line_col = brewer.pal(max(3, K), "Dark2")} else {line_col = hcl(seq(15, 375, length.out = (K + 1)))}
  if (plot_radial_separate == FALSE){
    radial.plot(t(mu), rp.type = "p", labels = rownames(mu), show.grid.labels = TRUE, line.col = line_col,
                lwd = 1.5, radial.lim = c(-1, 1), label.prop = 1.3)
    legend(radial_legend_pos[1], radial_legend_pos[2], seq(1, K), col=line_col, lty=1, cex = 0.5, title = "Cluster", bty = "n")
  } else {
    # Lay the K plots out on a grid with radial_separate_col columns.
    par(mfrow = c(ceiling(K/radial_separate_col), radial_separate_col))
    for (j in 1:K){
      radial.plot(t(mu[, j]), rp.type = "p", labels = rownames(mu), show.grid.labels = TRUE,
                  line.col = line_col[j], lwd = 1.5, radial.lim = c(-1, 1), label.prop = 1.3)
      title(paste('Cluster', j, sep = ' '))
    }
  }
  # Kept as the last expression so the return value is unchanged; the on.exit() call above simply
  # repeats the same (idempotent) reset.
  par(current_par)
}

row_norm = function(x){
  # Scales every row of the matrix x to unit Euclidean length.
  # Returns a matrix with the same dimensions as x and no dimnames. A row of zeros yields NaN.
  x = matrix(x, nrow = nrow(x), ncol = ncol(x))
  # Division recycles the vector of row norms down each column, so element [j, k] is divided by the
  # norm of row j. Being vectorised, this keeps the n x m shape for single-column and zero-row input.
  x / sqrt(rowSums(x^2))
}

col_norm = function(x){
  # Scales every column of the matrix x to unit Euclidean length.
  # Returns (visibly) a matrix with the same dimensions as x and no dimnames. A column of zeros yields NaN.
  x = matrix(x, nrow = nrow(x), ncol = ncol(x))
  # Repeat each column norm once per row so it lines up with the column-major storage of x; the
  # result therefore stays a matrix even when x has a single row.
  x / rep(sqrt(colSums(x^2)), each = nrow(x))
}

C_vMF = function(kappa, d){
  # Normalising constant of the von Mises-Fisher density with concentration kappa in d dimensions:
  #   kappa^(d/2 - 1) / ((2 * pi)^(d/2) * I_{d/2 - 1}(kappa)),
  # where I is the modified Bessel function of the first kind. Vectorised over kappa.
  bessel_order = d / 2 - 1
  scale_term = (2 * pi)^(bessel_order + 1) * besselI(kappa, bessel_order)
  kappa^bessel_order / scale_term
}

row_standardise = function(x, se, r){
  # Standardises each row of x by the inverse matrix square root of S = D r D, where D is the
  # diagonal matrix built from the matching row of se.
  #   x   numeric matrix of values, one row per observation.
  #   se  numeric matrix with the same dimensions as x (presumably standard errors - confirm with callers).
  #   r   square matrix with ncol(x) rows (presumably a correlation matrix - confirm with callers).
  # Returns a matrix with the same dimensions as x.
  n_col = ncol(x)
  x_std = matrix(nrow = nrow(x), ncol = n_col)
  for (j in seq_len(nrow(x))){
    # nrow is given explicitly: diag() applied to a single number would otherwise build an identity
    # matrix of that size instead of a 1 x 1 matrix holding the value.
    S1 = diag(se[j, ], nrow = n_col)
    S = S1 %*% r %*% S1
    x_std[j, ] = solve(sqrtm(S), x[j, ])
  }
  x_std
}

navmix_K = function(x0, K, pj_ini = 0.05, common_kappa = FALSE, no_ini = 5, tol = 1.0e-4, max_iter = 100){
  # Fits, by EM, a mixture of K von Mises-Fisher components plus one noise component to the rows of x0.
  # The algorithm is run no_ini times from skmeans() initialisations and the run with the highest
  # log-likelihood is returned.
  #   x0            numeric matrix, one row per observation.
  #   K             number of (non-noise) clusters.
  #   pj_ini        initial membership weight given to the noise component for every observation.
  #   common_kappa  if TRUE, one kappa is shared by all non-noise clusters.
  #   no_ini        number of initialisations.
  #   tol           convergence threshold on the absolute change in log-likelihood.
  #   max_iter      maximum number of EM iterations per initialisation.
  # Returns a list with mu (means, one column per component), kappa (1 x (K+1) matrix), g (membership
  # probabilities), z (hard assignments), bic and l (log-likelihood). The noise component is last.
  n = nrow(x0)
  m = ncol(x0)
  # Parameter count used in the BIC penalty.
  if (pj_ini > 0){
    no_par = (m + 1) * K + m + 1 + abs(as.numeric(common_kappa) - 1) * (K - 1)
  } else {
    no_par = (m + 1) * K + abs(as.numeric(common_kappa) - 1) * (K - 1)
  }
  x = row_norm(x0)
  # Mean direction of the noise component: the normalised sum of all unit-length observations.
  M = col_norm(matrix(colSums(x), ncol = 1))
  l = vector(length = no_ini)
  for (i in 1:no_ini){
    # ---- Initial memberships ----
    if (K > 1){
      clust_ini = skmeans(x0, K, m = 1)
      g = sapply(1:K, function(k){
        gk = rep(0, n)
        gk[which(clust_ini$cluster == k)] = 1
        gk
      })
    } else {
      g = rep(1, n)
    }
    # Append the noise column and scale the hard assignment so every row sums to one.
    g = cbind((1 - pj_ini) * g, rep(pj_ini, n))
    l[i] = -Inf
    e = tol + 1
    iter = 0
    while (e > tol){
      # ---- M step ----
      # Membership-weighted resultant vector of each non-noise cluster.
      r = t(x) %*% g[, 1:K, drop = FALSE]
      mu = cbind(col_norm(r), M)
      # kappa is obtained from the mean resultant length r1 as (r1 * m - r1^3) / (1 - r1^2).
      if (common_kappa == TRUE){
        # NOTE(review): unlike the per-cluster branch, this value is not capped at 500.
        r1 = sum(sqrt(colSums(r^2))) / sum(g[, 1:K])
        kappa = rep((r1*m - r1^3) / (1 - r1^2), K)
      } else {
        kappa = sapply(1:K, function(k){
          r1 = min(sqrt(sum(r[, k]^2)) / sum(g[, k]), 1)
          min((r1*m - r1^3) / (1 - r1^2), 500)
        })
      }
      # The noise component is given a fixed, very small concentration.
      kappa[(K+1)] = 0.0001
      p = colSums(g) / n
      # ---- E step ----
      C = sapply(1:(K+1), function(k){C_vMF(kappa[k], m)}) * p
      G = exp(x %*% mu %*% diag(kappa)) %*% diag(C)
      g = G / rowSums(G)
      l1 = sum(log(rowSums(G)))
      e = abs(l1 - l[i])
      l[i] = l1
      iter = iter + 1
      # A non-finite log-likelihood marks this initialisation as failed.
      if (is.na(e)){
        l[i] = -Inf
        break
      }
      if (iter >= max_iter){
        break
      }
    }
    bic = 2 * l[i] - no_par * log(n)
    # Keep this run if it is the first one or if it beats the best log-likelihood seen so far
    # (not merely the previous run, which could let a worse fit replace the best one).
    if (i == 1 || l[i] > l_out){
      mu_out = mu
      kappa_out = kappa
      g_out = g
      z_out = sapply(1:n, function(j){which.max(g[j, ])})
      bic_out = bic
      l_out = l[i]
    }
  }
  kappa_out = matrix(kappa_out, nrow = 1)
  cluster_labels = c(sprintf("Cluster %0d", seq_len(ncol(g) - 1)), "Noise")
  colnames(mu_out) = colnames(kappa_out) = colnames(g_out) = cluster_labels
  return(list(mu = mu_out, kappa = kappa_out, g = g_out, z = z_out, bic = bic_out, l = l_out))
}
aj-grant/navmix documentation built on Jan. 29, 2023, 10:21 a.m.