ml_metrics_regression: Extracts metrics from a fitted table

View source: R/ml_metrics.R

ml_metrics_regression    R Documentation


Description

Computes regression metrics from a 'tbl_spark' that contains both the observed values and the predictions. The function works best when passed a 'tbl_spark' created by 'ml_predict()', because that table already has the variable types and format that the Spark model "evaluator" expects.

Usage

ml_metrics_regression(
  x,
  truth,
  estimate = prediction,
  metrics = c("rmse", "rsq", "mae"),
  ...
)

Arguments

x

A 'tbl_spark' containing the estimate (the prediction) and the truth (the value that was actually observed).

truth

The name of the column in 'x' that contains the observed values (what actually happened). Supply it as an unquoted column name, as in the example below.

estimate

The name of the column in 'x' that contains the prediction. Defaults to 'prediction', the column name that 'ml_predict()' uses by default.

metrics

A character vector with the metrics to calculate. For regression models the possible values are: 'rmse' (root mean squared error), 'mse' (mean squared error), 'rsq' (R squared), 'mae' (mean absolute error), and 'var' (explained variance). Defaults to 'rmse', 'rsq', and 'mae'.

...

Optional arguments; currently unused.
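To make the metric names concrete, here is a minimal base-R sketch of the textbook formulas, computed on local vectors rather than on a 'tbl_spark'. The helper 'regression_metrics_sketch()' is hypothetical and not part of sparklyr, and the formulas are assumed to match those of Spark's regression evaluator. In particular, 'rsq' is taken here as the traditional 1 - SS_res / SS_tot, and 'var' as the sum of squared deviations of the predictions from the mean of the truth, divided by n. The real function does this work inside Spark.

```r
# Hypothetical helper: a local sketch of the formulas behind each metric name.
# Assumption: these mirror Spark's regression evaluator definitions; this is not
# the code that ml_metrics_regression() runs.
regression_metrics_sketch <- function(truth, estimate,
                                      metrics = c("rmse", "rsq", "mae")) {
  stopifnot(length(truth) == length(estimate))
  err <- truth - estimate
  n <- length(truth)
  ss_res <- sum(err^2)                    # residual sum of squares
  ss_tot <- sum((truth - mean(truth))^2)  # total sum of squares
  all_metrics <- c(
    rmse = sqrt(ss_res / n),                    # root mean squared error
    mse  = ss_res / n,                          # mean squared error
    rsq  = 1 - ss_res / ss_tot,                 # coefficient of determination
    mae  = mean(abs(err)),                      # mean absolute error
    var  = sum((estimate - mean(truth))^2) / n  # explained variance
  )
  unknown <- setdiff(metrics, names(all_metrics))
  if (length(unknown) > 0) {
    stop("Unknown metric(s): ", paste(unknown, collapse = ", "))
  }
  all_metrics[metrics]  # keep only the requested metrics, in the requested order
}

# Usage on a tiny local example
truth <- c(3, 5, 7)
estimate <- c(2.5, 5, 8)
regression_metrics_sketch(truth, estimate, c("rmse", "mse", "rsq", "mae", "var"))
```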

Details

The 'ml_metrics' family of functions wraps Spark's 'evaluate' so that it works more like the 'yardstick' package. The functions expect a table containing the truth and estimate columns, and return a 'tibble' with the results. The 'tibble' has the same format and variable names as the output of the 'yardstick' functions.
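For reference, 'yardstick' returns one row per metric with the columns '.metric', '.estimator', and '.estimate'. The sketch below shows that long format on local values. The helper 'as_yardstick_format()' is hypothetical, it builds a base 'data.frame' only to avoid a dependency (the real function returns a 'tibble'), and the '.estimator' value "standard" is assumed from yardstick's convention for numeric metrics.

```r
# Hypothetical helper: reshape a named numeric vector of metric values into the
# yardstick-style long format (one row per metric).
as_yardstick_format <- function(values) {
  data.frame(
    .metric    = names(values),
    .estimator = "standard",  # assumed: yardstick's label for numeric metrics
    .estimate  = unname(values),
    stringsAsFactors = FALSE
  )
}

# Usage with metrics computed on a tiny local example
truth <- c(3, 5, 7)
estimate <- c(2.5, 5, 8)
as_yardstick_format(c(
  rmse = sqrt(mean((truth - estimate)^2)),
  mae  = mean(abs(truth - estimate))
))
```

Because every metric occupies its own row, results from several models can be stacked with 'rbind()' and compared directly.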

Examples

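## Not run: 
library(sparklyr)

sc <- spark_connect("local")
tbl_iris <- copy_to(sc, iris)  # column names become Sepal_Length, etc.

# Split the data, fit a model on one half, and score the other half
iris_split <- sdf_random_split(tbl_iris, training = 0.5, test = 0.5)
training <- iris_split$training
reg_formula <- "Sepal_Length ~ Sepal_Width + Petal_Length + Petal_Width"
model <- ml_generalized_linear_regression(training, reg_formula)
tbl_predictions <- ml_predict(model, iris_split$test)

# 'estimate' defaults to the 'prediction' column added by ml_predict()
tbl_predictions %>%
  ml_metrics_regression(Sepal_Length)

# Request a different set of metrics
tbl_predictions %>%
  ml_metrics_regression(Sepal_Length, metrics = c("mse", "var"))

spark_disconnect(sc)
## End(Not run)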

rstudio/sparklyr documentation built on Sept. 18, 2024, 6:10 a.m.