R/compare-samplers.R

Defines functions simulation.result eval.sampler compare.samplers

Documented in compare.samplers simulation.result

# From SamplerCompare, (c) 2010 Madeleine Thompson

# compare-samplers.R contains compare.samplers and its support
# function eval.sampler, which runs an individual simulation.
# See "Graphical Comparison of MCMC Samplers" (https://arxiv.org/abs/1011.4457)
# for discussion of the figures of merit used here.

# compare.samplers is the main entry point for the SamplerCompare
# package.  See ?compare.samplers for more information.
#
# Runs every sampler in `samplers` on every distribution in `dists` at every
# value in `tuning` (one simulation per combination, each via eval.sampler)
# and returns the per-simulation results bound together with rbind.
# Simulations are dispatched with parallel::mclapply, so they may run on
# multiple cores.  If any simulation fails or produces no result, an error
# naming the first failure is raised instead of returning a garbled result.
compare.samplers <- function(sample.size, dists, samplers,
                             tuning = 1, trace = TRUE, seed = 17,
                             burn.in = 0.2) {
  # Ensure all distributions are of class scdist and all samplers are
  # functions with a name attribute.  vapply (unlike sapply) returns a
  # logical vector even for empty input, so all() never receives a list.
  dists.ok <- vapply(dists, function(a) is(a, "scdist"), logical(1))
  samplers.ok <- vapply(samplers, function(a) {
    is.function(a) && !is.null(attr(a, "name"))
  }, logical(1))
  stopifnot(all(dists.ok), all(samplers.ok))

  if (trace) {
    message("Simulation started at ", Sys.time(), ".")
  }

  # Set up a data frame with a row for each simulation to run: every
  # (sampler, distribution, tuning) combination.
  jobs <- expand.grid(sampler.id = seq_along(samplers),
                      dist.id = seq_along(dists), tuning = tuning)

  # Curry out all the parameters that do not vary simulation to simulation.
  eval.sampler.job.id <- function(job.id) {
    dist <- dists[[jobs$dist.id[job.id]]]
    sampler <- samplers[[jobs$sampler.id[job.id]]]
    tuning.param <- jobs$tuning[job.id]
    eval.sampler(
        dist, sampler, sample.size, tuning.param, burn.in, trace, seed)
  }

  # Call eval.sampler.job.id on each job id, possibly using multiple cores.
  # Without prescheduling, a job that signals an error comes back as a
  # "try-error" object and a job whose worker dies comes back as NULL;
  # neither may be passed on to rbind silently.
  results <- parallel::mclapply(seq_len(nrow(jobs)), eval.sampler.job.id,
                                mc.preschedule = FALSE)
  failed <- which(vapply(results, function(r) {
    is.null(r) || inherits(r, "try-error")
  }, logical(1)))
  if (length(failed) > 0) {
    first <- results[[failed[1]]]
    reason <- if (is.null(first)) "no result returned" else as.character(first)
    stop(sprintf("%d of %d simulations failed; first failure (job %d): %s",
                 length(failed), nrow(jobs), failed[1], trimws(reason)),
         call. = FALSE)
  }

  comparison <- do.call(rbind, results)
  if (trace) {
    message("Simulation finished at ", Sys.time(), ".")
  }
  comparison
}

# Takes a distribution, a sampler, a sample size, a tuning parameter and a
# burn-in fraction, and runs the associated simulation.  The sampler starts
# from dist$initial() when the distribution supplies it, otherwise from a
# uniform random point in the unit hypercube.  Returns the one-row data
# frame built by simulation.result: the distribution name, the sampler name,
# the tuning parameter, the autocorrelation time and information about its
# uncertainty, the log density evaluations per iteration, the elapsed
# (wall-clock) time per iteration, the two-norm of the error in the mean
# estimate, and a flag indicating whether the simulation was aborted (the
# sampler returned fewer than sample.size draws).
#
# If `seed` is a nonzero number (or TRUE), the RNG is seeded with it for the
# duration of the simulation and the caller's RNG state is restored on exit,
# including when the sampler signals an error.  FALSE, 0, NA or NULL leave
# the RNG untouched.
eval.sampler <- function(dist, sampler, sample.size, tuning.param, burn.in,
                         trace = FALSE, seed = 17) {
  if (isTRUE(as.logical(seed))) {
    # .Random.seed lives in the global environment.  Make sure it exists
    # (one draw initializes it), save it, and restore it on exit with
    # assign(); a plain `.Random.seed <- ...` here would only create a
    # local variable and leave the global RNG state modified.
    if (!exists(".Random.seed", envir = globalenv(), inherits = FALSE)) {
      runif(1)
    }
    saved.seed <- get(".Random.seed", envir = globalenv(), inherits = FALSE)
    on.exit(assign(".Random.seed", saved.seed, envir = globalenv()),
            add = TRUE)
    set.seed(seed)
  }

  # Run the sampler started at a random point on the unit hypercube
  # or an initial point generated by the distribution if available.
  if (is.null(dist$initial)) {
    x0 <- runif(dist$ndim)
  } else {
    x0 <- dist$initial()
  }
  timings <- system.time(
    S <- sampler(target.dist = dist, x0 = x0, sample.size = sample.size,
                 tuning = tuning.param))
  stopifnot(!is.null(S$X) && !is.null(S$evals))

  # Compute value to return.  Note the time reported is elapsed wall-clock
  # time, not user + system CPU time.
  rv <- simulation.result(
      dist, attr(sampler, "name"), S$X, S$evals, S$grads, tuning.param,
      timings[["elapsed"]], burn.in,
      sampler.expr = attr(sampler, "name.expression"),
      aborted = nrow(S$X) < sample.size)

  if (trace) {
    message(sprintf("%s %s: %.3g (%.3g,%.3g) evals tuning=%.3g%s; act.y=%.3g",
                    dist$name, attr(sampler, "name"), rv$act * rv$evals,
                    rv$act.025 * rv$evals, rv$act.975 * rv$evals, tuning.param,
                    ifelse(rv$aborted, " (aborted)", ""), rv$act.y))
  }

  # The caller's RNG state, if saved above, is restored by on.exit.
  rv
}

