Nothing
#' Spark ML -- Gradient Boosted Trees
#'
#' Perform binary classification and regression using gradient boosted trees. Multiclass classification is not supported yet.
#'
#' @template roxlate-ml-algo
#' @template roxlate-ml-decision-trees-base-params
#' @template roxlate-ml-formula-params
#' @template roxlate-ml-feature-subset-strategy
#' @param max_iter Maximum number of iterations.
#' @param step_size Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator. (default = 0.1)
#' @param subsampling_rate Fraction of the training data used for learning each decision tree, in range (0, 1]. (default = 1.0)
#' @param loss_type Loss function which GBT tries to minimize. Supported: \code{"squared"} (L2) and \code{"absolute"} (L1) (default = squared) for regression and \code{"logistic"} (default) for classification. For \code{ml_gradient_boosted_trees}, setting \code{"auto"}
#' will default to the appropriate loss type based on model type.
#' @name ml_gradient_boosted_trees
NULL
#' @rdname ml_gradient_boosted_trees
#' @template roxlate-ml-decision-trees-type
#' @details \code{ml_gradient_boosted_trees} is a wrapper around \code{ml_gbt_regressor.tbl_spark} and \code{ml_gbt_classifier.tbl_spark} and calls the appropriate method based on model type.
#' @template roxlate-ml-old-feature-response
#'
#' @examples
#' \dontrun{
#' sc <- spark_connect(master = "local")
#' iris_tbl <- sdf_copy_to(sc, iris, name = "iris_tbl", overwrite = TRUE)
#'
#' partitions <- iris_tbl %>%
#' sdf_random_split(training = 0.7, test = 0.3, seed = 1111)
#'
#' iris_training <- partitions$training
#' iris_test <- partitions$test
#'
#' gbt_model <- iris_training %>%
#' ml_gradient_boosted_trees(Sepal_Length ~ Petal_Length + Petal_Width)
#'
#' pred <- ml_predict(gbt_model, iris_test)
#'
#' ml_regression_evaluator(pred, label_col = "Sepal_Length")
#' }
#'
#' @export
ml_gradient_boosted_trees <- function(x, formula = NULL,
type = c("auto", "regression", "classification"),
features_col = "features", label_col = "label",
prediction_col = "prediction", probability_col = "probability",
raw_prediction_col = "rawPrediction", checkpoint_interval = 10,
loss_type = c("auto", "logistic", "squared", "absolute"),
max_bins = 32, max_depth = 5, max_iter = 20L,
min_info_gain = 0, min_instances_per_node = 1,
step_size = 0.1, subsampling_rate = 1, feature_subset_strategy = "auto",
seed = NULL, thresholds = NULL, cache_node_ids = FALSE,
max_memory_in_mb = 256, uid = random_string("gradient_boosted_trees_"),
response = NULL, features = NULL, ...) {
formula <- ml_standardize_formula(formula, response, features)
response_col <- gsub("~.+$", "", formula) %>% trimws()
sdf <- spark_dataframe(x)
# choose classification vs. regression model based on column type
schema <- sdf_schema(sdf)
response_type <- schema[[response_col]]$type
type <- rlang::arg_match(type)
model_type <- if (!identical(type, "auto")) {
type
} else {
if (response_type %in% c("DoubleType", "IntegerType")) {
"regression"
} else {
"classification"
}
}
loss_type <- rlang::arg_match(loss_type)
loss_type <- if (identical(loss_type, "auto")) {
if (identical(model_type, "classification")) "logistic" else "squared"
} else if (identical(model_type, "regression")) {
if (!loss_type %in% c("squared", "absolute")) {
stop("`loss_type` must be \"squared\" or \"absolute\" for regression.")
}
loss_type
} else {
if (!identical(loss_type, "logistic")) {
stop("`loss_type` must be \"logistic\" for classification.")
}
loss_type
}
if (spark_version(spark_connection(x)) < "2.2.0" && !is.null(thresholds)) {
stop("`thresholds` is only supported for GBT in Spark 2.2.0+.")
}
switch(
model_type,
regression = ml_gbt_regressor(
x = x,
formula = formula,
max_iter = max_iter,
max_depth = max_depth,
step_size = step_size,
subsampling_rate = subsampling_rate,
feature_subset_strategy = feature_subset_strategy,
min_instances_per_node = min_instances_per_node,
max_bins = max_bins,
min_info_gain = min_info_gain,
loss_type = loss_type,
seed = seed,
checkpoint_interval = checkpoint_interval,
cache_node_ids = cache_node_ids,
max_memory_in_mb = max_memory_in_mb,
features_col = features_col,
label_col = label_col,
prediction_col = prediction_col,
uid = uid,
...
),
classification = ml_gbt_classifier(
x = x,
formula = formula,
max_iter = max_iter,
max_depth = max_depth,
step_size = step_size,
subsampling_rate = subsampling_rate,
feature_subset_strategy = feature_subset_strategy,
min_instances_per_node = min_instances_per_node,
max_bins = max_bins,
min_info_gain = min_info_gain,
loss_type = loss_type,
seed = seed,
thresholds = thresholds,
checkpoint_interval = checkpoint_interval,
cache_node_ids = cache_node_ids,
max_memory_in_mb = max_memory_in_mb,
features_col = features_col,
label_col = label_col,
prediction_col = prediction_col,
probability_col = probability_col,
raw_prediction_col = raw_prediction_col,
uid = uid,
...
)
)
}
new_ml_model_gbt_classification <- function(pipeline_model, formula, dataset, label_col,
features_col, predicted_label_col) {
new_ml_model_classification(
pipeline_model, formula,
dataset = dataset,
label_col = label_col, features_col = features_col,
predicted_label_col = predicted_label_col,
class = "ml_model_gbt_classification"
)
}
new_ml_model_gbt_regression <- function(pipeline_model, formula, dataset, label_col,
features_col) {
new_ml_model_regression(
pipeline_model, formula,
dataset = dataset,
label_col = label_col, features_col = features_col,
class = "ml_model_gbt_regression"
)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.