R/partykit.R

Defines functions survival_prob_partykit

Documented in survival_prob_partykit

#' A wrapper for survival probabilities with partykit models
#' @param object A model object from `partykit::ctree()` or `partykit::cforest()`.
#' @param new_data A data frame to be predicted.
#' @param eval_time A vector of times to predict the survival probability.
#' @param time Deprecated in favor of `eval_time`. A vector of times to predict the survival probability.
#' @param output Type of output. Can be either `"surv"` or `"haz"`.
#' @return A tibble with a list column of nested tibbles.
#' @export
#' @keywords internal
#' @examples
#' library(partykit)
#' c_tree <- ctree(Surv(time, status) ~ age + ph.ecog, data = lung)
#' survival_prob_partykit(c_tree, lung[1:3, ], eval_time = 100)
#' c_forest <- cforest(Surv(time, status) ~ age + ph.ecog, data = lung, ntree = 10)
#' survival_prob_partykit(c_forest, lung[1:3, ], eval_time = 100)
survival_prob_partykit <- function(object,
                                   new_data,
                                   eval_time,
                                   time = deprecated(),
                                   output = "surv") {
  if (lifecycle::is_present(time)) {
    lifecycle::deprecate_warn(
      "0.2.0",
      "survival_prob_partykit(time)",
      "survival_prob_partykit(eval_time)"
    )
    eval_time <- time
  }

  # don't use the confidence intervals: the KM does not know about
  # how the sample for it got constructed (by the tree) and thus standard errors
  # don't take that into account
  output <- rlang::arg_match(output, c("surv", "haz"))

  n_obs <- nrow(new_data)
  # partykit handles missing values
  missings_in_new_data <- NULL

  y <- predict(object, newdata = new_data, type = "prob")

  survfit_summary_list <- purrr::map(y, summary, times = eval_time, extend = TRUE)
  survfit_summary_combined <- combine_list_of_survfit_summary(
    survfit_summary_list,
    eval_time = eval_time
  )

  res <- survfit_summary_patch(
    survfit_summary_combined,
    index_missing = missings_in_new_data,
    eval_time = eval_time,
    n_obs = n_obs
  ) %>%
    survfit_summary_to_tibble(eval_time = eval_time, n_obs = n_obs) %>%
    keep_cols(output) %>%
    tidyr::nest(.pred = c(-.row)) %>%
    dplyr::select(-.row)

  res
}

Try the censored package in your browser

Any scripts or data that you put into this service are public.

censored documentation built on April 14, 2023, 12:30 a.m.