## R/MultiBatchP.R

## S4 class "MultiBatchP": a subclass of "MultiBatch" that adds no new slots.
## Its validity method requires the current sigma2 to be a single-column
## matrix, and the coercions in this file map it to and from
## "MultiBatchPooled".
setClass(Class = "MultiBatchP", contains = "MultiBatch")

## Validity method for "MultiBatchP".
##
## Checks, in order, that
##   * the first column of dataMean(object) is numeric (only when it has rows);
##   * the first column of dataPrec(object) is numeric (only when it has rows);
##   * sigma2(object) has exactly one column;
##   * the sigma2 chain has nBatch(object) columns.
## Returns TRUE when every check passes, otherwise the message for the first
## failed check, which is the convention setValidity() expects.
setValidity("MultiBatchP", function(object){
  msg <- TRUE
  ## Compare the row count with zero. This mirrors the data.prec check below
  ## rather than taking nrow() of the logical matrix `dataMean(object) > 0`.
  if(nrow(dataMean(object)) > 0){
    vals <- dataMean(object)[, 1]
    if(!is(vals, "numeric")){
      msg <- "data.mean is not numeric"
      return(msg)
    }
  }
  if(nrow(dataPrec(object)) > 0){
    vals <- dataPrec(object)[, 1]
    if(!is(vals, "numeric")){
      msg <- "data.prec is not numeric"
      return(msg)
    }
  }
  if(!identical(ncol(sigma2(object)), 1L)){
    msg <- "sigma2 matrix should have a single column"
    return(msg)
  }
  if(!identical(ncol(sigma2(chains(object))), nBatch(object))){
    msg <- "chains for sigma2 does not have the correct dimension"
    return(msg)
  }
  msg
})

## Allocate MCMC chains for the pooled-variance model.
##
## specs:      model specification list; `number_batches` and `k` are used.
## parameters: model parameter list; iter(parameters$mp) gives the number of
##             saved iterations.
## Returns the result of initialize_mcmcP(K, S, B).
mcmc_chainsP <- function(specs, parameters){
  B <- specs$number_batches
  K <- specs$k
  S <- iter(parameters$mp)
  initialize_mcmcP(K, S, B)
}

## Starting values for the pooled-variance model.
##
## Takes the values from modelValues2(specs, data, hp). When the "sigma2"
## element has dimensions (a matrix whose rows presumably correspond to
## batches), each row is averaged and the result is stored as a
## number_batches x 1 matrix. A dimensionless sigma2 is returned unchanged.
modelValuesP <- function(specs, data, hp){
  vals <- modelValues2(specs, data, hp)
  s2 <- vals[["sigma2"]]
  if(is.null(dim(s2))){
    return(vals)
  }
  ## rowMeans() matches how the MultiBatch -> MultiBatchP coercion pools
  ## sigma2.
  vals[["sigma2"]] <- matrix(rowMeans(s2),
                             nrow=specs$number_batches,
                             ncol=1)
  vals
}

## Placeholder summaries for the pooled-variance model.
##
## specs: model specification list; `number_batches` and `k` are used.
## Returns a named list:
##   data.mean     number_batches x k numeric matrix filled with NA
##   data.prec     number_batches x 1 numeric matrix filled with NA
##   zfreq         integer vector of k zeros
##   marginal_lik  NA_real_
##   modes         empty list
modelSummariesP <- function(specs){
  n_batch <- specs$number_batches
  n_comp <- specs$k
  list(data.mean=matrix(NA_real_, nrow=n_batch, ncol=n_comp),
       data.prec=matrix(NA_real_, nrow=n_batch, ncol=1),
       zfreq=integer(n_comp),
       marginal_lik=NA_real_,
       modes=list())
}

