# R/runs.R
#
# Defines functions: end_run get_run_context.default get_run_context start_run record_logged_model log_artifact.ggplot log_artifact.default log_artifact get_experiment_from_run set_terminated list_artifacts get_artifact_path load_artifact search_runs get_metric_history log_params param_value_to_rest delete_tag set_tag validate_batch_input log_batch get_run restore_run delete_run create_run log_metrics get_param get_metric increment_metric_step get_most_recent_step exists_metric get_key_value_df
#
# Documented in: create_run delete_run delete_tag end_run get_artifact_path get_metric get_metric_history get_param get_run list_artifacts load_artifact log_artifact log_artifact.default log_artifact.ggplot log_batch log_metrics log_params record_logged_model restore_run search_runs set_tag start_run

#' @include globals.R
NULL

# Build a two-column (`key`, `value`) tibble from the arguments in `...`.
#
# Keys come from the names of the flattened values. Any value whose name is
# the empty string falls back to the text of an argument expression taken from
# the call `.which` frames away (by default the call of the function that
# invoked this helper), made syntactic with `make.names()`.
#
# NOTE(review): the fallback pairs the i-th flattened value with the i-th
# argument expression of that call, so it presumably expects one scalar per
# argument and the `...` arguments to come first in the caller's call --
# confirm against callers.
#' @importFrom tibble tibble
#' @importFrom rlang names2
get_key_value_df <- function(..., .which = -1) {
  # Flatten everything that was supplied into a single vector.
  values <- list(...) %>% unlist()
  # Names of the flattened values, as reported by `names2()`.
  keys <- names2(values)
  # The call `.which` frames away as a list: element 1 is the function,
  # the remaining elements are its argument expressions.
  args <- as.list(sys.call(which = .which))
  # Text of each argument expression, converted to a syntactically valid name.
  backup_keys <- args[2:length(args)] %>%
    as.vector() %>%
    as.character() %>%
    make.names()
  values <- unname(values)
  # Replace every empty key with the fallback name at the same position.
  for(i in seq_along(values)) {
    keys[i] <- ifelse(keys[i] == "", backup_keys[i], keys[i])
  }
  tibble(
    key = keys,
    value = values,
  )
}

# Report whether a metric history can be fetched for `metric_key` in `run_id`.
# Any error raised by `get_metric_history()` is turned into `FALSE`.
exists_metric <- function(metric_key, run_id, client) {
  history_found <- tryCatch(
    {
      get_metric_history(
        metric_key = metric_key,
        run_id = run_id,
        client = client
      )
      TRUE
    },
    error = function(e) FALSE
  )

  history_found
}

# Largest value in the `step` column of a metric history, ignoring NAs.
get_most_recent_step <- function(metric_history) {
  steps <- metric_history$step
  max(steps, na.rm = TRUE)
}

# Next step number for a metric: the largest step in its history plus one,
# returned as an integer.
increment_metric_step <- function(metric_key, run_id, client) {
  history <- get_metric_history(
    metric_key = metric_key,
    run_id = run_id,
    client = client
  )

  latest_step <- get_most_recent_step(history)

  as.integer(latest_step + 1)
}

#' Get the most recent value of a metric for a run
#'
#' Fetches the run with `get_run()` and looks the metric up by key in the
#' first element of the run's `metrics` field.
#'
#' @param metric_key The metric key to get
#' @param run_id The run ID to get the metric for. Defaults to the active run
#' @param client An MLFlow client. Autogenerated if missing.
#'
#' @return The value of the metric, or `NULL` (with a warning) when the run's
#'   metrics entry is `NULL`.
#' @export
get_metric <- function(metric_key, run_id = get_active_run_id(), client = mlflow_client()) {

  check_required(metric_key)
  assert_string(metric_key)
  assert_string(run_id)
  assert_mlflow_client(client)

  run <- get_run(run_id = run_id, client = client)
  run_metrics <- run$metrics[[1]]

  # Nothing recorded on the run: warn and give back NULL.
  if (is.null(run_metrics)) {
    warn(sprintf("No metric key %s found.", metric_key))
    return(NULL)
  }

  is_requested <- run_metrics$key == metric_key
  run_metrics[is_requested, ]$value
}

