# File-level guards: these run before any test_that() block in this file.
# NOTE(review): the skip_*() helpers are defined outside this file; presumably
# each one skips the whole file for the backend or configuration named in its
# identifier (and skip_connection() keys off the given test-file label).
# Confirm against the helper definitions.
skip_connection("ml-regression-aft-survival-regression")
skip_on_livy()
skip_on_arrow_devel()
skip_databricks_connect()
# Runs the shared default-argument check against
# ml_aft_survival_regression(), after the version gate for "3.0.0".
test_that("ml_aft_survival_regression() default params", {
  test_requires_version("3.0.0")

  spark_conn <- testthat_spark_connection()
  test_default_args(spark_conn, ml_aft_survival_regression)
})
# Hands a set of non-default argument values to the shared
# test_param_setting() helper together with ml_aft_survival_regression(),
# after the version gate for "3.0.0".
test_that("ml_aft_survival_regression() param setting", {
  test_requires_version("3.0.0")

  spark_conn <- testthat_spark_connection()

  # Argument overrides, in the same order as before: estimator settings
  # first, then the column-name arguments.
  param_overrides <- list(
    censor_col = "ccol",
    quantile_probabilities = c(0.02, 0.5, 0.98),
    fit_intercept = FALSE,
    max_iter = 42,
    tol = 1e-4,
    aggregation_depth = 3,
    quantiles_col = "qcol",
    features_col = "fcol",
    label_col = "lcol",
    prediction_col = "pcol"
  )

  test_param_setting(spark_conn, ml_aft_survival_regression, param_overrides)
})
# Shared fixture for the tests below: `sc` and `training_tbl` are created at
# file level and read inside the "works properly" and "ML Pipeline works"
# test_that() blocks.
sc <- testthat_spark_connection()
# Five-row local data set with a label column, a censor column and two
# numeric predictors (V1, V2).
training <- data.frame(
label = c(1.218, 2.949, 3.627, 0.273, 4.199),
censor = c(1.0, 0.0, 0.0, 1.0, 0.0),
V1 = c(1.560, 0.346, 1.380, 0.520, 0.795),
V2 = c(-0.605, 2.158, 0.231, 1.151, -0.226)
)
# Copy the local data frame to Spark, replacing any existing table.
# NOTE(review): no `name` is supplied, so the Spark table name presumably
# derives from the `training` expression -- keep the variable name as is
# unless that is confirmed not to matter.
training_tbl <- sdf_copy_to(sc, training, overwrite = TRUE)
# Fits an AFT survival regression on the file-level `training_tbl` fixture
# and compares the fitted parameters, the first row's predicted quantiles
# and prediction, and the formula-interface coefficients with stored
# reference values.
test_that("ml_aft_survival_regression() works properly", {
  # Assemble V1 and V2 into a single "features" vector column.
  assembled_tbl <- training_tbl %>%
    ft_vector_assembler(c("V1", "V2"), "features")

  fit <- ml_aft_survival_regression(
    assembled_tbl,
    quantile_probabilities = list(0.3, 0.6),
    quantiles_col = "quantiles"
  )

  # Fitted parameters.
  expected_coefficients <- c(-0.49631114666506776, 0.19844437699934067)
  expected_intercept <- 2.6380946151040043
  expected_scale <- 1.5472345574364683
  expect_equal(
    fit$coefficients, expected_coefficients,
    tolerance = 1e-4, scale = 1
  )
  expect_equal(fit$intercept, expected_intercept, tolerance = 1e-4, scale = 1)
  expect_equal(fit$scale, expected_scale, tolerance = 1e-4, scale = 1)

  scored_tbl <- ml_predict(fit, assembled_tbl)

  # The assignment is deliberately kept inside expect_warning_on_arrow(), so
  # the pull of the "quantiles" column is evaluated by that helper.
  expect_warning_on_arrow(
    first_quantiles <- scored_tbl %>%
      dplyr::pull(quantiles) %>%
      dplyr::first()
  )
  expect_equal(
    first_quantiles,
    c(1.1603238947151593, 4.995456010274735),
    tolerance = 1e-4, scale = 1
  )

  first_prediction <- scored_tbl %>%
    dplyr::pull(prediction) %>%
    dplyr::first()
  expect_equal(
    first_prediction,
    5.718979487634966,
    tolerance = 1e-4, scale = 1
  )

  # Formula interface; the assembled table already has a "features" column,
  # and this call writes its features to "feat" instead.
  formula_fit <- ml_aft_survival_regression(
    assembled_tbl,
    label ~ V1 + V2,
    features_col = "feat"
  )
  expected_coef <- c(
    "(Intercept)" = 2.63808989630564,
    "V1" = -0.496304411053117,
    "V2" = 0.198452172529228
  )
  expect_equal(coef(formula_fit), expected_coef, tolerance = 1e-05, scale = 1)
})
# Builds a two-stage pipeline (vector assembler + AFT survival regression)
# on the file-level connection `sc`, fits it on the file-level `training_tbl`
# fixture, and checks the class of both the unfitted and fitted objects.
test_that("ML Pipeline works", {
  aft_pipeline <- ml_pipeline(sc) %>%
    ft_vector_assembler(c("V1", "V2"), "features") %>%
    ml_aft_survival_regression(
      quantile_probabilities = list(0.3, 0.6),
      quantiles_col = "quantiles"
    )
  # expect_s3_class() replaces expect_is(), which is deprecated in the third
  # edition of testthat; it still checks class inheritance.
  expect_s3_class(aft_pipeline, "ml_pipeline")

  aft_fitted <- ml_fit(aft_pipeline, training_tbl)
  expect_s3_class(aft_fitted, "ml_pipeline_model")
})
# Calling the old ml_survival_regression() entry point must raise a
# deprecation warning. (Despite the description, the expectation is a
# warning, not an error.)
test_that("Deprecated function fails", {
  sc <- testthat_spark_connection()
  expect_warning(
    ml_survival_regression(sc),
    "'ml_survival_regression' is deprecated.",
    # Match the message literally; as a regular expression the trailing "."
    # would match any character.
    fixed = TRUE
  )
})
# Cross-validates an R-formula + AFT pipeline over a small parameter grid on
# survival::ovarian and checks the shape of the validation-metrics table.
test_that("Tuning works with AFT", {
  test_requires("survival")
  sc <- testthat_spark_connection()

  ovarian_tbl <- sdf_copy_to(
    sc,
    survival::ovarian,
    name = "ovarian_tbl",
    overwrite = TRUE
  )

  pipeline <- ml_pipeline(sc) %>%
    ft_r_formula(futime ~ ecog_ps + rx + age + resid_ds) %>%
    ml_aft_survival_regression(censor_col = "fustat")

  # 2 values of max_iter x 2 values of aggregation_depth = 4 combinations.
  param_grid <- list(
    aft_survival_regression = list(
      max_iter = c(10, 20),
      aggregation_depth = c(2, 4)
    )
  )

  cv <- ml_cross_validator(
    sc,
    estimator = pipeline,
    estimator_param_maps = param_grid,
    evaluator = ml_regression_evaluator(sc),
    num_folds = 2,
    seed = 1111
  )

  cv_model <- ml_fit(cv, ovarian_tbl)
  # expect_s3_class() replaces expect_is(), which is deprecated in the third
  # edition of testthat; it still checks class inheritance.
  expect_s3_class(cv_model, "ml_cross_validator_model")

  # One row per grid combination.
  # NOTE(review): the 3 columns are presumably the metric plus the two tuned
  # parameters -- confirm.
  cv_metrics <- ml_validation_metrics(cv_model)
  expect_equal(dim(cv_metrics), c(4, 3))
})
# File-level teardown, called after all tests in this file.
# NOTE(review): test_clear_cache() is defined elsewhere; presumably it drops
# cached state created by the tests above -- confirm.
test_clear_cache()
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.