R/template_pipeline_common.R

# NOTE: This code has been modified from the AWS Step Functions Data Science SDK for Python:
# https://github.com/aws/aws-step-functions-data-science-sdk-python/blob/main/src/stepfunctions/template/pipeline/common.py

#' @import R6

#' @include steps_states.R
#' @include template_utils.R
#' @include utils.R

StepId = Enum(
  Train             = 'Training',
  CreateModel       = 'Create Model',
  ConfigureEndpoint = 'Configure Endpoint',
  Deploy            = 'Deploy',

  TrainPreprocessor       = 'Train Preprocessor',
  CreatePreprocessorModel = 'Create Preprocessor Model',
  TransformInput          = 'Transform Input',
  CreatePipelineModel     = 'Create Pipeline Model'
)
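
# For reference, a minimal sketch of how these identifiers are read, assuming
# Enum() (from the included utils files) simply returns its named arguments as
# a list-like object:
#   StepId$Train                # "Training"
#   StepId$CreatePipelineModel  # "Create Pipeline Model"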

#' @title WorkflowTemplate Class
#' @description Abstract base class used to create templates for SageMaker workflows.
WorkflowTemplate = R6Class("WorkflowTemplate",
  public = list(

    #' @param s3_bucket (str): S3 bucket under which the output artifacts from
    #'              the training job will be stored. The parent path used is built
    #'              using the format: ``s3://{s3_bucket}/{pipeline_name}/models/{job_name}/``.
    #'              In this format, `pipeline_name` refers to the keyword argument provided
    #'              for TrainingPipeline. If a `pipeline_name` argument was not provided,
    #'              one is auto-generated by the pipeline as `training-pipeline-<timestamp>`.
    #'              Also, in the format, `job_name` refers to the job name provided
    #'              when calling the `TrainingPipeline$execute()` method (see the
    #'              commented sketch below the constructor for an example of the
    #'              resulting path).
    #' @param workflow : A \code{Workflow} object representing the AWS Step Functions state machine.
    #' @param role (str): An AWS IAM role (either name or full Amazon Resource Name (ARN)).
    #'              This role is used to create, manage, and execute the Step Functions workflows.
    #' @param client (SFN.Client, optional): \code{\link[paws]{sfn}} client to use for attaching the existing
    #'              workflow in Step Functions to the Workflow object. If not provided,
    #'              a default \code{\link[paws]{sfn}} client for Step Functions will be automatically
    #'              created and used. (default: NULL)
    #' @param ... : currently not implemented
    initialize = function(s3_bucket, workflow, role, client, ...){
      self$workflow = workflow
      self$role = role
      self$s3_bucket = s3_bucket
    },
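
    # A minimal sketch of the parent output path described for `s3_bucket`
    # above, using hypothetical pipeline_name and job_name values (both are
    # supplied by the concrete pipeline classes, not by this base class):
    #   s3_bucket     = "my-bucket"                              # hypothetical
    #   pipeline_name = "training-pipeline-2021-12-17-17-31-00"  # hypothetical
    #   job_name      = "my-training-job"                        # hypothetical
    #   sprintf("s3://%s/%s/models/%s/", s3_bucket, pipeline_name, job_name)
    #   #> "s3://my-bucket/training-pipeline-2021-12-17-17-31-00/models/my-training-job/"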

    #' @description Renders a visualization of the workflow graph.
    #' @param portrait (logical, optional): Set to `TRUE` if the workflow
    #'              graph should be rendered in portrait orientation, or `FALSE`
    #'              if it should be rendered in landscape orientation. (default: FALSE)
    render_graph = function(portrait=FALSE){
      return(self$workflow$render_graph(portrait=portrait))
    },

    #' @description Returns Workflow
    get_workflow = function(){
      return(self$workflow)
    },

    #' @description Build the workflow definition for the inference pipeline with
    #'              all the states involved.
    #' @return \code{Chain}: Workflow definition as a chain of the states
    #'              involved in the inference pipeline.
    build_workflow_definition = function(){
        stop("Not Implemented")
    },

    #' @description Creates the workflow on Step Functions.
    #' @return str: The Amazon Resource Name (ARN) of the workflow created. If the workflow
    #'              already existed, the ARN of the existing workflow is returned.
    create = function(){
      return(self$workflow$create())
    },

    #' @description Run the inference pipeline.
    #' @param ... : Not yet implemented
    #' @return \code{Execution}: Running instance of the inference pipeline.
    execute = function(...){
      stop("Not Implemented")
    },

    #' @description Format the class for printing.
    format = function(){
      cls_fmt = "%s(s3_bucket='%s', workflow='%s', role='%s')"
      # `s3_bucket` is the field set in `initialize()`; `self$bucket` does not exist.
      return(sprintf(cls_fmt, class(self)[1], self$s3_bucket,
             self$workflow, self$role))
    }
  ),
  private = list(
    .generate_timestamp = function(){
      return(strftime(Sys.time(),'%Y-%m-%d-%H-%M-%S'))
    },

    .extract_input_template = function(definition){
      input_template = list()

      # Record the original parameters of every Task state (keyed by state id)
      # so they can be supplied as execution input, then point each step's
      # parameters at the context object instead.
      for (step in definition$steps){
        if (inherits(step, "Task")){
          input_template[[step$state_id]] = step$parameters
          step$update_parameters(replace_parameters_with_context_object(step))
        }
      }
      return(input_template)
    }
  ),
  lock_objects = FALSE
)
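
# A rough usage sketch for a concrete subclass such as TrainingPipeline. The
# constructor arguments shown follow the upstream Python SDK and are assumed to
# match the R port; consult the TrainingPipeline documentation before relying
# on them:
#   pipeline = TrainingPipeline$new(
#     estimator = estimator,   # a SageMaker estimator object (assumed to exist)
#     role = "StepFunctionsWorkflowExecutionRole",
#     inputs = inputs,         # training input data (assumed to exist)
#     s3_bucket = "my-bucket"
#   )
#   pipeline$render_graph()
#   pipeline$create()
#   execution = pipeline$execute()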