# Part of the rstanarm package for estimating model parameters
# Copyright (C) 2015, 2016, 2017 Trustees of Columbia University
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Build the argument list for rstan::sampling
#
# Assembles a list of arguments to use with \code{rstan::sampling} via
# \code{do.call}.
#
# @param object The stanfit object to use for sampling.
# @param prior Prior distribution list (can be NULL). Passed on to
#   \code{default_stan_control} to choose the default control settings.
# @param user_dots The contents of \code{...} from the user's call to
#   the \code{stan_*} modeling function. Each named entry is assigned into the
#   argument list after the entries coming from \code{...}.
# @param user_adapt_delta The value for \code{adapt_delta} specified by the
#   user (NULL if none).
# @param ... Other arguments to \code{\link[rstan]{sampling}} not coming from
#   \code{user_dots} (e.g. \code{data}, \code{pars}, \code{init}, etc.)
# @return A list of arguments to use for the \code{args} argument for
#   \code{do.call(sampling, args)}. \code{save_warmup} is always set to FALSE.
set_sampling_args <- function(object, prior, user_dots = list(),
                              user_adapt_delta = NULL, ...) {
  args <- list(object = object, ...)
  dot_names <- names(user_dots)
  for (k in seq_along(user_dots)) {
    args[[dot_names[k]]] <- user_dots[[k]]
  }

  defaults <- default_stan_control(prior = prior,
                                   adapt_delta = user_adapt_delta)

  if ("control" %in% dot_names) {
    # user supplied a 'control' list: adapt_delta given directly to stan_*
    # is used if present, otherwise the default for the chosen prior
    args$control$adapt_delta <-
      if (is.null(user_adapt_delta)) defaults$adapt_delta else user_adapt_delta
    # fill in max_treedepth only when the user's 'control' does not set it
    if (is.null(args$control$max_treedepth)) {
      args$control$max_treedepth <- defaults$max_treedepth
    }
  } else {
    # no user-specified 'control' argument
    args$control <- defaults
  }

  args$save_warmup <- FALSE
  args
}
# Default 'control' list for sampling
#
# Used by set_sampling_args to build the default 'control' argument for
# \code{rstan::sampling}. The default \code{adapt_delta} depends on the prior.
#
# @param prior Prior distribution list (can be NULL). Its \code{dist} element
#   selects the default \code{adapt_delta}.
# @param adapt_delta User's \code{adapt_delta} argument. If not NULL it is
#   returned unchanged.
# @param max_treedepth Default for \code{max_treedepth}.
# @return A list with \code{adapt_delta} and \code{max_treedepth}.
default_stan_control <- function(prior, adapt_delta = NULL,
                                 max_treedepth = 15L) {
  if (is.null(adapt_delta)) {
    adapt_delta <- if (length(prior) == 0L) {
      0.95
    } else {
      # the priors listed here get a higher default than everything else
      switch(prior$dist,
             "R2" = ,
             "hs" = ,
             "hs_plus" = ,
             "lasso" = ,
             "product_normal" = 0.99,
             0.95)
    }
  }
  nlist(adapt_delta, max_treedepth)
}
# Class predicates for fitted model objects
#
# Each function returns TRUE when \code{x} inherits from the S3 class named
# in the function (stanreg, stansurv, stanmvreg, stanjm).
#
# @param x The object to test.
# @return A single logical.
is.stanreg <- function(x) {
  inherits(x, what = "stanreg")
}
is.stansurv <- function(x) {
  inherits(x, what = "stansurv")
}
is.stanmvreg <- function(x) {
  inherits(x, what = "stanmvreg")
}
is.stanjm <- function(x) {
  inherits(x, what = "stanjm")
}
# Test if an object contains a specific type of submodel
#
# The decision is based on the \code{stan_function} element of \code{x}.
# FALSE is returned when that element is absent.
#
# @param x The object to test.
# @return A single logical.
is.jm <- function(x) {
  isTRUE(x$stan_function %in% "stan_jm")
}
is.mvmer <- function(x) {
  mv_funs <- c("stan_jm", "stan_mvmer")
  isTRUE(x$stan_function %in% mv_funs)
}
is.surv <- function(x) {
  surv_funs <- c("stan_jm", "stan_surv")
  isTRUE(x$stan_function %in% surv_funs)
}
# Throw an error unless the object is a stanreg object
#
# @param x The object to test.
# @param call. Passed to \code{stop}; controls whether the call is included
#   in the error message.
# @return Nothing useful; called for the error it may raise.
validate_stanreg_object <- function(x, call. = FALSE) {
  ok <- is.stanreg(x)
  if (!ok) {
    stop("Object is not a stanreg object.", call. = call.)
  }
}
# Throw an error unless the object is a stanmvreg object
#
# @param x The object to test.
# @param call. Passed to \code{stop}; controls whether the call is included
#   in the error message.
# @return Nothing useful; called for the error it may raise.
validate_stanmvreg_object <- function(x, call. = FALSE) {
  ok <- is.stanmvreg(x)
  if (!ok) {
    stop("Object is not a stanmvreg object.", call. = call.)
  }
}
# Throw an error unless the object is a stanjm object
#
# @param x The object to test.
# @param call. Passed to \code{stop}; controls whether the call is included
#   in the error message.
# @return Nothing useful; called for the error it may raise.
validate_stanjm_object <- function(x, call. = FALSE) {
  ok <- is.stanjm(x)
  if (!ok) {
    stop("Object is not a stanjm object.", call. = call.)
  }
}
# Throw an error unless the object is a stansurv object
#
# @param x The object to test.
# @param call. Passed to \code{stop}; controls whether the call is included
#   in the error message.
# @return Nothing useful; called for the error it may raise.
validate_stansurv_object <- function(x, call. = FALSE) {
  ok <- is.stansurv(x)
  if (!ok) {
    stop("Object is not a stansurv object.", call. = call.)
  }
}
# Test for a given family
#
# Each function compares \code{x} elementwise with the family name it tests
# for.
#
# @param x A character vector (probably x = family(fit)$family)
# @return A logical vector the same length as \code{x}.
is.binomial <- function(x) {
  x == "binomial"
}
is.gaussian <- function(x) {
  x == "gaussian"
}
is.gamma <- function(x) {
  x == "Gamma"
}
is.ig <- function(x) {
  x == "inverse.gaussian"
}
is.nb <- function(x) {
  x == "neg_binomial_2"
}
is.poisson <- function(x) {
  x == "poisson"
}
# Test for the beta family (either of its two names). Uses the elementwise
# `|` so that, like the other family tests, a character vector of any length
# is accepted and a logical vector of the same length is returned.
is.beta <- function(x) x == "beta" | x == "Beta regression"
# Test if a stanreg object has class clogit
#
# @param object The object to test.
# @return A single logical.
is_clogit <- function(object) is(object, class2 = "clogit")

