## tests/testthat/test_caret.R

library(testthat)
library(rsample)

###################################################################
## Test cases for caret -> rsample that mimic `trainControl`

# Toy data frame with 15 rows; passed as `data` to `caret2rsample()` in the
# tests of this file.
dat <- data.frame(
  y = seq_len(15),
  x = rev(seq_len(15))
)

# Mimics a caret `trainControl` object with method "LGOCV": three resamples,
# each with its modelling rows (`index`) and held-out rows (`indexOut`).
lgo1 <- list(
  method = "LGOCV",
  index = list(
    Resample1 = c(1L, 4L, 5L, 6L, 7L, 9L, 10L, 14L),
    Resample2 = c(2L, 4L, 5L, 6L, 9L, 10L, 14L, 15L),
    Resample3 = c(1L, 2L, 3L, 5L, 6L, 7L, 8L, 9L)
  ),
  indexOut = list(
    Resample1 = c(2L, 3L, 8L, 11L, 12L, 13L, 15L),
    Resample2 = c(1L, 3L, 7L, 8L, 11L, 12L, 13L),
    Resample3 = c(4L, 10L, 11L, 12L, 13L, 14L, 15L)
  ),
  number = 3,
  p = 0.5
)

# Mimics a caret `trainControl` object with method "cv" (3 folds, no
# repeats). `index` is named by fold, `indexOut` by resample.
cv_1 <- list(
  method = "cv",
  index = list(
    Fold1 = c(2L, 3L, 4L, 5L, 6L, 7L, 8L, 10L, 11L, 13L),
    Fold2 = c(1L, 4L, 6L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L),
    Fold3 = c(1L, 2L, 3L, 5L, 7L, 9L, 12L, 14L, 15L)
  ),
  indexOut = list(
    Resample1 = c(1L, 9L, 12L, 14L, 15L),
    Resample2 = c(2L, 3L, 5L, 7L),
    Resample3 = c(4L, 6L, 8L, 10L, 11L, 13L)
  ),
  number = 3,
  repeats = NA
)

# Mimics a caret `trainControl` object with method "repeatedcv": 3 folds
# repeated twice. `index` names follow the "Fold<i>.Rep<j>" pattern.
cv_2 <- list(
  method = "repeatedcv",
  index = list(
    Fold1.Rep1 = c(1L, 3L, 4L, 6L, 9L, 10L, 12L, 13L, 14L, 15L),
    Fold2.Rep1 = c(2L, 5L, 7L, 8L, 10L, 11L, 13L, 14L, 15L),
    Fold3.Rep1 = c(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 11L, 12L),
    Fold1.Rep2 = c(1L, 2L, 3L, 5L, 6L, 7L, 10L, 11L, 12L, 14L),
    Fold2.Rep2 = c(2L, 4L, 6L, 8L, 9L, 11L, 13L, 14L, 15L),
    Fold3.Rep2 = c(1L, 3L, 4L, 5L, 7L, 8L, 9L, 10L, 12L, 13L, 15L)
  ),
  indexOut = list(
    Resample1 = c(2L, 5L, 7L, 8L, 11L),
    Resample2 = c(1L, 3L, 4L, 6L, 9L, 12L),
    Resample3 = c(10L, 13L, 14L, 15L),
    Resample4 = c(4L, 8L, 9L, 13L, 15L),
    Resample5 = c(1L, 3L, 5L, 7L, 10L, 12L),
    Resample6 = c(2L, 6L, 11L, 14L)
  ),
  number = 3,
  repeats = 2
)

# Same resamples as `cv_2`; only the method label differs.
cv_3 <- modifyList(cv_2, list(method = "adaptive_cv"))

# Mimics a caret `trainControl` object with method "boot": two resamples of
# 15 row numbers each (with repeats) plus the rows left out of each one.
bt_1 <- list(
  method = "boot",
  index = list(
    Resample1 = c(
      1L, 1L, 4L, 4L, 5L, 7L, 8L, 10L, 11L, 11L, 12L, 13L, 15L, 15L, 15L
    ),
    Resample2 = c(
      1L, 2L, 3L, 5L, 5L, 5L, 6L, 7L, 8L, 9L, 9L, 9L, 10L, 10L, 12L
    )
  ),
  indexOut = list(
    Resample1 = c(2L, 3L, 6L, 9L, 14L),
    Resample2 = c(4L, 11L, 13L, 14L, 15L)
  ),
  number = 2
)

# The remaining bootstrap fixtures reuse the indices of `bt_1` and differ
# only in the method label.
bt_2 <- modifyList(bt_1, list(method = "boot632"))
bt_3 <- modifyList(bt_1, list(method = "optimism_boot"))
bt_4 <- modifyList(bt_1, list(method = "boot_all"))
bt_5 <- modifyList(bt_1, list(method = "adaptive_boot"))

# Mimics a caret `trainControl` object with method "LOOCV" on 15 rows:
# resample i keeps every row except row i in `index` and holds out row i in
# `indexOut`. Built inside `local()` so the helper vector does not leak.
loo_1 <- local({
  rows <- 1:15
  list(
    method = "LOOCV",
    index = setNames(
      lapply(rows, function(held_out) rows[-held_out]),
      sprintf("Fold%02d", rows)
    ),
    indexOut = setNames(
      as.list(rows),
      sprintf("Resample%02d", rows)
    )
  )
})

