R/plot_precision_recall.R

Defines functions plot_precision_recall

Documented in plot_precision_recall

#' Precision recall plot from a cutpointr object
#'
#' Given a \code{cutpointr} object this function plots the precision recall
#' curve(s), one per subgroup if a \code{subgroup} column is present.
#' Precision is computed as \code{tp / (tp + fp)} and recall as
#' \code{tp / (tp + fn)} from the columns of each element of \code{roc_curve}.
#' Rows with a non-finite \code{x.sorted} value are dropped before plotting.
#' @param x A cutpointr object.
#' @param display_cutpoint (logical) Whether or not to display the optimal
#' cutpoint as a dot on the precision recall curve.
#' @param ... Additional arguments (unused).
#' @return A ggplot object.
#' @examples
#' library(cutpointr)
#'
#' ## Optimal cutpoint for dsi
#' data(suicide)
#' opt_cut <- cutpointr(suicide, dsi, suicide)
#' plot_precision_recall(opt_cut)
#' @family cutpointr plotting functions
#' @export
plot_precision_recall <- function(x, display_cutpoint = TRUE, ...) {

    stopifnot(inherits(x, "cutpointr"))

    # Evaluate once; every subgroup-dependent choice below keys off this flag.
    by_subgroup <- has_column(x, "subgroup")

    if (by_subgroup) {
        dts_pr <- c("roc_curve", "subgroup")
        plot_title <- ggplot2::ggtitle("Precision recall plot", "by class")
    } else {
        dts_pr <- "roc_curve"
        plot_title <- ggplot2::ggtitle("Precision recall plot")
    }

    # Add Precision and Recall to every nested ROC curve. Mapping over the
    # list column (instead of indexing 1:nrow(x)) is also safe for zero rows.
    x$roc_curve <- purrr::map(x$roc_curve, function(rc) {
        dplyr::mutate(rc,
                      Precision = tp / (tp + fp),
                      Recall = tp / (tp + fn))
    })

    if (display_cutpoint) {
        # One row of coordinates per row of x, looked up in the row's own
        # ROC curve at the index reported by get_opt_ind().
        optcut_coords <- purrr::pmap_df(x, function(...) {
            args <- list(...)
            opt_ind <- get_opt_ind(roc_curve = args$roc_curve,
                                   oc = args$optimal_cutpoint,
                                   direction = args$direction)
            data.frame(Precision = args$roc_curve$Precision[opt_ind],
                       Recall = args$roc_curve$Recall[opt_ind])
        })
    }

    # all_of() makes it explicit that dts_pr is an external character vector
    # and not a column of x; unnest() gets its column via the `cols` argument.
    res_unnested <- x %>%
        dplyr::select(dplyr::all_of(dts_pr)) %>%
        tidyr::unnest(cols = "roc_curve")
    res_unnested <- res_unnested[is.finite(res_unnested$x.sorted), ]

    if (by_subgroup) {
        mapping <- ggplot2::aes(x = Recall, y = Precision, color = subgroup)
    } else {
        mapping <- ggplot2::aes(x = Recall, y = Precision)
    }

    pr <- ggplot2::ggplot(res_unnested, mapping) +
        ggplot2::geom_line() +
        plot_title +
        ggplot2::theme(aspect.ratio = 1) +
        ggplot2::coord_cartesian(xlim = c(0, 1), ylim = c(0, 1))

    if (display_cutpoint) {
        # color is set as a fixed parameter, so the inherited color mapping
        # is not looked up in optcut_coords.
        pr <- pr + ggplot2::geom_point(data = optcut_coords, color = "black")
    }

    pr
}
Thie1e/cutpointr documentation built on March 7, 2020, 3:25 a.m.