R/cluster_component_update.R

Defines functions ClusterComponentUpdate.hierarchical ClusterComponentUpdate.nonconjugate ClusterComponentUpdate.conjugate ClusterComponentUpdate

Documented in ClusterComponentUpdate ClusterComponentUpdate.conjugate ClusterComponentUpdate.hierarchical

#' Update the cluster components of a Dirichlet process
#'
#' Performs one sweep over the data, re-sampling the cluster assignment of
#' every data point given the current assignments of all other points.
#' The generic dispatches on the class of \code{dpObj}:
#' \itemize{
#'   \item \code{conjugate}: the weight of opening a new cluster uses the
#'     precomputed \code{predictiveArray} entry for the point.
#'   \item \code{nonconjugate}: new-cluster candidates are \code{m} auxiliary
#'     components drawn with \code{PriorDraw}.
#'   \item \code{hierarchical}: each group-level process in \code{indDP} is
#'     updated and then passed through \code{DuplicateClusterRemove}.
#' }
#'
#' @param dpObj Dirichlet Process object
#' @return Dirichlet process object with updated components.
#'
#' @examples
#' dp <- DirichletProcessGaussian(rnorm(10))
#' dp <- ClusterComponentUpdate(dp)
#'
#' @export
ClusterComponentUpdate <- function(dpObj) {
  # With no explicit object, UseMethod dispatches on the first argument.
  UseMethod("ClusterComponentUpdate")
}

#' @export
#' @rdname ClusterComponentUpdate
ClusterComponentUpdate.conjugate <- function(dpObj) {

  # Values that stay fixed for the whole sweep.
  obs <- dpObj$data
  nObs <- dpObj$n
  concentration <- dpObj$alpha

  # Cluster state; refreshed from dpObj after every reassignment because
  # ClusterLabelChange() may add or drop clusters.
  labels <- dpObj$clusterLabels
  params <- dpObj$clusterParameters
  nClusters <- dpObj$numberClusters
  mixDist <- dpObj$mixingDistribution

  counts <- dpObj$pointsPerCluster

  # Per-point weight for opening a brand-new cluster.
  priorPredictive <- dpObj$predictiveArray

  for (idx in seq_len(nObs)) {
    oldLabel <- labels[idx]

    # Take the point out of its current cluster before scoring.
    counts[oldLabel] <- counts[oldLabel] - 1

    obsRow <- obs[idx, , drop = FALSE]
    weights <- c(
      counts * Likelihood(mixDist, obsRow, params),
      concentration * priorPredictive[idx]
    )

    # NA/NaN weights are treated as impossible.
    weights[is.na(weights)] <- 0

    # No usable weight at all: fall back to a uniform choice.
    if (!any(weights != 0)) {
      weights <- rep.int(1, length(weights))
    }

    # Index nClusters + 1 means "open a new cluster".
    drawn <- sample.int(nClusters + 1, 1, prob = weights)

    dpObj$pointsPerCluster <- counts
    dpObj <- ClusterLabelChange(dpObj, idx, drawn, oldLabel)

    counts <- dpObj$pointsPerCluster
    labels <- dpObj$clusterLabels
    params <- dpObj$clusterParameters
    nClusters <- dpObj$numberClusters
  }

  dpObj$pointsPerCluster <- counts
  dpObj$clusterLabels <- labels
  dpObj$clusterParameters <- params
  dpObj$numberClusters <- nClusters
  dpObj
}
#' @export
#' @rdname ClusterComponentUpdate
ClusterComponentUpdate.nonconjugate <- function(dpObj) {

  # Reassigns each point using m auxiliary components as candidate new
  # clusters (cf. Neal 2000, Algorithm 8). Each existing cluster is weighted
  # by its size times the likelihood; each auxiliary component by
  # (alpha / m) times the likelihood.

  y <- dpObj$data
  n <- dpObj$n
  alpha <- dpObj$alpha

  clusterLabels <- dpObj$clusterLabels
  clusterParams <- dpObj$clusterParameters
  numLabels <- dpObj$numberClusters

  mdObj <- dpObj$mixingDistribution
  m <- dpObj$m

  pointsPerCluster <- dpObj$pointsPerCluster

  aux <- vector("list", length(clusterParams))

  for (i in seq_len(n)) {

    currentLabel <- clusterLabels[i]

    # Remove point i from its cluster before scoring the candidates.
    pointsPerCluster[currentLabel] <- pointsPerCluster[currentLabel] - 1

    if (pointsPerCluster[currentLabel] == 0) {
      # Point i was its cluster's only member: that cluster's parameters
      # become the first auxiliary component, and only m - 1 further
      # components are drawn from the prior.
      priorDraws <- PriorDraw(mdObj, m - 1)

      for (j in seq_along(priorDraws)) {
        aux[[j]] <- array(
          c(clusterParams[[j]][, , currentLabel], priorDraws[[j]]),
          dim = c(dim(priorDraws[[j]])[1:2], m)
        )
      }
    } else {
      aux <- PriorDraw(mdObj, m)
    }

    yi <- y[i, , drop = FALSE]
    probs <- c(
      pointsPerCluster * Likelihood(mdObj, yi, clusterParams),
      (alpha / m) * Likelihood(mdObj, yi, aux)
    )

    # NA and NaN (e.g. 0 * Inf for an emptied cluster) carry no weight.
    probs[is.na(probs)] <- 0

    # sample.int() rejects non-finite weights. If any candidate has infinite
    # weight, it dominates: choose uniformly among the infinite candidates
    # and give every finite candidate weight 0. The mask is computed once,
    # before any entries are overwritten.
    infiniteMask <- is.infinite(probs)
    if (any(infiniteMask)) {
      probs <- as.numeric(infiniteMask)
    }

    # No usable weight at all: fall back to a uniform choice.
    if (all(probs == 0)) {
      probs <- rep_len(1, length(probs))
    }

    # Labels above numLabels select one of the m auxiliary components.
    newLabel <- sample.int(numLabels + m, 1, prob = probs)

    dpObj$pointsPerCluster <- pointsPerCluster

    dpObj <- ClusterLabelChange(dpObj, i, newLabel, currentLabel, aux)

    # ClusterLabelChange() may add or remove clusters; resync local state.
    pointsPerCluster <- dpObj$pointsPerCluster
    clusterLabels <- dpObj$clusterLabels
    clusterParams <- dpObj$clusterParameters
    numLabels <- dpObj$numberClusters

  }

  dpObj$pointsPerCluster <- pointsPerCluster
  dpObj$clusterLabels <- clusterLabels
  dpObj$clusterParameters <- clusterParams
  dpObj$numberClusters <- numLabels
  return(dpObj)
}

#' @export
#' @rdname ClusterComponentUpdate
ClusterComponentUpdate.hierarchical <- function(dpObj) {
  # Visit every group-level process stored in dpObj$indDP. Each one gets its
  # own component update, then the result is passed through
  # DuplicateClusterRemove() before being stored back.
  for (groupIdx in seq_along(dpObj$indDP)) {
    groupDP <- ClusterComponentUpdate(dpObj$indDP[[groupIdx]])
    dpObj$indDP[[groupIdx]] <- DuplicateClusterRemove(groupDP)
  }
  dpObj
}

Try the dirichletprocess package in your browser

Any scripts or data that you put into this service are public.

dirichletprocess documentation built on Aug. 25, 2023, 5:19 p.m.