R/add_pred_draws_car1.R

Defines functions add_pred_draws_car1

Documented in add_pred_draws_car1

#' Generate predictions for CAR(1) model
#'
#' Draws are first generated with `tidybayes::add_epred_draws()` with the
#' autocorrelation term excluded (`incl_autocor = FALSE`) and joined to the
#' model parameters. If `car1 = TRUE`, the CAR(1) component is then added:
#' to the mean for `type = "epred"`, or as residual error for
#' `type = "prediction"`. If `car1 = FALSE`, the draws are returned without
#' any CAR(1) component.
#'
#' If the response variable of `object` is not a column of `input`, `car1` is
#' set to `FALSE` and a message is emitted.
#'
#' @param input  A dataframe for which to generate model predictions.
#' @param object A `brms` model object.
#' @param type A single string. For `type = "epred"`, draws are from the
#' expectation of the posterior predictive; for `type = "prediction"`, draws
#' are from the posterior predictive.
#' @param car1 Logical. Add CAR(1) errors?
#' @param draw_ids Draw IDs from model object. If NULL (the default), all draws are used.
#' @param ... Arguments passed on to `tidybayes::add_epred_draws()`
#'
#' @return A dataframe of the type generated by `tidybayes::add_epred_draws()`,
#' grouped by the columns of `input`.
#' @importFrom glue glue
#' @importFrom tidybayes add_epred_draws
#' @importFrom rlang .data
#' @importFrom data.table setorderv
#' @importFrom dplyr %>% group_by ungroup rename select left_join across mutate
#' @importFrom tidyselect matches
#' @importFrom stats na.omit
#' @export
#'
#' @examples
#' library("brms")
#' seed <- 1
#' data <- read.csv(paste0(system.file("extdata", package = "bgamcar1"), "/data.csv"))
#' fit <- fit_stan_model(
#'    paste0(system.file("extdata", package = "bgamcar1"), "/test"),
#'    seed,
#'    bf(y | cens(ycens, y2 = y2) ~ 1),
#'    data,
#'    prior(normal(0, 1), class = Intercept),
#'    car1 = FALSE,
#'    save_warmup = FALSE,
#'    chains = 3
#'  )
#' add_pred_draws_car1(data, fit, car1 = FALSE, draw_ids = 1234)
add_pred_draws_car1 <- function(input,
                                object,
                                type = "epred",
                                car1 = TRUE,
                                draw_ids = NULL,
                                ...) {
  # Require a single string: a NULL or length > 1 `type` would otherwise slip
  # through `%in%` and only fail later in the scalar `if` conditions.
  stopifnot(
    "'type' must be either 'prediction' or 'epred'" =
      is.character(type) && length(type) == 1L &&
        type %in% c("epred", "prediction")
  )

  # Local bindings for column names referenced unquoted in the dplyr calls
  # below (presumably to avoid R CMD check notes about undefined globals).
  .draw <- NULL
  .chain <- NULL
  .iteration <- NULL

  inputnames <- names(input)

  # Regex matching exactly the column names of `input`; used at the end to
  # group the output by those columns.
  data_vars <- glue::glue("^{inputnames}$") %>%
    paste(collapse = "|")

  # extract variables from model object:
  varnames <- extract_resp(object) # responses, etc ...

  # The CAR(1) step is skipped when the response column is absent from `input`.
  resp_present <- varnames$resp %in% inputnames

  if (!resp_present) {
    car1 <- FALSE
    message(glue::glue("{varnames$resp} not found in input. Setting car1 = FALSE."))
  }

  params <- extract_params(object, car1, draw_ids)

  # Default to every draw available in the extracted parameters.
  if (is.null(draw_ids)) {
    draw_ids <- seq_len(nrow(params))
  }

  # Grouping and ordering variables for the CAR(1) error; NA entries (unused
  # grouping/time variables) are dropped.
  gr_vars <- c(".index", varnames$gr_ar) %>%
    na.omit()

  order_vars <- c(".index", varnames$gr_ar, varnames$time_ar) %>%
    na.omit()

  # generate predictions without AR term, keyed by draw (`.index`) and joined
  # to the per-draw parameters:
  preds <- tidybayes::add_epred_draws(
    input, object,
    incl_autocor = FALSE,
    draw_ids = draw_ids, dpar = TRUE,
    ...
  ) %>%
    ungroup() %>%
    rename(.index = .draw) %>%
    select(-c(.chain, .iteration)) %>%
    left_join(params, by = ".index")

  if (car1) {
    # Sort (by reference) within draw and group before applying the CAR(1)
    # step.
    setorderv(preds, order_vars)
    preds <- if (type == "epred") {
      # add CAR(1) process to mean:
      add_car1(preds, varnames$resp, gr_vars)
    } else {
      # add CAR(1) residual error:
      add_car1_err(preds, car1, gr_vars)
    }
  }

  # add grouping vars:
  preds %>%
    group_by(across(matches(data_vars)))
}
bentrueman/bgamcar1 documentation built on July 6, 2024, 11:16 p.m.