# R/rl_estimator.R

# NOTE: This code has been adapted from the AWS SageMaker Python SDK:
# https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/rl/estimator.py

#' @include tensorflow_estimator.R
#' @include r_utils.R

#' @import R6
#' @import sagemaker.core
#' @import sagemaker.common
#' @import sagemaker.mlcore
#' @import lgr
#' @importFrom stats setNames

# Hyperparameter key / value identifying this estimator to the training job.
SAGEMAKER_ESTIMATOR <- "sagemaker_estimator"
SAGEMAKER_ESTIMATOR_VALUE <- "RLEstimator"

# Python version passed to models built for the MXNet backend.
RL_PYTHON_VERSION <- "py3"

# Supported combinations, nested as:
#   toolkit -> toolkit version -> backend framework -> backend framework version
# Used to validate user input and to resolve `framework_version`.
TOOLKIT_FRAMEWORK_VERSION_MAP <- list(
  coach = list(
    `0.10.1` = list(tensorflow = "1.11"),
    `0.10`   = list(tensorflow = "1.11"),
    `0.11.0` = list(tensorflow = "1.11", mxnet = "1.3"),
    `0.11.1` = list(tensorflow = "1.12"),
    `0.11`   = list(tensorflow = "1.12", mxnet = "1.3"),
    `1.0.0`  = list(tensorflow = "1.12")
  ),
  ray = list(
    `0.5.3` = list(tensorflow = "1.11"),
    `0.5`   = list(tensorflow = "1.11"),
    `0.6.5` = list(tensorflow = "1.12"),
    `0.6`   = list(tensorflow = "1.12"),
    `0.8.2` = list(tensorflow = "2.1"),
    `0.8.5` = list(tensorflow = "2.1", pytorch = "1.5"),
    `1.6.0` = list(tensorflow = "2.5.0", pytorch = "1.8.1")
  )
)

#' @title RLToolkit enum environment list
#' @description Reinforcement learning toolkit used to execute
#'              your model training code.
#' @return environment containing [COACH, RAY]
#' @export
RLToolkit <- sagemaker.core::Enum(
  COACH = "coach",
  RAY = "ray"
)

#' @title RLFramework enum environment list
#' @description Backend framework (MXNet, TensorFlow or PyTorch) that the
#'              reinforcement learning toolkit runs on during training.
#' @return environment containing [TENSORFLOW, MXNET, PYTORCH]
#' @export
RLFramework <- sagemaker.core::Enum(
  TENSORFLOW = "tensorflow", MXNET = "mxnet", PYTORCH = "pytorch"
)

