#' @include ml_clustering.R
NULL
#' @rdname ml_gradient_boosted_trees
#' @template roxlate-ml-probabilistic-classifier-params
#' @export
ml_gbt_classifier <- function(
    x,
    formula = NULL,
    max_iter = 20,
    max_depth = 5,
    step_size = 0.1,
    subsampling_rate = 1,
    feature_subset_strategy = "auto",
    min_instances_per_node = 1L,
    max_bins = 32,
    min_info_gain = 0,
    loss_type = "logistic",
    seed = NULL,
    thresholds = NULL,
    checkpoint_interval = 10,
    cache_node_ids = FALSE,
    max_memory_in_mb = 256,
    features_col = "features",
    label_col = "label",
    prediction_col = "prediction",
    probability_col = "probability",
    raw_prediction_col = "rawPrediction",
    uid = random_string("gbt_classifier_"),
    ...) {
  # S3 generic: the implementation is chosen from the class of `x`.
  # check_dots_used() is invoked before dispatch (presumably so that
  # arguments in `...` that no method consumes get reported -- confirm
  # against the helper's documentation).
  check_dots_used()
  UseMethod("ml_gbt_classifier")
}
#' @export
ml_gbt_classifier.spark_connection <- function(
    x, formula = NULL, max_iter = 20, max_depth = 5,
    step_size = 0.1, subsampling_rate = 1,
    feature_subset_strategy = "auto", min_instances_per_node = 1L,
    max_bins = 32, min_info_gain = 0, loss_type = "logistic",
    seed = NULL, thresholds = NULL, checkpoint_interval = 10,
    cache_node_ids = FALSE, max_memory_in_mb = 256,
    features_col = "features", label_col = "label",
    prediction_col = "prediction", probability_col = "probability",
    raw_prediction_col = "rawPrediction",
    uid = random_string("gbt_classifier_"), ...) {
  # Gather the estimator parameters, append whatever arrived through `...`,
  # and validate the combined list. `formula` is accepted for signature
  # compatibility with the other methods but is not used here.
  named_params <- list(
    max_iter = max_iter,
    max_depth = max_depth,
    step_size = step_size,
    subsampling_rate = subsampling_rate,
    feature_subset_strategy = feature_subset_strategy,
    min_instances_per_node = min_instances_per_node,
    max_bins = max_bins,
    min_info_gain = min_info_gain,
    loss_type = loss_type,
    seed = seed,
    thresholds = thresholds,
    checkpoint_interval = checkpoint_interval,
    cache_node_ids = cache_node_ids,
    max_memory_in_mb = max_memory_in_mb,
    features_col = features_col,
    label_col = label_col,
    prediction_col = prediction_col,
    probability_col = probability_col,
    raw_prediction_col = raw_prediction_col
  )
  .args <- validator_ml_gbt_classifier(c(named_params, rlang::dots_list(...)))

  stage_class <- "org.apache.spark.ml.classification.GBTClassifier"

  # On Spark versions below 2.2.0 the stage is created without the
  # probability / raw-prediction column settings.
  base_stage <- if (spark_version(x) < "2.2.0") {
    spark_pipeline_stage(
      x, stage_class, uid,
      features_col = .args[["features_col"]],
      label_col = .args[["label_col"]],
      prediction_col = .args[["prediction_col"]]
    )
  } else {
    spark_pipeline_stage(
      x, stage_class, uid,
      features_col = .args[["features_col"]],
      label_col = .args[["label_col"]],
      prediction_col = .args[["prediction_col"]],
      probability_col = .args[["probability_col"]],
      raw_prediction_col = .args[["raw_prediction_col"]]
    )
  }

  # Each entry is a (setter name, value) pair. The jobj_set_param_helper()
  # entries may come back as NULL, in which case they are dropped below.
  setter_calls <- list(
    list("setCheckpointInterval", .args[["checkpoint_interval"]]),
    list("setMaxBins", .args[["max_bins"]]),
    list("setMaxDepth", .args[["max_depth"]]),
    list("setMinInfoGain", .args[["min_info_gain"]]),
    list("setMinInstancesPerNode", .args[["min_instances_per_node"]]),
    list("setCacheNodeIds", .args[["cache_node_ids"]]),
    list("setMaxMemoryInMB", .args[["max_memory_in_mb"]]),
    list("setLossType", .args[["loss_type"]]),
    list("setMaxIter", .args[["max_iter"]]),
    list("setStepSize", .args[["step_size"]]),
    list("setSubsamplingRate", .args[["subsampling_rate"]]),
    jobj_set_param_helper(
      base_stage, "setFeatureSubsetStrategy",
      .args[["feature_subset_strategy"]], "2.3.0", "auto"
    ),
    jobj_set_param_helper(base_stage, "setThresholds", .args[["thresholds"]]),
    jobj_set_param_helper(base_stage, "setSeed", .args[["seed"]])
  )
  active_setters <- Filter(Negate(is.null), setter_calls)

  # invoke() receives the stage, the literal "%>%" marker, and then one
  # argument per remaining setter entry.
  jobj <- do.call(invoke, c(base_stage, "%>%", active_setters))

  new_ml_gbt_classifier(jobj)
}
#' @export
ml_gbt_classifier.ml_pipeline <- function(
    x, formula = NULL, max_iter = 20, max_depth = 5,
    step_size = 0.1, subsampling_rate = 1,
    feature_subset_strategy = "auto", min_instances_per_node = 1L,
    max_bins = 32, min_info_gain = 0, loss_type = "logistic",
    seed = NULL, thresholds = NULL, checkpoint_interval = 10,
    cache_node_ids = FALSE, max_memory_in_mb = 256,
    features_col = "features", label_col = "label",
    prediction_col = "prediction", probability_col = "probability",
    raw_prediction_col = "rawPrediction",
    uid = random_string("gbt_classifier_"), ...) {
  # Build the estimator against the pipeline's Spark connection, forwarding
  # every argument unchanged, then append it to the pipeline `x`.
  gbt_stage <- ml_gbt_classifier.spark_connection(
    x = spark_connection(x),
    formula = formula,
    max_iter = max_iter, max_depth = max_depth,
    step_size = step_size, subsampling_rate = subsampling_rate,
    feature_subset_strategy = feature_subset_strategy,
    min_instances_per_node = min_instances_per_node,
    max_bins = max_bins, min_info_gain = min_info_gain,
    loss_type = loss_type,
    seed = seed, thresholds = thresholds,
    checkpoint_interval = checkpoint_interval,
    cache_node_ids = cache_node_ids,
    max_memory_in_mb = max_memory_in_mb,
    features_col = features_col, label_col = label_col,
    prediction_col = prediction_col,
    probability_col = probability_col,
    raw_prediction_col = raw_prediction_col,
    uid = uid,
    ...
  )

  ml_add_stage(x, gbt_stage)
}
#' @export
ml_gbt_classifier.tbl_spark <- function(
    x, formula = NULL, max_iter = 20, max_depth = 5,
    step_size = 0.1, subsampling_rate = 1,
    feature_subset_strategy = "auto", min_instances_per_node = 1L,
    max_bins = 32, min_info_gain = 0, loss_type = "logistic",
    seed = NULL, thresholds = NULL, checkpoint_interval = 10,
    cache_node_ids = FALSE, max_memory_in_mb = 256,
    features_col = "features", label_col = "label",
    prediction_col = "prediction", probability_col = "probability",
    raw_prediction_col = "rawPrediction",
    uid = random_string("gbt_classifier_"),
    response = NULL, features = NULL,
    predicted_label_col = "predicted_label", ...) {
  # Reconcile `formula` with the `response` / `features` arguments.
  formula <- ml_standardize_formula(formula, response, features)

  # The estimator itself is always built with `formula = NULL`; the
  # standardized formula is only consumed further down.
  gbt_stage <- ml_gbt_classifier.spark_connection(
    x = spark_connection(x),
    formula = NULL,
    max_iter = max_iter, max_depth = max_depth,
    step_size = step_size, subsampling_rate = subsampling_rate,
    feature_subset_strategy = feature_subset_strategy,
    min_instances_per_node = min_instances_per_node,
    max_bins = max_bins, min_info_gain = min_info_gain,
    loss_type = loss_type,
    seed = seed, thresholds = thresholds,
    checkpoint_interval = checkpoint_interval,
    cache_node_ids = cache_node_ids,
    max_memory_in_mb = max_memory_in_mb,
    features_col = features_col, label_col = label_col,
    prediction_col = prediction_col,
    probability_col = probability_col,
    raw_prediction_col = raw_prediction_col,
    uid = uid,
    ...
  )

  # Without a formula, fit the estimator directly on the table.
  if (is.null(formula)) {
    return(ml_fit(gbt_stage, x))
  }

  # With a formula, delegate to the supervised-model constructor.
  ml_construct_model_supervised(
    new_ml_model_gbt_classification,
    predictor = gbt_stage,
    formula = formula,
    dataset = x,
    features_col = features_col,
    label_col = label_col,
    predicted_label_col = predicted_label_col
  )
}
# Validator
#
# Takes the list of GBT classifier arguments, applies the shared
# decision-tree validation first, then casts the GBT-specific entries one
# by one. Returns the updated list.
validator_ml_gbt_classifier <- function(.args) {
  checked <- ml_validate_decision_tree_args(.args)

  checked[["thresholds"]] <- cast_double_list(
    checked[["thresholds"]],
    allow_null = TRUE
  )
  checked[["max_iter"]] <- cast_scalar_integer(checked[["max_iter"]])
  checked[["step_size"]] <- cast_scalar_double(checked[["step_size"]])
  checked[["subsampling_rate"]] <- cast_scalar_double(
    checked[["subsampling_rate"]]
  )
  # "logistic" is the only choice passed to cast_choice() for the loss type.
  checked[["loss_type"]] <- cast_choice(checked[["loss_type"]], "logistic")
  checked[["feature_subset_strategy"]] <- cast_string(
    checked[["feature_subset_strategy"]]
  )

  checked
}
# Wrap a GBT classifier Java object in its R-side estimator object.
# Below Spark 2.2.0 it is wrapped as a plain predictor; otherwise as a
# probabilistic classifier. Both carry the class "ml_gbt_classifier".
new_ml_gbt_classifier <- function(jobj) {
  version <- spark_version(spark_connection(jobj))

  if (version < "2.2.0") {
    return(new_ml_predictor(jobj, class = "ml_gbt_classifier"))
  }

  new_ml_probabilistic_classifier(jobj, class = "ml_gbt_classifier")
}
# Wrap a fitted GBT classification model Java object in its R-side model
# object.
#
# Only the wrapping constructor depends on the Spark version: below 2.2.0
# the model is wrapped as a plain prediction model, otherwise as a
# probabilistic classification model. The attached fields and the class
# tag are the same in both cases.
#
# Fix: the pre-2.2.0 path previously tagged the object with
# "ml_multilayer_perceptron_classification_model" (a copy-paste slip);
# it now uses "ml_gbt_classification_model" like the other path.
new_ml_gbt_classification_model <- function(jobj) {
  v <- spark_version(spark_connection(jobj))

  constructor <- if (v < "2.2.0") {
    new_ml_prediction_model
  } else {
    new_ml_probabilistic_classification_model
  }

  constructor(
    jobj,
    # `lazy val featureImportances` -- stored as a function, not called here
    feature_importances = possibly_null(
      ~ read_spark_vector(jobj, "featureImportances")
    ),
    # Called immediately (note the trailing `()`)
    num_classes = possibly_null(~ invoke(jobj, "numClasses"))(),
    # `lazy val totalNumNodes` -- deferred behind a function
    total_num_nodes = function() invoke(jobj, "totalNumNodes"),
    tree_weights = invoke(jobj, "treeWeights"),
    # `def trees` -- deferred; each tree is wrapped as a decision tree
    # regression model
    trees = function() {
      invoke(jobj, "trees") %>%
        purrr::map(new_ml_decision_tree_regression_model)
    },
    class = "ml_gbt_classification_model"
  )
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.