


# Fake template IDs. The tests below return them from mocked
# get_train_template_id / get_predict_template_id helpers and then check
# that the submitted scripts reference them.
ml_train_template_id <- 234
ml_predict_template_id <- 456

test_that("jsonlite works", {
  # The tests below fail when run via R CMD check with an
  # "invalid encoding argument" error. jsonlite::toJSON is the last thing in
  # the traceback. Adding this test here seems to make the issue go away.
  # TODO: find out why.
  s <- jsonlite::toJSON(list(sample_weight = "survey_weights"), auto_unbox = TRUE)
  expect_equal(unclass(s), "{\"sample_weight\":\"survey_weights\"}")
})

# Build

test_that("calls scripts_post_custom", {
  # civis_ml() on a civis_table should create a custom script from the train
  # template with the expected arguments, start a run, and poll its status.
  fake_get_database_id <- mock(456, cycle = TRUE)
  fake_scripts_post_custom <- mock(list(id = 999))
  fake_scripts_post_custom_runs <- mock(list(id = 888))
  fake_scripts_get_custom_runs <- mock(list(state = "running"), list(state = "succeeded"))
  fake_civis_ml_fetch_existing <- mock(NULL)

  with_mock(
    `civis::get_database_id` = fake_get_database_id,
    `civis::scripts_post_custom` = fake_scripts_post_custom,
    `civis::scripts_post_custom_runs` = fake_scripts_post_custom_runs,
    `civis::scripts_get_custom_runs` = fake_scripts_get_custom_runs,
    `civis::civis_ml_fetch_existing` = fake_civis_ml_fetch_existing,
    `civis::get_train_template_id` = function(...) ml_train_template_id,

    tbl <- civis_table(table_name = "schema.table",
                       database_name = "a_database",
                       sql_where = "1 = 2",
                       sql_limit = 10),

    civis_ml(x = tbl,
             model_type = "sparse_logistic",
             dependent_variable = "target",
             excluded_columns = c("col_1", "col_2", "col_3"),
             primary_key = "row_number",
             parameters = list(n_estimators = 10),
             cross_validation_parameters = list(n_estimators = c(10, 20, 30)),
             model_name = "awesome civisml",
             calibration = "sigmoid",
             oos_scores_table = "score.table",
             oos_scores_db = "another_database",
             oos_scores_if_exists = "drop",
             fit_params = list(sample_weight = "survey_weights"),
             cpu_requested = 1111,
             memory_requested = 9096,
             disk_requested = 9,
             notifications = list(successEmailSubject = "A success",
                                  successEmailAddresses = c("user@example.com")),
             polling_interval = .01,
             validation_data = "skip",
             n_jobs = 9,
             verbose = FALSE)
  )

  script_args <- mock_args(fake_scripts_post_custom)[[1]]
  expect_equal(script_args$from_template_id, ml_train_template_id)
  expect_equal(script_args$name, "awesome civisml Train")
  expect_equal(script_args$notifications, list(successEmailSubject = "A success",
                                               successEmailAddresses = c("user@example.com")))

  # These are template args/params:
  ml_args <- script_args$arguments
  expect_is(ml_args, "AsIs")  # We don't want jsonlite doing anything unexpected.
  expect_equal(ml_args$MODEL, "sparse_logistic")
  expect_equal(ml_args$TARGET_COLUMN, "target")
  expect_equal(ml_args$PRIMARY_KEY, "row_number")
  expect_equal(unclass(ml_args$PARAMS), '{"n_estimators":10}')
  expect_equal(unclass(ml_args$CVPARAMS), '{"n_estimators":[10,20,30]}')
  expect_equal(ml_args$CALIBRATION, "sigmoid")
  expect_equal(ml_args$IF_EXISTS, "drop")
  expect_equal(ml_args$TABLE_NAME, "schema.table")
  expect_equal(ml_args$CIVIS_FILE_ID, NULL)
  expect_equal(ml_args$OOSTABLE, "score.table")
  expect_equal(ml_args$OOSDB, list(database = 456))
  expect_equal(ml_args$WHERESQL, "1 = 2")
  expect_equal(ml_args$LIMITSQL, 10)
  expect_equal(ml_args$EXCLUDE_COLS, "col_1 col_2 col_3")
  expect_equal(unclass(ml_args$FIT_PARAMS), '{"sample_weight":"survey_weights"}')
  expect_equal(ml_args$DB, list(database = 456))
  expect_equal(ml_args$REQUIRED_CPU, 1111)
  expect_equal(ml_args$REQUIRED_MEMORY, 9096)
  expect_equal(ml_args$REQUIRED_DISK_SPACE, 9)
  expect_equal(ml_args$VALIDATION_DATA, "skip")
  expect_equal(ml_args$N_JOBS, 9)

  # Make sure we started the job.
  expect_args(fake_scripts_post_custom_runs, 1, 999)

  # And checked its status (first "running", then "succeeded").
  expect_args(fake_scripts_get_custom_runs, 1, 999, 888)
  expect_called(fake_scripts_get_custom_runs, 2)
})

