R/get_xvt_results.R

Defines functions: assess_kfold_convergence, xvt, get_xvt_results

Documented in get_xvt_results

#' Get results of DCPO cross-validation testing
#'
#' \code{get_xvt_results} summarizes the out-of-sample performance of dcpo's
#' estimates of cross-national public opinion, either for a single
#' cross-validation test or for a k-fold set of such tests.
#'
#' For each test, the model's mean absolute error (MAE) in predicting the
#' proportions \code{y_r/n_r} of the held-out observations is compared with
#' the MAE of a baseline that predicts each held-out observation with its
#' country's mean proportion in the training observations;
#' \code{improv_over_cmmae} is the percentage by which the model reduces the
#' baseline's MAE. Coverage is the percentage of held-out observations whose
#' \code{y_r} falls within the central \code{ci}\% interval of the model's
#' \code{y_r_test} draws. A warning is issued if any parameter has an Rhat
#' above 1.1.
#'
#' @param dcpo_xvt_output output from a single call to \code{DCPO::dcpo_xvt} or a k-fold test list of such output generated by purrr::map
#' @param ci an integer indicating the desired width of credible interval for coverage testing; 80 is the default.
#'
#' @examples
#' \donttest{
#' # Single cross-validation test with 25% test set
#' demsup_xvtest_25pct <- dcpo_xvt(demsup_data,
#'                            chime = FALSE,
#'                            number_of_folds = 4,
#'                            iter = 300,
#'                            chains = 1) # 1 chain/300 iterations for example only; use defaults
#'
#' get_xvt_results(demsup_xvtest_25pct)
#' }
#'
#' @return a tibble with columns \code{model}, \code{mae},
#'   \code{improv_over_cmmae}, and a coverage column named after \code{ci}
#'   (\code{coverage80ci} for the default \code{ci = 80}). It has one row for
#'   each fold and one for that fold's country-means baseline. For k-fold
#'   input, two summary rows, \code{"k-fold mean"} and
#'   \code{"mean country means"}, are appended.
#'
#' @import rstan
#' @importFrom dplyr mutate group_by summarize_if ungroup select mutate_at bind_rows first nth filter vars contains left_join
#' @importFrom purrr map_df
#' @importFrom tibble tibble rownames_to_column
#' @importFrom stats quantile runif sd
#'
#' @export
get_xvt_results <- function(dcpo_xvt_output, ci = 80) {
  # Column names used inside dplyr verbs; declared to silence R CMD check.
  model <- mae <- country_means <- improv_over_cmmae <- NULL

  if (length(dcpo_xvt_output) == 0) {
    stop("dcpo_xvt_output is empty; supply dcpo_xvt output or a list of it.",
         call. = FALSE)
  }

  kfc <- assess_kfold_convergence(dcpo_xvt_output)
  if (nrow(kfc) > 0) {
    # Passing the data frame itself to warning() coerces each column to a
    # deparsed string and produces gibberish, so name the offending folds and
    # parameters instead (capped so large models don't flood the console).
    unconverged <- unique(kfc$parameter)
    shown <- unconverged[seq_len(min(10L, length(unconverged)))]
    n_more <- length(unconverged) - length(shown)
    warning("These estimates have not yet fully converged. Increase the number of iterations.\n",
            "Rhat > 1.1 in fold(s) ", paste(unique(kfc$fold), collapse = ", "),
            " for: ", paste(shown, collapse = ", "),
            if (n_more > 0) paste0(", and ", n_more, " more"),
            call. = FALSE)
  }

  # A single dcpo_xvt output carries an `xvt_args` element; anything else is
  # treated as a (possibly named) list of per-fold outputs.
  if ("xvt_args" %in% names(dcpo_xvt_output)) {
    xvt_results <- xvt(dcpo_xvt_output, ci)
  } else {
    xvt_results <- purrr::map_df(dcpo_xvt_output, function(x) {
      xvt(dcpo_xvt_output = x, ci)
    })

    # Average the fold rows and the baseline rows separately. group_by on a
    # logical orders FALSE (fold rows) before TRUE (baseline rows), which is
    # what the two labels below rely on.
    mean_results <- xvt_results %>%
      mutate(country_means = (model == "country means")) %>%
      group_by(country_means) %>%
      summarize_if(is.numeric, mean) %>%
      mutate(model = c("k-fold mean", "mean country means")) %>%
      ungroup() %>%
      select(-country_means) %>%
      mutate(mae = round(mae, 3),
             improv_over_cmmae = round(improv_over_cmmae, 1)) %>%
      mutate_at(vars(contains("coverage")), round, 1)

    xvt_results <- xvt_results %>%
      bind_rows(mean_results)
  }

  xvt_results
}

