R/history.R

Defines functions compose_history_metadata plot.tf_estimator_history print.tf_estimator_history as.data.frame.tf_estimator_history new_tf_estimator_history

Documented in plot.tf_estimator_history

new_tf_estimator_history <- function(training_history = NULL, evaluation_history = NULL) {
  # Without an explicit history, use the one recorded for the TRAIN mode.
  if (is.null(training_history)) {
    training_history <- .globals$history[[mode_keys()$TRAIN]]
  }

  train_losses <- training_history$losses
  train_steps <- training_history$step

  # Evaluation is summarised by the final value of each metric plus the step
  # at which that final value was recorded.
  eval_summary <- NULL
  if (!is.null(evaluation_history)) {
    last_values <- lapply(evaluation_history$losses, utils::tail, 1)
    last_eval_step <- utils::tail(evaluation_history$step, 1)
    eval_summary <- c(last_values, list(step = last_eval_step))
  }

  history <- list(
    params = list(metrics = names(train_losses),
                  steps = utils::tail(train_steps, 1)),
    losses = train_losses,
    step = train_steps,
    evaluation_metrics = eval_summary
  )
  class(history) <- "tf_estimator_history"
  history
}

#' @export
as.data.frame.tf_estimator_history <- function(x, ...) {
  # Long format: columns `step`, `metric` and `value`, one row per recorded
  # (metric, step) pair.
  #
  # `check.names = FALSE` keeps metric names exactly as recorded (e.g.
  # "loss/total" is not rewritten to "loss.total"). This means they still
  # match `x$params$metrics`, which plot.tf_estimator_history() filters on.
  losses_df <- data.frame(x[["losses"]], check.names = FALSE)
  df <- cbind(losses_df, data.frame(x["step"]))

  # With no recorded losses or steps there are no columns to reshape.
  if (length(df) > 0)
    tidyr::gather(df, "metric", "value", -"step")
  else df
}

#' @export
print.tf_estimator_history <- function(x, ...) {

  steps <- x$params$steps
  params <- prettyNum(list(steps = steps), big.mark = ",")

  str <- paste0("Trained for ", params[["steps"]], " steps.")

  # Position of the final step in the recorded history. Take the last match:
  # with repeated step values, a multi-element index passed to `[[` would be
  # an error. Fall back to the last recorded entry if `steps` is not found.
  final_idx <- utils::tail(which(x$step == steps), 1)
  if (length(final_idx) == 0) final_idx <- length(x$step)

  # Metric values at the final step (NA if nothing was recorded there).
  metrics <- lapply(x$losses, function(metric) {
    if (final_idx >= 1 && final_idx <= length(metric)) metric[[final_idx]] else NA
  })

  cat(str, "\n")

  # Print `header` followed by one right-aligned "label: value" line per
  # metric.
  print_metrics <- function(metrics, header) {
    if (length(metrics) == 0) {
      # Nothing to tabulate; avoids max() over zero labels.
      cat(header, "\n")
      return(invisible(NULL))
    }
    labels <- names(metrics)
    max_label_len <- max(nchar(labels))
    labels <- sprintf(paste0("%", max_label_len, "s"), labels)
    metrics <- prettyNum(metrics, big.mark = ",", digits = 4, scientific = FALSE)
    str <- paste0(paste0(header, "\n"),
                  paste0(labels, ": ", metrics, collapse = "\n"),
                  collapse = "\n")
    cat(str, "\n")
  }

  print_metrics(metrics, "Final step (plot to see history):")

  if (!is.null(x$evaluation_metrics))
    print_metrics(
      x$evaluation_metrics[names(x$evaluation_metrics) != "step"],
      paste0("Evaluation metrics (step ", x$evaluation_metrics[["step"]], "):")
    )

  # Print methods return their input invisibly so they compose in pipelines.
  invisible(x)
}

