# tests/testthat/test_predict_equal.R

# Return the test features (`xtest`) held in the config's `data` object.
# The full test set is returned; no row subsampling is applied.
get_test = function(cfg) cfg$data$xtest

# Return the parameter transformations (`trafos`) held in the config's `data`
# object; predictions_equal() treats them as a list named by column.
get_trafos = function(cfg) cfg$data$trafos

# Predict with the surrogate's Keras model, the same way predictions are made
# during fit_surrogate.
#
# @param test Test features (as returned by get_test()).
# @param model_path Path to the saved Keras model (HDF5 file).
# @return A data.table built from the model's predictions (presumably one
#   column per model output -- confirm against the model definition).
predict_like_fitting = function(test, model_path) {
  model = keras::load_model_hdf5(model_path, compile = FALSE)
  # Bring the features into the model's input format with the same helper
  # that is used when fitting.
  reshaped = mlr3keras::reshape_data_embedding(test)
  as.data.table(predict(model, reshaped$data))
}

# Predict with the objective (ONNX model), the same way the objective function
# is evaluated.
#
# The objective applies the trafos itself, so every column that has a trafo is
# first mapped back to raw scale with that trafo's `retrafo()`.
#
# @param xdt A data.table of test features. It is not modified.
# @param objective An objective with an `eval_dt()` method
#   (from cfg$get_objective()).
# @param trafos A list of trafos named by column; each element provides a
#   `retrafo()` function.
# @return The result of `objective$eval_dt()` on the (retransformed) data.
predict_objective = function(xdt, objective, trafos) {
  if (length(trafos) > 0L) {
    # `:=` updates by reference; copy first so the caller's table is not
    # silently overwritten with raw-scale values.
    xdt = copy(xdt)
    xdt[, names(trafos) := Map(function(x, t) t$retrafo(x), .SD, trafos),
      .SDcols = names(trafos)]
  }
  objective$eval_dt(xdt)
}

# Check that a surrogate's Keras model (as used while fitting) and the
# objective (ONNX model) agree on the config's test set.
#
# All three checks must pass:
#   1. Predictions of both models differ element-wise by at most `tol_pred`.
#   2. Test metrics computed from both predictions differ by at most
#      `tol_metric`.
#   3. Metrics from the Keras predictions match the stored reference metrics
#      ("surrogate_test_metrics.csv" in cfg$subdir) within `tol_metric`.
#
# @param cfg A benchmark config (as returned by benchmark_configs$get()).
# @param tol_pred Absolute tolerance for the prediction comparison.
# @param tol_metric Absolute tolerance for the metric comparisons.
# @return TRUE if all checks pass, FALSE otherwise.
predictions_equal = function(cfg, tol_pred = 1e-4, tol_metric = 5e-2) {
  # Seed the RNG for the column shuffle below, but restore the caller's RNG
  # state on exit so later tests are not affected by this seed.
  genv = globalenv()
  old_seed = if (exists(".Random.seed", envir = genv, inherits = FALSE)) {
    get(".Random.seed", envir = genv, inherits = FALSE)
  }
  on.exit({
    if (!is.null(old_seed)) {
      assign(".Random.seed", old_seed, envir = genv)
    } else if (exists(".Random.seed", envir = genv, inherits = FALSE)) {
      rm(".Random.seed", envir = genv)
    }
  }, add = TRUE)
  set.seed(123L)

  test = get_test(cfg)

  # Predictions via the Keras model, reproducing the fitting pipeline.
  p_keras = as.matrix(predict_like_fitting(test,
    model_path = paste0(cfg$subdir, cfg$keras_model_file)))

  # Predictions via the objective; only trafos for columns present in the
  # test data are applicable.
  objective = cfg$get_objective(retrafo = FALSE)
  trafos = get_trafos(cfg)
  trafos = trafos[names(trafos) %in% names(test)]

  # Evaluate on a copy with shuffled column order; the objective is expected
  # to be insensitive to column order.
  xdt = copy(test)
  xdt = xdt[, mlr3misc::shuffle(names(xdt)), with = FALSE]
  p_onnx = as.matrix(predict_objective(xdt, objective = objective, trafos = trafos))

  truth = cfg$data$ytest
  colnames(truth) = colnames(p_keras) = colnames(p_onnx) = cfg$target_variables

  metrics_keras = compute_metrics(truth, p_keras)
  metrics_onnx = compute_metrics(truth, p_onnx)
  metrics_ref = fread(paste0(cfg$subdir, "surrogate_test_metrics.csv"))

  # TRUE only if every |a - b| is within `tol` (an NA counts as failure).
  within_tol = function(a, b, tol) isTRUE(all(abs(a - b) <= tol))
  # Metric columns 1, 2 and 5 are excluded from the comparison.
  # NOTE(review): presumably these are non-numeric / identifier columns;
  # confirm against compute_metrics().
  compared = function(m) m[, -c(1, 2, 5)]

  within_tol(p_keras, p_onnx, tol_pred) &&
    within_tol(compared(metrics_keras), compared(metrics_onnx), tol_metric) &&
    within_tol(compared(metrics_keras), compared(metrics_ref), tol_metric)
}

# For every rbv2 config plus lcbench and nb301, the Keras and ONNX surrogate
# predictions (and the metrics derived from them) must agree.
# Skipped when `workdir` does not exist.
test_that("predict equal", {
  skip_if_not(check_directory_exists(workdir))
  keys = c(grep("rbv2", benchmark_configs$keys(), value = TRUE), "lcbench", "nb301")
  for (key in keys) {
    config = benchmark_configs$get(key, workdir = workdir)
    # Label the expectation with the config key so a failure is attributable.
    expect_true(predictions_equal(config),
      label = paste0("predictions_equal(<", key, ">)"))
  }
})
# slds-lmu/paper_2021_multi_fidelity_surrogates documentation built on Feb. 20, 2022, 11:53 a.m.