Veridical Numeric Task

Before creating our plots we need to load in packages and prepare some convenience functions and themes for plotting.

Load Packages

I use the tidyverse a lot, so we'll load the parts we need (dplyr, ggplot2 and stringr) now. We also need patchwork for combining subplots into larger figures, along with a few helper packages used further down. We'll also tell knitr to create all figures at the full column width.

library(dplyr)
library(ggplot2)
library(ggh4x)
library(stringr)
library(patchwork)
library(watercolours)
library(tune)
library(mcce)   # devtools::install_github('gjcooper/gcphd-model_of_dce')
knitr::opts_chunk$set(out.width = "100%")

Load Data

# Saved sampler output for the two estimation runs
accept_file <- here::here("data", "output", "NumericVDCE_1878182.rcgbcm_Estimation5Model.RData")
reject_file <- here::here("data", "output", "NumericVDCE_2117450.rcgbcm_RejectEstimation5Model.RData")

# Read each file, extract its parameters with filter = "sample", and stack the
# two results with a `condition` column ("accept" / "reject") taken from the
# list names.
samples <- lapply(
  list(accept = accept_file, reject = reject_file),
  function(sample_file) {
    extract_parameters(get_samples(sample_file), filter = "sample")
  }
) %>%
  bind_rows(.id = "condition")

# Architecture codes; these are the names (keys) of every label vector below
arch_order <- c("FPP", "IST", "IEX", "MW", "CB")

# Map of old parameter names to new ones: short labels
par_label_short <- setNames(
  c("AORO", "AORA", "AARO", "AARA", "Int"),
  arch_order
)

# Long labels
par_label_long <- setNames(
  c("Accept OR Reject OR",
    "Accept OR Reject AND",
    "Accept AND Reject OR",
    "Accept AND Reject AND",
    "Integrated"),
  arch_order
)

# Long labels with every space replaced by a line break
par_label_long2 <- setNames(
  str_replace_all(unname(par_label_long), " ", "\n"),
  arch_order
)

Plot theme

To keep things as consistent as possible we'll create a theme and update the default ggplot theme with it.

# Base font size; the theme elements below are sized relative to it
txt_base_sz <- 12

# Theme adjustments that are added on top of theme_classic() (see theme_set below)
plot_theme <- theme(
  plot.title = element_text(face = "bold", size = txt_base_sz + 2, hjust = 0.5),
  plot.margin = margin(l = 2, r = 2, t = 2),
  axis.title = element_text(face = "bold", size = txt_base_sz + 2),
  axis.text = element_text(face = "bold", size = txt_base_sz - 2),
  strip.text = element_text(face = "bold", size = txt_base_sz + 2),
  strip.background = element_blank(),
  strip.placement = "outside",
  legend.title = element_blank(),
  legend.position = "bottom",
  legend.margin = margin(0, 0, 0, 0, "cm")
)

# Optional add-on theme with larger titles and strip text
bigger_text <- theme(
  plot.title = element_text(face = "bold", size = txt_base_sz + 4),
  axis.title = element_text(size = txt_base_sz + 4),
  strip.text = element_text(face = "bold", size = txt_base_sz + 4)
)

# Entries 1-4 and 7 of the durer discrete palette, with their names removed
five_col_palette <- unname(watercolour$durer$discrete[c(1:4, 7)])

# Make the combined theme the default for all subsequent plots
theme_set(theme_classic() + plot_theme)

Selective Influence test

# Median of the exponentiated alpha_* samples for each subject within each
# condition.
#
# `condition` must be part of the grouping: `samples` stacks the accept and
# reject files, so without it any subjectid that occurs in both files has its
# accept and reject draws pooled into a single median and is labelled with
# whichever condition happens to come first. The grouping passed to
# arch_medians() mirrors the groups left on the summarised data, following the
# same pattern used for the recovery medians later in this document.
model_medians <- samples %>%
  filter(startsWith(parameter, "alpha_"), stageid == "sample") %>%
  mutate(value = exp(value)) %>%
  group_by(condition, subjectid, parameter) %>%
  summarise(value = median(value), .groups = "drop_last") %>%
  arch_medians(grouping_vars = c("condition", "subjectid"))

