# R/fw_utils.R
#
# Defines functions: python_deprecation_warning, validate_version_or_image_args,
# .region_supports_profiler, .region_supports_debugger, .validate_smdataparallel_args,
# validate_smdistributed, warn_if_parameter_server_with_multi_gpu, model_code_key_prefix,
# framework_version_from_tag, framework_name_from_image, .list_files_to_compress,
# tar_and_upload_dir, validate_mp_config, get_mp_parameters, validate_source_dir

# NOTE: This code has been modified from AWS Sagemaker Python:
# https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/fw_utils.py

#' @include utils.R
#' @include r_utils.R
#' @include image_uris.R
#' @include error.R

#' @import lgr
#' @importFrom fs is_file

# Suffix for the temporary tarball that bundles the user's source code.
.TAR_SOURCE_FILENAME <- ".tar.gz"

#' @title sagemaker.fw_utils.UserCode: An object containing the S3 prefix and script name.
#' @description This is for the source code used for the entry point with an ``Estimator``. It can be
#'              instantiated with positional or keyword arguments.
#' @keywords internal
#' @export
UploadedCode <- list("s3_prefix" = NULL, "script_name" = NULL)

# Template used by python_deprecation_warning(); the four "%s" slots are filled
# with (latest_supported_version, framework, framework, framework).
PYTHON_2_DEPRECATION_WARNING <- paste(
  "%s is the latest version of %s that supports",
  "Python 2. Newer versions of %s will only be available for Python 3.",
  "Please set the argument \"py_version='py3'\" to use the Python 3 %s image.")

# Warning emitted by warn_if_parameter_server_with_multi_gpu().
PARAMETER_SERVER_MULTI_GPU_WARNING <- paste(
  "If you have selected a multi-GPU training instance type",
  "and also enabled parameter server for distributed training,",
  "distributed training with the default parameter server configuration will not",
  "fully leverage all GPU cores; the parameter server will be configured to run",
  "only one worker per host regardless of the number of GPUs.")

# Regions where SageMaker Debugger / the Debugger profiling feature are unavailable.
DEBUGGER_UNSUPPORTED_REGIONS = c("us-iso-east-1")
PROFILER_UNSUPPORTED_REGIONS = c("us-iso-east-1")

# Instance types carrying a single GPU (excluded from the multi-GPU warning).
SINGLE_GPU_INSTANCE_TYPES = c("ml.p2.xlarge", "ml.p3.2xlarge")
# Instance types and framework versions accepted by smdistributed dataparallel.
SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = c(
  "ml.p3.16xlarge",
  "ml.p3dn.24xlarge",
  "ml.p4d.24xlarge",
  "local_gpu"
)
SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = list(
  "tensorflow"=list("2.3", "2.3.1", "2.3.2", "2.4", "2.4.1"),
  "pytorch"=list("1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0", "1.8.1")
)
# Strategies the `smdistributed` distribution dict may select (at most one).
SMDISTRIBUTED_SUPPORTED_STRATEGIES = c("dataparallel", "modelparallel")

#' @title Validate that the source directory exists and it contains the user script
#' @param script (str): Script filename.
#' @param directory (str): Directory containing the source file. Non-character
#'              values (e.g. NULL) are accepted without any check.
#' @return TRUE if validation passes; otherwise a `ValueError` is raised.
#' @export
validate_source_dir <- function(script, directory){
  # Only a character `directory` can be inspected on disk.
  if (!is.character(directory))
    return(TRUE)
  target <- file.path(directory, script)
  if (!fs::is_file(target))
    ValueError$new(sprintf('No file named "%s" was found in directory "%s".',script, directory))
  return(TRUE)
}

#' @title Get the model parallelism parameters provided by the user.
#' @param distribution : distribution dictionary defined by the user.
#' @return params: dictionary containing model parallelism parameters
#'              used for training, or NULL when model parallelism is not enabled.
#' @export
get_mp_parameters <- function(distribution){
  mp_config <- distribution$smdistributed$modelparallel %||% list()
  # Model parallelism is opt-in; anything other than an explicit TRUE means off.
  if (!isTRUE(mp_config$enabled %||% FALSE))
    return(NULL)
  mp_params <- mp_config$parameters %||% list()
  validate_mp_config(mp_params)
  return(mp_params)
}

