R/misc.R

Defines functions test_221 sample_indices log_sum_exp avoid_crash is.sparc get_CXX create_progress_html_file get_stan_param_names set_cppo parse_data get_time_from_csv is_arg_deprecated is_arg_recognizable read_csv_header all_int_eq get_dims_from_fnames unique_par dotfnames_to_sqrfnames sqrfnames_to_dotfnames read_comments system_info obj_size_str is_null_cxxfun is_null_ptr makeconf_path boost_url legitimate_model_name stan_plot_inferences rstan_relist create_skeleton summary_sim_rhat summary_sim_ess summary_sim_quan summary_sim combine_msd_quan get_par_summary_quantile get_par_summary_msd get_par_summary default_summary_probs pars_total_indexes remove_empty_pars check_pars_second check_pars_first check_pars calc_starts num_pars flatnames flat_one_par seq_array_ind multi_idx_row2colm idx_row2colm idx_col2rowm read_rdump plot_rhat_legend get_rhat_cols stan_rdump writable_sample_file is_dir_writable config_argss is_named_list check_seed append_id check_args get_model_strcode read_model_from_con data_preprocess data_list2array is_legal_stan_vname mklist list_as_integer_if_doable real_is_integer filename_rm_ext filename_ext

Documented in makeconf_path read_rdump set_cppo stan_rdump

# This file is part of RStan
# Copyright (C) 2012, 2013, 2014, 2015, 2016, 2017 Trustees of Columbia University
#
# RStan is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# RStan is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

filename_ext <- function(x) {
  # Return the extension of each file name in `x`: the alphanumeric text
  # after the final dot, or "" when there is no such extension.
  # (Approach taken from the tools package.)
  dot_at <- regexpr("\\.([[:alnum:]]+)$", x)
  has_ext <- dot_at > -1L
  ifelse(has_ext, substring(x, dot_at + 1L), "")
}

filename_rm_ext <- function(x) {
  # Strip the last extension (the final dot and everything after it)
  # from each element of `x`; names without a dot are returned unchanged.
  ext_pattern <- "\\.[^.]*$"
  sub(pattern = ext_pattern, replacement = "", x = x)
}

real_is_integer <- function(x) {
  # Tell whether every value in the numeric vector `x` is a whole number.
  # An empty vector counts as integer-valued; any Inf/-Inf/NaN does not.
  if (length(x) == 0L) {
    return(TRUE)
  }
  has_non_finite <- any(is.infinite(x)) || any(is.nan(x))
  if (has_non_finite) {
    return(FALSE)
  }
  all(x == floor(x))
}


list_as_integer_if_doable <- function(x) {
  # Change the storage mode of list elements from 'double' to 'integer'
  # when that can be done without changing any value (R stores numbers
  # as doubles by default).
  #
  # Args:
  #  x: A list
  #
  # Returns:
  #  A list of the same length; whole-valued double elements that fit in
  #  the integer range are converted to integer, everything else is
  #  returned untouched.
  #
  # Note:
  #  Non-numeric elements are ignored since they are ignored in
  #  rlist_var_context as well.
  #
  lapply(x,
         FUN = function(y) {
           if (!is.numeric(y) || is.integer(y)) return(y)
           # Whole numbers beyond +/- .Machine$integer.max would be turned
           # into NA (with a warning) by the coercion, so they stay double.
           # isTRUE() also keeps vectors containing NA as they are instead
           # of failing on an NA condition.
           fits <- isTRUE(real_is_integer(y)) &&
             isTRUE(all(abs(y) <= .Machine$integer.max))
           if (fits) storage.mode(y) <- "integer"
           y
         })
}

mklist <- function(names) {
  # Make a named list by looking up objects by name in the calling frames.
  #
  # Args:
  #   names: character strings of names of objects
  #
  # Returns:
  #   A named list holding the objects found; objects of mode numeric come
  #   first, followed by the objects of mode list.
  #
  # Note:
  #   Only extracted are modes of numeric and list, which
  #   are enough for stan.
  #   A name counts as "not found" when mget() returns the NA placeholder,
  #   so an object whose value is a single NA is indistinguishable from a
  #   missing one here.

  names <- unique(names)
  cenv <- environment()
  # Walk the frames listed by sys.parents(), in reversed order.
  for (fn in rev(sys.parents())) {
    env1 <- sys.frame(fn)
    # Never look inside this function's own frame.
    if (identical(env1, cenv)) next
    # Look each name up in this frame only (no inheritance), once per mode.
    d1 <- mget(names, envir = env1, ifnotfound = NA, inherits = FALSE, mode = "numeric")
    d2 <- mget(names, envir = env1, ifnotfound = NA, inherits = FALSE, mode = "list")
    na_idx1 <- is.na(d1)
    na_idx2 <- is.na(d2)
    # TRUE where the name was found in neither mode.
    na_idx <- na_idx1 & na_idx2
    numf <- sum(na_idx)
    # Some, but not all, names are present in this frame: report the missing ones.
    if (numf > 0 && numf < length(names))
      stop(paste("objects ", paste("'", names[na_idx], "'", collapse = ', ', sep = ''),
                 " of mode numeric and list not found", sep = ''))
    # Nothing found in this frame: try the next one.
    if (numf == length(names))  next
    # Everything found: numeric hits first, then the names that were
    # not numeric (and therefore were found as lists).
    r <- c(d1[!na_idx1], d2[na_idx1])
    names(r) <- c(names[!na_idx1], names[na_idx1])
    return(r)
  }
  # No frame contained the requested objects.
  stop(paste("objects ", paste("'", names, "'", collapse = ', ', sep = ''),
             " of mode numeric and list not found", sep = ''))
}

# Names rejected by is_legal_stan_vname(): three groups of Stan words
# followed by the C++ keywords.
stan_kw1 <- c(
  "for", "in", "while", "repeat", "until",
  "if", "then", "else", "true", "false"
)
stan_kw2 <- c(
  "int", "real", "vector", "simplex", "ordered", "positive_ordered",
  "row_vector", "matrix", "corr_matrix", "cov_matrix", "lower", "upper"
)
stan_kw3 <- c(
  "model", "data", "parameters", "quantities", "transformed", "generated"
)

cpp_kw <- c(
  "alignas", "alignof", "and", "and_eq", "asm", "auto",
  "bitand", "bitor", "bool", "break",
  "case", "catch", "char", "char16_t", "char32_t", "class", "compl",
  "const", "constexpr", "const_cast", "continue",
  "decltype", "default", "delete", "do", "double", "dynamic_cast",
  "else", "enum", "explicit", "export", "extern",
  "false", "float", "for", "friend",
  "goto",
  "if", "inline", "int",
  "long",
  "mutable",
  "namespace", "new", "noexcept", "not", "not_eq", "nullptr",
  "operator", "or", "or_eq",
  "private", "protected", "public",
  "register", "reinterpret_cast", "return",
  "short", "signed", "sizeof", "static", "static_assert", "static_cast",
  "struct", "switch",
  "template", "this", "thread_local", "throw", "true", "try", "typedef",
  "typeid", "typename",
  "union", "unsigned", "using",
  "virtual", "void", "volatile",
  "wchar_t", "while",
  "xor", "xor_eq"
)


is_legal_stan_vname <- function(name) {
  # Screen a single candidate variable name.
  #
  # Return:
  #   FALSE: not a legal variable name in Stan (contains a dot, starts
  #          with a digit, ends with two underscores, or is one of the
  #          listed Stan / C++ words)
  #   TRUE: the name passed these checks; it may be valid, but that is
  #         not 100% sure
  bad_shape <- grepl('\\.', name) || grepl('^\\d', name) || grepl('__$', name)
  if (bad_shape) return(FALSE)
  reserved <- c(stan_kw1, stan_kw2, stan_kw3, cpp_kw)
  !(name %in% reserved)
}

data_list2array <- function(x) {
  # Stack a list of equally-shaped numeric vectors/arrays into one array
  # whose FIRST dimension indexes the list and whose remaining dimensions
  # are those of a single element.  This lets data declared in Stan as
  # `vector[J] y[I]` be supplied as a list of I vectors of length J, e.g.
  #
  # # I <- 4; J <- 5
  # # y <- lapply(1:I, function(i) rnorm(J))
  #
  # Args:
  #   x: A list of numeric arrays (or vectors) with the same dimensions
  # Returns:
  #   NULL for an empty list; otherwise an array with dimensions
  #   c(length(x), <dimensions of one element>)
  #
  n_elem <- length(x)
  if (n_elem == 0L) return(NULL)

  non_numeric <- vapply(x, function(el) !is.numeric(el), logical(1))
  if (any(non_numeric))
    stop("all elements of the list should be numeric")

  # A plain vector has no dim attribute; treat its length as its shape.
  shape_of <- function(el) {
    shp <- dim(el)
    if (is.null(shp)) length(el) else shp
  }
  ref_shape <- shape_of(x[[1]])
  n_dims <- length(ref_shape)

  if (n_elem > 1) {
    for (el in x[-1]) {
      if (!identical(shape_of(el), ref_shape))
        stop("the dimensions for all elements (array) of the list are not same")
    }
  }

  # Concatenate so that the list index is the LAST dimension, then move
  # it to the front.
  flat <- do.call(c, x)
  dim(flat) <- c(ref_shape, n_elem)
  aperm(flat, c(n_dims + 1L, seq_len(n_dims)))
}

data_preprocess <- function(data) {
  # Turn user-supplied data into the named list of numeric objects
  # that is handed on to Stan.
  #
  # Args:
  #  data: A named list or an environment
  #
  # Behavior:
  #   * stop if data is neither a list nor an environment, if a list has
  #     no names, or if it has duplicated names
  #   * stop if a name is rejected by is_legal_stan_vname
  #   * data frames become matrices, lists become arrays
  #     (data_list2array), logicals become integers
  #   * stop if any element contains NA
  #   * drop (with a warning) elements that are not numeric
  #   * store whole-valued doubles within the integer range as integers
  #
  # Returns:
  #   A named list of the retained elements.
  if (is.environment(data)) {
    data <- as.list(data)
  } else if (is.list(data)) {
    nms <- names(data)
    if (is.null(nms))
      stop("data must be a named list")

    # Empty names are not rejected here: Stan reports an error itself
    # when a variable it needs is not found in the list.
    dup <- duplicated(nms)
    if (any(dup)) {
      stop("duplicated names in data list: ",
           paste(nms[dup], collapse = " "))
    }
  } else {
    stop("data must be a list or an environment")
  }

  var_names <- names(data)
  for (nm in var_names) {
    if (!is_legal_stan_vname(nm))
      stop(paste0('data with name ', nm, " is not allowed in Stan"))
  }

  # Convert one element (looked up by name) or return NULL to drop it.
  to_stan_value <- function(name) {
    x <- data[[name]]
    if (is.data.frame(x)) {
      x <- data.matrix(x)
    } else if (is.list(x)) {
      x <- data_list2array(x)
    } else if (is.logical(x)) {
      storage.mode(x) <- "integer"
    }

    # It is not known at this point which variables the model needs,
    # so any NA is an error.
    if (any(is.na(x))) {
      stop("Stan does not support NA (in ", name, ") in data")
    }

    if (!is.numeric(x)) {
      warning("data with name ", name, " is not numeric and not used")
      return(NULL)
    }

    needs_cast <- !is.integer(x) &&
      all(abs(x) < .Machine$integer.max) && real_is_integer(x)
    if (needs_cast) storage.mode(x) <- "integer"
    x
  }

  data <- lapply(var_names, FUN = to_stan_value)
  names(data) <- var_names
  data[!vapply(data, is.null, logical(1))]
}