# Factor level order for the architectures in the plots, and the order in which
# they are listed in the legend
Par_order <- c("IST", "CB", "FPP", "MW", "IEX")
leg_order <- c("FPP", "IST", "IEX", "MW", "CB")

# Accept condition: relabel the architectures with their long names (levels in
# Par_order), draw the architecture plot and hide its legend.
acc_plot <- arch_plot(
  model_medians %>%
    filter(condition == "accept") %>%
    mutate(parameter = factor(parameter, levels = Par_order,
                              labels = par_label_long[Par_order]))
) +
  scale_fill_discrete(type = five_col_palette, breaks = par_label_long[leg_order]) +
  labs(title = "Accept condition") +
  theme(legend.position = "none")

# Reject condition: same plot, but this one keeps a legend with large keys.
rej_plot <- arch_plot(
  model_medians %>%
    filter(condition == "reject") %>%
    mutate(parameter = factor(parameter, levels = Par_order,
                              labels = par_label_long[Par_order]))
) +
  scale_fill_discrete(type = five_col_palette, breaks = par_label_long[leg_order]) +
  labs(title = "Reject condition") +
  theme(legend.position = "right",
        legend.key.size = unit(2, "cm"))

# patchwork design: A = accept plot across the top row, B = reject plot,
# C = area reserved for the collected legend, # = empty space.
layout <- "
AAAAAAAAA
BBBB##CC#
"
acc_plot + rej_plot + guide_area() +
  plot_layout(guides = "collect", design = layout)

# No plot is passed, so ggsave() writes the most recent plot
ggsave(
  here::here("results", "arch_mac_figures", "salience_manipulation.png"),
  width = 45,
  height = 30,
  units = "cm",
  dpi = 600
)

Parameter values for Veridical task

# Median of the exponentiated non-architecture parameters for each subject
# within each condition.
#
# `condition` is a grouping variable because the plot below facets on it;
# grouping by subjectid and parameter alone pools accept and reject draws for
# any subjectid that occurs in both files and labels the pooled median with
# whichever condition comes first.
par_medians <- samples %>%
  filter(!startsWith(parameter, "alpha_"), stageid == "sample") %>%
  mutate(value = exp(value)) %>%
  group_by(condition, subjectid, parameter) %>%
  summarise(value = median(value), .groups = "drop")

# Fill colour for each parameter name
par_colours <- c("t0" = "grey",
                 "A" = "grey", "b_acc" = "#73842E", "b_rej" = "#D0781C",
                 "v_acc_p_H" = "#569F72", "v_acc_p_L" = "#407755", "v_acc_p_D" = "#2B5039",
                 "v_rej_p_H" = "#F29F40", "v_rej_p_L" = "#EF8C1A", "v_rej_p_D" = "#BF6C0D",
                 "v_acc_r_H" = "#85C0FF", "v_acc_r_L" = "#47A0FF", "v_acc_r_D" = "#0A81FF",
                 "v_rej_r_H" = "#BBA9A0", "v_rej_r_L" = "#A2887C", "v_rej_r_D" = "#83685D")

# One horizontal boxplot per parameter, one panel per condition. The lookup
# goes through as.character() so that it is always by name (a factor would
# index par_colours by its integer codes), and unname() keeps the new column a
# plain character vector for scale_fill_identity().
par_medians %>%
  mutate(colour = unname(par_colours[as.character(parameter)])) %>%
  ggplot(aes(x = parameter, y = value, fill = colour)) +
  scale_colour_watercolour(palette = "durer") +
  scale_fill_identity() +
  geom_boxplot() +
  coord_flip() +
  facet_wrap(~ condition)

# No plot is passed, so ggsave() writes the most recent plot
ggsave(
  filename = here::here(
    "results",
    "arch_mac_figures",
    "vnumeric_pars.png"),
  dpi = 600,
  width = 45,
  height = 30,
  units = "cm"
)

Simulation

