# tests/testthat/test-split_cv.R

# WARNING - Generated by {fusen} from dev/flat_teaching.Rmd: do not edit by hand

# Build the two small fixtures shared by the split_cv tests.
#
# Returns a named list of data.tables:
#   data1 - 100 rows; x = 1:100, y = rnorm(), group cycles through "a".."d"
#   data2 -  50 rows; x = 1:50,  y = rnorm(), group cycles through "a".."b"
# Column `x` numbers the rows 1..n so the tests can tell which original rows
# ended up in each train/validate set.
create_test_data <- function() {
  make_table <- function(n_rows, n_groups) {
    data.table(
      x = 1:n_rows,
      y = rnorm(n_rows),
      group = rep(letters[1:n_groups], n_rows / n_groups)
    )
  }

  list(
    data1 = make_table(100, 4),
    data2 = make_table(50, 2)
  )
}

# Basic functionality test
test_that("split_cv returns correct structure", {
  inputs <- create_test_data()
  per_dataset <- split_cv(inputs, v = 5, repeats = 1)

  # Top level: a list with one element per input dataset, same names
  expect_type(per_dataset, "list")
  expect_equal(length(per_dataset), length(inputs))
  expect_named(per_dataset, names(inputs))

  # Each element: a data.table with the split columns and one row per fold
  # (v = 5 folds, a single repeat)
  for (tbl in per_dataset) {
    expect_s3_class(tbl, "data.table")
    expect_true(all(c("splits", "id", "train", "validate") %in% names(tbl)))
    expect_equal(nrow(tbl), 5)
  }
})

# Error handling test
test_that("split_cv handles invalid inputs correctly", {
  # NULL, an empty list, and a list of bare scalars must all be rejected
  expect_error(split_cv(NULL))
  expect_error(split_cv(list()))
  expect_error(split_cv(list(a = 1, b = 2)))

  # An unknown strata column should raise a warning naming that column
  datasets <- create_test_data()
  expect_warning(
    split_cv(datasets, v = 5, strata = "non_existent_column"),
    "Strata variable 'non_existent_column' not found"
  )
})

# Repeated cross-validation test
test_that("split_cv handles repeats correctly", {
  datasets <- create_test_data()

  # repeats = 1: only an `id` column (fold labels), one row per fold
  single_run <- split_cv(datasets, v = 5, repeats = 1)
  for (tbl in single_run) {
    expect_true("id" %in% names(tbl))
    expect_false("id2" %in% names(tbl))
    expect_equal(nrow(tbl), 5)
    expect_true(all(tbl$id %in% paste0("Fold", 1:5)))
  }

  # repeats = 3: `id` labels the repeat and `id2` the fold inside it,
  # so each dataset yields 3 x 5 = 15 rows
  triple_run <- split_cv(datasets, v = 5, repeats = 3)
  for (tbl in triple_run) {
    expect_true(all(c("id", "id2") %in% names(tbl)))
    expect_equal(nrow(tbl), 15)
    expect_true(all(grepl("^Repeat\\d+$", tbl$id)))
    expect_true(all(grepl("^Fold\\d+$", tbl$id2)))
    expect_equal(length(unique(tbl$id)), 3)
    expect_equal(length(unique(tbl$id2)), 5)

    # Each repeat must contribute exactly 5 folds
    folds_per_repeat <- tbl[, .N, by = id]$N
    for (n_folds in folds_per_repeat) {
      expect_equal(n_folds, 5)
    }
  }
})

# Train and validation sets test
test_that("split_cv generates correct train and validate sets", {
  test_data <- create_test_data()
  v <- 5
  result <- split_cv(test_data, v = v, repeats = 2)

  for (i in seq_along(result)) {
    res <- result[[i]]
    n_original <- nrow(test_data[[i]])
    all_original_rows <- seq_len(n_original)

    # Check every split deterministically instead of one sampled at random,
    # so a defect in any single fold cannot slip through.
    for (split_idx in seq_len(nrow(res))) {
      # `x` numbers the original rows 1..n (see create_test_data())
      train_rows <- res$train[[split_idx]]$x
      validate_rows <- res$validate[[split_idx]]$x

      # Train and validate must not share any original row
      expect_length(intersect(train_rows, validate_rows), 0)

      # Sizes must be within one row of n / v and n * (v - 1) / v.
      # Use an absolute difference: `tolerance` in expect_equal() is
      # relative, so `tolerance = 1` would accept anything within 100%
      # of the target size.
      expect_lte(abs(length(validate_rows) - n_original / v), 1)
      expect_lte(abs(length(train_rows) - n_original * (v - 1) / v), 1)

      # Together the two sets must be an exact partition of the original
      # rows: every row appears exactly once (none missing, no duplicates).
      expect_equal(sort(c(train_rows, validate_rows)), all_original_rows)
    }
  }
})

# Stratification test
test_that("split_cv handles stratification correctly", {
  datasets <- create_test_data()
  stratified <- split_cv(datasets, v = 5, strata = "group")

  # The first training set of every dataset should still carry the
  # `group` column and contain more than one distinct group
  for (tbl in stratified) {
    train_set <- tbl$train[[1]]
    expect_true("group" %in% names(train_set))
    n_groups <- length(unique(train_set$group))
    expect_gt(n_groups, 1)
  }
})

# Data type handling test
test_that("split_cv handles different input data types", {
  # Build the fixtures once instead of calling create_test_data()
  # (and its rnorm() draws) separately for every element.
  test_data <- create_test_data()

  # All elements supplied as data.frame: outputs should be data.tables
  df_data <- list(
    data1 = as.data.frame(test_data[["data1"]]),
    data2 = as.data.frame(test_data[["data2"]])
  )
  result_df <- split_cv(df_data, v = 5)
  # Guard against an empty result, which would make the check below vacuous
  expect_length(result_df, length(df_data))
  # vapply() guarantees a logical vector; sapply() would return list() for
  # an empty result, and all(list()) errors
  expect_true(all(vapply(result_df, is.data.table, logical(1))))

  # Mixed data.frame / data.table input: outputs should be data.tables
  mixed_data <- list(
    data1 = as.data.frame(test_data[["data1"]]),
    data2 = test_data[["data2"]]
  )
  result_mixed <- split_cv(mixed_data, v = 5)
  expect_length(result_mixed, length(mixed_data))
  expect_true(all(vapply(result_mixed, is.data.table, logical(1))))
})

# Try the mintyr package in your browser
#
# Any scripts or data that you put into this service are public.
#
# mintyr documentation built on April 4, 2025, 2:56 a.m.