#' Get a parameter for a run
#'
#' Fetches the run with `get_run()` and extracts the element named `param`
#' from the first element of the run's `params` field.
#'
#' @param param The parameter to get
#' @param run_id The run ID to get the parameter for. Defaults to the active run
#' @param client An MLFlow client. Autogenerated if missing.
#'
#' @return The value of the parameter, or `NULL` (with a warning) when the
#'   run's params entry is `NULL`.
#' @export
get_param <- function(param, run_id = get_active_run_id(), client = mlflow_client()) {

  check_required(param)
  assert_string(param)
  assert_string(run_id)
  assert_mlflow_client(client)

  params <- get_run(
    run_id = run_id,
    client = client
  )$params[[1]]

  if (is.null(params)) {
    # Format the requested name (`param`). `params` is NULL in this branch and
    # would make `sprintf()` return a zero-length message.
    warn(
      sprintf("No param %s found.", param)
    )
    return(NULL)
  }

  params[[param]]
}

#' Log Metrics
#'
#' Logs metrics for a run. A metric is a key-value pair that records a single float measure.
#'   During a single execution of a run, a particular metric can be logged several times.
#'   The MLflow Backend keeps track of historical metric values along two axes: timestamp and step.
#'
#' Each metric is logged at step `0` when no history can be fetched for its
#' key, and otherwise at one more than the largest step in its history.
#'
#' @param ... variable names from which a data.frame with `key` and `value` columns will be created.
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#' @param run_id A run uuid. Automatically inferred if a run is currently active.
#'
#' @importFrom rlang maybe_missing
#' @importFrom purrr imap_int
#' @importFrom magrittr add
#'
#' @return No return value. Called for side effects.
#'
#' @export
log_metrics <- function(..., run_id = get_active_run_id(), client = mlflow_client()) {

  assert_string(run_id)
  assert_mlflow_client(client)

  # Called directly from this function: the helper inspects its caller's call
  # to name arguments that were passed without a name.
  metrics <- get_key_value_df(...)

  metrics$timestamp <- convert_timestamp_to_ms(get_timestamp())

  keys <- metrics$key

  # First pass: which keys already have a history on this run?
  already_logged <- vapply(
    keys,
    function(key) exists_metric(key, run_id = run_id, client = client),
    logical(1),
    USE.NAMES = FALSE
  )

  # Second pass: new keys stay at step 0, known keys continue their sequence.
  steps <- rep(0L, length(keys))
  for (i in which(already_logged)) {
    steps[i] <- increment_metric_step(
      metric_key = keys[[i]],
      run_id = run_id,
      client = client
    )
  }
  metrics$step <- steps

  log_batch(
    metrics = metrics,
    run_id = run_id,
    client = client
  )
}



#' Create an MLFlow run
#'
#' Posts a run-creation request stamped with the current time, then fetches
#' the newly created run with `get_run()`.
#'
#' @importFrom purrr imap
#'
#' @param tags Additional tags to supply for the run, as a named list.
#' @param experiment_id The ID of the experiment to register the run under.
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#'
#' @return Metadata on the newly-created run.
create_run <- function(tags = list(), experiment_id = get_active_experiment_id(), client = mlflow_client()) {

  assert_list(tags)
  assert_string(experiment_id)
  assert_mlflow_client(client)

  # Reshape the named list into an unnamed list of list(key =, value =) entries.
  tag_entries <- unname(
    imap(
      tags,
      function(tag_value, tag_key) list(key = tag_key, value = tag_value)
    )
  )

  payload <- list(
    experiment_id = experiment_id,
    start_time = convert_timestamp_to_ms(get_timestamp()),
    tags = tag_entries
  )

  response <- call_mlflow_api(
    "runs", "create",
    client = client,
    verb = "POST",
    data = payload
  )

  new_run_id <- response$run$info$run_id
  get_run(run_id = new_run_id, client = client)
}

#' Delete a Run
#'
#' Deletes the run with the specified ID. Aborts when that run is the
#' currently active run.
#'
#' @param run_id A run uuid. Automatically inferred if a run is currently active.
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#'
#' @return No return value. Called for side effects.
#'
#' @export
delete_run <- function(run_id = get_active_run_id(), client = mlflow_client()) {

  assert_string(run_id)
  assert_mlflow_client(client)

  targets_active_run <- exists_active_run() &&
    identical(run_id, get_active_run_id())
  if (targets_active_run) {
    abort("Cannot delete an active run.")
  }

  call_mlflow_api(
    "runs", "delete",
    client = client,
    verb = "POST",
    data = list(run_id = run_id)
  )

  invisible(NULL)
}

