# Load the package under development so that its functions (used unqualified
# throughout this script) are available.
devtools::load_all()
# Hide chunk source code in the rendered document by default.
knitr::opts_chunk$set(echo = FALSE)

Script parameters

# Data source settings, forwarded to connect_to_database() further down.
database <- "test_signauxfaibles"
collection <- "Features"
type <- "dataframe"

# Batch identifier and reference period.
last_batch <- "1904"
actual_period <- as.Date("2019-04-01")

# Value passed as `min_effectif` when querying the database.
min_effectif <- 10

# NOTE(review): presumably starts or attaches to the H2O session that the
# modelling steps rely on -- confirm against the helper's definition.
connect_to_h2o()

Fields to import / train on

  # Feature sets handed to the model: the full training set, and a variant
  # requested with `urssaf = 0` (used later as the "without_urssaf" set).
  x_fields_model <- get_fields(training = TRUE)
  x_fields_mini <- get_fields(training = TRUE, urssaf = 0)
  # Columns requested from the database: the non-training field list plus
  # the extra "debit_18m" column.
  fields <- c(get_fields(training = FALSE), "debit_18m")

Load data

# Query the sample used for training and evaluation.
# If the query runs out of memory, restart the database server first:
#   systemctl restart mongod
my_data_frame <- connect_to_database(
  database = database,
  collection = collection,
  batch = last_batch,
  siren = NULL,
  date_inf = "2015-01-01",
  date_sup = "2017-01-01",
  min_effectif = min_effectif,
  fields = fields,
  code_ape = NULL,
  type = type,
  subsample = 50000
)

Split data

# Compute the train / validation / test partition of the loaded data.
res <- split_snapshot_rdm_month(my_data_frame)

# Keep the rows of the full data frame that match each partition. No `by`
# is supplied, so the join keys are whatever columns both tables share.
train_sample <- semi_join(my_data_frame, res[["train"]])
val_sample <- semi_join(my_data_frame, res[["validation"]])
test_sample <- semi_join(my_data_frame, res[["test"]])

Train model

# Prepare the training sample. The result is a list from which the "te_map"
# and "data" elements are read below.
out <- prepare_frame(
  data_to_prepare = train_sample,
  save_or_load_map = FALSE,
  test_or_train = "train"
)

# Prepare the validation sample in "test" mode, reusing the map obtained from
# the training sample rather than deriving a new one from validation data.
# NOTE(review): `te_map` is presumably a target-encoding map -- confirm.
te_map <- out[["te_map"]]
out2 <- prepare_frame(
  data_to_prepare = val_sample,
  test_or_train = "test",
  te_map = te_map,
  save_or_load_map = FALSE
)

h2o_train_sample <- out[["data"]]
h2o_val_sample <- out2[["data"]]

# Fit the model on the prepared training data, with the prepared validation
# data supplied for validation metrics. `save_results` is a logical flag: it
# must be FALSE, not the string "FALSE" (consistent with aux_train()).
model <- train_light_gradient_boosting(
  h2o_train_data = h2o_train_sample,
  h2o_validation_data = h2o_val_sample,
  x_fields_model = x_fields_model,
  outcome = "outcome",
  save_results = FALSE
)
# Prediction helper used as `predict_fun` by the model assesser.
#
# Converts the `periode` column of `new_data` to Date, uploads the data to
# H2O, scores it with `model` through predict_model(), and returns the `prob`
# element of the result.
aux_predict <- function(model, new_data) {
  new_data$periode <- as.Date(new_data$periode)
  h2o_frame <- h2o::as.h2o(new_data)
  scored <- predict_model(
    model = model,
    new_data = h2o_frame
  )
  scored$prob
}

# Training helper used as `train_fun` by the model assesser.
#
# Fits a model on `train_data` using the columns named in `features` to
# predict `target`. No validation data is supplied and results are not saved.
# Returns whatever train_light_gradient_boosting() returns.
aux_train <- function(train_data, features, target) {
  train_light_gradient_boosting(
    h2o_train_data = train_data,
    h2o_validation_data = NULL,
    x_fields_model = features,
    outcome = target,
    save_results = FALSE
  )
}

# Bring the prepared validation data back into a regular data frame, coerce
# the outcome column to logical, and score it with the trained model.
val_df <- as.data.frame(h2o_val_sample)
val_df[["outcome"]] <- as.logical(val_df[["outcome"]])
prediction <- aux_predict(model = model, new_data = val_df)

Model performance

# Build an assesser on the validation data frame, then register the model
# predictions and the name of the target column.
model_assesser <- MLsegmentr::Assesser$new(val_df)
model_assesser$set_predictions(prediction)
model_assesser$set_targets("outcome")

# Evaluate with the precision/recall evaluation function from MLsegmentr.
model_assesser$evaluation_funs <- MLsegmentr::eval_precision_recall()

model_assesser$assess_model()
# Replace the registered predictions with ones obtained through aux_train /
# aux_predict for each feature set: the full training field list ("initial")
# and the list requested with urssaf = 0 ("without_urssaf").
model_assesser$set_predictions_with_training(
  train_fun = aux_train,
  train_data = h2o_train_sample,
  test_data = val_df,
  target_names = "outcome",
  feature_sets = list(initial = x_fields_model,
    without_urssaf = x_fields_mini
    ),
  predict_fun = aux_predict
  )
# Switch the evaluation function to eval_urssaf() and assess again.
model_assesser$evaluation_funs <- eval_urssaf()
model_assesser$assess_model()
# Print the PR AUC stored in the model's validation metrics
# (no trailing newline is emitted).
cat(model@model$validation_metrics@metrics$pr_auc)
# Precision/recall curves on the training data, then on the validation data.
plotPR(model, h2o_train_sample, new_fig = FALSE)
plotPR(model, h2o_val_sample, new_fig = FALSE)

Model testing

# Warn the reader before any result on held-out data is shown.
# The two messages are printed back to back, with no separator or newline.
cat("BEWARE YOU ARE ABOUT TO TEST ON TEST DATA")
cat("TAKING DECISIONS ON THESE RESULTS MAY LEAD TO OVERFITTING TEST DATA")
# NOTE(review): `siret` is not assigned anywhere in the visible script --
# confirm where it is defined before running this line.
# NOTE(review): the call receives `my_data_frame` (the full sample), not
# `test_sample` -- confirm this is intended given the warnings above.
shapley_plot(siret, my_data_frame, model, batch = last_batch)


signaux-faibles/rsignauxfaibles documentation built on Dec. 2, 2020, 3:24 a.m.