tests/testthat/test_calibration.R

test_that("Test that kernel bandwidth calibration works", {
  # Fix the RNG so the simulated data, and therefore the calibrated
  # bandwidths, are reproducible from run to run.
  set.seed(123)

  N <- 50
  # Two continuous covariates plus one binary covariate; data.frame() names
  # the columns X1, X2, X3 (and y for the response).
  X <- cbind(matrix(rnorm(N * 2), ncol = 2), rbinom(N, 1, 0.5))
  y <- X %*% rnorm(ncol(X))
  train_data <- data.frame(X, y)

  # Every smooth index reported by get_calibration_information() must carry
  # the same bandwidth as the corresponding smooth stored on the fitted model.
  expect_bandwidths_match <- function(fit) {
    calib_info <- get_calibration_information(fit)
    stored_bandwidth <- vapply(
      calib_info$smooth,
      FUN = function(i) fit$smooth[[i]]$bandwidth,
      FUN.VALUE = numeric(1)
    )
    expect_equal(calib_info$bandwidth, stored_bandwidth)
  }

  # Rebuild a gKRLS smooth's sketched kernel by hand. The test covariates
  # are centred on the stored training means and multiplied by the stored
  # whitening matrix. They are then passed to the internal kernel
  # constructor with that smooth's (calibrated) bandwidth and sketch matrix.
  manual_kernel <- function(smooth, columns, newdata) {
    centred <- sweep(
      as.matrix(newdata[, columns, drop = FALSE]),
      MARGIN = 2, STATS = smooth$std_train$mean, FUN = "-"
    )
    whitened <- centred %*% smooth$std_train$whiten
    gKRLS:::create_sketched_kernel(
      X_test = as.matrix(whitened),
      X_train = smooth$X_train,
      bandwidth = smooth$bandwidth,
      S = smooth$sketch_matrix
    )
  }

  # A single calibrated smooth: fitting must succeed and the reported
  # calibration must agree with the stored bandwidth.
  fit_one <- gam(
    y ~ s(X1, X2, X3, bs = "gKRLS", xt = gKRLS(bandwidth = "calibrate")),
    data = train_data
  )
  expect_bandwidths_match(fit_one)

  # Two calibrated gKRLS smooths alongside an ordinary smooth.
  fit_two <- gam(
    y ~ s(X2) +
      s(X1, X3, bs = "gKRLS", xt = gKRLS(bandwidth = "calibrate")) +
      s(X1, X2, X3, bs = "gKRLS", xt = gKRLS(bandwidth = "calibrate")),
    data = train_data, method = "GCV.Cp"
  )
  expect_bandwidths_match(fit_two)

  # Check that the manually reconstructed kernels use the correct,
  # calibrated bandwidths.
  X_test <- data.frame(cbind(
    matrix(rnorm(N * 2 * 2), ncol = 2),
    rbinom(N * 2, 1, 0.5)
  ))
  # Prediction on new data must run and return one row per test observation.
  pred_test <- predict(fit_two, newdata = X_test, type = "terms")
  expect_equal(nrow(pred_test), nrow(X_test))

  kern_two <- Predict.matrix(fit_two$smooth[[2]], data = X_test)
  kern_three <- Predict.matrix(fit_two$smooth[[3]], data = X_test)

  expect_equal(
    manual_kernel(fit_two$smooth[[2]], c("X1", "X3"), X_test),
    kern_two
  )
  expect_equal(
    manual_kernel(fit_two$smooth[[3]], c("X1", "X2", "X3"), X_test),
    kern_three
  )
})

Try the gKRLS package in your browser

Any scripts or data that you put into this service are public.

gKRLS documentation built on Sept. 11, 2024, 8:24 p.m.