# tests/testthat/test-update-num-train-groups.r

####################
# Author: James Hickey
#
# Series of tests to check how training params are updated if groups are present
#
####################

# Labels the input-validation tests that follow.
# NOTE(review): context() is deprecated in testthat 3rd edition - confirm
# which edition this package targets before relying on it.
context("Input checking update_num_train_groups")
test_that("update_num_train_groups throws an error if not given a GBMTrainParams object", {
  # Arrange: build both inputs from their constructors with no arguments
  bad_params <- training_params()
  dist_obj <- gbm_dist()

  # Replace the class attribute of the training parameters so the first
  # argument no longer carries its original class
  class(bad_params) <- "wrong"

  # Assert: calling with the re-classed parameters is expected to error
  expect_error(update_num_train_groups(bad_params, dist_obj))
})
test_that("update_num_train_groups throws an error if not given a GBMDist object", {
  # Purpose: the second argument must be rejected when it does not carry the
  # class produced by gbm_dist().
  # Given a default train params and dist object
  params <- training_params()
  dist <- gbm_dist()
  
  # When dist is stripped of class and update_num_train_groups called.
  # The training params are deliberately left valid so that the error can
  # only come from the distribution argument (previously `params` was
  # re-classed here, which merely repeated the GBMTrainParams test).
  class(dist) <- "wrong"
  
  # Then error thrown
  expect_error(update_num_train_groups(params, dist))
})

