R/ml_pipeline.R

#' Spark ML -- Pipelines
#'
#' Create Spark ML Pipelines
#'
#' @param x Either a \code{spark_connection} or an \code{ml_pipeline_stage} object.
#' @template roxlate-ml-uid
#' @param ... \code{ml_pipeline_stage} objects.
#'
#' @return When \code{x} is a \code{spark_connection}, \code{ml_pipeline()} returns an empty pipeline object. When \code{x} is an \code{ml_pipeline_stage}, \code{ml_pipeline()} returns an \code{ml_pipeline} with the stages set to \code{x} and any transformers or estimators given in \code{...}.
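#'
#' @examples
#' \dontrun{
#' # A small usage sketch; assumes an active Spark connection `sc`.
#' iris_tbl <- copy_to(sc, iris, overwrite = TRUE)
#'
#' # Build an empty pipeline from the connection, appending stages with
#' # ft_*() / ml_*() verbs:
#' pipeline <- ml_pipeline(sc) %>%
#'   ft_r_formula(Species ~ Petal_Width + Petal_Length) %>%
#'   ml_logistic_regression()
#'
#' # Or build the same pipeline directly from existing stage objects:
#' r_formula <- ft_r_formula(sc, Species ~ Petal_Width + Petal_Length)
#' lr <- ml_logistic_regression(sc)
#' pipeline <- ml_pipeline(r_formula, lr)
#'
#' # Fitting a pipeline returns an ml_pipeline_model:
#' model <- ml_fit(pipeline, iris_tbl)
#' }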
#' @export
ml_pipeline <- function(x, ..., uid = random_string("pipeline_")) {
  UseMethod("ml_pipeline")
}

#' @export
ml_pipeline.spark_connection <-
  function(x, ..., uid = random_string("pipeline_")) {
    uid <- cast_string(uid)
    jobj <- invoke_new(x, "org.apache.spark.ml.Pipeline", uid)
    new_ml_pipeline(jobj)
  }

#' @export
ml_pipeline.ml_pipeline_stage <-
  function(x, ..., uid = random_string("pipeline_")) {
    uid <- cast_string(uid)
    sc <- spark_connection(x)
    dots <- list(...) %>%
      lapply(spark_jobj)
    stages <- c(spark_jobj(x), dots)
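    # Delegate to the sparklyr Scala helper `createPipelineFromStages`, which
    # (as used here) builds a new org.apache.spark.ml.Pipeline with the given
    # uid from the collected stage jobjs.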
    jobj <- invoke_static(
      sc,
      "sparklyr.MLUtils",
      "createPipelineFromStages",
      uid,
      stages
    )
    new_ml_pipeline(jobj)
  }

# Constructors

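# Wrap a Pipeline jobj as an `ml_pipeline` estimator. The Java-side stages are
# re-wrapped as R stage objects via ml_call_constructor(); if they cannot be
# retrieved (e.g. no stages have been set yet), `stages` is left as NULL.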
new_ml_pipeline <- function(jobj, ..., class = character()) {
  stages <- tryCatch(
    {
      jobj %>%
        invoke("getStages") %>%
        lapply(ml_call_constructor)
    },
    error = function(e) {
      NULL
    }
  )
  new_ml_estimator(
    jobj,
    stages = stages,
    stage_uids = if (rlang::is_null(stages)) {
      NULL
    } else {
      sapply(stages, function(x) {
        x$uid
      })
    },
    ...,
    class = c(class, "ml_pipeline")
  )
}

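# Wrap a PipelineModel jobj as an `ml_pipeline_model` transformer. The fitted
# Java-side stages are re-wrapped as R stage objects when they can be
# retrieved; otherwise `stages` is left as NULL.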
new_ml_pipeline_model <- function(jobj, ..., class = character()) {
  stages <- tryCatch(
    {
      jobj %>%
        invoke("stages")
    },
    error = function(e) {
      NULL
    }
  )

  if (!rlang::is_null(stages)) {
    stages <- lapply(stages, ml_call_constructor)
  }

  new_ml_transformer(
    jobj,
    stages = stages,
    stage_uids = if (rlang::is_null(stages)) {
      NULL
    } else {
      sapply(stages, function(x) {
        x$uid
      })
    },
    ...,
    class = c(class, "ml_pipeline_model")
  )
}

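# spark_connection() accessors: recover the originating connection from the
# underlying Java object reference.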
#' @export
spark_connection.ml_pipeline <- function(x, ...) {
  spark_connection(spark_jobj(x))
}

#' @export
spark_connection.ml_pipeline_stage <- function(x, ...) {
  spark_connection(spark_jobj(x))
}

#' @export
spark_connection.ml_pipeline_model <- function(x, ...) {
  spark_connection(spark_jobj(x))
}

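# Shared printer for pipelines and pipeline models: a header with the stage
# count, the uid in angle brackets, and each stage's own print output indented
# under "Stages" as "|--<n> ...".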
print_pipeline <- function(x, type = c("pipeline", "pipeline_model")) {
  type <- match.arg(type)
  if (identical(type, "pipeline")) {
    cat(paste0("Pipeline (Estimator) with "))
  } else {
    cat(paste0("PipelineModel (Transformer) with "))
  }
  num_stages <- length(ml_stages(x))
  if (num_stages == 0) {
    cat("no stages")
  } else if (num_stages == 1) {
    cat("1 stage")
  } else {
    cat(paste0(num_stages, " stages"))
  }
  cat("\n")
  cat(paste0("<", x$uid, ">"), "\n")
  if (num_stages > 0) {
    cat("  Stages", "\n")
    for (n in seq_len(num_stages)) {
      stage_output <- capture.output(print(ml_stage(x, n)))
      cat(paste0("  |--", n, " ", stage_output[1]), sep = "\n")
      cat(paste0("  |    ", stage_output[-1]), sep = "\n")
    }
  }
}

#' @export
print.ml_pipeline <- function(x, ...) {
  print_pipeline(x, "pipeline")
}

#' @export
print.ml_pipeline_model <- function(x, ...) {
  print_pipeline(x, "pipeline_model")
}

#' @export
spark_jobj.ml_pipeline_stage <- function(x, ...) x$.jobj
