R/mleap.R

Defines functions: ml_write_bundle, mleap_load_bundle, retrieve_model_schema, mleap_model_schema, print.mleap_transformer, new_mleap_transformer

Documented in: mleap_load_bundle, mleap_model_schema, ml_write_bundle

# Build an `mleap_transformer` S3 object from a loaded bundle reference.
#
# `jobj` is the JVM object returned by the bundle loader. Metadata is read
# from `jobj$info()`, the input/output schema is tabulated by
# `retrieve_model_schema()`, and the raw reference is kept in `.jobj`.
new_mleap_transformer <- function(jobj) {
  bundle_info <- jobj$info()

  model_uid <- bundle_info$uid()$toString()
  model_name <- bundle_info$name()
  bundle_format <- bundle_info$format()$toString()
  runtime_version <- bundle_info$version()
  schema_tbl <- retrieve_model_schema(jobj)

  transformer <- list(
    uid = model_uid,
    name = model_name,
    format = bundle_format,
    mleap_version = runtime_version,
    schema = schema_tbl,
    .jobj = jobj
  )
  class(transformer) <- "mleap_transformer"
  transformer
}

# Not covered by tests because the uid differs on every run.
#' @export
print.mleap_transformer <- function(x, ...) { # nocov start
  cat("MLeap Transformer\n")
  cat(paste0("<", x$uid, ">"), "\n")
  cat(paste0("  ", "Name: ", x$name), "\n")
  cat(paste0("  ", "Format: ", x$format), "\n")
  # Terminate the last line too, so the console prompt starts on a new line.
  cat(paste0("  ", "MLeap Version: ", x$mleap_version), "\n")
  # Print methods return their input invisibly so they can be used in pipes.
  invisible(x)
} # nocov end

#' MLeap model schema
#'
#' Returns the schema of an MLeap transformer.
#'
#' @param x An MLeap model object, as returned by \code{mleap_load_bundle()}.
#' @return A data frame of the model schema, with one row per input/output
#'   field.
#'
#' @export
mleap_model_schema <- function(x) {
  # The schema table is stored on the object when it is constructed, so this
  #   is a plain accessor.
  x[["schema"]]
}

# Tabulate the input and output schema of a loaded MLeap bundle.
#
# Returns one tibble with columns `name`, `type`, `nullable`, `dimension`
# and `io`. Input fields come first (`io == "input"`), followed by output
# fields (`io == "output"`).
retrieve_model_schema <- function(jobj) {
  model_root <- jobj$root()
  in_fields <- model_root$inputSchema()$fields()
  out_fields <- model_root$outputSchema()$fields()

  # `toArray()` on the Scala field collection needs a ClassTag for the
  #   element class. It is built from the class of the first input field and
  #   reused for the output fields.
  tag_companion <- rJava::.jnew("scala.reflect.ClassTag$")
  field_tag <- tag_companion$`MODULE$`$apply(in_fields$head()$getClass())

  # Describe one schema field as list(name, type, nullable, dimension).
  describe_field <- function(field) {
    field_type <- field$dataType()
    type_name <- field_type$base()$toString()
    # When the dimensions chain errors (e.g. no dimensions recorded -- assumed
    #   to be non-tensor fields, TODO confirm), record NA instead.
    dims <- tryCatch(
      {
        dim_values <- field_type$dimensions()$get()$toIterable()$array()
        paste0("(", dim_values, ")", collapse = ", ")
      },
      error = function(e) NA_character_
    )
    nullable <- field_type$isNullable()
    list(field$name(), type_name, nullable, dims)
  }

  # Turn a field collection into a tibble tagged with `io`.
  fields_to_tibble <- function(fields, io) {
    rows <- purrr::map(as.list(fields$toArray(field_tag)), describe_field)
    columns <- purrr::set_names(
      purrr::transpose(rows),
      c("name", "type", "nullable", "dimension")
    )
    tbl <- tibble::as_tibble(purrr::map(columns, unlist))
    tbl$io <- io
    tbl
  }

  rbind(
    fields_to_tibble(in_fields, "input"),
    fields_to_tibble(out_fields, "output")
  )
}

