R/get-tables.R

Defines functions get_bayes_R2 find_model_names loo_compare_order table_criterion

Documented in get_bayes_R2 table_criterion

#' Get a table of model comparison criteria
#' 
#' Builds a summary table for a list of brms models. Each row contains:
#' \itemize{
#'   \item the formula, response, distribution and link function;
#'   \item the number of divergent transitions;
#'   \item the longest per-chain run time (warmup + sampling);
#'   \item optionally, LOO, LOO-R2 and Bayesian R2 statistics.
#' }
#' 
#' The \code{"loo"} criterion is read from \code{fit$criteria$loo}. It must
#' already be stored on every fit, e.g. with
#' \code{add_criterion(fit, criterion = "loo")}. \code{"loo_R2"} and
#' \code{"bayes_R2"} are computed here with \code{loo_R2()} and
#' \code{bayes_R2()}. \code{"log_lik"} is accepted but not currently computed.
#' 
#' @param fits a list of brmsfit objects in the order that you want to compare them
#' @param criterion the criteria to include: any of \code{"loo"}, \code{"loo_R2"},
#'   \code{"bayes_R2"} and \code{"log_lik"} (the latter is currently ignored)
#' @param sort if \code{TRUE} and \code{"loo"} is requested, sort the table with the
#'   best models (highest ELPD) at the top; otherwise rows follow the order of \code{fits}
#' @param ... currently unused
#' @return a data frame with one row per model. Criterion values are returned as
#'   rounded, formatted character columns. The run time is returned as an
#'   \code{"HH:MM:SS"} string, where hours may exceed 24.
#' @import brms
#' @import dplyr
#' @importFrom rstan get_elapsed_time get_num_upars
#' @importFrom lubridate seconds_to_period hour minute second
#' @importFrom stats setNames
#' @export
#' 
table_criterion <- function(fits, 
                            criterion = c("loo", "loo_R2", "bayes_R2", "log_lik"), 
                            sort = TRUE, ...) {
  
  n <- length(fits)
  if (n == 0) stop("fits must contain at least one model.")
  for (i in seq_len(n)) {
    if (!is.brmsfit(fits[[i]])) stop("fit is not an object of class brmsfit.")
  }
  use_loo <- "loo" %in% criterion
  
  # One row of metadata per model: formula, response, family, link,
  # number of divergent transitions and the slowest chain's total run time.
  df_names <- bind_rows(lapply(seq_len(n), function(i) {
    fit <- fits[[i]]
    divergent <- nuts_params(fit) %>%
      filter(.data$Parameter %in% "divergent__")
    # Total seconds rounded first so the fields never show ":60"; hours are
    # not wrapped at 24 (a period's hour() would silently drop whole days).
    secs <- round(max(rowSums(get_elapsed_time(fit$fit))))
    data.frame(id = i,
               Model = as.character(fit$formula)[1], 
               Var = as.character(fit$formula)[2], 
               Distribution = as.character(fit$family)[1], 
               Link = as.character(fit$family)[2],
               n_divergent = sum(divergent$Value),
               max_time = sprintf("%02d:%02d:%02d",
                                  secs %/% 3600, (secs %% 3600) %/% 60, secs %% 60),
               stringsAsFactors = FALSE)
  }))
  
  # Format an R2 summary matrix (Estimate, Est.Error, Q2.5, Q97.5) as two
  # rounded character columns named <label> and se_<label>, plus the model id.
  r2_table <- function(r2, id, label) {
    r2 <- data.frame(r2)
    out <- data.frame(format(round(r2[["Estimate"]], digits = 3), nsmall = 3),
                      format(round(r2[["Est.Error"]], digits = 3), nsmall = 3),
                      id,
                      stringsAsFactors = FALSE)
    setNames(out, c(label, paste0("se_", label), "id"))
  }
  
  if (use_loo) {
    # unname() so model names come from the loo objects, not from names(fits)
    list_loo <- unname(lapply(fits, function(fit) fit$criteria$loo))
    no_loo <- which(vapply(list_loo, is.null, logical(1)))
    if (length(no_loo) > 0) {
      stop("no 'loo' criterion stored for fit(s) ", paste(no_loo, collapse = ", "),
           "; run add_criterion(fit, criterion = \"loo\") first.")
    }
    df_loo <- data.frame(loo_compare(list_loo))
    # loo_compare() returns rows best-first; map each row back to its position in fits
    df_loo$id <- loo_compare_order(list_loo)
    df_loo$model_name <- find_model_names(list_loo)[df_loo$id]
    df_loo <- df_loo %>%
      relocate(all_of(c("p_loo", "se_p_loo")), .after = "se_looic") %>%
      mutate(across(all_of(c("se_diff", "elpd_loo", "se_elpd_loo", "looic", "se_looic")),
                    ~ format(round(.x, digits = 0), nsmall = 0)),
             across(all_of(c("p_loo", "se_p_loo")),
                    ~ format(round(.x, digits = 1), nsmall = 1)))
  }
  
  if ("loo_R2" %in% criterion) {
    df_loo_R2 <- bind_rows(lapply(seq_len(n), function(i) {
      r2_table(loo_R2(fits[[i]]), id = i, label = "loo_R2")
    }))
  }
  
  if ("bayes_R2" %in% criterion) {
    df_bayes_R2 <- bind_rows(lapply(seq_len(n), function(i) {
      r2_table(bayes_R2(fits[[i]]), id = i, label = "bayes_R2")
    }))
  }
  
  # Combine the tables
  df_all <- df_names
  if (use_loo) df_all <- left_join(df_all, df_loo, by = "id")
  if ("loo_R2" %in% criterion) df_all <- left_join(df_all, df_loo_R2, by = "id")
  if ("bayes_R2" %in% criterion) df_all <- left_join(df_all, df_bayes_R2, by = "id")
  
  # Sort by elpd if wanted (elpd_diff is still numeric at this point)
  if (sort && use_loo) {
    df_all <- arrange(df_all, desc(.data$elpd_diff))
  } else {
    df_all <- arrange(df_all, .data$id)
  }
  
  # elpd_diff and model_name only exist when "loo" was requested
  if (use_loo) {
    df_all <- mutate(df_all, elpd_diff = format(round(.data$elpd_diff, digits = 0), nsmall = 0))
  }
  
  df_all %>%
    relocate(any_of(c("model_name", "Distribution", "Link")), .after = "id") %>%
    relocate(all_of(c("n_divergent", "max_time")), .after = last_col())
}


