R/stanfit-class.R

Defines functions is.array.stanfit dimnames.stanfit As.mcmc.list as.mcmc.list.stanfit dim.stanfit as.data.frame.stanfit as.matrix.stanfit as.array.stanfit `names<-.stanfit` names.stanfit sflist2stanfit is_sfinstance_valid is_sf_valid par_traceplot get_samples2 get_samples get_kept_samples2 get_kept_samples print.stanfit

Documented in as.array.stanfit as.data.frame.stanfit as.matrix.stanfit As.mcmc.list dimnames.stanfit dim.stanfit is.array.stanfit names.stanfit print.stanfit sflist2stanfit

# This file is part of RStan
# Copyright (C) 2012, 2013, 2014, 2015, 2016, 2017 Trustees of Columbia University
#
# RStan is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# RStan is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

setMethod("show", "stanfit", 
          function(object) {
            print.stanfit(x = object, pars = object@sim$pars_oi)
          })  

print.stanfit <- function(x, pars = x@sim$pars_oi, 
                          probs = c(0.025, 0.25, 0.5, 0.75, 0.975), 
                          digits_summary = 2, include = TRUE, ...) { 
  if (x@mode == 1L) { 
    cat("Stan model '", x@model_name, "' is of mode 'test_grad';\n",
        "sampling is not conducted.\n", sep = '')
    return(invisible(NULL)) 
  } else if (x@mode == 2L) {
    cat("Stan model '", x@model_name, "' does not contain samples.\n", sep = '') 
    return(invisible(NULL)) 
  } 

  if(!include) pars <- setdiff(x@sim$pars_oi, pars)
  s <- summary(x, pars, probs, ...)  
  if (is.null(s)) return(invisible(NULL))
  n_kept <- x@sim$n_save - x@sim$warmup2
  cat("Inference for Stan model: ", x@model_name, '.\n', sep = '')
  cat(x@sim$chains, " chains, each with iter=", x@sim$iter, 
      "; warmup=", x@sim$warmup, "; thin=", x@sim$thin, "; \n", 
      "post-warmup draws per chain=", n_kept[1], ", ", 
      "total post-warmup draws=", sum(n_kept), ".\n\n", sep = '')

  # round n_eff to integers
  s$summary[, 'n_eff'] <- round(s$summary[, 'n_eff'], 0)

  print(round(s$summary, digits_summary), ...) 

  sampler <- attr(x@sim$samples[[1]], "args")$sampler_t

  if (!is.null(x@stan_args[[1]]$method) && 
        isTRUE(x@stan_args[[1]]$method == "variational")) {
      if ("diagnostics" %in% names(x@sim)
          & "ir_idx" %in% names(x@sim$diagnostics)
          & !is.null(x@sim$diagnostics$ir_idx)) {
        cat("\nApproximate samples were drawn using VB(", x@stan_args[[1]]$algorithm, 
            ") + PSIS at ", x@date, ".\n", sep = '')
    } else {
      cat("\nApproximate samples were drawn using VB(", x@stan_args[[1]]$algorithm, 
          ") at ", x@date, ".\n", sep = '')
      message("We recommend genuine 'sampling' from the posterior distribution for final inferences!")
    }
    return(invisible(NULL))
  } else {
      cat("\nSamples were drawn using ", sampler, " at ", x@date, ".\n",
          "For each parameter, n_eff is a crude measure of effective sample size,\n", 
          "and Rhat is the potential scale reduction factor on split chains (at \n",
          "convergence, Rhat=1).\n", sep = '')
      return(invisible(NULL))
  }
}

setMethod("plot", signature(x = "stanfit", y = "missing"), 
          function(x, ..., plotfun) {
            if (x@mode == 1L) {
              cat("Stan model '", x@model_name, "' is of mode 'test_grad';\n",
                  "sampling is not conducted.\n", sep = '')
              return(invisible(NULL)) 
            } else if (x@mode == 2L) {
              cat("Stan model '", x@model_name, 
                  "' does not contain samples.\n", sep = '') 
              return(invisible(NULL)) 
            } 
            if (isTRUE(all.equal(x@sim$n_save, x@sim$warmup2,
                                 check.attributes = FALSE, check.names = FALSE))) {
              cat("Stan model '", x@model_name, 
                  "' does not contain samples after warmup.\n", sep = '')
              return(invisible(NULL))
            }
            
            if (missing(plotfun)) fun <- "stan_plot" 
            else {
              plotters <- paste0("stan_", c("plot", "trace", "scat", "hist", "dens", "ac",
                                            "diag", "rhat", "ess", "mcse", "par"))
              fun <- grep(paste0(plotfun, "$"), plotters, value = TRUE)
              if (!length(fun)) stop("Plotting function not found.")
              else fun <- match.fun(fun)
            }
            do.call(fun, list(object = x, ...))
          }) 

setGeneric(name = "get_stancode",
           def = function(object, ...) { standardGeneric("get_stancode")}) 

setGeneric(name = "get_cppo_mode", 
           def = function(object, ...) { standardGeneric("get_cppo_mode") }) 