read_model_from_con <- function(con) {
  # Read all lines from `con` (without the incomplete-final-line warning)
  # and return them as one string joined by newlines.
  paste(readLines(con, n = -1L, warn = FALSE), collapse = '\n')
}

get_model_strcode <- function(file, model_code = '') {
  # Return the model code as a character string.
  #
  # Args:
  #   file: a file name (character string) or a connection
  #   model_code: character string for one of the following
  #     * the name of an object of character string
  #     * the model code itself
  #
  # Returns:
  #   the model code with attribute model_name2,
  #   a name implied from file or object name,
  #   which can be used later when model_name is not
  #   specified for function stan.  When the code is read from a
  #   connection no name can be implied and 'anon_model' is used.

  if (!missing(file)) {
    if (is.character(file)) {
      fname <- file
      # the model name implied from the file name, which
      # will be used if model_name is not specified later
      model_name2 <- sub("\\.[^.]*$", "", filename_rm_ext(basename(fname)))
      file <- try(file(fname, "rt"))
      if (inherits(file, "try-error")) {
        stop(paste("cannot open model file \"", fname, "\"", sep = ""))
      }
      # only connections opened here are closed here
      on.exit(close(file), add = TRUE)
    } else if (inherits(file, "connection")) {
      # No file name to derive a model name from.
      model_name2 <- 'anon_model'
    } else {
      stop("file must be a character string or connection")
    }
    model_code <- paste(readLines(file, warn = TRUE), collapse = '\n')
    attr(model_code, "model_name2") <- model_name2
    return(model_code)
  }

  model_name2 <- attr(model_code, "model_name2")
  if (is.null(model_name2))
    model_name2 <- deparse(substitute(model_code))
  # Test the type first so that non-character input (e.g. NULL) falls
  # through to the error below instead of failing in the comparison.
  if (is.character(model_code) && model_code != '') {
    if (!grepl("\\{", model_code)) {
      # model_code points an object that includes the model
      model_name2 <- model_code
      if (exists(model_code, mode = 'character', envir = parent.frame()))
        model_code <- get(model_code, mode = 'character', envir = parent.frame())
    } else {
      # model_code includes the code itself, two cases of passing:
      #  1. using another object such as `stan(model_code = scode)`
      #  2. providing the string directly such as stan(model_code = "")
      # In case 2 the deparsed argument is the code itself, which is
      # not usable as a name.
      if (grepl("\\{", model_name2))
        model_name2 <- 'anon_model'
    }
    attr(model_code, "model_name2") <- model_name2
    return(model_code)
  }

  stop("model file missing and empty model_code")
}

# FIXME: implement more checks on the arguments
check_args <- function(argss) {
  # Placeholder validator: performs no checks at present and
  # returns NULL invisibly.
  invisible(NULL)
}

#
# model_code <- read_model_from_con('https://stan.googlecode.com/git/src/models/bugs_examples/vol1/dyes/dyes.stan')
# cat(model_code)


append_id <- function(file, id, suffix = '.csv') {
  # Insert "_<id>" in front of the file name's suffix.
  #
  # Args:
  #   file: path of the file (a single string)
  #   id: identifier to insert, e.g. a chain number
  #   suffix: the file-name suffix; it is matched literally at the end of
  #     the name (trailing white space is ignored and dropped)
  #
  # Returns:
  #   The path with "_<id><suffix>" replacing the suffix when the name
  #   ends with it, e.g. "a.csv" -> "a_1.csv"; otherwise "_<id><suffix>"
  #   is appended to the unchanged name, e.g. "a" -> "a_1.csv".
  fname <- basename(file)
  fpath <- dirname(file)
  tagged_suffix <- paste0("_", id, suffix)

  trimmed <- sub("[[:space:]]+$", "", fname)
  n_trim <- nchar(trimmed)
  n_suf <- nchar(suffix)
  # Literal comparison: the suffix is not interpreted as a regular expression.
  has_suffix <- n_suf > 0 && n_trim >= n_suf &&
    substring(trimmed, n_trim - n_suf + 1L) == suffix

  fname2 <- if (has_suffix) {
    paste0(substr(trimmed, 1L, n_trim - n_suf), tagged_suffix)
  } else {
    paste0(fname, tagged_suffix)
  }
  file.path(fpath, fname2)
}

check_seed <- function(seed, warn = 0) {
  # Validate / normalize a seed.
  #
  # Args:
  #   seed: a number, a string of digits, or NA
  #   warn: when 0, an invalid string seed is an error; otherwise a
  #     message is emitted and NULL is returned
  #
  # Returns:
  #   a numeric seed converted to integer, a digit string as it is, or
  #   a randomly drawn integer when the seed is NA.
  bad_string <- is.character(seed) && grepl("[^0-9]", seed)
  if (bad_string) {
    msg <- "seed needs to be string of digits"
    if (warn == 0) stop(msg)
    message(msg)
    return(NULL)
  }
  if (is.numeric(seed)) seed <- as.integer(seed)
  if (is.na(seed)) sample.int(.Machine$integer.max, 1) else seed
}

is_named_list <- function(x) {
  # TRUE only when x is a list that has a names attribute with no
  # empty-string name.
  is.list(x) && !is.null(names(x)) && !("" %in% names(x))
}


## from ../inst/include/rstan/stan_args.hpp
#
# enum sampling_algo_t { NUTS = 1, HMC = 2, Metroplos = 3};
# enum optim_algo_t { Newton = 1, BFGS = 3, LBFGS = 4};
# enum sampling_metric_t { UNIT_E = 1, DIAG_E = 2, DENSE_E = 3};
# enum stan_args_method_t { SAMPLING = 1, OPTIM = 2, TEST_GRADIENT = 3};


config_argss <- function(chains, iter, warmup, thin,
                         init, seed, sample_file, diagnostic_file, algorithm,
                         control, ...) {
  # Build the per-chain argument lists.
  #
  # Args:
  #   chains, iter, warmup, thin: coerced to integer and validated below
  #   init: a number or string, a function, or a list with one named
  #     list per chain
  #   seed: checked by check_seed; drawn at random when missing
  #   sample_file, diagnostic_file: optional output files; with several
  #     chains the chain number is added to the name by append_id
  #   algorithm: copied unchanged into every chain's list
  #   control: NULL or a list of control settings
  #   ...: further arguments, appended to every chain's list; `chain_id`
  #     and `test_grad` are interpreted here
  #
  # Returns:
  #   A list of length `chains`; element i is the argument list of chain i.

  iter <- as.integer(iter)
  if (iter < 1)
    stop("parameter 'iter' should be a positive integer")
  thin <- as.integer(thin)
  if (thin < 1 || thin > iter)
    stop("parameter 'thin' should be a positive integer less than 'iter'")
  # a negative warmup is treated as 0
  warmup <- max(0, as.integer(warmup))
  if (warmup >= iter)
    stop("parameter 'warmup' should be an integer less than 'iter'")
  chains <- as.integer(chains)
  if (chains < 1)
    stop("parameter 'chains' should be a positive integer")

  # the same settings are used for every chain
  iters <- rep(iter, chains)
  thins <- rep(thin, chains)
  warmups <- rep(warmup, chains)

  # init given as a number or a string: "0" and "random" are kept,
  # any other value is replaced by "random"
  inits_specified <- FALSE
  if (is.numeric(init)) init <- as.character(init)
  if (is.character(init)) {
    if (init[1] %in% c("0", "random")) inits <- rep(init[1], chains)
    else inits <- rep("random", chains)
    inits_specified <- TRUE
  }

  dotlist <- list(...)

  # use chain_id argument if specified
  chain_ids <- seq_len(chains)
  if (!is.null(dotlist$chain_id)) {
    chain_id <- as.integer(dotlist$chain_id)
    if (any(duplicated(chain_id))) stop("chain_id has duplicated elements")
    chain_id_len <- length(chain_id)
    # when fewer ids than chains are given, continue counting from the
    # largest id supplied
    chain_ids <- if (chain_id_len >= chains) chain_id else {
                   c(chain_id, max(chain_id) + seq_len(chains - chain_id_len))
                 }
    dotlist$chain_id <- NULL
  }

  # init given as a function: it is called once per chain
  if (!inits_specified && is.function(init)) {
    ## the function can take an argument named by chain_id
    if (any(names(formals(init)) == "chain_id")) {
      inits <- lapply(chain_ids, function(id) init(chain_id = id))
    } else {
      inits <- lapply(chain_ids, function(id) init())
    }
    # only the value returned for the first chain is checked
    if (!is_named_list(inits[[1]]))
      stop('the function for specifying initial values need return a named list')
    inits_specified <- TRUE
  }
  # init given as a list: one named list per chain
  if (!inits_specified && is.list(init)) {
    if (length(init) != chains)
      stop("initial value list mismatchs number of chains")
    # this rejects the input only when NONE of the elements is a list;
    # the loop below checks every element
    if (!any(sapply(init, is.list))) {
      stop("initial value list is not a list of lists")
    }
    inits <- init;
    for (i in 1:chains) {
      if (!is_named_list(inits[[i]]))
        stop('the list for specifying initial values need be a named list')
    }
    inits_specified <- TRUE
  }
  if (!inits_specified) stop("wrong specification of initial values")

  ## only one seed is needed by virtue of the RNG
  seed <- if (missing(seed)) sample.int(.Machine$integer.max, 1) else check_seed(seed)

  dotlist$method <- if (!is.null(dotlist$test_grad) && dotlist$test_grad) "test_grad" else "sampling"

  all_metrics <- c("unit_e", "diag_e", "dense_e")
  if (!is.null(control)) {
    if (!is.list(control))
      stop("'control' should be a named list")
      # NOTE: despite its indentation, the call below is not part of the
      # `if` above; it runs for every non-NULL control.
      # NOTE(review): is_arg_recognizable is defined elsewhere; presumably
      # it reports names of `control` that are not in the given set --
      # confirm against its definition.
      is_arg_recognizable(names(control),
                          c("adapt_engaged", "adapt_gamma",
                            "adapt_delta", "adapt_kappa", "adapt_t0",
                            "adapt_init_buffer", "adapt_term_buffer",
                            "adapt_window", "stepsize",
                            "stepsize_jitter", "metric", "int_time",
                            "max_treedepth",
                            "epsilon", "error"),
                          pre_msg = "'control' list contains unknown members of names: ",
                          call. = FALSE)
    metric <- control$metric
    if (!is.null(metric) && is.na(match(metric, all_metrics))) {
      stop("metric should be one of ", paste0(paste0('"', all_metrics, '"'), collapse = ", "))
    }
    dotlist$control <- control
  }

  argss <- vector("list", chains)
  ## the name of arguments in the list need to
  ## match those in include/rstan/stan_args.hpp
  for (i in 1:chains)
    argss[[i]] <- list(chain_id = chain_ids[i],
                       iter = iters[i], thin = thins[i], seed = seed,
                       warmup = warmups[i], init = inits[[i]],
                       algorithm = algorithm)

  # a single chain writes to the given file; several chains each get a
  # file name carrying the chain number
  if (!missing(sample_file) && !is.null(sample_file) && !is.na(sample_file)) {
    sample_file <- writable_sample_file(sample_file)
    if (chains == 1)
        argss[[1]]$sample_file <- sample_file
    if (chains > 1) {
      for (i in 1:chains)
        argss[[i]]$sample_file <- append_id(sample_file, i)
    }
  }

  # same handling for the diagnostic file
  if (!missing(diagnostic_file) && !is.null(diagnostic_file) && !is.na(diagnostic_file)) {
    diagnostic_file <- writable_sample_file(diagnostic_file)
    if (chains == 1)
        argss[[1]]$diagnostic_file <- diagnostic_file
    if (chains > 1) {
      for (i in 1:chains)
        argss[[i]]$diagnostic_file <- append_id(diagnostic_file, i)
    }
  }

  # the remaining `...` arguments (plus method and control) go to every chain
  for (i in 1:chains)
    argss[[i]] <- c(argss[[i]], dotlist)
  check_args(argss)
  argss
}

