R/predictLearner.R

Defines functions predictLearner predictLearner2 removeNALines insertLines checkPredictLearnerOutput

Documented in checkPredictLearnerOutput predictLearner

#' Predict new data with an R learner.
#'
#' Mainly for internal use: produces predictions for new data from a fitted model.
#' Anyone who adds another learner to this package has to implement a method for
#' this generic.
#'
#' A method implementation must satisfy the following:
#' \itemize{
#'   \item The observations in \code{.newdata} are predicted using the fitted
#'     model (\code{.model$learner.model}).
#'   \item Every parameter received through \code{...} is passed on to the underlying
#'     predict function.
#' }
#'
#' If the learner model stored in \code{.model} inherits from \code{NoFeaturesModel},
#' no method is dispatched and \code{predictNofeatures} is called instead. Otherwise
#' \code{.newdata} is asserted to be a data.frame with at least one row and one column
#' before dispatching.
#'
#' @param .learner [\code{\link{RLearner}}]\cr
#'   Wrapped learner.
#' @param .model [\code{\link{WrappedModel}}]\cr
#'   Model produced by training.
#' @param .newdata [\code{data.frame}]\cr
#'   New data to predict. Does not include target column.
#' @param ... [any]\cr
#'   Additional parameters, which need to be passed to the underlying predict function.
#' @return
#' \itemize{
#'   \item For classification: a factor with class labels for type \dQuote{response};
#'     or, if the learner supports it, a matrix of class probabilities for type
#'     \dQuote{prob}, whose columns must be named with the class labels.
#'   \item For regression: a numeric vector for type \dQuote{response}; or, if the
#'     learner supports it, a two-column matrix for type \dQuote{se}, where the first
#'     column holds the estimated response (mean value) and the second column the
#'     estimated standard errors.
#'   \item For survival: a numeric vector with some sort of orderable risk for type
#'     \dQuote{response}; or, if supported, a numeric vector with time dependent
#'     probabilities for type \dQuote{prob}.
#'   \item For clustering: an integer with cluster IDs for type \dQuote{response};
#'     or, if supported, a matrix of membership probabilities for type \dQuote{prob}.
#'   \item For multilabel: a logical matrix that indicates predicted class labels for
#'     type \dQuote{response}; or, if supported, a matrix of class probabilities for
#'     type \dQuote{prob}. The columns must be named with the class labels.
#'  }
#' @export
predictLearner = function(.learner, .model, .newdata, ...) {
  lmod = getLearnerModel(.model)
  # a NoFeaturesModel bypasses S3 dispatch and is handled by a dedicated helper
  if (inherits(lmod, "NoFeaturesModel"))
    return(predictNofeatures(.model, .newdata))
  # everything else needs non-empty feature data before the learner method runs
  assertDataFrame(.newdata, min.rows = 1L, min.cols = 1L)
  UseMethod("predictLearner")
}

# Wrapper around predictLearner: optionally re-levels factor features, drops rows with
# NAs for learners lacking the "missings" property, predicts, validates the prediction
# via checkPredictLearnerOutput and finally re-inserts NA entries for the dropped rows.
predictLearner2 = function(.learner, .model, .newdata, ...) {
  # with the option enabled, give factor columns the complete level sets stored in the model
  if (.learner$fix.factors.prediction) {
    stored.levels = .model$factor.levels
    # consider only those stored columns that actually occur in .newdata
    cols = intersect(colnames(.newdata), names(stored.levels))
    stored.levels = stored.levels[cols]
    if (length(cols) > 0L) {
      .newdata[cols] = Map(factor, x = .newdata[cols], levels = stored.levels)
    }
  }

  # fallback: predict on the data as is, nothing to re-insert afterwards
  untouched = list(newdata = .newdata, inserts = FALSE)

  if ("missings" %in% getLearnerProperties(.learner)) {
    na.split = untouched
  } else {
    na.split = removeNALines(.newdata)
  }
  # no choice if all lines contain NA: hand the complete data to the learner
  if (nrow(na.split$newdata) == 0L) {
    na.split = untouched
  }

  raw = predictLearner(.learner, .model, na.split$newdata, ...)
  checked = checkPredictLearnerOutput(.learner, .model, raw)
  insertLines(checked, na.split$inserts)
}

# Split off all rows of newdata that contain at least one NA.
# Returns a list with
#   newdata: the rows without any NA (dimensions are never dropped),
#   inserts: logical flag per original row, TRUE where the row was removed.
removeNALines = function(newdata) {
  has.na = apply(is.na(newdata), MARGIN = 1, FUN = any)
  complete = !has.na
  list(
    newdata = newdata[complete, , drop = FALSE],
    inserts = has.na
  )
}

# Expand a prediction so that it again has one entry (or matrix row) per original
# observation: positions flagged TRUE in inserts are filled with NA, all remaining
# positions receive the predicted values in order.
# A scalar inserts = FALSE adds nothing; the values are copied over, with names
# (vectors) or row names (matrices) removed.
insertLines = function(prediction, inserts) {
  keep = !inserts
  n.extra = sum(inserts)

  if (!is.matrix(prediction)) {
    out = rep(NA, length(prediction) + n.extra)
    out[keep] = prediction
    # carry over attributes such as factor levels and class, then discard names
    attributes(out) = attributes(prediction)
    names(out) = NULL
    return(out)
  }

  out = matrix(NA, nrow = nrow(prediction) + n.extra, ncol = ncol(prediction))
  out[keep, ] = prediction
  colnames(out) = colnames(prediction)
  out
}