setMethod('get_cppo_mode', signature = "stanfit", 
           function(object) { 
             cxxf <- get_cxxflags(object@stanmodel)
             if (identical(cxxf, character(0))) return(NA)
             l <- get_cxxo_level(cxxf)
             if ("" == l) l <- "0" 
             p <- match(l, c("3", "2", "1", "0")) 
             c("fast", "presentation2", "presentation1", "debug")[p]
           }) 

setMethod('get_stancode', signature = "stanfit", 
          function(object, print = FALSE) {
            return(get_stancode(object@stanmodel, print))
          }) 

setGeneric(name = 'get_stanmodel', 
           def = function(object, ...) { standardGeneric("get_stanmodel")})

setMethod("get_stanmodel", signature = "stanfit", 
          function(object) { 
            return(object@stanmodel)
          }) 

setGeneric(name = 'get_inits', 
           def = function(object, ...) { standardGeneric("get_inits")})

setMethod("get_inits", signature = "stanfit", 
          function(object, iter = NULL) { 
            if (is.null(iter)) return(object@inits)
            stopifnot(is.numeric(iter), iter > 0, 
                      iter <= length(object@sim$samples[[1]][[1]]))
            inits <- object@inits
            if (length(inits) == 0) {
              inits <- lapply(1:ncol(object), FUN = function(c) {
                sapply(object@par_dims, simplify = FALSE, FUN = function(p) {
                  if (length(p) == 0) return(NA_real_)
                  array(NA_real_, dim = p)
                })
              })
            }
            for (c in 1:ncol(object)) {
              vec <- as.matrix(as.data.frame(object@sim$samples[[c]]))[iter,]
              inits[[c]] <- relist(vec, skeleton = inits[[c]])
            }
            return(inits)                       
})

setGeneric(name = 'get_seed', 
           def = function(object, ...) { standardGeneric("get_seed")})

setMethod("get_seed", signature = "stanfit", 
          function(object) { 
            if (length(object@stan_args) < 1L) return(NULL) 
            object@stan_args[[1]]$seed })

setGeneric(name = 'get_seeds', 
           def = function(object, ...) { standardGeneric("get_seeds")})

setMethod("get_seeds", signature = "stanfit", 
          function(object) { 
            if (length(object@stan_args) < 1L) return(NULL) 
            sapply(object@stan_args, function(x) x$seed) })

### HELPER FUNCTIONS
### 


get_kept_samples <- function(n, sim) {
  #
  # Args:
  #  sim: the sim slot in object stanfit 
  #  n: the nth parameter (starting from 1) 
  # Note: 
  #  samples from different chains are merged. 

  # get chain kept samples (gcks) 
  gcks <- function(s, nw, permutation) {
    s <- s[[n]][-(1:nw)] 
    s[permutation] 
  } 
  ss <- mapply(gcks, sim$samples, sim$warmup2, sim$permutation,
               SIMPLIFY = FALSE, USE.NAMES = FALSE) 
  do.call(c, ss) 
} 

get_kept_samples2 <- function(n, sim) {
  # a different implementation of get_kept_samples 
  # It seems this one is faster than get_kept_samples 
  # TODO: to understand why it is faster? 
  lst <- vector("list", sim$chains)
  for (ic in 1:sim$chains) { 
    if (sim$warmup2[ic] > 0) 
      lst[[ic]] <- sim$samples[[ic]][[n]][-(1:sim$warmup2[ic])][sim$permutation[[ic]]]
    else 
      lst[[ic]] <- sim$samples[[ic]][[n]][sim$permutation[[ic]]]
  } 
  do.call(c, lst)
}

get_samples <- function(n, sim, inc_warmup = TRUE) {
  # get chain samples
  # Args:
  #   n: parameter index (integer)
  #   sim: the sim list in stanfit 
  # 
  # Returns:
  #   a list of chains for the nth parameter; each chain is an 
  #   element of the list.  
  if (all(sim$warmup2 == 0)) inc_warmup <- TRUE # for the case warmup sample is discarded
  gcs <- function(s, inc_warmup, nw) {
    # Args:
    #   s: samples of all chains 
    #   nw: number of warmup 
    if (inc_warmup)  return(s[[n]])
    else return(s[[n]][-(1:nw)]) 
  } 
  ss <- mapply(gcs, sim$samples, inc_warmup, sim$warmup2, 
               SIMPLIFY = FALSE, USE.NAMES = FALSE) 
  ss 
} 

get_samples2 <- function(n, sim, inc_warmup = TRUE) {
  # serves the same purpose with get_samples, but with 
  # different implementation 
  # It seems that this one is fast. 
  if (all(sim$warmup2 == 0)) inc_warmup <- TRUE # for the case warmup sample is discarded
  npar <- length(n)
  lst <- vector("list", sim$chains)
  for (ic in 1:sim$chains) {
    lst[[ic]] <- 
      if (inc_warmup) sim$samples[[ic]][[n]] else sim$samples[[ic]][[n]][-(1:sim$warmup2[ic])]
  }
  lst
} 