#' Loads an MLeap bundle
#'
#' Loads a bundle file, such as one written by \code{\link{ml_write_bundle}},
#' into an MLeap transformer object.
#'
#' @param path Path to the exported bundle zip file.
#' @return An MLeap model object.
#'
#' @export
mleap_load_bundle <- function(path) {
  # Validate the path before touching the JVM, so a typo gives a clear R
  #   error instead of an opaque Java exception.
  if (!is.character(path) || length(path) != 1L || is.na(path)) {
    stop("`path` must be a single string.", call. = FALSE)
  }
  if (!file.exists(path)) {
    stop(paste0("Can't load bundle file: ", path, " does not exist."),
         call. = FALSE)
  }
  # The Java side receives an absolute path.
  path <- normalizePath(path)

  # If the mleap runtime jars aren't in the class path (from package load),
  #   load them.
  if (!any(grepl("mleap-runtime", rJava::.jclassPath()))) {
    load_mleap_jars()
  }

  ctx_builder <- rJava::.jnew("ml.combust.mleap.runtime.javadsl.ContextBuilder")
  ctx <- rJava::.jcall(
    ctx_builder, "Lml/combust/mleap/runtime/MleapContext;",
    "createMleapContext"
  )
  bundle_file <- rJava::.jnew("java.io.File", path)
  bundle_builder <- rJava::.jnew("ml.combust.mleap.runtime.javadsl.BundleBuilder")
  transformer <- rJava::.jcall(
    bundle_builder,
    "Lml/combust/bundle/dsl/Bundle;",
    "load", bundle_file, ctx
  )
  new_mleap_transformer(transformer)
}

#' Export a Spark pipeline for serving
#'
#' This function serializes a Spark pipeline model into an MLeap bundle.
#'
#' @param x A Spark pipeline model object.
#' @param sample_input A sample input Spark DataFrame with the expected schema.
#' @param path Where to save the bundle. Must have a \code{.zip} extension.
#' @param overwrite Whether to overwrite an existing file, defaults to \code{FALSE}.
#' @return \code{NULL}, invisibly; called for the side effect of writing the
#'   bundle file.
#'
#' @examples
#' \dontrun{
#' library(sparklyr)
#' sc <- spark_connect(master = "local")
#' mtcars_tbl <- sdf_copy_to(sc, mtcars, overwrite = TRUE)
#' pipeline <- ml_pipeline(sc) %>%
#'   ft_binarizer("hp", "big_hp", threshold = 100) %>%
#'   ft_vector_assembler(c("big_hp", "wt", "qsec"), "features") %>%
#'   ml_gbt_regressor(label_col = "mpg")
#' pipeline_model <- ml_fit(pipeline, mtcars_tbl)
#' model_path <- file.path(tempdir(), "mtcars_model.zip")
#' ml_write_bundle(pipeline_model, 
#'                 mtcars_tbl,
#'                 model_path,
#'                 overwrite = TRUE)
#' }
#' 
#' @export
ml_write_bundle <- function(x, sample_input, path, overwrite = FALSE) {
  # Validate the destination first: these checks are cheap, whereas
  #   transforming `sample_input` below runs a Spark job.
  path <- resolve_path(path)

  if (!identical(fs::path_ext(path), "zip")) {
    stop("The bundle path must have a `.zip` extension.", call. = FALSE)
  }

  if (fs::file_exists(path) && !overwrite) {
    stop(paste0("Can't save bundle file: ", basename(path), " already exists."),
         call. = FALSE)
  }

  # A bare list is treated as several models; anything else as a single one.
  stages <- if (purrr::is_bare_list(x)) {
    purrr::map(x, sparklyr::spark_jobj)
  } else {
    list(sparklyr::spark_jobj(x))
  }

  if (length(stages) == 0L) {
    stop("`x` must contain at least one model.", call. = FALSE)
  }

  sc <- sparklyr::spark_connection(stages[[1]])

  # The transformed sample DataFrame is handed to the exporter with the stages.
  sdf <- x %>%
    sparklyr::ml_transform(sample_input) %>%
    sparklyr::spark_dataframe()

  # Remove an existing bundle only after the transform has succeeded, so a
  #   failure above leaves the previous export intact.
  if (overwrite && fs::file_exists(path)) {
    fs::file_delete(path)
  }

  sparklyr::invoke_static(sc, "mleap.Main", "exportArrayToBundle",
                          sdf, uri(path), stages)
  message("Model successfully exported.")
  invisible(NULL)
}

Try the mleap package in your browser

Any scripts or data that you put into this service are public.

mleap documentation built on Jan. 28, 2021, 1:09 a.m.