# NOTE: This code has been modified from AWS Sagemaker Python:
# https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/analytics.py
#' @include r_utils.R
#' @import lgr
#' @import R6
#' @import sagemaker.core
#' @import data.table
METRICS_PERIOD_DEFAULT = 60 # seconds
#' @title AnalyticsMetricsBase Class
#' @description Base class for tuning job or training job analytics classes. Understands
#' common functionality like persistence and caching.
#' @export
AnalyticsMetricsBase = R6Class("AnalyticsMetricsBase",
public = list(
#' @description Initialize a ``AnalyticsMetricsBase`` instance.
initialize = function() {
self$.dataframe = NULL
},
#' @description ersists the analytics dataframe to a file.
#' @param filename (str): The name of the file to save to.
export_csv = function(filename) {
fwrite(self$dataframe(), filename)
},
#' @description A dataframe with lots of interesting results about this
#' object. Created by calling SageMaker List and Describe APIs and
#' converting them into a convenient tabular summary.
#' @param force_refresh (bool): Set to True to fetch the latest data from
#' SageMaker API.
dataframe = function(force_refresh = FALSE) {
if (force_refresh)
self$clear_cache()
if (is.null(self$.dataframe))
self$.dataframe = private$.fetch_dataframe()
return(self$.dataframe)
},
#' @description Clear the object of all local caches of API methods, so that the next
#' time any properties are accessed they will be refreshed from the
#' service.
clear_cache = function() {
self$.dataframe = NULL
},
#' @description format class
format = function(){
format_class(self)
}
),
private = list(
# Sub-class must calculate the dataframe and return it.
.fetch_dataframe = function() {
stop("I'm an abstract interface method", call. = F)
}
),
lock_objects = F
)
#' @title HyperparameterTuningJobAnalytics Class
#' @description Fetch results about a hyperparameter tuning job and make them accessible
#' for analytics.
#' @export
HyperparameterTuningJobAnalytics = R6Class("HyperparameterTuningJobAnalytics",
inherit = AnalyticsMetricsBase,
public = list(
#' @description Initialize a ``HyperparameterTuningJobAnalytics`` instance.
#' @param hyperparameter_tuning_job_name (str): name of the
#' HyperparameterTuningJob to analyze.
#' @param sagemaker_session (sagemaker.session.Session): Session object which
#' manages interactions with Amazon SageMaker APIs and any other
#' AWS services needed. If not specified, one is created using the
#' default AWS configuration chain.
initialize = function(hyperparameter_tuning_job_name,
sagemaker_session=NULL){
self$sagemaker_session = sagemaker_session %||% Session$new()
self$.tuning_job_name = hyperparameter_tuning_job_name
self$.tuning_job_describe_result = NULL
self$.training_job_summaries = NULL
super$initialize()
self$clear_cache()
},
#' @description Call ``DescribeHyperParameterTuningJob`` for the hyperparameter
#' tuning job.
#' @param force_refresh (bool): Set to True to fetch the latest data from
#' SageMaker API.
#' @return dict: The Amazon SageMaker response for
#' ``DescribeHyperParameterTuningJob``.
description = function(force_refresh=FALSE){
if (force_refresh)
self$clear_cache()
if(islistempty(self$.tuning_job_describe_result))
self$.tuning_job_describe_result = self$sagemaker_session$sagemaker$describe_hyper_parameter_tuning_job(
HyperParameterTuningJobName=self$name)
return (self$.tuning_job_describe_result)
},
#' @description A (paginated) list of everything from
#' ``ListTrainingJobsForTuningJob``.
#' @param force_refresh (bool): Set to True to fetch the latest data from
#' SageMaker API.
#' @return dict: The Amazon SageMaker response for
#' ``ListTrainingJobsForTuningJob``.
training_job_summaries = function(force_refresh=FALSE){
if (force_refresh)
self$clear_cache()
if (!is.null(self$.training_job_summaries))
return(self$.training_job_summaries)
output = list()
next_args = list(HyperParameterTuningJobName=self$name, MaxResults=100)
for (count in 1:100){
LOGGER$debug("Calling list_training_jobs_for_hyper_parameter_tuning_job %d", count)
raw_result = do.call(
self$sagemaker_session$sagemaker$list_training_jobs_for_hyper_parameter_tuning_job,
next_args
)
new_output = raw_result[["TrainingJobSummaries"]]
output = c(output, new_output)
LOGGER$debug("Got %d more TrainingJobs. Total so far: %d", length(new_output), length(output))
if (!length(raw_result[["NextToken"]]) == 0 && length(new_output) > 0)
next_args$NextToken = raw_result$NextToken
else
break
}
self$.training_job_summaries = output
return (output)
},
#' @description Clear the object of all local caches of API methods.
clear_cache = function(){
super$clear_cache()
self$.tuning_job_describe_result = NULL
self$.training_job_summaries = NULL
}
),
private = list(
.fetch_dataframe = function(){
# Run that helper over all the summaries.
reshape = function(training_summary){
out = lapply(training_summary[["TunedHyperParameters"]], as.numeric)
out[["TrainingJobName"]] = training_summary[["TrainingJobName"]]
out[["TrainingJobStatus"]] = training_summary[["TrainingJobStatus"]]
out[["FinalObjectiveValue"]] = training_summary[["FinalHyperParameterTuningJobObjectiveMetric"]][["Value"]]
start_time = training_summary[["TrainingStartTime"]]
end_time = training_summary[["TrainingEndTime"]]
out[["TrainingStartTime"]] = start_time
out[["TrainingEndTime"]] = end_time
if(!is.null(start_time) && !is.null(end_time)){
out[["TrainingElapsedTimeSeconds"]] = difftime(end_time, start_time, units = "secs")
}
out[["TrainingJobDefinitionName"]] = training_summary[["TrainingJobDefinitionName"]]
return(out)
}
return(rbindlist(lapply(self$training_job_summaries(), reshape), fill = T))
},
# Convert parameter ranges a dictionary using the parameter range names as the keys
.prepare_parameter_ranges = function(parameter_ranges){
out = list()
for (i in seq_along(parameter_ranges)){
ranges = parameter_ranges[[i]]
for(param in ranges)
out[[param$Name]] = param
}
return(out)
}
),
active = list(
#' @field name
#' Name of the HyperparameterTuningJob being analyzed
name = function(){
return(self$.tuning_job_name)
},
#' @field tuning_ranges
#' A dictionary describing the ranges of all tuned hyperparameters. The
#' keys are the names of the hyperparameter, and the values are the ranges.
#' The output can take one of two forms:
#' * If the 'TrainingJobDefinition' field is present in the job description, the output
#' is a dictionary constructed from 'ParameterRanges' in
#' 'HyperParameterTuningJobConfig' of the job description. The keys are the
#' parameter names, while the values are the parameter ranges.
#' Example:
#' >>> {
#' >>> "eta": {"MaxValue": "1", "MinValue": "0", "Name": "eta"},
#' >>> "gamma": {"MaxValue": "10", "MinValue": "0", "Name": "gamma"},
#' >>> "iterations": {"MaxValue": "100", "MinValue": "50", "Name": "iterations"},
#' >>> "num_layers": {"MaxValue": "30", "MinValue": "5", "Name": "num_layers"},
#' >>> }
#' * If the 'TrainingJobDefinitions' field (list) is present in the job description,
#' the output is a dictionary with keys as the 'DefinitionName' values from
#' all items in 'TrainingJobDefinitions', and each value would be a dictionary
#' constructed from 'HyperParameterRanges' in each item in 'TrainingJobDefinitions'
#' in the same format as above
#' Example:
#' >>> {
#' >>> "estimator_1": {
#' >>> "eta": {"MaxValue": "1", "MinValue": "0", "Name": "eta"},
#' >>> "gamma": {"MaxValue": "10", "MinValue": "0", "Name": "gamma"},
#' >>> },
#' >>> "estimator_2": {
#' >>> "framework": {"Values": ["TF", "MXNet"], "Name": "framework"},
#' >>> "gamma": {"MaxValue": "1.0", "MinValue": "0.2", "Name": "gamma"}
#' >>> }
#' >>> }
#' For more details about the 'TrainingJobDefinition' and 'TrainingJobDefinitions' fields
#' in job description, see
#' https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job
tuning_ranges = function(){
description = self$description()
if(!islistempty(description[["TrainingJobDefinition"]])){
return(private$.prepare_parameter_ranges(
description[["HyperParameterTuningJobConfig"]][["ParameterRanges"]]))
}
output = lapply(description[["TrainingJobDefinitions"]], function(training_job_definition){
private$.prepare_parameter_ranges(training_job_definition[["HyperParameterRanges"]])
})
names(output) = sapply(description[["TrainingJobDefinitions"]], function(x) x[["DefinitionName"]])
return (output)
}
),
lock_objects = F
)
#' @title TrainingJobAnalytics Class
#' @description Fetch training curve data from CloudWatch Metrics for a specific training
#' job.
#' @export
TrainingJobAnalytics = R6Class("TrainingJobAnalytics",
inherit = AnalyticsMetricsBase,
public = list(
#' @field CLOUDWATCH_NAMESPACE
#' CloudWatch namespace to return Training Job Analytics data
CLOUDWATCH_NAMESPACE = "/aws/sagemaker/TrainingJobs",
#' @description Initialize a ``TrainingJobAnalytics`` instance.
#' @param training_job_name (str): name of the TrainingJob to analyze.
#' @param metric_names (list, optional): string names of all the metrics to
#' collect for this training job. If not specified, then it will
#' use all metric names configured for this job.
#' @param sagemaker_session (sagemaker.session.Session): Session object which
#' manages interactions with Amazon SageMaker APIs and any other
#' AWS services needed. If not specified, one is specified using
#' the default AWS configuration chain.
#' @param start_time :
#' @param end_time :
#' @param period :
initialize = function(training_job_name,
metric_names=NULL,
sagemaker_session=NULL,
start_time=NULL,
end_time=NULL,
period=NULL){
self$sagemaker_session = sagemaker_session %||% Session$new()
self$.cloudwatch = self$sagemaker_session$paws_session$client("cloudwatch")
self$.training_job_name = training_job_name
self$.start_time = start_time
self$.end_time = end_time
self$.period = period %||% METRICS_PERIOD_DEFAULT
if (!is.null(metric_names))
self$.metric_names = metric_names
else
self$.metric_names = private$.metric_names_for_training_job()
super$initialize()
self$clear_cache()
},
#' @description Clear the object of all local caches of API methods, so that the next
#' time any properties are accessed they will be refreshed from the
#' service.
clear_cache = function(){
super$clear_cache()
self$.data = data.table()
self$.time_interval = private$.determine_timeinterval()
}
),
private = list(
# Return a dictionary with two datetime objects, start_time and
# end_time, covering the interval of the training job
.determine_timeinterval = function(){
description = self$sagemaker_session$sagemaker$describe_training_job(TrainingJobName=self$name)
start_time = self$.start_time %||% description[["TrainingStartTime"]] # datetime object
# Incrementing end time by 1 min since CloudWatch drops seconds before finding the logs.
# This results in logs being searched in the time range in which the correct log line was
# not present.
# Example - Log time - 2018-10-22 08:25:55
# Here calculated end time would also be 2018-10-22 08:25:55 (without 1 min addition)
# CW will consider end time as 2018-10-22 08:25 and will not be able to search the
# correct log.
end_time = self$.end_time %||% (
description[["TrainingEndTime"]] %||% utcnow() + 60 # add 1 minute
)
return(list("start_time"= start_time, "end_time"= end_time))
},
.fetch_dataframe = function(){
self$.data = lapply(self$.metric_names, private$.fetch_metric)
return(rbindlist(self$.data, fill = T))
},
# Fetch all the values of a named metric, and add them to _data
# Args:
# metric_name:
.fetch_metric = function(metric_name){
request = list(
"Namespace"=self$CLOUDWATCH_NAMESPACE,
"MetricName"=metric_name,
"Dimensions"= list(list("Name"="TrainingJobName", "Value"=self$name)),
"StartTime"=self$.time_interval[["start_time"]],
"EndTime"=self$.time_interval[["end_time"]],
"Period"=self$.period,
"Statistics"=list("Average")
)
raw_cwm_data = do.call(self$.cloudwatch$get_metric_statistics, request)[["Datapoints"]]
if (islistempty(raw_cwm_data)){
LOGGER$warn("Warning: No metrics called %s found", metric_name)
return(NULL)
}
# Process data: normalize to starting time, and sort.
base_time = min(sapply(raw_cwm_data, function(x) x$Timestamp))
all_xy= rbindlist(lapply(raw_cwm_data, function(pt) list(
timestamp = (pt[["Timestamp"]]-base_time),
metric_name = metric_name,
value = pt[["Average"]])
),
fill = T
)
setorderv(all_xy, cols = "timestamp")
return(all_xy)
},
# Helper method to discover the metrics defined for a training job.
.metric_names_for_training_job = function(){
training_description = self$sagemaker_session$sagemaker$describe_training_job(
TrainingJobName=self$.training_job_name
)
metric_definitions = training_description$AlgorithmSpecification$MetricDefinitions
metric_names = lapply(metric_definitions, function(md) md[["Name"]])
return(metric_names)
}
),
active = list(
#' @field name
#' Name of the TrainingJob being analyzed
name = function(){
return(self$.training_job_name)
}
),
lock_objects = F
)
#' @title ExperimentAnalytics class
#' @description Fetch trial component data and make them accessible for analytics.
#' @export
ExperimentAnalytics = R6Class("ExperimentAnalytics",
inherit = AnalyticsMetricsBase,
public = list(
#' @field MAX_TRIAL_COMPONENTS
#' class metadata
MAX_TRIAL_COMPONENTS = 10000,
#' @description Initialize a ``ExperimentAnalytics`` instance.
#' @param experiment_name (str, optional): Name of the experiment if you want to constrain the
#' search to only trial components belonging to an experiment.
#' @param search_expression (dict, optional): The search query to find the set of trial components
#' to use to populate the data frame.
#' @param sort_by (str, optional): The name of the resource property used to sort
#' the set of trial components.
#' @param sort_order (str optional): How trial components are ordered, valid values are Ascending
#' and Descending. The default is Descending.
#' @param metric_names (list, optional): string names of all the metrics to be shown in the
#' data frame. If not specified, all metrics will be shown of all trials.
#' @param parameter_names (list, optional): string names of the parameters to be shown in the
#' data frame. If not specified, all parameters will be shown of all trials.
#' @param sagemaker_session (sagemaker.session.Session): Session object which manages interactions
#' with Amazon SageMaker APIs and any other AWS services needed. If not specified,
#' one is created using the default AWS configuration chain.
initialize = function(experiment_name=NULL,
search_expression=NULL,
sort_by=NULL,
sort_order=NULL,
metric_names=NULL,
parameter_names=NULL,
sagemaker_session=NULL){
self$sagemaker_session = sagemaker_session %||% Session$new()
if (!is.null(experiment_name) && !is.null(search_expression))
stop("Either experiment_name or search_expression must be supplied.", call. = F)
self$.experiment_name = experiment_name
self$.search_expression = search_expression
self$.sort_by = sort_by
self$.sort_order = sort_order
self$.metric_names = metric_names
self$.parameter_names = parameter_names
self$.trial_components = NULL
super$initialize()
self$clear_cache()
},
#' @description Clear the object of all local caches of API methods.
clear_cache = function() {
super$clear_cache()
self$.trial_components = NULL
}
),
private = list(
# Reshape trial component data to pandas columns
# Args:
# trial_component: dict representing a trial component
# Returns:
# dict: Key-Value pair representing the data in the pandas dataframe
.reshape = function(trial_component){
output = data.table(TrialComponentName = trial_component$TrialComponentName,
DisplayName = trial_component$DisplayName,
SourceArn = trial_component$Source$SourceArn)
# ----- bring .reshape_parameters into .reshape function -----
for(name in sort(names(trial_component$Parameters))){
if (!is.null(self$.parameter_names) && !(name %in% self$.parameter_names))
next
output[[name]] = (if(!islistempty(trial_component$Parameters[[name]]$NumberValue))
trial_component$Parameters[[name]]$NumberValue
else trial_component$Parameters[[name]]$StringValue)
}
# ----- bring .reshape_metrics into .reshape function -----
statistic_types = c("Min", "Max", "Avg", "StdDev", "Last", "Count")
for(metric_summary in trial_component$Metrics){
metric_name = trial_component$Metrics$MetricName
if (!is.null(self$.metric_name) && !(metric_name %in% self$.metric_names))
next
for(stat_type in statistic_types){
stat_value = metric_summary[[stat_type]]
if (!islistempty(stat_value))
output[[sprintf("%s - %s", metric_name, stat_type)]] = stat_value}
}
return(output)
},
# Return a pandas dataframe with all the trial_components,
# along with their parameters and metrics.
.fetch_dataframe = function(){
df = rbindlist(lapply(private$.get_trial_components(), private$.reshape), fill = T)
return(df)
},
# Get all trial components matching the given search query expression.
# Args:
# force_refresh (bool): Set to True to fetch the latest data from SageMaker API.
# Returns:
# list: List of dicts representing the trial components
.get_trial_components = function(force_refresh=FALSE){
if (force_refresh)
self$clear_cache()
if (!islistempty(self$.trial_components))
return(self$.trial_components)
if (islistempty(self$.search_expression))
self$.search_expression = list()
if (!is.null(self$.experiment_name)){
if (islistempty(self$.search_expression$Filters))
self$.search_expression$Filters = list()
self$.search_expression$Filters = c(
self$.search_expression$Filters,
list("Name"= "Parents.ExperimentName",
"Operator"= "Equals",
"Value"= self$.experiment_name))
}
return(self$.search(self$.search_expression, self$.sort_by, self$.sort_order))
},
# Perform a search query using SageMaker Search and return the matching trial components
# Args:
# search_expression: Search expression to filter trial components.
# sort_by: The name of the resource property used to sort the trial components.
# sort_order: How trial components are ordered, valid values are Ascending
# and Descending. The default is Descending.
# Returns:
# list: List of dict representing trial components.
.search = function(search_expression,
sort_by,
sort_order){
trial_components = list()
search_args = list(
"Resource"= "ExperimentTrialComponent",
"SearchExpression"= search_expression)
search_args$SortBy = sort_by
search_args$SortOrder = sort_order
while(length(trial_components) < self$MAX_TRIAL_COMPONENTS){
search_response = self$sagemaker_session$sagemaker$search(
Resource = search_args$Resource,
SearchExpression = search_args$SearchExpression,
SortBy = search_args$SortBy,
SortOrder = search_args$SortOrder,
NextToken = search_args$NextToken)
components = lapply(search_response$Results, function(result) result$TrialComponent)
trial_components = c(trial_components, components)
if (!islistempty(search_response$NextToken) && !islistempty(components))
search_args$NextToken = search_response$NextToken
else
break
}
return(trial_components)
}
),
active = list(
#' @field name
#' Name of the Experiment being analyzed
name = function(){
return(self$.experiment_name)
}
),
lock_objects = F
)
utcnow = function(){
now = Sys.time()
attr(now, "tzone") <- "UTC"
return(now)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.