test_that("calls civis_ml.data.frame for local df", {
  # A local data.frame should be uploaded with write_civis_file and the
  # returned file id forwarded to create_and_run_model with default options.
  fake_write_civis_file <- mock(1234)
  fake_get_database_id <- mock(456)
  fake_create_and_run_model <- mock(NULL)

  with_mock(
    `civis::write_civis_file` = fake_write_civis_file,
    `civis::get_database_id` = fake_get_database_id,
    `civis::create_and_run_model` = fake_create_and_run_model,
    `civis::get_train_template_id` = function(...) ml_train_template_id,

    civis_ml(x = iris,
             model_type = "sparse_logistic",
             dependent_variable = "the_target_column",
             primary_key = "the_pk_column"),

    expect_args(fake_create_and_run_model, 1,
                file_id = 1234,
                dependent_variable = "the_target_column",
                excluded_columns = NULL,
                primary_key = "the_pk_column",
                model_type = "sparse_logistic",
                parameters = NULL,
                cross_validation_parameters = NULL,
                fit_params = NULL,
                calibration = NULL,
                oos_scores_table = NULL,
                oos_scores_db_id = NULL,
                oos_scores_if_exists = 'fail',
                model_name = NULL,
                cpu_requested = NULL,
                memory_requested = NULL,
                disk_requested = NULL,
                validation_data = 'train',
                n_jobs = NULL,
                notifications = NULL,
                verbose = FALSE,
                civisml_version = "prod")
  )
})

test_that("calls civis_ml.civis_table for table_name", {
  # A civis_table should resolve its database name to an id and forward the
  # table, where clause and limit to create_and_run_model.
  fake_get_database_id <- mock(456)
  fake_create_and_run_model <- mock(NULL)

  with_mock(
    `civis::get_database_id` = fake_get_database_id,
    `civis::create_and_run_model` = fake_create_and_run_model,

    x <- civis_table(table_name = "a_schema.table",
                     database_name = "a_database",
                     sql_where = "a = b",
                     sql_limit = 6),

    civis_ml(x = x,
             model_type = "sparse_logistic",
             dependent_variable = "the_target_column",
             primary_key = "the_pk_column")
  )

  expect_args(fake_get_database_id, 1, "a_database")

  expect_args(fake_create_and_run_model, 1,
              table_name = "a_schema.table",
              database_id = 456,
              sql_where = "a = b",
              sql_limit = 6,
              dependent_variable = "the_target_column",
              excluded_columns = NULL,
              primary_key = "the_pk_column",
              model_type = "sparse_logistic",
              parameters = NULL,
              cross_validation_parameters = NULL,
              fit_params = NULL,
              calibration = NULL,
              oos_scores_table = NULL,
              oos_scores_db_id = NULL,
              oos_scores_if_exists = 'fail',
              model_name = NULL,
              cpu_requested = NULL,
              memory_requested = NULL,
              disk_requested = NULL,
              validation_data = 'train',
              n_jobs = NULL,
              notifications = NULL,
              verbose = FALSE,
              civisml_version = "prod")
})

test_that("calls civis_ml.civis_file for file_id", {
  # A civis_file should be passed through to create_and_run_model as file_id.
  fake_get_database_id <- mock(456)
  fake_create_and_run_model <- mock(NULL)

  with_mock(
    `civis::get_database_id` = fake_get_database_id,
    `civis::create_and_run_model` = fake_create_and_run_model,

    civis_ml(x = civis_file(file_id = 123),
             model_type = "sparse_logistic",
             dependent_variable = "the_target_column",
             primary_key = "the_pk_column")
  )

  expect_args(fake_create_and_run_model, 1,
              file_id = civis_file(123),
              dependent_variable = "the_target_column",
              excluded_columns = NULL,
              primary_key = "the_pk_column",
              model_type = "sparse_logistic",
              parameters = NULL,
              cross_validation_parameters = NULL,
              fit_params = NULL,
              calibration = NULL,
              oos_scores_table = NULL,
              oos_scores_db_id = NULL,
              oos_scores_if_exists = 'fail',
              model_name = NULL,
              cpu_requested = NULL,
              memory_requested = NULL,
              disk_requested = NULL,
              validation_data = 'train',
              n_jobs = NULL,
              notifications = NULL,
              verbose = FALSE,
              civisml_version = "prod")
})