Prep for data reads and get original data

# Folder holding the recovery output, and the file with the original estimation run
data_location <- here::here("data", "output", "VeridicalRecovery")
odata <- "NumericVDCE_1878182.rcgbcm_Estimation5Model.RData"

# `pattern` is a regular expression, so the dot is escaped and the extension is
# anchored to the end of the file name; an unescaped, unanchored ".RData" would
# also match names that merely contain those letters somewhere.
recovery_files <- list.files(data_location, pattern = "\\.RData$")
recovery_data_files <- list.files(data_location, pattern = "\\.RDS$")

original_samples <- get_samples(here::here("data", "output", odata))

Read in generated data

# Read every generated data set. data_location is already a full path built by
# here::here(), so the file name is appended with file.path() rather than being
# passed through here::here() a second time.
recovery_data <- lapply(recovery_data_files, function(x) {
  print(paste("Reading data", x))
  readRDS(file.path(data_location, x))
})

# Name each data set by the second "_"-separated piece of its file name.
# vapply() guarantees a character vector whatever the number of files.
names(recovery_data) <- vapply(
  strsplit(recovery_data_files, "_"), "[[", character(1), 2
)

# Add the data stored with the original samples, tagged as participant-generated
recovery_data[["Original"]] <- original_samples$data %>%
  mutate(generator = "participant")
recovery_data <- bind_rows(recovery_data, .id = "source")

Data plots

Plot 1

# Boxplots of rt for each data source, coloured by generator, with one panel
# per response. Only rows with rt < 10 are kept.
ggplot(
  data = recovery_data %>%
    filter(rt < 10) %>%
    mutate(accept = factor(accept, labels = c("Reject", "Accept"))),
  mapping = aes(x = source, y = rt, colour = generator)
) +
  geom_boxplot() +
  scale_colour_watercolour(palette = "durer") +
  facet_wrap(~ accept)

Plot 2

# rt densities per data source, one panel per price/rating combination. The
# original data gets a different line type from the other sources. Only rows
# with rt < 5 are kept.
ggplot(
  data = recovery_data %>%
    filter(rt < 5) %>%
    mutate(
      simulated = ifelse(source == "Original", "No", "Yes"),
      cell = paste0(price, rating)
    ),
  mapping = aes(x = rt, colour = source, linetype = simulated)
) +
  geom_density(linewidth = 0.7) +
  facet_wrap(~ cell) +
  scale_colour_watercolour(palette = "durer")

Read in the samples

# Extract the parameters (filter = "sample") from every recovery run.
# data_location is already a full path built by here::here(), so the file name
# is appended with file.path() rather than going through here::here() again.
samples <- lapply(recovery_files, function(x) {
  print(paste("Extracting", x))
  get_samples(file.path(data_location, x)) %>%
    extract_parameters(filter = "sample")
})

# Name each run by the second "_"-separated piece of its file name.
# vapply() guarantees a character vector whatever the number of files.
names(samples) <- vapply(strsplit(recovery_files, "_"), "[[", character(1), 2)

samples[["Original"]] <- original_samples %>% extract_parameters(filter = "sample")

# Stack everything, recording the run each row came from in `source`
samples <- bind_rows(samples, .id = "source")
rm(original_samples)

Get Architecture Medians

# Median of the exponentiated alpha_* samples per source and subject, handed to
# arch_medians() with the same grouping.
model_medians <- arch_medians(
  samples %>%
    filter(startsWith(parameter, "alpha_"), stageid == "sample") %>%
    mutate(value = exp(value)) %>%
    group_by(source, subjectid, parameter) %>%
    summarise(value = median(value)),
  grouping_vars = c("source", "subjectid")
)

Get subject order for plot

# Order subjects (the "Group" rows are left out) by their largest rel_val in
# each source, averaged over sources, from highest to lowest.
subject_order <- model_medians %>%
  filter(subjectid != "Group") %>%
  group_by(subjectid, source) %>%
  summarise(peak_rel = max(rel_val)) %>%
  group_by(subjectid) %>%
  summarise(mean_peak = mean(peak_rel)) %>%
  arrange(desc(mean_peak)) %>%
  pull(subjectid)

