# R_dev/summarize_roc_curve.R

#' Summarize the ROC curves from a cross validation of a two class classifier
#' built by `caret::train`
#'
#' The saved hold-out predictions belonging to the best tuning parameter
#' combination (`trainObj$bestTune`) are split by resample and one ROC curve
#' per resample is computed with `cutpointr::cutpointr`. Every curve is then
#' evaluated on the union of all false positive rates observed in any resample
#' (carrying the last known true positive rate forward), and the true positive
#' rates are summarized per false positive rate with `fsum`.
#'
#' @param trainObj Object of class `train` with saved predictions from the
#'   cross validation process. `trainObj$pred` must contain the columns `obs`,
#'   `Resample`, the tuning parameter columns of `trainObj$bestTune` and a
#'   column named like `pos` holding the predicted probability of the positive
#'   class.
#' @param pos String with the name of the `positive` class in the two class
#'   classification problem
#' @param fsum Function that is used to summarize the ROC curves from the cross
#'   validation. It is called with the vector of true positive rates of all
#'   resamples at a given false positive rate.
#' @param ... Further arguments for `fsum`
#'
#' @return A `tbl_df` with columns `fpr` and `tpr` that can be used to
#'   visualize the summarized ROC curve. The points (0, 0) and (1, 1) are
#'   always part of the result.
#'
#' @export

rocsy <- function(trainObj, pos, fsum = mean, ...) {

  # Fail early with readable messages instead of cryptic join / tidy-eval
  # errors further down the pipeline.
  if (!inherits(trainObj, "train")) {
    stop("`trainObj` must be an object of class `train`.", call. = FALSE)
  }
  if (!is.character(pos) || length(pos) != 1L || is.na(pos)) {
    stop("`pos` must be a single string naming the positive class.",
         call. = FALSE)
  }
  if (is.null(trainObj[["pred"]])) {
    stop("`trainObj` contains no saved predictions (`trainObj$pred` is NULL).",
         call. = FALSE)
  }
  if (!pos %in% colnames(trainObj[["pred"]])) {
    stop("`trainObj$pred` has no column named '", pos,
         "' holding the predicted probabilities of the positive class.",
         call. = FALSE)
  }

  # Keep only the predictions of the best tuning parameter combination and
  # nest them per resample: one row per resample, predictions in `pred`.
  roc_cv <- trainObj[["bestTune"]] %>%
    dplyr::left_join(
      trainObj[["pred"]],
      by = colnames(trainObj[["bestTune"]])
    ) %>%
    dplyr::mutate(
      resp = as.numeric(obs == pos),
      pred_prob = !!dplyr::sym(pos)
    ) %>%
    dplyr::select(resp, pred_prob, Resample) %>%
    tibble::as_tibble() %>%
    tidyr::nest(pred = c(resp, pred_prob))

  # ROC curve (fpr, tpr) of a single resample.
  f_calcroc <- function(x) {
    x %>%
      cutpointr::cutpointr(
        x = pred_prob, class = resp, silent = TRUE, direction = ">="
      ) %>%
      dplyr::select(roc_curve) %>%
      tidyr::unnest(cols = roc_curve) %>%
      dplyr::select(fpr, tpr)
  }

  roc_cv <- roc_cv %>%
    dplyr::mutate(roc = purrr::map(pred, f_calcroc))

  # Common grid: every false positive rate that occurs in any resample.
  roc_fpr_grid <- roc_cv %>%
    dplyr::select(roc) %>%
    tidyr::unnest(cols = roc) %>%
    dplyr::select(fpr) %>%
    dplyr::distinct() %>%
    dplyr::arrange(fpr)

  # Evaluate one ROC curve on the common grid. Grid points missing in the
  # curve get the last observed tpr; if a fpr occurs several times only its
  # last row is kept.
  f_rocexpand <- function(x) {
    x %>%
      dplyr::full_join(roc_fpr_grid, ., by = "fpr") %>%
      zoo::na.locf() %>%
      dplyr::filter(!duplicated(fpr, fromLast = TRUE))
  }

  roc_cv <- roc_cv %>%
    dplyr::mutate(roc_expand = purrr::map(roc, f_rocexpand))

  # Summarize the tpr over all resamples per grid point and make sure the
  # end points (0, 0) and (1, 1) are part of the curve.
  roc_cv %>%
    dplyr::select(roc_expand) %>%
    tidyr::unnest(cols = roc_expand) %>%
    dplyr::group_by(fpr) %>%
    dplyr::summarise(
      "tpr" = fsum(tpr, ...)
    ) %>%
    dplyr::bind_rows(tibble::tibble("fpr" = c(0, 1), "tpr" = c(0, 1))) %>%
    dplyr::arrange(fpr, tpr) %>%
    dplyr::distinct()

}
# seb09/rocsy documentation built on Nov. 5, 2019, 8:47 a.m.