R/timeseries_histogram_plotting.R

Defines functions gam0_hist_viterbi norm_hist timeseries_plot

Documented in gam0_hist_viterbi norm_hist timeseries_plot

#' Plot timeseries of observations
#'
#' Builds one plot per subject and variable showing the observations against
#' time, joined by a grey line, with each point coloured by its state.
#'
#' @param x The data to be fit with an HMM in the form of a 3D array. The
#'   first index (row) corresponds to time, the second (column) to the
#'   variable number, and the third (matrix number) to the subject number.
#' @param states A matrix with the columns containing the sequence of states
#'   that generated the data in `x` for the given subject.
#' @param num_subjects The number of subjects/trials that generated the data.
#' @param num_variables The number of variables in the data.
#' @param variable_names A vector containing the names of the variables in the
#'   data `x`. Used as the y-axis labels.
#' @param subject_names A vector containing the names of the subjects generating
#'   the data `x`. Used as the plot titles.
#' @param xaxis A list containing a list for each subject containing vectors for
#'   each variable of the desired minimum and maximum x-axis value. Subjects or
#'   variables without an entry keep the default (data-driven) limits.
#' @param yaxis A list containing a list for each subject containing vectors for
#'   each variable of the desired minimum and maximum y-axis value. Subjects or
#'   variables without an entry keep the default (data-driven) limits.
#'
#' @return A list of ggplot objects, one per subject and variable, ordered by
#'   subject and then by variable within each subject.
#' @export
#' @import RColorBrewer
#' @importFrom ggplot2 ggplot aes ggtitle theme labs geom_point geom_line
timeseries_plot <- function(x, states, num_subjects, num_variables,
                            variable_names = c("Var 1", "Var 2", "Var 3"),
                            subject_names = c("Subject 1", "Subject 2",
                                              "Subject 3", "Subject 4"),
                            xaxis = list(list(c(1, 900))), yaxis = NULL) {
  # Fetch the limits for subject i / variable j. Returns NULL (meaning "no
  # explicit limit" to coord_cartesian) when the nested list has no entry
  # there, instead of failing with "subscript out of bounds".
  axis_limits <- function(limits, i, j) {
    if (length(limits) < i || length(limits[[i]]) < j) {
      return(NULL)
    }
    limits[[i]][[j]]
  }

  n_time <- nrow(states)
  plots  <- vector("list", num_subjects * num_variables)
  for (i in seq_len(num_subjects)) {
    data <- data.frame(State = as.factor(states[, i]))
    data$Time <- seq_len(n_time)
    for (j in seq_len(num_variables)) {
      data$Observation <- x[, j, i]
      p <- ggplot2::ggplot(data, ggplot2::aes(x = Time, y = Observation)) +
        ggplot2::theme_light() +
        ggplot2::scale_color_brewer(palette = "Set1") +
        ggplot2::ggtitle(subject_names[i]) +
        ggplot2::theme(plot.title = ggplot2::element_text(hjust = 0.5)) +
        ggplot2::labs(x = "Time (s)", y = variable_names[j]) +
        ggplot2::geom_point(ggplot2::aes(color = State)) +
        ggplot2::geom_line(colour = "grey", alpha = 0.8, lwd = 0.4) +
        ggplot2::coord_cartesian(xlim = axis_limits(xaxis, i, j),
                                 ylim = axis_limits(yaxis, i, j))
      # Subject-major ordering, matching the order plots were appended before.
      plots[[(i - 1L) * num_variables + j]] <- p
    }
  }
  plots
  #egg::ggarrange(plotlist = plots, common.legend = TRUE, legend = "bottom")
}


