# tests/testthat/test-nest_cv.R

# WARNING - Generated by {fusen} from dev/flat_teaching.Rmd: do not edit by hand

# Tests for nest_cv(): each row of a nested data.table holds a data.table in a
# list column, and nest_cv() is expected to return cross-validation splits
# with `train` / `validate` data.tables per fold.
# NOTE(review): this file is generated by {fusen} from dev/flat_teaching.Rmd
# (see file header); mirror these changes there or they will be overwritten.
test_that("nest_cv functions correctly", {
  # Setup test data: 100 rows split into two groups of 50. `category`
  # alternates X/Y, so each group holds 25 "X" and 25 "Y" rows.
  base_dt <- data.table::data.table(
    id = 1:100,
    group = rep(c("A", "B"), each = 50),
    value = rnorm(100),
    category = factor(rep(c("X", "Y"), times = 50))
  )

  # Create nested test data: one row per group, its rows stored in `data`
  test_dt <- data.table::data.table(
    name = c("group1", "group2"),
    data = list(
      base_dt[group == "A"],
      base_dt[group == "B"]
    )
  )

  # Test basic functionality
  test_that("basic functionality works", {
    result <- nest_cv(test_dt, v = 5, repeats = 2)

    # Check structure
    expect_true(data.table::is.data.table(result))
    expect_true(
      all(c("name", "id", "splits", "train", "validate") %in% names(result))
    )
    expect_equal(nrow(result), 5 * 2 * 2)  # v * repeats * groups

    # Check content types; vapply() is type-stable, unlike sapply()
    expect_true(
      all(vapply(result$train, data.table::is.data.table, logical(1)))
    )
    expect_true(
      all(vapply(result$validate, data.table::is.data.table, logical(1)))
    )

    # Check split sizes
    first_fold <- result$train[[1]]
    expect_equal(nrow(first_fold), 40)  # 80% of 50 for training
    expect_equal(nrow(result$validate[[1]]), 10)  # 20% of 50 for validation
  })

  # Test with stratification
  test_that("stratification works correctly", {
    result <- nest_cv(test_dt, v = 5, repeats = 1, strata = "category")

    # Check if strata is maintained in splits
    first_train <- result$train[[1]]
    first_validate <- result$validate[[1]]

    # The class proportions of the validation fold must match those of the
    # training fold. (Previously train_prop was compared with itself, which
    # could never fail.) prop.table() normalises by the tabulated count, so a
    # stray NA level cannot skew the denominator.
    train_prop <- prop.table(table(first_train$category))
    validate_prop <- prop.table(table(first_validate$category))
    expect_equal(validate_prop, train_prop, tolerance = 0.1)
  })

  # Test with different parameters
  test_that("parameter variations work", {
    # Test with different v
    result_v3 <- nest_cv(test_dt, v = 3, repeats = 1)
    expect_equal(nrow(result_v3), 3 * 2)  # 3 folds * 2 groups

    # Test with different repeats
    result_r3 <- nest_cv(test_dt, v = 5, repeats = 3)
    expect_equal(nrow(result_r3), 5 * 3 * 2)  # 5 folds * 3 repeats * 2 groups

    # Test with different breaks; same return contract as the basic test
    result_breaks <- nest_cv(test_dt, v = 5, breaks = 3)
    expect_true(data.table::is.data.table(result_breaks))

    # Test with different pool
    result_pool <- nest_cv(test_dt, v = 5, pool = 0.2)
    expect_true(data.table::is.data.table(result_pool))
  })

  # Test error handling
  test_that("error handling works correctly", {
    # Test empty input
    empty_dt <- test_dt[0]
    expect_error(
      nest_cv(empty_dt),
      "Input 'nest_dt' cannot be empty"
    )

    # Test input without nested columns
    bad_dt <- data.table::data.table(a = 1:3, b = 4:6)
    expect_error(
      nest_cv(bad_dt),
      "Input 'nest_dt' must contain at least one nested column"
    )
  })

  # Test with multiple nested columns
  test_that("multiple nested columns work", {
    # Create test data with multiple nested columns
    multi_nest_dt <- data.table::data.table(
      name = c("group1", "group2"),
      data1 = list(
        base_dt[group == "A"],
        base_dt[group == "B"]
      ),
      data = list(
        base_dt[group == "A"],
        base_dt[group == "B"]
      )
    )

    result <- nest_cv(multi_nest_dt, v = 2)
    expect_true(data.table::is.data.table(result))
    expect_true(
      all(c("name", "splits", "train", "validate") %in% names(result))
    )
  })

  # Test data consistency
  test_that("data consistency is maintained", {
    result <- nest_cv(test_dt, v = 5, repeats = 1)

    # Check that all original columns are preserved in splits
    first_train <- result$train[[1]]
    expect_true(all(names(base_dt) %in% names(first_train)))

    # Check that no observations are lost or duplicated
    for (i in seq_len(nrow(result))) {
      train_set <- result$train[[i]]
      validate_set <- result$validate[[i]]

      # Total number of observations should equal original group size
      expect_equal(nrow(train_set) + nrow(validate_set), 50)

      # No ids shared between train and validate
      expect_length(intersect(train_set$id, validate_set$id), 0)
    }
  })

  # Test reproducibility
  # NOTE(review): set.seed() changes the global RNG state for any code that
  # runs afterwards in this session.
  test_that("results are reproducible with seed", {
    set.seed(123)
    result1 <- nest_cv(test_dt, v = 5)

    set.seed(123)
    result2 <- nest_cv(test_dt, v = 5)

    expect_equal(result1, result2)
  })

  # Test with different data types
  test_that("handles different data types correctly", {
    # Create test data with various data types. `letters` has only 26
    # elements, so `letters[1:50]` would pad `char` with NA; repeat 25
    # letters twice to get 50 real character values instead.
    complex_dt <- data.table::data.table(
      id = 1:50,
      num = rnorm(50),
      int = 1:50,
      fct = factor(rep(letters[1:5], 10)),
      date = seq(as.Date("2024-01-01"), by = "day", length.out = 50),
      char = rep(letters[1:25], times = 2)
    )

    nested_complex <- data.table::data.table(
      name = "group1",
      data = list(complex_dt)
    )

    result <- nest_cv(nested_complex, v = 5)

    # Check that data types are preserved
    first_train <- result$train[[1]]
    expect_type(first_train$num, "double")
    expect_type(first_train$int, "integer")
    expect_s3_class(first_train$fct, "factor")
    expect_s3_class(first_train$date, "Date")
    expect_type(first_train$char, "character")
  })
})

# Try the mintyr package in your browser

# Any scripts or data that you put into this service are public.

# mintyr documentation built on April 4, 2025, 2:56 a.m.