#' Restore a Run
#'
#' Restores the run with the specified ID, then fetches it with `get_run()`.
#'
#' @param run_id A run id
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#'
#' @return Metadata on the newly-restored run.
#' @export
restore_run <- function(run_id = get_active_run_id(), client = mlflow_client()) {

  assert_string(run_id)
  assert_mlflow_client(client)

  call_mlflow_api(
    "runs", "restore",
    client = client,
    verb = "POST",
    data = list(run_id = run_id)
  )

  restored <- get_run(run_id, client = client)
  restored
}

#' Get Run
#'
#' Gets metadata, params, tags, and metrics for a run. Returns a single value for each metric
#' key: the most recently logged metric value at the largest step.
#'
#' The `run` element of the API response is handed to `parse_run()`.
#'
#' @param run_id A run uuid. Automatically inferred if a run is currently active.
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#'
#' @return Metadata on the run, including the ID, the `experiment_id`, the `status`, etc.
#' @export
get_run <- function(run_id = get_active_run_id(), client = mlflow_client()) {

  assert_string(run_id)
  assert_mlflow_client(client)

  run_query <- list(run_id = run_id)

  response <- call_mlflow_api(
    "runs", "get",
    client = client,
    verb = "GET",
    query = run_query
  )

  parse_run(response$run)
}

#' Log Batch
#'
#' Log a batch of metrics, params, and/or tags for a run. The server will respond with an error (non-200 status code)
#'   if any data failed to be persisted. In case of error (due to internal server error or an invalid request), partial
#'   data may be written.
#'
#' Non-empty inputs are checked for the expected column names before the
#' request is sent, and each param value is converted with
#' `param_value_to_rest()`.
#'
#' @importFrom checkmate assert_data_frame
#'
#' @param metrics A dataframe of metrics to log, containing the following columns: "key", "value",
#'  "step", "timestamp". This dataframe cannot contain any missing ('NA') entries.
#' @param params A dataframe of params to log, containing the following columns: "key", "value".
#'  This dataframe cannot contain any missing ('NA') entries.
#' @param tags A dataframe of tags to log, containing the following columns: "key", "value".
#'  This dataframe cannot contain any missing ('NA') entries.
#' @param run_id A run uuid. Automatically inferred if a run is currently active.
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#'
#' @return No return value. Called for side effects.
#'
#' @export
log_batch <- function(metrics = data.frame(), params = data.frame(), tags = data.frame(), run_id = get_active_run_id(), client = mlflow_client()) {

  assert_data_frame(metrics)
  assert_data_frame(params)
  assert_data_frame(tags)
  assert_string(run_id)
  assert_mlflow_client(client)

  pair_columns <- c("key", "value")
  validate_batch_input("metrics", metrics, c(pair_columns, "step", "timestamp"))
  validate_batch_input("params", params, pair_columns)
  validate_batch_input("tags", tags, pair_columns)

  # Convert every param value, one at a time, to its REST text form.
  rest_values <- lapply(params$value, param_value_to_rest)
  params$value <- unlist(rest_values)

  payload <- list(
    run_id = run_id,
    metrics = metrics,
    params = params,
    tags = tags
  )

  call_mlflow_api(
    "runs", "log-batch",
    client = client,
    verb = "POST",
    data = payload
  )

  invisible(NULL)
}

# Check that a non-empty batch data frame has exactly the expected column
# names (compared as sets). `NULL` and zero-row inputs are accepted as-is;
# a mismatch aborts with a message listing expected and found columns.
validate_batch_input <- function(input_type, input_dataframe, expected_column_names) {

  # Nothing to validate when there are no rows.
  if (is.null(input_dataframe) || nrow(input_dataframe) == 0) {
    return()
  }

  found_column_names <- names(input_dataframe)

  if (!setequal(found_column_names, expected_column_names)) {
    abort(
      paste0(
        input_type,
        " batch input dataframe must contain exactly the following columns: ",
        paste(expected_column_names, collapse = ", "),
        ". Found: ",
        paste(found_column_names, collapse = ", ")
      )
    )
  }
}

