# NOTE: This code has been modified from AWS Sagemaker Python:
# https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/xgboost/estimator.py
#' @include xgboost_default.R
#' @include xgboost_model.R
#' @include r_utils.R
#' @import R6
#' @import sagemaker.core
#' @import sagemaker.common
#' @import sagemaker.mlcore
#' @import lgr
#' @title XGBoost Class
#' @description Handle end-to-end training and deployment of XGBoost booster training or training using
#' customer provided XGBoost entry point script.
#' @export
XGBoost = R6Class("XGBoost",
  inherit = sagemaker.mlcore::Framework,
  public = list(

    #' @field .module
    #' mimic python module
    .module = "sagemaker.xgboost.estimator",

    #' @description This ``Estimator`` executes an XGBoost based SageMaker Training Job.
    #' The managed XGBoost environment is an Amazon-built Docker container that executes functions
    #' defined in the supplied ``entry_point`` Python script.
    #' Training is started by calling :meth:`~sagemaker.amazon.estimator.Framework.fit` on this
    #' Estimator. After training is complete, calling
    #' :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a hosted SageMaker endpoint
    #' and returns an :class:`~sagemaker.amazon.xgboost.model.XGBoostPredictor` instance that
    #' can be used to perform inference against the hosted model.
    #' Technical documentation on preparing XGBoost scripts for SageMaker training and using the
    #' XGBoost Estimator is available on the project home-page:
    #' https://github.com/aws/sagemaker-python-sdk
    #' @param entry_point (str): Path (absolute or relative) to the Python source file which should
    #' be executed as the entry point to training. If ``source_dir`` is specified,
    #' then ``entry_point`` must point to a file located at the root of ``source_dir``.
    #' @param framework_version (str): XGBoost version you want to use for executing your model
    #' training code.
    #' @param source_dir (str): Path (absolute, relative or an S3 URI) to a directory
    #' with any other training source code dependencies aside from the entry
    #' point file (default: NULL). If ``source_dir`` is an S3 URI, it must
    #' point to a tar.gz file. Structure within this directory is preserved
    #' when training on Amazon SageMaker.
    #' @param hyperparameters (list): Hyperparameters that will be used for training (default: NULL).
    #' The hyperparameters are made accessible as a dict[str, str] to the training code
    #' on SageMaker. For convenience, this accepts other types for keys and values, but
    #' they are converted to strings before training.
    #' @param py_version (str): Python version you want to use for executing your model
    #' training code (default: 'py3').
    #' @param image_uri (str): If specified, the estimator will use this image for training and
    #' hosting, instead of selecting the appropriate SageMaker official image
    #' based on framework_version and py_version. It can be an ECR url or
    #' dockerhub image and tag.
    #' Examples:
    #' 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
    #' custom-image:latest.
    #' @param ... : Additional kwargs passed to the
    #' :class:`~sagemaker.estimator.Framework` constructor. The deprecated
    #' ``train_instance_type`` is accepted as an alias of ``instance_type``.
    initialize = function(entry_point,
                          framework_version,
                          source_dir = NULL,
                          hyperparameters = NULL,
                          py_version = "py3",
                          image_uri = NULL,
                          ...) {
      kwargs = list(...)
      # Resolve the deprecated `train_instance_type` alias. Use the helper's
      # return value: `kwargs` is an R list (copied, not shared), so it may
      # still only carry the old name after the call.
      instance_type = renamed_kwargs(
        "train_instance_type", "instance_type", kwargs[["instance_type"]], kwargs
      )
      if (!is.null(instance_type)) {
        kwargs[["instance_type"]] = instance_type
      }

      # Use list(), not c(): c() would splice a list-valued `hyperparameters`
      # into separate `hyperparameters.<key>` arguments. NULL arguments are
      # dropped so the parent constructor's own defaults apply.
      framework_args = Filter(Negate(is.null), list(
        entry_point = entry_point,
        source_dir = source_dir,
        hyperparameters = hyperparameters,
        image_uri = image_uri
      ))
      do.call(super$initialize, c(framework_args, kwargs))

      self$py_version = py_version
      self$framework_version = framework_version
      attr(self, "_framework_name") = XGBOOST_NAME

      # No explicit image: look up the official training image for this
      # framework version / region / instance type.
      if (is.null(image_uri)) {
        self$image_uri = sagemaker.core::ImageUris$new()$retrieve(
          attr(self, "_framework_name"),
          self$sagemaker_session$paws_region_name,
          version = framework_version,
          py_version = self$py_version,
          instance_type = instance_type,
          image_scope = "training"
        )
      }
    },

    #' @description Create a SageMaker ``XGBoostModel`` object that can be deployed to an ``Endpoint``.
    #' @param model_server_workers (int): Optional. The number of worker processes used by the
    #' inference server. If NULL, server will use one worker per vCPU.
    #' @param role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used
    #' during transform jobs. If not specified, the role from the Estimator will be used.
    #' @param vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the
    #' model.
    #' Default: use subnets and security groups from this Estimator.
    #' * 'Subnets' (list[str]): List of subnet ids.
    #' * 'SecurityGroupIds' (list[str]): List of security group ids.
    #' @param entry_point (str): Path (absolute or relative) to the local Python source file which
    #' should be executed as the entry point of the model. If ``source_dir`` is specified,
    #' then ``entry_point`` must point to a file located at the root of ``source_dir``.
    #' If not specified, the training entry point is used.
    #' @param source_dir (str): Path (absolute or relative) to a directory with any other serving
    #' source code dependencies aside from the entry point file.
    #' If not specified, the model source directory from training is used.
    #' @param dependencies (list[str]): A list of paths to directories (absolute or relative) with
    #' any additional libraries that will be exported to the container.
    #' If not specified, the dependencies from training are used.
    #' This is not supported with "local code" in Local Mode.
    #' @param ... : Additional kwargs passed to the :class:`~sagemaker.xgboost.model.XGBoostModel`
    #' constructor. An ``image_uri`` passed here overrides the estimator's image.
    #' @return sagemaker.xgboost.model.XGBoostModel: A SageMaker ``XGBoostModel`` object.
    #' See :func:`~sagemaker.xgboost.model.XGBoostModel` for full details.
    create_model = function(model_server_workers = NULL,
                            role = NULL,
                            vpc_config_override = "VPC_CONFIG_DEFAULT",
                            entry_point = NULL,
                            source_dir = NULL,
                            dependencies = NULL,
                            ...) {
      kwargs = list(...)
      # `[[` instead of `$`: `$` partially matches names on lists, so e.g. a
      # `name_prefix` argument would otherwise be read as `name`.
      kwargs[["name"]] = private$.get_or_create_name(kwargs[["name"]])
      # Only fall back to the estimator's image when the caller did not pass one.
      if (!("image_uri" %in% names(kwargs))) {
        kwargs[["image_uri"]] = self$image_uri
      }
      kwargs$model_data = self$model_data
      kwargs$role = role %||% self$role
      kwargs$entry_point = entry_point %||% private$.model_entry_point()
      kwargs$framework_version = self$framework_version
      kwargs$source_dir = source_dir %||% private$.model_source_dir()
      kwargs$container_log_level = self$container_log_level
      kwargs$code_location = self$code_location
      kwargs$py_version = self$py_version
      kwargs$model_server_workers = model_server_workers
      kwargs$sagemaker_session = self$sagemaker_session
      kwargs$vpc_config = self$get_vpc_config(vpc_config_override)
      kwargs$dependencies = dependencies %||% self$dependencies
      return(do.call(XGBoostModel$new, kwargs))
    },

    #' @description Attach to an existing training job.
    #' Create an Estimator bound to an existing training job. The job description
    #' is converted to constructor arguments by
    #' ``.prepare_init_params_from_job_description()``, and a clone of this
    #' object is re-initialised with them. After attaching, if the training job
    #' has a Complete status, it can be ``deploy()`` ed to create a SageMaker
    #' Endpoint and return a ``Predictor``.
    #' If the training job is in progress, attach will block (via ``wait()``)
    #' until the training job completes.
    #' Examples:
    #' my_estimator$fit(wait = FALSE)
    #' training_job_name = "<name of the training job>"
    #' Later on:
    #' attached_estimator = my_estimator$attach(training_job_name)
    #' attached_estimator$deploy()
    #' @param training_job_name (str): The name of the training job to attach to.
    #' @param sagemaker_session (sagemaker.session.Session): Session object which
    #' manages interactions with Amazon SageMaker APIs and any other
    #' AWS services needed. If not specified, the estimator creates one
    #' using the default AWS configuration chain.
    #' @param model_channel_name (str): Name of the channel where pre-trained
    #' model data will be downloaded (default: 'model'). If no channel
    #' with the same name exists in the training job, this option will
    #' be ignored.
    #' @return Instance of the calling ``Estimator`` Class with the attached
    #' training job.
    attach = function(training_job_name,
                      sagemaker_session = NULL,
                      model_channel_name = "model") {
      sagemaker_session = sagemaker_session %||% Session$new()
      job_details = sagemaker_session$sagemaker_client$describe_training_job(
        TrainingJobName = training_job_name
      )
      init_params = private$.prepare_init_params_from_job_description(
        job_details, model_channel_name
      )
      init_params[["tags"]] = sagemaker_session$sagemaker_client$list_tags(
        ResourceArn = job_details$TrainingJobArn
      )$Tags
      init_params$sagemaker_session = sagemaker_session

      # Re-initialise a clone of this object from the job's settings.
      estimator = self$clone()
      do.call(estimator$initialize, init_params)

      # Track the job by the exact name that was just described.
      # NOTE(review): `init_params$base_job_name` was used previously; if the
      # parent derives it as a name prefix (as the Python SDK's `base_from_name`
      # does), it would not identify this job.
      estimator$latest_training_job = training_job_name
      estimator$.current_job_name = estimator$latest_training_job
      estimator$wait()

      # Record where the job's code lives. Work on a local copy so the
      # package-level `UploadedCode` template is visibly left untouched.
      uploaded_code = UploadedCode
      uploaded_code$s3_prefix = estimator$source_dir
      uploaded_code$script_name = estimator$entry_point
      estimator$uploaded_code = uploaded_code
      return(estimator)
    }
  ),
  private = list(
    # Convert the job description to init params that can be handled by the
    # class constructor.
    # Args:
    #   job_details: the returned job details from a describe_training_job API call.
    #   model_channel_name: accepted for interface compatibility; not forwarded
    #     to the parent implementation (mirrors the Python SDK's XGBoost class).
    # Returns:
    #   list: The transformed init_params. `image_uri` is only kept when the
    #   framework cannot be parsed from it (i.e. it is not an official image).
    # Raises:
    #   An error if the image belongs to a different framework than this class.
    .prepare_init_params_from_job_description = function(job_details, model_channel_name = NULL) {
      init_params = super$.prepare_init_params_from_job_description(job_details)
      image_uri = init_params$image_uri
      init_params$image_uri = NULL

      fw_name = framework_name_from_image(image_uri)
      framework = fw_name[[1]]
      # Use one emptiness test for both branches below so they cannot disagree.
      framework_known = !islistempty(framework)

      init_params$py_version = fw_name[[2]]
      if (framework_known && framework != attr(self, "_framework_name")) {
        stop(
          sprintf("Training job: %s didn't use image for requested framework",
            job_details$TrainingJobName
          ),
          call. = FALSE
        )
      }
      init_params$framework_version = framework_version_from_tag(fw_name[[3]])

      if (!framework_known) {
        # If we were unable to parse the framework name from the image it is not one of our
        # officially supported images, in this case just add the image to the init params.
        init_params$image_uri = image_uri
      }
      return(init_params)
    }
  ),
  lock_objects = FALSE
)
# NOTE: the two lines below are leftover text from a documentation web page, not R code.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.