par_traceplot <- function(sim, n, par_name, inc_warmup = TRUE, window = NULL, ...) {
  # same thin, n_save, warmup2 for all the chains
  thin <- sim$thin
  warmup2 <- sim$warmup2[1] 
  warmup <- sim$warmup 
  main <- paste("Trace of ", par_name) 
  chain_cols <- rstan_options("rstan_chain_cols")
  warmup_col <- rstan_options("rstan_warmup_bg_col") 

  start_i <- window[1] 
  window_size <- (window[2] - start_i) %/% thin
  id <- seq.int(start_i, by = thin, length.out = window_size)
  start_idx <- (if (warmup2 == 0) (start_i - warmup) else start_i) %/% thin
  if (start_idx < 1)  start_idx <- 1
  idx <- seq.int(start_idx, by = 1, length.out = window_size)

  yrange <- NULL 
  for (i in 1:sim$chains)
    yrange <- range(yrange, sim$samples[[i]][[n]][idx]) 

  if (inc_warmup) {
    plot(range(id), yrange, type = 'n', bty = 'l',
         xlab = 'Iterations', ylab = "", main = main, ...)
    rect(par("usr")[1], par("usr")[3], sim$warmup, par("usr")[4], 
         col = warmup_col, border = NA)
  } else {  
    plot(range(id), yrange, type = 'n', bty = 'l',
         xlab = 'Iterations (without warmup)', ylab = "", main = main, ...)
  } 

  for (i in 1:sim$chains)  
    lines(id, sim$samples[[i]][[n]][idx], 
          xlab = '', ylab = '', col = chain_cols[(i-1) %% 6 + 1], ...) 
} 

######

setGeneric(name = 'get_adaptation_info', 
           def = function(object, ...) { standardGeneric("get_adaptation_info")})

setMethod("get_adaptation_info", 
          definition = function(object) {
            if (object@mode == 1L) {
              cat("Stan model '", object@model_name, "' is of mode 'test_grad';\n",
                  "sampling is not conducted.\n", sep = '')
              return(invisible(NULL)) 
            } else if (object@mode == 2L) {
              cat("Stan model '", object@model_name, "' does not contain samples.\n", sep = '') 
              return(invisible(NULL)) 
            } 

            lai <- lapply(object@sim$samples, function(x) attr(x, "adaptation_info"))
            is_empty <- function(x) { 
              if (is.null(x)) return(TRUE) 
              if (is.character(x) && all(nchar(x) == 0)) return(TRUE)
              FALSE
            }
            if (all(sapply(lai, FUN = is_empty))) return(invisible(NULL))  
            return(lai) 
          }) 

setGeneric(name = "get_logposterior", 
           def = function(object, ...) { standardGeneric("get_logposterior")})


setMethod("get_logposterior", "stanfit",
          definition = function(object, inc_warmup = TRUE) {
            if (object@mode == 1L) {
              cat("Stan model '", object@model_name, "' is of mode 'test_grad';\n",
                  "sampling is not conducted.\n", sep = '')
              return(invisible(NULL)) 
            } else if (object@mode == 2L) {
              cat("Stan model '", object@model_name, "' does not contain samples.\n", sep = '') 
              return(invisible(NULL)) 
            } 

            llp <- lapply(object@sim$samples, function(x) x[['lp__']]) 
            if (inc_warmup) return(llp)
            if (all(object@sim$warmup2 == 0)) return(llp)
            return(mapply(function(x, w) x[-(1:w)], 
                             llp, object@sim$warmup2,
                             SIMPLIFY = FALSE, USE.NAMES = FALSE)) 
          }) 

setGeneric(name = 'get_sampler_params', 
           def = function(object, ...) { standardGeneric("get_sampler_params")}) 

setMethod("get_sampler_params", 
          definition = function(object, inc_warmup = TRUE) {
            if (object@mode == 1L) {
              cat("Stan model '", object@model_name, "' is of mode 'test_grad';\n",
                  "sampling is not conducted.\n", sep = '')
              return(invisible(NULL)) 
            } else if (object@mode == 2L) {
              cat("Stan model '", object@model_name, "' does not contain samples.\n", sep = '') 
              return(invisible(NULL)) 
            }
            
            if (isTRUE(object@stan_args[[1]]$method == "variational")) {
              stop("'get_sampler_params' not available for ",
                   "meanfield or fullrank algorithms.") 
            }

            ldf <- lapply(object@sim$samples, 
                          function(x) do.call(cbind, attr(x, "sampler_params")))   
            if (all(sapply(ldf, is.null))) return(invisible(NULL))  
            if (inc_warmup) {
              if (all(object@sim$warmup2 == 0))
                warning("warmup samples not saved")
              return(invisible(ldf))
            }
            else if (all(object@sim$warmup2 == 0)) return(invisible(ldf))
            return(mapply(function(x, w) 
              if (w > 0) x[-(1:w), , drop = FALSE] else x, 
                             ldf, object@sim$warmup2, 
                             SIMPLIFY = FALSE, USE.NAMES = FALSE)) 
          }) 