#' @title Validate the configuration dictionary for model parallelism.
#' @description 'partitions' is mandatory; every other key is validated only
#'              when it is present in `config` (mirroring the AWS Python SDK,
#'              where absent optional keys are skipped). A `ValueError` is
#'              raised on the first violation found.
#' @param config (list): Dictionary holding configuration keys and values.
#' @export
validate_mp_config <- function(config){
  if (!("partitions" %in% names(config)))
    ValueError$new("'partitions' is a required parameter.")

  # When `key` is present, it must be a positive number.
  # (Fixed: was `config[key] < 1`, which compares a list, and ran even for
  # absent keys.)
  validate_positive <- function(key){
    if (key %in% names(config) &&
        (!inherits(config[[key]], c("integer", "numeric")) || config[[key]] < 1))
      ValueError$new(sprintf("The number of %s must be a positive integer.",key))
  }

  # When `key` is present, its value must be one of `vals`.
  # (Fixed: previously ran for absent keys, where `NULL %in% vals` yields
  # logical(0) and the `if` errors.)
  validate_in <- function(key, vals){
    if (key %in% names(config) && !(config[[key]] %in% vals))
      ValueError$new(sprintf("%s must be a value in: [%s].",
                   key, paste(vals, collapse = ", ")))
  }

  # When `key` is present, its value must be TRUE or FALSE.
  validate_bool <- function(key){
    validate_in(key, c(TRUE, FALSE))
  }

  validate_in("pipeline", c("simple", "interleaved", "_only_forward"))
  validate_in("placement_strategy", c("spread", "cluster"))
  validate_in("optimize", c("speed", "memory"))

  for (key in c("microbatches", "partitions"))
    validate_positive(key)

  for (key in c("auto_partition", "contiguous", "load_partition", "horovod", "ddp"))
    validate_bool(key)

  if ("partition_file" %in% names(config) &&
      !inherits(config$partition_file, "character"))
    ValueError$new("'partition_file' must be a character.")

  # `default_partition` is mandatory only when auto-partitioning is explicitly
  # disabled. (Fixed: `!isTRUE(...)` also triggered when the key was absent.)
  if (isFALSE(config$auto_partition) && !("default_partition" %in% names(config)))
    ValueError$new("default_partition must be supplied if auto_partition is set to `FALSE`!")

  if ("default_partition" %in% names(config) && config$default_partition >= config$partitions)
    ValueError$new("default_partition must be less than the number of partitions!")

  if ("memory_weight" %in% names(config) && (
    config$memory_weight > 1 || config$memory_weight < 0))
    ValueError$new("memory_weight must be between 0.0 and 1.0!")

  # `ddp_port` / `ddp_dist_backend` are only meaningful alongside `ddp`.
  # (Fixed: the `ddp_port` check was missing the negation and rejected the
  # valid case where both keys were set.)
  if ("ddp_port" %in% names(config) && !("ddp" %in% names(config)))
    ValueError$new("`ddp_port` needs `ddp` to be set as well")

  if ("ddp_dist_backend" %in% names(config) && !("ddp" %in% names(config)))
    ValueError$new("`ddp_dist_backend` needs `ddp` to be set as well")

  if ("ddp_port" %in% names(config)){
    if (!inherits(config$ddp_port, "integer") || config$ddp_port < 0){
      value = config$ddp_port
      ValueError$new(sprintf("Invalid port number %s.", value))
    }
  }

  # Both flags are guaranteed TRUE/FALSE by validate_bool() above, so isTRUE()
  # is equivalent to the previous `%||% FALSE` default.
  if (isTRUE(config$horovod) && isTRUE(config$ddp))
    ValueError$new("'ddp' and 'horovod' cannot be simultaneously enabled.")
}

