R/extract_dcpo_results.R

Defines functions extract_dcpo_results

Documented in extract_dcpo_results

#' Extract DCPO Results
#'
#' \code{extract_dcpo_results} is a convenience function that extracts the
#' posterior draws of one parameter of a DCPO model along with the relevant
#' identifying information (country, year, question, and/or cutpoint).
#'
#' @param dcpo_input the object generated by \code{DCPOtools::dcpo_setup}
#'   previously passed to \code{dcpo} to generate the DCPO object passed as
#'   \code{dcpo_output}; its \code{data} element must contain the columns
#'   \code{country}, \code{year}, \code{question}, \code{kk}, \code{tt},
#'   \code{qq}, and \code{rr}
#' @param dcpo_output a DCPO object output by \code{DCPOtools::dcpo}; only R6
#'   fit objects that provide a \code{$draws()} method are currently supported
#'   (rstan output is not)
#' @param par a single string naming the parameter to extract from the
#'   \code{DCPO} model: "theta" (mean public opinion), "sigma" (polarization in
#'   public opinion), "alpha" (question dispersion), "beta" (question-cutpoint
#'   difficulty), or "delta" (country-specific question bias)
#'
#' @examples
#' \donttest{
#' out1 <- dcpo(demsup_data,
#'              chime = FALSE,
#'              chains = 1,
#'              iter = 300) # 1 chain/300 iterations for example purposes only; use defaults
#'
#' theta_draws <- extract_dcpo_results(dcpo_input = demsup_data,
#'                                     dcpo_output = out1,
#'                                     par = "theta")
#' }
#'
#' @return a list of tibbles, one per posterior draw (in draw order). Each
#'   tibble holds the identifying columns for \code{par} (country and year for
#'   theta and sigma; question for alpha; question and cutpoint for beta;
#'   question and country for delta) followed by a column named after
#'   \code{par} holding that draw's values.
#'
#' @importFrom rstan summary
#' @importFrom dplyr mutate group_by summarize select first arrange
#' @importFrom purrr map_df
#' @importFrom tibble as_tibble rownames_to_column
#' @importFrom posterior quantile2
#'
#' @export

extract_dcpo_results <- function(dcpo_input,
                                 dcpo_output,
                                 par) {

    valid_pars <- c("theta", "sigma", "alpha", "beta", "delta")
    # errorCondition() only constructs a condition object; stop() is what
    # actually signals it. Requiring a single string keeps `if ()` scalar.
    if (!is.character(par) || length(par) != 1L || !par %in% valid_pars) {
        stop('par must be specified as one of "theta", "sigma", "alpha", "beta", or "delta"',
             call. = FALSE)
    }

    # Fail fast instead of emitting a message and then erroring on an
    # undefined result object.
    if (!inherits(dcpo_output, "R6")) {
        stop("extract_dcpo_results() cannot currently handle rstan output",
             call. = FALSE)
    }

    # Bind NSE column names to NULL to quiet R CMD check.
    variable <- .draw <- NULL

    # Long format: one row per (draw, element of `par`), with the element's
    # name (e.g. "theta[3,7]") in `variable` and its value in a column named
    # after `par`. The warning suppression is carried over from the original
    # code; presumably it silences posterior's notices about the draws_df
    # class -- TODO confirm.
    draws <- suppressWarnings(
        dcpo_output$draws(par, format = "df") %>%
            tidyr::pivot_longer(dplyr::starts_with(par),
                                names_to = "variable",
                                values_to = par)
    )

    # Label each distinct element once (cheaper than parsing every draw row).
    labels <- dcpo_element_labels(par, draws$variable, dcpo_input$data)
    id_cols <- setdiff(names(labels), "variable")

    # Inner join: elements the helper filtered out (e.g. country-years outside
    # a country's observed span) are dropped from every draw. Using `labels`
    # as the left table keeps rows within each draw in the helper's order.
    labels %>%
        dplyr::inner_join(draws, by = "variable") %>%
        dplyr::select(dplyr::all_of(c(id_cols, par)), .draw) %>%
        dplyr::group_split(.draw) %>%
        purrr::map(function(one_draw) dplyr::select(one_draw, -.draw))
}