setGeneric(name = 'get_elapsed_time',
           def = function(object, ...) { standardGeneric("get_elapsed_time")})

setMethod("get_elapsed_time",
          definition = function(object, inc_warmup = TRUE) {
            if (object@mode == 1L) {
              cat("Stan model '", object@model_name, "' is of mode 'test_grad';\n",
                  "sampling is not conducted.\n", sep = '')
              return(invisible(NULL))
            } else if (object@mode == 2L) {
              cat("Stan model '", object@model_name, "' does not contain samples.\n", sep = '')
              return(invisible(NULL))
            }

            ltime <- lapply(object@sim$samples,
                            function(x) attr(x, "elapsed_time"))
            t <- do.call(rbind, ltime)
            if (is.null(t)) return(t)
            cids <- sapply(object@stan_args, function(x) x$chain_id)
            rownames(t) <- paste0("chain:", cids)
            t
          })

setGeneric(name = 'get_posterior_mean', 
           def = function(object, ...) { standardGeneric("get_posterior_mean")}) 

setMethod("get_posterior_mean", signature = "stanfit",
          definition = function(object, pars) {
            if (object@mode == 1L) {
              cat("Stan model '", object@model_name, "' is of mode 'test_grad';\n",
                  "sampling is not conducted.\n", sep = '')
              return(invisible(NULL)) 
            } else if (object@mode == 2L) {
              cat("Stan model '", object@model_name, "' does not contain samples.\n", sep = '') 
              return(invisible(NULL)) 
            } 
            fnames <- flatnames(object@model_pars, object@par_dims, col_major = TRUE)
            if (!exists("posterior_mean_4all", envir = object@.MISC, inherits = FALSE)) {
              mean_pars <- lapply(object@sim$samples, function(x) attr(x, "mean_pars"))
              if (all(sapply(mean_pars, is.null))) 
                mean_pars <- list(rapply(object@sim$samples, mean))
              mean_lp__ <- lapply(object@sim$samples, function(x) attr(x, "mean_lp__"))
              if (all(sapply(mean_lp__, is.null))) 
                mean_lp__ <- as.list(rep(0, length(mean_lp__)))
              m <- rbind(do.call(cbind, mean_pars), do.call(cbind, mean_lp__))
              name_allchains <- NULL
              if (ncol(m) > 1) {
                m <- cbind(m, apply(m, 1, mean))
                name_allchains <- "mean-all chains"
              } 
              cids <- sapply(object@stan_args, function(x) x$chain_id)
              colnames(m) <- c(paste0('mean-chain:', cids), name_allchains)
              rownames(m) <- fnames
              assign("posterior_mean_4all", m, envir = object@.MISC)
            }
            pars <- if (missing(pars)) object@model_pars else check_pars(c(object@model_pars, fnames), pars)
            pars <- remove_empty_pars(pars, object@sim$dims_oi)
            tidx <- pars_total_indexes(object@model_pars, object@par_dims, fnames, pars) 
            tidx <- lapply(tidx, function(x) attr(x, "row_major_idx"))
            return(object@.MISC$posterior_mean_4all[unlist(tidx), , drop = FALSE])
          })

setGeneric(name = "extract",
           def = function(object, ...) { standardGeneric("extract") }) 

setMethod("extract", signature = "stanfit",
          definition = function(object, pars, permuted = TRUE, 
                                inc_warmup = FALSE, include = TRUE) {
            # Extract the samples in different forms for different parameters. 
            #
            # Args:
            #   object: the object of "stanfit" class 
            #   pars: the names of parameters (including other quantiles) 
            #   permuted: if TRUE, the returned samples are permuted without
            #     warming up. And all the chains are merged. 
            #   inc_warmup: if TRUE, warmup samples are kept; otherwise, 
            #     discarded. If permuted is TRUE, inc_warmup is ignored. 
            #   include: if FALSE interpret pars as those to exclude
            #
            # Returns:
            #   If permuted is TRUE, return an array (matrix) of samples with each
            #   column being the samples for a parameter. 
            #   If permuted is FALSE, return array with dimensions
            #   (# of iter (with or w.o. warmup), # of chains, # of flat parameters). 

            if (object@mode == 1L) {
              cat("Stan model '", object@model_name, "' is of mode 'test_grad';\n",
                  "sampling is not conducted.\n", sep = '')
              return(invisible(NULL)) 
            } else if (object@mode == 2L) {
              cat("Stan model '", object@model_name, "' does not contain samples.\n", sep = '') 
              return(invisible(NULL)) 
            } 

            if(!include) pars <- setdiff(object@sim$pars_oi, pars)
            pars <- if (missing(pars)) object@sim$pars_oi else check_pars_second(object@sim, pars) 
            pars <- remove_empty_pars(pars, object@sim$dims_oi)
            tidx <- pars_total_indexes(object@sim$pars_oi, 
                                       object@sim$dims_oi, 
                                       object@sim$fnames_oi, 
                                       pars) 

            n_kept <- object@sim$n_save - object@sim$warmup2
            fun1 <- function(par_i) {
              # sss <- sapply(tidx[[par_i]], get_kept_samples2, object@sim)
              # if (is.list(sss))  sss <- do.call(c, sss)
              # the above two lines are slower than the following line of code
              sss <- do.call(cbind, lapply(tidx[[par_i]], get_kept_samples2, object@sim)) 
              dim(sss) <- c(sum(n_kept), object@sim$dims_oi[[par_i]]) 
              dimnames(sss) <- list(iterations = NULL)
              sss 
            } 
           
            if (permuted) {
              slist <- lapply(pars, fun1) 
              names(slist) <- pars 
              return(slist) 
            } 

            tidx <- unlist(tidx, use.names = FALSE) 
            tidxnames <- object@sim$fnames_oi[tidx] 
            sss <- lapply(tidx, get_samples2, object@sim, inc_warmup) 
            sss2 <- lapply(sss, function(x) do.call(c, x))  # concatenate samples from different chains
            sssf <- unlist(sss2, use.names = FALSE) 
  
            n2 <- object@sim$n_save[1]  ## assuming all the chains have equal iter 
            if (!inc_warmup) n2 <- n2 - object@sim$warmup2[1] 
            dim(sssf) <- c(n2, object@sim$chains, length(tidx)) 
            cids <- sapply(object@stan_args, function(x) x$chain_id)
            dimnames(sssf) <- list(iterations = NULL, chains = paste0("chain:", cids), parameters = tidxnames)
            sssf 
          })  