# Mimics a caret `trainControl` object with method "timeSlice": a fixed
# 4-row training window sliding one row at a time, each followed by a 5-row
# testing window. Resamples are named after the last training row (04..10).
rof_1 <- local({
  last_train <- 4:10
  list(
    method = "timeSlice",
    index = setNames(
      lapply(last_train, function(end) (end - 3L):end),
      sprintf("Training%02d", last_train)
    ),
    indexOut = setNames(
      lapply(last_train, function(end) (end + 1L):(end + 5L)),
      sprintf("Testing%02d", last_train)
    ),
    initialWindow = 4,
    horizon = 5,
    fixedWindow = TRUE,
    skip = 0
  )
})

###################################################################
## Helper for comparing converted splits with the original caret indices

# Check that every split in a converted object reproduces the row indices of
# the caret-style control list it was built from.
#
# Arguments:
#   newer  Object with a `splits` list element (in this file, the result of
#          `caret2rsample()`).
#   orig   List with `index` and `indexOut` elements, each a list of integer
#          row numbers with one entry per resample.
#
# Returns `NULL` invisibly; the function is called for its expectations.
check_indices <- function(newer, orig) {
  # The loop below only visits splits that exist, so without these checks a
  # conversion that dropped resamples (or produced none) would pass silently.
  expect_equal(length(newer$splits), length(orig$index))
  expect_equal(length(newer$splits), length(orig$indexOut))

  for (i in seq_along(newer$splits)) {
    split_i <- newer$splits[[i]]
    # NOTE(review): presumably dispatches to rsample's `as.integer()` method
    # for splits, whose default returns the analysis rows -- confirm.
    expect_equal(as.integer(split_i), orig$index[[i]])
    # The second argument requests the assessment (held-out) rows instead.
    expect_equal(as.integer(split_i, "assessment"), orig$indexOut[[i]])
  }
  invisible(NULL)
}

###################################################################
## Tests

test_that("basic v-fold", {
  converted <- caret2rsample(cv_1, data = dat)
  check_indices(converted, cv_1)

  # Each split id should carry over the name of the caret resample.
  expected_ids <- names(cv_1$index)
  for (k in seq_along(converted$splits)) {
    expect_equal(converted$id[[k]], expected_ids[k])
  }
})

test_that("repeated v-fold", {
  converted <- caret2rsample(cv_2, data = dat)
  check_indices(converted, cv_2)

  # The caret names look like "Fold1.Rep1"; joining `id2` and `id` with a
  # dot should rebuild them.
  expected_ids <- names(cv_2$index)
  for (k in seq_along(converted$splits)) {
    rebuilt <- paste(converted$id2[[k]], converted$id[[k]], sep = ".")
    expect_equal(rebuilt, expected_ids[k])
  }
})

test_that("basic boot", {
  converted <- caret2rsample(bt_1, data = dat)
  check_indices(converted, bt_1)

  # Each split id should carry over the name of the caret resample.
  expected_ids <- names(bt_1$index)
  for (k in seq_along(converted$splits)) {
    expect_equal(converted$id[[k]], expected_ids[k])
  }
})

test_that("boot 632", {
  converted <- caret2rsample(bt_2, data = dat)
  check_indices(converted, bt_2)

  # Each split id should carry over the name of the caret resample.
  expected_ids <- names(bt_2$index)
  for (k in seq_along(converted$splits)) {
    expect_equal(converted$id[[k]], expected_ids[k])
  }
})

test_that("boot optim", {
  converted <- caret2rsample(bt_3, data = dat)
  check_indices(converted, bt_3)

  # Each split id should carry over the name of the caret resample.
  expected_ids <- names(bt_3$index)
  for (k in seq_along(converted$splits)) {
    expect_equal(converted$id[[k]], expected_ids[k])
  }
})

test_that("boot all", {
  converted <- caret2rsample(bt_4, data = dat)
  check_indices(converted, bt_4)

  # Each split id should carry over the name of the caret resample.
  expected_ids <- names(bt_4$index)
  for (k in seq_along(converted$splits)) {
    expect_equal(converted$id[[k]], expected_ids[k])
  }
})

test_that("adaptive boot", {
  converted <- caret2rsample(bt_5, data = dat)
  check_indices(converted, bt_5)

  # Each split id should carry over the name of the caret resample.
  expected_ids <- names(bt_5$index)
  for (k in seq_along(converted$splits)) {
    expect_equal(converted$id[[k]], expected_ids[k])
  }
})


test_that("loo", {
  converted <- caret2rsample(loo_1, data = dat)
  check_indices(converted, loo_1)

  # Each split id should carry over the name of the caret resample.
  expected_ids <- names(loo_1$index)
  for (k in seq_along(converted$splits)) {
    expect_equal(converted$id[[k]], expected_ids[k])
  }
})

test_that("mcv", {
  converted <- caret2rsample(lgo1, data = dat)
  check_indices(converted, lgo1)

  # Each split id should carry over the name of the caret resample.
  expected_ids <- names(lgo1$index)
  for (k in seq_along(converted$splits)) {
    expect_equal(converted$id[[k]], expected_ids[k])
  }
})

test_that("rolling origin", {
  converted <- caret2rsample(rof_1, data = dat)
  check_indices(converted, rof_1)

  # Each split id should carry over the name of the caret training slice.
  expected_ids <- names(rof_1$index)
  for (k in seq_along(converted$splits)) {
    expect_equal(converted$id[[k]], expected_ids[k])
  }
})
## topepo/rsample documentation built on May 4, 2019, 4:25 p.m.