# tests/testthat/test-estimator.R

# NOTE: This code has been adapted from the AWS SageMaker Python SDK unit tests:
# https://github.com/aws/sagemaker-python-sdk/blob/master/tests/unit/test_estimator.py

library(sagemaker.core)
library(sagemaker.common)

# Handle to the "sagemaker" logger from lgr; it is not referenced again in the
# visible part of this file.
lg = lgr::get_logger("sagemaker")

# ---- Scalar fixtures shared by the tests below ----

# Values handed to the dummy model class.
MODEL_DATA <- "s3://bucket/model.tar.gz"
MODEL_IMAGE <- "mi"
ENTRY_POINT <- "blah.py"

# Script fixture, resolved against the current working directory.
DATA_DIR <- file.path(getwd(), "data")
SCRIPT_NAME <- "dummy_script.py"
SCRIPT_PATH <- file.path(DATA_DIR, SCRIPT_NAME)

TIMESTAMP <- "2017-11-06-14:14:15.671"
TIME <- 1510006209.073025

# Estimator / session settings.
BUCKET_NAME <- "mybucket"
INSTANCE_COUNT <- 1
INSTANCE_TYPE <- "c4.4xlarge"
ACCELERATOR_TYPE <- "ml.eia.medium"
ROLE <- "DummyRole"
IMAGE_URI <- "fakeimage"
REGION <- "us-west-2"

# Regular expression: the image name followed by a timestamp-like suffix.
JOB_NAME <- paste0(IMAGE_URI, "-[0-9:.-]+")
TAGS <- list(list("Name" = "some-tag", "Value" = "value-for-tag"))
OUTPUT_PATH <- "s3://bucket/prefix"

# Git-related fixtures.
GIT_REPO <- "https://github.com/aws/sagemaker-python-sdk.git"
BRANCH <- "test-branch-git-config"
COMMIT <- "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
PRIVATE_GIT_REPO_SSH <- "git@github.com:testAccount/private-repo.git"
PRIVATE_GIT_REPO <- "https://github.com/testAccount/private-repo.git"
PRIVATE_BRANCH <- "test-branch"
PRIVATE_COMMIT <- "329bfcf884482002c05ff7f44f62599ebc9f445a"
CODECOMMIT_REPO <- "https://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/"
CODECOMMIT_REPO_SSH <- "ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/"
CODECOMMIT_BRANCH <- "master"
REPO_DIR <- "/tmp/repo_dir"

ENV_INPUT <- list(
  "env_key1" = "env_val1",
  "env_key2" = "env_val2",
  "env_key3" = "env_val3"
)

# Side effect at load time: sets AWS_REGION for the whole R process.
Sys.setenv("AWS_REGION" = REGION)

# ---- Structured fixtures (canned API responses and configs) ----

# Default value returned by the mocked describe_training_job calls.
DESCRIBE_TRAINING_JOB_RESULT <- list(
  "ModelArtifacts" = list("S3ModelArtifacts" = MODEL_DATA)
)

# Fuller training-job description; default input of training_job_description().
RETURNED_JOB_DESCRIPTION <- list(
  "AlgorithmSpecification" = list(
    "TrainingInputMode" = "File",
    "TrainingImage" = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other:1.0.4"
  ),
  "HyperParameters" = list(
    "sagemaker_submit_directory" = "s3://some/sourcedir.tar.gz",
    "checkpoint_path" = "s3://other/1508872349",
    "sagemaker_program" = "iris-dnn-classifier.py",
    "sagemaker_container_log_level" = "INFO",
    # The value itself contains the double quotes.
    "sagemaker_job_name" = '"neo"',
    "training_steps" = "100"
  ),
  "RoleArn" = "arn:aws:iam::366:role/SageMakerRole",
  "ResourceConfig" = list(
    "VolumeSizeInGB" = 30,
    "InstanceCount" = 1,
    "InstanceType" = "ml.c4.xlarge"
  ),
  "EnableNetworkIsolation" = FALSE,
  "StoppingCondition" = list("MaxRuntimeInSeconds" = 24 * 60 * 60),
  "TrainingJobName" = "neo",
  "TrainingJobStatus" = "Completed",
  "TrainingJobArn" = "arn:aws:sagemaker:us-west-2:336:training-job/neo",
  "OutputDataConfig" = list(
    "KmsKeyId" = "",
    "S3OutputPath" = "s3://place/output/neo"
  ),
  "TrainingJobOutput" = list("S3TrainingJobOutput" = "s3://here/output.tar.gz"),
  "EnableInterContainerTrafficEncryption" = FALSE
)

# Fixed container definition returned by DummyFrameworkModel.
MODEL_CONTAINER_DEF <- list(
  "Environment" = list(
    "SAGEMAKER_PROGRAM" = ENTRY_POINT,
    "SAGEMAKER_SUBMIT_DIRECTORY" = "s3://mybucket/mi-2017-10-10-14-14-15/sourcedir.tar.gz",
    "SAGEMAKER_CONTAINER_LOG_LEVEL" = "20",
    "SAGEMAKER_REGION" = REGION
  ),
  "Image" = MODEL_IMAGE,
  "ModelDataUrl" = MODEL_DATA
)

# Canned return values for the mocked endpoint and tag calls.
ENDPOINT_DESC <- list("EndpointConfigName" = "test-endpoint")

ENDPOINT_CONFIG_DESC <- list(
  "ProductionVariants" = list(
    list("ModelName" = "model-1"),
    list("ModelName" = "model-2")
  )
)

LIST_TAGS_RESULT <- list(
  "Tags" = list(list("Key" = "TagtestKey", "Value" = "TagtestValue"))
)

# Distribution configurations.
DISTRIBUTION_PS_ENABLED <- list("parameter_server" = list("enabled" = TRUE))
DISTRIBUTION_MPI_ENABLED <- list(
  "mpi" = list(
    "enabled" = TRUE,
    "custom_mpi_options" = "options",
    "processes_per_host" = 2
  )
)
DISTRIBUTION_SM_DDP_ENABLED <- list(
  "smdistributed" = list(
    "dataparallel" = list("enabled" = TRUE, "custom_mpi_options" = "options")
  )
)

# Minimal Framework subclass used as the estimator under test. It always
# reports IMAGE_URI as its training image and builds DummyFrameworkModel objects.
DummyFramework <- R6::R6Class(
  "DummyFramework",
  inherit = Framework,
  public = list(
    # Forward all arguments to the parent constructor, then tag the object
    # with a "_framework_name" attribute of "dummy".
    initialize = function(...) {
      super$initialize(...)
      attr(self, "_framework_name") <- "dummy"
    },

    # Always returns the fixed test image.
    training_image_uri = function() {
      IMAGE_URI
    },

    # Build a DummyFrameworkModel bound to this estimator's session.
    # `model_server_workers` and `model_dir` are accepted but not used here.
    create_model = function(role = NULL,
                            model_server_workers = NULL,
                            entry_point = NULL,
                            vpc_config_override = "VPC_CONFIG_DEFAULT",
                            enable_network_isolation = NULL,
                            model_dir = NULL,
                            ...) {
      # Fall back to the estimator's own setting when none is supplied.
      if (is.null(enable_network_isolation)) {
        enable_network_isolation <- self$enable_network_isolation()
      }
      DummyFrameworkModel$new(
        self$sagemaker_session,
        vpc_config = self$get_vpc_config(vpc_config_override),
        entry_point = entry_point,
        enable_network_isolation = enable_network_isolation,
        role = role,
        ...
      )
    }
  ),
  private = list(
    # Same as the parent implementation, with the "image_uri" entry removed
    # from the resulting parameters.
    .prepare_init_params_from_job_description = function(job_details,
                                                         model_channel_name = NULL) {
      params <- super$.prepare_init_params_from_job_description(
        job_details, model_channel_name
      )
      params[["image_uri"]] <- NULL
      params
    }
  ),
  lock_objects = FALSE
)

# Minimal FrameworkModel subclass with fixed model data, image and container
# definition.
DummyFrameworkModel <- R6::R6Class(
  "DummyFrameworkModel",
  inherit = FrameworkModel,
  public = list(
    # Initialise the parent with MODEL_DATA / MODEL_IMAGE; ENTRY_POINT is used
    # when no entry point is given.
    initialize = function(sagemaker_session,
                          entry_point = NULL,
                          role = ROLE,
                          ...) {
      script <- if (is.null(entry_point)) ENTRY_POINT else entry_point
      super$initialize(
        MODEL_DATA,
        MODEL_IMAGE,
        role,
        script,
        sagemaker_session = sagemaker_session,
        ...
      )
    },

    # No predictor is created for the dummy model.
    create_predictor = function(endpoint_name) {
      NULL
    },

    # Ignores its arguments and returns the fixed container definition.
    prepare_container_def = function(instance_type, accelerator_type = NULL) {
      MODEL_CONTAINER_DEF
    }
  ),
  lock_objects = FALSE
)

# Build a mocked Session object for the given region.
#
# The returned Mock carries a mocked paws session, a mocked `sagemaker` client
# and a mocked `s3` client, with the methods used by the tests registered via
# `.call_args()` (Mock is defined outside this file).
sagemaker_session <- function(region = REGION) {
  paws_session_mock <- Mock$new(
    name = "PawsSession",
    region_name = region
  )

  logs_client_mock <- Mock$new(name = "cloudwatchlogs")

  # client("cloudwatchlogs") yields the logs mock; any other name yields NULL
  # because switch() has no default branch.
  paws_session_mock$.call_args("client", side_effect = function(obj, ...) {
    switch(obj,
      "cloudwatchlogs" = logs_client_mock
    )
  })

  session_mock <- Mock$new(
    name = "Session",
    paws_session = paws_session_mock,
    paws_region_name = region,
    config = NULL,
    local_mode = FALSE,
    s3 = NULL
  )

  # Low-level SageMaker client with canned describe/list responses.
  sagemaker_client_mock <- Mock$new()
  sagemaker_client_mock$.call_args("describe_training_job", return_value = DESCRIBE_TRAINING_JOB_RESULT)
  sagemaker_client_mock$.call_args("describe_endpoint", return_value = ENDPOINT_DESC)
  sagemaker_client_mock$.call_args("describe_endpoint_config", return_value = ENDPOINT_CONFIG_DESC)
  sagemaker_client_mock$.call_args("list_tags", return_value = LIST_TAGS_RESULT)
  sagemaker_client_mock$.call_args("train")

  s3_client_mock <- Mock$new()
  s3_client_mock$.call_args("put_object")

  # Session-level methods exercised by the estimator.
  session_mock$.call_args("default_bucket", return_value = BUCKET_NAME)
  session_mock$.call_args("upload_data", return_value = OUTPUT_PATH)
  session_mock$.call_args("expand_role")
  session_mock$.call_args("train")
  session_mock$.call_args("logs_for_job")
  session_mock$.call_args("wait_for_job")
  session_mock$.call_args("describe_training_job", return_value = DESCRIBE_TRAINING_JOB_RESULT)
  session_mock$.call_args("update_training_job")
  session_mock$.call_args("wait_for_model_package")
  session_mock$.call_args("create_model_package_from_algorithm")
  session_mock$.call_args("create_model")
  session_mock$.call_args("endpoint_from_production_variants")
  session_mock$.call_args(
    "create_model_package_from_containers",
    return_value = list(ModelPackageArn = "dummy")
  )

  session_mock$sagemaker <- sagemaker_client_mock
  session_mock$s3 <- s3_client_mock
  session_mock
}

# Build a mocked session whose describe_training_job calls (on both the session
# and its `sagemaker` client) return `returned_job_description` with the
# entries of `ll` merged in via modifyList().
training_job_description <- function(returned_job_description = RETURNED_JOB_DESCRIPTION, ll = list()) {
  session_mock <- sagemaker_session()
  description <- modifyList(returned_job_description, ll)
  session_mock$sagemaker$.call_args("describe_training_job", return_value = description)
  session_mock$.call_args("describe_training_job", return_value = description)
  session_mock
}

test_that("test_framework_all_init_args", {
  # Construct the estimator with every optional argument populated, start a
  # job directly through the private `.start_new`, and compare what the mocked
  # `train` reports against the full expected argument list.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    "my_script.py",
    role = "DummyRole",
    instance_count = 3,
    instance_type = "ml.m4.xlarge",
    sagemaker_session = session,
    volume_size = 123,
    volume_kms_key = "volumekms",
    max_run = 456,
    input_mode = "inputmode",
    output_path = "outputpath",
    output_kms_key = "outputkms",
    base_job_name = "basejobname",
    tags = list(list("foo" = "bar")),
    subnets = c("123", "456"),
    security_group_ids = c("789", "012"),
    metric_definitions = list(list("Name" = "validation-rmse", "Regex" = "validation-rmse=(\\d+)")),
    encrypt_inter_container_traffic = TRUE,
    checkpoint_s3_uri = "s3://bucket/checkpoint",
    checkpoint_local_path = "file://local/checkpoint",
    enable_sagemaker_metrics = TRUE,
    enable_network_isolation = TRUE,
    environment = ENV_INPUT,
    max_retry_attempts = 2
  )

  estimator$.__enclos_env__$private$.start_new("s3://mydata", NULL)

  # Presumably the arguments recorded for the mocked train call (Mock semantics
  # are defined outside this file).
  recorded <- session$train(..return_value = TRUE)

  training_channel <- list(
    "DataSource" = list(
      "S3DataSource" = list(
        "S3DataType" = "S3Prefix",
        "S3Uri" = "s3://mydata",
        "S3DataDistributionType" = "FullyReplicated"
      )
    ),
    "ChannelName" = "training"
  )
  expected <- list(
    "input_config" = list(training_channel),
    "role" = session$expand_role(),
    "output_config" = list("S3OutputPath" = "outputpath", "KmsKeyId" = "outputkms"),
    "resource_config" = list(
      "InstanceCount" = 3,
      "InstanceType" = "ml.m4.xlarge",
      "VolumeSizeInGB" = 123,
      "VolumeKmsKeyId" = "volumekms"
    ),
    "stop_condition" = list("MaxRuntimeInSeconds" = 456),
    "vpc_config" = list("Subnets" = c("123", "456"), "SecurityGroupIds" = c("789", "012")),
    "input_mode" = "inputmode",
    "hyperparameters" = list(),
    "tags" = list(list("foo" = "bar")),
    "metric_definitions" = list(list("Name" = "validation-rmse", "Regex" = "validation-rmse=(\\d+)")),
    "environment" = list("env_key1" = "env_val1", "env_key2" = "env_val2", "env_key3" = "env_val3"),
    "enable_network_isolation" = TRUE,
    "retry_strategy" = list("MaximumRetryAttempts" = 2),
    "encrypt_inter_container_traffic" = TRUE,
    "image_uri" = "fakeimage",
    "checkpoint_s3_uri" = "s3://bucket/checkpoint",
    "checkpoint_local_path" = "file://local/checkpoint",
    "enable_sagemaker_metrics" = TRUE
  )

  expect_equal(recorded, expected)
})

test_that("test_framework_with_debugger_and_built_in_rule", {
  # A built-in debugger rule with custom parameters and its own collection,
  # combined with an explicit hook config: the train arguments must carry the
  # rule parameters, the merged hook config and a default profiler config.
  debugger_built_in_rule_with_custom_args = Rule$new()$sagemaker(
    base_config=sagemaker.debugger::stalled_training_rule(),
    rule_parameters=list("threshold"="120", "stop_training_on_fire"="True"),
    collections_to_save=list(
      CollectionConfig$new(
        name="losses", parameters=list("train.save_interval"="50", "eval.save_interval"="10")
      )
    )
  )
  sms = sagemaker_session()
  # No trailing comma after the last argument: an empty argument would be
  # forwarded through `...` into the constructor chain.
  f = DummyFramework$new(
    entry_point=SCRIPT_PATH,
    role=ROLE,
    sagemaker_session=sms,
    instance_count=INSTANCE_COUNT,
    instance_type=INSTANCE_TYPE,
    rules=list(debugger_built_in_rule_with_custom_args),
    debugger_hook_config=DebuggerHookConfig$new(s3_output_path="s3://output")
  )
  f$fit("s3://mydata")
  # Presumably the arguments recorded for the mocked train call (Mock semantics
  # are defined outside this file).
  args = sms$train(..return_value = TRUE)
  expect_equal(args[["debugger_rule_configs"]][[1]][["RuleParameters"]], list(
    "rule_to_invoke"="StalledTrainingRule",
    "threshold"="120",
    "stop_training_on_fire"="True"
    )
  )
  expect_equal(args[["debugger_hook_config"]], list(
    "S3OutputPath"="s3://output",
    "CollectionConfigurations"=list(
      list(
        "CollectionName"="losses",
        "CollectionParameters"=list("train.save_interval"="50", "eval.save_interval"="10")
        )
      )
    )
  )
  expect_equal(args[["profiler_config"]], list(
    "S3OutputPath"=sprintf("s3://%s/", BUCKET_NAME)
    )
  )
})

