knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)

Introduction

In this introduction, we'll provide a step-by-step guide to training models with AWS Sagemaker using the sagemaker R package.

We are going to train and tune an xgboost regression model on the sagemaker::abalone dataset, analyze the hyperparameters, and make new predictions.

Tuning

The tuning interface is similar to the caret package. We'll

  1. choose a model

  2. define a hyperparameter grid

  3. set the training and validation data

Dataset

We'll build a regression model on the built-in abalone dataset, taken from the UCI Machine Learning Repository.

library(sagemaker)
library(rsample)
library(dplyr)
library(ggplot2)
library(tidyr)

sagemaker::abalone

The built-in hyperparameter tuning methods in AWS Sagemaker require a train/validation split. Cross-validation is not supported out of the box.

We can quickly split the data with rsample:

abalone_split <- rsample::initial_split(sagemaker::abalone)
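
By default, initial_split reserves 75% of the rows for training. A minimal sketch if we wanted a different ratio, using rsample's prop argument:

# not run: an 80/20 train/validation split instead of the default 75/25
abalone_split <- rsample::initial_split(sagemaker::abalone, prop = 0.8)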

The training data needs to be uploaded to an S3 bucket that AWS Sagemaker has read/write permissions on. For the typical AWS Sagemaker role, this could be any bucket with "sagemaker" in the name.

We'll use the sagemaker::write_s3 helper to upload tibbles or data.frames to S3 as csv files.

write_s3(analysis(abalone_split), s3(s3_bucket(), "abalone-train.csv"))
write_s3(assessment(abalone_split), s3(s3_bucket(), "abalone-test.csv"))

You can also set a default bucket for sagemaker::s3_bucket with options(sagemaker.default.bucket = "bucket_name").
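
For example, with a hypothetical bucket name:

options(sagemaker.default.bucket = "my-sagemaker-bucket")
s3_bucket()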

Then we'll save the paths to use in tuning:

split <- s3_split(
  s3_train = s3(s3_bucket(), "abalone-train.csv"),
  s3_validation = s3(s3_bucket(), "abalone-test.csv")
)

Hyperparameters

Now we'll define ranges to tune over:

ranges <- list(
  max_depth = sagemaker_integer(3, 20),
  colsample_bytree = sagemaker_continuous(0, 1),
  subsample = sagemaker_continuous(0, 1)
)

Training

Then we kick off the training jobs:

tune <- sagemaker_hyperparameter_tuner(
  sagemaker_xgb_estimator(), split, ranges, max_jobs = 10
)

Tuning takes some time to complete. If the tuning job has already run, we can re-attach to it by name instead:

tune <- sagemaker_attach_tuner("xgboost-191114-1954")
tune

Analysis

Tuning

We can get more details about the tuning jobs by looking at the logs:

logs <- sagemaker_tuning_job_logs(tune)
logs %>%
  glimpse()
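
For example, assuming the tuner is minimizing the objective (for this regression model, validation RMSE), we can sort the logs to find the best set of hyperparameters:

logs %>%
  arrange(final_objective_value) %>%
  select(final_objective_value, colsample_bytree:subsample) %>%
  slice(1)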

From here, we can dig deeper into the tuning results:

logs %>%
  select(final_objective_value, colsample_bytree:subsample) %>%
  pivot_longer(colsample_bytree:subsample) %>%
  ggplot(aes(value, final_objective_value)) +
  geom_point() +
  facet_wrap(~name, scales = "free_x")

Training

We can also look at the logs of an individual training job to track the difference between training and validation error. This might be useful for advanced model tuning.

Note that tune$model_name is the name of the best model found during training.

job_logs <- sagemaker_training_job_logs(tune$model_name)
job_logs
job_logs %>%
  pivot_longer(`train:rmse`:`validation:rmse`) %>%
  ggplot(aes(iteration, value, color = name)) +
  geom_line()
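
A widening gap between the training and validation curves is a sign of overfitting. As a quick sketch using the same columns, we can plot the gap directly:

job_logs %>%
  mutate(gap = `validation:rmse` - `train:rmse`) %>%
  ggplot(aes(iteration, gap)) +
  geom_line()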

Predictions

The AWS Sagemaker API supports two prediction modes: a real-time endpoint and batch inference.

Real-time

Real-time mode opens a persistent web endpoint for predictions. Deploying takes a few minutes.

sagemaker_deploy_endpoint(tune)

Then we can make new predictions on tibbles or data.frames using the standard predict generic:

pred <- predict(tune, sagemaker::abalone[1:100, -1])

# saved predictions included with the package, standing in
# for the live endpoint call above
pred <- sagemaker::abalone_pred
glimpse(pred)
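
As a quick sanity check, we can plot predictions against the observed outcome. This sketch assumes the first column of sagemaker::abalone is the outcome (it was dropped from the data passed to predict) and that pred holds a single column of predictions:

tibble(
  truth = sagemaker::abalone[[1]][1:100],
  estimate = pred[[1]]
) %>%
  ggplot(aes(truth, estimate)) +
  geom_point() +
  geom_abline(linetype = "dashed")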

Once deployed, the endpoint has sub-second latency.

Make sure to delete the endpoint when you are done to avoid charges.

sagemaker_delete_endpoint(tune)

Batch

You can also make batch predictions on data saved in S3. batch_predict will write the predictions as a csv to an S3 folder.
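
The input file must already exist in S3. For example, we could upload the same feature columns used for real-time predictions with write_s3 (the file name matches the s3_input below):

write_s3(sagemaker::abalone[1:100, -1], s3(s3_bucket(), "abalone-inference.csv"))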

s3_output_path <- batch_predict(
  tune, 
  s3_input = s3(s3_bucket(), "abalone-inference.csv"),
  s3_output = s3(s3_bucket(), "abalone_predictions")
)

We can use the sagemaker::read_s3 helper to read csv data from S3.

# example output saved from a previous batch prediction run
s3_output_path <- s3(
  "sagemaker-us-east-2-495577990003/tests/batch-pred-test-output",
  "batch-pred-test.csv.out"
)
read_s3(s3_output_path) %>%
  glimpse()
