# Source: vignettes/internal_doc/stan_doc/one-country-vignette-cahill-data-model.R

library(tidyverse)
library(fpemdata)
library(fpemmodeling)
library(fpemreporting)
library(rstan)

### Restrict the division table from fpemdata to a single country ("Kenya"):
country_data <- fpemdata:::get_divisions() %>%
  dplyr::filter(is_country == "Y") %>%
  dplyr::left_join(
    y = fpemdata:::get_division_classifications(),
    by = 'division_numeric_code'
  ) %>%
  dplyr::filter(!is.na(sub_region_numeric_code)) %>%
  # temp: the numeric subregion/region codes are factor levels over all
  # remaining countries, so they are added BEFORE selecting one country
  dplyr::mutate(
    subreg_num = as.numeric(as.factor(sub_region_numeric_code)),
    reg_num = as.numeric(as.factor(region_numeric_code))
  ) %>%
  dplyr::filter(name == 'Kenya')
# Disabled step (wrong column name!):
#   dplyr::mutate(is_developing = factor(is_developed_region))

## Turn the selected country rows into model 'units'.
unit_data <- define_units(data = list(countries = country_data))

## Keep a map between the various indexing options for countries.
## (A narrower map, selecting only numeric_unit_code and
## division_numeric_code, was tried and is disabled.)
division_unit_map <- dplyr::mutate(
  unit_data,
  internal_unit_code = as.numeric(factor(division_numeric_code))
)

# 0 for rows flagged "N" in is_developed_region, 1 otherwise.
is.dev.c <- ifelse(division_unit_map$is_developed_region == "N", 0, 1)

# The three proportions from one survey must stay linked, so the observations
# are kept in wide format (one row per survey).
obs_data_wide <- fpemdata:::get_contraceptive_use() %>%
  dplyr::filter(
    age_range == "15-49",
    is_in_union == "Y",
    division_numeric_code %in% country_data[['division_numeric_code']]
  ) %>%
  dplyr::left_join(division_unit_map, by = 'division_numeric_code') %>%
  # Disabled standard-error handling:
  # dplyr::mutate(
  #   has_se = !is.na(se),
  #   se = dplyr::if_else(is.na(se), 0, se),
  #   zero_se = factor(se == 0),
  #   log_se = log(0.1)) %>%
  # Keep rows with non-NA modern and traditional use (total ignored for now),
  # for Kenya only, starting after 1975.
  dplyr::filter(
    !is.na(contraceptive_use_traditional),
    !is.na(contraceptive_use_modern),
    name == 'Kenya',
    start_date > 1975
  )

# Time index built from the observation start dates and the range 1975-2019.
time_frame <- fpemmodeling::seq_factory(
  obs_data_wide[['start_date']],
  c(1975, 2019)
)


####################  THIS COMES FROM THE GLOBAL RUN #################
## Stand-in values for the output of a global run (the numbers are made up).
## Each quantity gets a location column and, where present, a matching
## *_sigma column; the constants are recycled to every row of unit_data.
unit_data <- dplyr::mutate(
  unit_data,
  P_tilde_global = qlogis(0.8), P_tilde_global_sigma = 0.8,
  omega_global = log(0.05),     omega_global_sigma = 1.2,
  Omega_global = qlogis(0.2),   Omega_global_sigma = 0.7,
  R_tilde_global = qlogis(0.8), R_tilde_global_sigma = 0.5,
  phi_global = log(0.05),       phi_global_sigma = 0.85,
  Phi_global = qlogis(0.1),     Phi_global_sigma = 1.1,
  beta_1_global = 0.11,
  beta_2_global = 1,
  Z_global = -0.5,              Z_global_sigma = 0.8
)
####################  END OF STUFF FROM THE GLOBAL RUN #################

############## Settings section. A one-country run does not need every
############## feature of the settings machinery, but going through it
############## works and lets this script share code with other runs.

