# tests/testthat/test-ALEPlot-gold-standard.R

# test-ALEPlot-gold-standard.R
# Tests to ensure that the ale package gives the same results (within a small
# numeric tolerance for some tests) as the gold-standard reference ALEPlot package.

# test_file('tests/testthat/test-ALEPlot-gold-standard.R')

# TODO: To minimize test time, serialize the reference ALEPlot output with expect_snapshot_value() instead of recomputing it on every run.

# Skip this entire file unless the NOT_CRAN environment variable is exactly
# "true". This keeps the reference packages used below (ALEPlot, nnet, gbm)
# from having to be declared as dependencies for CRAN checks.
# https://community.rstudio.com/t/skip-an-entire-test-file-on-cran-only/162842
if (Sys.getenv("NOT_CRAN") != "true") {
  return()
}


# nnet -----------------

# Simulated regression data (presumably mirroring the ALEPlot documentation
# example). y depends on x1 and x2 (including an x1 * x2 interaction) and on
# x3 through a logistic curve. x4 does not enter y at all, so it is noise.
set.seed(0)
n <- 1000  # smaller dataset for more rapid execution
x1 <- runif(n, min = 0, max = 1)
x2 <- runif(n, min = 0, max = 1)
x3 <- runif(n, min = 0, max = 1)
x4 <- runif(n, min = 0, max = 1)
y <- 4 * x1 + 3.87 * x2^2 +
  2.97 * exp(-5 + 10 * x3) / (1 + exp(-5 + 10 * x3)) +
  13.86 * (x1 - 0.5) * (x2 - 0.5) + rnorm(n, 0, 1)

DAT <<- data.frame(y, x1, x2, x3, x4)

# nnet starts from random weights. Reseeding makes the fitted network depend
# only on this seed, not on how much randomness the data generation consumed.
set.seed(0)
nnet.DAT <<- nnet::nnet(
  y ~ ., data = DAT, linout = TRUE, skip = FALSE, size = 6,
  decay = 0.1, maxit = 1000, trace = FALSE
)

# Define the predict functions

# Prediction wrapper with the (X.model, newdata) signature that
# ALEPlot::ALEPlot() calls as `pred.fun`; returns a plain numeric vector.
nnet_pred_fun_ALEPlot <<- function(X.model, newdata) {
  as.numeric(predict(X.model, newdata, type = "raw"))
}

# Prediction wrapper passed to ALE() as `pred_fun`. The default `type` refers
# to `pred_type`, which is not defined in this file. Presumably ALE() always
# supplies `type` explicitly; verify before calling this wrapper directly.
nnet_pred_fun_ale <<- function(object, newdata, type = pred_type) {
  as.numeric(predict(object, newdata, type = type))
}


# gbm ----------------

# Census data prepared for ALEPlot:
# - a plain data.frame, because ALEPlot is not compatible with the tibble format;
# - columns rearranged to match the ALEPlot order;
# - rows with missing values dropped.
adult_data <<-
  census |>
  as.data.frame() |>
  select(age:native_country, higher_income) |>
  # Called with no extra arguments: anything written inside these parentheses
  # would be passed to na.omit() after the piped data frame.
  stats::na.omit()

# gbm::gbm() automatically generates plots while fitting. Send them to a null
# PDF device (file = NULL writes nothing) so they don't print. Close that
# device in `finally` so that it is released even if fitting fails.
pdf(file = NULL)
gbm.data <<- tryCatch(
  {
    set.seed(0)
    gbm::gbm(
      higher_income ~ .,
      # Columns 3 and 4 of adult_data are excluded from the model
      data = adult_data[, -c(3, 4)] |>
        # gbm::gbm() requires binary response outcomes to be numeric 0 or 1
        mutate(higher_income = as.integer(higher_income)),
      distribution = "bernoulli",
      n.trees = 100,  # smaller model than ALEPlot example for rapid execution
      shrinkage = 0.02,
      interaction.depth = 3
    )
  },
  # Return to regular printing of plots
  finally = invisible(dev.off())
)

# Link-scale predictions from all 100 trees, with the (X.model, newdata)
# signature that ALEPlot::ALEPlot() calls as `pred.fun`.
gbm_pred_fun_ALEPlot <<- function(X.model, newdata) {
  as.numeric(
    gbm::predict.gbm(X.model, newdata, n.trees = 100, type = "link")
  )
}