#' @title Package source files and uploads a compress tar file to S3.
#' @description Package source files and upload a compress tar file to S3. The S3
#'              location will be ``s3://<bucket>/s3_key_prefix/sourcedir.tar.gz``.
#'              If directory is an S3 URI, an UploadedCode object will be returned, but
#'              nothing will be uploaded to S3 (this allow reuse of code already in S3).
#'              If directory is NULL, the script will be added to the archive at
#'              ``./<basename of script>``.
#'              If directory is not NULL, the (recursive) contents of the directory will
#'              be added to the archive. directory is treated as the base path of the
#'              archive, and the script name is assumed to be a filename or relative path
#'              inside the directory.
#' @param sagemaker_session (sagemaker.Session): sagemaker_session session used to access S3.
#' @param bucket (str): S3 bucket to which the compressed file is uploaded.
#' @param s3_key_prefix (str): Prefix for the S3 key.
#' @param script (str): Script filename or path.
#' @param directory (str): Optional. Directory containing the source file. If it
#'              starts with "s3://", no action is taken.
#' @param dependencies (List[str]): Optional. A list of paths to directories
#'              (absolute or relative) containing additional libraries that will be
#'              copied into /opt/ml/lib
#' @param kms_key (str): Optional. KMS key ID used to upload objects to the bucket
#'              (default: NULL).
#' @return sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
#'               script name.
#' @export
tar_and_upload_dir <- function(sagemaker_session,
                               bucket,
                               s3_key_prefix,
                               script,
                               directory=NULL,
                               dependencies=NULL,
                               kms_key=NULL){
  # Code already staged in S3: just record its location, upload nothing.
  if (!is.null(directory) && startsWith(tolower(directory),"s3://")){
    UploadedCode$s3_prefix=directory
    UploadedCode$script_name= basename(script)
    return(UploadedCode)}

  script_name =  if(!is.null(directory)) script else basename(script)
  dependencies = dependencies %||% list()
  key = sprintf("%s/sourcedir.tar.gz",s3_key_prefix)
  tmp = tempfile(fileext = .TAR_SOURCE_FILENAME)
  # Register cleanup immediately so the temporary tarball is removed even when
  # archiving or the upload below fails (previously cleanup was registered only
  # after a successful upload, leaking the file on error).
  on.exit(unlink(tmp, recursive = TRUE), add = TRUE)

  # NOTE: the original wrapped this in a handler-less tryCatch(), which is a
  # no-op; errors should simply propagate to the caller.
  source_files = unlist(c(.list_files_to_compress(script, directory), dependencies))
  tar_file = create_tar_file(source_files, tmp)

  if (!is.null(kms_key)) {
    ServerSideEncryption = "aws:kms"
    SSEKMSKeyId =  kms_key
  } else {
    # Assigning NULL into `kwargs` below drops these entries entirely.
    ServerSideEncryption = NULL
    SSEKMSKeyId =  NULL
  }

  obj <- readBin(tar_file, "raw", n = file.size(tar_file))
  kwargs = list(
    Body = obj,
    Bucket = bucket,
    Key = key
  )
  kwargs[["ServerSideEncryption"]] = ServerSideEncryption
  kwargs[["SSEKMSKeyId"]] = SSEKMSKeyId

  do.call(sagemaker_session$s3$put_object, kwargs)

  UploadedCode$s3_prefix=sprintf("s3://%s/%s",bucket, key)
  UploadedCode$script_name=script_name

  return(UploadedCode)
}

# Resolve the set of files that go into the source tarball.
# With no directory, only the script itself is archived; otherwise every file
# directly under the directory is included (full paths).
.list_files_to_compress <- function(script, directory){
  if (is.null(directory)) {
    return(list(script))
  }
  # `directory` is non-NULL here, so the dirname(script) fallback never fires;
  # it is kept for parity with the upstream Python implementation.
  base_dir <- directory %||% dirname(script)
  list.files(base_dir, full.names = TRUE)
}

