R/analyze_roc.R

Defines functions analyze_roc

Documented in analyze_roc

#' @title Analyze ROC Curves
#'
#' @description This function analyzes ROC curves from the results of the \code{\link{assess_models}} function
#'
#' @param ... Output(s) of the \code{\link{language_model}}, \code{\link{comparison_model}}, or \code{\link{test_language_model}} functions
#' @param plot If TRUE, plots a matrix displaying the results of all model comparisons. Defaults to TRUE.
#' @param plot_diagonal If TRUE, each comparison is also shown mirrored across the diagonal of the matrix plot (with model1 and model2 swapped); the mirrored rows are also included in the returned dataframe. Defaults to FALSE.
#'
#' @return A dataframe with the results of statistical tests conducted on the ROCs for each model pairing
#'
#' @seealso \code{\link{language_model}}, \code{\link{comparison_model}}, \code{\link{test_language_model}}
#'
#' @import ggplot2
#' @importFrom rlang .data
#'
#' @export
#'
#' @examples
#'
#' \dontrun{
#' strong_movie_review_data$cleanText = clean_text(strong_movie_review_data$text)
#' mild_movie_review_data$cleanText = clean_text(mild_movie_review_data$text)
#'
#' # Using language to predict "Positive" vs. "Negative" reviews
#' # Only for strong reviews (ratings of 1 or 10)
#' movie_model_strong = language_model(strong_movie_review_data,
#'                                      outcome = "valence",
#'                                      outcomeType = "binary",
#'                                      text = "cleanText",
#'                                      progressBar = FALSE)
#'
#' # Using language to predict "Positive" vs. "Negative" reviews
#' # Only for mild reviews (ratings of 4 or 7)
#' movie_model_mild = language_model(mild_movie_review_data,
#'                                      outcome = "valence",
#'                                      outcomeType = "binary",
#'                                      text = "cleanText",
#'                                      progressBar = FALSE)
#'
#' # Analyze ROC curves
#' auc_tests = analyze_roc(movie_model_strong, movie_model_mild)
#' }