##
## Constructor
##
MultiBatchP <- function(model="MBP3",
                        data=modelData(),
                        ## by default, assume no downsampling
                        down_sample=seq_len(nrow(data)),
                        specs=model_spec(model, data, down_sample),
                        iter=1000L,
                        burnin=500L,
                        thin=1L,
                        nStarts=4L,
                        hp=Hyperparameters(k=specs$k),
                        mp=McmcParams(iter=iter, thin=thin,
                                      burnin=burnin,
                                      nStarts=nStarts),
                        parameters=modelParameters(mp=mp, hp=hp),
                        chains=mcmc_chainsP(specs, parameters),
                        current_values=modelValuesP(specs, data[down_sample, ], hp),
                        summaries=modelSummariesP(specs),
                        flags=modelFlags()){
  ## Model names beginning with "SB" are single-batch models. When the data
  ## have rows, every observation is relabelled as batch 1. Default
  ## arguments are evaluated lazily, so defaults that refer to `data`
  ## (down_sample, specs, current_values) are first forced inside new()
  ## below and therefore see the relabelled batch column.
  single_batch <- substr(model, 1, 2) == "SB"
  has_rows <- nrow(data) > 0
  if(has_rows && single_batch){
    data$batch <- 1L
  }
  new("MultiBatchP",
      data=data,
      down_sample=down_sample,
      specs=specs,
      parameters=parameters,
      chains=chains,
      current_values=current_values,
      summaries=summaries,
      flags=flags)
}

## Per-batch data precision for the pooled-variance model.
##
## Groups the down-sampled data by `batch` and computes 1 / var(oned) within
## each batch, ignoring component assignments. Returns a single-column
## matrix with one row per batch, in the order of the summarized groups
## (dplyr orders groups by their key). Component labels are not needed, so
## z is no longer computed here: the original computed it but never used it
## in the summary.
setMethod("computePrec", "MultiBatchP", function(object){
  tib <- downSampledData(object) %>%
    group_by(batch) %>%
    summarize(prec=1/var(oned))
  matrix(tib$prec, nrow=nrow(tib))
})

## Build a down-sampled "MultiBatchP" from `object`.
##
## N: number of rows to draw (with replacement) when `i` is not supplied.
##    When N is given explicitly and is >= the number of rows in
##    assays(object), `object` is returned unchanged.
## i: optional row indices to use instead of a random draw.
## The observation-level current values (u, z, probz) are subset by `i`,
## chains are re-initialized, and the data.mean, data.prec and zfreq
## summaries are recomputed on the new model.
setMethod("downSampleModel", "MultiBatchP", function(object, N=1000, i){
  if(!missing(N)){
    if(N >= nrow(assays(object))){
      return(object)
    }
  }
  ## Sorting keeps the sampled rows in their original order, so batches stay
  ## ordered provided the full data are ordered by batch (assumed; confirm
  ## upstream).
  if(missing(i)){
    i <- sort(sample(seq_len(nrow(object)), N, replace=TRUE))
  }
  current.vals <- current_values(object)
  current.vals[["u"]] <- current.vals[["u"]][i]
  current.vals[["z"]] <- current.vals[["z"]][i]
  current.vals[["probz"]] <- current.vals[["probz"]][i, , drop=FALSE]
  mb <- MultiBatchP(model=modelName(object),
                    data=assays(object),
                    down_sample=i,
                    parameters=parameters(object),
                    current_values=current.vals,
                    chains=mcmc_chainsP( specs(object), parameters(object) ))
  dataMean(mb) <- computeMeans(mb)
  dataPrec(mb) <- computePrec(mb)
  zFreq(mb) <- as.integer(table(z(mb)))
  mb
})

## Coerce a "MultiBatch" to "MultiBatchP".
##
## Each sigma2 matrix (the current value, the modes if present, and data.prec
## when it is not all NA) is collapsed to a single column by averaging
## across its rows' entries, leaving one value per batch. The sigma2 chain
## is replaced by an all-NA iter x numBatch matrix. The model name in specs
## becomes "MBP" for several batches and "SBP" otherwise.
setAs("MultiBatch", "MultiBatchP", function(from){
  nb <- numBatch(from)
  pool_rows <- function(x) matrix(rowMeans(x), nb, 1)
  vals <- current_values(from)
  vals[["sigma2"]] <- pool_rows(vals[["sigma2"]])
  s <- summaries(from)
  if(length(s$modes) > 0){
    s$modes[["sigma2"]] <- pool_rows(s$modes[["sigma2"]])
  }
  if(!all(is.na(s$data.prec))){
    s$data.prec <- pool_rows(s[["data.prec"]])
  }
  ch <- chains(from)
  sigma2(ch) <- matrix(NA_real_, iter(from), nb)
  specs(from)$model <- if(nb > 1) "MBP" else "SBP"
  new("MultiBatchP",
      data=assays(from),
      down_sample=down_sample(from),
      specs=specs(from),
      parameters=parameters(from),
      chains=ch,
      current_values=vals,
      summaries=s,
      flags=flags(from))
})