# Test if a stanreg object has class polr
#
# @param object The object to test.
# @return A single logical.
is_polr <- function(object) is(object, class2 = "polr")
# Test if a stanreg object is a scobit model
#
# A scobit model is identified here as a polr object whose
# \code{stan_summary} has a row named "alpha".
#
# @param object A stanreg object (an error is thrown otherwise).
# @return A single logical.
is_scobit <- function(object) {
  validate_stanreg_object(object)
  if (is_polr(object)) {
    "alpha" %in% rownames(object$stan_summary)
  } else {
    FALSE
  }
}
# Test for a given estimation method
#
# The tests look at the \code{algorithm} element of \code{x}.
#
# @param x A stanreg object.
# @return Logical.
used.optimizing <- function(x) {
  algo <- x$algorithm
  algo == "optimizing"
}
used.sampling <- function(x) {
  algo <- x$algorithm
  algo == "sampling"
}
used.variational <- function(x) {
  vb_algorithms <- c("meanfield", "fullrank")
  x$algorithm %in% vb_algorithms
}
# Test if stanreg object used stan_(g)lmer
#
# The object must both inherit from 'lmerMod' and carry a 'glmod' component.
# Having only one of the two is treated as an internal inconsistency and
# raises an error.
#
# @param x A stanreg object.
# @return A single logical.
is.mer <- function(x) {
  stopifnot(is.stanreg(x))
  has_class <- inherits(x, "lmerMod")
  has_glmod <- !is.null(x$glmod)
  if (has_class != has_glmod) {
    if (has_class) {
      stop("Bug found. 'x' has class 'lmerMod' but no 'glmod' component.")
    } else {
      stop("Bug found. 'x' has 'glmod' component but not class 'lmerMod'.")
    }
  }
  isTRUE(has_class && has_glmod)
}
# Test if stanreg object used stan_nlmer
#
# @param x A stanreg object.
# @return TRUE if \code{x} passes \code{is.mer} and also inherits from
#   'nlmerMod', otherwise FALSE.
is.nlmer <- function(x) {
  if (!is.mer(x)) {
    return(FALSE)
  }
  inherits(x, "nlmerMod")
}
# Consistent error for features that are only available for models fit
# using MCMC
#
# @param what An optional message to prepend to the default message
#   (separated from it by a space).
# @return Never returns; always throws an error.
STOP_sampling_only <- function(what) {
  suffix <- "only available for models fit using MCMC (algorithm='sampling')."
  full_msg <- if (missing(what)) suffix else paste(what, suffix)
  stop(full_msg, call. = FALSE)
}
# Consistent error for features that are available for models fit using
# MCMC or VB but not optimization
#
# @param what An optional message to prepend to the default message
#   (separated from it by a space).
# @return Never returns; always throws an error.
STOP_not_optimizing <- function(what) {
  suffix <- "not available for models fit using algorithm='optimizing'."
  full_msg <- if (missing(what)) suffix else paste(what, suffix)
  stop(full_msg, call. = FALSE)
}
# Message to issue when fitting a model with ADVI but 'QR=FALSE'
#
# Takes no arguments and emits a single message recommending 'QR = TRUE'.
recommend_QR_for_vb <- function() {
  note <- paste0(
    "Setting 'QR' to TRUE can often be helpful when using ",
    "one of the variational inference algorithms. ",
    "See the documentation for the 'QR' argument."
  )
  message(note)
}
# Issue warning if high rhat values
#
# @param rhats Vector of rhat values (usually named by parameter).
# @param threshold Threshold value. If any rhat values are above threshold a
#   warning is issued. Missing rhat values are ignored.
# @param check_lp If FALSE (the default) the elements named "lp__" and
#   "log-posterior" are excluded from the check.
check_rhats <- function(rhats, threshold = 1.1, check_lp = FALSE) {
  nms <- names(rhats)
  # Only filter by name when names exist. For an unnamed vector the name test
  # evaluates to logical(0), which would drop every value and silently
  # disable the check.
  if (!check_lp && !is.null(nms)) {
    rhats <- rhats[!nms %in% c("lp__", "log-posterior")]
  }
  if (any(rhats > threshold, na.rm = TRUE))
    warning("Markov chains did not converge! Do not analyze results!",
            call. = FALSE, noBreaks. = TRUE)
}
# Convert a 1D array to a plain vector, keeping its names (used in stan_glm)
#
# @param y Result of calling model.response
# @return \code{y} unchanged unless it has exactly one dimension, in which
#   case the dim attribute is removed and any names are restored.
array1D_check <- function(y) {
  if (length(dim(y)) != 1L) {
    return(y)
  }
  kept_names <- rownames(y)
  dim(y) <- NULL  # also discards the dimnames, hence the restore below
  if (!is.null(kept_names)) {
    names(y) <- kept_names
  }
  y
}
# Check for a binomial model with Y given as proportion of successes and
# weights given as total number of trials
#
# @param y The response.
# @param family A family object; only its \code{family} element is inspected.
# @param weights The weights (\code{double(0)} means no weights).
# @return TRUE only if the family is binomial, \code{y} is a single numeric
#   column within [0, 1] with at least one value strictly inside (0, 1), and
#   the weights are all positive whole numbers (up to numerical tolerance).
binom_y_prop <- function(y, family, weights) {
  if (!is.binomial(family$family)) {
    return(FALSE)
  }

  y_is_prop <- NCOL(y) == 1L &&
    is.numeric(y) &&
    any(y > 0 & y < 1) &&
    !any(y < 0 | y > 1)
  if (!y_is_prop) {
    return(FALSE)
  }

  tol <- .Machine$double.eps^0.5
  weights_are_trials <- !identical(weights, double(0)) &&
    all(weights > 0) &&
    all(abs(weights - round(weights)) < tol)
  isTRUE(weights_are_trials)
}
# Convert 2-level factor to 0/1
#
# @param y A factor with exactly two levels.
# @return An integer vector: 0 for the first level and 1 for the second.
fac2bin <- function(y) {
  if (!is.factor(y)) {
    stop("Bug found: non-factor as input to fac2bin.",
         call. = FALSE)
  }
  lev <- levels(y)
  if (length(lev) != 2L) {
    stop("Bug found: factor with nlevels != 2 as input to fac2bin.",
         call. = FALSE)
  }
  as.integer(y != lev[1L])
}
# Check weights argument
#
# @param w The \code{weights} argument specified by user or the result of
#   calling \code{model.weights} on a model frame.
# @return \code{double(0)} if \code{w} is missing or NULL. Otherwise, if no
#   error is thrown, \code{w} itself.
validate_weights <- function(w) {
  if (missing(w) || is.null(w)) {
    return(double(0))
  }
  if (!is.numeric(w)) {
    stop("'weights' must be a numeric vector.",
         call. = FALSE)
  }
  if (any(w < 0)) {
    stop("Negative weights are not allowed.",
         call. = FALSE)
  }
  w
}
# Check offset argument
#
# @param o The \code{offset} argument specified by user or the result of
#   calling \code{model.offset} on a model frame.
# @param y The result of calling \code{model.response} on a model frame.
# @return \code{double(0)} if \code{o} is NULL. Otherwise \code{o} itself,
#   provided its length equals the number of observations in \code{y}.
validate_offset <- function(o, y) {
  if (is.null(o)) {
    return(double(0))
  }
  n_off <- length(o)
  n_obs <- NROW(y)
  if (n_off != n_obs) {
    stop(gettextf("Number of offsets is %d but should be %d (number of observations)",
                  n_off, n_obs), domain = NA, call. = FALSE)
  }
  o
}
# Check family argument
#
# @param f The \code{family} argument specified by user (or the default).
#   May be a family object, a function returning one, or the name of such a
#   function.
# @return If no error is thrown, then either \code{f} itself is returned (if
#   already a family) or the family object created from \code{f} is returned
#   (if \code{f} is a string or function).
validate_family <- function(f) {
  if (is.character(f)) {
    # look the function up starting two frames above this one
    f <- get(f, mode = "function", envir = parent.frame(2))
  }
  if (is.function(f)) {
    f <- f()
  }
  if (is(f, "family")) {
    return(f)
  }
  stop("'family' must be a family.", call. = FALSE)
}
# Check for glmer syntax in formulas for non-glmer models
#
# @param f The model \code{formula}.
# @return Nothing is returned but an error is thrown if any part of the
#   formula contains the '|' character.
validate_glm_formula <- function(f) {
  has_bar <- any(grepl("|", f, fixed = TRUE))
  if (has_bar) {
    stop("Using '|' in model formula not allowed. ",
         "Maybe you meant to use 'stan_(g)lmer'?", call. = FALSE)
  }
}
# Check if any variables in a model frame are constants
# (the exception is that a constant variable of all 1's is allowed)
#
# @param mf A model frame or model matrix
# @return If no constant variables are found mf is returned, otherwise an
#   error is thrown naming the constant variables.
check_constant_vars <- function(mf) {
  # the first column is not checked when it is a two-column (binomial)
  # response or a Surv object
  skip_first <- NCOL(mf[, 1]) == 2 || survival::is.Surv(mf[, 1])
  candidates <- if (skip_first) mf[, -1, drop = FALSE] else mf

  # constant means a single distinct value, unless that value is 1
  is_const <- function(v) !all(v == 1) && length(unique(v)) == 1

  exempt <- c("(weights)", "(offset)", "(Intercept)")
  keep <- !colnames(candidates) %in% exempt
  is_constant <- apply(candidates[, keep, drop = FALSE], 2, is_const)
  if (any(is_constant)) {
    stop("Constant variable(s) found: ",
         paste(names(is_constant)[is_constant], collapse = ", "),
         call. = FALSE)
  }
  mf
}
# Grep for "b" parameters (ranef)
#
# Matches the elements of \code{x} that start with "b[".
#
# @param x Character vector (often rownames(fit$stan_summary))
# @param ... Passed to grep
# @return Whatever \code{grep} returns (indices by default).
b_names <- function(x, ...) {
  ranef_rx <- "^b\\["
  grep(ranef_rx, x, ...)
}
# Return names of the last dimension in a matrix/array (e.g. colnames if
# matrix)
#
# @param x A matrix or array
# @return The dimnames of the last dimension (NULL if there are none).
last_dimnames <- function(x) {
  all_names <- dimnames(x)
  all_names[[length(dim(x))]]
}
# Get the correct column name to use for selecting the median
#
# @param algorithm String naming the estimation algorithm (probably
#   \code{fit$algorithm}).
# @return \code{"Median"} for "optimizing" and \code{"50%"} for "sampling",
#   "meanfield" and "fullrank". Any other name is an error.
select_median <- function(algorithm) {
  switch(algorithm,
         sampling = ,
         meanfield = ,
         fullrank = "50%",
         optimizing = "Median",
         stop("Bug found (incorrect algorithm name passed to select_median)",
              call. = FALSE))
}
# Regex parameter selection
#
# @param x stanreg object
# @param regex_pars Character vector of patterns
# @return The row names of \code{x$stan_summary} matched by each pattern in
#   turn, concatenated. NULL (with a warning) for models fit by optimization.
#   An error is thrown if nothing matches.
grep_for_pars <- function(x, regex_pars) {
  validate_stanreg_object(x)
  if (used.optimizing(x)) {
    warning("'regex_pars' ignored for models fit using algorithm='optimizing'.",
            call. = FALSE)
    return(NULL)
  }
  stopifnot(is.character(regex_pars))
  par_names <- rownames(x$stan_summary)
  # unname() so that names on 'regex_pars' do not leak into the result
  matches <- unlist(lapply(unname(regex_pars), grep,
                           x = par_names, value = TRUE))
  if (!length(matches)) {
    stop("No matches for 'regex_pars'.", call. = FALSE)
  }
  matches
}
# Combine pars and regex_pars
#
# @param x stanreg object (only used when \code{regex_pars} is given)
# @param pars Character vector of parameter names. The name "varying" is
#   translated to "b".
# @param regex_pars Character vector of patterns
# @return NULL if both \code{pars} and \code{regex_pars} are NULL, otherwise
#   the unique parameter names from both sources.
collect_pars <- function(x, pars = NULL, regex_pars = NULL) {
  have_pars <- !is.null(pars)
  have_regex <- !is.null(regex_pars)
  if (!have_pars && !have_regex) {
    return(NULL)
  }
  if (have_pars) {
    pars[pars == "varying"] <- "b"
  }
  if (have_regex) {
    pars <- c(pars, grep_for_pars(x, regex_pars))
  }
  unique(pars)
}
# Get the posterior sample size
#
# @param x A stanreg object
# @return the posterior sample size (or size of sample from approximate
#   posterior). For optimization this is the number of rows of
#   \code{x$asymptotic_sampling_dist}. For variational fits it is the stanfit's
#   \code{n_save}. For MCMC it is \code{n_save} minus \code{warmup2}, summed
#   over its elements.
posterior_sample_size <- function(x) {
  validate_stanreg_object(x)
  if (used.optimizing(x)) {
    return(NROW(x$asymptotic_sampling_dist))
  }
  sim <- x$stanfit@sim
  n_saved <- sim$n_save
  if (used.variational(x)) {
    n_saved
  } else {
    sum(n_saved - sim$warmup2)
  }
}
# Fallback operators: return b if a is NULL (%ORifNULL%) or if a equals Inf
# (%ORifINF%), otherwise return a. In %ORifNULL% b is only evaluated when it
# is actually needed.
# @param a,b Objects
`%ORifNULL%` <- function(a, b) {
  if (!is.null(a)) a else b
}
`%ORifINF%` <- function(a, b) {
  is_inf <- a == Inf
  if (is_inf) b else a
}
# Maybe broadcast
#
# @param x A vector or scalar.
# @param n Number of replications to possibly make.
# @return If \code{x} has no length then \code{0} replicated \code{n} times is
#   returned. If \code{x} has length 1, then \code{x} replicated \code{n}
#   times is returned. Otherwise \code{x} itself is returned.
maybe_broadcast <- function(x, n) {
  len <- length(x)
  if (len == 0L) {
    return(rep(0, times = n))
  }
  if (len == 1L) {
    return(rep(x, times = n))
  }
  x
}
# Create a named list using specified names or, if names are omitted, using
# the names of the objects in the list
#
# @param ... Objects to include in the list.
# @return A named list. Elements supplied without a name are named after the
#   expression that was passed for them.
nlist <- function(...) {
  cl <- match.call()
  out <- list(...)
  given <- names(out)
  unnamed_call <- is.null(given)
  named <- if (unnamed_call) FALSE else nzchar(given)
  if (all(named)) {
    return(out)
  }
  # text of each argument expression, in call order
  arg_text <- as.character(cl)[-1L]
  if (unnamed_call) {
    names(out) <- arg_text
  } else {
    names(out)[!named] <- arg_text[!named]
  }
  out
}
# Check and set scale parameters for priors
#
# @param scale Value of scale parameter (can be NULL).
# @param default Default value to use if \code{scale} is NULL.
# @param link String naming the link function or NULL.
# @return If a probit link is being used, \code{scale} (or \code{default} if
#   \code{scale} is NULL) is scaled by \code{dnorm(0) / dlogis(0)}. Otherwise
#   either \code{scale} or \code{default} is returned.
set_prior_scale <- function(scale, default, link) {
  stopifnot(is.numeric(default), is.character(link) || is.null(link))
  out <- if (is.null(scale)) default else scale
  if (isTRUE(link == "probit")) {
    out <- out * dnorm(0) / dlogis(0)
  }
  out
}
# Methods for creating linear predictor
#
# Make linear predictor vector from x and point estimates for beta, or linear
# predictor matrix from x and full posterior sample of beta.
#
# @param beta A vector or matrix of parameter estimates. A matrix is treated
#   as one set of coefficients per row.
# @param x Predictor matrix.
# @param offset Optional offset vector.
# @return A vector (default method) or matrix (matrix method).
linear_predictor <- function(beta, x, offset = NULL) {
  UseMethod("linear_predictor")
}

linear_predictor.default <- function(beta, x, offset = NULL) {
  # a single predictor column is multiplied elementwise
  xb <- if (NCOL(x) == 1L) x * beta else x %*% beta
  eta <- as.vector(xb)
  if (length(offset)) {
    eta <- eta + offset
  }
  eta
}

linear_predictor.matrix <- function(beta, x, offset = NULL) {
  if (NCOL(beta) == 1L) {
    beta <- as.matrix(beta)
  }
  eta <- beta %*% t(x)
  if (!length(offset)) {
    return(eta)
  }
  # add the offset to each column of eta
  sweep(eta, 2L, offset, `+`)
}
#' Extract the survival response from a stansurv or stanjm object
#'
#' @keywords internal
#' @export
#' @param object A \code{stansurv} or \code{stanjm} object.
#' @param ... Other arguments passed to methods.
#' @return A \code{Surv} object, see \code{?survival::Surv}. An error is
#'   thrown if no response is found.
get_surv <- function(object, ...) {
  UseMethod("get_surv")
}

#' @export
get_surv.stansurv <- function(object, ...) {
  mf <- model.frame(object)
  model.response(mf) %ORifNULL% stop("response not found")
}

#' @export
get_surv.stanjm <- function(object, ...) {
  resp <- object$survmod$mod$y
  resp %ORifNULL% stop("response not found")
}
# Get inverse link function
#
# @param x A stanreg object, family object, or string.
# @param ... Other arguments passed to methods. For a \code{stanmvreg} object
#   this can be an integer \code{m} specifying the submodel.
# @return The inverse link function associated with x. For a
#   \code{stanmvreg} object without \code{m}, a named list with one inverse
#   link per submodel.
linkinv <- function(x, ...) {
  UseMethod("linkinv")
}

linkinv.stanreg <- function(x, ...) {
  if (is(x, "polr")) {
    return(polr_linkinv(x))
  }
  family(x)$linkinv
}

linkinv.stanmvreg <- function(x, m = NULL, ...) {
  ret <- lapply(family(x), function(fam) fam[["linkinv"]])
  stub <- get_stub(x)
  if (is.null(m)) list_nms(ret, stub = stub) else ret[[m]]
}

linkinv.family <- function(x, ...) {
  x$linkinv
}

