R/tidy-posterior-bartMachine.R

Defines functions residual_draws.bartMachine predicted_draws.bartMachine fitted_draws.bartMachine

Documented in fitted_draws.bartMachine predicted_draws.bartMachine residual_draws.bartMachine

#' Get fitted draws from posterior of \code{bartMachine} model
#'
#' Extracts the posterior draws of the sum-of-trees fit for each row of
#' \code{newdata} and returns them in long format: one row per observation per
#' posterior draw.
#'
#' @param model A \code{bartMachine} model.
#' @param newdata Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.
#' @param value The name of the output column for \code{fitted_draws}; default \code{".value"}.
#' @param n Not currently implemented. If supplied, must be a single positive number.
#' @param include_newdata Should the newdata be included in the tibble?
#' @param include_sigsqs Should the posterior sigma-squared draw be included?
#' @param ... Not currently in use.
#'
#' @return A tidy data frame (tibble) with fitted values. The columns of
#'   \code{newdata} (if included) come first, followed by \code{.row},
#'   \code{.chain}, \code{.iteration}, \code{.draw}, the \code{value} column
#'   and (optionally) \code{sigsq}. \code{.chain} and \code{.iteration} are
#'   always \code{NA}. The result is grouped by \code{.row} and any
#'   \code{newdata} columns.
#' @export
#'
fitted_draws.bartMachine <- function(model, newdata, value = ".value", ..., n = NULL, include_newdata = TRUE, include_sigsqs = FALSE) {
  if (missing(newdata)) newdata <- stats::model.matrix(model)

  stopifnot(
    is.data.frame(newdata),
    is.character(value),
    length(value) == 1L,
    # scalar logic: the elementwise form yielded logical(0) for NULL and
    # rejected plain doubles such as 5
    is.null(n) || (is.numeric(n) && length(n) == 1L && !is.na(n) && n > 0),
    is.logical(include_newdata),
    is.logical(include_sigsqs)
  )

  # draw-identifying columns, placed at the end of the output in this order
  col_order <- c(".row", ".chain", ".iteration", ".draw", value)

  posterior <- bartMachine::bart_machine_get_posterior(bart_machine = model, new_data = newdata)

  # wide format: optional newdata columns, one `.col_iter<k>` column per
  # posterior draw, and the row index
  out <- dplyr::bind_cols(
    if (include_newdata) dplyr::as_tibble(newdata) else NULL,
    dplyr::as_tibble(posterior$y_hat_posterior_samples, .name_repair = function(names) {
      paste0(".col_iter", seq_along(names))
    }),
    .row = seq_len(nrow(newdata))
  )

  # long format; gather() stacks the draw columns one after another
  out <- tidyr::gather(out, key = ".draw", value = !!value, dplyr::starts_with(".col_iter"))

  # add the columns expected by the generic draws format and turn the
  # `.col_iter<k>` key back into the integer draw index k
  out <- dplyr::mutate(
    out,
    .chain = NA_integer_,
    .iteration = NA_integer_,
    .draw = as.integer(sub(".col_iter", "", .data$.draw, fixed = TRUE))
  )

  # attach the sigma^2 draw matching each `.draw`, if requested
  if (include_sigsqs) {
    sigsq <- dplyr::tibble(
      .draw = seq_len(model$num_iterations_after_burn_in),
      sigsq = bartMachine::get_sigsqs(model)
    )

    out <- dplyr::left_join(out, sigsq, by = ".draw")

    col_order <- c(col_order, "sigsq")
  }

  # keep newdata columns first, then the draw-identifying columns
  out <- dplyr::select(out, -dplyr::all_of(col_order), dplyr::all_of(col_order))

  # group by `.row` plus every newdata column still present
  row_groups <- setdiff(names(out), setdiff(col_order, ".row"))

  # all_of(): a bare character vector inside across() is deprecated
  dplyr::group_by(out, dplyr::across(dplyr::all_of(row_groups)))
}


