Introduction and Basics: PRISM

knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.width=11.5, fig.height=8.5
)

Introduction

Welcome to the StratifiedMedicine R package. The overall goal of this package is to develop analytic and visualization tools to aid in stratified and personalized medicine. Stratified medicine aims to find subsets or subgroups of patients with similar treatment effects, for example responders vs non-responders, while personalized medicine aims to understand treatment effects at the individual level (does a specific individual respond to treatment A?). Development of this package is ongoing.

Currently, the main algorithm in this package is "PRISM" (Patient Response Identifiers for Stratified Medicine; Jemielita and Mehrotra 2019, https://arxiv.org/abs/1912.03337). Given a data structure of $(Y, A, X)$ (outcome(s), treatments, covariates), PRISM is a five-step procedure:

  1. Estimand: Determine the question(s) or estimand(s) of interest. For example, $\theta_0 = E(Y|A=1)-E(Y|A=0)$, where A is a binary treatment variable. While this isn't an explicit step in the PRISM function, the question of interest guides how to set up PRISM.

  2. Filter (filter): Reduce covariate space by removing variables unrelated to outcome/treatment. Formally: $$ filter(Y, A, X) \longrightarrow (Y, A, X^{\star}) $$ where $X^{\star}$ has potentially lower dimension than $X$.

  3. Patient-level estimate (ple): Estimate counterfactual patient-level quantities, for example the individual treatment effect, $\theta(x) = E(Y|X=x,A=1)-E(Y|X=x,A=0)$. Formally: $$ ple(Y, A, X^{\star}) \longrightarrow \hat{\mathbf{\Theta}}(X^{\star}) $$ where $\hat{\mathbf{\Theta}}(X^{\star})$ is the matrix of patient-level estimates. For example, these could refer to counterfactual estimates of $[E(Y|A=1,X=x), E(Y|A=0,X=x), E(Y|A=1,X=x)-E(Y|A=0,X=x)]$, or $[RMST_{\tau}(Y|A=1,X=x)-RMST_{\tau}(Y|A=0,X=x)]$, where RMST refers to the restricted mean survival time with truncation time $\tau$.

  4. Subgroup model (submod): Partition the data into subsets of patients (likely with similar treatment effects). Formally: $$ submod(Y, A, X^{\star}, \hat{\mathbf{\Theta}}(X^{\star})) \longrightarrow \mathbf{S}(X^{\star}) $$ where $\mathbf{S}(X^{\star})$ is a distinct set of rules that define the $k=0,...,K$ discovered subgroups, for example $\mathbf{S}(X^{\star}) = \{X_1=0, X_2=0\}$. Note that subgroups could be formed using the observed outcomes, PLEs, or both. By default, $k=0$ corresponds to the overall population.

  5. Parameter estimation and inference (param): For the overall population and discovered subgroups, output point estimates and variability metrics. Formally: $$ param(Y, A, X^{\star}, \hat{\mathbf{\Theta}}(X^{\star}), \mathbf{S}(X^{\star}) ) \longrightarrow \{ \hat{\theta}_{k}, SE(\hat{\theta}_k), CI_{\alpha}(\hat{\theta}_{k}), P(\hat{\theta}_{k} > c) \} \text{ for } k=0,...,K $$ where $\hat{\theta}_{k}$ is the point estimate, $SE(\hat{\theta}_k)$ is the standard error, $CI_{\alpha}(\hat{\theta}_{k})$ is a two (or one) sided confidence interval with nominal coverage $1-\alpha$, and $P(\hat{\theta}_{k} > c)$ is a probability statement for some constant $c$ (ex: $c=0$). These outputs are crucial for decision making and can also correspond to multiple estimates for each subgroup and overall. For binary/continuous outcomes, the default is to output point estimates, SEs, CIs, and p-values for corresponding estimands $[E(Y|A=1,X=x), E(Y|A=0,X=x), E(Y|A=1,X=x)-E(Y|A=0,X=x)]$ in each discovered subgroup and overall.
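The five steps above can be read as a chain of functions in which each output feeds the next step. The base-R sketch below only illustrates that data flow. The helper names (`toy_filter`, `toy_ple`, `toy_submod`, `toy_param`) are hypothetical, and each one uses a deliberately simple stand-in (correlation screening, arm-specific linear models, a median split on the PLEs, differences in arm means) rather than the package defaults.

```r
set.seed(1)
n <- 400
X <- data.frame(X1 = rnorm(n), X2 = rnorm(n), X3 = rnorm(n))
A <- rbinom(n, 1, 0.5)
Y <- 1 + 0.5 * X$X2 + A * 1.5 * (X$X1 > 0) + rnorm(n)  # benefit only when X1 > 0

# Step 2 (filter): keep covariates whose absolute correlation with Y exceeds
# a cutoff in at least one arm (stand-in for elastic net)
toy_filter <- function(Y, A, X, cut = 0.1) {
  keep <- sapply(X, function(x) {
    max(abs(cor(x[A == 1], Y[A == 1])), abs(cor(x[A == 0], Y[A == 0]))) > cut
  })
  names(X)[keep]
}

# Step 3 (ple): one linear model per arm, then predict both arms for everyone
toy_ple <- function(Y, A, X) {
  dat  <- data.frame(Y = Y, X)
  fit1 <- lm(Y ~ ., data = dat[A == 1, , drop = FALSE])
  fit0 <- lm(Y ~ ., data = dat[A == 0, , drop = FALSE])
  mu1  <- predict(fit1, newdata = dat)
  mu0  <- predict(fit0, newdata = dat)
  data.frame(mu1 = mu1, mu0 = mu0, PLE = mu1 - mu0)
}

# Step 4 (submod): split patients at the median PLE (stand-in for a tree)
toy_submod <- function(ple) ifelse(ple$PLE > median(ple$PLE), "PLE high", "PLE low")

# Step 5 (param): difference in observed arm means, overall (k = 0) and by subgroup
toy_param <- function(Y, A, grp) {
  est <- function(idx) mean(Y[idx & A == 1]) - mean(Y[idx & A == 0])
  idx_list <- c(list(overall = rep(TRUE, length(Y))),
                lapply(split(seq_along(Y), grp), function(i) seq_along(Y) %in% i))
  data.frame(subgroup = names(idx_list),
             estimate = vapply(idx_list, est, numeric(1)),
             row.names = NULL)
}

vars <- toy_filter(Y, A, X)                       # X -> X*
ple  <- toy_ple(Y, A, X[, vars, drop = FALSE])    # X* -> Theta-hat(X*)
grp  <- toy_submod(ple)                           # Theta-hat -> S(X*)
res  <- toy_param(Y, A, grp)                      # S(X*) -> theta-hat_k
res
```

Because every step takes the previous step's output, any single stand-in can be swapped without touching the others, which is the same modularity the PRISM steps are described as having.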

Ultimately, PRISM provides information at the patient level, the subgroup level (if any), and for the overall population. While there are defaults in place, the user can also input their own functions/model wrappers into the PRISM algorithm. We will demonstrate this later. PRISM can also be run without treatment assignment (A=NULL); in this setting, the focus is on finding subgroups based on prognostic effects.

The tables below describe the default PRISM configurations for the different family (gaussian, binomial, survival) and treatment (no treatment vs treatment) settings, including the associated estimands. Note that OLS refers to ordinary least squares (linear regression), GLM refers to generalized linear model, and MOB refers to model-based partitioning (Zeileis, Hothorn, Hornik 2008; Seibold, Zeileis, Hothorn 2016). To summarize, default models include elastic net (Zou and Hastie 2005) for filtering, random forest ("ranger" R package) for patient-level/counterfactual estimation, and MOB (through the "partykit" R package; lmtree, glmtree, and ctree (Hothorn, Hornik, Zeileis 2005)) for subgroup identification. When treatment assignment is provided, parameter estimation for continuous and binary outcomes involves averaging the patient-level estimates within the overall population and discovered subgroups (more details later). For survival outcomes, the Cox regression hazard ratio (HR) or RMST (from the survRM2 package) is used.

library(knitr)
summ.table = data.frame( `Step` = c("estimand(s)", "filter", "ple", "submod", "param"),
                        `gaussian` = c("E(Y|A=0)<br>E(Y|A=1)<br>E(Y|A=1)-E(Y|A=0)",
                                       "Elastic Net<br>(filter_glmnet)", 
                                       "Random Forest<br>(ple_ranger)",
                                       "MOB(OLS)<br>(submod_lmtree)", 
                                       "Average of PLEs<br>(param_ple)"),
                        `binomial` = c("E(Y|A=0)<br>E(Y|A=1)<br>E(Y|A=1)-E(Y|A=0)",
                                       "Elastic Net<br>(filter_glmnet)", 
                                       "Random Forest<br>(ple_ranger)",
                                       "MOB(GLM)<br>(submod_glmtree)", 
                                       "Average of PLEs<br>(param_ple)"),    
                        `survival` = c("HR(A=1 vs A=0)",
                                       "Elastic Net<br>(filter_glmnet)", 
                                       "Random Forest<br>(ple_ranger)",
                                       "MOB(weibull)<br>(submod_weibull)", 
                                       "Hazard Ratios<br>(param_HR)") )                        

kable( summ.table, caption = "Default PRISM Configurations (With Treatment)", full_width=T)

summ.table = data.frame( `Step` = c("estimand(s)", "filter", "ple", "submod", "param"),
                        `gaussian` = c("E(Y)",
                                       "Elastic Net<br>(filter_glmnet)", 
                                       "Random Forest<br>(ple_ranger)",
                                       "Conditional Inference Trees<br>submod_ctree",
                                       "OLS<br>(param_lm)"),
                        `binomial` = c("Prob(Y)",
                                       "Elastic Net<br>(filter_glmnet)", 
                                       "Random Forest<br>(ple_ranger)",
                                       "Conditional Inference Trees<br>submod_ctree", 
                                       "OLS<br>(param_lm)"),    
                        `survival` = c("RMST", "Elastic Net<br>(filter_glmnet)", 
                                       "Random Forest<br>(ple_ranger)",
                                       "Conditional Inference Trees<br>submod_ctree",
                                       "RMST<br>(param_rmst)") )                        

kable( summ.table, caption = "Default PRISM Configurations (Without Treatment, A=NULL)", full_width=T)

Example: Continuous Outcome with Binary Treatment

Consider a continuous outcome (ex: % change in tumor size) with a binary treatment (study drug vs standard of care). The estimand of interest is the average treatment effect, $\theta_0 = E(Y|A=1)-E(Y|A=0)$. First, we simulate continuous data where roughly 30\% of the patients receive no treatment benefit from $A=1$ vs $A=0$. Responders vs non-responders are defined by the continuous predictive covariates $X_1$ and $X_2$ for a total of four subgroups. Subgroup treatment effects are: $\theta_{1} = 0$ ($X_1 \leq 0, X_2 \leq 0$), $\theta_{2} = 0.25$ ($X_1 > 0, X_2 \leq 0$), $\theta_{3} = 0.45$ ($X_1 \leq 0, X_2 > 0$), $\theta_{4} = 0.65$ ($X_1>0, X_2>0$).

library(ggplot2)
library(dplyr)
library(partykit)
library(StratifiedMedicine)
library(survival)
dat_ctns = generate_subgrp_data(family="gaussian")
Y = dat_ctns$Y
X = dat_ctns$X # 50 covariates, 46 are noise variables, X1 and X2 are truly predictive
A = dat_ctns$A # binary treatment, 1:1 randomized 
length(Y)
table(A)
dim(X)

For continuous outcome data (family="gaussian"), the default PRISM configuration is: (1) filter_glmnet (elastic net), (2) ple_ranger (treatment-specific random forest models), (3) submod_lmtree (model-based partitioning with OLS loss), and (4) param_ple (parameter estimation/inference through the PLEs). Jemielita and Mehrotra (2019) show that this configuration performs quite well in terms of bias, efficiency, coverage, and selecting the right predictive covariates. To run PRISM, at a minimum, the outcome (Y), treatment (A), and covariates (X) must be provided. See below.

# PRISM Default: filter_glmnet, ple_ranger, submod_lmtree, param_ple #
res0 = PRISM(Y=Y, A=A, X=X)
summary(res0)
plot(res0) # same as plot(res0, type="tree")
## This is the same as running ##
# res1 = PRISM(Y=Y, A=A, X=X, family="gaussian", filter="filter_glmnet", 
#              ple = "ple_ranger", submod = "submod_lmtree", param="param_ple")

The summary gives a high-level overview of the findings (number of subgroups, parameter estimates, variables that survived the filter). The default plot() function currently combines tree plots with parameter estimates using the "ggparty" package. We can also look directly for prognostic effects by omitting A (treatment) from PRISM:

# PRISM Default: filter_glmnet, ple_ranger, submod_ctree, param_lm #
res_prog = PRISM(Y=Y, X=X)
# res_prog = PRISM(Y=Y, A=NULL, X=X) #also works
summary(res_prog)
plot(res_prog)

Next, circling back to the first PRISM model with treatment included, let's review other core PRISM outputs and plotting functionality. Results relating to the filter include "filter.mod" (model output) and "filter.vars" (variables that pass the filter).

# elastic net model: loss by lambda #
plot(res0$filter.mod)
## Variables that remain after filtering ##
res0$filter.vars
# All predictive variables (X1, X2) and prognostic variables (X3, X5, X7) remain.

Results relating to the PLE model include "ple.mod" (model output), "mu_train" (training predictions), and "mu_test" (test predictions) where, for continuous or binary data, predictions are of E(Y|X,A=a) and E(Y|X,A=1)-E(Y|X,A=0). The PLEs, or individual treatment effects, are informative of the overall treatment heterogeneity and can be visualized through built-in waterfall plots. In this case, roughly 73% of patients have an estimated PLE above zero, that is, an estimated benefit from treatment A=1 vs A=0. PRISM plots are built using "ggplot2", making it easy to enhance plot visualizations. For example,

prob.PLE = mean(I(res0$mu_train$PLE>0))
# Waterfall Plot #
plot(res0, type="PLE:waterfall")+geom_vline(xintercept = 0) + 
  geom_text(x=200, y=1, label=paste("Prob(PLE>0)=", prob.PLE, sep=""))

Next, the subgroup model (lmtree) identifies 4 subgroups based on varying treatment effects. By plotting the subgroup model object ("submod.fit\$mod"), we see that partitions are made through X1 (predictive) and X2 (predictive). Each node shows the parameter estimates of the node (subgroup) specific OLS model, $Y\sim \beta_0+\beta_1*A$. For example, patients in nodes 4 and 6 have estimated treatment effects of 0.47 and 0.06 respectively. Subgroup predictions for the train/test set can be found in the "out.train" and "out.test" data sets.

plot(res0$submod.fit$mod, terminal_panel = NULL)
table(res0$out.train$Subgrps)
table(res0$out.test$Subgrps)

These estimates tend to be overly positive or negative, as the same data that trains the subgroup model is used to estimate the treatment effects. Resampling, such as bootstrapping, can generally be used to obtain "de-biased" treatment effect estimates and valid inference (more details later).

For continuous and binary data, an alternative approach without resampling is to directly use the PLEs for parameter estimation and inference (param="param_ple"). Let $E(Y|X=x,A=a) = \mu(a, x)$ correspond to the outcome regression model(s) with estimates $\hat{\mu}(a, x)$. These estimates come directly from the fitted PLE model(s), in this case, treatment-specific random forest models. For the overall population and each discovered subgroup ($k=0,...,K$), the treatment effect (or risk difference) can be estimated by averaging the patient-specific treatment effect estimates (PLEs): $$ \hat{\theta}_k = \frac{1}{n_k} \sum_{i \in S_k} \hat{\theta}(x_i) $$ where $\hat{\theta}(x_i)=\hat{\mu}(a=1,x_i)-\hat{\mu}(a=0,x_i)$ and $n_k$ is the number of patients in subgroup $S_k$. For SEs / CIs, we utilize "pseudo-outcomes": $$ Y^{\star}_i = \frac{A_i Y_i - (A_i-\hat{\pi}(x_i))\hat{\mu}(a=1,x_i)}{\hat{\pi}(x_i)} - \frac{(1-A_i)Y_i - (A_i-\hat{\pi}(x_i))\hat{\mu}(a=0,x_i)}{1-\hat{\pi}(x_i)}$$ where $\pi(x)=P(A=1|X=x)$ is the treatment assignment probability for an individual. In a randomized controlled trial, this can be replaced by the marginal probability, $P(A=1)$. Note that $E(Y^{\star}_i)=E(Y|A=1,X)-E(Y|A=0,X)$ and $E(n_k^{-1}\sum_{i \in S_k} Y^{\star}_i)= E(Y|A=1, X \in S_k)-E(Y|A=0, X \in S_k)$. Next: $$SE(\hat{\theta}_k) = \sqrt{ n_k^{-2} \sum_{i \in S_k} \left( Y^{\star}_i-\hat{\theta}(x_i) \right)^2 } $$ CIs can then be formed using t- or Z-intervals. For example, a two-sided 95\% Z-interval is $CI_{\alpha}(\hat{\theta}_{k}) = \left[\hat{\theta}_{k} \pm 1.96*SE(\hat{\theta}_k) \right]$.
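As a rough illustration of the formulas above, the following base-R sketch computes PLE averages, pseudo-outcomes, SEs, and Z-intervals on simulated data. It is not the package's `param_ple` code: the arm-specific linear models stand in for the random forests, the subgroups (`X1<=0`, `X1>0`) are assumed rather than discovered, and the pseudo-outcome is coded term by term as written in the text.

```r
set.seed(2)
n  <- 500
X1 <- rnorm(n)
A  <- rbinom(n, 1, 0.5)
Y  <- 1 + 0.5 * X1 + A * 0.6 * (X1 > 0) + rnorm(n)
dat <- data.frame(Y = Y, X1 = X1)

# Outcome regressions mu(a, x): one linear model per arm (stand-in for ranger)
fit1 <- lm(Y ~ X1, data = dat[A == 1, ])
fit0 <- lm(Y ~ X1, data = dat[A == 0, ])
mu1  <- predict(fit1, newdata = dat)
mu0  <- predict(fit0, newdata = dat)
ple  <- mu1 - mu0        # theta-hat(x_i)
pi_hat <- mean(A)        # marginal P(A = 1), as suggested for a randomized trial

# Pseudo-outcomes Y*_i, following the signs of the displayed formula
y_star <- (A * Y - (A - pi_hat) * mu1) / pi_hat -
  ((1 - A) * Y - (A - pi_hat) * mu0) / (1 - pi_hat)

# Assumed subgroups for illustration: k = 0 is the overall population
groups <- list(overall = rep(TRUE, n), "X1<=0" = X1 <= 0, "X1>0" = X1 > 0)

param <- do.call(rbind, lapply(names(groups), function(g) {
  idx <- groups[[g]]
  n_k <- sum(idx)
  est <- mean(ple[idx])                                   # average of the PLEs
  se  <- sqrt(sum((y_star[idx] - ple[idx])^2) / n_k^2)    # SE from pseudo-outcomes
  z   <- qnorm(0.975)
  data.frame(subgroup = g, n = n_k, est = est, SE = se,
             LCL = est - z * se, UCL = est + z * se)
}))
param
```

Since the two assumed subgroups partition the sample, the overall estimate equals the size-weighted average of the two subgroup estimates, which is a quick sanity check on the averaging step.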

Moving back to the PRISM outputs, for any of the provided "param" options, a key output is the object "param.dat". By default, "param.dat" contains point estimates, standard errors, lower/upper confidence limits (which depend on alpha_s and alpha_ovrl), and p-values. This output feeds directly into the previously shown default ("tree") plot and the "forest" plot.

## Overall/subgroup specific parameter estimates/inference
res0$param.dat
## Forest plot: Overall/subgroup specific parameter estimates (CIs)
plot(res0, type="tree")
plot(res0, type="forest")

PLE dependence plots or heatmaps can also be generated from PRISM outputs. By default, if no grid data is supplied, all values are used for a categorical/factor variable, while a continuous variable is split into 20 equally spaced bins. In either case, based on a grid of values (with up to three variables), PLEs are estimated for each patient by fixing the grid variables to specific values. We then average the PLEs to obtain a point estimate for each specific set of grid values, and can likewise calculate probabilities. See below; note that the heatmap is also consistent with the truth (treatment benefit for $X_1>0, X_2>0$ patients).

plot_dependence(res0, vars="X1")
plot_dependence(res0, vars="X2")
plot_dependence(res0, vars=c("X1", "X2"))
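The grid-averaging idea behind these dependence plots can be sketched in a few lines of base R. This is an assumption-laden illustration, not the internals of `plot_dependence`: arm-specific linear models stand in for the PLE model, `predict_ple` is a hypothetical helper, and the grid uses 20 equally spaced values of X1.

```r
set.seed(3)
n <- 300
X <- data.frame(X1 = rnorm(n), X2 = rnorm(n))
A <- rbinom(n, 1, 0.5)
Y <- 1 + A * (0.8 * X$X1 + 0.8 * X$X2) + rnorm(n)
dat <- data.frame(Y = Y, X)

# Stand-in PLE model: one linear model per arm
fit1 <- lm(Y ~ X1 + X2, data = dat[A == 1, ])
fit0 <- lm(Y ~ X1 + X2, data = dat[A == 0, ])
predict_ple <- function(newdata) predict(fit1, newdata) - predict(fit0, newdata)

# Grid: 20 equally spaced values over the observed range of X1
grid <- seq(min(X$X1), max(X$X1), length.out = 20)

dep <- do.call(rbind, lapply(grid, function(v) {
  X_fixed <- X
  X_fixed$X1 <- v                    # fix X1 for every patient, keep X2 as observed
  ple_v <- predict_ple(X_fixed)      # PLEs at this grid value
  data.frame(X1 = v, est = mean(ple_v), prob_gt0 = mean(ple_v > 0))
}))
head(dep)
```

Each row of `dep` averages over the observed distribution of the remaining covariates, so `est` traces how the average PLE changes as X1 alone moves across the grid.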

The hyper-parameters for the individual steps of PRISM can also be easily modified. For example, "filter_glmnet" by default selects covariates based on "lambda.min", "ple_ranger" requires nodes to contain at least 10% of the total observations, and "submod_lmtree" requires nodes to contain at least 10% of the total observations. To modify this:

# PRISM Default: filter_glmnet, ple_ranger, submod_lmtree, param_ple #
# Change hyper-parameters #
res_new_hyper = PRISM(Y=Y, A=A, X=X, filter.hyper = list(lambda="lambda.1se"),
                      ple.hyper = list(min.node.pct=0.05), 
                      submod.hyper = list(minsize=200), verbose=FALSE)
plot(res_new_hyper)

Example: Binary Outcome with Binary Treatment

Consider a binary outcome (ex: overall response) with a binary treatment (study drug vs standard of care). The estimand of interest is the risk difference, $\theta_0 = E(Y|A=1)-E(Y|A=0)$. Similar to the continuous example, we simulate binomial data where roughly 30\% of the patients receive no treatment benefit from $A=1$ vs $A=0$. Responders vs non-responders are defined by the continuous predictive covariates $X_1$ and $X_2$ for a total of four subgroups. Subgroup treatment effects are: $\theta_{1} = 0$ ($X_1 \leq 0, X_2 \leq 0$), $\theta_{2} = 0.11$ ($X_1 > 0, X_2 \leq 0$), $\theta_{3} = 0.21$ ($X_1 \leq 0, X_2 > 0$), $\theta_{4} = 0.31$ ($X_1>0, X_2>0$).

For binary outcomes (Y=0,1), the default settings use glmnet to filter ("filter_glmnet"), random forest patient-level estimates ("ple_ranger"; for binary outcomes, the output is the risk difference), "submod_glmtree" (GLM MOB with binomial(link=identity)) for subgroup identification, and "param_ple" (average counterfactual risk differences within each subgroup; same formulas as in the continuous setting).

dat_bin = generate_subgrp_data(family="binomial", seed = 5558)
Y = dat_bin$Y
X = dat_bin$X # 50 covariates, 46 are noise variables, X1 and X2 are truly predictive
A = dat_bin$A # binary treatment, 1:1 randomized 

res0 = PRISM(Y=Y, A=A, X=X)
summary(res0)
plot(res0)

Example: Survival Outcome with Binary Treatment

Survival outcomes are also allowed in PRISM. The default settings use glmnet to filter ("filter_glmnet"), ranger patient-level estimates ("ple_ranger"; for survival, the output is the restricted mean survival time treatment difference), "submod_weibull" (MOB with Weibull loss function) for subgroup identification, and "param_cox" (subgroup-specific Cox regression models). Another subgroup option is "submod_ctree", which uses the conditional inference tree (ctree) algorithm to find subgroups; this looks for partitions irrespective of treatment assignment and thus corresponds to finding prognostic effects.

# Load TH.data (no treatment; generate treatment randomly to simulate null effect) ##
data("GBSG2", package = "TH.data")
surv.dat = GBSG2
# Design Matrices ###
Y = with(surv.dat, Surv(time, cens))
X = surv.dat[,!(colnames(surv.dat) %in% c("time", "cens")) ]
set.seed(6345)
A = rbinom( n = dim(X)[1], size=1, prob=0.5  )

# Default: filter_glmnet ==> ple_ranger (estimates patient-level RMST(1 vs 0)) ==> submod_weibull (MOB with Weibull) ==> param_cox (Cox regression)
res_weibull1 = PRISM(Y=Y, A=A, X=X)
plot(res_weibull1, type="PLE:waterfall")
plot(res_weibull1)

# PRISM: filter_glmnet ==> ple_ranger (default) ==> submod_ctree ==> param_cox (Cox regression) #
res_ctree1 = PRISM(Y=Y, A=A, X=X, submod = "submod_ctree")
plot(res_ctree1)
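For readers unfamiliar with the restricted mean survival time (RMST) that the survival PLEs refer to, the sketch below computes it as the area under a Kaplan-Meier curve up to a truncation time tau. It is a minimal base-R illustration on simulated data; `rmst_km` is a hypothetical helper and not a function from StratifiedMedicine or survRM2.

```r
# RMST up to tau: area under the Kaplan-Meier step function on [0, tau].
# Assumes follow-up extends to tau (otherwise the last rectangle extrapolates).
rmst_km <- function(time, status, tau) {
  ev_times <- sort(unique(time[status == 1 & time <= tau]))
  surv   <- 1
  area   <- 0
  t_prev <- 0
  for (tt in ev_times) {
    area    <- area + surv * (tt - t_prev)      # rectangle before the drop at tt
    at_risk <- sum(time >= tt)
    events  <- sum(time == tt & status == 1)
    surv    <- surv * (1 - events / at_risk)    # Kaplan-Meier update
    t_prev  <- tt
  }
  area + surv * (tau - t_prev)                  # last rectangle up to tau
}

# Simulated two-arm data with random censoring (illustrative values only)
set.seed(7)
n <- 300
A <- rbinom(n, 1, 0.5)
event_time <- rexp(n, rate = ifelse(A == 1, 0.5, 1))
cens_time  <- runif(n, 0, 5)
time   <- pmin(event_time, cens_time)
status <- as.numeric(event_time <= cens_time)

tau   <- 2
rmst1 <- rmst_km(time[A == 1], status[A == 1], tau)
rmst0 <- rmst_km(time[A == 0], status[A == 0], tau)
c(rmst1 = rmst1, rmst0 = rmst0, diff = rmst1 - rmst0)
```

Without censoring, the function reduces to the mean of min(T, tau); for example, `rmst_km(c(1, 2, 3), c(1, 1, 1), 3)` is the mean of 1, 2, and 3.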

Resampling

Resampling methods are also a feature in PRISM. Bootstrap (resample="Bootstrap"), permutation (resample="Permutation"), and cross-validation (resample="CV") based resampling are included. Resampling can be used for obtaining de-biased or "honest" subgroup estimates, inference, and/or probability statements. For each resampling method, the sampling mechanism can be stratified by the discovered subgroups (default: stratify=TRUE). To summarize:

Bootstrap Resampling

Given observed data $(Y, A, X)$, fit $PRISM(Y,A,X)$. Based on the identified $k=1,...,K$ subgroups, output the subgroup assignment for each patient. For the overall population and each subgroup ($k=0,...,K$), store the associated parameter estimates ($\hat{\theta}_{k}$). For $r=1,...,R$ resamples with replacement, $(Y_r, A_r, X_r)$, fit $PRISM(Y_r, A_r, X_r)$ and obtain new subgroup assignments $k_r=1,...,K_r$ with associated parameter estimates $\hat{\theta}_{k_r}$. For resample $r$, the bootstrap estimates and SEs for the original identified subgroups ($k=0,...,K$) are calculated respectively as: $$ \hat{\theta}_{rk} = \sum_{k_r} w_{k_r} \hat{\theta}_{k_r} $$ $$ SE(\hat{\theta}_{rk}) = \sqrt{ \sum_{k_r} w_{k_r}^2 SE(\hat{\theta}_{k_r})^2 } $$ where $w_{k_r} = n(k \cap k_r)/ \sum_{k_r} n(k \cap k_r)$ and $n(k \cap k_r)$ is the number of subjects in the original subgroup $k$ who are also in the bootstrap subgroup $k_r$. The bootstrap smoothed estimate and standard error, as well as probability statements, are calculated as: $$ \tilde{\theta}_{k} = \frac{1}{R} \sum_r \hat{\theta}_{rk} $$ $$ SE(\hat{\theta}_{k})_B = \sqrt{ \frac{1}{R} \sum_r (\hat{\theta}_{rk}-\tilde{\theta}_{k})^2 } $$ $$ \hat{P}(\hat{\theta}_{k}>c) = \frac{1}{R} \sum_r I(\hat{\theta}_{rk}>c) $$

Bootstrap confidence intervals can then be formed, $\left[\hat{\theta}_{k} \pm 1.96*SE(\hat{\theta}_k)_B \right]$. Bootstrap calibration, which uses the observed point estimates and SEs but adjusts the alpha level such that we obtain on average $1-\alpha$ coverage across all identified subgroups, is also implemented for PRISM (default: calibrate=FALSE). See Loh et al 2016 (GUIDE) for more details. Importantly, for calibration to be effective, $\hat{\theta}_{k}$ should be relatively unbiased. Another approach is to use the bootstrap smoothed estimates, $\tilde{\theta}_{k}$, along with percentile-based CIs (i.e. 2.5, 97.5 quantiles of the bootstrap distribution). Other metrics are also automatically calculated, such as bootstrap bias (which can be used to adjust the initial subgroup estimates).
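The overlap weights and smoothed summaries described above can be sketched with a toy "subgroup model". The example below is a simplified illustration under stated assumptions, not PRISM's resampling code: subgroups come from a median split on X1 (helpers `fit_split`, `assign_grp`, and `est_by_grp` are hypothetical), and the bootstrap rule is applied back to the original patients to count $n(k \cap k_r)$.

```r
set.seed(4)
n  <- 400
X1 <- rnorm(n)
A  <- rbinom(n, 1, 0.5)
Y  <- A * 1.5 * (X1 > 0) + rnorm(n)

# Toy subgroup model: split at the median of X1 in the data it is fit on
fit_split  <- function(x) median(x)
assign_grp <- function(x, cut) ifelse(x > cut, "high", "low")

# Toy parameter step: difference in arm means (and its SE) within each subgroup
est_by_grp <- function(Y, A, grp) {
  t(sapply(split(seq_along(Y), grp), function(i) {
    y1 <- Y[i][A[i] == 1]
    y0 <- Y[i][A[i] == 0]
    c(est = mean(y1) - mean(y0),
      se  = sqrt(var(y1) / length(y1) + var(y0) / length(y0)))
  }))
}

grp_orig <- assign_grp(X1, fit_split(X1))        # original subgroups k
R <- 200
theta_rk <- matrix(NA_real_, R, 2, dimnames = list(NULL, c("high", "low")))

for (r in seq_len(R)) {
  b     <- sample.int(n, replace = TRUE)         # resample with replacement
  cut_r <- fit_split(X1[b])                      # refit the toy model
  par_r <- est_by_grp(Y[b], A[b], assign_grp(X1[b], cut_r))  # theta-hat_{k_r}
  grp_r <- assign_grp(X1, cut_r)                 # bootstrap rule on original patients
  # w_{k_r}: share of original subgroup k falling in bootstrap subgroup k_r
  w <- unclass(prop.table(table(grp_orig, grp_r), margin = 1))
  theta_rk[r, ] <- (w %*% par_r[colnames(w), "est"])[colnames(theta_rk), 1]
}

theta_tilde <- colMeans(theta_rk)                                # smoothed estimate
se_boot     <- sqrt(colMeans(sweep(theta_rk, 2, theta_tilde)^2)) # bootstrap SE
prob_gt0    <- colMeans(theta_rk > 0)                            # P-hat(theta_k > 0)
pct_ci      <- apply(theta_rk, 2, quantile, probs = c(0.025, 0.975))
rbind(theta_tilde, se_boot, prob_gt0, pct_ci)
```

Each row of `w` sums to one, so $\hat{\theta}_{rk}$ is a weighted average of the bootstrap subgroup estimates, with more weight on the bootstrap subgroups that overlap most with original subgroup $k$.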

Returning to the survival example, we now re-run PRISM with 50 bootstrap resamples (for increased accuracy, use >1000). The smoothed bootstrap estimates, bootstrap standard errors, bootstrap bias, percentile CI, and calibrated CI correspond to "est_resamp", "SE_resamp", "bias.boot", "LCL.pct"/"UCL.pct", and "LCL.calib"/"UCL.calib" respectively. We can also plot a density plot of the bootstrap distributions through the plot(...,type="resample") option.

res_boot = PRISM(Y=Y, A=A, X=X, resample = "Bootstrap", R=50, ple = "None")
# Plot of distributions #
plot(res_boot, type="resample", estimand = "HR(A=1 vs A=0)")+geom_vline(xintercept = 1)

Permutation Resampling

Permutation resampling (resample="Permutation") follows the same general procedure as bootstrap resampling. The main difference is that we only randomly shuffle the treatment assignment $A$ without replacement. This simulates the null hypothesis of no treatment effect. Key outputs are the permutation p-values (pval_perm in param.dat) and the permutation resampling distributions.
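The core of a permutation test can be shown on a single statistic. The sketch below shuffles $A$ and recomputes a difference in means; in the full procedure PRISM would be refit on each permuted data set, which is omitted here. The helper `diff_means` and the "+1" correction in the p-value are assumptions of this example, not necessarily how pval_perm is computed in the package.

```r
set.seed(5)
n <- 200
A <- rbinom(n, 1, 0.5)
Y <- 0.5 * A + rnorm(n)

diff_means <- function(Y, A) mean(Y[A == 1]) - mean(Y[A == 0])
obs <- diff_means(Y, A)                           # observed statistic

R <- 1000
# sample(A) shuffles the treatment labels without replacement,
# breaking any link between A and Y (the null of no treatment effect)
perm <- replicate(R, diff_means(Y, sample(A)))

# Two-sided permutation p-value; the +1 counts the observed statistic itself
pval_perm <- (1 + sum(abs(perm) >= abs(obs))) / (R + 1)
c(observed = obs, pval_perm = pval_perm)
```

Because only the labels are shuffled, the arm sizes stay fixed across permutations, and the permutation distribution is centered near zero.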

Cross-Validation

Cross-validation resampling (resample="CV") also follows the same general procedure as bootstrap resampling. Given observed data $(Y, A, X)$, fit $PRISM(Y,A,X)$. Based on the identified $k=1,...,K$ subgroups, output the subgroup assignment for each patient. Next, split the data into $R$ folds (ex: 5). For fold $r$ with sample size $n_r$, fit PRISM on $(Y[-r],A[-r], X[-r])$ and predict the patient-level estimates and subgroup assignments ($k_r=1,...,K_r$) for patients in fold $r$. The data in fold $r$ is then used to obtain parameter estimates for each subgroup, $\hat{\theta}_{k_r}$. For fold $r$, estimates and SEs for the original subgroups ($k=1,...,K$) are then obtained using the same formulas as with bootstrap resampling, again denoted as ($\hat{\theta}_{rk}$, $SE(\hat{\theta}_{rk})$). This is repeated for each fold, and "CV" estimates and SEs are calculated for each identified subgroup. Let $w_r = n_r / \sum_r n_r$, then:

$$ \hat{\theta}_{k,CV} = \sum_{r} w_r \hat{\theta}_{rk} $$ $$ SE(\hat{\theta}_k)_{CV} = \sqrt{ \sum_{r} w_{r}^2 SE(\hat{\theta}_{rk})^2 }$$ CV-based confidence intervals can then be formed, $\left[\hat{\theta}_{k,CV} \pm 1.96*SE(\hat{\theta}_k)_{CV} \right]$.
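The final fold-combination step can be sketched as follows. This simplified base-R example makes several assumptions: the "model" is a median split on X1 learned without fold $r$, only the "high" subgroup is tracked, and the mapping back to original subgroups through overlap weights is skipped. The helper `est_se` is hypothetical.

```r
set.seed(6)
n  <- 300
X1 <- rnorm(n)
A  <- rbinom(n, 1, 0.5)
Y  <- A * 1.0 * (X1 > 0) + rnorm(n)

R    <- 5
fold <- sample(rep(seq_len(R), length.out = n))   # random fold labels 1..R

# Difference in arm means and its SE
est_se <- function(y, a) {
  y1 <- y[a == 1]
  y0 <- y[a == 0]
  c(est = mean(y1) - mean(y0),
    se  = sqrt(var(y1) / length(y1) + var(y0) / length(y0)))
}

fold_res <- t(sapply(seq_len(R), function(r) {
  test  <- fold == r
  cut_r <- median(X1[!test])          # rule learned without fold r
  idx   <- test & X1 > cut_r          # held-out patients assigned to "high"
  c(n_r = sum(test), est_se(Y[idx], A[idx]))
}))

w_r      <- fold_res[, "n_r"] / sum(fold_res[, "n_r"])   # w_r = n_r / sum_r n_r
theta_cv <- sum(w_r * fold_res[, "est"])
se_cv    <- sqrt(sum(w_r^2 * fold_res[, "se"]^2))
c(est = theta_cv, LCL = theta_cv - 1.96 * se_cv, UCL = theta_cv + 1.96 * se_cv)
```

The held-out estimates never use the patients that defined the split in that fold, which is what makes the combined estimate "honest" in the sense described earlier.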

Conclusion

Overall, PRISM is a flexible algorithm that can aid in subgroup detection and exploration of heterogeneous treatment effects. Each step of PRISM is customizable, allowing for fast experimentation and improvement of individual steps. More details on creating user-specific models can be found in the "User_Specific_Models_PRISM" vignette. The StratifiedMedicine R package and PRISM will be continually updated and improved. User feedback will further facilitate improvements.




StratifiedMedicine documentation built on March 1, 2020, 9:07 a.m.