test_that("calls civis_ml.character for local csv", {
  # A local csv path should be uploaded via write_civis_file and the returned
  # file id forwarded to create_and_run_model.
  fake_get_database_id <- mock(456)
  fake_write_civis_file <- mock(123)
  fake_create_and_run_model <- mock(NULL)

  with_mock(
    `civis::get_database_id` = fake_get_database_id,
    `civis::write_civis_file` = fake_write_civis_file,
    `civis::create_and_run_model` = fake_create_and_run_model,

    civis_ml(x =  "fake_temp_path",
             model_type = "sparse_logistic",
             dependent_variable = "the_target_column",
             primary_key = "the_pk_column")
  )

  expect_args(fake_write_civis_file, 1,
              path = "fake_temp_path",
              name = "modelpipeline_data.csv")

  expect_args(fake_create_and_run_model, 1,
              file_id = 123,
              dependent_variable = "the_target_column",
              excluded_columns = NULL,
              primary_key = "the_pk_column",
              model_type = "sparse_logistic",
              parameters = NULL,
              cross_validation_parameters = NULL,
              fit_params = NULL,
              calibration = NULL,
              oos_scores_table = NULL,
              oos_scores_db_id = NULL,
              oos_scores_if_exists = 'fail',
              model_name = NULL,
              cpu_requested = NULL,
              memory_requested = NULL,
              disk_requested = NULL,
              validation_data = 'train',
              n_jobs = NULL,
              notifications = NULL,
              verbose = FALSE,
              civisml_version = "prod")
})

test_that("raises error on invalid calibration", {
  # An unrecognised calibration value must be rejected with a clear message.
  fake_get_database_id <- mock(456)
  fake_write_civis_file <- mock(123)

  with_mock(
    `civis::get_database_id` = fake_get_database_id,
    `civis::write_civis_file` = fake_write_civis_file,

    expect_error(civis_ml(x = "fake_temp_path",
                          model_type = "sparse_logistic",
                          dependent_variable = "target",
                          primary_key = "pk",
                          calibration = "fake"),
                 "calibration must be 'sigmoid', 'isotonic', or NULL\\.")
  )
})

test_that("raises error if multioutput not supported", {
  # Each model type below must reject more than one dependent variable.
  fake_get_database_id <- mock(456, cycle = TRUE)
  fake_write_civis_file <- mock(123, cycle = TRUE)
  mo_not_supported <- c("sparse_linear_regressor", "sparse_ridge_regressor",
                        "gradient_boosting_regressor", "sparse_logistic",
                        "gradient_boosting_classifier")

  for (mtype in mo_not_supported) {
    with_mock(
      `civis::get_database_id` = fake_get_database_id,
      `civis::write_civis_file` = fake_write_civis_file,

      expect_error(civis_ml(x = "fake_temp_path",
                            model_type = mtype,
                            dependent_variable = c("target", "target_2"),
                            primary_key = "pk",
                            calibration = "fake"),
                   paste0("Multioutput is not supported for ", mtype))
    )
  }
})

# Predict

# Minimal stand-in for a trained "civis_ml" object, shared by the predict
# tests below: a training job (id 123, name "model_task", train template id,
# primary key "training_primary_key") and its run (id 456).
fake_model <- structure(
  list(
    job = list(
      id = 123,
      name = "model_task",
      fromTemplateId = ml_train_template_id,
      arguments = list(
        PRIMARY_KEY = "training_primary_key"
      )
    ),
    run = list(id = 456)
  ),
  class = "civis_ml"
)

test_that("calls scripts_post_custom", {
  # predict() on a civis_table should create a custom script from the predict
  # template with the expected arguments, start a run, and poll its status.
  fake_get_database_id <- mock(456, cycle = TRUE)
  fake_scripts_post_custom <- mock(list(id = 999))
  fake_scripts_post_custom_runs <- mock(list(id = 888))
  fake_scripts_get_custom_runs <- mock(list(state = "running"), list(state = "succeeded"))
  fake_scripts_get_custom <- mock(list(state = "succeeded"), cycle = TRUE)
  fake_fetch_predict_results <- mock(NULL)

  with_mock(
    `civis::get_database_id` = fake_get_database_id,
    `civis::scripts_post_custom` = fake_scripts_post_custom,
    `civis::scripts_post_custom_runs` = fake_scripts_post_custom_runs,
    `civis::scripts_get_custom` = fake_scripts_get_custom,
    `civis::scripts_get_custom_runs` = fake_scripts_get_custom_runs,
    `civis::fetch_predict_results` = fake_fetch_predict_results,
    `civis::get_predict_template_id` = function(...) ml_predict_template_id,

    tbl <- civis_table(table_name = "schema.table",
                       database_name = "a_database",
                       sql_where = "6 = 7",
                       sql_limit = 7),

    predict(fake_model,
            newdata = tbl,
            primary_key = "row_number",
            output_table = "score.table",
            output_db = "score_database",
            if_output_exists = "append",
            n_jobs = 10,
            cpu_requested = 2000,
            memory_requested = 10,
            disk_requested = 15,
            polling_interval = .01,
            verbose = TRUE)
  )

  script_args <- mock_args(fake_scripts_post_custom)[[1]]
  expect_equal(script_args$from_template_id, ml_predict_template_id)
  expect_equal(script_args$name, "model_task Predict")

  # These are template args/params:
  pred_args <- script_args$arguments
  expect_is(pred_args, "AsIs")  # We don't want jsonlite doing anything unexpected.
  expect_equal(pred_args$TRAIN_JOB, 123)
  expect_equal(pred_args$TRAIN_RUN, 456)
  expect_equal(pred_args$PRIMARY_KEY, "row_number")
  expect_equal(pred_args$IF_EXISTS, "append")
  expect_equal(pred_args$N_JOBS, 10)
  expect_equal(pred_args$CPU, 2000)
  expect_equal(pred_args$MEMORY, 10)
  expect_equal(pred_args$DISK_SPACE, 15)
  expect_equal(pred_args$DEBUG, TRUE)
  expect_equal(pred_args$CIVIS_FILE_ID, NULL)
  expect_equal(pred_args$TABLE_NAME, "schema.table")
  expect_equal(pred_args$DB, list(database = 456))
  expect_equal(pred_args$WHERESQL, "6 = 7")
  expect_equal(pred_args$LIMITSQL, 7)
  expect_equal(pred_args$OUTPUT_TABLE, "score.table")
  expect_equal(pred_args$OUTPUT_DB, list(database = 456))

  # Make sure we started the job.
  expect_args(fake_scripts_post_custom_runs, 1, 999)

  # And checked its status (first "running", then "succeeded").
  expect_args(fake_scripts_get_custom_runs, 1, 999, 888)
  expect_called(fake_scripts_get_custom_runs, 2)
})