#' Get predicted draws from posterior of \code{bartMachine} model
#'
#' For every posterior draw of the fitted values, draws a prediction from a
#' normal distribution centred on the fitted value with the matching posterior
#' sigma-squared as variance.
#'
#' @param object A \code{bartMachine} model.
#' @param newdata Data frame to generate predictions from. If omitted, most model types will generate predictions from the data used to fit the model.
#' @param value The name of the output column for \code{predicted_draws}; default \code{".prediction"}. Must not be \code{".fit"} or \code{"sigsq"}, which are used internally.
#' @param ndraws Not currently implemented.
#' @param include_newdata Should the newdata be included in the tibble?
#' @param include_fitted Should the posterior fitted values be included in the tibble?
#' @param include_sigsqs Should the posterior sigma-squared draw be included?
#' @param ... Not currently in use.
#'
#' @return A tidy data frame (tibble) with predicted values.
#' @export
#'
predicted_draws.bartMachine <- function(object, newdata, value = ".prediction", ..., ndraws = NULL, include_newdata = TRUE, include_fitted = FALSE, include_sigsqs = FALSE) {
  stopifnot(
    is.character(value),
    # ".fit" and "sigsq" are intermediate columns; reusing either name would
    # overwrite it and possibly drop the prediction column below
    !any(value %in% c(".fit", "sigsq")),
    is.logical(include_fitted),
    is.logical(include_sigsqs)
  )

  # fitted values together with sigma^2, which the prediction step needs.
  # The method's first formal is `model`: passing `object = object` would be
  # captured by `...` and leave `model` missing.
  out <- fitted_draws.bartMachine(
    model = object,
    newdata = newdata,
    value = ".fit",
    include_newdata = include_newdata,
    include_sigsqs = TRUE
  )

  # draw each prediction from N(fitted value, sigma^2)
  out <- dplyr::mutate(out, !!value := stats::rnorm(n = dplyr::n(), mean = .data$.fit, sd = sqrt(.data$sigsq)))

  # drop helper columns unless requested; strings (not the `.data` pronoun)
  # are the supported way to name columns in tidyselect
  if (!include_sigsqs) out <- dplyr::select(out, -dplyr::all_of("sigsq"))

  if (!include_fitted) out <- dplyr::select(out, -dplyr::all_of(".fit"))

  out
}

#' Get residual draws for \code{bartMachine} model
#'
#' Computes, for every posterior draw, the observed training response minus
#' the fitted value. Observed responses are matched to fitted rows by row
#' index (\code{.row}), so residuals are only meaningful for the data the
#' model was trained on.
#'
#' @param object \code{bartMachine} model.
#' @param newdata Data frame to generate predictions from. If omitted, original data used to fit the model.
#' @param value Name of the output column for residual_draws; default is \code{.residual}.
#' @param ... Not currently in use.
#' @param include_newdata Should the newdata be included in the tibble?
#' @param include_sigsqs Should the posterior sigma-squared draw be included?
#' @param ndraws Not currently implemented.
#'
#' @return Tibble with residuals, including the observed response \code{y} and
#'   fitted value \code{.fitted}, grouped by \code{.row}.
#' @export
#'
residual_draws.bartMachine <- function(object, newdata, value = ".residual", ..., ndraws = NULL, include_newdata = TRUE, include_sigsqs = FALSE) {
  # the observed response exists only for the training rows; a differently
  # sized newdata would silently get NA or mismatched residuals
  if (!missing(newdata) && is.data.frame(newdata) && nrow(newdata) != object$n) {
    warning(
      "`newdata` has a different number of rows than the training data; ",
      "residuals are computed against the training response matched by row index.",
      call. = FALSE
    )
  }

  obs <- dplyr::tibble(y = object$y, .row = seq_len(object$n))

  # call the bartMachine method directly, as predicted_draws.bartMachine does
  fitted <- fitted_draws.bartMachine(object, newdata,
    value = ".fitted", n = NULL,
    include_newdata = include_newdata,
    include_sigsqs = include_sigsqs
  )

  out <- dplyr::left_join(fitted, obs, by = ".row")
  out <- dplyr::mutate(out, !!value := .data$y - .data$.fitted)

  dplyr::group_by(out, .data$.row)
}

Try the tidytreatment package in your browser

Any scripts or data that you put into this service are public.

tidytreatment documentation built on March 18, 2022, 6:30 p.m.