test_that("test_framework_with_debugger_and_custom_rule", {
  # A custom debugger rule plus an explicit hook config with one collection.
  weights_hook <- DebuggerHookConfig$new(
    s3_output_path = "s3://output",
    collection_configs = list(CollectionConfig$new(name = "weights"))
  )
  custom_rule <- Rule$new()$custom(
    name = "CustomRule",
    image_uri = "RuleImageUri",
    instance_type = INSTANCE_TYPE,
    volume_size_in_gb = 5,
    source = "path/to/my_custom_rule.py",
    rule_to_invoke = "CustomRule",
    other_trials_s3_input_paths = c("s3://path/trial1", "s3://path/trial2"),
    rule_parameters = list("threshold" = "120")
  )

  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    rules = list(custom_rule),
    debugger_hook_config = weights_hook
  )
  estimator$fit("s3://mydata")

  train_args <- session$train(..return_value = TRUE)

  # The expected source URI is whatever the mocked upload_data returns.
  expected_rule <- list(
    "RuleConfigurationName" = "CustomRule",
    "RuleEvaluatorImage" = "RuleImageUri",
    "InstanceType" = INSTANCE_TYPE,
    "VolumeSizeInGB" = 5,
    "RuleParameters" = list(
      "other_trial_0" = "s3://path/trial1",
      "other_trial_1" = "s3://path/trial2",
      "source_s3_uri" = session$upload_data(),
      "rule_to_invoke" = "CustomRule",
      "threshold" = "120"
    )
  )
  expect_equal(train_args[["debugger_rule_configs"]], list(expected_rule))

  expect_equal(
    train_args[["debugger_hook_config"]],
    list(
      "S3OutputPath" = "s3://output",
      "CollectionConfigurations" = list(list("CollectionName" = "weights"))
    )
  )
})

test_that("test_framework_with_only_debugger_rule", {
  # Only a built-in rule is supplied; the hook config is expected to default
  # to the session bucket with no collections.
  session <- sagemaker_session()
  stalled_rule <- Rule$new()$sagemaker(sagemaker.debugger::stalled_training_rule())
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    rules = list(stalled_rule)
  )
  estimator$fit("s3://mydata")

  train_args <- session$train(..return_value = TRUE)

  expect_equal(
    train_args[["debugger_rule_configs"]][[1]][["RuleParameters"]],
    list("rule_to_invoke" = "StalledTrainingRule")
  )
  expect_equal(
    train_args[["debugger_hook_config"]],
    list(
      "S3OutputPath" = sprintf("s3://%s/", BUCKET_NAME),
      "CollectionConfigurations" = list()
    )
  )
})

test_that("test_framework_with_debugger_rule_and_single_action", {
  # A built-in rule carrying a single StopTraining action: the serialized
  # action must appear in the rule parameters, and the action's job prefix
  # must equal the estimator's current job name.
  stop_training_action <- sagemaker.debugger::StopTraining$new()
  session <- sagemaker_session()
  rule_with_action <- Rule$new()$sagemaker(
    sagemaker.debugger::stalled_training_rule(),
    actions = stop_training_action
  )
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    rules = list(rule_with_action)
  )
  estimator$fit("s3://mydata")

  train_args <- session$train(..return_value = TRUE)

  expect_equal(
    train_args[["debugger_rule_configs"]][[1]][["RuleParameters"]],
    list(
      "rule_to_invoke" = "StalledTrainingRule",
      "action_json" = stop_training_action$serialize()
    )
  )
  expect_equal(
    stop_training_action$action_parameters[["training_job_prefix"]],
    estimator$.current_job_name
  )
  expect_equal(
    train_args[["debugger_hook_config"]],
    list(
      "S3OutputPath" = sprintf("s3://%s/", BUCKET_NAME),
      "CollectionConfigurations" = list()
    )
  )
})

test_that("test_framework_with_debugger_rule_and_multiple_actions", {
  # A built-in rule carrying an ActionList (stop training, email, SMS): the
  # serialized list must appear in the rule parameters.
  actions <- sagemaker.debugger::ActionList$new(
    sagemaker.debugger::StopTraining$new(),
    sagemaker.debugger::Email$new("abc@abc.com"),
    sagemaker.debugger::SMS$new("+1234567890")
  )
  session <- sagemaker_session()
  rule_with_actions <- Rule$new()$sagemaker(
    sagemaker.debugger::stalled_training_rule(),
    actions = actions
  )
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    rules = list(rule_with_actions)
  )
  estimator$fit("s3://mydata")

  train_args <- session$train(..return_value = TRUE)

  expect_equal(
    train_args[["debugger_rule_configs"]][[1]][["RuleParameters"]],
    list(
      "rule_to_invoke" = "StalledTrainingRule",
      "action_json" = actions$serialize()
    )
  )
  expect_equal(
    train_args[["debugger_hook_config"]],
    list(
      "S3OutputPath" = sprintf("s3://%s/", BUCKET_NAME),
      "CollectionConfigurations" = list()
    )
  )
})

test_that("test_framework_with_only_debugger_hook_config", {
  # Only a hook config is supplied: it must be passed through, and no debugger
  # rule configs may appear in the train arguments.
  weights_hook <- DebuggerHookConfig$new(
    s3_output_path = "s3://output",
    collection_configs = list(CollectionConfig$new(name = "weights"))
  )
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    debugger_hook_config = weights_hook
  )
  estimator$fit("s3://mydata")

  train_args <- session$train(..return_value = TRUE)

  expect_equal(
    train_args[["debugger_hook_config"]],
    list(
      "S3OutputPath" = "s3://output",
      "CollectionConfigurations" = list(list("CollectionName" = "weights"))
    )
  )
  expect_false("debugger_rule_configs" %in% names(train_args))
})

test_that("test_framework_without_debugger_and_profiler", {
  # Nothing debugger- or profiler-related is supplied: expect the default hook
  # config, default profiler config and a default ProfilerReport rule.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")

  train_args <- session$train(..return_value = TRUE)
  default_s3_path <- sprintf("s3://%s/", BUCKET_NAME)

  expect_equal(
    train_args[["debugger_hook_config"]],
    list(
      "S3OutputPath" = default_s3_path,
      "CollectionConfigurations" = list()
    )
  )
  expect_false("debugger_rule_configs" %in% names(train_args))
  expect_equal(
    train_args[["profiler_config"]],
    list("S3OutputPath" = default_s3_path)
  )

  default_rule <- train_args[["profiler_rule_configs"]][[1]]
  expect_true(grepl("ProfilerReport-[0-9]+", default_rule[["RuleConfigurationName"]]))
  expect_equal(
    default_rule[["RuleEvaluatorImage"]],
    "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest"
  )
  expect_equal(
    default_rule[["RuleParameters"]],
    list("rule_to_invoke" = "ProfilerReport")
  )
})

test_that("test_framework_with_debugger_and_profiler_rules", {
  # One debugger rule and two profiler rules (built-in and custom) are passed
  # together; they must be split into debugger and profiler rule configs.
  rules_image <- "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest"

  debugger_rule <- Rule$new()$sagemaker(
    base_config = sagemaker.debugger::stalled_training_rule(),
    rule_parameters = list("threshold" = "120", "stop_training_on_fire" = "True"),
    collections_to_save = list(
      CollectionConfig$new(
        name = "losses",
        parameters = list("train.save_interval" = "50", "eval.save_interval" = "10")
      )
    )
  )
  profiler_report_rule <- ProfilerRule$new()$sagemaker(
    base_config = sagemaker.debugger::ProfilerReport$new(CPUBottleneck_threshold = 90),
    name = "CustomProfilerReportRule"
  )
  profiler_custom_rule <- ProfilerRule$new()$custom(
    name = "CustomProfilerRule",
    image_uri = "RuleImageUri",
    instance_type = INSTANCE_TYPE,
    volume_size_in_gb = 5,
    source = "path/to/my_custom_rule.py",
    rule_to_invoke = "CustomProfilerRule",
    rule_parameters = list("threshold" = "10")
  )

  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    rules = list(debugger_rule, profiler_report_rule, profiler_custom_rule)
  )
  estimator$fit("s3://mydata")

  train_args <- session$train(..return_value = TRUE)

  expect_equal(
    train_args[["debugger_rule_configs"]],
    list(
      list(
        "RuleConfigurationName" = "StalledTrainingRule",
        "RuleEvaluatorImage" = rules_image,
        "RuleParameters" = list(
          "rule_to_invoke" = "StalledTrainingRule",
          "threshold" = "120",
          "stop_training_on_fire" = "True"
        )
      )
    )
  )
  expect_equal(
    train_args[["debugger_hook_config"]],
    list(
      "S3OutputPath" = "s3://mybucket/",
      "CollectionConfigurations" = list(
        list(
          "CollectionName" = "losses",
          "CollectionParameters" = list("train.save_interval" = "50", "eval.save_interval" = "10")
        )
      )
    )
  )
  expect_equal(
    train_args[["profiler_config"]],
    list("S3OutputPath" = sprintf("s3://%s/", BUCKET_NAME))
  )
  expect_equal(
    train_args[["profiler_rule_configs"]],
    list(
      list(
        "RuleConfigurationName" = "CustomProfilerReportRule",
        "RuleEvaluatorImage" = rules_image,
        "RuleParameters" = list("CPUBottleneck_threshold" = "90", "rule_to_invoke" = "ProfilerReport")
      ),
      list(
        "RuleConfigurationName" = "CustomProfilerRule",
        "RuleEvaluatorImage" = "RuleImageUri",
        "InstanceType" = "c4.4xlarge",
        "VolumeSizeInGB" = 5,
        "RuleParameters" = list(
          "source_s3_uri" = OUTPUT_PATH,
          "rule_to_invoke" = "CustomProfilerRule",
          "threshold" = "10"
        )
      )
    )
  )
})

test_that("test_framework_with_only_profiler_rule_specified", {
  # Only a built-in CPUBottleneck profiler rule (gpu_threshold overridden) is
  # supplied; the rule parameters are expected as strings.
  session <- sagemaker_session()
  cpu_rule <- ProfilerRule$new()$sagemaker(
    sagemaker.debugger::CPUBottleneck$new(gpu_threshold = 60)
  )
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    rules = list(cpu_rule)
  )
  estimator$fit("s3://mydata")

  train_args <- session$train(..return_value = TRUE)

  expect_equal(
    train_args[["profiler_config"]],
    list("S3OutputPath" = sprintf("s3://%s/", BUCKET_NAME))
  )
  expect_equal(
    train_args[["profiler_rule_configs"]],
    list(
      list(
        "RuleConfigurationName" = "CPUBottleneck",
        "RuleEvaluatorImage" = "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest",
        "RuleParameters" = list(
          "threshold" = "50",
          "gpu_threshold" = "60",
          "cpu_threshold" = "90",
          "patience" = "1000",
          "scan_interval_us" = "60000000",
          "rule_to_invoke" = "CPUBottleneck"
        )
      )
    )
  )
})

# NOTE(review): this test has the same name and the same body as the test
# immediately preceding it in this file; it looks like an accidental copy and
# is a candidate for removal.
test_that("test_framework_with_only_profiler_rule_specified", {
  mock_session <- sagemaker_session()
  bottleneck_rule <- ProfilerRule$new()$sagemaker(
    sagemaker.debugger::CPUBottleneck$new(gpu_threshold = 60)
  )
  framework <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = mock_session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    rules = list(bottleneck_rule)
  )
  framework$fit("s3://mydata")

  call_args <- mock_session$train(..return_value = TRUE)

  expect_equal(
    call_args[["profiler_config"]],
    list("S3OutputPath" = sprintf("s3://%s/", BUCKET_NAME))
  )

  expected_rule <- list(
    "RuleConfigurationName" = "CPUBottleneck",
    "RuleEvaluatorImage" = "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest",
    "RuleParameters" = list(
      "threshold" = "50",
      "gpu_threshold" = "60",
      "cpu_threshold" = "90",
      "patience" = "1000",
      "scan_interval_us" = "60000000",
      "rule_to_invoke" = "CPUBottleneck"
    )
  )
  expect_equal(call_args[["profiler_rule_configs"]], list(expected_rule))
})

test_that("test_framework_with_profiler_config_without_s3_output_path", {
  # A profiler config with only a monitoring interval: the S3 output path is
  # expected to default to the session bucket.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    profiler_config = ProfilerConfig$new(system_monitor_interval_millis = 1000)
  )
  estimator$fit("s3://mydata")

  train_args <- session$train(..return_value = TRUE)

  expect_equal(
    train_args[["profiler_config"]],
    list(
      "S3OutputPath" = sprintf("s3://%s/", BUCKET_NAME),
      "ProfilingIntervalInMilliseconds" = 1000
    )
  )

  default_rule <- train_args[["profiler_rule_configs"]][[1]]
  expect_true(grepl("ProfilerReport-[0-9]+", default_rule[["RuleConfigurationName"]]))
  expect_equal(
    default_rule[["RuleEvaluatorImage"]],
    "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest"
  )
  expect_equal(
    default_rule[["RuleParameters"]],
    list("rule_to_invoke" = "ProfilerReport")
  )
})

test_that("test_framework_with_no_default_profiler_in_unsupported_region", {
  # In a region listed as unsupported for the profiler, fitting with default
  # settings must not send any profiler config or profiler rules.
  #
  # Each unsupported region is exercised on its own: a session represents a
  # single region, so the whole vector must not be passed as one region value.
  for (region in sagemaker.core:::PROFILER_UNSUPPORTED_REGIONS) {
    sms = sagemaker_session(region)
    sms$.call_args("train", list(TrainingJobArn = NULL))
    f = DummyFramework$new(
      entry_point=SCRIPT_PATH,
      role=ROLE,
      sagemaker_session=sms,
      instance_count=INSTANCE_COUNT,
      instance_type=INSTANCE_TYPE
    )
    f$fit("s3://mydata")
    # Presumably the arguments recorded for the mocked train call (Mock
    # semantics are defined outside this file).
    args = sms$train(..return_value = TRUE)
    # `info` identifies the failing region in the test report.
    expect_null(args[["profiler_config"]], info = region)
    expect_null(args[["profiler_rule_configs"]], info = region)
  }
})

test_that("test_framework_with_profiler_config_and_profiler_disabled", {
  # Supplying a profiler config while disabling the profiler must make fit()
  # fail with a RuntimeError.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    profiler_config = ProfilerConfig$new(),
    disable_profiler = TRUE
  )

  expected_message <- "profiler_config cannot be set when disable_profiler is True."
  expect_error(
    estimator$fit("s3://mydata"),
    expected_message,
    class = "RuntimeError"
  )
})

test_that("test_framework_with_profiler_rule_and_profiler_disabled", {
  # Supplying a profiler rule while disabling the profiler must make fit()
  # fail with a RuntimeError.
  custom_rule <- ProfilerRule$new()$custom(
    name = "CustomProfilerRule",
    image_uri = "RuleImageUri",
    instance_type = INSTANCE_TYPE,
    volume_size_in_gb = 5
  )
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    rules = list(custom_rule),
    disable_profiler = TRUE
  )

  expected_message <- "ProfilerRule cannot be set when disable_profiler is True."
  expect_error(
    estimator$fit("s3://mydata"),
    expected_message,
    class = "RuntimeError"
  )
})

