R/rm_main_results.R

Defines functions get_augmented_prediction_results get_parameters.rm_main_results plot.rm_main_results aggregate_resample_run_results.rm_main_results aggregate_CV_split_results.rm_main_results new_rm_main_results rm_main_results

Documented in plot.rm_main_results rm_main_results

#' A result metric (RM) that calculates main decoding accuracy measures
#'
#' This result metric calculate the zero-one loss, the normalized rank, and the
#' mean of the decision values. This is also an S3 object which has an
#' associated plot function to display the results.
#'
#' @details
#' Like all result metrics, this result metric has functions to aggregregate
#' results after completing each set of cross-validation classifications, and
#' also after completing all the resample runs. The results should then be
#' available in the DECODING_RESULTS object returned by the cross-validator.
#'
#' @param aggregate_decision_values A string or boolean specifying how the
#'   decision values should be aggregated. If this is a boolean set to TRUE or
#'   to the string "full", then the decision values for the correct category
#'   will be calculated. If this is a boolean set to FALSE or to the string
#'   "none", then the decision values will not be calculated. If this is a
#'   string set to either "diag" or "only same train test time" then the
#'   decision values will only be calculated when for results when training and
#'   testing at the same time. Not returning the full results can speed up the
#'   runtime of the code and will use less memory so this can be useful for
#'   large data sets.
#'
#' @param aggregate_normalized_rank A string or boolean specifying how the
#'  normalized rank results should be aggregated. If this is a boolean
#'  set to TRUE or to the string "full", then the decision values for the correct
#'  category will be calculated. If this is a boolean set to FALSE or to the
#'  string "none", then the decision values will not be calculated. If this is a
#'  string set to either "diag" or "only same train test time" then the decision
#'  values will only be calculated when for results when training and testing
#'  at the same time. Not returning the full results can greatly speed up the runtime
#'  of the code and will use less memory so this can be useful for large data sets.
#'
#'
#' @examples
#' # If you only want to use the rm_main_results(), then you can put it in a
#' # list by itself and pass it to the cross-validator.
#' the_rms <- list(rm_main_results())
#' 
#' 
#' @family result_metrics
#'
#'
#'
#' @export
rm_main_results <- function(aggregate_decision_values = TRUE, aggregate_normalized_rank = TRUE) {
  
  options <- list(
    aggregate_decision_values = aggregate_decision_values,
    aggregate_normalized_rank = aggregate_normalized_rank)

  rm_obj <- new_rm_main_results(options = options)

  rm_obj
  
}




# the internal private constructor
new_rm_main_results <- function(the_data = data.frame(), state = "initial", options = NULL) {
  
  rm_obj <- the_data

  attr(rm_obj, "class") <- c("rm_main_results", "data.frame")
  attr(rm_obj, "state") <- state
  attr(rm_obj, "options") <- options

  rm_obj
  
}