## Coerce a "MultiBatchModel" to "MultiBatchP" by way of "MultiBatch".
setAs("MultiBatchModel", "MultiBatchP", function(from){
  as(as(from, "MultiBatch"), "MultiBatchP")
})

## Coerce a "MultiBatchP" to the "MultiBatchPooled" class.
##
## Copies parameters, current values, summaries, chains and flags into the
## slots of "MultiBatchPooled". sigma2 and data.prec are taken from the
## first (only) column of their matrices. Modes are taken from
## modes(from), or computed when empty. "nu.0" and "p" are renamed to "nu0"
## and "mixprob", and zfreq is set from map_z() of the new object.
setAs("MultiBatchP", "MultiBatchPooled", function(from){
  flag1 <- as.integer(flags(from)[[".internal.constraint"]])
  flag2 <- as.integer(flags(from)[[".internal.counter"]])
  ## Observations per batch, labelled by batch. The labels come from the
  ## table itself: table() orders cells by sorted label, whereas unique()
  ## keeps order of appearance and would mislabel unsorted batches.
  batch.tab <- table(batch(from))
  be <- as.integer(batch.tab)
  names(be) <- names(batch.tab)
  dat <- downSampledData(from)
  th <- theta(from)
  KB <- nrow(th) * ncol(th)
  pred <- numeric(KB)
  zs <- integer(KB)
  obj <- new("MultiBatchPooled",
             k=k(from),
             hyperparams=hyperParams(from),
             theta=theta(from),
             sigma2=sigma2(from)[, 1],
             mu=mu(from),
             tau2=tau2(from),
             nu.0=nu.0(from),
             sigma2.0=sigma2.0(from),
             pi=p(from),
             data=dat$oned,
             data.mean=dataMean(from),
             data.prec=dataPrec(from)[, 1],
             predictive=pred,
             zstar=zs,
             z=z(from),
             u=u(from),
             zfreq=zFreq(from),
             probz=probz(from),
             logprior=logPrior(from),
             loglik=log_lik(from),
             mcmc.chains=chains(from),
             mcmc.params=mcmcParams(from),
             batch=batch(from),
             batchElements=be,
             label_switch=label_switch(from),
             marginal_lik=marginal_lik(from),
             .internal.constraint=flag1,
             .internal.counter=flag2)
  m <- modes(from)
  if(length(m) == 0){
    m <- computeModes(from)
  }
  ## Rename only the entries that are present. match() returns NA for absent
  ## names, and NA subscripts are not allowed in the names<- assignment.
  ## The previous length(ix) == 2 guard was always TRUE.
  ix <- match(c("nu.0", "p"), names(m))
  found <- !is.na(ix)
  if(any(found)){
    names(m)[ix[found]] <- c("nu0", "mixprob")[found]
  }
  m$zfreq <- table(map_z(obj))
  modes(obj) <- m
  obj
})