## Combine the global-run columns with the country-specific columns of
## unit_data, one model formula per process parameter:
process_settings <- fpemmodeling:::process_settings(
  models = list(
    P_tilde ~ P_tilde_global_sigma + constant(P_tilde_global),
    omega ~ omega_global_sigma + constant(omega_global),
    Omega ~ Omega_global_sigma + constant(Omega_global),
    R_tilde ~ R_tilde_global_sigma + constant(R_tilde_global),
    phi ~ phi_global_sigma + constant(phi_global),
    Phi ~ Phi_global_sigma + constant(Phi_global),
    Z_star ~ Z_global_sigma +
      constant(beta_1_global) + constant(beta_2_global) + constant(Z_global)
  ),
  data = unit_data
)

# Flatten the per-model inputs into one list. `use.names = FALSE` is consumed
# by the inner c(), which drops the outer names so that the outer c() does not
# prefix them onto the element names.
process_data <- do.call(c, c(process_settings$inputs, use.names = FALSE))

# Dimensions, lookup vectors and fixed settings handed to the model alongside
# the process and data-model inputs.
auxilliary_data <- list(
  N = nrow(obs_data_wide),                          # observation rows
  C = nrow(unit_data),                              # unit (country) rows
  T = max(time_frame$index()),                      # largest time index
  M = 4,
  t_star = which(time_frame$sequence() == 1990),    # position of 1990
  get_c_i = factor(obs_data_wide$division_numeric_code),
  get_t_i = match(x = round(time_frame$.items), time_frame$sequence()),
  n_fixed_ar_hyper = 3,
  data_rho = rep(0.5, 3),
  data_tau = rep(0.08, 3),
  model_type = array(data = c(1L, 0L))
)

# Reduce every entry to plain numbers: factors become their integer codes,
# numeric values pass through unchanged, anything else is rejected.
simple_auxilliary_data <- lapply(auxilliary_data, function(x) {
  if (is.factor(x)) {
    as.numeric(x)
  } else if (is.character(x)) {
    stop("Character data must be pre-converted to factors.")
  } else if (is.numeric(x)) {
    x
  } else {
    stop("Something is wrong.")
  }
})