# The aggregate_CV_split_results method needed to fulfill the results metric interface.
aggregate_CV_split_results.rm_main_results <- function(rm_obj, prediction_results) {


  # return a warning if the state is not intial
  if (attr(rm_obj, "state") != "initial") {
    
    warning(paste0(
      "The method aggregate_CV_split_results() should only be called on",
      "rm_main_results that are in the intial state.",
      "Any data that was already stored in this object will be overwritten"))
    
  }


  # get the options for how the normalized rank and decision values should be aggregated
  aggregate_decision_values <- attr(rm_obj, "options")$aggregate_decision_values
  aggregate_normalized_rank <- attr(rm_obj, "options")$aggregate_normalized_rank


  # get the zero-one loss loss results
  the_results <- prediction_results %>%
    dplyr::mutate(correct = .data$actual_labels == .data$predicted_labels) %>%
    dplyr::group_by(.data$CV, .data$train_time, .data$test_time) %>%
    summarize(zero_one_loss = mean(.data$correct, na.rm = TRUE))


  # get the normalized rank decision values
  if (aggregate_normalized_rank != FALSE && aggregate_normalized_rank != "none") {

    # data slightly augmented version of that has actual_labels with decision_vals. appended
    prediction_results_aug <- get_augmented_prediction_results(prediction_results, aggregate_normalized_rank)
    decision_vals_aug <- select(prediction_results_aug, starts_with("decision"))
    decision_vals_rest <- select(prediction_results_aug, -starts_with("decision"))

    # get the normalized rank decision values
    get_rank_one_row <- function(decision_vals_aug_row) {
      actual_label <- decision_vals_aug_row[1]
      decision_vals_row <- decision_vals_aug_row[2:length(decision_vals_aug_row)]
      the_names <- names(decision_vals_row)
      the_order <- order(as.numeric(decision_vals_row), decreasing = TRUE)
      which(the_names[the_order] == actual_label)
    }


    num_classes <- prediction_results %>%
      select(starts_with("decision_vals")) %>%
      ncol()

    normalized_rank_results <- 1 - ((apply(decision_vals_aug, 1, get_rank_one_row) - 1) / (num_classes - 1))

    summarized_normalized_rank_results <- decision_vals_rest %>%
      mutate(normalized_rank = normalized_rank_results) %>%
      dplyr::group_by(.data$CV, .data$train_time, .data$test_time) %>%
      summarize(normalized_rank = mean(.data$normalized_rank, na.rm = TRUE))

    the_results <- left_join(the_results, summarized_normalized_rank_results,
      by = c("CV", "train_time", "test_time"))
    
  }





  # get the decision values for the correct label
  if (aggregate_decision_values != FALSE && aggregate_decision_values != "none") {
    
    prediction_results_aug <- get_augmented_prediction_results(prediction_results, aggregate_decision_values)
    decision_vals_aug <- select(prediction_results_aug, starts_with("decision"))
    decision_vals_rest <- select(prediction_results_aug, -starts_with("decision"))

    get_decision_vals_one_row <- function(decision_vals_aug_row) {
      decision_vals_aug_row[which(as.character(as.matrix(decision_vals_aug_row[1])) == names(decision_vals_aug_row[2:length(decision_vals_aug_row)])) + 1]
    }

    correct_class_decision_val <- as.numeric(apply(decision_vals_aug, 1, get_decision_vals_one_row))

    summarized_correct_decision_val_results <- decision_vals_rest %>%
      mutate(decision_vals = correct_class_decision_val) %>%
      dplyr::group_by(.data$CV, .data$train_time, .data$test_time) %>%
      summarize(decision_vals = mean(.data$decision_vals, na.rm = TRUE))

    the_results <- left_join(the_results, summarized_correct_decision_val_results,
      by = c("CV", "train_time", "test_time"))
    
  }


  options <- attr(rm_obj, "options")
  options$zero_one_loss_chance_level <- 1/length(unique(prediction_results$actual_labels))
  
  new_rm_main_results(
    the_results,
    "results combined over one cross-validation split",
    options)
  
}




# The aggregate_resample_run_results method needed to fulfill the results metric interface
aggregate_resample_run_results.rm_main_results <- function(resample_run_results) {
  
  central_results <- resample_run_results %>%
    group_by(.data$train_time, .data$test_time) %>%
    summarize(zero_one_loss = mean(.data$zero_one_loss))


  if ("normalized_rank" %in% names(resample_run_results)) {
    
    normalized_rank_results <- resample_run_results %>%
      group_by(.data$train_time, .data$test_time) %>%
      summarize(normalized_rank = mean(.data$normalized_rank))

    central_results <- left_join(central_results, normalized_rank_results, by = c("train_time", "test_time"))
  }


  if ("decision_vals" %in% names(resample_run_results)) {
    
    decision_vals_results <- resample_run_results %>%
      group_by(.data$train_time, .data$test_time) %>%
      summarize(decision_vals = mean(.data$decision_vals))

    central_results <- left_join(central_results, decision_vals_results, by = c("train_time", "test_time"))
  }


  new_rm_main_results(
    central_results,
    "final results",
    attributes(resample_run_results)$options)
  
}