# Labels the behavioural tests that follow.
# NOTE(review): context() is deprecated in testthat 3rd edition - confirm
# which edition this package targets before relying on it.
context("Testing update_num_train_groups behaviour")
test_that("update_num_train_groups returns original params object when given an AdaBoostGBMDist obj", {
  # Arrange: training parameters built with no arguments and an AdaBoost dist
  original_params <- training_params()
  adaboost_dist <- gbm_dist("AdaBoost")

  # Act: run the group-based update
  updated_params <- update_num_train_groups(original_params, adaboost_dist)

  # Assert: the result is expected to equal the untouched input
  expect_equal(updated_params, original_params)
})
test_that("update_num_train_groups returns original params object when given an BernoulliGBMDist obj", {
  # Arrange: training parameters built with no arguments and a Bernoulli dist
  original_params <- training_params()
  bernoulli_dist <- gbm_dist("Bernoulli")

  # Act: run the group-based update
  updated_params <- update_num_train_groups(original_params, bernoulli_dist)

  # Assert: the result is expected to equal the untouched input
  expect_equal(updated_params, original_params)
})
test_that("update_num_train_groups returns original params object when given an CoxPHGBMDist obj", {
  # Arrange: training parameters built with no arguments and a CoxPH dist
  original_params <- training_params()
  coxph_dist <- gbm_dist("CoxPH")

  # Act: run the group-based update
  updated_params <- update_num_train_groups(original_params, coxph_dist)

  # Assert: the result is expected to equal the untouched input
  expect_equal(updated_params, original_params)
})
test_that("update_num_train_groups returns original params object when given an GammaGBMDist obj", {
  # Arrange: training parameters built with no arguments and a Gamma dist
  original_params <- training_params()
  gamma_dist <- gbm_dist("Gamma")

  # Act: run the group-based update
  updated_params <- update_num_train_groups(original_params, gamma_dist)

  # Assert: the result is expected to equal the untouched input
  expect_equal(updated_params, original_params)
})
test_that("update_num_train_groups returns original params object when given an GaussianGBMDist obj", {
  # Arrange: training parameters built with no arguments and a Gaussian dist
  original_params <- training_params()
  gaussian_dist <- gbm_dist("Gaussian")

  # Act: run the group-based update
  updated_params <- update_num_train_groups(original_params, gaussian_dist)

  # Assert: the result is expected to equal the untouched input
  expect_equal(updated_params, original_params)
})
test_that("update_num_train_groups returns original params object when given an HuberizedGBMDist obj", {
  # Arrange: training parameters built with no arguments and a Huberized dist
  original_params <- training_params()
  huberized_dist <- gbm_dist("Huberized")

  # Act: run the group-based update
  updated_params <- update_num_train_groups(original_params, huberized_dist)

  # Assert: the result is expected to equal the untouched input
  expect_equal(updated_params, original_params)
})
test_that("update_num_train_groups returns original params object when given an LaplaceGBMDist obj", {
  # Arrange: training parameters built with no arguments and a Laplace dist
  original_params <- training_params()
  laplace_dist <- gbm_dist("Laplace")

  # Act: run the group-based update
  updated_params <- update_num_train_groups(original_params, laplace_dist)

  # Assert: the result is expected to equal the untouched input
  expect_equal(updated_params, original_params)
})
test_that("update_num_train_groups returns original params object when given an PoissonGBMDist obj", {
  # Arrange: training parameters built with no arguments and a Poisson dist
  original_params <- training_params()
  poisson_dist <- gbm_dist("Poisson")

  # Act: run the group-based update
  updated_params <- update_num_train_groups(original_params, poisson_dist)

  # Assert: the result is expected to equal the untouched input
  expect_equal(updated_params, original_params)
})
test_that("update_num_train_groups returns original params object when given an QuantileGBMDist obj", {
  # Arrange: training parameters built with no arguments and a Quantile dist
  original_params <- training_params()
  quantile_dist <- gbm_dist("Quantile")

  # Act: run the group-based update
  updated_params <- update_num_train_groups(original_params, quantile_dist)

  # Assert: the result is expected to equal the untouched input
  expect_equal(updated_params, original_params)
})
test_that("update_num_train_groups returns original params object when given an TDistGBMDist obj", {
  # Arrange: training parameters built with no arguments and a TDist dist
  original_params <- training_params()
  tdist_dist <- gbm_dist("TDist")

  # Act: run the group-based update
  updated_params <- update_num_train_groups(original_params, tdist_dist)

  # Assert: the result is expected to equal the untouched input
  expect_equal(updated_params, original_params)
})
test_that("update_num_train_groups returns original params object when given an TweedieGBMDist obj", {
  # Purpose: the Tweedie distribution is expected to leave the training
  # parameters unchanged.
  # Given some training parameters and a Tweedie dist (the test previously
  # built an "AdaBoost" dist here, so Tweedie was never actually exercised)
  params <- training_params()
  dist <- gbm_dist("Tweedie")
  
  # When training parameters are updated using update_num_train_groups
  # Then no update happens
  expect_equal(update_num_train_groups(params, dist), params)
})
test_that("update_num_train_groups throws error if training_params id indicates multiple rows per observation - Pairwise", {
  # Arrange: a Pairwise dist plus training parameters whose ids are drawn
  # with replacement from only n_rows / 5 distinct values, so ids repeat
  n_rows <- 100
  pairwise_dist <- gbm_dist("Pairwise")
  repeated_ids <- sample(seq_len(n_rows / 5), n_rows, replace = TRUE)
  params_with_repeats <- training_params(num_train = n_rows, id = repeated_ids)

  # Assert: updating with repeated ids is expected to signal an error
  expect_error(update_num_train_groups(params_with_repeats, pairwise_dist))
})
test_that("update_num_train_groups returns correctly updated training_params - Pairwise", {
  # Purpose: for a Pairwise dist, the returned params must equal a copy of the
  # inputs whose num_train_rows and num_train were recomputed from the groups.
  # Given some training parameters object and a Pairwise object with groups
  N <- 100
  # ids are 1..N, i.e. one distinct id per row
  params <- training_params(num_train=N/2, id=seq_len(N))
  dist <- gbm_dist("Pairwise")
  # NOTE(review): labels are sampled without set.seed(), so the group data
  # differ from run to run.
  dist$groups <- sample(seq_len(5), N, replace=TRUE)
  
  # When update_num_train_groups
  params_update <- update_num_train_groups(params, dist)
  
  # Then correctly calculates update
  # NOTE(review): the field assigned above is `groups`, but `group` is read
  # below. `$` on a list partially matches names, so which element
  # `dist$group` resolves to depends on the fields gbm_dist("Pairwise")
  # creates - confirm this is the intended field.
  # Expected number of training groups: train_fraction times the number of
  # factor levels of dist$group, rounded, and never less than 1.
  num_train_groups <- max(1, round(params$train_fraction * nlevels(dist$group)))
  # Expected training rows: the last index whose group equals the level at
  # position num_train_groups.
  params$num_train_rows <- max(which(dist$group == levels(dist$group)[num_train_groups]))
  params$num_train <- params$num_train_rows
  expect_equal(params_update, params)
})
# gbm-developers/gbm3 documentation built on April 28, 2024, 10:04 p.m.