#' @title Extract the framework and Python version from the image name.
#' @param image_uri (str): Image URI, which should be one of the following forms:
#' \itemize{
#'    \item{legacy: '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>-<py_ver>-<device>:<container_version>'}
#'    \item{legacy: '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>-<py_ver>-<device>:<fw_version>-<device>-<py_ver>'}
#'    \item{current: '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>:<fw_version>-<device>-<py_ver>'}
#'    \item{current: '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-rl-<fw>:<rl_toolkit><rl_version>-<device>-<py_ver>'}
#'    \item{current: '<account>.dkr.ecr.<region>.amazonaws.com/<fw>-<image_scope>:<fw_version>-<device>-<py_ver>'}
#' }
#' @return tuple: A tuple containing:
#' \itemize{
#'    \item{str: The framework name}
#'    \item{str: The Python version}
#'    \item{str: The image tag}
#'    \item{str: If the TensorFlow image is script mode}
#' }
#' @export
framework_name_from_image <- function(image_uri){
  # Pull the "<repo>:<tag>" portion out of the full ECR URI: the last capture
  # group of ECR_URI_PATTERN holds it. (Removed the unused `sagemaker_pattern`
  # local that shadowed this constant.)
  sagemaker_match = unlist(regmatches(image_uri,regexec(ECR_URI_PATTERN,image_uri)))
  sagemaker_match = sagemaker_match[length(sagemaker_match)]
  # Check length FIRST: when the URI does not match, `sagemaker_match` is
  # character(0) and `is.na(character(0))` is zero-length, which is an error
  # on the left-hand side of `||` in modern R.
  if (length(sagemaker_match) == 0 || is.na(sagemaker_match))
    return(list(NULL, NULL, NULL, NULL))

  # extract framework, python version and image tag
  # We must support both the legacy and current image name format.
  name_pattern = paste0(
    "^(?:sagemaker(?:-rl)?-)?",
    "(tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost",
    "|huggingface-tensorflow|huggingface-pytorch)(?:-)?",
    "(scriptmode|training)?",
    ":(.*)-(.*?)-(py2|py3[67]?)(?:.*)$")
  name_match = unlist(regmatches(sagemaker_match,regexec(name_pattern,sagemaker_match)))
  if (length(name_match) > 0){
    fw_pts = as.list(name_match[-1])
    # regexec() reports non-participating groups as "", normalise to NULL.
    fw_pts = lapply(fw_pts, function(x) if(x =="") NULL else x)
    names(fw_pts) = c("fw", "scriptmode", "ver", "device", "py")
    return(list(fw_pts$fw, fw_pts$py, sprintf("%s-%s-%s", fw_pts$ver, fw_pts$device, fw_pts$py), fw_pts$scriptmode))
  }

  # Fall back to the oldest naming scheme.
  legacy_name_pattern = "^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$"
  legacy_match = unlist(regmatches(sagemaker_match,regexec(legacy_name_pattern,sagemaker_match)))
  if (length(legacy_match) > 0){
    lg_pts = legacy_match[-1]
    return (list(lg_pts[1], lg_pts[2], lg_pts[4], NULL))
  }
  return(list(NULL, NULL, NULL, NULL))
}

#' @title Extract the framework version from the image tag.
#' @param image_tag (str): Image tag, which should take the form
#'           '<framework_version>-<device>-<py_version>'
#' @return str: The framework version, or NULL when the tag does not match.
#' @export
framework_version_from_tag <- function(image_tag){
  # First capture group is everything before '-<cpu|gpu>-<pyX>'.
  pattern <- "^(.*)-(cpu|gpu)-(py2|py3[67]?)$"
  pieces <- regmatches(image_tag, regexec(pattern, image_tag))[[1]]
  if (length(pieces) == 0)
    return(NULL)
  pieces[2]
}

#' @title Returns the s3 key prefix for uploading code during model deployment
#' @description The location returned is a potential concatenation of 2 parts
#'              1. code_location_key_prefix if it exists
#'              2. model_name or a name derived from the image
#' @param code_location_key_prefix (str): the s3 key prefix from code_location
#' @param model_name (str): the name of the model
#' @param image (str): the image from which a default name can be extracted
#' @return str: the key prefix to be used in uploading code
#' @export
model_code_key_prefix <- function(code_location_key_prefix, model_name, image){
  # Fall back to a name derived from the image when no model name is given.
  default_name <- name_from_image(image)
  parts <- list(code_location_key_prefix, model_name %||% default_name)
  parts <- Filter(Negate(is.null), parts)
  paste0(parts, collapse = "/")
}