#' Assemble the data-model inputs from wide-format survey observations
#'
#' @param obs_data Data frame with one row per survey observation. Columns
#'   read here: `division_numeric_code`, `start_date`, `data_series_type`,
#'   `contraceptive_use_modern`, `contraceptive_use_traditional`,
#'   `unmet_need_for_any_method`, `has_geographical_region_bias`,
#'   `traditional_method_bias_reason` and
#'   `has_absence_of_probing_questions_bias`.
#' @param include_bias Passed through unchanged into the returned list.
#' @param include_misclassification Passed through unchanged into the
#'   returned list.
#' @return A named list with the observed proportions (N x 4 matrix), their
#'   missingness indicators, country / time / source lookup vectors, the
#'   bias indicators and indices, `Sigma_prior`, and the two `include_*`
#'   switches.
make_data_model_data <- function(obs_data, include_bias = FALSE, include_misclassification = FALSE) {
  # Number of observations
  n_obs <- nrow(obs_data)

  # Time index over the observation start dates.
  # NOTE(review): the script-level time frame is built with an explicit
  # c(1975, 2019) range while this one is not, so the time indices computed
  # below may be relative to a different origin -- confirm.
  time_frame <- fpemmodeling::seq_factory(as.numeric(obs_data$start_date))

  # Number of data source types
  S <- length(unique(obs_data$data_series_type))

  # Matrix of observed proportions: modern, traditional, unmet need, and the
  # remainder. This matches with the y variable in Alkema 2013.
  obs_proportion <- matrix(0, nrow = n_obs, ncol = 4)
  obs_proportion[, 1] <- obs_data$contraceptive_use_modern
  obs_proportion[, 2] <- obs_data$contraceptive_use_traditional
  obs_proportion[, 3] <- obs_data$unmet_need_for_any_method
  obs_proportion[, 4] <- 1 - obs_proportion[, 1] - obs_proportion[, 2] - obs_proportion[, 3]

  # Nudge exact 0s and 1s away from the boundary (NA entries stay NA here).
  obs_proportion <- ifelse(obs_proportion == 0, obs_proportion + 1e-4, obs_proportion)
  obs_proportion <- ifelse(obs_proportion == 1, obs_proportion - 1e-4, obs_proportion)

  # Indicator matrix for missing observations; missing entries are then
  # replaced by 0 so the matrix contains no NA.
  obs_proportion_missing <- ifelse(is.na(obs_proportion), 1, 0)
  obs_proportion <- ifelse(is.na(obs_proportion), 0, obs_proportion)

  # Precalculate the proportion who either have unmet need or do not use a
  # method; this is equivalent to y_i3 + y_i4 in the paper.
  obs_proportion_no_method_use <- 1 - obs_proportion[, 1] - obs_proportion[, 2]

  # Lookup vectors for finding the country, timepoint, and data source type
  # for an observation
  get_c_i <- as.numeric(factor(obs_data$division_numeric_code))
  get_t_i <- match(x = round(time_frame$.items), time_frame$sequence())
  get_s_i <- as.numeric(factor(obs_data$data_series_type))

  # Multiplier index
  # proportions are perturbed by a perturbation vector V.
  # each proportion category in each country gets its own V
  # (the V's are identified with a hierarchical model)
  #
  # `%in%` never returns NA, so a missing flag counts as "no bias". That is
  # how the filter() below and the folk-methods indicator already treat it.
  geo_bias_flag <- as.numeric(obs_data$has_geographical_region_bias %in% "Y")

  # One multiplier index per (country, source type) pair that has at least
  # one geographically biased observation.
  geographic_multiplier_indices_table <- obs_data %>%
    dplyr::filter(has_geographical_region_bias == "Y") %>%
    dplyr::distinct(division_numeric_code, data_series_type) %>%
    dplyr::mutate(multiplier_index = dplyr::row_number())

  # Explicit join keys: same columns the implicit natural join used, without
  # the join message and without depending on which columns happen to match.
  multiplier_index <- obs_data %>%
    dplyr::left_join(
      geographic_multiplier_indices_table,
      by = c("division_numeric_code", "data_series_type")
    ) %>%
    dplyr::pull(multiplier_index)

  # Observations without a matching pair get index 0.
  geographic_multiplier_indices <- ifelse(is.na(multiplier_index), 0, multiplier_index)

  folk_methods_flag <- ifelse(
    !is.na(obs_data$traditional_method_bias_reason) &
      obs_data$traditional_method_bias_reason == "Including folk methods.",
    1, 0
  )
  probing_questions_flag <- as.numeric(
    obs_data$has_absence_of_probing_questions_bias %in% "Y"
  )

  Sigma_prior <- diag(0.1, 2)

  list(
    S = S,
    obs_proportion = obs_proportion,
    obs_proportion_missing = obs_proportion_missing,
    obs_proportion_no_method_use = obs_proportion_no_method_use,
    get_c_i = get_c_i,
    get_t_i = get_t_i,
    get_s_i = get_s_i,

    # NOTE(review): this counts distinct index values, so the 0 used for
    # "no multiplier" is counted too whenever it occurs -- confirm that this
    # is the count the model expects.
    N_geographical_region_bias_parameters = length(unique(geographic_multiplier_indices)),
    has_geographical_region_bias = geo_bias_flag,
    get_geographical_region_bias_index_i = geographic_multiplier_indices,

    has_traditional_includes_folk_methods_bias = folk_methods_flag,
    has_absence_of_probing_questions_bias = probing_questions_flag,

    Sigma_prior = Sigma_prior,

    include_bias = include_bias,
    include_misclassification = include_misclassification
  )
}