# Private helper for extract_dcpo_results(): map the distinct element names of
# DCPO parameter `par` (e.g. "theta[3,7]") to their identifying labels using
# the data from dcpo_setup(). Returns one row per element to keep, sorted in
# output order: the `variable` column followed by the identifier columns.
dcpo_element_labels <- function(par, variables, dat) {
    question <- country <- year <- variable <- kk <- tt <- qq <- rr <- NULL
    first_yr <- last_yr <- rr_max <- NULL

    # Bracketed index layout, as assumed by the original implementation:
    # theta/sigma[tt, kk], alpha[qq], beta[rr, qq], delta[qq, kk].
    # NOTE(review): confirm these index orders against the Stan model.
    first_index <- function(v) {
        as.numeric(gsub("^.*\\[(\\d+)(,\\d+)?\\]$", "\\1", v))
    }
    second_index <- function(v) {
        as.numeric(gsub("^.*\\[\\d+,(\\d+)\\]$", "\\1", v))
    }

    elements <- dplyr::tibble(variable = unique(variables))

    kcodes <- dat %>%
        dplyr::group_by(country) %>%
        dplyr::summarize(kk = as.numeric(dplyr::first(kk)))

    qcodes <- dat %>%
        dplyr::group_by(question) %>%
        dplyr::filter(rr == max(rr)) %>%
        dplyr::summarize(qq = as.numeric(dplyr::first(qq)),
                         rr_max = max(rr))

    switch(
        par,
        theta = ,
        sigma = {
            tcodes <- dat %>%
                dplyr::group_by(year) %>%
                dplyr::summarize(tt = dplyr::first(tt))

            ktcodes <- dat %>%
                dplyr::group_by(country) %>%
                dplyr::summarize(first_yr = min(year),
                                 last_yr = max(year))

            elements %>%
                dplyr::mutate(tt = first_index(variable),
                              kk = second_index(variable)) %>%
                dplyr::left_join(kcodes, by = "kk") %>%
                dplyr::left_join(tcodes, by = "tt") %>%
                # Count forward from the first year so that time indices with
                # no year in the data (no tcodes match) still get a year.
                dplyr::mutate(year = dplyr::if_else(tt == 1,
                                                    as.integer(year),
                                                    as.integer(min(year, na.rm = TRUE) + tt - 1))) %>%
                dplyr::left_join(ktcodes, by = "country") %>%
                # Keep only each country's first-to-last observed years.
                dplyr::filter(year >= first_yr & year <= last_yr) %>%
                dplyr::arrange(kk, tt) %>%
                dplyr::select(variable, country, year)
        },
        alpha = elements %>%
            dplyr::mutate(qq = first_index(variable)) %>%
            dplyr::left_join(qcodes, by = "qq") %>%
            dplyr::arrange(qq) %>%
            dplyr::select(variable, question),
        beta = elements %>%
            dplyr::mutate(rr = first_index(variable),
                          qq = second_index(variable)) %>%
            dplyr::left_join(qcodes, by = "qq") %>%
            # Drop cutpoints beyond the largest rr observed for each question.
            dplyr::filter(rr <= rr_max) %>%
            dplyr::arrange(qq, rr) %>%
            dplyr::select(variable, question, cutpoint = rr),
        delta = elements %>%
            dplyr::mutate(qq = first_index(variable),
                          kk = second_index(variable)) %>%
            dplyr::left_join(qcodes, by = "qq") %>%
            dplyr::left_join(kcodes, by = "kk") %>%
            dplyr::arrange(qq, kk) %>%
            dplyr::select(variable, question, country)
    )
}
fsolt/DCPOtools documentation built on June 9, 2025, 4:10 p.m.