# R/predictive_model_R6.R

# R6 Predictive Model class ----

#' Predictive model assembled from user-supplied functions
#'
#' An R6 class that stores a model name, a list of model options and three
#' user-supplied functions, and exposes them through a uniform interface.
#' \code{fit(dataset, ...)} calls \code{fit_f(dataset, model_opts, ...)};
#' \code{predict(dataset, parameters, ...)} calls
#' \code{predict_f(dataset, parameters, model_opts, ...)}; and
#' \code{assess(dataset, predictions, ...)} calls
#' \code{assess_f(dataset, predictions, model_opts, ...)}, where
#' \code{model_opts} is always the options list stored on the object.
#' \code{train()} and \code{test()} are aliases for \code{fit()} and
#' \code{assess()}. The stored name, options and functions can be read or
#' replaced through the active bindings \code{name}, \code{model_opts},
#' \code{fit_f}, \code{predict_f} and \code{assess_f}.
#'
#' @export
PredictiveModel <- R6::R6Class("PredictiveModel",
  public = list(
    # Create a model. Each argument is checked by a helper defined elsewhere
    # in the package (check_model_name, check_opts, check_method). When a
    # check fails, the matching default is kept instead: the name stays NA,
    # the options become list(), and a rejected function is replaced by a
    # stub that ignores its inputs and returns list().
    initialize = function(name, model_opts, fit_f, predict_f, assess_f) {
      # Store in the declared `model_name` field. R6 locks the private
      # environment before initialize() runs, so assigning an undeclared
      # field (such as `private$name`) would raise an error instead of
      # recording the name.
      if (check_model_name(name)) {
        private$model_name <- name
      }

      if (check_opts(model_opts, "model", "PredictiveModel")) {
        private$model_options <- model_opts
      } else {
        private$model_options <- list()
      }

      if (check_method(
        fit_f, private$required_args$fit, "fit_f", "fit", "PredictiveModel"
      )) {
        private$fit_fn <- fit_f
      } else {
        private$fit_fn <- function(dataset, model_opts, ...) list()
      }

      if (check_method(
        predict_f, private$required_args$predict,
        "predict_f", "predict", "PredictiveModel"
      )) {
        private$predict_fn <- predict_f
      } else {
        private$predict_fn <- function(dataset, parameters, model_opts, ...) {
          list()
        }
      }

      if (check_method(
        assess_f, private$required_args$assess,
        "assess_f", "assess", "PredictiveModel"
      )) {
        private$assess_fn <- assess_f
      } else {
        private$assess_fn <- function(dataset, predictions, model_opts, ...) {
          list()
        }
      }

      notify_success("PredictiveModel")
      invisible(self)
    },

    # Print the name, the options and the three stored functions as a named
    # list. Extra arguments are forwarded to print(). Returns the object
    # invisibly, following the usual convention for print methods.
    print = function(...) {
      print(
        list(
          "name" = private$model_name,
          "model_opts" = private$model_options,
          "fit" = private$fit_fn,
          "predict" = private$predict_fn,
          "assess" = private$assess_fn
        ),
        ...
      )
      invisible(self)
    },

    # Call the stored fit function with the stored options and return
    # whatever it returns (the "parameters" later passed to predict()).
    fit = function(dataset, ...) {
      private$fit_fn(dataset, private$model_options, ...)
    },

    # Alias for fit().
    train = function(dataset, ...) self$fit(dataset, ...),

    # Call the stored predict function with the fitted parameters and the
    # stored options; returns its result unchanged.
    predict = function(dataset, parameters, ...) {
      private$predict_fn(dataset, parameters, private$model_options, ...)
    },

    # Call the stored assess function with the predictions and the stored
    # options; returns its result unchanged.
    assess = function(dataset, predictions, ...) {
      private$assess_fn(dataset, predictions, private$model_options, ...)
    },

    # Alias for assess().
    test = function(dataset, predictions, ...) {
      self$assess(dataset, predictions, ...)
    }
  ),
  active = list(
    # Get the model name, or set it if check_model_name(value, init = FALSE)
    # passes. Rejected values leave the current name unchanged.
    name = function(model_name) {
      if (missing(model_name)) {
        return(private$model_name)
      }
      if (check_model_name(model_name, init = FALSE)) {
        private$model_name <- model_name
      }
    },

    # Get the model options, or replace them if check_opts() accepts the new
    # value. Rejected values leave the current options unchanged.
    model_opts = function(model_options) {
      if (missing(model_options)) {
        return(private$model_options)
      }
      if (check_opts(model_options, "model", "PredictiveModel", init = FALSE)) {
        private$model_options <- model_options
      }
    },

    # Get the fit function, or replace it if check_method() accepts it.
    fit_f = function(fit_fn) {
      if (missing(fit_fn)) {
        return(private$fit_fn)
      }
      if (check_method(
        fit_fn, private$required_args$fit, "fit_f", "fit", "PredictiveModel",
        FALSE
      )) {
        private$fit_fn <- fit_fn
      }
    },

    # Get the predict function, or replace it if check_method() accepts it.
    # Validation uses the same argument list as initialize(), in the order
    # predict() passes the arguments.
    predict_f = function(predict_fn) {
      if (missing(predict_fn)) {
        return(private$predict_fn)
      }
      if (check_method(
        predict_fn, private$required_args$predict, "predict_f", "predict",
        "PredictiveModel", FALSE
      )) {
        private$predict_fn <- predict_fn
      }
    },

    # Get the assess function, or replace it if check_method() accepts it.
    # Validation uses the same argument list as initialize(), in the order
    # assess() passes the arguments.
    assess_f = function(assess_fn) {
      if (missing(assess_fn)) {
        return(private$assess_fn)
      }
      if (check_method(
        assess_fn, private$required_args$assess, "assess_f", "assess",
        "PredictiveModel", FALSE
      )) {
        private$assess_fn <- assess_fn
      }
    }
  ),
  private = list(
    model_name = NA,
    model_options = NA,
    fit_fn = NA,
    predict_fn = NA,
    assess_fn = NA,
    # Argument names handed to check_method() when validating each
    # user-supplied function, listed in the order fit(), predict() and
    # assess() pass them positionally. Shared by initialize() and the active
    # bindings so the two validation paths cannot drift apart.
    required_args = list(
      fit = c("dataset", "model_opts"),
      predict = c("dataset", "parameters", "model_opts"),
      assess = c("dataset", "predictions", "model_opts")
    )
  )
)
# EntirelyDS/modelr documentation built on May 6, 2019, 3:48 p.m.