## Nothing
## Tests for predictions
## Exercises predict() on ranger forests: accuracy, case weights, column
## order, missing values, num.trees, terminal node IDs and predict.all.
library(ranger)
## survival supplies Surv() and the veteran data used by the survival test.
library(survival)
context("ranger_pred")
## Tests
## A classification forest grown on iris should classify more than 90 % of
## its own training rows correctly.
test_that("predict returns good prediction", {
  forest <- ranger(Species ~ ., data = iris, write.forest = TRUE)
  predicted <- predictions(predict(forest, data = iris))
  accuracy <- mean(predicted == iris$Species)
  expect_gt(accuracy, 0.9)
})
## Case weights: uniform weights must train silently; with weight 0 on every
## row after the first 50, all predictions must be "setosa".
test_that("case weights work", {
  uniform_weights <- rep(1, nrow(iris))
  expect_silent(
    ranger(Species ~ ., iris, num.trees = 5, case.weights = uniform_weights)
  )

  ## Should only predict setosa now
  setosa_only <- rep(c(1, 0), times = c(50, 100))
  forest <- ranger(
    Species ~ ., iris,
    num.trees = 5,
    case.weights = setosa_only,
    write.forest = TRUE
  )
  predicted <- predict(forest, iris)$predictions
  expect_true(all(predicted == "setosa"))
})
## The response (Species) is moved to the middle column; prediction must
## work both with and without that column present in the new data.
test_that("Prediction works correctly if dependent variable is not first or last", {
  shuffled <- iris[, c(1:2, 5, 3:4)]
  forest <- ranger(Species ~ ., shuffled, num.trees = 5, write.forest = TRUE)

  with_response <- predictions(predict(forest, shuffled))
  expect_gte(mean(with_response == shuffled$Species), 0.9)

  ## No response column
  without_response <- predictions(predict(forest, shuffled[, -3]))
  expect_gte(mean(without_response == shuffled$Species), 0.9)
})
## Same check as the formula variant, but the forest is specified through
## dependent.variable.name instead of a formula.
test_that("Prediction works correctly if dependent variable is not first or last, alternative interface", {
  shuffled <- iris[, c(1:2, 5, 3:4)]
  truth <- shuffled$Species
  forest <- ranger(
    dependent.variable.name = "Species",
    data = shuffled,
    num.trees = 5,
    write.forest = TRUE
  )

  accuracy_full <- mean(predictions(predict(forest, shuffled)) == truth)
  expect_gte(accuracy_full, 0.9)

  ## No response column
  predictors_only <- shuffled[, -3]
  accuracy_reduced <- mean(predictions(predict(forest, predictors_only)) == truth)
  expect_gte(accuracy_reduced, 0.9)
})
## predict() must refuse new data containing NAs and name the offending
## columns in the error message.
test_that("Missing value columns detected in predict", {
  rf <- ranger(Species ~ ., iris, num.trees = 5, write.forest = TRUE)

  dat <- iris
  dat[4, 4] <- NA   # Petal.Width
  dat[25, 1] <- NA  # Sepal.Length

  ## fixed = TRUE: match the message literally, so the dots in the column
  ## names and the trailing full stop are not treated as regex wildcards.
  expect_error(
    predict(rf, dat),
    "Missing data in columns: Sepal.Length, Petal.Width.",
    fixed = TRUE
  )
})
## Passing num.trees to predict() must restrict prediction to that many
## trees, visible both in $num.trees and in the per-tree prediction matrix.
test_that("If num.trees set, these number is used for predictions", {
  trees_grown <- 5
  trees_used <- 3
  forest <- ranger(
    Species ~ ., iris,
    num.trees = trees_grown,
    write.forest = TRUE
  )
  per_tree <- predict(forest, iris, predict.all = TRUE, num.trees = trees_used)

  expect_equal(per_tree$num.trees, trees_used)
  expect_equal(dim(per_tree$predictions), c(nrow(iris), trees_used))
})
## Without num.trees in predict(), every tree of the forest contributes:
## one prediction column per grown tree.
test_that("If num.trees not set, all trees are used for prediction", {
  trees_grown <- 5
  forest <- ranger(
    Species ~ ., iris,
    num.trees = trees_grown,
    write.forest = TRUE
  )
  per_tree <- predict(forest, iris, predict.all = TRUE)

  expect_equal(per_tree$num.trees, trees_grown)
  expect_equal(dim(per_tree$predictions), c(nrow(iris), trees_grown))
})
## An unsupported value for predict()'s type argument must raise an error.
test_that("Error if unknown value for type", {
  forest <- ranger(Species ~ ., iris, num.trees = 5, write.forest = TRUE)
  predict_with_bad_type <- function() {
    predict(forest, iris, type = "class")
  }
  expect_error(predict_with_bad_type())
})
## type = "terminalNodes" on a classification forest: the result has one row
## per observation and one column per tree, and every entry lies within the
## bounds checked below.
test_that("Terminal nodes returned by predict are node ids, classification", {
  rf <- ranger(Species ~ ., iris, num.trees = 5, write.forest = TRUE)
  pred <- predict(rf, iris, type = "terminalNodes")

  expect_equal(dim(pred$predictions), c(nrow(iris), rf$num.trees))

  ## Upper bound: length of the longest per-tree split.varIDs entry
  ## (presumably the node count of the largest tree). lengths() replaces the
  ## type-unstable sapply(x, length).
  max_nodes <- max(lengths(rf$forest$split.varIDs))
  expect_true(all(pred$predictions > 0))
  expect_true(all(pred$predictions < max_nodes))
})
## type = "terminalNodes" on a probability forest: the result has one row
## per observation and one column per tree, and every entry lies within the
## bounds checked below.
test_that("Terminal nodes returned by predict are node ids, probability", {
  rf <- ranger(
    Species ~ ., iris,
    num.trees = 5,
    write.forest = TRUE,
    probability = TRUE
  )
  pred <- predict(rf, iris, type = "terminalNodes")

  expect_equal(dim(pred$predictions), c(nrow(iris), rf$num.trees))

  ## Upper bound: length of the longest per-tree split.varIDs entry
  ## (presumably the node count of the largest tree). lengths() replaces the
  ## type-unstable sapply(x, length).
  max_nodes <- max(lengths(rf$forest$split.varIDs))
  expect_true(all(pred$predictions > 0))
  expect_true(all(pred$predictions < max_nodes))
})
## type = "terminalNodes" on a regression forest (numeric response): the
## result has one row per observation and one column per tree, and every
## entry lies within the bounds checked below.
test_that("Terminal nodes returned by predict are node ids, regression", {
  rf <- ranger(Sepal.Length ~ ., iris, num.trees = 5, write.forest = TRUE)
  pred <- predict(rf, iris, type = "terminalNodes")

  expect_equal(dim(pred$predictions), c(nrow(iris), rf$num.trees))

  ## Upper bound: length of the longest per-tree split.varIDs entry
  ## (presumably the node count of the largest tree). lengths() replaces the
  ## type-unstable sapply(x, length).
  max_nodes <- max(lengths(rf$forest$split.varIDs))
  expect_true(all(pred$predictions > 0))
  expect_true(all(pred$predictions < max_nodes))
})
## type = "terminalNodes" on a survival forest (Surv response, veteran data):
## the result has one row per observation and one column per tree, and every
## entry lies within the bounds checked below.
test_that("Terminal nodes returned by predict are node ids, survival", {
  rf <- ranger(
    Surv(time, status) ~ ., veteran,
    num.trees = 5,
    write.forest = TRUE
  )
  pred <- predict(rf, veteran, type = "terminalNodes")

  expect_equal(dim(pred$predictions), c(nrow(veteran), rf$num.trees))

  ## Upper bound: length of the longest per-tree split.varIDs entry
  ## (presumably the node count of the largest tree). lengths() replaces the
  ## type-unstable sapply(x, length).
  max_nodes <- max(lengths(rf$forest$split.varIDs))
  expect_true(all(pred$predictions > 0))
  expect_true(all(pred$predictions < max_nodes))
})
## getTerminalNodeIDs() must emit a warning and return the same node IDs as
## predict(..., type = "terminalNodes").
test_that("Same result with warning if getTerminalNodeIDs() used", {
  rf <- ranger(Species ~ ., iris, num.trees = 5)
  pred <- predict(rf, iris, type = "terminalNodes")

  ## Keep the two expectations separate: the warning check wraps only the
  ## getTerminalNodeIDs() call, and the comparison runs on its captured value.
  expect_warning(node_ids <- getTerminalNodeIDs(rf, iris))
  expect_equal(node_ids, pred$predictions)
})
## predict.all on a one-row data frame must still yield a 1 x num.trees
## matrix rather than a dropped vector.
test_that("predict.all works for single observation", {
  forest <- ranger(Species ~ ., iris, num.trees = 5, write.forest = TRUE)
  first_row <- iris[1, ]
  per_tree <- predict(forest, first_row, predict.all = TRUE)
  expect_equal(dim(per_tree$predictions), c(1, forest$num.trees))
})
## Train a probability forest on row-reordered iris, then check for one
## observation that averaging the per-tree probabilities reproduces the
## aggregated prediction.
test_that("predict.all factor probabilities in correct order", {
  reordered <- iris[c(51:100, 101:150, 1:50), ]
  forest <- ranger(Species ~ ., reordered, probability = TRUE, num.trees = 5)

  obs <- 139
  per_tree <- predict(forest, iris, predict.all = TRUE)$predictions
  pred_all <- rowMeans(per_tree[obs, , ])
  pred_mean <- predict(forest, iris, predict.all = FALSE)$predictions[obs, ]

  expect_equal(pred_all, pred_mean, tolerance = .001)
})
## Predicting from a forest grown with importance = "impurity_corrected"
## must raise a warning.
test_that("Warning if predicting with corrected impurity importance", {
  forest <- ranger(
    Species ~ ., iris,
    num.trees = 5,
    importance = "impurity_corrected"
  )
  expect_warning(predict(forest, iris))
})
## Any scripts or data that you put into this service are public.
## Add the following code to your website.
## For more information on customizing the embed code, read Embedding Snippets.