#' @title Check output returned by predictLearner.
#'
#' @description
#' Check the output coming from a Learner's internal
#' \code{predictLearner} function.
#'
#' The checks depend on \code{learner$type} and \code{learner$predict.type}:
#' \itemize{
#'   \item classif, \dQuote{response}: \code{p} must be a factor. If its levels differ
#'     from the class levels of the task, it is re-created with exactly these levels.
#'   \item classif, \dQuote{prob}: \code{p} must be a matrix whose column names are,
#'     as a set, equal to the class levels.
#'   \item regr, \dQuote{response}: the first class of \code{p} must be \code{numeric}.
#'   \item regr, \dQuote{se}: \code{p} must be a matrix with 2 columns.
#'   \item surv: \dQuote{prob} is rejected with an error, otherwise \code{p} must be numeric.
#'   \item cluster, \dQuote{response}: the first class of \code{p} must be \code{integer}.
#'   \item cluster, \dQuote{prob}: \code{p} must be a matrix.
#'   \item multilabel, \dQuote{response}: \code{p} must be a logical matrix.
#'   \item multilabel, \dQuote{prob}: \code{p} must be a matrix of type double.
#' }
#' For every other combination \code{p} is returned without checks.
#'
#' This function is for internal use.
#'
#' @param learner [\code{\link{Learner}}]\cr
#'   The learner.
#' @param model [\code{\link{WrappedModel}}]\cr
#'   Model produced by training.
#' @param p [any]\cr
#'   The prediction made by \code{learner}.
#' @return [any]. A sanitized version of \code{p}.
#' @keywords internal
#' @export
checkPredictLearnerOutput = function(learner, model, p) {
  # first class of the prediction, used in error messages and for strict class checks
  cl = class(p)[1L]
  if (learner$type == "classif") {
    levs = model$task.desc$class.levels
    if (learner$predict.type == "response") {
      # the levels of the predicted classes might not be complete....
      # be sure to add the levels at the end, otherwise data gets changed!!!
      if (!is.factor(p)) {
        stopf("predictLearner for %s has returned a class %s instead of a factor!", learner$id, cl)
      }
      levs2 = levels(p)
      if (length(levs2) != length(levs) || any(levs != levs2)) {
        p = factor(p, levels = levs)
      }
    } else if (learner$predict.type == "prob") {
      if (!is.matrix(p)) {
        stopf("predictLearner for %s has returned a class %s instead of a matrix!", learner$id, cl)
      }
      cns = colnames(p)
      if (is.null(cns) || length(cns) == 0L) {
        stopf("predictLearner for %s has returned not the class levels as column names, but no column names at all!",
          learner$id)
      }
      # column order is not enforced, only the set of names has to match
      if (!setequal(cns, levs)) {
        stopf("predictLearner for %s has returned not the class levels as column names: %s",
          learner$id, collapse(colnames(p)))
      }
    }
  } else if (learner$type == "regr") {
    if (learner$predict.type == "response") {
      if (cl != "numeric") {
        stopf("predictLearner for %s has returned a class %s instead of a numeric!", learner$id, cl)
      }
    } else if (learner$predict.type == "se") {
      if (!is.matrix(p)) {
        stopf("predictLearner for %s has returned a class %s instead of a matrix!", learner$id, cl)
      }
      if (ncol(p) != 2L) {
        stopf("predictLearner for %s has not returned a numeric matrix with 2 columns!", learner$id)
      }
    }
  } else if (learner$type == "surv") {
    if (learner$predict.type == "prob") {
      stop("Survival does not support prediction of probabilities yet.")
    }
    if (!is.numeric(p)) {
      stopf("predictLearner for %s has returned a class %s instead of a numeric!", learner$id, cl)
    }
  } else if (learner$type == "cluster") {
    if (learner$predict.type == "response") {
      if (cl != "integer") {
        stopf("predictLearner for %s has returned a class %s instead of an integer!", learner$id, cl)
      }
    } else if (learner$predict.type == "prob") {
      if (!is.matrix(p)) {
        stopf("predictLearner for %s has returned a class %s instead of a matrix!", learner$id, cl)
      }
    }
  } else if (learner$type == "multilabel") {
    if (learner$predict.type == "response") {
      if (!(is.matrix(p) && typeof(p) == "logical")) {
        stopf("predictLearner for %s has returned a class %s instead of a logical matrix!", learner$id, cl)
      }
    } else if (learner$predict.type == "prob") {
      if (!(is.matrix(p) && typeof(p) == "double")) {
        stopf("predictLearner for %s has returned a class %s instead of a numerical matrix!", learner$id, cl)
      }
    }
  }
  return(p)
}
guillermozbta/s2 documentation built on May 17, 2019, 4:01 p.m.