Get parameter order

# Order the architectures by their median rel_val in the original (non-simulated)
# samples, from smallest to largest.
Par_order <- model_medians %>%
  filter(source == "Original") %>%
  group_by(parameter) %>%
  summarise(typical_rel = median(rel_val)) %>%
  arrange(typical_rel) %>%
  pull(parameter)

Architecture recovery

# Stacked bars of rel_val for every individual subject: one bar per subject
# (ordered by subject_order), one row of panels per generating architecture.
# Architectures are relabelled with the one-word-per-line long labels.
subject_recovery <- model_medians %>%
  filter(source != "Original") %>%
  filter(subjectid != "Group") %>%
  mutate(subjectid = factor(subjectid, subject_order)) %>%
  mutate(parameter = factor(parameter, Par_order,
                            labels = par_label_long2[Par_order])) %>%
  mutate(source = factor(source, leg_order,
                         labels = par_label_long2[leg_order])) %>%
  ggplot(aes(x = subjectid, y = rel_val, fill = parameter)) +
    geom_col() +
    labs(y = "Relative Evidence") +
    scale_fill_discrete(type = five_col_palette, breaks = par_label_long2[leg_order]) +
    theme(axis.title.x = element_blank(),
          axis.text.x = element_blank(),
          axis.ticks.x = element_blank(),
          strip.background = element_blank(),
          strip.text.y = element_blank()) +
    facet_grid(rows = vars(source)) +
    scale_y_continuous(labels = NULL, breaks = NULL) +
    ggtitle("Individual Subjects")

# The same plot for the "Group" rows only. subjectid is left as it is here:
# subject_order was built without "Group", so converting to a factor with those
# levels would turn every x value into NA.
group_recovery <- model_medians %>%
  filter(source != "Original") %>%
  filter(subjectid == "Group") %>%
  mutate(parameter = factor(parameter, Par_order,
                            labels = par_label_long2[Par_order])) %>%
  mutate(source = factor(source, leg_order,
                         labels = par_label_long2[leg_order])) %>%
  ggplot(aes(x = subjectid, y = rel_val, fill = parameter)) +
    geom_col() +
    scale_x_discrete(expand = c(0.5, 0)) +
    scale_fill_discrete(type = five_col_palette, breaks = par_label_long2[leg_order]) +
    theme(axis.title = element_blank(),
          axis.text.x = element_blank(),
          axis.ticks.x = element_blank(),
          # Row labels are shown on this panel only, horizontal and left-aligned
          strip.text.y.right = element_text(angle = 0, hjust = 0)) +
    facet_grid(rows = vars(source)) +
    scale_y_continuous(labels = NULL, breaks = NULL) +
    ggtitle("Group Level")

# Subject panel and group panel side by side (9:1 width) with a shared legend
subject_recovery + group_recovery +
  plot_layout(widths = c(9, 1), guides = "collect") +
  plot_annotation(
    title = "Relative Evidence for each of the 5 architectures",
    subtitle = "Each row corresponds to the generating architecture",
    caption = "The area of each stacked bar is the relative proportion of each of the five possible modelled architectures as their dirichlet process"
    )

# No plot is passed, so ggsave() writes the most recent plot
ggsave(
  filename = here::here(
    "results",
    "arch_mac_figures",
    "simulation_recovery.png"),
  dpi = 600,
  width = 45,
  height = 30,
  units = "cm"
)

Parameter Recovery

Get data in right format

# Median of every non-architecture parameter per source and subject
# (values are left on the scale they were sampled on).
par_medians <- samples %>%
  filter(!startsWith(parameter, "alpha_")) %>%
  group_by(source, subjectid, parameter) %>%
  summarise(value = median(value))

# Split into the recovery runs and the original estimates
recovery <- filter(par_medians, source != "Original")
original <- filter(par_medians, source == "Original")

