# tests/testthat/test-pred_predict.R

test_that("output of pred_predict for logistic model is correct", {
  # Existing logistic model described by a single row of coefficients
  # (intercept plus two predictors).
  logistic_coefs <- data.frame(
    "Intercept" = -3.4,
    "Sex_M" = 0.306,
    "Smoking_Status" = 0.628
  )
  logistic_model <- pred_input_info(
    model_type = "logistic",
    model_info = logistic_coefs
  )

  # New observations in which Sex is stored as a factor, not as an indicator.
  raw_data <- data.frame(
    "Sex" = as.factor(c("M", "F", "M", "M", "F", "F", "M")),
    "Smoking_Status" = c(1, 0, 0, 1, 1, 0, 1)
  )

  # Supplying the factor column directly is expected to raise an error.
  expect_error(pred_predict(x = logistic_model, new_data = raw_data))

  # The factor variable must be expanded into indicator variables before
  # the data can be handed to pred_predict().
  indicator_data <- dummy_vars(raw_data)
  predictions <- pred_predict(x = logistic_model, new_data = indicator_data)

  # A single logistic model yields a three-element named list.
  expect_type(predictions, type = "list")
  expect_length(predictions, 3)
  expect_named(predictions, c("LinearPredictor", "PredictedRisk", "Outcomes"))
})


test_that("output of pred_predict for survival model is correct", {
  # One time-to-event model (first row of the SYNPM model table) together
  # with its baseline cumulative hazard.
  tte_model <- pred_input_info(
    model_type = "survival",
    model_info = SYNPM$Existing_TTE_models[1, ],
    cum_hazard = SYNPM$TTE_mod1_baseline
  )

  predictions <- pred_predict(
    x = tte_model,
    new_data = SYNPM$ValidationData,
    survival_time = "ETime",
    event_indicator = "Status",
    time_horizon = 5
  )

  # A single survival model yields a four-element named list; compared with
  # the logistic case it additionally carries the time horizon.
  expect_type(predictions, type = "list")
  expect_length(predictions, 4)
  expect_named(
    predictions,
    c("LinearPredictor", "PredictedRisk", "TimeHorizon", "Outcomes")
  )
})


test_that("un-identified baseline hazard time is handled correctly", {
  # Case 1: a single survival model asked for a time horizon of 50, which
  # is expected to be rejected with an error.
  single_model <- pred_input_info(
    model_type = "survival",
    model_info = SYNPM$Existing_TTE_models[1, ],
    cum_hazard = SYNPM$TTE_mod1_baseline
  )
  expect_error(
    pred_predict(
      x = single_model,
      new_data = SYNPM$ValidationData,
      survival_time = "ETime",
      event_indicator = "Status",
      time_horizon = 50
    )
  )

  # Case 2: two survival models, each with its own baseline cumulative
  # hazard, asked for the same time horizon of 50.
  two_models <- pred_input_info(
    model_type = "survival",
    model_info = SYNPM$Existing_TTE_models[1:2, ],
    cum_hazard = list(SYNPM$TTE_mod1_baseline, SYNPM$TTE_mod2_baseline)
  )
  expect_error(
    pred_predict(
      x = two_models,
      new_data = SYNPM$ValidationData,
      survival_time = "ETime",
      event_indicator = "Status",
      time_horizon = 50
    )
  )

  # Case 3: the same two models with a vector of time horizons, which must
  # also raise an error.
  expect_error(
    pred_predict(
      x = two_models,
      new_data = SYNPM$ValidationData,
      survival_time = "ETime",
      event_indicator = "Status",
      time_horizon = c(5, 50)
    )
  )
})


test_that("output of pred_predict for multiple logistic model passing", {
  # All existing logistic models from SYNPM are passed in one object.
  logistic_models <- pred_input_info(
    model_type = "logistic",
    model_info = SYNPM$Existing_logistic_models
  )
  predictions <- pred_predict(
    x = logistic_models,
    new_data = SYNPM$ValidationData,
    binary_outcome = "Y"
  )

  # The output is a list whose length equals the M field of the model object.
  expect_type(predictions, type = "list")
  expect_equal(length(predictions), logistic_models$M)

  # The first two entries must each carry the per-model logistic output names.
  expected_names <- c("LinearPredictor", "PredictedRisk", "Outcomes")
  for (model_index in 1:2) {
    expect_named(predictions[[model_index]], expected_names)
  }
})