#' Set Tag
#'
#' Sets a tag on a run. Tags are run metadata that can be updated during a run and
#'  after a run completes.
#'
#' @param key Name of the tag. Maximum size is 255 bytes. This field is required.
#' @param value String value of the tag being logged. Maximum size is 500 bytes. This field is required.
#' @param run_id A run uuid. Automatically inferred if a run is currently active.
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#'
#' @return No return value. Called for side effects.
#'
#' @export
set_tag <- function(key, value, run_id = get_active_run_id(), client = mlflow_client()) {

  check_required(key)
  check_required(value)

  assert_string(key)
  assert_string(value)
  # Validate `run_id` like the other run-level functions in this file do, so a
  # bad ID fails here instead of inside the request.
  assert_string(run_id)
  assert_mlflow_client(client)

  data <- list(
    run_id = run_id,
    key = key,
    value = value
  )

  call_mlflow_api(
    "runs", "set-tag",
    client = client,
    verb = "POST",
    data = data
  )

  invisible()
}

#' Delete Tag
#'
#' Deletes a tag on a run. This is irreversible. Tags are run metadata that can be updated during a run and
#'  after a run completes.
#'
#' @param key Name of the tag. Maximum size is 255 bytes. This field is required.
#' @param run_id A run uuid. Automatically inferred if a run is currently active.
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#'
#' @return No return value. Called for side effects.
#'
#' @export
delete_tag <- function(key, run_id = get_active_run_id(), client = mlflow_client()) {

  check_required(key)
  assert_string(key)
  assert_string(run_id)
  assert_mlflow_client(client)

  call_mlflow_api(
    "runs", "delete-tag",
    client = client,
    verb = "POST",
    data = list(run_id = run_id, key = key)
  )

  invisible(NULL)
}

## Translate param values to text that is safe to send over REST:
## NaN -> "NaN", +/-Inf -> "Infinity" / "-Infinity", anything else ->
## as.character(). Built on nested ifelse() to avoid a dplyr dependency.
param_value_to_rest <- function(value) {
  # The helpers are only called from inside ifelse() arguments, so each branch
  # is evaluated lazily, exactly as in a directly nested ifelse() chain.
  render_infinite <- function() {
    ifelse(value < 0, "-Infinity", "Infinity")
  }

  render_not_nan <- function() {
    ifelse(is.infinite(value), render_infinite(), as.character(value))
  }

  ifelse(is.nan(value), "NaN", render_not_nan())
}

#' Log Parameters
#'
#' Logs parameters for a run. Examples are params and hyperparams
#'   used for ML training, or constant dates and values used in an ETL pipeline.
#'   A param is a STRING key-value pair. For a run, a single parameter is allowed
#'   to be logged only once.
#'
#' Values are converted with `param_value_to_rest()` and sent through
#' `log_batch()`.
#'
#' @inheritParams log_metrics
#'
#' @return No return value. Called for side effects.
#'
#' @export
log_params <- function(..., run_id = get_active_run_id(), client = mlflow_client()) {

  assert_string(run_id)
  assert_mlflow_client(client)

  # Called directly from this function: the helper inspects its caller's call
  # to name arguments that were passed without a name.
  params <- get_key_value_df(...)

  rest_values <- param_value_to_rest(params$value)
  params$value <- rest_values

  log_batch(
    params = params,
    run_id = run_id,
    client = client
  )
}

#' Get Metric History
#'
#' Get a list of all values for the specified metric for a given run.
#' Aborts when the server returns no entries for the metric.
#'
#' @param metric_key Name of the metric.
#'
#' @importFrom tibble as_tibble
#' @importFrom purrr list_modify
#'
#' @param run_id A run uuid. Automatically inferred if a run is currently active.
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#'
#' @return A `data.frame` of the history of the metric provided.
#' @export
get_metric_history <- function(metric_key, run_id = get_active_run_id(), client = mlflow_client()) {

  check_required(metric_key)
  assert_string(metric_key)
  assert_string(run_id)
  assert_mlflow_client(client)

  history_query <- list(
    run_id = run_id,
    metric_key = metric_key
  )

  response <- call_mlflow_api(
    "metrics", "get-history",
    client = client,
    verb = "GET",
    query = history_query
  )

  # Guard clause: an empty history is reported as an error.
  if (is_empty(response$metrics)) {
    abort(
      sprintf(
        "Could not find a metric called %s in run %s.",
        metric_key,
        run_id
      )
    )
  }

  # One tibble per logged value, with its timestamp converted by
  # `milliseconds_to_datetime()`.
  as_history_row <- function(entry) {
    as_tibble(
      list_modify(
        entry,
        timestamp = milliseconds_to_datetime(entry$timestamp)
      )
    )
  }

  history_rows <- map(response$metrics, as_history_row)
  reduce(history_rows, bind_rows)
}