linkinv.character <- function(x, ...) {
  # a single string is interpreted as a stan_polr "method"
  stopifnot(length(x) == 1)
  polr_linkinv(x)
}
# Make inverse link function for stan_polr models, neglecting any
# exponent in the scobit case
#
# @param x A stanreg object or character scalar giving the "method".
# @return The inverse link function associated with x. A missing method and
#   "logistic" are both treated as "logit". "loglog" returns \code{pgumbel}.
#   Any other method is looked up with \code{make.link}.
polr_linkinv <- function(x) {
  from_fit <- is.stanreg(x) && is(x, "polr")
  if (from_fit) {
    method <- x$method
  } else if (is.character(x) && length(x) == 1L) {
    method <- x
  } else {
    stop("'x' should be a stanreg object created by stan_polr ",
         "or a single string.")
  }

  if (is.null(method) || method == "logistic") {
    method <- "logit"
  }
  if (method == "loglog") {
    return(pgumbel)
  }
  link <- make.link(method)
  link$linkinv
}
# Wrapper for rstan::summary
#
# Requests the median together with the lower and upper quantiles of the
# central 50%, 80% and 95% intervals.
#
# @param stanfit A stanfit object created using rstan::sampling or rstan::vb
# @return A matrix of summary stats
make_stan_summary <- function(stanfit) {
  interval_levels <- c(0.5, 0.8, 0.95)
  lower_tail <- (1 - interval_levels) / 2
  probs <- sort(c(0.5, lower_tail, 1 - lower_tail))
  out <- rstan::summary(stanfit, probs = probs, digits = 10)
  out$summary
}
# Check for duplicated group-specific terms
#
# For every grouping factor name that appears more than once in
# \code{reTrms$cnms}, the terms of the repeated entry are compared with those
# of the first entry with the same name. Any overlap is an error.
#
# @param reTrms A list whose \code{cnms} element is a named list of character
#   vectors (term names per grouping factor).
# @return NULL, invisibly, if no error is thrown.
check_reTrms <- function(reTrms) {
  stopifnot(is.list(reTrms))
  nms <- names(reTrms$cnms)
  dupes <- duplicated(nms)
  for (i in which(dupes)) {
    # indexing by name returns the first entry with that name
    original <- reTrms$cnms[[nms[i]]]
    dupe <- reTrms$cnms[[i]]
    overlap <- dupe %in% original
    if (any(overlap))
      stop("rstanarm does not permit formulas with duplicate group-specific terms.\n",
           "In this case ", nms[i], " is used as a grouping factor multiple times and\n",
           # collapse so that several repeated terms stay readable
           paste(dupe[overlap], collapse = ", "), " is included multiple times.\n",
           "Consider using || or -1 in your formulas to prevent this from happening.")
  }
  return(invisible(NULL))
}
# Build the glmerControl object with rstanarm's settings for the checks
#
# Four of lme4's checks on numbers of levels and observations are set to
# "ignore"; check.nlev.gtr.1 is set to "stop".
#
# @param ... Further arguments passed to \code{glmerControl}.
# @return The result of calling \code{glmerControl}.
#' @importFrom lme4 glmerControl
make_glmerControl <- function(...) {
  glmerControl(
    check.nlev.gtreq.5 = "ignore",
    check.nlev.gtr.1 = "stop",
    check.nobs.vs.rankZ = "ignore",
    check.nobs.vs.nlev = "ignore",
    check.nobs.vs.nRE = "ignore",
    ...
  )
}
# Check if a fitted model (stanreg object) has weights
#
# @param x stanreg object
# @return Logical. Only TRUE if x$weights has positive length and the elements
#   of x$weights are not all the same.
#
model_has_weights <- function(x) {
  wts <- x[["weights"]]
  if (length(wts) == 0L) {
    return(FALSE)
  }
  # identical weights are treated the same as no weights
  if (all(wts == wts[1])) FALSE else TRUE
}
# Check that a stanfit object (or list returned by rstan::optimizing) is valid
#
# @param x Either a list, which must have elements named "par" and "value",
#   or a stanfit object, whose \code{mode} slot must be 0.
# @return TRUE if no error is thrown.
check_stanfit <- function(x) {
  if (!is.list(x)) {
    stopifnot(is(x, "stanfit"))
    if (x@mode != 0) {
      stop("Invalid stanfit object produced please report bug")
    }
    return(TRUE)
  }
  required <- c("par", "value")
  if (!all(required %in% names(x))) {
    stop("Invalid object produced please report bug")
  }
  TRUE
}
# Drop redundant dimensions from the variables in a data frame
#
# @param data A data frame.
# @return \code{data} with every one-column matrix variable converted to a
#   vector via \code{drop}.
#
drop_redundant_dims <- function(data) {
  # vapply (not sapply) so the selector is always a logical vector, including
  # for a data frame with zero columns
  drop_dim <- vapply(data, function(v) is.matrix(v) && NCOL(v) == 1,
                     FUN.VALUE = logical(1))
  data[, drop_dim] <- lapply(data[, drop_dim, drop = FALSE], drop)
  return(data)
}

# Validate data argument
#
# Make sure that, if specified, data is a data frame. If data is not missing
# then dimension reduction is also performed on variables (i.e., a one column
# matrix inside a data frame is converted to a vector).
#
# @param data User's data argument
# @param if_missing Object to return if data is missing/null
# @return If no error is thrown, data itself is returned if not missing/null
#   (after dropping redundant dimensions), otherwise a warning is issued and
#   if_missing is returned.
#
validate_data <- function(data, if_missing = NULL) {
  if (missing(data) || is.null(data)) {
    warn_data_arg_missing()
    return(if_missing)
  }
  if (!is.data.frame(data)) {
    stop("'data' must be a data frame.", call. = FALSE)
  }
  drop_redundant_dims(data)
}
# Throw a warning if 'data' argument to modeling function is missing
#
# Takes no arguments and issues a single warning.
warn_data_arg_missing <- function() {
  msg <- paste0(
    "Omitting the 'data' argument is not recommended ",
    "and may not be allowed in future versions of rstanarm. ",
    "Some post-estimation functions (in particular 'update', 'loo', 'kfold') ",
    "are not guaranteed to work properly unless 'data' is specified as a data frame."
  )
  warning(msg, call. = FALSE)
}
# Validate newdata argument for posterior_predict, log_lik, etc.
#
# Doesn't check if the correct variables are included (that's done in
# pp_data), just that newdata is either NULL or a data frame with no missing
# values. Also drops any unused dimensions in variables (e.g. a one column
# matrix inside a data frame is converted to a vector).
#
# @param x User's 'newdata' argument
# @return Either NULL or a data frame
#
validate_newdata <- function(x) {
  if (is.null(x)) {
    return(NULL)
  }
  if (!is.data.frame(x)) {
    stop("If 'newdata' is specified it must be a data frame.", call. = FALSE)
  }
  has_na <- any(is.na(x))
  if (has_na) {
    stop("NAs are not allowed in 'newdata'.", call. = FALSE)
  }
  # convert data frame subclasses to a plain data frame before cleaning up
  drop_redundant_dims(as.data.frame(x))
}
#---------------------- for stan_{mvmer,jm} only -----------------------------
# Return a list (or vector if unlist = TRUE) which contains the embedded
# elements in list x named y
#
# @param x A list of lists.
# @param y,z,zz Names (or indices) used to extract from each element of x in
#   turn: first y, then optionally z from that result, then optionally zz.
# @param null_to_zero If TRUE, each extracted element is passed through
#   \code{ifelse(is.null(i), 0L, i)}, so NULL becomes 0L.
# @param pad_length If not NULL, elements equal to 0L are appended until the
#   result has this length.
# @param unlist If TRUE the result is unlisted.
fetch <- function(x, y, z = NULL, zz = NULL, null_to_zero = FALSE,
                  pad_length = NULL, unlist = FALSE) {
  ret <- lapply(x, `[[`, y)
  # descend through the optional deeper levels
  for (key in list(z, zz)) {
    if (!is.null(key)) {
      ret <- lapply(ret, `[[`, key)
    }
  }
  if (null_to_zero) {
    ret <- lapply(ret, function(i) ifelse(is.null(i), 0L, i))
  }
  if (!is.null(pad_length)) {
    n_pad <- pad_length - length(ret)
    ret <- c(ret, rep(list(0L), n_pad))
  }
  if (unlist) unlist(ret) else ret
}

# Wrapper for using fetch with unlist = TRUE
fetch_ <- function(x, y, z = NULL, zz = NULL, null_to_zero = FALSE,
                   pad_length = NULL) {
  fetch(x = x, y = y, z = z, zz = zz, null_to_zero = null_to_zero,
        pad_length = pad_length, unlist = TRUE)
}

