Nothing
test_that("sage_batch_predict works for regression without batching", {
  skip_if_not_installed("rpart")
  set.seed(123)

  # Fit a regression tree on a 50-row synthetic friedman1 task
  regr_task = tgen("friedman1")$generate(n = 50)
  tree = lrn("regr.rpart")
  tree$train(regr_task)
  newdata = regr_task$data()

  # batch_size = NULL: no batching, the whole data set in one call
  preds = sage_batch_predict(
    tree,
    newdata,
    regr_task,
    batch_size = NULL,
    task_type = "regr"
  )

  # Expect 50 finite double-precision predictions
  expect_type(preds, "double")
  expect_length(preds, 50)
  expect_true(all(is.finite(preds)))
})
test_that("sage_batch_predict works for regression with batching", {
  skip_if_not_installed("rpart")
  set.seed(123)
  task = tgen("friedman1")$generate(n = 50)
  learner = lrn("regr.rpart")
  learner$train(task)
  test_data = task$data()

  # Reference: predict all 50 rows in a single call (no batching)
  pred_no_batch = sage_batch_predict(
    learner,
    test_data,
    task,
    batch_size = NULL,
    task_type = "regr"
  )
  expect_length(pred_no_batch, 50)

  # Batched predictions must reproduce the unbatched reference.
  # 10 divides 50 evenly. 7 does not, so (assuming batches are consecutive
  # chunks of batch_size rows) the last batch holds a single leftover row.
  # That exercises the remainder-batch path, which an even split never reaches.
  for (batch_size in c(10, 7)) {
    pred_with_batch = sage_batch_predict(
      learner,
      test_data,
      task,
      batch_size = batch_size,
      task_type = "regr"
    )
    expect_equal(
      pred_with_batch,
      pred_no_batch,
      info = paste("batch_size =", batch_size)
    )
  }
})
test_that("sage_batch_predict works for classification without batching", {
  skip_if_not_installed("rpart")
  set.seed(123)
  task = tgen("2dnormals")$generate(n = 50)
  learner = lrn("classif.rpart", predict_type = "prob")
  learner$train(task)
  test_data = task$data()
  predictions = sage_batch_predict(
    learner,
    test_data,
    task,
    batch_size = NULL,
    task_type = "classif"
  )

  # Probability matrix: 50 rows, one column per class (binary task here)
  expect_true(is.matrix(predictions))
  expect_equal(nrow(predictions), 50)
  expect_equal(ncol(predictions), 2)
  expect_true(all(is.finite(predictions)))

  # Every entry must be a valid probability...
  expect_true(all(predictions >= 0 & predictions <= 1))
  # ...and each row must sum to 1. unname() drops any row names that
  # rowSums() carries over, so only the values are compared.
  expect_equal(unname(rowSums(predictions)), rep(1, 50), tolerance = 1e-10)
})
test_that("sage_batch_predict works for classification with batching", {
  skip_if_not_installed("rpart")
  set.seed(123)
  task = tgen("2dnormals")$generate(n = 50)
  learner = lrn("classif.rpart", predict_type = "prob")
  learner$train(task)
  test_data = task$data()

  # Reference: predict all 50 rows in a single call (no batching)
  pred_no_batch = sage_batch_predict(
    learner,
    test_data,
    task,
    batch_size = NULL,
    task_type = "classif"
  )
  expect_equal(nrow(pred_no_batch), 50)

  # Batched probability matrices must reproduce the unbatched reference.
  # 10 divides 50 evenly. 7 does not, so (assuming batches are consecutive
  # chunks of batch_size rows) the last batch holds a single row. A 1-row
  # probability result is where matrix dimensions are easily dropped.
  for (batch_size in c(10, 7)) {
    pred_with_batch = sage_batch_predict(
      learner,
      test_data,
      task,
      batch_size = batch_size,
      task_type = "classif"
    )
    expect_equal(
      pred_with_batch,
      pred_no_batch,
      info = paste("batch_size =", batch_size)
    )
  }
})
test_that("sage_aggregate_predictions works for regression", {
  # Two coalitions for a single test instance, three sampled rows each
  combined_data = data.table::data.table(
    .coalition_id = rep(c(1, 2), each = 3),
    .test_instance_id = rep(1, 6),
    feature1 = as.double(1:6)
  )
  predictions = as.double(1:6)

  result = sage_aggregate_predictions(
    combined_data,
    predictions,
    task_type = "regr",
    class_names = NULL
  )

  # One aggregated row per coalition, with id columns and the averaged prediction
  expect_equal(nrow(result), 2)
  required_cols = c(".coalition_id", ".test_instance_id", "avg_pred")
  expect_true(all(required_cols %in% names(result)))

  # avg_pred is the mean of the three predictions belonging to each coalition
  expect_equal(result[.coalition_id == 1]$avg_pred, mean(c(1, 2, 3)))
  expect_equal(result[.coalition_id == 2]$avg_pred, mean(c(4, 5, 6)))
})
test_that("sage_aggregate_predictions works for classification", {
  # Two coalitions for a single test instance, three sampled rows each
  combined_data = data.table::data.table(
    .coalition_id = rep(c(1, 2), each = 3),
    .test_instance_id = rep(1, 6),
    feature1 = as.double(1:6)
  )

  # Binary-class probability matrix: one row per sampled row, columns A and B
  predictions = rbind(
    c(0.9, 0.1),
    c(0.8, 0.2),
    c(0.7, 0.3),
    c(0.6, 0.4),
    c(0.5, 0.5),
    c(0.4, 0.6)
  )
  class_names = c("A", "B")

  result = sage_aggregate_predictions(
    combined_data,
    predictions,
    task_type = "classif",
    class_names = class_names
  )

  # One row per coalition, with the id columns plus one column per class
  expect_equal(nrow(result), 2)
  required_cols = c(".coalition_id", ".test_instance_id", "A", "B")
  expect_true(all(required_cols %in% names(result)))

  # Coalition 1 averages the first three probability rows
  first = result[.coalition_id == 1]
  expect_equal(first$A, mean(c(0.9, 0.8, 0.7)))
  expect_equal(first$B, mean(c(0.1, 0.2, 0.3)))

  # Coalition 2 averages the last three probability rows
  second = result[.coalition_id == 2]
  expect_equal(second$A, mean(c(0.6, 0.5, 0.4)))
  expect_equal(second$B, mean(c(0.4, 0.5, 0.6)))
})
test_that("sage_aggregate_predictions handles multiple test instances", {
  # 2 coalitions x 2 test instances, two sampled rows per pair:
  # coalition ids 1,1,1,1,2,2,2,2 and instance ids 1,1,2,2,1,1,2,2
  combined_data = data.table::data.table(
    .coalition_id = rep(c(1, 2), each = 4),
    .test_instance_id = rep(rep(c(1, 2), each = 2), times = 2),
    feature1 = 1:8
  )
  predictions = 1:8

  result = sage_aggregate_predictions(
    combined_data,
    predictions,
    task_type = "regr",
    class_names = NULL
  )

  # Aggregation is per (coalition, test instance) pair -> 4 rows
  expect_equal(nrow(result), 4)

  # Each pair averages its own two consecutive predictions
  expect_equal(
    result[.coalition_id == 1 & .test_instance_id == 1]$avg_pred,
    mean(c(1, 2))
  )
  expect_equal(
    result[.coalition_id == 1 & .test_instance_id == 2]$avg_pred,
    mean(c(3, 4))
  )
  expect_equal(
    result[.coalition_id == 2 & .test_instance_id == 1]$avg_pred,
    mean(c(5, 6))
  )
  expect_equal(
    result[.coalition_id == 2 & .test_instance_id == 2]$avg_pred,
    mean(c(7, 8))
  )
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.