#' Plot histograms
#'
#' This function plots the histograms for each subject and variable with the
#' true state dependent distributions overlaid.
#'
#' Each state's normal density is scaled by the number of observations in that
#' state times the bin width so it is comparable to the histogram counts. The
#' sum of the scaled densities is drawn in black.
#'
#' @param sample The sample generated by `norm_generate_sample()`. The
#'   elements `state` (indexed `[time, subject]`) and `observ` (indexed
#'   `[time, variable, subject]`) are used.
#' @param num_states The number of states in the desired HMM.
#' @param num_variables The number of variables in the data.
#' @param num_subjects The number of subjects/trials that generated the data.
#' @param hmm A list of parameters that specify the normal HMM, including
#'   `num_states`, `num_variables`, `num_subjects`, `mu`, `sigma`, `gamma`,
#'   `delta`. Only `mu` and `sigma` are read here, each as a list with one
#'   matrix per variable indexed `[subject, state]`.
#' @param width The width of the histogram bins.
#' @param x_step A value indicating the step length for the range of
#'   observation values.
#'
#' @return A list of ggplot histograms with the distributions overlaid, one
#'   per subject and variable, ordered by subject and then by variable.
#' @export
#' @importFrom ggplot2 ggplot aes ggtitle theme_bw labs geom_line
#' @importFrom stats dnorm
norm_hist <- function(sample, num_states, num_variables, num_subjects,
                      hmm, width = 1, x_step = 0.2) {
  states <- sample$state
  x      <- sample$observ
  # Generate labels on demand so more than four subjects/variables still get
  # a title/axis label instead of NA.
  var_labels <- paste("Variable", seq_len(num_variables))
  sub_labels <- paste("Subject", seq_len(num_subjects))
  plots <- vector("list", num_subjects * num_variables)
  for (i in seq_len(num_subjects)) {
    subvar_data <- data.frame(State = as.factor(states[, i]))
    for (j in seq_len(num_variables)) {
      subvar_data$Observation <- x[, j, i]
      h <- ggplot2::ggplot() +
        ggplot2::geom_histogram(data = subvar_data,
                                mapping = ggplot2::aes(x = Observation),
                                binwidth = width,
                                colour = "grey",
                                fill = "white") +
        ggplot2::scale_color_brewer(palette = "Set1") +
        ggplot2::theme_bw() +
        ggplot2::ggtitle(sub_labels[i]) +
        ggplot2::theme(panel.grid.major = ggplot2::element_blank(),
                       panel.grid.minor = ggplot2::element_blank(),
                       plot.title = ggplot2::element_text(hjust = 0.5)) +
        ggplot2::labs(x = var_labels[j], y = "")

      # Evaluation grid spanning the observed range; NAs are ignored so a
      # missing observation does not abort the plot.
      obs_range <- range(subvar_data$Observation, na.rm = TRUE)
      xfit      <- seq(obs_range[1], obs_range[2], by = x_step)
      marginal  <- numeric(length(xfit))
      for (k in seq_len(num_states)) {
        yfit <- stats::dnorm(xfit, hmm$mu[[j]][i, k], hmm$sigma[[j]][i, k])
        # Scale the density to histogram counts for state k.
        yfit <- yfit * sum(subvar_data$State == k) * width
        df   <- data.frame(xfit = xfit, yfit = yfit,
                           col = as.factor(rep(k, length(xfit))))
        h    <- h + ggplot2::geom_line(data = df,
                                       ggplot2::aes(xfit, yfit, colour = col),
                                       lwd = 0.7)
        marginal <- marginal + yfit
      }
      h  <- h + ggplot2::labs(color = "State")
      df <- data.frame(xfit = xfit, yfit = marginal)
      h  <- h + ggplot2::geom_line(data = df, ggplot2::aes(xfit, yfit),
                                   col = "black", lwd = 0.7)
      plots[[(i - 1L) * num_variables + j]] <- h
    }
  }
  plots
  #ggarrange(plotlist = plots, common.legend = TRUE, legend = "bottom")
}



