# NOTE: This code has been modified from AWS Sagemaker Python:
# https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/algorithm.py
#' @include r_utils.R
#' @include parameter.R
#' @include estimator.R
#' @include serializers.R
#' @include deserializers.R
#' @include predictor.R
#' @import R6
#' @import sagemaker.core
#' @import sagemaker.common
#' @title AlgorithmEstimator Class
#' @description A generic Estimator to train using any algorithm object (with an
#' ``algorithm_arn``). The Algorithm can be your own, or any Algorithm from AWS
#' Marketplace that you have a valid subscription for. This class will perform
#' client-side validation on all the inputs.
#' @export
AlgorithmEstimator = R6Class("AlgorithmEstimator",
  inherit = EstimatorBase,
  public = list(

    #' @field .hyperpameters_with_range
    #' These Hyperparameter Types have a range definition.
    #' (Name kept as-is, typo included, for backward compatibility.)
    .hyperpameters_with_range = c("Integer", "Continuous", "Categorical"),

    #' @description Initialize an ``AlgorithmEstimator`` instance.
    #' @param algorithm_arn (str): algorithm arn used for training. Can be just the name if your
    #'              account owns the algorithm.
    #' @param role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker
    #'              training jobs and APIs that create Amazon SageMaker endpoints use this role to
    #'              access training data and model artifacts. After the endpoint
    #'              is created, the inference code might use the IAM role, if it
    #'              needs to access an AWS resource.
    #' @param instance_count (int): Number of Amazon EC2 instances to
    #'              use for training.
    #' @param instance_type (str): Type of EC2
    #'              instance to use for training, for example, 'ml.c4.xlarge'.
    #' @param volume_size (int): Size in GB of the EBS volume to use for
    #'              storing input data during training (default: 30). Must be large enough to store
    #'              training data if File Mode is used (which is the default).
    #' @param volume_kms_key (str): Optional. KMS key ID for encrypting EBS volume attached
    #'              to the training instance (default: NULL).
    #' @param max_run (int): Timeout in seconds for training (default: 24 * 60 * 60).
    #'              After this amount of time Amazon SageMaker terminates the
    #'              job regardless of its current status.
    #' @param input_mode (str): The input mode that the algorithm supports
    #'              (default: 'File'). Valid modes:
    #'              * 'File' - Amazon SageMaker copies the training dataset from
    #'              the S3 location to a local directory.
    #'              * 'Pipe' - Amazon SageMaker streams data directly from S3 to
    #'              the container via a Unix-named pipe.
    #'              This argument can be overridden on a per-channel basis using
    #'              ``TrainingInput.input_mode``.
    #' @param output_path (str): S3 location for saving the training result (model artifacts and
    #'              output files). If not specified, results are stored to a default bucket. If
    #'              the bucket with the specific name does not exist, the
    #'              estimator creates the bucket during the
    #'              :meth:`~sagemaker.estimator.EstimatorBase.fit` method
    #'              execution.
    #' @param output_kms_key (str): Optional. KMS key ID for encrypting the
    #'              training output (default: NULL).
    #' @param base_job_name (str): Prefix for
    #'              training job name when the
    #'              :meth:`~sagemaker.estimator.EstimatorBase.fit`
    #'              method launches. If not specified, the estimator generates a
    #'              default job name, based on the training image name and
    #'              current timestamp.
    #' @param hyperparameters (dict): Dictionary containing the hyperparameters to
    #'              initialize this estimator with.
    #' @param sagemaker_session (sagemaker.session.Session): Session object which manages
    #'              interactions with Amazon SageMaker APIs and any other AWS services needed. If
    #'              not specified, the estimator creates one using the default
    #'              AWS configuration chain.
    #' @param tags (list[dict]): List of tags for labeling a training job. For more, see
    #'              https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
    #' @param subnets (list[str]): List of subnet ids. If not specified
    #'              training job will be created without VPC config.
    #' @param security_group_ids (list[str]): List of security group ids. If
    #'              not specified training job will be created without VPC config.
    #' @param model_uri (str): URI where a pre-trained model is stored, either locally or in S3
    #'              (default: NULL). If specified, the estimator will create a channel pointing to
    #'              the model so the training job can download it. This model
    #'              can be a 'model.tar.gz' from a previous training job, or
    #'              other artifacts coming from a different source.
    #'              More information:
    #'              https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
    #' @param model_channel_name (str): Name of the channel where 'model_uri'
    #'              will be downloaded (default: 'model').
    #' @param metric_definitions (list[dict]): A list of dictionaries that defines the metric(s)
    #'              used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for
    #'              the name of the metric, and 'Regex' for the regular
    #'              expression used to extract the metric from the logs.
    #' @param encrypt_inter_container_traffic (bool): Specifies whether traffic between training
    #'              containers is encrypted for the training job (default: ``FALSE``).
    #' @param ... : Additional kwargs. This is unused. It's only added for AlgorithmEstimator
    #'              to ignore the irrelevant arguments.
    initialize = function(algorithm_arn,
                          role,
                          instance_count,
                          instance_type,
                          volume_size=30,
                          volume_kms_key=NULL,
                          max_run=24 * 60 * 60,
                          input_mode="File",
                          output_path=NULL,
                          output_kms_key=NULL,
                          base_job_name=NULL,
                          sagemaker_session=NULL,
                          hyperparameters=NULL,
                          tags=NULL,
                          subnets=NULL,
                          security_group_ids=NULL,
                          model_uri=NULL,
                          model_channel_name="model",
                          metric_definitions=NULL,
                          encrypt_inter_container_traffic=FALSE,
                          ...){
      self$algorithm_arn = algorithm_arn
      super$initialize(
        role,
        instance_count,
        instance_type,
        volume_size,
        volume_kms_key,
        max_run,
        input_mode,
        output_path,
        output_kms_key,
        base_job_name,
        sagemaker_session,
        tags,
        subnets,
        security_group_ids,
        model_uri=model_uri,
        model_channel_name=model_channel_name,
        metric_definitions=metric_definitions,
        encrypt_inter_container_traffic=encrypt_inter_container_traffic)
      # Describe the algorithm up front so all client-side validation (input
      # modes, instance types, hyperparameters) can run against its spec.
      self$algorithm_spec = self$sagemaker_session$sagemaker$describe_algorithm(
        AlgorithmName=algorithm_arn)
      self$validate_train_spec()
      self$hyperparameter_definitions = private$.parse_hyperparameters()
      self$hyperparam_list = list()
      # FIX: was `self$set_hyperparameters` — no such method exists on this
      # class; the public method is `set_hyperparameter`.
      if (!is.null(hyperparameters))
        do.call(self$set_hyperparameter, hyperparameters)
    },

    #' @description Validate this estimator's configuration against the
    #' algorithm's ``TrainingSpecification`` (input mode, instance type, and
    #' distributed-training support). Raises an error on the first mismatch.
    validate_train_spec = function(){
      train_spec = self$algorithm_spec[["TrainingSpecification"]]
      algorithm_name = self$algorithm_spec[["AlgorithmName"]]
      # Check that the input mode provided is compatible with the training input modes for the
      # algorithm.
      # FIX: use `[[` so the channel list itself (not a length-1 wrapper list
      # produced by `[`) is passed down; with `[`, no supported modes were found.
      train_input_modes = private$.algorithm_training_input_modes(
        train_spec[["TrainingChannels"]])
      if (!(self$input_mode %in% train_input_modes))
        stop(sprintf("Invalid input mode: %s. %s only supports: %s",
                     self$input_mode, algorithm_name,
                     toString(train_input_modes)),
             call. = FALSE)
      # Check that the training instance type is compatible with the algorithm.
      # FIX: `sprint` -> `sprintf` (the error path itself errored before).
      supported_instances = train_spec[["SupportedTrainingInstanceTypes"]]
      if (!(self$instance_type %in% supported_instances)){
        stop(sprintf("Invalid instance_type: %s. %s supports the following instance types: %s",
                     self$instance_type, algorithm_name,
                     toString(supported_instances)),
             call. = FALSE)}
      # Verify if distributed training is supported by the algorithm.
      # FIX: the original condition was inverted — it raised only when the spec
      # DID advertise SupportsDistributedTraining. The error applies when more
      # than one instance is requested and the flag is absent or not TRUE.
      if (self$instance_count > 1
          && !isTRUE(train_spec[["SupportsDistributedTraining"]]))
        stop(sprintf("Distributed training is not supported by %s. Please set instance_count=1",
                     algorithm_name),
             call. = FALSE)
    },

    #' @description Formats and validates hyperparameters for model tuning,
    #' casting each value to its declared type, checking declared ranges, and
    #' then filling in defaults for any unset hyperparameters.
    #' @param ... model hyperparameters
    set_hyperparameter = function(...){
      args = list(...)
      for(x in names(args)){
        value = private$.validate_and_cast_hyperparameter(x, args[[x]])
        self$hyperparam_list[[x]] = value}
      private$.validate_and_set_default_hyperparameters()
    },

    #' @description Returns the hyperparameters as a dictionary to use for training.
    #'              The fit() method, that does the model training, calls this method to
    #'              find the hyperparameters you specified.
    hyperparameters = function(){
      return(self$hyperparam_list)
    },

    #' @description Returns the docker image to use for training.
    #'              The fit() method, that does the model training, calls this method to
    #'              find the image to use for model training.
    training_image_uri = function(){
      # Algorithm-based training jobs are driven by the algorithm ARN, never a
      # raw training image, so this accessor is intentionally unusable here.
      stop("training_image_uri is never meant to be called on Algorithm Estimators",
           call. = FALSE)
    },

    #' @description Return True if this Estimator will need network isolation to run.
    #'              On Algorithm Estimators this depends on the algorithm being used. If
    #'              this is algorithm owned by your account it will be False. If this is an
    #'              algorithm consumed from Marketplace it will be True.
    #' @return bool: Whether this Estimator needs network isolation or not.
    enable_network_isolation = function(){
      return(private$.is_marketplace())
    },

    #' @description Create a model to deploy.
    #'              The serializer and deserializer arguments are
    #'              only used to define a default Predictor. They are ignored if an
    #'              explicit predictor class is passed in. Other arguments are passed
    #'              through to the Model class.
    #' @param role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
    #'              which is also used during transform jobs. If not specified, the
    #'              role from the Estimator will be used.
    #' @param predictor_cls (RealTimePredictor): The predictor class to use when
    #'              deploying the model.
    #' @param serializer (callable): Should accept a single argument, the input
    #'              data, and return a sequence of bytes. May provide a content_type
    #'              attribute that defines the endpoint request content type
    #' @param deserializer (callable): Should accept two arguments, the result
    #'              data and the response content type, and return a sequence of
    #'              bytes. May provide a content_type attribute that defines the
    #'              endpoint response Accept content type.
    #' @param vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
    #'              the model. Default: use subnets and security groups from this Estimator.
    #'              * 'Subnets' (list[str]): List of subnet ids.
    #'              * 'SecurityGroupIds' (list[str]): List of security group ids.
    #' @param ... : Additional arguments for creating a :class:`~sagemaker.model.ModelPackage`.
    #'              .. tip::
    #'              You can find additional parameters for using this method at
    #'              :class:`~sagemaker.model.ModelPackage` and
    #'              :class:`~sagemaker.model.Model`.
    #' @return a Model ready for deployment.
    create_model = function(role=NULL,
                            predictor_cls=NULL,
                            serializer=IdentitySerializer$new(),
                            deserializer=BytesDeserializer$new(),
                            vpc_config_override="VPC_CONFIG_DEFAULT",
                            ...){
      # FIX: `kwargs` was referenced without ever being defined; capture the
      # extra arguments before warning about removed ones.
      kwargs = list(...)
      removed_kwargs("content_type", kwargs)
      removed_kwargs("accept", kwargs)
      if (is.null(predictor_cls)) {
        # FIX: dropped the undefined `content_type`/`accept` variables from the
        # Predictor call — those parameters were removed from this method's
        # signature (see removed_kwargs above).
        predict_wrapper = function(endpoint, session){
          return(sagemaker.mlcore::Predictor$new(
            endpoint, session, serializer, deserializer))
        }
        predictor_cls = predict_wrapper
      }
      role = role %||% self$role
      param = list(role=role,
                   algorithm_arn=self$algorithm_arn,
                   model_data=self$model_data,
                   vpc_config=self$get_vpc_config(vpc_config_override),
                   sagemaker_session=self$sagemaker_session,
                   predictor_cls=predictor_cls,
                   ...)
      return(do.call(ModelPackage$new, param))
    },

    #' @description Return a ``Transformer`` that uses a SageMaker Model based on the
    #'              training job. It reuses the SageMaker Session and base job name used by
    #'              the Estimator.
    #' @param instance_count (int): Number of EC2 instances to use.
    #' @param instance_type (str): Type of EC2 instance to use, for example,
    #'              'ml.c4.xlarge'.
    #' @param strategy (str): The strategy used to decide how to batch records in
    #'              a single request (default: None). Valid values: 'MultiRecord'
    #'              and 'SingleRecord'.
    #' @param assemble_with (str): How the output is assembled (default: None).
    #'              Valid values: 'Line' or 'None'.
    #' @param output_path (str): S3 location for saving the transform result. If
    #'              not specified, results are stored to a default bucket.
    #' @param output_kms_key (str): Optional. KMS key ID for encrypting the
    #'              transform output (default: None).
    #' @param accept (str): The accept header passed by the client to
    #'              the inference endpoint. If it is supported by the endpoint,
    #'              it will be the format of the batch transform output.
    #' @param env (dict): Environment variables to be set for use during the
    #'              transform job (default: None).
    #' @param max_concurrent_transforms (int): The maximum number of HTTP requests
    #'              to be made to each individual transform container at one time.
    #' @param max_payload (int): Maximum size of the payload in a single HTTP
    #'              request to the container in MB.
    #' @param tags (list[dict]): List of tags for labeling a transform job. If
    #'              none specified, then the tags used for the training job are used
    #'              for the transform job.
    #' @param role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
    #'              which is also used during transform jobs. If not specified, the
    #'              role from the Estimator will be used.
    #' @param volume_kms_key (str): Optional. KMS key ID for encrypting the volume
    #'              attached to the ML compute instance (default: None).
    transformer = function(instance_count,
                           instance_type,
                           strategy=NULL,
                           assemble_with=NULL,
                           output_path=NULL,
                           output_kms_key=NULL,
                           accept=NULL,
                           env=NULL,
                           max_concurrent_transforms=NULL,
                           max_payload=NULL,
                           tags=NULL,
                           role=NULL,
                           volume_kms_key=NULL){
      # FIX: `self.role` (Python attribute syntax) -> `self$role`.
      role = role %||% self$role
      if (!is.null(self$latest_training_job)){
        model = self$create_model(role=role)
        model$.create_sagemaker_model()
        model_name = model$name
        transform_env = list()
        if(!islistempty(env)){
          # Caller-supplied variables override the model's own environment.
          transform_env = model$env
          transform_env = modifyList(transform_env, env, keep.null = TRUE)}
        if(private$.is_marketplace()){
          # Marketplace containers run network-isolated and cannot receive a
          # custom environment.
          transform_env = NULL}
        tags = tags %||% self$tags
      } else {
        stop("No finished training job found associated with this estimator",
             call. = FALSE)}
      return(sagemaker.common::Transformer$new(
        model_name,
        instance_count,
        instance_type,
        strategy=strategy,
        assemble_with=assemble_with,
        output_path=output_path,
        output_kms_key=output_kms_key,
        accept=accept,
        max_concurrent_transforms=max_concurrent_transforms,
        max_payload=max_payload,
        env=transform_env,
        tags=tags,
        base_transform_job_name=self$base_job_name,
        volume_kms_key=volume_kms_key,
        sagemaker_session=self$sagemaker_session))
    },

    #' @description Train a model using the input training dataset.
    #'              The API calls the Amazon SageMaker CreateTrainingJob API to start
    #'              model training. The API uses configuration you provided to create the
    #'              estimator and the specified input training data to send the
    #'              CreatingTrainingJob request to Amazon SageMaker.
    #'              This is a synchronous operation. After the model training
    #'              successfully completes, you can call the ``deploy()`` method to host the
    #'              model using the Amazon SageMaker hosting services.
    #' @param inputs (str or dict or TrainingInput): Information
    #'              about the training data. This can be one of three types:
    #'              \itemize{
    #'                \item{\strong{(str)} the S3 location where training data is saved, or a file:// path in
    #'              local mode.}
    #'                \item{\strong{(dict[str, str]} or dict[str, TrainingInput]) If using multiple
    #'              channels for training data, you can specify a dict mapping channel names to
    #'              strings or :func:`~TrainingInput` objects.}
    #'                \item{\strong{(TrainingInput)} - channel configuration for S3 data sources that can
    #'              provide additional information as well as the path to the training dataset.
    #'              See :func:`TrainingInput` for full details.}
    #'                \item{\strong{(sagemaker.session.FileSystemInput)} - channel configuration for
    #'              a file system data source that can provide additional information as well as
    #'              the path to the training dataset.}}
    #' @param wait (bool): Whether the call should wait until the job completes (default: True).
    #' @param logs ([str]): A list of strings specifying which logs to print. Acceptable
    #'              strings are "All", "NULL", "Training", or "Rules". To maintain backwards
    #'              compatibility, boolean values are also accepted and converted to strings.
    #'              Only meaningful when wait is True.
    #' @param job_name (str): Training job name. If not specified, the estimator generates
    #'              a default job name, based on the training image name and current timestamp.
    fit = function(inputs=NULL,
                   wait=TRUE,
                   logs=TRUE,
                   job_name=NULL){
      # Validate channel names against the algorithm spec before submitting.
      if (!is.null(inputs))
        private$.validate_input_channels(inputs)
      super$fit(inputs, wait, logs, job_name)
    },

    #' @description
    #' Printer.
    #' @param ... (ignored).
    print = function(...){
      print_class(self)
    }
  ),
  private = list(

    # An algorithm consumed from AWS Marketplace carries a ProductId in its
    # describe_algorithm response; account-owned algorithms do not.
    # FIX: membership must be tested against names(), not the list values.
    .is_marketplace = function(){
      return("ProductId" %in% names(self$algorithm_spec))
    },

    .prepare_for_training = function(job_name = NULL){
      # Validate hyperparameters
      # an explicit call to set_hyperparameter() will also validate the hyperparameters
      # but it is possible that the user never called it.
      private$.validate_and_set_default_hyperparameters()
      super$.prepare_for_training(job_name)
    },

    # Check the provided input channels against the algorithm's
    # TrainingSpecification: reject unknown channels and require all
    # channels marked IsRequired.
    .validate_input_channels = function(channels){
      train_spec = self$algorithm_spec$TrainingSpecification
      algorithm_name = self$algorithm_spec$AlgorithmName
      training_channels = lapply(train_spec$TrainingChannels, function(c) c)
      names(training_channels) = sapply(train_spec$TrainingChannels, function(c) c$Name)
      # check for unknown channels that the algorithm does not support
      # FIX: iterate the channel NAMES and compare against the spec's channel
      # names — the original compared list values, which never matched.
      for (name in names(channels)){
        if(!(name %in% names(training_channels)))
          stop(sprintf("Unknown input channel: %s is not supported by: %s",
                       name, algorithm_name),
               call. = FALSE)
      }
      # check for required channels that were not provided
      for (i in seq_along(training_channels)){
        name = names(training_channels)[i]
        channel = training_channels[[i]]
        if (!(name %in% names(channels)) && "IsRequired" %in% names(channel) && !islistempty(channel$IsRequired))
          stop(sprintf("Required input channel: %s Was not provided.", name), call. = FALSE)
      }
    },

    # Cast a hyperparameter value to its declared type and enforce its
    # declared range, erroring on unknown names or out-of-range values.
    # FIX: method name typo `.vatlidate_...` corrected to match its call site
    # in set_hyperparameter().
    .validate_and_cast_hyperparameter = function(name, v){
      algorithm_name = self$algorithm_spec$AlgorithmName
      if (!(name %in% names(self$hyperparameter_definitions)))
        stop(sprintf("Invalid hyperparameter: %s is not supported by %s", name, algorithm_name),
             call. = FALSE)
      definition = self$hyperparameter_definitions[[name]]
      if ("class" %in% names(definition))
        value = definition$class$public_methods$cast_to_type(v)
      else
        value = v
      if ("range" %in% names(definition) && !definition$range$is_valid(value)){
        valid_range = definition$range$as_tuning_range(name)
        stop(sprintf("Invalid value: %s Supported range: %s", value, valid_range), call. = FALSE)}
      return(value)
    },

    .validate_and_set_default_hyperparameters = function(){
      # Check if all the required hyperparameters are set. If there is a default value
      # for one, set it.
      for (i in seq_along(self$hyperparameter_definitions)){
        name = names(self$hyperparameter_definitions)[i]
        definition = self$hyperparameter_definitions[[i]]
        # FIX: membership must be tested against names(), not the list values.
        if (!(name %in% names(self$hyperparam_list))){
          spec = definition$spec
          if ("DefaultValue" %in% names(spec))
            self$hyperparam_list[[name]] = spec$DefaultValue
          else if ("IsRequired" %in% names(spec) && !islistempty(spec$IsRequired))
            stop(sprintf("Required hyperparameter: %s is not set", name), call. = FALSE)}
      }
    },

    # Build a named list of hyperparameter definitions (raw spec plus, when
    # declared, a Parameter class and range object) from the algorithm spec.
    .parse_hyperparameters = function(){
      definitions = list()
      training_spec = self$algorithm_spec$TrainingSpecification
      if ("SupportedHyperParameters" %in% names(training_spec)){
        hyperparameters = training_spec$SupportedHyperParameters
        for (h in hyperparameters){
          parameter_type = h$Type
          name = h$Name
          param = private$.hyperparameter_range_and_class(
            parameter_type, h)
          definitions[[name]] = list("spec"= h)
          # FIX: `parameter_range`/`parameter_class` were undefined free
          # variables; the values live on the returned `param` list.
          if (!islistempty(param$parameter_range))
            definitions[[name]]$range = param$parameter_range
          if (!islistempty(param$parameter_class))
            definitions[[name]]$class = param$parameter_class
        }
      }
      return(definitions)
    },

    # Map an algorithm hyperparameter spec to a Parameter class and (when a
    # Range is declared) a range instance. FreeText parameters get neither.
    .hyperparameter_range_and_class = function(parameter_type, hyperparameter){
      if (parameter_type %in% self$.hyperpameters_with_range){
        range_name = paste0(parameter_type, "ParameterRangeSpecification")
      }
      parameter_class = NULL
      parameter_range = NULL
      if (parameter_type %in% c("Integer", "Continuous")){
        # Integer and Continuous are handled the same way. We get the min and max values
        # and just create an Instance of Parameter. Note that the range is optional for all
        # the Parameter Types.
        if (parameter_type == "Integer")
          parameter_class = sagemaker.mlcore::IntegerParameter
        else
          parameter_class = sagemaker.mlcore::ContinuousParameter
        if ("Range" %in% names(hyperparameter)){
          min_value = parameter_class$public_methods$cast_to_type(
            hyperparameter$Range[[range_name]][["MinValue"]])
          max_value = parameter_class$public_methods$cast_to_type(
            hyperparameter$Range[[range_name]][["MaxValue"]])
          parameter_range = parameter_class$new(min_value, max_value)
        }
      } else if(parameter_type == "Categorical") {
        parameter_class = sagemaker.mlcore::CategoricalParameter
        if("Range" %in% names(hyperparameter)){
          values = hyperparameter$Range[[range_name]][["Values"]]
          parameter_range = sagemaker.mlcore::CategoricalParameter$new(values)}
      } else if(parameter_type == "FreeText") {
        NULL
      } else
        stop(sprintf("Invalid Hyperparameter type: %s. Valid ones are: (Integer, Continuous, Categorical, FreeText)", parameter_type),
             call. = FALSE)
      return(list(parameter_class = parameter_class, parameter_range = parameter_range))
    },

    # Collect the union of input modes supported by any training channel;
    # 'File' and 'Pipe' are always considered available.
    .algorithm_training_input_modes = function(training_channels){
      current_input_modes = c("File", "Pipe")
      for (channel in training_channels){
        supported_input_modes = unique(channel$SupportedInputModes)
        current_input_modes = c(current_input_modes, supported_input_modes)}
      return(unique(current_input_modes))
    },

    # Convert the job description to init params that can be handled by the
    # class constructor
    # Args:
    #     job_details (dict): the returned job details from a DescribeTrainingJob
    #         API call.
    #     model_channel_name (str): Name of the channel where pre-trained
    #         model data will be downloaded.
    # Returns:
    #     dict: The transformed init_params
    # FIX: default was the Python literal `None`; R uses NULL.
    .prepare_init_params_from_job_description = function(job_details, model_channel_name=NULL){
      init_params = super$.prepare_init_params_from_job_description(
        job_details, model_channel_name)
      # This hyperparameter is added by Amazon SageMaker Automatic Model Tuning.
      # It cannot be set through instantiating an estimator.
      if ("_tuning_objective_metric" %in% names(init_params$hyperparameters)){
        init_params[["hyperparameters"]][["_tuning_objective_metric"]] = NULL}
      return(init_params)
    }
  ),
  lock_objects = FALSE
)
# (Removed documentation-site embed boilerplate that was not valid R code.)