R/plot_surv_model_performance_rocs.R

Defines functions plot.surv_model_performance_rocs

Documented in plot.surv_model_performance_rocs

#' Plot ROC Curves for Survival Models
#'
#' This function plots objects of class `"surv_model_performance_rocs"` - ROC curves for
#' specific time points for survival models created using the `model_performance(..., type="roc")`.
#'
#' @param x an object of class `"surv_model_performance_rocs"` to be plotted
#' @param ... additional objects of class `"surv_model_performance_rocs"` to be plotted together
#' @param title character, title of the plot
#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#' @param auc boolean, whether the AUC values should be plotted
#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
#' @param facet_ncol number of columns for arranging subplots
#'
#' @return An object of the class `ggplot`.
#'
#' @family functions for plotting 'model_performance_survival' objects
#'
#' @examples
#' \donttest{
#' library(survival)
#' library(survex)
#'
#' model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
#' exp <- explain(model)
#'
#' m_perf_roc <- model_performance(exp, type = "roc", times = c(100, 300))
#' plot(m_perf_roc)
#' }
#' @export
plot.surv_model_performance_rocs <- function(x,
                                             ...,
                                             title = "ROC curves for selected time points",
                                             subtitle = "default",
                                             auc = TRUE,
                                             colors = NULL,
                                             facet_ncol = NULL) {
    dfl <- c(list(x), list(...))

    alldfs <- lapply(dfl, function(x) {
        label <- attr(x, "label")

        cbind(x, label = label)
    })

    df <- do.call(rbind, alldfs)

    if (!is.null(subtitle) && subtitle == "default") {
        labels <- unique(df$label)
        endword <- ifelse(length(labels) > 1, " models", " model")
        subtitle <- paste0("created for the ", paste0(labels, collapse = ", "), endword)
    }

    num_colors <- length(unique(df$label))

    base_plot <- with(df, {
        ggplot(data = df, aes(x = FPR, y = TPR, group = label, color = label)) +
            geom_line(linewidth = 0.8) +
            theme_default_survex() +
            xlab("1 - specificity (FPR)") +
            ylab("sensitivity (TPR)") +
            coord_fixed() +
            theme(
                panel.grid.major.x = element_line(color = "grey90", linewidth = 0.5, linetype = 1),
                panel.grid.minor.x = element_line(color = "grey90", linewidth = 0.5, linetype = 1)
            ) +
            labs(title = title, subtitle = subtitle) +
            scale_color_manual("", values = generate_discrete_color_scale(num_colors, colors)) +
            facet_wrap(~time, ncol = facet_ncol, labeller = function(x) lapply(x, function(x) paste0("t=", x)))
    })

    if (auc) {
        auc_df <- unique(df[, c("label", "time", "AUC")])
        auc_df$AUC <- round(auc_df$AUC, 3)
        auc_df$y <- rep((0:(num_colors - 1)) * 0.1, each = length(unique(auc_df$time)))
        return_plot <- base_plot +
            with(auc_df, {
                geom_text(auc_df,
                    mapping = aes(
                        x = 0.75,
                        y = y,
                        label = paste("AUC =", AUC),
                        color = label
                    ),
                    show.legend = FALSE
                )
            })
    } else {
        return_plot <- base_plot
    }
    return_plot
}

Try the survex package in your browser

Any scripts or data that you put into this service are public.

survex documentation built on Oct. 25, 2023, 1:06 a.m.