#' @title Visualize the risk
#'
#' @description
#' This function visualizes the risk during training. If validation data are given, then
#' the train risk is plotted against the validation risk.
#'
#' @return `ggplot` object containing the graphic.
#' @param cboost ([Compboost])\cr
#' A trained [Compboost] object.
#' @examples
#' cboost_no_valdat = boostSplines(data = iris, target = "Sepal.Length",
#' loss = LossQuadratic$new())
#' plotRisk(cboost_no_valdat)
#'
#' cboost_valdat = boostSplines(data = iris, target = "Sepal.Length",
#' loss = LossQuadratic$new(), oob_fraction = 0.3)
#' plotRisk(cboost_valdat)
#' @export
plotRisk = function(cboost) {
if (! requireNamespace("ggplot2", quietly = TRUE)) stop("Please install ggplot2 to create plots.")
checkmate::assertClass(cboost, "Compboost")
if (is.null(cboost$model))
stop("Model has not been trained!")
if (! cboost$model$isTrained())
stop("Model has not been trained!")
inbag_trace = cboost$getInbagRisk()
oob_log = cboost$getLoggerData()
.data = ggplot2::.data
if ("oob_risk" %in% names(oob_log)) {
oob_trace = oob_log[["oob_risk"]]
df_risk = data.frame(
risk = c(inbag_trace, oob_trace),
type = rep(c("inbag", "oob"), times = c(length(inbag_trace), length(oob_trace))),
iter = c(seq_along(inbag_trace), seq_along(oob_trace))
)
gg = ggplot2::ggplot(stats::na.omit(df_risk), ggplot2::aes(x = .data$iter, y = .data$risk, color = .data$type))
} else {
df_risk = data.frame(iter = seq_along(inbag_trace), risk = inbag_trace)
gg = ggplot2::ggplot(stats::na.omit(df_risk), ggplot2::aes(x = .data$iter, y = .data$risk))
}
gg = gg + ggplot2::geom_line(linewidth = 1.1) +
ggplot2::xlab("Iteration") +
ggplot2::ylab("Risk")
return(gg)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.