#' A plot function for the rm_main_results object
#'
#' This function can create a line plot of the results or temporal
#' cross-decoding results for the the zero-one loss, normalized rank and/or
#' decision values after the decoding analysis has been run (and all results
#' have been aggregated).
#'
#' @param x A rm_main_result object that has aggregated runs from a
#'   decoding analysis, e.g., if DECODING_RESULTS are the out from the
#'   run_decoding(cv) then this argument should be
#'   `DECODING_RESULTS$rm_main_results`.
#'
#' @param ... This is needed to conform to the plot generic interface.
#'
#' @param result_type A string specifying the types of results to plot. Options
#'   are: 'zero_one_loss', 'normalized_rank', 'decision_values', or 'all'.
#'
#' @param plot_type A string specifying the type of results to plot. Options are
#'   'TCD' to plot a temporal cross decoding matrix or 'line' to create a line
#'   plot of the decoding results as a function of time
#'
#' @family result_metrics
#'
#'
#'
#' @export
plot.rm_main_results <- function(x, ..., result_type = "zero_one_loss", plot_type = "TCD") {
  
  main_results <- x

  if (attributes(main_results)$state != "final results") {
    stop("The results can only be plotted *after* the decoding analysis has been run")
  }


  # convert the zero-one loss results to percentages
  main_results <- dplyr::mutate(main_results, zero_one_loss = .data$zero_one_loss * 100)


  # parse which type of results should be plotted
  if (result_type == "all") {
    
    # do nothing
    
  } else if (result_type == "zero_one_loss") {
    
    main_results <- dplyr::select(main_results, .data$train_time, .data$test_time, .data$zero_one_loss)
    
  } else if (result_type == "normalized_rank") {
    
    if (!(result_type %in% names(main_results))) {
      stop(paste("Can't plot", result_type, "results because this type of result was not saved."))
    }

    main_results <- dplyr::select(main_results, .data$train_time, .data$test_time, .data$normalized_rank)
    
  } else if (result_type == "decision_vals") {
    
    if (!(result_type %in% names(main_results))) {
      stop(paste("Can't plot", result_type, "results because this type of result was not saved."))
    }

    main_results <- dplyr::select(main_results, .data$train_time, .data$test_time, .data$decision_vals)
    
  } else {
    
    warning(paste0(
      "result_type must be set to either 'all', 'zero_one_loss', 'normalized_rank', or 'decision_vals'.",
      "Using the default value of all"))
  }



  if (!(plot_type == "TCD" || plot_type == "line")) {
    warning("plot_type must be set to 'TCD' or 'line'. Using the default value of 'TCD'")
  }



  main_results$train_time <- round(get_center_bin_time(main_results$train_time))
  main_results$test_time <- round(get_center_bin_time(main_results$test_time))

  # an alternative way to display the labels (not used)
  # main_results$train_time <- get_time_range_strings(main_results$train_time)
  # main_results$test_time <- get_time_range_strings(main_results$test_time)

  
  main_results <- main_results %>%
    tidyr::gather("result_type", "accuracy", -.data$train_time, -.data$test_time) %>%
    dplyr::mutate(
      result_type = replace(result_type, result_type == "zero_one_loss", "Zero-one loss"),
      result_type = replace(result_type, result_type == "normalized_rank", "Normalized rank"),
      result_type = replace(result_type, result_type == "decision_vals", "Decision values"))

  
  # data frame for plotting a horizontal line at chance decoding accuracy levels
  #  (chance level from the input x which is the rm_main_results object passed to plot)
  zero_one_loss_chance <- 100 * attributes(x)$options$zero_one_loss_chance_level
  chance_accuracy_df <- data.frame(
    result_type = c("Zero-one loss", "Normalized rank", "Decision values"),
    chance_level = c(zero_one_loss_chance, .5, NA)) %>%
    dplyr::filter(.data$result_type %in% unique(main_results$result_type))
  
  
  # if only a single time, just plot a bar for the decoding accuracy
  if (length(unique(main_results$train_time)) == 1) {
    
    main_results %>%
      ggplot(aes(.data$test_time, .data$accuracy)) +
      geom_col() +
      facet_wrap(~result_type, scales = "free") +
      xlab("Time") +
      ylab("Accuracy")
    
  } else if ((sum(main_results$train_time == main_results$test_time) == dim(main_results)[1]) ||
    plot_type == "line") {

    # if only trained and tested at the same time, create line plot
    main_results %>%
      dplyr::filter(.data$train_time == .data$test_time) %>%
      ggplot(aes(.data$test_time, .data$accuracy)) +
      facet_wrap(~result_type, scales = "free") +
      xlab("Time") +
      ylab("Accuracy") +
      geom_hline(data = chance_accuracy_df, 
                 aes(yintercept = .data$chance_level),
                 color = "maroon", na.rm=TRUE) + 
      geom_line() + 
      theme_classic() + 
      theme(
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        strip.background = element_rect(colour = "white", fill = "white"))
      
    
  } else {

    # if trained and testing at all times, create a TCD plot

    if (result_type != "all") {
      
      g <- main_results %>%
        ggplot(aes(.data$test_time, .data$train_time, fill = .data$accuracy)) +
        geom_tile() +
        facet_wrap(~result_type) +
        scale_fill_continuous(type = "viridis", name = "Prediction \n accuracy") +
        ylab("Train time") +
        xlab("Test time") +
        theme(
          panel.grid.major = element_blank(),
          panel.grid.minor = element_blank(),
          strip.background = element_rect(colour = "white", fill = "white"))

      g
      
    } else if (result_type == "all") {

      # plotting multiple TCD subplots on the same figure

      all_TCD_plots <- lapply(unique(main_results$result_type), function(curr_result_name) {
        curr_results <- filter(main_results, .data$result_type == curr_result_name)

        g <- curr_results %>%
          ggplot(aes(.data$test_time, .data$train_time, fill = .data$accuracy)) +
          geom_tile() +
          facet_wrap(~result_type, scales = "free") +
          # scale_fill_continuous(type = "viridis", name = curr_results$result_type[1]) +
          scale_fill_continuous(type = "viridis", name = "") +
          ylab("Train time") +
          xlab("Test time") +
          theme(legend.position = "bottom") +
          theme(
            panel.grid.major = element_blank(),
            panel.grid.minor = element_blank(),
            strip.background = element_rect(colour = "white", fill = "white"))
        
        g
        
      })

      all_TCD_plots[["ncol"]] <- 3
      do.call(gridExtra::grid.arrange, all_TCD_plots)
    }
    
  }
}