test_that("uses training primary_key by default", {
  # Without an explicit primary_key, predict() reuses the training job's one.
  fake_get_database_id <- mock(123)
  fake_create_and_run_pred <- mock(NULL)

  with_mock(
    `civis::get_database_id` = fake_get_database_id,
    `civis::create_and_run_pred` = fake_create_and_run_pred,
    `civis::get_predict_template_id` = function(...) ml_predict_template_id,

    tbl <- civis_table(table_name = "schema.table", database_name = "the_db"),
    predict(fake_model, newdata = tbl)
  )

  run_args <- mock_args(fake_create_and_run_pred)[[1]]
  expect_equal(run_args$primary_key, "training_primary_key")
})

test_that("uploads local df and passes a file_id", {
  # A local data.frame is uploaded and its file id passed to
  # create_and_run_pred alongside the training job/run ids.
  fake_write_civis_file <- mock(1234)
  fake_create_and_run_pred <- mock(NULL)

  with_mock(
    `civis::write_civis_file` = fake_write_civis_file,
    `civis::create_and_run_pred` = fake_create_and_run_pred,
    `civis::get_predict_template_id` = function(...) ml_predict_template_id,

    predict(fake_model, iris, primary_key = NULL)
  )

  expect_args(fake_create_and_run_pred, 1,
              train_job_id = fake_model$job$id,
              train_run_id = fake_model$run$id,
              template_id = ml_predict_template_id,
              primary_key = NULL,
              output_table = NULL,
              output_db_id = NULL,
              if_output_exists = 'fail',
              model_name = "model_task",
              n_jobs = NULL,
              cpu_requested = NULL,
              memory_requested = NULL,
              disk_requested = NULL,
              polling_interval = NULL,
              verbose = FALSE,
              file_id = 1234)
})

test_that("uploads a local file and passes a file_id", {
  # A local file path is uploaded via write_civis_file and the returned file
  # id passed to create_and_run_pred.
  fake_write_civis_file <- mock(561)
  fake_create_and_run_pred <- mock(NULL)

  with_mock(
    `civis::write_civis_file` = fake_write_civis_file,
    `civis::create_and_run_pred` = fake_create_and_run_pred,
    `civis::get_predict_template_id` = function(...) ml_predict_template_id,

    predict(fake_model, "fake_temp_path", primary_key = NULL)
  )

  # NOTE(review): expected upload args reconstructed to mirror the
  # civis_ml.character test; confirm against predict.civis_ml.
  expect_args(fake_write_civis_file, 1,
              path = "fake_temp_path",
              name = "modelpipeline_data.csv")

  expect_args(fake_create_and_run_pred, 1,
              train_job_id = fake_model$job$id,
              train_run_id = fake_model$run$id,
              template_id = ml_predict_template_id,
              primary_key = NULL,
              output_table = NULL,
              output_db_id = NULL,
              if_output_exists = 'fail',
              model_name = "model_task",
              n_jobs = NULL,
              cpu_requested = NULL,
              memory_requested = NULL,
              disk_requested = NULL,
              polling_interval = NULL,
              verbose = FALSE,
              file_id = 561)
})

test_that("passes a file_id directly", {
  # A civis_file is forwarded as file_id without any upload.
  fake_create_and_run_pred <- mock(NULL)

  with_mock(
    `civis::create_and_run_pred` = fake_create_and_run_pred,
    `civis::get_predict_template_id` = function(...) ml_predict_template_id,

    predict(fake_model, civis_file(1234))
  )

  expect_args(fake_create_and_run_pred, 1,
              train_job_id = fake_model$job$id,
              train_run_id = fake_model$run$id,
              template_id = ml_predict_template_id,
              primary_key = "training_primary_key",
              output_table = NULL,
              output_db_id = NULL,
              if_output_exists = 'fail',
              model_name = "model_task",
              n_jobs = NULL,
              cpu_requested = NULL,
              memory_requested = NULL,
              disk_requested = NULL,
              polling_interval = NULL,
              verbose = FALSE,
              file_id = 1234)
})

