# File: tests/testthat/test-old-api-dist-creation.r

####################
# Author: James Hickey
#
# Series of tests to check creation of GBMDist object
# for old gbm API.
#
####################
context("Test the creation of GBMDist objects for old API")
test_that("Error thrown if distribution not recognised", {
  # Given a made up distribution name
  dist <- list(name = "woooooo! Error!")
  expected_msg <- paste("The distribution", dist$name,
                        "is not available in the gbm package.")

  # When create_dist_obj_for_gbmt_fit is called
  # Then an error naming the unknown distribution is thrown.
  # The expected message is assembled from data, so it is matched
  # literally (fixed = TRUE) instead of as a regular expression;
  # otherwise metacharacters such as "." or "(" in the name/message
  # would be interpreted as regex syntax.
  expect_error(create_dist_obj_for_gbmt_fit(dist),
               expected_msg,
               fixed = TRUE)
})
test_that("Can create default distributions", {
  # Given the names of every distribution offered by gbm
  dist_names <- c("AdaBoost", "Bernoulli", "CoxPH", "Gamma", "Gaussian",
                  "Huberized", "Laplace", "Pairwise", "Poisson", "Quantile",
                  "TDist", "Tweedie")

  # When creating each default through create_dist_obj_for_gbmt_fit
  # Then it equals the object built by gbm_dist. `info` names the
  # distribution in any failure message, since the looped expression
  # labels alone would not identify which one broke.
  for (dist_name in dist_names) {
    expect_equal(gbm_dist(dist_name),
                 create_dist_obj_for_gbmt_fit(list(name = dist_name)),
                 info = dist_name)
  }
})
test_that("Can specify model dependent data - CoxPH", {
  # Given a CoxPH distribution plus strata ids, a tied-times method
  # and a prior node coefficient of variation
  coxph_dist <- list(name = "CoxPH")
  strata_ids <- c(1, 2, 1)
  ties_method <- "breslow"
  node_coeff_var <- 150.4

  # When the distribution object is built; the extra arguments are
  # passed positionally after the distribution list
  built <- create_dist_obj_for_gbmt_fit(coxph_dist, ties_method,
                                        strata_ids, node_coeff_var)

  # Then the model dependent fields hold exactly what was supplied
  expect_equal(built$original_strata_id, strata_ids)
  expect_equal(built$ties, ties_method)
  expect_equal(built$prior_node_coeff_var, node_coeff_var)
})
test_that("Can specify model dependent data - Pairwise", {
  # Given a Pairwise specification carrying metric, max.rank and group
  pairwise_spec <- list(name = "Pairwise",
                        metric = "mrr",
                        max.rank = 1,
                        group = "query")

  # When the distribution object is created from that specification
  pairwise_obj <- create_dist_obj_for_gbmt_fit(pairwise_spec)

  # Then every field is carried across (max.rank is stored as max_rank)
  expect_equal(pairwise_obj$metric, pairwise_spec$metric)
  expect_equal(pairwise_obj$group, pairwise_spec$group)
  expect_equal(pairwise_obj$max_rank, pairwise_spec$max.rank)
})
test_that("Can specify model dependent data - Quantile", {
  # Given: a Quantile specification with an explicit alpha
  quantile_alpha <- 0.5
  quantile_spec <- list(name = "Quantile", alpha = quantile_alpha)

  # When: the distribution object is built from it
  quantile_obj <- create_dist_obj_for_gbmt_fit(quantile_spec)

  # Then: alpha is passed through unchanged
  expect_equal(quantile_obj$alpha, quantile_alpha)
})
test_that("Can specify model dependent data - TDist", {
  # Given: a TDist specification with explicit degrees of freedom
  t_df <- 10
  tdist_spec <- list(name = "TDist", df = t_df)

  # When: the distribution object is built from it
  tdist_obj <- create_dist_obj_for_gbmt_fit(tdist_spec)

  # Then: df is passed through unchanged
  expect_equal(tdist_obj$df, t_df)
})
test_that("Can specify model dependent data - Tweedie", {
  # Given: a Tweedie specification with an explicit power
  tweedie_power <- 100
  tweedie_spec <- list(name = "Tweedie", power = tweedie_power)

  # When: the distribution object is built from it
  tweedie_obj <- create_dist_obj_for_gbmt_fit(tweedie_spec)

  # Then: power is passed through unchanged
  expect_equal(tweedie_obj$power, tweedie_power)
})
test_that("Case of distribution name characters is irrelevant", {
  # Given each distribution name in its canonical mixed case
  dist_names <- c("AdaBoost", "Bernoulli", "CoxPH", "Gamma", "Gaussian",
                  "Huberized", "Laplace", "Pairwise", "Poisson", "Quantile",
                  "TDist", "Tweedie")

  # When creating a distribution from the all-lower-case name
  # Then it equals the one created from the canonical name.
  # `info` names the distribution in any failure message.
  for (dist_name in dist_names) {
    lower_name <- tolower(dist_name)
    expect_equal(create_dist_obj_for_gbmt_fit(list(name = lower_name)),
                 create_dist_obj_for_gbmt_fit(list(name = dist_name)),
                 info = dist_name)
  }
})
# Source: gbm-developers/gbm3 (documentation built on April 28, 2024, 10:04 p.m.)