R/update_model.R

Defines functions concise capture_warnings update_model

Documented in update_model

#' Fit causal model using 'stan'
#'
#' Takes a model and data and returns a model object with data
#' attached and a posterior model
#'
#' @inheritParams CausalQueries_internal_inherit_params
#'
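#' @param data_type Either 'long' (as made by \code{\link{make_data}}) or
#'   'compact' (as made by \code{\link{collapse_data}}). If \code{NULL}
#'   (the default) the type is guessed: data containing both an \code{event}
#'   and a \code{count} column is treated as 'compact', anything else as
#'   'long'. Compact data must have entries for each member of each strategy
#'   family to produce a valid simplex. When long form data is provided with
#'   missingness, missing data is assumed to be missing at random.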
#' @param keep_type_distribution Logical. Whether to keep the (transformed) distribution
#'   of the causal types. Defaults to `TRUE`.
#' @param keep_event_probabilities Logical. Whether to keep the (transformed) distribution
#'   of event probabilities. Defaults to `FALSE`.
#' @param keep_fit Logical. Whether to keep the \code{stanfit} object produced
#'   by \link[rstan]{sampling} for further inspection.
#'   See \code{?stanfit} for more details. Defaults to `FALSE`. Note that the
#'   \code{stanfit} object uses internal names for parameters (lambdas), event
#'   probabilities (w), and the type distribution (types).
#' @param censored_types vector of data types that are selected out of
#'   the data, e.g. \code{c("X0Y0")}
#' @param ... Options passed on to the \link[rstan]{sampling} call. For
#'   details see \code{?rstan::sampling}.
#'
#' @return An object of class \code{causal_model} with posterior distribution on
#'  parameters and other elements generated by updating; all elements accessible
#'  via \code{get} and \code{inspect}.
#'
#' @seealso \code{\link{make_model}} to create a new model,
#'   \code{\link{summary.causal_model}} provides a summary method for
#'   output objects of class \code{causal_model}
#'
#' @examples
#'  model <- make_model('X->Y')
#'  data_long   <- make_data(model, n = 4)
#'  data_short  <- collapse_data(data_long, model)
#'  model <-  update_model(model, data_long)
#'  model <-  update_model(model, data_short)
#'
#'    # It is possible to implement updating without data, in which
#'    # case the posterior is a stan object that reflects the prior
#'
#'    update_model(model)
#'
#'  \dontrun{
#'
#'    # Censored data types illustrations
#'    # Here we update less than we might because we are aware of filtered data
#'
#'    data <- data.frame(X=rep(0:1, 10), Y=rep(0:1,10))
#'    uncensored <-
#'      make_model("X->Y") |>
#'      update_model(data) |>
#'      query_model(te("X", "Y"), using = "posteriors")
#'
#'    censored <-
#'      make_model("X->Y") |>
#'      update_model(
#'        data,
#'        censored_types = c("X1Y0")) |>
#'      query_model(te("X", "Y"), using = "posteriors")
#'
#'
#'    # Censored data: We learn nothing because the data
#'    # we see is the only data we could ever see
#'    make_model("X->Y") |>
#'      update_model(
#'        data,
#'        censored_types = c("X1Y0", "X0Y0", "X0Y1")) |>
#'      query_model(te("X", "Y"), using = "posteriors")
#'  }
#'
#' @import methods
#' @import Rcpp
#' @import rstantools
#' @importFrom rstan stan
#' @importFrom rstan extract
#' @importFrom rstan sampling
#'
#' @export

update_model <- function(model,
                         data = NULL,
                         data_type = NULL,
                         keep_type_distribution = TRUE,
                         keep_event_probabilities = FALSE,
                         keep_fit = FALSE,
                         censored_types = NULL,
                         ...) {
  # Guess data_type when it is not supplied: data carrying both an `event`
  # and a `count` column is treated as compact, anything else (including
  # NULL data) as long.
  if (is.null(data_type)) {
    data_type <-
      if (all(c("event", "count") %in% names(data))) "compact" else "long"
  }

  # Fail early with a clear message. Without this check an unsupported
  # data_type skips both branches below and only surfaces later as
  # "object 'data_events' not found".
  if (length(data_type) != 1 || !data_type %in% c("long", "compact")) {
    stop("`data_type` must be either 'long' or 'compact'")
  }

  # Bring the data into compact (event / count) form
  if (data_type == "long") {
    no_data <- is.null(data) || nrow(data) == 0 || all(is.na(data))

    if (no_data) {
      # Nothing to update on: fall back to the model's minimal event data
      message("No data provided")
      data_events <- minimal_event_data(model)
    } else {
      # Only stop when *none* of the model nodes appears in the data
      if (!any(model$nodes %in% names(data))) {
        stop("Data should contain columns corresponding to model nodes")
      }

      data_events <- collapse_data(data, model)
    }
  } else {
    # data_type == "compact"
    if (!all(c("event", "count") %in% names(data))) {
      stop(paste(
        "Compact data should contain columns",
        "`event` and `count`"
      ))
    }

    if (!is.integer(data$count)) {
      data$count <- as.integer(data$count)
      warning("count column should be integer valued; value has been forced to integer")
    }

    data_events <- data
  }

  stan_data <- prep_stan_data(
    model = model,
    data = data_events,
    keep_type_distribution = keep_type_distribution,
    censored_types = censored_types
  )

  # Precompiled stan model used for sampling
  stanfit <- stanmodels$simplexes

  # Parameters to drop from the fit; `w` and `types` are dropped unless the
  # caller asked to keep event probabilities / the type distribution
  drop_pars <- c("parlam", "parlam2", "gamma", "sum_gammas", "w_full", "w_0")

  if (!keep_event_probabilities) {
    drop_pars <- c(drop_pars, "w")
  }

  if (!keep_type_distribution) {
    drop_pars <- c(drop_pars, "types")
  }

  # `pars` together with `include = FALSE` names the parameters to exclude
  sampling_args <- set_sampling_args(
    object = stanfit,
    user_dots = list(...),
    data = stan_data,
    pars = drop_pars,
    include = FALSE
  )

  # newfit is list(fit = <sampling result>, warnings = <collapsed warnings>)
  newfit <- do.call(rstan::sampling, sampling_args) |> capture_warnings()

  model$stan_objects <- list(
    data = stan_data$data,
    stan_warnings = newfit$warnings
  )

  # Keep full fit object
  if (keep_fit) {
    model$stan_objects$stanfit <- newfit$fit
  }

  # Retain posterior distribution, labelled with the model's parameter names
  model$posterior_distribution <-
    extract(newfit$fit, pars = "lambdas")$lambdas |>
    as.data.frame()
  colnames(model$posterior_distribution) <- get_parameter_names(model)

  # Retain type distribution
  if (keep_type_distribution) {
    model$stan_objects$type_posterior <-
      extract(newfit$fit, pars = "types")$types
    colnames(model$stan_objects$type_posterior) <- colnames(stan_data$P)
  }

  # Retain event (pre-censoring) probabilities
  if (keep_event_probabilities) {
    model$stan_objects$event_probabilities <-
      extract(newfit$fit, pars = "w")$w
    colnames(model$stan_objects$event_probabilities) <- colnames(stan_data$E)
  }

  # Retain stanfit summary with readable names.
  # Readable names of the saved quantities, in the order: parameters,
  # event probabilities (if kept), causal types (if kept), lp__
  params <- colnames(model$posterior_distribution)

  if (keep_event_probabilities) {
    params <- c(params, colnames(model$stan_objects$event_probabilities))
  }

  if (keep_type_distribution) {
    params <- c(params, colnames(model$stan_objects$type_posterior))
  }

  params <- c(params, "lp__")

  # Internal names used by the stanfit object
  params_labels <- newfit$fit@sim$fnames_oi

  # Right-pad both name sets to one common width so that every fixed-string
  # replacement swaps text of equal length. The width is computed once
  # rather than once per name.
  pad_width <- max(nchar(c(params, params_labels)))
  pad_right <- function(x) {
    paste0(x, strrep(" ", pad_width - nchar(x)))
  }
  readable_padded <- pad_right(params)
  internal_padded <- pad_right(params_labels)

  stan_summary <- utils::capture.output(print(newfit$fit))

  # Replace internal names position by position. Bounding the loop by the
  # shorter vector avoids an NA pattern (which would turn the whole summary
  # into NA) should the two name vectors ever differ in length.
  n_rename <- min(length(readable_padded), length(internal_padded))
  for (i in seq_len(n_rename)) {
    stan_summary <- gsub(
      pattern = internal_padded[i],
      replacement = readable_padded[i],
      x = stan_summary,
      fixed = TRUE
    )
  }

  model$stan_objects$stan_summary <- stan_summary

  model
}


# Evaluate `expr` while recording the message of every warning it signals.
#
# `expr` is forced inside withCallingHandlers(), so warnings raised during
# its (lazy) evaluation are seen by the handler. The handler only records
# the message; it does not muffle the warning.
#
# Returns a list with `fit` (the value of `expr`) and `warnings` (the
# recorded messages after being passed through concise()).
capture_warnings <- function(expr) {
  seen <- character()

  # Append the message of each signalled warning to `seen`
  record <- function(cond) {
    seen <<- c(seen, conditionMessage(cond))
  }

  result <- withCallingHandlers(expr, warning = record)

  list(fit = result, warnings = concise(seen))
}


# Shorten known verbose warnings and merge all messages into one string.
#
# A message containing "not mixed" is cut off right after that phrase;
# otherwise a message containing "too low" is cut off right after that
# phrase; any other message is kept unchanged.
#
# `warnings` is a character vector of warning messages. The result is a
# single character string with the messages joined by "\n" ("" when there
# are none).
concise <- function(warnings) {
  shorten <- function(msg) {
    if (grepl("not mixed", msg)) {
      sub("not mixed.*", "not mixed", msg)
    } else if (grepl("too low", msg)) {
      sub("too low.*", "too low", msg)
    } else {
      msg
    }
  }

  # vapply keeps the result a character vector even for empty input
  shortened <- vapply(warnings, shorten, character(1), USE.NAMES = FALSE)

  paste(shortened, collapse = "\n")
}

Try the CausalQueries package in your browser

Any scripts or data that you put into this service are public.

CausalQueries documentation built on April 3, 2025, 7:46 p.m.