test_that("passes a manifest file_id", {
  # A civis_file_manifest is forwarded as the manifest argument.
  fake_create_and_run_pred <- mock(NULL)

  with_mock(
    `civis::create_and_run_pred` = fake_create_and_run_pred,
    `civis::get_predict_template_id` = function(...) ml_predict_template_id,

    predict(fake_model, civis_file_manifest(123), primary_key = NULL)
  )

  expect_args(fake_create_and_run_pred, 1,
              train_job_id = fake_model$job$id,
              train_run_id = fake_model$run$id,
              template_id = ml_predict_template_id,
              primary_key = NULL,
              output_table = NULL,
              output_db_id = NULL,
              if_output_exists = 'fail',
              model_name = "model_task",
              n_jobs = NULL,
              cpu_requested = NULL,
              memory_requested = NULL,
              disk_requested = NULL,
              polling_interval = NULL,
              verbose = FALSE,
              manifest = 123)
})

test_that("passes table info", {
  # A civis_table has its database name resolved to an id, and its table,
  # where clause and limit are forwarded to create_and_run_pred.
  fake_get_database_id <- mock(999)
  fake_create_and_run_pred <- mock(NULL)

  with_mock(
    `civis::get_database_id` = fake_get_database_id,
    `civis::create_and_run_pred` = fake_create_and_run_pred,
    `civis::get_predict_template_id` = function(...) ml_predict_template_id,

    table_to_score <- civis_table(
      table_name = "a_schema.table",
      database_name = "a_database",
      sql_where = "row_number in (1, 2, 4)",
      sql_limit = 11
    ),
    predict(fake_model, table_to_score, primary_key = NULL)
  )

  expect_args(fake_get_database_id, 1, "a_database")
  expect_args(fake_create_and_run_pred, 1,
              train_job_id = fake_model$job$id,
              train_run_id = fake_model$run$id,
              template_id = ml_predict_template_id,
              primary_key = NULL,
              output_table = NULL,
              output_db_id = NULL,
              if_output_exists = 'fail',
              model_name = "model_task",
              n_jobs = NULL,
              cpu_requested = NULL,
              memory_requested = NULL,
              disk_requested = NULL,
              polling_interval = NULL,
              verbose = FALSE,
              table_name = "a_schema.table",
              database_id = 999,
              sql_where = "row_number in (1, 2, 4)",
              sql_limit = 11)
})


test_that("newer CivisML versions use feather", {
  # enforce newer CivisML version
  temp_id <- 11219
  # factor should not cause errors when using feather. Since R 4.0,
  # data.frame() no longer converts strings to factors by default, so ask for
  # it explicitly or this case is never exercised.
  x <- data.frame(a = 1:3, b = letters[1:3], stringsAsFactors = TRUE)
  fake_file <- mock(1)
  with_mock(
    `civis::write_civis_file` = fake_file,
    {
      stash_local_dataframe(x, temp_id)
      args <- mock_args(fake_file)
      expect_equal(args[[1]]$name, "modelpipeline_data.feather")
    }
  )
})


test_that("older CivisML versions use csv", {
  # enforce older CivisML version
  temp_id <- 9969
  # factor type should not matter for older version
  x <- data.frame(a = 1:3, b = letters[1:3], stringsAsFactors = FALSE)
  fake_file <- mock(1)
  with_mock(
    `civis::write_civis_file` = fake_file,
    {
      stash_local_dataframe(x, temp_id)
      args <- mock_args(fake_file)
      expect_equal(args[[1]]$name, "modelpipeline_data.csv")
    }
  )
})

# run build model

test_that("uses the correct template_id", {
  # create_and_run_model passes the train template id on to run_model.
  fake_run_model <- mock(list(job_id = 133, run_id = 244))
  fake_civis_ml_fetch_existing <- mock(NULL)

  with_mock(
    `civis::run_model` = fake_run_model,
    `civis::civis_ml_fetch_existing` = fake_civis_ml_fetch_existing,
    `civis::get_train_template_id` = function(...) ml_train_template_id,

    create_and_run_model(file_id = 123)
  )

  run_args <- mock_args(fake_run_model)[[1]]
  expect_equal(run_args$template_id, ml_train_template_id)
})

test_that("converts parameters arg to JSON string", {
  # Scalar parameters are serialised as unboxed JSON values.
  fake_run_model <- mock(list(job_id = 133, run_id = 244))
  fake_civis_ml_fetch_existing <- mock(NULL)

  with_mock(
    `civis::run_model` = fake_run_model,
    `civis::civis_ml_fetch_existing` = fake_civis_ml_fetch_existing,
    `civis::get_train_template_id` = function(...) ml_train_template_id,

    create_and_run_model(file_id = 123, parameters = list(n_trees = 500, c = -1))
  )

  run_args <- mock_args(fake_run_model)[[1]]
  expect_equal(unclass(run_args$arguments$PARAMS), '{"n_trees":500,"c":-1}')
})