test_that("test_framework_with_enabling_default_profiling_when_profiler_is_already_enabled", {
  # The described job reports ProfilingStatus "Enabled", so enabling default
  # profiling again must fail with a ValueError.
  session <- sagemaker_session()
  enabled_description <- modifyList(
    DESCRIBE_TRAINING_JOB_RESULT,
    list("ProfilingStatus" = "Enabled")
  )
  session$.call_args("describe_training_job", return_value = enabled_description)

  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")

  expected_message <- paste0(
    "Debugger monitoring is already enabled. To update the profiler_config parameter ",
    "and the Debugger profiling rules, please use the update_profiler function."
  )
  expect_error(
    estimator$enable_default_profiling(),
    expected_message,
    class = "ValueError"
  )
})

test_that("test_framework_with_enabling_default_profiling", {
  # The job was fitted with the profiler disabled; enabling default profiling
  # must update the job with the default profiler config and report rule.
  session <- sagemaker_session()
  disabled_description <- modifyList(
    DESCRIBE_TRAINING_JOB_RESULT,
    list("ProfilingStatus" = "Disabled")
  )
  session$.call_args("describe_training_job", return_value = disabled_description)

  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    disable_profiler = TRUE
  )
  estimator$fit("s3://mydata")
  estimator$enable_default_profiling()

  update_args <- session$update_training_job(..return_value = TRUE)

  expect_equal(
    update_args[["profiler_config"]],
    list("S3OutputPath" = sprintf("s3://%s/", BUCKET_NAME))
  )

  default_rule <- update_args[["profiler_rule_configs"]][[1]]
  expect_true(grepl("ProfilerReport-[0-9]+", default_rule[["RuleConfigurationName"]]))
  expect_equal(
    default_rule[["RuleEvaluatorImage"]],
    "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest"
  )
  expect_equal(
    default_rule[["RuleParameters"]],
    list("rule_to_invoke" = "ProfilerReport")
  )
})

test_that("test_framework_with_enabling_default_profiling_with_existed_s3_output_path", {
  # The described job already has a profiler S3 output path; enabling default
  # profiling must reuse that path, and only the path is expected in the
  # resulting profiler config.
  session <- sagemaker_session()
  existing_profiler <- list(
    "S3OutputPath" = "s3://custom/",
    "ProfilingIntervalInMilliseconds" = 1000
  )
  described_job <- modifyList(
    DESCRIBE_TRAINING_JOB_RESULT,
    list(
      "ProfilingStatus" = "Disabled",
      "ProfilerConfig" = existing_profiler
    )
  )
  session$.call_args("describe_training_job", return_value = described_job)

  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    disable_profiler = TRUE
  )
  estimator$fit("s3://mydata")
  estimator$enable_default_profiling()

  update_args <- session$update_training_job(..return_value = TRUE)

  expect_equal(
    update_args[["profiler_config"]],
    list("S3OutputPath" = "s3://custom/")
  )

  default_rule <- update_args[["profiler_rule_configs"]][[1]]
  expect_true(grepl("ProfilerReport-[0-9]+", default_rule[["RuleConfigurationName"]]))
  expect_equal(
    default_rule[["RuleEvaluatorImage"]],
    "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest"
  )
  expect_equal(
    default_rule[["RuleParameters"]],
    list("rule_to_invoke" = "ProfilerReport")
  )
})

test_that("test_framework_with_disabling_profiling_when_profiler_is_already_disabled", {
  # The described job reports ProfilingStatus "Disabled", so disabling the
  # profiler again must fail with a ValueError.
  session <- sagemaker_session()
  disabled_description <- modifyList(
    DESCRIBE_TRAINING_JOB_RESULT,
    list("ProfilingStatus" = "Disabled")
  )
  session$.call_args("describe_training_job", return_value = disabled_description)

  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")

  expect_error(
    estimator$disable_profiling(),
    "Profiler is already disabled.",
    class = "ValueError"
  )
})

test_that("test_framework_with_disabling_profiling", {
  # With ProfilingStatus "Enabled", disabling the profiler must update the job
  # with a profiler config that only sets DisableProfiler.
  session <- sagemaker_session()
  enabled_description <- modifyList(
    DESCRIBE_TRAINING_JOB_RESULT,
    list("ProfilingStatus" = "Enabled")
  )
  session$.call_args("describe_training_job", return_value = enabled_description)

  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")
  estimator$disable_profiling()

  update_args <- session$update_training_job(..return_value = TRUE)
  expect_equal(
    update_args[["profiler_config"]],
    list("DisableProfiler" = TRUE)
  )
})

test_that("test_framework_with_update_profiler_when_no_training_job", {
  # update_profiler() is called without a preceding fit(), which must raise a
  # ValueError.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )

  expected_message <- "Estimator is not associated with a training job"
  expect_error(
    estimator$update_profiler(system_monitor_interval_millis = 1000),
    expected_message,
    class = "ValueError"
  )
})

test_that("test_framework_with_update_profiler_without_any_parameter", {
  # Calling update_profiler() with no arguments after fit() must raise a
  # ValueError.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")

  expected_message <- "Please provide profiler config or profiler rule to be updated."
  expect_error(
    estimator$update_profiler(),
    expected_message,
    class = "ValueError"
  )
})

test_that("test_framework_with_update_profiler_with_debugger_rule", {
  # Passing a debugger Rule (rather than a ProfilerRule) to update_profiler()
  # must raise a ValueError.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")

  expect_error(
    estimator$update_profiler(
      rules = list(Rule$new()$sagemaker(sagemaker.debugger::stalled_training_rule()))
    ),
    "Please provide ProfilerRule to be updated.",
    class = "ValueError"
  )
})

test_that("test_framework_with_update_profiler_config", {
  # Updating only the monitoring interval must send just that setting and no
  # profiler rule configs.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")
  estimator$update_profiler(system_monitor_interval_millis = 1000)

  update_args <- session$update_training_job(..return_value = TRUE)

  expect_equal(
    update_args[["profiler_config"]],
    list("ProfilingIntervalInMilliseconds" = 1000)
  )
  expect_false("profiler_rule_configs" %in% names(update_args))
})

test_that("test_framework_with_update_profiler_report_rule", {
  # Updating with a renamed ProfilerReport rule must send only the rule config
  # and no profiler config.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")

  report_rule <- ProfilerRule$new()$sagemaker(
    sagemaker.debugger::ProfilerReport$new(),
    name = "CustomProfilerReportRule"
  )
  estimator$update_profiler(rules = list(report_rule))

  update_args <- session$update_training_job(..return_value = TRUE)

  expect_equal(
    update_args[["profiler_rule_configs"]],
    list(
      list(
        "RuleConfigurationName" = "CustomProfilerReportRule",
        "RuleEvaluatorImage" = "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest",
        "RuleParameters" = list("rule_to_invoke" = "ProfilerReport")
      )
    )
  )
  expect_false("profiler_config" %in% names(update_args))
})

test_that("test_framework_with_disable_framework_metrics", {
  # Disabling framework metrics must send an empty ProfilingParameters entry
  # and no profiler rule configs.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")
  estimator$update_profiler(disable_framework_metrics = TRUE)

  update_args <- session$update_training_job(..return_value = TRUE)

  expect_equal(
    update_args[["profiler_config"]],
    list("ProfilingParameters" = list())
  )
  expect_false("profiler_rule_configs" %in% names(update_args))
})

test_that("test_framework_with_disable_framework_metrics_and_update_system_metrics", {
  # A new monitor interval combined with disabled framework metrics should
  # put both settings in the profiler config and send no rule configs.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")
  estimator$update_profiler(
    system_monitor_interval_millis = 1000,
    disable_framework_metrics = TRUE
  )

  update_args <- session$update_training_job(..return_value = TRUE)
  expected_config <- list(
    "ProfilingIntervalInMilliseconds" = 1000,
    "ProfilingParameters" = list()
  )
  expect_equal(update_args[["profiler_config"]], expected_config)
  expect_false("profiler_rule_configs" %in% names(update_args))
})

test_that("test_framework_with_disable_framework_metrics_and_update_framework_params", {
  # Supplying framework profile params while also disabling framework metrics
  # is contradictory and should raise a ValueError.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")

  expect_error(
    estimator$update_profiler(
      framework_profile_params = FrameworkProfile$new(),
      disable_framework_metrics = TRUE
    ),
    "framework_profile_params cannot be set when disable_framework_metrics is True",
    class = "ValueError"
  )
})

test_that("test_framework_with_update_profiler_config_and_profiler_rule", {
  # Updating a custom profiler rule together with the monitor interval should
  # send both a profiler config and the matching rule config.
  custom_rule <- ProfilerRule$new()$custom(
    name = "CustomProfilerRule",
    image_uri = "RuleImageUri",
    instance_type = INSTANCE_TYPE,
    volume_size_in_gb = 5
  )
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")
  estimator$update_profiler(
    rules = list(custom_rule),
    system_monitor_interval_millis = 1000
  )

  update_args <- session$update_training_job(..return_value = TRUE)
  expect_equal(
    update_args[["profiler_config"]],
    list("ProfilingIntervalInMilliseconds" = 1000)
  )
  expected_rule <- list(
    "RuleConfigurationName" = "CustomProfilerRule",
    "RuleEvaluatorImage" = "RuleImageUri",
    "InstanceType" = "c4.4xlarge",
    "VolumeSizeInGB" = 5
  )
  expect_equal(update_args[["profiler_rule_configs"]], list(expected_rule))
})

test_that("test_training_job_with_rule_job_summary", {
  # rule_job_summary() should return the debugger rule statuses followed by
  # the profiler rule statuses from the described training job.
  debugger_status <- list(
    "RuleConfigurationName" = "debugger_rule",
    "RuleEvaluationJobArn" = "debugger_rule_job_arn",
    "RuleEvaluationStatus" = "InProgress"
  )
  profiler_status_1 <- list(
    "RuleConfigurationName" = "profiler_rule_1",
    "RuleEvaluationJobArn" = "profiler_rule_job_arn_1",
    "RuleEvaluationStatus" = "InProgress"
  )
  profiler_status_2 <- list(
    "RuleConfigurationName" = "profiler_rule_2",
    "RuleEvaluationJobArn" = "profiler_rule_job_arn_2",
    "RuleEvaluationStatus" = "ERROR"
  )
  rule_statuses <- list(
    "DebugRuleEvaluationStatuses" = list(debugger_status),
    "ProfilerRuleEvaluationStatuses" = list(profiler_status_1, profiler_status_2)
  )

  session <- sagemaker_session()
  session$.call_args(
    "describe_training_job",
    return_value = modifyList(DESCRIBE_TRAINING_JOB_RESULT, rule_statuses)
  )
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit("s3://mydata")

  summary_result <- estimator$rule_job_summary()
  expect_equal(
    summary_result,
    list(debugger_status, profiler_status_1, profiler_status_2)
  )
})

test_that("test_framework_with_spot_and_checkpoints", {
  # The train call should receive the resource, network, spot-instance and
  # checkpoint settings that were given to the constructor.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    "my_script.py",
    role = "DummyRole",
    instance_count = 3,
    instance_type = "ml.m4.xlarge",
    sagemaker_session = session,
    volume_size = 123,
    volume_kms_key = "volumekms",
    max_run = 456,
    input_mode = "inputmode",
    output_path = "outputpath",
    output_kms_key = "outputkms",
    base_job_name = "basejobname",
    tags = list(list("foo" = "bar")),
    subnets = list("123", "456"),
    security_group_ids = list("789", "012"),
    metric_definitions = list(
      list("Name" = "validation-rmse", "Regex" = "validation-rmse=(\\d+)")
    ),
    encrypt_inter_container_traffic = TRUE,
    use_spot_instances = TRUE,
    max_wait = 500,
    checkpoint_s3_uri = "s3://mybucket/checkpoints/",
    checkpoint_local_path = "/tmp/checkpoints"
  )
  # Invoke the private job-start method directly rather than going via fit().
  estimator$.__enclos_env__$private$.start_new("s3://mydata", NULL)
  train_args <- session$train(..return_value = TRUE)

  expected_channel <- list(
    "DataSource" = list(
      "S3DataSource" = list(
        "S3DataType" = "S3Prefix",
        "S3Uri" = "s3://mydata",
        "S3DataDistributionType" = "FullyReplicated"
      )
    ),
    "ChannelName" = "training"
  )
  expected_args <- list(
    "input_config" = list(expected_channel),
    "role" = session$expand_role(),
    "output_config" = list("S3OutputPath" = "outputpath", "KmsKeyId" = "outputkms"),
    "resource_config" = list(
      "InstanceCount" = 3,
      "InstanceType" = "ml.m4.xlarge",
      "VolumeSizeInGB" = 123,
      "VolumeKmsKeyId" = "volumekms"
    ),
    "stop_condition" = list("MaxRuntimeInSeconds" = 456, "MaxWaitTimeInSeconds" = 500),
    "vpc_config" = list(
      "Subnets" = list("123", "456"),
      "SecurityGroupIds" = list("789", "012")
    ),
    "input_mode" = "inputmode",
    "hyperparameters" = list(),
    "tags" = list(list("foo" = "bar")),
    "metric_definitions" = list(
      list("Name" = "validation-rmse", "Regex" = "validation-rmse=(\\d+)")
    ),
    "encrypt_inter_container_traffic" = TRUE,
    "image_uri" = "fakeimage",
    "use_spot_instances" = TRUE,
    "checkpoint_s3_uri" = "s3://mybucket/checkpoints/",
    "checkpoint_local_path" = "/tmp/checkpoints"
  )
  expect_equal(train_args, expected_args)
})

test_that("test_framework_init_s3_entry_point_invalid", {
  # An S3 URI is not accepted as an entry point; construction should fail
  # with a ValueError.
  session <- sagemaker_session()
  expect_error(
    DummyFramework$new(
      "s3://remote-script-because-im-mistaken",
      role = ROLE,
      sagemaker_session = session,
      instance_count = INSTANCE_COUNT,
      instance_type = INSTANCE_TYPE
    ),
    "Must be a path to a local file",
    class = "ValueError"
  )
})

test_that("test_sagemaker_s3_uri_invalid", {
  # Training input that is neither an S3 nor a file URI should be rejected.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  expect_error(
    estimator$fit("thisdoesntstartwiths3"),
    "must be a valid S3 or FILE URI",
    class = "ValueError"
  )
})

test_that("test_sagemaker_model_s3_uri_invalid", {
  # A model_uri that is neither an S3 nor a file URI should make fit() fail.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    model_uri = "thisdoesntstartwiths3either.tar.gz"
  )
  expect_error(
    estimator$fit("s3://mydata"),
    "must be a valid S3 or FILE URI",
    class = "ValueError"
  )
})

test_that("test_sagemaker_model_file_uri_invalid", {
  # A file:// model_uri with a non-local session should make fit() fail.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    model_uri = "file://notins3.tar.gz"
  )
  expect_error(
    estimator$fit("s3://mydata"),
    "File URIs are supported in local mode only",
    class = "ValueError"
  )
})

test_that("test_sagemaker_model_default_channel_name", {
  # With a model_uri and no explicit channel name, the model input channel
  # should be called "model".
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = "my_script.py",
    role = "DummyRole",
    instance_count = 3,
    instance_type = "ml.m4.xlarge",
    sagemaker_session = session,
    model_uri = "s3://model-bucket/prefix/model.tar.gz"
  )
  estimator$.__enclos_env__$private$.start_new(list(), NULL)
  train_args <- session$train(..return_value = TRUE)

  expected_channel <- list(
    "DataSource" = list(
      "S3DataSource" = list(
        "S3DataType" = "S3Prefix",
        "S3Uri" = "s3://model-bucket/prefix/model.tar.gz",
        "S3DataDistributionType" = "FullyReplicated"
      )
    ),
    "ContentType" = "application/x-sagemaker-model",
    "InputMode" = "File",
    "ChannelName" = "model"
  )
  expect_equal(train_args[["input_config"]], list(expected_channel))
})