setMethod("summary", signature = "stanfit", 
          function(object, pars, 
                   probs = c(0.025, 0.25, 0.50, 0.75, 0.975), use_cache = TRUE, ...) { 
            # Summarize the samples (that is, compute the mean, SD, quantiles, for 
            # the samples in all chains and chains individually after removing
            # warmup samples, and n_eff and split R hat for all the kept samples.)
            # 
            # Returns: 
            #   A list with elements:
            #   summary: the summary for all the kept samples 
            #   c_summary: the summary for individual chains. 
            # 
            # Note: 
            #   This function is not straight in terms of implementation as it
            #   saves some standard summaries including n_eff and Rhat in the
            #   environment of the object. The summaries and created upon 
            #   the first time standard summary is called and resued later if possible. 
            #   In addition, the indexes complicate the implementation as internally
            #   we use column major indexes for vector/array parameters. But it might
            #   be better to use row major indexes for the output such as print.

            if (object@mode == 1L) {
              cat("Stan model '", object@model_name, "' is of mode 'test_grad';\n",
                  "sampling is not conducted.\n", sep = '')
              return(invisible(NULL)) 
            } else if (object@mode == 2L) {
              cat("Stan model '", object@model_name, "' does not contain samples.\n", sep = '') 
              return(invisible(NULL)) 
            } 

            if (isTRUE(all.equal(object@sim$n_save, object@sim$warmup2,
                check.attributes = FALSE, check.names = FALSE))) {
              cat("Stan model '", object@model_name, "' does not contain samples after warmup.\n", sep = '')
              return(invisible(NULL))
            }

            if (!exists("summary", envir = object@.MISC, inherits = FALSE) && use_cache) 
              assign("summary", summary_sim(object@sim), envir = object@.MISC)
           
            pars <- if (missing(pars)) object@sim$pars_oi else check_pars_second(object@sim, pars) 
            pars <- remove_empty_pars(pars, object@sim$dims_oi)
            if (missing(probs)) 
              probs <- c(0.025, 0.25, 0.50, 0.75, 0.975)  

            if (!use_cache) {
              # not using the cached (and not create cache, which takes time for too many pars)
              ss <-  summary_sim(object@sim, pars, probs) 
              if (isTRUE(object@stan_args[[1]]$method == "variational")) {
                s1 <- cbind(ss$msd[, 1, drop = FALSE], ss$sem, ss$msd[, 2, drop = FALSE], 
                            ss$quan, ss$ess, ss$khat)
                colnames(s1) <- c("mean", "se_mean", "sd", colnames(ss$quan), 'n_eff', 'khat')
              } else {
                s1 <- cbind(ss$msd[, 1, drop = FALSE], ss$sem, ss$msd[, 2, drop = FALSE], 
                            ss$quan, ss$ess, ss$rhat)
                colnames(s1) <- c("mean", "se_mean", "sd", colnames(ss$quan), 'n_eff', 'Rhat')
              }
              s2 <- combine_msd_quan(ss$c_msd, ss$c_quan) 
              idx2 <- match(attr(ss, "row_major_idx"), attr(ss, "col_major_idx"))
              sf <- list(summary = s1[idx2, , drop = FALSE],
                         c_summary = s2[idx2, , , drop = FALSE])
              return(sf)
            } 
            m <- match(probs, default_summary_probs())
            if (any(is.na(m))) { # unordinary quantiles are requested 
              ss <-  summary_sim_quan(object@sim, pars, probs) 
              col_idx <- attr(ss, "col_major_idx") 
              ss$ess <- object@.MISC$summary$ess[col_idx, drop = FALSE]
              if (isTRUE(object@stan_args[[1]]$method == "variational")) {
                ss$hat <- object@.MISC$summary$khat[col_idx, drop = FALSE]
                hatstr <- "khat"
              } else {
                ss$hat <- object@.MISC$summary$rhat[col_idx, drop = FALSE]
                hatstr <- "Rhat"
              }
              ss$mean <- object@.MISC$summary$msd[col_idx, 1, drop = FALSE] 
              ss$sd <- object@.MISC$summary$msd[col_idx, 2, drop = FALSE] 
              ss$sem <- object@.MISC$summary$sem[col_idx]  
              s1 <- cbind(ss$mean, ss$sem, ss$sd, 
                          ss$quan, ss$ess, ss$hat)
              colnames(s1) <- c("mean", "se_mean", "sd", colnames(ss$quan), 'n_eff', hatstr)
              s2 <- combine_msd_quan(object@.MISC$summary$c_msd[col_idx, , , drop = FALSE], ss$c_quan) 
              idx2 <- match(attr(ss, "row_major_idx"), col_idx)
              ss <- list(summary = s1[idx2, , drop = FALSE],
                         c_summary = s2[idx2, , , drop = FALSE])
              return(ss)
            }

            tidx <- pars_total_indexes(object@sim$pars_oi, 
                                       object@sim$dims_oi, 
                                       object@sim$fnames_oi, 
                                       pars) 
            tidx <- lapply(tidx, function(x) attr(x, "row_major_idx"))
            tidx <- unlist(tidx, use.names = FALSE)
            tidx_len <- length(tidx)

            ss <- object@.MISC$summary 
            qnames <- colnames(ss$quan)[m]

            if (!is.null(object@stan_args[[1]]$method) && 
                  isTRUE(object@stan_args[[1]]$method == "variational")) {
              s1 <- cbind(ss$msd[tidx, 1, drop = FALSE], 
                          ss$sem[tidx, drop = FALSE], 
                          ss$msd[tidx, 2, drop = FALSE], 
                          ss$quan[tidx, m, drop = FALSE], 
                          ss$ess[tidx, drop = FALSE],
                          ss$khat[tidx, drop = FALSE])
              dim(s1) <- c(length(tidx), length(m) + 5L)
              colnames(s1) <- c("mean", "se_mean", "sd", qnames, 'n_eff', 'khat')
            }
            else {
              s1 <- cbind(ss$msd[tidx, 1, drop = FALSE], 
                          ss$sem[tidx, drop = FALSE], 
                          ss$msd[tidx, 2, drop = FALSE], 
                          ss$quan[tidx, m, drop = FALSE], 
                          ss$ess[tidx, drop = FALSE],
                          ss$rhat[tidx, drop = FALSE])
              dim(s1) <- c(length(tidx), length(m) + 5L) 
              colnames(s1) <- c("mean", "se_mean", "sd", qnames, 'n_eff', 'Rhat')
            }
            pars_names <- rownames(ss$msd)[tidx] 
            rownames(s1) <- pars_names 
            s2 <- combine_msd_quan(ss$c_msd[tidx, , , drop = FALSE], ss$c_quan[tidx, m, , drop = FALSE]) 
            # dim(s2) <- c(tidx_len, length(m) + 2, object@sim$chains)
            # dimnames(s2) <- list(parameter = pars_names, 
            #                      stats = c("mean", "sd", qnames), NULL) 
            ss <- list(summary = s1, c_summary = s2) 
            return(ss)
          })  

