#' Feature Transformation -- PCA (Estimator)
#'
#' Fits a principal component analysis model that projects input vectors onto
#' the lower-dimensional space spanned by the top \code{k} principal
#' components.
#'
#' @template roxlate-ml-feature-input-output-col
#' @template roxlate-ml-feature-transformer
#' @template roxlate-ml-feature-estimator-transformer
#'
#' @param k The number of principal components to project onto. Coerced to a
#'   scalar integer during validation; may be \code{NULL}.
#'
#' @export
ft_pca <- function(x, input_col = NULL, output_col = NULL, k = NULL,
                   uid = random_string("pca_"), ...) {
  # NOTE(review): presumably flags arguments in `...` that no method ends up
  # consuming -- confirm against the helper's definition.
  check_dots_used()

  # S3 generic: dispatch on the class of `x`.
  UseMethod("ft_pca")
}
# Method for a Spark connection: validates the arguments, creates the
# `org.apache.spark.ml.feature.PCA` pipeline stage, sets `k` on it, and
# returns the stage wrapped as an `ml_pca` estimator.
#' @export
ft_pca.spark_connection <- function(x, input_col = NULL, output_col = NULL, k = NULL,
                                    uid = random_string("pca_"), ...) {
  named_args <- list(
    input_col = input_col,
    output_col = output_col,
    k = k,
    uid = uid
  )
  # Explicit arguments come first, followed by whatever arrived through `...`;
  # the combined list is validated in one pass.
  .args <- validator_ml_pca(c(named_args, rlang::dots_list(...)))

  stage_jobj <- spark_pipeline_stage(
    x, "org.apache.spark.ml.feature.PCA",
    input_col = .args[["input_col"]],
    output_col = .args[["output_col"]],
    uid = .args[["uid"]]
  )
  stage_jobj <- jobj_set_param(stage_jobj, "setK", .args[["k"]])

  pca_estimator <- new_ml_pca(stage_jobj)
  pca_estimator
}
# Method for an ML pipeline: builds the PCA stage on the pipeline's Spark
# connection and appends it to the pipeline `x`.
#' @export
ft_pca.ml_pipeline <- function(x, input_col = NULL, output_col = NULL, k = NULL,
                               uid = random_string("pca_"), ...) {
  pca_stage <- ft_pca.spark_connection(
    x = spark_connection(x),
    input_col = input_col,
    output_col = output_col,
    k = k,
    uid = uid,
    ...
  )

  ml_add_stage(x, pca_stage)
}
# Method for a Spark table: builds the PCA stage on the table's connection and
# applies it to `x` right away.
#' @export
ft_pca.tbl_spark <- function(x, input_col = NULL, output_col = NULL, k = NULL,
                             uid = random_string("pca_"), ...) {
  pca_stage <- ft_pca.spark_connection(
    x = spark_connection(x),
    input_col = input_col,
    output_col = output_col,
    k = k,
    uid = uid,
    ...
  )

  # A stage that is not already a transformer has to be fitted on `x` before
  # it can transform it.
  if (!is_ml_transformer(pca_stage)) {
    return(ml_fit_and_transform(pca_stage, x))
  }
  ml_transform(pca_stage, x)
}
# Constructor: wraps a PCA stage `jobj` as an ML estimator carrying the
# additional class "ml_pca".
new_ml_pca <- function(jobj) {
  new_ml_estimator(
    jobj,
    class = "ml_pca"
  )
}
# Constructor: wraps a fitted PCA model `jobj` as an ML transformer with class
# "ml_pca_model", attaching the "explainedVariance" vector and the "pc" matrix
# read from the object.
new_ml_pca_model <- function(jobj) {
  # NOTE(review): `possibly_null()` presumably yields a function that returns
  # NULL when the wrapped read fails -- confirm against its definition.
  read_explained_variance <- possibly_null(~ read_spark_vector(jobj, "explainedVariance"))
  read_components <- possibly_null(~ read_spark_matrix(jobj, "pc"))

  new_ml_transformer(
    jobj,
    explained_variance = read_explained_variance(),
    pc = read_components(),
    class = "ml_pca_model"
  )
}
# Validates the PCA argument list: runs the shared transformer validation,
# then coerces `k` to a scalar integer (NULL is allowed). Returns the updated
# list.
validator_ml_pca <- function(.args) {
  checked <- validate_args_transformer(.args)
  checked[["k"]] <- cast_nullable_scalar_integer(checked[["k"]])
  checked
}
#' @rdname ft_pca
#' @param features The columns to use in the principal components
#'   analysis. Defaults to all columns in \code{x}.
#' @param pc_prefix Length-one character vector used to prepend names of components.
#'
#' @details \code{ml_pca()} is a wrapper around \code{ft_pca()} that returns a
#'   \code{ml_model}. When \code{x} is a \code{spark_connection},
#'   \code{ml_pca()} acts as a constructor alias for \code{ft_pca()}: an
#'   explicitly supplied \code{k} and any arguments in \code{...} are
#'   forwarded, while \code{features} and \code{pc_prefix} are not used.
#'
#' @examples
#' \dontrun{
#' library(dplyr)
#'
#' sc <- spark_connect(master = "local")
#' iris_tbl <- sdf_copy_to(sc, iris, name = "iris_tbl", overwrite = TRUE)
#'
#' iris_tbl %>%
#'   select(-Species) %>%
#'   ml_pca(k = 2)
#' }
#'
#' @export
#' @importFrom dplyr tbl_vars
ml_pca <- function(x,
                   features = tbl_vars(x),
                   k = length(features),
                   pc_prefix = "PC",
                   ...) {
  # If being used as a constructor alias for `ft_pca()`:
  if (inherits(x, "spark_connection")) {
    ctor_args <- rlang::dots_list(x = x, ...)
    # `k` is a formal of `ml_pca()`, so it never ends up in `...`. Forward it
    # only when the caller supplied it: the default `length(features)` would
    # force `tbl_vars()` on a connection.
    if (!missing(k)) {
      ctor_args[["k"]] <- k
    }
    return(
      rlang::exec("ft_pca.spark_connection", !!!ctor_args)
    )
  }

  k <- cast_scalar_integer(k)
  sc <- spark_connection(x)

  # Random column names for the intermediate assembled vector and the PCA
  # output column.
  assembled <- random_string("assembled")
  out <- random_string("out")
  features <- as.character(features)

  # Assemble the feature columns into one vector column, then run PCA on it.
  pipeline <- ml_pipeline(sc) %>%
    ft_vector_assembler(features, assembled) %>%
    ft_pca(assembled, out, k = k)

  pipeline_model <- pipeline %>%
    ml_fit(x)

  m <- new_ml_model(
    pipeline_model = pipeline_model,
    formula = paste0("~ ", paste0(features, collapse = " + ")),
    dataset = x,
    class = "ml_model_pca"
  )

  model <- m$model
  pc <- model$pc
  explained_variance <- model$explained_variance

  # `pc` is read through `possibly_null()` when the model is constructed, so
  # it may be NULL; guard it the same way as `explained_variance` instead of
  # failing in `ncol()` / `rownames<-`. Fall back to `k` for the number of
  # component names in that case.
  n_components <- if (is.null(pc)) k else ncol(pc)
  pc_names <- paste0(pc_prefix, seq_len(n_components))

  if (!is.null(pc)) {
    rownames(pc) <- features
    colnames(pc) <- pc_names
  }
  if (!is.null(explained_variance)) {
    names(explained_variance) <- pc_names
  }

  m$k <- k
  m$pc <- pc
  m$explained_variance <- explained_variance
  m
}
# Print method for `ml_model_pca` objects: writes the explained variance (or a
# placeholder when it is NULL) followed by the rotation matrix `x$pc`.
# Returns `x` invisibly, as print methods conventionally do.
#' @export
print.ml_model_pca <- function(x, ...) {
  cat("Explained variance:", sep = "\n")

  if (is.null(x$explained_variance)) {
    cat("[not available in this version of Spark]", sep = "\n")
  } else {
    print_newline()
    print(x$explained_variance)
  }

  print_newline()
  cat("Rotation:", sep = "\n")
  print(x$pc)

  # Hand back the model itself, not the value of the last `print()` call.
  invisible(x)
}
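# NOTE: the two lines below are residue from the web page this file was copied
# from; they are not part of the package source.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.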