test_that("test_sagemaker_model_custom_channel_name", {
  # A custom model_channel_name should be used as the ChannelName of the
  # model input channel.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = "my_script.py",
    role = "DummyRole",
    instance_count = 3,
    instance_type = "ml.m4.xlarge",
    sagemaker_session = session,
    model_uri = "s3://model-bucket/prefix/model.tar.gz",
    model_channel_name = "testModelChannel"
  )
  estimator$.__enclos_env__$private$.start_new(list(), NULL)
  train_args <- session$train(..return_value = TRUE)

  expected_channel <- list(
    "DataSource" = list(
      "S3DataSource" = list(
        "S3DataType" = "S3Prefix",
        "S3Uri" = "s3://model-bucket/prefix/model.tar.gz",
        "S3DataDistributionType" = "FullyReplicated"
      )
    ),
    "ContentType" = "application/x-sagemaker-model",
    "InputMode" = "File",
    "ChannelName" = "testModelChannel"
  )
  expect_equal(train_args[["input_config"]], list(expected_channel))
})

test_that("test_custom_code_bucket", {
  # With a code_location of bucket + prefix, the source archive should be
  # uploaded under that prefix and referenced by the submit directory.
  bucket <- "codebucket"
  key_prefix <- "someprefix"
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    code_location = sprintf("s3://%s/%s", bucket, key_prefix)
  )
  estimator$fit("s3://bucket/mydata")

  # JOB_NAME is a regular expression, so the key is matched, not compared.
  key_pattern <- sprintf("%s/%s/source/sourcedir.tar.gz", key_prefix, JOB_NAME)
  upload_args <- session$s3$put_object(..return_value = TRUE)
  expect_equal(upload_args[["Bucket"]], bucket)
  expect_true(grepl(key_pattern, upload_args[["Key"]]))

  submit_dir_pattern <- sprintf("s3://%s/%s", bucket, key_pattern)
  train_args <- session$train(..return_value = TRUE)
  submit_dir <- train_args[["hyperparameters"]][["sagemaker_submit_directory"]]
  expect_true(grepl(submit_dir_pattern, submit_dir))
})

test_that("test_custom_code_bucket_without_prefix", {
  # With a bucket-only code_location, the source archive key should start
  # directly at the job name.
  bucket <- "codebucket"
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    code_location = sprintf("s3://%s", bucket)
  )
  estimator$fit("s3://bucket/mydata")

  # JOB_NAME is a regular expression, so the key is matched, not compared.
  key_pattern <- sprintf("%s/source/sourcedir.tar.gz", JOB_NAME)
  upload_args <- session$s3$put_object(..return_value = TRUE)
  expect_equal(upload_args[["Bucket"]], bucket)
  expect_true(grepl(key_pattern, upload_args[["Key"]]))

  submit_dir_pattern <- sprintf("s3://%s/%s", bucket, key_pattern)
  train_args <- session$train(..return_value = TRUE)
  submit_dir <- train_args[["hyperparameters"]][["sagemaker_submit_directory"]]
  expect_true(grepl(submit_dir_pattern, submit_dir))
})

test_that("test_invalid_custom_code_bucket", {
  # A code_location that is not an S3 URI should make fit() raise an error.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    code_location = "thisllworkright?"
  )
  expect_error(estimator$fit("s3://bucket/mydata"))
})

test_that("test_augmented_manifest", {
  # An augmented-manifest TrainingInput should carry its URI, data type and
  # attribute names through to the S3 data source of the train call.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = "DummyRole",
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit(
    inputs = TrainingInput$new(
      "s3://mybucket/train_manifest",
      s3_data_type = "AugmentedManifestFile",
      attribute_names = list("foo", "bar")
    )
  )

  train_args <- session$train(..return_value = TRUE)
  first_channel <- train_args[["input_config"]][[1]]
  data_source <- first_channel[["DataSource"]][["S3DataSource"]]
  expect_equal(data_source[["S3Uri"]], "s3://mybucket/train_manifest")
  expect_equal(data_source[["S3DataType"]], "AugmentedManifestFile")
  expect_equal(data_source[["AttributeNames"]], list("foo", "bar"))
})

test_that("test_s3_input_mode", {
  # A Pipe-mode TrainingInput should set the input mode both on the channel
  # and on the train call itself.
  pipe_mode <- "Pipe"
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = "DummyRole",
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit(
    inputs = TrainingInput$new("s3://mybucket/train_manifest", input_mode = pipe_mode)
  )

  train_args <- session$train(..return_value = TRUE)
  expect_equal(train_args[["input_config"]][[1]][["InputMode"]], "Pipe")
  expect_equal(train_args[["input_mode"]], "Pipe")
})

test_that("test_shuffle_config", {
  # The seed of a ShuffleConfig attached to the input should show up in the
  # channel passed to the train call.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = "DummyRole",
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit(
    inputs = TrainingInput$new(
      "s3://mybucket/train_manifest",
      shuffle_config = ShuffleConfig$new(100)
    )
  )

  train_args <- session$train(..return_value = TRUE)
  first_channel <- train_args[["input_config"]][[1]]
  expect_equal(first_channel[["ShuffleConfig"]][["Seed"]], 100)
})

# Baseline hyperparameters expected for a framework job. The job name and
# submit directory embed JOB_NAME, which is a regular expression.
BASE_HP <- list(
  "sagemaker_program" = SCRIPT_NAME,
  "sagemaker_submit_directory" = sprintf(
    "s3://mybucket/%s/source/sourcedir.tar.gz",
    JOB_NAME
  ),
  "sagemaker_job_name" = JOB_NAME
)

# Build a mocked local-mode session for tests.
#
# region: region name given to both the paws mock and the session mock.
# config: optional configuration list stored on the session mock.
#
# Returns the session mock with stubbed session methods and with mocked
# `sagemaker` and `s3` clients attached.
sagemaker_local_session <- function(region = REGION, config = NULL) {
  paws_session_mock <- Mock$new(
    name = "PawsSession",
    region_name = region
  )
  session <- Mock$new(
    name = "LocalSession",
    paws_session = paws_session_mock,
    paws_region_name = region,
    config = config,
    local_mode = TRUE,
    s3 = NULL
  )

  # Mocked SageMaker client with canned describe/list responses.
  sm_client <- Mock$new()
  sm_client$.call_args("describe_training_job", return_value = DESCRIBE_TRAINING_JOB_RESULT)
  sm_client$.call_args("describe_endpoint", return_value = ENDPOINT_DESC)
  sm_client$.call_args("describe_endpoint_config", return_value = ENDPOINT_CONFIG_DESC)
  sm_client$.call_args("list_tags", return_value = LIST_TAGS_RESULT)
  sm_client$.call_args("train")

  # Mocked S3 client that only needs to accept put_object calls.
  s3_mock <- Mock$new()
  s3_mock$.call_args("put_object")

  session$.call_args("default_bucket", return_value = BUCKET_NAME)
  session$.call_args("upload_data", return_value = OUTPUT_PATH)
  session$.call_args("expand_role")
  session$.call_args("train")
  session$.call_args("logs_for_job")
  session$.call_args("wait_for_job")
  session$sagemaker <- sm_client
  session$s3 <- s3_mock
  session
}

test_that("test_local_code_location", {
  # With local_code enabled in a local session, fit() should leave the
  # estimator pointing at the script's directory and bare file name.
  local_config <- list("local" = list("local_code" = TRUE, "region" = "us-west-2"))
  session <- sagemaker_local_session(config = local_config)
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = 1,
    instance_type = "local",
    base_job_name = IMAGE_URI,
    hyperparameters = list("123" = 456, "learning_rate" = 0.1)
  )
  estimator$fit("file:///data/file")

  expect_equal(estimator$source_dir, DATA_DIR)
  expect_equal(estimator$entry_point, "dummy_script.py")
})

test_that("test_start_new_convert_hyperparameters_to_str", {
  # Hyperparameters handed to the train call should be strings: numbers as
  # character values and lists as JSON text.
  data_uri <- "bucket/mydata"
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    base_job_name = IMAGE_URI,
    hyperparameters = list("123" = list(456), "learning_rate" = 0.1)
  )
  estimator$fit(sprintf("s3://%s", data_uri))

  expected_hp <- BASE_HP
  expected_hp[["sagemaker_container_log_level"]] <- "20"
  expected_hp[["learning_rate"]] <- "0.1"
  expected_hp[["123"]] <- as.character(
    jsonlite::toJSON(list(456), auto_unbox = TRUE)
  )
  expected_hp[["sagemaker_region"]] <- 'us-west-2'

  actual_hp <- session$train(..return_value = TRUE)$hyperparameters

  # These two entries embed the generated job name, so their expected values
  # are regular expressions rather than exact strings.
  pattern_keys <- c("sagemaker_job_name", "sagemaker_submit_directory")
  for (key in sort(names(expected_hp))) {
    if (key %in% pattern_keys) {
      expect_true(grepl(expected_hp[[key]], actual_hp[[key]]))
    } else {
      expect_equal(actual_hp[[key]], expected_hp[[key]])
    }
  }
})

test_that("test_start_new_wait_called", {
  # After fit(), the train call should carry the baseline hyperparameters
  # plus the container log level and region.
  data_uri <- "bucket/mydata"
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  estimator$fit(sprintf("s3://%s", data_uri))

  expected_hp <- BASE_HP
  expected_hp[["sagemaker_container_log_level"]] <- "20"
  expected_hp[["sagemaker_region"]] <- 'us-west-2'

  actual_hp <- session$train(..return_value = TRUE)$hyperparameters

  # These two entries embed the generated job name, so their expected values
  # are regular expressions rather than exact strings.
  pattern_keys <- c("sagemaker_job_name", "sagemaker_submit_directory")
  for (key in sort(names(expected_hp))) {
    if (key %in% pattern_keys) {
      expect_true(grepl(expected_hp[[key]], actual_hp[[key]]))
    } else {
      expect_equal(actual_hp[[key]], expected_hp[[key]])
    }
  }
})

test_that("test_attach_framework", {
  # Attaching to an existing training job should rebuild the estimator from
  # the job description returned by the mocked session, including the VPC
  # and network-isolation overrides supplied here.
  sms = training_job_description(ll=list(
    "VpcConfig" = list("Subnets"=list("foo"), "SecurityGroupIds"=list("bar")),
    "EnableNetworkIsolation" = TRUE)
  )
  f = DummyFramework$new("dummy", instance_count=10, instance_type="dummy", role = "dummy", sagemaker_session=sms)
  framework_estimator = f$attach(
    training_job_name="neo", sagemaker_session=sms
  )
  expect_equal(framework_estimator$.current_job_name, "neo")
  expect_equal(framework_estimator$latest_training_job, "neo")
  expect_equal(framework_estimator$role, "arn:aws:iam::366:role/SageMakerRole")
  expect_equal(framework_estimator$instance_count, 1)
  expect_equal(framework_estimator$max_run, 24 * 60 * 60)
  expect_equal(framework_estimator$input_mode, "File")
  expect_equal(framework_estimator$base_job_name, "neo")
  expect_equal(framework_estimator$output_path, "s3://place/output/neo")
  expect_equal(framework_estimator$output_kms_key, "")
  expect_equal(framework_estimator$hyperparameters()$training_steps, "100")
  expect_equal(framework_estimator$source_dir, "s3://some/sourcedir.tar.gz")
  expect_equal(framework_estimator$entry_point, "iris-dnn-classifier.py")
  expect_equal(framework_estimator$subnets, list("foo"))
  expect_equal(framework_estimator$security_group_ids, list("bar"))
  expect_false(framework_estimator$encrypt_inter_container_traffic)
  # The tags assertion used to appear twice; one check is sufficient.
  expect_equal(framework_estimator$tags, LIST_TAGS_RESULT[["Tags"]])
  expect_true(framework_estimator$enable_network_isolation())
})

# Training-job description overrides reused by several attach tests: a VPC
# configuration plus network isolation turned on.
mod_list <- list(
  "VpcConfig" = list(
    "Subnets" = list("foo"),
    "SecurityGroupIds" = list("bar")
  ),
  "EnableNetworkIsolation" = TRUE
)

test_that("test_attach_framework_with_shared_overrides", {
  # Same attach checks as "test_attach_framework", but the job description
  # overrides come from the shared mod_list fixture. The description is kept
  # distinct so a failure report identifies which of the two tests failed.
  session = training_job_description(ll=mod_list)
  f = DummyFramework$new("dummy", instance_count=10, instance_type="dummy", role = "dummy", sagemaker_session=session)
  framework_estimator = f$attach(training_job_name="neo", sagemaker_session=session)
  expect_equal(framework_estimator$.current_job_name, "neo")
  expect_equal(framework_estimator$latest_training_job, "neo")
  expect_equal(framework_estimator$role, "arn:aws:iam::366:role/SageMakerRole")
  expect_equal(framework_estimator$instance_count, 1)
  expect_equal(framework_estimator$max_run, 24 * 60 * 60)
  expect_equal(framework_estimator$input_mode, "File")
  expect_equal(framework_estimator$base_job_name, "neo")
  expect_equal(framework_estimator$output_path, "s3://place/output/neo")
  expect_equal(framework_estimator$output_kms_key, "")
  expect_equal(framework_estimator$hyperparameters()$training_steps, "100")
  expect_equal(framework_estimator$source_dir, "s3://some/sourcedir.tar.gz")
  expect_equal(framework_estimator$entry_point, "iris-dnn-classifier.py")
  expect_equal(framework_estimator$subnets, list("foo"))
  expect_equal(framework_estimator$security_group_ids, list("bar"))
  expect_false(framework_estimator$encrypt_inter_container_traffic)
  # The tags assertion used to appear twice; one check is sufficient.
  expect_equal(framework_estimator$tags, LIST_TAGS_RESULT[["Tags"]])
  expect_true(framework_estimator$enable_network_isolation())
})

test_that("test_attach_no_logs", {
  # Attaching to a job must not request its logs from the session.
  session <- training_job_description(ll = mod_list)
  base_estimator <- Estimator$new(
    "dummy",
    instance_count = 10,
    instance_type = "dummy",
    role = "dummy",
    sagemaker_session = session
  )
  base_estimator$attach(training_job_name = "job", sagemaker_session = session)

  expect_equal(session$logs_for_job(..count = TRUE), 0)
  expect_null(session$logs_for_job(..return_value = TRUE))
})

test_that("test_logs", {
  # Calling logs() on an attached estimator should ask the session for the
  # job logs with wait set to TRUE.
  session <- training_job_description(ll = mod_list)
  base_estimator <- Estimator$new(
    "dummy",
    instance_count = 10,
    instance_type = "dummy",
    role = "dummy",
    sagemaker_session = session
  )
  attached <- base_estimator$attach(training_job_name = "job", sagemaker_session = session)
  attached$logs()

  expect_true(session$logs_for_job(..return_value = TRUE)$wait)
})

test_that("test_attach_without_hyperparameters", {
  # A job description with no HyperParameters entry should attach cleanly and
  # report an empty hyperparameter list.
  job_desc_no_hp <- RETURNED_JOB_DESCRIPTION
  job_desc_no_hp[["HyperParameters"]] <- NULL
  session <- training_job_description(job_desc_no_hp, ll = mod_list)
  base_estimator <- Estimator$new(
    "dummy",
    instance_count = 10,
    instance_type = "dummy",
    role = "dummy",
    sagemaker_session = session
  )
  attached <- base_estimator$attach(training_job_name = "job", sagemaker_session = session)

  expect_equal(attached$hyperparameters(), list())
})

