# NOTE: This code has been modified from AWS Stepfunctions Python:
# https://github.com/aws/aws-step-functions-data-science-sdk-python/blob/main/src/stepfunctions/template/pipeline/train.py
#' @import R6
#' @include steps_sagemaker.R
#' @include steps_states.R
#' @include workflow_stepfunctions.R
#' @include template_pipeline_common.R
#' @title TrainingPipeline for SageMaker
#' @description Creates a standard training pipeline with the following steps in order:
#' \itemize{
#' \item{Train estimator}
#' \item{Create estimator model}
#' \item{Endpoint configuration}
#' \item{Deploy model}
#' }
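#' @examples
#' \dontrun{
#' # A minimal usage sketch. The estimator object, role ARN, S3 locations and
#' # bucket name below are illustrative placeholders, not values from this
#' # package; any SageMaker estimator exposing `instance_type`,
#' # `instance_count` and `create_model()` is assumed to work here.
#' pipeline = TrainingPipeline$new(
#'   estimator=my_estimator,
#'   role='arn:aws:iam::123456789012:role/StepFunctionsWorkflowExecutionRole',
#'   inputs='s3://my-bucket/path/to/training-data',
#'   s3_bucket='my-bucket'
#' )
#' # Register the state machine, then start an execution; `create()` is assumed
#' # to be provided by the parent WorkflowTemplate class.
#' pipeline$create()
#' execution = pipeline$execute(job_name='my-training-job')
#' }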
#' @export
TrainingPipeline = R6Class("TrainingPipeline",
inherit = WorkflowTemplate,
public = list(
#' @description Initialize TrainingPipeline class
#' @param estimator (sagemaker.estimator.EstimatorBase): The estimator to use
#' for training. Can be a BYO estimator, Framework estimator or Amazon
#' algorithm estimator.
#' @param role (str): An AWS IAM role (either name or full Amazon Resource Name (ARN)).
#' This role is used to create, manage, and execute the Step Functions workflows.
#' @param inputs : Information about the training data. Please refer to the `fit()`
#' method of the associated estimator, as this can take any of the following forms:
#' \itemize{
#' \item{(str) - The S3 location where training data is saved.}
#' \item{(list[str, str] or list[str, `sagemaker.inputs.TrainingInput`]) - If
#' using multiple channels for training data, you can specify a list mapping
#' channel names to strings or `sagemaker.inputs.TrainingInput` objects.}
#' \item{(`sagemaker.inputs.TrainingInput`) - Channel configuration for S3 data
#' sources that can provide additional information about the training dataset.
#' See `sagemaker.inputs.TrainingInput` for full details.}
#' \item{(`sagemaker.amazon.amazon_estimator.RecordSet`) - A collection of Amazon
#' `Record` objects serialized and stored in S3. For use with an estimator
#' for an Amazon algorithm.}
#' \item{(list[`sagemaker.amazon.amazon_estimator.RecordSet`]) - A list of
#' `sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance
#' is a different channel of training data.}
#' }
#' @param s3_bucket (str): S3 bucket under which the output artifacts from
#' the training job will be stored. The parent path used is built
#' using the format: ``s3://{s3_bucket}/{pipeline_name}/models/{job_name}/``.
#' In this format, `pipeline_name` refers to the keyword argument provided
#' for TrainingPipeline. If a `pipeline_name` argument was not provided,
#' one is auto-generated by the pipeline as `training-pipeline-<timestamp>`.
#' Also, in this format, `job_name` refers to the job name provided
#' when calling the :meth:`TrainingPipeline$execute()` method.
#' @param client (SFN.Client, optional): \code{\link[paws]{sfn}} client to use for creating and
#' interacting with the training pipeline in Step Functions. (default: NULL)
#' @param pipeline_name (str, optional): Name of the pipeline. This name will
#' be used to name jobs (if not provided when calling execute()),
#' models, endpoints, and S3 objects created by the pipeline. If
#' a `pipeline_name` argument was not provided, one is auto-generated
#' by the pipeline as `training-pipeline-<timestamp>`. (default: NULL)
initialize = function(estimator,
role,
inputs,
s3_bucket,
client=NULL,
pipeline_name=NULL){
self$estimator = estimator
self$inputs = inputs
self$pipeline_name = pipeline_name
if (is.null(self$pipeline_name)){
private$.pipeline_name_unique = TRUE
self$pipeline_name = sprintf('training-pipeline-%s', private$.generate_timestamp())
}
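# Build the state-machine definition once, then keep a per-step input template
# that execute() fills in for each run.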
self$definition = self$build_workflow_definition()
self$input_template = private$.extract_input_template(self$definition)
workflow = Workflow$new(
name=self$pipeline_name,
definition=self$definition,
role=role,
format_json=TRUE,
client=client)
super$initialize(s3_bucket=s3_bucket, workflow=workflow, role=role, client=client)
},
#' @description Build the workflow definition for the training pipeline with
#' all the states involved.
#' @return :class:`~stepfunctions.steps.states.Chain`: Workflow definition as
#' a chain of states involved in the training pipeline.
build_workflow_definition = function(){
default_name = self$pipeline_name
instance_type = self$estimator$instance_type
instance_count = self$estimator$instance_count
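# Step 1: train the estimator on the supplied input data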
training_step = TrainingStep$new(
StepId$Train,
estimator=self$estimator,
job_name=paste0(default_name,'/estimator-source'),
data=self$inputs
)
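# Step 2: create a SageMaker model from the trained estimator's artifacts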
model = self$estimator$create_model()
model_step = ModelStep$new(
StepId$CreateModel,
instance_type=instance_type,
model=model,
model_name=default_name
)
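# Step 3: create an endpoint configuration that points at the model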
endpoint_config_step = EndpointConfigStep$new(
StepId$ConfigureEndpoint,
endpoint_config_name=default_name,
model_name=default_name,
initial_instance_count=instance_count,
instance_type=instance_type
)
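# Step 4: deploy the model to a real-time endpoint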
deploy_step = EndpointStep$new(
StepId$Deploy,
endpoint_name=default_name,
endpoint_config_name=default_name
)
return(Chain$new(list(training_step, model_step, endpoint_config_step, deploy_step)))
},
#' @description Run the training pipeline.
#' @param job_name (str, optional): Name for the training job. If one is not
#' provided, a job name will be auto-generated. (default: NULL)
#' @param hyperparameters (list, optional): Hyperparameters for the estimator
#' training. (default: NULL)
#' @return :class:`~stepfunctions.workflow.Execution`: Running instance of
#' the training pipeline.
execute = function(job_name=NULL,
hyperparameters=NULL){
inputs = self$input_template
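# If hyperparameters were supplied, override the Train step's defaults;
# SageMaker expects hyperparameter values as strings.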
if (!is.null(hyperparameters))
inputs[[StepId$Train]][['HyperParameters']] = lapply(
hyperparameters, function(v) as.character(v))
if (is.null(job_name))
job_name = sprintf('%s-%s','training-pipeline', private$.generate_timestamp())
# Configure training and model
inputs[[StepId$Train]][['TrainingJobName']] = paste0('estimator-', job_name)
inputs[[StepId$Train]][['OutputDataConfig']][['S3OutputPath']] = (
sprintf('s3://%s/%s/models',
self$s3_bucket,
self$workflow$name)
)
inputs[[StepId$CreateModel]][['ModelName']] = job_name
# Configure endpoint
inputs[[StepId$ConfigureEndpoint]][['EndpointConfigName']] = job_name
for (i in seq_along(inputs[[StepId$ConfigureEndpoint]][['ProductionVariants']])){
inputs[[StepId$ConfigureEndpoint]][['ProductionVariants']][[i]][['ModelName']] = job_name}
inputs[[StepId$Deploy]][['EndpointConfigName']] = job_name
inputs[[StepId$Deploy]][['EndpointName']] = job_name
# Configure the path to model artifact
inputs[[StepId$CreateModel]][['PrimaryContainer']][['ModelDataUrl']] = sprintf('%s/%s/output/model.tar.gz',
inputs[[StepId$Train]][['OutputDataConfig']][['S3OutputPath']],
inputs[[StepId$Train]][['TrainingJobName']]
)
return(self$workflow$execute(inputs=inputs, name=job_name))
}
),
private = list(
.allowed_kwargs = 'pipeline_name'
),
lock_objects=FALSE
)