if (!isGeneric("traceplot")) {
  setGeneric(name = "traceplot",
             def = function(object, ...) { standardGeneric("traceplot") }) 
} 

if (!isGeneric("log_prob")) {
  setGeneric(name = "log_prob", 
             def = function(object, ...) { standardGeneric("log_prob") }) 
} 

if (!isGeneric("unconstrain_pars")) {
  setGeneric(name = "unconstrain_pars", 
             def = function(object, ...) { standardGeneric("unconstrain_pars") }) 
} 

setMethod("unconstrain_pars", signature = "stanfit", 
          function(object, pars) {
            # pars is a list as in specifying inits for a chain
            if (!is_sfinstance_valid(object)) 
              stop("the model object is not created or not valid")
            object@.MISC$stan_fit_instance$unconstrain_pars(pars)
          })

if (!isGeneric("constrain_pars")) {
  setGeneric(name = "constrain_pars", 
             def = function(object, ...) { standardGeneric("constrain_pars") }) 
} 

setMethod("constrain_pars", signature = "stanfit", 
          function(object, upars) {
            # upars is a vector on the unconstrained space (R^N*), 
            # where N* is the number of unconstrained parameters. 
            if (!is_sfinstance_valid(object)) 
              stop("the model object is not created or not valid")
            p <- object@.MISC$stan_fit_instance$constrain_pars(upars)
            idx_wo_lp <- which(object@model_pars != 'lp__')
            rstan_relist(p, create_skeleton(object@model_pars[idx_wo_lp], object@par_dims[idx_wo_lp]))
          })