test_that("test_attach_framework_with_tuning", {
  # A tuning objective metric in the job's hyperparameters should be kept
  # alongside the ordinary hyperparameters after attach.
  tuning_overrides <- list(
    "HyperParameters" = list("_tuning_objective_metric" = "Validation-accuracy")
  )
  session <- training_job_description(ll = tuning_overrides)
  base_estimator <- DummyFramework$new(
    "dummy",
    instance_count = 10,
    instance_type = "dummy",
    role = "dummy",
    sagemaker_session = session
  )
  attached <- base_estimator$attach(training_job_name = "neo", sagemaker_session = session)

  expect_equal(attached$latest_training_job, "neo")
  expect_equal(attached$role, "arn:aws:iam::366:role/SageMakerRole")
  expect_equal(attached$instance_count, 1)
  expect_equal(attached$max_run, 24 * 60 * 60)
  expect_equal(attached$input_mode, "File")
  expect_equal(attached$base_job_name, "neo")
  expect_equal(attached$output_path, "s3://place/output/neo")
  expect_equal(attached$output_kms_key, "")

  attached_hp <- attached$hyperparameters()
  expect_equal(attached_hp[["training_steps"]], "100")
  expect_equal(attached_hp[["_tuning_objective_metric"]], "Validation-accuracy")

  expect_equal(attached$source_dir, "s3://some/sourcedir.tar.gz")
  expect_equal(attached$entry_point, "iris-dnn-classifier.py")
  expect_false(attached$encrypt_inter_container_traffic)
})

test_that("test_attach_framework_with_model_channel", {
  # A "model" input channel in the job description should populate the
  # attached estimator's model_uri.
  model_location <- "s3://some/s3/path/model.tar.gz"
  model_channel <- list(
    "ChannelName" = "model",
    "InputMode" = "File",
    "DataSource" = list("S3DataSource" = list("S3Uri" = model_location))
  )
  session <- training_job_description(
    ll = list("InputDataConfig" = list(model_channel))
  )
  base_estimator <- DummyFramework$new(
    "dummy",
    instance_count = 10,
    instance_type = "dummy",
    role = "dummy",
    sagemaker_session = session
  )
  attached <- base_estimator$attach(training_job_name = "neo", sagemaker_session = session)

  expect_equal(attached$model_uri, model_location)
  expect_false(attached$encrypt_inter_container_traffic)
})

test_that("test_attach_framework_with_inter_container_traffic_encryption_flag", {
  # The traffic-encryption flag of the described job should carry over to the
  # attached estimator.
  session <- training_job_description(
    ll = list("EnableInterContainerTrafficEncryption" = TRUE)
  )
  base_estimator <- DummyFramework$new(
    "dummy",
    instance_count = 10,
    instance_type = "dummy",
    role = "dummy",
    sagemaker_session = session
  )
  attached <- base_estimator$attach(training_job_name = "neo", sagemaker_session = session)

  expect_true(attached$encrypt_inter_container_traffic)
})

test_that("test_attach_framework_base_from_generated_name", {
  # Attaching by a name generated from a base should recover that base as the
  # estimator's base_job_name.
  expected_base <- "neo"
  session <- training_job_description()
  base_estimator <- DummyFramework$new(
    "dummy",
    instance_count = 10,
    instance_type = "dummy",
    role = "dummy",
    sagemaker_session = session
  )
  attached <- base_estimator$attach(
    training_job_name = name_from_base(expected_base),
    sagemaker_session = session
  )

  expect_equal(attached$base_job_name, expected_base)
})

test_that("test_fit_verify_job_name", {
  # fit() should pass the image, input mode, tags and encryption flag to the
  # train call, with a job name matching the JOB_NAME pattern.
  # (The description previously read "est_fit_verify_job_name", missing the
  # leading "t" used by every other test in this file.)
  sms = sagemaker_session()
  fw = DummyFramework$new(
    entry_point=SCRIPT_PATH,
    role="DummyRole",
    sagemaker_session=sms,
    instance_count=INSTANCE_COUNT,
    instance_type=INSTANCE_TYPE,
    tags=TAGS,
    encrypt_inter_container_traffic=TRUE
  )
  fw$fit(inputs=TrainingInput$new("s3://mybucket/train"))
  # TRUE rather than T: T is an ordinary variable that can be reassigned.
  train_kwargs = sms$train(..return_value = TRUE)
  expect_equal(train_kwargs$image_uri, IMAGE_URI)
  expect_equal(train_kwargs$input_mode, "File")
  expect_equal(train_kwargs$tags, TAGS)
  expect_true(grepl(JOB_NAME, train_kwargs$job_name))
  expect_true(train_kwargs$encrypt_inter_container_traffic)
})

test_that("test_prepare_for_training_unique_job_name_generation", {
  # Two successive preparations should yield two different job names.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )

  estimator$.prepare_for_training()
  name_before <- estimator$.current_job_name

  # Brief pause between the two preparations.
  Sys.sleep(0.1)
  estimator$.prepare_for_training()
  name_after <- estimator$.current_job_name

  expect_false(name_before == name_after)
})

test_that("test_prepare_for_training_force_name", {
  # An explicit job_name should be used as given, regardless of base_job_name.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    base_job_name = "some"
  )
  estimator$.prepare_for_training(job_name = "use_it")

  expect_equal(estimator$.current_job_name, "use_it")
})

test_that("test_prepare_for_training_force_name_generation", {
  # After base_job_name is cleared, the generated job name should match the
  # JOB_NAME pattern.
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    base_job_name = "some"
  )
  estimator$base_job_name <- NULL
  estimator$.prepare_for_training()

  expect_true(grepl(JOB_NAME, estimator$.current_job_name))
})

test_that("test_git_support_with_branch_and_commit_succeed", {
  # With repo, branch and commit configured, fit() should call the mocked
  # clone helper with the git config and the entry point.
  clone_mock <- mock_fun(side_effect = function(...) {
    list(
      "entry_point" = "/tmp/repo_dir/entry_point",
      "source_dir" = NULL,
      "dependencies" = NULL
    )
  })
  upload_mock <- mock_fun(side_effect = function(...) list())
  repo_config <- list("repo" = GIT_REPO, "branch" = BRANCH, "commit" = COMMIT)
  script <- "entry_point"
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = script,
    git_config = repo_config,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )

  with_mock(
    `sagemaker.core::git_clone_repo` = clone_mock,
    `sagemaker.core::tar_and_upload_dir` = upload_mock,
    {
      estimator$fit()
    }
  )

  clone_args <- clone_mock(..return_value = TRUE)
  expect_equal(clone_args, list(repo_config, script, NULL, list()))
})

test_that("test_git_support_with_branch_succeed", {
  # With only repo and branch configured, fit() should still call the mocked
  # clone helper with the git config and the entry point.
  clone_mock <- mock_fun(side_effect = function(...) {
    list(
      "entry_point" = "/tmp/repo_dir/source_dir/entry_point",
      "source_dir" = NULL,
      "dependencies" = NULL
    )
  })
  upload_mock <- mock_fun(side_effect = function(...) list())
  repo_config <- list("repo" = GIT_REPO, "branch" = BRANCH)
  script <- "entry_point"
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = script,
    git_config = repo_config,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )

  with_mock(
    `sagemaker.core::git_clone_repo` = clone_mock,
    `sagemaker.core::tar_and_upload_dir` = upload_mock,
    {
      estimator$fit()
    }
  )

  clone_args <- clone_mock(..return_value = TRUE)
  expect_equal(clone_args, list(repo_config, script, NULL, list()))
})

test_that("test_git_support_with_dependencies_succeed", {
  # Dependencies given to the estimator should be passed to the mocked clone
  # helper along with the git config and the entry point.
  clone_mock <- mock_fun(side_effect = function(...) {
    list(
      "entry_point" = "/tmp/repo_dir/source_dir/entry_point",
      "source_dir" = NULL,
      "dependencies" = list("/tmp/repo_dir/foo", "/tmp/repo_dir/foo/bar")
    )
  })
  upload_mock <- mock_fun(side_effect = function(...) list())
  repo_config <- list("repo" = GIT_REPO, "branch" = BRANCH, "commit" = COMMIT)
  script <- "entry_point"
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = script,
    git_config = repo_config,
    dependencies = list("foo", "foo/bar"),
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )

  with_mock(
    `sagemaker.core::git_clone_repo` = clone_mock,
    `sagemaker.core::tar_and_upload_dir` = upload_mock,
    {
      estimator$fit()
    }
  )

  clone_args <- clone_mock(..return_value = TRUE)
  expect_equal(
    clone_args,
    list(repo_config, script, NULL, list("foo", "foo/bar"))
  )
})

test_that("test_git_support_without_branch_and_commit_succeed", {
  # A git config with only a repo should be enough for fit() to call the
  # mocked clone helper.
  clone_mock <- mock_fun(side_effect = function(...) {
    list(
      "entry_point" = "/tmp/repo_dir/source_dir/entry_point",
      "source_dir" = NULL,
      "dependencies" = NULL
    )
  })
  upload_mock <- mock_fun(side_effect = function(...) list())
  repo_config <- list("repo" = GIT_REPO)
  script <- "source_dir/entry_point"
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = script,
    git_config = repo_config,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )

  with_mock(
    `sagemaker.core::git_clone_repo` = clone_mock,
    `sagemaker.core::tar_and_upload_dir` = upload_mock,
    {
      estimator$fit()
    }
  )

  clone_args <- clone_mock(..return_value = TRUE)
  expect_equal(clone_args, list(repo_config, script, NULL, list()))
})

test_that("test_git_support_repo_not_provided", {
  # A git config without a repo entry should make fit() raise a ValueError.
  repo_config <- list("branch" = BRANCH, "commit" = COMMIT)
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = "entry_point",
    git_config = repo_config,
    source_dir = "source_dir",
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )

  expect_error(
    estimator$fit(),
    "Please provide a repo for git_config.",
    class = "ValueError"
  )
})

test_that("test_git_support_bad_repo_url_format", {
  # A repo URL with a malformed scheme should make fit() raise a ValueError.
  repo_config <- list(
    "repo" = "hhttps://github.com/user/repo.git",
    "branch" = BRANCH
  )
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = "entry_point",
    git_config = repo_config,
    source_dir = "source_dir",
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )

  expect_error(
    estimator$fit(),
    "Invalid Git url provided.",
    class = "ValueError"
  )
})

test_that("test_git_support_entry_point_not_exist", {
  # A ValueError raised by the mocked clone helper for a missing entry point
  # should surface from fit().
  clone_mock <- mock_fun(side_effect = function(...) {
    ValueError$new("Entry point does not exist in the repo.")
  })
  repo_config <- list("repo" = GIT_REPO, "branch" = BRANCH, "commit" = COMMIT)
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = "entry_point_that_does_not_exist",
    git_config = repo_config,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )

  with_mock(
    `sagemaker.core::git_clone_repo` = clone_mock,
    {
      expect_error(
        estimator$fit(),
        "Entry point does not exist in the repo.",
        class = "ValueError"
      )
    }
  )
})

test_that("test_git_support_source_dir_not_exist", {
  # A ValueError raised by the mocked clone helper for a missing source
  # directory should surface from fit().
  clone_mock <- mock_fun(side_effect = function(...) {
    ValueError$new("Source directory does not exist in the repo.")
  })
  repo_config <- list("repo" = GIT_REPO, "branch" = BRANCH, "commit" = COMMIT)
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = "entry_point",
    git_config = repo_config,
    source_dir = "source_dir_that_does_not_exist",
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )

  with_mock(
    `sagemaker.core::git_clone_repo` = clone_mock,
    {
      expect_error(
        estimator$fit(),
        "Source directory does not exist in the repo.",
        class = "ValueError"
      )
    }
  )
})

test_that("test_git_support_dependencies_not_exist", {
  # A ValueError raised by the mocked clone helper for a missing dependency
  # should surface from fit().
  clone_mock <- mock_fun(side_effect = function(...) {
    ValueError$new("Dependency no-such-dir does not exist in the repo.")
  })
  repo_config <- list("repo" = GIT_REPO, "branch" = BRANCH, "commit" = COMMIT)
  session <- sagemaker_session()
  estimator <- DummyFramework$new(
    entry_point = "entry_point",
    git_config = repo_config,
    source_dir = "source_dir",
    dependencies = list("foo", "no-such-dir"),
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )

  with_mock(
    `sagemaker.core::git_clone_repo` = clone_mock,
    {
      expect_error(
        estimator$fit(),
        "Dependency no-such-dir does not exist in the repo.",
        class = "ValueError"
      )
    }
  )
})

test_that("test_git_support_with_username_password_no_2fa", {
  # Private repo accessed with username/password credentials. Both the
  # clone and the upload helpers are mocked, so fit() touches no network.
  clone_stub <- mock_fun(side_effect = function(...) {
    list(
      entry_point = "/tmp/repo_dir/entry_point",
      source_dir = NULL,
      dependencies = NULL
    )
  })
  upload_stub <- mock_fun(side_effect = function(...) list())
  entry_point <- "entry_point"
  git_config <- list(
    repo = PRIVATE_GIT_REPO,
    branch = PRIVATE_BRANCH,
    commit = PRIVATE_COMMIT,
    username = "username",
    password = "passw0rd!"
  )
  session <- sagemaker_session()
  framework <- DummyFramework$new(
    entry_point = entry_point,
    git_config = git_config,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )

  with_mock(
    `sagemaker.core::git_clone_repo` = clone_stub,
    `sagemaker.core::tar_and_upload_dir` = upload_stub,
    {
      framework$fit()
    }
  )

  # What the clone mock recorded, and the entry point rewritten by fit().
  expect_equal(
    clone_stub(..return_value = TRUE),
    list(git_config, entry_point, NULL, list())
  )
  expect_equal(framework$entry_point, "/tmp/repo_dir/entry_point")
})

