Nothing
#' Spark ML -- Generalized Linear Regression
#'
#' Perform regression using Generalized Linear Model (GLM).
#'
#' @template roxlate-ml-algo
#' @template roxlate-ml-formula-params
#' @template roxlate-ml-linear-regression-params
#' @template roxlate-ml-predictor-params
#' @param offset_col Offset column name. If this is not set, we treat all instance offsets as 0.0. The feature specified as offset has a constant coefficient of 1.0.
#' @param family Name of family which is a description of the error distribution to be used in the model. Supported options: "gaussian", "binomial", "poisson", "gamma" and "tweedie". Default is "gaussian".
#' @param link Name of link function which provides the relationship between the linear predictor and the mean of the distribution function. See Details for the link functions supported by each family.
#' @param link_power Index in the power link function. Only applicable to the Tweedie family. Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt link, respectively. When not set, this value defaults to 1 - variancePower, which matches the R "statmod" package.
#' @param link_prediction_col Link prediction (linear predictor) column name. Default is not set, which means we do not output link prediction.
#' @param solver Solver algorithm for optimization.
#' @param variance_power Power in the variance function of the Tweedie distribution which provides the relationship between the variance and mean of the distribution. Only applicable to the Tweedie family. (see \href{https://en.wikipedia.org/wiki/Tweedie_distribution}{Tweedie Distribution (Wikipedia)}) Supported values: 0 and [1, Inf). Note that variance power 0, 1, or 2 corresponds to the Gaussian, Poisson or Gamma family, respectively.
#'
#' @details Valid link functions for each family are listed below. The first link function of each family is the default one.
#' \itemize{
#' \item gaussian: "identity", "log", "inverse"
#' \item binomial: "logit", "probit", "cloglog"
#' \item poisson: "log", "identity", "sqrt"
#' \item gamma: "inverse", "identity", "log"
#' \item tweedie: power link function specified through \code{link_power}. The default link power in the tweedie family is \code{1 - variance_power}.
#' }
#'
#' @examples
#' \dontrun{
#' library(sparklyr)
#'
#' sc <- spark_connect(master = "local")
#' mtcars_tbl <- sdf_copy_to(sc, mtcars, name = "mtcars_tbl", overwrite = TRUE)
#'
#' partitions <- mtcars_tbl %>%
#' sdf_random_split(training = 0.7, test = 0.3, seed = 1111)
#'
#' mtcars_training <- partitions$training
#' mtcars_test <- partitions$test
#'
#' # Specify the grid
#' family <- c("gaussian", "gamma", "poisson")
#' link <- c("identity", "log")
#' family_link <- expand.grid(family = family, link = link, stringsAsFactors = FALSE)
#' family_link <- data.frame(family_link, rmse = 0)
#'
#' # Train the models
#' for (i in seq_len(nrow(family_link))) {
#' glm_model <- mtcars_training %>%
#' ml_generalized_linear_regression(mpg ~ .,
#' family = family_link[i, 1],
#' link = family_link[i, 2]
#' )
#'
#' pred <- ml_predict(glm_model, mtcars_test)
#' family_link[i, 3] <- ml_regression_evaluator(pred, label_col = "mpg")
#' }
#'
#' family_link
#' }
#'
#' @export
ml_generalized_linear_regression <- function(x, formula = NULL, family = "gaussian",
link = NULL, fit_intercept = TRUE, offset_col = NULL,
link_power = NULL, link_prediction_col = NULL,
reg_param = 0, max_iter = 25, weight_col = NULL,
solver = "irls", tol = 1e-6, variance_power = 0,
features_col = "features", label_col = "label",
prediction_col = "prediction",
uid = random_string("generalized_linear_regression_"),
...) {
# S3 generic: dispatches on the class of `x`. The spark_connection,
# ml_pipeline and tbl_spark methods in this file are all bound to
# ml_generalized_linear_regression_impl.
# NOTE(review): check_dots_used() is commented out, so arguments passed
# through `...` are not checked for use here.
#check_dots_used()
UseMethod("ml_generalized_linear_regression")
}
# Shared implementation behind every ml_generalized_linear_regression()
# method (spark_connection, ml_pipeline and tbl_spark).
#
# `family` is normally a string such as "gaussian". For backward
# compatibility it may also be:
#   * an R family object such as `poisson()`; its `family` and `link`
#     entries then replace the `family` and `link` arguments;
#   * (deprecated) a family function such as `poisson`, which is called with
#     no arguments first and then treated like its result.
# `offset_col` is passed through param_min_version() with version "2.3.0".
# All estimator parameters are forwarded to ml_process_model() as
# `invoke_steps`.
ml_generalized_linear_regression_impl <- function(x, formula = NULL, family = "gaussian",
                                                  link = NULL, fit_intercept = TRUE, offset_col = NULL,
                                                  link_power = NULL, link_prediction_col = NULL,
                                                  reg_param = 0, max_iter = 25, weight_col = NULL,
                                                  solver = "irls", tol = 1e-6, variance_power = 0,
                                                  features_col = "features", label_col = "label",
                                                  prediction_col = "prediction",
                                                  uid = random_string("generalized_linear_regression_"),
                                                  response = NULL, features = NULL,
                                                  ...) {
  offset_col <- param_min_version(x, offset_col, "2.3.0")

  # Deprecated path: evaluate a family *function* to get its value. Whatever
  # it returns (a family object or a plain string) is handled below instead of
  # leaving the function itself in `family`.
  if (is.function(family)) {
    warning("Specifying a function for `family` is deprecated; please specify strings for `family` and `link`.")
    family <- family()
  }

  # A family object carries both the family name and the link; it overrides
  # any `link` given explicitly.
  if (inherits(family, "family")) {
    link <- family$link
    family <- family$family
  }

  ml_process_model(
    x = x,
    r_class = "ml_generalized_linear_regression",
    ml_function = new_ml_model_generalized_linear_regression,
    features = features,
    response = response,
    uid = uid,
    formula = formula,
    invoke_steps = list(
      features_col = features_col,
      label_col = label_col,
      prediction_col = prediction_col,
      family = family,
      link = link,
      fit_intercept = fit_intercept,
      reg_param = reg_param,
      max_iter = max_iter,
      solver = solver,
      tol = tol,
      link_power = link_power,
      variance_power = variance_power,
      link_prediction_col = link_prediction_col,
      weight_col = weight_col,
      offset_col = offset_col
    )
  )
}
#' @export
params_validator.ml_generalized_linear_regression <- function(x) {
  # Start from the validators shared by all estimators, then override the one
  # for `solver` so it is checked with cast_choice() against "irls".
  validators <- params_base_validator(x)
  validators$solver <- function(x) cast_choice(x, "irls")
  validators
}
# ------------------------------- Methods --------------------------------------
# Every supported input type (a Spark connection, an ML pipeline, or a
# tbl_spark) is bound to the same implementation; it passes `x` on to
# ml_process_model() unchanged.
#' @export
ml_generalized_linear_regression.spark_connection <- ml_generalized_linear_regression_impl
#' @export
ml_generalized_linear_regression.ml_pipeline <- ml_generalized_linear_regression_impl
#' @export
ml_generalized_linear_regression.tbl_spark <- ml_generalized_linear_regression_impl
# ------------------------------ Fitted models ---------------------------------
# Wrap a fitted Spark GeneralizedLinearRegressionModel in an R model object.
#
# The training summary is attached only when the JVM model reports
# `hasSummary`, and the link prediction column is reported only when that
# param was explicitly set. Otherwise both fields are NULL.
new_ml_generalized_linear_regression_model <- function(jobj) {
  training_summary <- NULL
  if (invoke(jobj, "hasSummary")) {
    intercept_fitted <- ml_get_param_map(jobj)$fit_intercept
    training_summary <- new_ml_generalized_linear_regression_training_summary(
      invoke(jobj, "summary"),
      intercept_fitted
    )
  }

  link_col_param <- invoke(jobj, "linkPredictionCol")
  link_col <- NULL
  if (invoke(jobj, "isSet", link_col_param)) {
    link_col <- invoke(jobj, "getLinkPredictionCol")
  }

  new_ml_prediction_model(
    jobj,
    coefficients = read_spark_vector(jobj, "coefficients"),
    intercept = invoke(jobj, "intercept"),
    link_prediction_col = link_col,
    summary = training_summary,
    class = "ml_generalized_linear_regression_model"
  )
}
# Build the R-side summary object for a GeneralizedLinearRegression model.
#
# Most statistics are exposed as zero-argument closures, so the Spark call is
# made only when the value is requested. `prediction_col`, `predictions` and
# `rank` are fetched eagerly. `...` carries extra fields supplied by callers
# such as new_ml_generalized_linear_regression_training_summary().
#
# `class` is prepended to "ml_generalized_linear_regression_summary". This
# keeps a subclass requested by the caller (previously the argument was
# silently ignored). `fit_intercept` is kept in the signature for existing
# callers; it is not needed to build the generic summary.
new_ml_generalized_linear_regression_summary <- function(jobj, fit_intercept, ...,
                                                         class = character()) {
  version <- jobj %>%
    spark_connection() %>%
    spark_version()
  new_ml_summary(
    jobj,
    aic = function() invoke(jobj, "aic"), # lazy val
    degrees_of_freedom = function() invoke(jobj, "degreesOfFreedom"), # lazy val
    deviance = function() invoke(jobj, "deviance"), # lazy val
    dispersion = function() invoke(jobj, "dispersion"), # lazy val
    null_deviance = function() invoke(jobj, "nullDeviance"), # lazy val
    # numInstances is public on the Spark summary since 2.2.0, so 2.2.0
    # itself must be included.
    num_instances = if (version >= "2.2.0") {
      function() invoke(jobj, "numInstances") # lazy val
    } else {
      NULL
    },
    prediction_col = invoke(jobj, "predictionCol"),
    predictions = invoke(jobj, "predictions") %>% sdf_register(),
    rank = invoke(jobj, "rank"), # lazy val
    residual_degree_of_freedom = function() invoke(jobj, "residualDegreeOfFreedom"), # lazy val
    residual_degree_of_freedom_null = function() invoke(jobj, "residualDegreeOfFreedomNull"), # lazy val
    residuals = function(type = "deviance") (invoke(jobj, "residuals", type) %>% sdf_register()),
    ...,
    class = c(class, "ml_generalized_linear_regression_summary")
  )
}
# ---------------------------- Constructors ------------------------------------
# Build the training summary of a fitted GLR model: the generic GLR summary
# plus statistics read from the training-time summary object.
#
# Standard errors, p-values and t-values are wrapped in possibly_null() and
# reordered with the arranger returned by make_stats_arranger(fit_intercept).
new_ml_generalized_linear_regression_training_summary <- function(jobj, fit_intercept) {
  order_stats <- make_stats_arranger(fit_intercept)

  new_ml_generalized_linear_regression_summary(
    jobj,
    fit_intercept,
    coefficient_standard_errors = possibly_null(
      ~ order_stats(invoke(jobj, "coefficientStandardErrors"))
    ),
    num_iterations = invoke(jobj, "numIterations"),
    solver = invoke(jobj, "solver"),
    p_values = possibly_null(~ order_stats(invoke(jobj, "pValues"))),
    t_values = possibly_null(~ order_stats(invoke(jobj, "tValues"))),
    class = "ml_generalized_linear_regression_training_summary"
  )
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.