is_dir_writable <- function(path) {
  # TRUE when file.access() reports success (0) for `path` with both
  # mode 2 and mode 1 (documented by R as write and execute permission).
  write_ok <- file.access(path, mode = 2) == 0
  write_ok && file.access(path, mode = 1) == 0
}

writable_sample_file <-
function(file, warn = TRUE,
         wfun = function(x, x2) {
           paste('"', x, '" is not writable; use "', x2, '" instead', sep = '')
         }) {
  # Make sure the file lives in a writable directory, falling back to
  # tempdir() otherwise.
  #
  # Args:
  #  file: The file of interest.
  #  warn: When TRUE, give a warning if the directory is replaced.
  #  wfun: A function taking the original and the replacement
  #    directory and returning the warning message.
  #
  # Returns:
  #  `file` itself when its directory is writable; otherwise the same
  #  base name placed under tempdir().

  target_dir <- dirname(file)
  if (!is_dir_writable(target_dir)) {
    fallback_dir <- tempdir()
    if (warn) warning(wfun(target_dir, fallback_dir))
    return(file.path(fallback_dir, basename(file)))
  }
  file
}


stan_rdump <- function(list, file = "", append = FALSE,
                       envir = parent.frame(),
                       width = options("width")$width, quiet = FALSE) {
  # Dump R objects for a model's data in the R dump format
  # that Stan supports.
  #
  # Args:
  #   list: a vector of character for all variables interested
  #         (the same as in R's dump function)
  #   file: the output file for dumping the variables: a file name,
  #         "" for standard output, or a connection.
  #   append: when TRUE, the file is opened with
  #           mode of appending; otherwise, a new file
  #           is created.
  #   envir: the environment in which the objects are looked up.
  #   width: line width used when breaking long lines of values.
  #   quiet: no warning if TRUE
  #
  # Return:
  #   Invisibly, the names of the dumped vectors of length two or more
  #   and of the dumped arrays (scalars and empty vectors are written
  #   but not listed); an empty character vector when `file` is a
  #   string and none of the objects exists.

  if (is.character(file)) {
    ex <- sapply(list, exists, envir = envir)
    if (!all(ex)) {
      notfound_list <- list[!ex]
      if (!quiet)
        warning(paste("objects not found: ", paste(notfound_list, collapse = ', '), sep = ''))
    }
    list <- list[ex]
    if (!any(ex))
      return(invisible(character()))

    if (nzchar(file)) {
      file <- file(file, ifelse(append, "a", "w"))
      # only a connection opened here is closed here
      on.exit(close(file), add = TRUE)
    } else {
      file <- stdout()
    }
  }

  for (x in list) {
    if (!quiet && !is_legal_stan_vname(x))
      warning(paste("variable name ", x, " is not allowed in Stan", sep = ''))
  }

  l2 <- NULL
  # matches up to `width` characters followed by white space or the end,
  # used to insert line breaks into long lines
  addnlpat <- paste0("(.{1,", width, "})(\\s|$)")
  for (v in list) {
    vv <- get(v, envir)

    if (is.data.frame(vv)) {
      vv <- data.matrix(vv)
    } else if (is.list(vv)) {
      vv <- data_list2array(vv)
    } else if (is.logical(vv)) {
      mode(vv) <- "integer"
    } else if (is.factor(vv)) {
      vv <- as.integer(vv)
    }

    if (!is.numeric(vv))  {
      if (!quiet)
        warning(paste0("variable ", v, " is not supported for dumping."))
      next
    }

    # Store whole-valued doubles within the integer range as integers.
    # all() is used rather than max() so that a zero-length value does
    # not trigger the "no non-missing arguments to max" warning.
    if (!is.integer(vv) && all(abs(vv) < .Machine$integer.max) && real_is_integer(vv))
      storage.mode(vv) <- "integer"

    if (is.vector(vv)) {
      if (length(vv) == 0) {
        cat(v, " <- integer(0)\n", file = file, sep = '')
        next
      }
      if (length(vv) == 1) {
        cat(v, " <- ", as.character(vv), "\n", file = file, sep = '')
        next
      }
      str <- paste0(v, " <- \nc(", paste(vv, collapse = ', '), ")")
      str <-  gsub(addnlpat, '\\1\n', str)
      cat(str, file = file)
      l2 <- c(l2, v)
      next
    }

    if (is.matrix(vv) || is.array(vv)) {
      l2 <- c(l2, v)
      vvdim <- dim(vv)
      cat(v, " <- \n", file = file, sep = '')
      if (length(vv) == 0) {
        str <- paste0("structure(integer(0), ")
      } else {
        str <- paste0("structure(c(", paste(as.vector(vv), collapse = ', '), "),")
      }
      str <- gsub(addnlpat, '\\1\n', str)
      cat(str,
          ".Dim = c(", paste(vvdim, collapse = ', '), "))\n", file = file, sep = '')
      next
    }
  }
  invisible(l2)
}

get_rhat_cols <- function(rhats) {
  # Pick a color for every Rhat value.
  #
  # Args:
  #   rhats: a vector of rhats
  #
  # Returns:
  #   For each element: the "plot_rhat_nan_col" option when the value is
  #   NA/NaN/infinite; otherwise the element of "plot_rhat_cols" at the
  #   position of the first break in "plot_rhat_breaks" that exceeds the
  #   value; the "plot_rhat_large_col" option when no break exceeds it.
  #
  # NOTE(review): rstan_options is defined elsewhere; the break and
  # color options are assumed to be vectors of matching order -- confirm.
  rhat_nan_col <- rstan_options("plot_rhat_nan_col")
  rhat_large_col <- rstan_options("plot_rhat_large_col")
  rhat_breaks <- rstan_options("plot_rhat_breaks")
  rhat_colors <- rstan_options("plot_rhat_cols")

  sapply(rhats,
         FUN = function(x) {
           # is.na() is TRUE for NaN as well
           if (is.na(x) || is.infinite(x))
             return(rhat_nan_col)
           # seq_along() so that an empty set of breaks simply falls
           # through to the "large" color
           for (i in seq_along(rhat_breaks)) {
             if (x < rhat_breaks[i]) return(rhat_colors[i])
           }
           rhat_large_col
         })
}

plot_rhat_legend <- function(x, y, cex = 1) {
  # Draw a one-line legend for the Rhat colors on the current plot:
  # the label "Rhat:" followed by a colored rectangle and a text label
  # for every break, for values at or above the largest break, and for
  # NaN/Inf.
  #
  # Args
  #   x, y: left, bottom corner coordinates
  #   cex: cex for the labels
  #
  # NOTE(review): rstan_options is defined elsewhere; the break and
  # color options are assumed to be vectors of matching order -- confirm.
  rhat_breaks <- rstan_options("plot_rhat_breaks")
  n_breaks <- length(rhat_breaks)
  rhat_colors <- rstan_options("plot_rhat_cols")[seq_len(n_breaks)]
  rhat_legend_labels <- c(paste("< ", rhat_breaks, "  ", sep = ''),
                        paste(">= ", max(rhat_breaks), "  ", sep = ''),
                        "NaN/Inf")
  rhat_legend_cols <- c(rhat_colors, rstan_options('plot_rhat_large_col'),
                        rstan_options("plot_rhat_nan_col"))
  rhat_legend_width <- strwidth(rhat_legend_labels, cex = cex)
  # every color rectangle is as wide as the string "r-hat "
  rhat_rect_width <- strwidth("r-hat ", cex = cex)
  text(x, y, label = 'Rhat:  ', adj = c(0, 0), cex = cex)
  s1 <- strwidth('Rhat:  ', cex = cex)
  # left edge of every rectangle: each entry starts where the previous
  # rectangle plus its label ends
  starts <- x + c(s1, s1 + cumsum(rhat_rect_width + rhat_legend_width))

  height <- strheight("0123456789<>=", cex = cex)

  for (i in seq_along(rhat_legend_cols)) {
    rect(starts[i], y, starts[i] + rhat_rect_width, y + height, col = rhat_legend_cols[i], border = NA)
    text(starts[i] + rhat_rect_width, y, adj = c(0, 0), label = rhat_legend_labels[i], cex = cex)
  }
}