test_that("converts cross_validation_parameters to JSON string", {
  # Cross-validation grids are serialised as JSON arrays.
  fake_run_model <- mock(list(job_id = 133, run_id = 244))
  fake_civis_ml_fetch_existing <- mock(NULL)

  with_mock(
    `civis::run_model` = fake_run_model,
    `civis::civis_ml_fetch_existing` = fake_civis_ml_fetch_existing,
    `civis::get_train_template_id` = function(...) ml_train_template_id,

    create_and_run_model(file_id = 123,
                         model_type = "sparse_logistic",
                         cross_validation_parameters = list(n_trees = c(500, 250), c = -1))
  )

  run_args <- mock_args(fake_run_model)[[1]]
  # NOTE(review): the original expectation was lost; this assumes grid values
  # are serialised without auto_unbox (scalars become one-element arrays).
  expect_equal(unclass(run_args$arguments$CVPARAMS), '{"n_trees":[500,250],"c":[-1]}')
})

test_that("converts fit_params to JSON string", {
  # fit_params are serialised as unboxed JSON.
  fake_run_model <- mock(list(job_id = 133, run_id = 244))
  fake_civis_ml_fetch_existing <- mock(NULL)

  with_mock(
    `civis::run_model` = fake_run_model,
    `civis::civis_ml_fetch_existing` = fake_civis_ml_fetch_existing,
    `civis::get_train_template_id` = function(...) ml_train_template_id,

    create_and_run_model(file_id = 123,
                         fit_params = list(weights = "weight_col"))
  )

  run_args <- mock_args(fake_run_model)[[1]]
  expect_equal(unclass(run_args$arguments$FIT_PARAMS), '{"weights":"weight_col"}')
})

test_that("space separates excluded_columns", {
  # excluded_columns are joined into one space-separated string.
  fake_run_model <- mock(list(job_id = 133, run_id = 244))
  fake_civis_ml_fetch_existing <- mock(NULL)

  with_mock(
    `civis::run_model` = fake_run_model,
    `civis::civis_ml_fetch_existing` = fake_civis_ml_fetch_existing,
    `civis::get_train_template_id` = function(...) ml_train_template_id,

    create_and_run_model(file_id = 132, excluded_columns = c("c1", "c2", "c3"))
  )

  run_args <- mock_args(fake_run_model)[[1]]
  expect_equal(run_args$arguments$EXCLUDE_COLS, "c1 c2 c3")
})

test_that("space separates target_column", {
  # Multiple dependent variables are joined into one space-separated string.
  fake_run_model <- mock(list(job_id = 133, run_id = 244))
  fake_civis_ml_fetch_existing <- mock(NULL)

  with_mock(
    `civis::run_model` = fake_run_model,
    `civis::civis_ml_fetch_existing` = fake_civis_ml_fetch_existing,
    `civis::get_train_template_id` = function(...) ml_train_template_id,

    create_and_run_model(file_id = 132, dependent_variable = c("c1", "c2"),
                         model_type = "random_forest_regressor")
  )

  run_args <- mock_args(fake_run_model)[[1]]
  expect_equal(run_args$arguments$TARGET_COLUMN, "c1 c2")
})

test_that("file_id is always numeric", {
  # A civis_file object is reduced to its numeric id in CIVIS_FILE_ID.
  fake_run_model <- mock(list(job_id = 133, run_id = 244))
  fake_civis_ml_fetch_existing <- mock(NULL)

  with_mock(
    `civis::run_model` = fake_run_model,
    `civis::civis_ml_fetch_existing` = fake_civis_ml_fetch_existing,
    `civis::get_train_template_id` = function(...) ml_train_template_id,

    create_and_run_model(file_id = civis_file(132))
  )

  run_args <- mock_args(fake_run_model)[[1]]
  expect_equal(run_args$arguments$CIVIS_FILE_ID, 132)
})

test_that("exceptions with hyperband correct", {
  # "hyperband" cross-validation must be rejected for sparse_logistic.
  fake_run_model <- mock(list(job_id = 133, run_id = 244))
  fake_civis_ml_fetch_existing <- mock(NULL)
  err <- "cross_validation_parameters = \"hyperband\" not supported for sparse_logistic"

  with_mock(
    `civis::run_model` = fake_run_model,
    `civis::civis_ml_fetch_existing` = fake_civis_ml_fetch_existing,

    expect_error(create_and_run_model(file_id = civis_file(132),
                                      model_type = "sparse_logistic",
                                      cross_validation_parameters = "hyperband"), err)
  )
})

