R/setPredictType.R

Defines functions setPredictType.Learner setPredictType

Documented in setPredictType

#' @title Set the type of predictions the learner should return.
#'
#' @description
#' Possible prediction types are:
#' Classification: Labels or class probabilities (including labels).
#' Regression: Numeric response or standard errors (including numeric response).
#' Survival: Linear predictor or survival probability.
#'
#' For complex wrappers, the predict type is usually also passed down to the
#' encapsulated learner recursively.
#'
#' @template arg_learner
#' @param predict.type (`character(1)`)\cr
#'   Classification: \dQuote{response} or \dQuote{prob}.
#'   Multilabel: \dQuote{response} or \dQuote{prob}.
#'   Regression: \dQuote{response} or \dQuote{se}.
#'   Survival: \dQuote{response} (linear predictor) or \dQuote{prob}.
#'   Cost-sensitive: \dQuote{response} only.
#'   Clustering: \dQuote{response} or \dQuote{prob}.
#'   This argument has no default and must be supplied; learners are created with \dQuote{response}.
#' @template ret_learner
#' @family predict
#' @family learner
#' @export
setPredictType = function(learner, predict.type) {
  assertClass(learner, classes = "Learner")
  UseMethod("setPredictType")
}

#' @export
setPredictType.Learner = function(learner, predict.type) {
  # Validate here in the method rather than in the generic, because wrappers
  # call setPredictType recursively on their encapsulated learners.
  assertChoice(predict.type, choices = switch(learner$type,
    classif = c("response", "prob"),
    multilabel = c("response", "prob"),
    regr = c("response", "se"),
    surv = c("response", "prob"),
    costsens = "response",
    cluster = c("response", "prob")
  ))
  if (predict.type == "prob" && !hasLearnerProperties(learner, "prob")) {
    stopf("Trying to predict probs, but %s does not support that!", learner$id)
  }
  if (predict.type == "se" && !hasLearnerProperties(learner, "se")) {
    stopf("Trying to predict standard errors, but %s does not support that!", learner$id)
  }
  learner$predict.type = predict.type
  return(learner)
}
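
The following is a minimal usage sketch, not part of the package source. It assumes the mlr package and rpart are installed, and it uses the learner ID "classif.rpart" purely for illustration. It sets the predict type on a classification learner, then shows that a value outside the choices allowed for that learner type ("se" on a classification learner) is rejected by the assertChoice() check in the method above.

```r
library(mlr)

# Illustrative learner (an assumption); any classification learner
# with the "prob" property should behave the same way.
lrn = makeLearner("classif.rpart")

# Check support first, mirroring the property check inside the method.
if (hasLearnerProperties(lrn, "prob")) {
  lrn = setPredictType(lrn, "prob")
}
print(lrn$predict.type)

# "se" is not among the choices for type "classif",
# so assertChoice() raises an error, which is captured here.
res = tryCatch(
  setPredictType(lrn, "se"),
  error = function(e) e
)
print(inherits(res, "error"))
```

Because setPredictType returns the modified learner rather than changing it in place, the result has to be assigned back, as done above.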

mlr documentation built on Sept. 29, 2022, 5:05 p.m.