read_rdump <- function(f, keep.source = FALSE, ...) {
  # Read the objects defined in an R dump file into a list.
  #
  # Args:
  #   f: the file to be sourced
  #   keep.source: see doc of function source
  #   ...: further arguments passed on to source
  #
  # Returns:
  #   A list with one element per object created by the file
  #
  # NOTE: the file is evaluated with source(), so any R code it contains
  # is executed; only read files from trusted origins.

  if (missing(f)) {
    stop("no file specified.")
  }
  dump_env <- new.env()
  source(file = f, local = dump_env, keep.source = keep.source, ...)
  as.list(dump_env)
}


idx_col2rowm <- function(d) {
  # Suppose an iteration of samples for an array parameter is ordered by
  # col-major. This function generates the indexes that can be used to change
  # the sequences to row-major.
  # Args:
  #   d: the dimension of the parameter
  # Returns:
  #   1 for a scalar (empty d); otherwise a permutation of
  #   1..prod(d), which is empty when the parameter has no elements.
  len <- length(d)
  if (0 == len) return(1)
  # seq_len() rather than 1:d so that a zero-length vector yields no
  # indexes instead of c(1, 0)
  if (1 == len) return(seq_len(d))
  idx <- aperm(array(seq_len(prod(d)), dim = d))
  as.vector(idx)
}


idx_row2colm <- function(d) {
  # Counterpart for a sequence ordered by row-major: generate the indexes
  # that change it to col-major.
  # Args:
  #   d: the dimension of the parameter
  # Returns:
  #   1 for a scalar (empty d); otherwise a permutation of
  #   1..prod(d), which is empty when the parameter has no elements.
  len <- length(d)
  if (0 == len) return(1)
  # seq_len() rather than 1:d so that a zero-length vector yields no
  # indexes instead of c(1, 0)
  if (1 == len) return(seq_len(d))
  idx <- aperm(array(seq_len(prod(d)), dim = rev(d)))
  as.vector(idx)
}

multi_idx_row2colm <- function(dims) {
  # Get the indexes that reorder a whole vector of parameter names, each
  # parameter being laid out in row major, to col major.
  # Args:
  #   dims: a list of dimensions for all the parameters
  # Returns:
  #   The per-parameter indexes from idx_row2colm, each shifted by the
  #   parameter's offset (calc_starts(dims) - 1), concatenated.
  offsets <- calc_starts(dims) - 1
  pieces <- vector("list", length(offsets))
  for (k in seq_along(offsets)) {
    pieces[[k]] <- offsets[k] + idx_row2colm(dims[[k]])
  }
  do.call(c, pieces)
}


seq_array_ind <- function(d, col_major = FALSE) {
  #
  # Enumerate all index tuples of an array parameter, in row-major
  # or column-major order.
  #
  # Args:
  #   d: the dimensions of an array parameter, for example,
  #     c(2, 3).
  #
  #   col_major: Determine the order of the indexes.
  #   If col_major = TRUE the first index varies fastest; for
  #   d = c(2, 3):
  #   [1, 1] [2, 1] [1, 2] [2, 2] [1, 3] [2, 3]
  #   If col_major = FALSE the last index varies fastest; for
  #   d = c(2, 3):
  #   [1, 1] [1, 2] [1, 3] [2, 1] [2, 2] [2, 3]
  #
  # Returns:
  #   If length of d is 0, an empty numeric vector; if the array has no
  #   elements, an empty array.  Otherwise an array with prod(d) rows
  #   and length(d) columns, each row of which is an index.
  #
  # Note:
  #   R function arrayInd might be helpful sometimes.
  #
  n_dims <- length(d)
  if (n_dims == 0L)
    return(numeric(0L))

  n_elem <- prod(d)
  if (n_elem == 0L)
    return(array(0L, dim = 0L))

  if (n_dims == 1L)
    return(array(1:n_elem, dim = c(n_elem, 1)))

  idx <- array(1L, dim = c(n_elem, n_dims))

  # Handle cases like 1x1 matrices
  if (n_elem == 1)
    return(idx)

  # Order in which positions are incremented, fastest-varying first.
  carry_order <- if (col_major) seq_len(n_dims) else rev(seq_len(n_dims))
  for (row in 2L:n_elem) {
    # Start from the previous tuple and add one, odometer style: bump
    # the first position that has room, resetting the ones before it.
    current <- idx[row - 1, ]
    for (k in carry_order) {
      if (current[k] < d[k]) {
        current[k] <- current[k] + 1
        break
      }
      current[k] <- 1
    }
    idx[row, ] <- current
  }
  idx
}

flat_one_par <- function(n, d, col_major = FALSE) {
  # Return the names of all elements of one vector/array parameter,
  # such as "alpha[1,2]".
  #
  # Args:
  #  n: Name of the parameter. For example, n = "alpha"
  #  d: A vector indicating the dimensions of parameter n.
  #     For example, d = c(2, 3).  d could be empty
  #     as well when n is a scalar, in which case n is returned.
  #  col_major: order of the elements, passed on to seq_array_ind.
  #
  if (length(d) == 0) return(n)
  index_rows <- seq_array_ind(d, col_major)
  label_row <- function(idx) {
    paste0(n, "[", paste(idx, collapse = ','), "]")
  }
  as.vector(apply(index_rows, 1, label_row))
}


flatnames <- function(names, dims, col_major = FALSE) {
  # Return the element names of several parameters, in the order
  # of `names`.
  #
  # Args:
  #   names: parameter names
  #   dims: a list of the parameters' dimensions, in the same order
  #   col_major: order of the elements inside each parameter
  if (length(names) == 1) {
    return(flat_one_par(names, dims[[1]], col_major = col_major))
  }
  # SIMPLIFY = FALSE: always a list with one character vector per parameter
  per_par <- mapply(flat_one_par, names, dims,
                    MoreArgs = list(col_major = col_major),
                    SIMPLIFY = FALSE,
                    USE.NAMES = FALSE)
  do.call(c, per_par)
}

# Number of elements of a parameter with dimensions d (1 for a scalar,
# whose d is empty).
num_pars <- function(d) prod(d)

calc_starts <- function(dims) {
  # Position of the first element of every parameter in the flattened
  # sequence of all parameters.
  #
  # Args:
  #   dims: a list of the parameters' dimensions
  # Returns:
  #   A numeric vector as long as dims (empty when dims is empty);
  #   the first parameter starts at 1.
  sizes <- vapply(unname(dims), num_pars, numeric(1), USE.NAMES = FALSE)
  # seq_along() rather than 1:length(dims) so that an empty dims gives
  # an empty result
  cumsum(c(1, sizes))[seq_along(dims)]
}

check_pars <- function(allpars, pars) {
  # Check that every name in pars (white space removed) is one of allpars.
  #
  # Returns:
  #   the unique names without white space; stops when a name is
  #   unknown or when pars is empty
  stripped <- gsub('\\s+', '', pars)
  unknown <- !(stripped %in% allpars)
  if (any(unknown)) {
    stop("no parameter ", paste(pars[unknown], collapse = ', '))
  }
  if (length(stripped) == 0) {
    stop("no parameter specified (pars is empty)")
  }
  unique(stripped)
}

check_pars_first <- function(object, pars) {
  # Check if all parameters in pars are valid parameters of the model
  # Args:
  #   object: a stanfit object
  #   pars: a character vector of parameter names
  # Returns:
  #   pars without white spaces, if any, if all are valid
  #   otherwise stop reporting error
  #
  # flatnames() needs the parameters' dimensions as well as their names.
  # NOTE(review): assumes the stanfit slot `par_dims` holds the
  # dimensions in the same order as `model_pars` -- confirm against the
  # class definition.
  allpars <- c(object@model_pars,
               flatnames(object@model_pars, object@par_dims))
  check_pars(allpars, pars)
}

check_pars_second <- function(sim, pars) {
  #
  # Check if all parameters in pars are parameters for which we saved
  # their samples
  #
  # Args:
  #   sim: The sim slot of class stanfit
  #   pars: a character vector of parameter names; when missing,
  #     sim$pars_oi is returned
  #
  # Returns:
  #   pars without white spaces, if any, if all are valid
  #   otherwise stop reporting error
  if (!missing(pars)) {
    candidates <- c(sim$pars_oi, sim$fnames_oi)
    return(check_pars(candidates, pars))
  }
  sim$pars_oi
}

remove_empty_pars <- function(pars, model_dims) {
  #
  # Remove parameters that are actually empty, which
  # could happen when for example a user specifies the
  # following stan model code:
  #
  # transformed data { int n; n <- 0; }
  # parameters { array[n] real y; }
  #
  # Args:
  #   pars: a character vector of parameters names
  #   model_dims: a named list of the parameter dimension
  #
  # Returns:
  #   A character vector of parameter names with empty parameters
  #   being removed; names not found in model_dims are kept.
  #
  model_pars <- names(model_dims)
  if (is.null(model_pars)) stop("model_dims need be a named list")

  is_kept <- function(p) {
    if (is.na(match(p, model_pars))) return(TRUE)
    # a parameter is empty when one of its dimensions is 0
    if (prod(model_dims[[p]]) == 0) FALSE else TRUE
  }
  keep <- vapply(pars, is_kept, logical(1), USE.NAMES = FALSE)
  pars[keep]
}

pars_total_indexes <- function(names, dims, fnames, pars) {
  # Obtain, for every parameter of interest, its positions in the whole
  # flattened sequence of names, which is ordered by 'column major'.
  # Args:
  #   names: all the parameters names specifying the sequence of parameters
  #   dims:  the dimensions for all parameters, in the same order as
  #          'names'
  #   fnames: all the flattened parameter names specified by names and dims
  #   pars:  the parameters of interest. This function assumes that
  #     pars are in names (or in fnames).
  # Returns:
  #   A list named by the non-empty parameters of pars.  Each element is
  #   a named vector of positions with an attribute 'row_major_idx'
  #   holding the positions in row-major order.
  # Note: inside each parameter (vector or array), the sequence is in terms of
  #   col-major. That means if we have parameter alpha and beta, the dims
  #   of which are [2,2] and [2,3] respectively, the whole parameter sequence
  #   is alpha[1,1], alpha[2,1], alpha[1,2], alpha[2,2], beta[1,1], beta[2,1],
  #   beta[1,2], beta[2,2], beta[1,3], beta[2,3].

  first_pos <- calc_starts(dims)
  locate_one <- function(par) {
    # A name that is itself a flattened name (this includes scalar
    # parameters) maps to exactly one position.
    flat_pos <- match(par, fnames)
    if (!is.na(flat_pos)) {
      names(flat_pos) <- par
      attr(flat_pos, "row_major_idx") <- flat_pos
      return(flat_pos)
    }
    # Otherwise it is a whole vector/array parameter: take its block.
    which_par <- match(par, names)
    n_elem <- num_pars(dims[[which_par]])
    if (n_elem == 0) return(NULL)
    positions <- first_pos[which_par] + seq(0, by = 1, length.out = n_elem)
    names(positions) <- fnames[positions]
    attr(positions, "row_major_idx") <-
      first_pos[which_par] + idx_col2rowm(dims[[which_par]]) - 1
    positions
  }
  located <- lapply(pars, FUN = locate_one)
  # drop the parameters that turned out to be empty
  found <- !vapply(located, is.null, logical(1))
  located <- located[found]
  names(located) <- pars[found]
  located
}

