# S3 `multi_predict()` method for fits of class `_H2OMultinomialModel`.
#
# The work is delegated to the helper matching `object$fit@algorithm`:
#   * "gbm"          -> gbm_multi_predict(), using the `trees` entry of `...`
#   * "deeplearning" -> mlp_multi_predict(), using the `epochs` entry of `...`
# Any other algorithm matches no branch, so `NULL` is returned.
#
# Aborts when an argument named `newdata` is supplied through `...`.
#' @export
multi_predict._H2OMultinomialModel <-
  function(object, new_data, type = "class", ...) {
    # Collect the extra arguments first; the helpers read `trees` / `epochs`
    # from this list.
    dots <- list(...)

    # Catch the common `newdata` / `new_data` mix-up.
    dot_names <- names(rlang::enquos(...))
    if ("newdata" %in% dot_names) {
      rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
    }

    algorithm <- object$fit@algorithm
    out <- switch(algorithm,
      gbm = gbm_multi_predict(object, new_data, type, dots$trees),
      deeplearning = mlp_multi_predict(object, new_data, type, dots$epochs)
    )
    out
  }
# S3 `multi_predict()` method for fits of class `_H2ORegressionModel`.
#
# The work is delegated to the helper matching `object$fit@algorithm`:
#   * "gbm"          -> gbm_multi_predict(), using the `trees` entry of `...`
#   * "deeplearning" -> mlp_multi_predict(), using the `epochs` entry of `...`
# Any other algorithm matches no branch, so `NULL` is returned.
#
# Aborts when an argument named `newdata` is supplied through `...`.
#' @export
multi_predict._H2ORegressionModel <-
  function(object, new_data, type = "numeric", ...) {
    # Collect the extra arguments first; the helpers read `trees` / `epochs`
    # from this list.
    dots <- list(...)

    # Catch the common `newdata` / `new_data` mix-up.
    dot_names <- names(rlang::enquos(...))
    if ("newdata" %in% dot_names) {
      rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
    }

    algorithm <- object$fit@algorithm
    out <- switch(algorithm,
      gbm = gbm_multi_predict(object, new_data, type, dots$trees),
      deeplearning = mlp_multi_predict(object, new_data, type, dots$epochs)
    )
    out
  }
