R/job.R

# NOTE: This code has been modified from AWS Sagemaker Python:
# https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/job.py

#' @include r_utils.R

#' @import R6
#' @import sagemaker.core

#' @title .Job Class
#' @description Handle creating, starting and waiting for Amazon SageMaker jobs to
#'              finish.
#' @note This class shouldn't be directly instantiated.
#'              Subclasses must define a way to create, start and wait for an Amazon
#'              SageMaker job.
#' @keywords internal
#' @export
.Job = R6Class(".Job",
  public = list(

    #' @field sagemaker_session
    #' Sagemaker Session Class
    sagemaker_session = NULL,

    #' @field job_name
    #' name of job
    job_name = NULL,

    #' @description Base class initializer. Subclasses which override ``__init__`` should
    #'              invoke ``super()``
    #' @param sagemaker_session (sagemaker.session.Session): Session object which
    #'              manages interactions with Amazon SageMaker APIs and any other
    #'              AWS services needed. If not specified, the estimator creates one
    #'              using the default AWS configuration chain.
    #' @param job_name (str): Prefix for training job name
    initialize = function(sagemaker_session = NULL,
                          job_name = NULL){
      self$sagemaker_session = sagemaker_session
      self$job_name = job_name
    },

    #' @description Create a new Amazon SageMaker job from the estimator.
    #' @param estimator (sagemaker.estimator.EstimatorBase): Estimator object
    #'              created by the user.
    #' @param inputs (str): Parameters used when called
    #'              :meth:`~sagemaker.estimator.EstimatorBase.fit`.
    #' @return sagemaker.job: Constructed object that captures all information
    #'              about the started job.
    start_new = function(estimator,
                         inputs){
      NotImplementedError$new("I'm an abstract interface method")
    },

    #' @description Wait for the Amazon SageMaker job to finish.
    wait = function(){NotImplementedError$new("I'm an abstract interface method")},

    #' @description Describe the job.
    describe = function(){NotImplementedError$new("I'm an abstract interface method")},

    #' @description Stop the job.
    stop = function(){NotImplementedError$new("I'm an abstract interface method")},

    #' @description format class
    format = function(){
      return(format_class(self))
    }
  ),
  private = list(
    .load_config = function(inputs,
                           estimator,
                           expand_role=TRUE,
                           validate_uri=TRUE){
      input_config = private$.format_inputs_to_input_config(inputs, validate_uri)
      role = if (expand_role) estimator$sagemaker_session$expand_role(estimator$role) else estimator$role


      output_config = private$.prepare_output_config(estimator$output_path, estimator$output_kms_key)

      resource_config = private$.prepare_resource_config(
        estimator$instance_count,
        estimator$instance_type,
        estimator$volume_size,
        estimator$volume_kms_key)

      stop_condition = private$.prepare_stop_condition(
        estimator$max_run,
        estimator$max_wait)

      vpc_config = estimator$get_vpc_config()

      model_channel = private$.prepare_channel(
        input_config,
        estimator$model_uri,
        estimator$model_channel_name,
        validate_uri,
        content_type="application/x-sagemaker-model",
        input_mode="File")

      if (!islistempty(model_channel)){
        input_config = if (is.null(input_config)) list() else input_config
        input_config = list.append(input_config, model_channel)
      }
      if (estimator$enable_network_isolation()){
        code_channel = private$.prepare_channel(input_config, estimator$code_uri, estimator$code_channel_name, validate_uri)
        if (!islistempty(code_channel)) {
          input_config = if (is.null(input_config)) list() else input_config
          input_config = list.append(input_config, code_channel)}
      }
      return (list(
        "input_config"= input_config,
        "role"= role,
        "output_config"= output_config,
        "resource_config"= resource_config,
        "stop_condition"= stop_condition,
        "vpc_config"= vpc_config)
      )
    },

    .format_inputs_to_input_config = function(inputs = NULL,
                                              validate_uri=TRUE){
      if (is.null(inputs))
        return(NULL)

      if (inherits(inputs, c("RecordSet", "FileSystemRecordSet")))
        inputs = inputs$data_channel()

      input_dict = list()
      if (inherits(inputs, "character")){
        input_dict[["training"]] = private$.format_string_uri_input(inputs, validate_uri)
      } else if (inherits(inputs, "TrainingInput")){
        input_dict[["training"]] = inputs
      } else if (inherits(inputs, "file_input")){
        input_dict[["training"]] = inputs
      } else if (is_list_named(inputs)){
        for (v in 1:length(inputs)){
          input_dict[[names(inputs)[v]]] = private$.format_string_uri_input(inputs[[v]], validate_uri)
        }
      } else if (is.list(inputs)){
        input_dict = private$.format_record_set_list_input(inputs)
      } else if (inherits(inputs, "FileSystemInput")) {
        input_dict[["training"]] = inputs
      } else {
        msg = "Cannot format input %s. Expecting one of str, dict, TrainingInput or FileSystemInput"
        ValueError$new(sprintf(msg, inputs))
      }
      channels = lapply(names(input_dict), function(x) private$.convert_input_to_channel(x, input_dict[[x]]))
      return(channels)
    },

    .convert_input_to_channel = function(channel_name,
                                         channel_s3_input){
      channel_config = channel_s3_input$config
      channel_config[["ChannelName"]] = channel_name
      return(channel_config)
    },

    .prepare_output_config = function(s3_path, kms_key_id){
      config = list("S3OutputPath"= s3_path)

      if (!is.null(kms_key_id) || length(kms_key_id) == 0)
        config[["KmsKeyId"]] = kms_key_id
      return(config)
    },

    .format_string_uri_input = function(
      uri_input,
      validate_uri=TRUE,
      content_type=NULL,
      input_mode=NULL,
      compression=NULL,
      target_attribute_name=NULL){
      if (is.character(uri_input) && validate_uri && startsWith(uri_input,"s3://")){
        s3_input_result = TrainingInput$new(
          uri_input,
          content_type=content_type,
          input_mode=input_mode,
          compression=compression,
          target_attribute_name=target_attribute_name)
        return(s3_input_result)
      }
      if (is.character(uri_input) && validate_uri && startsWith(uri_input,"file://")){
        file_input = pkg_method("file_input", "sagemaker.core")
        return (file_input$new(uri_input))
      }
      if (inherits(uri_input, "character") && validate_uri)
        ValueError$new(sprintf(
          'URI input %s must be a valid S3 or FILE URI: must start with "s3://" or "file://"',
          uri_input)
        )
      if (inherits(uri_input, "character")){
          s3_input_result = TrainingInput$new(
            uri_input,
            content_type=content_type,
            input_mode=input_mode,
            compression=compression,
            target_attribute_name=target_attribute_name)
        return(s3_input_result)}
      if (inherits(uri_input, c("TrainingInput", "file_input", "FileSystemInput")))
          return(uri_input)

      ValueError$new(sprintf(
        "Cannot format input %s. Expecting one of str, TrainingInput, file_input or FileSystemInput",
        uri_input)
      )
    },

    .prepare_channel = function(input_config,
                                channel_uri=NULL,
                                channel_name=NULL,
                                validate_uri=TRUE,
                                content_type=NULL,
                                input_mode=NULL){
      if (is.null(channel_uri))
        return(NULL)
      if (is.null(channel_name))
        ValueError$new(sprintf(
          "Expected a channel name if a channel URI %s is specified",channel_uri)
        )

      if (!is.null(input_config) || length(input_config) > 0){
        for (existing_channel in input_config){
          if (existing_channel[["ChannelName"]] == channel_name)
            ValueError$new(sprintf("Duplicate channel %s not allowed.",channel_name))
        }
      }

      channel_input = private$.format_string_uri_input(channel_uri, validate_uri, content_type, input_mode)
      channel = private$.convert_input_to_channel(channel_name, channel_input)

      return(channel)
    },

    .format_record_set_list_input = function(inputs){
      # Deferred import due to circular dependency
      #from sagemaker.amazon.amazon_estimator import FileSystemRecordSet, RecordSet

      input_dict = list()
      for (record in inputs){
        if (!inherits(record, c("RecordSet", "FileSystemRecordSet")))
          ValueError$new("List compatible only with RecordSets or FileSystemRecordSets.")

        if (record$channel %in% names(input_dict))
          ValueError$new("Duplicate channels not allowed.")
        if (inherits(record, "RecordSet"))
          input_dict[[record$channel]] = record$records_s3_input()
        if (inherits(record, "FileSystemRecordSet"))
          input_dict[[record$channel]] = record$file_system_input
      }
      return (input_dict)
    },

    .prepare_resource_config = function(instance_count,
                                        instance_type,
                                        volume_size,
                                        volume_kms_key = NULL){
      resource_config = list("InstanceCount"= instance_count,
                             "InstanceType"= instance_type,
                             "VolumeSizeInGB"= volume_size)

      if (!is.null(volume_kms_key) || length(volume_kms_key) == 0)
        resource_config[["VolumeKmsKeyId"]] = volume_kms_key

        return(resource_config)
    },

    .prepare_stop_condition = function(max_run,
                                       max_wait){
      if (!is.null(max_wait))
        return(list("MaxRuntimeInSeconds"= max_run, "MaxWaitTimeInSeconds"= max_wait))
      return(list("MaxRuntimeInSeconds"= max_run))
    }
  ),
  active = list(
    #' @field name
    #' Returns job name
    name = function(){
      return(self$job_name)
    }
  ),
  lock_objects = F
)
DyfanJones/sagemaker-r-common documentation built on June 14, 2022, 10:31 p.m.