# four greys, from light to dark, with alpha 100 (out of 255)
rstancolgrey <- rgb(rbind(c(247, 247, 247),
                          c(204, 204, 204),
                          c(150, 150, 150),
                          c(82, 82, 82)),
                    alpha = 100,
                    names = as.character(1:4),
                    maxColorValue = 255)

# from https://colorbrewer2.org/, colorblind safe,
# 6 different colors, diverging
rstancolc <- rgb(rbind(c(230, 97, 1),
                       c(153, 142, 195),
                       c(84, 39, 136),
                       c(241, 163, 64),
                       c(216, 218, 235),
                       c(254, 224, 182)),
                 names = as.character(1:6),
                 maxColorValue = 255)

default_summary_probs <- function() {
  # Probabilities used for quantile summaries when the caller gives none.
  c(0.025, 0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95, 0.975)
}

## summarize the chains merged and individually
get_par_summary <- function(sim, n, probs = default_summary_probs()) {
  # Mean, sd and quantiles of one flat parameter, per chain and merged.
  # Args:
  #   sim: list with elements `chains` (number of chains), `warmup2`
  #     (per-chain count of leading draws to drop) and `samples`
  #     (per-chain list of draw vectors).
  #   n: index of the parameter inside each chain's `samples` list.
  #   probs: probabilities passed to quantile().
  # Returns:
  #   list(msd, quan, c_msd, c_quan): `msd`/`quan` summarize all kept draws
  #   merged across chains; `c_msd`/`c_quan` are the per-chain results
  #   concatenated chain after chain, without names.
  # NA draws are ignored in every statistic (na.rm = TRUE).
  ss <- lapply(seq_len(sim$chains),
               function(i) {
                 if (sim$warmup2[i] == 0) sim$samples[[i]][[n]]
                 else sim$samples[[i]][[n]][-(1:sim$warmup2[i])]
               })
  msdfun <- function(chain) c(mean(chain, na.rm = TRUE), sd(chain, na.rm = TRUE))
  qfun <- function(chain) quantile(chain, probs = probs, na.rm = TRUE)
  c_msd <- unlist(lapply(ss, msdfun), use.names = FALSE)
  c_quan <- unlist(lapply(ss, qfun), use.names = FALSE)
  # Merge the chains once and summarize the merged draws a single time.
  ass <- do.call(c, ss)
  list(msd = msdfun(ass), quan = qfun(ass), c_msd = c_msd, c_quan = c_quan)
}

# mean and sd
get_par_summary_msd <- function(sim, n) {
  # Mean and sd of one flat parameter, for the merged chains (`msd`) and
  # for every chain separately (`c_msd`, concatenated chain after chain).
  # The first `sim$warmup2[i]` draws of chain i are dropped beforehand.
  mean_sd <- function(draws) c(mean(draws), sd(draws))
  post_warmup <- function(chain_id) {
    draws <- sim$samples[[chain_id]][[n]]
    n_drop <- sim$warmup2[chain_id]
    if (n_drop == 0) draws else draws[-(1:n_drop)]
  }
  kept <- lapply(1:sim$chains, post_warmup)
  per_chain <- unlist(lapply(kept, mean_sd), use.names = FALSE)
  list(msd = mean_sd(do.call(c, kept)), c_msd = per_chain)
}

# quantiles
get_par_summary_quantile <- function(sim, n, probs = default_summary_probs()) {
  # Quantiles of one flat parameter, for the merged chains (`quan`) and
  # for every chain separately (`c_quan`, concatenated chain after chain).
  # The first `sim$warmup2[i]` draws of chain i are dropped beforehand and
  # NA draws are ignored.
  quantiles_of <- function(draws) quantile(draws, probs = probs, na.rm = TRUE)
  post_warmup <- function(chain_id) {
    draws <- sim$samples[[chain_id]][[n]]
    n_drop <- sim$warmup2[chain_id]
    if (n_drop == 0) draws else draws[-(1:n_drop)]
  }
  kept <- lapply(1:sim$chains, post_warmup)
  per_chain <- unlist(lapply(kept, quantiles_of), use.names = FALSE)
  list(quan = quantiles_of(do.call(c, kept)), c_quan = per_chain)
}

combine_msd_quan <- function(msd, quan) {
  # Combine msd and quantiles for chain's summary
  # Args:
  #   msd: the array for mean and sd with dim num.par * 2 * chains
  #   quan: the array for quantiles with dim num.par * n.quan * chains
  # Returns:
  #   an array with dim num.par * (2 + n.quan) * chains in which, for every
  #   chain, the columns of `msd` are followed by the columns of `quan`.
  #   Parameter and chain names are taken from `msd`; the stat names are
  #   the stat names of `msd` followed by those of `quan`.
  dim1 <- dim(msd)
  dim2 <- dim(quan)
  if (length(dim1) != 3L || length(dim2) != 3L)
    stop("msd and quan must both be 3-dimensional arrays")
  if (any(dim1[c(1, 3)] != dim2[c(1, 3)]))
    stop("numbers of parameter/chains differ in msd and quan")
  chains <- dim1[3]
  n_par <- dim1[1]
  n_stat <- dim1[2] + dim2[2]
  par_names <- dimnames(msd)[[1]]
  stat_names <- c(dimnames(msd)[[2]], dimnames(quan)[[2]])
  chain_id_names <- dimnames(msd)[[3]]
  fun <- function(i) {
    # Slicing one chain drops dimensions when there is a single parameter,
    # so the 2-d shape is restored explicitly before binding the columns.
    a1 <- msd[, , i]
    a2 <- quan[, , i]
    dim(a1) <- dim1[1:2]
    dim(a2) <- dim2[1:2]
    cbind(a1, a2)
  }
  ll <- lapply(seq_len(chains), fun)
  msdquan <- array(unlist(ll), dim = c(n_par, n_stat, chains))
  dimnames(msdquan) <- list(parameter = par_names, stats = stat_names,
                            chains = chain_id_names)
  msdquan
}


summary_sim <- function(sim, pars, probs = default_summary_probs()) {
  # Summarize the kept draws of the selected parameters.
  # Args:
  #   sim: the sim list of a fit
  #   pars: parameters of interest; defaults to sim$pars_oi
  #   probs: probabilities for the quantile summaries
  # Returns:
  #   a list with the merged mean/sd (msd), their standard error (sem),
  #   merged quantiles (quan), the per-chain versions (c_msd, c_quan), ess
  #   and either khat (when sim$diagnostics carries n_eff) or rhat.
  #   Attributes "row_major_idx" and "col_major_idx" hold the flat
  #   parameter indexes in the two orders.
  n_probs <- length(probs)
  if (missing(pars)) {
    pars <- sim$pars_oi
  } else {
    pars <- check_pars_second(sim, pars)
  }

  idx_by_par <- pars_total_indexes(sim$pars_oi, sim$dims_oi, sim$fnames_oi, pars)
  rowm_idx <- unlist(lapply(idx_by_par, function(x) attr(x, "row_major_idx")),
                     use.names = FALSE)
  colm_idx <- unlist(idx_by_par, use.names = FALSE)
  n_flat <- length(colm_idx)
  flat_names <- sim$fnames_oi[colm_idx]

  per_par <- lapply(colm_idx, function(n) get_par_summary(sim, n, probs))
  stack_field <- function(field) {
    do.call(rbind, lapply(per_par, function(x) x[[field]]))
  }

  # merged-chain statistics
  msd <- stack_field("msd")
  quan <- stack_field("quan")
  # the column names must be read before dim<- discards them
  probs_str <- colnames(quan)
  dim(msd) <- c(n_flat, 2)
  dim(quan) <- c(n_flat, n_probs)
  rownames(msd) <- flat_names
  colnames(msd) <- c("mean", "sd")
  rownames(quan) <- flat_names
  colnames(quan) <- probs_str

  # per-chain statistics
  c_msd <- stack_field("c_msd")
  c_quan <- stack_field("c_quan")
  dim(c_msd) <- c(n_flat, 2, sim$chains)
  dim(c_quan) <- c(n_flat, n_probs, sim$chains)

  sim_attr_args <- attr(sim, "args")
  if (is.null(sim_attr_args)) {
    cids <- 1:sim$chains
  } else {
    cids <- sapply(sim_attr_args, function(x) x$chain_id)
  }

  # NOTE(review): the chain labels use the prefix "chain:" for c_msd but
  # "chains:" for c_quan -- confirm whether the difference is intended.
  dimnames(c_msd) <- list(parameters = flat_names,
                          stats = c("mean", "sd"),
                          chains = paste0("chain:", cids))
  dimnames(c_quan) <- list(parameters = flat_names,
                           stats = probs_str,
                           chains = paste0("chains:", cids))

  has_vb_diag <- "diagnostics" %in% names(sim) &&
    "n_eff" %in% names(sim$diagnostics)
  if (has_vb_diag) {
    #vb
    ess <- array(sim$diagnostics$n_eff, dim = c(n_flat, 1))
    khat <- array(sim$diagnostics$theta_pareto_k, dim = c(n_flat, 1))
    mcse <- array(sim$diagnostics$mcse, dim = c(n_flat, 1))
    ss <- list(msd = msd, sem = mcse,
               c_msd = c_msd, quan = quan, c_quan = c_quan,
               ess = ess, khat = khat)
  } else {
    ess <- array(sapply(colm_idx, function(n) rstan_ess(sim, n)),
                 dim = c(n_flat, 1))
    rhat <- array(sapply(colm_idx, function(n) rstan_splitrhat(sim, n)),
                  dim = c(n_flat, 1))
    ss <- list(msd = msd, sem = msd[, 2] / sqrt(ess),
               c_msd = c_msd, quan = quan, c_quan = c_quan,
               ess = ess, rhat = rhat)
  }
  attr(ss, "row_major_idx") <- rowm_idx
  attr(ss, "col_major_idx") <- colm_idx
  ss
}

