Nothing
#' @title Plot method for msm objects
#' @description Generate histogram of weights and weighted Kaplan-Meier (KM)
#' plot for censored outcomes of an msm object.
#'
#' @param x An object of class \code{msm}.
#' @param time_unit The time unit used in the input data.
#' Options are "day" (default), "week", "month", or "year".
#' @param show_hr Logical; whether to show hazard ratio on the KM plot.
#' Default is TRUE.
#' @param show_risk Logical; whether to show number at risk table
#' below the KM plot. Default is TRUE.
#' @param ... Ensures that all arguments starting from "..." are named.
#'
#' @return A list of two ggplot2 objects, one for histogram and the other
#' for KM plot.
#'
#' @keywords internal
#'
#' @author Kaifeng Lu, \email{kaifenglu@@gmail.com}
#'
#' @method plot msm
#' @export
plot.msm <- function(x, time_unit = "day",
show_hr = TRUE, show_risk = TRUE, ...) {
if (!inherits(x, "msm")) {
stop("x must be of class 'msm'")
}
alpha <- x$settings$alpha
conflev <- 100*(1-alpha)
treat_var <- x$settings$treat
if (!is.null(x$data_outcome) && nrow(x$data_outcome) > 0) {
# --- weight plot ---
df <- x$data_outcome
arm <- x$settings$data[[treat_var]]
df_arm <- data.frame(arm = c(df[[treat_var]][df$treated == 0][1],
df[[treat_var]][df$treated == 1][1]))
if (is.factor(arm)) {
df_arm$arm <- factor(df_arm$arm, labels = levels(arm))
} else if (is.numeric(arm) && all(arm %in% c(0, 1))) {
df_arm$arm <- factor(df_arm$arm, levels = c(1, 0),
labels = c("Treatment", "Control"))
} else {
df_arm$arm <- factor(df_arm$arm)
}
if (x$settings$swtrt_control_only) {
if (x$settings$stabilized_weights) {
p_w <- ggplot2::ggplot(df[df$treated == 0, ],
ggplot2::aes(x = .data$stabilized_weight)) +
ggplot2::geom_histogram(fill = "#77bd89", color = "#1f6e34",
bins = 30, alpha = 0.8) +
ggplot2::scale_x_continuous("Stabilized Weights") +
ggplot2::labs(title = df_arm$arm[[1]]) +
ggplot2::theme_bw()
} else {
p_w <- ggplot2::ggplot(df[df$treated == 0, ],
ggplot2::aes(x = .data$unstabilized_weight)) +
ggplot2::geom_histogram(fill = "#77bd89", color = "#1f6e34",
bins = 30, alpha = 0.8) +
ggplot2::scale_x_continuous("Unstabilized Weights") +
ggplot2::labs(title = df_arm$arm[[1]]) +
ggplot2::theme_bw()
}
} else {
if (x$settings$stabilized_weights) {
p_w <- ggplot2::ggplot(df, ggplot2::aes(x = .data$stabilized_weight)) +
ggplot2::geom_histogram(fill="#77bd89", color="#1f6e34",
bins = 30, alpha=0.8) +
ggplot2::scale_x_continuous("Stabilized Weights") +
ggplot2::facet_wrap(~ .data[[treat_var]]) +
ggplot2::theme_bw()
} else {
p_w <- ggplot2::ggplot(df, ggplot2::aes(x = .data$unstabilized_weight)) +
ggplot2::geom_histogram(fill="#77bd89", color="#1f6e34",
bins = 30, alpha=0.8) +
ggplot2::scale_x_continuous("Unstabilized Weights") +
ggplot2::facet_wrap(~ .data[[treat_var]]) +
ggplot2::theme_bw()
}
}
# --- Kaplan-Meier plot for counterfactual outcomes ---
df <- x$km_outcome
if (time_unit == "day") {
df$month <- df$time / 30.4375
} else if (time_unit == "week") {
df$month <- df$time / 4.3482
} else if (time_unit == "month") {
df$month <- df$time
} else if (time_unit == "year") {
df$month <- df$time * 12
} else {
stop("time_unit must be one of 'day', 'week', 'month', or 'year'")
}
if (!is.factor(df[[treat_var]])) {
if (is.numeric(df[[treat_var]]) && all(df[[treat_var]] %in% c(0, 1))) {
df[[treat_var]] <- factor(df[[treat_var]], levels = c(1, 0),
labels = c("Treatment", "Control"))
} else {
df[[treat_var]] <- factor(df[[treat_var]])
}
}
min_surv <- data.table::data.table(df)[, min(get("surv")),
by = treat_var][, get("V1")]
p_km <- ggplot2::ggplot(
df, ggplot2::aes(x = .data$month, y = .data$surv,
group = .data[[treat_var]], colour = .data[[treat_var]])) +
ggplot2::geom_step() +
ggplot2::scale_x_continuous(n.breaks = 11) +
ggplot2::scale_y_continuous(limits = c(0, 1)) +
ggplot2::labs(
x = "Months", y = "Survival Probability",
title = "Weighted Kaplan-Meier Curves for Censored Outcomes") +
ggplot2::theme_bw() +
ggplot2::theme(
plot.title = ggplot2::element_text(hjust = 0.5),
legend.title = ggplot2::element_blank(),
panel.grid.minor.x = ggplot2::element_blank(),
plot.margin = ggplot2::margin(t = 2, r = 5, b = 0, l = 20))
if (max(min_surv) < 0.5) {
p_km <- p_km + ggplot2::theme(legend.position = c(0.7, 0.85))
} else{
p_km <- p_km + ggplot2::theme(legend.position = c(0.15, 0.25))
}
# add hazard ratio to plot
if (show_hr) {
if (max(min_surv) < 0.5) {
p_km <- p_km +
ggplot2::annotate(
"text", x = 0.6*max(df$month), y = 0.7, hjust = 0,
label = sprintf("HR = %.3f (%.0f%% CI: %.3f, %.3f)",
x$hr, conflev, x$hr_CI[1], x$hr_CI[2]),
size = 3.5, color = "black"
)
} else {
p_km <- p_km +
ggplot2::annotate(
"text", x = 0, y = 0, hjust = 0,
label = sprintf("HR = %.3f (%.0f%% CI: %.3f, %.3f)",
x$hr, conflev, x$hr_CI[1], x$hr_CI[2]),
size = 3.5, color = "black"
)
}
}
# add number at risk table below plot
if (show_risk) {
xbreaks <- ggplot2::ggplot_build(p_km)$layout$panel_params[[1]]$x$breaks
xbreaks <- xbreaks[!is.na(xbreaks)]
limits <- c(min(xbreaks), max(max(xbreaks), max(df$month)))
# --- Compute number at risk at or before each tick ---
tablist <- lapply(0:1, function(h) {
t <- df$month[df$treated == h]
n <- df$nrisk[df$treated == h]
idx <- findInterval(xbreaks, t) + 1
atrisk <- ifelse(idx <= length(n), n[idx], 0)
df1 <- data.frame(time = xbreaks, atrisk = atrisk)
df1[[treat_var]] <- df[[treat_var]][df$treated == h][1]
df1
})
df_risk <- do.call(rbind, tablist)
df_risk[[treat_var]] <- factor(df_risk[[treat_var]],
levels = levels(df[[treat_var]]))
# --- Create number at risk plot ---
p_risk <- ggplot2::ggplot(
df_risk, ggplot2::aes(x = .data$time, y = .data[[treat_var]],
label = .data$atrisk, colour = .data[[treat_var]])) +
ggplot2::geom_text(size = 3.2, na.rm = TRUE) +
ggplot2::scale_x_continuous(breaks = xbreaks, limits = range(xbreaks)) +
ggplot2::scale_y_discrete(limits = rev(levels(df_risk[[treat_var]]))) +
ggplot2::coord_cartesian(clip = "off") +
ggplot2::theme_minimal() +
ggplot2::theme(
axis.title = ggplot2::element_blank(),
axis.text.x = ggplot2::element_blank(),
axis.ticks = ggplot2::element_blank(),
legend.position = "none",
panel.grid = ggplot2::element_blank(),
plot.margin = ggplot2::margin(t = 6, r = 5, b = 0, l = 20)) +
ggplot2::annotate(
"text", x = min(xbreaks), y = 3,
label = "No. of Subjects at Risk", size = 4, hjust = 0.5)
suppressMessages({
p_km <- p_km +
ggplot2::scale_x_continuous(
breaks = xbreaks, limits = limits, expand = c(0.05, 0))
p_risk <- p_risk +
ggplot2::scale_x_continuous(
breaks = xbreaks, limits = limits, expand = c(0.05, 0))
})
# 2. Align plots using align_plots()
aligned <- cowplot::align_plots(p_km, p_risk, align = "v", axis = "lr")
# 3. Combine with plot_grid()
p_km <- cowplot::plot_grid(aligned[[1]], aligned[[2]], ncol = 1,
rel_heights = c(4, 0.6))
}
list(p_w = p_w, p_km = p_km)
} else {
stop("No outcome data available for plotting.")
}
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.