# R/LearnerMultioutputCforest.R

#' @title Conditional Random Forest Multioutput Learner
#'
#' @name mlr_learners_multioutput.cforest
#' @include LearnerMultioutput.R
#'
#' @description
#' A [LearnerMultioutput] implementation of the random forest and bagging ensemble
#' algorithms utilizing conditional inference trees as base learners.
#' Calls [partykit::cforest()] for training and its `predict()` method for
#' prediction.
#' Supports multilabel classification.
#'
#' @templateVar id multioutput.cforest
#' @template section_dictionary_learner
#'
#' @export
LearnerMultioutputCForest = R6Class("LearnerMultioutputCForest",
  inherit = LearnerMultioutput,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {

      ps = ParamSet$new(list(
        # missing: weights (see bottom), subset, strata
        # (see FIXME: below), na.action, ytrafo
        ParamInt$new("ntree", default = 500L, lower = 1L, tags = "train"),
        # replace and fraction go in perturb (named list)
        ParamLgl$new("replace", default = FALSE, tags = "train"),
        ParamDbl$new("fraction", default = 0.632, lower = 0, upper = 1,
          tags = "train"),
        ParamInt$new("mtry", lower = 0L, special_vals = list(Inf),
          tags = "train"), # default actually "ceiling(sqrt(nvar))"
        ParamUty$new("applyfun", tags = c("train", "importance")),
        ParamInt$new("cores", default = NULL, special_vals = list(NULL),
          tags = c("train", "importance")),
        ParamLgl$new("trace", default = FALSE, tags = "train"),
        ParamUty$new("offset", tags = "train"),
        ParamUty$new("cluster", tags = "train"),
        ParamUty$new("scores", tags = "train"),

        # all in ctree_control(); missing: mtry, applyfun, cores
        # (see above, passed directly)
        ParamFct$new("teststat", default = "quadratic",
          levels = c("quadratic", "maximum"), tags = "train"),
        ParamFct$new("splitstat", default = "quadratic",
          levels = c("quadratic", "maximum"), tags = "train"),
        ParamLgl$new("splittest", default = FALSE, tags = "train"),
        ParamFct$new("testtype", default = "Univariate",
          levels = c("Bonferroni", "MonteCarlo", "Univariate", "Teststatistic"),
          tags = "train"),
        ParamUty$new("nmax", tags = "train"),

        # pargs arguments as a list for GenzBretz() within ctree_control:
        # maxpts, abseps, releps
        ParamUty$new("pargs", tags = "train"),

        ParamDbl$new("alpha", default = 0.05, lower = 0, upper = 1,
          tags = "train"),
        ParamDbl$new("mincriterion", default = 0, lower = 0, upper = 1,
          tags = "train"),
        ParamDbl$new("logmincriterion", default = 0, tags = "train"),
        ParamInt$new("minsplit", lower = 1L, default = 20L, tags = "train"),
        ParamInt$new("minbucket", lower = 1L, default = 7L, tags = "train"),
        ParamDbl$new("minprob", default = 0.01, lower = 0, upper = 1,
          tags = "train"),
        ParamLgl$new("stump", default = FALSE, tags = "train"),
        ParamLgl$new("lookahead", default = FALSE, tags = "train"),
        ParamLgl$new("MIA", default = FALSE, tags = "train"),
        ParamInt$new("nresample", default = 9999L, lower = 1L, tags = "train"),
        ParamDbl$new("tol", default = sqrt(.Machine$double.eps), lower = 0,
          tags = "train"),
        ParamInt$new("maxsurrogate", default = 0L, lower = 0L, tags = "train"),
        ParamLgl$new("numsurrogate", default = FALSE, tags = "train"),
        ParamInt$new("maxdepth", default = Inf, lower = 0L,
          special_vals = list(Inf), tags = "train"),
        ParamLgl$new("multiway", default = FALSE, tags = "train"),
        ParamInt$new("splittry", default = 2L, lower = 0L, tags = "train"),
        ParamLgl$new("intersplit", default = FALSE, tags = "train"),
        ParamLgl$new("majority", default = FALSE, tags = "train"),
        ParamLgl$new("caseweights", default = TRUE, tags = "train"),
        ParamLgl$new("saveinfo", default = FALSE, tags = "train"),
        ParamLgl$new("update", default = FALSE, tags = "train"),
        ParamFct$new("splitflavour", default = "ctree",
          levels = c("ctree", "exhaustive"), tags = "train"),

        # predict; missing FUN and simplify (not needed here)
        ParamLgl$new("OOB", default = FALSE, tags = c("predict", "importance")),
        ParamLgl$new("simplify", default = TRUE, tags = "predict"),
        ParamLgl$new("scale", default = TRUE, tags = "predict"),

        # importance; OOB see predict, applyfun, cores see train
        ParamInt$new("nperm", default = 1L, lower = 0L,
          tags = c("train", "importance")),
        ParamFct$new("risk", default = "loglik",
          levels = c("loglik", "misclassification"),
          tags = c("train", "importance")),
        ParamLgl$new("conditional", default = FALSE,
          tags = c("train", "importance")),
        ParamDbl$new("threshold", default = 0.2,
          tags = c("train", "importance"))
      ))

      ps$add_dep("nresample", on = "testtype",
        cond = CondEqual$new("MonteCarlo"))
      ps$add_dep("nperm", on = "conditional", cond = CondEqual$new(TRUE))
      ps$add_dep("threshold", on = "conditional", cond = CondEqual$new(TRUE))

      # Set the cforest-specific ctree_control() parameters. `.train()` always
      # passes an explicit `control` to cforest(), so the cforest-style
      # control settings have to be provided as initial values here.
      ps$values$teststat = "quadratic"
      ps$values$testtype = "Univariate"
      ps$values$mincriterion = 0
      ps$values$saveinfo = FALSE

      super$initialize(
        id = "multioutput.cforest",
        feature_types = mlr_reflections$task_feature_types,
        predict_types = c("response", "prob"),
        # Register the parameter set built above. An empty ParamSet would
        # silently discard every hyperparameter and the initial values.
        param_set = ps,
        properties = c("missings", "twoclass", "multiclass", "multioutput", "multilabel"),
        packages = c("partykit", "sandwich", "coin")
      )
    }
  ),

  private = list(
    # Fits partykit::cforest() on the task. Hyperparameters tagged "train" are
    # split into three groups:
    # - `perturb`: replace, fraction;
    # - ctree_control() arguments, wrapped into `control`;
    # - everything else, passed to cforest() directly.
    .train = function(task) {
      pars = self$param_set$get_values(tags = "train")

      # cforest() expects `replace` and `fraction` bundled in the named list
      # `perturb`: start from these defaults and overwrite with user values.
      perturb = list(replace = FALSE, fraction = 0.632)
      perturb_names = intersect(names(perturb), names(pars))
      perturb[perturb_names] = pars[perturb_names]

      # ctree_control() arguments form the `control` object. mtry, applyfun
      # and cores are excluded because they are passed to cforest() directly.
      control_names = setdiff(methods::formalArgs(partykit::ctree_control),
        c("mtry", "applyfun", "cores"))
      pars_control = pars[names(pars) %in% control_names]
      control = mlr3misc::invoke(partykit::ctree_control, .args = pars_control)

      # Remaining values are forwarded to cforest() itself.
      pars = pars[names(pars) %nin% c(names(perturb), names(pars_control))]

      mlr3misc::invoke(partykit::cforest,
        formula = task$formula(),
        data = task$data(),
        weights = task$weights$weight, # weights are handled here
        # FIXME: strata handling
        control = control,
        perturb = perturb,
        .args = pars
      )
    },

    # Predicts all targets at once with the fitted forest, then builds one
    # prediction-data list per target and combines them into a
    # PredictionMultioutput.
    .predict = function(task) {
      pars = self$param_set$get_values(tags = "predict")
      newdata = task$data(cols = task$feature_names)
      preds = mlr3misc::invoke(predict, object = self$model, newdata = newdata,
        type = self$predict_type, .args = pars)

      # Query the ground truth once instead of once per target.
      truth = task$truth()
      p = map(task$target_names, function(t) {
        if (self$predict_type == "response") {
          list(
            row_ids = task$row_ids,
            truth = truth[[t]],
            response = preds[[t]]
          )
        } else {
          list(
            row_ids = task$row_ids,
            truth = truth[[t]],
            # FIXME: format of per-target `prob` predictions not yet verified
            prob = preds[[t]]
          )
        }
      })

      names(p) = task$target_names
      preds = map(as.PredictionDataMultioutput(p, task$task_types), as_prediction)
      PredictionMultioutput$new(task, task$row_ids, preds)
    }
  )
)
# mlr-org/mlr3multioutput documentation built on Nov. 22, 2020, 1:17 p.m.