# Load the package under development from the working directory so that its
# functions are available to the rest of the script.
devtools::load_all()
# Default knitr chunk option for the report: do not echo the source code.
knitr::opts_chunk$set(echo = FALSE)

Script parameters

# Data source: database and collection the features are read from.
collection <- "Features"
database <- "test_signauxfaibles"

# Batch identifier used for the import, and reference period of this run.
last_batch <- "1904"
actual_period <- as.Date("2019-04-01")

# Import settings forwarded to connect_to_database() further down.
type <- "dataframe"
min_effectif <- 10

# Project helper called for its side effect; presumably sets up the H2O
# session required by the prepare_frame() / model steps -- confirm.
connect_to_h2o()

Fields to import / train on

  # Columns requested from the database (non-training field list).
  fields <- get_fields(training = FALSE)
  # Training field list obtained with urssaf = 0; only referenced by the
  # disabled "without_urssaf" feature set in the performance section.
  x_fields_mini <- get_fields(urssaf = 0, training = TRUE)
  # Features given to the model: the training fields plus "debit_18m".
  x_fields_model <- append(get_fields(training = TRUE), "debit_18m")

Load data

# Import the feature collection for the chosen batch, restricted to the
# 2015-01-01 / 2017-01-01 window and subsampled to 50000.
# If the import runs out of memory, run: systemctl restart mongod
my_data_frame <- connect_to_database(
  database = database,
  collection = collection,
  batch = last_batch,
  date_inf = "2015-01-01",
  date_sup = "2017-01-01",
  min_effectif = min_effectif,
  siren = NULL,
  code_ape = NULL,
  fields = fields,
  type = type,
  subsample = 50000
)

Split data

# Split the imported data: split_snapshot_rdm_month() returns the "train",
# "validation" and "test" subsets, and semi_join() keeps the rows of
# my_data_frame that have a match in each of them.
res <- split_snapshot_rdm_month(my_data_frame)
train_sample <- my_data_frame %>% semi_join(res[["train"]])
val_sample <- my_data_frame %>% semi_join(res[["validation"]])
test_sample <- my_data_frame %>% semi_join(res[["test"]])
# Memory gain: drop what is no longer needed. train_sample is deliberately
# kept until it has been converted below, because the performance section
# reads h2o_train_sample.
rm(my_data_frame)
rm(res)
rm(test_sample)

Load trained model

# Most recent saved model and target-encoding map.
model <- load_h2o_object(
  "lgb",
  "model",
  last = TRUE
)
te_map <- load_h2o_object(
  "te_map",
  "temap",
  last = TRUE
)

# Prepared training frame. It was previously never created, although
# assess_training_sets() and plotPR() below use h2o_train_sample.
# NOTE(review): test_or_train = "train" is assumed to be the value expected
# for the training split, and the loaded te_map is passed as for validation;
# confirm against prepare_frame() that this matches the encoding the loaded
# model was trained with.
out1 <- prepare_frame(
  data_to_prepare = train_sample,
  test_or_train = "train",
  te_map = te_map,
  save_or_load_map = FALSE
)

h2o_train_sample <- out1[["data"]]
# Memory gain: the raw training rows are no longer needed once converted.
rm(train_sample)

# Prepared validation frame, encoded with the loaded te_map.
out2 <- prepare_frame(
  data_to_prepare = val_sample,
  test_or_train = "test",
  te_map = te_map,
  save_or_load_map = FALSE
)

h2o_val_sample <- out2[["data"]]

# Prediction adapter handed to MLsegmentr::assess_training_sets() as
# `predict_fun`: runs predict_model() on `new_data` with `model` and returns
# only the `prob` element of its result.
aux_predict <- function(model, new_data) {
  predictions <- predict_model(model = model, new_data = new_data)
  predictions$prob
}

Model performance

# Assess the "initial" feature set: models are trained with
# train_light_gradient_boosting on h2o_train_sample, predictions on
# h2o_val_sample come from aux_predict, and custom_eval_urssaf evaluates them.
# Requires h2o_train_sample and h2o_val_sample to exist in the session.
result <- MLsegmentr::assess_training_sets(
  training_frames = h2o_train_sample,
  test_frame = h2o_val_sample,
  # Disabled alternative feature set: without_urssaf = x_fields_mini
  feature_sets = list(initial = x_fields_model),
  outcome_names = "outcome",
  training_fun = train_light_gradient_boosting,
  predict_fun = aux_predict,
  # Disabled alternative evaluation: MLsegmentr::main_diff_eval
  eval_funs = custom_eval_urssaf,
  more_training_args = list(
    h2o_validation_data = h2o_val_sample,
    save_results = FALSE
  ),
  # Logical segment flag: TRUE where code_naf equals "C", NA counted as FALSE.
  segments = tidyr::replace_na(
    as.vector(h2o_val_sample["code_naf"]) == "C",
    FALSE
  ),
  ... = as.vector(h2o_val_sample[["siret"]])
)

# Validation PR AUC stored in the loaded model, then precision-recall plots of
# that model on the training and validation frames.
cat(model@model$validation_metrics@metrics$pr_auc)
plotPR(model, h2o_train_sample, new_fig = FALSE)
plotPR(model, h2o_val_sample, new_fig = FALSE)

Model testing

# Reminder printed before any evaluation on the test set. cat() does not add
# a line break by itself, so each message ends with "\n"; otherwise the two
# warnings are printed glued together on a single line.
cat("BEWARE YOU ARE ABOUT TO TEST ON TEST DATA\n")
cat("TAKING DECISIONS ON THESE RESULTS MAY LEAD TO OVERFITTING TEST DATA\n")
# NOTE(review): `siret` is not assigned anywhere in this script, and
# `my_data_frame` is removed with rm() right after the data split; confirm
# where these two inputs are expected to come from before running this call.
shapley_plot(siret, my_data_frame, model, batch = last_batch)


signaux-faibles/rsignauxfaibles documentation built on Dec. 2, 2020, 3:24 a.m.