summary_sim_quan <- function(sim, pars, probs = default_summary_probs()) {
  # Quantile-only summary of the selected parameters.
  # Returns a list with the merged quantiles (quan) and the per-chain
  # quantiles (c_quan); attributes "row_major_idx" and "col_major_idx"
  # hold the flat parameter indexes in the two orders.
  n_probs <- length(probs)
  if (missing(pars)) {
    pars <- sim$pars_oi
  } else {
    pars <- check_pars_second(sim, pars)
  }

  idx_by_par <- pars_total_indexes(sim$pars_oi, sim$dims_oi, sim$fnames_oi, pars)
  rowm_idx <- unlist(lapply(idx_by_par, function(x) attr(x, "row_major_idx")),
                     use.names = FALSE)
  colm_idx <- unlist(idx_by_par, use.names = FALSE)
  n_flat <- length(colm_idx)
  flat_names <- sim$fnames_oi[colm_idx]

  per_par <- lapply(colm_idx, function(n) get_par_summary_quantile(sim, n, probs))

  quan <- do.call(rbind, lapply(per_par, function(x) x[["quan"]]))
  # the column names must be read before dim<- discards them
  probs_str <- colnames(quan)
  dim(quan) <- c(n_flat, n_probs)
  rownames(quan) <- flat_names
  colnames(quan) <- probs_str

  sim_attr_args <- attr(sim, "args")
  if (is.null(sim_attr_args)) {
    cids <- 1:sim$chains
  } else {
    cids <- sapply(sim_attr_args, function(x) x$chain_id)
  }

  c_quan <- do.call(rbind, lapply(per_par, function(x) x[["c_quan"]]))
  dim(c_quan) <- c(n_flat, n_probs, sim$chains)
  dimnames(c_quan) <- list(parameters = flat_names,
                           stats = probs_str,
                           chains = paste0("chains:", cids))

  ss <- list(quan = quan, c_quan = c_quan)
  attr(ss, "row_major_idx") <- rowm_idx
  attr(ss, "col_major_idx") <- colm_idx
  ss
}

summary_sim_ess <- function(sim, pars) {
  # Value of rstan_ess() for every flat parameter selected by `pars`,
  # named by flat parameter name. Attributes "row_major_idx" and
  # "col_major_idx" hold the flat indexes in the two orders.
  if (missing(pars)) {
    pars <- sim$pars_oi
  } else {
    pars <- check_pars_second(sim, pars)
  }
  idx_by_par <- pars_total_indexes(sim$pars_oi, sim$dims_oi, sim$fnames_oi, pars)
  rowm_idx <- unlist(lapply(idx_by_par, function(x) attr(x, "row_major_idx")),
                     use.names = FALSE)
  colm_idx <- unlist(idx_by_par, use.names = FALSE)

  ess <- sapply(colm_idx, function(n) rstan_ess(sim, n))
  names(ess) <- sim$fnames_oi[colm_idx]
  attr(ess, "row_major_idx") <- rowm_idx
  attr(ess, "col_major_idx") <- colm_idx
  ess
}

summary_sim_rhat <- function(sim, pars) {
  # Value of rstan_splitrhat() for every flat parameter selected by `pars`,
  # named by flat parameter name. Attributes "row_major_idx" and
  # "col_major_idx" hold the flat indexes in the two orders.
  if (missing(pars)) {
    pars <- sim$pars_oi
  } else {
    pars <- check_pars_second(sim, pars)
  }
  idx_by_par <- pars_total_indexes(sim$pars_oi, sim$dims_oi, sim$fnames_oi, pars)
  rowm_idx <- unlist(lapply(idx_by_par, function(x) attr(x, "row_major_idx")),
                     use.names = FALSE)
  colm_idx <- unlist(idx_by_par, use.names = FALSE)

  rhat <- sapply(colm_idx, function(n) rstan_splitrhat(sim, n))
  names(rhat) <- sim$fnames_oi[colm_idx]
  attr(rhat, "row_major_idx") <- rowm_idx
  attr(rhat, "col_major_idx") <- colm_idx
  rhat
}


create_skeleton <- function(pars, dims) {
  # Build a zero-filled template list, one element per parameter, for use
  # as the `skeleton` of relist() when converting a vector to a list.
  # A parameter whose dims entry is empty becomes the scalar 0; any other
  # parameter becomes an array of zeros with those dims.
  zeros_for <- function(d) {
    if (length(d) < 1) 0 else array(0, dim = d)
  }
  skeleton <- lapply(seq_along(pars), function(k) zeros_for(dims[[k]]))
  names(skeleton) <- pars
  skeleton
}

rstan_relist <- function(x, skeleton) {
  # relist() `x` into the shape of `skeleton`, then give every element the
  # same dim attribute as the matching skeleton element (NULL for elements
  # that have no dims).
  out <- relist(x, skeleton)
  for (k in seq_along(skeleton)) {
    target_dim <- dim(skeleton[[k]])
    dim(out[[k]]) <- target_dim
  }
  out
}

# ported from bugs.plot.inferences in R2WinBUGS
#
# Draws, on the current graphics device, one row per parameter in `pars`:
# the parameter name on the left and a mini-plot on the right showing either
# the merged 80% interval with per-chain medians or, when `display_parallel`
# is TRUE, one 80% interval per chain. A coloured band under each mini-plot
# encodes Rhat and carries the array indexes. Graphical parameters are saved
# and restored on exit. Returns NULL invisibly.
stan_plot_inferences <- function(sim, summary, pars, model_info, display_parallel = FALSE, ...) {
  #
  # Args:
  #   sim: the sim list in stanfit object
  #   summary: list whose elements `quan`, `c_quan` and `rhat` are read here;
  #     the quantile columns are looked up by position in
  #     default_summary_probs(), so they are assumed to have been computed
  #     with those probabilities -- TODO confirm against the caller
  #   pars: parameters of interest
  #   model_info: names list with elements model_name and model_date
  #   display_parallel
  #   ...: not referenced in the body

  alert_col <- rstan_options("rstan_alert_col")
  chain_cols <- rstan_options("rstan_chain_cols")
  chain_cols.len <- length(chain_cols)

  # point size and minimal interval length depend on the active device
  if (.Device %in% c("windows", "X11cairo", 'quartz')) {
    cex.points <- .7
    min.width <- .02
  } else {
    cex.points <- .3
    min.width <- .01
  }

  cex_names <- .7
  cex.axis <- .6
  cex_tiny <- .4
  # the standard number of parameters in an array parameters.
  # we have this so that even the # of parameters are less than
  # 30, we still have equal space between parameters for
  # the whole plot.
  standard_width <- rstan_options('plot_standard_npar')
  max_width <- rstan_options('plot_max_npar')

  pars <- if (missing(pars)) sim$pars_oi else check_pars_second(sim, pars)
  n_pars <- length(pars)
  chains <- sim$chains

  tidx <- pars_total_indexes(sim$pars_oi, sim$dims_oi, sim$fnames_oi, pars)

  height <- .6
  # mar: c(bottom, left, top, right)
  par.old <- par(no.readonly = TRUE)
  on.exit(par(par.old))
  par(mar = c(1, 0, 1, 0))

  # empty canvas: x in [0, 1]; parameter k is drawn around y = -k
  plot(c(0, 1), c(-n_pars - .5, -.4),
       ann = FALSE, bty = "n", xaxt = "n", yaxt = "n", type = "n")

  # plot the model general information
  header <- paste("Stan model '", model_info$model_name, "' (", chains,
                  " chains: iter=", sim$iter, "; warmup=", sim$warmup,
                  "; thin=", sim$thin, ") fitted at ",
                  model_info$model_date, sep = '')
  # side: (1=bottom, 2=left, 3=top, 4=right)
  mtext(header, side = 3, outer = TRUE, line = -1, cex = .7)

  W <- max(strwidth(pars, cex = cex_names))
  # the max width of the variable names

  # cex_names is defined at the beginning of this fun
  # A is the x position where the mini-plots start; B is the x unit that
  # `spacing` is scaled by
  B <- (1 - W) / 3.8
  A <- 1 - 3.5 * B
  title <- if (display_parallel) "80% interval for each chain" else  "medians and 80% intervals"
  text(A, -.4, title, adj = 0, cex = cex_names)
  num_height <- strheight (1:9, cex = cex_tiny) * 1.2

  truncated <- FALSE
  for (k in 1:n_pars) {
    text (0, -k, pars[k], adj = 0, cex = cex_names)

    k_dim <- sim$dims_oi[[pars[k]]]
    k_dim_len <- length(k_dim)
    k_aidx <- seq_array_ind(k_dim, col_major = FALSE)

    # the index for the parameters in the whole
    # sequences of parameters
    index <- attr(tidx[[k]], "row_major_idx")

    # number of parameters we could plot for this
    # particular vector/array parameter
    k_num_p <- length(index)

    # number of parameter we would plot
    J <- min(k_num_p, max_width)
    spacing <- 3.5 / max(J, standard_width)

    # the medians for all the kept samples merged
    sprobs = default_summary_probs()
    mp <- match(0.5, sprobs)
    i80p <- match(c(0.1, 0.9), sprobs)
    med <- summary$quan[index, mp]
    med <- array(med, dim = c(k_num_p, 1))
    i80 <- summary$quan[index, i80p]
    i80 <- array(i80, dim = c(k_num_p, 2))
    # NOTE(review): rhats_cols is built from the whole summary$rhat and is
    # indexed below by j (the position inside this parameter), not by
    # index[j] -- confirm this picks the intended colour when several
    # parameters are plotted.
    rhats <- summary$rhat
    rhats_cols <- get_rhat_cols(rhats)

    # per-chain medians and 80% intervals
    med.chain <- summary$c_quan[index, mp, ]
    med.chain <- array(med.chain, dim = c(k_num_p, sim$chains))
    i80.chain <- summary$c_quan[index, i80p, ]
    i80.chain <- array(i80.chain, dim = c(k_num_p, 2, sim$chains))

    # y = a + b * value maps the pretty value range onto this row's
    # vertical band of size `height`
    rng <- if (display_parallel) range(i80, i80.chain) else range(i80)
    p.rng <- pretty(rng, n = 2)
    b <- height / (max(p.rng) - min(p.rng))
    a <- -(k + height / 2) - b * p.rng[1]
    lines(A + c(0, 0), -k + 0.5 * height * c(-1, 1))

    # plot a line at zero (if zero is in the range of the mini-plot)
    if (min(p.rng) < 0 & max(p.rng) > 0) {
      lines(A + B * spacing * c(0, J + 1),
            rep(a, 2), lwd = .5, col = "gray")
    }
    # plot the breaks of the axis
    for (x in p.rng){
      text(A - B * .2, a + b * x, x, cex = cex.axis)
      lines(A + B * c(-.05, 0), rep(a + b * x, 2))
    }
    for (j in 1:J){
      if (display_parallel){
        for (m in 1:chains){
          interval <- a + b * i80.chain[j, , m]

          # When the interval is too tiny, we use the min.width instead
          # of the real one.
          if (interval[2] - interval[1] < min.width)
            interval <- mean(interval) + c(-.5, .5) * min.width
          # chains are spread horizontally around position j; colours are
          # recycled when there are more chains than colours
          segments(x0 = A + B * spacing * (j + .6 *(m - (chains + 1) / 2) / chains),
                   y0 = interval[1], y1 = interval[2], lwd = .5,
                   col = chain_cols[(m-1) %% chain_cols.len + 1])
        }
      } else {
        lines(A + B * spacing * rep(j, 2), a + b * i80[j,], lwd = .5)
        for (m in 1:chains)
          points(A + B * spacing * j, a + b * med.chain[j, m],
                 pch = 20, cex = cex.points,
                 col = chain_cols[(m-1) %% chain_cols.len + 1])
      }

      # draw an indicator for Rhat
      # (xleft, ybottom, xright, ytop)

      if (k_dim_len == 0)
        rect(A + B * spacing * (j - .5), -k - height / 2 - 0.05 + num_height * .5,
             A + B * spacing * (j + .5), -k - height / 2 - 0.05 - num_height * .5, col = rhats_cols[j], border = NA)

      # plot the dimension indexes for this parameter
      if (k_dim_len  >= 1) {
        rect(A + B * spacing * (j - .5), -k - height / 2 - 0.05 + num_height * .5,
             A + B * spacing * (j + .5), -k - height / 2 - 0.05 - num_height * (k_dim_len - .5), col = rhats_cols[j], border = NA)

        # k_dim: the dimension of parameter k
        for (m in 1:k_dim_len) {
          index0 <- k_aidx[j, m]
          if (j == 1)
            text(A+B*spacing*j, -k-height/2-.05-num_height*(m-1), index0, cex=cex_tiny)
          else if (index0 != k_aidx[j - 1, m] & (index0 %% (floor(log10(index0) + 1)) == 0))
            text(A+B*spacing*j, -k-height/2-.05-num_height*(m-1), index0, cex=cex_tiny)

          # Note for `(index0 %% (floor(log10(index0) + 1)) == 0) in the above condition.
          # When 10 <= index0 <= 99, floor(log10(index0) + 1) == 2,
          # so that one index would be drawn out of two consecutive.
          # That is, we would have 10, 12, 14, 16, etc.
          # Similarly, when 100 <= index0 <= 999, we draw one out of three
          # though in the case, we do not draw them at all since the max is
          # 40.
        }
      }
    }
    # mark parameters that have more elements than were drawn
    if (J < k_num_p) {
      text (-.015, -k, "*", cex = cex_names, col = alert_col)
      truncated <- TRUE
    }
  }
  plot_rhat_legend(0, -n_pars - .5, cex = cex_names)
  if (truncated) {
    text(0, -n_pars - .5 - num_height * 2.5, "*  array truncated for lack of space",
         adj = c(0, 0), cex = cex_names, col = alert_col)
  }
  invisible(NULL)
}

