R/ClassificationViaRegressionWrapper.R

Defines functions isFailureModel.ClassificationViaRegressionModel setPredictType.ClassificationViaRegressionWrapper getLearnerProperties.ClassificationViaRegressionWrapper predictLearner.ClassificationViaRegressionWrapper trainLearner.ClassificationViaRegressionWrapper makeClassificationViaRegressionWrapper

Documented in makeClassificationViaRegressionWrapper

#' @title Classification via regression wrapper.
#'
#' @description
#' Builds a regression model that predicts, for the positive class, whether a particular example belongs to it (1) or not (-1).
#'
#' Probabilities are generated by transforming the predictions with a softmax.
#'
#' Inspired by WEKA's ClassificationViaRegression (http://weka.sourceforge.net/doc.dev/weka/classifiers/meta/ClassificationViaRegression.html).
#'
#' @template arg_learner
#' @param predict.type (`character(1)`)\cr
#'   \dQuote{response} (= labels) or \dQuote{prob} (= probabilities and labels by selecting the one with maximal probability).
#' @template ret_learner
#' @export
#' @family wrapper
#' @examples
#' lrn = makeLearner("regr.rpart")
#' lrn = makeClassificationViaRegressionWrapper(lrn)
#' mod = train(lrn, sonar.task, subset = 1:140)
#' predictions = predict(mod, newdata = getTaskData(sonar.task)[141:208, 1:60])
makeClassificationViaRegressionWrapper = function(learner, predict.type = "response") {
  # Validate up front: predictLearner() treats every value other than
  # "response" as a request for probabilities, so an invalid value would
  # otherwise go unnoticed until prediction time.
  assertChoice(predict.type, c("response", "prob"))
  learner = checkLearner(learner, "regr")
  lrn = makeBaseWrapper(
    id = stri_paste(learner$id, "classify", sep = "."),
    type = "classif",
    next.learner = learner,
    package = "mlr",
    par.set = makeParamSet(),
    par.vals = list(),
    learner.subclass = "ClassificationViaRegressionWrapper",
    model.subclass = "ClassificationViaRegressionModel"
  )
  # Stored on the wrapper only; the wrapped regression learner is left as is.
  lrn$predict.type = predict.type
  lrn
}

# Trains the wrapped regression learner on a +1/-1 encoding of the binary
# target: rows of the positive class get 1, all other rows get -1. The new
# numeric target column is named "<positive>.prob".
#' @export
trainLearner.ClassificationViaRegressionWrapper = function(.learner, .task, .subset = NULL, .weights = NULL, ...) {
  pos = getTaskDesc(.task)$positive
  td = getTaskData(.task, target.extra = TRUE, subset = .subset)
  target.name = stri_paste(pos, "prob", sep = ".")
  data = td$data
  data[[target.name]] = ifelse(td$target == pos, 1, -1)

  # Task weights and blocking are read from the whole task (presumably one
  # entry per task row), while `data` only holds the .subset rows. Restrict
  # them to .subset so their lengths match the data handed to makeRegrTask().
  # Indexing NULL yields NULL, so tasks without weights are unaffected.
  task.weights = getTaskWeights(.task)
  blocking = .task$blocking
  if (!is.null(.subset)) {
    task.weights = task.weights[.subset]
    if (!is.null(blocking)) {
      # Drop levels that no longer occur so no empty blocks remain.
      blocking = droplevels(blocking[.subset])
    }
  }

  regr.task = makeRegrTask(
    id = stri_paste(getTaskId(.task), pos, sep = "."),
    data = data,
    target = target.name,
    weights = task.weights,
    blocking = blocking)
  # Explicit .weights are forwarded to train() unchanged.
  model = train(.learner$next.learner, regr.task, weights = .weights)
  makeChainModel(next.model = model, cl = "ClassificationViaRegressionModel")
}

# Predicts with the wrapped regression model and turns its numeric score p
# (trained against +1 for the positive class and -1 otherwise) into class
# labels or a two-column probability matrix named by class.
#' @export
predictLearner.ClassificationViaRegressionWrapper = function(.learner, .model, .newdata, .subset = NULL, ...) {
  model = getLearnerModel(.model, more.unwrap = FALSE)
  p = predict(model, newdata = .newdata, subset = .subset, ...)$data$response
  td = getTaskDesc(.model)
  levs = c(td$positive, td$negative)

  if (.learner$predict.type == "response") {
    # Positive scores map to the positive class. Fixing the levels keeps both
    # classes present even when only one of them is predicted.
    factor(ifelse(p > 0, td$positive, td$negative), levels = levs)
  } else {
    # Softmax over the two class scores, p for the positive class and -p for
    # the negative class (mirroring the +1/-1 training encoding):
    #   exp(p) / (exp(p) + exp(-p)) = 1 / (1 + exp(-2 * p)).
    # Applying exp(x) / sum(exp(x)) to one scalar score would always give 1.
    prob.pos = 1 / (1 + exp(-2 * p))
    matrix(c(prob.pos, 1 - prob.pos), ncol = 2L, dimnames = list(NULL, levs))
  }
}

# Properties of the wrapper: everything the wrapped regression learner offers
# plus "twoclass" and "prob", restricted to properties valid for classifiers.
#' @export
getLearnerProperties.ClassificationViaRegressionWrapper = function(learner) {
  inherited = getLearnerProperties(learner$next.learner)
  candidates = union(inherited, c("twoclass", "prob"))
  allowed = mlr$learner.properties$classif
  intersect(candidates, allowed)
}

# Sets the predict type on the wrapper itself; learner$next.learner is left
# untouched, so the wrapped regression learner is not reconfigured.
#' @export
setPredictType.ClassificationViaRegressionWrapper = function(learner, predict.type) {
  assertChoice(predict.type, c("response", "prob"))
  learner$predict.type = predict.type
  # Return the modified learner explicitly: ending on the assignment would
  # make the function return `predict.type` (a string) instead.
  learner
}

# A classification-via-regression model counts as failed exactly when the
# wrapped regression model it contains has failed.
#' @export
isFailureModel.ClassificationViaRegressionModel = function(model) {
  regr.model = getLearnerModel(model, more.unwrap = FALSE)
  isFailureModel(regr.model)
}
mlr-org/mlr documentation built on Jan. 12, 2023, 5:16 a.m.