# Prediction wrapper passed to ALE() as `pred_fun`. The default `type` refers
# to `pred_type`, which is not defined in this file. Presumably ALE() always
# supplies `type` explicitly; verify before calling this wrapper directly.
gbm_pred_fun_ale <<- function(object, newdata, type = pred_type) {
  as.numeric(gbm::predict.gbm(object, newdata, n.trees = 100, type = type))
}


# Tests --------------------

test_that('ale function matches output of ALEPlot with nnet', {
  # Create list of ALEPlot data that can be readily compared for accuracy:
  # one tibble (x.values, f.values) per predictor column x1..x4.
  # ALEPlot() draws plots as a side effect. Send them to a null PDF device
  # and close it again (in `finally`) even if ALEPlot() errors.
  pdf(file = NULL)
  nnet_ALEPlot <- tryCatch(
    map(1:4, \(it.col_idx) {
      ALEPlot::ALEPlot(
        DAT[, 2:5], nnet.DAT,
        pred.fun = nnet_pred_fun_ALEPlot, J = it.col_idx, K = 10
      ) |>
        as_tibble() |>
        select(-K)
    }) |>
      set_names(names(DAT[, 2:5])),
    # Return to regular printing of plots
    finally = invisible(dev.off())
  )

  # Create ale results with data only
  nnet_ale <- ALE(
    # basic arguments
    model = nnet.DAT,
    data = DAT,
    # make ale equivalent to ALEPlot
    parallel = 0,
    output_stats = FALSE,
    boot_it = 0,
    # specific options requested by ALEPlot example
    pred_type = "raw", pred_fun = nnet_pred_fun_ale,
    max_num_bins = 10 + 1,
    silent = TRUE
  )

  # Convert ale results to version that can be readily compared with ALEPlot
  nnet_ale_to_ALEPlot <-
    get(nnet_ale, ale_centre = 'zero') |>
    map(\(it.x) {
      tibble(
        x.values = it.x[[1]],
        f.values = it.x$.y
      )
    })

  # Compare results of ALEPlot with ale. all.equal() returns a character
  # vector describing the differences (not FALSE) when the inputs don't
  # match. isTRUE() turns that into a clean pass/fail, and the description
  # is reported through `info` on failure.
  comparison <- all.equal(nnet_ALEPlot, nnet_ale_to_ALEPlot, tolerance = 0.01)
  expect_true(isTRUE(comparison), info = paste(comparison, collapse = "\n"))
})


test_that('ale function matches output of ALEPlot with gbm', {
  # Predictors passed to ALEPlot: adult_data without columns 3 and 4 (not in
  # the model) and without the response column (15).
  ALEPlot_x <- adult_data[, -c(3, 4, 15)]

  # Create list of ALEPlot data that can be readily compared for accuracy.
  # For this test, get only four variables: c('age', 'workclass', 'education_num', 'sex')
  # These are column indexes c(1, 2, 3, 8) of ALEPlot_x.
  # ALEPlot() draws plots as a side effect. Send them to a null PDF device
  # and close it again (in `finally`) even if ALEPlot() errors.
  pdf(file = NULL)
  gbm_ALEPlot <- tryCatch(
    map(c(1, 2, 3, 8), \(it.col_idx) {
      ALEPlot::ALEPlot(
        ALEPlot_x, gbm.data, pred.fun = gbm_pred_fun_ALEPlot,
        J = it.col_idx,
        K = 10, NA.plot = TRUE
      ) |>
        as_tibble() |>
        select(-K)
    }) |>
      set_names(names(ALEPlot_x)[c(1, 2, 3, 8)]),
    # Return to regular printing of plots
    finally = invisible(dev.off())
  )

  # Create ale results with data only
  gbm_ale <- ALE(
    model = gbm.data,
    x_cols = c('age', 'workclass', 'education_num', 'sex'),
    data = adult_data[, -c(3, 4)],  # unlike ALEPlot, include the y column (15)
    # make ale equivalent to ALEPlot
    parallel = 0,
    output_stats = FALSE,
    boot_it = 0,
    # specific options requested by ALEPlot example
    pred_fun = gbm_pred_fun_ale, pred_type = 'link',
    max_num_bins = 10 + 1,
    silent = TRUE
  ) |>
    suppressMessages()

  # Convert ale results to version that can be readily compared with ALEPlot.
  # Factor x values become character to match ALEPlot's representation.
  gbm_ale_to_ALEPlot <-
    get(gbm_ale, ale_centre = 'zero') |>
    map(\(it.x) {
      tibble(
        x.values = it.x[[1]],
        f.values = unname(it.x$.y)
      ) |>
        mutate(across(where(is.factor), as.character))
    })

  # Compare results of ALEPlot with ale. all.equal() returns a character
  # vector describing the differences (not FALSE) when the inputs don't
  # match. isTRUE() turns that into a clean pass/fail, and the description
  # is reported through `info` on failure.
  comparison <- all.equal(gbm_ALEPlot, gbm_ale_to_ALEPlot)
  expect_true(isTRUE(comparison), info = paste(comparison, collapse = "\n"))
})