#' Search Runs
#'
#' Search for runs that satisfy expressions. Search expressions can use Metric and Param keys.
#'
#' @param experiment_ids List of string experiment IDs (or a single string experiment ID) to search
#' over. Required.
#' @param filter A filter expression over params, metrics, and tags, allowing returning a subset of runs.
#'   The syntax is a subset of SQL which allows only ANDing together binary operations between a param/metric/tag and a constant.
#' @param run_view_type Run view type. One of "ACTIVE_ONLY" (the default), "DELETED_ONLY" or "ALL".
#' @param order_by List of properties to order by. Example: "metrics.acc DESC".
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#'
#' @return A data.frame of runs matching the search criteria.
#'
#' @export
search_runs <- function(experiment_ids, run_view_type = c("ACTIVE_ONLY", "DELETED_ONLY", "ALL"), order_by = list(), filter = "", client = mlflow_client()) {

  check_required(experiment_ids)

  # If we get back a single experiment ID, e.g. the active experiment ID, convert it to a list
  if (is.atomic(experiment_ids)) {
    experiment_ids <- list(experiment_ids)
  }

  assert_list(experiment_ids)
  run_view_type <- match.arg(run_view_type)
  assert_list(order_by)
  assert_string(filter)
  assert_mlflow_client(client)

  # Forward the caller's `filter` and `order_by`. They used to be replaced by
  # hard-coded empty values, so every search was unfiltered and unordered.
  response <- call_mlflow_api(
    "runs", "search",
    client = client,
    verb = "POST",
    data = list(
      experiment_ids = experiment_ids,
      filter = filter,
      run_view_type = run_view_type,
      order_by = order_by
    )
  )

  # NOTE(review): `$run` presumably reaches the response's `runs` element
  # through partial matching -- confirm against `call_mlflow_api()`.
  runs_list <- response$run %>%
    map(parse_run)

  do.call("bind_rows", runs_list) %||% data.frame()
}

#' Load an artifact into an R object
#'
#' Reads `artifact_name` from under the run's artifact path with
#' `s3read_using`, retrying with backoff.
#'
#' @importFrom checkmate assert_function
#' @importFrom aws.s3 s3read_using
#'
#' @param artifact_name The name of the artifact to load
#' @param run_id A run ID to find the URI for
#' @param client An MLFlow client
#' @param FUN a function to use to load the artifact
#' @param \dots Additional arguments to pass on to `s3read_using`
#' @param pause_base,max_times,pause_cap See \link[purrr]{insistently}
#'
#' @return An R object. The result of `s3read_using`
#' @export
load_artifact <- function(artifact_name, FUN = readRDS, run_id = get_active_run_id(), client = mlflow_client(), pause_base = .5, max_times = 5, pause_cap = 60, ...) {

  assert_function(FUN)
  assert_string(artifact_name)
  assert_string(run_id)
  assert_mlflow_client(client)

  artifact_dir <- get_artifact_path(
    run_id = run_id,
    client = client
  )

  retry_rate <- rate_backoff(
    pause_base = pause_base,
    max_times = max_times,
    pause_cap = pause_cap
  )

  # Wrap the reader so failed attempts are retried at `retry_rate`.
  retrying_read <- insistently(
    s3read_using,
    rate = retry_rate,
    quiet = FALSE
  )

  retrying_read(
    FUN = FUN,
    ...,
    object = paste(artifact_dir, artifact_name, sep = "/")
  )
}

