R/dcpo_xvt.R

Defines functions dcpo_xvt

Documented in dcpo_xvt

#' Cross-validation testing for DCPO
#'
#' \code{dcpo_xvt} performs a single cross-validation test for DCPO
#'
#' @param dcpo_input a list generated by \code{DCPOtools::dcpo_setup}; its \code{data} element
#'   (a data frame of survey items and marginals) and its \code{data_args} element are used
#' @param fold_number an integer indicating the number of the fold to be treated as test data in
#'   the current analysis; must be between 1 and \code{number_of_folds}
#' @param number_of_folds an integer indicating the total number of folds
#' @param fold_seed a seed for reproducibly randomly assigning observations to folds; when a complete set of k-fold cross-validations is to be performed, the same seed should be used for all
#' @param chime play chime when complete?
#' @param ... arguments to be passed to \code{rstan::stan}. See \code{dcpo}.
#'
#' @details \code{dcpo_xvt} performs a single cross-validation test of a DCPO estimation.  To perform
#' a complete k-fold cross-validation, call it repeatedly, changing only the fold_number argument.
#'
#' Every row sharing a country, year, and question is assigned to the same fold, drawn uniformly
#' from \code{1:number_of_folds} after \code{set.seed(fold_seed)}.
#'
#' Unless supplied through \code{...}, the following sampler arguments are set:
#' \code{control = list(adapt_delta = 0.99, stepsize = 0.005, max_treedepth = 14)},
#' \code{seed = 324}, \code{thin = 2}, \code{pars = "y_r_test"}, and \code{cores} equal to the
#' smaller of the number of chains (4 if not given) and half the detected cores (at least 1).
#'
#' @examples
#' \donttest{
#' # Single cross-validation test with 25% test set
#' demsup_xvtest_25pct <- dcpo_xvt(demsup_data,
#'                            number_of_folds = 4,
#'                            iter = 300,
#'                            chains = 1) # 1 chain/300 iterations for example only; use defaults
#' }
#'
#' @return a list with elements \code{dcpo_input_fold} (the data list passed to Stan),
#'   \code{dcpo_output_fold} (the stanfit object returned by \code{rstan::sampling}), and
#'   \code{xvt_args} (the \code{fold_number}, \code{number_of_folds}, and \code{fold_seed} used)
#'
#' @import rstan
#' @importFrom dplyr '%>%' group_by mutate ungroup arrange summarize filter mutate_at select mutate_all n_distinct vars matches pull if_else
#' @importFrom forcats as_factor
#' @importFrom janitor clean_names
#' @importFrom tidyr spread
#' @importFrom beepr beep
#' @importFrom stats setNames
#'
#' @export