test_that("output of pred_predict for multiple survival model passing", {
  # Two time-to-event models, each paired with its own baseline
  # cumulative hazard.
  tte_models <- pred_input_info(
    model_type = "survival",
    model_info = SYNPM$Existing_TTE_models[1:2, ],
    cum_hazard = list(SYNPM$TTE_mod1_baseline, SYNPM$TTE_mod2_baseline)
  )
  predictions <- pred_predict(
    x = tte_models,
    new_data = SYNPM$ValidationData,
    survival_time = "ETime",
    event_indicator = "Status",
    time_horizon = 5
  )

  # The output is a list whose length equals the M field of the model object.
  expect_type(predictions, type = "list")
  expect_equal(length(predictions), tte_models$M)

  # Both entries must carry the per-model survival output names.
  expected_names <- c("LinearPredictor", "PredictedRisk",
                      "TimeHorizon", "Outcomes")
  for (model_index in 1:2) {
    expect_named(predictions[[model_index]], expected_names)
  }
})


test_that("output of pred_predict for survival models with no cum_hazard", {
  # Part 1: every model in the SYNPM table, with a NULL baseline cumulative
  # hazard supplied for each of the three list entries.
  models_without_hazard <- pred_input_info(
    model_type = "survival",
    model_info = SYNPM$Existing_TTE_models,
    cum_hazard = list(NULL, NULL, NULL)
  )
  predictions <- pred_predict(
    x = models_without_hazard,
    new_data = SYNPM$ValidationData,
    survival_time = "ETime",
    event_indicator = "Status",
    time_horizon = 5
  )

  expect_type(predictions, type = "list")
  expect_equal(length(predictions), models_without_hazard$M)
  # PredictedRisk of the first model must not carry any names.
  expect_null(names(predictions[[1]]$PredictedRisk))

  # Part 2: a single model (second row) with its baseline cumulative hazard
  # supplied; PredictedRisk must likewise be unnamed.
  model_with_hazard <- pred_input_info(
    model_type = "survival",
    model_info = SYNPM$Existing_TTE_models[2, ],
    cum_hazard = SYNPM$TTE_mod2_baseline
  )
  predictions <- pred_predict(
    x = model_with_hazard,
    new_data = SYNPM$ValidationData,
    survival_time = "ETime",
    event_indicator = "Status",
    time_horizon = 5
  )

  expect_type(predictions, type = "list")
  expect_length(predictions, 4)
  expect_null(names(predictions$PredictedRisk))
})



test_that("error messages of pred_predict are as expected", {
  # All SYNPM time-to-event models, each with its own baseline
  # cumulative hazard.
  tte_models <- pred_input_info(
    model_type = "survival",
    model_info = SYNPM$Existing_TTE_models,
    cum_hazard = list(
      SYNPM$TTE_mod1_baseline,
      SYNPM$TTE_mod2_baseline,
      SYNPM$TTE_mod3_baseline
    )
  )

  # Multiple time horizons are not allowed.
  expect_error(
    pred_predict(
      x = tte_models,
      new_data = SYNPM$ValidationData,
      survival_time = "ETime",
      event_indicator = "Status",
      time_horizon = c(3, 4, 5)
    )
  )

  # The survival-time outcome name is wrong ("Time" instead of "ETime").
  expect_error(
    pred_predict(
      x = tte_models,
      new_data = SYNPM$ValidationData,
      survival_time = "Time",
      event_indicator = "Status",
      time_horizon = 5
    )
  )

  # x is a plain data.frame rather than a predinfo object.
  expect_error(
    pred_predict(
      x = data.frame("test"),
      new_data = SYNPM$ValidationData,
      survival_time = "ETime",
      event_indicator = "Status",
      time_horizon = 5
    )
  )
})

# Try the predRupdate package in your browser
#
# Any scripts or data that you put into this service are public.
#
# predRupdate documentation built on Sept. 11, 2024, 5:34 p.m.