## Utility functions
#' Build an m x m transition probability matrix from posterior samples
#'
#' Averages every column of \code{x$gamma_prob_bar} and arranges the resulting
#' vector row by row into an m x m matrix, with m taken from
#' \code{x$input$m}. All rows of \code{gamma_prob_bar} are averaged; burn-in
#' samples are not removed by this function.
#'
#' @param x an mHMM_cont object
#' @param ... currently unused; accepted so that methods match the generic
#'
#' @return m x m transition probability matrix, where m is the number of
#'   hidden states
#' @export
get_subject_tpm <- function(x, ...) {
  UseMethod("get_subject_tpm", x)
}

#' @export
get_subject_tpm.mHMM_cont <- function(x, ...) {
  # Number of hidden states
  m <- x$input$m
  # Column-wise average of the sampled transition probabilities
  p <- apply(x$gamma_prob_bar, 2, mean)
  # matrix() would recycle or truncate silently on a length mismatch, so
  # fail loudly instead of returning a malformed matrix
  if (length(p) != m * m) {
    stop(
      "Expected ", m * m, " transition probabilities (m x m) but found ",
      length(p), ".",
      call. = FALSE
    )
  }
  # Fill row by row: the first m averages form the first row
  matrix(p, nrow = m, ncol = m, byrow = TRUE)
}
#' Remove burn-in samples from model output
#'
#' Keeps rows \code{burn_in + 1} to \code{J} of every sampled parameter, where
#' \code{burn_in} and \code{J} are read from \code{x$input}. Elements of mode
#' "numeric" are subset directly; elements of mode "list" are subset entry by
#' entry. Character elements and the elements named \code{label_switch},
#' \code{state_orders} and \code{gamma_naccept} are left untouched.
#'
#' The \code{input} element is returned unchanged, so the result does not
#' record that the burn-in samples were already removed.
#'
#' @param x an mHMM_cont object
#' @param ... currently unused; accepted so that methods match the generic
#'
#' @return mHMM_cont object for which the burn-in samples have been removed
#'   for each parameter.
#' @export
burn <- function(x, ...) {
  UseMethod("burn", x)
}

#' @export
burn.mHMM_cont <- function(x, ...) {
  # Sampler settings: number of burn-in iterations and total iterations
  inp <- x$input
  burn_in <- inp$burn_in
  J <- inp$J
  if (!is.numeric(burn_in) || !is.numeric(J) ||
      length(burn_in) != 1L || length(J) != 1L ||
      anyNA(c(burn_in, J)) || burn_in >= J) {
    stop(
      "`burn_in` must be a single number smaller than the number of ",
      "iterations `J`.",
      call. = FALSE
    )
  }
  # Rows that survive the burn-in
  keep <- (burn_in + 1):J
  # Elements that are not subset by iteration
  skip <- c("label_switch", "state_orders", "gamma_naccept")
  # Set the input aside so the loop only visits model output
  x$input <- NULL
  for (idx in seq_along(x)) {
    el <- x[[idx]]
    if (mode(el) == "character" || isTRUE(names(x)[idx] %in% skip)) {
      next
    }
    if (mode(el) == "list") {
      # One matrix of samples per list entry
      for (subj_idx in seq_along(el)) {
        # drop = FALSE keeps one-column matrices as matrices
        el[[subj_idx]] <- el[[subj_idx]][keep, , drop = FALSE]
      }
      x[[idx]] <- el
    } else if (mode(el) == "numeric") {
      x[[idx]] <- el[keep, , drop = FALSE]
    }
  }
  # Restore the input and the class
  x$input <- inp
  class(x) <- "mHMM_cont"
  x
}
#' Retrieve MAP estimates for parameters
#'
#' Removes the burn-in samples with \code{burn()} and then summarises each
#' remaining non-character element column by column. Every summary is a list
#' with the elements \code{mean}, \code{median} and \code{SE}, where
#' \code{SE} is the standard deviation of the retained samples. Elements that
#' are lists of matrices yield one such summary per list entry.
#'
#' The elements \code{input} and \code{label_switch} are dropped from the
#' result. \code{state_orders} is not summarised; its entry in the result is
#' left as \code{NULL}.
#'
#' @param x an mHMM_cont object
#' @param ... currently unused; accepted so that methods match the generic
#'
#' @return List. Maximum a Posteriori (MAP) estimates for each parameter.
#'   The names of the elements are identical to the names of the input
#'   parameters.
#' @export
MAP <- function(x, ...) {
  UseMethod("MAP", x)
}

#' @export
MAP.mHMM_cont <- function(x, ...) {
  # Column-wise summary of one matrix of posterior samples. Building the list
  # in one go gives the same shape for one-column and multi-column input.
  summarise_draws <- function(draws) {
    list(
      "mean" = unname(apply(draws, 2, mean)),
      "median" = unname(apply(draws, 2, median)),
      "SE" = unname(apply(draws, 2, sd))
    )
  }
  # Remove burn-in samples and the sampler input
  post <- burn(x)
  post$input <- NULL
  # Character elements carry no samples
  is_chr <- vapply(post, function(el) mode(el) == "character", logical(1))
  post <- post[!is_chr]
  # One slot per remaining element, named like the element
  map_out <- vector("list", length(post))
  names(map_out) <- names(post)
  for (param_idx in seq_along(post)) {
    if (names(post)[param_idx] %in% c("label_switch", "state_orders")) {
      next
    }
    draws <- post[[param_idx]]
    map_out[[param_idx]] <- if (mode(draws) == "numeric") {
      summarise_draws(draws)
    } else {
      # List of matrices: summarise each entry separately
      lapply(draws, summarise_draws)
    }
  }
  # Remove label switch
  map_out$label_switch <- NULL
  map_out
}
#' Plot histograms of between-subject means
#'
#' Draws one histogram panel per column of the selected parameter, using the
#' samples that remain after \code{burn()}. The mean of each column is marked
#' with a blue dashed line.
#'
#' @param x an mHMM_cont object
#' @param param string. Which of the parameters should be plotted? Must be one
#'   of 'emiss_mu_bar', 'gamma_int_bar', 'emiss_var_bar' or 'emiss_varmu_bar'
#' @param var int. Which emission distribution should be plotted? Must be an
#'   integer equal to or less than the total number of emission distributions.
#'   Index of variables is the same as the order in which they appear in the
#'   input data given to the mHMM. Ignored when \code{param} is
#'   'gamma_int_bar'.
#' @param ground_truth numeric. Vector of ground-truth values; its length
#'   must equal \code{x$input$n_dep}. Element i is drawn as a red dot-dashed
#'   line in the panel of column i.
#' @param ... currently unused; accepted so that methods match the generic
#'
#' @importFrom magrittr '%>%'
#'
#' @return a ggplot object with one histogram panel per column of the selected
#'   parameter
#' @export
plot_posterior <- function(x, ...) {
  UseMethod("plot_posterior", x)
}

#' @rdname plot_posterior
#' @export
plot_posterior.mHMM_cont <- function(x,
                                     param = c(
                                       "emiss_mu_bar",
                                       "gamma_int_bar",
                                       "emiss_var_bar",
                                       "emiss_varmu_bar"
                                     ),
                                     var = 1, ground_truth = NULL, ...) {
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package \"ggplot2\" needed for this function to work. Please install it.",
      call. = FALSE
    )
  }
  if (!requireNamespace("assertthat", quietly = TRUE)) {
    # Name the package that is actually missing
    stop("Package \"assertthat\" needed for this function to work. Please install it.",
      call. = FALSE
    )
  }
  # NOTE(review): the length is checked against n_dep, but ground_truth is
  # indexed by column of the selected parameter below -- confirm that these
  # two counts are meant to coincide.
  if (!is.null(ground_truth)) {
    assertthat::assert_that(length(ground_truth) == x$input$n_dep,
      msg = "Supplied more ground truth means than that there are emission distributions."
    )
  }
  # Match arg
  param <- match.arg(param)
  # Remove burn-in samples, then pick the requested parameter
  betmu <- burn(x)[[param]]
  # Name of the variable (unless gamma)
  if (param != "gamma_int_bar") {
    var_name <- x$input$dep_labels[var]
    # Cannot select a variable using an index that doesn't exist
    assertthat::assert_that(var <= length(betmu), msg = paste0(
      "You selected ", var,
      ". But there are only ", length(betmu), " independent variables."
    ))
    betmu <- as.data.frame(betmu[[var]])
  } else {
    var_name <- "Between-subject TPM (intercepts)"
    betmu <- as.data.frame(betmu)
  }
  # To long format: one data frame per column, stacked afterwards
  cols <- colnames(betmu)
  betmu_long <- lapply(seq_along(cols), function(cn) {
    vals <- betmu[[cn]]
    data.frame(
      "var" = rep(cols[cn], length(vals)),
      "val" = vals,
      "mval" = mean(vals),
      "gtv" = if (is.null(ground_truth)) NA else ground_truth[cn]
    )
  })
  # Bind
  outd <- do.call(rbind.data.frame, betmu_long)
  # Plot
  # NOTE(review): newer ggplot2 releases prefer `linewidth` over `size` for
  # lines; `size` is kept here to stay compatible with older releases.
  outp <- ggplot2::ggplot(outd, ggplot2::aes(x = val)) +
    ggplot2::geom_histogram() +
    ggplot2::theme_bw() +
    ggplot2::facet_wrap(". ~ var", ncol = 1) +
    ggplot2::geom_vline(
      data = outd, ggplot2::aes(xintercept = mval),
      linetype = "dashed", color = "#2b8cbe",
      size = 1.1
    ) +
    # The second positional argument of ggtitle() is the subtitle, so the two
    # parts are pasted into a single title
    ggplot2::ggtitle(paste0("Between-subject means for ", var_name))
  # If ground truth supplied
  if (!is.null(ground_truth)) {
    outp +
      ggplot2::geom_vline(
        data = outd, ggplot2::aes(xintercept = gtv),
        linetype = "dotdash", color = "#e34a33",
        size = 1.1
      )
  } else {
    outp
  }
}
#' Make a trace plot of posterior samples
#'
#' Draws the samples that remain after \code{burn()} against their position
#' in the retained chain, with one panel per column of the selected parameter.
#' The iteration axis starts at 1 for the first retained sample.
#'
#' @param x an mHMM_cont object
#' @param param string. Which of the parameters should be plotted? Must be one
#'   of 'emiss_mu_bar', 'gamma_int_bar', 'emiss_var_bar' or 'emiss_varmu_bar'
#' @param var int. Which emission distribution should be plotted? Must be an
#'   integer equal to or less than the total number of emission distributions.
#'   Index of variables is the same as the order in which they appear in the
#'   input data given to the mHMM. Ignored when \code{param} is
#'   'gamma_int_bar'.
#' @param ... currently unused; accepted so that methods match the generic
#'
#' @importFrom magrittr '%>%'
#'
#' @return a ggplot object with one line panel per column of the selected
#'   parameter
#'
#' @export
trace_plot <- function(x, ...) {
  UseMethod("trace_plot", x)
}

#' @rdname trace_plot
#' @export
trace_plot.mHMM_cont <- function(x,
                                 param = c(
                                   "emiss_mu_bar",
                                   "gamma_int_bar",
                                   "emiss_var_bar",
                                   "emiss_varmu_bar"
                                 ),
                                 var = 1, ...) {
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package \"ggplot2\" needed for this function to work. Please install it.",
      call. = FALSE
    )
  }
  if (!requireNamespace("assertthat", quietly = TRUE)) {
    stop("Package \"assertthat\" needed for this function to work. Please install it.",
      call. = FALSE
    )
  }
  # Match arg
  param <- match.arg(param)
  # Remove burn-in samples, then pick the requested parameter
  betmu <- burn(x)[[param]]
  # Name of the variable (unless gamma)
  if (param != "gamma_int_bar") {
    var_name <- x$input$dep_labels[var]
    # Cannot select a variable using an index that doesn't exist
    assertthat::assert_that(var <= length(betmu), msg = paste0(
      "You selected ", var,
      ". But there are only ", length(betmu), " independent variables."
    ))
    betmu <- as.data.frame(betmu[[var]])
  } else {
    var_name <- "Between-subject TPM (intercepts)"
    betmu <- as.data.frame(betmu)
  }
  # To long format: one data frame per column, stacked afterwards
  cols <- colnames(betmu)
  n <- nrow(betmu)
  betmu_long <- lapply(seq_along(cols), function(cn) {
    data.frame(
      # seq_len() yields an empty index when no samples are left
      "iteration" = seq_len(n),
      "var" = rep(cols[cn], n),
      "val" = betmu[[cn]]
    )
  })
  # Bind
  outd <- do.call(rbind.data.frame, betmu_long)
  # Plot
  ggplot2::ggplot(outd, ggplot2::aes(x = iteration, y = val)) +
    ggplot2::geom_line() +
    ggplot2::facet_wrap(". ~ var") +
    ggplot2::ggtitle(var_name)
}
#' Compute upper and lower values of the 95\% or 99\% credible interval
#'
#' Takes the lower and upper quantiles of every column of \code{x}: 0.025 and
#' 0.975 for the 95\% interval, 0.005 and 0.995 for the 99\% interval.
#'
#' @param x posterior values for a parameter. A matrix with one column per
#'   parameter, or a plain vector, which is treated as a single column.
#' @param type either one of '0.95' for 95\% CI or '0.99' for 99\% CI
#'
#' @return numeric matrix without dimnames, with two rows and one column per
#'   column of \code{x}. Row 1 is the lower bound of the interval, row 2 is
#'   the upper bound.
#'
#' @export
credible_interval <- function(x, type = c("0.95", "0.99")) {
  type <- match.arg(type)
  # Tail probabilities that bound the requested interval
  probs <- switch(type,
    "0.95" = c(0.025, 0.975),
    "0.99" = c(0.005, 0.995)
  )
  # apply() needs dimensions; a bare vector is one parameter's samples
  if (is.null(dim(x))) {
    x <- matrix(x, ncol = 1)
  }
  unname(apply(x, 2, function(y) quantile(y, probs)))
}
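# NOTE: the two lines below are web-page boilerplate picked up with this
# file, not R code; they are kept as comments so that the file parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.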