
# Defines functions normalize_mlr_names normalize_h2o_names yhat.model_stack yhat.workflow yhat.xgb.Booster yhat.GraphLearner yhat.LearnerClassif yhat.LearnerRegr yhat.keras yhat.scikitlearn_model yhat.h2o yhat.WrappedModel
#
# Documented in yhat.GraphLearner yhat.keras yhat.LearnerClassif yhat.LearnerRegr yhat.model_stack yhat.scikitlearn_model yhat.workflow yhat.WrappedModel yhat.xgb.Booster

#' Wrapper over the predict function
#'
#' These functions are the default predict functions for models from the
#' packages listed below. For regression and binary classification each
#' function returns a single numeric score for every new observation.
#' They are important because information has to be extracted from many
#' different kinds of models, each requiring its own technique.
#' Currently supported packages are:
#' \itemize{
#' \item \code{mlr} see more in \code{\link{explain_mlr}}
#' \item \code{h2o} see more in \code{\link{explain_h2o}}
#' \item \code{scikit-learn} see more in \code{\link{explain_scikitlearn}}
#' \item \code{keras} see more in \code{\link{explain_keras}}
#' \item \code{mlr3} see more in \code{\link{explain_mlr3}}
#' \item \code{xgboost} see more in \code{\link{explain_xgboost}}
#' \item \code{tidymodels} see more in \code{\link{explain_tidymodels}}
#' }
#' @inheritParams DALEX::yhat
#' @return A numeric vector of predictions, or a matrix / data frame of
#'   class probabilities for multiclass models.

#' @rdname yhat
#' @export
yhat.WrappedModel <- function(X.model, newdata, ...) {
  # Predictions for an mlr WrappedModel, dispatched on the task type stored
  # in the model ("classif" or "regr"); other task types are rejected.
  switch(X.model$task.desc$type,
         "classif" = {
           pred <- predict(X.model, newdata = newdata)
           # Only probability predictions can be turned into numeric scores.
           # NOTE(review): the original body of this branch was lost in the
           # scraped source; an explicit error is raised here - confirm
           # against upstream DALEXtra.
           if (!identical(X.model$learner$predict.type, "prob")) {
             stop("Predict type of the mlr learner has to be \"prob\" for classification tasks")
           }
           target <- attr(X.model, "predict_function_target_column")
           if (!is.null(target)) {
             return(pred$data[, target])
           }
           # The code assumes pred$data holds the probability columns
           # followed by a trailing 'response' column, optionally preceded
           # by a 'truth' column when newdata contained the target.
           if ("truth" %in% colnames(pred$data)) {
             if (ncol(pred$data) == 4) {
               # Binary task: probability of the second class.
               response <- pred$data[, 3]
             } else {
               response <- pred$data[, -c(1, ncol(pred$data))]
               names(response) <- normalize_mlr_names(names(response))
             }
           } else {
             if (ncol(pred$data) == 3) {
               # Binary task: probability of the second class.
               response <- pred$data[, 2]
             } else {
               response <- pred$data[, -ncol(pred$data)]
               names(response) <- normalize_mlr_names(names(response))
             }
           }
           response
         },
         "regr" = {
           pred <- predict(X.model, newdata = newdata)
           pred$data$response
         },
         stop("Model is not explainable mlr object"))
}

yhat.h2o <- function(X.model, newdata, ...) {
  # Predictions for h2o regression, binomial and multinomial models,
  # selected by the model's class.

  # h2o.predict needs an H2OFrame; other inputs are converted with as.h2o.
  # Conversion happens inside the supported branches only, so as.h2o is
  # never called for an unsupported model.
  to_h2o_frame <- function(data) {
    if (inherits(data, "H2OFrame")) data else h2o::as.h2o(data)
  }
  target <- attr(X.model, "predict_function_target_column")

  switch(class(X.model)[1],
    "H2ORegressionModel" = {
      as.vector(h2o::h2o.predict(X.model, newdata = to_h2o_frame(newdata)))
    },
    "H2OBinomialModel" = {
      ret <- as.data.frame(h2o::h2o.predict(X.model, newdata = to_h2o_frame(newdata)))
      if (!is.null(target)) {
        return(ret[, target])
      }
      # With a leading 'predict' label column the second class probability
      # is column 3, otherwise column 2.
      if ("predict" %in% colnames(ret)) {
        ret[, 3]
      } else {
        ret[, 2]
      }
    },
    "H2OMultinomialModel" = {
      ret <- as.data.frame(h2o::h2o.predict(X.model, newdata = to_h2o_frame(newdata)))
      raw_names <- colnames(ret)
      colnames(ret) <- normalize_h2o_names(raw_names)
      if (!is.null(target)) {
        return(ret[, target])
      }
      # NOTE(review): the original return expression was lost in the scraped
      # source; the predicted-label column is dropped so only class
      # probabilities remain - confirm against upstream DALEXtra.
      ret[, raw_names != "predict", drop = FALSE]
    },
    stop("Model is not explainable h2o object")
  )
}

# DALEX::yhat dispatches on the model class, so each supported h2o model
# class gets its own exported method. All of them share yhat.h2o, which
# handles each of these classes in a separate branch.
#' @rdname yhat
#' @export
yhat.H2ORegressionModel <- yhat.h2o

#' @rdname yhat
#' @export
yhat.H2OBinomialModel <- yhat.h2o

#' @rdname yhat
#' @export
yhat.H2OMultinomialModel <- yhat.h2o

#' @rdname yhat
#' @export
yhat.scikitlearn_model <- function(X.model, newdata, ...) {
  # Predictions for a scikit-learn model accessed through reticulate.
  # Models exposing predict_proba are treated as classifiers; all others
  # return the output of predict unchanged.
  if ("predict_proba" %in% names(X.model)) {
    pred <- X.model$predict_proba(newdata)
    # Label columns with class indices 0..k-1 so a character
    # predict_function_target_column can select a class by label.
    colnames(pred) <- 0:(ncol(pred) - 1)

    target <- attr(X.model, "predict_function_target_column")
    if (!is.null(target)) {
      return(pred[, target])
    }
    # Binary classification: probability of the second class.
    if (ncol(pred) == 2) {
      pred <- pred[, 2]
    }
  } else {
    pred <- X.model$predict(newdata)
  }
  pred
}

#' @rdname yhat
#' @export
yhat.keras <- function(X.model, newdata, ...) {
  # Predictions for a keras model accessed through reticulate. Models
  # exposing predict_proba are treated as classifiers; all others return
  # the output of predict unchanged.
  if ("predict_proba" %in% names(X.model)) {
    pred <- X.model$predict_proba(newdata)
    # Label columns with class indices 0..k-1 so a character
    # predict_function_target_column can select a class by label.
    colnames(pred) <- 0:(ncol(pred) - 1)

    target <- attr(X.model, "predict_function_target_column")
    if (!is.null(target)) {
      return(pred[, target])
    }
    if (ncol(pred) == 1) {
      # Single output unit: already one score per observation.
      pred <- as.numeric(pred)
    } else if (ncol(pred) == 2) {
      # Two output units: probability of the second class.
      pred <- as.numeric(pred[, 2])
    }
  } else {
    pred <- X.model$predict(newdata)
  }
  pred
}

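# Disabled mljar support, kept for reference. Plain '#' comments are used so
# that roxygen does not parse this code as documentation.
# yhat.mljar_model <- function(X.model, newdata, ...) {
#   unlist(mljar::mljar_predict(model = X.model, x_pred = newdata, project_title = X.model$project), use.names = FALSE)
# }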

#' @rdname yhat
#' @export
yhat.LearnerRegr <- function(X.model, newdata, ...) {
  # mlr3 regression learner: return the numeric $response of the prediction
  # made on newdata.
  X.model$predict_newdata(newdata, ...)$response
}

#' @rdname yhat
#' @export
yhat.LearnerClassif <- function(X.model, newdata, ...) {
  # mlr3 classification learner: return class probabilities.
  pred <- X.model$predict_newdata(newdata)

  # Matrix of class probabilities, one column per class. It is NULL when
  # the prediction holds no probabilities (e.g. the learner's predict_type
  # is not "prob"), which would otherwise fail later on ncol(NULL).
  response <- pred$prob
  if (is.null(response)) {
    stop("Probabilities are not available; create the learner with predict_type = \"prob\"")
  }

  target <- attr(X.model, "predict_function_target_column")
  if (!is.null(target)) {
    return(response[, target])
  }
  # Binary task: probability of the second class.
  if (ncol(response) == 2) {
    response <- response[, 2]
  }
  response
}

#' @rdname yhat
#' @export
yhat.GraphLearner <- function(X.model, newdata, ...) {
  # mlr3pipelines GraphLearner: class probabilities when the learner
  # supports the "prob" predict type, otherwise the numeric $response.
  if ("prob" %in% X.model$predict_types) {
    pred <- X.model$predict_newdata(newdata)
    # Matrix of class probabilities, one column per class; NULL when the
    # prediction holds no probabilities.
    response <- pred$prob
    if (is.null(response)) {
      stop("Probabilities are not available; create the learner with predict_type = \"prob\"")
    }

    target <- attr(X.model, "predict_function_target_column")
    if (!is.null(target)) {
      return(response[, target])
    }
    # Binary task: probability of the second class.
    if (ncol(response) == 2) {
      response <- response[, 2]
    }
    response
  } else {
    X.model$predict_newdata(newdata, ...)$response
  }
}

#' @rdname yhat
#' @export
yhat.xgb.Booster <- function(X.model, newdata, ...) {
  # Predictions for an xgboost booster, shaped according to its objective.

  # An optional preprocessing function stored on the model is applied to
  # newdata before prediction.
  encoder <- attr(X.model, "encoder")
  if (!is.null(encoder)) {
    newdata <- encoder(newdata)
  }

  # params$objective may be NULL (no objective set explicitly); map it to ""
  # so that every comparison below works on a single string.
  objective <- X.model$params$objective
  if (is.null(objective)) {
    objective <- ""
  }
  target <- attr(X.model, "predict_function_target_column")
  true_labels <- attr(X.model, "true_labels")

  # Column labels for a k-class probability table: levels of the stored
  # true labels when available, otherwise class indices 0..k-1.
  class_names <- function(k) {
    if (!is.null(true_labels)) levels(as.factor(true_labels)) else 0:(k - 1)
  }

  if (identical(objective, "multi:softprob")) {
    num_class <- X.model$params$num_class
    p <- predict(X.model, newdata, type = "response")
    # Some xgboost versions return a flat row-major vector of class
    # probabilities, others an n x k matrix.
    ret <- if (is.matrix(p)) p else matrix(p, ncol = num_class, byrow = TRUE)
    # Name the columns BEFORE selecting the target so a class label can be
    # used as predict_function_target_column.
    colnames(ret) <- class_names(num_class)
    if (!is.null(target)) {
      return(ret[, target])
    }
    ret
  } else if (identical(objective, "multi:softmax")) {
    # softmax yields class labels, not probabilities.
    stop("Please use objective \"multi:softprob\" to get probability output")
  } else if (identical(objective, "binary:logistic")) {
    ret <- predict(X.model, newdata, type = "response")
    if (!is.null(target)) {
      # ret is the probability of the second class; expand it into a
      # two-column table so the target class is selected exactly as in the
      # multiclass case (indexing the plain vector by column always failed).
      probs <- cbind(1 - ret, ret)
      colnames(probs) <- class_names(2)
      return(probs[, target])
    }
    ret
  } else if (objective %in% c("binary:logitraw", "binary:hinge")) {
    stop("Please use objective \"binary:logistic\" to get probability output")
  } else {
    predict(X.model, newdata, type = "response")
  }
}

#' @rdname yhat
#' @export
yhat.workflow <- function(X.model, newdata, ...) {
  # Predictions for a tidymodels workflow, shaped by the mode of the
  # fitted model specification.
  if (inherits(newdata, "tbl")) {
    newdata <- as.data.frame(newdata)
  }
  model_mode <- X.model$fit$fit$spec$mode

  if (identical(model_mode, "classification")) {
    response <- as.matrix(predict(X.model, newdata, type = "prob"))
    # Label probability columns with the outcome levels of the fitted model.
    colnames(response) <- X.model$fit$fit$lvl

    target <- attr(X.model, "predict_function_target_column")
    if (!is.null(target)) {
      return(response[, target])
    }
    # Binary task: probability of the second level.
    if (ncol(response) == 2) {
      response <- response[, 2]
    }
    response
  } else if (identical(model_mode, "regression")) {
    predict(X.model, newdata)$.pred
  } else {
    # Also reached when the mode is missing (NULL).
    stop("Mode specification has to be either classification or regression")
  }
}



#' @rdname yhat
#' @export
yhat.model_stack <- function(X.model, newdata, ...) {
  # Predictions for a stacks model_stack, shaped by the stack's mode.
  if (inherits(newdata, "tbl")) {
    newdata <- as.data.frame(newdata)
  }
  stack_mode <- X.model$mode

  if (identical(stack_mode, "classification")) {
    response <- as.data.frame(predict(X.model, newdata, type = "prob"))
    # Probability columns are named ".pred_<class>"; keep the class part.
    colnames(response) <- vapply(colnames(response), function(x) {
      strsplit(x, ".pred_", fixed = TRUE)[[1]][2]
    }, FUN.VALUE = character(1))

    target <- attr(X.model, "predict_function_target_column")
    if (!is.null(target)) {
      return(response[, target])
    }
    # Binary task: probability of the second class.
    if (ncol(response) == 2) {
      response <- response[, 2]
    }
    response
  } else if (identical(stack_mode, "regression")) {
    predict(X.model, newdata)$.pred
  } else {
    # Also reached when the mode is missing (NULL).
    stop("Mode specification has to be either classification or regression")
  }
}

# Maps h2o prediction column names to class labels by removing the "p"
# prefix h2o puts on probability columns of numeric classes
# (e.g. "p0" -> "0").
#
# @param names Character vector of column names.
# @return Unnamed character vector of the same length.
normalize_h2o_names <- function(names) {
  # Only a leading "p" directly followed by a digit is removed. The previous
  # split on every "p" mangled labels such as "apple" (-> "") or "predict".
  unname(sub("^p(?=[0-9])", "", as.character(names), perl = TRUE))
}

# Strips the "prob." prefix that mlr puts on probability columns
# (e.g. "prob.setosa" -> "setosa"); other names are returned unchanged.
#
# @param names Character vector of column names.
# @return Unnamed character vector of the same length.
normalize_mlr_names <- function(names) {
  # Anchored, literal match: the previous unanchored regex split on "prob."
  # (where "." matches any character) broke labels such as "problem".
  unname(sub("^prob\\.", "", as.character(names)))
}

# Try the DALEXtra package in your browser
#
# Any scripts or data that you put into this service are public.
#
# DALEXtra documentation built on May 31, 2023, 5:30 p.m.