## tests/testthat/test_survival.R

## Tests for random forests for survival analysis

library(blockForest)
library(survival)
context("blockForest_surv")

## Shared fixture: a 10-tree survival forest grown on `veteran`.
## The top-level tests that reference `rg.surv` inspect this object.
rg.surv <- blockForest(
  Surv(time, status) ~ .,
  data = veteran,
  num.trees = 10,
  write.forest = TRUE,
  verbose = FALSE
)

## Basic tests (the same checks apply to every kind of forest)
test_that("survival result is of class blockForest with 16 elements", {
  # The fitted object must inherit from the package's forest class ...
  expect_is(rg.surv, "blockForest")
  # ... and contain exactly 16 top-level elements.
  expect_length(rg.surv, 16)
})

test_that("results have right number of trees", {
  # The shared forest `rg.surv` was grown with num.trees = 10.
  n_trees_requested <- 10
  expect_equal(rg.surv$num.trees, n_trees_requested)
})

test_that("results have right number of independent variables", {
  # `time` and `status` make up the response in the fixture's formula, so
  # every remaining column of `veteran` counts as an independent variable.
  n_covariates <- ncol(veteran) - 2
  expect_equal(rg.surv$num.independent.variables, n_covariates)
})

test_that("Alternative interface works for survival", {
  # Identify the response through column names rather than a formula.
  fit <- blockForest(
    data = veteran,
    dependent.variable.name = "time",
    status.variable.name = "status",
    num.trees = 10
  )

  expect_equal(fit$treetype, "Survival")
})

test_that("Alternative interface prediction works for survival", {
  fit <- blockForest(
    data = veteran,
    dependent.variable.name = "time",
    status.variable.name = "status",
    num.trees = 10
  )

  # First predict on the full data frame, response columns included.
  pred_full <- predict(fit, veteran)
  expect_equal(pred_full$num.independent.variables, ncol(veteran) - 2)

  # Then predict on the covariate columns only.
  covariate_cols <- setdiff(names(veteran), c("time", "status"))
  pred_covariates <- predict(fit, veteran[, covariate_cols])
  expect_equal(pred_covariates$num.independent.variables, ncol(veteran) - 2)
})

test_that("Matrix interface works for survival", {
  # Fit on a numeric matrix (via data.matrix()) instead of a data frame.
  rf <- blockForest(
    dependent.variable.name = "time",
    status.variable.name = "status",
    data = data.matrix(veteran),
    write.forest = TRUE,
    num.trees = 10
  )
  expect_equal(rf$treetype, "Survival")

  # The covariates are all columns except the two response columns. Select
  # them by name rather than by hard-coded position, so the expectation does
  # not depend on the column order of `veteran`.
  covariate_names <- setdiff(colnames(veteran), c("time", "status"))
  expect_equal(rf$forest$independent.variable.names, covariate_names)
})

test_that("Matrix interface prediction works for survival", {
  veteran_mat <- data.matrix(veteran)
  fit <- blockForest(
    data = veteran_mat,
    dependent.variable.name = "time",
    status.variable.name = "status",
    write.forest = TRUE,
    num.trees = 10
  )

  # Predicting on the same matrix the forest was grown on must run silently.
  expect_silent(predict(fit, veteran_mat))
})

test_that("growing works for single observations, survival", {
  # Grow the forest on the first row of `veteran` only.
  fit <- blockForest(
    Surv(time, status) ~ .,
    veteran[1, ],
    num.trees = 10,
    write.forest = TRUE
  )

  expect_is(fit$survival, "matrix")
})

test_that("predict works for single observations, survival", {
  fit <- blockForest(
    Surv(time, status) ~ .,
    veteran,
    num.trees = 10,
    write.forest = TRUE
  )

  # Predict for a single row; the result must hold one survival value per
  # unique death time stored on the forest.
  single_pred <- predict(fit, head(veteran, 1))
  expect_length(single_pred$survival, length(fit$unique.death.times))
})

## Tests specific to survival forests
test_that("unique death times in survival result is right", {
  # Expected: the distinct observed times of `veteran`, in ascending order.
  expected_times <- sort(unique(veteran$time))
  expect_equal(rg.surv$unique.death.times, expected_times)
})

