R/template_pipeline_train.R

# NOTE: This code has been modified from the AWS Step Functions Data Science SDK for Python:
# https://github.com/aws/aws-step-functions-data-science-sdk-python/blob/main/src/stepfunctions/template/pipeline/train.py

#' @import R6

#' @include steps_sagemaker.R
#' @include steps_states.R
#' @include workflow_stepfunctions.R
#' @include template_pipeline_common.R

#' @title TrainingPipeline for SageMaker
#' @description Creates a standard training pipeline with the following steps in order:
#' \itemize{
#'   \item{Train estimator}
#'   \item{Create estimator model}
#'   \item{Endpoint configuration}
#'   \item{Deploy model}
#' }
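#'
#' @examples
#' \dontrun{
#' # Minimal usage sketch. `est` is assumed to be a previously configured
#' # sagemaker estimator; the role ARN and bucket names are placeholders.
#' pipeline = TrainingPipeline$new(
#'   estimator=est,
#'   role="arn:aws:iam::012345678901:role/StepFunctionsWorkflowExecutionRole",
#'   inputs="s3://my-bucket/my-training-data/",
#'   s3_bucket="my-bucket"
#' )
#' # create() is assumed to be inherited from WorkflowTemplate; it registers
#' # the state machine in Step Functions before the pipeline is run.
#' pipeline$create()
#' execution = pipeline$execute()
#' }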
#' @export
TrainingPipeline = R6Class("TrainingPipeline",
  inherit = WorkflowTemplate,
  public = list(

    #' @description Initialize TrainingPipeline class
    #' @param estimator (sagemaker.estimator.EstimatorBase): The estimator to use
    #'              for training. Can be a BYO estimator, Framework estimator or Amazon
    #'              algorithm estimator.
    #' @param role (str): An AWS IAM role (either name or full Amazon Resource Name (ARN)).
    #'              This role is used to create, manage, and execute the Step Functions workflows.
    #' @param inputs : Information about the training data. Please refer to the `fit()`
    #'              method of the associated estimator, as this can take any of the following forms:
    #' \itemize{
    #'     \item{(str) - The S3 location where training data is saved.}
    #'     \item{(named list of str or `sagemaker.inputs.TrainingInput`) - If
    #'           using multiple channels for training data, you can specify a named list
    #'           mapping channel names to strings or `sagemaker.inputs.TrainingInput`
    #'           objects (see the example below).}
    #'     \item{(`sagemaker.inputs.TrainingInput`) - Channel configuration for S3 data
    #'           sources that can provide additional information about the training dataset.
    #'          See `sagemaker.inputs.TrainingInput` for full details.}
    #'     \item{(`sagemaker.amazon.amazon_estimator.RecordSet`) - A collection of Amazon
    #'          `Record` objects serialized and stored in S3. For use with an estimator
    #'          for an Amazon algorithm.}
    #'     \item{(list[`sagemaker.amazon.amazon_estimator.RecordSet`]) - A list of
    #'          `sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance
    #'          is a different channel of training data.}
    #' }
    #' @param s3_bucket (str): S3 bucket under which the output artifacts from
    #'              the training job will be stored. The parent path used is built
    #'              using the format: ``s3://{s3_bucket}/{pipeline_name}/models/{job_name}/``.
    #'              In this format, `pipeline_name` refers to the keyword argument provided
    #'              for TrainingPipeline. If a `pipeline_name` argument was not provided,
    #'              one is auto-generated by the pipeline as `training-pipeline-<timestamp>`.
    #'              Also, in the format, `job_name` refers to the job name provided
    #'              when calling the `TrainingPipeline$execute()` method.
    #' @param client (SFN.Client, optional): \code{\link[paws]{sfn}} client to use for creating and
    #'              interacting with the training pipeline in Step Functions. (default: NULL)
    #' @param pipeline_name (str, optional): Name of the pipeline. This name will
    #'              be used to name jobs (if not provided when calling execute()),
    #'              models, endpoints, and S3 objects created by the pipeline. If
    #'              a `pipeline_name` argument was not provided, one is auto-generated
    #'              by the pipeline as `training-pipeline-<timestamp>`. (default: NULL)
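    #' @examples
    #' \dontrun{
    #' # Hedged sketch of the multi-channel `inputs` form; `est` and the S3
    #' # prefixes are placeholders. A named list maps channel names to data.
    #' pipeline = TrainingPipeline$new(
    #'   estimator=est,
    #'   role="StepFunctionsWorkflowExecutionRole",
    #'   inputs=list(
    #'     train="s3://my-bucket/train/",
    #'     validation="s3://my-bucket/validation/"
    #'   ),
    #'   s3_bucket="my-bucket"
    #' )
    #' }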
    initialize = function(estimator,
                          role,
                          inputs,
                          s3_bucket,
                          client=NULL,
                          pipeline_name=NULL){
      self$estimator = estimator
      self$inputs = inputs
      self$pipeline_name = pipeline_name

      if (is.null(self$pipeline_name)){
        private$.pipeline_name_unique = TRUE
        self$pipeline_name = sprintf('training-pipeline-%s', private$.generate_timestamp())
      }
      self$definition = self$build_workflow_definition()
      self$input_template = private$.extract_input_template(self$definition)

      workflow = Workflow$new(
        name=self$pipeline_name,
        definition=self$definition,
        role=role,
        format_json=TRUE,
        client=client)

      super$initialize(s3_bucket=s3_bucket, workflow=workflow, role=role, client=client)
    },

    #' @description Build the workflow definition for the training pipeline with
    #'              all the states involved.
    #' @return `Chain`: Workflow definition as a chain of states involved in
    #'              the training pipeline.
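    #' @examples
    #' \dontrun{
    #' # Hedged sketch: `pipeline` is assumed to be an existing TrainingPipeline.
    #' # The definition is rebuilt as a Chain of the four pipeline states.
    #' defn = pipeline$build_workflow_definition()
    #' }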
    build_workflow_definition = function(){
      default_name = self$pipeline_name

      instance_type = self$estimator$instance_type
      instance_count = self$estimator$instance_count

      training_step = TrainingStep$new(
        StepId$Train,
        estimator=self$estimator,
        job_name=paste0(default_name,'/estimator-source'),
        data=self$inputs
      )

      model = self$estimator$create_model()
      model_step = ModelStep$new(
        StepId$CreateModel,
        instance_type=instance_type,
        model=model,
        model_name=default_name
      )

      endpoint_config_step = EndpointConfigStep$new(
        StepId$ConfigureEndpoint,
        endpoint_config_name=default_name,
        model_name=default_name,
        initial_instance_count=instance_count,
        instance_type=instance_type
      )
      deploy_step = EndpointStep$new(
        StepId$Deploy,
        endpoint_name=default_name,
        endpoint_config_name=default_name
      )
      return(Chain$new(list(training_step, model_step, endpoint_config_step, deploy_step)))
    },

    #' @description Run the training pipeline.
    #' @param job_name (str, optional): Name for the training job. If one is not
    #'              provided, a job name will be auto-generated. (default: NULL)
    #' @param hyperparameters (list, optional): Hyperparameters for the estimator
    #'              training. (default: NULL)
    #' @return `Execution`: Running instance of the training pipeline.
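    #' @examples
    #' \dontrun{
    #' # Hedged sketch: `pipeline` is assumed to be a created TrainingPipeline.
    #' # Hyperparameter values are coerced to strings before being passed on.
    #' execution = pipeline$execute(
    #'   job_name="my-training-job",
    #'   hyperparameters=list(epochs=10, learning_rate=0.01)
    #' )
    #' }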
    execute = function(job_name=NULL,
                       hyperparameters=NULL){
      inputs = self$input_template

      if (!is.null(hyperparameters)) {
        # SageMaker expects hyperparameter values to be passed as strings
        inputs[[StepId$Train]][['HyperParameters']] = lapply(
          hyperparameters, function(v) as.character(v))
      }

      if (is.null(job_name))
        job_name = sprintf('%s-%s','training-pipeline', private$.generate_timestamp())

      # Configure training and model
      inputs[[StepId$Train]][['TrainingJobName']] = paste0('estimator-', job_name)
      inputs[[StepId$Train]][['OutputDataConfig']][['S3OutputPath']] = sprintf(
        's3://%s/%s/models', self$s3_bucket, self$workflow$name)
      inputs[[StepId$CreateModel]][['ModelName']] = job_name

      # Configure endpoint
      inputs[[StepId$ConfigureEndpoint]][['EndpointConfigName']] = job_name
      # Update each production variant in place; assigning to the loop variable
      # alone would not modify the list because of R's copy semantics.
      for (i in seq_along(inputs[[StepId$ConfigureEndpoint]][['ProductionVariants']])){
        inputs[[StepId$ConfigureEndpoint]][['ProductionVariants']][[i]][['ModelName']] = job_name
      }
      inputs[[StepId$Deploy]][['EndpointConfigName']] = job_name
      inputs[[StepId$Deploy]][['EndpointName']] = job_name

      # Configure the path to model artifact
      inputs[[StepId$CreateModel]][['PrimaryContainer']][['ModelDataUrl']] = sprintf('%s/%s/output/model.tar.gz',
        inputs[[StepId$Train]][['OutputDataConfig']][['S3OutputPath']],
        inputs[[StepId$Train]][['TrainingJobName']]
      )
      return(self$workflow$execute(inputs=inputs, name=job_name))
    }

  ),
  private = list(
    .allowed_kwargs = 'pipeline_name'
  ),
  lock_objects=FALSE
)