R/survTreeLaplaceHazards.R

Defines functions survTreeLaplaceHazard

Documented in survTreeLaplaceHazard

#' Laplace Hazards for a Competing Risk Survival Tree Object
#'
#' Predicts the Laplace-smoothed hazards of a discrete survival tree.
#' Can be used for single-risk or competing risk discrete survival data.
#'
#' For every terminal node of \code{treeModel} the class counts
#' \eqn{n_1, \dots, n_k} stored by \code{rpart} are smoothed to
#' \eqn{(n_j + \lambda) / (\sum_l n_l + k \lambda)}{(n_j + lambda) / (sum(n) + k * lambda)}.
#' Each observation in \code{newdata} receives the smoothed hazards of the
#' terminal node it falls into.
#'
#' @param treeModel Fitted classification tree as generated by "rpart" with
#' \code{method = "class"} ("class rpart").
#' @param newdata Data in long format for which hazards are to be computed. Must
#' contain the same columns that were used for tree fitting ("class data.frame").
#' @param lambda Smoothing parameter for Laplace smoothing. Must be a single,
#' finite, non-negative number. A value of 0 corresponds to no smoothing
#' ("numeric vector").
#' @return A m by k matrix with m being the number of rows of newdata and k
#' being the number of classes in treeModel. Each row corresponds to the
#' smoothed hazard of the respective observation. Observations that are not
#' assigned to a terminal node get a row of \code{NA}.
#'
#' @keywords survival
#' @examples
#' library(pec)
#' library(caret)
#' library(rpart)
#' # Example data
#' data(cost)
#' # Convert time to years and select training and testing subsample
#' cost$time <- ceiling(cost$time/365)
#' costTrain <- cost[1:100, ]
#' costTest  <- cost[101:120, ]
#' # Convert to long format
#' timeColumn <- "time"
#' eventColumn <- "status"
#' costTrainLong <- dataLong(dataShort=costTrain, timeColumn = "time", 
#'                           eventColumn = "status")
#' costTestLong  <- dataLong(dataShort=costTest, timeColumn = "time", 
#'                           eventColumn = "status")
#' head(costTrainLong)
#' # Fit a survival tree
#' costTree <- rpart(formula = y ~ timeInt + prevStroke + age + sex, data = costTrainLong, 
#'                   method = "class")
#' # Compute smoothed hazards for test data
#' predictedhazards <- survTreeLaplaceHazard(costTree, costTestLong, 1)
#' predictedhazards
#' @export survTreeLaplaceHazard
survTreeLaplaceHazard <- function(treeModel, newdata, lambda)
{
  # Input checks
  if (!is.list(treeModel) || is.null(treeModel$frame))
  {
    stop("Incorrect model. Please provide an object of type rpart.")
  }
  # Scalar checks with short-circuiting: `lambda < 0` is only evaluated for a
  # single finite number, so `if` never receives a vector or NA condition.
  if (!is.numeric(lambda) || length(lambda) != 1L || !is.finite(lambda) ||
      lambda < 0)
  {
    stop("Lambda must be a non-negative number.")
  }
  frame <- treeModel$frame
  is_leaf <- frame$var == "<leaf>"
  # Covariates used in at least one split ("<leaf>" marks terminal nodes)
  split_vars <- setdiff(as.character(frame$var), "<leaf>")
  if (!all(split_vars %in% colnames(newdata)))
  {
    stop("Newdata does not contain the same covariates as the tree model.")
  }
  # For classification trees, rpart stores per node in yval2: the fitted
  # class, one count column per class, one probability column per class and
  # the node probability, i.e. 2 * (number of classes) + 2 columns.
  yval2 <- frame$yval2
  if (!is.matrix(yval2) || ncol(yval2) < 4L || ncol(yval2) %% 2L != 0L)
  {
    stop("The tree model must be a classification tree (rpart method \"class\").")
  }
  # Number of risks (response classes), taken from the count columns so it
  # does not depend on treeModel$y being stored
  n_events <- (ncol(yval2) - 2L) %/% 2L
  # Node numbers of the terminal nodes, in frame row order
  leaf_nodes <- as.integer(rownames(frame))[is_leaf]
  # Class counts of each terminal node. Rows are selected by position (a
  # logical mask); indexing with a factor would use its sorted level codes.
  leaf_counts <- yval2[is_leaf, 1L + seq_len(n_events), drop = FALSE]
  # Laplace-smoothed hazards per terminal node. The divisor has one entry
  # per row and is recycled down the columns, so each row is normalised.
  hazards <- (leaf_counts + lambda) / (rowSums(leaf_counts) + lambda * n_events)
  # Row of `hazards` for each new observation (NA if no terminal node)
  leaf_row <- match(predict_leaves(treeModel, newdata), leaf_nodes)
  # drop = FALSE keeps the documented matrix shape for single-row newdata
  hazards[leaf_row, , drop = FALSE]
}
# Predict the node of an rpart tree object that each observation ends in.
#
# Adapted from the node lookup in rpart's predict method.
#   object     fitted rpart object (uses $frame, $where, $terms and the
#              "xlevels" attribute).
#   newdata    optional data; when missing, the training assignment
#              object$where is used. Data without a "terms" attribute is
#              first turned into a model frame with the tree's predictor
#              terms and factor levels, and its column classes are checked.
#   na.action  passed to model.frame() (default na.pass).
# Returns an integer vector of node numbers (the row names of object$frame).
predict_leaves <- function(object, newdata, na.action = na.pass) {
  node_ids <- as.integer(row.names(object$frame))
  if (missing(newdata)) {
    return(node_ids[object$where])
  }
  if (is.null(attr(newdata, "terms"))) {
    x_terms <- delete.response(object$terms)
    newdata <- model.frame(x_terms, newdata, na.action = na.action,
                           xlev = attr(object, "xlevels"))
    data_classes <- attr(x_terms, "dataClasses")
    if (!is.null(data_classes)) {
      .checkMFClasses(data_classes, newdata, TRUE)
    }
  }
  # pred.rpart returns, per observation, a row position in object$frame
  frame_rows <- pred.rpart(object, rpart.matrix(newdata))
  node_ids[frame_rows]
}

# rpart does not export its tree-traversal helpers; bind the internal
# versions used by predict_leaves() to map observations to tree nodes.
pred.rpart <- utils::getFromNamespace(x = "pred.rpart", ns = "rpart")
rpart.matrix <- utils::getFromNamespace(x = "rpart.matrix", ns = "rpart")

Try the discSurv package in your browser

Any scripts or data that you put into this service are public.

discSurv documentation built on March 18, 2022, 7:12 p.m.