# tests/testthat/test-predictor.R

# Checks the Predictor wrapper against the models it wraps:
#   * the wrapper exposes the task, the fitted model and the training data;
#   * predict() on the wrapper matches predict() on the underlying model for
#     a forestry forest (with and without extra predict arguments), an lm
#     fit, and a binomial glm fit.
test_that("Tests that the predictor wrapper is working", {

  library(Rforestry)
  # Fixed seed for the random train/test split drawn below.
  set.seed(491)

  data <- iris

  # Hold out one fifth of the rows as a test set; the rest is training data.
  test_ind <- sample(seq_len(nrow(data)), nrow(data) %/% 5)
  train_reg <- data[-test_ind, ]
  test_reg <- data[test_ind, ]

  # Train a random forest on the data set. Column 1 (Sepal.Length) is the
  # outcome; all remaining columns are features.
  forest <- forestry(x = train_reg[, -1],
                     y = train_reg[, 1])

  # Create a predictor wrapper for the forest
  forest_predictor <- Predictor$new(model = forest,
                                    data = train_reg,
                                    y = "Sepal.Length",
                                    task = "regression")

  # ---- Check class attributes ----
  expect_equal(forest_predictor$task, "regression")

  # NOTE(review): 500 is presumably the forestry default for ntree, since no
  # ntree is passed above -- confirm against the Rforestry documentation.
  expect_equal(forest_predictor$model@ntree, 500)

  model_data <- forest_predictor$model@processed_dta$processed_x
  wrapper_data <- forest_predictor$data

  # The first name of the wrapper data is dropped before comparing with the
  # feature names stored in the model (Sepal.Length, the outcome, is the
  # first column of train_reg).
  expect_equal(names(model_data), names(wrapper_data)[-1])

  # ---- Check the model in predictor ----
  # The model held by the wrapper must carry the same outcome values as the
  # forest that was passed in.
  expect_equal(forest_predictor$model@processed_dta$y[1:10],
               forest@processed_dta$y[1:10],
               tolerance = 1e-4)

  # ---- Check predictions are equal ----
  # The wrapper's predictions are read from the first column of its result.
  wrapped_preds <- predict(forest_predictor, test_reg[, -1])
  standard_preds <- predict(forest, test_reg[, -1])

  expect_equal(wrapped_preds[, 1], standard_preds, tolerance = 1e-4)

  # ---- Try passing parameters to the predict function ----
  # Extra arguments (here aggregation = "oob") are given to both predict
  # calls and the results must still agree.
  wrapped_preds <- predict(forest_predictor, train_reg[, -1],
                           aggregation = "oob")
  standard_preds <- predict(forest, train_reg[, -1], aggregation = "oob")

  expect_equal(wrapped_preds[, 1], standard_preds, tolerance = 1e-4)

  # ---- Try a non forestry model ----
  # Linear model on the numeric columns only (Species, column 5, removed).
  mod2 <- lm(Sepal.Length ~ ., data = iris[, -5])

  linear_predictor <- Predictor$new(model = mod2,
                                    data = mod2$model,
                                    y = "Sepal.Length",
                                    task = "regression")

  preds_wrapped <- predict(linear_predictor, iris[, -c(1, 5)])
  preds_std <- predict(mod2, iris[, -c(1, 5)])

  # predict.lm returns a named vector; names are stripped before comparing.
  expect_equal(preds_wrapped[, 1], unname(preds_std))

  # ---- Try a classification model ----
  # Warnings raised while fitting are deliberately silenced; this test is
  # only concerned with the wrapper reproducing the model's predictions.
  mod <- suppressWarnings(glm(Species ~ ., data = iris, family = "binomial"))

  preds <- predict(mod, iris, type = "response")

  # type = "response" is given to the constructor; the expectation below
  # requires the wrapper's predictions to match predict(mod, type = "response").
  predictor <- Predictor$new(model = mod,
                             data = iris,
                             y = "Species",
                             task = "classification",
                             type = "response")

  expect_equal(predict(predictor, iris)[, 1],
               unname(preds),
               tolerance = 1e-4)
})
# forestry-labs/interpretability_sandbox documentation built on April 26, 2023, 4:14 p.m.