# inst/templates/Lrnr_template.R

##' Template of a \code{sl3} Learner.
##'
##' This is a template for defining a new learner.
##' This can be copied to a new file using \code{\link{write_learner_template}}.
##' The remainder of this documentation is an example of how you might write 
##' documentation for your new learner.
##' 
##' This learner uses \code{\link[my_package]{my_ml_fun}} from \code{my_package} 
##' to fit my favorite machine learning algorithm.
##'
##' @docType class
##'
##' @importFrom R6 R6Class
##'
##' @export
##'
##' @keywords data
##'
##' @return A learner object inheriting from \code{\link{Lrnr_base}} with
##'  methods for training and prediction. For a full list of learner
##'  functionality, see the complete documentation of \code{\link{Lrnr_base}}.
##'
##' @format An \code{\link[R6]{R6Class}} object inheriting from
##'  \code{\link{Lrnr_base}}.
##'
##' @family Learners
##'
##' @section Parameters:
##'   - \code{param_1="default_1"}: This parameter does something and is not 
##'      already specified in the task. 
##'   - \code{param_2="default_2"}: This parameter does something else and is 
##'      not already specified in the task. 
##'   - \code{...}: Other parameters passed directly to
##'      \code{\link[my_package]{my_ml_fun}}; see its documentation for
##'      details. Any additional parameters understood by
##'      \code{\link{Lrnr_base}} may also be supplied here.
##'
##' @examples 
##' # Include an example here; for instance, see the example in ?Lrnr_ranger.
##' 
Lrnr_template <- R6Class(
  classname = "Lrnr_template",
  inherit = Lrnr_base,
  portable = TRUE,
  class = TRUE,
  # Above, you should change Lrnr_template (in both the object name and the
  # classname argument) to a name that indicates what your learner does.
  public = list(
    # Define default parameter values in the signature here. If possible,
    # your learner should define defaults for all required parameters.
    initialize = function(param_1 = "default_1", param_2 = "default_2", ...) {
      # Capture all parameters passed to initialize and save them as
      # self$params.
      params <- args_to_list()
      super$initialize(params = params, ...)
    },

    # Public methods let your learner do special things; for instance, a glm
    # learner might return prediction standard errors. As written, this
    # placeholder has an empty body and therefore returns NULL.
    special_function = function(arg_1) {
    }
  ),
  private = list(
    # List the properties your learner supports here.
    # Use sl3_list_properties() for a list of options.
    .properties = c(""),

    # .train takes a task and returns a fit object that .predict later uses
    # to generate predictions.
    .train = function(task) {
      # Start the argument list from the parameters captured when the learner
      # was initialized; this lets users pass arguments directly to your ml
      # function.
      args <- self$params

      # Get the outcome variable type, preferring learner$params$outcome_type
      # first, then task$outcome_type. You should pass something on to your
      # ml function indicating the outcome type, e.g. a family or objective.
      outcome_type <- self$get_outcome_type(task)

      # Add task data to the argument list. What these arguments are called
      # depends on the function you are wrapping. Whatever design matrix is
      # built here must be built the same way in .predict, so the model sees
      # the same columns at prediction time as it was fit on.
      args$x <- as.matrix(task$X_intercept)
      args$y <- outcome_type$format(task$Y)

      # Only add weights, offset, and id if those nodes were specified when
      # the task was created.
      if (task$has_node("weights")) {
        args$weights <- task$weights
      }
      if (task$has_node("offset")) {
        args$offset <- task$offset
      }
      if (task$has_node("id")) {
        args$id <- task$id
      }

      # Call the function that fits your algorithm with the argument list
      # constructed above.
      fit_object <- call_with_args(my_ml_fun, args)

      # The returned fit object is stored on the trained learner and is
      # accessible as self$fit_object (for example, in .predict below).
      return(fit_object)
    },

    # .predict takes a task and returns predictions for that task.
    .predict = function(task = NULL) {
      # Besides `task`, these learner fields can be used here if needed:
      #   self$training_task
      #   self$training_outcome_type
      #   self$fit_object  (the object returned by .train)

      # Build the design matrix exactly as .train did (intercept column,
      # coerced to a matrix) so prediction columns match the fitted model.
      X <- as.matrix(task$X_intercept)
      predictions <- predict(self$fit_object, X)
      return(predictions)
    },

    # List any packages required for your learner here.
    .required_packages = c("my_package")
  )
)
# tlverse/sl3 documentation built on Nov. 18, 2024, 12:46 a.m.