# NOTE: This code has been modified from AWS Sagemaker Python:
# https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/transformer.py
#' @include r_utils.R
#' @import R6
#' @import sagemaker.core
#' @title Transformer class
#' @description A class for handling creating and interacting with Amazon SageMaker
#' transform jobs
#' @export
Transformer = R6Class("Transformer",
  public = list(

    #' @description Initialize a ``Transformer``.
    #' @param model_name (str): Name of the SageMaker model being used for the
    #'              transform job.
    #' @param instance_count (int): Number of EC2 instances to use.
    #' @param instance_type (str): Type of EC2 instance to use, for example,
    #'              'ml.c4.xlarge'.
    #' @param strategy (str): The strategy used to decide how to batch records in
    #'              a single request (default: NULL). Valid values: 'MultiRecord'
    #'              and 'SingleRecord'.
    #' @param assemble_with (str): How the output is assembled (default: NULL).
    #'              Valid values: 'Line' or 'None'.
    #' @param output_path (str): S3 location for saving the transform result. If
    #'              not specified, results are stored under
    #'              \code{s3://<default bucket>/<job name>} when the job starts.
    #' @param output_kms_key (str): Optional. KMS key ID for encrypting the
    #'              transform output (default: NULL).
    #' @param accept (str): The accept header passed by the client to
    #'              the inference endpoint. If it is supported by the endpoint,
    #'              it will be the format of the batch transform output.
    #' @param max_concurrent_transforms (int): The maximum number of HTTP requests
    #'              to be made to each individual transform container at one time.
    #' @param max_payload (int): Maximum size of the payload in a single HTTP
    #'              request to the container in MB.
    #' @param tags (list[dict]): List of tags for labeling a transform job
    #'              (default: NULL). For more, see the SageMaker API documentation for
    #'              `Tag <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_.
    #' @param env (dict): Environment variables to be set for use during the
    #'              transform job (default: NULL).
    #' @param base_transform_job_name (str): Prefix for the transform job when the
    #'              \code{transform} method launches. If not specified, a prefix is
    #'              derived from the model's container image, falling back to
    #'              \code{model_name} when no image can be found.
    #' @param sagemaker_session (sagemaker.session.Session): Session object which
    #'              manages interactions with Amazon SageMaker APIs and any other
    #'              AWS services needed. If not specified, a new \code{Session}
    #'              is created with \code{Session$new()}.
    #' @param volume_kms_key (str): Optional. KMS key ID for encrypting the volume
    #'              attached to the ML compute instance (default: NULL).
    initialize = function(model_name,
                          instance_count,
                          instance_type,
                          strategy=NULL,
                          assemble_with=NULL,
                          output_path=NULL,
                          output_kms_key=NULL,
                          accept=NULL,
                          max_concurrent_transforms=NULL,
                          max_payload=NULL,
                          tags=NULL,
                          env=NULL,
                          base_transform_job_name=NULL,
                          sagemaker_session=NULL,
                          volume_kms_key=NULL){
      self$model_name = model_name
      self$strategy = strategy
      self$env = env
      self$output_path = output_path
      self$output_kms_key = output_kms_key
      self$accept = accept
      self$assemble_with = assemble_with
      self$instance_count = instance_count
      self$instance_type = instance_type
      self$volume_kms_key = volume_kms_key
      self$max_concurrent_transforms = max_concurrent_transforms
      self$max_payload = max_payload
      self$tags = tags
      self$base_transform_job_name = base_transform_job_name
      # Job-tracking state; populated by `transform()`.
      self$.current_job_name = NULL
      self$latest_transform_job = NULL
      # Becomes TRUE once a default output path has been generated, so each
      # later job gets its own default path instead of reusing the first one.
      self$.reset_output_path = FALSE
      self$sagemaker_session = sagemaker_session %||% Session$new()
    },

    #' @description Start a new transform job.
    #' @param data (str): Input data location in S3.
    #' @param data_type (str): What the S3 location defines (default: 'S3Prefix').
    #'              Valid values:
    #'              \itemize{
    #'                \item{\strong{'S3Prefix'} - the S3 URI defines a key name prefix. All objects with this prefix
    #'                         will be used as inputs for the transform job.}
    #'                \item{\strong{'ManifestFile'} - the S3 URI points to a single manifest file listing each S3
    #'                         object to use as an input for the transform job.}}
    #' @param content_type (str): MIME type of the input data (default: NULL).
    #' @param compression_type (str): Compression type of the input data, if
    #'              compressed (default: NULL). Valid values: 'Gzip', NULL.
    #' @param split_type (str): The record delimiter for the input object
    #'              (default: NULL). Valid values: 'None', 'Line', 'RecordIO', and
    #'              'TFRecord'.
    #' @param job_name (str): job name (default: NULL). If not specified, one will
    #'              be generated.
    #' @param input_filter (str): A JSONPath to select a portion of the input to
    #'              pass to the algorithm container for inference. If you omit the
    #'              field, it gets the value '$', representing the entire input.
    #'              For CSV data, each row is taken as a JSON array,
    #'              so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
    #'              CSV data should follow the `RFC format <https://tools.ietf.org/html/rfc4180>`_.
    #'              See `Supported JSONPath Operators
    #'              <https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html#data-processing-operators>`_
    #'              for a table of supported JSONPath operators.
    #'              For more information, see the SageMaker API documentation for
    #'              `CreateTransformJob
    #'              <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
    #'              Some examples: "$[1:]", "$.features" (default: NULL).
    #' @param output_filter (str): A JSONPath to select a portion of the
    #'              joined/original output to return as the output.
    #'              For more information, see the SageMaker API documentation for
    #'              `CreateTransformJob
    #'              <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
    #'              Some examples: "$[1:]", "$.prediction" (default: NULL).
    #' @param join_source (str): The source of data to be joined to the transform
    #'              output. It can be set to 'Input' meaning the entire input record
    #'              will be joined to the inference result. You can use OutputFilter
    #'              to select the useful portion before uploading to S3. (default:
    #'              NULL). Valid values: 'Input', 'None'.
    #' @param experiment_config (dict[str, str]): Experiment management configuration.
    #'              Named list with three optional keys,
    #'              'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
    #'              (default: NULL).
    #' @param model_client_config (dict[str, str]): Model configuration.
    #'              Named list with two optional keys,
    #'              'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
    #'              (default: NULL).
    #' @param wait (bool): Whether the call should wait until the job completes
    #'              (default: TRUE).
    #' @param logs (bool): Whether to show the logs produced by the job.
    #'              Only meaningful when wait is TRUE (default: TRUE).
    #' @param ... Other parameters (currently not used)
    #' @return NULL invisible
    transform = function(data,
                         data_type="S3Prefix",
                         content_type=NULL,
                         compression_type=NULL,
                         split_type=NULL,
                         job_name=NULL,
                         input_filter=NULL,
                         output_filter=NULL,
                         join_source=NULL,
                         experiment_config=NULL,
                         model_client_config=NULL,
                         wait=TRUE,
                         logs=TRUE,
                         ...){
      # isTRUE() treats a missing/NULL `local_mode` as "not local" instead of
      # letting `!NULL && ...` raise a length-zero error.
      local_mode = isTRUE(self$sagemaker_session$local_mode)
      if (!local_mode && !grepl("^s3://", data)) {
        ValueError$new(sprintf("Invalid S3 URI: %s", data))
      }

      if (!is.null(job_name)) {
        self$.current_job_name = job_name
      } else {
        base_name = self$base_transform_job_name %||% private$.retrieve_base_name()
        self$.current_job_name = sagemaker.core::name_from_base(base_name)
      }

      # Generate a per-job default output location when none was supplied,
      # and keep regenerating it for subsequent jobs.
      if (is.null(self$output_path) || self$.reset_output_path) {
        self$output_path = sprintf(
          "s3://%s/%s",
          self$sagemaker_session$default_bucket(),
          self$.current_job_name
        )
        self$.reset_output_path = TRUE
      }

      self$latest_transform_job = private$.start_new(
        data,
        data_type,
        content_type,
        compression_type,
        split_type,
        input_filter,
        output_filter,
        join_source,
        experiment_config,
        model_client_config
      )

      if (wait) self$wait(logs=logs)
      return(invisible(NULL))
    },

    #' @description Delete the corresponding SageMaker model for this Transformer.
    delete_model = function(){
      self$sagemaker_session$delete_model(self$model_name)
    },

    #' @description Wait for latest running batch transform job
    #' @param logs (bool): If TRUE (default), wait while showing the job's logs
    #'              via \code{logs_for_transform_job}; otherwise wait silently via
    #'              \code{wait_for_transform_job}.
    wait = function(logs=TRUE){
      private$.ensure_last_transform_job()
      if (logs) {
        self$sagemaker_session$logs_for_transform_job(self$latest_transform_job, wait=TRUE)
      } else {
        self$sagemaker_session$wait_for_transform_job(self$latest_transform_job)
      }
    },

    #' @description Stop latest running batch transform job.
    #' @param wait (bool): Whether to wait (showing logs) after requesting the
    #'              stop (default: TRUE).
    stop_transform_job = function(wait=TRUE){
      private$.ensure_last_transform_job()
      self$sagemaker_session$stop_transform_job(name=self$latest_transform_job)
      if (wait) self$wait()
    },

    #' @description Attach an existing transform job to a new Transformer instance.
    #'              The returned object is a clone of this one, re-initialised from
    #'              the job description; the calling object is not modified.
    #' @param transform_job_name (str): Name for the transform job to be attached.
    #' @param sagemaker_session (sagemaker.session.Session): Session object which
    #'              manages interactions with Amazon SageMaker APIs and any other
    #'              AWS services needed. If not specified, one will be created using
    #'              \code{Session$new()}.
    #' @return Transformer (class): The Transformer instance with the
    #'              specified transform job attached.
    attach = function(transform_job_name,
                      sagemaker_session= NULL){
      sagemaker_session = sagemaker_session %||% Session$new()
      job_details = sagemaker_session$sagemaker$describe_transform_job(
        TransformJobName=transform_job_name
      )
      init_params = private$.prepare_init_params_from_job_description(job_details)
      init_params$sagemaker_session = sagemaker_session
      # R6 has no class methods, so build the new instance from a clone and
      # re-run its initializer with the job's parameters.
      transformer = self$clone()
      do.call(transformer$initialize, init_params)
      transformer$latest_transform_job = init_params$base_transform_job_name
      return(transformer)
    },

    #' @description format class
    format = function(){
      return(format_class(self))
    }
  ),
  private = list(

    # Pick a job-name prefix: derived from the model's container image when one
    # is available, otherwise the model name itself.
    .retrieve_base_name = function(){
      image_uri = private$.retrieve_image_uri()
      # An empty image string is treated like a missing one.
      if (!is.null(image_uri) && nzchar(image_uri)) {
        return(base_name_from_image(image_uri))
      }
      return(self$model_name)
    },

    # Look up the model's image URI: the PrimaryContainer image if present,
    # else the first entry of Containers, else NULL.
    .retrieve_image_uri = function(){
      model_desc = tryCatch(
        self$sagemaker_session$sagemaker$describe_model(ModelName=self$model_name),
        error = function(e){
          ValueError$new(sprintf(
            paste0(
              "Failed to fetch model information for %s. ",
              "Please ensure that the model exists. ",
              "Local instance types require locally created models."
            ),
            self$model_name
          ))
        }
      )
      # length() > 0 rejects both NULL and empty lists, so an empty
      # PrimaryContainer falls through and an empty Containers list is never
      # indexed.
      primary_container = model_desc$PrimaryContainer
      if (length(primary_container) > 0) {
        return(primary_container$Image)
      }
      containers = model_desc$Containers
      if (length(containers) > 0) {
        return(containers[[1]]$Image)
      }
      return(NULL)
    },

    # Signal an error if no transform job has been started or attached yet.
    .ensure_last_transform_job = function(){
      if (is.null(self$latest_transform_job))
        ValueError$new("No transform job available")
    },

    # Convert the transform job description to init params that can be
    # handled by the class constructor.
    # Args:
    #   job_details (list): the returned job details from a
    #     describe_transform_job API call.
    # Returns:
    #   list: The transformed init_params. Fields absent from the description
    #   are dropped (assigning NULL via `[[<-` removes the entry), so the
    #   constructor's NULL defaults apply.
    .prepare_init_params_from_job_description = function(job_details){
      init_params = list()
      init_params[["model_name"]] = job_details$ModelName
      init_params[["instance_count"]] = job_details$TransformResources$InstanceCount
      init_params[["instance_type"]] = job_details$TransformResources$InstanceType
      init_params[["volume_kms_key"]] = job_details$TransformResources$VolumeKmsKeyId
      init_params[["strategy"]] = job_details$BatchStrategy
      init_params[["assemble_with"]] = job_details$TransformOutput$AssembleWith
      init_params[["output_path"]] = job_details$TransformOutput$S3OutputPath
      init_params[["output_kms_key"]] = job_details$TransformOutput$KmsKeyId
      init_params[["accept"]] = job_details$TransformOutput$Accept
      init_params[["max_concurrent_transforms"]] = job_details$MaxConcurrentTransforms
      init_params[["max_payload"]] = job_details$MaxPayloadInMB
      init_params[["base_transform_job_name"]] = job_details$TransformJobName
      return(init_params)
    },

    # Submit the transform job through the session and return its name, which
    # is what `latest_transform_job` stores.
    .start_new = function(data,
                          data_type,
                          content_type,
                          compression_type,
                          split_type,
                          input_filter,
                          output_filter,
                          join_source,
                          experiment_config,
                          model_client_config){
      transform_args = private$.get_transform_args(
        data,
        data_type,
        content_type,
        compression_type,
        split_type,
        input_filter,
        output_filter,
        join_source,
        experiment_config,
        model_client_config
      )
      do.call(self$sagemaker_session$transform, transform_args)
      return(self$.current_job_name)
    },

    # Assemble the named argument list for `sagemaker_session$transform`.
    # modifyList() drops NULL-valued entries, so unset options fall back to
    # the session method's own defaults.
    .get_transform_args = function(data,
                                   data_type,
                                   content_type,
                                   compression_type,
                                   split_type,
                                   input_filter,
                                   output_filter,
                                   join_source,
                                   experiment_config,
                                   model_client_config){
      config = private$.load_config(
        data, data_type, content_type, compression_type, split_type
      )
      data_processing = private$.prepare_data_processing(
        input_filter, output_filter, join_source
      )
      transform_args = modifyList(
        config,
        list(
          "job_name"=self$.current_job_name,
          "model_name"=self$model_name,
          "strategy"=self$strategy,
          "max_concurrent_transforms"=self$max_concurrent_transforms,
          "max_payload"=self$max_payload,
          "env"=self$env,
          "experiment_config"=experiment_config,
          "model_client_config"=model_client_config,
          "tags"=self$tags,
          "data_processing"=data_processing
        )
      )
      return(transform_args)
    },

    # Build the input, output and resource configuration sub-lists.
    .load_config = function(data,
                            data_type,
                            content_type,
                            compression_type,
                            split_type){
      input_config = private$.format_inputs_to_input_config(
        data, data_type, content_type, compression_type, split_type)
      output_config = private$.prepare_output_config(
        self$output_path,
        self$output_kms_key,
        self$assemble_with,
        self$accept
      )
      resource_config = private$.prepare_resource_config(
        self$instance_count, self$instance_type, self$volume_kms_key
      )
      return(list(
        "input_config"= input_config,
        "output_config"= output_config,
        "resource_config"= resource_config)
      )
    },

    # Input configuration: S3 data source plus any optional content,
    # compression and split settings that were supplied.
    .format_inputs_to_input_config = function(data,
                                              data_type,
                                              content_type = NULL,
                                              compression_type = NULL,
                                              split_type = NULL){
      config = list("DataSource"=list("S3DataSource"=list("S3DataType"= data_type, "S3Uri"= data)))
      if (!is.null(content_type))
        config[["ContentType"]] = content_type
      if (!is.null(compression_type))
        config[["CompressionType"]] = compression_type
      if (!is.null(split_type))
        config[["SplitType"]] = split_type
      return(config)
    },

    # Output configuration: S3 output path plus optional KMS key, assembly
    # mode and accept type.
    .prepare_output_config = function(s3_path,
                                      kms_key_id,
                                      assemble_with = NULL,
                                      accept = NULL){
      config = list("S3OutputPath"= s3_path)
      if (!is.null(kms_key_id))
        config[["KmsKeyId"]] = kms_key_id
      if (!is.null(assemble_with))
        config[["AssembleWith"]] = assemble_with
      if (!is.null(accept))
        config[["Accept"]] = accept
      return(config)
    },

    # Resource configuration: instance count/type plus optional volume KMS key.
    .prepare_resource_config = function(instance_count,
                                        instance_type,
                                        volume_kms_key = NULL){
      config = list("InstanceCount"= instance_count, "InstanceType"= instance_type)
      if (!is.null(volume_kms_key))
        config[["VolumeKmsKeyId"]] = volume_kms_key
      return(config)
    },

    # Data-processing configuration, or NULL when no filter/join is requested.
    .prepare_data_processing = function(input_filter = NULL,
                                        output_filter = NULL,
                                        join_source = NULL){
      config = list()
      if (!is.null(input_filter))
        config[["InputFilter"]] = input_filter
      if (!is.null(output_filter))
        config[["OutputFilter"]] = output_filter
      if (!is.null(join_source))
        config[["JoinSource"]] = join_source
      if (length(config) == 0)
        return(NULL)
      return(config)
    }
  ),
  lock_objects = FALSE
)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.