# Pair each recovered median with the original estimate for the same subject
# and parameter, then tidy the column names, label the recovery model with the
# short architecture names and zero-pad the subject ids ("theta_mu" becomes
# "Group").
combined <- left_join(recovery, original, by = c("parameter", "subjectid")) %>%
  select(
    recovery_model = source.x,
    subjectid,
    parameter,
    recovered_value = value.x,
    estimated_value = value.y
  ) %>%
  mutate(
    recovery_model = factor(recovery_model, levels = arch_order,
                            labels = par_label_short[arch_order]),
    subjectid = case_when(
      subjectid == "theta_mu" ~ "Group",
      TRUE ~ str_pad(subjectid, 2, pad = "0")
    )
  )

# Point colour for the recovery scatter plots
pt_col <- watercolour$durer$discrete["shadow"]

Plotting function

# Scatter plot of recovered against estimated parameter values.
#
# subset: data with columns estimated_value, recovered_value, parameter and
#         recovery_model.
# Returns a ggplot with one row of panels per parameter and one column per
# recovery model, an identity line (intercept 0, slope 1) and
# coord_obs_pred() coordinates. Points use the global colour `pt_col`.
parplot <- function(subset) {
  scatter <- ggplot(subset, aes(x = estimated_value, y = recovered_value)) +
    geom_point(size = 0.5, colour = pt_col) +
    geom_abline(intercept = 0, slope = 1)

  scatter +
    facet_grid(rows = vars(parameter), cols = vars(recovery_model)) +
    labs(x = "Generating Value", y = "Recovered Value") +
    coord_obs_pred()
}

Non-drift parameters

# Parameters other than drift rates, and the labels used for their panel rows
pars <- c("A", "t0", "b_acc", "b_rej")
plabs <- c(
  "Start Point",
  "Non-decision Time",
  "Threshold to Accept",
  "Threshold to Reject"
)

# Recovery plot restricted to those parameters
parplot(
  combined %>%
    filter(parameter %in% pars) %>%
    mutate(parameter = factor(parameter, levels = pars, labels = plabs))
)

# No plot is passed, so ggsave() writes the most recent plot
ggsave(
  here::here("results", "arch_mac_figures", "parameter_recovery_other.png"),
  width = 45,
  height = 30,
  units = "cm",
  dpi = 600
)

Accept drift rates

# Accept drift-rate parameters, in panel order, and their row labels
acc_dr <- paste0("v_acc_", c("p_D", "r_D", "p_L", "r_L", "p_H", "r_H"))
drlabs <- c(
  "Price Distractor", "Rating Distractor",
  "Price Low", "Rating Low",
  "Price High", "Rating High"
)

# Recovery plot for every parameter whose name starts with "v_acc"
parplot(
  combined %>%
    filter(startsWith(parameter, "v_acc")) %>%
    mutate(parameter = factor(parameter, levels = acc_dr, labels = drlabs))
)

# No plot is passed, so ggsave() writes the most recent plot
ggsave(
  here::here("results", "arch_mac_figures", "parameter_recovery_drift_acc.png"),
  width = 45,
  height = 30,
  units = "cm",
  dpi = 600
)

Reject drift rates

# Reject drift-rate parameters, in panel order, and their row labels
rej_dr <- paste0("v_rej_", c("p_D", "r_D", "p_L", "r_L", "p_H", "r_H"))
drlabs <- c(
  "Price Distractor", "Rating Distractor",
  "Price Low", "Rating Low",
  "Price High", "Rating High"
)

# Recovery plot for every parameter whose name starts with "v_rej"
parplot(
  combined %>%
    filter(startsWith(parameter, "v_rej")) %>%
    mutate(parameter = factor(parameter, levels = rej_dr, labels = drlabs))
)

# No plot is passed, so ggsave() writes the most recent plot
ggsave(
  here::here("results", "arch_mac_figures", "parameter_recovery_drift_rej.png"),
  width = 45,
  height = 30,
  units = "cm",
  dpi = 600
)

Architecture Explainer