#' @title Warn the user that training will not fully leverage all the GPU cores
#' @description Warn the user that training will not fully leverage all the GPU
#'              cores if parameter server is enabled and a multi-GPU instance is selected.
#'              Distributed training with the default parameter server setup doesn't
#'              support multi-GPU instances.
#' @param training_instance_type (str): A string representing the type of training instance selected.
#' @param distribution (dict): A dictionary with information to enable distributed training.
#'              (Defaults to NULL if distributed training is not enabled.).
#' @export
warn_if_parameter_server_with_multi_gpu <- function(training_instance_type, distribution){
  # Local CPU mode or no distribution config: nothing to warn about.
  if (training_instance_type == "local" || is.null(distribution))
    return(invisible(NULL))

  # GPU families start with "p" in the instance family segment (e.g. ml.p3.*).
  gpu_family <- training_instance_type == "local_gpu" ||
    startsWith(split_str(training_instance_type,  "\\.")[[2]],"p")
  multi_gpu <- gpu_family &&
    !(training_instance_type %in% SINGLE_GPU_INSTANCE_TYPES)

  ps_on <- ("parameter_server" %in% names(distribution)) &&
    (distribution$parameter_server$enabled %||% FALSE)

  if (multi_gpu && ps_on)
    LOGGER$warn(PARAMETER_SERVER_MULTI_GPU_WARNING)
}

#' @title Check if smdistributed strategy is correctly invoked by the user.
#' @description Currently, two strategies are supported: `dataparallel` or `modelparallel`.
#'              Validate if the user requested strategy is supported.
#'              Currently, only one strategy can be specified at a time. Validate if the user has requested
#'              more than one strategy simultaneously.
#'              Validate if the smdistributed dict arg is syntactically correct.
#'              Additionally, perform strategy-specific validations.
#' @param instance_type (str): A string representing the type of training instance selected.
#' @param framework_name (str): A string representing the name of framework selected.
#' @param framework_version (str): A string representing the framework version selected.
#' @param py_version (str): A string representing the python version selected.
#' @param distribution (dict): A dictionary with information to enable distributed training.
#'           (Defaults to NULL if distributed training is not enabled.)
#' @param image_uri (str): A string representing a Docker image URI.
#' @export
validate_smdistributed <- function(instance_type,
                                   framework_name,
                                   framework_version,
                                   py_version,
                                   distribution,
                                   image_uri=NULL){
  # Nothing to validate unless the user opted into smdistributed.
  if (!("smdistributed" %in% names(distribution)))
    return(NULL)

  smdist_conf <- distribution$smdistributed
  if (!is_list_named(smdist_conf))
    ValueError$new("smdistributed strategy requires to be a named list")

  # Only one strategy may be requested at a time.
  if (length(smdist_conf) > 1){
    ValueError$new(paste(
      "Cannot use more than 1 smdistributed strategy.\n",
      "Choose one of the following supported strategies:",
      paste(SMDISTRIBUTED_SUPPORTED_STRATEGIES, collapse = ", ")))
  }

  # Reject any strategy name outside the supported set (at most one key here).
  for (candidate in names(smdist_conf)){
    if (!(candidate %in% SMDISTRIBUTED_SUPPORTED_STRATEGIES)){
      ValueError$new(paste(
        sprintf("Invalid smdistributed strategy provided: %s\n", candidate),
        sprintf("Supported strategies: %s", paste(SMDISTRIBUTED_SUPPORTED_STRATEGIES, collapse = ", "))
      ))
    }
  }

  # dataparallel carries extra instance/framework constraints.
  if ("dataparallel" %in% names(smdist_conf)){
    .validate_smdataparallel_args(
      instance_type, framework_name, framework_version, py_version, distribution, image_uri
    )
  }
}