test_that('2D ALE matches output of ALEPlot interactions with nnet', {
  # Create list of ALEPlot data that can be readily compared for accuracy:
  # one long tibble (.x1, .x2, .y) per predictor pair, named 'x{i}:x{j}'.
  nnet_ALEPlot_ixn <- list()

  # ALEPlot() draws plots as a side effect. Send them to a null PDF device
  # and close it again (in `finally`) even if ALEPlot() errors.
  pdf(file = NULL)
  tryCatch(
    for (it.x1 in 1:4) {
      for (it.x2 in 1:4) {
        if (it.x1 < it.x2) {
          ap_data <- ALEPlot::ALEPlot(
            DAT[, 2:5],
            nnet.DAT,
            pred.fun = nnet_pred_fun_ALEPlot,
            J = c(it.x1, it.x2),
            K = 10
          )
          .x1 <- ap_data$x.values[[1]]
          .x2 <- ap_data$x.values[[2]]
          .y  <- ap_data$f.values

          # Flatten ALEPlot's grid into long format: one row per
          # (.x1, .x2) combination, with .y read from the 2D f.values.
          ixn_tbl <-
            expand.grid(
              row = seq_along(.x1),
              col = seq_along(.x2)
            ) |>
            as_tibble() |>
            mutate(
              .x1 = .x1[row],
              .x2 = as.numeric(.x2[col]),
              .y  = as.numeric(.y[cbind(row, col)])
            ) |>
            select(-row, -col) |>
            arrange(.x1, .x2, .y)

          # Remove extraneous attributes, otherwise comparison will not match
          attributes(ixn_tbl)$out.attrs <- NULL

          nnet_ALEPlot_ixn[[str_glue('x{it.x1}:x{it.x2}')]] <- ixn_tbl
        }
      }
    },
    # Return to regular printing of plots
    finally = invisible(dev.off())
  )

  nnet_2D <- ALE(
    # basic arguments
    model = nnet.DAT,
    data = DAT,
    x_cols = list(d2 = TRUE),
    parallel = 0,
    output_stats = FALSE,
    pred_fun = nnet_pred_fun_ale,
    pred_type = "raw", max_num_bins = 10 + 1,  # specific options requested
    silent = TRUE
  )

  # Convert ale results to version that can be readily compared with ALEPlot
  nnet_2D_to_ALEPlot <-
    get(nnet_2D, ale_centre = 'zero') |>
    map(\(it.ale) {
      it.ale <- it.ale |>
        select(1, 2, .y) |>
        set_names(c('.x1', '.x2', '.y')) |>
        arrange(.x1, .x2, .y)

      # Strip incomparable attributes
      attr(it.ale, 'x') <- NULL

      it.ale
    })

  # Compare results of ALEPlot with ale. all.equal() returns a character
  # vector describing the differences (not FALSE) when the inputs don't
  # match. isTRUE() turns that into a clean pass/fail, and the description
  # is reported through `info` on failure.
  comparison <- all.equal(
    nnet_ALEPlot_ixn, nnet_2D_to_ALEPlot, tolerance = 0.01
  )
  expect_true(isTRUE(comparison), info = paste(comparison, collapse = "\n"))
})