# Summarizes one MCMC run as a one-row data frame suitable for rbind-ing with
# other runs.  X holds the draws, one row per iteration (anything as.matrix
# accepts, such as a coda mcmc object).  The first floor(burn.in * nrow(X))
# rows are discarded as burn-in.  The remaining chain is used to compute:
#   - autocorrelation times (via ar.act) for the components;
#   - autocorrelation times for the log density, when `y` is given or
#     target.dist provides log.density;
#   - the two-norm error of the sample mean, when target.dist supplies the
#     true mean.
# evals, grads and cpu are totals over the whole run and are reported per
# iteration (divided by nrow(X)).  Any of them that are NULL, and any
# statistic that cannot be computed, are NA.
simulation.result <- function(target.dist, sampler.name, X,
    evals = NULL, grads = NULL, tuning = NULL, cpu = NULL, burn.in = 0.2,
    y = NULL, sampler.expr = sprintf("plain('%s')", sampler.name),
    aborted = NA) {
  # Cast X to a matrix.  Necessary when it is an mcmc object from coda.
  X <- as.matrix(X)
  n.iter <- nrow(X)

  # Check dimensions and drop the burn-in segment of the chain.  Exactly
  # floor(burn.in * n.iter) rows are discarded: burn.in = 0 keeps every row,
  # and burn.in = 0.2 with 1000 rows keeps the last 800.
  stopifnot(target.dist[["ndim"]] == ncol(X))
  n.burn <- floor(burn.in * n.iter)
  chain <- X[seq_len(n.iter) > n.burn, , drop = FALSE]

  # Fields are read with [[ ]] rather than $ so that a distribution lacking
  # `mean` is not partially matched to `mean.log.dens`.
  true.mean <- target.dist[["mean"]]
  dist.name <- target.dist[["name"]]
  dist.expr <- target.dist[["name.expression"]]
  if (is.null(dist.expr)) {
    dist.expr <- sprintf("plain('%s')", dist.name)
  }
  if (is.null(sampler.expr)) {
    sampler.expr <- sprintf("plain('%s')", sampler.name)
  }

  # Totals over the run are reported per iteration; missing totals are NA.
  num_na <- as.numeric(NA)
  per.iter <- function(total) {
    if (is.null(total)) num_na else total / n.iter
  }

  # Initialize return value.  A NULL tuning becomes NA so that the column is
  # always present and results from different runs still rbind.
  rv <- list(dist = dist.name, dist.expr = dist.expr,
             ndim = target.dist[["ndim"]], sampler = sampler.name,
             sampler.expr = sampler.expr,
             tuning = if (is.null(tuning)) num_na else tuning,
             act = num_na, act.025 = num_na, act.975 = num_na,
             act.y = num_na, act.y.025 = num_na, act.y.975 = num_na,
             evals = per.iter(evals), grads = per.iter(grads),
             cpu = per.iter(cpu), err = num_na, aborted = aborted)

  # Fill in maximal autocorrelation times for components of chain.
  acts <- ar.act(chain, true.mean = true.mean)
  rv$act <- acts$act
  rv$act.025 <- acts$act.025
  rv$act.975 <- acts$act.975

  # Fill in autocorrelation times for log density if the caller
  # specifies the log density function or passes in log density states
  # explicitly.
  log.density <- target.dist[["log.density"]]
  if (is.null(y) && !is.null(log.density)) {
    y <- apply(chain, 1, log.density)
  }
  if (!is.null(y)) {
    acts.y <- ar.act(y, true.mean = target.dist[["mean.log.dens"]])
    rv$act.y <- acts.y$act
    rv$act.y.025 <- acts.y$act.025
    rv$act.y.975 <- acts.y$act.975
  }

  # Compute error in sample mean if the true mean is known.
  if (!is.null(true.mean)) {
    rv$err <- sqrt(sum((true.mean - colMeans(chain))^2))
  }

  # Return data as a data frame with strings for factors so rbind
  # works as expected.
  as.data.frame(rv, stringsAsFactors = FALSE)
}

Try the SamplerCompare package in your browser

Any scripts or data that you put into this service are public.

SamplerCompare documentation built on April 24, 2023, 9:09 a.m.