#' @title Check if request is using unsupported arguments.
#' @description Validate if user specifies a supported instance type, framework version, and python
#'              version.
#' @param instance_type (str): A string representing the type of training instance selected. Ex: `ml.p3.16xlarge`
#' @param framework_name (str): A string representing the name of framework selected. Ex: `tensorflow`
#' @param framework_version (str): A string representing the framework version selected. Ex: `2.3.1`
#' @param py_version (str): A string representing the python version selected. Ex: `py3`
#' @param distribution (dict): A dictionary with information to enable distributed training.
#'              (Defaults to NULL if distributed training is not enabled.)
#' @param image_uri (str): A string representing a Docker image URI.
#' @keywords internal
#' @export
.validate_smdataparallel_args <- function(instance_type,
                                          framework_name,
                                          framework_version,
                                          py_version,
                                          distribution,
                                          image_uri=NULL){
  smdataparallel_enabled = distribution$smdistributed$dataparallel$enabled %||% FALSE

  # Nothing to validate unless the user switched dataparallel on.
  if (!smdataparallel_enabled)
    return(NULL)

  is_instance_type_supported = instance_type %in% SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES

  # Collect all violations and raise them in a single error.
  err_msg = ""

  if (!is_instance_type_supported){
    # BUG FIX: this message previously listed the smdistributed strategies
    # instead of the supported instance types.
    err_msg = paste0(err_msg,
                     sprintf("Provided instance_type %s is not supported by smdataparallel.\n",instance_type),
                     sprintf("Please specify one of the supported instance types: %s\n",
                             paste(SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES, collapse = ", ")))
  }

  if (is.null(image_uri)){
    # ignore framework_version & py_version if image_uri is set
    # in case image_uri is not set, then both are mandatory
    supported = SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS[[framework_name]]
    if (!(framework_version %in% supported)){
      err_msg = paste0(err_msg,
                       sprintf("Provided framework_version %s is not supported by", framework_version),
                       " smdataparallel.\n",
                       sprintf("Please specify one of the supported framework versions: %s \n", paste(supported, collapse = ", ")))
    }
    # BUG FIX: `%in%` tests exact equality, so valid versions like "py36"
    # or "py37" were rejected; mirror the Python SDK's substring check
    # ("py3" in py_version).
    if (!grepl("py3", py_version, fixed = TRUE)){
      err_msg = paste0(err_msg,
                       sprintf("Provided py_version %s is not supported by smdataparallel.\n", py_version),
                       "Please specify py_version=py3")
    }
  }
  if (nchar(err_msg) > 0)
    ValueError$new(err_msg)
}

#' @title Returns boolean indicating whether the region supports Amazon SageMaker Debugger.
#' @param region_name (str): Name of the region to check against.
#' @return bool: Whether or not the region supports Amazon SageMaker Debugger.
#' @keywords internal
#' @export
.region_supports_debugger <- function(region_name){
  # Case-insensitive lookup against the deny-list of regions.
  region <- tolower(region_name)
  !(region %in% DEBUGGER_UNSUPPORTED_REGIONS)
}

#' @title Returns bool indicating whether region supports Amazon SageMaker Debugger profiling feature.
#' @param region_name (str): Name of the region to check against.
#' @return bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
#' @keywords internal
#' @export
.region_supports_profiler <- function(region_name){
  # Case-insensitive lookup against the deny-list of regions.
  region <- tolower(region_name)
  !(region %in% PROFILER_UNSUPPORTED_REGIONS)
}

#' @title Checks if version or image arguments are specified.
#' @description Validates framework and model arguments to enforce version or image specification.
#'              Either both `framework_version` and `py_version` must be given,
#'              or `image_uri` must be given.
#' @param framework_version (str): The version of the framework.
#' @param py_version (str): The version of Python.
#' @param image_uri (str): The URI of the image.
#' @export
validate_version_or_image_args <- function(framework_version, py_version, image_uri){
  versions_missing <- is.null(framework_version) || is.null(py_version)
  if (versions_missing && is.null(image_uri))
    ValueError$new(
      "`framework_version` or `py_version` was NULL, yet `image_uri` was also NULL.",
      " Either specify both `framework_version` and `py_version`, or specify `image_uri`."
    )
}


#' @title Raise warning for deprecated python versions
#' @param framework (str): model framework
#' @param latest_supported_version (str): latest supported version
#' @return str: the formatted deprecation warning message.
#' @export
python_deprecation_warning <- function(framework, latest_supported_version){
  # Fill the template: latest version, then the framework name three times.
  msg <- sprintf(
    PYTHON_2_DEPRECATION_WARNING,
    latest_supported_version, framework, framework, framework
  )
  return(msg)
}
# DyfanJones/sagemaker-r-local documentation built on June 14, 2022, 10:32 p.m.