Nothing
#' Get results of DCPO cross-validation testing
#'
#' \code{get_xvt_results} summarizes the out-of-sample performance of dcpo's
#' estimates of cross-national public opinion, either for a single
#' cross-validation fold or for every fold of a k-fold test.
#'
#' Before summarizing, the underlying Stan fits are checked for convergence;
#' if any parameter has an Rhat above 1.1, warnings are issued that name each
#' such parameter together with its fold and Rhat.
#'
#' @param dcpo_xvt_output output from a single call to \code{DCPO::dcpo_xvt} or a k-fold test list of such output generated by purrr::map
#' @param ci an integer indicating the desired width of credible interval for coverage testing; 80 is the default.
#'
#' @examples
#' \donttest{
#' # Single cross-validation test with 25% test set
#' demsup_xvtest_25pct <- dcpo_xvt(demsup_data,
#' chime = FALSE,
#' number_of_folds = 4,
#' iter = 300,
#' chains = 1) # 1 chain/300 iterations for example only; use defaults
#'
#' get_xvt_results(demsup_xvtest_25pct)
#' }
#'
#' @return a tibble with columns \code{model}; \code{mae}, the mean absolute
#'   error of the predicted proportions for the test observations;
#'   \code{improv_over_cmmae}, the percent reduction in that error relative to
#'   a baseline that predicts each test observation by its country's mean in
#'   the training data; and a coverage column (\code{coverage80ci} for the
#'   default \code{ci = 80}) giving the percent of test observations that fall
#'   inside the central credible interval of width \code{ci}. Each fold
#'   contributes one row for the model and one for the country-means baseline;
#'   for k-fold input, two further rows ("k-fold mean" and
#'   "mean country means") give the averages across folds.
#'
#' @import rstan
#' @importFrom dplyr mutate group_by summarize_if ungroup select mutate_at bind_rows first nth filter vars contains left_join
#' @importFrom purrr map_df
#' @importFrom tibble tibble rownames_to_column
#' @importFrom stats quantile runif sd
#'
#' @export
get_xvt_results <- function(dcpo_xvt_output, ci = 80) {
  # NULL bindings silence R CMD check notes about dplyr's bare column names.
  model <- mae <- country_means <- improv_over_cmmae <- NULL

  kfc <- assess_kfold_convergence(dcpo_xvt_output)
  if (nrow(kfc) > 0) {
    warning("These estimates have not yet fully converged. Increase the number of iterations.\n")
    # Passing the data frame itself to warning() deparses each column into
    # one unreadable string, so name the offending parameters explicitly.
    warning("Parameters with Rhat > 1.1: ",
            paste0(kfc$parameter, " (fold ", kfc$fold,
                   ", Rhat = ", round(kfc$Rhat, 3), ")",
                   collapse = "; "))
  }

  # A single dcpo_xvt() result carries its own `xvt_args` element; a k-fold
  # list of such results (named or unnamed) does not.
  if ("xvt_args" %in% names(dcpo_xvt_output)) {
    xvt_results <- xvt(dcpo_xvt_output, ci)
  } else {
    xvt_results <- purrr::map_df(dcpo_xvt_output, xvt, ci = ci)
    # Average the model rows and the country-means rows separately. Grouping
    # on a logical sorts FALSE (model rows) before TRUE (baseline rows), which
    # is what the order of the labels assigned below relies on.
    mean_results <- xvt_results %>%
      mutate(country_means = (model == "country means")) %>%
      group_by(country_means) %>%
      summarize_if(is.numeric, mean) %>%
      mutate(model = c("k-fold mean", "mean country means")) %>%
      ungroup() %>%
      select(-country_means) %>%
      mutate(mae = round(mae, 3),
             improv_over_cmmae = round(improv_over_cmmae, 1)) %>%
      mutate_at(vars(contains("coverage")), round, 1)
    xvt_results <- bind_rows(xvt_results, mean_results)
  }
  xvt_results
}
# Score one dcpo_xvt() result against its held-out (test == 1) observations.
#
# dcpo_xvt_output: a single dcpo_xvt() result. Its first element is a data
#   list whose second-to-last element holds the observations (columns test,
#   country, y_r, n_r); its second element is the Stan fit containing the
#   y_r_test draws; `xvt_args` supplies the fold number, fold count and seed.
# ci: width, in percent, of the central credible interval used for coverage.
# Returns a two-row tibble: the fitted model for this fold, then the
#   country-means baseline built from the training (test == 0) observations.
xvt <- function(dcpo_xvt_output, ci) {
  test <- country <- y_r <- n_r <- NULL

  observations <- dplyr::nth(dplyr::first(dcpo_xvt_output), -2)
  held_out <- dplyr::filter(observations, test == 1)
  training <- dplyr::filter(observations, test == 0)

  # Posterior draws of y_r_test; the code below assumes column j of `draws`
  # corresponds to row j of `held_out`.
  fit <- dplyr::nth(dcpo_xvt_output, 2)
  draws <- dplyr::first(rstan::extract(fit, pars = "y_r_test"))

  # Model error: observed proportion vs. posterior-mean predicted proportion.
  model_mae <- round(
    mean(abs(held_out$y_r / held_out$n_r - (colMeans(draws) / held_out$n_r))),
    3
  )

  # Baseline error: predict each held-out cell by its country's mean
  # proportion in the training data.
  training_means <- dplyr::summarize(
    dplyr::group_by(training, country),
    country_mean = mean(y_r / n_r)
  )
  baseline <- dplyr::left_join(held_out, training_means, by = "country")
  country_mean_mae <- round(
    mean(abs((baseline$y_r / baseline$n_r - baseline$country_mean))),
    3
  )

  improv_vs_cmmae <- round(
    (country_mean_mae - model_mae) / country_mean_mae * 100,
    1
  )

  # Coverage: percent of held-out y_r within the central ci% interval.
  alpha_half <- (1 - ci / 100) / 2
  lower <- apply(draws, 2, quantile, alpha_half)
  upper <- apply(draws, 2, quantile, 1 - alpha_half)
  coverage <- round(
    mean(held_out$y_r >= lower & held_out$y_r <= upper) * 100,
    1
  )

  fold_info <- dcpo_xvt_output$xvt_args
  fold_label <- paste0("Fold ", fold_info$fold_number,
                       " of ", fold_info$number_of_folds,
                       " (", fold_info$fold_seed, ")")

  scores <- tibble::tibble(model = c(fold_label, "country means"),
                           mae = c(model_mae, country_mean_mae),
                           improv_over_cmmae = c(improv_vs_cmmae, NA))
  scores[[paste0("coverage", ci, "ci")]] <- c(coverage, NA)
  scores
}
# Collect the parameters whose Rhat exceeds 1.1 in the Stan fit(s) of a
# dcpo_xvt() result or of a k-fold list of such results.
#
# dcpo_xvt_output: a single dcpo_xvt() result (recognized by its `xvt_args`
#   element) or a list of them, named or unnamed.
# Returns a data frame with one row per parameter whose Rhat is above 1.1 and
#   columns parameter, mean, se_mean, sd, 2.5%, 50%, 97.5%, n_eff, Rhat and
#   fold; it has zero rows when no parameter exceeds that threshold.
assess_kfold_convergence <- function(dcpo_xvt_output) {
  parameter <- mean <- se_mean <- sd <- `2.5%` <- `50%` <- `97.5%` <- n_eff <- Rhat <- fold <- NULL

  # Summarize one result's Stan fit (its second element) and keep the rows
  # whose Rhat exceeds 1.1, tagged with that result's fold number.
  fold_convergence <- function(x) {
    x %>%
      dplyr::nth(2) %>%
      summary() %>%
      `[[`("summary") %>%
      as.data.frame() %>%
      rownames_to_column(var = "parameter") %>%
      mutate(fold = x$xvt_args$fold_number) %>%
      filter(Rhat > 1.1) %>%
      select(parameter, mean, se_mean, sd, `2.5%`, `50%`, `97.5%`, n_eff, Rhat, fold)
  }

  # Only a single dcpo_xvt() result has `xvt_args` at the top level; this also
  # routes a *named* k-fold list correctly, which a bare names() check does not.
  if ("xvt_args" %in% names(dcpo_xvt_output)) {
    kfc <- fold_convergence(dcpo_xvt_output)
  } else {
    kfc <- purrr::map_df(dcpo_xvt_output, fold_convergence)
  }
  kfc
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.