#' Plot an individual's mean posterior predictions compared to their raw affect
#' ratings
#'
#' \code{plot_affect} is capable of plotting either grouped or
#' individual-level posterior predictions (vs. raw observations) for a defined
#' list of posterior predictions and/or grouping, or the distributions of the
#' individual-level affect model weights.
#'
#' @param data Either a list of outputs from [get_affect_ppc] (one per
#' adjective, in the same order as \code{adj_order}), or parameters from
#' [make_par_df].
#' @param plt_type Possible types are "grouped" or "individual" (for
#' [get_affect_ppc] outputs) or "weights" (for [make_par_df] output).
#' @param adj_order Same as [fit_learning_model()].
#' @param nouns Formatted noun versions of the adjectives, in order.
#' @param id_no If \code{plt_type == "individual"}, a participant number to
#' plot (recycled across adjectives, so one number per adjective may also be
#' given). If left as \code{NULL}, defaults to an individual whose
#' pseudo-\eqn{R^2} matches the median (to 2 d.p.) for each adjective, picked
#' at random if several individuals match.
#' @param r2_coords If \code{plt_type == "individual"}, coordinates to print the
#' pseudo-\eqn{R^2} value.
#' @param cred Same as [plot_glm]: a length-2 vector of credible interval
#' widths. Ignored unless \code{plt_type == "weights"}.
#' @param legend_pos,pal,font,font_size Same as [plot_import]. If \code{pal} is
#' \code{NULL}, a default palette is used (3 colours for "grouped" and
#' "weights"; 6 colours, i.e. a predicted/real pair per adjective, for
#' "individual").
#'
#' @return A single or list of \code{ggplot} object(s) depending on type.
#'
#' @importFrom stats quantile
#' @importFrom scales pseudo_log_trans
#'
#' @examples \dontrun{
#' fit_affect <- fit_learning_model(
#'   example_data$nd,
#'   model = "2a",
#'   affect = TRUE,
#'   exp_part = "training",
#'   algorithm = "fullrank"
#' )
#'
#' fit_dfs <- list()
#' for (adj in c("happy", "confident", "engaged")) {
#'   fit_dfs[[adj]] <- get_affect_ppc(
#'     fit_affect$draws, fit_affect$raw_df, adj = adj
#'   )
#' }
#'
#' # Grouped plot
#' plot_affect(fit_dfs, plt_type = "grouped")
#'
#' # Individual-level median posterior predictions
#' plot_affect(fit_dfs, plt_type = "individual", r2_coords = c(0.8, 0.97))
#'
#' # Weight plot
#' pars <- make_par_df(fit_affect$raw_df, fit_affect$summary)
#' plot_affect(pars, plt_type = "weights")
#' }
#'
#' @export
plot_affect <- function(data,
                        plt_type = c("individual", "grouped", "weights"),
                        adj_order = c("happy", "confident", "engaged"),
                        nouns = c("Happiness", "Confidence", "Engagement"),
                        id_no = NULL,
                        r2_coords = c(0.9, 0.8),
                        cred = c(0.95, 0.99),
                        legend_pos = "right",
                        pal = NULL,
                        font = "",
                        font_size = 11) {
  plt_type <- match.arg(plt_type)
  # appease R CMD check: data columns referenced via non-standard evaluation
  type <- trial_no_q <- value <- mean_val <- se_val <- se_pred <- parameter <-
    posterior_mean <- NULL

  if (plt_type == "weights") {
    if (length(cred) != 2) {
      stop(
        "`cred` must be a vector of two credible interval widths.",
        call. = FALSE
      )
    }
    # scale_*_manual() aborts on `values = NULL`, so fall back to a default
    if (is.null(pal)) pal <- c("#ffc9b5", "#95a7ce", "#987284")
    cred <- sort(cred)
    # lower-tail probabilities of the wider (whiskers) and narrower (box)
    # credible intervals
    cred_l1 <- (1 - cred[2]) / 2
    cred_l2 <- (1 - cred[1]) / 2
    data <- data |> dplyr::filter(!is.na(adj))
    p <- unique(data$parameter)
    labs <- c(
      expression(w[0]), expression(w[1]), expression(w[1]^b),
      expression(w[1]^o), expression(w[2]), expression(w[3]), "\u03B3"
    )
    # pick the axis labels matching the weights present in this model
    has_w1_o <- any(grepl("w1_o", p))
    has_w1_b <- any(grepl("w1_b", p))
    if (has_w1_o && !has_w1_b) {
      labs <- labs[c(1, 2, 5, 6, 7)]
    } else if (!has_w1_o) {
      labs <- labs[c(1, 5, 6, 7)]
    } else {
      labs <- labs[c(1, 3, 4, 5, 6, 7)] # for "full" model
    }
    weight_plot <- data |>
      dplyr::mutate(
        # move gamma to end of plot
        parameter = ifelse(parameter == "gamma", "zgamma", parameter),
        adj = paste0(toupper(substr(adj, 1, 1)), substr(adj, 2, nchar(adj)))
      ) |>
      ggplot2::ggplot(
        ggplot2::aes(
          x = parameter, y = posterior_mean, color = factor(adj),
          fill = factor(adj)
        )
      ) +
      geom_flat_violin(
        position = ggplot2::position_nudge(x = .125, y = 0), adjust = 2,
        trim = FALSE, alpha = 0.5
      ) +
      ggplot2::geom_point(
        # offset the raw points to the left of each parameter's violin
        ggplot2::aes(x = as.numeric(as.factor(parameter)) - 0.225),
        position = ggplot2::position_jitter(width = .1, height = 0),
        size = .25,
        alpha = 0.15
      ) +
      ggplot2::stat_summary(
        geom = "boxplot",
        # box = narrower interval, whiskers = wider interval, line = median
        fun.data = function(x) {
          stats::setNames(
            quantile_hdi(x, c(cred_l1, cred_l2, 0.5, 1 - cred_l2, 1 - cred_l1)),
            c("ymin", "lower", "middle", "upper", "ymax")
          )
        },
        position = ggplot2::position_dodge2(), alpha = 0.6, width = 0.2
      ) +
      # constant reference lines: drawn once rather than once per data row
      ggplot2::geom_hline(
        yintercept = c(0, 1), linetype = "dashed", colour = "slategrey"
      ) +
      ggplot2::scale_color_manual(name = NULL, values = pal) +
      ggplot2::scale_fill_manual(name = NULL, values = pal) +
      ggplot2::scale_x_discrete(name = "Parameter", labels = labs) +
      ggplot2::scale_y_continuous(
        name = "Posterior mean",
        trans = scales::pseudo_log_trans(sigma = 0.1),
        breaks = c(-5, -1, 0, 1, 5)
      ) +
      cowplot::theme_half_open(font_family = font, font_size = font_size) +
      ggplot2::theme(legend.position = legend_pos)
    return(weight_plot)
  } else if (plt_type == "grouped") {
    if (is.null(pal)) pal <- c("#ffc9b5", "#95a7ce", "#987284")
    # stack every participant's PPCs per adjective, tagged with subject + adj
    ppc_list <- lapply(
      seq_along(adj_order),
      function(f) {
        dplyr::mutate(
          data.table::rbindlist(data[[f]]$indiv_ppcs, idcol = "subjID"),
          adj = adj_order[f]
        )
      }
    )
    grouped_plot <- data.table::rbindlist(ppc_list) |>
      dplyr::group_by(adj, type, trial_no_q) |>
      dplyr::summarise(
        mean_val = mean(value), se_val = std(value), .groups = "drop"
      ) |>
      ggplot2::ggplot(
        ggplot2::aes(
          x = trial_no_q, y = mean_val, color = adj, fill = adj,
          linetype = factor(type, levels = c("raw", "pred"))
        )
      ) +
      ggplot2::scale_x_continuous(
        limits = c(0, 120), breaks = seq(0, 120, 20)
      ) +
      ggplot2::geom_line(size = 1.1, alpha = 0.5) +
      ggplot2::geom_ribbon(
        ggplot2::aes(ymin = mean_val - se_val, ymax = mean_val + se_val),
        alpha = 0.3, colour = NA
      ) +
      ggplot2::scale_color_manual(name = NULL, values = pal) +
      ggplot2::scale_fill_manual(name = NULL, values = pal) +
      ggplot2::xlab("Rating number") +
      ggplot2::ylab("Mean (\u00B1 SE) affect rating") +
      ggplot2::guides(linetype = "none", color = "none", fill = "none") +
      cowplot::theme_half_open(font_size = font_size, font_family = font) +
      ggplot2::theme(legend.position = legend_pos)
    return(grouped_plot)
  } else if (plt_type == "individual") {
    if (is.null(pal)) {
      pal <- c("#ffc9b5", "#648767", "#b1ddf1", "#95a7ce", "#987284", "#3d5a80")
    }
    # kind = "num": the id_no of a participant whose pseudo-R2 rounds (2 d.p.)
    #   to the median pseudo-R2; random pick when several match.
    # kind = "id": the subjID belonging to participant number `id`.
    median_id <- function(df, kind, id = id_no) {
      pseudo_R2 <- NULL # appease R CMD check
      if (kind == "num") {
        med_r2 <- round(
          stats::quantile(df$pseudo_R2, 0.5, na.rm = TRUE, type = 3), 2
        )
        med <- subset(df, round(pseudo_R2, 2) == med_r2)$id_no
        # takes someone with approx. median pseudo-R2, so can show diff. ppts
        if (length(med) > 1) sample(med, 1) else med
      } else if (kind == "id" && !is.null(id)) {
        subset(df, id_no == id)$subjID
      }
    }
    # user-supplied ids are recycled across adjectives
    id_vec <- if (is.null(id_no)) NULL else rep_len(id_no, length(adj_order))
    indiv_ppc_plots <- list()
    for (a in seq_along(adj_order)) {
      adj <- adj_order[a]
      fit_df <- data[[a]]$fit_df
      id_a <- if (is.null(id_vec)) median_id(fit_df, "num") else id_vec[a]
      subj <- median_id(fit_df, "id", id_a)
      if (length(subj) == 0) {
        stop(
          "Could not find participant '", toString(id_a), "' for '", adj, "'.",
          call. = FALSE
        )
      }
      r2 <- subset(fit_df, id_no == id_a)$pseudo_R2
      # one predicted/real colour pair per adjective
      type_cols <- c(pal[a * 2 - 1], pal[a * 2])
      indiv_ppc_plots[[adj]] <- data[[a]]$indiv_ppcs[[subj]] |>
        ggplot2::ggplot(
          ggplot2::aes(x = trial_no_q, y = value, color = type, fill = type)
        ) +
        ggplot2::geom_line() +
        ggplot2::geom_ribbon(
          ggplot2::aes(ymin = value - se_pred, ymax = value + se_pred),
          alpha = 0.5
        ) +
        ggplot2::scale_color_manual(
          name = "Data type", labels = c("Predicted", "Real"),
          values = type_cols
        ) +
        ggplot2::scale_fill_manual(
          name = "Data type", labels = c("Predicted", "Real"),
          values = type_cols
        ) +
        ggplot2::xlab("Rating number") +
        ggplot2::ylab(paste0(nouns[a], " rating /100")) +
        cowplot::theme_half_open(font_size = font_size, font_family = font) +
        ggplot2::annotation_custom(
          grid::textGrob(
            bquote("Pseudo-" ~ R^2 ~ "=" ~ .(round(r2, 2))),
            gp = grid::gpar(fontsize = font_size + 2, col = "steelblue4"),
            x = r2_coords[1], y = r2_coords[2]
          )
        ) +
        ggplot2::theme(legend.position = legend_pos)
    }
    return(indiv_ppc_plots)
  }
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.