test_that("robust if metrics.json not present", {
  # Only model_info.json is among the run outputs; fetching the model must
  # still succeed and expose the parsed model_info.
  fn <- tempfile()
  cat(jsonlite::toJSON(list(a = 1, b = letters[1:3])), file = fn)

  fake_outputs <- mock(list(list(objectType = "File", objectId = 1, name = "model_info.json")))
  fake_download <- mock(fn)
  fake_fetch_job <- mock(1)
  fake_fetch_run <- mock(list(state = "succeeded"))
  fake_model_type <- mock("regressor")
  res <- with_mock(
    `civis::must_fetch_civis_ml_job` = fake_fetch_job,
    `civis::must_fetch_civis_ml_run` = fake_fetch_run,
    `civis::scripts_list_custom_runs_outputs` = fake_outputs,
    `civis::download_civis` = fake_download,
    `civis::model_type` = fake_model_type,
    civis_ml_fetch_existing(123, 1)
  )
  expect_equal(res$model_info, list(a = 1, b = letters[1:3]))

  # Remove the temporary JSON fixture.
  unlink(fn)
})

# run predictions

test_that("uses the correct template_id", {
  # create_and_run_pred passes its template_id through to run_model.
  fake_run_model <- mock(list(job_id = 133, run_id = 244))
  fake_fetch_predict_results <- mock(NULL)

  with_mock(
    `civis::run_model` = fake_run_model,
    `civis::fetch_predict_results` = fake_fetch_predict_results,

    create_and_run_pred(train_job_id = 111, train_run_id = 222, template_id = 555),
    run_args <- mock_args(fake_run_model)[[1]],
    expect_equal(run_args$template_id, 555)
  )
})

# fetch existing model

test_that("raises an error on not found", {
  # A 404 from scripts_get_custom becomes a "model not found" error.
  fake_scripts_get_custom <- function(id) stop(httr::http_condition(404L, "error"))

  with_mock(
    `civis::scripts_get_custom` = fake_scripts_get_custom,

    expect_error(civis_ml_fetch_existing(123), "Error: model 123 not found\\.")
  )
})

test_that("raises an error on invalid model", {
  # A job whose last run has no id is reported as an invalid model task.
  fake_must_fetch_civis_ml_job <- mock(
    list(
      lastRun = list(
        id = NULL
      )
    )
  )

  with_mock(
    `civis::must_fetch_civis_ml_job` = fake_must_fetch_civis_ml_job,

    expect_error(civis_ml_fetch_existing(123), "Error: invalid model task\\.")
  )
})

test_that("issues message for still running", {
  # A last run in state "running" produces an informational message.
  fake_must_fetch_civis_ml_job <- mock(
    list(
      lastRun = list(
        id = 456
      ),
      arguments = list(
        MODEL = "regressor"
      )
    )
  )
  fake_scripts_get_custom_runs <- mock(list(state = "running"))

  with_mock(
    `civis::must_fetch_civis_ml_job` = fake_must_fetch_civis_ml_job,
    `civis::scripts_get_custom_runs` = fake_scripts_get_custom_runs,

    expect_message(civis_ml_fetch_existing(123),
                   "The model task is still running\\.")
  )
})

test_that("raises an error if job failed", {
  # A last run in state "failed" raises an error pointing at fetch_logs.
  fake_must_fetch_civis_ml_job <- mock(
    list(
      lastRun = list(
        id = 456
      ),
      arguments = list(
        MODEL = "regressor"
      )
    )
  )
  fake_scripts_get_custom_runs <- mock(list(state = "failed"))

  with_mock(
    `civis::must_fetch_civis_ml_job` = fake_must_fetch_civis_ml_job,
    `civis::scripts_get_custom_runs` = fake_scripts_get_custom_runs,

    expect_error(civis_ml_fetch_existing(123),
                 "The model task failed, use fetch_logs to retreive any error messages.")
  )
})

test_that("fetch_logs.civis_ml_error works", {
  # The error raised by a failed run_model can be handed to fetch_logs.
  with_mock(
    `civis::scripts_post_custom` = function(...) NULL,
    `civis::scripts_post_custom_runs` = function(...) NULL,
    `civis::scripts_get_custom_runs` = function(...) list(state = "failed", id = 1, run_id = 2, error = "msg"),
    `civis::fetch_logs.civis_ml_error` = function(...) list("A log message"),
    e <- tryCatch(civis:::run_model(1234, name = "sparse_logistic", list(), list(),
                                    verbose = TRUE, polling_interval = NULL),
                  error = function(e) e),
    log <- fetch_logs(e)[[1]],
    expect_equal(log, "A log message")
  )
})


test_that("it removes notifications when NULL", {
  # A NULL notifications argument is omitted from the scripts_post_custom call.
  fake_scripts_post_custom <- mock(list(id = 123))
  fake_scripts_post_custom_runs <- mock(list(id = 456))
  fake_scripts_get_custom_runs <- mock(list(state = "succeeded"))

  with_mock(
    `civis::scripts_post_custom` = fake_scripts_post_custom,
    `civis::scripts_post_custom_runs` = fake_scripts_post_custom_runs,
    `civis::scripts_get_custom_runs` = fake_scripts_get_custom_runs,

    run_model(template_id = 123, name = "a name", arguments = list(a = "b"),
              notifications = NULL, verbose = TRUE)
  )
  script_args <- mock_args(fake_scripts_post_custom)[[1]]
  expect_false("notifications" %in% names(script_args))
})