#' Plot histograms
#'
#' Plots a histogram of the observations for each subject and variable and
#' overlays gamma state dependent densities evaluated with the parameters in
#' `hmm`. Each density is scaled by the number of time points decoded into
#' that state times the bin width. State 1 additionally places the weight
#' `zweight` on the first grid point and scales its gamma density by
#' `1 - zweight` (presumably a zero-inflated gamma; confirm against the model
#' fitting code). The sum of the scaled curves is drawn in black.
#'
#' @param x The data as a 3D array indexed `[time, variable, subject]`.
#' @param viterbi A matrix whose column `i` holds the decoded state sequence
#'   for subject `i`.
#' @param num_states The number of states in the HMM.
#' @param num_subjects The number of subjects/trials that generated the data.
#' @param num_variables The number of variables in the data.
#' @param hmm A list containing `zweight` (a list with one vector per
#'   variable, indexed by subject), and `alpha` and `theta` (lists with one
#'   matrix per variable, indexed `[subject, state]`) giving the gamma shape
#'   and scale.
#' @param state_dep_dist_pooled If `TRUE`, row 1 of `alpha` and `theta` is
#'   used for every subject. `zweight` is still indexed by subject.
#' @param width The width of the histogram bins.
#' @param n Currently unused; kept for backward compatibility.
#' @param level Currently unused; kept for backward compatibility.
#' @param x_step A value indicating the step length for the range of
#'   observation values.
#' @param x_limits A length-2 numeric vector giving the x-axis limits of each
#'   histogram. Observations outside the limits are dropped by ggplot2.
#'
#' @return A list with elements `plots` (a list of ggplot histograms, one per
#'   subject and variable, ordered by subject and then by variable), and
#'   `xfit` and `marginal` (the evaluation grid and summed curve of the last
#'   subject/variable plotted, or `NULL` if nothing was plotted).
#' @export
#' @importFrom ggplot2 ggplot aes ggtitle theme_bw labs geom_line
#' @importFrom stats dnorm
gam0_hist_viterbi <- function(x, viterbi, num_states, num_subjects,
                              num_variables, hmm, state_dep_dist_pooled = FALSE,
                              width = 1, n = 100, level = 0.975, x_step = 0.2,
                              x_limits = c(0.0000001, 4)) {
  # Generate labels on demand so more than four subjects/variables still get
  # a title/axis label instead of NA.
  var_labels <- paste("Variable", seq_len(num_variables))
  sub_labels <- paste("Subject", seq_len(num_subjects))
  plots <- vector("list", num_subjects * num_variables)
  # Defined up front so the return value is well-formed when nothing is
  # plotted (zero subjects or zero variables).
  xfit     <- NULL
  marginal <- NULL
  for (i in seq_len(num_subjects)) {
    subvar_data <- data.frame(State = as.factor(viterbi[, i]))
    # Row of alpha/theta to read: the subject's own, or row 1 when pooled.
    s_ind <- if (state_dep_dist_pooled) 1 else i
    for (j in seq_len(num_variables)) {
      subvar_data$Observation <- x[, j, i]
      h <- ggplot2::ggplot() +
        ggplot2::geom_histogram(data = subvar_data,
                                mapping = ggplot2::aes(x = Observation),
                                binwidth = width,
                                colour = "cornsilk4",
                                fill = "white") +
        ggplot2::theme_bw() +
        ggplot2::xlim(x_limits[1], x_limits[2]) +
        ggplot2::ggtitle(sub_labels[i]) +
        ggplot2::theme(panel.grid.major = ggplot2::element_blank(),
                       panel.grid.minor = ggplot2::element_blank(),
                       plot.title = ggplot2::element_text(hjust = 0.5)) +
        ggplot2::labs(x = var_labels[j], y = "")

      xfit <- seq(min(subvar_data$Observation, na.rm = TRUE),
                  max(subvar_data$Observation, na.rm = TRUE),
                  by = x_step)

      # State 1: weight zweight on the first grid point, gamma density scaled
      # by (1 - zweight) on the rest. xfit[-1] (rather than 2:length(xfit))
      # stays correct when the grid has a single point.
      # NOTE(review): zweight is indexed by i even when the state dependent
      # distributions are pooled -- confirm this is intended.
      zero_weight <- hmm$zweight[[j]][i]
      yfit <- c(zero_weight,
                stats::dgamma(xfit[-1],
                              shape = hmm$alpha[[j]][s_ind, 1],
                              scale = hmm$theta[[j]][s_ind, 1]) *
                  (1 - zero_weight))
      yfit <- yfit * sum(subvar_data$State == 1) * width
      df   <- data.frame(xfit = xfit, yfit = yfit,
                         col = as.factor(rep(1, length(xfit))))
      h    <- h + ggplot2::geom_line(data = df,
                                     ggplot2::aes(xfit, yfit, colour = col),
                                     lwd = 0.7)
      marginal <- yfit

      # Remaining states are plain gamma densities. seq_len(...)[-1] is empty
      # for a single-state model, where 2:num_states would run k = 2, 1.
      for (k in seq_len(num_states)[-1]) {
        yfit <- stats::dgamma(xfit, shape = hmm$alpha[[j]][s_ind, k],
                              scale = hmm$theta[[j]][s_ind, k])
        yfit <- yfit * sum(subvar_data$State == k) * width
        df   <- data.frame(xfit = xfit, yfit = yfit,
                           col = as.factor(rep(k, length(xfit))))
        h    <- h + ggplot2::geom_line(data = df,
                                       ggplot2::aes(xfit, yfit, colour = col),
                                       lwd = 0.7)
        marginal <- marginal + yfit
      }

      h  <- h + ggplot2::labs(color = "State")
      df <- data.frame(xfit = xfit, yfit = marginal)
      h  <- h + ggplot2::geom_line(data = df, ggplot2::aes(xfit, yfit),
                                   col = "black", lwd = 0.7)
      plots[[(i - 1L) * num_variables + j]] <- h
    }
  }

  list(plots = plots, xfit = xfit, marginal = marginal)
}
simonecollier/lizardHMM documentation built on Dec. 23, 2021, 2:24 a.m.