#' Generate predictions for CAR(1) model
#'
#' Draws are generated with `tidybayes::add_epred_draws()` (with
#' `incl_autocor = FALSE`) and joined to the output of `extract_params()`.
#' When `car1 = TRUE`, a CAR(1) process is added to the mean
#' (`type = "epred"`) or CAR(1) residual error is added
#' (`type = "prediction"`).
#'
#' @param input A dataframe for which to generate model predictions.
#' @param object A `brms` model object.
#' @param type A single string. For `type = "epred"`, draws are from the
#' expectation of the posterior predictive; for `type = "prediction"`, draws
#' are from the posterior predictive.
#' @param car1 Logical. Add CAR(1) errors? Set to `FALSE` (with a message) when
#' the model's response variable is not a column of `input`.
#' @param draw_ids Draw IDs from model object. If NULL (the default), all draws are used.
#' @param ... Arguments passed on to `tidybayes::add_epred_draws()`
#'
#' @return A dataframe of the type generated by `tidybayes::add_epred_draws()`,
#' grouped by the columns of `input`. The `.draw` column is renamed to `.index`
#' and the `.chain` and `.iteration` columns are dropped.
#' @importFrom glue glue
#' @importFrom tidybayes add_epred_draws
#' @importFrom rlang .data
#' @importFrom data.table setorderv
#' @importFrom dplyr %>% group_by ungroup rename select left_join across mutate
#' @importFrom tidyselect matches
#' @importFrom stats na.omit
#' @export
#'
#' @examples
#' library("brms")
#' seed <- 1
#' data <- read.csv(paste0(system.file("extdata", package = "bgamcar1"), "/data.csv"))
#' fit <- fit_stan_model(
#' paste0(system.file("extdata", package = "bgamcar1"), "/test"),
#' seed,
#' bf(y | cens(ycens, y2 = y2) ~ 1),
#' data,
#' prior(normal(0, 1), class = Intercept),
#' car1 = FALSE,
#' save_warmup = FALSE,
#' chains = 3
#' )
#' add_pred_draws_car1(data, fit, car1 = FALSE, draw_ids = 1234)
add_pred_draws_car1 <- function(input,
                                object,
                                type = "epred",
                                car1 = TRUE,
                                draw_ids = NULL,
                                ...) {
  # Require a single valid string so the scalar `if` conditions below cannot
  # fail with an opaque "argument is of length zero" / "length > 1" error.
  stopifnot(
    "'type' must be either 'prediction' or 'epred'" =
      length(type) == 1L && type %in% c("epred", "prediction")
  )

  # Bind names used via non-standard evaluation (avoids R CMD check notes):
  .draw <- NULL
  .chain <- NULL
  .iteration <- NULL

  inputnames <- names(input) # output is grouped by these columns

  # extract variables from model object:
  varnames <- extract_resp(object) # responses, etc ...
  resp_present <- varnames$resp %in% inputnames
  if (!resp_present) {
    car1 <- FALSE
    message(glue::glue("{varnames$resp} not found in input. Setting car1 = FALSE."))
  }

  params <- extract_params(object, car1, draw_ids)
  if (is.null(draw_ids)) {
    # default to every draw available in the extracted parameters
    draw_ids <- seq_len(nrow(params))
  }

  # grouping and ordering variables for the CAR(1) error; entries that are NA
  # are dropped
  gr_vars <- c(".index", varnames$gr_ar) %>%
    na.omit()
  order_vars <- c(".index", varnames$gr_ar, varnames$time_ar) %>%
    na.omit()

  # generate predictions without AR term:
  preds <- tidybayes::add_epred_draws(
    input, object,
    incl_autocor = FALSE,
    draw_ids = draw_ids, dpar = TRUE,
    ...
  ) %>%
    ungroup() %>%
    rename(.index = .draw) %>%
    select(-c(.chain, .iteration)) %>%
    left_join(params, by = ".index")

  # add CAR(1) process to mean:
  if (type == "epred" && car1) {
    setorderv(preds, order_vars) # sorts `preds` by reference
    preds <- add_car1(preds, varnames$resp, gr_vars)
  }

  # add CAR(1) residual error:
  # NOTE(review): when type = "prediction" and car1 is FALSE, neither branch
  # runs and the add_epred_draws() output is returned as is -- confirm that
  # this is intended.
  if (type == "prediction" && car1) {
    setorderv(preds, order_vars) # sorts `preds` by reference
    preds <- add_car1_err(preds, car1, gr_vars)
  }

  # add grouping vars: select the input columns by literal name. A regex built
  # from the names would match case-insensitively (the `matches()` default) and
  # treat characters such as "." or "(" as metacharacters, so it could group by
  # unrelated columns or miss the intended ones.
  preds %>%
    group_by(across(tidyselect::any_of(inputnames)))
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.