dcpo_xvt <- function(dcpo_input,
                     fold_number = 1,
                     number_of_folds = 10,
                     fold_seed = 324,
                     chime = TRUE,
                     ...) {

  # Bind names used via non-standard evaluation so R CMD check does not flag them.
  country <- year <- question <- fold <- test <- kk <- tt <- qq <- rr <-
    n_r <- y_r <- n <- years <- countries <- `.` <- NULL

  # An out-of-range fold_number would otherwise silently yield an empty test set.
  stopifnot(length(number_of_folds) == 1, number_of_folds >= 1,
            length(fold_number) == 1,
            fold_number %in% seq_len(number_of_folds))

  set.seed(fold_seed)

  # Draw one fold per country-year-question group, uniformly from 1..number_of_folds.
  # (The previous round(runif(1, 1, k)) gave folds 1 and k only half the probability
  # of the interior folds.)
  # NOTE(review): `kk` is used in arrange() before it is redefined in the final
  # mutate(), so it must already be a column of dcpo_input$data -- confirm.
  dat <- dcpo_input$data %>%
    group_by(country, year, question) %>%
    mutate(fold = as.numeric(sample.int(number_of_folds, size = 1))) %>%
    ungroup() %>%
    mutate(test = as.numeric(fold == fold_number)) %>%
    arrange(test, kk, year, question) %>%
    mutate(kk = as_factor(country),
           tt = year - min(year) + 1,
           qq = as_factor(question))

  scale_q <- dcpo_input$data_args$scale_q
  delta <- dcpo_input$data_args$delta
  scale_cp <- dcpo_input$data_args$scale_cp

  # Question-by-country indicator built from training rows only. A cell is 1 when
  # the country asked the question in more than one year, more than two countries
  # did so, and the question is not scale_q. The matrix is all zeros when delta is FALSE.
  use_delta <- dat %>%
    group_by(test, qq, kk) %>%
    summarize(years = n_distinct(year)) %>%
    ungroup() %>%
    filter(!test) %>%
    spread(key = kk, value = years, fill = 0) %>%
    mutate(countries = rowSums(.[, c(-1, -2)] > 1)) %>%
    mutate_at(vars(-test, -qq, -countries),
              ~ if_else(. > 1 & countries > 2 & qq != scale_q, 1, 0)) %>%
    select(-test, -qq, -countries) %>%
    {if (!delta) mutate_all(., ~ 0) else .}

  # Question-by-response-category matrix that marks the single fixed cutpoint.
  # Observed cells become 10. Column x<scale_cp> gets +1 in the scale_q row and 0
  # elsewhere. Then 10 is subtracted from positive cells. The result is 1 at
  # (scale_q, scale_cp) and 0 everywhere else. The stopifnot() fails if that cell
  # was never observed (it becomes -9) or if no column matches scale_cp.
  scale_item_matrix <- dat %>%
    group_by(qq, rr) %>%
    summarize(n = sum(n_r)) %>%
    ungroup() %>%
    spread(key = rr, value = n, fill = 0) %>%
    janitor::clean_names() %>%
    mutate_at(vars(matches("x\\d+$")), ~if_else(. > 0, 10, 0)) %>%
    mutate_at(vars(matches(paste0("x", scale_cp, "$"))), ~if_else(qq == scale_q, . + 1, 0)) %>%
    mutate_at(vars(matches("x\\d+$")), ~if_else(. > 0, . - 10, 0)) %>%
    select(-qq) %>%
    as.matrix()
  stopifnot(sum(scale_item_matrix) == 1)

  # Split once instead of re-filtering for every element below.
  train <- dat %>% filter(test == 0)
  held_out <- dat %>% filter(test == 1)

  dcpo_input_fold <- list(K          = max(as.numeric(dat$kk)),
                          "T"        = max(dat$tt),
                          Q          = max(as.numeric(dat$qq)),
                          R          = max(dat$rr),
                          N          = nrow(train),
                          kk         = as.numeric(pull(train, kk)),
                          tt         = as.numeric(pull(train, tt)),
                          qq         = as.numeric(pull(train, qq)),
                          rr         = pull(train, rr),
                          y_r        = pull(train, y_r),
                          n_r        = pull(train, n_r),
                          N_test     = nrow(held_out),
                          kk_test    = as.numeric(pull(held_out, kk)),
                          tt_test    = as.numeric(pull(held_out, tt)),
                          qq_test    = as.numeric(pull(held_out, qq)),
                          rr_test    = pull(held_out, rr),
                          n_r_test   = pull(held_out, n_r),
                          fixed_cutp = scale_item_matrix,
                          use_delta  = use_delta,
                          data       = dat,
                          data_args  = dcpo_input$data_args)

  stan_args <- list(object = stanmodels$dcpo_kfold,
                    data = dcpo_input_fold,
                    ...)

  # Fill in sampler defaults the caller did not supply. Exact [[ lookup avoids
  # the partial matching that $ performs on lists.
  stan_defaults <- list(control = list(adapt_delta = 0.99, stepsize = 0.005, max_treedepth = 14),
                        seed = 324,
                        thin = 2,
                        pars = "y_r_test")
  for (arg in names(stan_defaults)) {
    if (length(stan_args[[arg]]) == 0) {
      stan_args[[arg]] <- stan_defaults[[arg]]
    }
  }

  # Default cores: at most one per chain (rstan defaults to 4 chains) and half the
  # machine, never below 1. detectCores() can return NA.
  if (length(stan_args[["cores"]]) == 0) {
    n_chains <- if (length(stan_args[["chains"]]) == 0) 4 else stan_args[["chains"]]
    n_avail <- parallel::detectCores()
    if (is.na(n_avail)) {
      n_avail <- 2
    }
    stan_args[["cores"]] <- max(1, min(n_chains, n_avail %/% 2))
  }

  dcpo_output_fold <- do.call(rstan::sampling, stan_args)

  out1 <- list(dcpo_input_fold = dcpo_input_fold,
               dcpo_output_fold = dcpo_output_fold,
               xvt_args = list(fold_number = fold_number,
                               number_of_folds = number_of_folds,
                               fold_seed = fold_seed))

  # Chime; best-effort only, so any error from beep() must not abort the run.
  if (chime) {
    try(beep())
  }

  out1
}

Try the DCPO package in your browser

Any scripts or data that you put into this service are public.

DCPO documentation built on July 8, 2020, 7:03 p.m.