# Data-model inputs for the Kenya observations, with the bias terms switched
# on and misclassification switched off.
data_model_data <- make_data_model_data(
  obs_data_wide,
  include_bias = TRUE,
  include_misclassification = FALSE
)

# Everything the Stan program receives, as one flat list.
# NOTE(review): simple_auxilliary_data and data_model_data both carry entries
# named get_c_i and get_t_i, so this list has duplicated names -- confirm
# which copy fpemmodeling::run() / Stan ends up using.
stan_model_data <- c(simple_auxilliary_data, process_data, data_model_data)

samples <- fpemmodeling::run(
  data = stan_model_data,
  model = "models/fpem--rl--dp--proportions--cahill.stan",
  cache = '.my_cache',
  control = list(
    iter = 1200, warmup = 1000,
    init = 'random', init_r = 1,
    control = list(adapt_delta = 0.999, max_treedepth = 16)
  ),
  cores = 4
)

# Post-warmup draws of P, R and Z, each labelled as iteration x country x year.
PRZ <- rstan::extract(
  samples,
  pars = c("P", "R", "Z"),
  inc_warmup = FALSE,
  permuted = TRUE
) %>%
  lapply(function(A) {
    dimnames(A) <- list(
      # seq_len() stays empty for zero draws, where 1:0 would give c(1, 0)
      iteration = seq_len(dim(A)[1]),
      country = "Kenya",
      year = time_frame$sequence()
    )
    A
  })

# The format of one-country-run output is different from all-country, will make code below into a function eventually
# Convert P, R and Z into three proportions and concatenate them in the order
# traditional use, modern use, unmet need:
#   traditional = (1 - R) * P
#   modern      = R * P
#   unmet       = (1 - P) * Z
# P, R and Z are numeric vectors combined elementwise.
calculate_props <- function(P, R, Z) {
  share_traditional <- 1 - R
  share_without_use <- 1 - P
  c(
    share_traditional * P,
    R * P,
    share_without_use * Z
  )
}

# Flatten the P, R and Z draws and convert them to proportions. The result is
# one long vector: all traditional values, then modern, then unmet need.
all_vec <- calculate_props(
  PRZ$P %>% unlist() %>% as.vector(),
  PRZ$R %>% unlist() %>% as.vector(),
  PRZ$Z %>% unlist() %>% as.vector()
)

# Reshape into chain x iteration x unit_time x props, the layout handed to
# fpemreporting below.
# NOTE(review): the last dimension is labelled "P", "R", "Z" although the
# values stored there are the traditional / modern / unmet proportions from
# calculate_props() -- confirm these are the labels fpem_calculate_results()
# expects.
props <- c("P", "R", "Z")
posterior_samples <- array(
  data = all_vec,
  dim = c(1, dim(PRZ$P)[1], dim(PRZ$P)[3], length(props)),
  dimnames = list(
    chain = 1,
    # seq_len() stays empty for a zero extent, where 1:0 would give c(1, 0)
    iteration = seq_len(dim(PRZ$P)[1]),
    unit_time = seq_len(dim(PRZ$P)[3]),
    props = props
  )
)


# Identifiers and time axis for the selected country.
code <- country_data$division_numeric_code
name <- country_data$name
nyears <- time_frame$`.->.index` %>% length()
years <- time_frame$.sequence

# Reference tables from the default fpemdata data set.
data_set_path <- fpemdata:::default_data_set_path()
population_counts <- fpemdata::get_population_counts(data_set_path = data_set_path)
divisions <- fpemdata:::get_divisions(data_set_path = data_set_path)

# Codes and names of all divisions flagged as countries.
divisions_country <- divisions %>%
  subset(is_country == "Y")
country_div_codes <- divisions_country %>%
  dplyr::select(division_numeric_code) %>%
  unlist()
country_names <- divisions_country %>%
  dplyr::select(name) %>%
  unlist()