#' Get the artifact path for a run
#'
#' Builds `<experiment artifact location>/<run_id>/artifacts` from the
#' experiment that the run belongs to.
#'
#' @param run_id A run id. Automatically inferred if a run is currently active.
#' @param client An MLFlow client. Auto-generated if not provided.
#'
#' @return A path to the run's artifacts in S3
#' @export
get_artifact_path <- function(run_id = get_active_run_id(), client = mlflow_client()) {
  # Look the run up with the supplied `client` so that a non-default client is
  # honoured for both API calls, not only for the experiment lookup.
  run <- get_run(run_id = run_id, client = client)
  experiment_id <- unique(run$experiment_id)

  experiment <- get_experiment(
    experiment_id = experiment_id,
    client = client
  )

  paste(
    experiment$artifact_location,
    run_id,
    "artifacts",
    sep = "/"
  )
}
#' List Artifacts
#'
#' Gets a list of artifacts.
#'
#' @param path The run's relative artifact path to list from. If not specified, it is
#'  set to the root artifact path
#' @param run_id A run id. Automatically inferred if a run is currently active.
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#'
#' @importFrom purrr transpose
#' @importFrom rlang inform
#'
#' @return A `data.frame` of the artifacts at the path provided for the run provided.
#' @export
list_artifacts <- function(path = NULL, run_id = get_active_run_id(), client = mlflow_client()) {

  assert_string(path, null.ok = TRUE)
  assert_string(run_id)
  assert_mlflow_client(client)

  response <- call_mlflow_api(
    "artifacts", "list",
    client = client,
    verb = "GET",
    query = list(
      run_id = run_id,
      path = path
    )
  )

  # A response without `files` is treated as an empty listing.
  files <- response$files
  if (is.null(files)) {
    files <- list()
  }

  # Entries without a `file_size` get an explicit NA.
  fill_missing_size <- function(entry) {
    if (is.null(entry$file_size)) {
      entry$file_size <- NA
    }
    entry
  }
  files <- map(files, fill_missing_size)

  # Turn the list of entries into one column per field.
  columns <- map(transpose(files), unlist)
  as.data.frame(columns)
}

# Post a run update carrying `status` and `end_time`, then fetch the run whose
# ID is reported in the response's `run_info`.
set_terminated <- function(status, end_time, run_id, client) {

  payload <- list(
    run_id = run_id,
    status = status,
    end_time = end_time
  )

  response <- call_mlflow_api(
    "runs", "update",
    verb = "POST",
    client = client,
    data = payload
  )

  updated_run_id <- response$run_info$run_id
  get_run(client = client, run_id = updated_run_id)
}

# Look up the experiment ID recorded on a run.
#
# `client` is optional and defaults to `mlflow_client()`, the same default
# that `get_run()` uses, so existing one-argument calls behave as before while
# callers holding a specific client can now pass it through.
get_experiment_from_run <- function(run_id, client = mlflow_client()) {
  run <- get_run(run_id = run_id, client = client)
  unique(run$experiment_id)
}

#' Log Artifact
#'
#' Saves an R object as an artifact of a run by writing it under the run's
#' artifact path. Modeled after `aws.s3::s3write_using`. This is an S3 generic
#' that dispatches on the class of `x`.
#'
#' @param x The object to log as an artifact
#' @param FUN the function to use to save the artifact. The default method
#'   defaults to `saveRDS`; the `ggplot` method requires the argument but
#'   writes the plot with `ggplot2::ggsave()`.
#' @param filename the name of the file to save
#' @param run_id A run uuid. Automatically inferred if a run is currently active.
#' @param client An MLFlow client. Auto-generated if not provided
#' @param pause_base,max_times,pause_cap See \link[purrr]{insistently}
#' @param ... Additional arguments. The default method passes them to
#'   `aws.s3::s3write_using`; the `ggplot` method passes them to
#'   `ggplot2::ggsave()`.
#'
#' @details
#'
#' When logging to Amazon S3, ensure that you have the s3:PutObject, s3:GetObject,
#' s3:ListBucket, and s3:GetBucketLocation permissions on your bucket.
#'
#' Additionally, at least the \code{AWS_ACCESS_KEY_ID} and \code{AWS_SECRET_ACCESS_KEY}
#' environment variables must be set to the corresponding key and secrets provided
#' by Amazon IAM.
#'
#' @importFrom stringr str_remove str_split str_sub
#' @importFrom aws.s3 s3write_using
#' @importFrom purrr insistently rate_backoff
#'
#' @return The path to the file, invisibly
#' @export
log_artifact <- function(x, FUN, filename, run_id, client = mlflow_client(), pause_base = .5, max_times = 5, pause_cap = 60, ...) {
  UseMethod("log_artifact")
}