# Order a list of loo-type objects from best to worst expected log predictive
# density (ELPD). This is the same ordering loo_compare() uses for its rows.
#
# loos: a list of objects with an `estimates` matrix. In that matrix, one row
#   name starts with "elpd" and the first column holds the point estimates.
# Returns an integer vector of positions in `loos`, best model first.
loo_compare_order <- function(loos) {
  stopifnot(is.list(loos))
  # vapply() keeps the result a numeric vector even if the objects carry
  # different sets of estimate rows (sapply() would silently return a list).
  elpds <- vapply(loos, function(x) {
    est <- x$estimates
    est[grep("^elpd", rownames(est))[1], 1]
  }, numeric(1))
  order(elpds, decreasing = TRUE)
}

# Resolve a display name for each element of a list of model objects.
#
# Sources are tried in this order of precedence:
#   1. a non-empty name of the list element;
#   2. a "model_name" attribute on the element;
#   3. a `model_name` entry inside the element;
#   4. the fallback "model<j>".
# Returns a character vector the same length as `x`.
find_model_names <- function(x) {
  stopifnot(is.list(x))
  n_models <- length(x)

  # All candidate sources are gathered up front, for every element.
  list_names <- names(x)
  attr_names <- lapply(x, attr, which = "model_name", exact = TRUE)
  elem_names <- lapply(x, `[[`, "model_name")
  fallback <- paste0("model", seq_along(x))

  resolved <- character(n_models)
  for (k in seq_len(n_models)) {
    resolved[k] <- if (isTRUE(nzchar(list_names[k]))) {
      list_names[k]
    } else if (length(attr_names[[k]])) {
      attr_names[[k]]
    } else if (length(elem_names[[k]])) {
      elem_names[[k]]
    } else {
      fallback[k]
    }
  }

  return(resolved)
}


#' Get Bayesian R2 for a list of models
#' 
#' Calculates the Bayesian R2 for each model in a list of brms models using the
#' \code{add_criterion} function. It tabulates the posterior mean R2, with the
#' best model first.
#' 
#' @param fits a list of brmsfit objects in the order that you want to compare them
#' @param ... additional parameters passed to \code{add_criterion}
#' @return an ungrouped tibble with one row per model. Rows are sorted by
#'   decreasing \code{R2}. Columns:
#'   \itemize{
#'     \item \code{Model}, \code{Distribution} and \code{Link};
#'     \item \code{R2}, the posterior mean Bayesian R2;
#'     \item \code{diff}, the difference in \code{R2} from the previous row
#'       (0 for the first row).
#'   }
#' 
#' @author Darcy Webber \email{darcy@quantifish.co.nz}
#' 
#' @importFrom brms add_criterion
#' @import dplyr
#' @export
#' 
get_bayes_R2 <- function(fits, ...) {
  
  if (length(fits) == 0) stop("fits must contain at least one model.")
  
  # One summary row per fit, so distinct fits that share a formula, family
  # and link are kept apart rather than pooled into a single averaged row.
  rows <- vector("list", length(fits))
  for (i in seq_along(fits)) {
    fit <- fits[[i]]
    if (!is.brmsfit(fit)) stop("fit is not an object of class brmsfit.")
    fit <- add_criterion(x = fit, criterion = "bayes_R2", ...)
    rows[[i]] <- data.frame(Model = as.character(fit$formula)[1],
                            Distribution = as.character(fit$family)[1],
                            Link = as.character(fit$family)[2],
                            R2 = mean(fit$criteria$bayes_R2),
                            stringsAsFactors = FALSE)
  }
  
  df <- bind_rows(rows) %>%
    arrange(desc(.data$R2)) %>%
    as_tibble()
  
  df$diff <- c(0, diff(df$R2))
  
  df
}
quantifish/influ2 documentation built on Dec. 14, 2024, 5:10 a.m.