## Coerce a "MultiBatchPooled" to "MultiBatchP".
##
## sigma2 and data.prec are reshaped into single-column matrices (one row
## per batch). The specs are rebuilt from modelName(from) with no
## down-sampling. When modes are present, they are renamed to the
## MultiBatch names ("nu0" -> "nu.0", "mixprob" -> "p"). They are then
## completed with z, probz, u and a reshaped sigma2, and restricted to the
## names of the new model's current values.
setAs("MultiBatchPooled", "MultiBatchP", function(from){
  values <- extractValues(from)
  values[["sigma2"]] <- matrix(values[["sigma2"]], nBatch(from), 1)
  flags <- extractFlags(from)
  data <- extractData(from)
  params <- extractParameters(from)
  summaries <- extractSummaries(from)
  summaries[["data.prec"]] <- matrix(summaries[["data.prec"]],
                                     numBatch(from), 1)
  ## By default, assume no downsampling.
  down_sample <- seq_len(nrow(data))
  specs <- model_spec(modelName(from), data, down_sample)
  modal.ordinates <- modes(from)
  mb <- MultiBatchP(data=data,
                    down_sample=down_sample,
                    specs=specs,
                    parameters=params,
                    chains=chains(from),
                    current_values=values,
                    summaries=summaries,
                    flags=flags)
  if(length(modal.ordinates) > 0 ){
    ## Rename only the names that are present. match() returns NA for absent
    ## names, and NA subscripts are not allowed in the names<- assignment.
    ix <- match(c("nu0", "mixprob"), names(modal.ordinates))
    found <- !is.na(ix)
    if(any(found)){
      names(modal.ordinates)[ix[found]] <- c("nu.0", "p")[found]
    }
    modal.ordinates$z <- map_z(from)
    modal.ordinates$probz <- probz(from)
    modal.ordinates$u <- u(from)
    modal.ordinates$sigma2 <- matrix(modes(from)[["sigma2"]],
                                     numBatch(from), 1)
    m <- modal.ordinates[names(current_values(mb))]
    modes(mb) <- m
  }
  mb
})


## Coerce a "MultiBatchP" to a plain list of nStarts(from) "MultiBatchPooled"
## objects, each with nStarts set to 1.
##
## The coercion is done once and the result is copied. The original
## re-evaluated the same as() call on the same object for every element.
## rep(list(.)) always returns a plain list, while replicate()'s default
## simplify="array" can attempt to simplify the result.
setAs("MultiBatchP", "list", function(from){
  ns <- nStarts(from)
  mbp <- as(from, "MultiBatchPooled")
  nStarts(mbp) <- 1
  rep(list(mbp), ns)
})


## Estimate the marginal likelihood of a "MultiBatchP" model.
##
## params: settings for marginalLikelihood(). The default is
##         mlParams(root=1/2, reject.threshold=exp(-100), prop.threshold=0.5,
##         prop.effective.size=0).
## The model is coerced to "MultiBatchPooled" and passed to
## marginalLikelihood(). On success the estimate is stored in
## summaries(object)[["marginal_lik"]] and reported with message(). Any
## warning or error from the estimator is caught (best effort) and its text
## is reported. In that case `object` is returned unchanged.
setMethod("compute_marginal_lik", "MultiBatchP", function(object, params){
  if(missing(params)){
    params <- mlParams(root=1/2,
                       reject.threshold=exp(-100),
                       prop.threshold=0.5,
                       prop.effective.size=0)
  }
  mbm <- as(object, "MultiBatchPooled")
  ## Return the condition object rather than NULL so the reason can be
  ## reported.
  ml <- tryCatch(marginalLikelihood(mbm, params),
                 warning=function(w) w,
                 error=function(e) e)
  if(inherits(ml, "condition")){
    message("Unable to compute marginal likelihood: ", conditionMessage(ml))
    return(object)
  }
  if(is.null(ml)){
    message("Unable to compute marginal likelihood")
    return(object)
  }
  summaries(object)[["marginal_lik"]] <- ml
  message("     marginal likelihood: ", round(ml, 2))
  object
})

## Modal values for a "MultiBatchP" model.
##
## Starts from the inherited computeModes() result and replaces "sigma2" with
## the sigma2-chain row at the first index returned by argMax(object),
## reshaped to a number_batches x 1 matrix.
setMethod("computeModes", "MultiBatchP", function(object){
  out <- callNextMethod(object)
  best <- argMax(object)[1]
  sigma2_chain <- sigma2(chains(object))
  n_batch <- specs(object)$number_batches
  out[["sigma2"]] <- matrix(sigma2_chain[best, ], n_batch, 1)
  out
})

## Try the CNPBayes package in your browser
##
## Any scripts or data that you put into this service are public.
##
## CNPBayes documentation built on May 6, 2019, 4:06 a.m.