Nothing
#' Function for plotting iterations of posterior distributions
#'
#' @param fits A list of brmsfit objects following the same data over time.
#' Currently checkpointing is not supported.
#' @param form A formula describing the growth model similar to \code{\link{growthSS}}
#' and \code{\link{brmPlot}} such as: outcome ~ predictor |individual/group
#' @param df data used to fit models (this is used to plot each subject's trend line).
#' @param priors a named list of samples from the prior distributions for each parameter in
#' \code{params}. This is only used if sample_prior=FALSE in the brmsfit object.
#' If left NULL then no prior is included.
#' @param params a vector of parameters to include distribution plots of.
#' Defaults to NULL which will use all parameters from the top level model.
#' @param maxTime Optional parameter to designate a max time not observed in the models so far
#' @param patch Logical, should a patchwork plot be returned or should lists of ggplots be returned?
#' @keywords Bayesian brms
#' @import ggplot2
#' @import patchwork
#' @importFrom methods is
#' @importFrom stats setNames
#' @import viridis
#' @return A ggplot or a list of ggplots (depending on patch).
#' @export
#' @examples
#' \donttest{
#' f <- "https://raw.githubusercontent.com/joshqsumner/pcvrTestData/main/brmsFits.rdata"
#' tryCatch(
#' {
#' print(load(url(f)))
#' library(brms)
#' library(ggplot2)
#' library(patchwork)
#' fits <- list(fit_3, fit_15)
#' form <- y~time | id / group
#' priors <- list(
#' "phi1" = rlnorm(2000, log(130), 0.25),
#' "phi2" = rlnorm(2000, log(12), 0.25),
#' "phi3" = rlnorm(2000, log(3), 0.25)
#' )
#' params <- c("A", "B", "C")
#' d <- simdf
#' maxTime <- NULL
#' patch <- TRUE
#' from3to25 <- list(
#' fit_3, fit_5, fit_7, fit_9, fit_11,
#' fit_13, fit_15, fit_17, fit_19, fit_21, fit_23, fit_25
#' )
#' distributionPlot(
#' fits = from3to25, form = y ~ time | id / group,
#' params = params, d = d, priors = priors, patch = FALSE
#' )
#' distributionPlot(
#' fits = from3to25, form = y ~ time | id / group,
#' params = params, d = d, patch = FALSE
#' )
#' },
#' error = function(e) {
#' message(e)
#' }
#' )
#' }
#' ## End(Not run)
distributionPlot <- function(fits, form, df, priors = NULL,
params = NULL, maxTime = NULL, patch = TRUE) {
#* ***** `Reused helper variables`
parsed_form <- .parsePcvrForm(form, df)
y <- parsed_form$y
x <- parsed_form$x
individual <- parsed_form$individual
group <- parsed_form$group
d <- parsed_form$data
fitData <- fits[[length(fits)]]$data
dSplit <- split(d, d[[group]])
startTime <- min(unlist(lapply(fits, function(ft) {
return(min(ft$data[[x]], na.rm = TRUE))
})))
if (is.null(maxTime)) {
endTime <- max(unlist(lapply(fits, function(ft) {
return(max(ft$data[[x]], na.rm = TRUE))
})))
}
byTime <- mean(diff(unlist(lapply(fits, function(ft) {
return(max(ft$data[[x]], na.rm = TRUE))
}))))
timeRange <- seq(startTime, endTime, byTime)
virOptions <- c("C", "G", "B", "D", "A", "H", "E", "F")
palettes <- lapply(
seq_along(unique(fitData[[group]])),
function(i) {
group_pal <- viridis::viridis(length(timeRange),
begin = 0.1,
end = 1, option = virOptions[i], direction = 1
)
return(group_pal)
}
)
names(palettes) <- unique(fitData[[group]])
#* ***** `if params is null then pull them from growth formula`
if (is.null(params)) {
fit <- fits[[1]]
growthForm <- as.character(fit$formula[[1]])[[3]]
test <- gsub(x, "", growthForm) # ;test
test2 <- gsub("exp\\(", "", test) # ; test2
test3 <- gsub("\\(1", "", test2) # ;test3
test4 <- gsub("[/]|[+]|[-]|[)]|[()]", "", test3)
params <- strsplit(test4, "\\s+")[[1]]
test3 <- gsub("[)]|[()]", "", test2)
test3
}
#* ***** `growth trendline plots`
growthTrendPlots <- lapply(seq_along(dSplit), function(i) {
dt <- dSplit[[i]]
p <- ggplot2::ggplot(dt, ggplot2::aes(
x = .data[[x]], y = .data[[y]], color = .data[[x]],
group = .data[[individual]]
)) +
ggplot2::geom_line(show.legend = FALSE) +
viridis::scale_color_viridis(begin = 0.1, end = 1, option = virOptions[i], direction = 1) +
ggplot2::scale_x_continuous(limits = c(startTime, endTime)) +
pcv_theme()
return(p)
})
#* ***** `posterior distribution extraction`
posts <- do.call(rbind, lapply(fits, function(fit) {
time <- max(fit$data[[x]], na.rm = TRUE)
fitDraws <- do.call(cbind, lapply(params, function(par) {
draws <- as.data.frame(fit)[grepl(par, colnames(as.data.frame(fit)))]
if (nrow(brms::prior_draws(fit)) > 1) {
draws <- draws[!grepl("^prior_", colnames(draws))]
}
splits <- strsplit(colnames(draws), split = "")
mx <- max(unlist(lapply(splits, length)))
ind <- which(unlist(lapply(1:mx, function(i) {
l_over_1 <- length(unique(rapply(splits, function(j) {
return(j[i])
}))) != 1
return(l_over_1)
})))
if (length(ind) > 0) {
colnames(draws) <- paste(par, unlist(lapply(colnames(draws), function(c) {
return(substr(c, min(ind), max(c(ind, nchar(c)))))
})), sep = "_")
}
return(draws)
}))
fitDraws$time <- time
return(fitDraws)
}))
#* ***** `prior distribution extraction`
distPlotPriorExtractionRes <- .distPlotPriorExtraction(fits, priors, d, group, params, x)
prior_df <- distPlotPriorExtractionRes[["prior_df"]]
USEPRIOR <- distPlotPriorExtractionRes[["UP"]]
#* ***** `posterior distribution plots`
#* need to assign ordering of factors
#* if USEPRIOR then join data, don't make separate geom
if (USEPRIOR) {
posts <- rbind(prior_df, posts)
}
posts[[x]] <- factor(posts[[x]], levels = sort(as.numeric(unique(posts[[x]]))), ordered = TRUE)
lapply(posts, summary)
xlims <- lapply(params, function(par) {
diff <- as.numeric(as.matrix(posts[, grepl(paste0("^", par, "_"), colnames(posts))]))
rng <- c(min(diff, na.rm = TRUE), max(diff, na.rm = TRUE))
return(rng)
})
names(xlims) <- params
postPlots <- lapply(unique(fitData[[group]]), function(groupVal) {
groupPlots <- lapply(params, function(par) {
p <- ggplot2::ggplot(posts) +
ggplot2::geom_density(ggplot2::aes(
x = .data[[paste(par, groupVal, sep = "_")]],
fill = .data[[x]], color = .data[[x]],
group = .data[[x]]
), alpha = 0.8) +
ggplot2::labs(x = paste(par, group, groupVal)) +
ggplot2::coord_cartesian(xlim = xlims[[par]]) +
pcv_theme() +
ggplot2::theme(
axis.text.x.bottom = ggplot2::element_text(angle = 0),
legend.position = "none", axis.title.y = ggplot2::element_blank()
)
if (USEPRIOR) {
p <- p + ggplot2::scale_fill_manual(values = c("black", palettes[[groupVal]])) +
ggplot2::scale_color_manual(values = c("black", palettes[[groupVal]]))
} else {
p <- p + ggplot2::scale_fill_manual(values = palettes[[groupVal]]) +
ggplot2::scale_color_manual(values = palettes[[groupVal]])
}
return(p)
})
return(groupPlots)
})
if (patch) {
ncol_patch <- 1 + length(params)
nrow_patch <- length(postPlots)
patchPlot <- growthTrendPlots[[1]] + postPlots[[1]]
if (length(unique(d[[group]])) > 1) {
for (i in 2:length(growthTrendPlots)) {
patchPlot <- patchPlot + growthTrendPlots[[i]] + postPlots[[i]]
}
}
out <- patchPlot + patchwork::plot_layout(ncol = ncol_patch, nrow = nrow_patch)
} else {
out <- list(growthTrendPlots, postPlots)
}
return(out)
}
#' Prior extraction in distPlot
#' @keywords internal
#' @noRd
.distPlotPriorExtraction <- function(fits, priors, d, group, params, x) {
if (is.null(priors)) {
return(list("prior_df" = NULL, "UP" = FALSE))
}
if (all(unlist(lapply(fits, function(fit) nrow(brms::prior_draws(fit)) < 1)))) {
# if no models were fit with sample_prior
if (!is.null(priors)) { # if prior is supplied as argument
USEPRIOR <- TRUE
if (!methods::is(priors[[1]], "list")) {
priors <- lapply(seq_along(unique(d[[group]])), function(i) priors)
names(priors) <- unique(d[[group]])
}
prior_df <- do.call(cbind, lapply(names(priors), function(nm) {
nmp <- priors[[nm]]
nm_res <- setNames(data.frame(do.call(cbind, lapply(names(nmp), function(nmpn) {
return(nmp[[nmpn]])
}))), paste0(names(nmp), "_", nm))
return(nm_res)
}))
prior_df[[x]] <- 0
} else {
USEPRIOR <- FALSE
}
} else { #* `need to fit some models with sample_prior and see how this works with them`
prior_df <- brms::prior_draws(fits[[1]])
prior_df <- prior_df[, grepl(paste0("b_", paste0(params, collapse = "|")), colnames(prior_df))]
colnames(prior_df) <- gsub(group, "", colnames(prior_df))
colnames(prior_df) <- gsub("^b_", "", colnames(prior_df))
prior_df[[x]] <- 0
USEPRIOR <- TRUE
}
return(list("prior_df" = prior_df, "UP" = USEPRIOR))
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.