predict.orsf_fit: Compute predictions using ORSF

View source: R/orsf_predict.R

predict.orsf_fit    R Documentation

Compute predictions using ORSF

Description

Predicted risk, survival, hazard, or mortality from an ORSF model.

Usage

## S3 method for class 'orsf_fit'
predict(
  object,
  new_data,
  pred_horizon = NULL,
  pred_type = "risk",
  na_action = "fail",
  boundary_checks = TRUE,
  n_thread = 1,
  verbose_progress = FALSE,
  pred_aggregate = TRUE,
  ...
)

Arguments

object

(orsf_fit) a trained oblique random survival forest (see orsf).

new_data

a data.frame, tibble, or data.table to compute predictions in.

pred_horizon

(double) a value or vector indicating the time(s) that predictions will be calibrated to. E.g., if you were predicting risk of incident heart failure within the next 10 years, then pred_horizon = 10. pred_horizon can be NULL if pred_type is 'mort', since mortality predictions are aggregated over all event times.

pred_type

(character) the type of predictions to compute. Valid options are

  • 'risk' : probability of having an event at or before pred_horizon.

  • 'surv' : 1 - risk.

  • 'chf' : cumulative hazard function.

  • 'mort' : mortality prediction.

na_action

(character) what should happen when new_data contains missing values (i.e., NA values). Valid options are:

  • 'fail' : an error is thrown if new_data contains NA values

  • 'pass' : the output will have NA in every row where new_data has one or more NA values for the predictors used by object

  • 'omit' : rows in new_data with incomplete data will be dropped

  • 'impute_meanmode' : missing values for continuous and categorical variables in new_data will be imputed using the mean and mode, respectively. To clarify, the mean and mode used to impute missing values are from the training data of object, not from new_data.
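The effect of each na_action option can be seen on a small example. The following is a minimal sketch rather than a definitive recipe: it assumes the pbc_orsf data shipped with aorsf, uses n_tree = 50 only to keep the fit fast, and introduces a missing bili value in a hypothetical copy of the data named new_data_na.

```r
library(aorsf)

set.seed(329730)

# small forest; n_tree = 50 is an arbitrary choice to keep this fast
fit <- orsf(data = pbc_orsf,
            formula = Surv(time, status) ~ . - id,
            n_tree = 50)

# hypothetical new data with one missing predictor value in row 2
new_data_na <- pbc_orsf[1:5, ]
new_data_na$bili[2] <- NA

# 'pass': row 2 of the output is NA, the other rows are predicted as usual
pred_pass <- predict(fit, new_data = new_data_na,
                     pred_horizon = 1000, na_action = 'pass')

# 'omit': row 2 is dropped, so the output has one row fewer than new_data_na
pred_omit <- predict(fit, new_data = new_data_na,
                     pred_horizon = 1000, na_action = 'omit')

# 'impute_meanmode': the missing bili value is filled in with the
# mean of bili from the training data before predicting
pred_impute <- predict(fit, new_data = new_data_na,
                       pred_horizon = 1000, na_action = 'impute_meanmode')

nrow(pred_pass)
nrow(pred_omit)
nrow(pred_impute)
```

With 'omit', the rows of the output no longer line up one-to-one with the rows of new_data, so 'pass' may be the more convenient choice when predictions are joined back to the original data.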

boundary_checks

(logical) if TRUE, pred_horizon will be checked to make sure the requested values are less than the maximum observed time in object's training data. If FALSE, these checks are skipped.

n_thread

(integer) number of threads to use while computing predictions. Default is one thread. To use the maximum number of threads that your system provides for concurrent execution, set n_thread = 0.

verbose_progress

(logical) if TRUE, progress messages are printed in the console. If FALSE (the default), nothing is printed.

pred_aggregate

(logical) if TRUE (the default), predictions will be aggregated over all trees by taking the mean. If FALSE, the returned output will be a matrix with one row per observation and one column for each tree. If the length of pred_horizon is two or more and pred_aggregate is FALSE, then the result will be a list of such matrices, with the i'th item in the list corresponding to the i'th value of pred_horizon.
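As a rough illustration of the shapes described for pred_aggregate, the sketch below compares aggregated and per-tree output. It assumes the pbc_orsf data shipped with aorsf and an arbitrary n_tree = 50 so that the number of columns is easy to recognize; the object names are only for illustration.

```r
library(aorsf)

set.seed(329730)

# small forest; n_tree = 50 is an arbitrary choice for this sketch
fit <- orsf(data = pbc_orsf,
            formula = Surv(time, status) ~ . - id,
            n_tree = 50)

new_data <- pbc_orsf[1:5, ]

# aggregated predictions: one row per observation, one column per horizon
pred_agg <- predict(fit, new_data = new_data, pred_horizon = 1000)

# per-tree predictions: one row per observation, one column per tree
pred_tree <- predict(fit, new_data = new_data, pred_horizon = 1000,
                     pred_aggregate = FALSE)
dim(pred_tree)

# if aggregation is a plain mean over trees, these should agree closely
all.equal(rowMeans(pred_tree), as.numeric(pred_agg))

# two horizons without aggregation: a list with one matrix per horizon
pred_tree_list <- predict(fit, new_data = new_data,
                          pred_horizon = c(500, 1000),
                          pred_aggregate = FALSE)
length(pred_tree_list)
```

Per-tree output can be useful for inspecting how much individual trees disagree, for example by taking row-wise quantiles of pred_tree.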

...

Further arguments passed to or from other methods (not currently used).

Details

new_data must have the same columns with equivalent types as the data used to train object. Also, factors in new_data must not have levels that were not in the data used to train object.

pred_horizon values should not exceed the maximum follow-up time in object's training data. If you truly want to predict beyond that time, set boundary_checks = FALSE; pred_horizon can then be as large as you want. Note that predictions beyond the maximum follow-up time in object's training data are equal to predictions at the maximum follow-up time, because aorsf does not estimate survival beyond its maximum observed time.

If pred_horizon is not specified, it may be set automatically to the value used for oobag_pred_horizon when object was created (see orsf).
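The behavior described in these details can be checked with a short sketch. It assumes the pbc_orsf data shipped with aorsf, an arbitrary n_tree = 50 to keep the fit fast, and a helper variable max_time that is introduced here only for illustration.

```r
library(aorsf)

set.seed(329730)

fit <- orsf(data = pbc_orsf,
            formula = Surv(time, status) ~ . - id,
            n_tree = 50,
            oobag_pred_horizon = 365.25 * 5)

new_data <- pbc_orsf[1:5, ]

# maximum follow-up time in the training data
max_time <- max(pbc_orsf$time)

# pred_horizon omitted: the oobag_pred_horizon given to orsf() may be used
pred_default <- predict(fit, new_data = new_data)

# horizons at and beyond the maximum follow-up time, with checks skipped
pred_far <- predict(fit, new_data = new_data,
                    pred_horizon = c(max_time, 2 * max_time),
                    boundary_checks = FALSE)

# predictions beyond the maximum follow-up time should match
# predictions at the maximum follow-up time
all.equal(pred_far[, 1], pred_far[, 2])
```

Because predictions are carried forward unchanged past the last observed time, a very large pred_horizon does not extrapolate; it simply repeats the last estimate.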

Value

A matrix of predictions. Column j of the matrix corresponds to value j in pred_horizon. Row i of the matrix corresponds to row i in new_data.
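A minimal sketch of how the returned matrix lines up with new_data and pred_horizon is given below. It assumes the pbc_orsf data shipped with aorsf and an arbitrary n_tree = 50; the name horizons and the column labels are hypothetical and chosen only for readability.

```r
library(aorsf)

set.seed(329730)

fit <- orsf(data = pbc_orsf,
            formula = Surv(time, status) ~ . - id,
            n_tree = 50)

new_data <- pbc_orsf[1:5, ]
horizons <- c(500, 1000, 1500)

risk <- predict(fit, new_data = new_data,
                pred_horizon = horizons, pred_type = 'risk')

surv <- predict(fit, new_data = new_data,
                pred_horizon = horizons, pred_type = 'surv')

# rows follow new_data, columns follow pred_horizon
dim(risk)

# risk and survival at the same horizon should sum to 1
all.equal(risk + surv, matrix(1, nrow = 5, ncol = 3))

# the matrix has no column names, so label the horizons if needed
colnames(risk) <- paste0("t_", horizons)
risk
```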

Examples

Begin by fitting an ORSF ensemble:

library(aorsf)

set.seed(329730)

index_train <- sample(nrow(pbc_orsf), 150) 

pbc_orsf_train <- pbc_orsf[index_train, ]
pbc_orsf_test <- pbc_orsf[-index_train, ]

fit <- orsf(data = pbc_orsf_train, 
            formula = Surv(time, status) ~ . - id,
            oobag_pred_horizon = 365.25 * 5)

Predict risk, survival, or cumulative hazard at one or several times:

# predicted risk, the default
predict(fit, 
        new_data = pbc_orsf_test[1:5, ], 
        pred_type = 'risk', 
        pred_horizon = c(500, 1000, 1500))
##             [,1]       [,2]       [,3]
## [1,] 0.458321864 0.72956562 0.89157462
## [2,] 0.032194868 0.08028381 0.15592011
## [3,] 0.115945485 0.24099853 0.38094684
## [4,] 0.008378033 0.02964250 0.06977315
## [5,] 0.009798295 0.01793586 0.04454374
# predicted survival, i.e., 1 - risk
predict(fit, 
        new_data = pbc_orsf_test[1:5, ], 
        pred_type = 'surv',
        pred_horizon = c(500, 1000, 1500))
##           [,1]      [,2]      [,3]
## [1,] 0.5416781 0.2704344 0.1084254
## [2,] 0.9678051 0.9197162 0.8440799
## [3,] 0.8840545 0.7590015 0.6190532
## [4,] 0.9916220 0.9703575 0.9302269
## [5,] 0.9902017 0.9820641 0.9554563
# predicted cumulative hazard function
# (expected number of events for person i at time j)
predict(fit, 
        new_data = pbc_orsf_test[1:5, ], 
        pred_type = 'chf',
        pred_horizon = c(500, 1000, 1500))
##            [,1]       [,2]       [,3]
## [1,] 0.63432745 1.26771489 1.74183802
## [2,] 0.03415809 0.09124550 0.20017014
## [3,] 0.14715014 0.34375274 0.62976148
## [4,] 0.00857621 0.03195771 0.08744159
## [5,] 0.01043219 0.01888677 0.05177019

Predict mortality, defined as the number of events in the forest’s population if all observations had characteristics like the current observation. This type of prediction does not require you to specify a prediction horizon:

predict(fit, 
        new_data = pbc_orsf_test[1:5, ], 
        pred_type = 'mort')
##           [,1]
## [1,] 78.575909
## [2,] 20.872849
## [3,] 37.341745
## [4,] 13.616617
## [5,]  8.798328

aorsf documentation built on Oct. 26, 2023, 5:08 p.m.