setMethod("log_prob", signature = "stanfit", 
          function(object, upars, adjust_transform = TRUE, gradient = FALSE) {
            if (!is_sfinstance_valid(object)) 
              stop("the model object is not created or not valid")
            return(object@.MISC$stan_fit_instance$log_prob(upars, adjust_transform, gradient))
          }) 

if (!isGeneric("get_num_upars")) {
  setGeneric(name = "get_num_upars", 
             def = function(object, ...) { standardGeneric("get_num_upars") }) 
} 

setMethod("get_num_upars", signature = "stanfit", 
          function(object) {
            if (!is_sfinstance_valid(object)) 
              stop("the model object is not created or not valid")
            object@.MISC$stan_fit_instance$num_pars_unconstrained()
          })

if (!isGeneric("grad_log_prob")) {
  setGeneric(name = "grad_log_prob", 
             def = function(object, ...) { standardGeneric("grad_log_prob") }) 
} 

setMethod("grad_log_prob", signature = "stanfit", 
          function(object, upars, adjust_transform = TRUE) {
            if (!is_sfinstance_valid(object)) 
              stop("the model object is not created or not valid")
            object@.MISC$stan_fit_instance$grad_log_prob(upars, adjust_transform) 
          }) 

setMethod("traceplot", signature = "stanfit", 
          function(object, pars, include = TRUE, unconstrain = FALSE,
                   inc_warmup = FALSE, window = NULL, nrow = NULL, ncol = NULL,
                   ...) { 

            if (object@mode == 1L) {
              cat("Stan model '", object@model_name, "' is of mode 'test_grad';\n",
                  "sampling is not conducted.\n", sep = '')
              return(invisible(NULL)) 
            } else if (object@mode == 2L) {
              cat("Stan model '", object@model_name, "' does not contain samples.\n", sep = '') 
              return(invisible(NULL)) 
            } 
            args <- list(object = object, include = include, 
                         unconstrain = unconstrain, inc_warmup = inc_warmup,
                         nrow = nrow, ncol = ncol, window = window, ...)
            if (!missing(pars)) { 
              if ("log-posterior" %in% pars)
                pars[which(pars == "log-posterior")] <- "lp__"
              args$pars <- pars
            }
            do.call("stan_trace", args)
          })  


is_sf_valid <- function(sf) {
  # Similar to is_sm_valid, this is only to test whether
  # the compiled DSO is loaded 
  return(is_sm_valid(sf@stanmodel)) 
} 

is_sfinstance_valid <- function(object) {
  # Args
  #  object: an instance of S4 class stanfit 
  exists("stan_fit_instance", envir = object@.MISC, inherits = FALSE) && 
  !(is_null_ptr(object@.MISC$stan_fit_instance@.xData$.pointer) ||
    is_null_ptr(object@.MISC$stan_fit_instance@.xData$.module) || 
    is_null_ptr(object@.MISC$stan_fit_instance@.xData$.cppclass)) 
} 


sflist2stanfit <- function(sflist) {
  # merge a list of stanfit objects into one
  # Args:
  #   sflist, a list of stanfit objects, each element of which 
  #   should have equal length of `iter`, `warmup`, and `thin`. 
  # Returns: 
  #   A new stanfit objects have all the chains in each element of sf_list. 
  #   The date would be where the new object is created. 
  # Note: 
  #   * method get_seed would not work well for this merged 
  #     stanfit object in that it only returns the seed used
  #     for the first object. But all the information is still there. 
  #   * When print function is called, the sampler info is obtained 
  #     only from the first chain. 
  #  
  sf_len <- length(sflist) 
  if (sf_len == 0) stop("'sflist' should have at least 1 element")
  if (!is.list(sflist) || 
      any(sapply(sflist, function(x) !is(x, "stanfit")))) {   
    stop("'sflist' must be a list of 'stanfit' objects")
  }

  non_zero_modes_idx <- which(sapply(sflist, function(x) x@mode) > 0)
  if (length(non_zero_modes_idx) > 0) { 
    stop("The following elements of 'sflist' do not contain samples: ",
         paste(non_zero_modes_idx, collapse = ', '), ".") 
  }   
  if (sf_len == 1) return(sflist[[1]])
  for (i in 2:sf_len) { 
    if (!identical(sflist[[i]]@sim$pars_oi, sflist[[1]]@sim$pars_oi) || 
        !identical(sflist[[i]]@sim$dims_oi, sflist[[1]]@sim$dims_oi))   
      stop("parameters in element ", i, " (stanfit object) are different from in element 1") 
    if (sflist[[i]]@sim$n_save[1] != sflist[[1]]@sim$n_save[1] ||  
        sflist[[i]]@sim$warmup2[1] != sflist[[1]]@sim$warmup2[1])
      stop("all 'stanfit' objects should have equal length of iterations and warmup") 
  } 
  n_chains = sum(sapply(sflist, function(x) x@sim$chains))  
  sim = list(samples = do.call(c, lapply(sflist, function(x) x@sim$samples)),  
             chains = n_chains, 
             iter = sflist[[1]]@sim$iter, 
             thin = sflist[[1]]@sim$thin, 
             warmup = sflist[[1]]@sim$warmup,
             n_save = rep(sflist[[1]]@sim$n_save[1], n_chains), 
             warmup2 = rep(sflist[[1]]@sim$warmup2[1], n_chains), 
             permutation = do.call(c, lapply(sflist, function(x) x@sim$permutation)), 
             pars_oi = sflist[[1]]@sim$pars_oi, 
             dims_oi = sflist[[1]]@sim$dims_oi, 
             fnames_oi = sflist[[1]]@sim$fnames_oi,
             n_flatnames = sflist[[1]]@sim$n_flatnames) 
  nfit <- new("stanfit", 
              model_name = sflist[[1]]@model_name, 
              model_pars = sflist[[1]]@model_pars,
              par_dims  = sflist[[1]]@par_dims, 
              mode = 0L, 
              sim = sim, 
              inits = do.call(c, lapply(sflist, function(x) x@inits)), 
              stan_args = do.call(c, lapply(sflist, function(x) x@stan_args)), 
              stanmodel = sflist[[1]]@stanmodel, 
              date = date(), 
              .MISC = new.env(parent = emptyenv())) 
  return(nfit)
} 
# sflist2stan(list(l1=ss1, l2=ss2))

