R/initialise.R

Defines functions InitialisePredictive.nonconjugate InitialisePredictive.conjugate InitialisePredictive Initialise.nonconjugate Initialise.conjugate Initialise

Documented in Initialise

#' Initialise a Dirichlet process object
#'
#' Prepare a Dirichlet process object for fitting by giving every data point an
#' initial cluster allocation (by default, a single cluster holding all of the
#' data) together with starting parameters for the clusters, drawn from either
#' the posterior or the prior.
#'
#' @param dpObj A Dirichlet process object.
#' @param posterior Logical; whether the cluster parameters should be drawn from the posterior (\code{TRUE}) or from the prior (\code{FALSE}).
#' @param m Number of auxiliary variables to use for a non-conjugate mixing distribution (default \code{m = 3}). See \code{\link{ClusterComponentUpdate}} for more details on m.
#' @param verbose Logical; whether to print the acceptance ratio for non-conjugate mixtures.
#' @param numInitialClusters Number of clusters to initialise with.
#' @return A Dirichlet process object that has initial cluster allocations.
#' @export
Initialise <- function(dpObj, posterior = TRUE, m = 3, verbose = TRUE, numInitialClusters = 1) {
  # S3 dispatch on the class of dpObj (e.g. conjugate / nonconjugate).
  UseMethod("Initialise")
}

# Conjugate method of Initialise().
# Assigns the n data points to `numInitialClusters` clusters in round-robin
# order (1, 2, ..., k, 1, 2, ...). Cluster parameters are drawn from the
# posterior given all of the data only when a single cluster is requested and
# `posterior` is TRUE; in every other case they are drawn from the prior.
# `m` and `verbose` are accepted for interface compatibility and are unused.
# Finally the object is passed through InitialisePredictive().
# Errors if `numInitialClusters` is not a whole number in [1, dpObj$n].
#' @export
Initialise.conjugate <- function(dpObj, posterior = TRUE, m = NULL, verbose = NULL, numInitialClusters = 1) {

  n <- dpObj$n
  k <- numInitialClusters

  # Guard against values that silently corrupt the state: k = 0 yields NA
  # labels (rep_len() of an empty vector), and k > n leaves empty clusters
  # whose counts disagree with the labels.
  if (!is.numeric(k) || length(k) != 1L || is.na(k) || k < 1 ||
      k != round(k) || k > n) {
    stop("numInitialClusters must be a single whole number between 1 and ",
         "the number of data points (", n, ").", call. = FALSE)
  }

  dpObj$clusterLabels <- rep_len(seq_len(k), length.out = n)
  dpObj$numberClusters <- k
  # One-pass count per cluster; as.numeric() keeps the double type that the
  # previous vapply(..., numeric(1)) implementation produced.
  dpObj$pointsPerCluster <- as.numeric(tabulate(dpObj$clusterLabels, nbins = k))

  if (posterior && k == 1) {
    # A single cluster owns every point, so draw from its posterior given all
    # of the data.
    dpObj$clusterParameters <- PosteriorDraw(dpObj$mixingDistribution, dpObj$data, 1)
  } else {
    dpObj$clusterParameters <- PriorDraw(dpObj$mixingDistribution, k)
  }

  InitialisePredictive(dpObj)
}

# Non-conjugate method of Initialise().
# Assigns the n data points to `numInitialClusters` clusters in round-robin
# order (1, 2, ..., k, 1, 2, ...), matching the conjugate method. With a single
# cluster and `posterior = TRUE`, 1000 posterior draws are taken and the final
# draw of each parameter is kept. When `verbose` is TRUE an "Accept Ratio" is
# printed. In every other case the parameters are drawn from the prior.
# `m` (number of auxiliary variables) is stored on the object.
# Errors if `numInitialClusters` is not a whole number in [1, dpObj$n].
#' @export
Initialise.nonconjugate <- function(dpObj, posterior = TRUE, m = 3, verbose = TRUE, numInitialClusters = 1) {

  n <- dpObj$n
  k <- numInitialClusters

  # Guard against values that silently corrupt the state: k = 0 yields NA
  # labels, and k > n leaves empty clusters whose counts disagree with labels.
  if (!is.numeric(k) || length(k) != 1L || is.na(k) || k < 1 ||
      k != round(k) || k > n) {
    stop("numInitialClusters must be a single whole number between 1 and ",
         "the number of data points (", n, ").", call. = FALSE)
  }

  labels <- rep_len(seq_len(k), length.out = n)
  dpObj$clusterLabels <- labels
  dpObj$numberClusters <- k
  dpObj$pointsPerCluster <- tabulate(labels, nbins = k)

  if (posterior && k == 1) {
    # Number of posterior draws requested; only the last one is kept.
    nDraws <- 1000
    post_draws <- PosteriorDraw(dpObj$mixingDistribution, dpObj$data, nDraws)

    if (verbose) {
      # Share of distinct values among the draws of the first parameter.
      # NOTE(review): presumably repeated values correspond to rejected
      # sampler proposals -- confirm against the PosteriorDraw method.
      cat(paste("Accept Ratio: ",
                length(unique(c(post_draws[[1]]))) / nDraws,
                "\n"))
    }

    # Keep the final draw of each parameter; drop = FALSE preserves the
    # array's dimensions.
    dpObj$clusterParameters <- lapply(post_draws, function(x) x[, , nDraws, drop = FALSE])
  } else {
    dpObj$clusterParameters <- PriorDraw(dpObj$mixingDistribution, k)
  }

  dpObj$m <- m

  dpObj
}


# Internal S3 generic (no @export tag): dispatches on the class of dpObj so
# each mixture type can attach its own predictive quantities.
InitialisePredictive <- function(dpObj) {
  UseMethod("InitialisePredictive")
}

# Conjugate method: evaluates Predictive() for the mixing distribution over the
# data and stores the result on the object as `predictiveArray`.
InitialisePredictive.conjugate <- function(dpObj) {
  mixing <- dpObj$mixingDistribution
  dpObj$predictiveArray <- Predictive(mixing, dpObj$data)
  dpObj
}

# Non-conjugate method: nothing to precompute, so the object is returned
# unchanged.
InitialisePredictive.nonconjugate <- function(dpObj) dpObj

Try the dirichletprocess package in your browser

Any scripts or data that you put into this service are public.

dirichletprocess documentation built on Aug. 25, 2023, 5:19 p.m.