# Nothing
# Copyright 2024 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#' @title Bayesian Logit Model Factory
#' @docType class
#' @export
#' @description A class for creating and managing Bayesian Logit Models
#' @field version imt package version used to fit the model
#' @field tau_draws Posterior draws for the treatment effect
#' @field mcmChecks MCMC diagnostics
#' @field credible_interval Credible interval for the treatment effect
#' @field prior_eta Prior distribution for eta
#' @field prior_tau Prior distribution for tau
#' @field prior_mean_y Prior distribution for mean y
#' @field eta_draws Posterior draws for eta
#' @field predict_list List of predictions
logit <- R6::R6Class(
classname = "logit",
# Private state of the logit object. Every field starts as NULL; the notes
# below say which public method of this class assigns it.
private = list(
# mcmcChecks object built for "tau" from the fitted model in initialize().
..mcmc_checks = NULL, # a mcmc_checks object
# Result of packageVersion("imt"), recorded in initialize().
..version = NULL,
# Object returned by rstan::sampling() in initialize() when fit = TRUE.
..stanfit = NULL,
# "tau" column of the fitted model's draws.
..tau_draws = NULL,
# Data list handed to Stan; the stored copy has run_estimation = 0.
..stan_data = NULL,
# NOTE(review): ..true_tau and ..true_eta are never assigned or read by any
# method visible in this class.
..true_tau = NULL,
..true_eta = NULL,
# Result of the most recent credibleInterval() call on this object.
..credible_interval = NULL,
# The tau_mean and tau_sd arguments given to initialize().
..tau_prior_mean = NULL,
..tau_prior_sd = NULL,
# "eta" and "alpha" columns of the fitted model's draws.
..eta_draws = NULL,
..alpha_draws = NULL,
# Data frame from as.data.frame(stanfit, pars = "beta").
..beta_draws = NULL,
# Draws taken with run_estimation = 0 (the "draw from the prior" step):
# first column for "eta" and "tau", the full matrix for "mean_y_sim".
..prior_eta = NULL,
..prior_tau = NULL,
..prior_mean_y = NULL,
# The x (covariate names) and treatment (column name) arguments of
# initialize(); predict() uses them to pick columns from new_data.
..var_cols = NULL,
..treatment = NULL,
# List of prediction names, appended to on every predict() call.
..predict_list = NULL,
# List of prediction matrices written by predict(), keyed by name.
..predictions = NULL
),
active = list(
  #' @description Package version recorded when the object was created
  version = function() {
    private$..version
  },
  #' @description Posterior draws for tau
  tau_draws = function() {
    private$..tau_draws
  },
  #' @description MCMC diagnostics object
  mcmChecks = function() {
    private$..mcmc_checks
  },
  #' @description Most recently computed credible interval
  credible_interval = function() {
    private$..credible_interval
  },
  #' @description Prior draws for eta
  prior_eta = function() {
    private$..prior_eta
  },
  #' @description Prior draws for tau
  prior_tau = function() {
    private$..prior_tau
  },
  #' @description Prior draws for mean y
  prior_mean_y = function() {
    private$..prior_mean_y
  },
  #' @description Posterior draws for eta
  eta_draws = function() {
    private$..eta_draws
  },
  #' @description Names of the predictions made so far
  predict_list = function() {
    private$..predict_list
  }
),
public = list(
#' @description
#' Create a new Bayesian Logit Model object.
#'
#' The model is first sampled with `run_estimation = 0` to obtain draws
#' from the prior, and then (when `fit = TRUE`) sampled again with
#' `run_estimation = 1` to fit it to the data.
#'
#' @param data Data frame to be used
#' @param y Name of the outcome variable in the data frame
#' @param x Vector of names of all covariates in the data frame
#' @param treatment Name of the treatment indicator variable in the data frame
#' @param mean_alpha Prior mean for alpha
#' @param sd_alpha Prior standard deviation for alpha
#' @param mean_beta Prior mean for beta
#' @param sd_beta Prior standard deviation for beta
#' @param tau_mean Prior mean for the treatment effect estimation
#' @param tau_sd Prior standard deviation for the treatment effect estimation
#' @param seed Seed passed to Stan for both the prior draws and the fit
#' @param fit Flag for fitting the data to the model or not
#' @param ... Additional arguments passed to `rstan::sampling()` when
#'   fitting the model (they are not used for the prior draws)
#' @return invisible
initialize = function(data, y, x, treatment,
                      mean_alpha,
                      sd_alpha,
                      mean_beta,
                      sd_beta,
                      tau_mean,
                      tau_sd,
                      seed = 1982,
                      fit = TRUE,
                      ...) {
  private$..version <- packageVersion("imt")
  private$..var_cols <- x
  private$..treatment <- treatment
  cleaned_data <- cleanData(
    data = data,
    y = y,
    treatment = treatment,
    x = x,
    binary = TRUE
  )
  stan_data <- list(
    N = cleaned_data$N,
    y = cleaned_data$Y,
    K = cleaned_data$K,
    X = cleaned_data$X,
    mean_alpha = mean_alpha,
    sd_alpha = sd_alpha,
    mean_beta = mean_beta,
    sd_beta = sd_beta,
    treat = cleaned_data$treat_vec,
    tau_mean = tau_mean,
    tau_sd = tau_sd,
    run_estimation = 0
  )
  private$..tau_prior_sd <- tau_sd
  private$..tau_prior_mean <- tau_mean
  # The stored copy keeps run_estimation = 0; only the local copy is
  # switched to estimation mode below.
  private$..stan_data <- stan_data
  # Draw from the prior. The seed is forwarded so that repeated
  # constructions with the same inputs give the same prior draws.
  sim_out <- rstan::sampling(stanmodels$logit,
    data = private$..stan_data,
    seed = seed
  )
  stan_data$run_estimation <- 1
  private$..prior_eta <- as.matrix(sim_out, pars = "eta")[, 1]
  private$..prior_tau <- as.matrix(sim_out, pars = "tau")[, 1]
  private$..prior_mean_y <- as.matrix(sim_out, pars = "mean_y_sim")
  # Fit model
  if (fit) {
    message("Fitting model to the data")
    # `seed` is a named formal of this method, so it can never also be
    # present in `...`; passing both is safe.
    private$..stanfit <- rstan::sampling(stanmodels$logit,
      data = stan_data, seed = seed, ...
    )
    private$..mcmc_checks <- mcmcChecks$new(
      fit = private$..stanfit,
      pars = "tau"
    )
    private$..tau_draws <-
      as.data.frame(private$..stanfit, pars = "tau") |>
      dplyr::pull(tau)
    private$..eta_draws <-
      as.data.frame(private$..stanfit, pars = "eta") |>
      dplyr::pull(eta)
    private$..alpha_draws <-
      as.data.frame(private$..stanfit, pars = "alpha") |>
      dplyr::pull(alpha)
    private$..beta_draws <-
      as.data.frame(private$..stanfit, pars = "beta")
  }
  invisible(NULL)
},
#' @description
#' Plot the MCMC trace for the tau parameter of the fitted model.
#' @param ... Additional arguments passed to `bayesplot::mcmc_trace()`
#' @return A ggplot object.
tracePlot = function(...) {
  # The Stan fit only exists when the object was created with fit = TRUE.
  if (is.null(private$..stanfit)) {
    stop("The model has not been fitted, so there is no trace to plot.")
  }
  bayesplot::mcmc_trace(private$..stanfit, pars = "tau", ...)
},
#' @description
#' Calculates the probability of the effect being greater than, less than,
#' or within a range defined by thresholds.
#'
#' The thresholds are expressed in percentage points, the same unit used
#' by the other reporting methods of this class.
#'
#' @param a Optional. Lower bound for the threshold, in percentage points.
#' @param b Optional. Upper bound for the threshold, in percentage points.
#' @param prior Logical. If TRUE, calculates probabilities based on
#' the prior distribution.
#' If FALSE (default), uses the posterior distribution.
#'
#' @return A character string summarizing the estimated probability
calcProb = function(a = 0, b = NULL, prior = FALSE) {
  # Input validation
  if (is.null(a) && is.null(b)) {
    stop("Either 'a' or 'b' must be provided.")
  }
  if (!is.null(a) && !is.null(b) && b <= a) {
    stop("'b' must be greater than 'a'.")
  }
  # No trailing space here: the templates below add the separating space.
  if (prior) {
    draws <- private$..prior_eta
    txt <- "Our prior is"
  } else {
    draws <- private$..eta_draws
    txt <- "Given the data, we estimate"
  }
  # The other methods of this class multiply eta by 100 before reporting
  # "percentage points"; do the same so `a` and `b` are compared in the
  # unit the returned sentence names.
  draws_pp <- draws * 100
  if (is.null(b)) {
    p <- scales::percent(mean(draws_pp > a))
    statement <- glue::glue(
      "{txt} that the ",
      "probability that the effect is more than {a}",
      " percentage points is {p}."
    )
  } else if (is.null(a)) {
    p <- scales::percent(mean(draws_pp < b))
    statement <- glue::glue(
      "{txt} that the probability that the",
      " effect is less than {b} percentage points is {p}."
    )
  } else { # both 'a' and 'b' are present
    p <- scales::percent(mean(draws_pp > a & draws_pp < b))
    statement <- glue::glue(
      "{txt} that the probability that the effect",
      " is between {a} and {b} percentage points is {p}."
    )
  }
  statement
},
#' @description
#' Point estimate of the effect, in percentage points.
#'
#' Summarises the posterior draws of eta held by the object with either
#' their median or their mean and multiplies the result by 100.
#'
#' @param median Logical value. If TRUE (default), the median of
#' the eta draws is used. If FALSE, the mean is used.
#'
#' @return A numeric value representing the point estimate.
pointEstimate = function(median = TRUE) {
  draws <- private$..eta_draws
  # `median` is both the flag and the summary function; R skips the
  # logical value when it looks up the function being called.
  centre <- if (median) median(draws) else mean(draws)
  centre * 100
},
#' @description
#' Credible interval for the effect of the intervention.
#'
#' Computes a credible interval of the requested width from the posterior
#' draws of eta, stores it on the object, and returns a sentence reporting
#' its bounds in percentage points.
#'
#' @param width Numeric value between 0 and 1 representing the desired
#' width of the credible interval (e.g., 0.95 for a 95% credible interval).
#' @param round Integer value indicating the number of decimal places to
#' round the lower and upper bounds of the credible interval.
#'
#' @return A character string stating the probability associated with
#' `width` and the rounded lower and upper bounds of the interval.
credibleInterval = function(width = 0.75, round = 0) {
  # The bare name resolves to the function of the same name outside this
  # class; the method itself is not in scope by that name.
  eta_interval <- credibleInterval(
    draws = private$..eta_draws, width
  )
  # Store first: the template below reads the bounds back from `private`.
  private$..credible_interval <- eta_interval
  glue::glue(
    "Given the data, we estimate that there is a ",
    "{scales::percent(width)} probability that the effect is between ",
    "{round(private$..credible_interval$lower_bound * 100, round)} and ",
    "{round(private$..credible_interval$upper_bound * 100, round)} ",
    "percentage points."
  )
},
#' @description
#' Plots the prior and posterior distributions of the impact.
#'
#' @param tau Logical. If TRUE, plot tau (unscaled) instead of eta
#'   (multiplied by 100).
#' @param ... other arguments passed to vizdraws.
#' @return An interactive plot of the prior and posterior distributions.
vizdraws = function(tau = FALSE, ...) {
  if (tau) {
    prior_draws <- private$..prior_tau
    posterior_draws <- private$..tau_draws
  } else {
    prior_draws <- private$..prior_eta * 100
    posterior_draws <- private$..eta_draws * 100
  }
  vizdraws::vizdraws(
    prior = prior_draws,
    posterior = posterior_draws, ...
  )
},
#' @description
#' Plots a lollipop chart of the prior and posterior probability that the
#' impact exceeds a threshold.
#'
#' For more details see [vizdraws::lollipops()].
#' @param threshold cutoff, in percentage points, used to calculate the
#'   probability. Defaults to zero.
#' @param ... other arguments passed to [vizdraws::lollipops()].
#' @return A lollipop chart with the prior and posterior probability of
#' the impact being above the threshold.
lollipop = function(threshold = 0, ...) {
  prior_prob <- mean(private$..prior_eta * 100 > threshold)
  posterior_prob <- mean(private$..eta_draws * 100 > threshold)
  probabilities <- data.frame(
    Name = "Impact",
    Prior = prior_prob,
    Posterior = posterior_prob
  )
  vizdraws::lollipops(probabilities, ...)
},
#' @description
#' Plots histograms of the prior draws of the impact (in percentage
#' points), of tau, and of the mean outcome, arranged in one figure titled
#' "Draws from prior distributions".
plotPrior = function() {
  outcome_draws <- private$..prior_mean_y
  impact_draws_pp <- private$..prior_eta * 100

  # Text shown next to the 90th percentile of each annotated histogram.
  impact_label <- paste0(
    "mean: ",
    round(mean(private$..prior_eta) * 100, 1)
  )
  outcome_label <- paste0(
    "mean: ",
    round(mean(outcome_draws * 100), 1)
  )

  impact_hist <- ggplot2::ggplot(
    data = tibble::tibble(draws = impact_draws_pp),
    ggplot2::aes(x = draws)
  ) +
    ggplot2::geom_histogram(bins = 30) +
    ggplot2::annotate("text",
      x = quantile(impact_draws_pp, probs = 0.9),
      y = 200, label = impact_label,
      hjust = 0
    ) +
    ggplot2::xlab("Impact in percentage points") +
    ggplot2::ylab("N draws") +
    ggplot2::theme_minimal()

  tau_hist <- ggplot2::ggplot(
    data = tibble::tibble(draws = private$..prior_tau),
    ggplot2::aes(x = draws)
  ) +
    ggplot2::geom_histogram(bins = 30) +
    ggplot2::xlab(expression(tau)) +
    ggplot2::ylab("N draws") +
    ggplot2::theme_minimal()

  outcome_hist <- ggplot2::ggplot(
    data = tibble::tibble(draws = outcome_draws),
    ggplot2::aes(x = draws)
  ) +
    ggplot2::geom_histogram(bins = 30) +
    ggplot2::scale_x_continuous(labels = scales::percent) +
    ggplot2::annotate("text",
      x = quantile(outcome_draws, probs = 0.9),
      y = 200, label = outcome_label,
      hjust = 0
    ) +
    ggplot2::xlab(expression(bar(y))) +
    ggplot2::ylab("N draws") +
    ggplot2::theme_minimal()

  # Impact on the top row; tau and the outcome share the bottom row.
  bottom_row <- ggpubr::ggarrange(tau_hist, outcome_hist,
    labels = c("tau", "Outcome"), ncol = 2
  )
  figure <- ggpubr::ggarrange(impact_hist,
    bottom_row,
    labels = "Impact", nrow = 2
  )
  ggpubr::annotate_figure(figure,
    top = ggpubr::text_grob("Draws from prior distributions",
      face = "bold", size = 14
    )
  )
},
#' @description
#' Predict new data
#'
#' Simulates outcomes for `new_data` from the posterior draws of the
#' fitted model and stores the result on the object under `name`.
#'
#' @param new_data Data frame to be predicted. Must contain the covariate
#'   and treatment columns used to fit the model.
#' @param name Group name of the prediction. When NULL a name of the form
#'   `pred<k>` is generated and a warning is issued.
#' @param M Number of posterior draws to use. Defaults to all draws; a
#'   larger value is reduced to the number of available draws with a
#'   warning. When smaller, `M` evenly spaced draws are used, so repeated
#'   calls with the same `M` use the same draws.
#' @param ... Additional arguments (not used)
#' @return invisible(self)
predict = function(new_data, name = NULL, M = NULL, ...) {
  # Posterior draws only exist when the object was created with fit = TRUE.
  n_draws <- length(private$..alpha_draws)
  if (n_draws == 0) {
    stop("The model has not been fitted, so there are no posterior draws.")
  }
  # Check new_data contains correct columns; report all missing ones in a
  # single readable message.
  cols <- c(private$..var_cols, private$..treatment)
  missing_cols <- unique(cols[!(cols %in% colnames(new_data))])
  if (length(missing_cols) > 0) {
    verb <- if (length(missing_cols) == 1) "is" else "are"
    stop(paste(missing_cols, collapse = ", "), " ", verb, " missing.")
  }
  # Check number of posterior draws is not out of range
  if (is.null(M)) {
    M <- n_draws
  } else if (!is.numeric(M) || length(M) != 1 || is.na(M) || M < 1) {
    stop("'M' must be a single positive number.")
  } else if (M > n_draws) {
    M <- n_draws
    warning(
      "Number of posterior draws can not exceed ", M, ". ",
      "Setting number of posterior draws to ", M, ". "
    )
  }
  M <- as.integer(M)
  # Deterministic, evenly spaced selection: two predictions made with the
  # same M stay aligned draw by draw, which predCompare() relies on when it
  # subtracts them column-wise.
  draw_idx <- if (M < n_draws) {
    as.integer(round(seq(1, n_draws, length.out = M)))
  } else {
    seq_len(n_draws)
  }
  # Create prediction list if NULL
  if (is.null(private$..predictions)) {
    private$..predictions <- list()
    private$..predict_list <- list()
  }
  # Create group name if NULL
  if (is.null(name)) {
    name <- paste0("pred", length(private$..predictions) + 1)
    warning(
      "No name was supplied, assigning predictions to ", name, "."
    )
  }
  # Transform data: centre and scale each covariate with the column means
  # and standard deviations of the X matrix stored at construction time.
  # drop = FALSE keeps a data frame even with a single covariate.
  X_pred <- new_data[, private$..var_cols, drop = FALSE]
  mean_X <- colMeans(private$..stan_data$X)
  sd_X <- apply(private$..stan_data$X, 2, sd)
  for (k in seq_along(private$..var_cols)) {
    X_pred[[k]] <- (X_pred[[k]] - mean_X[k]) / sd_X[k]
  }
  # Sample posterior predictives: one logitRng() call per selected draw.
  N <- nrow(X_pred)
  treat <- new_data[[private$..treatment]]
  y_sim <- purrr::pmap(
    .l = list(
      alpha = private$..alpha_draws[draw_idx],
      tau = private$..tau_draws[draw_idx],
      beta = purrr::array_branch(private$..beta_draws, 1)[draw_idx]
    ),
    .f = logitRng,
    X = X_pred, treat = treat, N = N
  )
  # Reshape to one row per observation and one column per draw.
  # NOTE(review): assumes logitRng() returns N values per draw.
  y_sim <- t(matrix(unlist(y_sim), ncol = N, byrow = TRUE))
  private$..predict_list[[length(private$..predict_list) + 1]] <- name
  private$..predictions[[name]] <- y_sim
  invisible(self)
},
#' @description
#' Get the posterior predictive draws stored under a name.
#' @param name Group name of the prediction. When NULL, the last entry of
#'   the prediction list is used.
#' @param ... Additional arguments (not used)
#' @return Matrix of posterior predictive draws
getPred = function(name = NULL, ...) {
  stored <- private$..predictions
  # Nothing has been predicted yet
  if (is.null(stored)) {
    stop("No predictions in the object.")
  }
  known_names <- private$..predict_list
  # Fall back to the most recently recorded name
  if (is.null(name)) {
    name <- known_names[[length(known_names)]]
  }
  # Unknown names are an error
  if (!(name %in% known_names)) {
    stop(name, " is not in the prediction list.")
  }
  stored[[name]]
},
#' @description
#' Get point estimate, credible interval and prob summary of predictive draws
#'
#' The summary is computed on the per-draw average of the selected rows of
#' the stored prediction matrix.
#'
#' @param name Optional. Group name of the prediction. Defaults to the last
#'   stored prediction.
#' @param subgroup Optional. A boolean vector to get summary on the
#'   conditional group average. Defaults to all rows.
#' @param median Optional. Logical value for using median or mean
#' @param width Optional. Numeric value for credible interval width
#' @param round Optional. Integer value for rounding
#' @param a Optional. Lower bound threshold
#' @param b Optional. Upper bound threshold
#' @param ... Additional arguments
#' @return A character string with summary information
predSummary = function(name = NULL,
                       subgroup = NULL,
                       median = TRUE,
                       width = 0.75,
                       round = 0,
                       a = NULL,
                       b = NULL,
                       ...) {
  # Check if predictions exists
  if (is.null(private$..predictions)) {
    stop("No predictions in the object.")
  }
  # If name is NULL, supply the last predicted item.
  if (is.null(name)) {
    name <- names(private$..predictions)[length(private$..predictions)]
  }
  # If name is not in the list, throw an error
  if (!(name %in% names(private$..predictions))) {
    stop(name, " is not in the prediction list.")
  }
  pred <- private$..predictions[[name]]
  # Validate subgroup
  if (!is.null(subgroup)) {
    validate_logical_vector(subgroup, nrow(pred))
  } else {
    subgroup <- rep(TRUE, nrow(pred))
  }
  # Get posterior draws of group average. drop = FALSE keeps a matrix when
  # only one row is selected; otherwise colMeans() fails on a plain vector.
  mean_y_draws <- colMeans(pred[subgroup, , drop = FALSE])
  # Get point estimate of posterior draws
  point_estimate_y <- pointEstimate(x = mean_y_draws, median = median)
  pe_statement <- glue::glue(
    "Given the data, we estimate that for group: {name}, ",
    "the point estimate of the group average is ",
    "{round(point_estimate_y, round)}%."
  )
  # Get credible interval
  credible_interval <- credibleInterval(mean_y_draws, width)
  ci_statement <- glue::glue(
    "With {scales::percent(width)} probability, ",
    "the point estimate is between ",
    "{round(credible_interval$lower_bound * 100, round)} and ",
    "{round(credible_interval$upper_bound * 100, round)} ",
    "percentage points."
  )
  # Calculate prob
  prob_statement <- calcProb(mean_y_draws, a, b,
    group_name = "group average"
  )
  paste(
    pe_statement,
    ci_statement,
    prob_statement
  )
},
#' @description
#' Compare the average of the posterior draws of two groups
#'
#' The comparison is made draw by draw: the per-draw average of the second
#' group is subtracted from the per-draw average of the first.
#'
#' @param name1 Group name of the first prediction to be compared
#' @param name2 Group name of the second prediction to be compared
#' @param subgroup1 Optional. A boolean vector for the first group
#' @param subgroup2 Optional. A boolean vector for the second group
#' @param median Optional. Logical value for using median or mean
#' @param width Optional. Numeric value for credible interval width
#' @param round Optional. Integer value for rounding
#' @param a Optional. Lower bound threshold
#' @param b Optional. Upper bound threshold
#' @param ... Additional arguments
#' @return A character string with comparison summary
predCompare = function(name1, name2,
                       subgroup1 = NULL,
                       subgroup2 = NULL,
                       median = TRUE,
                       width = 0.75,
                       round = 0,
                       a = NULL,
                       b = NULL,
                       ...) {
  # Check if predictions exists
  if (is.null(private$..predictions)) {
    stop("No predictions in the object.")
  }
  if (length(private$..predictions) < 2) {
    stop("Less than two groups were predicted.")
  }
  # Both group names are required; there is no default to fall back on.
  if (is.null(name1) || is.null(name2)) {
    stop("Group names need to be provided.")
  }
  # If name is not in the list, throw an error
  if (!(name1 %in% names(private$..predictions))) {
    stop(name1, " is not in the prediction list.")
  }
  if (!(name2 %in% names(private$..predictions))) {
    stop(name2, " is not in the prediction list.")
  }
  pred1 <- private$..predictions[[name1]]
  pred2 <- private$..predictions[[name2]]
  # Validate subgroup1
  if (!is.null(subgroup1)) {
    validate_logical_vector(subgroup1, nrow(pred1))
  } else {
    subgroup1 <- rep(TRUE, nrow(pred1))
  }
  # Validate subgroup2
  if (!is.null(subgroup2)) {
    validate_logical_vector(subgroup2, nrow(pred2))
  } else {
    subgroup2 <- rep(TRUE, nrow(pred2))
  }
  # Get posterior draws of each group average. drop = FALSE keeps a matrix
  # when only one row is selected; otherwise colMeans() fails on a vector.
  draws1 <- colMeans(pred1[subgroup1, , drop = FALSE])
  draws2 <- colMeans(pred2[subgroup2, , drop = FALSE])
  # Subtracting vectors of different lengths would silently recycle.
  if (length(draws1) != length(draws2)) {
    stop(
      name1, " and ", name2,
      " hold different numbers of posterior draws."
    )
  }
  eta_draws <- draws1 - draws2
  # Get point estimate of posterior draws
  point_estimate_eta <- pointEstimate(x = eta_draws, median = median)
  pe_statement <- glue::glue(
    "Given the data, we estimate that ",
    "the point estimate of the group difference is ",
    "{round(point_estimate_eta, round)}%."
  )
  # Get credible interval
  credible_interval <- credibleInterval(eta_draws, width)
  ci_statement <- glue::glue(
    "With {scales::percent(width)} probability, ",
    "the point estimate is between ",
    "{round(credible_interval$lower_bound * 100, round)} and ",
    "{round(credible_interval$upper_bound * 100, round)} ",
    "percentage points."
  )
  # Calculate prob
  prob_statement <- calcProb(eta_draws, a, b,
    group_name = "group difference"
  )
  paste(
    pe_statement,
    ci_statement,
    prob_statement
  )
}
)
)
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.