# Staged predictions for an H2O GBM fit at several tree counts.
#
# Arguments:
#   object    Model fit; `object$fit` is handed to
#             `h2o::staged_predict_proba.H2OModel()` and `object$lvl` supplies
#             the class labels used by the "class" and "prob" types.
#   new_data  Data to predict on; converted with `h2o::as.h2o()`.
#   type      One of "class", "prob" or "numeric".
#   trees     Tree counts at which predictions are wanted (sorted before use).
#
# Returns a tibble with a single list-column `.pred`. The list has one element
# per prediction row; each element holds a `trees` column plus the prediction
# column(s): `.pred_class`, `.pred_<level>` or `.pred`.
gbm_multi_predict <- function(object, new_data, type, trees) {
  # Fail early with a clear message instead of a downstream
  # "object not found" error.
  if (is.null(trees)) {
    rlang::abort("`trees` must be supplied to compute staged predictions.")
  }
  if (!isTRUE(type %in% c("class", "prob", "numeric"))) {
    rlang::abort("`type` should be one of \"class\", \"prob\" or \"numeric\".")
  }

  trees <- sort(trees)

  # NOTE(review): the staged-probability endpoint is used for every `type`,
  # including "numeric" -- confirm it returns usable staged output for
  # regression fits.
  preds <- h2o::staged_predict_proba.H2OModel(
    object = object$fit,
    newdata = h2o::as.h2o(new_data)
  )
  preds <- as.data.frame(preds)

  # Append a 1..n row index as the last column. seq_len() keeps this safe for
  # zero-row input, where `1:max(...)` would fail.
  with_row_id <- function(df) {
    df$.row <- seq_len(nrow(df))
    df
  }

  if (type %in% c("class", "prob")) {
    # Staged columns are looked up as "T<tree>.C<class index>".
    class_cols <- paste0("C", seq_len(length(object$lvl)))
    res <- purrr::map(trees, function(tree) {
      stage_cols <- paste(paste0("T", tree), class_cols, sep = ".")
      x <- preds[, stage_cols, drop = FALSE]
      colnames(x) <- object$lvl
      x <- as.data.frame(x)
      # Hard class = level with the highest staged probability in each row.
      x$predict <- object$lvl[apply(x, 1, which.max)]
      x$trees <- tree
      x
    })
  } else {
    # NOTE(review): staged output is indexed by column position `tree` here;
    # verify this matches the column layout returned for regression models.
    res <- purrr::map(trees, function(tree) {
      tibble(trees = tree, .pred = preds[, tree])
    })
  }

  # prepare predictions in parsnip format
  if (type == "class") {
    res <- purrr::map(res, function(x) {
      out <- dplyr::select(x, trees, predict)
      out <- dplyr::rename(out, .pred_class = predict)
      out$.pred_class <- factor(out$.pred_class, levels = object$lvl)
      with_row_id(out)
    })
  } else if (type == "numeric") {
    res <- purrr::map(res, with_row_id)
  } else {
    new_names <- object$lvl
    names(new_names) <- paste(".pred", object$lvl, sep = "_")
    # The row index has to be added to each per-tree data frame; `res` is a
    # list, so it cannot be passed to dplyr::mutate() as a whole.
    res <- purrr::map(res, function(x) {
      out <- dplyr::select(x, trees, !!!object$lvl)
      out <- dplyr::rename(out, !!!new_names)
      with_row_id(out)
    })
  }

  res <- dplyr::bind_rows(res)
  res <- dplyr::arrange(res, .row, trees)
  # Drop the helper index by name rather than by position before splitting.
  keep <- setdiff(names(res), ".row")
  res <- split(res[, keep, drop = FALSE], res$.row)
  tibble(.pred = res)
}
# Predictions from an H2O deep learning fit at several epoch counts.
#
# For every requested epoch count, `h2o::h2o.deeplearning()` is called with the
# fitted model's id as `checkpoint` and the stored x / y / training frame /
# hidden / l2 / dropout / activation parameters, and the resulting model is
# used to predict on `new_data`.
#
# Arguments:
#   object    Model fit; `object$fit` is the H2O model and `object$lvl`
#             supplies the class labels used by the "class" and "prob" types.
#   new_data  Data to predict on; converted with `h2o::as.h2o()`.
#   type      One of "class", "prob" or "numeric".
#   epochs    Epoch counts at which predictions are wanted (sorted before use).
#
# Returns a tibble with a single list-column `.pred`. The list has one element
# per prediction row; each element holds an `epochs` column plus the
# prediction column(s): `.pred_class`, `.pred_<level>` or `.pred`.
mlp_multi_predict <- function(object, new_data, type, epochs) {
  # Validate before any model is retrained, so bad input fails fast and with
  # a clear message.
  if (is.null(epochs)) {
    rlang::abort("`epochs` must be supplied to compute staged predictions.")
  }
  if (!isTRUE(type %in% c("class", "prob", "numeric"))) {
    rlang::abort("`type` should be one of \"class\", \"prob\" or \"numeric\".")
  }

  epochs <- sort(epochs)
  new_data <- h2o::as.h2o(new_data)
  fit_params <- object$fit@parameters

  preds <- purrr::map(epochs, function(epoch) {
    model <- h2o::h2o.deeplearning(
      x = fit_params$x,
      y = fit_params$y,
      training_frame = fit_params$training_frame,
      checkpoint = object$fit@model_id,
      epochs = epoch,
      hidden = fit_params$hidden,
      l2 = fit_params$l2,
      hidden_dropout_ratios = fit_params$hidden_dropout_ratios,
      activation = fit_params$activation
    )
    as.data.frame(predict(model, new_data))
  })

  # Append a 1..n row index as the last column. seq_len() keeps this safe for
  # zero-row input, where `1:max(...)` would fail.
  with_row_id <- function(df) {
    df$.row <- seq_len(nrow(df))
    df
  }

  # prepare predictions in parsnip format
  if (type == "class") {
    res <- purrr::map2(preds, epochs, function(pred, epoch) {
      out <- dplyr::select(pred, predict)
      out <- dplyr::rename(out, .pred_class = predict)
      out <- dplyr::mutate(
        out,
        epochs = epoch,
        .pred_class = factor(.pred_class, levels = object$lvl)
      )
      with_row_id(out)
    })
  } else if (type == "prob") {
    new_names <- object$lvl
    names(new_names) <- paste(".pred", object$lvl, sep = "_")
    res <- purrr::map2(preds, epochs, function(pred, epoch) {
      out <- dplyr::select(pred, !!!object$lvl)
      out <- dplyr::rename(out, !!!new_names)
      out <- dplyr::mutate(out, epochs = epoch)
      with_row_id(out)
    })
  } else {
    res <- purrr::map2(preds, epochs, function(pred, epoch) {
      with_row_id(tibble(epochs = epoch, .pred = pred$predict))
    })
  }

  res <- dplyr::bind_rows(res)
  res <- dplyr::arrange(res, .row, epochs)
  # Drop the helper index by name rather than by position before splitting.
  keep <- setdiff(names(res), ".row")
  res <- split(res[, keep, drop = FALSE], res$.row)
  tibble(.pred = res)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.