legitimate_model_name <- function(name, obfuscate_name = TRUE) {
  # Turn `name` into a string that is a valid C++ name: a prefix and "_" are
  # put in front of it and every non-alphanumeric character becomes "_".
  # Args:
  #   name: the model name
  #   obfuscate_name: if TRUE the prefix is the basename of a fresh
  #     tempfile() name starting with "model"; if FALSE it is just "model".
  #
  # Why a different (ideally unique) name?
  # The returned name is used as the Rcpp module name and as the name of the
  # stan_fit class of each model, and it needs to be unique. Experiments
  # show that when the class name in a module is the same, modules created
  # later use the previous one even though the DSO files differ. The exact
  # cause is not known; a guess is that Rcpp reaches the module through
  # .Call, so that a function of the same name is always resolved to the
  # same loaded one. With obfuscate_name = FALSE it is the user's
  # responsibility to keep the name unique; in return, tools such as ccache
  # might be usable.
  prefix <- "model"
  if (obfuscate_name) {
    prefix <- basename(tempfile('model', ''))
  }
  gsub('[^[:alnum:]]', '_', paste0(prefix, '_', name))
}

boost_url <- function() {
  # URL of the Boost download page.
  "https://www.boost.org/users/download/"
}

makeconf_path <- function() {
  # Path of the make configuration file: the value of the environment
  # variable R_MAKEVARS_USER when it is set and non-empty, otherwise the
  # `Makeconf` file under R's `etc` directory (inside the architecture
  # sub-directory when .Platform$r_arch is not empty).
  user_file <- Sys.getenv("R_MAKEVARS_USER")
  if (!is.na(user_file) && user_file != "") return(user_file)

  etc_dir <- R.home(component = 'etc')
  arch <- .Platform$r_arch
  if (arch == '') {
    file.path(etc_dir, 'Makeconf')
  } else {
    file.path(etc_dir, arch, 'Makeconf')
  }
}

# Pass `ns` to the native routine is_Null_NS and return its result.
is_null_ptr <- function(ns) .Call(is_Null_NS, ns)

is_null_cxxfun <- function(cx) {
  # Tell whether the object returned by cxxfunction (package inline)
  # contains a null pointer. The second element of the wrapped function's
  # body is the NativeSymbol that is handed to the native check.
  native_sym <- body(cx@.Data)[[2]]
  .Call(is_Null_NS, native_sym)
}

obj_size_str <- function(x) {
  # Format a size given in bytes with the largest unit (Gb, Mb, Kb) that it
  # reaches, rounded to one decimal; smaller sizes are printed in bytes.
  unit_sizes <- c(Gb = 1024^3, Mb = 1024^2, Kb = 1024)
  for (unit in names(unit_sizes)) {
    if (x >= unit_sizes[[unit]]) {
      return(paste(round(x / unit_sizes[[unit]], 1L), unit))
    }
  }
  paste(x, "bytes")
}

system_info <- function() {
  # One-line description of the running system: R.version$system and the
  # installed versions of rstan, Rcpp and inline.
  paste0("OS: ", R.version$system,
         "; rstan: ", packageVersion('rstan'),
         "; Rcpp: ", packageVersion('Rcpp'),
         "; inline: ", packageVersion('inline'))
}

read_comments <- function(f, n = -1) {
  # Read comments beginning with `#`
  # Args:
  #   f: the filename
  #   n: max number of comment lines to collect; a value <= 0 means all
  # Returns:
  #   a vector of strings, one per line containing `#`, each reduced to the
  #   text from its last `#` onwards; NULL when no line contains `#`.
  con <- file(f, 'r')
  # close the connection even when reading fails part-way
  on.exit(close(con), add = TRUE)
  comments <- character(0)
  # stop as soon as `n` comments are collected, without reading further
  while ((n <= 0 || length(comments) < n) &&
         length(input <- readLines(con, n = 1)) > 0) {
    if (grepl("#", input)) {
      comments <- c(comments, gsub("^.*#", "#", input))
    }
  }
  if (length(comments) == 0) NULL else comments
}

sqrfnames_to_dotfnames <- function(fnames) {
  # change names such as alpha[1,1] to alpha.1.1:
  # "[" and "," become ".", then "]" is removed
  dotted <- gsub('\\[|,', '.', fnames)
  gsub('\\]', '', dotted)
}


dotfnames_to_sqrfnames <- function(fnames) {
  # change names such as alpha.1.1 to alpha[1,1]; names without a dot are
  # returned unchanged
  add_brackets <- function(nm) {
    if (grepl("\\.", nm)) {
      # first dot opens the bracket, the end of the name closes it
      nm <- sub("\\s*$", "]", sub("\\.", "[", nm))
    }
    nm
  }
  bracketed <- sapply(fnames, add_brackets, USE.NAMES = FALSE)
  # the remaining dots separate the indexes
  gsub("\\.\\s*", ",", bracketed)
}

unique_par <- function(fnames) {
  # Distinct parameter names behind flat names: everything from the first
  # dot on is dropped, so c("alpha.1", "alpha.2", "beta.1.1", ...,
  # "beta.3.4") gives c("alpha", "beta").
  base_names <- gsub('\\..*', '', fnames)
  unique(base_names)
}


get_dims_from_fnames <- function(fnames, pname) {
  # Get the dimension for a parameter from
  # the flatnames such as "alpha.1.1", ..., "alpha.3.4", the
  # format of names in the CSV files generated by Stan.
  # Currently, this function assume fnames are correctly given.
  # Args:
  #   fnames: a character of names for one (vector/array) parameter
  #   pname: the name for this vector/array parameter such as "alpha"
  #     for the above example; when missing it is the part of fnames[1]
  #     before the first dot
  # Returns:
  #   an integer vector with the largest index seen in each position;
  #   integer(0) for a scalar (a single name equal to pname)

  if (missing(pname)) pname <- gsub('\\..*', '', fnames[1])

  if (length(fnames) == 1 && fnames == pname)
    return(integer(0)) # a scalar

  # what is left after removing the parameter name holds only the indexes
  idxs <- sub(pname, '', fnames, fixed = TRUE)
  # extract exactly the digit runs, so that no separator character is
  # carried along into the integer conversion
  idx_strs <- regmatches(idxs, gregexpr('\\d+', idxs))

  # the number of dimensions is taken from the first name
  dim_len <- length(idx_strs[[1]])
  dims <- integer(dim_len)
  for (i in seq_len(dim_len)) {
    dimi <- vapply(idx_strs, function(s) as.integer(s[i]), integer(1))
    dims[i] <- max(dimi)
  }
  dims
}

all_int_eq <- function(is) {
  # TRUE when every element of the integer vector `is` has the same value;
  # an error is raised when `is` is not of integer type
  if (!is.integer(is)) {
    stop("not all are integers")
  }
  max(is) == min(is)
}