#' Plot training history
#'
#' Plots metrics recorded during training.
#'
#' @param x Training history object returned from `train()`.
#' @param y Unused.
#' @param metrics One or more metrics to plot (e.g. `c('total_losses', 'mean_losses')`).
#'   Defaults to plotting all captured metrics.
#' @param method Method to use for plotting. The default "auto" will use
#'   \pkg{ggplot2} if available, and otherwise will use base graphics.
#' @param smooth Whether a loess smooth should be added to the plot, only
#'   available for the `ggplot2` method. If the number of data points is smaller
#'   than ten, it is forced to `FALSE`.
#' @param theme_bw Use `ggplot2::theme_bw()` to plot the history in
#'   black and white.
#' @param ... Additional parameters to pass to the [plot()] method.
#' @return For the `ggplot2` method, a ggplot object. For the `base` method,
#'   the plot is drawn as a side effect and `NULL` is returned invisibly.
#' @importFrom graphics par plot
#'
#' @export
plot.tf_estimator_history <- function(x, y, metrics = NULL, method = c("auto", "ggplot2", "base"),
                                      smooth = getOption("tf.estimator.plot.history.smooth", TRUE),
                                      theme_bw = getOption("tf.estimator.plot.history.theme_bw", FALSE),
                                      ...) {
  # check which method we should use
  method <- match.arg(method)
  if (method == "auto") {
    method <- if (requireNamespace("ggplot2", quietly = TRUE)) "ggplot2" else "base"
  }

  # convert to long format: one row per (step, metric) pair
  df <- x %>%
    compose_history_metadata(rename_step_col = FALSE) %>%
    tidyr::gather("metric", "value", -"step")

  # if metrics is null we plot all of the metrics
  if (is.null(metrics)) metrics <- x$params$metrics

  # select the requested metrics
  df <- df[df$metric %in% metrics, ]

  if (method == "ggplot2") {
    # helper for axis breaks restricted to integer steps
    int_breaks <- function(limits) {
      breaks <- pretty(limits)
      breaks[breaks %% 1 == 0]
    }

    p <- ggplot2::ggplot(df, ggplot2::aes_(~step, ~value))

    smooth_args <- list(se = FALSE, method = 'loess', na.rm = TRUE)

    if (theme_bw) {
      smooth_args$size <- 0.5
      smooth_args$color <- "gray47"
      p <- p +
        ggplot2::theme_bw() +
        ggplot2::geom_point(col = 1, na.rm = TRUE, size = 2) +
        ggplot2::scale_shape(solid = FALSE)
    } else {
      p <- p +
        ggplot2::geom_point(shape = 21, col = 1, na.rm = TRUE)
    }

    # a loess smooth is only meaningful with enough points
    if (smooth && nrow(df) >= 10)
      p <- p + do.call(ggplot2::geom_smooth, smooth_args)

    p <- p +
      ggplot2::facet_grid(metric~., switch = 'y', scales = 'free_y') +
      ggplot2::scale_x_continuous(breaks = int_breaks) +
      ggplot2::theme(axis.title.y = ggplot2::element_blank(), strip.placement = 'outside',
                     strip.text = ggplot2::element_text(colour = 'black', size = 11),
                     strip.background = ggplot2::element_rect(fill = NA, color = NA))

    return(p)
  }

  if (method == 'base') {
    # Only draw panels for metrics that actually have data. An empty panel
    # would make plot() fail on non-finite axis limits.
    metrics <- metrics[metrics %in% df$metric]
    n_panels <- length(metrics)
    if (n_panels == 0) {
      warning("None of the requested metrics have recorded values to plot.",
              call. = FALSE)
      return(invisible(NULL))
    }

    # one panel per metric, stacked vertically; restored on exit
    op <- par(mfrow = c(n_panels, 1),
              mar = c(3, 3, 2, 2)) # (bottom, left, top, right)
    on.exit(par(op), add = TRUE)

    for (i in seq_along(metrics)) {

      # get metric
      metric <- metrics[[i]]

      # adjust margins; only the bottom panel draws the x axis
      top_plot <- i == 1
      bottom_plot <- i == n_panels
      if (top_plot)
        par(mar = c(1.5, 3, 1.5, 1.5))
      else if (bottom_plot)
        par(mar = c(2.5, 3, .5, 1.5))
      else
        par(mar = c(1.5, 3, .5, 1.5))

      # select data for current panel
      df2 <- df[df$metric == metric, ]

      # plot values (open circles, matching the legend symbol below)
      plot(df2$step, df2$value, pch = 1,
           xaxt = if (bottom_plot) "s" else "n", xlab = "step", ylab = metric, ...)

      # Put the legend in the corner the series leaves empty: a decreasing
      # series (first value above the last) leaves the top-right corner free.
      values <- df2$value[!is.na(df2$value)]
      decreasing <- isTRUE(utils::head(values, 1) > utils::tail(values, 1))
      legend_location <- if (decreasing) "topright" else "bottomright"
      graphics::legend(legend_location, legend = metric, pch = 1)
    }
  }

  invisible(NULL)
}

# Reshape a training history into a wide data frame with one row per step and
# one column per metric, thinned to at most `max_rows` rows.
#
# history: object with an `as.data.frame()` method returning long-format
#   `step` / `metric` / `value` columns (e.g. a `tf_estimator_history`).
# max_rows: maximum number of rows to keep. Longer histories are reduced to
#   `max_rows` evenly spaced rows that always include the first and the last
#   recorded step.
# rename_step_col: if TRUE, the `step` column is renamed to `epoch`.
# Returns a data frame.
compose_history_metadata <- function(history, max_rows = 100, rename_step_col = TRUE) {
  training_history <- as.data.frame(history) %>%
    tidyr::spread("metric", "value")

  nrow_history <- nrow(training_history)
  if (nrow_history > max_rows) {
    # Cap the number of points plotted. `length.out` pins both endpoints, so
    # the most recent step is never dropped. Spacing exceeds one row here, so
    # the rounded indices are distinct.
    sampling_indices <- as.integer(round(seq(1, nrow_history, length.out = max_rows)))
    # drop = FALSE keeps a data frame even when only the step column exists
    training_history <- training_history[sampling_indices, , drop = FALSE]
  }

  if (rename_step_col)
    names(training_history)[names(training_history) == "step"] <- "epoch"
  training_history
}
rstudio/tfestimators documentation built on Nov. 24, 2021, 6:56 a.m.