R/calc_waic.R

Defines functions calc_waic

Documented in calc_waic

#' @title Calculate a block version of Watanabe-Akaike Information Criterion (WAIC)
#' @description This function is used to get the WAIC value when blocking members
#' @param vote_info A logical vote matrix (or a rollcall object) in which rows represent members and columns represent issues.
#' The entries should be FALSE ("No"), TRUE ("Yes"), or NA (missing data).
#' @param years_v A vector representing the time period for each vote in the model. This is defultly set as `NULL` for a static model.
#' @param prob_array An array of probabilities with three dimensions.
#' @return The block WAIC value for a static PUM or a vector of WAIC by time for a dynamic PUM.
#' @examples
#' \donttest{
#' # Long-running example
#' data(h116)
#' h116.c = preprocess_rollcall(h116)
#' hyperparams <- list(beta_mean = 0, beta_var = 1, alpha_mean = c(0, 0),
#'                     alpha_scale = 5, delta_mean = c(-2, 10), delta_scale = sqrt(10))
#' control <- list(num_iter = 2, burn_in = 0, keep_iter = 1, flip_rate = 0.1)
#' h116.c.pum <- sample_pum_static(h116.c, hyperparams,
#'                                   control, pos_leg = grep("SCALISE", rownames(h116.c$votes)),
#'                                   verbose = FALSE, pre_run = NULL, appended = FALSE)
#' h116.c.pum.predprob = predict_pum(h116.c, years_v = NULL, h116.c.pum)
#' h116.c.pum.waic = calc_waic(h116.c, prob_array = h116.c.pum.predprob)
#' }
#' @export
calc_waic <- function(vote_info, years_v = NULL, prob_array) {
  # Check and process input vote object
  if (is.matrix(vote_info)) {

    if (all(is.na(vote_info) | (vote_info %in% c(0, 1) & is.numeric(vote_info)))) {
      vote_m <- vote_info
    } else if (all(is.logical(vote_info))) {
      vote_m <- vote_info
      vote_m[vote_m == TRUE] <- 1
      vote_m[vote_m == FALSE] <- 0
    } else if (all(vote_info %in% c("T", "F", "NA"))) {
      vote_m <- vote_info
      vote_m[vote_m == "T"] <- 1
      vote_m[vote_m == "F"] <- 0
      vote_m[vote_m == "NA"] <- NA
    } else {
      invalid_values <- vote_info[!(is.na(vote_info) | vote_info %in% c(0, 1, TRUE, FALSE, "T", "F", "NA"))]
      if (length(invalid_values) > 0) {
        stop(paste("Invalid value found in your vote matrix:", paste(invalid_values, collapse = ", ")))
      }
    }

  } else if (is.list(vote_info) && "votes" %in% names(vote_info)) {
    vote_m <- vote_info$votes
    rownames_old <- rownames(vote_m)
    colnames_old <- colnames(vote_m)
    yea_values <- vote_info$codes$yea
    nay_values <- vote_info$codes$nay
    vote_m <- matrix(ifelse(vote_m %in% yea_values, 1,
                            ifelse(vote_m %in% nay_values, 0, NA)),
                     nrow = nrow(vote_m), ncol = ncol(vote_m))
    rownames(vote_m) <- rownames_old
    colnames(vote_m) <- colnames_old

  } else {
    stop("vote_info is not in a valid type. It has to be a matrix or a rollcall object.")
  }

  n_rows <- dim(prob_array)[1]
  n_cols <- dim(prob_array)[2]
  n_samples <- dim(prob_array)[3]

  issue_cols <- colnames(prob_array[,,1])
  vote_cols <- colnames(vote_m)
  common_cols <- intersect(vote_cols, issue_cols)
  case_vote_m <- vote_m[, common_cols]
  leg_cols = rownames(prob_array[,,1])
  vote_rows <- rownames(case_vote_m)
  common_leg <- intersect(vote_rows, leg_cols)
  case_vote_m <- case_vote_m[common_leg,]

  block_leg = TRUE
  block_issue = FALSE

  if (is.null(years_v)){
    # static
    if (block_leg) {
      num_votes = nrow(case_vote_m)
    } else if (block_issue) {
      num_votes = ncol(case_vote_m)
    } else {
      num_votes = sum(!is.na(as.vector(case_vote_m)))
    }

    mean_prob <- rep(-Inf, num_votes)
    mean_log_prob <- rep(0, num_votes)
    log_prob_var <- rep(0, num_votes)
    num_iter = 0

    for (j in 1:n_samples){
      log_prob <- case_vote_m * log(prob_array[,,j]) +
        (1 - case_vote_m) * log(1 - prob_array[,,j])
      if (block_leg) {
        log_prob <- rowSums(log_prob, na.rm = T)
      } else if (block_issue) {
        log_prob <- colSums(log_prob, na.rm = T)
      } else {
        log_prob <- log_prob[!is.na(log_prob)]
      }
      mean_prob <- pmax(mean_prob, log_prob) + log(1 + exp(pmin(mean_prob, log_prob)
                                                           - pmax(mean_prob, log_prob)))
      next_mean_log_prob = (num_iter * mean_log_prob + log_prob) / (num_iter + 1)
      log_prob_var = log_prob_var +
        (log_prob - mean_log_prob) * (log_prob - next_mean_log_prob)
      mean_log_prob = next_mean_log_prob
      num_iter = num_iter + 1

      delta <- log_prob - mean_log_prob
      mean_log_prob <- mean_log_prob + delta / num_iter
      log_prob_var <- log_prob_var + delta * (log_prob - mean_log_prob)
    }
    return(-2*sum(mean_prob - log(num_iter) - (log_prob_var) / (num_iter - 1)))

  } else {

    # dynamic
    # Initialize WAIC components
    num_points <- sum(apply(case_vote_m, 1, function(row) {
      interested_inds <- which(!is.na(row))
      length(table(years_v[interested_inds]))
    }))
    # mean_prob <- rep(0, num_points)
    mean_prob = rep(-Inf, num_points)
    mean_log_prob <- rep(0, num_points)
    log_prob_var <- rep(0, num_points)
    num_iter <- 0

    for (j in 1:n_samples) {
      log_prob <- case_vote_m * log(prob_array[, , j]) +
        (1 - case_vote_m) * log(1 - prob_array[, , j])

      # Group log probabilities by year
      data_prob_m <- as.data.frame(cbind(years_v, t(log_prob)))
      split_data <- split(data_prob_m[-1], data_prob_m$years_v)
      log_prob <- lapply(split_data, function(df) {
        apply(df, 2, function(x) if (all(is.na(x))) NA else sum(x, na.rm = TRUE))
      })
      log_prob <- do.call(rbind, log_prob)

      # Filter non-NA rows
      non_na_rows <- which(!is.na(log_prob), arr.ind = TRUE)[, "row"]
      years_vector <- as.numeric(rownames(log_prob)[non_na_rows])
      log_prob <- log_prob[!is.na(log_prob)]

      # Update WAIC components
      if (num_iter == 0){
        mean_prob = log_prob
      }
      # mean_prob <- mean_prob + exp(log_prob)
      mean_prob = pmax(mean_prob, log_prob) +
        log(1 + exp(pmin(mean_prob, log_prob) - pmax(mean_prob, log_prob)))
      next_mean_log_prob <- (num_iter * mean_log_prob + log_prob) / (num_iter + 1)
      log_prob_var <- log_prob_var + (log_prob - mean_log_prob) * (log_prob - next_mean_log_prob)
      mean_log_prob <- next_mean_log_prob
      num_iter <- num_iter + 1
    }

    # Calculate WAIC
    waic_result <- -2 * (mean_prob -log(num_iter) - (log_prob_var) / (num_iter - 1))
    # Aggregate WAIC by year
    waic_by_year <- tapply(waic_result, years_vector, sum, na.rm = TRUE)

    return(waic_by_year)
  }
}

Try the pumBayes package in your browser

Any scripts or data that you put into this service are public.

pumBayes documentation built on June 8, 2025, 11:29 a.m.