#' @title RLEstimator Class
#' @description Handle end-to-end training and deployment of custom RLEstimator code.
#' @export
RLEstimator <- R6Class("RLEstimator",
  inherit = sagemaker.mlcore::Framework,
  public = list(

    #' @field COACH_LATEST_VERSION_TF
    #' latest version of toolkit coach for tensorflow
    COACH_LATEST_VERSION_TF = "0.11.1",

    #' @field COACH_LATEST_VERSION_MXNET
    #' latest version of toolkit coach for mxnet
    COACH_LATEST_VERSION_MXNET = "0.11.0",

    #' @field RAY_LATEST_VERSION
    #' latest version of toolkit ray
    RAY_LATEST_VERSION = "1.6.0",

    #' @field .module
    #' mimic python module
    .module = "sagemaker.rl.estimator",

    #' @description Creates an RLEstimator for managed Reinforcement Learning (RL).
    #'              It will execute an RLEstimator script within a SageMaker Training Job. The managed RL
    #'              environment is an Amazon-built Docker container that executes functions defined in the
    #'              supplied ``entry_point`` Python script.
    #'              Training is started by calling
    #'              :meth:`~sagemaker.amazon.estimator.Framework.fit` on this Estimator.
    #'              After training is complete, calling
    #'              :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a hosted
    #'              SageMaker endpoint and based on the specified framework returns an
    #'              :class:`~sagemaker.amazon.mxnet.model.MXNetPredictor` or
    #'              :class:`~sagemaker.amazon.tensorflow.model.TensorFlowPredictor` instance that
    #'              can be used to perform inference against the hosted model.
    #'              Technical documentation on preparing RLEstimator scripts for
    #'              SageMaker training and using the RLEstimator is available on the project
    #'              homepage: https://github.com/aws/sagemaker-python-sdk
    #' @param entry_point (str): Path (absolute or relative) to the Python source
    #'              file which should be executed as the entry point to training.
    #'              If ``source_dir`` is specified, then ``entry_point``
    #'              must point to a file located at the root of ``source_dir``.
    #' @param toolkit (sagemaker.rl.RLToolkit): RL toolkit you want to use for
    #'              executing your model training code.
    #' @param toolkit_version (str): RL toolkit version you want to use for
    #'              executing your model training code.
    #' @param framework (sagemaker.rl.RLFramework): Framework (MXNet, TensorFlow
    #'              or PyTorch) you want to be used as a toolkit backend for
    #'              reinforcement learning training.
    #' @param source_dir (str): Path (absolute, relative or an S3 URI) to a directory
    #'              with any other training source code dependencies aside from the entry
    #'              point file (default: NULL). If ``source_dir`` is an S3 URI, it must
    #'              point to a tar.gz file. Structure within this directory are preserved
    #'              when training on Amazon SageMaker.
    #' @param hyperparameters (dict): Hyperparameters that will be used for
    #'              training (default: NULL). The hyperparameters are made
    #'              accessible as a dict[str, str] to the training code on
    #'              SageMaker. For convenience, this accepts other types for keys
    #'              and values.
    #' @param image_uri (str): An ECR url. If specified, the estimator will use
    #'              this image for training and hosting, instead of selecting the
    #'              appropriate SageMaker official image based on framework_version
    #'              and py_version. Example:
    #'              123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
    #' @param metric_definitions (list[dict]): A list of dictionaries that defines
    #'              the metric(s) used to evaluate the training jobs. Each
    #'              dictionary contains two keys: 'Name' for the name of the metric,
    #'              and 'Regex' for the regular expression used to extract the
    #'              metric from the logs. This should be defined only for jobs that
    #'              don't use an Amazon algorithm. When NULL and ``image_uri`` is
    #'              NULL, toolkit-specific defaults from
    #'              ``default_metric_definitions()`` are used.
    #' @param ... : Additional kwargs passed to the
    #'              :class:`~sagemaker.estimator.Framework` constructor.
    #'              .. tip::
    #'              You can find additional parameters for initializing this class at
    #'              :class:`~sagemaker.estimator.Framework` and
    #'              :class:`~sagemaker.estimator.EstimatorBase`.
    initialize = function(entry_point,
                          toolkit = NULL,
                          toolkit_version = NULL,
                          framework = NULL,
                          source_dir = NULL,
                          hyperparameters = NULL,
                          image_uri = NULL,
                          metric_definitions = NULL,
                          ...) {
      # Either a complete toolkit/version/framework triple or an image_uri is required.
      private$.validate_images_args(toolkit, toolkit_version, framework, image_uri)

      if (is.null(image_uri)) {
        private$.validate_toolkit_support(toolkit, toolkit_version, framework)
        self$toolkit <- toolkit
        self$toolkit_version <- toolkit_version
        self$framework <- framework
        # The backend framework version is fixed by the validated combination.
        self$framework_version <- TOOLKIT_FRAMEWORK_VERSION_MAP[[
          self$toolkit
        ]][[self$toolkit_version]][[self$framework]]

        # set default metric_definitions based on the toolkit
        if (is.null(metric_definitions))
          metric_definitions <- self$default_metric_definitions(toolkit)
      }
      super$initialize(
        entry_point,
        source_dir,
        hyperparameters,
        image_uri = image_uri,
        metric_definitions = metric_definitions,
        ...
      )
    },

    #' @description Create a SageMaker ``RLEstimatorModel`` object that can be deployed to an Endpoint.
    #' @param role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
    #'              which is also used during transform jobs. If not specified, the
    #'              role from the Estimator will be used.
    #' @param vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
    #'              the model. Default: use subnets and security groups from this Estimator.
    #'              * 'Subnets' (list[str]): List of subnet ids.
    #'              * 'SecurityGroupIds' (list[str]): List of security group ids.
    #' @param entry_point (str): Path (absolute or relative) to the Python source
    #'              file which should be executed as the entry point for MXNet
    #'              hosting (default: self.entry_point). If ``source_dir`` is specified,
    #'              then ``entry_point`` must point to a file located at the root of ``source_dir``.
    #'              Not forwarded when TensorFlow is the RL backend.
    #' @param source_dir (str): Path (absolute or relative) to a directory with
    #'              any other training source code dependencies aside from the entry
    #'              point file (default: self.source_dir). Structure within this
    #'              directory are preserved when hosting on Amazon SageMaker.
    #'              Not forwarded when TensorFlow is the RL backend.
    #' @param dependencies (list[str]): A list of paths to directories (absolute
    #'              or relative) with any additional libraries that will be exported
    #'              to the container (default: self.dependencies). The library
    #'              folders will be copied to SageMaker in the same folder where the
    #'              entry_point is copied. If the ```source_dir``` points to S3,
    #'              code will be uploaded and the S3 location will be used instead.
    #'              This is not supported with "local code" in Local Mode.
    #'              Not forwarded when TensorFlow is the RL backend.
    #' @param ... : Additional kwargs; only ``image_uri`` and ``name`` are read.
    #' @return sagemaker.model.FrameworkModel: Depending on input parameters returns
    #'              one of the following:
    #'              * :class:`~sagemaker.model.FrameworkModel` - if ``image_uri`` is specified
    #'              on the estimator;
    #'              * :class:`~sagemaker.mxnet.MXNetModel` - if ``image_uri`` isn't specified and
    #'                                  MXNet is used as the RL backend;
    #'              * :class:`~sagemaker.tensorflow.model.TensorFlowModel` - if ``image_uri`` isn't
    #'              specified and TensorFlow is used as the RL backend.
    create_model = function(role = NULL,
                            vpc_config_override = "VPC_CONFIG_DEFAULT",
                            entry_point = NULL,
                            source_dir = NULL,
                            dependencies = NULL,
                            ...) {
      kwargs <- list(...)
      # Arguments shared by every model type. `[[` is used on `kwargs`
      # because `$` would partially match argument names.
      base_args <- list(
        model_data = self$model_data,
        role = role %||% self$role,
        image_uri = kwargs[["image_uri"]] %||% self$image_uri,
        container_log_level = self$container_log_level,
        sagemaker_session = self$sagemaker_session,
        vpc_config = self$get_vpc_config(vpc_config_override)
      )
      base_args[["name"]] <- private$.get_or_create_name(kwargs[["name"]])

      if (is.null(entry_point) && (!is.null(source_dir) || !is.null(dependencies)))
        AttributeError$new("Please provide an `entry_point`.")

      entry_point <- entry_point %||% private$.model_entry_point()
      source_dir <- source_dir %||% private$.model_source_dir()
      dependencies <- dependencies %||% self$dependencies

      # Shared arguments plus code-related ones; modifyList() drops NULL entries.
      extended_args <- modifyList(base_args, list(
        entry_point = entry_point,
        source_dir = source_dir,
        code_location = self$code_location,
        dependencies = dependencies
      ))

      if (!is.null(self$image_uri))
        return(do.call(FrameworkModel$new, extended_args))

      if (isTRUE(self$toolkit == RLToolkit$RAY))
        NotImplementedError$new(
          "Automatic deployment of Ray models is not currently available.",
          " Train policy parameters are available in model checkpoints",
          " in the TrainingJob output.")

      if (isTRUE(self$framework == RLFramework$TENSORFLOW)) {
        # Mirrors the Python SDK: the TensorFlow model is built from the shared
        # arguments only, so the RL training entry point / source code is not
        # forwarded as inference code.
        tf_args <- c(framework_version = self$framework_version, base_args)
        return(do.call(TensorFlowModel$new, tf_args))
      }
      if (isTRUE(self$framework == RLFramework$MXNET)) {
        mx_args <- c(
          framework_version = self$framework_version,
          py_version = RL_PYTHON_VERSION,
          extended_args
        )
        return(do.call(MXNetModel$new, mx_args))
      }
      ValueError$new(sprintf(
        "An unknown RLFramework enum was passed in. framework: %s",
        toString(self$framework)
      ))
    },

    #' @description Return the Docker image to use for training.
    #'              The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does
    #'              the model training, calls this method to find the image to use for model
    #'              training.
    #' @return str: The URI of the Docker image. A user-supplied ``image_uri``
    #'              takes precedence over the toolkit/framework lookup.
    training_image_uri = function() {
      if (!is.null(self$image_uri))
        return(self$image_uri)
      sagemaker.core::ImageUris$new()$retrieve(
        private$.image_framework(),
        self$sagemaker_session$paws_region_name,
        version = self$toolkit_version,
        instance_type = self$instance_type
      )
    },

    #' @description Return hyperparameters used by your custom RL training code
    #'              during model training: the parent class's hyperparameters
    #'              plus the SageMaker output location and the estimator name,
    #'              JSON encoded.
    #' @return list: hyperparameters
    hyperparameters = function() {
      # Fall back to an empty list so modifyList() never receives NULL.
      hyperparameters <- super$hyperparameters() %||% list()

      additional_hyperparameters <- setNames(
        list(
          self$output_path,
          # TODO: can be applied to all other estimators
          SAGEMAKER_ESTIMATOR_VALUE
        ),
        c(model_parameters$SAGEMAKER_OUTPUT_LOCATION, SAGEMAKER_ESTIMATOR)
      )

      modifyList(
        hyperparameters,
        private$.json_encode_hyperparameters(additional_hyperparameters)
      )
    },

    #' @description Provides default metric definitions based on provided toolkit.
    #' @param toolkit (sagemaker.rl.RLToolkit): RL Toolkit to be used for
    #'              training.
    #' @return list: metric definitions
    default_metric_definitions = function(toolkit) {
      # isTRUE() routes NULL/NA toolkits to the ValueError below instead of
      # failing inside `if`.
      if (isTRUE(toolkit == RLToolkit$COACH)) {
        return(list(
          list("Name" = "reward-training", "Regex" = "^Training>.*Total reward=(.*?),"),
          list("Name" = "reward-testing", "Regex" = "^Testing>.*Total reward=(.*?),")
        ))
      }
      if (isTRUE(toolkit == RLToolkit$RAY)) {
        float_regex <- "[-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?"
        return(list(
          list("Name" = "episode_reward_mean", "Regex" = sprintf("episode_reward_mean: (%s)", float_regex)),
          list("Name" = "episode_reward_max", "Regex" = sprintf("episode_reward_max: (%s)", float_regex))
        ))
      }
      ValueError$new(sprintf(
        "An unknown RLToolkit enum was passed in. toolkit: %s", toString(toolkit)
      ))
    }
  ),

  private = list(

    # Convert the job description to init params.
    # This is done so that the init params can be handled by the class constructor.
    # Args:
    #   job_details: the returned job details from a describe_training_job
    #     API call.
    #   model_channel_name (str): Name of the channel where pre-trained
    #     model data will be downloaded.
    # Returns:
    #   list: The transformed init_params. For official RL images `image_uri`
    #   is replaced by `toolkit`, `toolkit_version` and `framework`.
    .prepare_init_params_from_job_description = function(job_details,
                                                         model_channel_name = NULL) {
      init_params <- super$.prepare_init_params_from_job_description(
        job_details, model_channel_name)

      # `[[` avoids `$` partial matching on the params list.
      image_uri <- init_params[["image_uri"]]
      init_params[["image_uri"]] <- NULL
      img_split <- framework_name_from_image(image_uri)
      names(img_split) <- c("framework", "py_version", "tag", "scriptmode")

      if (is.null(img_split$framework)) {
        # If we were unable to parse the framework name from the image it is not one of our
        # officially supported images, in this case just add the image to the init params.
        init_params[["image_uri"]] <- image_uri
        return(init_params)
      }
      ll_tag <- private$.toolkit_and_version_from_tag(img_split$tag)
      names(ll_tag) <- c("toolkit", "toolkit_version")

      if (!private$.is_combination_supported(ll_tag$toolkit, ll_tag$toolkit_version, img_split$framework))
        ValueError$new(sprintf(
          "Training job: %s didn't use image for requested framework",
          job_details[["TrainingJobName"]]
        ))

      init_params[["toolkit"]] <- ll_tag$toolkit
      init_params[["toolkit_version"]] <- ll_tag$toolkit_version
      init_params[["framework"]] <- img_split$framework

      init_params
    },

    # Split an image tag such as "coach0.11.1-cpu-py3" into
    # list(toolkit = "coach", toolkit_version = "0.11.1");
    # both elements are NULL when the tag does not match.
    .toolkit_and_version_from_tag = function(image_tag) {
      tag_pattern <- "^([A-Z]*|[a-z]*)(\\d.*)-(cpu|gpu)-(py2|py3)$"
      m <- regexec(tag_pattern, image_tag)
      tag_match <- unlist(regmatches(image_tag, m))
      # tag_match[[1]] is the full match; [[2]] and [[3]] are the first two groups.
      if (length(tag_match) > 0)
        return(list(toolkit = tag_match[[2]], toolkit_version = tag_match[[3]]))
      list(toolkit = NULL, toolkit_version = NULL)
    },

    # Signal a ValueError when `framework` is given but is not an RLFramework value.
    .validate_framework_format = function(framework) {
      rl_framework <- unname(as.list(RLFramework))
      if (!is.null(framework) && !(framework %in% rl_framework))
        ValueError$new(sprintf(
          "Invalid type: %s, valid RL frameworks types are: %s",
          framework, paste(rl_framework, collapse = ", ")
        ))
    },

    # Signal a ValueError when `toolkit` is given but is not an RLToolkit value.
    .validate_toolkit_format = function(toolkit) {
      rl_toolkit <- unname(as.list(RLToolkit))
      if (!is.null(toolkit) && !(toolkit %in% rl_toolkit))
        ValueError$new(sprintf(
          "Invalid type: %s, valid RL toolkits types are: %s",
          toolkit, paste(rl_toolkit, collapse = ", ")
        ))
    },

    # Check that either all of toolkit/toolkit_version/framework or an image_uri
    # is supplied; warn when an image_uri makes the other arguments redundant.
    .validate_images_args = function(toolkit = NULL,
                                     toolkit_version = NULL,
                                     framework = NULL,
                                     image_uri = NULL) {
      private$.validate_toolkit_format(toolkit)
      private$.validate_framework_format(framework)

      # Which of the image-selection arguments were supplied (order matters
      # for the messages below).
      supplied <- c(
        toolkit = !is.null(toolkit),
        toolkit_version = !is.null(toolkit_version),
        framework = !is.null(framework)
      )

      if (is.null(image_uri)) {
        not_found_args <- names(supplied)[!supplied]
        if (length(not_found_args) > 0)
          AttributeError$new(sprintf(
            "Please provide `%s` or `image_uri` parameter.",
            paste(not_found_args, collapse = "`, `")
          ))
      } else {
        found_args <- names(supplied)[supplied]
        if (length(found_args) > 0)
          LOGGER$warn(paste(
            "Parameter `image_uri` is specified,",
            "`%s` are going to be ignored when choosing the image."),
            paste(found_args, collapse = "`, `"))
      }
    },

    # TRUE when the toolkit/version/framework triple appears in
    # TOOLKIT_FRAMEWORK_VERSION_MAP; FALSE otherwise, including when any
    # component is missing (a NULL would make `[[` error).
    .is_combination_supported = function(toolkit,
                                         toolkit_version,
                                         framework) {
      if (length(toolkit) != 1L || length(toolkit_version) != 1L || length(framework) != 1L)
        return(FALSE)
      supported_frameworks <- TOOLKIT_FRAMEWORK_VERSION_MAP[[toolkit]][[toolkit_version]]
      !is.null(supported_frameworks[[framework]])
    },

    # Signal an AttributeError when the combination is not in the version map.
    .validate_toolkit_support = function(toolkit,
                                         toolkit_version,
                                         framework) {
      if (!private$.is_combination_supported(toolkit, toolkit_version, framework))
        AttributeError$new(sprintf(
          "Provided `%s-%s` and `%s` combination is not supported.",
          toolkit, toolkit_version, framework
        ))
    },

    # Toolkit name and framework name for retrieving Docker image URI config,
    # e.g. "coach-tensorflow".
    .image_framework = function() {
      paste(self$toolkit, self$framework, sep = "-")
    }
  ),
  lock_objects = FALSE
)
# DyfanJones/sagemaker-r-mlframework documentation built on March 18, 2022, 7:41 a.m.