test_that("C-index splitting works", {
  # Grow a survival forest with the "C" split rule.
  fit <- blockForest(
    Surv(time, status) ~ .,
    data = veteran,
    splitrule = "C",
    num.trees = 10,
    verbose = FALSE
  )

  expect_equal(fit$treetype, "Survival")
})

test_that("C-index splitting not working on classification data", {
  # Requesting the "C" split rule for the iris Species response must fail.
  expect_error(
    blockForest(Species ~ ., iris, splitrule = "C", num.trees = 10)
  )
})

test_that("Logrank splitting not working on classification data", {
  # Requesting the "logrank" split rule for the iris Species response must fail.
  expect_error(
    blockForest(Species ~ ., iris, splitrule = "logrank", num.trees = 10)
  )
})

test_that("No error if survival tree without OOB observations", {
  # Two observations and a single tree on one thread; growing must be silent.
  two_obs <- data.frame(time = c(1, 2), status = c(0, 1), x = c(1, 2))

  expect_silent(
    blockForest(Surv(time, status) ~ ., two_obs, num.trees = 1, num.threads = 1)
  )
})

test_that("predict.all for survival returns 3d array of size samples x times x trees", {
  fit <- blockForest(Surv(time, status) ~ ., veteran, num.trees = 5)
  per_tree <- predict(fit, veteran, predict.all = TRUE)

  # Both per-tree outputs share one expected shape: samples x times x trees.
  expected_dim <- c(
    nrow(veteran),
    length(per_tree$unique.death.times),
    fit$num.trees
  )

  expect_is(per_tree$survival, "array")
  expect_equal(dim(per_tree$survival), expected_dim)

  expect_is(per_tree$chf, "array")
  expect_equal(dim(per_tree$chf), expected_dim)
})

test_that("Mean of predict.all for survival is equal to forest prediction", {
  fit <- blockForest(Surv(time, status) ~ ., veteran, num.trees = 5)

  forest_pred <- predict(fit, veteran, predict.all = FALSE)
  tree_preds <- predict(fit, veteran, predict.all = TRUE)

  # Average the per-tree CHF over the third (tree) dimension, keeping the
  # first two dimensions, and compare with the aggregated forest CHF.
  chf_tree_mean <- apply(tree_preds$chf, 1:2, mean)
  expect_equal(chf_tree_mean, forest_pred$chf)
})

test_that("timepoints() function returns timepoints", {
  fit <- blockForest(Surv(time, status) ~ ., veteran, num.trees = 5)
  expected_times <- fit$unique.death.times

  # timepoints() on the forest itself ...
  expect_equal(timepoints(fit), expected_times)

  # ... and on a prediction object must both give the forest's death times.
  prediction <- predict(fit, veteran)
  expect_equal(timepoints(prediction), expected_times)
})

test_that("timepoints() working on survival forest only", {
  # Grow a non-survival forest (iris Species response); timepoints() must
  # refuse both the forest and a prediction made from it.
  rf <- blockForest(Species ~ ., iris, num.trees = 5)

  # fixed = TRUE matches the message literally. Without it the string is a
  # regular expression and every "." would match any character.
  expect_error(
    timepoints(rf),
    "No timepoints found. Object is no Survival forest.",
    fixed = TRUE
  )

  pred <- predict(rf, iris)
  expect_error(
    timepoints(pred),
    "No timepoints found. Object is no Survival prediction object.",
    fixed = TRUE
  )
})

test_that("Survival error without covariates", {
  # Keep only the two response columns, so the data offers no covariates.
  response_only <- veteran[, c("time", "status")]

  # fixed = TRUE matches the message literally. Without it the string is a
  # regular expression and the trailing "." would match any character.
  expect_error(
    blockForest(Surv(time, status) ~ ., response_only, num.trees = 5),
    "Error: No covariates found.",
    fixed = TRUE
  )
})

## Try the blockForest package in your browser
##
## Any scripts or data that you put into this service are public.
##
## blockForest documentation built on April 3, 2023, 5:49 p.m.