# R/eval_cv.R

#' Create a cross-validation evaluator
#'
#' @param nfolds integer. Number of cross-validation folds.
#' @param ntrials integer. Number of cross-validation trials to run, where
#'   each trial uses a different random split of the data into folds.
#' @param conf_type string. How to calculate confidence intervals of
#'   performance metrics across trials: 'norm' computes a normal-approximation
#'   interval from the standard error (via the 'sd' function), 'perc' computes
#'   lower and upper percentile bounds (via the 'quantile' function).
#' @param contrasts logical. Whether to compare test performance of fits
#'   within each group-outcome-stat combination (i.e., between predictors).
#'   This results in a p-value for each model comparison, calculated as the
#'   proportion of trials in which one model had lower performance than the
#'   other. Thus, a p-value of 0.05 indicates that one model performed worse
#'   than the other in 5% of trials. If ntrials == 1, this value can only be
#'   0 or 1, indicating which model is better.
#' @return An evaluator object (class 'abaEval') with eval type 'cv'
#' @export
#'
#' @examples
#' data <- adnimerge %>% dplyr::filter(VISCODE == 'bl')
#' model <- aba_model() %>%
#'   set_data(data) %>%
#'   set_groups(everyone()) %>%
#'   set_outcomes(ConvertedToAlzheimers, CSF_ABETA_STATUS_bl) %>%
#'   set_predictors(
#'     PLASMA_ABETA_bl, PLASMA_PTAU181_bl, PLASMA_NFL_bl,
#'     c(PLASMA_ABETA_bl, PLASMA_PTAU181_bl, PLASMA_NFL_bl)
#'   ) %>%
#'   set_stats('glm') %>%
#'   set_evals('cv') %>%
#'   fit()
eval_cv <- function(nfolds = 5,
                    ntrials = 1,
                    conf_type = c('norm', 'perc'),
                    contrasts = FALSE) {
  conf_type <- match.arg(conf_type)

  struct <- list(
    nfolds = nfolds,
    ntrials = ntrials,
    conf_type = conf_type,
    contrasts = contrasts
  )
  struct$eval_type <- 'cv'
  class(struct) <- 'abaEval'
  struct
}
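
# A minimal sketch of how the two conf_type options differ, using a
# hypothetical vector of per-trial estimates (base R only, not part of the
# package API):
#
#   trial_estimates <- c(0.81, 0.79, 0.84, 0.80, 0.82)
#
#   # 'norm': normal-approximation interval from the standard error
#   est <- mean(trial_estimates)
#   se <- sd(trial_estimates)
#   c(conf_low = est - 1.96 * se, conf_high = est + 1.96 * se)
#
#   # 'perc': percentile interval from the empirical distribution
#   quantile(trial_estimates, c(0.025, 0.975), na.rm = TRUE)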

fit_cv <- function(model, nfolds = 5, ntrials = 1, verbose = FALSE) {
  # compile model
  fit_df <- model %>% aba_compile()

  # progress bar
  pb <- NULL
  if (verbose) pb <- progress::progress_bar$new(total = nrow(fit_df))
  # run ntrials independent cross-validation trials
  fit_df <- 1:ntrials %>%
    purrr::map(
      function(index) {
        fit_df <- fit_df %>%
          group_by(group, outcome, stat) %>%
          nest() %>%
          rename(info=data) %>%
          rowwise() %>%
          mutate(
            data = process_dataset(
              data = model$data,
              group = .data$group,
              outcome = .data$outcome,
              stat = .data$stat,
              predictors = model$predictors,
              covariates = model$covariates
            )
          ) %>%
          ungroup()

        # split each processed dataset into nfolds train/test pairs
        fit_df <- fit_df %>%
          mutate(
            data = purrr::map(
              .data$data,
              function(data) {
                # randomly assign each row to one of nfolds near-equal folds
                cv_idx <- sample(cut(1:nrow(data), breaks = nfolds, labels = FALSE))
                data$cv_idx <- cv_idx
                data_list <- 1:nfolds %>% purrr::map(
                  function(idx) {
                    data_train <- data %>% dplyr::filter(.data$cv_idx != idx)
                    data_test <- data %>% dplyr::filter(.data$cv_idx == idx)
                    list('data' = data_train, 'data_test' = data_test)
                  }
                )
                names(data_list) <- 1:nfolds
                data_list
              }
            )
          ) %>%
          unnest_longer(data, indices_to = 'fold') %>%
          unnest_wider(data) %>%
          unnest(info)

        # fit model
        fit_df <- fit_df %>%
          rowwise() %>%
          mutate(
            fit = fit_stat(
              data = .data$data,
              outcome = .data$outcome,
              stat = .data$stat,
              predictors = .data$predictor,
              covariates = .data$covariate,
              pb = pb
            )
          ) %>%
          ungroup()

        # keep only the factor labels, fits, and held-out test data
        fit_df <- fit_df %>%
          select(gid, oid, sid, pid, fit, fold, data_test) %>%
          rename(
            group = gid,
            outcome = oid,
            stat = sid,
            predictor = pid
          )

        # record the trial index and move fold/trial to the end
        fit_df <- fit_df %>%
          mutate(trial = index) %>%
          select(-c(fold, trial), everything())

        # stop if every model failed to fit
        if (all(purrr::map_lgl(fit_df$fit, is.null))) {
          stop('All models failed to be fit. Check your model setup.')
        }
        fit_df
      }
    ) %>%
    bind_rows()

  fit_df <- fit_df %>% mutate(fold = as.integer(fold))

  model$results <- fit_df
  model$is_fit <- TRUE
  model$fit_type <- 'cv'
  return(model)
}
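
# For intuition, the fold assignment used in fit_cv() can be run in isolation
# (a standalone base-R sketch, not package code):
#
#   set.seed(1)
#   n <- 10; nfolds <- 5
#   fold_idx <- sample(cut(1:n, breaks = nfolds, labels = FALSE))
#   table(fold_idx)  # folds of (near-)equal size
#
# cut() bins the row positions 1..n into nfolds contiguous, near-equal groups,
# and sample() permutes those bin labels so fold membership is random while
# fold sizes stay balanced.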



# summarise cv results across folds and trials, optionally with model contrasts
summary_cv <- function(model,
                       label,
                       control = aba_control(),
                       adjust = aba_adjust(),
                       verbose = FALSE) {
  if (length(model$evals) > 1) model$results <- model$results[[label]]
  results <- model$results
  ntrials <- max(results$trial)
  nfolds <- max(results$fold)
  eval_obj <- model$evals[[label]]
  conf_type <- eval_obj$conf_type
  contrasts <- eval_obj$contrasts

  # grab stat object
  results <- results %>%
    mutate(
      stat_obj = purrr::map(stat, ~model$stats[[.]])
    )

  # use evaluate function from stat object on fitted model and test data
  results <- results %>%
    mutate(
      results_test = purrr::pmap(
        list(stat_obj, fit, data_test),
        function(stat_obj, fit, data_test) {
          # if an error happens here, then the stat has no evaluate function
          x <- stat_obj$fns$evaluate(fit, data_test)
          x
        }
      )
    )

  results <- results %>%
    select(-c(fit, data_test, stat_obj)) %>%
    unnest(results_test)

  # metric columns are everything after the model-identifier columns
  metrics <- results %>% select(-c(group:form)) %>% colnames()

  # average each metric across folds, within each trial
  results <- results %>%
    pivot_longer(all_of(metrics)) %>%
    group_by(group, outcome, stat, predictor, form, name, trial) %>%
    summarise(
      estimate_trial = mean(value),
      .groups='keep'
    ) %>%
    ungroup()

  results_raw <- results %>%
    pivot_wider(names_from=name, values_from=estimate_trial)

  # now summarise across trials
  results <- results %>%
    group_by(group, outcome, stat, predictor, form, name) %>%
    summarise(
      estimate = mean(estimate_trial),
      std_err = sd(estimate_trial),
      conf_low = quantile(estimate_trial, 0.025, na.rm = TRUE),
      conf_high = quantile(estimate_trial, 0.975, na.rm = TRUE),
      .groups = 'keep'
    ) %>%
    ungroup()

  results_train <- results %>%
    filter(form == 'train') %>%
    select(group:predictor, name, estimate) %>%
    rename(estimate_train = estimate)

  results <- results %>%
    filter(form == 'test') %>%
    select(-form) %>%
    left_join(
      results_train,
      by = c("group", "outcome", "stat", "predictor", "name")
    ) %>%
    rename(term = name)

  # 'norm' intervals override the percentile bounds computed above
  if (conf_type == 'norm') {
    results <- results %>%
      mutate(
        conf_low = estimate - 1.96 * std_err,
        conf_high = estimate + 1.96 * std_err
      )
  }

  # confidence intervals are undefined with a single trial
  if (ntrials == 1) results <- results %>% mutate(conf_low = NA, conf_high = NA)

  results <- results %>%
    select(group:term, estimate, conf_low, conf_high, estimate_train)

  results_list <- list(
    test_metrics = results
  )

  if (contrasts) {
    # contrasts are computed on the first test metric only
    metric <- results_raw %>% select(-c(group:trial)) %>% names() %>% head(1)
    contrasts_df <- results_raw %>%
      filter(form == 'test') %>%
      rename(estimate = {{ metric }}) %>%
      select(group:trial, estimate) %>%
      pivot_wider(names_from=predictor, values_from=estimate)

    xdf <- contrasts_df %>% select(all_of(unique(results_raw$predictor)))

    # pairwise differences between predictor columns (one new column per pair)
    cdf <- utils::combn(data.frame(xdf), 2, FUN = function(x) x[,1] - x[,2]) %>%
      data.frame() %>% tibble() %>%
      set_names(
        utils::combn(unique(results_raw$predictor), 2,
              FUN = function(o) paste0(o[[1]],'_',o[[2]]))
      )

    contrasts_df <- contrasts_df %>%
      select(-all_of(unique(results_raw$predictor))) %>%
      bind_cols(cdf)

    contrasts_df <- contrasts_df %>%
      group_by(group, outcome, stat) %>%
      summarise(
        across(colnames(cdf),
               list(
                 'estimate' = ~ mean(.x, na.rm = TRUE),
                 'stderr' = ~ sd(.x, na.rm = TRUE),
                 'conflow' = ~ quantile(.x, 0.025, na.rm = TRUE),
                 'confhigh' = ~ quantile(.x, 0.975, na.rm = TRUE),
                 # proportion of trials where the difference is negative;
                 # whether that means 'worse' depends on the metric's direction
                 'pval' = ~ mean(.x < 0, na.rm = TRUE)
               )),
        .groups = 'keep'
      ) %>%
      ungroup()

    contrasts_df <- contrasts_df %>%
      pivot_longer(
        cols = -c(group, outcome, stat),
        names_to=c('predictor', 'predictor2', 'form'),
        names_sep = '_'
      ) %>%
      pivot_wider(names_from = form, values_from = value) %>%
      rename(conf_low = conflow, conf_high = confhigh, std_err = stderr)

    if (conf_type == 'norm') {
      contrasts_df <- contrasts_df %>%
        mutate(
          conf_low = estimate - 1.96 * std_err,
          conf_high = estimate + 1.96 * std_err
        )
    }

    contrasts_df <- contrasts_df %>%
      select(-c(std_err))

    results_list$contrasts <- contrasts_df
  }

  results_list
}
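
# For intuition, the contrast p-value computed in summary_cv() reduces to the
# proportion of trials where the paired difference in a metric falls below
# zero. A standalone sketch with made-up per-trial AUC values (not package
# code):
#
#   auc_model1 <- c(0.80, 0.82, 0.79, 0.85, 0.81)
#   auc_model2 <- c(0.78, 0.83, 0.75, 0.80, 0.79)
#   diffs <- auc_model1 - auc_model2
#   mean(diffs < 0, na.rm = TRUE)  # 0.2: model1 was lower in 20% of trials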


as_table_cv <- function(results, control) {
  as_table_traintest(results, control)
}