# R/spmtree.R

#' @title Simple Precision Medicine Tree
#' @description This function creates a classification tree
#'              designed to identify subgroups in which subjects
#'              perform especially well or especially poorly in a
#'              given treatment group.
#'
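#' @param formula A description of the model to be fit with format
#'                \code{Y ~ treatment | X1 + X2} for data with a
#'                continuous outcome variable Y and
#'                \code{Surv(Y, delta) ~ treatment | X1 + X2} for data with
#'                a right-censored survival outcome variable Y and
#'                a status indicator delta; the treatment variable must
#'                come first after \code{~}, and \code{.} may be used
#'                after \code{|} to use all remaining variables in
#'                \code{data} as candidate split variables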
#' @param data A matrix or data frame containing the outcome,
#'             treatment, status (if applicable), and candidate split
#'             variables
#' @param types A vector, data frame, or matrix of the types
#'              of each variable in the data; if left blank, the
#'              default is to assume all of the candidate split
#'              variables are ordinal; otherwise, the types of all
#'              variables in the data must be specified, in the same
#'              order as the columns of \code{data}, and the possible
#'              variable types are "response", "treatment", "status",
#'              "binary", "ordinal", and "nominal" for the outcome
#'              variable Y, the treatment variable, the status
#'              indicator (if applicable), binary candidate split
#'              variables, ordinal candidate split variables, and
#'              nominal candidate split variables, respectively
#' @param nmin An integer specifying the minimum node size of
#'             the overall classification tree
#' @param maxdepth An integer specifying the maximum depth of the
#'                 overall classification tree; this argument is 
#'                 optional but useful for shortening computation 
#'                 time; if left blank, the default is to grow the 
#'                 full tree until the minimum node size \code{nmin} 
#'                 is reached
#' @param print A boolean (TRUE/FALSE) value, where TRUE prints
#'              a more readable version of the final tree to the
#'              screen
#' @param dataframe A boolean (TRUE/FALSE) value, where TRUE returns
#'                  the final tree as a data frame instead of a
#'                  \code{party} object
#' @param prune A boolean (TRUE/FALSE) value, where TRUE prunes
#'              the final tree using the \code{pmprune} function
#'
#' @details To identify the best split at each node of the
#'          classification tree, all possible splits of all
#'          candidate split variables are considered. The single
#'          split with the highest split criterion score is
#'          identified as the best split of the node. For data with
#'          a continuous outcome variable, the split criterion is the
#'          DIFF value that was first proposed for use in the
#'          relative-effectiveness based method (Zhang et al. (2010),
#'          Tsai et al. (2016)). For data with a survival outcome
#'          variable, the split criterion is the squared test
#'          statistic that tests the significance of the split by
#'          treatment interaction term in a Cox proportional hazards
#'          model.
#'
#'          When using \code{spmtree}, note the following
#'          requirements for the supplied data. First, the dataset
#'          must contain an outcome variable Y and a treatment
#'          variable. If Y is a right-censored survival time
#'          outcome, then there must also be a status indicator
#'          delta, where values of 1 denote the occurrence of the 
#'          (harmful) event of interest, and values of 0 denote
#'          censoring. If there are only two treatment groups, then
#'          the two possible values must be 0 or 1. If there are
#'          more than two treatment groups, then the possible values
#'          must be integers starting from 1 to the total number of
#'          treatment assignments. Regarding the candidate split
#'          variables, if a variable is binary, then the variable
#'          must take values of 0 or 1. If a variable is nominal,
#'          then the values must be integers starting from 1 to the
#'          total number of categories. There cannot be any missing
#'          values in the dataset. For candidate split variables
#'          with missing values, the missings together (MT) method
#'          proposed by Zhang et al. (1996) may be helpful.
#'
#' @return \code{spmtree} returns the final classification tree as a
#'         \code{party} object by default, or as a data frame if
#'         \code{dataframe = TRUE}. See Hothorn and Zeileis (2015) for
#'         details on \code{party} objects. The data frame contains
#'         the following columns of information:
#'         \item{node}{Unique integer values that identify each node
#'                     in the tree, where all of the nodes are
#'                     indexed starting from 1}
#'         \item{splitvar}{Integers that represent the candidate split
#'                         variable used to split each node, where
#'                         all of the variables are indexed starting
#'                         from 1; for terminal nodes, i.e., nodes
#'                         without child nodes, the value is set 
#'                         equal to NA}
#'         \item{splitvar_name}{The names of the candidate split 
#'                              variables used to split each node
#'                              obtained from the column names of the
#'                              supplied data; for terminal nodes,
#'                              the value is set equal to NA}
#'         \item{type}{Characters that denote the type of each 
#'                     candidate split variable; "bin" is for binary
#'                     variables, "ord" for ordinal, and "nom" for
#'                     nominal; for terminal nodes, the value is set
#'                     equal to NA}
#'         \item{splitval}{The value(s) defining the left child node
#'                         of the current split; for binary variables,
#'                         a value of 0 is printed, and subjects with
#'                         values of 0 for the current \code{splitvar}
#'                         are in the left child node, while subjects
#'                         with values of 1 are in the right child
#'                         node; for ordinal variables,
#'                         \code{splitval} is numeric and implies
#'                         that subjects with values of the current
#'                         \code{splitvar} less than or equal to
#'                         \code{splitval} are in the left child 
#'                         node, while the remaining subjects with 
#'                         values greater than \code{splitval} are in 
#'                         the right child node; for nominal
#'                         variables, the \code{splitval} is a set of
#'                         integers separated by commas, and subjects
#'                         in that set of categories are in the left
#'                         child node, while the remaining subjects
#'                         are in the right child node; for terminal
#'                         nodes, the value is set equal to NA}
#'         \item{lchild}{Integers that represent the index (i.e.,
#'                       \code{node} value) of each node's left
#'                       child node; for terminal nodes, the value is
#'                       set equal to NA}
#'         \item{rchild}{Integers that represent the index (i.e.,
#'                       \code{node} value) of each node's right
#'                       child node; for terminal nodes, the value is
#'                       set equal to NA}
#'         \item{depth}{Integers that specify the depth of each
#'                      node; the root node has depth 1, its 
#'                      children have depth 2, etc.}
#'         \item{nsubj}{Integers that count the total number of
#'                      subjects within each node}
#'         \item{besttrt}{Integers that denote the identified best 
#'                        treatment assignment of each node}
#'
#' @references Chen, V., Li, C., and Zhang, H. (2022). dipm: an 
#'             R package implementing the Depth Importance in 
#'             Precision Medicine (DIPM) tree and Forest-based method.
#'             \emph{Bioinformatics Advances}, \strong{2}(1), vbac041.
#'             
#'             Chen, V. and Zhang, H. (2022). Depth importance in 
#'             precision medicine (DIPM): A tree-and forest-based 
#'             method for right-censored survival outcomes. 
#'             \emph{Biostatistics} \strong{23}(1), 157-172.
#'             
#'             Chen, V. and Zhang, H. (2020). Depth importance in 
#'             precision medicine (DIPM): a tree and forest based method. 
#'             In \emph{Contemporary Experimental Design, 
#'             Multivariate Analysis and Data Mining}, 243-259.
#' 
#'             Tsai, W.-M., Zhang, H., Buta, E., O'Malley, S., and
#'             Gueorguieva, R. (2016). A modified classification
#'             tree method for personalized medicine decisions.
#'             \emph{Statistics and its Interface} \strong{9}, 
#'             239-253.
#'
#'             Zhang, H., Holford, T., and Bracken, M.B. (1996).
#'             A tree-based method of analysis for prospective
#'             studies. \emph{Statistics in Medicine} \strong{15},
#'             37-49.
#'
#'             Zhang, H., Legro, R.S., Zhang, J., Zhang, L., Chen,
#'             X., et al. (2010). Decision trees for identifying
#'             predictors of treatment effectiveness in clinical
#'             trials and its application to ovulation in a study of
#'             women with polycystic ovary syndrome. \emph{Human
#'             Reproduction} \strong{25}, 2612-2621.
#'             
#'             Hothorn, T. and Zeileis, A. (2015). partykit: 
#'             a modular toolkit for recursive partytioning in R. 
#'             \emph{The Journal of Machine Learning Research} 
#'             \strong{16}(1), 3905-3909.
#'
#' @seealso \code{\link{dipm}}
#'
#' @examples
#' 
#' #
#' # ... an example with a continuous outcome variable
#' #     and two treatment groups
#' #
#'
#' N = 300
#' set.seed(123)
#'
#' # generate binary treatments
#' treatment = rbinom(N, 1, 0.5)
#'
#' # generate candidate split variables
#' X1 = rnorm(n = N, mean = 0, sd = 1)
#' X2 = rnorm(n = N, mean = 0, sd = 1)
#' X3 = rnorm(n = N, mean = 0, sd = 1)
#' X4 = rnorm(n = N, mean = 0, sd = 1)
#' X5 = rnorm(n = N, mean = 0, sd = 1)
#' X = cbind(X1, X2, X3, X4, X5)
#' colnames(X) = paste0("X", 1:5)
#'
#' # generate continuous outcome variable
#' calculateLink = function(X, treatment){
#'
#'     ((X[, 1] <= 0) & (X[, 2] <= 0)) *
#'         (25 * (1 - treatment) + 8 * treatment) + 
#'
#'     ((X[, 1] <= 0) & (X[, 2] > 0)) *
#'         (18 * (1 - treatment) + 20 * treatment) +
#'
#'     ((X[, 1] > 0) & (X[, 3] <= 0)) *
#'         (20 * (1 - treatment) + 18 * treatment) + 
#'
#'     ((X[, 1] > 0) & (X[, 3] > 0)) *
#'         (8 * (1 - treatment) + 25 * treatment)
#' }
#'
#' Link = calculateLink(X, treatment)
#' Y = rnorm(N, mean = Link, sd = 1)
#'
#' # combine variables in a data frame
#' data = data.frame(X, Y, treatment)
#' 
#' # fit a classification tree
#' tree1 = spmtree(Y ~ treatment | ., data, maxdepth = 3)
#' # predict optimal treatment for new subjects
#' predict(tree1, newdata = head(data),
#'         FUN = function(n) as.numeric(n$info$opt_trt))
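#'
#' # A brief sketch of the other output options documented above:
#' # dataframe = TRUE returns the tree as the data frame described in
#' # the Value section (print = FALSE suppresses the printed tree), and
#' # prune = TRUE prunes the fitted tree with pmprune() before returning it.
#' tree1_df = spmtree(Y ~ treatment | ., data, maxdepth = 3,
#'                    print = FALSE, dataframe = TRUE)
#' head(tree1_df)
#' tree1_pruned = spmtree(Y ~ treatment | ., data, maxdepth = 3,
#'                        prune = TRUE)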
#'
#'\donttest{
#' #
#' # ... an example with a continuous outcome variable
#' #     and three treatment groups
#' #
#' 
#' N = 600
#' set.seed(123)
#' 
#' # generate treatments
#' treatment = sample(1:3, N, replace = TRUE)
#' 
#' # generate candidate split variables
#' X1 = round(rnorm(n = N, mean = 0, sd = 1), 4)
#' X2 = round(rnorm(n = N, mean = 0, sd = 1), 4)
#' X3 = sample(1:4, N, replace = TRUE)
#' X4 = sample(1:5, N, replace = TRUE)
#' X5 = rbinom(N, 1, 0.5)
#' X6 = rbinom(N, 1, 0.5)
#' X7 = rbinom(N, 1, 0.5)
#' X = cbind(X1, X2, X3, X4, X5, X6, X7)
#' colnames(X) = paste0("X", 1:7)
#' 
#' # generate continuous outcome variable
#' calculateLink = function(X, treatment){
#' 
#'     10.2 - 0.3 * (treatment == 1) - 0.1 * X[, 1] + 
#'     2.1 * (treatment == 1) * X[, 1] +
#'     1.2 * X[, 2]
#' }
#' 
#' Link = calculateLink(X, treatment)
#' Y = rnorm(N, mean = Link, sd = 1)
#' 
#' # combine variables in a data frame
#' data = data.frame(X, Y, treatment)
#' 
#' # create vector of variable types
#' types = c(rep("ordinal", 2), rep("nominal", 2), rep("binary", 3),
#'         "response", "treatment")
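#'
#' # An optional sanity check (a sketch; spmtree also validates its
#' # inputs) of the data requirements listed in Details: no missing
#' # values, one type per column, treatments coded 1 to 3, nominal
#' # variables coded 1 to the number of categories, binary variables 0/1.
#' stopifnot(!anyNA(data),
#'           length(types) == ncol(data),
#'           all(is.element(data$treatment, 1:3)),
#'           all(is.element(data$X3, 1:4)),
#'           all(is.element(data$X4, 1:5)),
#'           all(is.element(unlist(data[, c("X5", "X6", "X7")]), 0:1)))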
#' 
#' # fit a classification tree
#' tree2 = spmtree(Y ~ treatment | ., data, types = types)
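#'
#' # To illustrate what "all possible splits" can mean for the nominal
#' # variable X3 (categories 1 to 4), this sketch lists the left-node
#' # category set of every binary partition, keeping category 1 on the
#' # left so mirror-image partitions are not counted twice; this gives
#' # 2^(4 - 1) - 1 = 7 candidate splits. It is an illustration only; the
#' # compiled code inside spmtree may enumerate splits differently.
#' k = 4
#' left_sets = lapply(0:(2^(k - 1) - 2), function(b)
#'     c(1, (2:k)[bitwAnd(b, 2^(0:(k - 2))) > 0]))
#' sapply(left_sets, paste, collapse = ",")
#' # an ordinal variable such as X1 has at most one candidate cut point
#' # per distinct value, excluding the largest value
#' length(unique(data$X1)) - 1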
#' 
#' #
#' # ... an example with a survival outcome variable
#' #     and two treatment groups
#' #
#'
#' N = 300
#' set.seed(321)
#'
#' # generate binary treatments
#' treatment = rbinom(N, 1, 0.5)
#'
#' # generate candidate split variables
#' X1 = rnorm(n = N, mean = 0, sd = 1)
#' X2 = rnorm(n = N, mean = 0, sd = 1)
#' X3 = rnorm(n = N, mean = 0, sd = 1)
#' X4 = rnorm(n = N, mean = 0, sd = 1)
#' X5 = rnorm(n = N, mean = 0, sd = 1)
#' X = cbind(X1, X2, X3, X4, X5)
#' colnames(X) = paste0("X", 1:5)
#'
#' # generate survival outcome variable
#' calculateLink = function(X, treatment){
#'
#'     X[, 1] + 0.5 * X[, 3] + (3 * treatment - 1.5) * (abs(X[, 5]) - 0.67)
#' }
#'
#' Link = calculateLink(X, treatment)
#' T = rexp(N, exp(-Link))
#' C0 = rexp(N, 0.1 * exp(X[, 5] + X[, 2]))
#' Y = pmin(T, C0)
#' delta = (T <= C0)
#'
#' # combine variables in a data frame
#' data = data.frame(X, Y, delta, treatment)
#' 
#' # fit a classification tree
#' tree3 = spmtree(Surv(Y, delta) ~ treatment | ., data, maxdepth = 2)
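#'
#' # A hedged sketch of the survival split criterion described in Details,
#' # evaluated for one hypothetical candidate split (X5 <= 0.67): fit a Cox
#' # model with a split-by-treatment interaction and square the Wald z
#' # statistic of the interaction term. The compiled code in spmtree may
#' # use a different test statistic, so this need not match its scores.
#' cut_left = data$X5 <= 0.67
#' cox_fit = survival::coxph(survival::Surv(Y, delta) ~ treatment * cut_left,
#'                           data = data)
#' summary(cox_fit)$coefficients["treatment:cut_leftTRUE", "z"]^2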
#' 
#' #
#' # ... an example with a survival outcome variable
#' #     and four treatment groups
#' #
#' 
#' N = 800
#' set.seed(321)
#' 
#' # generate treatments
#' treatment = sample(1:4, N, replace = TRUE)
#' 
#' # generate candidate split variables
#' X1 = round(rnorm(n = N, mean = 0, sd = 1), 4)
#' X2 = round(rnorm(n = N, mean = 0, sd = 1), 4)
#' X3 = sample(1:4, N, replace = TRUE)
#' X4 = sample(1:5, N, replace = TRUE)
#' X5 = rbinom(N, 1, 0.5)
#' X6 = rbinom(N, 1, 0.5)
#' X7 = rbinom(N, 1, 0.5)
#' X = cbind(X1, X2, X3, X4, X5, X6, X7)
#' colnames(X) = paste0("X", 1:7)
#' 
#' # generate survival outcome variable
#' calculateLink = function(X, treatment){
#' 
#'     -0.2 * (treatment == 1) -
#'     1.1 * X[, 1] +
#'     1.2 * (treatment == 1) * X[, 1] +
#'     1.2 * X[, 2]
#' }
#' 
#' Link = calculateLink(X, treatment)
#' T = rweibull(N, shape = 2, scale = exp(Link))
#' Cnoise = runif(n = N) + runif(n = N)
#' C0 = rexp(N, exp(0.3 * -Cnoise))
#' Y = pmin(T, C0)
#' delta = (T <= C0)
#' 
#' # combine variables in a data frame
#' data = data.frame(X, Y, delta, treatment)
#' 
#' # create vector of variable types
#' types = c(rep("ordinal", 2), rep("nominal", 2), rep("binary", 3),
#'         "response", "status", "treatment")
#' 
#' # fit two classification trees
#' tree4 = spmtree(Surv(Y, delta) ~ treatment | ., data, types = types,
#'                 maxdepth = 2)
#' tree5 = spmtree(Surv(Y, delta) ~ treatment | X3 + X4, data, types = types,
#'                 maxdepth = 2)
#' }
#' @export
#' @import partykit
#' @import survival
#' @import stats

spmtree = function(formula,
                   data,
                   types = NULL,
                   nmin = 5,
                   maxdepth = Inf,
                   print = TRUE,
                   dataframe = FALSE,
                   prune = FALSE){

#    check inputs
    if(missing(formula)){
        stop("The formula input is missing.")
    }

    if(missing(data)){
        stop("The data input is missing.")
    }

#    coerce data input to R "data.frame" object
    data = as.data.frame(data)

#    coerce formula input to R "formula" object
    form = as.formula(formula)

#    if supplied, coerce types input to R "data.frame" object
    if(!is.null(types)){

        if(!all(types %in% c("ordinal", "nominal", "binary",
                             "response", "status", "treatment"))){
            stop("The types input contains an invalid variable type.")
        }

        types = as.data.frame(types)

        if(nrow(types) != 1){
            types = t(types)
            types = as.data.frame(types)
        }

        if(ncol(types) != ncol(data)){
            stop("The number of variables in types does not equal the number of variables in the data.")
        }

        colnames(types) = colnames(data)
    }

#    get names of variables in the formula
    form_vars = all.vars(form)      # all variables
    form_lhs = all.vars(form[[2]])  # variables to the left of ~
    form_rhs = all.vars(form[[3]])  # variables to the right of ~

#    response variable should always be (first) in lhs
    Y = data[, form_lhs[1]]
    if(!is.numeric(Y)){
        stop("Response Y must be numerical.")
    }

#    get status variable if applicable
    if(length(form_lhs) == 1){
        C = rep(0, nrow(data))
        surv = 0
    }

    if(length(form_lhs) == 2){
        C = data[, form_lhs[2]]
        if(!(is.numeric(C) || is.logical(C))){
            stop("delta must be numeric or logical.")
        }
        if(!all(unique(C) %in% c(0, 1))){
            stop("delta must be 0 or 1.")
        }
        surv = 1
    }

#    treatment variable should always be first in rhs
    treatment = data[, form_rhs[1]]
    if(!is.numeric(treatment)){
        stop("Treatment must be integers.")
    }

#    determine appropriate method from data
    ntrts = nlevels(as.factor(treatment))
    
#    a maxdepth of -7 tells the compiled code to grow the full tree
    if(maxdepth == Inf){
        maxdepth = -7
    }

    if(ntrts <= 1){
        stop("At least 2 treatment groups are required.")
    }

    if(ntrts == 2){
        
        if(!all(unique(treatment) %in% c(0, 1))){
            stop("Treatment must be 0 or 1 for two treatment groups.")
        }

        if(surv == 0){

            method = -1

        }else if(surv == 1){

            method = 11
        }

    }else if(ntrts > 2){
        
        if(!all(unique(treatment) %in% 1:ntrts)){
            stop("Treatment must be 1 to ntrts for more than two treatment groups.")
        }

        if(surv == 0){

            method = 24

        }else if(surv == 1){

            method = 25
        }
    }

#    get matrix of candidate split variables X
    if(form_rhs[2] == "."){  # account for the Y ~ treatment | . formula

        exclude = c(which(colnames(data) == form_lhs[1]), # Y variable
                  which(colnames(data) == form_rhs[1])) # treatment

        if(length(form_lhs) == 2){ # exclude status indicator

            exclude = c(exclude, which(colnames(data) == form_lhs[2]))
        }

        X = data.frame(data[, -exclude])
        types = types[, -exclude]
    }else{

        include = which(colnames(data) %in% form_rhs[-1])

        X = data.frame(data[, include])
        types = types[, include]
    }

#    calculate number of observations n and variables nc
    n = nrow(X)
    nc = ncol(X)
    if(nc == 1){
        if( form_rhs[2] == "." ){
            names(X) = names(data)[-exclude]
        }else{
            names(X) = names(data)[include]
        }
    }

#    prepare types
    if(is.null(types)){  

        # default is to assume all candidate split variables are ordinal
        types = rep(2, nc)

        message("Note that all candidate split variables are assumed to be ordinal.")

    }else{
        if(nc == 1){
            types = data.frame(types)
            if( form_rhs[2] == "." ){
                names(types) = names(data)[-exclude]
            }else{
                names(types) = names(data)[include]
            }
            rownames(types) = "types"
        }
        lll = ncol(types)
        for (i in 1:lll){
            if(types[i] == "binary") types[i] = 1
            if(types[i] == "ordinal") types[i] = 2
            if(types[i] == "nominal") types[i] = 3
        }
    }

#    check binary and ordinal candidate split variables; the positions in
#    types refer to the columns of X, so X (not data) is indexed here
    ifbinary = any(types == 1)
    if(ifbinary == TRUE){
        ibin = which(types == 1)
        if(!all(sapply(X[ibin], is.numeric))){
            stop("Binary variables must be integers.")
        }
        if(!all(unlist(X[ibin]) %in% c(0, 1))){
            stop("Binary variables must be 0 or 1.")
        }
    }

    ifordinal = any(types == 2)
    if(ifordinal == TRUE){
        iord = which(types == 2)
        if(!all(sapply(X[iord], is.numeric))){
            stop("Ordinal variables must be numerical.")
        }
    }
    
#    create array of number of categories for nominal variables
    ifnominal = any(types == 3)
    if(ifnominal == TRUE){

        inom = which(types == 3)
        for(i in 1:length(inom)){
            if(!is.numeric(X[, inom[i]])){
                stop("Nominal variables must be integers.")
            }
            ncats = length(unique(X[, inom[i]]))
            if(!all(unique(X[, inom[i]]) %in% 1:ncats)){
                stop("Nominal variables must be coded 1 to ncats.")
            }
            X[, inom[i]] = factor(X[, inom[i]])
            data[, colnames(X)[inom[i]]] = X[, inom[i]]
        }

        ncat = sapply(X, function(x) 
                    if(is.null(levels(x))) -7 
                    else max(as.numeric(levels(x)[x])))

    }else{
        ncat = rep(-7, nc)
    }

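#    prepare covariate data; data.matrix() keeps numeric columns at full
#    precision and maps nominal factors (levels 1 to ncats) back to their
#    integer codes, which equal the original category values
    XC = t(data.matrix(X))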

#    set other unused parameter values to 0
    ntree = 0
    mtry = 0
    nmin2 = 0
    maxdepth2 = 0

#    set types of R arguments to C
    storage.mode(ntree) = "integer"
    storage.mode(n) = "integer"
    storage.mode(nc) = "integer"
    storage.mode(Y) = "double"
    storage.mode(XC) = "double"
    storage.mode(types) = "integer"
    storage.mode(ncat) = "integer"
    storage.mode(treatment) = "integer"
    storage.mode(C) = "integer"
    storage.mode(nmin) = "integer"
    storage.mode(nmin2) = "integer"
    storage.mode(mtry) = "integer"
    storage.mode(maxdepth) = "integer"
    storage.mode(maxdepth2) = "integer"
    storage.mode(method) = "integer"

    tree = .Call("maketree",
               ntree = ntree,
               n = n,
               nc = nc,
               Y = Y,
               X = XC,
               types = types,
               ncat = ncat,
               treat = treatment,
               censor = C,
               nmin = nmin,
               nmin2 = nmin2,
               mtry = mtry,
               maxdepth = maxdepth,
               maxdepth2 = maxdepth2,
               method = method,
               environment(lm_R_to_C))

    rm(XC)

#    reformat tree
    tree_txt = data.frame(as.vector(tree[[1]]),
                        as.vector(tree[[2]]),
                        as.vector(tree[[3]]),
                        as.vector(tree[[4]]),
                        as.vector(tree[[5]]),
                        as.vector(tree[[7]]),
                        as.vector(tree[[8]]),
                        as.vector(tree[[9]]),
                        as.vector(tree[[6]]),
                        as.vector(tree[[10]]),
                        as.vector(tree[[11]]),
                        as.vector(tree[[12]]),
                        as.vector(tree[[13]]),
                        as.vector(tree[[14]]),
                        as.vector(tree[[15]]),
                        as.vector(tree[[16]]),
                        as.vector(tree[[17]]))

    colnames(tree_txt) = c("node",
                         "splitvar",
                         "type",
                         "sign",
                         "splitval",
                         "parent",
                         "lchild",
                         "rchild",
                         "depth",
                         "nsubj",
                         "ntrt0",
                         "ntrt1",
                         "r0",
                         "r1",
                         "p0",
                         "p1",
                         "besttrt")

#    process tree output and/or print tree to screen
    if(form_rhs[2] != "."){
        splitvar_include = t(data.frame(include))
        colnames(splitvar_include) = colnames(X)
    }else{
        splitvar_include = NULL
    }
    tree_txt = print.dipm(tree_txt, X, Y, C, treatment,
                        types, ncat, method, ntree, print,
                        splitvar_include)
    if(prune){
        tree_txt = pmprune(tree_txt)
    }
    
    if(dataframe){
        return(tree_txt)
    }else{
        tree_pn = ini_node(1, tree_txt, data, form_rhs[1], surv)

#        store a Surv object as the response for survival outcomes and
#        the numeric outcome otherwise
        if(surv){
            response = Surv(Y, C)
        }else{
            response = Y
        }
        tree_py = party(tree_pn, data,
                        fitted = data.frame(
                            "(fitted)" = fitted_node(tree_pn, data = data),
                            "(response)" = response, check.names = FALSE),
                        terms = terms(form))

        return(tree_py)
    }

}

# dipm documentation built on Oct. 29, 2022, 1:09 a.m.