R/workflow_airflow.R


# NOTE: This code has been modified from AWS Sagemaker Python:
# https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/workflow/airflow.py

#' @include r_utils.R

#' @import fs
#' @import sagemaker.core
#' @import sagemaker.common
#' @import sagemaker.mlcore
#' @importFrom utils modifyList

#' @title Prepare S3 operations and environment variables related to the framework.
#' @description S3 operations specify where to upload `source_dir`.
#' @param estimator (sagemaker.estimator.Estimator): The framework estimator to
#'              get information from and update.
#' @param s3_operations (list): The list to specify S3 operations (upload
#'              `source_dir`).
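#' @details The `s3_operations` list is updated in the caller's environment via
#'              `assign()`, emulating the pass-by-reference semantics of the Python
#'              original. A minimal usage sketch follows (the estimator object is
#'              hypothetical):
#' @examples
#' \dontrun{
#' # assumes `estimator` is a Framework estimator with an entry point set
#' s3_operations <- list()
#' prepare_framework(estimator, s3_operations)
#' # `s3_operations` now holds an "S3Upload" entry describing where
#' # sourcedir.tar.gz should be uploaded
#' }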
#' @export
prepare_framework = function(estimator,
                             s3_operations){
  s3_split = list()
  if (!is.null(estimator$code_location)){
    s3_split = parse_s3_url(estimator$code_location)
    s3_split$key = as.character(fs::path(s3_split$key, estimator$.current_job_name, "source", "sourcedir.tar.gz"))
  } else if (!is.null(estimator$uploaded_code)){
    s3_split = parse_s3_url(estimator$uploaded_code$s3_prefix)
  } else {
    s3_split$bucket = estimator$sagemaker_session$.__enclos_env__$private$.default_bucket
    s3_split$key = file.path(estimator$.current_job_name, "source", "sourcedir.tar.gz")
  }

  script = basename(estimator$entry_point)

  if (!is.null(estimator$source_dir) && grepl("^s3://", tolower(estimator$source_dir))){
    code_dir = estimator$source_dir
    UploadedCode$s3_prefix = code_dir
    UploadedCode$script_name = script
    estimator$uploaded_code = UploadedCode
  } else {
    code_dir = sprintf("s3://%s/%s", s3_split$bucket, s3_split$key)
    UploadedCode$s3_prefix = code_dir
    UploadedCode$script_name = script
    estimator$uploaded_code = UploadedCode
    # write the upload instruction back to the caller's `s3_operations`
    # variable: deparse(substitute(...)) recovers its name and assign()
    # updates it in the parent frame, emulating pass-by-reference.
    ll = deparse(substitute(s3_operations))
    s3_operations[["S3Upload"]] = list(
      list(
        "Path"=(estimator$source_dir %||% estimator$entry_point),
        "Bucket"=s3_split$bucket,
        "Key"=s3_split$key,
        "Tar"=TRUE)
    )
    assign(ll, s3_operations, envir = parent.frame())
  }
  estimator$.hyperparameters[[model_parameters$DIR_PARAM_NAME]] = code_dir
  estimator$.hyperparameters[[model_parameters$SCRIPT_PARAM_NAME]] = script
  estimator$.hyperparameters[[
    model_parameters$CONTAINER_LOG_LEVEL_PARAM_NAME
  ]] = estimator$container_log_level
  estimator$.hyperparameters[[model_parameters$JOB_NAME_PARAM_NAME]] = estimator$.current_job_name
  estimator$.hyperparameters[[
    model_parameters$SAGEMAKER_REGION_PARAM_NAME
  ]] = estimator$sagemaker_session$paws_region_name
}

#' @title Set up an Amazon algorithm estimator.
#' @description This is done by adding the required `feature_dim` hyperparameter from training data.
#' @param estimator (sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase): An estimator
#'              for a built-in Amazon algorithm to get information from and update.
#' @param inputs : The training data.
#'              * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
#'              Amazon :class:~`Record` objects serialized and stored in S3. For
#'              use with an estimator for an Amazon algorithm.
#'              * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
#'              :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects,
#'              where each instance is a different channel of training data.
#' @param mini_batch_size (numeric): The size of each mini-batch to use when
#'              training (default: NULL).
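#' @details A minimal usage sketch (the `records` object is assumed to be a
#'              RecordSet returned by an Amazon algorithm estimator's
#'              `record_set()` method):
#' @examples
#' \dontrun{
#' prepare_amazon_algorithm_estimator(estimator, inputs = records,
#'                                    mini_batch_size = 100)
#' # estimator$feature_dim is taken from the "train" channel RecordSet
#' }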
#' @export
prepare_amazon_algorithm_estimator = function(estimator,
                                              inputs,
                                              mini_batch_size=NULL){
  if (is.list(inputs)){
    for (record in inputs){
      if (inherits(record, "RecordSet") && record$channel == "train"){
        estimator$feature_dim = record$feature_dim
        break
      }
    }
  } else if (inherits(inputs, "RecordSet")){
    estimator$feature_dim = inputs$feature_dim
  } else {
    TypeError$new("Training data must be represented in RecordSet or list of RecordSets")
  }
  estimator$mini_batch_size = mini_batch_size
}

#' @title Export Airflow base training config from an estimator
#' @param estimator (sagemaker.estimator.EstimatorBase): The estimator to export
#'              training config from. Can be a BYO estimator, Framework estimator or
#'              Amazon algorithm estimator.
#' @param inputs : Information about the training data. Please refer to the ``fit()``
#'              method of
#'              the associated estimator, as this can take any of the following
#'              forms:
#'              * (str) - The S3 location where training data is saved.
#'              * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
#'              channels for training data, you can specify a dict mapping channel names to
#'              strings or :func:`~sagemaker.inputs.TrainingInput` objects.
#'              * (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
#'              provide additional information about the training dataset. See
#'              :func:`sagemaker.inputs.TrainingInput` for full details.
#'              * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
#'              Amazon :class:~`Record` objects serialized and stored in S3.
#'              For use with an estimator for an Amazon algorithm.
#'              * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
#'              :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects,
#'              where each instance is a different channel of training data.
#' @param job_name (str): Specify a training job name if needed.
#' @param mini_batch_size (int): Specify this argument only when estimator is a
#'              built-in estimator of an Amazon algorithm. For other estimators,
#'              batch size should be specified in the estimator.
#' @return dict: Training config that can be directly used by
#'              SageMakerTrainingOperator in Airflow.
#' @export
training_base_config = function(estimator,
                                inputs=NULL,
                                job_name=NULL,
                                mini_batch_size=NULL){
  if (inherits(estimator, "AmazonAlgorithmEstimatorBase")){
    estimator$prepare_workflow_for_training(
      records=inputs, mini_batch_size=mini_batch_size, job_name=job_name
    )
  } else {
    estimator$prepare_workflow_for_training(job_name=job_name)
  }
  s3_operations = list()

  if (!is.null(job_name)){
    estimator$.current_job_name = job_name
  } else {
    base_name = estimator$base_job_name %||% sagemaker.core::base_name_from_image(
      estimator$training_image_uri()
    )
    estimator$.current_job_name = sagemaker.core::name_from_base(base_name)
  }
  if (is.null(estimator$output_path)){
    default_bucket = estimator$sagemaker_session$default_bucket()
    estimator$output_path = sprintf("s3://%s/",default_bucket)
  }
  if (inherits(estimator, "Framework")){
    prepare_framework(estimator, s3_operations)
  } else if (inherits(estimator, "AmazonAlgorithmEstimatorBase")){
    prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size)
  }
  job_config = .Job$new()$.__enclos_env__$private$.load_config(
    inputs, estimator, expand_role=FALSE, validate_uri=FALSE)

  train_config = list(
    "AlgorithmSpecification"=list(
      "TrainingImage"=estimator$training_image_uri(),
      "TrainingInputMode"=estimator$input_mode
      ),
    "OutputDataConfig"=job_config[["output_config"]],
    "StoppingCondition"=job_config[["stop_condition"]],
    "ResourceConfig"=job_config[["resource_config"]],
    "RoleArn"=job_config[["role"]]
  )

  train_config[["InputDataConfig"]] = job_config[["input_config"]]
  train_config[["VpcConfig"]] = job_config[["vpc_config"]]

  if (isTRUE(estimator$use_spot_instances))
    train_config[["EnableManagedSpotTraining"]] = TRUE

  if (!islistempty(estimator$hyperparameters())){
    train_config[["HyperParameters"]] = lapply(estimator$hyperparameters(), as.character)
  }
  if (!islistempty(s3_operations))
    train_config[["S3Operations"]] = s3_operations

  if (!is.null(estimator$checkpoint_local_path) && !is.null(estimator$checkpoint_s3_uri)){
    train_config[["CheckpointConfig"]] = list(
      "LocalPath"=estimator$checkpoint_local_path,
      "S3Uri"=estimator$checkpoint_s3_uri
    )
  }
  return(train_config)
}

#' @title Export Airflow training config from an estimator
#' @param estimator (sagemaker.estimator.EstimatorBase): The estimator to export
#'              training config from. Can be a BYO estimator, Framework estimator or
#'              Amazon algorithm estimator.
#' @param inputs : Information about the training data. Please refer to the ``fit()``
#'              method of the associated estimator, as this can take any of the following forms:
#'              * (str) - The S3 location where training data is saved.
#'              * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
#'              channels for training data, you can specify a dict mapping channel names to
#'              strings or :func:`~sagemaker.inputs.TrainingInput` objects.
#'              * (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
#'              provide additional information about the training dataset. See
#'              :func:`sagemaker.inputs.TrainingInput` for full details.
#'              * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
#'              Amazon :class:~`Record` objects serialized and stored in S3.
#'              For use with an estimator for an Amazon algorithm.
#'              * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
#'              :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects,
#'              where each instance is a different channel of training data.
#' @param job_name (str): Specify a training job name if needed.
#' @param mini_batch_size (int): Specify this argument only when estimator is a
#'              built-in estimator of an Amazon algorithm. For other estimators,
#'              batch size should be specified in the estimator.
#' @return list: Training config that can be directly used by
#'              SageMakerTrainingOperator in Airflow.
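#' @details A minimal usage sketch (the estimator object and S3 locations are
#'              hypothetical):
#' @examples
#' \dontrun{
#' # assumes `estimator` is any EstimatorBase subclass with a configured
#' # sagemaker session and role
#' config <- training_config(
#'   estimator,
#'   inputs = "s3://my-bucket/my-data/train",
#'   job_name = "my-training-job"
#' )
#' # pass `config` to SageMakerTrainingOperator(config = config, ...) in an
#' # Airflow DAG
#' }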
#' @export
training_config = function(estimator,
                           inputs=NULL,
                           job_name=NULL,
                           mini_batch_size=NULL){
  train_config = training_base_config(estimator, inputs, job_name, mini_batch_size)
  train_config[["TrainingJobName"]] = estimator$.current_job_name

  if (!is.null(estimator$tags))
    train_config[["Tags"]] = estimator$tags

  if (!is.null(estimator$metric_definitions))
    train_config[["AlgorithmSpecification"]][["MetricDefinitions"]] = estimator$metric_definitions

  return(train_config)
}

#' @title Export Airflow tuning config from a HyperparameterTuner
#' @param tuner (sagemaker.tuner.HyperparameterTuner): The tuner to export tuning
#'              config from.
#' @param inputs : Information about the training data. Please refer to the ``fit()``
#'              method of the associated estimator in the tuner, as this can take any of the
#'              following forms:
#'              * (str) - The S3 location where training data is saved.
#'              * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
#'              channels for training data, you can specify a dict mapping channel names to
#'              strings or :func:`~sagemaker.inputs.TrainingInput` objects.
#'              * (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
#'              provide additional information about the training dataset. See
#'              :func:`sagemaker.inputs.TrainingInput` for full details.
#'              * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
#'              Amazon :class:~`Record` objects serialized and stored in S3.
#'              For use with an estimator for an Amazon algorithm.
#'              * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
#'              :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects,
#'              where each instance is a different channel of training data.
#'              * (dict[str, one of the forms above]): Required only by tuners created via
#'              the factory method ``HyperparameterTuner.create()``. The keys should be the
#'              same estimator names as keys for the ``estimator_list`` argument of the
#'              ``HyperparameterTuner.create()`` method.
#' @param job_name (str): Specify a tuning job name if needed.
#' @param include_cls_metadata : It can take one of the following two forms.
#'              * (bool) - Whether or not the hyperparameter tuning job should include information
#'              about the estimator class (default: False). This information is passed as a
#'              hyperparameter, so if the algorithm you are using cannot handle unknown
#'              hyperparameters (e.g. an Amazon SageMaker built-in algorithm that does not
#'              have a custom estimator in the Python SDK), then set ``include_cls_metadata``
#'              to ``False``.
#'              * (dict[str, bool]) - This version should be used for tuners created via the factory
#'              method ``HyperparameterTuner.create()``, to specify the flag for individual
#'              estimators provided in the ``estimator_list`` argument of the method. The keys
#'              would be the same estimator names as in ``estimator_list``. If one estimator
#'              doesn't need the flag set, then no need to include it in the dictionary. If none
#'              of the estimators need the flag set, then an empty dictionary ``{}`` must be used.
#' @param mini_batch_size : It can take one of the following two forms.
#'              * (int) - Specify this argument only when estimator is a built-in estimator of an
#'              Amazon algorithm. For other estimators, batch size should be specified in the
#'              estimator.
#'              * (dict[str, int]) - This version should be used for tuners created via the factory
#'              method ``HyperparameterTuner.create()``, to specify the value for individual
#'              estimators provided in the ``estimator_list`` argument of the method. The keys
#'              would be the same estimator names as in ``estimator_list``. If one estimator
#'              doesn't need the value set, then no need to include it in the dictionary. If
#'              none of the estimators need the value set, then an empty dictionary ``{}``
#'              must be used.
#' @return list: Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
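#' @details A minimal usage sketch (the tuner object is hypothetical and assumed
#'              to have been built around a single estimator via its `estimator` field):
#' @examples
#' \dontrun{
#' config <- tuning_config(
#'   tuner,
#'   inputs = "s3://my-bucket/my-data/train",
#'   job_name = "my-tuning-job"
#' )
#' # pass `config` to SageMakerTuningOperator in an Airflow DAG
#' }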
#' @export
tuning_config = function(tuner,
                         inputs,
                         job_name=NULL,
                         include_cls_metadata=FALSE,
                         mini_batch_size=NULL){
  tuner$.__enclos_env__$private$.prepare_job_name_for_tuning(job_name=job_name)

  tune_config = list(
    "HyperParameterTuningJobName"=tuner$.current_job_name,
    "HyperParameterTuningJobConfig"= .extract_tuning_job_config(tuner)
  )

  if (!is.null(tuner$estimator)){
    ll = .extract_training_config_from_estimator(
      tuner, inputs, include_cls_metadata, mini_batch_size)
    tune_config[["TrainingJobDefinition"]] = ll$training_job_def
    s3_operations = ll$s3_operations
  } else {
    ll = .extract_training_config_list_from_estimator_list(
      tuner, inputs, include_cls_metadata, mini_batch_size)
    tune_config[["TrainingJobDefinitions"]] = ll$training_job_def
    s3_operations = ll$s3_operations
  }

  if (!islistempty(s3_operations))
    tune_config[["S3Operations"]] = s3_operations

  if (!islistempty(tuner$tags))
    tune_config[["Tags"]] = tuner$tags

  if (!islistempty(tuner$warm_start_config))
    tune_config[["WarmStartConfig"]] = tuner$warm_start_config$to_input_req()

  return(tune_config)
}

# Extract tuning job config from a HyperparameterTuner
.extract_tuning_job_config = function(tuner){
  tuning_job_config = list(
    "Strategy"=tuner$strategy,
    "ResourceLimits"=list(
      "MaxNumberOfTrainingJobs"=tuner$max_jobs,
      "MaxParallelTrainingJobs"=tuner$max_parallel_jobs),
    "TrainingJobEarlyStoppingType"=tuner$early_stopping_type
  )

  if (!islistempty(tuner$objective_metric_name))
    tuning_job_config[["HyperParameterTuningJobObjective"]] = list(
      "Type"=tuner$objective_type,
      "MetricName"=tuner$objective_metric_name
    )

  parameter_ranges = tuner$hyperparameter_ranges()
  if (!islistempty(parameter_ranges))
    tuning_job_config[["ParameterRanges"]] = parameter_ranges

  return(tuning_job_config)
}

# Extract training job config from a HyperparameterTuner that uses the ``estimator`` field
.extract_training_config_from_estimator = function(tuner,
                                                   inputs,
                                                   include_cls_metadata,
                                                   mini_batch_size){
  train_config = training_base_config(tuner$estimator, inputs, mini_batch_size=mini_batch_size)
  train_config[["HyperParameters"]] = NULL

  tuner$.__enclos_env__$private$.prepare_static_hyperparameters_for_tuning(
    include_cls_metadata=include_cls_metadata)
  train_config[["StaticHyperParameters"]] = tuner$static_hyperparameters

  if (!islistempty(tuner$metric_definitions))
    train_config[["AlgorithmSpecification"]][["MetricDefinitions"]] = tuner$metric_definitions

  s3_operations = train_config[["S3Operations"]]
  train_config[["S3Operations"]] = NULL

  return(list(training_job_def = train_config, s3_operations = s3_operations))
}

# Extracts a list of training job configs from a Hyperparameter Tuner.
# It uses the ``estimator_list`` field.
.extract_training_config_list_from_estimator_list = function(tuner,
                                                             inputs,
                                                             include_cls_metadata,
                                                             mini_batch_size){
  estimator_names = sort(names(tuner$estimator_list))
  tuner$.__enclos_env__$private$.validate_list_argument(
    name="inputs", value=inputs, allowed_keys=estimator_names)
  tuner$.__enclos_env__$private$.validate_list_argument(
    name="include_cls_metadata", value=include_cls_metadata, allowed_keys=estimator_names
  )
  tuner$.__enclos_env__$private$.validate_list_argument(
    name="mini_batch_size", value=mini_batch_size, allowed_keys=estimator_names
  )

  train_config_dict = list()
  for (estimator_name in names(tuner$estimator_list)){
    estimator = tuner$estimator_list[[estimator_name]]
    train_config_dict[[estimator_name]] = training_base_config(
      estimator=estimator,
      inputs=(if(!islistempty(inputs)) inputs[[estimator_name]] else NULL),
      mini_batch_size=(if(!islistempty(mini_batch_size)) mini_batch_size[[estimator_name]] else NULL)
    )
  }

  tuner$.__enclos_env__$private$.prepare_static_hyperparameters_for_tuning(
    include_cls_metadata=include_cls_metadata)

  train_config_list = list()
  s3_operations_list = list()

  for (estimator_name in sort(names(train_config_dict))){
    train_config = train_config_dict[[estimator_name]]
    train_config[["HyperParameters"]]=NULL
    train_config[["StaticHyperParameters"]] = tuner$static_hyperparameters_list[[estimator_name]]

    train_config[["AlgorithmSpecification"]][[
      "MetricDefinitions"
    ]] = tuner$metric_definitions_list[[estimator_name]]

    train_config[["DefinitionName"]] = estimator_name
    train_config[["TuningObjective"]] = list(
      "Type"=tuner$objective_type,
      "MetricName"=tuner$objective_metric_name_list[[estimator_name]]
    )
    train_config[["HyperParameterRanges"]] = tuner$hyperparameter_ranges_list()[[estimator_name]]

    s3_operations_list = list.append(s3_operations_list, (train_config[["S3Operations"]] %||% list()))
    train_config[["S3Operations"]] = NULL

    train_config_list = list.append(train_config_list, train_config)
  }
  return(list(training_job_def = train_config_list, s3_operations = .merge_s3_operations(s3_operations_list)))
}

# Merge a list of S3 operation dictionaries into one
.merge_s3_operations = function(s3_operations_list){
  s3_operations_merged =list()
  for (s3_operations in s3_operations_list){
    for (key in names(s3_operations)){
      operations = s3_operations[[key]]
      if (is.null(names(s3_operations_merged)) || !(key %in% names(s3_operations_merged)))
        s3_operations_merged[[key]] = list()

      for (operation in operations){
        if (!list.exist.in(operation, s3_operations_merged[[key]]))
          s3_operations_merged[[key]] = list.append(s3_operations_merged[[key]], operation)
      }
    }
  }
  return(s3_operations_merged)
}

#' @title Update the S3 URI of the framework source directory in the given estimator.
#' @param estimator (sagemaker.estimator.Framework): The Framework estimator to
#'              update.
#' @param job_name (str): The new job name included in the submit S3 URI.
#' @return The estimator's `uploaded_code` is updated in place with the new
#'              S3 URI of the framework source directory.
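#' @details A minimal usage sketch (hypothetical URIs): the job-name component
#'              of the submit URI is rewritten, e.g.
#'              `s3://bucket/old-job/source/sourcedir.tar.gz` becomes
#'              `s3://bucket/new-job/source/sourcedir.tar.gz`.
#' @examples
#' \dontrun{
#' update_submit_s3_uri(estimator, "new-job")
#' estimator$uploaded_code$s3_prefix
#' }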
#' @export
update_submit_s3_uri = function(estimator, job_name){
  if (islistempty(estimator$uploaded_code))
    return(NULL)

  pattern = "(?<=/)[^/]+?(?=/source/sourcedir.tar.gz)"

  # update the S3 URI with the latest training job.
  # s3://path/old_job/source/sourcedir.tar.gz will become s3://path/new_job/source/sourcedir.tar.gz
  submit_uri = estimator$uploaded_code$s3_prefix
  submit_uri = gsub(pattern, job_name, submit_uri, perl = TRUE)
  script_name = estimator$uploaded_code$script_name
  UploadedCode$s3_prefix = submit_uri
  UploadedCode$script_name = script_name
  estimator$uploaded_code = UploadedCode
}

#' @title Update training job of the estimator from a task in the DAG
#' @param estimator (sagemaker.estimator.EstimatorBase): The estimator to update
#' @param task_id (str): The task id of any
#'              airflow.contrib.operators.SageMakerTrainingOperator or
#'              airflow.contrib.operators.SageMakerTuningOperator that generates
#'              training jobs in the DAG.
#' @param task_type (str): Whether the task is from SageMakerTrainingOperator or
#'              SageMakerTuningOperator. Values can be 'training', 'tuning' or None
#'              (which means training job is not from any task).
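#' @details A minimal usage sketch ("train_task" is a hypothetical Airflow task
#'              id): the job name is set to a Jinja template that Airflow resolves
#'              at runtime by pulling the task's result from XCom.
#' @examples
#' \dontrun{
#' update_estimator_from_task(estimator, task_id = "train_task",
#'                            task_type = "training")
#' estimator$.current_job_name
#' # "{{ ti.xcom_pull(task_ids='train_task')['Training']['TrainingJobName'] }}"
#' }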
#' @export
update_estimator_from_task = function(estimator,
                                      task_id,
                                      task_type){
  if (is.null(task_type))
    return(NULL)
  if (tolower(task_type) == "training"){
    training_job = sprintf(
      "{{ ti.xcom_pull(task_ids='%s')['Training']['TrainingJobName'] }}", task_id)
    job_name = training_job
  } else if (tolower(task_type) == "tuning"){
    training_job = sprintf(
      "{{ ti.xcom_pull(task_ids='%s')['Tuning']['BestTrainingJob']['TrainingJobName'] }}",
      task_id)
    # need to strip the double quotes in json to get the string
    job_name = sprintf(paste0(
      "{{ ti.xcom_pull(task_ids='%s')['Tuning']['TrainingJobDefinition']",
      "['StaticHyperParameters']['sagemaker_job_name'].strip('%s') }}"), task_id, '"')
  } else {
    ValueError$new("task_type must be either 'training', 'tuning' or NULL.")
  }
  estimator$.current_job_name = training_job
  if (inherits(estimator, "Framework"))
    update_submit_s3_uri(estimator, job_name)
}

#' @title Prepare the framework model container information and specify related S3 operations.
#' @description Prepare the framework model container information and specify the
#'              related S3 operations for Airflow to perform (upload `source_dir`).
#' @param model (sagemaker.model.FrameworkModel): The framework model
#' @param instance_type (str): The EC2 instance type to deploy this Model to. For
#'              example, 'ml.p2.xlarge'.
#' @param s3_operations (dict): The dict to specify S3 operations (upload
#'              `source_dir` ).
#' @return dict: The container information of this framework model.
#' @export
prepare_framework_container_def = function(model,
                                           instance_type,
                                           s3_operations){
  deploy_image = model$image_uri
  if (islistempty(deploy_image)){
    region_name = model$sagemaker_session$paws_region_name
    deploy_image = model$serving_image_uri(region_name, instance_type)
  }
  base_name = sagemaker.core::base_name_from_image(deploy_image)
  model$name = model$name %||% sagemaker.core::name_from_base(base_name)

  bucket = model$bucket %||% model$sagemaker_session$.__enclos_env__$private$.default_bucket
  if (!is.null(model$entry_point)){
    script = basename(model$entry_point)
    key = sprintf("%s/source/sourcedir.tar.gz", model$name)

    if (!islistempty(model$source_dir) && grepl("^s3://", tolower(model$source_dir))){
      code_dir = model$source_dir
      UploadedCode$s3_prefix = code_dir
      UploadedCode$script_name = script
      model$uploaded_code = UploadedCode
    } else {
      code_dir = sprintf("s3://%s/%s", bucket, key)
      UploadedCode$s3_prefix = code_dir
      UploadedCode$script_name = script
      model$uploaded_code = UploadedCode
      # write the upload instruction back to the caller's `s3_operations`
      # variable (pass-by-reference emulation, as in prepare_framework()).
      ll = deparse(substitute(s3_operations))
      s3_operations[["S3Upload"]] = list(
        list("Path"=(model$source_dir %||% script), "Bucket"=bucket, "Key"=key, "Tar"=TRUE)
      )
      assign(ll, s3_operations, envir = parent.frame())
    }
  }
  deploy_env = as.list(model$env)
  deploy_env = modifyList(deploy_env, model$.__enclos_env__$private$.framework_env_vars())

  tryCatch({
    if (!islistempty(model$model_server_workers))
      deploy_env[[toupper(model_parameters$MODEL_SERVER_WORKERS_PARAM_NAME)]] = as.character(
        model$model_server_workers)
  }, error = function(e) {
    # This applies to a FrameworkModel which is not SageMaker Deep Learning Framework Model
    NULL
  })

  return(container_def(deploy_image, model$model_data, deploy_env))
}

#' @title Export Airflow model config from a SageMaker model
#' @param model (sagemaker.model.Model): The Model object from which to export the Airflow config
#' @param instance_type (str): The EC2 instance type to deploy this Model to. For
#'              example, 'ml.p2.xlarge'
#' @param role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
#' @param image_uri (str): A Docker image URI to use for deploying the model
#' @return dict: Model config that can be directly used by SageMakerModelOperator
#'              in Airflow. It can also be part of the config used by
#'              SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
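#' @details A minimal usage sketch (the model object is hypothetical; it is
#'              assumed to carry `model_data` and a `role`):
#' @examples
#' \dontrun{
#' config <- model_config(model, instance_type = "ml.m5.xlarge")
#' # pass `config` to SageMakerModelOperator in an Airflow DAG
#' }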
#' @export
model_config = function(model,
                        instance_type=NULL,
                        role=NULL,
                        image_uri=NULL){
  s3_operations = list()
  model$image_uri = image_uri %||% model$image_uri

  if (inherits(model, "FrameworkModel")){
    container_def = prepare_framework_container_def(model, instance_type, s3_operations)
  } else {
    container_def = model$prepare_container_def()
    base_name = sagemaker.core::base_name_from_image(container_def[["Image"]])
    model$name = model$name %||% sagemaker.core::name_from_base(base_name)
  }
  primary_container = sagemaker.core::Session$private_methods$.expand_container_def(container_def)

  config = list(
    "ModelName"=model$name,
    "PrimaryContainer"=primary_container,
    "ExecutionRoleArn"=(role %||% model$role)
  )
  if (!islistempty(model$vpc_config))
    config[["VpcConfig"]] = model$vpc_config

  if (!islistempty(s3_operations))
    config[["S3Operations"]] = s3_operations

  return(config)
}

#' @title Export Airflow model config from a SageMaker estimator
#' @param estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to
#'              export Airflow config from. It has to be an estimator associated
#'              with a training job.
#' @param task_id (str): The task id of any
#'              airflow.contrib.operators.SageMakerTrainingOperator or
#'              airflow.contrib.operators.SageMakerTuningOperator that generates
#'              training jobs in the DAG. The model config is built based on the
#'              training job generated in this operator.
#' @param task_type (str): Whether the task is from SageMakerTrainingOperator or
#'              SageMakerTuningOperator. Values can be 'training', 'tuning' or None
#'              (which means training job is not from any task).
#' @param instance_type (str): The EC2 instance type to deploy this Model to. For
#'              example, 'ml.p2.xlarge'
#' @param role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
#' @param image_uri (str): A Docker image URI to use for deploying the model
#' @param name (str): Name of the model
#' @param model_server_workers (int): The number of worker processes used by the
#'              inference server. If None, server will use one worker per vCPU. Only
#'              effective when estimator is a SageMaker framework.
#' @param vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on
#'              the model. Default: use subnets and security groups from this Estimator.
#'              * 'Subnets' (list[str]): List of subnet ids.
#'              * 'SecurityGroupIds' (list[str]): List of security group ids.
#' @return dict: Model config that can be directly used by SageMakerModelOperator in Airflow. It can
#'              also be part of the config used by SageMakerEndpointOperator in Airflow.
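#' @details A minimal usage sketch ("train_task" is a hypothetical Airflow task
#'              id of a SageMakerTrainingOperator that trained `estimator`):
#' @examples
#' \dontrun{
#' config <- model_config_from_estimator(
#'   estimator,
#'   task_id = "train_task",
#'   task_type = "training",
#'   instance_type = "ml.m5.xlarge"
#' )
#' }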
#' @export
model_config_from_estimator = function(estimator,
                                       task_id,
                                       task_type,
                                       instance_type=NULL,
                                       role=NULL,
                                       image_uri=NULL,
                                       name=NULL,
                                       model_server_workers=NULL,
                                       vpc_config_override="VPC_CONFIG_DEFAULT"){
  update_estimator_from_task(estimator, task_id, task_type)
  if (inherits(estimator, "Estimator")){
    model = estimator$create_model(
      role=role, image_uri=image_uri, vpc_config_override=vpc_config_override
    )
  } else if (inherits(estimator, "AmazonAlgorithmEstimatorBase")){
    model = estimator$create_model(vpc_config_override=vpc_config_override)
  } else if (inherits(estimator, "TensorFlow")){
    model = estimator$create_model(
      role=role, vpc_config_override=vpc_config_override, entry_point=estimator$entry_point
    )
  } else if (inherits(estimator, "Framework")){
    model = estimator$create_model(
      model_server_workers=model_server_workers,
      role=role,
      vpc_config_override=vpc_config_override,
      entry_point=estimator$entry_point
    )
  } else {
    TypeError$new(
      "Estimator must be one of Estimator, Framework ",
      "or AmazonAlgorithmEstimatorBase."
    )
  }
  model$name = name

  return(model_config(model, instance_type, role, image_uri))
}

#' @title Export Airflow transform config from a SageMaker transformer
#' @param transformer (sagemaker.transformer.Transformer): The SageMaker
#'              transformer to export Airflow config from.
#' @param data (str): Input data location in S3.
#' @param data_type (str): What the S3 location defines (default: 'S3Prefix').
#'              Valid values:
#'              * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix will
#'              be used as inputs for the transform job.
#'              * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 object
#'              to use as an input for the transform job.
#' @param content_type (str): MIME type of the input data (default: None).
#' @param compression_type (str): Compression type of the input data, if
#'              compressed (default: None). Valid values: 'Gzip', None.
#' @param split_type (str): The record delimiter for the input object (default:
#'              'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
#' @param job_name (str): job name (default: None). If not specified, one will be
#'              generated.
#' @param input_filter (str): A JSONPath to select a portion of the input to
#'              pass to the algorithm container for inference. If you omit the
#'              field, it gets the value '$', representing the entire input.
#'              For CSV data, each row is taken as a JSON array,
#'              so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
#'              CSV data should follow the `RFC format <https://tools.ietf.org/html/rfc4180>`_.
#'              See `Supported JSONPath Operators
#'              <https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html#data-processing-operators>`_
#'              for a table of supported JSONPath operators.
#'              For more information, see the SageMaker API documentation for
#'              `CreateTransformJob
#'              <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
#'              Some examples: "$[1:]", "$.features" (default: None).
#' @param output_filter (str): A JSONPath to select a portion of the
#'              joined/original output to return as the output.
#'              For more information, see the SageMaker API documentation for
#'              `CreateTransformJob
#'              <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
#'              Some examples: "$[1:]", "$.prediction" (default: None).
#' @param join_source (str): The source of data to be joined to the transform
#'              output. It can be set to 'Input' meaning the entire input record
#'              will be joined to the inference result. You can use OutputFilter
#'              to select the useful portion before uploading to S3. (default:
#'              None). Valid values: Input, None.
#' @return dict: Transform config that can be directly used by
#'              SageMakerTransformOperator in Airflow.
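#' @details A minimal usage sketch (the transformer is assumed to have been
#'              created from an existing model, e.g. via an estimator's
#'              `transformer()` method):
#' @examples
#' \dontrun{
#' config <- transform_config(
#'   transformer,
#'   data = "s3://my-bucket/my-data/batch",
#'   content_type = "text/csv",
#'   split_type = "Line"
#' )
#' # pass `config` to SageMakerTransformOperator in an Airflow DAG
#' }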
#' @export
transform_config = function(transformer,
                            data,
                            data_type="S3Prefix",
                            content_type=NULL,
                            compression_type=NULL,
                            split_type=NULL,
                            job_name=NULL,
                            input_filter=NULL,
                            output_filter=NULL,
                            join_source=NULL){
  if (!is.null(job_name)) {
    transformer$.current_job_name = job_name
  } else {
    base_name = transformer$base_transform_job_name
    transformer$.current_job_name = (if (!is.null(base_name))
      sagemaker.core::name_from_base(base_name) else transformer$model_name)
  }
  if (is.null(transformer$output_path)){
    transformer$output_path = sprintf("s3://%s/%s",
      transformer$sagemaker_session$default_bucket(), transformer$.current_job_name
    )
  }
  job_config = transformer$.__enclos_env__$private$.load_config(
    data, data_type, content_type, compression_type, split_type
  )

  config = list(
    "TransformJobName"=transformer$.current_job_name,
    "ModelName"=transformer$model_name,
    "TransformInput"=job_config[["input_config"]],
    "TransformOutput"=job_config[["output_config"]],
    "TransformResources"=job_config[["resource_config"]])

  data_processing = transformer$.__enclos_env__$private$.prepare_data_processing(
    input_filter, output_filter, join_source
  )
  if (!is.null(data_processing))
    config[["DataProcessing"]] = data_processing

  if (!is.null(transformer$strategy))
    config[["BatchStrategy"]] = transformer$strategy

  if (!is.null(transformer$max_concurrent_transforms))
    config[["MaxConcurrentTransforms"]] = transformer$max_concurrent_transforms

  if (!is.null(transformer$max_payload))
    config[["MaxPayloadInMB"]] = transformer$max_payload

  if (!is.null(transformer$env))
    config[["Environment"]] = transformer$env

  if (!is.null(transformer$tags))
    config[["Tags"]] = transformer$tags

  return(config)
}

#' @title Export Airflow transform config from a SageMaker estimator
#' @param estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to
#'              export Airflow config from. It has to be an estimator associated
#'              with a training job.
#' @param task_id (str): The task id of any
#'              airflow.contrib.operators.SageMakerTrainingOperator or
#'              airflow.contrib.operators.SageMakerTuningOperator that generates
#'              training jobs in the DAG. The transform config is built based on the
#'              training job generated in this operator.
#' @param task_type (str): Whether the task is from SageMakerTrainingOperator or
#'              SageMakerTuningOperator. Values can be 'training', 'tuning' or None
#'              (which means training job is not from any task).
#' @param instance_count (int): Number of EC2 instances to use.
#' @param instance_type (str): Type of EC2 instance to use, for example,
#'              'ml.c4.xlarge'.
#' @param data (str): Input data location in S3.
#' @param data_type (str): What the S3 location defines (default: 'S3Prefix').
#'              Valid values:
#'              * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix will
#'              be used as inputs for the transform job.
#'              * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 object
#'              to use as an input for the transform job.
#' @param content_type (str): MIME type of the input data (default: None).
#' @param compression_type (str): Compression type of the input data, if
#'              compressed (default: None). Valid values: 'Gzip', None.
#' @param split_type (str): The record delimiter for the input object (default:
#'              'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
#' @param job_name (str): transform job name (default: None). If not specified,
#'              one will be generated.
#' @param model_name (str): model name (default: None). If not specified, one will
#'              be generated.
#' @param strategy (str): The strategy used to decide how to batch records in a
#'              single request (default: None). Valid values: 'MultiRecord' and
#'              'SingleRecord'.
#' @param assemble_with (str): How the output is assembled (default: None). Valid
#'              values: 'Line' or 'None'.
#' @param output_path (str): S3 location for saving the transform result. If not
#'              specified, results are stored to a default bucket.
#' @param output_kms_key (str): Optional. KMS key ID for encrypting the transform
#'              output (default: None).
#' @param accept (str): The accept header passed by the client to
#'              the inference endpoint. If it is supported by the endpoint,
#'              it will be the format of the batch transform output.
#' @param env (dict): Environment variables to be set for use during the transform
#'              job (default: None).
#' @param max_concurrent_transforms (int): The maximum number of HTTP requests to
#'              be made to each individual transform container at one time.
#' @param max_payload (int): Maximum size of the payload in a single HTTP request
#'              to the container in MB.
#' @param tags (list[dict]): List of tags for labeling a transform job. If none
#'              specified, then the tags used for the training job are used for the
#'              transform job.
#' @param role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
#'              which is also used during transform jobs. If not specified, the role
#'              from the Estimator will be used.
#' @param volume_kms_key (str): Optional. KMS key ID for encrypting the volume
#'              attached to the ML compute instance (default: None).
#' @param model_server_workers (int): Optional. The number of worker processes
#'              used by the inference server. If None, server will use one worker
#'              per vCPU.
#' @param image_uri (str): A Docker image URI to use for deploying the model
#' @param vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on
#'              the model. Default: use subnets and security groups from this Estimator.
#'              * 'Subnets' (list[str]): List of subnet ids.
#'              * 'SecurityGroupIds' (list[str]): List of security group ids.
#' @param input_filter (str): A JSONPath to select a portion of the input to
#'              pass to the algorithm container for inference. If you omit the
#'              field, it gets the value '$', representing the entire input.
#'              For CSV data, each row is taken as a JSON array,
#'              so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
#'              CSV data should follow the `RFC format <https://tools.ietf.org/html/rfc4180>`_.
#'              See `Supported JSONPath Operators
#'              <https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html#data-processing-operators>`_
#'              for a table of supported JSONPath operators.
#'              For more information, see the SageMaker API documentation for
#'              `CreateTransformJob
#'              <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
#'              Some examples: "$[1:]", "$.features" (default: None).
#' @param output_filter (str): A JSONPath to select a portion of the
#'              joined/original output to return as the output.
#'              For more information, see the SageMaker API documentation for
#'              `CreateTransformJob
#'              <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
#'              Some examples: "$[1:]", "$.prediction" (default: None).
#' @param join_source (str): The source of data to be joined to the transform
#'              output. It can be set to 'Input' meaning the entire input record
#'              will be joined to the inference result. You can use OutputFilter
#'              to select the useful portion before uploading to S3. (default:
#'              None). Valid values: Input, None.
#' @return dict: Transform config that can be directly used by
#'              SageMakerTransformOperator in Airflow.
#' @export
transform_config_from_estimator = function(estimator,
                                           task_id,
                                           task_type,
                                           instance_count,
                                           instance_type,
                                           data,
                                           data_type="S3Prefix",
                                           content_type=NULL,
                                           compression_type=NULL,
                                           split_type=NULL,
                                           job_name=NULL,
                                           model_name=NULL,
                                           strategy=NULL,
                                           assemble_with=NULL,
                                           output_path=NULL,
                                           output_kms_key=NULL,
                                           accept=NULL,
                                           env=NULL,
                                           max_concurrent_transforms=NULL,
                                           max_payload=NULL,
                                           tags=NULL,
                                           role=NULL,
                                           volume_kms_key=NULL,
                                           model_server_workers=NULL,
                                           image_uri=NULL,
                                           vpc_config_override="VPC_CONFIG_DEFAULT",
                                           input_filter=NULL,
                                           output_filter=NULL,
                                           join_source=NULL){
  model_base_config = model_config_from_estimator(
    estimator=estimator,
    task_id=task_id,
    task_type=task_type,
    instance_type=instance_type,
    role=role,
    image_uri=image_uri,
    name=model_name,
    model_server_workers=model_server_workers,
    vpc_config_override=vpc_config_override)

  if (inherits(estimator, "Framework")){
    transformer = estimator$transformer(
      instance_count,
      instance_type,
      strategy,
      assemble_with,
      output_path,
      output_kms_key,
      accept,
      env,
      max_concurrent_transforms,
      max_payload,
      tags,
      role,
      model_server_workers,
      volume_kms_key)
  } else {
    transformer = estimator$transformer(
      instance_count,
      instance_type,
      strategy,
      assemble_with,
      output_path,
      output_kms_key,
      accept,
      env,
      max_concurrent_transforms,
      max_payload,
      tags,
      role,
      volume_kms_key)
  }
  transformer$model_name = model_base_config[["ModelName"]]

  transform_base_config = transform_config(
    transformer,
    data,
    data_type,
    content_type,
    compression_type,
    split_type,
    job_name,
    input_filter,
    output_filter,
    join_source
  )

  config = list("Model"=model_base_config, "Transform"=transform_base_config)

  return(config)
}

#' @title Export Airflow deploy config from a SageMaker model
#' @param model (sagemaker.model.Model): The SageMaker model to export the Airflow
#'              config from.
#' @param initial_instance_count (int): The initial number of instances to run in
#'              the ``Endpoint`` created from this ``Model``.
#' @param instance_type (str): The EC2 instance type to deploy this Model to. For
#'              example, 'ml.p2.xlarge'.
#' @param endpoint_name (str): The name of the endpoint to create (default: None).
#'              If not specified, a unique endpoint name will be created.
#' @param tags (list[dict]): List of tags for labeling a training job. For more,
#'              see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
#' @return dict: Deploy config that can be directly used by
#'              SageMakerEndpointOperator in Airflow.
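#' @details A minimal usage sketch (the model object is hypothetical). The
#'              returned config bundles the "Model", "EndpointConfig" and
#'              "Endpoint" sections consumed by SageMakerEndpointOperator:
#' @examples
#' \dontrun{
#' config <- deploy_config(
#'   model,
#'   initial_instance_count = 1L,
#'   instance_type = "ml.m5.xlarge",
#'   endpoint_name = "my-endpoint"
#' )
#' }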
#' @export
deploy_config = function(model,
                         initial_instance_count,
                         instance_type,
                         endpoint_name=NULL,
                         tags=NULL){
  model_base_config = model_config(model, instance_type)

  # use a distinct name so the production_variant() helper is not shadowed
  prod_variant = production_variant(
    model$name, instance_type, initial_instance_count
  )
  name = model$name
  config_options = list("EndpointConfigName"=name, "ProductionVariants"=list(prod_variant))
  config_options[["Tags"]] = tags

  endpoint_name = endpoint_name %||% name
  endpoint_base_config = list("EndpointName"=endpoint_name, "EndpointConfigName"=name)

  # if there is s3 operations needed for model, move it to root level of config
  s3_operations = model_base_config[["S3Operations"]]
  model_base_config[["S3Operations"]] = NULL

  config = list(
    "Model"=model_base_config,
    "EndpointConfig"=config_options,
    "Endpoint"=endpoint_base_config
  )

  if (!is.null(s3_operations))
    config[["S3Operations"]] = s3_operations

  return(config)
}

#' @title Export Airflow deploy config from a SageMaker estimator
#' @param estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to
#'              export Airflow config from. It has to be an estimator associated
#'              with a training job.
#' @param task_id (str): The task id of any
#'              airflow.contrib.operators.SageMakerTrainingOperator or
#'              airflow.contrib.operators.SageMakerTuningOperator that generates
#'              training jobs in the DAG. The endpoint config is built based on the
#'              training job generated in this operator.
#' @param task_type (str): Whether the task is from SageMakerTrainingOperator or
#'              SageMakerTuningOperator. Values can be 'training', 'tuning' or None
#'              (which means training job is not from any task).
#' @param initial_instance_count (int): Minimum number of EC2 instances to deploy
#'              to an endpoint for prediction.
#' @param instance_type (str): Type of EC2 instance to deploy to an endpoint for
#'              prediction, for example, 'ml.c4.xlarge'.
#' @param model_name (str): Name to use for creating an Amazon SageMaker model. If
#'              not specified, one will be generated.
#' @param endpoint_name (str): Name to use for creating an Amazon SageMaker
#'              endpoint. If not specified, the name of the SageMaker model is used.
#' @param tags (list[dict]): List of tags for labeling a training job. For more,
#'              see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
#' @param ... : Passed to the invocation of ``create_model()``. Implementations
#'              may customize ``create_model()`` to accept additional arguments to
#'              customize model creation during deploy. For more, see the implementation docs.
#' @return dict: Deploy config that can be directly used by
#'              SageMakerEndpointOperator in Airflow.
#' @export
deploy_config_from_estimator = function(estimator,
                                        task_id,
                                        task_type,
                                        initial_instance_count,
                                        instance_type,
                                        model_name=NULL,
                                        endpoint_name=NULL,
                                        tags=NULL,
                                        ...){
  update_estimator_from_task(estimator, task_id, task_type)
  model = estimator$create_model(...)
  model$name = model_name
  config = deploy_config(model, initial_instance_count, instance_type, endpoint_name, tags)
  return(config)
}

#' @title Export Airflow processing config from a SageMaker processor
#' @param processor (sagemaker.processor.Processor): The SageMaker
#'              processor to export Airflow config from.
#' @param inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
#'              the processing job. These must be provided as
#'              :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
#' @param outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
#'              the processing job. These can be specified as either path strings or
#'              :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
#' @param job_name (str): Processing job name. If not specified, the processor generates
#'              a default job name, based on the base job name and current timestamp.
#' @param experiment_config (dict[str, str]): Experiment management configuration.
#'              Dictionary contains three optional keys:
#'              'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
#' @param container_arguments ([str]): The arguments for a container used to run a processing job.
#' @param container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
#' @param kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
#'              uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key,
#'              ARN of a KMS key, alias of a KMS key, or alias of a KMS key.
#'              The KmsKeyId is applied to all outputs.
#' @return dict: Processing config that can be directly used by
#'            SageMakerProcessingOperator in Airflow.
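#' @details A minimal usage sketch (the processor and the `input1`/`output1`
#'              objects are hypothetical ProcessingInput and ProcessingOutput
#'              instances):
#' @examples
#' \dontrun{
#' config <- processing_config(
#'   processor,
#'   inputs = list(input1),
#'   outputs = list(output1),
#'   job_name = "my-processing-job"
#' )
#' # pass `config` to SageMakerProcessingOperator in an Airflow DAG
#' }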
#' @export
processing_config = function(processor,
                             inputs=NULL,
                             outputs=NULL,
                             job_name=NULL,
                             experiment_config=NULL,
                             container_arguments=NULL,
                             container_entrypoint=NULL,
                             kms_key_id=NULL){
  if (!is.null(job_name)){
    processor$.current_job_name = job_name
  } else {
    base_name = processor$base_job_name
    processor$.current_job_name = (
      if (!is.null(base_name)) {
        sagemaker.core::name_from_base(base_name)
      } else {
        sagemaker.core::base_name_from_image(processor$image_uri)
      }
    )
  }
  config = list(
    "ProcessingJobName"=processor$.current_job_name,
    "ProcessingInputs"=input_output_list_converter(inputs))

  processing_output_config = sagemaker.common::ProcessingJob$public_methods$prepare_output_config(
    kms_key_id, input_output_list_converter(outputs)
  )
  config[["ProcessingOutputConfig"]] = processing_output_config

  config[["ExperimentConfig"]] = experiment_config

  app_specification = sagemaker.common::ProcessingJob$public_methods$prepare_app_specification(
    container_arguments, container_entrypoint, processor$image_uri
  )
  config[["AppSpecification"]] = app_specification

  config[["RoleArn"]] = processor$role

  config[["Environment"]]= processor$env

  if (!is.null(processor$network_config))
    config[["NetworkConfig"]] = processor$network_config$to_request_list()

  processing_resources = sagemaker.common::ProcessingJob$public_methods$prepare_processing_resources(
    instance_count=processor$instance_count,
    instance_type=processor$instance_type,
    volume_kms_key_id=processor$volume_kms_key,
    volume_size_in_gb=processor$volume_size_in_gb
  )
  config[["ProcessingResources"]] = processing_resources

  if (!is.null(processor$max_runtime_in_seconds)) {
    stopping_condition = sagemaker.common::ProcessingJob$public_methods$prepare_stopping_condition(
      processor$max_runtime_in_seconds
    )
    config[["StoppingCondition"]] = stopping_condition
  }
  if(!is.null(processor$tags))
    config[["Tags"]] = processor$tags

  return(config)
}

#' @title Converts a list of ProcessingInput or ProcessingOutput objects to a list of dicts
#' @param object_list (list[ProcessingInput or ProcessingOutput]): The list of
#'              ProcessingInput or ProcessingOutput objects to convert.
#' @return list: List of request dicts, built by calling each object's
#'              `to_request_list()` method.
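#' @details A minimal usage sketch (assumes `ProcessingInput` is available from
#'              the sagemaker.common package and exposes a `to_request_list()` method):
#' @examples
#' \dontrun{
#' input1 <- sagemaker.common::ProcessingInput$new(
#'   source = "s3://my-bucket/my-data/input",
#'   destination = "/opt/ml/processing/input"
#' )
#' input_output_list_converter(list(input1))
#' }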
#' @export
input_output_list_converter = function(object_list){
  if (!islistempty(object_list))
    return(lapply(object_list, function(obj) obj$to_request_list()))
  return(object_list)
}