# R/caret.R

# Functions to ease the caret interface.

#' Extract data for an ROC figure; eg TPR vs FPR
#'
#' Scores the model with `extract_model_probas()`, pairs those scores with the
#' observed outcomes stored in `model$trainingData$.outcome`, and evaluates
#' the true/false positive rates with `prediction()` / `performance()`
#' (presumably the ROCR package -- confirm against the package imports).
#'
#' @param model A fitted model that carries `trainingData$.outcome` and is
#'   accepted by `extract_model_probas()`.
#'
#' @return data.frame with one row per point of the ROC curve (i.e. per
#'       classification threshold evaluated), and 3 variables:
#'       # x: num [0, 1] corresponding to the False Positive Rate
#'       # y: num [0, 1] corresponding to the True Positive Rate
#'       # alpha: num corresponding to the classification threshold
#'       #        (NOTE(review): ROCR usually reports the first threshold as
#'       #        Inf rather than 1 -- confirm before plotting)
#' @export
extract_model_roc_data <- function(model) {
    observed <- model$trainingData$.outcome
    # NOTE(review): assumes the scores come back in the same row order as
    # `trainingData` -- verify against extract_model_probas().
    scores <- extract_model_probas(model)

    pred <- prediction(scores, observed)
    perf <- performance(pred, measure = "tpr", x.measure = "fpr")

    # The data.frame is the final expression (not the right-hand side of an
    # assignment) so the result is returned visibly and prints at the console.
    data.frame(
        x = unlist(perf@x.values),
        y = unlist(perf@y.values),
        alpha = unlist(perf@alpha.values)
    )
}

#' Extract the predicted probability of the `true` class for one model
#'
#' Reconciles differences between glm, ranger, and glmnet model schemas.
#'
#' Calls `predict(..., type = "prob")` on the model, keeps the `true` column,
#' and returns it as a vector named by case identifier. When the prediction
#' has no row names, the row names of `model$trainingData` are used instead.
#'
#' @param model A caret `train` object (any object inheriting from class
#'   `"train"`, including models fit through the formula interface).
#'
#' @return Named numeric of pYs, named by `case_seq`
#' @export
#'
#' @examples
#' map(comparison@models, compose(enframe, extract_model_probas)) %>%
#'     reduce(.f=full_join, by="name")
extract_model_probas <- function(model) {
    # `class(model) == "train"` is a length-2 logical for formula-interface
    # fits (class c("train", "train.formula")), so test inheritance instead.
    if (!inherits(model, "train")) {
        stop("`model` must be a caret `train` object", call. = FALSE)
    }

    # Keep a one-column data.frame so row names survive the column selection.
    pY <- predict(model, model$finalModel$data, type = "prob")["true"]

    # glm and glmnet use rownames (case_id) by default, but ranger does not...
    if (!has_rownames(pY)) {  # add the case rownames from the trainingdata back
        rownames(pY) <- rownames(model$trainingData)
    }

    # Strip a leading "X" and turn "." back into "-".
    # NOTE(review): presumably undoes make.names()-style mangling of the
    # original case ids -- confirm against how the training data is built.
    case_ids <- str_replace(rownames(pY), "^X", "")
    case_ids <- str_replace_all(case_ids, "\\.", "-")
    rownames(pY) <- case_ids

    # rowname column + value column -> named vector
    deframe(rownames_to_column(pY))
}
# mbadge/AnalysisToolkitR documentation built on May 27, 2019, 1:08 p.m.