Nothing
#' @export
print.keras_training_history <- function(x, ...) {
  # Print a short summary of a training history: the training parameters
  # (when they were recorded) followed by the value of every metric at the
  # last epoch trained. Like other print methods, returns `x` invisibly.

  # compute epochs actually trained for (may be fewer than requested)
  epochs <- min(x$params$epochs, length(x$metrics[[1]]))

  # training params, formatted with thousands separators; an absent (NULL)
  # entry is detected below by comparing against the string "NULL"
  params <- x$params
  params <- list(samples = params$samples,
                 validation_samples = params$validation_samples,
                 batch_size = params$batch_size,
                 epochs = epochs)
  params <- prettyNum(params, big.mark = ",")

  if (!identical(params[["validation_samples"]], "NULL")) {
    validate <- paste0(", validated on ", params[["validation_samples"]], " samples")
  } else {
    validate <- ""
  }

  header <- ""
  if (params[["samples"]] != "NULL") {
    header <- paste0(header, "Trained on ", params[["samples"]], " samples", validate,
                     " (batch_size=", params[["batch_size"]],
                     ", epochs=", params[["epochs"]], ")")
  }

  # last epoch metrics, right-aligned on their names
  metrics <- lapply(x$metrics, function(metric) {
    metric[[epochs]]
  })
  labels <- names(metrics)
  max_label_len <- max(nchar(labels))
  labels <- sprintf(paste0("%", max_label_len, "s"), labels)
  metrics <- prettyNum(metrics, big.mark = ",", digits = 4, scientific = FALSE)

  out <- paste0(header, "\n",
                "Final epoch (plot to see history):\n",
                paste0(labels, ": ", metrics, collapse = "\n"))
  cat(out, "\n")

  invisible(x)
}
#' Plot training history
#'
#' Plots metrics recorded during training.
#'
#' @param x Training history object returned from
#' [`fit.keras.src.models.model.Model()`].
#' @param y Unused.
#' @param metrics One or more metrics to plot (e.g. `c('loss', 'accuracy')`).
#' Defaults to plotting all captured metrics. Give names without the `val_`
#' prefix; the matching validation series is plotted alongside.
#' @param method Method to use for plotting. The default "auto" will use
#' \pkg{ggplot2} if available, and otherwise will use base graphics.
#' @param smooth Whether a loess smooth should be added to the plot, only
#' available for the `ggplot2` method. If fewer than ten epochs were actually
#' trained (e.g. because training stopped early), it is forced to `FALSE`.
#' @param theme_bw Use `ggplot2::theme_bw()` to plot the history in
#' black and white.
#' @param ... Additional parameters to pass to the [plot()] method.
#'
#' @importFrom rlang .data
#'
#' @returns if `method == "ggplot2"`, the ggplot object is returned. If
#' `method == "base"`, then this function will draw to the graphics device and
#' return `NULL`, invisibly.
#'
#' @export
plot.keras_training_history <- function(x, y, metrics = NULL, method = c("auto", "ggplot2", "base"),
                                        smooth = getOption("keras.plot.history.smooth", TRUE),
                                        theme_bw = getOption("keras.plot.history.theme_bw", FALSE),
                                        ...) {
  # check which method we should use
  method <- match.arg(method)
  if (method == "auto") {
    method <- if (requireNamespace("ggplot2", quietly = TRUE)) "ggplot2" else "base"
  }

  # number of epochs that actually ran: `params$epochs` is only the requested
  # count, and the recorded series may be shorter (they are padded with NA
  # when converted to a data frame)
  epochs_trained <- min(x$params$epochs, max(0L, lengths(x$metrics)))

  # convert to a long data frame
  df <- as.data.frame(x)

  # if metrics is null we plot all of the (non-validation) metrics
  if (is.null(metrics))
    metrics <- Filter(function(name) !grepl("^val_", name), names(x$metrics))

  # select the requested metrics
  df <- df[df$metric %in% metrics, ]

  # for TF < 2.2 the validation flag is read from params; otherwise it is
  # inferred from the presence of `val_*` metrics
  if (tensorflow::tf_version() < "2.2")
    do_validation <- x$params$do_validation
  else
    do_validation <- any(grepl("^val_", names(x$metrics)))

  if (method == "ggplot2") {
    # helper function for correct breaks (integers only)
    int_breaks <- function(x) {
      breaks <- pretty(x)
      breaks[breaks %% 1 == 0]
    }

    if (do_validation) {
      if (theme_bw)
        p <- ggplot2::ggplot(df, ggplot2::aes(.data$epoch, .data$value, color = .data$data, fill = .data$data, linetype = .data$data, shape = .data$data))
      else
        p <- ggplot2::ggplot(df, ggplot2::aes(.data$epoch, .data$value, color = .data$data, fill = .data$data))
    } else {
      p <- ggplot2::ggplot(df, ggplot2::aes(.data$epoch, .data$value))
    }

    smooth_args <- list(se = FALSE, method = 'loess', na.rm = TRUE,
                        formula = y ~ x)
    # keep the formula from capturing this function's (potentially large)
    # environment inside the returned plot object
    environment(smooth_args$formula) <- baseenv()

    if (theme_bw) {
      smooth_args$size <- 0.5
      smooth_args$color <- "gray47"
      p <- p +
        ggplot2::theme_bw() +
        ggplot2::geom_point(col = 1, na.rm = TRUE, size = 2) +
        ggplot2::scale_shape(solid = FALSE)
    } else {
      p <- p +
        ggplot2::geom_point(shape = 21, col = 1, na.rm = TRUE)
    }

    # loess needs a reasonable number of points; gate on epochs actually run
    if (smooth && epochs_trained >= 10)
      p <- p + do.call(ggplot2::geom_smooth, smooth_args)

    p <- p +
      ggplot2::facet_grid(metric~., switch = 'y', scales = 'free_y') +
      ggplot2::scale_x_continuous(breaks = int_breaks) +
      ggplot2::theme(axis.title.y = ggplot2::element_blank(), strip.placement = 'outside',
                     strip.text = ggplot2::element_text(colour = 'black', size = 11),
                     strip.background = ggplot2::element_rect(fill = NA, color = NA))

    return(p)
  }

  if (method == 'base') {
    # one panel per metric, stacked vertically; restore par on exit
    op <- par(mfrow = c(length(metrics), 1),
              mar = c(3, 3, 2, 2)) # (bottom, left, top, right)
    on.exit(par(op), add = TRUE)

    for (i in seq_along(metrics)) {
      # get metric
      metric <- metrics[[i]]

      # adjust margins: extra space above the first panel and below the last
      # one (the only panel that draws an x axis)
      top_plot <- i == 1
      bottom_plot <- i == length(metrics)
      mar <- c(1.5, 5, 0.5, 1.5)
      if (top_plot)
        mar[3] <- mar[3] + 3.5
      if (bottom_plot)
        mar[1] <- mar[1] + 3.5
      par(mar = mar)

      # select data for current panel
      df2 <- df[df$metric == metric, ]

      # plot values: pch 1 for training, pch 4 for validation
      plot(df2$epoch, df2$value, pch = c(1, 4)[df2$data],
           xaxt = if (bottom_plot) 's' else 'n', xlab = "epoch", ylab = metric, ...)

      # legend goes top right when the training value at the first epoch
      # exceeds the value at the last *trained* epoch, bottom right otherwise.
      # Indexing by the trained count (not params$epochs) avoids the NA
      # padding after early stopping; isTRUE() keeps legend() from ever
      # receiving an NA position.
      train_values <- df2$value[df2$data == 'training']
      decreasing <- isTRUE(train_values[1] > train_values[epochs_trained])
      legend_location <- if (decreasing) "topright" else "bottomright"

      if (do_validation)
        graphics::legend(legend_location, legend = c(metric, paste0("val_", metric)), pch = c(1, 4))
      else
        graphics::legend(legend_location, legend = metric, pch = 1)
    }

    invisible(NULL)
  }
}
#' @export
as.data.frame.keras_training_history <- function(x, ...) {
  # Reshape the history into a long data frame with one row per
  # (metric, epoch) and columns:
  #   epoch  - integer epoch index
  #   value  - recorded value, NA where nothing was recorded
  #   metric - metric name with any `val_` prefix removed (factor, levels in
  #            first-seen order)
  #   data   - factor "training" / "validation", from the `val_` prefix
  metric_names <- names(x$metrics)

  # flatten each series (it may be a list of scalars) to an atomic vector
  series <- lapply(x$metrics, unlist, use.names = FALSE)

  # Pad every series with NA to a common length: the requested number of
  # epochs, or the longest recorded series if that is longer. Each series is
  # padded on its own (rather than by the length of `loss`), so a series of
  # a different length no longer produces a misaligned or invalid `value`
  # column; a shorter series' values occupy the first epochs.
  n_epochs <- max(c(x$params$epochs, lengths(series), 0L))
  padded <- lapply(series, function(s) c(s, rep(NA, n_epochs - length(s))))
  value <- unlist(padded, use.names = FALSE)
  if (is.null(value))
    value <- numeric() # empty history -> zero-row data frame

  # prepare data to plot as a data.frame
  df <- data.frame(
    epoch = rep(seq_len(n_epochs), times = length(series)),
    value = value,
    metric = rep(sub("^val_", "", metric_names), each = n_epochs),
    data = rep(grepl("^val_", metric_names), each = n_epochs)
  )
  rownames(df) <- NULL

  # order factor levels appropriately
  df$data <- factor(df$data, c(FALSE, TRUE), c('training', 'validation'))
  df$metric <- factor(df$metric, unique(sub("^val_", "", metric_names)))

  df
}
to_keras_training_history <- function(history) {
  # Wrap a fit history (presumably the Python-side History object) in an R
  # `keras_training_history` object, so it can be persisted and so the
  # print/plot methods can dispatch on it. `history$params` supplies the
  # training parameters and `history$history` the per-metric series; both are
  # passed through unchanged.
  keras_training_history(
    params = history$params,
    metrics = history$history
  )
}
keras_training_history <- function(params, metrics) {
  # Construct a `keras_training_history`: a plain list holding the training
  # `params` and the per-metric `metrics` series, tagged with the S3 class
  # used by the print/plot/as.data.frame methods. Series are stored as given
  # (no NA padding is applied here).
  history <- list(params = params, metrics = metrics)
  class(history) <- "keras_training_history"
  history
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.