####################
# Author: James Hickey
#
# Series of tests to check prettifying of trees
#
####################
context("Testing input checking for prettifying")
test_that("error thrown when tree_index is not a positive integer", {
  # Given a fitted gbm object
  ## test Gaussian distribution gbm model
  set.seed(1)
  # create some data (RNG call order matters for reproducibility)
  N <- 1000
  X1 <- runif(N)
  X2 <- 2 * runif(N)
  X3 <- factor(sample(letters[1:4], N, replace = TRUE))
  X4 <- ordered(sample(letters[1:6], N, replace = TRUE))
  X5 <- factor(sample(letters[1:3], N, replace = TRUE))
  X6 <- 3 * runif(N)
  mu <- c(-1, 0, 1, 2)[as.numeric(X3)]
  SNR <- 10 # signal-to-noise ratio
  Y <- X1^1.5 + 2 * (X2^0.5) + mu
  sigma <- sqrt(var(Y) / SNR)
  Y <- Y + rnorm(N, 0, sigma)
  # create a bunch of missing values
  X1[sample(seq_len(N), size = 100)] <- NA
  X3[sample(seq_len(N), size = 300)] <- NA
  w <- rep(1, N)
  offset <- rep(0, N)
  data <- data.frame(Y = Y, X1 = X1, X2 = X2, X3 = X3, X4 = X4, X5 = X5, X6 = X6)
  # Set up for new API
  params <- training_params(
    num_trees = 20, interaction_depth = 3, min_num_obs_in_node = 10,
    shrinkage = 0.005, bag_fraction = 0.5, id = seq(nrow(data)),
    num_train = N / 2, num_features = 6
  )
  dist <- gbm_dist("Gaussian")
  fit <- gbmt(
    Y ~ X1 + X2 + X3 + X4 + X5 + X6, data = data, distribution = dist,
    weights = w, offset = offset, train_params = params,
    var_monotone = c(0, 0, 0, 0, 0, 0), keep_gbm_data = TRUE,
    cv_folds = 10, is_verbose = FALSE
  )
  # When calling pretty_gbm_tree on the object and the tree_index is not a
  # positive integer
  # Then an error is thrown
  expect_error(pretty_gbm_tree(fit, tree_index = NA))      # missing value
  expect_error(pretty_gbm_tree(fit, tree_index = c(1, 2))) # not a scalar
  expect_error(pretty_gbm_tree(fit, tree_index = 0))       # zero is not positive
  expect_error(pretty_gbm_tree(fit, tree_index = -1))      # negative index
  expect_error(pretty_gbm_tree(fit, tree_index = FALSE))   # logical, not numeric
})
test_that("error thrown when tree_index exceeds the number of trees fitted", {
  # Given a fitted gbm object
  ## test Gaussian distribution gbm model
  set.seed(1)
  # create some data (RNG call order matters for reproducibility)
  N <- 1000
  X1 <- runif(N)
  X2 <- 2 * runif(N)
  X3 <- factor(sample(letters[1:4], N, replace = TRUE))
  X4 <- ordered(sample(letters[1:6], N, replace = TRUE))
  X5 <- factor(sample(letters[1:3], N, replace = TRUE))
  X6 <- 3 * runif(N)
  mu <- c(-1, 0, 1, 2)[as.numeric(X3)]
  SNR <- 10 # signal-to-noise ratio
  Y <- X1^1.5 + 2 * (X2^0.5) + mu
  sigma <- sqrt(var(Y) / SNR)
  Y <- Y + rnorm(N, 0, sigma)
  # create a bunch of missing values
  X1[sample(seq_len(N), size = 100)] <- NA
  X3[sample(seq_len(N), size = 300)] <- NA
  w <- rep(1, N)
  offset <- rep(0, N)
  data <- data.frame(Y = Y, X1 = X1, X2 = X2, X3 = X3, X4 = X4, X5 = X5, X6 = X6)
  # Set up for new API
  params <- training_params(
    num_trees = 20, interaction_depth = 3, min_num_obs_in_node = 10,
    shrinkage = 0.005, bag_fraction = 0.5, id = seq(nrow(data)),
    num_train = N / 2, num_features = 6
  )
  dist <- gbm_dist("Gaussian")
  fit <- gbmt(
    Y ~ X1 + X2 + X3 + X4 + X5 + X6, data = data, distribution = dist,
    weights = w, offset = offset, train_params = params,
    var_monotone = c(0, 0, 0, 0, 0, 0), keep_gbm_data = TRUE,
    cv_folds = 10, is_verbose = FALSE
  )
  # When calling pretty_gbm_tree with a tree_index one past the number of
  # fitted trees
  tree_index <- params$num_trees + 1
  # Then an error is thrown
  expect_error(pretty_gbm_tree(fit, tree_index = tree_index))
})
context("Test output of pretty_gbm_tree")
test_that("Tree is prettified correctly", {
  # Given a fitted gbm object and a sensible tree_index
  n_trees <- 200
  tree_index <- 100
  ## test Gaussian distribution gbm model
  set.seed(1)
  # create some data (RNG call order matters for reproducibility)
  N <- 1000
  X1 <- runif(N)
  X2 <- 2 * runif(N)
  X3 <- factor(sample(letters[1:4], N, replace = TRUE))
  X4 <- ordered(sample(letters[1:6], N, replace = TRUE))
  X5 <- factor(sample(letters[1:3], N, replace = TRUE))
  X6 <- 3 * runif(N)
  mu <- c(-1, 0, 1, 2)[as.numeric(X3)]
  SNR <- 10 # signal-to-noise ratio
  Y <- X1^1.5 + 2 * (X2^0.5) + mu
  sigma <- sqrt(var(Y) / SNR)
  Y <- Y + rnorm(N, 0, sigma)
  # create a bunch of missing values
  X1[sample(seq_len(N), size = 100)] <- NA
  X3[sample(seq_len(N), size = 300)] <- NA
  w <- rep(1, N)
  offset <- rep(0, N)
  data <- data.frame(Y = Y, X1 = X1, X2 = X2, X3 = X3, X4 = X4, X5 = X5, X6 = X6)
  # Set up for new API
  params <- training_params(
    num_trees = n_trees, interaction_depth = 3, min_num_obs_in_node = 10,
    shrinkage = 0.005, bag_fraction = 0.5, id = seq(nrow(data)),
    num_train = N / 2, num_features = 6
  )
  dist <- gbm_dist("Gaussian")
  fit <- gbmt(
    Y ~ X1 + X2 + X3 + X4 + X5 + X6, data = data, distribution = dist,
    weights = w, offset = offset, train_params = params,
    var_monotone = c(0, 0, 0, 0, 0, 0), keep_gbm_data = TRUE,
    cv_folds = 10, is_verbose = FALSE
  )
  # When calling pretty_gbm_tree on the object with that tree_index
  pretty_gbm_tree_tree <- pretty_gbm_tree(fit, tree_index = tree_index)
  # Build the expected result by hand from the raw stored tree: named columns
  # and 0-based row names (node ids)
  tree <- data.frame(fit$trees[[tree_index]])
  names(tree) <- c(
    "SplitVar", "SplitCodePred", "LeftNode",
    "RightNode", "MissingNode", "ErrorReduction",
    "Weight", "Prediction"
  )
  row.names(tree) <- 0:(nrow(tree) - 1)
  # Then correctly prettifies tree
  expect_equal(pretty_gbm_tree_tree, tree)
})
test_that("default prettified tree is the first one", {
  # Given a fitted gbm object and the default tree_index
  n_trees <- 200
  tree_index <- 1
  ## test Gaussian distribution gbm model
  set.seed(1)
  # create some data (RNG call order matters for reproducibility)
  N <- 1000
  X1 <- runif(N)
  X2 <- 2 * runif(N)
  X3 <- factor(sample(letters[1:4], N, replace = TRUE))
  X4 <- ordered(sample(letters[1:6], N, replace = TRUE))
  X5 <- factor(sample(letters[1:3], N, replace = TRUE))
  X6 <- 3 * runif(N)
  mu <- c(-1, 0, 1, 2)[as.numeric(X3)]
  SNR <- 10 # signal-to-noise ratio
  Y <- X1^1.5 + 2 * (X2^0.5) + mu
  sigma <- sqrt(var(Y) / SNR)
  Y <- Y + rnorm(N, 0, sigma)
  # create a bunch of missing values
  X1[sample(seq_len(N), size = 100)] <- NA
  X3[sample(seq_len(N), size = 300)] <- NA
  w <- rep(1, N)
  offset <- rep(0, N)
  data <- data.frame(Y = Y, X1 = X1, X2 = X2, X3 = X3, X4 = X4, X5 = X5, X6 = X6)
  # Set up for new API
  params <- training_params(
    num_trees = n_trees, interaction_depth = 3, min_num_obs_in_node = 10,
    shrinkage = 0.005, bag_fraction = 0.5, id = seq(nrow(data)),
    num_train = N / 2, num_features = 6
  )
  dist <- gbm_dist("Gaussian")
  fit <- gbmt(
    Y ~ X1 + X2 + X3 + X4 + X5 + X6, data = data, distribution = dist,
    weights = w, offset = offset, train_params = params,
    var_monotone = c(0, 0, 0, 0, 0, 0), keep_gbm_data = TRUE,
    cv_folds = 10, is_verbose = FALSE
  )
  # When calling pretty_gbm_tree on the object without specifying tree_index
  pretty_gbm_tree_tree <- pretty_gbm_tree(fit)
  # Build the expected result by hand from the first stored tree: named
  # columns and 0-based row names (node ids)
  tree <- data.frame(fit$trees[[tree_index]])
  names(tree) <- c(
    "SplitVar", "SplitCodePred", "LeftNode",
    "RightNode", "MissingNode", "ErrorReduction",
    "Weight", "Prediction"
  )
  row.names(tree) <- 0:(nrow(tree) - 1)
  # Then the first tree is prettified, exactly as if tree_index = 1 were given
  expect_equal(pretty_gbm_tree_tree, tree)
  expect_equal(pretty_gbm_tree_tree, pretty_gbm_tree(fit, tree_index = tree_index))
})
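# NOTE(review): the two lines below are web-page embed boilerplate that leaked
# into this file (they are not valid R); kept as comments, safe to delete.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.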