# R/huggingface_estimator.R

# NOTE: This code has been modified from AWS Sagemaker Python:
# https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/huggingface/estimator.py

#' @import lgr
#' @import R6
#' @import sagemaker.core
#' @import sagemaker.common
#' @import sagemaker.mlcore

#' @title HuggingFace estimator class
#' @description Handle training of custom HuggingFace code.
#' @export
HuggingFace = R6Class("HuggingFace",
  inherit = sagemaker.mlcore::Framework,
  public = list(

    #' @field .module
    #' mimic python module
    .module = "sagemaker.huggingface.estimator",

    #' @description This ``Estimator`` executes a HuggingFace script in a managed execution environment.
    #'              The managed HuggingFace environment is an Amazon-built Docker container that executes
    #'              functions defined in the supplied ``entry_point`` Python script within a SageMaker
    #'              Training Job.
    #'              Training is started by calling
    #'              :meth:`~sagemaker.amazon.estimator.Framework.fit` on this Estimator.
    #' @param py_version (str): Python version you want to use for executing your model training
    #'              code. Defaults to ``None``. Required unless ``image_uri`` is provided.  List
    #'              of supported versions:
    #'              \url{https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators}
    #' @param entry_point (str): Path (absolute or relative) to the Python source
    #'              file which should be executed as the entry point to training.
    #'              If ``source_dir`` is specified, then ``entry_point``
    #'              must point to a file located at the root of ``source_dir``.
    #' @param transformers_version (str): Transformers version you want to use for
    #'              executing your model training code. Defaults to ``None``.  Required unless
    #'              ``image_uri`` is provided. The current supported version is ``4.6.1``.
    #' @param tensorflow_version (str): TensorFlow version you want to use for
    #'              executing your model training code. Defaults to ``None``. Required unless
    #'              ``pytorch_version`` is provided. List of supported versions:
    #'              \url{https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators}.
    #' @param pytorch_version (str): PyTorch version you want to use for
    #'              executing your model training code. Defaults to ``None``. Required unless
    #'              ``tensorflow_version`` is provided. List of supported versions:
    #'              \url{https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators}.
    #' @param source_dir (str): Path (absolute, relative or an S3 URI) to a directory
    #'              with any other training source code dependencies aside from the entry
    #'              point file (default: None). If ``source_dir`` is an S3 URI, it must
    #'              point to a tar.gz file. The structure within this directory is preserved
    #'              when training on Amazon SageMaker.
    #' @param hyperparameters (dict): Hyperparameters that will be used for
    #'              training (default: None). The hyperparameters are made
    #'              accessible as a dict[str, str] to the training code on
    #'              SageMaker. For convenience, this accepts other types for keys
    #'              and values, but ``str()`` will be called to convert them before
    #'              training.
    #' @param image_uri (str): If specified, the estimator will use this image
    #'              for training and hosting, instead of selecting the appropriate
    #'              SageMaker official image based on transformers_version and
    #'              py_version. It can be an ECR url or dockerhub image and tag.
    #'              Examples:
    #'              * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
    #'              * ``custom-image:latest``
    #'              If ``transformers_version`` is ``NULL``, then
    #'              ``image_uri`` is required. If also ``NULL``, then a ``ValueError``
    #'              will be raised.
    #' @param distribution (dict): A dictionary with information on how to run distributed training
    #'              (default: None).  Currently, the following are supported:
    #'              distributed training with parameter servers, SageMaker Distributed (SMD) Data
    #'              and Model Parallelism, and MPI. SMD Model Parallelism can only be used with MPI.
    #'              To enable parameter server use the following setup:
    #'              .. code:: python
    #'              {
    #'              "parameter_server": {
    #'              "enabled": True
    #'              }
    #'              }
    #'              To enable MPI:
    #'              .. code:: python
    #'              {
    #'              "mpi": {
    #'              "enabled": True
    #'              }
    #'              }
    #'              To enable SMDistributed Data Parallel or Model Parallel:
    #'              .. code:: python
    #'              {
    #'              "smdistributed": {
    #'              "dataparallel": {
    #'              "enabled": True
    #'              },
    #'              "modelparallel": {
    #'              "enabled": True,
    #'              "parameters": {}
    #'              }
    #'              }
    #'              }
    #' @param compiler_config (:class:`sagemaker.mlcore::TrainingCompilerConfig`):
    #'              Configures SageMaker Training Compiler to accelerate training.
    #' @param ... : Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
    #'              constructor.
    initialize = function(py_version,
                          entry_point,
                          transformers_version=NULL,
                          tensorflow_version=NULL,
                          pytorch_version=NULL,
                          source_dir=NULL,
                          hyperparameters=NULL,
                          image_uri=NULL,
                          distribution=NULL,
                          compiler_config=NULL,
                          ...){
      kwargs = list(...)
      self$framework_version = transformers_version
      self$py_version = py_version
      self$tensorflow_version = tensorflow_version
      self$pytorch_version = pytorch_version

      # Fail fast on inconsistent version arguments before the parent class is initialised.
      private$.validate_args(image_uri=image_uri)
      # Accept the deprecated `train_instance_type` name as an alias for `instance_type`.
      instance_type = renamed_kwargs(
        "train_instance_type", "instance_type", kwargs[["instance_type"]], kwargs
      )

      # Distribution settings are validated against the underlying base framework
      # (TensorFlow when tensorflow_version is given, otherwise PyTorch).
      base_framework_name = if (!is.null(tensorflow_version)) "tensorflow" else "pytorch"
      base_framework_version = (
        if (!is.null(tensorflow_version)) tensorflow_version else pytorch_version
      )

      if (!is.null(distribution)){
        validate_smdistributed(
          instance_type=instance_type,
          framework_name=base_framework_name,
          framework_version=base_framework_version,
          py_version=self$py_version,
          distribution=distribution,
          image_uri=image_uri
        )

        warn_if_parameter_server_with_multi_gpu(
          training_instance_type=instance_type, distribution=distribution
        )
      }
      kwargs[["py_version"]] = self$py_version

      if (!("enable_sagemaker_metrics" %in% names(kwargs)))
        kwargs[["enable_sagemaker_metrics"]] = TRUE

      # `hyperparameters` is wrapped in list() so that a NULL value is kept as an
      # explicit argument and a named list is passed through as a single element.
      kwargs = c(
        entry_point=entry_point,
        source_dir=source_dir,
        hyperparameters=list(hyperparameters),
        image_uri=image_uri,
        kwargs)
      do.call(super$initialize, kwargs)

      if (!is.null(compiler_config)){
        if (!inherits(compiler_config, "TrainingCompilerConfig")){
          error_string = paste(
            "Expected instance of type `sagemaker.mlcore::TrainingCompilerConfig`",
            "for argument compiler_config.",
            sprintf("Instead got `%s`", class(compiler_config)[1])
          )
          ValueError$new(error_string)
        }
        # Only an enabled compiler configuration needs to be checked against the
        # chosen image, instance type and distribution.
        if (isTRUE(compiler_config$enabled)){
          compiler_config$validate(
            image_uri=image_uri,
            instance_type=instance_type,
            distribution=distribution
          )
        }
      }
      self$distribution = distribution %||% list()
      self$compiler_config = compiler_config

      attr(self, "_framework_name") = "huggingface"
    },

    #' @description Return hyperparameters used by your custom HuggingFace training script
    #'              during model training. Distribution and (when set) Training Compiler
    #'              settings are JSON-encoded and merged into the parent hyperparameters.
    hyperparameters = function(){
      hyperparameters = super$hyperparameters()
      additional_hyperparameters = private$.distribution_configuration(
        distribution=self$distribution
      )
      hyperparameters = modifyList(
        hyperparameters, private$.json_encode_hyperparameters(additional_hyperparameters)
      )

      if (!is.null(self$compiler_config)){
        training_compiler_hyperparameters = self$compiler_config$.to_hyperparameter_list()
        hyperparameters = modifyList(
          hyperparameters, private$.json_encode_hyperparameters(training_compiler_hyperparameters)
        )
      }
      return(hyperparameters)
    },

    #' @description Create a model to deploy.
    #'              The serializer, deserializer, content_type, and accept arguments are only used to define a
    #'              default Predictor. They are ignored if an explicit predictor class is passed in.
    #'              Other arguments are passed through to the Model class.
    #'              Creating model with HuggingFace training job is not supported.
    #' @param model_server_workers (int): Optional. The number of worker processes
    #'              used by the inference server. If None, server will use one
    #'              worker per vCPU.
    #' @param role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
    #'              which is also used during transform jobs. If not specified, the
    #'              role from the Estimator will be used.
    #' @param vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
    #'              the model.
    #'              Default: use subnets and security groups from this Estimator.
    #'              * 'Subnets' (list[str]): List of subnet ids.
    #'              * 'SecurityGroupIds' (list[str]): List of security group ids.
    #' @param entry_point (str): Path (absolute or relative) to the local Python
    #'              source file which should be executed as the entry point to
    #'              training. If ``source_dir`` is specified, then ``entry_point``
    #'              must point to a file located at the root of ``source_dir``.
    #'              If 'git_config' is provided, 'entry_point' should be
    #'              a relative location to the Python source file in the Git repo.
    #' @param source_dir (str): Path (absolute, relative or an S3 URI) to a directory
    #'              with any other training source code dependencies aside from the entry
    #'              point file (default: None). If ``source_dir`` is an S3 URI, it must
    #'              point to a tar.gz file. The structure within this directory is preserved
    #'              when training on Amazon SageMaker. If 'git_config' is provided,
    #'              'source_dir' should be a relative location to a directory in the Git
    #'              repo.
    #' @param dependencies (list[str]): A list of paths to directories (absolute
    #'              or relative) with any additional libraries that will be exported
    #'              to the container (default: []). The library folders will be
    #'              copied to SageMaker in the same folder where the entrypoint is
    #'              copied. If 'git_config' is provided, 'dependencies' should be a
    #'              list of relative locations to directories with any additional
    #'              libraries needed in the Git repo.
    #' @param ... : Additional parameters passed to :class:`~sagemaker.model.Model`
    #'              .. tip::
    #'              You can find additional parameters for using this method at
    #'              :class:`~sagemaker.model.Model`.
    #' @return (sagemaker.model.Model) a Model ready for deployment.
    create_model = function(model_server_workers=NULL,
                            role=NULL,
                            vpc_config_override="VPC_CONFIG_DEFAULT",
                            entry_point=NULL,
                            source_dir=NULL,
                            dependencies=NULL,
                            ...){
      kwargs = list(...)
      # Default to the estimator's training image unless the caller overrides it.
      if (!("image_uri" %in% names(kwargs))){
        kwargs[["image_uri"]] = self$image_uri
      }
      kwargs[["name"]] = private$.get_or_create_name(kwargs[["name"]])
      kwargs = append(
        list(
          role = role %||% self$role,
          model_data=self$model_data,
          entry_point=entry_point,
          transformers_version=self$framework_version,
          tensorflow_version=self$tensorflow_version,
          pytorch_version=self$pytorch_version,
          py_version=self$py_version,
          source_dir=(source_dir %||% private$.model_source_dir()),
          container_log_level=self$container_log_level,
          code_location=self$code_location,
          model_server_workers=model_server_workers,
          sagemaker_session=self$sagemaker_session,
          vpc_config=self$get_vpc_config(vpc_config_override),
          dependencies=(dependencies %||% self$dependencies)
        ),
        kwargs
      )
      return(do.call(HuggingFaceModel$new, kwargs))
    }
  ),
  private = list(
    # Check that the version arguments stored on `self` are consistent.
    # Skipped entirely when a custom `image_uri` is supplied. Otherwise requires
    # transformers_version and exactly one of tensorflow_version/pytorch_version,
    # and both versions must use the same number of dot-separated components
    # (both full, e.g. "4.6.1"/"1.7.1", or both shortened).
    .validate_args = function(image_uri=NULL){
      if (!is.null(image_uri))
        return(NULL)

      if (is.null(self$framework_version) && is.null(image_uri))
        ValueError$new(
          "transformers_version, and image_uri are both NULL. ",
          "Specify either transformers_version or image_uri")
      if (!is.null(self$tensorflow_version) && !is.null(self$pytorch_version))
        ValueError$new(
          "tensorflow_version and pytorch_version are both not NULL. ",
          "Specify only tensorflow_version or pytorch_version.")
      if (is.null(self$tensorflow_version) && is.null(self$pytorch_version))
        ValueError$new(
          "tensorflow_version and pytorch_version are both NULL. ",
          "Specify either tensorflow_version or pytorch_version.")
      base_framework_version_len = (
        if (!is.null(self$tensorflow_version))
          length(split_str(self$tensorflow_version, "\\."))
        else length(split_str(self$pytorch_version, "\\."))
      )
      transformers_version_len = length(split_str(self$framework_version,"\\."))
      if (transformers_version_len != base_framework_version_len)
        ValueError$new(
          "Please use either full version or shortened version for both ",
          "transformers_version, tensorflow_version and pytorch_version.")
    },

    # Convert the job description to init params that can be handled by the class constructor.
    # Args:
    #   job_details: The returned job details from a describe_training_job
    #     API call.
    #   model_channel_name (str): Name of the channel where pre-trained
    #     model data will be downloaded.
    # Returns:
    #   list: The transformed init_params. For official HuggingFace images the
    #   image URI is replaced by transformers/pytorch/tensorflow/py versions parsed
    #   from it; for unrecognised images the image_uri is kept as-is.
    .prepare_init_params_from_job_description = function(job_details,
                                                         model_channel_name=NULL){
      init_params = super$.prepare_init_params_from_job_description(
        job_details, model_channel_name)

      # Use [[ ]] rather than $ to avoid partial name matching on the list.
      image_uri = init_params[["image_uri"]]
      init_params[["image_uri"]] = NULL
      img_split = framework_name_from_image(image_uri)
      names(img_split) = c("framework", "py_version", "tag", "scriptmode")

      # `framework` must be defined on every path: it is compared against the
      # estimator's framework name below even when the image has no tag.
      framework = img_split$framework
      framework_version = NULL
      if (!is.null(img_split$tag)){
        # The framework part of the image name looks like "<family>-<base framework>",
        # e.g. "huggingface-pytorch"; keep the family and the base framework separately.
        fmwk = split_str(framework, "-")
        framework = fmwk[1]
        pt_or_tf = fmwk[2]
        tag_pattern = "^(.*)-transformers(.*)-(cpu|gpu)-(py2|py3[67]?)$"
        tag_match = regmatches(img_split$tag, regexec(tag_pattern, img_split$tag))[[1]]
        if (length(tag_match) == 0)
          ValueError$new(sprintf(
            "Unable to parse HuggingFace image tag: %s", img_split$tag))
        # tag_match[1] is the full match; [2] and [3] are the first two capture groups:
        # the base framework version and the transformers version.
        pt_or_tf_version = tag_match[2]
        framework_version = tag_match[3]
        if (identical(pt_or_tf, "pytorch")){
          init_params[["pytorch_version"]] = pt_or_tf_version
        } else {
          init_params[["tensorflow_version"]] = pt_or_tf_version
        }
      }
      init_params[["transformers_version"]] = framework_version
      # `[<-` with list() overwrites any existing entry instead of duplicating the
      # name, and keeps a NULL py_version as an explicit argument for the constructor.
      init_params["py_version"] = list(img_split$py_version)

      if (is.null(framework)){
        # If we were unable to parse the framework name from the image it is not one of our
        # officially supported images, in this case just add the image to the init params.
        init_params[["image_uri"]] = image_uri
        return(init_params)
      }
      if (framework != attr(self, "_framework_name"))
        ValueError$new(
          sprintf("Training job: %s didn't use image for requested framework",
            job_details[["TrainingJobName"]])
        )

      return(init_params)
    }
  ),
  lock_objects=FALSE
)
# DyfanJones/sagemaker-r-mlframework documentation built on March 18, 2022, 7:41 a.m.