names.stanfit <- function(x) dimnames(x)$parameters

`names<-.stanfit` <- function(x, value) {
  value <- as.character(value)
  len <- length(x@sim$fnames_oi)
  if(length(value) != len)
    stop(paste("parameter names must be of length", len))
  x@sim$fnames_oi <- value
  if(length(x@.MISC$summary)) {
    x@.MISC$summary <- rapply(x@.MISC$summary, f = function(y) {
      rownames(y) <- value
      return(y)
    }, how = "replace")
  }
  return(x)
}

as.array.stanfit <- function(x, ...) {
  if (x@mode != 0) return(numeric(0)) 
  out <- extract(x, permuted = FALSE, inc_warmup = FALSE, ...)
  # dimnames(out) <- dimnames(x)
  return(out)
} 
as.matrix.stanfit <- function(x, ...) {
  if (x@mode != 0) return(numeric(0)) 
  e <- extract(x, permuted = FALSE, inc_warmup = FALSE, ...) 
  if (is.null(e)) return(e)
  enames <- dimnames(e)
  edim <- dim(e)
  dim(e) <- c(edim[1] * edim[2], edim[3])
  dimnames(e) <- enames[-2]
  e
}
 
as.data.frame.stanfit <- function(x, ...) {
  return( as.data.frame(as.matrix(x, ...)) )
}

dim.stanfit <- function(x) {
  if (x@mode != 0) return(numeric(0)) 
  c(x@sim$n_save[1] - x@sim$warmup2[1], x@sim$chains, x@sim$n_flatnames)
}

setGeneric("as.mcmc.list", function(object, ...) standardGeneric("as.mcmc.list"))

as.mcmc.list.stanfit <- function(object, pars, ...) {
  pars <- if (missing(pars)) object@sim$pars_oi else check_pars_second(object@sim, pars) 
  pars <- remove_empty_pars(pars, object@sim$dims_oi)
  tidx <- pars_total_indexes(object@sim$pars_oi, object@sim$dims_oi, object@sim$fnames_oi, pars)
  tidx <- lapply(tidx, function(x) attr(x, "row_major_idx"))
  tidx <- unlist(tidx, use.names = FALSE)

  lst <- vector("list", object@sim$chains)
  for (ic in 1:object@sim$chains) { 
    x <- do.call(cbind, object@sim$samples[[ic]])[,tidx,drop=FALSE]
    warmup2 <- object@sim$warmup2[ic] 
    if (warmup2 > 0) x <- x[-(1:warmup2), ]
    x <- as.matrix(x)
    if (is.null(colnames(x))) colnames(x) <- pars
    end <- object@sim$iter
    thin <- object@sim$thin
    start <- end - (nrow(x) - 1) * thin
    class(x) <- 'mcmc'
    attr(x, "mcpar") <- c(start, end, thin)
    lst[[ic]] <- x 
  }
  class(lst) <- "mcmc.list"
  return(lst)
}

setMethod("as.mcmc.list", "stanfit", as.mcmc.list.stanfit)

As.mcmc.list <- function(object, pars, include = TRUE, ...) {
  if (missing(pars)) pars <- object@sim$pars_oi
  else if (!include) pars <- setdiff(object@sim$pars_oi, pars)
  pars <-  check_pars_second(object@sim, pars)
  return(as.mcmc.list.stanfit(object, pars = pars))
}

dimnames.stanfit <- function(x) {
  if (x@mode != 0) return(character(0)) 
  cids <- sapply(x@stan_args, function(x) x$chain_id)
  list(iterations = NULL, chains = paste0("chain:", cids), parameters = x@sim$fnames_oi) 
}
is.array.stanfit <- function(x)  return(x@mode == 0)

Try the rstan package in your browser

Any scripts or data that you put into this service are public.

rstan documentation built on May 29, 2024, 11:04 a.m.