R/get_tree.R

Defines functions get_tree

Documented in get_tree

#' Extract split information for a user-specified tree
#'
#' @inheritParams compute_vimp
#' @param tree Integer indicating the tree identifier
#'
#' @return A table sorted by the node/leaf identifier with each row representing a node/leaf. Each column provides information about the splits:\tabular{ll}{
#' \code{type} \tab The nature of the predictor (\code{Longitudinal} for longitudinal predictor, \code{Numeric} for continuous predictor or \code{Factor} for categorical predictor) if the node was split, \code{Leaf} otherwise \cr
#' \tab \cr
#' \code{var_split} \tab The predictor used for the split defined by its order in \code{timeData} and \code{fixedData} \cr
#' \tab \cr
#' \code{feature} \tab The feature used for the split defined by its position in random statistic \cr
#' \tab \cr
#' \code{threshold} \tab The threshold used for the split (only with \code{Longitudinal} and \code{Numeric}). No information is returned for \code{Factor} \cr
#' \tab \cr
#' \code{N} \tab The number of subjects in the node/leaf \cr
#' \tab \cr
#' \code{Nevent} \tab The number of events of interest in the node/leaf (only with survival outcome) \cr
#' \tab \cr
#' \code{depth} \tab The depth level of the node/leaf \cr
#' }
#'
#' @seealso [dynforest()]
#'
#' @examples
#' \donttest{
#' data(pbc2)
#'
#' # Get Gaussian distribution for longitudinal predictors
#' pbc2$serBilir <- log(pbc2$serBilir)
#' pbc2$SGOT <- log(pbc2$SGOT)
#' pbc2$albumin <- log(pbc2$albumin)
#' pbc2$alkaline <- log(pbc2$alkaline)
#'
#' # Sample 100 subjects
#' set.seed(1234)
#' id <- unique(pbc2$id)
#' id_sample <- sample(id, 100)
#' id_row <- which(pbc2$id%in%id_sample)
#'
#' pbc2_train <- pbc2[id_row,]
#'
#' # Build longitudinal data
#' timeData_train <- pbc2_train[,c("id","time",
#'                                 "serBilir","SGOT",
#'                                 "albumin","alkaline")]
#'
#' # Create object with longitudinal association for each predictor
#' timeVarModel <- list(serBilir = list(fixed = serBilir ~ time,
#'                                      random = ~ time),
#'                      SGOT = list(fixed = SGOT ~ time + I(time^2),
#'                                  random = ~ time + I(time^2)),
#'                      albumin = list(fixed = albumin ~ time,
#'                                     random = ~ time),
#'                      alkaline = list(fixed = alkaline ~ time,
#'                                      random = ~ time))
#'
#' # Build fixed data
#' fixedData_train <- unique(pbc2_train[,c("id","age","drug","sex")])
#'
#' # Build outcome data
#' Y <- list(type = "surv",
#'           Y = unique(pbc2_train[,c("id","years","event")]))
#'
#' # Run dynforest function
#' res_dyn <- dynforest(timeData = timeData_train, fixedData = fixedData_train,
#'                      timeVar = "time", idVar = "id",
#'                      timeVarModel = timeVarModel, Y = Y,
#'                      ntree = 50, nodesize = 5, minsplit = 5,
#'                      cause = 2, ncores = 2, seed = 1234)
#'
#' # Extract split information from tree 4
#' res_tree4 <- get_tree(dynforest_obj = res_dyn, tree = 4)
#' }
#' @export
get_tree <- function(dynforest_obj, tree){

  # Only objects of class "dynforest" carry the `param` and `rf` elements
  # that are read below.
  if (!methods::is(dynforest_obj,"dynforest")){
    cli_abort(c(
      "{.var dynforest_obj} must be a dynforest object",
      "x" = "You've supplied a {.cls {class(dynforest_obj)}} object"
    ))
  }

  # is.numeric() is TRUE for both double (4) and integer (4L) identifiers;
  # inherits(tree, "numeric") would wrongly reject integer input because
  # class(4L) is "integer".
  if (!is.numeric(tree)){
    cli_abort(c(
      "{.var tree} must be a numeric object containing the tree identifier",
      "x" = "You've supplied a {.cls {class(tree)}} object"
    ))
  }

  ntree <- dynforest_obj$param$ntree

  # The identifier must be a single, non-missing value in 1..ntree.
  # The length and NA guards keep the condition a scalar TRUE/FALSE, so an
  # NA or vector input gets this informative error instead of an `if` failure
  # or a silently recycled comparison.
  if (length(tree) != 1L || is.na(tree) || !(tree %in% seq_len(ntree))){
    cli_abort(c(
      "{.var tree} must be chosen between 1 and {dynforest_obj$param$ntree}",
      "x" = "You've chosen {tree}"
    ))
  }

  # Select the column of `rf` for the requested tree and return its
  # `V_split` element (the split table described in the documentation).
  out <- dynforest_obj$rf[,tree]$V_split

  return(out)

}
anthonydevaux/DynForest documentation built on June 9, 2025, 11 p.m.