# Draw one accumulator panel on the unit square: a dashed horizontal line at
# y = 0.9 and an arrow from the origin to `end`.
#
# racecol: colour of the arrow.
# diagcol: colour of the dashed line and of the axis lines.
# end:     c(x, y) coordinates of the arrow head.
# width:   line width of the arrow and of the axis lines.
# Returns a ggplot with axis lines only (no titles, text or padding).
one_acc_plot <- function(racecol = "darkgreen", diagcol = "black", end = c(0.8, 0.8), width = 1) {
  unit_grid <- seq(0, 1, 0.1)
  frame <- data.frame(x = unit_grid, y = unit_grid)
  arrow_head <- arrow(length = unit(0.3, "inches"))

  ggplot(frame, aes(x = x, y = y)) +
    geom_blank() +
    geom_hline(yintercept = 0.9, linetype = 2, colour = diagcol) +
    geom_segment(
      x = 0, y = 0,
      xend = end[1], yend = end[2],
      colour = racecol,
      linewidth = width,
      linetype = 1,
      lineend = "round",
      linejoin = "mitre",
      arrow = arrow_head
    ) +
    theme_void() +
    # theme_void() removes the axes, so the axis lines are added back here
    theme(
      axis.line = element_line(linewidth = width, colour = diagcol),
      axis.title = element_blank(),
      axis.text = element_blank()
    ) +
    scale_x_continuous(expand = c(0, 0), limits = c(0, 1)) +
    scale_y_continuous(expand = c(0, 0), limits = c(0, 1))
}
# Empty white canvas with a black border, spanning [0, xmax] x [0, ymax] with
# no padding, for placing other plots on with annotation_custom().
plot_canvas <- function(xmax = 1, ymax = 1) {
  corners <- data.frame(
    x = seq(0, xmax, length.out = 10),
    y = seq(0, ymax, length.out = 10)
  )

  ggplot(corners, aes(x = x, y = y)) +
    geom_blank() +
    scale_x_continuous(expand = c(0, 0), limits = c(0, xmax)) +
    scale_y_continuous(expand = c(0, 0), limits = c(0, ymax)) +
    theme_void() +
    theme(panel.background = element_rect(fill = "white", colour = "black"))
}

# Draw the schematic for one architecture.
#
# arch: key into the global `arch_schema` list (e.g. "AORO").
# Returns a square canvas titled with the architecture's name, with one
# accumulator panel (see one_acc_plot) placed at each position listed in the
# architecture's `components`.
# Stops with an error if `arch` has no entry in `arch_schema`; previously an
# unknown key silently produced an empty, untitled canvas.
arch_schema_plot <- function(arch) {
  arch_spec <- arch_schema[[arch]]
  if (is.null(arch_spec)) {
    stop(
      "Unknown architecture '", arch, "'; expected one of: ",
      paste(names(arch_schema), collapse = ", "),
      call. = FALSE
    )
  }

  p <- plot_canvas() +
    theme(
      plot.title = element_text(size = 24, hjust = 0.5)
    ) +
    coord_fixed()

  for (component in arch_spec$components) {
    # Render the accumulator, then embed it as a grob at the component's position
    lba <- do.call(one_acc_plot, component$args)
    annote_args <- c(list(grob = ggplotGrob(lba)), component$pos)
    p <- p + do.call(annotation_custom, annote_args)
  }
  p + ggtitle(arch_spec$name)
}

# Muted colours used for some of the accumulators in the first schematic
greengrey <- "#C5D3C5"
redgrey <- "#D3C5C5"

arch_order <- c("FPP", "IST", "IEX", "MW", "CB")