# Wrapper for using fetch with unlist = TRUE and
# returning array. Also converts logical to integer.
fetch_array <- function(x, y, z = NULL, zz = NULL, null_to_zero = FALSE,
                        pad_length = NULL) {
  val <- fetch_(x = x, y = y, z = z, zz = zz, null_to_zero = null_to_zero,
                pad_length = pad_length)
  if (is.logical(val)) {
    val <- as.integer(val)
  }
  as.array(val)
}
# Unlist the result from an lapply call
#
# @param X,FUN,... Same as lapply
# @return The lapply result after \code{unlist}.
uapply <- function(X, FUN, ...) {
  res <- lapply(X, FUN, ...)
  unlist(res)
}
# Unlist the result from an lapply call, removing only one level of nesting
#
# @param X,FUN,... Same as lapply
# @return The lapply result after \code{unlist(recursive = FALSE)}.
nruapply <- function(X, FUN, ...) {
  res <- lapply(X, FUN, ...)
  unlist(res, recursive = FALSE)
}
# A refactored version of mapply with SIMPLIFY = FALSE
#
# @param FUN,... Same as mapply
# @param args Passed to MoreArgs
# @return A list with one element per call of \code{FUN}.
xapply <- function(..., FUN, args = NULL) {
  mapply(FUN = FUN, ..., SIMPLIFY = FALSE, MoreArgs = args)
}
# Test if family object corresponds to a linear mixed model
#
# @param x A family object (an error is thrown otherwise)
# @return TRUE if the family is gaussian with identity link, else FALSE.
is.lmer <- function(x) {
  if (!is(x, "family")) {
    stop("x should be a family object.", call. = FALSE)
  }
  gaussian_family <- x$family == "gaussian"
  isTRUE(gaussian_family && (x$link == "identity"))
}
# Split a 2D array into nsplits subarrays, returned as a list
#
# @param x A 2D array or matrix
# @param nsplits An integer, the number of subarrays or submatrices
# @param bycol A logical, if TRUE then the subarrays are generated by
#   splitting the columns of x, otherwise by splitting the rows
# @return A list of nsplits arrays or matrices, each made of consecutive
#   columns (or rows) of x. An error is thrown if x cannot be divided evenly.
array2list <- function(x, nsplits, bycol = TRUE) {
  len <- if (bycol) ncol(x) else nrow(x)
  len_k <- len %/% nsplits
  if (len != len_k * nsplits) {
    stop("Dividing x by nsplits does not result in an integer.")
  }
  lapply(1:nsplits, function(k) {
    # indices of the k-th chunk
    idx <- (k - 1) * len_k + 1:len_k
    if (bycol) {
      x[, idx, drop = FALSE]
    } else {
      x[idx, , drop = FALSE]
    }
  })
}
# Multiply a vector or array by a vector or scalar. A plain vector, which
# sweep does not handle, is multiplied directly; anything else is multiplied
# with sweep.
#
# @param x A vector or array.
# @param y The vector or scalar to multiply 'x' by.
# @param margin The margin of 'x' across which to apply 'y' (only relevant
#   if 'x' is an array, i.e. not a vector).
# @return An object of the same class as 'x'.
sweep_multiply <- function(x, y, margin = 2L) {
  if (!is.vector(x)) {
    return(sweep(x, margin, y, `*`))
  }
  x * y
}
# Convert a standardised quadrature node to an unstandardised value based on
# the specified integral limits
#
# @param x A standardised quadrature node (a single numeric value)
# @param a The lower limit(s) of the integral, possibly a vector
# @param b The upper limit(s) of the integral, possibly a vector
# @param na.ok If FALSE, 'a' and 'b' are additionally required to be numeric
#   with no 'b' smaller than the corresponding 'a'.
# @return \code{((b - a) / 2) * x + ((b + a) / 2)}
unstandardise_qpts <- function(x, a, b, na.ok = TRUE) {
  x_is_scalar <- identical(length(x), 1L) && is.numeric(x)
  if (!x_is_scalar) {
    stop2("'x' should be a single numeric value.")
  }
  if (!length(a) %in% c(1L, length(b))) {
    stop2("'a' and 'b' should be vectors of length 1, or, be the same length.")
  }
  if (!na.ok) {
    if (!(is.numeric(a) && is.numeric(b))) {
      stop2("'a' and 'b' should be numeric.")
    }
    if (any((b - a) < 0)) {
      stop2("The upper limits for the integral ('b' values) should be greater than ",
            "the corresponding lower limits for the integral ('a' values).")
    }
  }
  half_width <- (b - a) / 2
  midpoint <- (b + a) / 2
  half_width * x + midpoint
}
# Convert a standardised quadrature weight to an unstandardised value based
# on the specified integral limits
#
# @param x A standardised quadrature weight (a single numeric value)
# @param a The lower limit(s) of the integral, possibly a vector
# @param b The upper limit(s) of the integral, possibly a vector
# @param na.ok If FALSE, 'a' and 'b' are additionally required to be numeric
#   with no 'b' smaller than the corresponding 'a'.
# @return \code{((b - a) / 2) * x}
unstandardise_qwts <- function(x, a, b, na.ok = TRUE) {
  x_is_scalar <- identical(length(x), 1L) && is.numeric(x)
  if (!x_is_scalar) {
    stop2("'x' should be a single numeric value.")
  }
  if (!length(a) %in% c(1L, length(b))) {
    stop2("'a' and 'b' should be vectors of length 1, or, be the same length.")
  }
  if (!na.ok) {
    if (!(is.numeric(a) && is.numeric(b))) {
      stop2("'a' and 'b' should be numeric.")
    }
    if (any((b - a) < 0)) {
      stop2("The upper limits for the integral ('b' values) should be greater than ",
            "the corresponding lower limits for the integral ('a' values).")
    }
  }
  half_width <- (b - a) / 2
  half_width * x
}
# Throw error if parameter isn't positive
#
# The error message refers to the argument by the expression the caller
# passed for it.
#
# @param x The object to test. Must be non-NULL and numeric, with all
#   elements strictly positive.
# @param not_greater_than Optional positive upper bound that no element of
#   \code{x} may exceed.
# @return Nothing useful; called for the error it may raise.
validate_positive_scalar <- function(x, not_greater_than = NULL) {
  # capture the caller's expression before 'x' is evaluated
  nm <- deparse(substitute(x))
  if (is.null(x))
    stop(nm, " cannot be NULL", call. = FALSE)
  if (!is.numeric(x))
    stop(nm, " should be numeric", call. = FALSE)
  if (any(x <= 0))
    stop(nm, " should be positive", call. = FALSE)
  if (!is.null(not_greater_than)) {
    if (!is.numeric(not_greater_than) || (not_greater_than <= 0))
      stop("'not_greater_than' should be numeric and positive", call. = FALSE)
    if (!all(x <= not_greater_than))
      stop(nm, " should be less than or equal to ", not_greater_than,
           call. = FALSE)
  }
}
# Return a matrix or list with the median and prob% CrI bounds for
# each column of a matrix or 2D array
#
# @param x A matrix or 2D array
# @param prob Value between 0 and 1 indicating the desired width of the CrI
# @param na.rm Passed to \code{median} and \code{quantile}.
# @param return_matrix Logical, if TRUE then a matrix with three columns is
#   returned (med, lb, ub) else if FALSE a list with three elements is
#   returned.
median_and_bounds <- function(x, prob, na.rm = FALSE, return_matrix = FALSE) {
  if (!(is.matrix(x) || is.array(x))) {
    stop("x should be a matrix or 2D array.")
  }
  col_quantile <- function(p) apply(x, 2, quantile, p, na.rm = na.rm)
  med_vals <- apply(x, 2, median, na.rm = na.rm)
  lb_vals <- col_quantile((1 - prob)/2)
  ub_vals <- col_quantile((1 + prob)/2)
  if (return_matrix) {
    cbind(med = med_vals, lb = lb_vals, ub = ub_vals)
  } else {
    list(med = med_vals, lb = lb_vals, ub = ub_vals)
  }
}
# Return the stub for variable names from one submodel of a stan_jm model
#
# @param m An integer specifying the number of the longitudinal submodel or
#   a character string specifying the submodel (e.g. "Long1", "Event", etc)
# @param stub A character string to prefix to m, if m is supplied as an
#   integer
# @return NULL if m is NULL, otherwise the stub followed by "|".
get_m_stub <- function(m, stub = "Long") {
  if (is.null(m)) {
    return(NULL)
  }
  if (is.numeric(m)) {
    return(paste0(stub, m, "|"))
  }
  if (is.character(m)) {
    return(paste0(m, "|"))
  }
}
# Return the appropriate stub for variable names
#
# @param object A stanmvreg object
# @return "Long" for a joint model, "y" for a multivariate model that is not
#   a joint model, otherwise NULL.
get_stub <- function(object) {
  if (is.jm(object)) {
    return("Long")
  }
  if (is.mvmer(object)) {
    return("y")
  }
  NULL
}
# Separates a names object into separate parts based on the longitudinal,
# event, or association parameters.
#
# @param x Character vector (often rownames(fit$stan_summary))
# @param M An integer specifying the number of longitudinal submodels.
# @param stub The character string used at the start of the names of
#   variables in the longitudinal/GLM submodels
# @param ... Arguments passed to grep
# @return A named list (y, y_extra, y_b, e, e_extra, a, b, alpha, beta, ppd)
#   with x separated out into those names corresponding to parameters from
#   the M longitudinal submodels, the event submodel or association
#   parameters.
collect_nms <- function(x, M, stub = "Long", ...) {
  # every lookup below is a grep over x with the caller's extra arguments
  find <- function(rx) grep(rx, x, ...)
  long_ids <- 1:M

  ppd <- find(paste0("^", stub, ".{1}\\|mean_PPD"))

  # auxiliary parameters of each longitudinal submodel
  aux_suffixes <- c("sigma", "shape", "lambda", "reciprocal_dispersion")
  y_extra <- lapply(long_ids, function(m) {
    do.call(c, lapply(aux_suffixes, function(sfx) {
      find(paste0("^", stub, m, "\\|", sfx))
    }))
  })

  # remaining names of each longitudinal submodel
  y_all <- lapply(long_ids, function(m) find(mod2rx(m, stub = stub)))
  y <- lapply(long_ids, function(m) {
    setdiff(y_all[[m]], c(y_extra[[m]], ppd[m]))
  })

  e_extra <- c(find("^Event\\|weibull-shape|^Event\\|b-splines-coef|^Event\\|piecewise-coef"))
  e <- setdiff(find(mod2rx("^Event")), e_extra)
  a <- find(mod2rx("^Assoc"))

  b <- b_names(x, ...)
  y_b <- lapply(long_ids, function(m) b_names_M(x, m, stub = stub, ...))

  alpha <- c(find("^.{5}\\|\\(Intercept\\)"),
             find(paste0("^", stub, ".{1}\\|\\(Intercept\\)")))
  beta <- setdiff(c(unlist(y), e, a), alpha)

  nlist(y, y_extra, y_b, e, e_extra, a, b, alpha, beta, ppd)
}
# Grep for "b" parameters (ranef), can optionally be specified
# for a specific longitudinal submodel
#
# @param x Character vector (often rownames(fit$stan_summary))
# @param submodel Optional integer specifying which long submodel
# @param stub The prefix that precedes the submodel number in the names
# @param ... Passed to grep
# @return Whatever \code{grep} returns (indices by default).
b_names_M <- function(x, submodel = NULL, stub = "Long", ...) {
  rx <- if (is.null(submodel)) {
    "^b\\["
  } else {
    paste0("^b\\[", stub, submodel, "\\|")
  }
  grep(rx, x, ...)
}
# Grep for regression coefs (fixef), can optionally be specified
# for a specific submodel
#
# @param x Character vector (often rownames(fit$stan_summary))
# @param submodel Character vector specifying which submodels
#   to obtain the coef names for. Can be "Long", "Event", "Assoc", or
#   an integer specifying a specific longitudinal submodel. Specifying
#   NULL selects all submodels.
# @param ... Passed to grep
# @return The grep results for each selected submodel, concatenated.
beta_names <- function(x, submodel = NULL, ...) {
  known <- c("Long", "Event", "Assoc")
  if (is.null(submodel)) {
    rxlist <- c(mod2rx("^Long"), mod2rx("^Event"), mod2rx("^Assoc"))
  } else {
    rxlist <- c()
    for (key in known) {
      if (key %in% submodel) {
        rxlist <- c(rxlist, mod2rx(paste0("^", key)))
      }
    }
    # anything else is handed to mod2rx as is (e.g. a submodel number)
    miss <- setdiff(submodel, known)
    if (length(miss)) {
      rxlist <- c(rxlist, sapply(miss, mod2rx))
    }
  }
  unlist(lapply(rxlist, function(rx) grep(rx, x, ...)))
}
# Converts "Long", "Event" or "Assoc" to the regular expression
# used at the start of variable names for the fitted joint model
#
# @param x The submodel for which the regular expression should be
# obtained. Can be "Long", "Event", "Assoc", or an integer specifying
# a specific longitudinal submodel.
mod2rx <- function(x, stub = "Long") {
  # Fixed labels and the regular expression each one maps to. The [1-9]
  # class only matches single-digit submodel indices.
  fixed <- c(
    "^Long"  = "^Long[1-9]\\|",
    "^Event" = "^Event\\|",
    "^Assoc" = "^Assoc\\|",
    "Long"   = "Long[1-9]\\|",
    "Event"  = "Event\\|",
    "Assoc"  = "Assoc\\|",
    "^y"     = "^y[1-9]\\|",
    "y"      = "y[1-9]\\|"
  )
  for (label in names(fixed)) {
    if (x == label) return(fixed[[label]])
  }
  # Not a fixed label: treat 'x' as the index of one specific submodel.
  paste0("^", stub, x, "\\|")
}
# Return the number of longitudinal submodels
#
# @param object A stanmvreg object
get_M <- function(object) {
  # Validate first, then read the stored 'n_markers' element.
  validate_stanmvreg_object(object)
  object$n_markers
}
# Supplies names for the output list returned by most stanmvreg methods
#
# @param object The list object to which the names are to be applied
# @param M The number of longitudinal/GLM submodels. If NULL then the number of
# longitudinal/GLM submodels is assumed to be equal to the length of object.
# @param stub The character string to use at the start of the names for
# list items related to the longitudinal/GLM submodels
list_nms <- function(object, M = NULL, stub = "Long") {
  ok_type <- is.null(object) || is.list(object) || is.vector(object)
  if (!ok_type)
    stop("'object' argument should be a list or vector.")
  if (is.null(object))
    return(object)
  if (is.null(M))
    M <- length(object)
  # Build "<stub>1", ..., "<stub>M". M == 0 is handled explicitly because
  # 1:0 counts down (giving "<stub>1", "<stub>0") and paste0() recycles a
  # zero-length index into a single "<stub>".
  nms <- if (M > 0) paste0(stub, seq_len(M)) else character(0)
  # Any element beyond the M submodels is taken to be the event submodel.
  if (length(object) > M)
    nms <- c(nms, "Event")
  names(object) <- nms
  object
}
# Removes the submodel identifying text (e.g. "Long1|", "Event|", etc.)
# from variable names
#
# @param x Character vector (often rownames(fit$stan_summary)) from which
# the stub should be removed
rm_stub <- function(x) {
  # Strip each leading submodel prefix in turn: "y<k>|", "Long<k>|", "Event|".
  for (rx in c(mod2rx("^y"), mod2rx("^Long"), mod2rx("^Event"))) {
    x <- gsub(rx, "", x)
  }
  # The result is returned invisibly.
  invisible(x)
}
# Removes a specified character string from the names of an
# object (for example, a matched call)
#
# @param x The matched call
# @param string The character string to be removed
strip_nms <- function(x, string) {
  # 'string' is used as a regular expression; every match is deleted from
  # the names, the values are left untouched.
  cleaned <- gsub(string, "", names(x))
  names(x) <- cleaned
  x
}
# Check argument contains one of the allowed options
check_submodelopt2 <- function(x) {
  # Returns NULL invisibly when 'x' is allowed, otherwise raises an error.
  allowed <- c("long", "event")
  if (x %in% allowed)
    return(invisible(NULL))
  stop("submodel option must be 'long' or 'event'")
}
check_submodelopt3 <- function(x) {
  # Returns NULL invisibly when 'x' is allowed, otherwise raises an error.
  allowed <- c("long", "event", "both")
  if (x %in% allowed)
    return(invisible(NULL))
  stop("submodel option must be 'long', 'event' or 'both'")
}
# Error message when the argument contains an object of the incorrect type
STOP_arg <- function(arg_name, type) {
  # Several acceptable types are joined with " or " in the message.
  allowed <- paste0(type, collapse = " or ")
  msg <- paste0("'", arg_name, "' should be a ", allowed,
                " object or a list of those objects.")
  stop(msg, call. = FALSE)
}
# Return error msg if both elements of the object are TRUE
STOP_combination_not_allowed <- function(object, x, y) {
  # Nothing to report unless both named elements are TRUE.
  both_set <- object[[x]] && object[[y]]
  if (!both_set)
    return(invisible(NULL))
  # The message names the expression the caller passed as 'object'.
  stop("In ", deparse(substitute(object)), ", '", x, "' and '", y,
       "' cannot be specified together", call. = FALSE)
}
# Error message when not specifying an argument required for stanmvreg objects
#
# @param arg The argument
STOP_arg_required_for_stanmvreg <- function(arg) {
  # Report the argument by the expression the caller wrote, not its value.
  arg_label <- deparse(substitute(arg))
  stop2(paste0("Argument '", arg_label, "' required for stanmvreg objects."))
}
# Error message when not specifying 'id_var' for stansurv methods that require it
#
# (This function takes no arguments.)
STOP_id_var_required <- function() {
  # Always raises the same error.
  msg <- paste0("'id_var' must be specified for models with a start-stop response ",
                "or with time dependent effects.")
  stop2(msg)
}
# Error message when a function is not yet implemented for stanmvreg objects
#
# @param what A character string naming the function not yet implemented
STOP_if_stanmvreg <- function(what) {
  suffix <- "not yet implemented for stanmvreg objects."
  # 'what' is optional; when given it is prepended, separated by a space.
  msg <- if (missing(what)) suffix else paste(what, suffix)
  stop2(msg)
}
# Error message when a function is not yet implemented for stansurv objects
#
# @param what A character string naming the function not yet implemented
STOP_if_stansurv <- function(what) {
  suffix <- "not yet implemented for stansurv objects."
  # 'what' is optional; when given it is prepended, separated by a space.
  msg <- if (missing(what)) suffix else paste(what, suffix)
  stop2(msg)
}
# Error message when a function is not yet implemented for stan_mvmer models
#
# @param what An optional message to prepend to the default message.
STOP_stan_mvmer <- function(what) {
  suffix <- "is not yet implemented for models fit using stan_mvmer."
  # 'what' is optional; when given it is prepended, separated by a space.
  msg <- if (missing(what)) suffix else paste(what, suffix)
  stop2(msg)
}
# Consistent error message to use when something that is only available for
# models fit using stan_jm
#
# @param what An optional message to prepend to the default message.
STOP_jm_only <- function(what) {
  suffix <- "can only be used with stan_jm models."
  # 'what' is optional; when given it is prepended, separated by a space.
  msg <- if (missing(what)) suffix else paste(what, suffix)
  stop2(msg)
}
# Consistent error message when binomial models with greater than
# one trial are not allowed
#
STOP_binomial <- function() {
  # Always raises the same error.
  msg <- paste0("Binomial models with number of trials greater than one ",
                "are not allowed (i.e. only bernoulli models are allowed).")
  stop2(msg)
}
# Error message when a required variable is missing from the data frame
#
# @param var The name of the variable that could not be found
STOP_no_var <- function(var) {
  # Always raises an error naming the missing variable.
  msg <- paste0("Variable '", var, "' cannot be found in the data frame.")
  stop2(msg)
}
# Error message when values for the time variable are negative
#
# @param var The name of the time variable
STOP_negative_times <- function(var) {
  # Always raises an error naming the offending time variable.
  msg <- paste0("Values for the time variable (", var, ") should not be negative.")
  stop2(msg)
}
# Error message for dynamic predictions
#
# @param what A reason why the dynamic predictions are not allowed
STOP_dynpred <- function(what) {
  # 'what' completes the sentence, separated from it by a single space.
  msg <- paste("Dynamic predictions are not yet implemented for", what)
  stop2(msg)
}
# Check if individuals in ids argument were also used in model estimation
#
# @param object A stanmvreg object
# @param ids A vector of ids appearing in the pp data
# @param m Integer specifying which submodel to get the estimation IDs from
# @return A logical. TRUE indicates there are new ids in the prediction data,
# while FALSE indicates all ids in the prediction data were used in fitting
# the model. This return is used to determine whether to draw new b pars.
check_pp_ids <- function(object, ids, m = 1) {
  # IDs present in the model frame of submodel 'm'.
  fitted_ids <- unique(model.frame(object, m = m)[[object$id_var]])
  already_seen <- ids %in% fitted_ids
  # Any overlap triggers an immediate warning, since predictions for those
  # individuals will use their estimated random effects.
  if (any(already_seen))
    warning("Some of the IDs in the 'newdata' correspond to individuals in the ",
            "estimation dataset. Please be sure you want to obtain subject-",
            "specific predictions using the estimated random effects for those ",
            "individuals. If you instead meant to marginalise over the distribution ",
            "of the random effects (for posterior_predict or posterior_traj), or ",
            "to draw new random effects conditional on outcome data provided in ",
            "the 'newdata' arguments (for posterior_survfit), then please make ",
            "sure the ID values do not correspond to individuals in the ",
            "estimation dataset.", immediate. = TRUE)
  # TRUE as soon as at least one supplied id was not used in estimation.
  !all(already_seen)
}
# Validate newdataLong and newdataEvent arguments
#
# @param object A stanmvreg object
# @param newdataLong A data frame, or a list of data frames
# @param newdataEvent A data frame
# @param duplicate_ok A logical. If FALSE then only one row per individual is
# allowed in the newdataEvent data frame
# @param response A logical specifying whether the longitudinal response
# variable must be included in the new data frame
# @return A list of validated data frames
validate_newdatas <- function(object, newdataLong = NULL, newdataEvent = NULL,
duplicate_ok = FALSE, response = TRUE) {
# Validates the supplied prediction data and returns it as a single list:
# the longitudinal data frames first (if any), then the event data frame
# under the name "Event" (if supplied). Returns NULL if neither is supplied.
validate_stanmvreg_object(object)
id_var <- object$id_var
newdatas <- list()
if (!is.null(newdataLong)) {
# a single data frame is reused for each of the M longitudinal submodels
# (is() rather than is.list(), since a data frame is itself a list)
if (!is(newdataLong, "list"))
newdataLong <- rep(list(newdataLong), get_M(object))
dfcheck <- sapply(newdataLong, is.data.frame)
if (!all(dfcheck))
stop("'newdataLong' must be a data frame or list of data frames.", call. = FALSE)
# NA check is restricted to the variables appearing in each submodel formula
nacheck <- sapply(seq_along(newdataLong), function(m) {
if (response) { # newdataLong needs the response variable
fmL <- formula(object, m = m)
} else { # newdataLong only needs the covariates
# dropping the 2nd element of the formula removes the left hand side
fmL <- formula(object, m = m)[c(1,3)]
}
all(!is.na(get_all_vars(fmL, newdataLong[[m]])))
})
if (!all(nacheck))
stop("'newdataLong' cannot contain NAs.", call. = FALSE)
newdatas <- c(newdatas, newdataLong)
}
if (!is.null(newdataEvent)) {
if (!is.data.frame(newdataEvent))
stop("'newdataEvent' must be a data frame.", call. = FALSE)
if (response) { # newdataEvent needs the response variable
fmE <- formula(object, m = "Event")
} else { # newdataEvent only needs the covariates
fmE <- formula(object, m = "Event")[c(1,3)]
}
# 'dat' is only used for the NA check; the full 'newdataEvent' is returned
dat <- get_all_vars(fmE, newdataEvent)
dat[[id_var]] <- newdataEvent[[id_var]] # include ID variable in event data
if (any(is.na(dat)))
stop("'newdataEvent' cannot contain NAs.", call. = FALSE)
if (!duplicate_ok && any(duplicated(newdataEvent[[id_var]])))
stop("'newdataEvent' should only contain one row per individual, since ",
"time varying covariates are not allowed in the prediction data.")
newdatas <- c(newdatas, list(Event = newdataEvent))
}
if (length(newdatas)) {
# every data frame must contain the ID variable
idvar_check <- sapply(newdatas, function(x) id_var %in% colnames(x))
if (!all(idvar_check))
STOP_no_var(id_var)
# the unique ids must be the same set (compared after sorting), and must
# also appear in the same order (compared unsorted), in every data frame
ids <- lapply(newdatas, function(x) unique(x[[id_var]]))
sorted_ids <- lapply(ids, sort)
if (!length(unique(sorted_ids)) == 1L)
stop("The same subject ids should appear in each new data frame.")
if (!length(unique(ids)) == 1L)
stop("The subject ids should be ordered the same in each new data frame.")
return(newdatas)
} else return(NULL)
}
# Return data frames only including the specified subset of individuals
#
# @param data A data frame, or a list of data frames
# @param ids A vector of ids indicating which individuals to keep
# @param id_var Character string, the name of the ID variable
# @return A data frame, or a list of data frames, depending on the input
subset_ids <- function(data, ids, id_var) {
  if (is.null(data))
    return(NULL)
  # is() rather than is.list(): a data frame is itself a list, but must be
  # treated as a single data set here.
  is_list <- is(data, "list")
  if (!is_list)
    data <- list(data) # convert to list
  is_df <- vapply(data, inherits, logical(1), what = "data.frame")
  if (!all(is_df))
    stop("'data' should be a data frame, or list of data frames.")
  data <- lapply(data, function(x) {
    if (!id_var %in% colnames(x)) STOP_no_var(id_var)
    # Every requested id must be present. Subset with `[` so that several
    # missing ids can be reported at once (`[[` only accepts one index).
    not_found <- ids[!ids %in% x[[id_var]]]
    if (length(not_found))
      stop("The following 'ids' do not appear in the data: ",
           paste(not_found, collapse = ", "))
    x[x[[id_var]] %in% ids, , drop = FALSE]
  })
  # Return the same shape as the input: a list in, a list out.
  if (is_list) data else data[[1]]
}
# Return a data.table with a key set using the appropriate id/time/grp variables
#
# @param data A data frame.
# @param id_var The name of the ID variable.
# @param grp_var The name of the variable identifying groups clustered within
# individuals.
# @param time_var The name of the time variable.
# @return A data.table (which will be used in a rolling merge against the
# event times and/or quadrature times).
prepare_data_table <- function(data, id_var, time_var, grp_var = NULL) {
  if (!requireNamespace("data.table"))
    stop("the 'data.table' package must be installed to use this function")
  if (!is.data.frame(data))
    stop("'data' should be a data frame.")
  # check required vars are in the data
  data_nms <- colnames(data)
  has_grp <- !is.null(grp_var)
  if (!id_var %in% data_nms)
    STOP_no_var(id_var)
  if (!time_var %in% data_nms)
    STOP_no_var(time_var)
  if (has_grp && (!grp_var %in% data_nms))
    STOP_no_var(grp_var)
  # key order is id, (grp), time; c() silently drops a NULL 'grp_var'
  key_vars <- c(id_var, grp_var, time_var)
  dt <- data.table::data.table(data, key = key_vars)
  # coerce the key columns to the types used when merging
  dt[[time_var]] <- as.numeric(dt[[time_var]]) # ensures no rounding on merge
  dt[[id_var]] <- factor(dt[[id_var]]) # ensures matching of ids
  if (has_grp)
    dt[[grp_var]] <- factor(dt[[grp_var]]) # ensures matching of grps
  dt
}
# Carry out a rolling merge
#
# @param data A data.table with a set key corresponding to ids, times (and
# possibly also grps).
# @param ids A vector of patient ids to merge against.
# @param times A vector of times to (rolling) merge against.
# @param grps An optional vector of groups clustered within patients to
# merge against. Only relevant when there is clustering within patient ids.
# @return A data.table formed by a merge of ids, (grps), times, and the closest
# preceding (in terms of times) rows in data.
rolling_merge <- function(data, ids, times, grps = NULL) {
  if (!requireNamespace("data.table"))
    stop("the 'data.table' package must be installed to use this function")
  has_grps <- !is.null(grps)
  # the key must hold ids and times, plus grps when they are supplied
  n_key <- length(data.table::key(data))
  n_expected <- if (has_grps) 3L else 2L
  if (n_key == 0L)
    stop2("Bug found: data.table should have a key.")
  if (!n_key == n_expected)
    stop2("Bug found: data.table key is not the same length as supplied keylist.")
  # ensure data types are same as returned by the prepare_data_table function
  ids <- factor(ids) # ensures matching of ids
  times <- as.numeric(times) # ensures no rounding on merge
  # table of lookup values, with columns in the same order as the key
  if (has_grps) {
    grps <- factor(grps)
    tmp <- data.table::data.table(ids, grps, times)
  } else {
    tmp <- data.table::data.table(ids, times)
  }
  # rolling join on the last key column, with both ends rolled
  data[tmp, roll = TRUE, rollends = c(TRUE, TRUE)]
}
# Return an array or list with the time sequence used for posterior predictions
#
# @param increments An integer with the number of increments (time points) at
# which to predict the outcome for each individual
# @param t0,t1 Numeric vectors giving the start and end times across which to
# generate prediction times
# @param simplify Logical specifying whether to return each increment as a
# column of an array (TRUE) or as an element of a list (FALSE)
get_time_seq <- function(increments, t0, t1, simplify = TRUE) {
  n_steps <- increments - 1
  val <- sapply(0:n_steps, function(x, t0, t1) {
    # Fraction of the way from t0 to t1. With a single increment there are
    # no steps to take, so stay at t0 rather than evaluating 0 / 0 (NaN).
    frac <- if (n_steps > 0) x / n_steps else 0
    t0 + (t1 - t0) * frac
  }, t0 = t0, t1 = t1, simplify = simplify)
  if (simplify && is.vector(val)) {
    # need to transform if there is only one individual, so that the
    # result is still one row per individual and one column per increment
    val <- t(val)
    rownames(val) <- if (!is.null(names(t0))) {
      names(t0)
    } else if (!is.null(names(t1))) {
      names(t1)
    } else {
      NULL
    }
  }
  val
}
# Extract parameters from stanmat and return as a list
#
# @param object A stanmvreg or stansurv object
# @param stanmat A matrix of posterior draws, may be provided if the desired
# stanmat is only a subset of the draws from as.matrix(object$stanfit)
# @return A named list
extract_pars <- function(object, ...) {
  # S3 generic: dispatch on the class of 'object'.
  UseMethod("extract_pars", object)
}
extract_pars.stansurv <- function(object, stanmat = NULL, means = FALSE) {
  validate_stansurv_object(object)
  # default to all posterior draws
  if (is.null(stanmat))
    stanmat <- as.matrix(object$stanfit)
  # collapse the draws to a single row of posterior means
  if (means)
    stanmat <- t(colMeans(stanmat))
  # column names of each parameter block
  nms_int  <- get_int_name_basehaz(object$basehaz)
  nms_beta <- colnames(object$x)
  nms_tde  <- get_smooth_name(object$s_cpts, type = "smooth_coefs")
  nms_aux  <- get_aux_name_basehaz(object$basehaz)
  nms_smth <- get_smooth_name(object$s_cpts, type = "smooth_sd")
  # always keep the draws as a matrix, even for a single column
  pick <- function(nms) stanmat[, nms, drop = FALSE]
  alpha    <- pick(nms_int)
  beta     <- pick(nms_beta)
  beta_tde <- pick(nms_tde)
  aux      <- pick(nms_aux)
  smooth   <- pick(nms_smth)
  nlist(alpha, beta, beta_tde, aux, smooth, stanmat)
}
extract_pars.stanmvreg <- function(object, stanmat = NULL, means = FALSE) {
  validate_stanmvreg_object(object)
  M <- get_M(object)
  # default to all posterior draws
  if (is.null(stanmat))
    stanmat <- as.matrix(object$stanfit)
  # collapse the draws to a single row of posterior means
  if (means)
    stanmat <- t(colMeans(stanmat))
  nms <- collect_nms(colnames(stanmat), M, stub = get_stub(object))
  # always keep the draws as a matrix, even for a single column
  pick <- function(cols) stanmat[, cols, drop = FALSE]
  # one matrix per longitudinal submodel
  beta <- lapply(1:M, function(m) pick(nms$y[[m]]))
  b    <- lapply(1:M, function(m) pick(nms$y_b[[m]]))
  # event submodel, association and baseline hazard parameters
  ebeta  <- pick(nms$e)
  abeta  <- pick(nms$a)
  bhcoef <- pick(nms$e_extra)
  nlist(beta, ebeta, abeta, bhcoef, b, stanmat)
}
# Promote a character variable to a factor
#
# @param x The variable to potentially promote
promote_to_factor <- function(x) {
  # Anything that is not a character vector is returned untouched.
  if (!is.character(x))
    return(x)
  as.factor(x)
}
# Draw from a multivariate t distribution
# @param mu A mean vector
# @param Sigma A variance-covariance matrix
# @param df A degrees of freedom
rmt <- function(mu, Sigma, df) {
  # Normal draw with covariance Sigma, via the lower Cholesky factor
  lower <- t(chol(Sigma))
  z <- rnorm(length(mu))
  # One chi-square draw scales the whole vector
  w <- rchisq(1, df = df)
  mu + c(lower %*% z) / sqrt(w / df)
}
# Evaluate the multivariate t log-density
# @param x A realization
# @param mu A mean vector
# @param Sigma A variance-covariance matrix
# @param df A degrees of freedom
dmt <- function(x, mu, Sigma, df) {
  p <- length(x)
  dev <- x - mu
  half_df_p <- 0.5 * (df + p)
  # log-determinant of Sigma; c() drops the attributes of the modulus
  log_det <- c(determinant(Sigma, logarithm = TRUE)$modulus)
  # quadratic form dev' Sigma^{-1} dev, inverting via the Cholesky factor
  quad <- (dev %*% chol2inv(chol(Sigma)) %*% dev)[1]
  # normalising constant, then the kernel
  log_const <- lgamma(half_df_p) - lgamma(0.5 * df) -
    0.5 * p * log(df) - 0.5 * p * log(pi) -
    0.5 * log_det
  log_const - half_df_p * log1p(quad / df)
}
# Count the number of unique values
#
# @param x A vector or list
n_distinct <- function(x) {
  # Length of 'x' after duplicates are removed.
  distinct_vals <- unique(x)
  length(distinct_vals)
}
# Transpose function that can handle NULL objects
#
# @param x A matrix, a vector, or otherwise (e.g. NULL)
transpose <- function(x) {
  # Only matrices and vectors are transposed; anything else (e.g. NULL)
  # is returned as is.
  transposable <- is.matrix(x) || is.vector(x)
  if (!transposable)
    return(x)
  t(x)
}
# Translate group/factor IDs into integer values
#
# @param x A vector of group/factor IDs
groups <- function(x) {
  # NULL passes straight through
  if (is.null(x))
    return(x)
  # the integer codes are the positions of the factor levels
  as.integer(as.factor(x))
}
# Drop named attributes listed in ... from the object x
#
# @param x Any object with attributes
# @param ... The named attributes to drop
drop_attributes <- function(x, ...) {
  # Remove each named attribute in turn; with no names given the loop body
  # never runs and 'x' is returned unchanged.
  for (attr_nm in list(...)) {
    attr(x, attr_nm) <- NULL
  }
  x
}
# Check if x and any objects in ... were all NULL or not
#
# @param x The first object to use in the comparison
# @param ... Any additional objects to include in the comparison
# @param error If TRUE then return an error if all objects aren't
# equal with regard to the 'is.null' test.
# @return If error = TRUE, then an error if all objects aren't
# equal with regard to the 'is.null' test. Otherwise, a logical
# specifying whether all objects were equal with regard to the
# 'is.null' test.
supplied_together <- function(x, ..., error = FALSE) {
  dots <- list(...)
  # Unevaluated expressions passed in '...', used to name the offending
  # argument in the error message. (Deparsing the loop variable would give
  # the argument's value rather than its name.)
  dot_exprs <- as.list(substitute(list(...)))[-1L]
  x_is_null <- is.null(x)
  for (j in seq_along(dots)) {
    if (!identical(x_is_null, is.null(dots[[j]]))) {
      if (!error)
        return(FALSE) # not supplied together, ie. one NULL and one not NULL
      nm_x <- paste(deparse(substitute(x)), collapse = "")
      nm_j <- paste(deparse(dot_exprs[[j]]), collapse = "")
      # same as the file's stop2(): an error without the call
      stop(nm_x, " and ", nm_j, " must be supplied together.", call. = FALSE)
    }
  }
  TRUE # supplied together, ie. all NULL or all not NULL
}
# Check variables specified in ... are in the data frame
#
# @param data A data frame
# @param ... The names of the variables
check_vars_are_included <- function(data, ...) {
  available <- names(data)
  for (v in list(...)) {
    if (v %in% available)
      next
    # report the data set by the expression the caller passed
    arg_nm <- deparse(substitute(data))
    stop2("Variable '", v, "' is not present in ", arg_nm, ".")
  }
  # every variable was found: hand the data back unchanged
  data
}
# Check whether a vector/matrix/array contains an "(Intercept)"
check_for_intercept <- function(x, logical = FALSE) {
  # Locate "(Intercept)" among the column names (matrix) or names (otherwise).
  #
  # @param x A named vector, or a matrix with column names.
  # @param logical If TRUE return whether an intercept is present, rather
  #   than where it is.
  # @return The position(s) of "(Intercept)" (integer(0) if absent), or a
  #   single logical when 'logical = TRUE'.
  nms <- if (is.matrix(x)) colnames(x) else names(x)
  # Test each name against "(Intercept)" so that which() yields the actual
  # position; testing '"(Intercept)" %in% nms' would always yield 1.
  sel <- which(nms %in% "(Intercept)")
  if (logical) as.logical(length(sel)) else sel
}
# Drop intercept from a vector/matrix/array of named coefficients
drop_intercept <- function(x) {
  sel <- check_for_intercept(x)
  # no intercept found: nothing to drop
  if (!length(sel))
    return(x)
  # matrices lose the column (and stay matrices), vectors lose the element
  if (is.matrix(x)) x[, -sel, drop = FALSE] else x[-sel]
}
# Return intercept from a vector/matrix/array of named coefficients
return_intercept <- function(x) {
  # Return the element(s) of 'x' named "(Intercept)", or NULL if none.
  # Only names(x) is consulted, so a matrix whose intercept is identified
  # by its column names yields NULL.
  #
  # Test each name against "(Intercept)" so that which() yields the actual
  # position; testing '"(Intercept)" %in% names(x)' would always yield 1.
  sel <- which(names(x) %in% "(Intercept)")
  if (length(sel)) x[sel] else NULL
}
# Standardise a coefficient
standardise_coef <- function(x, location = 0, scale = 1) {
  # Centre on 'location', then divide by 'scale'.
  centred <- x - location
  centred / scale
}
# Return a one-dimensional array or an empty numeric
array_else_double <- function(x) {
  # Non-empty input is flattened and returned as a 1-d array; empty input
  # becomes a zero-length double.
  if (length(x)) as.array(unlist(x)) else double(0)
}
# Return a matrix of uniform random variables or an empty matrix
matrix_of_uniforms <- function(nrow = 0, ncol = 0) {
  # A zero dimension yields the 0 x 0 matrix and consumes no random numbers.
  if (nrow == 0 || ncol == 0)
    return(matrix(0,0,0))
  draws <- runif(nrow * ncol)
  matrix(draws, nrow, ncol)
}
# If x is NULL then return an empty object of the specified 'type'
#
# @param x An object to test whether it is null.
# @param type The type of empty object to return if x is null.
convert_null <- function(x, type = c("double", "integer", "matrix",
                                     "arraydouble", "arrayinteger")) {
  # Non-NULL input is returned untouched (and 'type' is not inspected).
  if (!is.null(x))
    return(x)
  # Resolve 'type' to a single choice. Without this the default (the full
  # vector of choices) reaches the comparisons below with length 5, which
  # is an error inside 'if'. An unrecognised 'type' is rejected here.
  type <- match.arg(type)
  switch(type,
         double       = double(0),
         integer      = integer(0),
         matrix       = matrix(0,0,0),
         arraydouble  = as.array(double(0)),
         arrayinteger = as.array(integer(0)))
}
# Expand/pad a matrix to the specified number of cols/rows
#
# @param x A matrix or 2D array
# @param cols,rows Integer specifying the desired number
# of columns/rows
# @param value The value to use for the padded cells
# @return A matrix
pad_matrix <- function(x, cols = NULL, rows = NULL,
                       value = 0L) {
  # Columns first: append a block of 'value' on the right if 'x' is too narrow.
  extra_cols <- if (is.null(cols)) 0 else cols - ncol(x)
  if (extra_cols > 0)
    x <- cbind(x, matrix(value, nrow(x), extra_cols))
  # Then rows: the appended block spans the (possibly widened) matrix.
  extra_rows <- if (is.null(rows)) 0 else rows - nrow(x)
  if (extra_rows > 0)
    x <- rbind(x, matrix(value, extra_rows, ncol(x)))
  # A matrix that is already large enough is returned unchanged.
  x
}
# Return the cutpoints for a specified number of quantiles of 'x'
#
# @param x A numeric vector.
# @param nq Integer specifying the number of quantiles.
# @return A vector of percentiles corresponding to percentages 100*k/nq for
# k=1,2,...,nq-1.
qtile <- function(x, nq = 2) {
  if (nq < 1)
    stop("'nq' must be >= 1.")
  # a single quantile group has no interior cutpoints
  if (nq == 1)
    return(NULL)
  # interior cutpoints at probabilities 1/nq, ..., (nq - 1)/nq
  probs <- seq(1, nq - 1) / nq
  quantile(x, probs = probs)
}
# Return the desired spline basis for the given knot locations
get_basis <- function(x, iknots, bknots = range(x),
                      degree = 3, intercept = TRUE,
                      type = c("bs", "is", "ms")) {
  type <- match.arg(type)
  # Only the matching alternative is evaluated, so 'splines2' is needed
  # only for the "is" and "ms" bases.
  switch(type,
         bs = splines::bs(x, knots = iknots, Boundary.knots = bknots,
                          degree = degree, intercept = intercept),
         is = splines2::iSpline(x, knots = iknots, Boundary.knots = bknots,
                                degree = degree, intercept = intercept),
         ms = splines2::mSpline(x, knots = iknots, Boundary.knots = bknots,
                                degree = degree, intercept = intercept),
         stop2("'type' is not yet accommodated."))
}
# Paste character vector collapsing with a comma
comma <- function(x) {
  # Join the elements of 'x' into one string, separated by ", ".
  paste0(x, collapse = ", ")
}
# Select rows of a matrix
#
# @param x A matrix.
# @param rows Logical or numeric vector stating which rows of 'x' to retain.
keep_rows <- function(x, rows = seq_len(nrow(x))) {
  # The default keeps every row. seq_len() is used because 1:nrow(x)
  # counts down to c(1, 0) for a zero-row matrix, which then fails with
  # "subscript out of bounds".
  # drop = FALSE keeps the result a matrix even for a single row.
  x[rows, , drop = FALSE]
}
# Drop rows of a matrix
#
# @param x A matrix.
# @param rows Logical or numeric vector stating which rows of 'x' to drop
drop_rows <- function(x, rows = seq_len(nrow(x))) {
  # The default drops every row, leaving a zero-row matrix.
  #
  # A numeric 'rows' holds row positions, so it cannot simply be negated
  # with '!': that turns every non-zero position into FALSE and drops all
  # rows. Convert positions into a keep/drop flag per row instead.
  keep <- if (is.numeric(rows)) {
    !(seq_len(nrow(x)) %in% rows)
  } else {
    !rows
  }
  # drop = FALSE keeps the result a matrix even for a single row.
  x[keep, , drop = FALSE]
}
# Replicate rows of a matrix or data frame
#
# @param x A matrix or data frame.
# @param ... Arguments passed to 'rep', namely 'each' or 'times'.
rep_rows <- function(x, ...) {
  # NULL has nothing to replicate.
  if (is.null(x))
    return(x)
  # nrow() is NULL for objects without dimensions. Testing for that
  # explicitly lets such input reach the informative error below instead
  # of failing inside the 'if' condition.
  n_rows <- nrow(x)
  if (!is.null(n_rows) && n_rows == 0)
    return(x)
  if (!(is.matrix(x) || is.data.frame(x)))
    stop2("'x' must be a matrix or data frame.")
  # '...' ('each' or 'times') controls how the row indices are repeated.
  x[rep(seq_len(n_rows), ...), , drop = FALSE]
}
# Stop without printing call
stop2 <- function(...) {
  # Same as stop(), but the failing call is left out of the error.
  stop(..., call. = FALSE)
}
# Immediate warning without printing call
warning2 <- function(...) {
  # Same as warning(), but emitted immediately and without the call.
  warning(..., immediate. = TRUE, call. = FALSE)
}
# Shorthand for suppress warnings
SW <- function(expr) {
  # Evaluate 'expr' with all warnings suppressed and return its value.
  base::suppressWarnings(expr)
}
# Check if an object is NULL
is_null <- function(x) {
  # NULL itself counts as null
  if (is.null(x))
    return(TRUE)
  # only vectors (including lists) can be "all NULL" element-wise
  if (!is.vector(x))
    return(FALSE)
  all(sapply(x, is.null))
}
# Check if all objects are NULL
all_null <- function(...) {
  # One flag per argument: TRUE when the argument is NULL, or is a vector
  # whose elements are all NULL.
  flags <- uapply(list(...), function(arg) {
    if (is.null(arg))
      return(TRUE)
    if (!is.vector(arg))
      return(FALSE)
    all(sapply(arg, is.null))
  })
  all(flags)
}
# Check if any objects are NULL
any_null <- function(...) {
  # One flag per argument: TRUE when the argument is NULL, or is a vector
  # whose elements are all NULL.
  flags <- uapply(list(...), function(arg) {
    if (is.null(arg))
      return(TRUE)
    if (!is.vector(arg))
      return(FALSE)
    all(sapply(arg, is.null))
  })
  any(flags)
}
# Recursively removes NULL entries from an object
rm_null <- function(x, recursive = TRUE) {
  # keep only the entries that is_null() does not flag
  kept <- Filter(function(el) !is_null(el), x)
  if (!recursive)
    return(kept)
  # descend into list elements; everything else is left as it is
  lapply(kept, function(el) if (is.list(el)) rm_null(el) else el)
}
# Check if all elements are equal allowing NA and NULL
is_equal <- function(x, y, ...) {
  # all.equal() returns TRUE or a description of the differences; reduce
  # that to a single TRUE/FALSE.
  comparison <- all.equal(x, y, ...)
  isTRUE(comparison)
}
# Check if x behaves like a factor in design matrices
is_like_factor <- function(x) {
  # Factors, character vectors and logical vectors all qualify.
  if (is.factor(x))
    return(TRUE)
  if (is.character(x))
    return(TRUE)
  is.logical(x)
}
# Check if 'x' is FALSE
isFALSE <- function(x) {
  # Strict test: only a bare, unnamed, attribute-free FALSE qualifies.
  identical(x, FALSE)
}
sw <- function(f) {
  # Evaluate 'f' with all warnings suppressed and return its value.
  suppressWarnings(f)
}
# Concatenate (i.e. 'c(...)') but don't demote factors to integers
ulist <- function(...) {
  # Collect the arguments in a list, then flatten that list.
  collected <- list(...)
  unlist(collected)
}
dlist <- function(x) {
  # Flatten a single level of nesting only.
  unlist(x, recursive = FALSE)
}
# Return the names for the group specific coefficients
#
# @param cnms A named list with the names of the parameters nested within each
# grouping factor.
# @param flevels A named list with the (unique) factor levels nested within each
# grouping factor.
# @return A character vector.
get_ranef_name <- function(cnms, flevels) {
# Builds one "b[<parameter> <factor>:<level>]" name for every combination of
# parameter and level within each grouping factor, plus one extra
# "_NEW_<factor>" level per grouping factor.
cnms_nms <- names(cnms)
b_nms <- uapply(seq_along(cnms), FUN = function(i) {
nm <- cnms_nms[i]
# "<parameter> <factor>" for each parameter of the i-th grouping factor
nms_i <- paste(cnms[[i]], nm)
# spaces in level names become underscores, and a "_NEW_<factor>" level is
# appended (this only changes the copy of 'flevels' local to this function)
flevels[[nm]] <- c(gsub(" ", "_", flevels[[nm]]),
paste0("_NEW_", nm))
if (length(nms_i) == 1) {
paste0(nms_i, ":", flevels[[nm]])
} else {
# transposing before flattening groups the names by level: all parameters
# for the first level, then all parameters for the second level, and so on
c(t(sapply(nms_i, paste0, ":", flevels[[nm]])))
}
})
c(paste0("b[", b_nms, "]"))
}
# Return the name for the mean_PPD
get_ppd_name <- function(x, ...) {
  # "<stub>|mean_PPD", using the stub stored on 'x'.
  stub <- x$stub
  paste0(stub, "|mean_PPD")
}
# Return the name for the intercept parameter
get_int_name_basehaz <- function(x, is_jm = FALSE, ...) {
  # With 'is_jm = TRUE' the intercept name is always returned and
  # has_intercept() is never consulted.
  if (!is_jm && !has_intercept(x))
    return(NULL)
  "(Intercept)"
}
get_int_name_ymod <- function(x, ...) {
  # A zero (or FALSE) 'intercept_type$number' means there is no intercept
  # to name.
  if (!x$intercept_type$number)
    return(NULL)
  paste0(x$stub, "|(Intercept)")
}
get_int_name_emod <- function(x, is_jm = FALSE, ...) {
  # Prefix the baseline hazard intercept name, if there is one, with "Event|".
  nm <- get_int_name_basehaz(x$basehaz, is_jm = is_jm)
  if (is.null(nm))
    return(NULL)
  paste0("Event|", nm)
}
# Return the names for the auxiliary parameters
get_aux_name_basehaz <- function(x, ...) {
  # "<prefix>1", "<prefix>2", ... up to x$nvars; only evaluated for the
  # baseline hazards that have one coefficient per basis term
  numbered <- function(prefix) paste0(prefix, seq(x$nvars))
  type_name <- get_basehaz_name(x)
  switch(type_name,
         exp       = NULL,
         weibull   = "weibull-shape",
         gompertz  = "gompertz-scale",
         ms        = numbered("m-splines-coef"),
         bs        = numbered("b-splines-coef"),
         piecewise = numbered("piecewise-coef"),
         NA) # any other baseline hazard name
}
get_aux_name_ymod <- function(x, ...) {
  # Name of the auxiliary parameter, by family; families not listed here
  # yield NULL.
  family_name <- x$family$family
  switch(family_name,
         gaussian         = paste0(x$stub, "|sigma"),
         Gamma            = paste0(x$stub, "|shape"),
         inverse.gaussian = paste0(x$stub, "|lambda"),
         neg_binomial_2   = paste0(x$stub, "|reciprocal_dispersion"),
         NULL)
}
get_aux_name_emod <- function(x, ...) {
  # Prefix the baseline hazard auxiliary names, if there are any, with "Event|".
  nms <- get_aux_name_basehaz(x$basehaz)
  if (is.null(nms))
    return(NULL)
  paste0("Event|", nms)
}
# Return the names for the coefficients
get_beta_name_ymod <- function(x) {
  # "<stub>|<column>" for each column of the 'xtemp' design matrix; NULL
  # when that matrix has no column names.
  coef_nms <- colnames(x$x$xtemp)
  if (is.null(coef_nms))
    return(NULL)
  paste0(x$stub, "|", coef_nms)
}
get_beta_name_emod <- function(x, ...) {
  # "Event|<column>" for each column of x$x; NULL when it has no column names.
  coef_nms <- colnames(x$x)
  if (is.null(coef_nms))
    return(NULL)
  paste0("Event|", coef_nms)
}
# Return the names for the association parameters
get_assoc_name <- function(a_mod, assoc, ...) {
# Builds the names of the association parameters, one block of names per
# longitudinal submodel, each prefixed with "Assoc|Long<m>|".
# NOTE(review): 'assoc' is indexed as assoc[<type>, ][[m]], so it is
# presumably a matrix of list cells with one row per association type and
# one column per submodel -- confirm against the caller.
M <- length(a_mod)
a <- assoc
# short aliases for the row names of 'assoc' / the labels used in the names
ev <- "etavalue"
es <- "etaslope"
ea <- "etaauc"
mv <- "muvalue"
ms <- "muslope"
ma <- "muauc"
evd <- "etavalue_data"
esd <- "etaslope_data"
mvd <- "muvalue_data"
msd <- "muslope_data"
evev <- "etavalue_etavalue"
evmv <- "etavalue_muvalue"
mvev <- "muvalue_etavalue"
mvmv <- "muvalue_muvalue"
p <- function(...) paste0(...)
# "Long<k>" labels of the submodels that submodel m interacts with, for
# interaction type x
indx <- function(x, m) paste0("Long", assoc["which_interactions",][[m]][[x]])
# column names of the data interaction matrix of type x for submodel m
cnms <- function(x, m) colnames(a_mod[[m]][["X_data"]][[x]])
nms <- character()
for (m in 1:M) {
stub <- paste0("Assoc|Long", m, "|")
# order matters here! (needs to line up with the monitored stanpars)
if (a[ev, ][[m]]) nms <- c(nms, p(stub, ev ))
if (a[evd, ][[m]]) nms <- c(nms, p(stub, ev, ":", cnms(evd, m) ))
if (a[evev,][[m]]) nms <- c(nms, p(stub, ev, ":", indx(evev, m), "|", ev))
if (a[evmv,][[m]]) nms <- c(nms, p(stub, ev, ":", indx(evmv, m), "|", mv))
if (a[es, ][[m]]) nms <- c(nms, p(stub, es ))
if (a[esd, ][[m]]) nms <- c(nms, p(stub, es, ":", cnms(esd, m) ))
if (a[ea, ][[m]]) nms <- c(nms, p(stub, ea ))
if (a[mv, ][[m]]) nms <- c(nms, p(stub, mv ))
if (a[mvd, ][[m]]) nms <- c(nms, p(stub, mv, ":", cnms(mvd, m) ))
if (a[mvev,][[m]]) nms <- c(nms, p(stub, mv, ":", indx(mvev, m), "|", ev))
if (a[mvmv,][[m]]) nms <- c(nms, p(stub, mv, ":", indx(mvmv, m), "|", mv))
if (a[ms, ][[m]]) nms <- c(nms, p(stub, ms ))
if (a[msd, ][[m]]) nms <- c(nms, p(stub, ms, ":", cnms(msd, m) ))
if (a[ma, ][[m]]) nms <- c(nms, p(stub, ma ))
}
nms
}
# Return the list with summary information about the baseline hazard
#
# @return A named list.
get_basehaz <- function(x) {
  # stansurv objects store the baseline hazard at the top level, stanjm
  # objects inside the 'survmod' element.
  if (is.stansurv(x)) {
    x$basehaz
  } else if (is.stanjm(x)) {
    x$survmod$basehaz
  } else {
    stop("Bug found: could not find basehaz.")
  }
}
# Return the name of the baseline hazard
#
# @return A character string.
get_basehaz_name <- function(x) {
  # Accepts, in order of precedence: the name itself, a stansurv object, a
  # stanjm object, or anything with a character 'type_name' element.
  if (is.character(x)) {
    x
  } else if (is.stansurv(x)) {
    x$basehaz$type_name
  } else if (is.stanjm(x)) {
    x$survmod$basehaz$type_name
  } else if (is.character(x$type_name)) {
    x$type_name
  } else {
    stop("Bug found: could not resolve basehaz name.")
  }
}
# Add the variables in ...'s to the RHS of a model formula
#
# @param x A model formula.
# @param ... Character strings, the variable names.
addto_formula <- function(x, ...) {
  # Terms of the existing right hand side
  rhs_terms <- terms(reformulate_rhs(rhs(x)))
  # existing term labels followed by the new variable names
  all_labels <- c(attr(rhs_terms, "term.labels"), c(...))
  # rebuild with the original response and intercept setting
  reformulate(all_labels,
              response = lhs(x),
              intercept = attr(rhs_terms, "intercept"))
}
# Shorthand for as.integer, as.double, as.matrix, as.array
ai <- function(x, ...) {
  as.integer(x, ...)
}
ad <- function(x, ...) {
  as.double(x, ...)
}
am <- function(x, ...) {
  as.matrix(x, ...)
}
aa <- function(x, ...) {
  as.array(x, ...)
}
# Sample rows from a two-dimensional object
#
# @param x The two-dimensional object (e.g. matrix, data frame, array).
# @param size Integer specifying the number of rows to sample.
# @param replace Should the rows be sampled with replacement?
# @return A two-dimensional object with 'size' rows and 'ncol(x)' columns.
sample_rows <- function(x, size, replace = FALSE) {
  # Draw the row indices, then subset; drop = FALSE keeps two dimensions
  # even when a single row is drawn.
  n_rows <- nrow(x)
  picked <- sample(n_rows, size, replace)
  x[picked, , drop = FALSE]
}
# Sample rows from a stanmat object
#
# @param object A stanreg object.
# @param draws The number of draws/rows to sample from the stanmat.
# @param default_draws Integer or NA. If 'draws' is NULL then the number of
# rows sampled from the stanmat is equal to
# min(default_draws, posterior_sample_size, na.rm = TRUE)
# @return A matrix with 'draws' rows and 'ncol(stanmat)' columns.
sample_stanmat <- function(object, draws = NULL, default_draws = NA) {
  n_post <- posterior_sample_size(object)
  # without 'draws', use the smaller of 'default_draws' and the posterior
  # sample size (an NA 'default_draws' is ignored)
  n_draws <- if (is.null(draws)) {
    min(default_draws, n_post, na.rm = TRUE)
  } else {
    draws
  }
  if (n_draws > n_post)
    stop2("'draws' should be <= posterior sample size (", n_post, ").")
  all_draws <- as.matrix(object$stanfit)
  # only subsample when fewer draws than available were asked for
  if (!isTRUE(n_draws < n_post))
    return(all_draws)
  sample_rows(all_draws, n_draws)
}
# Method to truncate a numeric vector at defined limits
#
# @param con A numeric vector.
# @param lower Scalar, the lower limit for the returned vector.
# @param upper Scalar, the upper limit for the returned vector.
# @return A numeric vector.
truncate.numeric <- function(con, lower = NULL, upper = NULL) {
  # Each limit is optional; values beyond a supplied limit are set to it.
  if (!is.null(lower)) {
    too_low <- con < lower
    con[too_low] <- lower
  }
  if (!is.null(upper)) {
    too_high <- con > upper
    con[too_high] <- upper
  }
  con
}
# Transpose only if 'x' is a vector
transpose_vector <- function(x) {
  # Matrices and other non-vectors are returned untouched.
  if (!is.vector(x))
    return(x)
  t(x)
}
# Simplified conditional for 'if (is.null(...))'
if_null <- function(test, yes, no) {
  # Only the selected branch is ever evaluated.
  if (is.null(test))
    return(yes)
  no
}
# Replace entries of 'x' based on a (possibly) vectorised condition
#
# @param x The vector, matrix, or array.
# @param condition The logical condition, possibly a logical vector.
# @param replacement The value to replace with, where the condition is TRUE.
# @param margin The margin of 'x' on which to apply the condition.
# @return The same class as 'x' but possibly with some entries replaced.
replace_where <- function(x, condition, replacement, margin = 1L) {
# A numeric 'margin' selects the alternative at that position in switch(),
# and only that alternative is evaluated (in this function's frame):
#   margin = 1: x[condition] <- replacement
#   margin = 2: x[, condition] <- replacement
#   margin = 3: error
# For a numeric 'margin' above 3 no alternative is selected, so 'x' is
# returned unchanged and without an error.
switch(margin,
x[condition] <- replacement,
x[,condition] <- replacement,
stop("Cannot handle 'margin' > 2."))
x
}
# Calculate row means, but don't simplify to a vector
row_means <- function(x, na.rm = FALSE) {
  mns <- rowMeans(x, na.rm = na.rm)
  # Return the means as a single column, in a container matching the input.
  if (is.matrix(x)) {
    matrix(mns, ncol = 1)
  } else if (is.array(x)) {
    array(mns, dim = c(nrow(x), 1))
  } else if (is.data.frame(x)) {
    data.frame(mns)
  } else {
    stop2("Cannot handle objects of class: ", class(x))
  }
}
# Column means that keep a two-dimensional shape (one row) instead of
# collapsing to a plain vector
#
# @param x A matrix or array.
# @param na.rm Logical, passed on to colMeans.
# @return A one-row matrix or array holding the column means.
col_means <- function(x, na.rm = FALSE) {
  avgs <- colMeans(x, na.rm = na.rm)
  if (is.matrix(x)) {
    matrix(avgs, nrow = 1)
  } else if (is.array(x)) {
    array(avgs, dim = c(1, ncol(x)))
  } else {
    stop2("Cannot handle objects of class: ", class(x))
  }
}
# Return 'x' after assigning 'names' as its row names (set_rownames) or as
# its column names (set_colnames)
set_rownames <- function(x, names) {
  `rownames<-`(x, value = names)
}
set_colnames <- function(x, names) {
  `colnames<-`(x, value = names)
}
# Subset the rows (select_rows) or the columns (select_cols) of 'x' by name
# or index, always with drop = FALSE so the dimensions are not dropped
select_rows <- function(x, rows) {
  x[rows, , drop = FALSE]
}
select_cols <- function(x, cols) {
  x[, cols, drop = FALSE]
}
# Conditional version of structure(): the attributes supplied in '...' are
# only attached to '.Data' when 'condition' is TRUE
structure2 <- function(.Data, condition, ...) {
  if (condition) {
    return(structure(.Data, ...))
  }
  .Data
}
# Split a vector in a specified number of (equally sized) segments
#
# @param x The vector to split.
# @param n_segments Integer specifying the desired number of segments. Must be
#   a positive whole number that divides length(x) exactly.
# @return A list of vectors, see `?split`. Segment k holds the k-th run of
#   length(x) / n_segments consecutive elements of 'x'.
split_vector <- function(x, n_segments = 1) {
  if (!is.numeric(n_segments) || length(n_segments) != 1L ||
      is.na(n_segments) || n_segments < 1 ||
      n_segments != round(n_segments)) {
    stop("'n_segments' must be a positive whole number.")
  }
  len <- length(x)
  # Without this check the grouping vector would be shorter than 'x' and
  # split() would recycle it, returning non-contiguous segments.
  if (len %% n_segments != 0) {
    stop("The length of 'x' (", len, ") must be a multiple of 'n_segments' (",
         n_segments, ").")
  }
  segment_id <- rep(seq_len(n_segments), each = len %/% n_segments)
  split(x, segment_id)
}
# Replace an NA object, or NA entries in a vector
#
# @param x The vector with elements to potentially replace.
# @param replace_with The replacement value.
# @return 'replace_with' itself when 'x' is a single NA; otherwise 'x' with
#   its NA entries replaced by 'replace_with'.
replace_na <- function(x, replace_with = "0") {
  is_missing <- is.na(x)
  # Test the length explicitly: using a vector-valued is.na(x) directly as an
  # 'if' condition is an error for length > 1 (and for length 0).
  if (length(is_missing) == 1L && is_missing[[1L]]) {
    return(replace_with)
  }
  x[is_missing] <- replace_with
  x
}
# Replace a NULL object, or NULL entries in a list
#
# @param x The object with elements to potentially replace.
# @param replace_with The replacement value.
# @return 'replace_with' itself when 'x' is NULL; a list 'x' with each NULL
#   entry replaced by 'replace_with'; otherwise 'x' as handled previously.
replace_null <- function(x, replace_with = "0") {
  if (is.null(x)) {
    return(replace_with)
  }
  # is.null() is not vectorised, so NULL entries have to be located
  # element by element; only lists can contain them.
  null_entries <- if (is.list(x)) {
    unname(vapply(x, is.null, logical(1L)))
  } else {
    FALSE
  }
  if (any(null_entries)) {
    x[null_entries] <- list(replace_with)
  } else {
    # Retained from the previous implementation so that inputs without NULL
    # entries are treated exactly as before. The index is always FALSE here.
    # NOTE(review): presumably the empty assignment can still coerce the type
    # of an atomic 'x' -- confirm whether any caller relies on that.
    x[is.null(x)] <- replace_with
  }
  x
}
# Prepend a column of ones (the intercept) to a predictor matrix
#
# @param x A matrix.
# @return A matrix with one more column than 'x'; the first column is all 1.
add_intercept <- function(x) {
  stopifnot(is.matrix(x))
  n_obs <- nrow(x)
  # Pass an expression rather than a named object to cbind, so that no
  # column name is generated for the intercept column
  cbind(rep(1, times = n_obs), x)
}
# Overwrite (or add) the elements of 'x' whose names appear in 'y', using the
# corresponding values from 'y'
replace_named_elements <- function(x, y) {
  replace(x, names(y), y)
}
# Negation of 'is.null': TRUE when 'x' is anything other than NULL
not.null <- function(x) {
  if (is.null(x)) FALSE else TRUE
}
# Short aliases for the base coercion functions; all arguments are passed on:
#   ai = as.integer, ad = as.double, am = as.matrix, aa = as.array
ai <- function(x, ...) {
  as.integer(x, ...)
}
ad <- function(x, ...) {
  as.double(x, ...)
}
am <- function(x, ...) {
  as.matrix(x, ...)
}
aa <- function(x, ...) {
  as.array(x, ...)
}
# Return a vector of length 'n' filled with 0's (zeros) or with 1's (ones)
zeros <- function(n) {
  rep(0, times = n)
}
ones <- function(n) {
  rep(1, times = n)
}
# Check that no element of a vector differs from zero
all_zero <- function(x) {
  !any(x != 0)
}
# Largest representable integer (max_integer) and largest finite double
# (max_double), as recorded in .Machine
max_integer <- function() {
  .Machine[["integer.max"]]
}
max_double <- function() {
  .Machine[["double.xmax"]]
}
# Check for a scalar (a numeric vector of length one) or a string (a
# character vector of length one)
is.scalar <- function(x) {
  is.numeric(x) && length(x) == 1L
}
is.string <- function(x) {
  is.character(x) && length(x) == 1L
}
# Deparse using a line width cutoff of 500 bytes, to limit the splitting of
# the result across several strings
safe_deparse <- function(expr) {
  deparse(expr, width.cutoff = 500L)
}
# Parse a character string as R code and evaluate it
#
# NOTE(review): this executes whatever code 'x' contains, so it should only
# ever be given trusted, internally generated strings -- confirm against the
# callers.
#
# @param x Character string(s) containing R code, passed to parse(text = ).
# @return The value returned by eval() for the parsed expression(s).
eval_string <- function(x) eval(parse(text = x))
# Mutate, similar to dplyr (ie. append a new variable(s) to the data frame)
#
# @param x The object (e.g. a data frame) to add elements to.
# @param ... Named values. Each one is assigned into 'x' using
#   x[[name]] <- value, so an existing element with that name is overwritten.
# @param names_eval Logical. If FALSE the argument names in '...' are used
#   as the element names. If TRUE each argument name is converted to a symbol
#   and evaluated with eval.parent(); the result is used as the element name.
# @param n The number of generations to go back, passed to eval.parent().
#   NOTE(review): presumably 4 reaches the caller of mutate(), given the
#   sapply -> lapply -> anonymous function frames in between -- TODO confirm
#   before changing how the names are looped over.
# @return 'x' with the elements added or replaced.
mutate <- function(x, ..., names_eval = FALSE, n = 4) {
dots <- list(...)
if (names_eval) { # evaluate names in parent frame
nms <- sapply(names(dots), function(x) eval.parent(as.name(x), n = n))
} else {
nms <- names(dots)
}
# assign one element at a time, in the order the arguments were supplied
for (i in seq_along(dots))
x[[nms[[i]]]] <- dots[[i]]
x
}
# Variant of mutate() that always evaluates the argument names; n = 5 is
# presumably one more than the default to account for this wrapper's own
# frame -- TODO confirm
mutate_ <- function(x, ...) mutate(x, ..., names_eval = TRUE, n = 5)
# Sort the rows of a data frame based on the variables specified in dots.
#
# Note that names in ... that do not match a column of 'x' are NOT skipped:
# each name is evaluated as a symbol inside with(x, ...), which will
# typically fail with an 'object not found' error.
#
# @param x A data frame.
# @param ... Character strings; names of the columns of 'x' on which to sort,
#   in order of priority. If none are supplied 'x' is returned unchanged.
# @return A data frame, with the rows of 'x' reordered.
row_sort <- function(x, ...) {
  stopifnot(is.data.frame(x))
  vars <- lapply(list(...), as.name) # convert string to name
  # With no sort keys order() returns integer(0), which would drop every row
  if (length(vars) == 0L) {
    return(x)
  }
  row_order <- with(x, do.call(order, vars))
  x[row_order, , drop = FALSE]
}
# Reorder the columns of a data frame so that the columns named in the dots
# come first, in the order given. All remaining columns of 'x' follow, in
# their existing order.
#
# @param x A data frame.
# @param ... Character strings; the names of the columns of 'x' to move to
#   the front, in the desired order.
# @return A data frame with the same columns as 'x', reordered.
col_sort <- function(x, ...) {
  stopifnot(is.data.frame(x))
  leading_cols <- unlist(list(...))
  # whatever was not requested explicitly is kept, after the requested columns
  trailing_cols <- setdiff(colnames(x), leading_cols)
  new_order <- c(leading_cols, trailing_cols)
  x[, new_order, drop = FALSE]
}
# Calculate the specified quantiles for each column of a matrix or array
#
# @param x A matrix or array.
# @param probs Vector of probabilities; one set of column quantiles is
#   calculated for each element.
# @param na.rm Logical, passed on to quantile.
# @param return_matrix Logical. If TRUE the results for the different
#   probabilities are bound together as the columns of a matrix, otherwise
#   they are returned as a list with one element per probability.
# @return A list or a matrix, depending on 'return_matrix'.
col_quantiles <- function(x, probs, na.rm = FALSE, return_matrix = FALSE) {
  stopifnot(is.matrix(x) || is.array(x))
  per_prob <- lapply(probs, function(p) {
    apply(x, 2, quantile, probs = p, na.rm = na.rm)
  })
  if (!return_matrix) {
    return(per_prob)
  }
  do.call(cbind, per_prob)
}
# Shorthand for col_quantiles that removes NAs and returns a matrix
col_quantiles_ <- function(x, probs) {
  col_quantiles(x, probs = probs, na.rm = TRUE, return_matrix = TRUE)
}
# Add the prefix 'str' to the front of each column name of a matrix or array.
# An object with zero columns is returned unchanged.
append_prefix_to_colnames <- function(x, str) {
  if (ncol(x)) {
    return(set_colnames(x, paste0(str, colnames(x))))
  }
  x
}
# Return the name of the calling function as a string
#
# @param which Passed to sys.call() to choose which call on the stack to
#   inspect. NOTE(review): the default of -2 presumably refers to the caller
#   of the function that called get_calling_fun() -- confirm before changing
#   how sys.call() is wrapped, since the result depends on the call stack.
# @return The deparsed first element of the selected call (i.e. the function
#   part of the call), or NULL if retrieving it raised an error or gave NULL.
get_calling_fun <- function(which = -2) {
# any error raised while looking up the call is converted into NULL
fn <- tryCatch(sys.call(which = which)[[1L]], error = function(e) NULL)
if (!is.null(fn)) safe_deparse(fn) else NULL
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.