Nothing
library(ranger)
library(survival)
context("ranger_treeInfo")
## Classification
rf.class.formula <- ranger(Species ~ ., iris, num.trees = 5)
rf.class.first <- ranger(dependent.variable.name = "Species", data = iris[, c(5, 1:4)], num.trees = 5)
rf.class.mid <- ranger(dependent.variable.name = "Species", data = iris[, c(1:2, 5, 3:4)], num.trees = 5)
rf.class.last <- ranger(dependent.variable.name = "Species", data = iris, num.trees = 5)
ti.class.formula <- treeInfo(rf.class.formula)
ti.class.first <- treeInfo(rf.class.first)
ti.class.mid <- treeInfo(rf.class.mid)
ti.class.last <- treeInfo(rf.class.last)
test_that("Terminal nodes have only prediction, non-terminal nodes all others, classification formula", {
expect_true(all(is.na(ti.class.formula[ti.class.formula$terminal, 2:6])))
expect_true(all(!is.na(ti.class.formula[ti.class.formula$terminal, c(1, 7:8)])))
expect_true(all(!is.na(ti.class.formula[!ti.class.formula$terminal, -8])))
expect_true(all(is.na(ti.class.formula[!ti.class.formula$terminal, 8])))
})
test_that("Terminal nodes have only prediction, non-terminal nodes all others, classification depvarname first", {
expect_true(all(is.na(ti.class.first[ti.class.first$terminal, 2:6])))
expect_true(all(!is.na(ti.class.first[ti.class.first$terminal, c(1, 7:8)])))
expect_true(all(!is.na(ti.class.first[!ti.class.first$terminal, -8])))
expect_true(all(is.na(ti.class.first[!ti.class.first$terminal, 8])))
})
test_that("Terminal nodes have only prediction, non-terminal nodes all others, classification depvarname mid", {
expect_true(all(is.na(ti.class.mid[ti.class.mid$terminal, 2:6])))
expect_true(all(!is.na(ti.class.mid[ti.class.mid$terminal, c(1, 7:8)])))
expect_true(all(!is.na(ti.class.mid[!ti.class.mid$terminal, -8])))
expect_true(all(is.na(ti.class.mid[!ti.class.mid$terminal, 8])))
})
test_that("Terminal nodes have only prediction, non-terminal nodes all others, classification depvarname last", {
expect_true(all(is.na(ti.class.last[ti.class.last$terminal, 2:6])))
expect_true(all(!is.na(ti.class.last[ti.class.last$terminal, c(1, 7:8)])))
expect_true(all(!is.na(ti.class.last[!ti.class.last$terminal, -8])))
expect_true(all(is.na(ti.class.last[!ti.class.last$terminal, 8])))
})
test_that("Names in treeInfo match, classification", {
varnames <- colnames(iris)[1:4]
expect_true(all(is.na(ti.class.formula$splitvarName) | ti.class.formula$splitvarName %in% varnames))
expect_true(all(is.na(ti.class.first$splitvarName) | ti.class.first$splitvarName %in% varnames))
expect_true(all(is.na(ti.class.mid$splitvarName) | ti.class.mid$splitvarName %in% varnames))
expect_true(all(is.na(ti.class.last$splitvarName) | ti.class.last$splitvarName %in% varnames))
})
test_that("Prediction for classification is factor with correct levels", {
expect_is(ti.class.formula$prediction, "factor")
expect_equal(levels(ti.class.formula$prediction), levels(iris$Species))
})
test_that("Prediction for classification is same as class prediction", {
dat <- iris[sample(nrow(iris)), ]
rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 1,
replace = FALSE, sample.fraction = 1)
pred_class <- predict(rf, dat)$predictions
nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
ti <- treeInfo(rf, 1)
pred_ti <- sapply(nodes, function(x) {
ti[ti$nodeID == x, "prediction"]
})
expect_equal(pred_ti, pred_class)
})
test_that("Prediction for classification is same as class prediction, new factor", {
dat <- iris[sample(nrow(iris)), ]
dat$Species <- factor(dat$Species, levels = sample(levels(dat$Species)))
rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 1,
replace = FALSE, sample.fraction = 1)
pred_class <- predict(rf, dat)$predictions
nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
ti <- treeInfo(rf, 1)
pred_ti <- sapply(nodes, function(x) {
ti[ti$nodeID == x, "prediction"]
})
expect_equal(pred_ti, pred_class)
})
test_that("Prediction for classification is same as class prediction, unused factor levels", {
dat <- iris[c(101:150, 51:100), ]
expect_warning(rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 1,
replace = FALSE, sample.fraction = 1))
pred_class <- predict(rf, dat)$predictions
nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
ti <- treeInfo(rf, 1)
pred_ti <- sapply(nodes, function(x) {
ti[ti$nodeID == x, "prediction"]
})
expect_equal(pred_ti, pred_class)
})
test_that("Prediction for probability is same as probability prediction", {
dat <- iris[sample(nrow(iris)), ]
rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 1,
sample.fraction = 1, replace = FALSE, probability = TRUE)
ti <- treeInfo(rf)
pred_prob <- predict(rf, dat)$predictions
nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
pred_ti <- t(sapply(nodes, function(x) {
as.matrix(ti[ti$nodeID == x, 8:10])
}))
colnames(pred_ti) <- gsub("pred\\.", "", colnames(ti)[8:10])
expect_equal(pred_prob, pred_ti)
})
test_that("Prediction for probability is same as probability prediction, new factor", {
dat <- iris[sample(nrow(iris)), ]
dat$Species <- factor(dat$Species, levels = sample(levels(dat$Species)))
rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 1,
sample.fraction = 1, replace = FALSE, probability = TRUE)
ti <- treeInfo(rf)
pred_prob <- predict(rf, dat)$predictions
nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
pred_ti <- t(sapply(nodes, function(x) {
as.matrix(ti[ti$nodeID == x, 8:10])
}))
colnames(pred_ti) <- gsub("pred\\.", "", colnames(ti)[8:10])
expect_equal(pred_prob, pred_ti)
})
test_that("Prediction for probability is same as probability prediction, unused factor levels", {
dat <- iris[c(101:150, 51:100), ]
dat$Species <- factor(dat$Species, levels = sample(levels(dat$Species)))
expect_warning(rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 1,
sample.fraction = 1, replace = FALSE, probability = TRUE))
ti <- treeInfo(rf)
pred_prob <- predict(rf, dat)$predictions
nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
pred_ti <- t(sapply(nodes, function(x) {
as.matrix(ti[ti$nodeID == x, 8:9])
}))
colnames(pred_ti) <- gsub("pred\\.", "", colnames(ti)[8:9])
expect_equal(pred_prob, pred_ti)
})
test_that("Prediction for matrix classification is integer with correct values", {
rf <- ranger(dependent.variable.name = "Species", data = data.matrix(iris),
num.trees = 5, classification = TRUE)
ti <- treeInfo(rf, 1)
expect_is(ti$prediction, "numeric")
expect_equal(sort(unique(ti$prediction)), 1:3)
})
## Regression
n <- 20
dat <- data.frame(y = rnorm(n),
replicate(2, runif(n)),
replicate(2, rbinom(n, size = 1, prob = .5)))
rf.regr.formula <- ranger(y ~ ., dat, num.trees = 5)
rf.regr.first <- ranger(dependent.variable.name = "y", data = dat, num.trees = 5)
rf.regr.mid <- ranger(dependent.variable.name = "y", data = dat[, c(2:3, 1, 4:5)], num.trees = 5)
rf.regr.last <- ranger(dependent.variable.name = "y", data = dat[, c(2:5, 1)], num.trees = 5)
ti.regr.formula <- treeInfo(rf.regr.formula)
ti.regr.first <- treeInfo(rf.regr.first)
ti.regr.mid <- treeInfo(rf.regr.mid)
ti.regr.last <- treeInfo(rf.regr.last)
test_that("Terminal nodes have only prediction, non-terminal nodes all others, regression formula", {
expect_true(all(is.na(ti.regr.formula[ti.regr.formula$terminal, 2:6])))
expect_true(all(!is.na(ti.regr.formula[ti.regr.formula$terminal, c(1, 7:8)])))
expect_true(all(!is.na(ti.regr.formula[!ti.regr.formula$terminal, -8])))
expect_true(all(is.na(ti.regr.formula[!ti.regr.formula$terminal, 8])))
})
test_that("Terminal nodes have only prediction, non-terminal nodes all others, regression depvarname first", {
expect_true(all(is.na(ti.regr.first[ti.regr.first$terminal, 2:6])))
expect_true(all(!is.na(ti.regr.first[ti.regr.first$terminal, c(1, 7:8)])))
expect_true(all(!is.na(ti.regr.first[!ti.regr.first$terminal, -8])))
expect_true(all(is.na(ti.regr.first[!ti.regr.first$terminal, 8])))
})
test_that("Terminal nodes have only prediction, non-terminal nodes all others, regression depvarname mid", {
expect_true(all(is.na(ti.regr.mid[ti.regr.mid$terminal, 2:6])))
expect_true(all(!is.na(ti.regr.mid[ti.regr.mid$terminal, c(1, 7:8)])))
expect_true(all(!is.na(ti.regr.mid[!ti.regr.mid$terminal, -8])))
expect_true(all(is.na(ti.regr.mid[!ti.regr.mid$terminal, 8])))
})
test_that("Terminal nodes have only prediction, non-terminal nodes all others, regression depvarname last", {
expect_true(all(is.na(ti.regr.last[ti.regr.last$terminal, 2:6])))
expect_true(all(!is.na(ti.regr.last[ti.regr.last$terminal, c(1, 7:8)])))
expect_true(all(!is.na(ti.regr.last[!ti.regr.last$terminal, -8])))
expect_true(all(is.na(ti.regr.last[!ti.regr.last$terminal, 8])))
})
test_that("Names in treeInfo match, regression", {
varnames <- c("X1", "X2", "X1.1", "X2.1")
expect_true(all(is.na(ti.regr.formula$splitvarName) | ti.regr.formula$splitvarName %in% varnames))
expect_true(all(is.na(ti.regr.first$splitvarName) | ti.regr.first$splitvarName %in% varnames))
expect_true(all(is.na(ti.regr.mid$splitvarName) | ti.regr.mid$splitvarName %in% varnames))
expect_true(all(is.na(ti.regr.last$splitvarName) | ti.regr.last$splitvarName %in% varnames))
})
test_that("Prediction for regression is numeric in correct range", {
expect_is(ti.regr.formula$prediction, "numeric")
expect_true(all(is.na(ti.regr.formula$prediction) | ti.regr.formula$prediction >= min(dat$y)))
expect_true(all(is.na(ti.regr.formula$prediction) | ti.regr.formula$prediction <= max(dat$y)))
})
## Probability estimation
rf.prob.formula <- ranger(Species ~ ., iris, num.trees = 5, probability = TRUE)
rf.prob.first <- ranger(dependent.variable.name = "Species", data = iris[, c(5, 1:4)], num.trees = 5, probability = TRUE)
rf.prob.mid <- ranger(dependent.variable.name = "Species", data = iris[, c(1:2, 5, 3:4)], num.trees = 5, probability = TRUE)
rf.prob.last <- ranger(dependent.variable.name = "Species", data = iris, num.trees = 5, probability = TRUE)
ti.prob.formula <- treeInfo(rf.prob.formula)
ti.prob.first <- treeInfo(rf.prob.first)
ti.prob.mid <- treeInfo(rf.prob.mid)
ti.prob.last <- treeInfo(rf.prob.last)
test_that("Terminal nodes have only prediction, non-terminal nodes all others, probability formula", {
expect_true(all(is.na(ti.prob.formula[ti.prob.formula$terminal, 2:6])))
expect_true(all(!is.na(ti.prob.formula[ti.prob.formula$terminal, c(1, 7:10)])))
expect_true(all(!is.na(ti.prob.formula[!ti.prob.formula$terminal, c(-8, -9, -10)])))
expect_true(all(is.na(ti.prob.formula[!ti.prob.formula$terminal, 8:10])))
})
test_that("Terminal nodes have only prediction, non-terminal nodes all others, probability depvarname first", {
expect_true(all(is.na(ti.prob.first[ti.prob.first$terminal, 2:6])))
expect_true(all(!is.na(ti.prob.first[ti.prob.first$terminal, c(1, 7:8)])))
expect_true(all(!is.na(ti.prob.first[!ti.prob.first$terminal, c(-8, -9, -10)])))
expect_true(all(is.na(ti.prob.first[!ti.prob.first$terminal, 8:10])))
})
test_that("Terminal nodes have only prediction, non-terminal nodes all others, probability depvarname mid", {
expect_true(all(is.na(ti.prob.mid[ti.prob.mid$terminal, 2:6])))
expect_true(all(!is.na(ti.prob.mid[ti.prob.mid$terminal, c(1, 7:8)])))
expect_true(all(!is.na(ti.prob.mid[!ti.prob.mid$terminal, c(-8, -9, -10)])))
expect_true(all(is.na(ti.prob.mid[!ti.prob.mid$terminal, 8:10])))
})
test_that("Terminal nodes have only prediction, non-terminal nodes all others, probability depvarname last", {
expect_true(all(is.na(ti.prob.last[ti.prob.last$terminal, 2:6])))
expect_true(all(!is.na(ti.prob.last[ti.prob.last$terminal, c(1, 7:8)])))
expect_true(all(!is.na(ti.prob.last[!ti.prob.last$terminal, c(-8, -9, -10)])))
expect_true(all(is.na(ti.prob.last[!ti.prob.last$terminal, 8:10])))
})
test_that("Names in treeInfo match, probability", {
varnames <- colnames(iris)[1:4]
expect_true(all(is.na(ti.prob.formula$splitvarName) | ti.prob.formula$splitvarName %in% varnames))
expect_true(all(is.na(ti.prob.first$splitvarName) | ti.prob.first$splitvarName %in% varnames))
expect_true(all(is.na(ti.prob.mid$splitvarName) | ti.prob.mid$splitvarName %in% varnames))
expect_true(all(is.na(ti.prob.last$splitvarName) | ti.prob.last$splitvarName %in% varnames))
})
test_that("Prediction for probability is one probability per class, sum to 1", {
expect_equal(ncol(ti.prob.formula), 10)
expect_is(ti.prob.formula$pred.setosa, "numeric")
expect_true(all(!ti.prob.formula$terminal | rowSums(ti.prob.formula[, 8:10]) == 1))
})
test_that("Prediction for probability has correct factor levels", {
dat <- iris[c(101:150, 1:100), ]
rf <- ranger(dependent.variable.name = "Species", data = dat, num.trees = 5, probability = TRUE)
# Predict
pred_rf <- predict(rf, dat, num.trees = 1)$predictions
# Predict with treeInfo
ti <- treeInfo(rf)
terminal_nodes <- predict(rf, dat, type = "terminalNodes")$predictions[, 1]
pred_ti <- as.matrix(ti[terminal_nodes + 1, grep("pred", colnames(ti))])
colnames(pred_ti) <- gsub("pred\\.", "", colnames(pred_ti))
rownames(pred_ti) <- NULL
expect_equal(pred_rf, pred_ti)
})
## Survival
rf.surv.formula <- ranger(Surv(time, status) ~ ., veteran, num.trees = 5)
rf.surv.first <- ranger(dependent.variable.name = "time", status.variable.name = "status", data = veteran[, c(3:4, 1:2, 5:8)], num.trees = 5)
rf.surv.mid <- ranger(dependent.variable.name = "time", status.variable.name = "status", data = veteran, num.trees = 5)
rf.surv.last <- ranger(dependent.variable.name = "time", status.variable.name = "status", data = veteran[, c(2, 1, 5:8, 3:4)], num.trees = 5)
ti.surv.formula <- treeInfo(rf.surv.formula)
ti.surv.first <- treeInfo(rf.surv.first)
ti.surv.mid <- treeInfo(rf.surv.mid)
ti.surv.last <- treeInfo(rf.surv.last)
test_that("Terminal nodes have only nodeID, non-terminal nodes all, survival formula", {
expect_true(all(is.na(ti.surv.formula[ti.surv.formula$terminal, 2:6])))
expect_true(all(!is.na(ti.surv.formula[ti.surv.formula$terminal, c(1, 7)])))
expect_true(all(!is.na(ti.surv.formula[!ti.surv.formula$terminal, ])))
})
test_that("Terminal nodes have only prediction, non-terminal nodes all others, survival depvarname first", {
expect_true(all(is.na(ti.surv.first[ti.surv.first$terminal, 2:6])))
expect_true(all(!is.na(ti.surv.first[ti.surv.first$terminal, c(1, 7)])))
expect_true(all(!is.na(ti.surv.first[!ti.surv.first$terminal, ])))
})
test_that("Terminal nodes have only prediction, non-terminal nodes all others, survival depvarname mid", {
expect_true(all(is.na(ti.surv.mid[ti.surv.mid$terminal, 2:6])))
expect_true(all(!is.na(ti.surv.mid[ti.surv.mid$terminal, c(1, 7)])))
expect_true(all(!is.na(ti.surv.mid[!ti.surv.mid$terminal, ])))
})
test_that("Terminal nodes have only prediction, non-terminal nodes all others, survival depvarname last", {
expect_true(all(is.na(ti.surv.last[ti.surv.last$terminal, 2:6])))
expect_true(all(!is.na(ti.surv.last[ti.surv.last$terminal, c(1, 7)])))
expect_true(all(!is.na(ti.surv.last[!ti.surv.last$terminal, ])))
})
test_that("Names in treeInfo match, survival", {
varnames <- colnames(veteran)[c(1:2, 5:8)]
expect_true(all(is.na(ti.surv.formula$splitvarName) | ti.surv.formula$splitvarName %in% varnames))
expect_true(all(is.na(ti.surv.first$splitvarName) | ti.surv.first$splitvarName %in% varnames))
expect_true(all(is.na(ti.surv.mid$splitvarName) | ti.surv.mid$splitvarName %in% varnames))
expect_true(all(is.na(ti.surv.last$splitvarName) | ti.surv.last$splitvarName %in% varnames))
})
test_that("No prediction for Survival", {
expect_equal(ncol(ti.surv.formula), 7)
})
## General
test_that("Error if no saved forest", {
expect_error(treeInfo(ranger(Species ~ ., iris, write.forest = FALSE)),
"Error\\: No saved forest in ranger object\\. Please set write.forest to TRUE when calling ranger\\.")
})
## Unordered splitting
test_that("Spitting value is comma separated list for partition splitting", {
n <- 50
dat <- data.frame(x = sample(c("A", "B", "C", "D", "E"), n, replace = TRUE),
y = rbinom(n, 1, 0.5),
stringsAsFactors = FALSE)
rf.partition <- ranger(y ~ ., dat, num.trees = 5, respect.unordered.factors = "partition")
ti.partition <- treeInfo(rf.partition)
expect_is(ti.partition$splitval, "character")
expect_true(all(is.na(ti.partition$splitval) | grepl("^\\d+(?:,\\d+)*$", ti.partition$splitval)))
})
test_that("Spitting value is numeric for order splitting", {
set.seed(100)
rf.order <- ranger(Sepal.Length ~ ., iris, num.trees = 5, respect.unordered.factors = "order")
ti.order <- treeInfo(rf.order)
expect_is(ti.order$splitval[!ti.order$terminal & ti.order$splitvarName == "Species"], "numeric")
})
test_that("treeInfo works for 31 unordered factor levels but not for 32", {
n <- 31
dt <- data.frame(x = factor(1:n, ordered = FALSE),
y = rbinom(n, 1, 0.5))
rf <- ranger(y ~ ., data = dt, num.trees = 10, splitrule = "extratrees")
expect_silent(treeInfo(rf))
n <- 32
dt <- data.frame(x = factor(1:n, ordered = FALSE),
y = rbinom(n, 1, 0.5))
rf <- ranger(y ~ ., data = dt, num.trees = 10, splitrule = "extratrees")
expect_warning(treeInfo(rf), "Unordered splitting levels can only be shown for up to 31 levels.")
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.