# NOTE: This code has been modified from AWS Sagemaker Python:
# https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/model.py
#' @include r_utils.R
#' @import jsonlite
#' @import R6
#' @import sagemaker.core
#' @import sagemaker.common
#' @import lgr
#' @importFrom stats setNames
NEO_ALLOWED_FRAMEWORKS <- list(
  "mxnet",
  "tensorflow",
  "keras",
  "pytorch",
  "onnx",
  "xgboost",
  "tflite"
)
#' @title An object that encapsulates a trained model.
#' @description Models can be deployed to compute services like a SageMaker ``Endpoint``
#' or Lambda. Deployed models can be used to perform real-time inference.
#' @keywords internal
#' @export
ModelBase <- R6Class(
  classname = "ModelBase",
  public = list(

    #' @description Deploy this model to a compute service. The base class
    #'              only constructs a ``NotImplementedError``.
    #' @param ... : not currently implemented
    deploy = function(...) {
      NotImplementedError$new()
    },

    #' @description Destroy resources associated with this model. The base
    #'              class only constructs a ``NotImplementedError``.
    #' @param ... : not currently implemented
    delete_model = function(...) {
      NotImplementedError$new()
    },

    #' @description Format this object by delegating to ``format_class``.
    format = function() {
      format_class(self)
    },

    #' @description Return class documentation by delegating to ``cls_help``.
    help = function() {
      cls_help(self)
    }
  )
)
#' @title Model Class
#' @description A SageMaker ``Model`` that can be deployed to an ``Endpoint``.
#' @export
Model = R6Class("Model",
inherit = ModelBase,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
#' @param image_uri (str): A Docker image URI.
#' @param model_data (str): The S3 location of a SageMaker model data
#' ``.tar.gz`` file.
#' @param role (str): An AWS IAM role (either name or full ARN). The Amazon
#' SageMaker training jobs and APIs that create Amazon SageMaker
#' endpoints use this role to access training data and model
#' artifacts. After the endpoint is created, the inference code
#' might use the IAM role if it needs to access some AWS resources.
#' It can be null if this is being used to create a Model to pass
#' to a ``PipelineModel`` which has its own Role field. (Default:
#' NULL)
#' @param predictor_cls (callable[string, :Session]): A
#' function to call to create a predictor (default: None). If not
#' None, ``deploy`` will return the result of invoking this
#' function on the created endpoint name.
#' @param env (dict[str, str]): Environment variables to run with ``image``
#' when hosted in SageMaker (Default: NULL).
#' @param name (str): The model name. If None, a default model name will be
#' selected on each ``deploy``.
#' @param vpc_config (dict[str, list[str]]): The VpcConfig set on the model
#' (Default: NULL)
#' \itemize{
#' \item{\strong{'Subnets' (list[str])} List of subnet ids.}
#' \item{\strong{'SecurityGroupIds' (list[str]):} List of security group ids.}}
#' @param sagemaker_session (:Session): A SageMaker Session
#' object, used for SageMaker interactions (default: None). If not
#' specified, one is created using the default AWS configuration
#' chain.
#' @param enable_network_isolation (Boolean): Default False. if True, enables
#' network isolation in the endpoint, isolating the model
#' container. No inbound or outbound network calls can be made to
#' or from the model container.
#' @param model_kms_key (str): KMS key ARN used to encrypt the repacked
#' model archive file if the model is repacked
#' @param image_config (dict[str, str]): Specifies whether the image of
#' model container is pulled from ECR, or private registry in your
#' VPC. By default it is set to pull model container image from
#' ECR. (default: None).
initialize = function(image_uri,
                      model_data=NULL,
                      role=NULL,
                      predictor_cls=NULL,
                      env=NULL,
                      name=NULL,
                      vpc_config=NULL,
                      sagemaker_session=NULL,
                      enable_network_isolation=FALSE,
                      model_kms_key=NULL,
                      image_config=NULL){
  # Container image and model artifact.
  self$image_uri <- image_uri
  self$image_config <- image_config
  self$model_data <- model_data
  self$model_kms_key <- model_kms_key

  # Hosting configuration; a missing env becomes an empty list.
  self$role <- role
  self$env <- env %||% list()
  self$vpc_config <- vpc_config
  self$.enable_network_isolation <- enable_network_isolation
  self$predictor_cls <- predictor_cls

  # Naming state; the base name and endpoint name start out unset.
  self$name <- name
  self$.base_name <- NULL
  self$endpoint_name <- NULL

  # A default Session is created when the caller supplies none.
  self$sagemaker_session <- sagemaker_session %||% Session$new()

  # Compilation / edge-packaging bookkeeping starts in the "not done" state.
  private$.is_compiled_model <- FALSE
  private$.compilation_job_name <- NULL
  private$.is_edge_packaged_model <- FALSE
},
#' @description Creates a model package for creating SageMaker models or listing on Marketplace.
#' @param content_types (list): The supported MIME types for the input data (default: None).
#' @param response_types (list): The supported MIME types for the output data (default: None).
#' @param inference_instances (list): A list of the instance types that are used to
#' generate inferences in real-time (default: None).
#' @param transform_instances (list): A list of the instance types on which a transformation
#' job can be run or on which an endpoint can be deployed (default: None).
#' @param model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
#' using `model_package_name` makes the Model Package un-versioned (default: None).
#' @param model_package_group_name (str): Model Package Group name, exclusive to
#' `model_package_name`, using `model_package_group_name` makes the Model Package
#' versioned (default: None).
#' @param image_uri (str): Inference image uri for the container. The model's own
#'              ``image_uri`` is used if this is NULL (default: NULL).
#' @param model_metrics (ModelMetrics): ModelMetrics object (default: None).
#' @param metadata_properties (MetadataProperties): MetadataProperties object (default: None).
#' @param marketplace_cert (bool): A boolean value indicating if the Model Package is certified
#' for AWS Marketplace (default: False).
#' @param approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
#' or "PendingManualApproval" (default: "PendingManualApproval").
#' @param description (str): Model Package description (default: None).
#' @param drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
#' @return str: A string of SageMaker Model Package ARN.
register = function(content_types,
                    response_types,
                    inference_instances,
                    transform_instances,
                    model_package_name=NULL,
                    model_package_group_name=NULL,
                    image_uri=NULL,
                    model_metrics=NULL,
                    metadata_properties=NULL,
                    marketplace_cert=FALSE,
                    approval_status=NULL,
                    description=NULL,
                    drift_check_baselines=NULL){
  # The package is built from the model artifact, so it has to be set.
  if (is.null(self$model_data)) {
    ValueError$new("SageMaker Model Package cannot be created without model data.")
  }

  # An explicit image wins; otherwise the model's own image is used.
  container_image <- image_uri %||% self$image_uri

  package_args <- get_model_package_args(
    content_types,
    response_types,
    inference_instances,
    transform_instances,
    model_package_name,
    model_package_group_name,
    self$model_data,
    container_image,
    model_metrics,
    metadata_properties,
    marketplace_cert,
    approval_status,
    description,
    drift_check_baselines = drift_check_baselines
  )

  # The argument list is splatted into the session call.
  created_package <- do.call(
    self$sagemaker_session$create_model_package_from_containers,
    package_args
  )

  ModelPackage$new(
    role = self$role,
    model_data = self$model_data,
    model_package_arn = created_package$ModelPackageArn
  )
},
#' @description Return a dict created by ``sagemaker.container_def()`` for deploying
#' this model to a specified instance type.
#' Subclasses can override this to provide custom container definitions
#' for deployment to a specific instance type. Called by ``deploy()``.
#' @param instance_type (str): The EC2 instance type to deploy this Model to.
#' For example, 'ml.p2.xlarge'.
#' @param accelerator_type (str): The Elastic Inference accelerator type to
#' deploy to the instance for loading and making inferences to the
#' model. For example, 'ml.eia1.medium'.
#' @return dict: A container definition object usable with the CreateModel API.
prepare_container_def = function(instance_type,
                                 accelerator_type=NULL){
  # Neither argument is consulted here; the definition is built purely from
  # the model's own image, artifact, environment and image config.
  definition <- container_def(
    self$image_uri,
    self$model_data,
    self$env,
    image_config = self$image_config
  )
  definition
},
#' @description Whether to enable network isolation when creating this Model
#' @return bool: If network isolation should be enabled or not.
enable_network_isolation = function(){
  # Simple accessor for the flag stored by initialize().
  isolation_flag <- self$.enable_network_isolation
  isolation_flag
},
#' @description Check whether SageMaker Neo is available in the given region.
#' @param region (str): The region in which you want to run the compilation.
#' @return bool: TRUE if Neo is available in the specified region,
#'              FALSE otherwise.
check_neo_region = function(region){
  # A region is supported when it appears among the NEO_IMAGE_ACCOUNT names.
  supported_regions <- names(NEO_IMAGE_ACCOUNT)
  if (region %in% supported_regions) TRUE else FALSE
},
#' @description Package this ``Model`` with SageMaker Edge.
#' Creates a new EdgePackagingJob and wait for it to finish.
#' model_data will now point to the packaged artifacts.
#' @param output_path (str): Specifies where to store the packaged model
#' @param role (str): Execution role
#' @param model_name (str): the name to attach to the model metadata
#' @param model_version (str): the version to attach to the model metadata
#' @param job_name (str): The name of the edge packaging job
#' @param resource_key (str): the kms key to encrypt the disk with
#' @param s3_kms_key (str): the kms key to encrypt the output with
#' @param tags (list[dict]): List of tags for labeling an edge packaging job. For
#' more, see
#' https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
#' @return sagemaker.model.Model: A SageMaker ``Model`` object. See
#' :func:`~sagemaker.model.Model` for full details.
package_for_edge = function(output_path,
                            model_name,
                            model_version,
                            role=NULL,
                            job_name=NULL,
                            resource_key=NULL,
                            s3_kms_key=NULL,
                            tags=NULL){
  # Packaging is sourced from a compilation job, so compile() must have run.
  if (is.null(private$.compilation_job_name))
    ValueError$new("You must first compile this model")
  # Default job name: "packaging" followed by the compilation job name from
  # its 12th character onwards.
  if (is.null(job_name))
    job_name = sprintf(
      "packaging%s",
      substr(private$.compilation_job_name, 12, nchar(private$.compilation_job_name))
    )
  # Make sure a session exists before it is used to expand the role.
  private$.init_sagemaker_session_if_does_not_exist(NULL)
  # Fall back to the model's own role when none is given, and only expand a
  # role that actually exists (the previous code expanded only a NULL role).
  role = role %||% self$role
  if (!is.null(role))
    role = self$sagemaker_session$expand_role(role)
  config = private$.edge_packaging_job_config(
    output_path,
    role,
    model_name,
    model_version,
    job_name,
    private$.compilation_job_name,
    resource_key,
    s3_kms_key,
    tags
  )
  # `$` (not `.`) is required to reach the session method.
  do.call(self$sagemaker_session$package_model_for_edge, config)
  job_status = self$sagemaker_session$wait_for_edge_packaging_job(job_name)
  # model_data now points at the packaged artifact reported by the job.
  self$model_data = job_status$ModelArtifact
  private$.is_edge_packaged_model = TRUE
  return(self)
},
#' @description Compile this ``Model`` with SageMaker Neo.
#' @param target_instance_family (str): Identifies the device that you want to
#'              run your model after compilation, for example: ml_c5. For allowed
#'              strings see
#'              https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
#'              May be NULL (the default) when the target is described with
#'              ``target_platform_os`` and ``target_platform_arch`` instead.
#' @param input_shape (list): Specifies the name and shape of the expected
#'              inputs for your trained model in json dictionary form, for
#'              example: \code{list('data'= list(1,3,1024,1024)), or list('var1'= list(1,1,28,28),
#'              'var2'= list(1,1,28,28))}
#' @param output_path (str): Specifies where to store the compiled model
#' @param role (str): Execution role
#' @param tags (list[dict]): List of tags for labeling a compilation job. For
#'              more, see
#'              https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
#' @param job_name (str): The name of the compilation job. Required: an error
#'              is raised when it is NULL.
#' @param compile_max_run (int): Timeout in seconds for compilation (default:
#'              5 * 60). After this amount of time Amazon SageMaker Neo
#'              terminates the compilation job regardless of its current status.
#' @param framework (str): The framework that is used to train the original
#'              model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch',
#'              'onnx', 'xgboost', 'tflite'
#' @param framework_version (str): The version of the framework. If NULL, the
#'              model's ``framework_version`` field is used.
#' @param target_platform_os (str): Target Platform OS, for example: 'LINUX'.
#'              For allowed strings see
#'              https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
#'              It can be used instead of target_instance_family.
#' @param target_platform_arch (str): Target Platform Architecture, for example: 'X86_64'.
#'              For allowed strings see
#'              https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
#'              It can be used instead of target_instance_family.
#' @param target_platform_accelerator (str, optional): Target Platform Accelerator,
#'              for example: 'NVIDIA'. For allowed strings see
#'              https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
#'              It can be used instead of target_instance_family.
#' @param compiler_options (dict, optional): Additional parameters for compiler.
#'              Compiler Options are TargetPlatform / target_instance_family specific. See
#'              https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html for details.
#' @return sagemaker.model.Model: A SageMaker ``Model`` object. See
#'              :func:`~sagemaker.model.Model` for full details.
compile = function(target_instance_family=NULL,
                   input_shape,
                   output_path,
                   role,
                   tags=NULL,
                   job_name=NULL,
                   compile_max_run=5 * 60,
                   framework=NULL,
                   framework_version=NULL,
                   target_platform_os=NULL,
                   target_platform_arch=NULL,
                   target_platform_accelerator=NULL,
                   compiler_options=NULL){
  # An explicit framework wins over the one attached to the model.
  framework = framework %||% private$.framework()
  if (is.null(framework))
    ValueError$new(
      sprintf("You must specify framework, allowed values %s",
              paste0(NEO_ALLOWED_FRAMEWORKS, collapse = ", ")))
  if (!(framework %in% NEO_ALLOWED_FRAMEWORKS))
    ValueError$new(
      sprintf("You must provide valid framework, allowed values %s",
              paste0(NEO_ALLOWED_FRAMEWORKS, collapse = ", ")))
  if(is.null(job_name))
    ValueError$new("You must provide a compilation job name")
  if (is.null(self$model_data))
    ValueError$new("You must provide an S3 path to the compressed model artifacts.")
  framework_version = framework_version %||% private$.get_framework_version()
  private$.init_sagemaker_session_if_does_not_exist(target_instance_family)
  config = private$.compilation_job_config(
    target_instance_family,
    input_shape,
    output_path,
    role,
    compile_max_run,
    job_name,
    framework,
    tags,
    target_platform_os,
    target_platform_arch,
    target_platform_accelerator,
    compiler_options,
    framework_version
  )
  do.call(self$sagemaker_session$compile_model, config)
  job_status = self$sagemaker_session$wait_for_compilation_job(job_name)
  # model_data now points at the compiled artifact reported by the job.
  self$model_data = job_status[["ModelArtifacts"]][["S3ModelArtifacts"]]
  if (!is.null(target_instance_family)){
    if(target_instance_family == "ml_eia2"){
      invisible() # used for python pass
    } else if (grepl("^ml_", target_instance_family)){
      # ml_* targets get a matching image and are marked as compiled.
      self$image_uri = private$.compilation_image_uri(
        self$sagemaker_session$paws_region_name,
        target_instance_family,
        framework,
        framework_version)
      private$.is_compiled_model = TRUE
    } else {
      LOGGER$warn(paste(
        "The instance type %s is not supported for deployment via SageMaker.",
        "Please deploy the model manually."),
        target_instance_family)
    }
  } else {
    LOGGER$warn(paste(
      "Devices described by Target Platform OS, Architecture and Accelerator are not",
      "supported for deployment via SageMaker. Please deploy the model manually."))
  }
  # Remembered so that package_for_edge() can source this compilation job.
  private$.compilation_job_name = job_name
  return(self)
},
#' @description Deploy this ``Model`` to an ``Endpoint`` and optionally return a
#' ``Predictor``.
#' Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an
#' ``Endpoint`` from this ``Model``. If ``self.predictor_cls`` is not None,
#' this method returns the result of invoking ``self.predictor_cls`` on
#' the created endpoint name.
#' The name of the created model is accessible in the ``name`` field of
#' this ``Model`` after deploy returns
#' The name of the created endpoint is accessible in the
#' ``endpoint_name`` field of this ``Model`` after deploy returns.
#' @param initial_instance_count (int): The initial number of instances to run
#' in the ``Endpoint`` created from this ``Model``.
#' @param instance_type (str): The EC2 instance type to deploy this Model to.
#' For example, 'ml.p2.xlarge', or 'local' for local mode.
#' @param serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
#' serializer object, used to encode data for an inference endpoint
#' (default: None). If ``serializer`` is not None, then
#' ``serializer`` will override the default serializer. The
#' default serializer is set by the ``predictor_cls``.
#' @param deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
#' deserializer object, used to decode data from an inference
#' endpoint (default: None). If ``deserializer`` is not None, then
#' ``deserializer`` will override the default deserializer. The
#' default deserializer is set by the ``predictor_cls``.
#' @param accelerator_type (str): Type of Elastic Inference accelerator to
#' deploy this model for model loading and inference, for example,
#' 'ml.eia1.medium'. If not specified, no Elastic Inference
#' accelerator will be attached to the endpoint. For more
#' information:
#' https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
#' @param endpoint_name (str): The name of the endpoint to create (Default:
#' NULL). If not specified, a unique endpoint name will be created.
#' @param tags (List[dict[str, str]]): The list of tags to attach to this
#' specific endpoint.
#' @param kms_key (str): The ARN of the KMS key that is used to encrypt the
#' data on the storage volume attached to the instance hosting the
#' endpoint.
#' @param wait (bool): Whether the call should wait until the deployment of
#' this model completes (default: True).
#' @param data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
#' configuration related to Endpoint data capture for use with
#' Amazon SageMaker Model Monitoring. Default: None.
#' @param serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
#' Specifies configuration related to serverless endpoint. Use this configuration
#' when trying to create serverless endpoint and make serverless inference. If
#' empty object passed through, we will use pre-defined values in
#' ``ServerlessInferenceConfig`` class to deploy serverless endpoint (default: None)
#' @param ... : pass deprecated parameters.
#' @return callable[string, sagemaker.session.Session] or None: Invocation of
#' ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
#' is not None. Otherwise, return None.
deploy = function(initial_instance_count=NULL,
                  instance_type=NULL,
                  serializer=NULL,
                  deserializer=NULL,
                  accelerator_type=NULL,
                  endpoint_name=NULL,
                  tags=NULL,
                  kms_key=NULL,
                  wait=TRUE,
                  data_capture_config=NULL,
                  serverless_inference_config=NULL,
                  ...){
  kwargs = list(...)
  # NOTE(review): presumably flags use of the removed `update_endpoint`
  # argument -- confirm against the removed_kwargs() helper.
  removed_kwargs("update_endpoint", kwargs)
  private$.init_sagemaker_session_if_does_not_exist(instance_type)
  if(is.null(self$role))
    ValueError$new("Role can not be null for deploying a model")
  is_serverless = !is.null(serverless_inference_config)
  # Instance-based endpoints need BOTH values, so either one missing is an error.
  if (!is_serverless && (is.null(instance_type) || is.null(initial_instance_count)))
    ValueError$new(
      "Must specify instance type and instance count unless using serverless inference"
    )
  if(is_serverless && !inherits(serverless_inference_config, "ServerlessInferenceConfig"))
    ValueError$new(
      "serverless_inference_config needs to be a ServerlessInferenceConfig object"
    )
  # instance_type may be NULL for serverless deployments; startsWith() would
  # fail on NULL, so test for it first.
  if (!is.null(instance_type)
      && startsWith(instance_type,"ml.inf")
      && isFALSE(private$.is_compiled_model))
    LOGGER$warn("Your model is not compiled. Please compile your model before using Inferentia.")
  # The suffix only applies to instance-based deployments of compiled models.
  compiled_model_suffix = if (is_serverless) NULL else gsub("\\.", "-", instance_type)
  uses_compiled_suffix = isTRUE(private$.is_compiled_model) && !is_serverless
  if (uses_compiled_suffix){
    private$.ensure_base_name_if_needed(self$image_uri)
    if(!is.null(self$.base_name))
      self$.base_name = paste(self$.base_name, compiled_model_suffix, sep = "-", collapse = "-")
  }
  self$.create_sagemaker_model(instance_type, accelerator_type, tags)
  serverless_inference_config_list = if (is_serverless) {
    serverless_inference_config$to_request_list()
  } else {
    NULL
  }
  prod_variant = production_variant(
    self$name,
    instance_type,
    initial_instance_count,
    accelerator_type=accelerator_type,
    serverless_inference_config=serverless_inference_config_list
  )
  if (!is.null(endpoint_name)) {
    self$endpoint_name = endpoint_name
  } else {
    base_endpoint_name = self$.base_name %||% sagemaker.core::base_from_name(self$name)
    # .is_compiled_model is initialised to FALSE, so it must be tested for
    # truth (not for NULL) or every endpoint name would receive the suffix.
    if (uses_compiled_suffix){
      if (!endsWith(base_endpoint_name, compiled_model_suffix))
        base_endpoint_name = paste(base_endpoint_name, compiled_model_suffix, sep = "-", collapse = "-")
    }
    self$endpoint_name = name_from_base(base_endpoint_name)
  }
  data_capture_config_list = NULL
  if (!is.null(data_capture_config))
    data_capture_config_list = data_capture_config$to_request_list()
  self$sagemaker_session$endpoint_from_production_variants(
    name=self$endpoint_name,
    production_variants=list(prod_variant),
    tags=tags,
    kms_key=kms_key,
    wait=wait,
    data_capture_config_list=data_capture_config_list
  )
  # A predictor is only built when a predictor class was supplied; explicit
  # (de)serializers replace the ones the predictor was created with.
  if (!is.null(self$predictor_cls)){
    predictor = self$predictor_cls$new(self$endpoint_name, self$sagemaker_session)
    if (!is.null(serializer))
      predictor$serializer = serializer
    if (!is.null(deserializer))
      predictor$deserializer = deserializer
    return(predictor)
  }
  return(invisible(NULL))
},
#' @description Return a ``Transformer`` that uses this Model.
#' @param instance_count (int): Number of EC2 instances to use.
#' @param instance_type (str): Type of EC2 instance to use, for example,
#' 'ml.c4.xlarge'.
#' @param strategy (str): The strategy used to decide how to batch records in
#' a single request (default: None). Valid values: 'MultiRecord'
#' and 'SingleRecord'.
#' @param assemble_with (str): How the output is assembled (default: None).
#' Valid values: 'Line' or 'None'.
#' @param output_path (str): S3 location for saving the transform result. If
#' not specified, results are stored to a default bucket.
#' @param output_kms_key (str): Optional. KMS key ID for encrypting the
#' transform output (default: None).
#' @param accept (str): The accept header passed by the client to
#' the inference endpoint. If it is supported by the endpoint,
#' it will be the format of the batch transform output.
#' @param env (dict): Environment variables to be set for use during the
#' transform job (default: None).
#' @param max_concurrent_transforms (int): The maximum number of HTTP requests
#' to be made to each individual transform container at one time.
#' @param max_payload (int): Maximum size of the payload in a single HTTP
#' request to the container in MB.
#' @param tags (list[dict]): List of tags for labeling a transform job. If
#' none specified, then the tags used for the training job are used
#' for the transform job.
#' @param volume_kms_key (str): Optional. KMS key ID for encrypting the volume
#' attached to the ML compute instance (default: None).
transformer = function(instance_count,
                       instance_type,
                       strategy=NULL,
                       assemble_with=NULL,
                       output_path=NULL,
                       output_kms_key=NULL,
                       accept=NULL,
                       env=NULL,
                       max_concurrent_transforms=NULL,
                       max_payload=NULL,
                       tags=NULL,
                       volume_kms_key=NULL){
  private$.init_sagemaker_session_if_does_not_exist(instance_type)
  # The SageMaker model has to exist before a transformer can refer to it.
  self$.create_sagemaker_model(instance_type, tags = tags)

  # The caller's env is dropped when network isolation is enabled.
  transform_env <- if (self$enable_network_isolation()) NULL else env
  job_base_name <- self$.base_name %||% self$name

  Transformer$new(
    self$name,
    instance_count,
    instance_type,
    strategy = strategy,
    assemble_with = assemble_with,
    output_path = output_path,
    output_kms_key = output_kms_key,
    accept = accept,
    max_concurrent_transforms = max_concurrent_transforms,
    max_payload = max_payload,
    env = transform_env,
    tags = tags,
    base_transform_job_name = job_base_name,
    volume_kms_key = volume_kms_key,
    sagemaker_session = self$sagemaker_session
  )
},
#' @description Delete an Amazon SageMaker Model.
delete_model = function(){
  # A model without a name has not been created yet.
  model_name <- self$name
  if (is.null(model_name)) {
    ValueError$new("The SageMaker model must be created first before attempting to delete.")
  }
  self$sagemaker_session$delete_model(model_name)
},
#' @description Create a SageMaker Model Entity
#' @param instance_type (str): The EC2 instance type that this Model will be
#' used for, this is only used to determine if the image needs GPU
#' support or not.
#' @param accelerator_type (str): Type of Elastic Inference accelerator to
#' attach to an endpoint for model loading and inference, for
#' example, 'ml.eia1.medium'. If not specified, no Elastic
#' Inference accelerator will be attached to the endpoint.
#' @param tags (List[dict[str, str]]): Optional. The list of tags to add to
#' the model. Example: >>> tags = [{'Key': 'tagname', 'Value':
#' 'tagvalue'}] For more information about tags, see
#' \url{https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags}.
.create_sagemaker_model = function(instance_type,
                                   accelerator_type=NULL,
                                   tags=NULL){
  # Build the container definition first: its image feeds the base name.
  container <- self$prepare_container_def(
    instance_type,
    accelerator_type = accelerator_type
  )
  private$.ensure_base_name_if_needed(container$Image)
  private$.set_model_name_if_needed()

  isolate_network <- self$enable_network_isolation()
  private$.init_sagemaker_session_if_does_not_exist(instance_type)

  self$sagemaker_session$create_model(
    self$name,
    self$role,
    container,
    vpc_config = self$vpc_config,
    enable_network_isolation = isolate_network,
    tags = tags
  )
}
),
private = list(
# Set to FALSE by initialize(); compile() sets it to TRUE after assigning a
# compiled-model image URI.
.is_compiled_model = NULL,
# Name of the last compilation job run by compile(); package_for_edge()
# refuses to run while this is NULL.
.compilation_job_name = NULL,
# Set to FALSE by initialize(); package_for_edge() sets it to TRUE once the
# packaging job has been waited on.
.is_edge_packaged_model = NULL,
# Set ``self.sagemaker_session`` to be a ``LocalSession`` or
# ``Session`` if it is not already. The type of session object is
# determined by the instance type.
.init_sagemaker_session_if_does_not_exist = function(instance_type){
  # Nothing to do when a session is already attached to the model.
  if (!is.null(self$sagemaker_session))
    return(invisible(NULL))
  # instance_type can be NULL (package_for_edge() passes NULL); `%in%` then
  # yields logical(0), which `if` rejects, so collapse the test with isTRUE().
  if (isTRUE(instance_type %in% c("local", "local_gpu"))){
    self$sagemaker_session = LocalSession$new()
  } else {
    self$sagemaker_session = Session$new()
  }
},
# Create a base name from the image URI if there is no model name provided.
.ensure_base_name_if_needed = function(image_uri){
  # An explicit model name makes a base name unnecessary.
  if (!is.null(self$name)) {
    return(invisible(NULL))
  }
  # Keep an existing base name; otherwise derive one from the image URI.
  self$.base_name <- self$.base_name %||% base_name_from_image(image_uri)
},
# Generate a new model name if ``self._base_name`` is present.
.set_model_name_if_needed = function(){
  # Without a base name the current model name is left untouched.
  base_name <- self$.base_name
  if (is.null(base_name)) {
    return(invisible(NULL))
  }
  self$name <- name_from_base(base_name)
},
.framework = function(){
  # Reads the "_framework_name" attribute attached to this object; NULL when
  # no such attribute has been set.
  framework_name <- attr(self, "_framework_name")
  framework_name
},
.get_framework_version = function(obj){
  # `obj` is accepted but not used; the version comes from this object's
  # `framework_version` field.
  version <- self$framework_version
  version
},
# Creates a request object for a packaging job.
# Args:
# output_path (str): where in S3 to store the output of the job
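# role (str): what role to use when executing the job
# model_name (str): the name to attach to the model metadata
# model_version (str): the version to attach to the model metadata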
# packaging_job_name (str): what to name the packaging job
# compilation_job_name (str): what compilation job to source the model from
# resource_key (str): the kms key to encrypt the disk with
# s3_kms_key (str): the kms key to encrypt the output with
# tags (list[dict]): List of tags for labeling an edge packaging job. For
# more, see
# https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
# Returns:
# dict: the request object to use when creating a packaging job
.edge_packaging_job_config = function(output_path,
                                      role,
                                      model_name,
                                      model_version,
                                      packaging_job_name,
                                      compilation_job_name,
                                      resource_key,
                                      s3_kms_key,
                                      tags){
  # The KMS key is only included in the output settings when one is given.
  output_settings <- list("S3OutputLocation" = output_path)
  if (!is.null(s3_kms_key)) {
    output_settings[["KmsKeyId"]] <- s3_kms_key
  }

  # Keys are the argument names expected by the packaging call.
  request <- list(
    "output_model_config" = output_settings,
    "role" = role,
    "tags" = tags,
    "model_name" = model_name,
    "model_version" = model_version,
    "job_name" = packaging_job_name,
    "compilation_job_name" = compilation_job_name,
    "resource_key" = resource_key
  )
  request
},
# Builds the argument list for a compilation job.
# Returns a list with input_model_config, output_model_config, role,
# stop_condition, tags and job_name.
.compilation_job_config = function(target_instance_type,
                                   input_shape,
                                   output_path,
                                   role,
                                   compile_max_run,
                                   job_name,
                                   framework,
                                   tags,
                                   target_platform_os=NULL,
                                   target_platform_arch=NULL,
                                   target_platform_accelerator=NULL,
                                   compiler_options=NULL,
                                   framework_version=NULL){
  input_model_config = list(
    "S3Uri" = self$model_data,
    "DataInputConfig" = input_shape,
    "Framework" = toupper(framework)
  )
  # FrameworkVersion is only added for PyTorch on ml_* targets other than
  # ml_inf*. target_instance_type is NULL when a target platform is used
  # instead; grepl() on NULL returns logical(0) and would break the `&&`
  # chain, so NULL is ruled out first.
  if(tolower(framework) == "pytorch"
     && !is.null(target_instance_type)
     && grepl("(?=^ml_)(?!ml_inf)", target_instance_type, perl = TRUE)
     && !is.null(framework_version)){
    input_model_config[["FrameworkVersion"]] = get_short_version(framework_version)
  }
  role = self$sagemaker_session$expand_role(role)
  output_model_config = list(
    "S3OutputLocation" = output_path)
  if(!is.null(target_instance_type)){
    output_model_config$TargetDevice = target_instance_type
  } else {
    # Without a target device, the target must be described as a platform.
    if (is.null(target_platform_os) && is.null(target_platform_arch))
      ValueError$new(
        "target_instance_type or (target_platform_os and target_platform_arch) ",
        "should be provided"
      )
    target_platform = list(
      "Os"= target_platform_os,
      "Arch"= target_platform_arch
    )
    if (!is.null(target_platform_accelerator))
      target_platform$Accelerator = target_platform_accelerator
    output_model_config$TargetPlatform = target_platform
  }
  if (!is.null(compiler_options)){
    output_model_config$CompilerOptions = compiler_options
  }
  return(list(
    "input_model_config"= input_model_config,
    "output_model_config"= output_model_config,
    "role"= role,
    "stop_condition"= list("MaxRuntimeInSeconds"= compile_max_run),
    "tags"= tags,
    "job_name"= job_name))
},
# Constructs a dict of arguments for an Amazon SageMaker compilation job from estimator.
# Args:
# estimator (sagemaker.estimator.EstimatorBase): Estimator object
# created by the user.
# inputs (CompilationInput): class containing all the parameters that
# can be used when calling ``sagemaker.model.Model.compile_model()``
.get_compilation_args = function(estimator, inputs){
  if (!inherits(inputs, "CompilationInput"))
    TypeError$new("Your inputs must be provided as CompilationInput objects.")
  target_instance_family = inputs$target_instance_type
  input_shape = inputs$input_shape
  output_path = inputs$output_path
  role = estimator$role
  # Field access needs `$`; `inputs.compile_max_run` was an undefined symbol.
  compile_max_run = inputs$compile_max_run
  # The job name comes from the estimator's private helper.
  job_name = estimator$.__enclos_env__$private$.compilation_job_name()
  # An explicit framework on the inputs wins over the one attached to the model.
  framework = inputs$framework %||% private$.framework()
  if (is.null(framework))
    ValueError$new(sprintf(
      "You must specify framework, allowed values %s",
      paste(NEO_ALLOWED_FRAMEWORKS, collapse = ", ")
    ))
  if (!(framework %in% NEO_ALLOWED_FRAMEWORKS))
    ValueError$new(sprintf(
      "You must provide valid framework, allowed values %s",
      paste(NEO_ALLOWED_FRAMEWORKS, collapse = ", ")
    ))
  if (is.null(self$model_data))
    ValueError$new("You must provide an S3 path to the compressed model artifacts.")
  tags = inputs$tags
  target_platform_os = inputs$target_platform_os
  target_platform_arch = inputs$target_platform_arch
  target_platform_accelerator = inputs$target_platform_accelerator
  compiler_options = inputs$compiler_options
  framework_version = inputs$framework_version %||% private$.get_framework_version()
  return(private$.compilation_job_config(
    target_instance_family,
    input_shape,
    output_path,
    role,
    compile_max_run,
    job_name,
    framework,
    tags,
    target_platform_os,
    target_platform_arch,
    target_platform_accelerator,
    compiler_options,
    framework_version)
  )
},
# Retrieve the Neo or Inferentia image URI.
# Args:
# region (str): The AWS region.
# target_instance_type (str): Identifies the device on which you want to run
# your model after compilation, for example: ml_c5. For valid values, see
# https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
# framework (str): The framework name.
# framework_version (str): The framework version.
.compilation_image_uri = function(region,
                                  target_instance_type,
                                  framework,
                                  framework_version){
  # xgboost is marked with a "-neo" suffix and no prefix; every other
  # framework gets a prefix chosen by the target ("inferentia-" for ml_inf*
  # targets, "neo-" otherwise) and no suffix.
  is_xgboost <- framework == "xgboost"
  name_suffix <- if (is_xgboost) "-neo" else ""
  name_prefix <- if (is_xgboost) {
    ""
  } else if (grepl("^ml_inf", target_instance_type)) {
    "inferentia-"
  } else {
    "neo-"
  }
  image_framework <- sprintf("%s%s%s", name_prefix, framework, name_suffix)

  ImageUris$new()$retrieve(
    image_framework,
    region,
    instance_type = target_instance_type,
    version = framework_version
  )
}
),
lock_objects = F
)
#' @title An array of parameters for modelling methods
#' @name model_parameters
#' @keywords internal
#' @export
model_parameters <- sagemaker.core::Enum(
# Each entry pairs a constant name with the literal parameter string it
# stands for; the collection is built by sagemaker.core::Enum.
SCRIPT_PARAM_NAME = "sagemaker_program",
DIR_PARAM_NAME = "sagemaker_submit_directory",
CLOUDWATCH_METRICS_PARAM_NAME = "sagemaker_enable_cloudwatch_metrics",
CONTAINER_LOG_LEVEL_PARAM_NAME = "sagemaker_container_log_level",
JOB_NAME_PARAM_NAME = "sagemaker_job_name",
MODEL_SERVER_WORKERS_PARAM_NAME = "sagemaker_model_server_workers",
SAGEMAKER_REGION_PARAM_NAME = "sagemaker_region",
SAGEMAKER_OUTPUT_LOCATION = "sagemaker_s3_output"
)
#' @title A Model for working with a SageMaker ``Framework``.
#' @description This class hosts user-defined code in S3 and sets code location and
#' configuration in model environment variables.
#' @export
FrameworkModel = R6Class("FrameworkModel",
inherit = Model,
public = list(
#' @description Initialize a ``FrameworkModel``.
#' @param model_data (str): The S3 location of a SageMaker model data
#' ``.tar.gz`` file.
#' @param image_uri (str): A Docker image URI.
#' @param role (str): An IAM role name or ARN for SageMaker to access AWS
#' resources on your behalf.
#' @param entry_point (str): Path (absolute or relative) to the Python source
#' file which should be executed as the entry point to model
#' hosting. This should be compatible with either Python 2.7 or
#' Python 3.5. If 'git_config' is provided, 'entry_point' should be
#' a relative location to the Python source file in the Git repo.
#' Example
#' With the following GitHub repo directory structure:
#' >>> |----- README.md
#' >>> |----- src
#' >>> |----- inference.py
#' >>> |----- test.py
#' You can assign entry_point='src/inference.py'.
#' @param source_dir (str): Path (absolute, relative or an S3 URI) to a directory
#' with any other training source code dependencies aside from the entry
#' point file (default: None). If ``source_dir`` is an S3 URI, it must
#' point to a tar.gz file. Structure within this directory are preserved
#' when training on Amazon SageMaker. If 'git_config' is provided,
#' 'source_dir' should be a relative location to a directory in the Git repo.
#' If the directory points to S3, no code will be uploaded and the S3 location
#' will be used instead.
#' .. admonition:: Example
#' With the following GitHub repo directory structure:
#' >>> |----- README.md
#' >>> |----- src
#' >>> |----- inference.py
#' >>> |----- test.py
#' You can assign entry_point='inference.py', source_dir='src'.
#' @param predictor_cls (callable[string, sagemaker.session.Session]): A
#' function to call to create a predictor (default: None). If not
#' None, ``deploy`` will return the result of invoking this
#' function on the created endpoint name.
#' @param env (dict[str, str]): Environment variables to run with ``image``
#' when hosted in SageMaker (default: None).
#' @param name (str): The model name. If None, a default model name will be
#' selected on each ``deploy``.
#' @param container_log_level (str): Log level to use within the container
#' (default: "INFO").
#' @param code_location (str): Name of the S3 bucket where custom code is
#' uploaded (default: None). If not specified, default bucket
#' created by ``sagemaker.session.Session`` is used.
#' @param sagemaker_session (sagemaker.session.Session): A SageMaker Session
#' object, used for SageMaker interactions (default: None). If not
#' specified, one is created using the default AWS configuration
#' chain.
#' @param dependencies (list[str]): A list of paths to directories (absolute
#' or relative) with any additional libraries that will be exported
#' to the container (default: []). The library folders will be
#' copied to SageMaker in the same folder where the entrypoint is
#' copied. If 'git_config' is provided, 'dependencies' should be a
#' list of relative locations to directories with any additional
#' libraries needed in the Git repo. If the ```source_dir``` points
#' to S3, code will be uploaded and the S3 location will be used
#' instead. .. admonition:: Example
#' The following call >>> Estimator(entry_point='inference.py',
#' dependencies=['my/libs/common', 'virtual-env']) results in
#' the following inside the container:
#' >>> $ ls
#' >>> opt/ml/code
#' >>> |------ inference.py
#' >>> |------ common
#' >>> |------ virtual-env
#' @param git_config (dict[str, str]): Git configurations used for cloning
#' files, including ``repo``, ``branch``, ``commit``,
#' ``2FA_enabled``, ``username``, ``password`` and ``token``. The
#' ``repo`` field is required. All other fields are optional.
#' ``repo`` specifies the Git repository where your training script
#' is stored. If you don't provide ``branch``, the default value
#' 'master' is used. If you don't provide ``commit``, the latest
#' commit in the specified branch is used. .. admonition:: Example
#' The following config:
#' >>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git',
#' >>> 'branch': 'test-branch-git-config',
#' >>> 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}
#' results in cloning the repo specified in 'repo', then
#' checkout the 'master' branch, and checkout the specified
#' commit.
#' ``2FA_enabled``, ``username``, ``password`` and ``token`` are
#' used for authentication. For GitHub (or other Git) accounts, set
#' ``2FA_enabled`` to 'True' if two-factor authentication is
#' enabled for the account, otherwise set it to 'False'. If you do
#' not provide a value for ``2FA_enabled``, a default value of
#' 'False' is used. CodeCommit does not support two-factor
#' authentication, so do not provide "2FA_enabled" with CodeCommit
#' repositories.
#' For GitHub and other Git repos, when SSH URLs are provided, it
#' doesn't matter whether 2FA is enabled or disabled; you should
#' either have no passphrase for the SSH key pairs, or have the
#' ssh-agent configured so that you will not be prompted for SSH
#' passphrase when you do 'git clone' command with SSH URLs. When
#' HTTPS URLs are provided: if 2FA is disabled, then either token
#' or username+password will be used for authentication if provided
#' (token prioritized); if 2FA is enabled, only token will be used
#' for authentication if provided. If required authentication info
#' is not provided, python SDK will try to use local credentials
#' storage to authenticate. If that fails either, an error message
#' will be thrown.
#' For CodeCommit repos, 2FA is not supported, so '2FA_enabled'
#' should not be provided. There is no token in CodeCommit, so
#' 'token' should not be provided too. When 'repo' is an SSH URL,
#' the requirements are the same as GitHub-like repos. When 'repo'
#' is an HTTPS URL, username+password will be used for
#' authentication if they are provided; otherwise, python SDK will
#' try to use either CodeCommit credential helper or local
#' credential storage for authentication.
#' @param ... : Keyword arguments passed to the ``Model`` initializer.
initialize = function(model_data,
image_uri,
role,
entry_point,
source_dir=NULL,
predictor_cls=NULL,
env=NULL,
name=NULL,
container_log_level="INFO",
code_location=NULL,
sagemaker_session=NULL,
dependencies=NULL,
git_config=NULL,
...){
super$initialize(image_uri=image_uri,
model_data=model_data,
role=role,
predictor_cls=predictor_cls,
env=env,
name=name,
sagemaker_session=sagemaker_session,
...)
self$entry_point = entry_point
self$source_dir = source_dir
self$dependencies = dependencies %||% list()
self$git_config = git_config
# Align logging level with python logging
if(!is.numeric(container_log_level)){
container_log_level = switch(toupper(container_log_level),
"DEBUG" = 10,
"INFO" = 20,
"WARN" = 30,
"ERROR" = 40,
"FATAL" = 50,
"CRITICAL" = 50,
container_log_level
)
}
self$container_log_level = as.character(container_log_level)
if (!is.null(code_location)){
s3_parts = parse_s3_url(code_location)
self$bucket =s3_parts$bucket
self$key_prefix = s3_parts$key
} else {
self$bucket = NULL
self$key_prefix = NULL
}
if (!islistempty(self$git_config)){
updates = git_clone_repo(
self$git_config, self$entry_point, self$source_dir, self$dependencies
)
self$entry_point = updates$entry_point
self$source_dir = updates$source_dir
self$dependencies = updates$dependencies}
self$uploaded_code = NULL
self$repacked_model_data = NULL
},
#' @description Return a container definition with framework configuration set in
#' model environment variables.
#' This also uploads user-supplied code to S3.
#' @param instance_type (str): The EC2 instance type to deploy this Model to.
#' For example, 'ml.p2.xlarge'.
#' @param accelerator_type (str): The Elastic Inference accelerator type to
#' deploy to the instance for loading and making inferences to the
#' model. For example, 'ml.eia1.medium'.
#' @return dict[str, str]: A container definition object usable with the
#' CreateModel API.
prepare_container_def = function(instance_type=NULL,
accelerator_type=NULL){
deploy_key_prefix = model_code_key_prefix(
self$key_prefix, self$name, self$image_uri
)
private$.upload_code(deploy_key_prefix)
deploy_env = as.list(self$env)
deploy_env = modifyList(deploy_env, private$.framework_env_vars())
return(container_def(
self$image_uri,
self$repacked_model_data %||% self$model_data,
deploy_env,
image_config=self$image_config
)
)
}
),
private = list(
.upload_code = function(key_prefix, repack=FALSE){
local_code = get_config_value("local.local_code", self$sagemaker_session$config)
if ((isTRUE(self$sagemaker_session$local_mode) && !is.null(local_code)) || is.null(self$entry_point)){
self$uploaded_code = NULL
} else if (isFALSE(repack)){
bucket = self$bucket %||% self$sagemaker_session$default_bucket()
self$uploaded_code = tar_and_upload_dir(
sagemaker_session=self$sagemaker_session,
bucket=bucket,
s3_key_prefix=key_prefix,
script=self$entry_point,
directory=self$source_dir,
dependencies=self$dependencies)
}
if (repack && !is.null(self$model_data) && !is.null(self$entry_point)){
bucket = self$bucket %||% self$sagemaker_session$default_bucket()
repacked_model_data = s3_path_join("s3://", bucket, key_prefix, "model.tar.gz")
repack_model(
inference_script=self$entry_point,
source_directory=self$source_dir,
dependencies=self$dependencies,
model_uri=self$model_data,
repacked_model_uri=repacked_model_data,
sagemaker_session=self$sagemaker_session,
kms_key=self$model_kms_key)
self$repacked_model_data = repacked_model_data
UploadedCode$s3_prefix=self$repacked_model_data
UploadedCode$script_name=basename(self$entry_point)
self$uploaded_code = UploadedCode
}
},
.framework_env_vars = function(){
script_name = NULL
dir_name = NULL
if (!is.null(self$uploaded_code)){
script_name = self$uploaded_code$script_name
if (isTRUE(self$enable_network_isolation()))
dir_name = "/opt/ml/model/code"
else
dir_name = self$uploaded_code$s3_prefix
} else if (!islistempty(self$entry_point)){
script_name = self$entry_point
if(!is.null(self$source_dir))
dir_name = paste0("file://", self$source_dir)
}
output = setNames(list(
script_name,
dir_name,
self$container_log_level,
self$sagemaker_session$paws_region_name),
c(toupper(model_parameters$SCRIPT_PARAM_NAME),
toupper(model_parameters$DIR_PARAM_NAME),
toupper(model_parameters$CONTAINER_LOG_LEVEL_PARAM_NAME),
toupper(model_parameters$SAGEMAKER_REGION_PARAM_NAME))
)
return(output)
}
),
lock_objects = F
)
#' @title ModelPackage class
#' @description A SageMaker ``Model`` that can be deployed to an ``Endpoint``.
#' @export
ModelPackage = R6Class("ModelPackage",
  inherit = Model,
  public = list(

    #' @description Initialize a SageMaker ModelPackage.
    #'              Exactly one of ``model_package_arn`` and ``algorithm_arn``
    #'              must be supplied.
    #' @param role (str): An AWS IAM role (either name or full ARN). The Amazon
    #'              SageMaker training jobs and APIs that create Amazon SageMaker
    #'              endpoints use this role to access training data and model
    #'              artifacts. After the endpoint is created, the inference code
    #'              might use the IAM role, if it needs to access an AWS resource.
    #' @param model_data (str): The S3 location of a SageMaker model data
    #'              ``.tar.gz`` file. Must be provided if algorithm_arn is provided.
    #' @param algorithm_arn (str): algorithm arn used to train the model, can be
    #'              just the name if your account owns the algorithm. Must also
    #'              provide ``model_data``.
    #' @param model_package_arn (str): An existing SageMaker Model Package arn,
    #'              can be just the name if your account owns the Model Package.
    #'              ``model_data`` is not required.
    #' @param ... : Additional kwargs passed to the Model constructor.
    initialize = function(role,
                          model_data=NULL,
                          algorithm_arn=NULL,
                          model_package_arn=NULL,
                          ...){
      super$initialize(role = role, model_data = model_data, image_uri = NULL, ...)

      if(!is.null(model_package_arn) && !is.null(algorithm_arn))
        ValueError$new(
          "model_package_arn and algorithm_arn are mutually exclusive.",
          sprintf("Both were provided: model_package_arn: %s algorithm_arn: %s",
                  model_package_arn, algorithm_arn))

      if (is.null(model_package_arn) && is.null(algorithm_arn))
        ValueError$new(
          "either model_package_arn or algorithm_arn is required. NULL was provided."
        )

      self$algorithm_arn = algorithm_arn
      if (!is.null(self$algorithm_arn)){
        if (is.null(model_data))
          ValueError$new("model_data must be provided with algorithm_arn")
        self$model_data = model_data
      }

      self$model_package_arn = model_package_arn
      # Name of the model package created from ``algorithm_arn``; filled in by
      # the first ``.create_sagemaker_model`` call and reused afterwards.
      self$.created_model_package_name = NULL
    },

    #' @description Whether to enable network isolation when creating a model out of this
    #'              ModelPackage
    #' @return bool: If network isolation should be enabled or not.
    enable_network_isolation = function(){
      return(private$.is_marketplace())
    },

    #' @description Create a SageMaker Model Entity
    #' @param ... : Positional arguments coming from the caller. This class does not require
    #'              any so they are ignored.
    .create_sagemaker_model = function(...){
      if (!is.null(self$algorithm_arn)){
        # When ModelPackage is created using an algorithm_arn we need to first
        # create a ModelPackage. If we had already created one then its fine to re-use it.
        if (is.null(self$.created_model_package_name)){
          model_package_name = private$.create_sagemaker_model_package()
          self$sagemaker_session$wait_for_model_package(model_package_name)
          self$.created_model_package_name = model_package_name
        }
        model_package_name = self$.created_model_package_name
      } else {
        # When a ModelPackageArn is provided we just create the Model
        model_package_name = self$model_package_arn
      }

      container_def = list("ModelPackageName"= model_package_name)

      # Only attach an Environment when there is something in it; the length
      # test also covers a NULL or an empty *named* list.
      if (length(self$env) > 0)
        container_def$Environment = self$env

      # The short name is the last "/"-separated component of the name/ARN.
      name_parts = split_str(model_package_name, "/")
      model_package_short_name = name_parts[length(name_parts)]
      private$.ensure_base_name_if_needed(model_package_short_name)
      private$.set_model_name_if_needed()

      self$sagemaker_session$create_model(
        self$name,
        self$role,
        container_def,
        vpc_config=self$vpc_config,
        enable_network_isolation=self$enable_network_isolation())
    }
  ),
  private = list(

    # Create a model package from ``self$algorithm_arn`` and ``self$model_data``.
    # The package name is ``self$name`` when set, otherwise it is derived via
    # name_from_base() from the last "/"-separated component of the algorithm ARN.
    # Returns: the model package name that was used.
    .create_sagemaker_model_package = function(){
      if (is.null(self$algorithm_arn))
        ValueError$new("No algorithm_arn was provided to create a SageMaker Model Package")
      alg_split = split_str(self$algorithm_arn, "/")
      name = self$name %||% name_from_base(alg_split[length(alg_split)])
      description = sprintf("Model Package created from training with %s", self$algorithm_arn)
      self$sagemaker_session$create_model_package_from_algorithm(
        name, description, self$algorithm_arn, self$model_data)
      return(name)
    },

    # Decide whether the model package is treated as a marketplace package.
    # Returns TRUE when no package name is known yet, or when any container in
    # the package's InferenceSpecification carries a "ProductId"; FALSE otherwise.
    .is_marketplace = function(){
      model_package_name = self$model_package_arn %||% self$.created_model_package_name
      if (is.null(model_package_name))
        return(TRUE)

      # Models can lazy-init sagemaker_session until deploy() is called to support
      # LocalMode so we must make sure we have an actual session to describe the model package.
      sagemaker_session = self$sagemaker_session %||% Session$new()

      model_package_desc = sagemaker_session$sagemaker_client$describe_model_package(
        ModelPackageName=model_package_name)

      for (container in model_package_desc[["InferenceSpecification"]][["Containers"]]){
        if ("ProductId" %in% names(container))
          return(TRUE)
      }
      return(FALSE)
    },

    # Set the base name if there is no model name provided.
    .ensure_base_name_if_needed = function(base_name){
      if (is.null(self$name))
        self$.base_name = base_name
    }
  ),
  lock_objects = FALSE
)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.