R/plot.mHMM.R

Defines functions plot.mHMM

Documented in plot.mHMM

#' Plotting the posterior densities for a fitted multilevel HMM
#'
#' \code{plot.mHMM} plots the posterior densities for a fitted multilevel hidden
#' Markov model for the group and subject level parameters simultaneously. The
#' plotted posterior densities are either for the transition probability matrix
#' gamma, or for the emission distribution: the emission probabilities
#' (categorical data), the means and standard deviations (continuous data), or
#' the means (count data).
#'
#' Note that the standard deviation of the (variable and) state specific Normal
#' emission distribution in case of continuous data is fixed over subjects.
#' Hence, for the standard deviation, only the posterior distribution at the
#' group level is plotted.
#'
#' Group level densities are drawn with line width \code{lwd1} and line type
#' \code{lty1}; subject level densities are drawn in the same color with line
#' width \code{lwd2} and line type \code{lty2}.
#'
#' @param x Object of class \code{mHMM}, generated by the function
#'   \code{\link{mHMM}}.
#' @param component String specifying if the displayed posterior densities
#'   should be for the transition probability matrix gamma (\code{component =
#'   "gamma"}), or for the emission distribution (\code{component =
#'   "emiss"}). In case of the latter and the model is based on multiple
#'   dependent variables, the user has to indicate for which dependent variable
#'   the posterior densities have to be plotted, see \code{dep}.
#' @param dep Integer between 1 and the number of dependent variables,
#'   specifying for which dependent variable the posterior densities should be
#'   plotted. Only used when plotting the emission distribution
#'   (\code{component = "emiss"}). Defaults to \code{dep = 1}.
#' @param col Vector of colors for the posterior density lines. If one is
#'   plotting the posterior densities for gamma, or the posterior densities of
#'   the emission means (continuous or count data), the vector has length
#'   \code{m} (i.e., number of hidden states). If one is plotting the posterior
#'   densities for categorical emission probabilities, the vector has length
#'   \code{q_emiss[dep]} (i.e., the number of outcome categories for the
#'   dependent variable \code{dep}). Defaults to \code{grDevices::rainbow}
#'   colors.
#' @param dep_lab Optional string of length 1 giving the label of the plotted
#'   dependent variable, used in the plot titles when plotting the emission
#'   distribution. Automatically obtained from the input object \code{x} when
#'   not specified.
#' @param cat_lab Optional vector of strings when plotting the posterior
#'   densities of categorical emission probabilities, denoting the labels of the
#'   categorical outcome values. Defaults to \code{"Category 1"},
#'   \code{"Category 2"}, etc.
#' @param lwd1 Positive number indicating the line width of the posterior
#'   density at the group level.
#' @param lwd2 Positive number indicating the line width of the posterior
#'   density at the subject level.
#' @param lty1 Positive number indicating the line type of the posterior
#'   density at the group level.
#' @param lty2 Positive number indicating the line type of the posterior
#'   density at the subject level.
#' @param burn_in An integer which specifies the number of initial iterations
#'   to discard when estimating the posterior densities; iterations
#'   \code{burn_in + 1} to \code{J} are used, so \code{burn_in} must be at most
#'   \code{J - 2}. When left unspecified, the burn in period specified at
#'   creating the \code{mHMM} object with the function \code{\link{mHMM}} will
#'   be used.
#' @param legend_cex A numerical value giving the amount by which plotting text
#'   and symbols in the legend should be magnified relative to the default.
#'   Defaults to 0.8 when \code{component = "gamma"} and 0.7 when
#'   \code{component = "emiss"}.
#' @param ... Further graphical arguments passed on to
#'   \code{\link[graphics]{plot.default}} when setting up each panel (see
#'   \code{\link[graphics]{par}}).
#'
#' @return \code{plot.mHMM} is called for its side effect: it draws the
#'   posterior densities and invisibly returns \code{NULL}. Depending on
#'   whether (\code{component = "gamma"}) or (\code{component = "emiss"}), the
#'   plotted posterior densities are either for the transition probability
#'   matrix gamma or for the emission distribution, respectively.
#'
#' @seealso \code{\link{mHMM}} for fitting the multilevel hidden Markov
#'   model, creating the object \code{mHMM}.
#'
#' @examples
#' ###### Example on package example data, see ?nonverbal
#' # First run the function mHMM on example data
#' \donttest{
#' # specifying general model properties:
#' m <- 2
#' n_dep <- 4
#' q_emiss <- c(3, 2, 3, 2)
#'
#' # specifying starting values
#' start_TM <- diag(.8, m)
#' start_TM[lower.tri(start_TM) | upper.tri(start_TM)] <- .2
#' start_EM <- list(matrix(c(0.05, 0.90, 0.05, 0.90, 0.05, 0.05), byrow = TRUE,
#'                         nrow = m, ncol = q_emiss[1]), # vocalizing patient
#'                  matrix(c(0.1, 0.9, 0.1, 0.9), byrow = TRUE, nrow = m,
#'                         ncol = q_emiss[2]), # looking patient
#'                  matrix(c(0.90, 0.05, 0.05, 0.05, 0.90, 0.05), byrow = TRUE,
#'                         nrow = m, ncol = q_emiss[3]), # vocalizing therapist
#'                  matrix(c(0.1, 0.9, 0.1, 0.9), byrow = TRUE, nrow = m,
#'                         ncol = q_emiss[4])) # looking therapist
#'
#' # Run a model without covariate(s):
#' out_2st <- mHMM(s_data = nonverbal, gen = list(m = m, n_dep = n_dep,
#'                 q_emiss = q_emiss), start_val = c(list(start_TM), start_EM),
#'                 mcmc = list(J = 11, burn_in = 5))
#'
#' ## plot the posterior densities for gamma
#' plot(out_2st, component = "gamma")
#' }
#'
#' @export
#'
plot.mHMM <- function(x, component = "gamma", dep = 1, col, dep_lab, cat_lab,
                      lwd1 = 2, lwd2 = 1, lty1 = 1, lty2 = 3,
                      legend_cex, burn_in, ...){
  if (!is.mHMM(x)){
    stop("The input object x should be from the class mHMM, obtained with the function mHMM.")
  }
  # Objects fitted with earlier package versions lack the subject level
  # element "log_likl" and are not supported by this function.
  if (!("log_likl" %in% names(x$PD_subj[[1]]))){
    stop("The input object is created using an earlier version of the mHMMbayes package. Please re-run the function mHMM with the current package version, or post-process the object using the earlier version of the package.")
  }
  # Scalar `&&` checks, so a NULL or vector `component` gives a clear error
  # instead of failing inside `if`.
  if (!(is.character(component) && length(component) == 1 &&
        component %in% c("gamma", "emiss"))){
    stop("The input specified under component should be a string, restricted to either gamma or emiss.")
  }
  object     <- x
  input      <- x$input
  n_subj     <- input$n_subj
  m          <- input$m
  n_dep      <- input$n_dep
  data_distr <- input$data_distr
  J          <- input$J
  if (missing(burn_in)){
    burn_in <- input$burn_in
  }
  if (burn_in >= (J-1)){
    stop(paste("The specified burn in period should be at least 2 points smaller
               compared to the number of iterations J, J =", J))
  }
  if (component == "emiss"){
    dep_ok <- is.numeric(dep) && length(dep) == 1 && !is.na(dep) &&
      dep >= 1 && dep <= n_dep && dep == round(dep)
    if (!dep_ok){
      stop(paste("The input specified under dep should be a single integer between 1 and the number of dependent variables n_dep, n_dep =", n_dep))
    }
  }
  if (missing(legend_cex)){
    legend_cex <- if (component == "gamma") .8 else .7
  }
  # Discard the first burn_in iterations: densities use draws burn_in + 1 to J,
  # which the check above guarantees are at least 2 draws.
  keep <- (burn_in + 1):J

  old_par <- graphics::par(no.readonly = TRUE)
  on.exit(graphics::par(old_par), add = TRUE)

  # One panel per state; panels are spread over two rows when m > 3.
  set_state_panels <- function(){
    n_row <- if (m > 3) 2 else 1
    graphics::par(mfrow = c(n_row, ceiling(m / n_row)),
                  mar = c(4,2,3,1) + 0.1, mgp = c(2,1,0))
  }
  # Highest point over the kernel density estimates of the columns `cols` of
  # the matrix of draws, used as a common y-axis limit within one panel.
  peak_density <- function(draws, cols){
    max(vapply(cols, function(k) max(stats::density(draws[keep, k])$y),
               numeric(1)))
  }
  # One panel with the posterior densities of the state specific emission means
  # of variable dep: group level from emiss_mu_bar[[dep]], subject level from
  # column (dep - 1) * m + state of PD_subj[[s]][[subj_field]].
  plot_mean_panel <- function(subj_field, state_col, ...){
    mu_bar <- object$emiss_mu_bar[[dep]]
    # x-axis spans the pooled 2.5% - 97.5% posterior quantiles of all states
    x_lim <- range(apply(mu_bar[keep, , drop = FALSE], 2, stats::quantile,
                         probs = c(0.025, 0.975)))
    graphics::plot.default(x = 1, ylim = c(0, peak_density(mu_bar, seq_len(m))),
                           xlim = x_lim, type = "n",
                           main = paste("Posterior density mean \n", dep_lab),
                           yaxt = "n", ylab = "", xlab = "Mean", ...)
    graphics::title(ylab = "Density", line = .5)
    for (i in seq_len(m)){
      # group level posterior density
      graphics::lines(stats::density(mu_bar[keep, i]),
                      type = "l", col = state_col[i], lwd = lwd1, lty = lty1)
      # subject level posterior densities
      for (s in seq_len(n_subj)){
        subj_draws <- object$PD_subj[[s]][[subj_field]]
        graphics::lines(stats::density(subj_draws[keep, (dep - 1) * m + i]),
                        type = "l", col = state_col[i], lwd = lwd2, lty = lty2)
      }
    }
    graphics::legend("topright", col = state_col,
                     legend = paste("State", seq_len(m)),
                     bty = 'n', lty = 1, lwd = 2, cex = legend_cex)
  }

  if (component == "gamma"){
    state_col <- if (missing(col)) grDevices::rainbow(m) else col
    set_state_panels()
    for (i in seq_len(m)){
      # gamma is stored row-wise: columns for transitions out of state i
      cols_i <- m * (i - 1) + seq_len(m)
      graphics::plot.default(x = 1,
                             ylim = c(0, peak_density(object$gamma_prob_bar, cols_i)),
                             xlim = c(0, 1), type = "n", cex = .8,
                             main = paste("From state", i, "to state ..."),
                             yaxt = "n", ylab = "",
                             xlab = "Transition probability", ...)
      graphics::title(ylab = "Density", line = .5)
      for (j in seq_len(m)){
        col_ij <- cols_i[j]
        graphics::lines(stats::density(object$gamma_prob_bar[keep, col_ij]),
                        type = "l", col = state_col[j], lwd = lwd1, lty = lty1)
        for (s in seq_len(n_subj)){
          graphics::lines(stats::density(object$PD_subj[[s]]$trans_prob[keep, col_ij]),
                          type = "l", col = state_col[j], lwd = lwd2, lty = lty2)
        }
      }
      graphics::legend("topright", col = state_col,
                       legend = paste("To state", seq_len(m)),
                       bty = 'n', lty = 1, lwd = 2, cex = legend_cex)
    }
  } else if (component == "emiss"){
    if (missing(dep_lab)){
      dep_lab <- input$dep_labels[dep]
    }
    if (data_distr == "categorical"){
      q_emiss <- input$q_emiss
      n_cat <- q_emiss[dep]
      if (missing(cat_lab)){
        cat_lab <- paste("Category", seq_len(n_cat))
      }
      # In the subject level cat_emiss matrix the m * q_emiss[d] columns of each
      # dependent variable d are placed one after another; skip those of the
      # variables before dep.
      subj_offset <- m * sum(q_emiss[seq_len(dep - 1)])
      cat_col <- if (missing(col)) grDevices::rainbow(n_cat) else col
      set_state_panels()
      for (i in seq_len(m)){
        cols_i <- n_cat * (i - 1) + seq_len(n_cat)
        emiss_bar <- object$emiss_prob_bar[[dep]]
        # set plotting area, y-axis scaled to the highest group level density
        graphics::plot.default(x = 1, ylim = c(0, peak_density(emiss_bar, cols_i)),
                               xlim = c(0, 1), type = "n",
                               main = paste(dep_lab, ", state", i),
                               yaxt = "n", ylab = "",
                               xlab = "Conditional probability", ...)
        graphics::title(ylab = "Density", line = .5)
        for (q in seq_len(n_cat)){
          # group level posterior density
          graphics::lines(stats::density(emiss_bar[keep, cols_i[q]]),
                          type = "l", col = cat_col[q], lwd = lwd1, lty = lty1)
          # subject level posterior densities
          for (s in seq_len(n_subj)){
            graphics::lines(stats::density(object$PD_subj[[s]]$cat_emiss[keep, subj_offset + cols_i[q]]),
                            type = "l", col = cat_col[q], lwd = lwd2, lty = lty2)
          }
        }
        graphics::legend("topright", col = cat_col, legend = cat_lab,
                         bty = 'n', lty = 1, lwd = 2, cex = legend_cex)
      }
    } else if (data_distr == "continuous"){
      graphics::par(mfrow = c(1,2), mar = c(4,2,3,1) + 0.1, mgp = c(2,1,0))
      state_col <- if (missing(col)) grDevices::rainbow(m) else col
      # Left panel: posterior densities of the means
      plot_mean_panel("cont_emiss", state_col, ...)
      # Right panel: posterior densities of the SDs. The SD is fixed over
      # subjects, so only the group level density is drawn.
      sd_bar <- object$emiss_sd_bar[[dep]]
      x_lim <- range(apply(sd_bar[keep, , drop = FALSE], 2, stats::quantile,
                           probs = c(0.025, 0.975)))
      graphics::plot.default(x = 1, ylim = c(0, peak_density(sd_bar, seq_len(m))),
                             xlim = x_lim, type = "n",
                             main = paste("Posterior density SD \n", dep_lab),
                             yaxt = "n", ylab = "", xlab = "Standard deviation", ...)
      graphics::title(ylab = "Density", line = .5)
      for (i in seq_len(m)){
        graphics::lines(stats::density(sd_bar[keep, i]),
                        type = "l", col = state_col[i], lwd = lwd1, lty = lty1)
      }
      graphics::legend("topright", col = state_col,
                       legend = paste("State", seq_len(m)),
                       bty = 'n', lty = 1, lwd = 2, cex = legend_cex)
    } else if (data_distr == "count"){
      graphics::par(mfrow = c(1,1), mar = c(4,2,3,1) + 0.1, mgp = c(2,1,0))
      state_col <- if (missing(col)) grDevices::rainbow(m) else col
      # Count data: only the posterior densities of the means are plotted
      plot_mean_panel("count_emiss", state_col, ...)
    }
  }
  invisible(NULL)
}

Try the mHMMbayes package in your browser

Any scripts or data that you put into this service are public.

mHMMbayes documentation built on May 29, 2024, 6:41 a.m.