test_that("it removes name when NULL", {
  # A NULL name argument is omitted from the scripts_post_custom call.
  fake_scripts_post_custom <- mock(list(id = 123))
  fake_scripts_post_custom_runs <- mock(list(id = 456))
  fake_scripts_get_custom_runs <- mock(list(state = "succeeded", NULL))

  with_mock(
    `civis::scripts_post_custom` = fake_scripts_post_custom,
    `civis::scripts_post_custom_runs` = fake_scripts_post_custom_runs,
    `civis::scripts_get_custom_runs` = fake_scripts_get_custom_runs,

    run_model(template_id = 123, name = NULL, arguments = list(a = "b"),
              notifications = NULL, verbose = TRUE)
  )
  script_args <- mock_args(fake_scripts_post_custom)[[1]]
  expect_false("name" %in% names(script_args))
})

test_that("civis_ml_error is caught from run_model", {
  # A failed run raises a civis_ml_error whose message, print output and
  # get_error() data include the fetched log.
  e <- with_mock(
    `civis::scripts_post_custom` = function(...) NULL,
    `civis::scripts_post_custom_runs` = function(...) NULL,
    `civis::scripts_get_custom_runs` = function(...) list(state = "failed", id = 1, run_id = 2, error = "msg"),
    `civis::fetch_logs.civis_ml_error` = function(...) list("A log message"),
    tryCatch(civis:::run_model(1234, name = "sparse_logistic", list(), list(), verbose = TRUE),
             error = function(e) e))
  msg <- "scripts_get_custom_runs(... = NULL, run_id = NULL): msg\nA log message"
  expect_equal(e$message, msg)
  expect_is(e, c("civis_ml_error", "civis_error"))
  expect_true(any(grepl("A log message", capture.output(print(e)))))

  err_data <- get_error(e)
  expect_equal(err_data$f, "scripts_get_custom_runs")
  expect_equal(err_data$log[[1]], "A log message")
})


test_that("it checks input type", {
  # fetch_oos_scores rejects anything that is not a civis_ml object.
  fake_must_fetch_output_file <- mock(NULL)

  with_mock(
    `civis::must_fetch_output_file` = fake_must_fetch_output_file,
    expect_error(fetch_oos_scores("not a model"), "is_civis_ml(model) is not TRUE", fixed = TRUE)
  )
})

test_that("it looks for predictions.csv.gz", {
  # fetch_oos_scores requests the "predictions.csv.gz" output file.
  fake_must_fetch_output_file <- mock(textConnection(c("a, b, c")))

  with_mock(
    `civis::must_fetch_output_file` = fake_must_fetch_output_file,
    fetch_oos_scores(structure(list(), class = "civis_ml"))
  )

  fetch_args <- mock_args(fake_must_fetch_output_file)[[1]]
  expect_equal(fetch_args[[2]], "predictions.csv.gz")
})

test_that("it calls read.csv with extra args", {
  # Extra arguments to fetch_oos_scores are forwarded to read.csv.
  fake_must_fetch_output_file <- mock(textConnection(c("a,b,c")))

  df <- with_mock(
    `civis::must_fetch_output_file` = fake_must_fetch_output_file,
    fetch_oos_scores(structure(list(), class = "civis_ml"),
                     stringsAsFactors = FALSE, header = FALSE)
  )
  ans <- data.frame(V1 = "a", V2 = "b", V3 = "c", stringsAsFactors = FALSE)
  expect_equal(df, ans)
})


test_that("it checks input type", {
  # fetch_predictions rejects anything that is not a civis_ml_prediction.
  expect_error(fetch_predictions("not a model"), "is(x, \"civis_ml_prediction\") is not TRUE", fixed = TRUE)
})

test_that("it calls read.csv with extra args, and dowload_civis with correct id", {
  # Extra arguments are forwarded to read.csv, and the first output file id
  # from the prediction results is downloaded.
  fake_download_civis <- mock(textConnection(c("a,b,c")))

  df <- with_mock(
    `civis::fetch_predict_results` = function(...) list(model_info = list(output_file_ids = 1)),
    `civis::download_civis` = fake_download_civis,
    fetch_predictions(structure(list(), class = "civis_ml_prediction"),
                      header = FALSE, stringsAsFactors = FALSE)
  )
  ans <- data.frame(V1 = "a", V2 = "b", V3 = "c", stringsAsFactors = FALSE)
  expect_equal(df, ans)

  dl_args <- mock_args(fake_download_civis)[[1]]
  expect_equal(dl_args[[1]], 1)
})

# Try the civis package in your browser
#
# Any scripts or data that you put into this service are public.
#
# civis documentation built on April 1, 2023, 12:01 a.m.