read_csv_header <- function(f, comment.char = '#') {
  # Read the header of a csv file (the first line not containing
  # comment.char). And the line number is return as attribute of name 'lineno'.
  # The comment lines in front of the header are scanned for the settings
  # iter, num_samples, num_warmup, thin, save_warmup and output_sample, from
  # which the attribute 'iter.count' is derived:
  #   - the value given by "# iter=" or "output_sample" when present,
  #   - otherwise num_samples, plus num_warmup when save_warmup is on,
  #   and the result is integer-divided by thin when a thin line was seen.
  con <- file(f, 'r')
  # close the connection even when reading fails part-way
  on.exit(close(con), add = TRUE)
  niter <- 0
  iter.count <- NA
  save.warmup <- FALSE
  sample.count <- NA_integer_
  # initialized once, so that the value read from the num_warmup line is
  # not overwritten by the comment lines that follow it
  warmup.count <- 0L
  thin <- NULL
  while (length(input <- readLines(con, n = 1)) > 0) {
    niter <- niter + 1
    if (!grepl(comment.char, input)) break;
    if (grepl("# iter=",input))
      iter.count <- as.integer(gsub("# iter=","",input))
    if (grepl("#.*num_samples",input)){
      sample.count <- as.integer(gsub("[^0-9]*([0-9]*).*","\\1",input))
    }
    if (grepl("#.*num_warmup",input)){
      warmup.count <- as.integer(gsub("[^0-9]*([0-9]*).*","\\1",input))
    }
    if (grepl("#.*thin", input)){
      thin <- as.integer(gsub("[^0-9]*([0-9]*).*","\\1",input))
    }
    if (grepl("#.*save_warmup",input)){
      # any "0" on the line is read as "warmup not saved"
      save.warmup <- !grepl("0",input)
    }
    if (grepl("#.*output_sample",input)){
      iter.count <- as.numeric(gsub("[^0-9]*([0-9]*).*","\\1",input))
    }

  }
  header <- input
  if(is.na(iter.count)){
    if(save.warmup)
      iter.count <- warmup.count + sample.count
    else
      iter.count <- sample.count
  }
  if(!is.null(thin)){
    iter.count <- iter.count %/% thin
  }
  attr(header, "iter.count") <- iter.count
  attr(header, "lineno") <- niter
  header
}

is_arg_recognizable <- function(x, y, pre_msg = '', post_msg = '', ...) {
  # check if all elements of x are in y; stop with a message listing the
  # elements that are not, otherwise return TRUE.
  # x: a vector of characters
  # y: a vector of characters
  # pre_msg, post_msg: text put before / after the listed elements
  # ...: further arguments handed to stop()
  unknown <- x[is.na(match(x, y))]
  if (length(unknown) > 0) {
    stop(pre_msg, paste(unknown, collapse = ', '), ".", post_msg, ...)
  }
  TRUE
}

is_arg_deprecated <- function(x, y, pre_msg = '', post_msg = '', ...) {
  # check if any elements of x are in y; emit a message listing those that
  # are. TRUE is returned in every case.
  # x: a vector of characters
  # y: a vector of characters
  # pre_msg, post_msg: text put before / after the listed elements
  # ...: further arguments handed to message()
  hits <- x[!is.na(match(x, y))]
  if (length(hits) > 0) {
    message(pre_msg, paste(hits, collapse = ', '), ".", post_msg, ...)
  }
  TRUE
}

get_time_from_csv <- function(tlines) {
  # get the warmup time and sample time from the commented lines
  # about time in the CSV files
  # Args:
  #  tlines: character vector of length 3 (or 2 since the last one is not used)
  #          from the CSV File. For example, it could be
  #          # Elapsed Time: 0.005308 seconds (Warm-up)
  #                          0.003964 seconds (Sampling)
  #                          0.009272 seconds (Total)
  # Returns:
  #   a vector named c("warmup", "sample"); both entries are NA when fewer
  #   than two lines are given
  elapsed <- rep(NA, 2)
  names(elapsed) <- c("warmup", "sample")
  if (length(tlines) < 2) return(elapsed)

  # keep only digits and dots of a line and read the rest as a number
  seconds_in <- function(line) as.double(gsub("[^0-9.]", "", line))
  elapsed[1] <- seconds_in(tlines[1])
  elapsed[2] <- seconds_in(tlines[2])
  elapsed
}

# Collect the values of the objects declared between the `private:` and
# `public:` lines of generated C++ code, looking the names up with dynGet()
# and, for entries that are NULL afterwards, in the global environment.
# Args:
#   cppcode: the C++ code as text
# Returns:
#   a list of the values found, named by object name
parse_data <- function(cppcode) {
  # split the code into one string per line
  cppcode <- scan(what = character(), sep = "\n", quiet = TRUE,
                  text = cppcode)
  # first line after `private:` and last line before `public:`
  private <- grep("^private:$", cppcode) + 1L
  public <- grep("^public:$", cppcode) - 1L
  # pull out object names from the data block
  objects <- gsub("^.* ([0-9A-Za-z_]+).*;.*$", "\\1",
                  cppcode[private:public])
  # Remove model internal name _data__ suffix for stanc3 v2.30+
  objects <- gsub("_data__$", "", objects)
  # Remove model internal name underscores in case of Eigen::Maps
  # NOTE(review): this pattern and the next have no capture group, so the
  # "\\1" replacement presumably acts as an empty string -- confirm.
  objects <- gsub("__$", "\\1", objects)
  # Remove any bad regex matches that found the end of an Eigen::Map.
  objects <- gsub("^[[:digit:]]+", "\\1", objects)
  # Remove empty characters and trim whitespaces
  objects <- objects[nzchar(trimws(objects))]

  # Get them from the calling environment
  # NOTE(review): assigning NULL with `[[<-` does not create a list element,
  # so a name that dynGet() does not find adds nothing to `stuff`; the
  # positional pairing of stuff[[i]] with objects[i] in the second loop and
  # its is.null() test then look questionable -- confirm the intended
  # behavior of the global-environment fallback.
  stuff <- list()
  for (int in seq_along(objects)) {
   stuff[[objects[int]]] <- dynGet(objects[int], inherits = FALSE, ifnotfound = NULL)
  }
  # fallback: a numeric, else a logical, object of that name in globalenv()
  for (i in seq_along(stuff)) if (is.null(stuff[[i]])) {
    if (exists(objects[i], envir = globalenv(), mode = "numeric"))
      stuff[[i]] <- get(objects[i], envir = globalenv(), mode = "numeric")
    else if (exists(objects[i], envir = globalenv(), mode = "logical"))
      stuff[[i]] <- get(objects[i], envir = globalenv(), mode = "logical")
  }
  return(stuff)
}

set_cppo <- function(...) {
  # Defunct: all arguments are ignored. A warning pointing to the Makevars
  # file is raised and NULL is returned invisibly.
  warning("'set_cppo' is defunct; manually edit your Makevars file if necessary")
  invisible(NULL)
}

get_stan_param_names <- function(object) {
  # Names read from the model's C++ code: for every line containing
  # `vals_r__ = context__.vals_r(` the text between the first pair of double
  # quotes is taken, and only names that are also in object@sim$pars_oi are
  # kept. Fails when `object` is not a stanfit or when no name remains.
  stopifnot(is(object, "stanfit"))
  cpp_lines <- strsplit(get_cppcode(get_stanmodel(object)), "\n")[[1]]
  reader_lines <- grep("vals_r__ = context__.vals_r(", cpp_lines,
                       fixed = TRUE, value = TRUE)
  quoted <- sapply(strsplit(reader_lines, "\""), FUN = function(x) x[[2]])
  params <- intersect(quoted, object@sim$pars_oi)
  stopifnot(length(params) > 0)
  params
}

create_progress_html_file <- function(htmlfname, textfname) {
  # Write an HTML page built from the package template
  # `misc/stan_progress.html`.
  # Args:
  #   htmlfname: the HTML file name
  #   textfname: the text file name; it replaces the first `%title%` and
  #     then the first `%filename%` placeholder of the template
  template_dir <- system.file('misc', package = 'rstan')
  template <- paste(readLines(file.path(template_dir, 'stan_progress.html')),
                    collapse = '\n')
  with_title <- sub("%title%", textfname, template, fixed = TRUE)
  page <- sub("%filename%", textfname, with_title, fixed = TRUE)
  cat(page, file = htmlfname)
}

get_CXX <- function(...) {
  # The C++ compiler. Arguments are ignored.
  # On non-Windows systems this is the output of `R CMD config CXX17`.
  # On Windows the path is built from the location of `ls` found on the
  # search path and the environment variable WIN; NULL when `ls` is not
  # found.
  if (.Platform$OS.type == "windows") {
    ls_path <- Sys.which("ls")
    if (ls_path == "") return(NULL)
    # two directory levels above the `ls` executable
    install_path <- dirname(dirname(ls_path))
    mingw_dir <- paste0('mingw_', Sys.getenv('WIN'))
    file.path(install_path, mingw_dir, 'bin', 'g++')
  } else {
    r_binary <- file.path(R.home(component = "bin"), "R")
    system2(r_binary, args = "CMD config CXX17", stdout = TRUE, stderr = FALSE)
  }
}

is.sparc <- function() {
  # TRUE when the platform string of the running R starts with "sparc"
  platform <- R.version$platform
  grepl("^sparc", platform)
}

avoid_crash <- function(mod) {
  # TRUE when the file named by the "path" entry of `packageName` (looked
  # up in `mod`) exists and the character form of its "info" entry is one
  # of the printed forms of a null pointer.
  null_ptr_forms <- c("<pointer: (nil)>", "<pointer: 0x0>")
  pkg <- get("packageName", envir = mod)
  path_exists <- file.exists(pkg[["path"]])
  path_exists && as(pkg["info"][1], "character") %in% null_ptr_forms
}

# log(sum(exp(x))), computed after subtracting max(x) from every element
# so that the largest exponent is zero.
# @param x numeric vector
log_sum_exp <- function(x) {
  shift <- max(x)
  shifted_sum <- sum(exp(x - shift))
  shift + log(shifted_sum)
}

sample_indices <- function(wts, n_draws) {
  ## Stratified resampling
  ##   Kitagawa, G., Monte Carlo Filter and Smoother for Non-Gaussian
  ##   Nonlinear State Space Models, Journal of Computational and
  ##   Graphical Statistics, 5(1):1-25, 1996.
  ##
  ## wts: weights, one per model; n_draws: number of indices to return.
  ## Returns a vector of length n_draws whose entries are positions in wts.
  expected <- n_draws * wts # expected number of draws from each model
  picks <- rep(NA, n_draws)

  carry <- 0     # expected draws accumulated but not yet handed out
  n_filled <- 0  # number of entries of `picks` already set

  for (k in 1:length(wts)) {
    carry <- carry + expected[k]
    if (carry >= 1) {
      # hand out the whole part of the accumulated expectation to model k
      whole <- floor(carry)
      carry <- carry - whole
      picks[n_filled + 1:whole] <- k
      n_filled <- n_filled + whole
    }
    # one extra draw for model k, decided by a uniform random number; no
    # random number is used once all draws are filled
    if (n_filled < n_draws && carry >= runif(1)) {
      carry <- carry - 1
      n_filled <- n_filled + 1
      picks[n_filled] <- k
    }
  }
  picks
}

test_221 <- function(cppcode) {
  # TRUE when the code carries the banner of a Stan version whose number
  # starts with "2.2" or with "3"
  banner <- "Code generated by Stan version "
  grepl(paste0(banner, "2.2"), cppcode, fixed = TRUE) ||
    grepl(paste0(banner, "3"), cppcode, fixed = TRUE)
}

Try the rstan package in your browser

Any scripts or data that you put into this service are public.

rstan documentation built on May 29, 2024, 11:04 a.m.