# Specification of each architecture schematic: a display name plus a list of
# components. Each component has `pos` (where the accumulator panel sits on the
# unit canvas) and `args` (arguments passed to one_acc_plot). Built inside
# local() so the helpers below do not end up in the global environment.
arch_schema <- local({
  cell <- function(xmin, xmax, ymin, ymax) {
    list(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax)
  }

  # Panel positions: a 2 x 2 grid, plus a vertically centred pair
  top_left <- cell(0.05, 0.45, 0.55, 0.95)
  top_right <- cell(0.55, 0.95, 0.55, 0.95)
  bottom_left <- cell(0.05, 0.45, 0.05, 0.45)
  bottom_right <- cell(0.55, 0.95, 0.05, 0.45)
  mid_left <- cell(0.05, 0.45, 0.3, 0.7)
  mid_right <- cell(0.55, 0.95, 0.3, 0.7)

  # One component; `width` is only included when given, so that one_acc_plot
  # otherwise falls back to its own default.
  racer <- function(pos, racecol, end, width = NULL) {
    args <- list(racecol = racecol, end = end)
    if (!is.null(width)) {
      args$width <- width
    }
    list(pos = pos, args = args)
  }

  list(
    AORO = list(
      name = "Accept OR Reject OR",
      components = list(
        racer(top_left, "darkgreen", c(0.6, 0.9), width = 2),
        racer(top_right, redgrey, c(0.95, 0.8)),
        racer(bottom_left, greengrey, c(0.7, 0.8)),
        racer(bottom_right, redgrey, c(0.99, 0.75))
      )
    ),
    AORA = list(
      name = "Accept OR Reject AND",
      components = list(
        racer(top_left, "lightgrey", c(0.6, 0.75)),
        racer(top_right, "darkred", c(0.55, 0.9), width = 2),
        racer(bottom_left, "darkgreen", c(0.6, 0.9), width = 2),
        racer(bottom_right, "lightgrey", c(0.65, 0.85))
      )
    ),
    AARO = list(
      name = "Accept AND Reject OR",
      components = list(
        racer(top_left, "darkgreen", c(0.6, 0.9), width = 2),
        racer(top_right, "lightgrey", c(0.8, 0.8)),
        racer(bottom_left, "darkgreen", c(0.8, 0.9), width = 2),
        racer(bottom_right, "lightgrey", c(0.9, 0.85))
      )
    ),
    AARA = list(
      name = "Accept AND Reject AND",
      components = list(
        racer(top_left, "darkgreen", c(0.7, 0.95), width = 2),
        racer(top_right, "lightgrey", c(0.8, 0.8)),
        racer(bottom_left, "darkgreen", c(0.7, 0.9), width = 2),
        racer(bottom_right, "darkred", c(0.55, 0.9), width = 2)
      )
    ),
    Int = list(
      name = "Integrated",
      components = list(
        racer(mid_left, "darkgreen", c(0.6, 0.9), width = 4),
        racer(mid_right, "lightgrey", c(0.7, 0.8), width = 2)
      )
    )
  )
})

All together

# 2 x 3 canvas: the four two-rule architectures in a 2 x 2 grid on the top two
# rows, and the integrated architecture centred on the bottom row.
p <- plot_canvas(xmax = 2, ymax = 3) + coord_fixed()

# Top row
p <- p + annotation_custom(ggplotGrob(arch_schema_plot("AORO")),
                           xmin = 0.05, xmax = 0.95, ymin = 2.05, ymax = 2.95)
p <- p + annotation_custom(ggplotGrob(arch_schema_plot("AORA")),
                           xmin = 1.05, xmax = 1.95, ymin = 2.05, ymax = 2.95)

# Middle row
p <- p + annotation_custom(ggplotGrob(arch_schema_plot("AARO")),
                           xmin = 0.05, xmax = 0.95, ymin = 1.05, ymax = 1.95)
p <- p + annotation_custom(ggplotGrob(arch_schema_plot("AARA")),
                           xmin = 1.05, xmax = 1.95, ymin = 1.05, ymax = 1.95)

# Bottom row, centred
p <- p + annotation_custom(ggplotGrob(arch_schema_plot("Int")),
                           xmin = 0.55, xmax = 1.45, ymin = 0.05, ymax = 0.95)

print(p)

# No plot is passed, so ggsave() writes the most recent plot (the one just printed)
ggsave(
  here::here("results", "arch_mac_figures", "accumulators_diagram.png"),
  width = 40,
  height = 60,
  units = "cm",
  dpi = 600
)


gjcooper/gcphd-model_of_dce documentation built on March 25, 2024, 8:57 a.m.