# R/Tracker.R

#' Machine learning experiment logging
#'
#' @section Attributes:
#' * `id_columns`  -- `character()` \cr
#'   Column name(s) of the test and train frames that uniquely identify each
#'   row. Only these columns are fingerprinted with an md5 checksum when an
#'   experiment is logged, which makes it possible to check that models are
#'   compared on the same evaluation perimeter (`test_frame` and
#'   `train_frame` fields). Set from `control$id_columns` at construction.
#' @inheritSection GenericTracker Attributes
#' @inheritSection GenericTracker Methods
#' @export
#'
Tracker <- R6::R6Class("Tracker",
  inherit = GenericTracker, #nolint

  # -------- Attributes -----------------------------------------------------
  public = list(
    id_columns = NULL,
    # Build a tracker with the standard set of experiment fields.
    #
    # database, collection: forwarded unchanged to GenericTracker$initialize.
    # control: list forwarded to GenericTracker$initialize; its `id_columns`
    #   entry is also stored in `self$id_columns`.
    initialize = function(
      database = NULL,
      collection = NULL,
      control = list()
      ){
      # test_frame and train_frame follow the same rules, so they share one
      # validator and one transformer. Both take (frame, control); presumably
      # GenericTracker calls them with its control list -- TODO confirm.
      #
      # Valid when `frame` is a data frame containing every id column.
      validate_frame <- function(frame, control) {
        id_columns <- control[["id_columns"]]
        is.data.frame(frame) && all(id_columns %in% names(frame))
      }
      # Store only the md5 fingerprint of the id columns, not the data.
      hash_frame <- function(frame, control) {
        md5(frame[control[["id_columns"]]])
      }

      super$initialize(
        database = database,
        collection = collection,
        fields = list(
          field_timestamp(),
          field("experiment_name", is_compulsary = TRUE),
          field("experiment_description", is_compulsary = TRUE),
          # NOTE(review): dput() prints `x` to the console and returns it
          # invisibly, so this transform is identity plus console output.
          # If a text serialisation was intended, deparse() would be the
          # side-effect-free equivalent -- confirm against GenericTracker.
          field(
            "model",
            transform = function(x) dput(x)),
          field("model_name", is_compulsary = TRUE),
          field("model_parameters", is_compulsary = TRUE),
          field("model_features", is_compulsary = TRUE),
          field("model_target", is_compulsary = TRUE),
          field("model_performance", is_compulsary = TRUE),
          field("resampling_strategy", is_compulsary = TRUE),
          field("preprocessing_strategy", is_compulsary = TRUE),
          field("train_val_test_shares", is_compulsary = TRUE),
          field("additional_comment"),
          # Same dput() caveat as for "model" above.
          field(
            "additional_R_obj",
            transform = function(x) dput(x)
            ),
          field("test_frame",
            validate = validate_frame,
            transform = hash_frame
            ),
          field("train_frame",
            validate = validate_frame,
            transform = hash_frame
            ),
          field_uuid()
          ),
        control = control
        )
      # The public attribute was declared and documented but never filled in;
      # expose the id columns actually used for fingerprinting.
      self$id_columns <- control[["id_columns"]]
    }
    )
  )


# Return the private environment of an R6 object.
#
# R6 keeps `private` inside the hidden `.__enclos_env__` environment of the
# object. Returns NULL when `x` has no such enclosing environment or when
# that environment holds no `private` binding.
get_private <- function(x) {
  enclosing_env <- x[[".__enclos_env__"]]
  enclosing_env[["private"]]
}


# md5 fingerprint of an R object, insensitive to data frame row order.
#
# Data frames are normalised before hashing. Rows are sorted on all columns
# (left to right), then every attribute (names, row names, class) is
# stripped, so that neither row order nor row names change the fingerprint.
#
# value: any R object; data frames get the normalisation described above.
# Returns a single md5 string, or NA_character_ with a warning when the
# digest package is not installed (best effort: logging is not aborted).
md5 <- function(value){

  if (is.data.frame(value)){
    # Neither the row order nor the attributes should influence the
    # fingerprint of a data frame.
    if (ncol(value) > 0L) {
      # unname(): columns are passed positionally, so a column called e.g.
      # "method" or "decreasing" cannot collide with order()'s arguments.
      # method = "radix" sorts character keys in the C locale, so the row
      # order -- and hence the hash -- does not depend on the machine locale.
      row_order <- do.call(
        order,
        c(unname(as.list(value)), method = "radix")
        )
      value <- value[row_order, , drop = FALSE]
    }
    attributes(value)  <- NULL
  }

  # requireNamespace() checks availability without attaching digest to the
  # user's search path (require() would).
  if (!requireNamespace("digest", quietly = TRUE)){
    warning(
      "Package digest is required to save the md5 of the test frame and
      check that evaluation is made on constant perimeter"
      )
    return(NA_character_)
  }
  digest::digest(value, algo = "md5")
}
# signaux-faibles/MLlogr documentation built on June 27, 2019, 1:20 p.m.