# R/PredictorMLR3.R

#' @title PredictorMLR3
#'
#' @include Predictor.R
#'
#' @description
#' This class specializes [Predictor] for `mlr3` models.
#' The `model` is assumed to be a `LearnerRegr` or `LearnerClassif`.
#' Classification learners must have `predict_type = "prob"`.
#'
#' It is recommended to use [makePredictor()] for construction of Predictor objects.
#' @export
PredictorMLR3 = R6::R6Class("PredictorMLR3",

  inherit = Predictor,

  public = list(

    #' @description
    #' Create a new PredictorMLR3 object.
    #' @param model `LearnerRegr` or `LearnerClassif` object.
    #' @param data The data used for computing FMEs, must be data.frame or data.table.
    initialize = function(model, data) {
      # Setup is delegated to initializeSubclass(), inherited from [Predictor].
      private$initializeSubclass(model, data)
    },

    #' @description
    #' Predicts on an observation `"newdata"`.
    #' @param newdata The observation(s) for which the target should be
    #'   predicted; passed unchanged to the learner's `predict_newdata()`.
    #' @return The selected column of `as.data.table()` of the mlr3 prediction,
    #'   renamed to `prediction`: column 3 for regression learners, column 4
    #'   for classification learners.
    predict = function(newdata) {
      if (inherits(self$model, "LearnerRegr")) {
        # NOTE(review): column 3 is presumably mlr3's `response` column
        # (after `row_ids` and `truth`) - confirm against the mlr3 version used.
        prediction = as.data.table(self$model$predict_newdata(newdata))[, 3]
      } else if (inherits(self$model, "LearnerClassif")) {
        # Probabilities are only available when the learner predicts them.
        if (!"prob" %in% self$model$predict_type) {
          stop(paste(class(self)[1], "Your learner needs predict_type = `prob`"))
        }
        # The target class for the probability is intended to be the last
        # item of the levels of the target.
        # NOTE(review): if mlr3's table layout is (row_ids, truth, response,
        # prob.<class>, ...), column 4 is the FIRST probability column -
        # confirm which class this actually corresponds to.
        prediction = as.data.table(self$model$predict_newdata(newdata))[, 4]
      } else {
        # Without this branch `prediction` would be undefined below and R
        # would fail with an unhelpful "object not found" error.
        stop(paste(
          class(self)[1],
          "Your model must be a `LearnerRegr` or `LearnerClassif`"
        ))
      }
      names(prediction) = "prediction"
      prediction
    }

  ),
  private = list(

    # Returns the target name(s) of the task the learner was trained on,
    # read from the learner's stored training state.
    getTarget = function(model) {
      model$state$train_task$target_names
    }

  )
)

Try the fmeffects package in your browser

Any scripts or data that you put into this service are public.

fmeffects documentation built on June 22, 2024, 9:32 a.m.