#' @rdname log_artifact
#' @export
log_artifact.default <- function(x, FUN = saveRDS, filename, run_id = get_active_run_id(), client = mlflow_client(), pause_base = .5, max_times = 5, pause_cap = 60, ...) {

  check_required(x)
  check_required(filename)

  # Destination: "<run artifact path>/<filename>".
  run_artifact_path <- get_artifact_path(
    run_id = run_id,
    client = client
  )
  target_path <- paste(run_artifact_path, filename, sep = "/")

  retry_rate <- rate_backoff(
    pause_base = pause_base,
    max_times = max_times,
    pause_cap = pause_cap
  )

  # Wrap the writer so failed attempts are retried at `retry_rate`.
  retrying_write <- insistently(
    s3write_using,
    rate = retry_rate,
    quiet = FALSE
  )

  retrying_write(
    x = x,
    FUN = FUN,
    ...,
    object = target_path
  )

  invisible(target_path)
}

#' @importFrom aws.s3 put_object
#' @importFrom tools file_ext
#' @rdname log_artifact
#' @export
log_artifact.ggplot <- function(x, FUN, filename, run_id = get_active_run_id(), client = mlflow_client(), pause_base = .5, max_times = 5, pause_cap = 60, ...) {

  check_required(x)
  # NOTE(review): `FUN` is required here but never called by this method; the
  # plot is always written with `ggplot2::ggsave()`.
  check_required(FUN)
  check_required(filename)

  ## based on https://github.com/hrbrmstr/hrbrthemes/blob/master/R/aaa.r
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    # Build one message string: a second positional argument to `abort()` is
    # taken as the condition class, not appended to the message.
    abort(
      paste0(
        "Package `ggplot2` required for `ggsave`.\n",
        "Please install and try again."
      )
    )
  }

  artifact_dir <- get_artifact_path(
    run_id = run_id,
    client = client
  )

  artifact_filepath <- paste(artifact_dir, filename, sep = "/")

  # `file_ext()` returns the extension without its dot, while `tempfile()`
  # appends `fileext` verbatim. Restore the dot so the temporary file has a
  # real extension for `ggsave()` to pick the graphics device from.
  ext <- file_ext(artifact_filepath)
  temp_file <- tempfile(fileext = if (nzchar(ext)) paste0(".", ext) else "")
  # `add = TRUE` keeps any exit handlers that are already registered.
  on.exit(unlink(temp_file, recursive = TRUE), add = TRUE)

  # Render the plot locally; `...` is forwarded to `ggsave()`.
  ggplot2::ggsave(filename = temp_file, plot = x, ...)

  rate <- rate_backoff(
    pause_base = pause_base,
    max_times = max_times,
    pause_cap = pause_cap
  )

  # Upload the rendered file, retrying failed attempts at `rate`.
  insistently_put <- insistently(
    put_object,
    rate = rate,
    quiet = FALSE
  )

  insistently_put(
    file = temp_file,
    object = artifact_filepath
  )

  invisible(artifact_filepath)
}

#' Record logged model metadata with the tracking server.
#'
#' Serialises `model_spec` to JSON (with `auto_unbox = TRUE`) and posts it
#' together with the run ID.
#'
#' @param model_spec A model specification.
#' @param run_id A run uuid. Automatically inferred if a run is currently active.
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#'
#' @importFrom jsonlite toJSON
record_logged_model <- function(model_spec, run_id = get_active_run_id(), client = mlflow_client()) {
  call_mlflow_api(
    "runs",
    "log-model",
    data = list(
      run_id = run_id,
      model_json = toJSON(model_spec, auto_unbox = TRUE)
    ),
    verb = "POST",
    client = client
  )
}