analyze_roc = function(..., plot=TRUE, plot_diagonal=FALSE) {
  # Runs `roc.test` on every pair of supplied models, optionally draws a
  # matrix plot of the comparisons, and returns one row per model pairing
  # (plus mirrored rows when `plot_diagonal` is TRUE).
  #
  # Errors if no model is supplied, if the same argument is supplied twice,
  # if an argument is not a langModel/compModel/testAssessment object, if a
  # model's `type` slot is not "binary", or if the user declines the prompt
  # shown for a testAssessment whose trained model was not also supplied.

  # Column names used unquoted inside aes(); bound here so that R CMD check
  # does not report them as undefined global variables.
  model1=model2=model1_auc=model2_auc=size=width=height=sig=font=NULL

  dots = list(...)
  if (length(dots) == 0) {
    stop("At least one model must be provided.")
  }

  # Label each model with the expression the caller wrote for it. deparse()
  # always yields a single string per argument, even for calls such as
  # `models[[1]]` (as.character() would split those into several pieces).
  dots_expr = as.list(match.call(expand.dots = FALSE)$...)
  model_names = vapply(
    dots_expr,
    function(expr) paste(deparse(expr, width.cutoff = 500L), collapse = ""),
    character(1),
    USE.NAMES = FALSE
  )

  # The names become factor levels further down, so they must be unique.
  if (anyDuplicated(model_names) > 0) {
    stop(paste0("Each model can only be included once (`", model_names[anyDuplicated(model_names)], "` appears more than once)."))
  }

  # ---- Validate every input before doing any work ----
  valid_classes = c("langModel", "compModel", "testAssessment")
  for (i in seq_along(dots)) {
    input = dots[[i]]
    # inherits() gives a single TRUE/FALSE even for objects with several
    # classes (e.g. a matrix or tibble), so the message below is reached.
    if (!inherits(input, valid_classes)) {
      stop(paste0("Your argument '", model_names[i], "' must be a model generated by the `language_model`, `comparison_model`, or `test_language_model` functions."))
    }
    if (input@type != "binary") {
      stop(paste0("ROCs can only be analyzed for models with a binary outcome variable (`", model_names[i], "` does not have a binary outcome)."))
    }
    if (inherits(input, "testAssessment") && !(input@trainedModel %in% model_names)) {
      result = utils::askYesNo(paste0("`", model_names[i], "` is the outcome of testing a model on new data, but the original model (`", input@trainedModel, "`) has not been included. Are you sure you want to continue without including it?"))
      # NA (cancel) and FALSE (no) both abort.
      if (!isTRUE(result)) {
        stop("Function aborted.")
      }
    }
  }

  # ---- Collect the ROC object of each model, keyed by its label ----
  roc_list = lapply(dots, function(model) model@roc)
  names(roc_list) = model_names
  n_models = length(roc_list)

  # ---- Test every unordered pair (i < j) of models ----
  pair_rows = vector("list", n_models * (n_models - 1) / 2)
  k = 0
  for (i in seq_len(n_models - 1)) {
    for (j in seq(i + 1, n_models)) {
      test_output = roc.test(roc_list[[i]], roc_list[[j]])
      k = k + 1
      pair_rows[[k]] = data.frame(
        model1 = model_names[i],
        model2 = model_names[j],
        model1_auc = roc_list[[i]]$auc,
        model2_auc = roc_list[[j]]$auc,
        statistic = names(test_output$statistic)[1],
        statistic_value = test_output$statistic,
        p_value = test_output$p.value
      )
    }
  }

  if (length(pair_rows) > 0) {
    auc_tests = do.call(rbind, pair_rows)
  } else {
    # A single model yields no pairings: return an empty, correctly named frame.
    auc_tests = data.frame(matrix(ncol=7, nrow=0))
    colnames(auc_tests) = c("model1", "model2", "model1_auc", "model2_auc", "statistic", "statistic_value", "p_value")
  }

  # ---- Plot-only helper columns (dropped again before returning) ----
  significant = auc_tests$p_value < .05
  auc_tests$font = ifelse(significant, "bold", "plain")
  auc_tests$size = as.factor(ifelse(significant, 1, 0))
  auc_tests$width = ifelse(significant, .95, 1)
  auc_tests$height = ifelse(significant, .95, 1)

  # Significance label: NS (p >= .05), * (p >= .01), ** (p >= .001), *** otherwise.
  auc_tests$sig = ifelse(auc_tests$p_value >= .05, "NS",
                         ifelse(auc_tests$p_value >= .01, "*",
                                ifelse(auc_tests$p_value >= .001, "**", "***")))

  # Fix the axis order: model1 in input order, model2 reversed.
  auc_tests$model1 = factor(auc_tests$model1, levels = model_names)
  auc_tests$model2 = factor(auc_tests$model2, levels = rev(model_names))

  if (plot_diagonal) {
    # Mirror each pairing by swapping the model1/model2 columns by name;
    # rbind() on data frames then lines the columns up by name.
    # NOTE(review): statistic_value is copied unchanged into the mirrored
    # rows -- confirm whether its sign should be flipped for them.
    mirrored = auc_tests
    swap_from = c("model1", "model2", "model1_auc", "model2_auc")
    swap_to = c("model2", "model1", "model2_auc", "model1_auc")
    names(mirrored)[match(swap_from, names(mirrored))] = swap_to
    auc_tests = rbind(auc_tests, mirrored)
  }

  if (plot && n_models > 1) {
    r = suppressWarnings(ggplot(auc_tests) +
                           geom_tile(aes(x=model1, y=model2, fill=(model2_auc-model1_auc), size=size, width=width, height=height), color="black") +
                           geom_text(aes(x=model1, y=model2, label=sig, fontface=font)) +
                           scale_fill_gradient2(low="#ff4c4c", mid="#ffffff", high="#4c4cff", midpoint=0) +
                           scale_x_discrete(position = "top", drop=FALSE) +
                           scale_y_discrete(drop=FALSE) +
                           scale_size_discrete(range=c(1,2)) +
                           labs(fill="Difference\nbetween\nAUCs\n(Model2 - Model1)", title = "Testing differences between ROC curves\n(significance labeled)") +
                           # "none" is the supported way to hide a legend (FALSE is deprecated).
                           guides(size="none") +
                           coord_fixed() +
                           theme(panel.grid = element_blank(),
                                 panel.background = element_blank(),
                                 axis.text.x = element_text(angle=90)))
    print(r)
  }

  # Strip the plot-only helper columns from the returned data frame.
  helper_cols = c("font", "size", "width", "height")
  auc_tests = auc_tests[, !(names(auc_tests) %in% helper_cols), drop = FALSE]

  return(auc_tests)
}
nlanderson9/languagePredictR documentation built on June 10, 2021, 11 a.m.