# Get the parameters for the rm_main_results object
#' @export
get_parameters.rm_main_results <- function(ndtr_obj) {

  # get the options for now the normalized rank and decision values should be aggregated
  aggregate_decision_values <- attr(ndtr_obj, "options")$aggregate_decision_values
  aggregate_normalized_rank <- attr(ndtr_obj, "options")$aggregate_normalized_rank

  data.frame(
    rm_main_results.aggregate_decision_values = aggregate_decision_values,
    rm_main_results.aggregate_normalized_rank = aggregate_normalized_rank)
  
}





# A private helper function to get data needed to create the normalized rank
#  and decision value results
get_augmented_prediction_results <- function(prediction_results, aggregate_options) {

  # if only getting the decision values when training and testing at the same time
  if (aggregate_options == "diag" || aggregate_options == "only same train test time") {
    
    prediction_results <- prediction_results %>%
      dplyr::filter(.data$train_time == .data$test_time)

    # if getting the decision values for all time points
  } else if (aggregate_options == TRUE || aggregate_options == "full") {

    # if getting data from all times do nothing
  } else {
    
    argument_name <- deparse(substitute(aggregate_options))
    stop(paste0(
      argument_name, " was set to ", aggregate_options, ". ", argument_name,
      " must be set to one of the following: ",
      "TRUE, 'all', FALSE, 'none', 'diag', or 'only same train test time'"))
    
  }


  # add decision_vals. to the actual label names to allow a comparion
  #  of the actual labels to column names
  prediction_results <- prediction_results %>%
    mutate(decision_actual_labels = paste0("decision_vals.", .data$actual_labels)) %>%
    select(.data$decision_actual_labels, starts_with("decision"), everything())

  prediction_results
  
}
emeyers/NDTr documentation built on Aug. 8, 2020, 3:41 p.m.