test_that("test_git_support_with_token_2fa", {
  # Private repo accessed with a token and the 2FA flag set. The clone and
  # upload helpers are mocked.
  clone_stub <- mock_fun(side_effect = function(...) {
    list(
      entry_point = "/tmp/repo_dir/entry_point",
      source_dir = NULL,
      dependencies = NULL
    )
  })
  upload_stub <- mock_fun(side_effect = function(...) list())
  entry_point <- "entry_point"
  git_config <- list(
    repo = PRIVATE_GIT_REPO,
    branch = PRIVATE_BRANCH,
    commit = PRIVATE_COMMIT,
    token = "my-token",
    "2FA_enabled" = TRUE
  )
  session <- sagemaker_session()
  framework <- DummyFramework$new(
    entry_point = entry_point,
    git_config = git_config,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  with_mock(
    `sagemaker.core::git_clone_repo` = clone_stub,
    `sagemaker.core::tar_and_upload_dir` = upload_stub,
    {
      framework$fit()
    }
  )
  expect_equal(
    clone_stub(..return_value = TRUE),
    list(git_config, entry_point, NULL, list())
  )
  expect_equal(framework$entry_point, "/tmp/repo_dir/entry_point")
})

test_that("test_git_support_ssh_no_passphrase_needed", {
  # SSH-style private repo URL with no credentials in git_config. The
  # clone and upload helpers are mocked so fit() performs no real git or
  # S3 work.
  mock_git_clone_repo = mock_fun(side_effect = function(...) list(
    "entry_point"="/tmp/repo_dir/entry_point",
    "source_dir"=NULL,
    "dependencies"=NULL
  ))
  mock_tar_and_upload_dir = mock_fun(side_effect = function(...) list())
  git_config = list("repo"=PRIVATE_GIT_REPO_SSH, "branch"=PRIVATE_BRANCH, "commit"=PRIVATE_COMMIT)
  sms = sagemaker_session()
  entry_point="entry_point"
  # No trailing comma after the last argument: a trailing comma makes R
  # pass an extra empty argument to the constructor call.
  fw = DummyFramework$new(
    entry_point=entry_point,
    git_config=git_config,
    role=ROLE,
    sagemaker_session=sms,
    instance_count=INSTANCE_COUNT,
    instance_type=INSTANCE_TYPE
  )

  with_mock(
    `sagemaker.core::git_clone_repo` = mock_git_clone_repo,
    `sagemaker.core::tar_and_upload_dir` = mock_tar_and_upload_dir,{
      fw$fit()
  })
  # Value recorded by the clone mock, and the entry point after fit().
  expect_equal(mock_git_clone_repo(..return_value = TRUE), list(git_config, entry_point, NULL, list()))
  expect_equal(fw$entry_point, "/tmp/repo_dir/entry_point")
})

test_that("test_git_support_codecommit_with_username_and_password_succeed", {
  # CodeCommit HTTPS repo with username/password. The clone and upload
  # helpers are mocked.
  clone_stub <- mock_fun(side_effect = function(...) {
    list(
      entry_point = "/tmp/repo_dir/entry_point",
      source_dir = NULL,
      dependencies = NULL
    )
  })
  upload_stub <- mock_fun(side_effect = function(...) list())
  entry_point <- "entry_point"
  git_config <- list(
    repo = CODECOMMIT_REPO,
    branch = CODECOMMIT_BRANCH,
    username = "username",
    password = "passw0rd!"
  )
  session <- sagemaker_session()
  framework <- DummyFramework$new(
    entry_point = entry_point,
    git_config = git_config,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  with_mock(
    `sagemaker.core::git_clone_repo` = clone_stub,
    `sagemaker.core::tar_and_upload_dir` = upload_stub,
    {
      framework$fit()
    }
  )
  expect_equal(
    clone_stub(..return_value = TRUE),
    list(git_config, entry_point, NULL, list())
  )
  expect_equal(framework$entry_point, "/tmp/repo_dir/entry_point")
})

test_that("test_git_support_codecommit_with_ssh_no_passphrase_needed", {
  # CodeCommit SSH repo URL with only repo and branch in git_config. The
  # clone and upload helpers are mocked.
  clone_stub <- mock_fun(side_effect = function(...) {
    list(
      entry_point = "/tmp/repo_dir/entry_point",
      source_dir = NULL,
      dependencies = NULL
    )
  })
  upload_stub <- mock_fun(side_effect = function(...) list())
  entry_point <- "entry_point"
  git_config <- list(repo = CODECOMMIT_REPO_SSH, branch = CODECOMMIT_BRANCH)
  session <- sagemaker_session()
  framework <- DummyFramework$new(
    entry_point = entry_point,
    git_config = git_config,
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  with_mock(
    `sagemaker.core::git_clone_repo` = clone_stub,
    `sagemaker.core::tar_and_upload_dir` = upload_stub,
    {
      framework$fit()
    }
  )
  expect_equal(
    clone_stub(..return_value = TRUE),
    list(git_config, entry_point, NULL, list())
  )
  expect_equal(framework$entry_point, "/tmp/repo_dir/entry_point")
})

test_that("test_init_with_source_dir_s3", {
  # With an S3 source_dir, the hyperparameters produced after
  # .prepare_for_training() must carry that location and the script name.
  session <- sagemaker_session()
  framework <- DummyFramework$new(
    entry_point = SCRIPT_NAME,
    source_dir = "s3://location",
    role = ROLE,
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  framework$.prepare_for_training()
  hps <- framework$hyperparameters()

  expect_equal(hps[["sagemaker_submit_directory"]], "s3://location")
  expect_equal(hps[["sagemaker_program"]], SCRIPT_NAME)
  expect_equal(hps[["sagemaker_container_log_level"]], "20")
  expect_true(grepl(JOB_NAME, hps[["sagemaker_job_name"]]))
  expect_equal(hps[["sagemaker_region"]], "us-west-2")
})

test_that("test_framework_transformer_creation", {
  # transformer() with only instance count/type: checks the value recorded
  # by the create_model mock and the fields of the returned object.
  session <- sagemaker_session()
  vpc <- list(Subnets = list("foo"), SecurityGroupIds = list("bar"))
  framework <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    sagemaker_session = session,
    subnets = vpc[["Subnets"]],
    security_group_ids = vpc[["SecurityGroupIds"]]
  )
  framework$latest_training_job <- JOB_NAME

  # Replace name_from_base in the environment of the private
  # .get_or_create_name method with a mock built from MODEL_IMAGE.
  name_env <- environment(framework$.__enclos_env__$private$.get_or_create_name)
  assign("name_from_base", mock_fun(MODEL_IMAGE), envir = name_env)

  transformer <- framework$transformer(INSTANCE_COUNT, INSTANCE_TYPE)

  expected_create_model <- list(
    MODEL_IMAGE,
    ROLE,
    MODEL_CONTAINER_DEF,
    vpc_config = vpc,
    enable_network_isolation = FALSE,
    tags = NULL
  )
  expect_equal(session$create_model(..return_value = TRUE), expected_create_model)
  expect_true(inherits(transformer, "Transformer"))
  expect_equal(transformer$sagemaker_session, session)
  expect_equal(transformer$instance_count, INSTANCE_COUNT)
  expect_equal(transformer$instance_type, INSTANCE_TYPE)
  expect_equal(transformer$model_name, MODEL_IMAGE)
  expect_null(transformer$tags)
  expect_equal(transformer$env, list())
})

test_that("test_framework_transformer_creation_with_optional_params", {
  # transformer() with every optional argument supplied: the overrides
  # (role, vpc config, network isolation, model name, tags) must show up
  # in the create_model mock and on the returned transformer.
  base_name <- "foo"
  session <- sagemaker_session()
  vpc <- list(Subnets = list("foo"), SecurityGroupIds = list("bar"))
  framework <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    sagemaker_session = session,
    base_job_name = base_name,
    subnets = vpc[["Subnets"]],
    security_group_ids = vpc[["SecurityGroupIds"]],
    enable_network_isolation = FALSE
  )
  framework$latest_training_job <- JOB_NAME
  name_env <- environment(framework$.__enclos_env__$private$.get_or_create_name)
  assign("name_from_base", mock_fun(MODEL_IMAGE), envir = name_env)

  # Optional settings passed to transformer().
  strategy <- "MultiRecord"
  assemble_with <- "Line"
  kms_key <- "key"
  accept <- "text/csv"
  max_concurrent_transforms <- 1
  max_payload <- 6
  env <- list(FOO = "BAR")
  override_role <- "dummy-model-role"
  override_vpc <- list(Subnets = list("x"), SecurityGroupIds = list("y"))
  model_name <- "model-name"

  transformer <- framework$transformer(
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    strategy = strategy,
    assemble_with = assemble_with,
    output_path = OUTPUT_PATH,
    output_kms_key = kms_key,
    accept = accept,
    tags = TAGS,
    max_concurrent_transforms = max_concurrent_transforms,
    max_payload = max_payload,
    volume_kms_key = kms_key,
    env = env,
    role = override_role,
    model_server_workers = 1,
    vpc_config_override = override_vpc,
    enable_network_isolation = TRUE,
    model_name = model_name
  )

  expected_create_model <- list(
    model_name,
    override_role,
    MODEL_CONTAINER_DEF,
    vpc_config = override_vpc,
    enable_network_isolation = TRUE,
    tags = TAGS
  )
  expect_equal(session$create_model(..return_value = TRUE), expected_create_model)
  expect_equal(transformer$strategy, strategy)
  expect_equal(transformer$assemble_with, assemble_with)
  expect_equal(transformer$output_path, OUTPUT_PATH)
  expect_equal(transformer$output_kms_key, kms_key)
  expect_equal(transformer$accept, accept)
  expect_equal(transformer$max_concurrent_transforms, max_concurrent_transforms)
  expect_equal(transformer$max_payload, max_payload)
  expect_equal(transformer$env, env)
  expect_equal(transformer$base_transform_job_name, base_name)
  expect_equal(transformer$volume_kms_key, kms_key)
  expect_equal(transformer$model_name, model_name)
})

test_that("test_ensure_latest_training_job", {
  # Once latest_training_job is set, the private guard is expected to
  # return NULL instead of raising.
  session <- sagemaker_session()
  framework <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    sagemaker_session = session
  )
  framework$latest_training_job <- "training_job"
  guard <- framework$.__enclos_env__$private$.ensure_latest_training_job
  expect_null(guard())
})

test_that("test_ensure_latest_training_job_failure", {
  # No latest_training_job has been assigned, so the private guard must
  # raise a ValueError.
  session <- sagemaker_session()
  framework <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = ROLE,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    sagemaker_session = session
  )
  expect_error(
    framework$.__enclos_env__$private$.ensure_latest_training_job(),
    "Estimator is not associated with a training job",
    class = "ValueError"
  )
})

test_that("test_estimator_transformer_creation", {
  # Generic Estimator: transformer() with defaults yields a Transformer
  # carrying the session, instance settings and the mocked model name.
  session <- sagemaker_session()
  est <- Estimator$new(
    image_uri = IMAGE_URI,
    role = ROLE,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    sagemaker_session = session
  )
  est$latest_training_job <- JOB_NAME

  model_name <- "model_name"
  name_env <- environment(est$.__enclos_env__$private$.get_or_create_name)
  assign("name_from_base", mock_fun(model_name), envir = name_env)

  transformer <- est$transformer(INSTANCE_COUNT, INSTANCE_TYPE)

  expect_true(inherits(transformer, "Transformer"))
  expect_equal(transformer$sagemaker_session, session)
  expect_equal(transformer$instance_count, INSTANCE_COUNT)
  expect_equal(transformer$instance_type, INSTANCE_TYPE)
  expect_equal(transformer$model_name, model_name)
  expect_null(transformer$tags)
})

test_that("test_estimator_transformer_creation_with_optional_params", {
  # Generic Estimator: every optional transformer() argument supplied must
  # be reflected on the returned transformer.
  base_name <- "foo"
  kms_key <- "key"
  session <- sagemaker_session()
  est <- Estimator$new(
    image_uri = IMAGE_URI,
    role = ROLE,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    sagemaker_session = session,
    base_job_name = base_name,
    output_kms_key = kms_key
  )
  est$latest_training_job <- JOB_NAME

  # Optional settings passed to transformer().
  strategy <- "MultiRecord"
  assemble_with <- "Line"
  accept <- "text/csv"
  max_concurrent_transforms <- 1
  max_payload <- 6
  env <- list(FOO = "BAR")
  override_vpc <- list(Subnets = list("x"), SecurityGroupIds = list("y"))
  model_name <- "model-name"

  name_env <- environment(est$.__enclos_env__$private$.get_or_create_name)
  assign("name_from_base", mock_fun(model_name), envir = name_env)

  transformer <- est$transformer(
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    strategy = strategy,
    assemble_with = assemble_with,
    output_path = OUTPUT_PATH,
    output_kms_key = kms_key,
    accept = accept,
    tags = TAGS,
    max_concurrent_transforms = max_concurrent_transforms,
    max_payload = max_payload,
    env = env,
    role = ROLE,
    vpc_config_override = override_vpc,
    enable_network_isolation = TRUE,
    model_name = model_name
  )

  expect_equal(transformer$strategy, strategy)
  expect_equal(transformer$assemble_with, assemble_with)
  expect_equal(transformer$output_path, OUTPUT_PATH)
  expect_equal(transformer$output_kms_key, kms_key)
  expect_equal(transformer$accept, accept)
  expect_equal(transformer$max_concurrent_transforms, max_concurrent_transforms)
  expect_equal(transformer$max_payload, max_payload)
  expect_equal(transformer$env, env)
  expect_equal(transformer$base_transform_job_name, base_name)
  expect_equal(transformer$tags, TAGS)
  expect_equal(transformer$model_name, model_name)
})

test_that("test_start_new", {
  # Calls the private .start_new() directly and checks that the
  # hyperparameters and experiment config show up in what the train mock
  # recorded.
  session <- sagemaker_session()
  hps <- list(mock = "hyperparameters")
  s3_input <- "s3://mybucket/train"
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session,
    hyperparameters = hps
  )

  experiment <- list(
    ExperimentName = "exp",
    TrialName = "t",
    TrialComponentDisplayName = "tc"
  )
  est$.__enclos_env__$private$.start_new(s3_input, experiment)

  recorded <- session$train(..return_value = TRUE)
  expect_equal(recorded[["hyperparameters"]], hps)
  expect_equal(recorded[["experiment_config"]], experiment)
})

test_that("test_start_new_not_local_mode_error", {
  # A file:// input passed to the private .start_new() must be rejected
  # with the "local mode only" ValueError.
  #
  # The experiment config is passed explicitly as NULL: no `exp_config`
  # variable exists in this test's scope, so referencing one could raise
  # "object not found" instead of the ValueError under test.
  inputs = "file://mybucket/train"
  sms = sagemaker_session()
  estimator = Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path=OUTPUT_PATH,
    sagemaker_session=sms
  )
  expect_error(
    estimator$.__enclos_env__$private$.start_new(inputs, NULL),
    "File URIs are supported in local mode only. Please use a S3 URI instead.",
    class="ValueError"
  )
})

test_that("test_container_log_level", {
  # container_log_level = "DEBUG" is expected to reach the train call as
  # the string "10" in the hyperparameters.
  session <- sagemaker_session()
  framework <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = "DummyRole",
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE,
    container_log_level = "DEBUG"
  )
  framework$fit(inputs = TrainingInput$new("s3://mybucket/train"))

  recorded <- session$train(..return_value = TRUE)
  log_level <- recorded[["hyperparameters"]][["sagemaker_container_log_level"]]
  expect_equal(log_level, "10")
})

test_that("test_wait_without_logs", {
  # wait(FALSE) must go through wait_for_job and leave logs_for_job
  # untouched (call count 0).
  session <- sagemaker_session()
  framework <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = "DummyRole",
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  framework$latest_training_job <- "JOB_NAME"
  framework$wait(FALSE)

  wait_args <- session$wait_for_job(..return_value = TRUE)
  expect_equal(wait_args[["job"]], "JOB_NAME")
  expect_null(session$logs_for_job(..return_value = TRUE))
  expect_equal(session$logs_for_job(..count = TRUE), 0)
})

test_that("test_wait_with_logs", {
  # wait() with defaults must go through logs_for_job and leave
  # wait_for_job untouched (call count 0).
  session <- sagemaker_session()
  framework <- DummyFramework$new(
    entry_point = SCRIPT_PATH,
    role = "DummyRole",
    sagemaker_session = session,
    instance_count = INSTANCE_COUNT,
    instance_type = INSTANCE_TYPE
  )
  framework$latest_training_job <- "JOB_NAME"
  framework$wait()

  expected_logs_call <- list(
    job_name = "JOB_NAME",
    wait = TRUE,
    log_type = "All"
  )
  expect_equal(session$logs_for_job(..return_value = TRUE), expected_logs_call)
  expect_null(session$wait_for_job(..return_value = TRUE))
  expect_equal(session$wait_for_job(..count = TRUE), 0)
})

#################################################################################
# Tests for the generic Estimator class

# Expected train-call arguments when fit() is invoked without inputs.
# RuleConfigurationName holds a regex pattern; the tests below match it
# with grepl() and strip it before comparing the remaining fields.
NO_INPUT_TRAIN_CALL <- list(
  input_config = NULL,
  output_config = list(S3OutputPath = OUTPUT_PATH),
  resource_config = list(
    InstanceCount = INSTANCE_COUNT,
    InstanceType = INSTANCE_TYPE,
    VolumeSizeInGB = 30
  ),
  stop_condition = list(MaxRuntimeInSeconds = 86400),
  vpc_config = NULL,
  input_mode = "File",
  hyperparameters = list(),
  image_uri = IMAGE_URI,
  profiler_rule_configs = list(
    list(
      RuleConfigurationName = "ProfilerReport-[0-9]+",
      RuleEvaluatorImage = "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest",
      RuleParameters = list(rule_to_invoke = "ProfilerReport")
    )
  ),
  profiler_config = list(S3OutputPath = OUTPUT_PATH)
)

# Expected input_config for a single "train" channel pointing at
# s3://bucket/training-prefix.
INPUT_CONFIG <- list(
  list(
    DataSource = list(
      S3DataSource = list(
        S3DataType = "S3Prefix",
        S3Uri = "s3://bucket/training-prefix",
        S3DataDistributionType = "FullyReplicated"
      )
    ),
    ChannelName = "train"
  )
)

