# R/Assesser.R

# Package-level NAMESPACE directive. The roxygen2 tag is `@import`;
# `@imports` is not a roxygen2 tag, so dplyr would not be imported.
#'
#' @import dplyr
#'
NULL

#' @title Assesser Object
#'
#' @format [R6::R6Class] object.
#' @name Assesser
#'
#' @description
#' This is an object used to assess one or several models, on the whole data
#' or on specific segments, for one or several target variables.
#'
#' @section Construction:
#' * `new` :: `{data.frame, character()} => Assesser`\cr
#'   Initialize a new `Assesser` with the given `test_data` and `id_names`.
#'   If no id names are provided, an `.id` column (the row number) is added
#'   and used as identifier. Provided `id_names` must be columns of
#'   `test_data`.
#'
#' @section Fields:
#'
#' * `test_data` :: `data.frame()`\cr
#'   Data used for defining targets, segments, and optionally predictions,
#'   features and other columns.
#'
#' * `id_names` :: `character()`\cr
#'   Names of one or several columns of the test data that uniquely identify
#'   each row.
#'
#' * `evaluation_funs` :: `[eval_function]`\cr
#'   Evaluation functions used to evaluate the models. Its
#'   `compulsory_fields` element, if any, names extra `test_data` columns
#'   that are added to the evaluation frame.
#'
#' @section Methods:
#' * `set_predictions()`, `set_segments()`, `set_targets()` ::
#'   `character() => invisible(self)`\cr
#'   Define the columns storing respectively the predictions, segments and
#'   targets. `target_type` and `prediction_type` are derived from the column
#'   names. For these three methods, several inputs are given as lists and
#'   recycled with a warning if their lengths are not equal.
#'
#' * `assess_model(..., plot = TRUE)`\cr
#'   Assess the models with `evaluation_funs`. `plot` and any additional
#'   `...` arguments are passed on to `assess_eval_frame()`. Predictions and
#'   targets must have been set beforehand.
#'
#' @export
Assesser <- R6::R6Class("Assesser",
  public = list(
    test_data = NULL,
    id_names = NULL,
    evaluation_funs = NULL,
    # -------- Methods --------------------------------------------------------
    initialize = function(test_data = NULL, id_names = NULL) {
      assertthat::assert_that(is.data.frame(test_data))
      if (is.null(id_names)) {
        # No identifier given: add a ".id" column and use it as the id.
        self$test_data <- add_id(test_data)
        self$id_names <- ".id"
      } else {
        # Fail early here rather than in a later join on missing columns.
        assertthat::assert_that(
          all(id_names %in% names(test_data)),
          msg = paste(
            "id_names are missing from the test_data:",
            paste(setdiff(id_names, names(test_data)), collapse = ", ")
          )
        )
        self$test_data <- test_data
        self$id_names <- id_names
      }
    },

    # Builds the evaluation frame (including the evaluation functions'
    # compulsory fields) and delegates the assessment to assess_eval_frame().
    assess_model = function(..., plot = TRUE) {
      eval_frame <- private$get_eval_frame(
        additional_fields = self$evaluation_funs[["compulsory_fields"]]
      )
      assess_eval_frame(self$evaluation_funs, eval_frame, plot, ...)
    },

    # The three setters below store, in a private field, the frame built by
    # the package-level helper of the same name.
    set_predictions = function(prediction_names) {
      private$predictions <- set_predictions(
        col_names = prediction_names,
        id_names = self$id_names,
        test_data = self$test_data
      )
      invisible(self)
    },
    set_segments = function(segment_names) {
      private$segments <- set_segments(
        col_names = segment_names,
        id_names = self$id_names,
        test_data = self$test_data
      )
      invisible(self)
    },
    set_targets = function(target_names) {
      private$targets <- set_targets(
        col_names = target_names,
        id_names = self$id_names,
        test_data = self$test_data
      )
      invisible(self)
    }
  ),
  private = list(
    # -------- Attributes -----------------------------------------------------
    predictions = NULL, # set by set_predictions(); required
    targets = NULL,     # set by set_targets(); required
    segments = NULL,    # set by set_segments(); NULL means a single segment
    # -------- Methods --------------------------------------------------------
    # get_eval_frame returns the prediction frame joined (by id_names) with
    # the segments and the targets. Per the set_* helpers it is expected to
    # hold "model", "prediction", "target", "target_type" and "segment".
    # additional_fields (NULL or a character vector of column names) are
    # also fetched from self$test_data.
    get_eval_frame = function(additional_fields = NULL) {
      assertthat::assert_that(
        is.null(additional_fields) || is.character(additional_fields),
        msg = paste(
          "additional_fields should be equal to NULL or a vector of field",
          "names that need to be included in the eval frame"
        )
      )
      assertthat::assert_that(
        !is.null(private$predictions),
        msg = "predictions are not set: call set_predictions() first"
      )
      assertthat::assert_that(
        !is.null(private$targets),
        msg = "targets are not set: call set_targets() first"
      )

      # Start from the predictions themselves. Seeding with `.id = 1:n` only
      # worked when the sole id column was ".id" holding exactly 1..n.
      eval_frame <- private$predictions

      if (is.null(private$segments)) {
        # Default: one segment covering every row.
        eval_frame <- dplyr::mutate(eval_frame, segment = "segment")
      } else {
        eval_frame <- dplyr::left_join(
          x = eval_frame,
          y = private$segments,
          by = self$id_names
        )
      }

      eval_frame <- dplyr::left_join(
        x = eval_frame,
        y = private$targets,
        by = self$id_names
      )

      if (!is.null(additional_fields)) {
        missing_fields <- setdiff(additional_fields, names(self$test_data))
        assertthat::assert_that(
          length(missing_fields) == 0,
          msg = paste("additional_fields are missing from the test_data:",
            paste(missing_fields, collapse = ", "))
        )
        # unique(): an id column also listed in additional_fields must not be
        # selected twice.
        additional_frame <- self$test_data[
          unique(c(self$id_names, additional_fields))
        ]
        eval_frame <- dplyr::left_join(
          x = eval_frame,
          y = additional_frame,
          by = self$id_names
        )
      }
      eval_frame
    }
  )
)
# signaux-faibles/MLsegmentr documentation built on Aug. 29, 2019, 2:22 p.m.