Nothing
#' @title Simple Precision Medicine Tree
#' @description This function creates a classification tree
#' designed to identify subgroups in which subjects
#' perform especially well or especially poorly in a
#' given treatment group.
#'
#' @param formula A description of the model to be fit with format
#' \code{Y ~ treatment | X1 + X2} for data with a
#' continuous outcome variable Y and
#' \code{Surv(Y, delta) ~ treatment | X1 + X2} for data with
#' a right-censored survival outcome variable Y and
#' a status indicator delta
#' @param data A matrix or data frame of the data
#' @param types A vector, data frame, or matrix of the types
#' of each variable in the data; if left blank, the
#' default is to assume all of the candidate split
#' variables are ordinal; otherwise, all variables in
#' the data must be specified, and the possible variable
#' types are: "response", "treatment", "status", "binary",
#' "ordinal", and "nominal" for outcome variable Y, the
#' treatment variable, the status indicator (if
#' applicable), binary candidate split variables, ordinal
#' candidate split variables, and nominal candidate split
#' variables respectively
#' @param nmin An integer specifying the minimum node size of
#' the overall classification tree
#' @param maxdepth An integer specifying the maximum depth of the
#' overall classification tree; this argument is
#' optional but useful for shortening computation
#' time; if left blank, the default is to grow the
#' full tree until the minimum node size \code{nmin}
#' is reached
#' @param print A boolean (TRUE/FALSE) value, where TRUE prints
#' a more readable version of the final tree to the
#' screen
#' @param dataframe A boolean (TRUE/FALSE) value, where TRUE returns
#' the final tree as a dataframe
#' @param prune A boolean (TRUE/FALSE) value, where TRUE prunes
#' the final tree using \code{pmprune} function
#'
#' @details To identify the best split at each node of the
#' classification tree, all possible splits of all
#' candidate split variables are considered. The single
#' split with the highest split criteria score is
#' identified as the best split of the node. For data with
#' a continuous outcome variable, the split criteria is the
#' DIFF value that was first proposed for usage in the
#' relative-effectiveness based method (Zhang et al. (2010),
#' Tsai et al. (2016)). For data with a survival outcome
#' variable, the split criteria is the squared test
#' statistic that tests the significance of the split by
#' treatment interaction term in a Cox proportional hazards
#' model.
#'
#' When using \code{spmtree}, note the following
#' requirements for the supplied data. First, the dataset
#' must contain an outcome variable Y and a treatment
#' variable. If Y is a right-censored survival time
#' outcome, then there must also be a status indicator
#' delta, where values of 1 denote the occurrence of the
#' (harmful) event of interest, and values of 0 denote
#' censoring. If there are only two treatment groups, then
#' the two possible values must be 0 or 1. If there are
#' more than two treatment groups, then the possible values
#' must be integers starting from 1 to the total number of
#' treatment assignments. In regard to the candidate split
#' variables, if a variable is binary, then the variable
#' must take values of 0 or 1. If a variable is nominal,
#' then the values must be integers starting from 1 to the
#' total number of categories. There cannot be any missing
#' values in the dataset. For candidate split variables
#' with missing values, the missings together (MT) method
#' proposed by Zhang et al. (1996) is helpful.
#'
#' @return \code{spmtree} returns the final classification tree as a
#' \code{party} object by default or a data frame. See
#' Hothorn and Zeileis (2015) for details. The data
#' frame contains the following columns of information:
#' \item{node}{Unique integer values that identify each node
#' in the tree, where all of the nodes are
#' indexed starting from 1}
#' \item{splitvar}{Integers that represent the candidate split
#' variable used to split each node, where
#' all of the variables are indexed starting
#' from 1; for terminal nodes, i.e., nodes
#' without child nodes, the value is set
#' equal to NA}
#' \item{splitvar_name}{The names of the candidate split
#' variables used to split each node
#' obtained from the column names of the
#' supplied data; for terminal nodes,
#' the value is set equal to NA}
#' \item{type}{Characters that denote the type of each
#' candidate split variable; "bin" is for binary
#' variables, "ord" for ordinal, and "nom" for
#' nominal; for terminal nodes, the value is set
#' equal to NA}
#' \item{splitval}{Values of the left child node of the
#' current split/node; for binary variables,
#' a value of 0 is printed, and subjects with
#' values of 0 for the current \code{splitvar}
#' are in the left child node, while subjects
#' with values of 1 are in the right child
#' node; for ordinal variables,
#' \code{splitval} is numeric and implies
#' that subjects with values of the current
#' \code{splitvar} less than or equal to
#' \code{splitval} are in the left child
#' node, while the remaining subjects with
#' values greater than \code{splitval} are in
#' the right child node; for nominal
#' variables, the \code{splitval} is a set of
#' integers separated by commas, and subjects
#' in that set of categories are in the left
#' child node, while the remaining subjects
#' are in the right child node; for terminal
#' nodes, the value is set equal to NA}
#' \item{lchild}{Integers that represent the index (i.e.,
#' \code{node} value) of each node's left
#' child node; for terminal nodes, the value is
#' set equal to NA}
#' \item{rchild}{Integers that represent the index (i.e.,
#' \code{node} value) of each node's right
#' child node; for terminal nodes, the value is
#' set equal to NA}
#' \item{depth}{Integers that specify the depth of each
#' node; the root node has depth 1, its
#' children have depth 2, etc.}
#' \item{nsubj}{Integers that count the total number of
#' subjects within each node}
#' \item{besttrt}{Integers that denote the identified best
#' treatment assignment of each node}
#'
#' @references Chen, V., Li, C., and Zhang, H. (2022). dipm: an
#' R package implementing the Depth Importance in
#' Precision Medicine (DIPM) tree and Forest-based method.
#' \emph{Bioinformatics Advances}, \strong{2}(1), vbac041.
#'
#' Chen, V. and Zhang, H. (2022). Depth importance in
#' precision medicine (DIPM): A tree-and forest-based
#' method for right-censored survival outcomes.
#' \emph{Biostatistics} \strong{23}(1), 157-172.
#'
#' Chen, V. and Zhang, H. (2020). Depth importance in
#' precision medicine (DIPM): a tree and forest based method.
#' In \emph{Contemporary Experimental Design,
#' Multivariate Analysis and Data Mining}, 243-259.
#'
#' Tsai, W.-M., Zhang, H., Buta, E., O'Malley, S.,
#' Gueorguieva, R. (2016). A modified classification
#' tree method for personalized medicine decisions.
#' \emph{Statistics and its Interface} \strong{9},
#' 239-253.
#'
#' Zhang, H., Holford, T., and Bracken, M.B. (1996).
#' A tree-based method of analysis for prospective
#' studies. \emph{Statistics in Medicine} \strong{15},
#' 37-49.
#'
#' Zhang, H., Legro, R.S., Zhang, J., Zhang, L., Chen,
#' X., et al. (2010). Decision trees for identifying
#' predictors of treatment effectiveness in clinical
#' trials and its application to ovulation in a study of
#' women with polycystic ovary syndrome. \emph{Human
#' Reproduction} \strong{25}, 2612-2621.
#'
#' Hothorn, T. and Zeileis, A. (2015). partykit:
#' a modular toolkit for recursive partytioning in R.
#' \emph{The Journal of Machine Learning Research}
#' \strong{16}(1), 3905-3909.
#'
#' @seealso \code{\link{dipm}}
#'
#' @examples
#'
#' #
#' # ... an example with a continuous outcome variable
#' # and two treatment groups
#' #
#'
#' N = 300
#' set.seed(123)
#'
#' # generate binary treatments
#' treatment = rbinom(N, 1, 0.5)
#'
#' # generate candidate split variables
#' X1 = rnorm(n = N, mean = 0, sd = 1)
#' X2 = rnorm(n = N, mean = 0, sd = 1)
#' X3 = rnorm(n = N, mean = 0, sd = 1)
#' X4 = rnorm(n = N, mean = 0, sd = 1)
#' X5 = rnorm(n = N, mean = 0, sd = 1)
#' X = cbind(X1, X2, X3, X4, X5)
#' colnames(X) = paste0("X", 1:5)
#'
#' # generate continuous outcome variable
#' calculateLink = function(X, treatment){
#'
#' ((X[, 1] <= 0) & (X[, 2] <= 0)) *
#' (25 * (1 - treatment) + 8 * treatment) +
#'
#' ((X[, 1] <= 0) & (X[, 2] > 0)) *
#' (18 * (1 - treatment) + 20 * treatment) +
#'
#' ((X[, 1] > 0) & (X[, 3] <= 0)) *
#' (20 * (1 - treatment) + 18 * treatment) +
#'
#' ((X[, 1] > 0) & (X[, 3] > 0)) *
#' (8 * (1 - treatment) + 25 * treatment)
#' }
#'
#' Link = calculateLink(X, treatment)
#' Y = rnorm(N, mean = Link, sd = 1)
#'
#' # combine variables in a data frame
#' data = data.frame(X, Y, treatment)
#'
#' # fit a classification tree
#' tree1 = spmtree(Y ~ treatment | ., data, maxdepth = 3)
#' # predict optimal treatment for new subjects
#' predict(tree1, newdata = head(data),
#' FUN = function(n) as.numeric(n$info$opt_trt))
#'
#'\donttest{
#' #
#' # ... an example with a continuous outcome variable
#' # and three treatment groups
#' #
#'
#' N = 600
#' set.seed(123)
#'
#' # generate treatments
#' treatment = sample(1:3, N, replace = TRUE)
#'
#' # generate candidate split variables
#' X1 = round(rnorm(n = N, mean = 0, sd = 1), 4)
#' X2 = round(rnorm(n = N, mean = 0, sd = 1), 4)
#' X3 = sample(1:4, N, replace = TRUE)
#' X4 = sample(1:5, N, replace = TRUE)
#' X5 = rbinom(N, 1, 0.5)
#' X6 = rbinom(N, 1, 0.5)
#' X7 = rbinom(N, 1, 0.5)
#' X = cbind(X1, X2, X3, X4, X5, X6, X7)
#' colnames(X) = paste0("X", 1:7)
#'
#' # generate continuous outcome variable
#' calculateLink = function(X, treatment){
#'
#' 10.2 - 0.3 * (treatment == 1) - 0.1 * X[, 1] +
#' 2.1 * (treatment == 1) * X[, 1] +
#' 1.2 * X[, 2]
#' }
#'
#' Link = calculateLink(X, treatment)
#' Y = rnorm(N, mean = Link, sd = 1)
#'
#' # combine variables in a data frame
#' data = data.frame(X, Y, treatment)
#'
#' # create vector of variable types
#' types = c(rep("ordinal", 2), rep("nominal", 2), rep("binary", 3),
#' "response", "treatment")
#'
#' # fit a classification tree
#' tree2 = spmtree(Y ~ treatment | ., data, types = types)
#'
#' #
#' # ... an example with a survival outcome variable
#' # and two treatment groups
#' #
#'
#' N = 300
#' set.seed(321)
#'
#' # generate binary treatments
#' treatment = rbinom(N, 1, 0.5)
#'
#' # generate candidate split variables
#' X1 = rnorm(n = N, mean = 0, sd = 1)
#' X2 = rnorm(n = N, mean = 0, sd = 1)
#' X3 = rnorm(n = N, mean = 0, sd = 1)
#' X4 = rnorm(n = N, mean = 0, sd = 1)
#' X5 = rnorm(n = N, mean = 0, sd = 1)
#' X = cbind(X1, X2, X3, X4, X5)
#' colnames(X) = paste0("X", 1:5)
#'
#' # generate survival outcome variable
#' calculateLink = function(X, treatment){
#'
#' X[, 1] + 0.5 * X[, 3] + (3 * treatment - 1.5) * (abs(X[, 5]) - 0.67)
#' }
#'
#' Link = calculateLink(X, treatment)
#' T = rexp(N, exp(-Link))
#' C0 = rexp(N, 0.1 * exp(X[, 5] + X[, 2]))
#' Y = pmin(T, C0)
#' delta = (T <= C0)
#'
#' # combine variables in a data frame
#' data = data.frame(X, Y, delta, treatment)
#'
#' # fit a classification tree
#' tree3 = spmtree(Surv(Y, delta) ~ treatment | ., data, maxdepth = 2)
#'
#' #
#' # ... an example with a survival outcome variable
#' # and four treatment groups
#' #
#'
#' N = 800
#' set.seed(321)
#'
#' # generate treatments
#' treatment = sample(1:4, N, replace = TRUE)
#'
#' # generate candidate split variables
#' X1 = round(rnorm(n = N, mean = 0, sd = 1), 4)
#' X2 = round(rnorm(n = N, mean = 0, sd = 1), 4)
#' X3 = sample(1:4, N, replace = TRUE)
#' X4 = sample(1:5, N, replace = TRUE)
#' X5 = rbinom(N, 1, 0.5)
#' X6 = rbinom(N, 1, 0.5)
#' X7 = rbinom(N, 1, 0.5)
#' X = cbind(X1, X2, X3, X4, X5, X6, X7)
#' colnames(X) = paste0("X", 1:7)
#'
#' # generate survival outcome variable
#' calculateLink = function(X, treatment, noise){
#'
#' -0.2 * (treatment == 1) +
#' -1.1 * X[, 1] +
#' 1.2 * (treatment == 1) * X[, 1] +
#' 1.2 * X[, 2]
#' }
#'
#' Link = calculateLink(X, treatment)
#' T = rweibull(N, shape = 2, scale = exp(Link))
#' Cnoise = runif(n = N) + runif(n = N)
#' C0 = rexp(N, exp(0.3 * -Cnoise))
#' Y = pmin(T, C0)
#' delta = (T <= C0)
#'
#' # combine variables in a data frame
#' data = data.frame(X, Y, delta, treatment)
#'
#' # create vector of variable types
#' types = c(rep("ordinal", 2), rep("nominal", 2), rep("binary", 3),
#' "response", "status", "treatment")
#'
#' # fit two classification trees
#' tree4 = spmtree(Surv(Y, delta) ~ treatment | ., data, types = types, maxdepth = 2)
#' tree5 = spmtree(Surv(Y, delta) ~ treatment | X3 + X4, data, types = types,
#' maxdepth = 2)
#' }
#' @export
#' @import partykit
#' @import survival
#' @import stats
spmtree = function(formula,
data,
types = NULL,
nmin = 5,
maxdepth = Inf,
print = TRUE,
dataframe = FALSE,
prune = FALSE){
# check inputs
if(missing(formula)){
stop("The formula input is missing.")
}
if(missing(data)){
stop("The data input is missing.")
}
# coerce data input to R "data.frame" object
data = as.data.frame(data)
# coerce formula input to R "formula" object
form = as.formula(formula)
# if not missing, coerce types input to R "data.frame" object
if(missing(types) == FALSE){
if(!all(types %in% c("ordinal", "nominal", "binary",
"response", "status", "treatment"))){
stop("The type input is invalid.")
}
types = as.data.frame(types)
if(nrow(types) != 1){
types = t(types)
types = as.data.frame(types)
}
if(ncol(types) != ncol(data)){
stop("The number of variables in types does not equal the number of variables in the data.")
}
colnames(types) = colnames(data)
}
# get names of variables in the formula
form_vars = all.vars(form) # all variables
form_lhs = all.vars(form[[2]]) # variables to the left of ~
form_rhs = all.vars(form[[3]]) # variables to the right of ~
# response variable should always be (first) in lhs
Y = data[, form_lhs[1]]
if(!class(Y) %in% c("numeric", "integer")){
stop("Response Y must be numerical.")
}
# get status variable if applicable
if(length(form_lhs) == 1){
C = rep(0, nrow(data))
surv = 0
}
if(length(form_lhs) == 2){
C = data[, form_lhs[2]]
if(!class(C) %in% c("numeric", "integer", "logical")){
stop("delta must be integers.")
}
if(!all(unique(C) %in% c(0, 1))){
stop("delta must be 0 or 1.")
}
surv = 1
}
# treatment variable should always be first in rhs
treatment = data[, form_rhs[1]]
if(!class(treatment) %in% c("numeric", "integer")){
stop("Treatment must be integers.")
}
# determine appropriate method from data
ntrts = nlevels(as.factor(treatment))
if(maxdepth == Inf){
maxdepth = -7
}
if(ntrts <= 1){
stop("At least 2 treatment groups are required.")
}
if(ntrts == 2){
if(!all(unique(treatment) %in% c(0, 1))){
stop("Treatment must be 0 or 1 for two treatment groups.")
}
if(surv == 0){
method = -1
}else if(surv == 1){
method = 11
}
}else if(ntrts > 2){
if(!all(unique(treatment) %in% rep(1:ntrts))){
stop("Treatment must be 1 to ntrts for more than two treatment groups.")
}
if(surv == 0){
method = 24
}else if(surv == 1){
method = 25
}
}
# get matrix of candidate split variables X
if(form_rhs[2] == "."){ # account for Y ~ treatment | .
# formula
exclude = c(which(colnames(data) == form_lhs[1]), # Y variable
which(colnames(data) == form_rhs[1])) # treatment
if(length(form_lhs) == 2){ # exclude status indicator
exclude = c(exclude, which(colnames(data) == form_lhs[2]))
}
X = data.frame(data[, -exclude])
types = types[, -exclude]
}else{
include = which(colnames(data) %in% form_rhs[-1])
X = data.frame(data[, include])
types = types[, include]
}
# calculate number of observations n and variables nc
n = nrow(X)
nc = ncol(X)
if(nc == 1){
if( form_rhs[2] == "." ){
names(X) = names(data)[-exclude]
}else{
names(X) = names(data)[include]
}
}
# prepare types
if(is.null(types)){
types = rep(2, nc) # default is to assume all candidate
# split variables are ordinal
message("Note that all candidate split variables are assumed to be ordinal.")
}else{
if(nc == 1){
types = data.frame(types)
if( form_rhs[2] == "." ){
names(types) = names(data)[-exclude]
}else{
names(types) = names(data)[include]
}
rownames(types) = "types"
}
lll = ncol(types)
for (i in 1:lll){
if(types[i] == "binary") types[i] = 1
if(types[i] == "ordinal") types[i] = 2
if(types[i] == "nominal") types[i] = 3
}
}
ifbinary = any(types == 1)
if(ifbinary == TRUE){
ibin = which(types == 1)
if(length(ibin) == 1){
if(!class(data[, ibin]) %in% c("numeric", "integer")){
stop("Binary variables must be integers.")
}
if(!all(unique(data[, ibin]) %in% c(0, 1))){
stop("Binary variables must be 0 or 1.")
}
}else{
if(!all(apply(data[, ibin], 2, class) %in% c("numeric", "integer"))){
stop("Binary variables must be integers.")
}
if(!all(apply(data[, ibin], 2, unique) %in% c(0, 1))){
stop("Binary variables must be 0 or 1.")
}
}
}
ifordinal = any(types == 2)
if(ifordinal == TRUE){
iord = which(types == 2)
if(length(iord) == 1){
if(!class(data[, iord]) %in% c("numeric", "integer")){
stop("Ordinal variables must be numerical.")
}
}else{
if(!all(apply(data[, iord], 2, class) %in% c("numeric", "integer"))){
stop("Ordinal variables must be numerical.")
}
}
}
# create array of number of categories for nominal variables
ifnominal = any(types == 3)
if(ifnominal == TRUE){
inom = which(types == 3)
for(i in 1:length(inom)){
if(!class(X[, inom[i]]) %in% c("numeric", "integer")){
stop("Nominal variables must be integers.")
}
ncats = length(unique(X[, inom[i]]))
if(!all(unique(X[, inom[i]]) %in% rep(1:ncats))){
stop("Nominal must be 1 to ncats.")
}
X[, inom[i]] = factor(X[, inom[i]])
data[, colnames(X)[inom[i]]] = X[, inom[i]]
}
ncat = sapply(X, function(x)
if(is.null(levels(x))) -7
else max(as.numeric(levels(x)[x])))
}else{
ncat = rep(-7, nc)
}
# prepare covariate data
XC = t(X)
# set other unused parameter values to 0
ntree = 0
mtry = 0
nmin2 = 0
maxdepth2 = 0
# set types of R arguments to C
storage.mode(ntree) = "integer"
storage.mode(n) = "integer"
storage.mode(nc) = "integer"
storage.mode(Y) = "double"
storage.mode(XC) = "double"
storage.mode(types) = "integer"
storage.mode(ncat) = "integer"
storage.mode(treatment) = "integer"
storage.mode(C) = "integer"
storage.mode(nmin) = "integer"
storage.mode(nmin2) = "integer"
storage.mode(mtry) = "integer"
storage.mode(maxdepth) = "integer"
storage.mode(maxdepth2) = "integer"
storage.mode(method) = "integer"
tree = .Call("maketree",
ntree = ntree,
n = n,
nc = nc,
Y = Y,
X = XC,
types = types,
ncat = ncat,
treat = treatment,
censor = C,
nmin = nmin,
nmin2 = nmin2,
mtry = mtry,
maxdepth = maxdepth,
maxdepth2 = maxdepth2,
method = method,
environment(lm_R_to_C))
rm(XC)
# reformat tree
tree_txt = data.frame(as.vector(tree[[1]]),
as.vector(tree[[2]]),
as.vector(tree[[3]]),
as.vector(tree[[4]]),
as.vector(tree[[5]]),
as.vector(tree[[7]]),
as.vector(tree[[8]]),
as.vector(tree[[9]]),
as.vector(tree[[6]]),
as.vector(tree[[10]]),
as.vector(tree[[11]]),
as.vector(tree[[12]]),
as.vector(tree[[13]]),
as.vector(tree[[14]]),
as.vector(tree[[15]]),
as.vector(tree[[16]]),
as.vector(tree[[17]]))
colnames(tree_txt) = c("node",
"splitvar",
"type",
"sign",
"splitval",
"parent",
"lchild",
"rchild",
"depth",
"nsubj",
"ntrt0",
"ntrt1",
"r0",
"r1",
"p0",
"p1",
"besttrt")
# process tree output and/or print tree to screen
if(form_rhs[2] != "."){
splitvar_include = t(data.frame(include))
colnames(splitvar_include) = colnames(X)
}else{
splitvar_include = NULL
}
tree_txt = print.dipm(tree_txt, X, Y, C, treatment,
types, ncat, method, ntree, print,
splitvar_include)
if(prune){
tree_txt = pmprune(tree_txt)
}
if(dataframe){
return(tree_txt)
}else{
tree_pn = ini_node(1, tree_txt, data, form_rhs[1], surv)
if(surv){
tree_py = party(tree_pn, data,
fitted = data.frame(
"(fitted)" = fitted_node(tree_pn, data = data),
"(response)" = Surv(Y, C), check.names = F),
terms = terms(form))
}else{
tree_py = party(tree_pn, data,
fitted = data.frame(
"(fitted)" = fitted_node(tree_pn, data = data),
"(response)" = Y, check.names = F),
terms = terms(form))
}
return(tree_py)
}
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.