# tests/testthat/test-distill_feats.R

# Exercises distill() when it is restricted to a subset of features.
#
# A forest is fitted on iris (Sepal.Length as the response), wrapped in a
# Predictor and an Interpreter, and distill() is then called with several
# option combinations. Each scenario checks the names and signs of the
# surrogate weights, and, where a grid is produced, the contents of that grid.
test_that("Tests that distillation with selected features is working", {

  library(Rforestry)
  set.seed(491)

  data <- iris

  # Hold out one fifth of the rows; only the remaining rows are used below.
  test_ind <- sample(seq_len(nrow(data)), nrow(data) %/% 5)
  train_reg <- data[-test_ind, ]

  # Train a random forest on the data set (column 1 is Sepal.Length)
  forest <- forestry(x = train_reg[, -1],
                     y = train_reg[, 1])

  # Create a predictor wrapper for the forest
  forest_predictor <- Predictor$new(model = forest,
                                    data = train_reg,
                                    y = "Sepal.Length",
                                    task = "regression")

  # Initialize an interpreter
  forest_interpret <- Interpreter$new(predictor = forest_predictor)

  # Indices of the features the surrogate is restricted to, and the order in
  # which their names are expected to appear in the fitted weights.
  expected_feats <- c("Sepal.Width", "Petal.Length", "Petal.Width")
  feat.index <- which(forest_interpret$features %in% expected_feats)

  # ---- Default distillation function ----
  surrogate.model <- distill(forest_interpret, features = feat.index)
  expect_equal(names(surrogate.model$weights), expected_feats)
  expect_true(all(surrogate.model$weights >= 0))

  # With the defaults, the first grid column is compared against the
  # interpreter's sampled training values for the feature.
  petal_length_pts <- forest_interpret$predictor$data[
    forest_interpret$data.points, "Petal.Length"
  ]
  expect_equal(surrogate.model$grid$Petal.Length[, 1], petal_length_pts)

  # The tolerance used to be 1e4, which accepted any finite values; 1e-4
  # makes this an actual comparison against the 1-D PDP function.
  expect_equal(surrogate.model$grid$Petal.Length[, 2],
               forest_interpret$pdp.1d$Petal.Length(petal_length_pts),
               tolerance = 1e-4)

  # ---- Snap.train F for distill (equal grid spacing) ----
  surrogate.model <- distill(forest_interpret, features = feat.index,
                             snap.train = FALSE)
  expect_equal(names(surrogate.model$weights), expected_feats)
  expect_true(all(surrogate.model$weights >= 0))

  petal_width_pts <- forest_interpret$grid.points$Petal.Width
  expect_equal(surrogate.model$grid$Petal.Width[, 1], petal_width_pts)

  # Same tolerance correction as above (1e4 -> 1e-4).
  expect_equal(surrogate.model$grid$Petal.Width[, 2],
               forest_interpret$pdp.1d$Petal.Width(petal_width_pts),
               tolerance = 1e-4)

  # ---- Snap.grid FALSE for distill (no snapping to grid) ----
  surrogate.model <- distill(forest_interpret, features = feat.index,
                             snap.grid = FALSE)
  expect_equal(names(surrogate.model$weights), expected_feats)
  expect_true(all(surrogate.model$weights >= 0))
  expect_equal(surrogate.model$grid, NA)

  # ---- distill with specified features ----
  surrogate.model <- distill(forest_interpret, features = feat.index)
  expect_equal(names(surrogate.model$weights), expected_feats)
  expect_true(all(surrogate.model$weights >= 0))
  expect_length(surrogate.model$grid, 3)

  # ---- distill with user-specified parameters for fitting ----
  # Check that we can pass parameters to glmnet
  surrogate.model <- distill(forest_interpret, features = feat.index,
                             params.glmnet = list(family = "gaussian",
                                                  alpha = 1,
                                                  lambda = 0,
                                                  intercept = FALSE))
  expect_equal(names(surrogate.model$weights), expected_feats)
  expect_equal(sum(surrogate.model$weights > 0), 3)

  # ---- distill with cross validation (currently disabled) ----
  # NOTE(review): the expected names below (Species_Blue, FrontalLobe, ...)
  # are not iris columns, so this block must be rewritten before it is
  # re-enabled.
  # surrogate.model <- distill(forest_interpret, cv = TRUE)
  # expect_equal(all.equal(names(surrogate.model$weights),
  #                        c("Species_Blue", "Species_Orange", "Sex_Male", "Sex_Female", "Index",
  #                          "FrontalLobe","RearWidth","CarapaceLength","CarapaceWidth")),
  #              TRUE)
  # expect_equal(sum(surrogate.model$weights < 0), 0)
  #
  # surrogate.model <- distill(forest_interpret, cv = TRUE,
  #                            params.cv.glmnet = list(upper.limits = 0,
  #                                                    intercept = FALSE,
  #                                                    alpha = 1))
  # expect_equal(all.equal(names(surrogate.model$weights),
  #                        c("Species_Blue", "Species_Orange", "Sex_Male", "Sex_Female", "Index",
  #                          "FrontalLobe","RearWidth","CarapaceLength","CarapaceWidth")),
  #              TRUE)
  # expect_equal(sum(surrogate.model$weights > 0), 0)
})
# forestry-labs/interpretability_sandbox documentation built on April 26, 2023, 4:14 p.m.