test_that('2D ALE matches output of ALEPlot interactions with gbm', {
  # Create list of ALEPlot data that can be readily compared for accuracy:
  # one long tibble (.x1, .x2, .y) per predictor pair, named '{var1}:{var2}'.
  gbm_ALEPlot_ixn <- list()
  # ALEPlot predictors: adult_data without columns 3, 4 and the response (15)
  adult_data_subset <- adult_data[, -c(3, 4, 15)]

  # Both ALEPlot() and ALE() on gbm.data draw plots as a side effect. Send
  # them to a null PDF device and close it again (in `finally`) even if
  # either call errors.
  pdf(file = NULL)
  tryCatch(
    {
      # Pairs with it.x1 < it.x2; these should correspond to the six
      # interactions listed in `x_cols` of the ALE() call below.
      for (it.x1 in c(1, 2, 3, 8)) {
        for (it.x2 in c(1, 3, 11)) {
          if (it.x1 < it.x2) {
            ap_data <- ALEPlot::ALEPlot(
              adult_data_subset,
              gbm.data,
              pred.fun = gbm_pred_fun_ALEPlot,
              J = c(it.x1, it.x2),
              K = 10,
              NA.plot = TRUE
            )
            .x1 <- ap_data$x.values[[1]]
            .x2 <- ap_data$x.values[[2]]
            .y  <- ap_data$f.values

            # Flatten ALEPlot's grid into long format: one row per
            # (.x1, .x2) combination, with .y read from the 2D f.values.
            ixn_tbl <-
              expand.grid(
                row = seq_along(.x1),
                col = seq_along(.x2)
              ) |>
              as_tibble() |>
              mutate(
                .x1 = .x1[row],
                .x2 = as.numeric(.x2[col]),
                .y  = as.numeric(.y[cbind(row, col)])
              ) |>
              select(-row, -col) |>
              arrange(.x1, .x2, .y)

            # Remove extraneous attributes, otherwise comparison will not match
            attributes(ixn_tbl)$out.attrs <- NULL

            gbm_ALEPlot_ixn[[str_glue(
              '{names(adult_data_subset)[it.x1]}:{names(adult_data_subset)[it.x2]}'
            )]] <- ixn_tbl
          }
        }
      }

      gbm_2D <- ALE(
        model = gbm.data,
        data = adult_data,
        x_cols = c(
          'age:education_num',
          'age:hours_per_week',
          'workclass:education_num',
          'workclass:hours_per_week',
          'education_num:hours_per_week',
          'sex:hours_per_week'
        ),
        parallel = 0,
        output_stats = FALSE,
        pred_fun = gbm_pred_fun_ale,
        pred_type = 'link', max_num_bins = 10 + 1,  # specific options requested
        silent = TRUE
      )
    },
    # Return to regular printing of plots.
    # For some reason, calling ALE() on gbm.data also prints some plots.
    finally = invisible(dev.off())
  )

  # Convert ale results to version that can be readily compared with ALEPlot
  gbm_2D_to_ALEPlot <-
    get(gbm_2D, ale_centre = 'zero') |>
    map(\(it.ale) {
      it.ale <- it.ale |>
        select(1, 2, .y) |>
        set_names(c('.x1', '.x2', '.y')) |>
        # Convert [ordered] factor columns to character for comparability with ALEPlot
        mutate(across(
          '.x1',
          \(it.col) if (is.factor(it.col)) as.character(it.col) else it.col
        )) |>
        arrange(.x1, .x2, .y)

      # Strip incomparable attributes
      attr(it.ale, 'x') <- NULL

      it.ale
    })

  # Compare results of ALEPlot with ale. all.equal() returns a character
  # vector describing the differences (not FALSE) when the inputs don't
  # match. isTRUE() turns that into a clean pass/fail, and the description
  # is reported through `info` on failure.
  comparison <- all.equal(gbm_ALEPlot_ixn, gbm_2D_to_ALEPlot)
  expect_true(isTRUE(comparison), info = paste(comparison, collapse = "\n"))
})

# Try the ale package in your browser

# Any scripts or data that you put into this service are public.

# ale documentation built on April 11, 2025, 6:09 p.m.