Nothing
#' Cross-validation testing for DCPO
#'
#' \code{dcpo_xvt} performs a single cross-validation test for DCPO.
#'
#' @param dcpo_input a list generated by \code{DCPOtools::dcpo_setup}, with
#'   elements \code{data} (a data frame of survey items and marginals whose
#'   columns include \code{country}, \code{year}, \code{question}, \code{rr},
#'   \code{y_r}, and \code{n_r}) and \code{data_args} (including
#'   \code{scale_q}, \code{scale_cp}, and \code{delta})
#' @param fold_number an integer indicating the number of the fold to be
#'   treated as test data in the current analysis; must be between 1 and
#'   \code{number_of_folds}
#' @param number_of_folds an integer indicating the total number of folds
#' @param fold_seed a seed for reproducibly randomly assigning observations to
#'   folds; when a complete set of k-fold cross-validations is to be performed,
#'   the same seed should be used for all
#' @param chime play a chime (via \code{beepr::beep}) when complete?
#' @param ... arguments to be passed to \code{rstan::sampling}. See \code{dcpo}.
#'   Unless supplied here, \code{control}, \code{seed}, \code{thin},
#'   \code{pars}, and \code{cores} are given the defaults set in this function.
#'
#' @details \code{dcpo_xvt} performs a single cross-validation test of a DCPO
#' estimation. To perform a complete k-fold cross-validation, call it
#' repeatedly, changing only the fold_number argument.
#'
#' Each country-year-question combination is assigned to one of
#' \code{number_of_folds} folds with equal probability, using
#' \code{fold_seed}; all rows sharing that combination fall in the same fold.
#' Because this calls \code{set.seed}, it resets the session's random number
#' stream.
#'
#' @examples
#' \donttest{
#' # Single cross-validation test with a roughly 25% test set
#' demsup_xvtest_25pct <- dcpo_xvt(demsup_data,
#'                                 number_of_folds = 4,
#'                                 iter = 300,
#'                                 chains = 1) # 1 chain/300 iterations for example only; use defaults
#' }
#'
#' @return a list with elements \code{dcpo_input_fold} (the named data list
#'   passed to Stan), \code{dcpo_output_fold} (the \code{stanfit} object
#'   returned by \code{rstan::sampling}), and \code{xvt_args} (the
#'   \code{fold_number}, \code{number_of_folds}, and \code{fold_seed} used)
#'
#' @import rstan
#' @importFrom dplyr '%>%' group_by mutate ungroup arrange summarize filter mutate_at select mutate_all n_distinct vars matches pull if_else
#' @importFrom forcats as_factor
#' @importFrom janitor clean_names
#' @importFrom tidyr spread
#' @importFrom beepr beep
#' @importFrom stats setNames
#'
#' @export
dcpo_xvt <- function(dcpo_input,
                     fold_number = 1,
                     number_of_folds = 10,
                     fold_seed = 324,
                     chime = TRUE,
                     ...) {
  # Local NULL bindings for column names used via data masking (keeps
  # R CMD check from reporting "no visible binding for global variable").
  country <- year <- question <- fold <- test <- kk <- tt <- qq <- rr <- n_r <- y_r <- n <- years <- countries <- `.` <- NULL

  # Validate fold arguments up front: an out-of-range fold_number would
  # otherwise produce an empty test set.
  folds_ok <- is.numeric(number_of_folds) && length(number_of_folds) == 1 &&
    !is.na(number_of_folds) && number_of_folds >= 1 &&
    number_of_folds == round(number_of_folds)
  if (!folds_ok) {
    stop("number_of_folds must be a single positive whole number",
         call. = FALSE)
  }
  if (!(length(fold_number) == 1 &&
        fold_number %in% seq_len(number_of_folds))) {
    stop("fold_number must be a whole number between 1 and number_of_folds",
         call. = FALSE)
  }

  # Assign folds: one draw per country-year-question group, so every row of
  # that combination shares a fold. sample.int() makes all folds equally
  # likely (rounding runif(1, 1, k) would give folds 1 and k half the weight
  # of the others).
  set.seed(fold_seed)
  dat <- dcpo_input$data %>%
    group_by(country, year, question) %>%
    mutate(fold = sample.int(number_of_folds, 1L)) %>%
    ungroup() %>%
    mutate(test = as.numeric(fold == fold_number)) %>%
    # Training rows (test == 0) sort first, so for character columns
    # as_factor() numbers countries/questions seen in training before
    # test-only ones.
    # NOTE(review): this sorts by any `kk` column already present in
    # dcpo_input$data (kk is rebuilt just below); confirm dcpo_setup supplies it.
    arrange(test, kk, year, question) %>%
    mutate(kk = as_factor(country),
           tt = year - min(year) + 1,
           qq = as_factor(question))

  scale_q <- dcpo_input$data_args$scale_q
  delta <- dcpo_input$data_args$delta
  scale_cp <- dcpo_input$data_args$scale_cp

  # 1/0 indicator, one row per question and one column per country present in
  # the training rows: 1 where the country has the question in more than one
  # year, more than two such countries exist, and the question is not
  # scale_q. All zeros when data_args$delta is FALSE.
  use_delta <- dat %>%
    group_by(test, qq, kk) %>%
    summarize(years = n_distinct(year)) %>%
    ungroup() %>%
    filter(test == 0) %>%
    spread(key = kk, value = years, fill = 0) %>%
    mutate(countries = rowSums(.[, c(-1, -2)] > 1)) %>%
    mutate_at(vars(-test, -qq, -countries),
              ~ if_else(. > 1 & countries > 2 & qq != scale_q, 1, 0)) %>%
    select(-test, -qq, -countries) %>%
    {if (!delta) mutate_all(., ~ 0) else .}

  # Question-by-response-category matrix: observed cells are shifted to 10,
  # the scale_cp column gains +1 only in the scale_q row, and the -10 shift
  # leaves 1 at (scale_q, scale_cp) and 0 in every other observed cell.
  # The stopifnot() confirms exactly that single cell was flagged.
  scale_item_matrix <- dat %>%
    group_by(qq, rr) %>%
    summarize(n = sum(n_r)) %>%
    ungroup() %>%
    spread(key = rr, value = n, fill = 0) %>%
    janitor::clean_names() %>%
    mutate_at(vars(matches("x\\d+$")), ~if_else(. > 0, 10, 0)) %>%
    mutate_at(vars(matches(paste0("x", scale_cp, "$"))), ~if_else(qq == scale_q, . + 1, 0)) %>%
    mutate_at(vars(matches("x\\d+$")), ~if_else(. > 0, . - 10, 0)) %>%
    select(-qq) %>%
    as.matrix()
  stopifnot(sum(scale_item_matrix) == 1)

  # Split once instead of re-filtering for every Stan data element.
  train <- dat %>% filter(test == 0)
  holdout <- dat %>% filter(test == 1)

  dcpo_input_fold_names <- c("K", "T", "Q", "R", "N",
                             "kk", "tt", "qq", "rr", "y_r", "n_r",
                             "N_test", "kk_test", "tt_test", "qq_test",
                             "rr_test", "n_r_test",
                             "fixed_cutp", "use_delta", "data", "data_args")
  # Renamed positionally by dcpo_input_fold_names ("Time" becomes "T").
  dcpo_input_fold <- setNames(
    list(K = max(as.numeric(dat$kk)),
         Time = max(dat$tt),
         Q = max(as.numeric(dat$qq)),
         R = max(dat$rr),
         N = nrow(train),
         kk = as.numeric(pull(train, kk)),
         tt = as.numeric(pull(train, tt)),
         qq = as.numeric(pull(train, qq)),
         rr = pull(train, rr),
         y_r = pull(train, y_r),
         n_r = pull(train, n_r),
         N_test = nrow(holdout),
         kk_test = as.numeric(pull(holdout, kk)),
         tt_test = as.numeric(pull(holdout, tt)),
         qq_test = as.numeric(pull(holdout, qq)),
         rr_test = pull(holdout, rr),
         n_r_test = pull(holdout, n_r),
         fixed_cutp = scale_item_matrix,
         use_delta = use_delta,
         data = dat,
         data_args = dcpo_input$data_args),
    dcpo_input_fold_names
  )

  stan_args <- list(object = stanmodels$dcpo_kfold,
                    data = dcpo_input_fold,
                    ...)
  # Fill defaults only for arguments the caller left out. [[ ]] matches names
  # exactly; $ on a list would also accept partial name matches.
  if (length(stan_args[["control"]]) == 0) {
    stan_args[["control"]] <- list(adapt_delta = 0.99, stepsize = 0.005, max_treedepth = 14)
  }
  if (length(stan_args[["seed"]]) == 0) {
    stan_args[["seed"]] <- 324
  }
  if (length(stan_args[["thin"]]) == 0) {
    stan_args[["thin"]] <- 2
  }
  if (length(stan_args[["pars"]]) == 0) {
    stan_args[["pars"]] <- "y_r_test"
  }
  if (length(stan_args[["cores"]]) == 0) {
    # rstan::sampling() runs 4 chains when `chains` is not supplied.
    n_chains <- if (length(stan_args[["chains"]]) == 0) 4L else stan_args[["chains"]]
    # Up to half the machine's cores, as a whole number of at least 1;
    # detectCores() returns NA when the count cannot be determined.
    half_cores <- parallel::detectCores() %/% 2
    if (is.na(half_cores)) {
      half_cores <- 1L
    }
    stan_args[["cores"]] <- max(1L, min(n_chains, half_cores))
  }
  dcpo_output_fold <- do.call(rstan::sampling, stan_args)

  out <- list(dcpo_input_fold = dcpo_input_fold,
              dcpo_output_fold = dcpo_output_fold,
              xvt_args = list(fold_number = fold_number,
                              number_of_folds = number_of_folds,
                              fold_seed = fold_seed))
  # Best effort: a missing audio device must not discard the fitted model.
  if (chime) {
    try(beep())
  }
  out
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.