results <- fpemreporting::fpem_calculate_results(
  posterior_samples = posterior_samples,
  country_population_counts = population_counts,
  first_year = min(years)
)

# Read one JAGS percentile export and reshape it for plotting.
#
# csv:     path to a CSV with columns Name, Iso, Percentile and one column
#          per year from `1975.5` to `2020.5`.
# country: value of Name to keep.
#
# Returns one row per year with a numeric `year` column and one column per
# Percentile value (e.g. `0.025`, `0.5`, `0.975`).
process_jags <- function(csv, country) {
  readr::read_csv(csv) %>%
    dplyr::filter(Name == country) %>%
    dplyr::select(Name, Iso, Percentile, `1975.5`:`2020.5`) %>%
    tidyr::gather(year, value, `1975.5`:`2020.5`) %>%
    dplyr::mutate(year = readr::parse_number(year)) %>%
    tidyr::spread(Percentile, value)
}

# Overlay the JAGS median and 95% bounds as dashed lines on an existing plot.
add_jags_overlay <- function(plot, jags_estimates) {
  plot +
    geom_line(aes(x = year, y = `0.5`), data = jags_estimates, linetype = 'dashed') +
    geom_line(aes(x = year, y = `0.025`), data = jags_estimates, linetype = 'dashed') +
    geom_line(aes(x = year, y = `0.975`), data = jags_estimates, linetype = 'dashed')
}

# The JAGS estimates are loaded BEFORE the comparison plots, because those
# plots use them as overlay data.
unmet_jags <- process_jags("~/../Downloads/190128_100329_global_rate_married_15-49_Country_perc_Unmet.csv", "Kenya")
modern_jags <- process_jags("~/../Downloads/190128_100329_global_rate_married_15-49_Country_perc_Modern.csv", "Kenya")
trad_jags <- process_jags("~/../Downloads/190128_100329_global_rate_married_15-49_Country_perc_Traditional.csv", "Kenya")

# Every plot is printed explicitly so it also renders when the script is run
# through source(), not only at an interactive prompt.

indicator <- "contraceptive_use_all"
fpemreporting::fpem_plot_country_results(
  results$total_cpr_proportions,
  contraceptive_use = obs_data_wide,
  indicator = indicator,
  title = paste(name, indicator),
  y_label = "Proportion",
  breaks = seq(min(years), 2020, by = 5)
) %>% print()


indicator <- "contraceptive_use_modern"
fpemreporting::fpem_plot_country_results(
  results$modern_cpr_proportions,
  contraceptive_use = obs_data_wide,
  indicator = indicator,
  title = paste(name, indicator),
  y_label = "Proportion",
  breaks = seq(min(years), 2020, by = 5)
) %>%
  add_jags_overlay(modern_jags) %>%
  print()


indicator <- "contraceptive_use_traditional"
fpemreporting::fpem_plot_country_results(
  results$traditional_cpr_proportions,
  contraceptive_use = obs_data_wide,
  indicator = indicator,
  title = paste(name, indicator),
  y_label = "Proportion",
  breaks = seq(min(years), 2020, by = 5)
) %>%
  add_jags_overlay(trad_jags) %>%
  print()


indicator <- "unmet_need_for_any_method"
fpemreporting::fpem_plot_country_results(
  results$unmet_proportions,
  contraceptive_use = obs_data_wide,
  indicator = indicator,
  title = paste(name, indicator),
  y_label = "Proportion",
  breaks = seq(min(years), 2020, by = 5)
) %>%
  add_jags_overlay(unmet_jags) %>%
  print()

# JAGS unmet-need estimates on their own: median line with a 95% ribbon.
jags_unmet_plot <- ggplot(unmet_jags, aes(x = year, y = `0.5`)) +
  geom_ribbon(aes(ymin = `0.025`, ymax = `0.975`), alpha = 0.5) +
  geom_line()
print(jags_unmet_plot)
# FPRgroup/FPEM documentation built on March 3, 2020, 8:19 a.m.