Nothing
#' Fit causal model using 'stan'
#'
#' Takes a model and data and returns a model object with data
#' attached and a posterior model
#'
#' @inheritParams CausalQueries_internal_inherit_params
#'
#' @param data_type Either 'long' (as made by \code{\link{make_data}}) or
#' 'compact' (as made by \code{\link{collapse_data}}). Compact data must
#' have entries for each member of each strategy family to produce a
#' valid simplex. When long form data is provided with missingness, missing
#' data is assumed to be missing at random.
#' @param keep_type_distribution Logical. Whether to keep the (transformed) distribution
#' of the causal types. Defaults to `TRUE`
#' @param keep_event_probabilities Logical. Whether to keep the (transformed) distribution
#' of event probabilities. Defaults to `FALSE`
#' @param keep_fit Logical. Whether to keep the \code{stanfit} object produced
#' by \link[rstan]{sampling} for further inspection.
#' See \code{?stanfit} for more details. Defaults to `FALSE`. Note the \code{stanfit}
#' object has internal names for parameters (lambdas), event probabilities (w), and the
#' type distribution (types).
#' @param censored_types Vector of data types that are censored, i.e. selected
#' out of the data, e.g. \code{c("X0Y0")}. Defaults to `NULL`.
#' @param ... Options passed onto \link[rstan]{sampling} call. For
#' details see \code{?rstan::sampling}
#'
#' @return An object of class \code{causal_model} with posterior distribution on
#' parameters and other elements generated by updating; all elements accessible
#' via \code{get} and \code{inspect}.
#'
#' @seealso \code{\link{make_model}} to create a new model,
#' \code{\link{summary.causal_model}} provides a summary method for
#' output objects of class \code{causal_model}
#'
#' @examples
#' model <- make_model('X->Y')
#' data_long <- make_data(model, n = 4)
#' data_short <- collapse_data(data_long, model)
#' model <- update_model(model, data_long)
#' model <- update_model(model, data_short)
#'
#' # It is possible to implement updating without data, in which
#' # case the posterior is a stan object that reflects the prior
#'
#' update_model(model)
#'
#' \dontrun{
#'
#' # Censored data types illustrations
#' # Here we update less than we might because we are aware of filtered data
#'
#' data <- data.frame(X=rep(0:1, 10), Y=rep(0:1,10))
#' uncensored <-
#' make_model("X->Y") |>
#' update_model(data) |>
#' query_model(te("X", "Y"), using = "posteriors")
#'
#' censored <-
#' make_model("X->Y") |>
#' update_model(
#' data,
#' censored_types = c("X1Y0")) |>
#' query_model(te("X", "Y"), using = "posteriors")
#'
#'
#' # Censored data: We learn nothing because the data
#' # we see is the only data we could ever see
#' make_model("X->Y") |>
#' update_model(
#' data,
#' censored_types = c("X1Y0", "X0Y0", "X0Y1")) |>
#' query_model(te("X", "Y"), using = "posteriors")
#' }
#'
#' @import methods
#' @import Rcpp
#' @import rstantools
#' @importFrom rstan stan
#' @importFrom rstan extract
#' @importFrom rstan sampling
#'
#' @export
update_model <- function(model,
                         data = NULL,
                         data_type = NULL,
                         keep_type_distribution = TRUE,
                         keep_event_probabilities = FALSE,
                         keep_fit = FALSE,
                         censored_types = NULL,
                         ...) {
  # Guess data_type from the columns present when it is not supplied:
  # data carrying both `event` and `count` columns is treated as compact.
  if (is.null(data_type)) {
    data_type <- if (all(c("event", "count") %in% names(data))) {
      "compact"
    } else {
      "long"
    }
  }

  # Fail early on an unsupported data_type; otherwise neither branch below
  # would define `data_events` and the failure would surface much later.
  if (!is.character(data_type) ||
      length(data_type) != 1L ||
      !data_type %in% c("long", "compact")) {
    stop("`data_type` should be either 'long' or 'compact'")
  }

  if (data_type == "long") {
    # Missing, empty, or all-NA data are all treated as "no data"
    no_data <- is.null(data) || nrow(data) == 0 || all(is.na(data))
    if (no_data) {
      message("No data provided")
      data_events <- minimal_event_data(model)
    } else {
      if (!any(model$nodes %in% names(data))) {
        stop("Data should contain columns corresponding to model nodes")
      }
      data_events <- collapse_data(data, model)
    }
  } else {
    # Compact data
    if (!all(c("event", "count") %in% names(data))) {
      stop(paste(
        "Compact data should contain columns",
        "`event` and `count`"
      ))
    }
    if (!is.integer(data$count)) {
      data$count <- as.integer(data$count)
      warning("count column should be integer valued; value has been forced to integer")
    }
    data_events <- data
  }

  stan_data <- prep_stan_data(
    model = model,
    data = data_events,
    keep_type_distribution = keep_type_distribution,
    censored_types = censored_types
  )

  # Stan model object handed to rstan::sampling
  stanfit <- stanmodels$simplexes

  # Parameters excluded from the saved draws (passed with include = FALSE)
  drop_pars <- c("parlam", "parlam2", "gamma", "sum_gammas", "w_full", "w_0")
  if (!keep_event_probabilities) {
    drop_pars <- c(drop_pars, "w")
  }
  if (!keep_type_distribution) {
    drop_pars <- c(drop_pars, "types")
  }

  sampling_args <- set_sampling_args(
    object = stanfit,
    user_dots = list(...),
    data = stan_data,
    pars = drop_pars,
    include = FALSE
  )

  # Run the sampler, recording any warnings it raises alongside the fit
  newfit <- capture_warnings(do.call(rstan::sampling, sampling_args))

  model$stan_objects <- list(
    data = stan_data$data,
    stan_warnings = newfit$warnings
  )

  # Keep full fit object
  if (keep_fit) {
    model$stan_objects$stanfit <- newfit$fit
  }

  # Retain posterior distribution, labelled with the model's parameter names
  model$posterior_distribution <-
    as.data.frame(extract(newfit$fit, pars = "lambdas")$lambdas)
  colnames(model$posterior_distribution) <- get_parameter_names(model)

  # Retain type distribution
  if (keep_type_distribution) {
    model$stan_objects$type_posterior <-
      extract(newfit$fit, pars = "types")$types
    colnames(model$stan_objects$type_posterior) <- colnames(stan_data$P)
  }

  # Retain event (pre-censoring) probabilities
  if (keep_event_probabilities) {
    model$stan_objects$event_probabilities <-
      extract(newfit$fit, pars = "w")$w
    colnames(model$stan_objects$event_probabilities) <- colnames(stan_data$E)
  }

  # Retain stanfit summary with readable names.
  # Readable names of the saved quantities, in the order they are matched
  # against the fit's internal names below.
  # NOTE(review): assumes this order agrees with `fnames_oi` -- confirm.
  params <- colnames(model$posterior_distribution)
  if (keep_event_probabilities) {
    params <- c(params, colnames(model$stan_objects$event_probabilities))
  }
  if (keep_type_distribution) {
    params <- c(params, colnames(model$stan_objects$type_posterior))
  }
  params <- c(params, "lp__")
  params_labels <- newfit$fit@sim$fnames_oi

  # Right-pad every label to one common width (computed once) so that
  # replacing an internal name by a readable one keeps line lengths equal.
  label_width <- max(nchar(c(params, params_labels)))
  pad_right <- function(x) {
    paste0(x, strrep(" ", label_width - nchar(x)))
  }
  readable_names <- pad_right(params)
  internal_names <- pad_right(params_labels)

  stan_summary <- utils::capture.output(print(newfit$fit))
  # Only iterate over pairs that exist on both sides: an NA pattern would
  # turn every line of the summary into NA.
  n_pairs <- min(length(readable_names), length(internal_names))
  for (i in seq_len(n_pairs)) {
    stan_summary <- gsub(
      pattern = internal_names[i],
      replacement = readable_names[i],
      x = stan_summary,
      fixed = TRUE
    )
  }
  model$stan_objects$stan_summary <- stan_summary

  model
}
# Evaluate `expr`, recording the message of every warning raised meanwhile.
#
# Returns a list with two elements: `fit`, the value of `expr`, and
# `warnings`, the recorded messages after processing by `concise()`.
# The handler only records each warning; it does not muffle it.
capture_warnings <- function(expr) {
  collected <- character()

  # Calling handler: append the message to the accumulator in the
  # enclosing function environment.
  record_warning <- function(w) {
    collected <<- c(collected, conditionMessage(w))
  }

  # `expr` is forced here, inside the handler's scope
  result <- withCallingHandlers(expr, warning = record_warning)

  list(fit = result, warnings = concise(collected))
}
# Shorten verbose warning messages and join them into a single string.
#
# For each message: if it contains "not mixed", everything after that phrase
# is dropped; otherwise, if it contains "too low", everything after that
# phrase is dropped; otherwise the message is kept unchanged.
#
# @param warnings Character vector of warning messages (may be empty).
# @return A length-one character string with the processed messages
#   separated by newlines; "" when `warnings` is empty.
concise <- function(warnings) {
  shortened <- vapply(
    X = warnings,
    FUN = function(msg) {
      if (grepl("not mixed", msg)) {
        sub("not mixed.*", "not mixed", msg)
      } else if (grepl("too low", msg)) {
        sub("too low.*", "too low", msg)
      } else {
        msg
      }
    },
    FUN.VALUE = character(1),
    USE.NAMES = FALSE
  )
  # Collapse to one string, one warning per line
  paste(shortened, collapse = "\n")
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.