#' Spark ML -- Transform, fit, and predict methods (ml_ interface)
#'
#' Methods for transformation, fit, and prediction. These are mirrors of the corresponding \link{sdf-transform-methods}.
#'
#' @param x A \code{ml_estimator}, \code{ml_transformer} (or a list thereof), or \code{ml_model} object.
#' @param dataset A \code{tbl_spark}.
#' @template roxlate-ml-dots
#'
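#' @details \code{ml_fit()}, \code{ml_transform()}, and \code{ml_predict()} are S3 generics that dispatch on the class of \code{x}. \code{ml_fit_and_transform()} fits an estimator and then applies the fitted transformer to the same \code{dataset}. \code{is_ml_transformer()} and \code{is_ml_estimator()} test whether \code{x} inherits from the corresponding class.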
#'
#' @return When \code{x} is an estimator, \code{ml_fit()} returns a transformer whereas \code{ml_fit_and_transform()} returns a transformed dataset. When \code{x} is a transformer, \code{ml_transform()} and \code{ml_predict()} return a transformed dataset. When \code{ml_predict()} is called on a \code{ml_model} object, additional columns (e.g. probabilities in case of classification models) are appended to the transformed output for the user's convenience.
#'
#' @name ml-transform-methods
NULL
#' @rdname ml-transform-methods
#' @export
is_ml_transformer <- function(x) {
  # TRUE when "ml_transformer" appears anywhere in the S3 class vector of `x`.
  inherits(x, what = "ml_transformer")
}
#' @rdname ml-transform-methods
#' @export
is_ml_estimator <- function(x) {
  # TRUE when "ml_estimator" appears anywhere in the S3 class vector of `x`.
  inherits(x, what = "ml_estimator")
}
#' @rdname ml-transform-methods
#' @export
ml_fit <- function(x, dataset, ...) {
# S3 generic: dispatches on the class of `x`. The method visible in this
# section is `ml_fit.default()`, which rejects anything that is not an
# 'ml_estimator'.
UseMethod("ml_fit")
}
#' @rdname ml-transform-methods
#' @export
ml_fit.default <- function(x, dataset, ...) {
  # Only estimators can be fitted; everything else is rejected up front.
  if (!is_ml_estimator(x)) {
    stop("'ml_fit()' is only applicable to 'ml_estimator' objects")
  }

  # Call the estimator's "fit" method on the Spark DataFrame behind `dataset`,
  # then hand the resulting object to ml_call_constructor() to wrap it for R.
  estimator_jobj <- spark_jobj(x)
  fitted_jobj <- invoke(estimator_jobj, "fit", spark_dataframe(dataset))
  ml_call_constructor(fitted_jobj)
}
#' @rdname ml-transform-methods
#' @export
ml_transform <- function(x, dataset, ...) {
# S3 generic: dispatches on the class of `x`. Methods visible in this section
# cover a single 'ml_transformer', a list of them, and a default that errors.
UseMethod("ml_transform")
}
#' @export
ml_transform.default <- function(x, dataset, ...) {
  # Fallback method: reached when `x` matched no more specific method, so it
  # always signals an error.
  msg <- "Transformers must be 'ml_transformer' objects."
  stop(msg)
}
#' @export
ml_transform.list <- function(x, dataset, ...) {
  # Every element must be an 'ml_transformer'. vapply() guarantees a logical
  # vector whatever the length or contents of `x`, unlike sapply(), whose
  # return type depends on its input.
  are_transformers <- vapply(x, is_ml_transformer, logical(1))
  if (!all(are_transformers)) {
    stop("Transformers must be 'ml_transformer' objects.")
  }

  sdf <- spark_dataframe(dataset)
  transformer_jobjs <- lapply(x, spark_jobj)

  # Apply the transformers in list order, feeding each one the output of the
  # previous one. An empty list leaves `sdf` unchanged.
  result_sdf <- Reduce(
    function(current_sdf, transformer_jobj) {
      invoke(transformer_jobj, "transform", current_sdf)
    },
    transformer_jobjs,
    init = sdf
  )

  sdf_register(result_sdf)
}
#' @export
ml_transform.ml_transformer <- function(x, dataset, ...) {
  # Resolve the Spark DataFrame first, then call the transformer's "transform"
  # method on it and register the result.
  input_sdf <- spark_dataframe(dataset)
  transformer_jobj <- spark_jobj(x)
  sdf_register(invoke(transformer_jobj, "transform", input_sdf))
}
#' @rdname ml-transform-methods
#' @export
ml_fit_and_transform <- function(x, dataset, ...) {
  # Only estimators can be fitted; everything else is rejected up front.
  applicable <- is_ml_estimator(x)
  if (!applicable) {
    stop("'ml_fit_and_transform()' is only applicable to 'ml_estimator' objects")
  }

  # The same Spark DataFrame is used for fitting and for the transform that
  # follows.
  input_sdf <- spark_dataframe(dataset)
  fitted_jobj <- invoke(spark_jobj(x), "fit", input_sdf)
  output_sdf <- invoke(fitted_jobj, "transform", input_sdf)
  sdf_register(output_sdf)
}
#' @rdname ml-transform-methods
#' @export
ml_predict <- function(x, dataset, ...) {
# S3 generic: dispatches on the class of `x`. Methods visible in this section
# cover regression, classification, clustering and recommendation 'ml_model'
# objects, plus a default that delegates to `ml_transform()`.
UseMethod("ml_predict")
}
#' @export
ml_predict.default <- function(x, dataset, ...) {
  # With no model-specific method available, prediction is a plain transform.
  # Extra arguments in `...` are accepted but not passed on.
  ml_transform(x = x, dataset = dataset)
}
#' @export
ml_predict.ml_model_regression <- function(x, dataset, ...) {
  # With no dataset supplied (missing or NULL), fall back to the one stored on
  # the model object.
  if (missing(dataset) || rlang::is_null(dataset)) {
    dataset <- x$dataset
  }

  # Look up the names of the output columns; parameters that come back empty
  # (e.g. an unset variance column) are dropped.
  param_values <- ml_params(
    x$model,
    c("prediction_col", "variance_col"),
    allow_null = TRUE
  )
  output_cols <- unlist(Filter(length, param_values), use.names = FALSE)

  predictions <- ml_transform(x$pipeline_model, dataset)

  # Keep only the input columns plus the model's output columns.
  select(predictions, !!!rlang::syms(c(tbl_vars(dataset), output_cols)))
}
#' @rdname ml-transform-methods
#' @param probability_prefix String used to prepend the class probability output columns.
#' @export
ml_predict.ml_model_classification <- function(
    x, dataset,
    probability_prefix = "probability_", ...) {
  sc <- spark_connection(x$model)
  probability_prefix <- cast_string(probability_prefix)

  # With no dataset supplied (missing or NULL), fall back to the one stored on
  # the model object.
  if (missing(dataset) || rlang::is_null(dataset)) {
    dataset <- x$dataset
  }

  predictions <- ml_transform(x$pipeline_model, dataset)

  # Models without a probability column have nothing further to unpack.
  probability_col <- ml_param(x$model, "probability_col", allow_null = TRUE)
  if (rlang::is_null(probability_col)) {
    return(predictions)
  }

  # Use the stored labels when present; otherwise fall back to the zero-based
  # class indices 0 .. num_classes - 1.
  class_labels <- x$index_labels %||% (seq_len(x$model$num_classes) - 1L)
  class_suffixes <- spark_sanitize_names(class_labels, sc$config)

  # Output column names are `probability_prefix` followed by each sanitized
  # label.
  sdf_separate_column(
    predictions,
    probability_col,
    paste0(probability_prefix, class_suffixes)
  )
}
#' @export
ml_predict.ml_model_clustering <- function(x, dataset, ...) {
  # With no dataset supplied (missing or NULL), fall back to the one stored on
  # the model object.
  use_stored_dataset <- missing(dataset) || rlang::is_null(dataset)
  if (use_stored_dataset) {
    dataset <- x$dataset
  }

  ml_transform(x$pipeline_model, dataset)
}
#' @export
ml_predict.ml_model_recommendation <- function(x, dataset, ...) {
  # With no dataset supplied (missing or NULL), fall back to the one stored on
  # the model object.
  use_stored_dataset <- missing(dataset) || rlang::is_null(dataset)
  if (use_stored_dataset) {
    dataset <- x$dataset
  }

  ml_transform(x$pipeline_model, dataset)
}
#' Spark ML -- Transform, fit, and predict methods (sdf_ interface)
#'
#' Deprecated methods for transformation, fit, and prediction. These are mirrors of the corresponding \link{ml-transform-methods}.
#'
#' @param x A \code{tbl_spark}.
#' @param model A \code{ml_transformer} or a \code{ml_model} object.
#' @param transformer A \code{ml_transformer} object.
#' @param estimator A \code{ml_estimator} object.
#' @param ... Optional arguments passed to the corresponding \code{ml_} methods.
#'
#' @return \code{sdf_predict()}, \code{sdf_transform()}, and \code{sdf_fit_and_transform()} return a transformed dataframe whereas \code{sdf_fit()} returns a \code{ml_transformer}.
#'
#' @name sdf-transform-methods
NULL
#' @rdname sdf-transform-methods
#' @export
sdf_predict <- function(x, model, ...) {
# Emit a deprecation warning that points callers at `ml_predict()`, then
# dispatch as an S3 generic on the class of `x`.
.Deprecated("ml_predict")
UseMethod("sdf_predict")
}
#' @export
sdf_predict.default <- function(x, model, ...) {
# Argument order is swapped relative to `ml_predict()`: `model` becomes the
# dispatch argument, `sdf_register(x)` is passed as the dataset, and `...` is
# forwarded unchanged.
ml_predict(model, sdf_register(x), ...)
}
#' @rdname sdf-transform-methods
#' @export
sdf_transform <- function(x, transformer, ...) {
  # Deprecated wrapper: warn, then delegate to ml_transform() with the
  # arguments swapped into its (transformer, dataset) order.
  .Deprecated("ml_transform")
  # `...` is forwarded so that optional arguments reach the ml_ method, as the
  # parameter documentation states, instead of being silently dropped.
  ml_transform(transformer, sdf_register(x), ...)
}
#' @rdname sdf-transform-methods
#' @export
sdf_fit <- function(x, estimator, ...) {
  # Deprecated wrapper: warn, then delegate to ml_fit() with the arguments
  # swapped into its (estimator, dataset) order.
  .Deprecated("ml_fit")
  # `...` is forwarded so that optional arguments reach the ml_ method, as the
  # parameter documentation states, instead of being silently dropped.
  ml_fit(estimator, sdf_register(x), ...)
}
#' @rdname sdf-transform-methods
#' @export
sdf_fit_and_transform <- function(x, estimator, ...) {
  # Deprecated wrapper: warn, then delegate to ml_fit_and_transform() with the
  # arguments swapped into its (estimator, dataset) order.
  .Deprecated("ml_fit_and_transform")
  # `...` is forwarded so that optional arguments reach the ml_ function, as
  # the parameter documentation states, instead of being silently dropped.
  ml_fit_and_transform(estimator, sdf_register(x), ...)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.