# Expected train call with inputs: the no-input call with its
# input_config replaced by INPUT_CONFIG.
BASE_TRAIN_CALL <- modifyList(
  NO_INPUT_TRAIN_CALL,
  list(input_config = INPUT_CONFIG)
)

# Hyperparameters as supplied by the tests, and the same values converted
# to character for comparison against the recorded train call.
HYPERPARAMS <- list(x = 1, y = "hello")
STRINGIFIED_HYPERPARAMS <- lapply(HYPERPARAMS, as.character)
# Expected train call when the stringified hyperparameters are set.
HP_TRAIN_CALL <- modifyList(
  BASE_TRAIN_CALL,
  list(hyperparameters = STRINGIFIED_HYPERPARAMS)
)

# Expected train call when an experiment config is passed to fit().
EXP_TRAIN_CALL <- modifyList(
  BASE_TRAIN_CALL,
  list(
    experiment_config = list(
      ExperimentName = "exp",
      TrialName = "trial",
      TrialComponentDisplayName = "tc"
    )
  )
)

test_that("test_fit_deploy_tags_in_estimator", {
  # Tags given to the Estimator constructor must appear in both the
  # endpoint and the create_model calls made by deploy().
  session <- sagemaker_session()
  tags <- list(list(Key = "TagtestKey", Value = "TagtestValue"))
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    tags = tags,
    sagemaker_session = session
  )
  est$fit()

  model_name <- "model_name"
  name_env <- environment(est$.__enclos_env__$private$.get_or_create_name)
  assign("name_from_base", mock_fun(model_name), envir = name_env)
  est$deploy(INSTANCE_COUNT, INSTANCE_TYPE)

  expected_variants <- list(
    list(
      ModelName = model_name,
      VariantName = "AllTraffic",
      InitialVariantWeight = 1,
      InitialInstanceCount = 1,
      InstanceType = "c4.4xlarge"
    )
  )
  expected_endpoint_call <- list(
    name = model_name,
    production_variants = expected_variants,
    tags = tags,
    kms_key = NULL,
    wait = TRUE,
    data_capture_config_list = NULL
  )
  expected_model_call <- list(
    model_name,
    "DummyRole",
    list(Image = "fakeimage", Environment = list(), ModelDataUrl = "s3://bucket/model.tar.gz"),
    vpc_config = NULL,
    enable_network_isolation = FALSE,
    tags = tags
  )
  expect_equal(
    session$endpoint_from_production_variants(..return_value = TRUE),
    expected_endpoint_call
  )
  expect_equal(session$create_model(..return_value = TRUE), expected_model_call)
})

test_that("test_fit_deploy_tags", {
  # Tags given to deploy() (not the constructor) must appear in both the
  # endpoint and the create_model calls.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE,
    sagemaker_session = session
  )
  est$fit()

  model_name <- "model_name"
  name_stub <- mock_fun(model_name)
  name_env <- environment(est$.__enclos_env__$private$.get_or_create_name)
  assign("name_from_base", name_stub, envir = name_env)

  tags <- list(list(Key = "TagtestKey", Value = "TagtestValue"))
  est$deploy(INSTANCE_COUNT, INSTANCE_TYPE, tags = tags)

  expected_variants <- list(
    list(
      ModelName = model_name,
      VariantName = "AllTraffic",
      InitialVariantWeight = 1,
      InitialInstanceCount = 1,
      InstanceType = "c4.4xlarge"
    )
  )
  expected_endpoint_call <- list(
    name = model_name,
    production_variants = expected_variants,
    tags = tags,
    kms_key = NULL,
    wait = TRUE,
    data_capture_config_list = NULL
  )
  expected_model_call <- list(
    model_name,
    "DummyRole",
    list(Image = "fakeimage", Environment = list(), ModelDataUrl = "s3://bucket/model.tar.gz"),
    vpc_config = NULL,
    enable_network_isolation = FALSE,
    tags = tags
  )
  expect_equal(name_stub(..return_value = TRUE), list(IMAGE_URI))
  expect_equal(
    session$endpoint_from_production_variants(..return_value = TRUE),
    expected_endpoint_call
  )
  expect_equal(session$create_model(..return_value = TRUE), expected_model_call)
})

test_that("test_generic_to_fit_no_input", {
  # fit() without inputs: the recorded train call must equal
  # NO_INPUT_TRAIN_CALL once job_name, role and the rule name are removed.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session
  )
  est$fit()

  recorded <- session$train(..return_value = TRUE)
  expect_true(startsWith(recorded[["job_name"]], IMAGE_URI))

  recorded[["job_name"]] <- NULL
  recorded[["role"]] <- NULL

  # The rule name is checked by regex, then dropped from both sides so the
  # remaining fields can be compared exactly.
  expected <- NO_INPUT_TRAIN_CALL
  rule_name <- recorded$profiler_rule_configs[[1]]$RuleConfigurationName
  rule_pattern <- expected$profiler_rule_configs[[1]]$RuleConfigurationName
  recorded$profiler_rule_configs[[1]]$RuleConfigurationName <- NULL
  expected$profiler_rule_configs[[1]]$RuleConfigurationName <- NULL
  expect_true(grepl(rule_pattern, rule_name))
  expect_equal(recorded, expected)
})

test_that("test_generic_to_fit_no_hps", {
  # fit() with a train channel but no hyperparameters: the recorded train
  # call must equal BASE_TRAIN_CALL after normalisation.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session
  )
  est$fit(list(train = "s3://bucket/training-prefix"))

  recorded <- session$train(..return_value = TRUE)
  expect_true(startsWith(recorded[["job_name"]], IMAGE_URI))

  recorded[["job_name"]] <- NULL
  recorded[["role"]] <- NULL

  # The rule name is checked by regex, then dropped from both sides.
  expected <- BASE_TRAIN_CALL
  rule_name <- recorded$profiler_rule_configs[[1]]$RuleConfigurationName
  rule_pattern <- expected$profiler_rule_configs[[1]]$RuleConfigurationName
  recorded$profiler_rule_configs[[1]]$RuleConfigurationName <- NULL
  expected$profiler_rule_configs[[1]]$RuleConfigurationName <- NULL
  expect_true(grepl(rule_pattern, rule_name))
  expect_equal(recorded, expected)
})

test_that("test_generic_to_fit_with_hps", {
  # fit() after set_hyperparameters(): the recorded train call must equal
  # HP_TRAIN_CALL after normalisation.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session
  )
  do.call(est$set_hyperparameters, HYPERPARAMS)
  est$fit(list(train = "s3://bucket/training-prefix"))

  recorded <- session$train(..return_value = TRUE)
  expect_true(startsWith(recorded[["job_name"]], IMAGE_URI))

  recorded[["job_name"]] <- NULL
  recorded[["role"]] <- NULL

  # The rule name is checked by regex, then dropped from both sides.
  expected <- HP_TRAIN_CALL
  rule_name <- recorded$profiler_rule_configs[[1]]$RuleConfigurationName
  rule_pattern <- expected$profiler_rule_configs[[1]]$RuleConfigurationName
  recorded$profiler_rule_configs[[1]]$RuleConfigurationName <- NULL
  expected$profiler_rule_configs[[1]]$RuleConfigurationName <- NULL
  expect_true(grepl(rule_pattern, rule_name))
  expect_equal(recorded, expected)
})

test_that("test_generic_to_fit_with_experiment_config", {
  # fit() with an experiment config: the recorded train call must equal
  # EXP_TRAIN_CALL. Both sides are sorted by name before comparing.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session
  )
  experiment <- list(
    ExperimentName = "exp",
    TrialName = "trial",
    TrialComponentDisplayName = "tc"
  )
  est$fit(
    inputs = list(train = "s3://bucket/training-prefix"),
    experiment_config = experiment
  )

  recorded <- session$train(..return_value = TRUE)
  expect_true(startsWith(recorded[["job_name"]], IMAGE_URI))

  recorded[["job_name"]] <- NULL
  recorded[["role"]] <- NULL

  # The rule name is checked by regex, then dropped from both sides.
  expected <- EXP_TRAIN_CALL
  rule_name <- recorded$profiler_rule_configs[[1]]$RuleConfigurationName
  rule_pattern <- expected$profiler_rule_configs[[1]]$RuleConfigurationName
  recorded$profiler_rule_configs[[1]]$RuleConfigurationName <- NULL
  expected$profiler_rule_configs[[1]]$RuleConfigurationName <- NULL
  expect_true(grepl(rule_pattern, rule_name))
  expect_equal(
    recorded[sort(names(recorded))],
    expected[sort(names(expected))]
  )
})

test_that("test_generic_to_fit_with_encrypt_inter_container_traffic_flag", {
  # The encrypt_inter_container_traffic flag must be TRUE in the recorded
  # train call.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session,
    encrypt_inter_container_traffic = TRUE
  )
  est$fit()

  recorded <- session$train(..return_value = TRUE)
  expect_true(recorded[["encrypt_inter_container_traffic"]])
})

test_that("test_generic_to_fit_with_network_isolation", {
  # The enable_network_isolation flag must be TRUE in the recorded train
  # call.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session,
    enable_network_isolation = TRUE
  )
  est$fit()

  recorded <- session$train(..return_value = TRUE)
  expect_true(recorded[["enable_network_isolation"]])
})

test_that("test_generic_to_fit_with_sagemaker_metrics_missing", {
  # When enable_sagemaker_metrics is not given, the key must be absent
  # from the recorded train call.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session
  )
  est$fit()

  recorded <- session$train(..return_value = TRUE)
  expect_false("enable_sagemaker_metrics" %in% names(recorded))
})

test_that("test_add_environment_variables_to_train_args", {
  # The environment list given to the constructor must be passed through
  # unchanged to the train call.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session,
    environment = ENV_INPUT
  )
  est$fit()

  recorded <- session$train(..return_value = TRUE)
  expect_equal(recorded[["environment"]], ENV_INPUT)
})

test_that("test_add_retry_strategy_to_train_args", {
  # max_retry_attempts must surface in the train call as
  # retry_strategy$MaximumRetryAttempts.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session,
    max_retry_attempts = 2
  )
  est$fit()

  recorded <- session$train(..return_value = TRUE)
  expect_equal(recorded[["retry_strategy"]], list(MaximumRetryAttempts = 2))
})

test_that("test_generic_to_fit_with_sagemaker_metrics_enabled", {
  # enable_sagemaker_metrics = TRUE must be TRUE in the recorded train
  # call.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session,
    enable_sagemaker_metrics = TRUE
  )
  est$fit()

  recorded <- session$train(..return_value = TRUE)
  expect_true(recorded[["enable_sagemaker_metrics"]])
})

test_that("test_generic_to_fit_with_sagemaker_metrics_disabled", {
  # enable_sagemaker_metrics = FALSE must be FALSE (not absent) in the
  # recorded train call.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session,
    enable_sagemaker_metrics = FALSE
  )
  est$fit()

  recorded <- session$train(..return_value = TRUE)
  expect_false(recorded[["enable_sagemaker_metrics"]])
})

test_that("test_generic_to_deploy", {
  # fit() followed by deploy(): checks the recorded train call, the
  # recorded create_model call, and the returned predictor.
  sms = sagemaker_session()
  e = Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path=OUTPUT_PATH,
    sagemaker_session=sms
  )
  do.call(e$set_hyperparameters, HYPERPARAMS)
  e$fit(list("train"="s3://bucket/training-prefix"))
  predictor = e$deploy(INSTANCE_COUNT, INSTANCE_TYPE)

  train_args = sms$train(..return_value = TRUE)

  expect_true(startsWith(train_args[["job_name"]], IMAGE_URI))
  train_args[["job_name"]] = NULL
  train_args[["role"]] = NULL

  # The rule name is checked by regex, then dropped from both sides so the
  # remaining train arguments can be compared exactly with HP_TRAIN_CALL.
  exp_args = HP_TRAIN_CALL
  actual_rule_config = train_args$profiler_rule_configs[[1]]$RuleConfigurationName
  exp_rule_config = exp_args$profiler_rule_configs[[1]]$RuleConfigurationName
  train_args$profiler_rule_configs[[1]]$RuleConfigurationName = NULL
  exp_args$profiler_rule_configs[[1]]$RuleConfigurationName = NULL
  expect_true(grepl(exp_rule_config, actual_rule_config))
  expect_equal(train_args, exp_args)

  model_args = sms$create_model(..return_value = TRUE)
  expect_true(startsWith(model_args[[1]], IMAGE_URI))
  expect_equal(model_args[[2]], ROLE)
  expect_equal(model_args[[3]][["Image"]], IMAGE_URI)
  expect_equal(model_args[[3]][["ModelDataUrl"]], MODEL_DATA)
  expect_null(model_args[["vpc_config"]])

  expect_true(inherits(predictor, "Predictor"))
  expect_true(startsWith(predictor$endpoint_name, IMAGE_URI))
  expect_equal(predictor$sagemaker_session, sms)
})

test_that("test_generic_to_deploy_network_isolation", {
  # An estimator built with enable_network_isolation = TRUE must pass the
  # flag on to create_model when deployed.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    enable_network_isolation = TRUE,
    sagemaker_session = session
  )
  est$fit()
  est$deploy(INSTANCE_COUNT, INSTANCE_TYPE)

  model_args <- session$create_model(..return_value = TRUE)
  expect_true(model_args[["enable_network_isolation"]])
})

test_that("test_generic_training_job_analytics", {
  # training_job_analytics must raise a ValueError before fit() and be
  # non-NULL afterwards.
  session <- sagemaker_session()
  described_job <- list(
    TuningJobArn = "arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/mock-tuner",
    TrainingStartTime = 1530562991.299,
    AlgorithmSpecification = list(
      TrainingImage = "some-image-url",
      TrainingInputMode = "File",
      MetricDefinitions = list(
        list(Name = "train:loss", Regex = "train_loss=([0-9]+\\.[0-9]+)"),
        list(Name = "validation:loss", Regex = "valid_loss=([0-9]+\\.[0-9]+)")
      )
    )
  )
  # Register the description above on the session's sagemaker client mock.
  session$sagemaker$.call_args("describe_training_job", return_value = described_job)

  est <- Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path = OUTPUT_PATH,
    sagemaker_session = session
  )
  expect_error(
    est$training_job_analytics,
    class = "ValueError"
  )

  do.call(est$set_hyperparameters, HYPERPARAMS)
  est$fit(list(train = "s3://bucket/training-prefix"))
  analytics <- est$training_job_analytics
  expect_false(is.null(analytics))
})

test_that("test_generic_create_model_vpc_config_override", {
  # Exercises vpc_config_override on create_model()/get_vpc_config(),
  # first with no VPC settings on the estimator, then with subnets and
  # security groups set, and finally with an invalid override.
  first_vpc <- list(Subnets = list("foo"), SecurityGroupIds = list("bar"))
  second_vpc <- list(Subnets = list("foo", "bar"), SecurityGroupIds = list("baz"))

  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE,
    sagemaker_session = session
  )
  est$fit(list(train = "s3://bucket/training-prefix"))

  # No VPC settings on the estimator.
  expect_null(est$get_vpc_config())
  expect_null(est$create_model()$vpc_config)
  expect_equal(est$create_model(vpc_config_override = first_vpc)$vpc_config, first_vpc)
  expect_null(est$create_model(vpc_config_override = NULL)$vpc_config)

  # VPC settings assigned on the estimator.
  est$subnets <- first_vpc[["Subnets"]]
  est$security_group_ids <- first_vpc[["SecurityGroupIds"]]
  expect_equal(est$get_vpc_config(), first_vpc)
  expect_equal(est$create_model()$vpc_config, first_vpc)
  expect_equal(est$create_model(vpc_config_override = second_vpc)$vpc_config, second_vpc)
  expect_null(est$create_model(vpc_config_override = NULL)$vpc_config)

  # An invalid override must be rejected by both entry points.
  expect_error(
    est$get_vpc_config(vpc_config_override = list("invalid")),
    class = "ValueError"
  )
  expect_error(
    est$create_model(vpc_config_override = list("invalid")),
    class = "ValueError"
  )
})

