R/get_task.R

Defines functions get_task.xgb.Booster get_task.train get_task.svm get_task.rpart get_task.regbagg get_task.ranger get_task.randomForest get_task.RandomForest get_task.qda get_task.ppr get_task.party get_task.nnet get_task.nls get_task.naiveBayes get_task.multinom get_task.mars get_task.lm get_task.lda get_task.ksvm get_task.glm get_task.gbm get_task.fda get_task.earth get_task.cubist get_task.classbagg get_task.cforest get_task.C5.0 get_task.boosting get_task.bagging get_task.BinaryTree get_task.default get_task

#' @keywords internal
get_task <- function(object) {
  UseMethod("get_task")
}


#' @keywords internal
get_task.default <- function(object) {
  warning('`type` could not be determined; assuming `type = "regression"`')
  "regression"
}


#' @keywords internal
get_task.BinaryTree <- function(object) {
  if (object@responses@is_nominal) {
    "classification"
  } else if (object@responses@is_ordinal || object@responses@is_censored) {
    "other"
  } else {
    "regression"
  }
}


#' @keywords internal
get_task.bagging <- function(object) {
  "classification"
}


#' @keywords internal
get_task.boosting <- function(object) {
  "classification"
}


#' @keywords internal
get_task.C5.0 <- function(object) {
  "classification"
}


get_task.cforest <- function(object) {
  if (is.factor(object$fitted[["(response)"]])) {
    "classification"
  } else if (is.numeric(object$fitted[["(response)"]])) {
    "regression"
  } else {
    "other"
  }
}


#' @keywords internal
get_task.classbagg <- function(object) {
  "classification"
}


#' @keywords internal
get_task.cubist <- function(object) {
  "regression"
}


#' @keywords internal
get_task.earth <- function(object) {
  if (!is.null(object$glm.list) &&
      object$glm.list[[1L]]$family$family == "binomial") {
    "classification"
  } else if (is.null(object$glm.list) ||
             object$glm.list[[1L]]$family$family %in%
             c("gaussian", "Gamma", "inverse.gaussian", "poisson")) {
    "regression"
  } else {
    "other"
  }
}


#' @keywords internal
get_task.fda <- function(object) {
  "classification"
}


#' @keywords internal
get_task.gbm <- function(object) {
  if (object$distribution %in%
      c("coxph", "gaussian", "laplace", "tdist", "gamma", "poisson", "tweedie")) {
    "regression"
  } else if (object$distribution %in%
             c("bernoulli", "huberized", "multinomial", "adaboost")) {
    "classification"
  } else {
    "other"
  }
}


#' @keywords internal
get_task.glm <- function(object) {
  if(object$family$family == "binomial") {
    "classification"
  } else if (object$family$family %in%
             c("gaussian", "Gamma", "inverse.gaussian", "poisson")) {
    "regression"
  } else {
    "other"
  }
}


#' @keywords internal
get_task.ksvm <- function(object) {
  if (grepl("svr$", object@type)) {
    "regression"
  } else if (grepl("svc$", object@type)) {
    "classification"
  } else {
    "other"
  }
}


#' @keywords internal
get_task.lda <- function(object) {
  "classification"
}


#' @keywords internal
get_task.lm <- function(object) {
  # FIXME: What about multivariate response models?
  "regression"
}


#' @keywords internal
get_task.mars <- function(object) {
  if (ncol(object$fitted.values) > 1) {
    stop("`partial` does not currently support multivariate response models.")
  }
  "regression"
}


#' @keywords internal
get_task.multinom <- function(object) {
  # FIXME: What about multivariate response models?
  "classification"
}


#' @keywords internal
get_task.naiveBayes <- function(object) {
  "classification"
}


#' @keywords internal
get_task.nls <- function(object) {
  "regression"
}


#' @keywords internal
get_task.nnet <- function(object) {
  if (is.null(object$lev)) {
    "regression"
  } else {
    "classification"
  }
}


get_task.party <- function(object) {
  if (is.factor(object$fitted[["(response)"]])) {
    "classification"
  } else if (is.numeric(object$fitted[["(response)"]])) {
    "regression"
  } else {
    "other"
  }
}


#' @keywords internal
get_task.ppr <- function(object) {
  if (object$q > 1) {
    stop("`partial` does not currently support multivariate response models.")
  }
  "regression"
}


#' @keywords internal
get_task.qda<- function(object) {
  "classification"
}


#' @keywords internal
get_task.RandomForest <- function(object) {
  if (object@responses@is_nominal) {
    "classification"
  } else if (object@responses@is_ordinal || object@responses@is_censored) {
    "other"
  } else {
    "regression"
  }
}


#' @keywords internal
get_task.randomForest <- function(object) {
  if (object$type == "regression") {
    "regression"
  } else if (object$type == "classification") {
    "classification"
  } else {
    "unsupervised"
  }
}


#' @keywords internal
get_task.ranger <- function(object) {
  if (object$treetype == "Regression") {
    "regression"
  } else if (object$treetype == "Probability estimation") {
    "classification"
  } else if (object$treetype == "Classification") {
    stop("Partial dependence for classification tasks with \"ranger\" objects ",
         "requires a probability forest. Try refitting the model with ",
         "`probability = TRUE`; see `?ranger::ranger` for details.",
         call. = FALSE)
  } else {
    "other"
  }
}


#' @keywords internal
get_task.regbagg <- function(object) {
  "regression"
}


#' @keywords internal
get_task.rpart <- function(object) {
  if (object$method == "anova") {
    "regression"
  } else if (object$method == "class") {
    "classification"
  } else {
    "other"
  }
}


#' @keywords internal
get_task.svm <- function(object) {
  if (object$type %in% c(3, 4)) {
    "regression"
  } else {
    if (is.null(object$call$probability)) {
      stop("Partial dependence for classification tasks with \"svm\" objects ",
           "requires estimating predicted class probabilities. Try refitting ",
           "the model with `probability = TRUE`; see `?e1071::svm` for ",
           "details.", call. = FALSE)
    }
    "classification"
  }
}


#' @keywords internal
get_task.train <- function(object) {
  if (object$modelType == "Classification") {
    "classification"
  } else if (object$modelType == "Regression") {
    "regression"
  } else {
    "other"
  }
}


#' @keywords internal
get_task.xgb.Booster <- function(object) {
  # FIXME: "reg:linear" was changed to "reg:squarederror" in v0.90.0, but the
  # following should suffice without having to check package version.
  if (object$params$objective %in%
      c("reg:gamma", "reg:linear", "reg:logistic", "reg:squarederror",
        "reg:squaredlogerror", "count:poisson")) {
    "regression"
  } else if (object$params$objective %in%
             c("binary:logistic", "multi:softprob")) {
    "classification"
  } else if (object$params$objective %in%
             c("binary:logitraw", "multi:softmax")) {
    stop(paste("For classification, switch to an objective function",
               "that returns the predicted probabilities."))
  } else {
    "other"
  }
}

Try the pdp package in your browser

Any scripts or data that you put into this service are public.

pdp documentation built on June 8, 2022, 1:07 a.m.