# Checks that distill() builds a surrogate model restricted to a chosen subset
# of features, under the different grid-snapping options and with
# user-supplied glmnet parameters.
test_that("Tests that distillation with selected features is working", {
  library(Rforestry)
  set.seed(491)

  # Hold out a fifth of iris; only the training rows are used below.
  data <- iris
  test_ind <- sample(seq_len(nrow(data)), nrow(data) %/% 5)
  train_reg <- data[-test_ind, ]

  # Train a random forest predicting the first column (Sepal.Length).
  forest <- forestry(x = train_reg[, -1], y = train_reg[, 1])

  # Create a predictor wrapper for the forest
  forest_predictor <- Predictor$new(
    model = forest,
    data = train_reg,
    y = "Sepal.Length",
    task = "regression"
  )

  # Initialize an interpreter
  forest_interpret <- Interpreter$new(predictor = forest_predictor)

  # which() returns indices in the order of forest_interpret$features, so the
  # expected weight names are listed in that order rather than the order used
  # in the lookup vector.
  feat.index <- which(
    forest_interpret$features %in%
      c("Sepal.Width", "Petal.Width", "Petal.Length")
  )
  expected_names <- c("Sepal.Width", "Petal.Length", "Petal.Width")

  # Default distillation function ----
  surrogate.model <- distill(forest_interpret, features = feat.index)
  expect_equal(names(surrogate.model$weights), expected_names)
  expect_equal(sum(surrogate.model$weights < 0), 0)

  # The first grid column is expected to hold the training values at the
  # interpreter's data points, the second the 1-D PDP evaluated at them.
  petal_length <- forest_interpret$predictor$data[
    forest_interpret$data.points, "Petal.Length"
  ]
  expect_equal(surrogate.model$grid$Petal.Length[, 1], petal_length)
  # NOTE(review): tolerance was 1e4, which effectively disabled this
  # comparison; tightened to 1e-4. If this now fails, the stored grid values
  # differ from the raw PDP values and the expectation needs revisiting.
  expect_equal(
    surrogate.model$grid$Petal.Length[, 2],
    forest_interpret$pdp.1d$Petal.Length(petal_length),
    tolerance = 1e-4
  )

  # Snap.train F for distill (equal grid spacing) ----
  surrogate.model <- distill(
    forest_interpret,
    features = feat.index,
    snap.train = FALSE
  )
  expect_equal(names(surrogate.model$weights), expected_names)
  expect_equal(sum(surrogate.model$weights < 0), 0)

  petal_width_grid <- forest_interpret$grid.points$Petal.Width
  expect_equal(surrogate.model$grid$Petal.Width[, 1], petal_width_grid)
  # Same tolerance fix as above (was 1e4).
  expect_equal(
    surrogate.model$grid$Petal.Width[, 2],
    forest_interpret$pdp.1d$Petal.Width(petal_width_grid),
    tolerance = 1e-4
  )

  # Snap.grid FALSE for distill (no snapping to grid) ----
  surrogate.model <- distill(
    forest_interpret,
    features = feat.index,
    snap.grid = FALSE
  )
  expect_equal(names(surrogate.model$weights), expected_names)
  expect_equal(sum(surrogate.model$weights < 0), 0)
  # No grid is expected to be stored when snapping is turned off.
  expect_equal(surrogate.model$grid, NA)

  # distill with specified features ----
  surrogate.model <- distill(forest_interpret, features = feat.index)
  expect_equal(names(surrogate.model$weights), expected_names)
  expect_equal(sum(surrogate.model$weights < 0), 0)
  # One grid entry per selected feature.
  expect_equal(length(surrogate.model$grid), 3)

  # distill with user-specified parameters for fitting ----
  # Check that we can pass parameters to glmnet
  surrogate.model <- distill(
    forest_interpret,
    features = feat.index,
    params.glmnet = list(
      family = "gaussian",
      alpha = 1,
      lambda = 0,
      intercept = FALSE
    )
  )
  expect_equal(names(surrogate.model$weights), expected_names)
  expect_equal(sum(surrogate.model$weights > 0), 3)

  # distill with cross validation ----
  # NOTE(review): these checks are disabled. The expected names below
  # (Species_Blue, CarapaceLength, ...) do not match the iris columns used in
  # this test, so they must be updated before the block is re-enabled.
  # surrogate.model <- distill(forest_interpret, cv = TRUE)
  # expect_equal(all.equal(names(surrogate.model$weights),
  # c("Species_Blue", "Species_Orange", "Sex_Male", "Sex_Female", "Index",
  # "FrontalLobe","RearWidth","CarapaceLength","CarapaceWidth")),
  # TRUE)
  # expect_equal(sum(surrogate.model$weights < 0), 0)
  #
  # surrogate.model <- distill(forest_interpret, cv = TRUE,
  # params.cv.glmnet = list(upper.limits = 0,
  # intercept = FALSE,
  # alpha = 1))
  # expect_equal(all.equal(names(surrogate.model$weights),
  # c("Species_Blue", "Species_Orange", "Sex_Male", "Sex_Female", "Index",
  # "FrontalLobe","RearWidth","CarapaceLength","CarapaceWidth")),
  # TRUE)
  # expect_equal(sum(surrogate.model$weights > 0), 0)
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.