# tests/testthat/test-global_validation.R

test_that("global_validation correctly handles missing predictions", {
  skip_if_not_installed("randomForest")
  data("iris")
  set.seed(123)

  # trainControl is called without a savePredictions argument here, so the
  # call to global_validation() below is expected to raise an error.
  predictors <- iris[, c("Sepal.Width", "Petal.Length", "Petal.Width")]
  response <- iris[["Sepal.Length"]]
  cv_control <- caret::trainControl(method = "cv")
  rf_fit <- caret::train(predictors, response,
                         method = "rf", trControl = cv_control, ntree = 10)

  expect_error(global_validation(rf_fit))
})

test_that("global_validation works with caret regression", {
  skip_if_not_installed("randomForest")
  data("iris")
  set.seed(123)

  # Random forest regression of Sepal.Length, keeping the held-out
  # predictions of the final model (savePredictions = "final").
  predictors <- iris[, c("Sepal.Width", "Petal.Length", "Petal.Width")]
  response <- iris[["Sepal.Length"]]
  cv_control <- caret::trainControl(method = "cv", savePredictions = "final")
  rf_fit <- caret::train(predictors, response,
                         method = "rf", trControl = cv_control, ntree = 10)

  # Reference values are compared with a tolerance of 0.02.
  expected <- c("RMSE" = 0.3307870, "Rsquared" = 0.8400544, "MAE" = 0.2621827)
  expect_equal(global_validation(rf_fit), expected, tolerance = 0.02)
})

test_that("global_validation works with caret classification", {
  skip_if_not_installed("randomForest")
  data("iris")
  set.seed(123)

  # Random forest classification of Species from the four measurements,
  # keeping the held-out predictions of the final model.
  predictors <- iris[, c("Sepal.Width", "Petal.Length",
                         "Petal.Width", "Sepal.Length")]
  response <- iris[["Species"]]
  cv_control <- caret::trainControl(method = "cv", savePredictions = "final")
  rf_fit <- caret::train(predictors, response,
                         method = "rf", trControl = cv_control, ntree = 10)

  # Only the first two entries of the result are compared.
  overall <- global_validation(rf_fit)
  expected <- c("Accuracy" = 0.96, "Kappa" = 0.94)
  expect_equal(overall[1:2], expected, tolerance = 0.02)
})

test_that("global_validation works with CreateSpacetimeFolds", {
  skip_if_not_installed("randomForest")
  data("iris")
  set.seed(123)

  # Assign each row a random fold label in 1..10 and build the resampling
  # indices from that column.
  n_obs <- nrow(iris)
  iris$folds <- sample(rep(1:10, ceiling(n_obs / 10)), n_obs)
  fold_indices <- CreateSpacetimeFolds(iris, "folds")

  predictors <- iris[, c("Sepal.Width", "Petal.Length",
                         "Petal.Width", "Sepal.Length")]
  response <- iris[["Species"]]
  cv_control <- caret::trainControl(method = "cv", savePredictions = "final",
                                    index = fold_indices$index)
  rf_fit <- caret::train(predictors, response,
                         method = "rf", trControl = cv_control, ntree = 10)

  # Only the first two entries of the result are compared.
  overall <- global_validation(rf_fit)
  expected <- c("Accuracy" = 0.96, "Kappa" = 0.94)
  expect_equal(overall[1:2], expected, tolerance = 0.02)
})

# Try the CAST package in your browser
#
# Any scripts or data that you put into this service are public.
#
# CAST documentation built on June 22, 2024, 10:47 a.m.