# Score a single cross-validation fold.
#
# `dcpo_xvt_output` is one fold's output. The second-to-last component of its
# first element holds the observations, with a `test` column flagging held-out
# rows (1) versus training rows (0). Its second element is the fitted stanfit
# containing `y_r_test` draws. Its `xvt_args` element supplies the fold label.
# Returns a two-row tibble (the fold, then the country-means baseline) with
# the MAE, the percentage improvement over the baseline MAE, and the coverage
# of the central `ci`% interval of the `y_r_test` draws.
xvt <- function(dcpo_xvt_output, ci) {
  # Column names referenced inside dplyr verbs; declared for R CMD check.
  test <- country <- y_r <- n_r <- NULL

  obs <- dplyr::nth(dplyr::first(dcpo_xvt_output), -2)
  held_out <- dplyr::filter(obs, test == 1)

  # Draws of the held-out counts; columns line up with the rows of
  # `held_out`, as the column-wise arithmetic below requires.
  y_draws <- dplyr::first(
    rstan::extract(dplyr::nth(dcpo_xvt_output, 2), pars = "y_r_test")
  )

  # Model error: observed proportions vs. proportions implied by the mean draw.
  share_obs <- held_out$y_r / held_out$n_r
  model_mae <- round(mean(abs(share_obs - colMeans(y_draws) / held_out$n_r)), 3)

  # Baseline error: predict each held-out row by its country's training mean.
  baseline <- obs %>%
    dplyr::filter(test == 0) %>%
    dplyr::group_by(country) %>%
    dplyr::summarize(country_mean = mean(y_r/n_r))
  scored <- dplyr::left_join(held_out, baseline, by = "country")
  country_mean_mae <- round(mean(abs(scored$y_r / scored$n_r - scored$country_mean)), 3)

  improv_vs_cmmae <- round((country_mean_mae - model_mae) / country_mean_mae * 100, 1)

  # Share of held-out counts inside the central `ci`% interval of the draws.
  alpha <- (1 - ci/100) / 2
  lo <- apply(y_draws, 2, quantile, alpha)
  hi <- apply(y_draws, 2, quantile, 1 - alpha)
  coverage <- round(mean(held_out$y_r >= lo & held_out$y_r <= hi) * 100, 1)

  info <- dcpo_xvt_output$xvt_args
  label <- paste0("Fold ", info$fold_number, " of ", info$number_of_folds, " (", info$fold_seed, ")")

  results <- tibble::tibble(model = c(label, "country means"),
                            mae = c(model_mae, country_mean_mae),
                            improv_over_cmmae = c(improv_vs_cmmae, NA))
  results[[paste0("coverage", ci, "ci")]] <- c(coverage, NA)
  results
}

# Collect the parameters whose Rhat exceeds 1.1.
#
# Accepts either a single dcpo_xvt output (recognized by its `xvt_args`
# element) or a list of such outputs, e.g. a k-fold test built with
# purrr::map; the list may be named. Returns one row per offending parameter
# per fold, with columns parameter, mean, se_mean, sd, 2.5%, 50%, 97.5%,
# n_eff, Rhat (from the stanfit summary) and fold (from
# xvt_args$fold_number). Rows whose Rhat is NA are dropped by the filter.
assess_kfold_convergence <- function(dcpo_xvt_output) {
  # Column names used inside dplyr verbs; declared to silence R CMD check.
  parameter <- mean <- se_mean <- sd <- `2.5%` <- `50%` <- `97.5%` <- n_eff <- Rhat <- fold <- NULL

  # Unconverged rows of one fold's stanfit summary, tagged with the fold number.
  unconverged_in_fold <- function(fold_output) {
    fold_output %>%
      dplyr::nth(2) %>%
      summary() %>%
      `[[`("summary") %>%
      as.data.frame() %>%
      rownames_to_column(var = "parameter") %>%
      mutate(fold = fold_output$xvt_args$fold_number) %>%
      filter(Rhat > 1.1) %>%
      select(parameter, mean, se_mean, sd, `2.5%`, `50%`, `97.5%`, n_eff, Rhat, fold)
  }

  # The original `length(names(x)[3]) > 0` test only checked that names were
  # present, so a named list of folds was mistaken for a single output.
  if ("xvt_args" %in% names(dcpo_xvt_output)) {
    unconverged_in_fold(dcpo_xvt_output)
  } else {
    purrr::map_df(dcpo_xvt_output, unconverged_in_fold)
  }
}

Try the DCPO package in your browser

Any scripts or data that you put into this service are public.

DCPO documentation built on July 8, 2020, 7:03 p.m.