tests/testthat/test-weight.R

library(testthat)
library(reticulate)

skip_if_no_keras <- function() {

  if (!tryCatch(
    reticulate::py_module_available("keras"),
    error = function(e) return(FALSE)
  )
  ) skip("keras not available for testing...")
}
# Test case 1: Check the weights for gaussian family
test_that("Weights for gaussian family should be equal to input weights", {
  skip_if_no_keras()

  family <- "gaussian"
  muhat <- c(0.1, 0.5, 0.9)
  w <- c(0.2, 0.6, 0.8)
  expected_output <- w
  actual_output <- weight(w, muhat, family)
  expect_equal(actual_output, expected_output)
})

# Test case 2: Check the weights for binomial family
test_that("Weights for binomial family should be correctly calculated", {
  skip_if_no_keras()

  family <- "binomial"
  muhat <- c(0.2, 0.7, 0.99)
  w <- c(0.2, 0.6, 0.8)

  # Calculate the expected weights
  diriv_values <- diriv(family, muhat)
  aux <- muhat * (1 - muhat) * (diriv_values^2)
  aux[aux <= 0.001] <- 0.001
  expected_output <- w / aux

  actual_output <- weight(w, muhat, family)
  expect_equal(actual_output, expected_output)
})

# Test case 3: Check for missing 'muhat' argument
test_that("Function should throw an error for missing 'muhat' argument", {
  skip_if_no_keras()

  family <- "gaussian"
  w <- c(0.2, 0.6, 0.8)
  expect_error(weight(w, family))
})

# Test case 4: Check for missing 'w' argument
test_that("Function should throw an error for missing 'w' argument", {
  skip_if_no_keras()

  family <- "gaussian"
  muhat <- c(0.1, 0.5, 0.9)
  expect_error(weight(muhat, family))
})

# Test case 5: Check for missing 'family' argument
test_that("Function should throw an error for missing 'family' argument", {
  skip_if_no_keras()

  muhat <- c(0.1, 0.5, 0.9)
  w <- c(0.2, 0.6, 0.8)
  expect_error(weight(w, muhat))
})

# Test case 6: Check for unsupported family
test_that("Function should throw an error for unsupported 'family'", {
  skip_if_no_keras()

  family <- "poisson"
  muhat <- c(0.1, 0.5, 0.9)
  w <- c(0.2, 0.6, 0.8)
  expect_error(weight(w, muhat, family))
})

# Test case 7: Check that weights calculation handles extreme muhat values for binomial family
test_that("Weights calculation should handle extreme muhat values for binomial family", {
  skip_if_no_keras()

  family <- "binomial"
  muhat <- c(0.0001, 0.9999)
  w <- c(0.2, 0.6)

  # Calculate the expected weights
  muhat[muhat <= 0.001] <- 0.001
  muhat[muhat >= 0.999] <- 0.999
  diriv_values <- neuralGAM:::diriv(family, muhat)
  aux <- muhat * (1 - muhat) * (diriv_values**2)
  aux[aux <= 0.001] <- 0.001
  expected_output <- w / aux

  actual_output <- neuralGAM:::weight(w, muhat, family)
  expect_equal(actual_output, expected_output)
})

Try the neuralGAM package in your browser

Any scripts or data that you put into this service are public.

neuralGAM documentation built on June 22, 2024, 6:55 p.m.