#' Start Run
#'
#' Starts a new run, or picks up an existing one when `run_id` is non-empty,
#'   and registers it as the active run. The run's experiment becomes the
#'   active experiment.
#'
#' @param run_id If specified, get the run with the specified UUID and log metrics
#'   and params under that run. Defaults to the `MLFLOW_RUN_ID` environment
#'   variable.
#' @param experiment_id Used only when `run_id` is unspecified. ID of the experiment under
#'   which to create the current run.
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#' @param nested Controls whether the run to be started is nested in a parent run. `TRUE` creates a nest run.
#'
#' @return Metadata on the newly-started run.
#'
#' @export
start_run <- function(run_id = Sys.getenv("MLFLOW_RUN_ID"), experiment_id = get_active_experiment_id(), client = mlflow_client(), nested = FALSE) {

  assert_logical(nested)
  assert_string(run_id)
  # `experiment_id` may only be NULL when an existing run is being picked up.
  # Scalar condition, so use the short-circuiting `&&`.
  assert_string(experiment_id, null.ok = (!is.null(run_id) && run_id != ""))
  assert_mlflow_client(client)

  if (exists_active_run() && !nested) {
    abort(
      paste(
        "Run with id",
        get_active_run_id(),
        "is already active. To start a nested run, Call `start_run()` with `nested = TRUE`."
      )
    )
  }

  run <- if (run_id != "") {
    # This is meant to pick up existing run when we're inside `mlflow_source()` called via `mlflow run`.
    get_run(client = client, run_id = run_id)
  } else {
    # Plain `if` for a scalar choice; `ifelse()` is meant for vectors.
    if (is.null(experiment_id)) {
      experiment_id <- infer_experiment_id()
    }

    args <- get_run_context(
      client,
      experiment_id = experiment_id
    )

    do.call(create_run, args)
  }

  push_active_run_id(mlflow_id(run))
  set_active_experiment_id(experiment_id = run$experiment_id)

  run
}

# S3 generic that assembles the argument list used to create a run
# (`start_run()` passes the result to `create_run()` via `do.call()`).
# Dispatches on the class of `client`.
get_run_context <- function(client, ...) {
  UseMethod("get_run_context")
}

# Default run context: the arguments for `create_run()`, with tags for the
# user, source name, source version and source type. When a run is already
# active, its ID is added under the parent-run tag.
get_run_context.default <- function(client, experiment_id, ...) {
  run_tags <- list()
  run_tags[[MLFLOW_TAGS$MLFLOW_USER]] <- mlflow_user()
  run_tags[[MLFLOW_TAGS$MLFLOW_SOURCE_NAME]] <- get_source_name()
  run_tags[[MLFLOW_TAGS$MLFLOW_SOURCE_VERSION]] <- get_source_version()
  run_tags[[MLFLOW_TAGS$MLFLOW_SOURCE_TYPE]] <- MLFLOW_SOURCE_TYPE$LOCAL

  # create a tag containing the parent run ID so that MLflow UI can display
  # nested runs properly
  has_parent_run <- exists_active_run()
  if (has_parent_run) {
    run_tags[[MLFLOW_TAGS$MLFLOW_PARENT_run_id]] <- get_active_run_id()
  }

  list(
    client = client,
    tags = run_tags,
    experiment_id = experiment_id %||% 0,
    ...
  )
}

#' End a Run
#'
#' Terminates a run. Attempts to end the current active run if `run_id` is not specified.
#' When the ended run is the active run, it is also removed from the active runs.
#'
#' @importFrom checkmate assert_number
#'
#' @param status Updated status of the run. Defaults to `FINISHED`. Can also be set to
#' "FAILED" or "KILLED".
#' @param run_id A run uuid. Automatically inferred if a run is currently active.
#' @param client An MLFlow client. Defaults to `mlflow_client()`.
#'
#' @return Metadata on the newly-ended run.
#'
#' @export
end_run <- function(status = c("FINISHED", "FAILED", "KILLED"), run_id = get_active_run_id(), client = mlflow_client()) {

  status <- match.arg(status)
  assert_string(run_id)
  assert_mlflow_client(client)

  ended_run <- set_terminated(
    status = status,
    end_time = convert_timestamp_to_ms(get_timestamp()),
    run_id = run_id,
    client = client
  )

  # Only deactivate the run if it is the one that is currently active.
  ended_active_run <- exists_active_run() &&
    identical(run_id, get_active_run_id())
  if (ended_active_run) {
    pop_active_run_id()
  }

  ended_run
}

# Tag keys under which run context is recorded when a run is created
# (see `get_run_context.default()`).
MLFLOW_TAGS <- list(
  MLFLOW_USER = "mlflow.user",
  MLFLOW_SOURCE_NAME = "mlflow.source.name",
  MLFLOW_SOURCE_VERSION = "mlflow.source.version",
  MLFLOW_SOURCE_TYPE = "mlflow.source.type",
  # Key for the ID of the run that was active when a new run was started.
  MLFLOW_PARENT_run_id = "mlflow.parentRunId"
)
# collegevine/lightMLFlow documentation built on Jan. 16, 2024, 5:52 a.m.