test_that("test_generic_deploy_vpc_config_override", {
  # After each deploy() the vpc_config recorded by the create_model mock
  # must reflect the estimator settings or the explicit override.
  first_vpc <- list(Subnets = list("foo"), SecurityGroupIds = list("bar"))
  second_vpc <- list(Subnets = list("foo", "bar"), SecurityGroupIds = list("baz"))

  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE,
    sagemaker_session = session
  )
  est$fit(list(train = "s3://bucket/training-prefix"))

  # No VPC settings on the estimator.
  est$deploy(INSTANCE_COUNT, INSTANCE_TYPE)
  expect_null(session$create_model(..return_value = TRUE)[["vpc_config"]])

  # VPC settings assigned on the estimator.
  est$subnets <- first_vpc[["Subnets"]]
  est$security_group_ids <- first_vpc[["SecurityGroupIds"]]
  est$deploy(INSTANCE_COUNT, INSTANCE_TYPE)
  expect_equal(session$create_model(..return_value = TRUE)[["vpc_config"]], first_vpc)

  # Explicit override with another config.
  est$deploy(INSTANCE_COUNT, INSTANCE_TYPE, vpc_config_override = second_vpc)
  expect_equal(session$create_model(..return_value = TRUE)[["vpc_config"]], second_vpc)

  # Explicit NULL override.
  est$deploy(INSTANCE_COUNT, INSTANCE_TYPE, vpc_config_override = NULL)
  expect_null(session$create_model(..return_value = TRUE)[["vpc_config"]])
})

test_that("test_generic_deploy_accelerator_type", {
  # accelerator_type given to deploy() must appear on the first production
  # variant, next to the instance count and type.
  session <- sagemaker_session()
  est <- Estimator$new(
    IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE,
    sagemaker_session = session
  )
  est$fit(list(train = "s3://bucket/training-prefix"))
  est$deploy(INSTANCE_COUNT, INSTANCE_TYPE, accelerator_type = ACCELERATOR_TYPE)

  endpoint_args <- est$sagemaker_session$endpoint_from_production_variants(..return_value = TRUE)
  first_variant <- endpoint_args[["production_variants"]][[1]]
  expect_true(startsWith(endpoint_args[["name"]], IMAGE_URI))
  expect_equal(first_variant[["AcceleratorType"]], ACCELERATOR_TYPE)
  expect_equal(first_variant[["InitialInstanceCount"]], INSTANCE_COUNT)
  expect_equal(first_variant[["InstanceType"]], INSTANCE_TYPE)
})

test_that("test_deploy_with_model_name", {
  # A model_name passed to deploy() should be the first argument handed to the
  # session's create_model.
  sms = sagemaker_session()
  estimator = Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path=OUTPUT_PATH,
    sagemaker_session=sms
  )

  do.call(estimator$set_hyperparameters, HYPERPARAMS)
  estimator$fit(list("train"="s3://bucket/training-prefix"))
  model_name = "model-name"
  estimator$deploy(INSTANCE_COUNT, INSTANCE_TYPE, model_name=model_name)

  # NOTE(review): `..return_value = TRUE` presumably returns the recorded
  # create_model arguments -- confirm against the mock helper. TRUE replaces
  # `T`, which is a reassignable variable rather than a constant.
  args = sms$create_model(..return_value = TRUE)
  expect_equal(args[[1]], model_name)
})

test_that("test_deploy_with_no_model_name", {
  # Without an explicit model_name, the first create_model argument is expected
  # to start with the image URI.
  sms = sagemaker_session()
  estimator = Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path=OUTPUT_PATH,
    sagemaker_session=sms
  )

  do.call(estimator$set_hyperparameters, HYPERPARAMS)
  estimator$fit(list("train"="s3://bucket/training-prefix"))
  estimator$deploy(INSTANCE_COUNT, INSTANCE_TYPE)

  # NOTE(review): `..return_value = TRUE` presumably returns the recorded
  # create_model arguments -- confirm against the mock helper. TRUE replaces
  # `T`, which is a reassignable variable rather than a constant.
  args = sms$create_model(..return_value = TRUE)
  expect_true(startsWith(args[[1]], IMAGE_URI))
})

test_that("test_register_default_image", {
  # register() without an explicit image should build the model package
  # request from the estimator's own image_uri and model_data.
  sms = sagemaker_session()
  estimator = Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path=OUTPUT_PATH,
    sagemaker_session=sms
  )
  do.call(estimator$set_hyperparameters, HYPERPARAMS)
  estimator$fit(list("train"="s3://bucket/training-prefix"))

  # These values are only passed to register() and compared with the request
  # below.
  model_package_name = "test-estimator-register-model"
  content_types = "application/json"
  response_types = "application/json"
  inference_instances = "ml.m4.xlarge"
  transform_instances = "ml.m4.xlarget"

  estimator$register(
    content_types=content_types,
    response_types=response_types,
    inference_instances=inference_instances,
    transform_instances=transform_instances,
    model_package_name=model_package_name
  )

  expected_create_model_package_request = list(
    "containers"=list(
      list(
        "Image"=estimator$image_uri,
        "ModelDataUrl"=estimator$model_data
      )
    ),
    "content_types"=content_types,
    "response_types"=response_types,
    "inference_instances"=inference_instances,
    "transform_instances"=transform_instances,
    "marketplace_cert"=FALSE,
    "model_package_name"=model_package_name
  )
  # NOTE(review): `..return_value = TRUE` presumably returns the arguments the
  # mocked session recorded -- confirm against the mock helper. TRUE replaces
  # `T`, which is a reassignable variable rather than a constant.
  expect_equal(
    sms$create_model_package_from_containers(..return_value = TRUE),
    expected_create_model_package_request
  )
})

test_that("test_register_inference_image", {
  # register() with an explicit image_uri should put that image, not the
  # estimator's training image, into the model package request.
  sms = sagemaker_session()
  estimator = Estimator$new(
    IMAGE_URI,
    ROLE,
    INSTANCE_COUNT,
    INSTANCE_TYPE,
    output_path=OUTPUT_PATH,
    sagemaker_session=sms
  )
  do.call(estimator$set_hyperparameters, HYPERPARAMS)
  estimator$fit(list("train"="s3://bucket/training-prefix"))

  # These values are only passed to register() and compared with the request
  # below.
  model_package_name = "test-estimator-register-model"
  content_types = "application/json"
  response_types = "application/json"
  inference_instances = "ml.m4.xlarge"
  transform_instances = "ml.m4.xlarget"
  inference_image = "fake-inference-image"

  estimator$register(
    content_types=content_types,
    response_types=response_types,
    inference_instances=inference_instances,
    transform_instances=transform_instances,
    model_package_name=model_package_name,
    image_uri=inference_image
  )

  expected_create_model_package_request = list(
    "containers"=list(
      list(
        "Image"=inference_image,
        "ModelDataUrl"=estimator$model_data
      )
    ),
    "content_types"=content_types,
    "response_types"=response_types,
    "inference_instances"=inference_instances,
    "transform_instances"=transform_instances,
    "marketplace_cert"=FALSE,
    "model_package_name"=model_package_name
  )
  # NOTE(review): `..return_value = TRUE` presumably returns the arguments the
  # mocked session recorded -- confirm against the mock helper. TRUE replaces
  # `T`, which is a reassignable variable rather than a constant.
  expect_equal(
    sms$create_model_package_from_containers(..return_value = TRUE),
    expected_create_model_package_request
  )
})

test_that("test_prepare_init_params_from_job_description_with_image_training_job", {
  # The shared job description carries a TrainingImage; the derived init
  # params should expose the role, the instance count and that image.
  params = EstimatorBase$private_methods$.prepare_init_params_from_job_description(
    job_details=RETURNED_JOB_DESCRIPTION
  )

  wanted = list(
    "role"="arn:aws:iam::366:role/SageMakerRole",
    "instance_count"=1,
    "image_uri"="1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other:1.0.4"
  )
  for (field in names(wanted)) {
    expect_equal(
      params[[field]],
      wanted[[field]],
      label = sprintf("init param '%s'", field)
    )
  }
})

test_that("test_prepare_init_params_from_job_description_with_algorithm_training_job", {
  # Swap the AlgorithmSpecification for one that names an algorithm ARN and
  # has an empty TrainingImage; the ARN should come back as `algorithm_arn`.
  algorithm_spec = list(
    "TrainingInputMode"="File",
    "AlgorithmName"="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
    "TrainingImage"=""
  )
  job_details_with_algorithm = RETURNED_JOB_DESCRIPTION
  job_details_with_algorithm[["AlgorithmSpecification"]] = algorithm_spec

  params = EstimatorBase$private_methods$.prepare_init_params_from_job_description(
    job_details=job_details_with_algorithm
  )

  wanted = list(
    "role"="arn:aws:iam::366:role/SageMakerRole",
    "instance_count"=1,
    "algorithm_arn"="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees"
  )
  for (field in names(wanted)) {
    expect_equal(
      params[[field]],
      wanted[[field]],
      label = sprintf("init param '%s'", field)
    )
  }
})

test_that("test_prepare_init_params_from_job_description_with_spot_training", {
  # Enable managed spot training and add a stopping condition; both limits and
  # the spot flag should be reflected in the derived init params.
  stopping_condition = list(
    "MaxRuntimeInSeconds"=86400,
    "MaxWaitTimeInSeconds"=87000
  )
  spot_job_details = RETURNED_JOB_DESCRIPTION
  spot_job_details[["EnableManagedSpotTraining"]] = TRUE
  spot_job_details[["StoppingCondition"]] = stopping_condition

  params = EstimatorBase$private_methods$.prepare_init_params_from_job_description(
    job_details=spot_job_details
  )

  expect_equal(params[["role"]], "arn:aws:iam::366:role/SageMakerRole")
  expect_equal(params[["instance_count"]], 1)
  expect_true(params[["use_spot_instances"]])
  expect_equal(params[["max_run"]], 86400)
  expect_equal(params[["max_wait"]], 87000)
})

test_that("test_prepare_init_params_from_job_description_with_retry_strategy", {
  # Add a RetryStrategy and a stopping condition; the retry attempts and both
  # time limits should be reflected in the derived init params.
  retry_job_details = RETURNED_JOB_DESCRIPTION
  retry_job_details[["RetryStrategy"]] = list("MaximumRetryAttempts"=2)
  retry_job_details[["StoppingCondition"]] = list(
    "MaxRuntimeInSeconds"=86400,
    "MaxWaitTimeInSeconds"=87000
  )

  params = EstimatorBase$private_methods$.prepare_init_params_from_job_description(
    job_details=retry_job_details
  )

  wanted = list(
    "role"="arn:aws:iam::366:role/SageMakerRole",
    "instance_count"=1,
    "max_run"=86400,
    "max_wait"=87000,
    "max_retry_attempts"=2
  )
  for (field in names(wanted)) {
    expect_equal(
      params[[field]],
      wanted[[field]],
      label = sprintf("init param '%s'", field)
    )
  }
})

test_that("test_prepare_init_params_from_job_description_with_invalid_training_job", {
  # An AlgorithmSpecification with neither a training image nor an algorithm
  # name must be rejected with a RuntimeError. The description used to repeat
  # the retry-strategy test's name, which made the two indistinguishable in
  # test reports.
  invalid_job_description = RETURNED_JOB_DESCRIPTION
  invalid_job_description[["AlgorithmSpecification"]] = list("TrainingInputMode"="File")

  expect_error(
    EstimatorBase$private_methods$.prepare_init_params_from_job_description(
      job_details=invalid_job_description
    ),
    "Invalid AlgorithmSpecification",
    class = "RuntimeError"
  )
})

test_that("test_prepare_for_training_with_base_name", {
  # With a base_job_name supplied, the job name set by .prepare_for_training()
  # should contain that base name.
  session = sagemaker_session()
  est = Estimator$new(
    image_uri="some-image",
    role="some_image",
    instance_count=1,
    instance_type="ml.m4.xlarge",
    sagemaker_session=session,
    base_job_name="base_job_name"
  )

  est$.prepare_for_training()

  job_name = est$.current_job_name
  contains_base_name = grepl("base_job_name", job_name)
  expect_true(contains_base_name)
})

test_that("test_prepare_for_training_with_name_based_on_image", {
  # Without a base_job_name, the job name set by .prepare_for_training()
  # should contain the image name.
  session = sagemaker_session()
  est = Estimator$new(
    image_uri="some-image",
    role="some_image",
    instance_count=1,
    instance_type="ml.m4.xlarge",
    sagemaker_session=session
  )

  est$.prepare_for_training()

  job_name = est$.current_job_name
  contains_image_name = grepl("some-image", job_name)
  expect_true(contains_image_name)
})

test_that("test_estimator_local_mode_error", {
  # instance_type "local" combined with a session that is not a LocalSession
  # must be rejected with a RuntimeError.
  session = sagemaker_session()

  # Construction is wrapped in a function so the failing call is evaluated
  # inside expect_error().
  build_local_estimator = function() {
    Estimator$new(
      image_uri="some-image",
      role="some_image",
      instance_count=1,
      instance_type="local",
      sagemaker_session=session,
      base_job_name="base_job_name"
    )
  }

  expect_error(build_local_estimator(), class = "RuntimeError")
})

test_that("test_estimator_local_mode_ok", {
  # When using instance local with a local session, construction should
  # succeed and return an Estimator. The description and comment used to be
  # copies of the error-case test above, which hid what this test checks.
  estimator = Estimator$new(
    image_uri="some-image",
    role="some_image",
    instance_count=1,
    instance_type="local",
    sagemaker_session=sagemaker_local_session(),
    base_job_name="base_job_name"
  )
  expect_true(inherits(estimator, "Estimator"))
})

test_that("test_framework_distribution_configuration", {
  # Checks the hyperparameter-style configuration produced by the private
  # .distribution_configuration() method for three distribution settings:
  # parameter server, MPI and SageMaker distributed data parallel.
  sms = sagemaker_session()
  framework = DummyFramework$new(
    entry_point="script",
    role=ROLE,
    sagemaker_session=sms,
    instance_count=INSTANCE_COUNT,
    instance_type=INSTANCE_TYPE
  )

  # The method is private, so it is reached through the R6 enclosing env.
  actual_ps = framework$.__enclos_env__$private$.distribution_configuration(
    distribution=DISTRIBUTION_PS_ENABLED
  )
  expected_ps = list("sagemaker_parameter_server_enabled"=TRUE)
  # This comparison was missing: both values were computed but never checked.
  expect_equal(actual_ps, expected_ps)

  actual_mpi = framework$.__enclos_env__$private$.distribution_configuration(
    distribution=DISTRIBUTION_MPI_ENABLED
  )
  expected_mpi = list(
    "sagemaker_mpi_enabled"=TRUE,
    "sagemaker_mpi_num_of_processes_per_host"=2,
    "sagemaker_mpi_custom_mpi_options"="options"
  )
  expect_equal(actual_mpi, expected_mpi)

  actual_ddp = framework$.__enclos_env__$private$.distribution_configuration(
    distribution=DISTRIBUTION_SM_DDP_ENABLED
  )
  expected_ddp = list(
    "sagemaker_distributed_dataparallel_enabled"=TRUE,
    "sagemaker_instance_type"=INSTANCE_TYPE,
    "sagemaker_distributed_dataparallel_custom_mpi_options"="options"
  )
  expect_equal(actual_ddp, expected_ddp)
})

test_that("test_image_name_map", {
  # Passing `image_name` is expected to raise a warning, and the value should
  # then be readable from the framework as `image_uri`.
  session = sagemaker_session()

  expect_warning({
    framework = DummyFramework$new(
      "my_script.py",
      image_name=IMAGE_URI,
      role=ROLE,
      sagemaker_session=session,
      instance_count=INSTANCE_COUNT,
      instance_type=INSTANCE_TYPE
    )
  })

  resolved_image = framework$image_uri
  expect_equal(resolved_image, IMAGE_URI)
})
# DyfanJones/sagemaker-r-mlcore documentation built on May 3, 2022, 10:08 a.m.