R/tempo_plotting.R

Defines functions tempo_plot_means tempo_plot_covs_pmf tempo_trace tempo_plot_posteriors tempo_plot_covariates tempo_plot_pmf

Documented in tempo_plot_covs_pmf tempo_plot_means tempo_plot_posteriors tempo_trace

#' Plot predicted mean transition times
#'
#' This function plots the posteriors of mean transition times for observations
#' included in the output from \code{tempo_derive()} (specified using the
#' \code{obs_idx} argument to \code{tempo_derive()}). In order to use this
#' function, the \code{quantity} argument needs to include \code{"mean"} at the
#' time that \code{tempo_derive()} was called.
#'
#' Each observation is drawn as one ridgeline density. Ridges are filled
#' according to how far the observation's posterior mean lies from the grand
#' mean across all plotted observations, and the grand mean itself is marked
#' with a dashed vertical line.
#'
#' @param x output from \code{tempo_derive()}. \code{tempo_derive()} needs
#' to have been called with the \code{quantity} argument equal to or containing
#' \code{"mean"}.
#'
#' @param x_lim optional; A vector of length two of lower and upper bounds
#' for the x axis. Defaults to \code{NULL}, in which case the minimum and the
#' 99.5th percentile of the plotted values (rounded to whole numbers) are used
#' as the x axis limits.
#'
#' @param y_axis_order optional; a vector of observation IDs specifying the
#' order in which they should be plotted. Defaults to \code{NULL}, in which case
#' ascending alphanumeric order is used.
#'
#' @return a ggplot2 plot of posteriors for mean transition time predictions.
#' @importFrom tidyr gather separate
#' @importFrom tibble as_tibble rownames_to_column
#' @importFrom ggridges geom_density_ridges
#' @importFrom forcats fct_relevel
#' @importFrom scales rescale
#' @importFrom ggthemes theme_hc
#' @export
#' @name tempo_plot_means
#' @rdname tempo_plot_means
tempo_plot_means <- function(x, x_lim = NULL, y_axis_order = NULL) {
  # Fail with an actionable message instead of an opaque aperm() error.
  if (is.null(x$mean)) {
    stop("`x` has no `mean` element; call tempo_derive() with `quantity` ",
         "equal to or containing \"mean\".",
         call. = FALSE)
  }

  # One row per (observation, iteration), plus each observation's own
  # posterior mean and every draw's deviation from it.
  d_pmf_mean <- as_tibble(aperm(x$mean, perm = c(2, 1)),
                          rownames = "obs_id") %>%
    gather("iter", "value", -"obs_id") %>%
    group_by(.data$obs_id) %>%
    mutate(mean_value = mean(.data$value),
           dif_from_mean = .data$value - .data$mean_value) %>%
    ungroup()

  if (is.null(y_axis_order)) {
    y_axis_order <- sort(unique(d_pmf_mean$obs_id))
  }

  if (is.null(x_lim)) {
    # Upper bound is the 99.5th percentile so a long right tail does not
    # squash the ridges.
    x_lim <- round(quantile(d_pmf_mean$value, c(0, 0.995)))
  }

  grand_mean <- mean(d_pmf_mean$mean_value)
  diffs <- unique(d_pmf_mean$mean_value - grand_mean)

  # Position of "no departure from the grand mean" on the 0-1 fill scale,
  # clamped to guard against floating point overshoot.
  zero_scaled <- rescale(0, from = range(diffs), to = c(0, 1))
  zero_scaled <- min(max(zero_scaled, 0), 1)

  # Colour stops must be increasing. The intermediate colours sit halfway
  # between the zero point and each end of the scale, so the white stop always
  # separates the orange side from the blue side even when the site means are
  # asymmetric around the grand mean. This reduces to 0.25 / 0.75 when the
  # zero point is at 0.5.
  fill_stops <- c(0,
                  zero_scaled / 2,
                  zero_scaled,
                  (1 + zero_scaled) / 2,
                  1)

  ggplot(d_pmf_mean) +
    geom_vline(linetype = "dashed",
               aes(colour = "mean",
                   xintercept = grand_mean),
               show.legend = TRUE) +
    geom_density_ridges(aes(x = .data$value,
                            y = fct_relevel(as.factor(.data$obs_id),
                                            y_axis_order),
                            group = .data$obs_id,
                            fill = .data$mean_value - grand_mean),
                        alpha = 0.8,
                        size = 0.25) +
    coord_cartesian(xlim = x_lim) +
    ylab("Observation") +
    xlab("Time Step") +
    scale_y_discrete(expand = expand_scale(mult = c(0.01, 0.055))) +
    scale_fill_gradientn(name = "Average departure from\ngrand mean",
                         colors = c("#ffaa00",
                                    "#ffefd1",
                                    "#ffffff",
                                    "#488fb1",
                                    "#002733"),
                         values = fill_stops,
                         guide = guide_colorbar(order = 2)) +
    scale_color_manual(name = "",
                       labels = "Grand Mean",
                       values = c(mean = "black")) +
    ggtitle(expression(
      paste("Per-site posterior densities for mean transition times")
    )) +
    theme_hc(base_size = 16,
             base_family = "serif") +
    theme(legend.position = "right", legend.title = element_text(size = 14),
          legend.text.align = 1)
}

#' Plot covariates and PMF
#'
#' Builds a two-panel figure for a single observation: the covariate time
#' series (annotated with the magnitude and sign of the estimated
#' coefficients) on top, and the predicted transition time PMF aligned
#' underneath. A left-aligned title strip is placed above both panels.
#'
#' @param x output from \code{tempo_derive()}. \code{tempo_derive()} needs
#' to have been called with the \code{quantity} argument equal to or containing
#' \code{"pmf"}.
#'
#' @param draws the output from \code{tempo_mcmc()}.
#'
#' @param obs_id The observation id (corresponding to the \code{obs_id} column
#' in the \code{y} argument to \code{tempo_mcmc()}). \code{tempo_derive()}
#' needs to have calculated the pmf for this observation (controllable via the
#' \code{obs_idx} argument to \code{tempo_derive()}).
#'
#' @param confidence_intervals Boolean; Should confidence intervals be shown on
#' the pmf plot? Defaults to \code{TRUE}.
#'
#' @param title character string; optional title to be placed at the top of the
#' plot. Defaults to \code{NULL}, in which case a default title is generated:
#' "PMF and covariates for <obs_id>".
#'
#' @export
#' @importFrom cowplot plot_grid draw_label ggdraw
tempo_plot_covs_pmf <- function(x,
                                draws,
                                obs_id,
                                confidence_intervals = TRUE,
                                title = NULL) {
  # Build both panels first: covariates on top, PMF underneath.
  panels <- list(
    tempo_plot_covariates(draws, obs_id),
    tempo_plot_pmf(x, obs_id, confidence_intervals)
  )

  # Use the generated title unless the caller supplied one.
  if (is.null(title)) {
    title <- sprintf("PMF and covariates for %s", obs_id)
  }

  # Title strip; the left margin lines the label up with the left edge of
  # the first panel.
  title_label <- draw_label(title, fontface = 'bold', x = 0, hjust = 0)
  title_strip <- ggdraw() +
    title_label +
    theme(plot.margin = margin(0, 0, 0, 7))

  stacked_panels <- plot_grid(plotlist = panels, ncol = 1, align = "v")

  # The title gets a thin band relative to the stacked panels.
  plot_grid(title_strip, stacked_panels, ncol = 1, rel_heights = c(0.05, 1))
}

#' Traceplots
#'
#' Generate traceplots of model parameters. Monitored group-level effects
#' (parameters whose names start with \code{"eps"}) are not plotted from this
#' function. Only supports plotting up to 5 chains.
#'
#' @param draws Output from \code{tempo_mcmc()}. If \code{draws} has more than 5
#' chains, only the first 5 will be plotted (with a warning).
#' @return A ggplot2 object with one facet per parameter and one line per
#' chain.
#' @rdname tempo_trace
#' @name tempo_trace
#' @export
tempo_trace <- function(draws) {
  palette_cols <- c("#006281",
                    "#ffb11c",
                    "#a2e36c",
                    "#7dc5c9",
                    "#ff9880")
  # The number of plottable chains is bounded by the number of colours.
  max_chains <- length(palette_cols)

  n_chains <- attr(draws, "n_chains")
  if (n_chains > max_chains) {
    n_chains <- max_chains
    warning(sprintf("draws has more than %d chains, plotting only the first %d.",
                    max_chains, max_chains))
  }

  # Row names of the bound chains look like "<chain>.<iteration>"; split them
  # into separate columns and reshape to one row per chain/iteration/parameter.
  # Columns are selected by string rather than with `.data$`, which is
  # deprecated inside tidyselect arguments.
  draws_long <- do.call(rbind, draws) %>%
    rownames_to_column("rowname") %>%
    as_tibble() %>%
    separate("rowname", sep = "[.]", into = c("chain", "iteration")) %>%
    filter(.data$chain %in% paste0("chain_", seq_len(max_chains))) %>%
    mutate(chain = as.factor(str_replace(.data$chain, "_", " ")),
           iteration = as.numeric(.data$iteration)) %>%
    gather("parameter", "value", -"chain", -"iteration") %>%
    filter(!str_detect(.data$parameter, "^eps")) %>%  # drop group effects
    mutate(chain = fct_relevel(.data$chain,
                               paste("chain", seq_len(n_chains))))

  ggplot(draws_long, aes(x = .data$iteration,
                         y = .data$value,
                         group = .data$parameter)) +
    geom_line(aes(group = .data$chain, color = .data$chain), size = 0.25) +
    facet_wrap(facets = vars(.data$parameter), scales = "free", ncol = 2) +
    theme_hc(base_size = 16,
             base_family = "serif") +
    xlab("Iteration") +
    ylab("Value") +
    scale_color_manual(name = "",
                       values = palette_cols[seq_len(n_chains)]) +
    theme(legend.position = "right", legend.title = element_text(size = 14))
}

#' Plot posterior distributions
#'
#' Generate histograms of the posterior distributions for model parameters,
#' one facet per parameter, with the posterior mean marked by a vertical line.
#' Draws from all chains are pooled. Monitored group-level effects (parameters
#' whose names start with \code{"eps"}) are not plotted from this function.
#' @return A ggplot2 object with histograms of posteriors
#'
#' @param draws output from \code{tempo_mcmc()}.
#' @importFrom dplyr summarize
#' @export
tempo_plot_posteriors <- function(draws) {
  # Pool the chains and reshape to one row per (parameter, draw).
  pooled <- as_tibble(do.call(rbind, draws))
  posterior_long <- gather(pooled, "parameter", "value")
  # Group-level effects are excluded from the plot.
  posterior_long <- filter(posterior_long,
                           !str_detect(.data$parameter, "^eps"))

  # Posterior mean of every remaining parameter, for the marker lines.
  param_means <- summarize(group_by(posterior_long, .data$parameter),
                           mean = mean(.data$value))

  hist_layer <- geom_histogram(
    aes(x = .data$value, y = .data$`..density..`),
    fill = "#488fb1",
    color = "grey20",
    size = 0.25,
    bins = 30
  )
  mean_layer <- geom_vline(
    mapping = aes(xintercept = .data$mean, color = "Mean"),
    data = param_means,
    size = 1
  )

  ggplot(posterior_long, aes(group = .data$parameter)) +
    hist_layer +
    mean_layer +
    facet_wrap(facets = vars(.data$parameter), scales = "free", ncol = 3) +
    theme_hc(base_size = 16,
             base_family = "serif") +
    xlab("Value") +
    ylab("Density") +
    scale_color_manual("",
                       labels = "Mean",
                       breaks = c("Mean"),
                       values = "#ffaa00") +
    theme(legend.position = "right", legend.title = element_text(size = 14))
}


# This function is not exported and thus not documented with roxygen.
# See the docs for tempo_plot_covs_pmf for descriptions of the arguments.
#
# Plots the covariate time series of one observation. Line colour identifies
# the covariate, line width encodes the absolute value of the posterior mean
# of its coefficient, and line type encodes the coefficient's sign (solid =
# positive, twodash = negative). Returns a ggplot object.
#' @import ggplot2
#' @importFrom scales rescale
tempo_plot_covariates <- function(draws, obs_id) {
  # Remove the first slice of the third dimension (presumably the intercept;
  # the first coefficient is dropped to match below). `drop = FALSE` keeps the
  # array three-dimensional when only one slice, one observation or one time
  # step remains; otherwise melt() returns fewer than four columns and the
  # column renaming below fails.
  covariate_array <- attr(draws, "covariates")[, , -1, drop = FALSE]

  if (attr(draws, "rand_eff_bool")) {
    # Random-effects model: use the coefficients of the group that this
    # observation belongs to.
    all_betas <- tempo_calc_all_params_re(draws)
    group_id <-
      attr(draws, "group_ids")[attr(draws, "y")$obs_id %in% obs_id]
    param_means <- data.frame(mean = colMeans(all_betas[, , group_id])[-1])
    param_means$covariate <- as.factor(rownames(param_means))
  } else {
    # Fixed-effects model: pool the chains and average every "beta" column.
    draws_all <- do.call(rbind, draws)
    is_beta <- grepl(colnames(draws$chain_1), pattern = "beta")
    param_means <- data.frame(mean = colMeans(draws_all[, is_beta])[-1])
    # Strip the first five characters of the column name (presumably a
    # "beta_" prefix) to recover the covariate name.
    param_means$covariate <-
      as.factor(str_sub(rownames(param_means), start = 6))
  }

  param_means$sign <- ifelse(param_means$mean > 0, "Positive", "Negative")
  param_means$linetype <-
    as.factor(ifelse(param_means$sign == "Positive", "solid", "twodash"))

  # Long format: one row per (time, covariate) for the requested observation,
  # joined with the coefficient summaries. Columns are selected by string
  # rather than with `.data$`, which is deprecated in tidyselect arguments.
  covariates_long <- covariate_array %>%
    melt() %>%
    `colnames<-`(c("obs_id_internal", "time", "covariate", "value")) %>%
    filter(.data$obs_id_internal == obs_id) %>%
    mutate(covariate = fct_relevel(as.factor(.data$covariate),
                                   levels(param_means$covariate))) %>%
    left_join(param_means, by = "covariate") %>%
    select(-"obs_id_internal")

  # The size aesthetic is a factor of |mean|, whose levels sort ascending, so
  # widths and labels are supplied in ascending order of |mean| as well.
  magnitude_order <- order(abs(param_means$mean))
  abs_means_in_order <- abs(param_means$mean)[magnitude_order]
  line_widths <- rescale(abs_means_in_order, to = c(0.25, 1.5))
  line_labels <- round(abs(param_means$mean), 2)[magnitude_order]

  covs_plot <- ggplot(data = covariates_long,
                      aes(x = .data$time,
                          y = .data$value,
                          colour = .data$covariate,
                          size = as.factor(abs(.data$mean)))) +
    geom_line(aes(linetype = .data$linetype)) +
    scale_color_discrete(guide = guide_legend(order = 1)) +
    scale_size_manual(name = expression(paste("abs(", beta, ")")),
                      values = line_widths,
                      labels = line_labels,
                      guide = guide_legend(order = 2)) +
    scale_linetype_identity(name = expression(paste(beta, " Sign")),
                            breaks = c("solid", "twodash"),
                            labels = c("Positive", "Negative"),
                            guide = guide_legend(order = 3)) +
    theme_hc(base_size = 16,
             base_family = "serif") +
    theme(legend.position = "right", legend.title = element_text(size = 14)) +
    ylab("Value") +
    xlab("Time") +
    labs(colour = "Covariate")

  covs_plot
}


# This function is not exported and thus not documented with roxygen.
# See the docs for tempo_plot_covs_pmf for descriptions of the arguments.
#
# Draws the PMF of one observation as a bar chart of the per-time-step mean
# of x$pmf, optionally outlined by the 2.5% / 97.5% quantiles. Returns a
# ggplot object.
#' @importFrom stats quantile
#' @importFrom dplyr filter ungroup group_by
#' @importFrom tidyr spread
#' @importFrom reshape2 melt
tempo_plot_pmf <- function(x, obs_id, confidence_intervals = TRUE) {
  # Summarise over the first dimension of x$pmf: interval bounds and mean for
  # every combination of the second and third dimensions.
  interval_bounds <- aperm(
    apply(x$pmf, MARGIN = c(2, 3), FUN = quantile, probs = c(0.025, 0.975)),
    c(3, 2, 1)
  )
  pmf_mean <- t(apply(x$pmf, MARGIN = c(2, 3), FUN = mean))

  # Stack the three statistics along a new third dimension.
  pmf_summary <- abind(interval_bounds, pmf_mean, along = 3)
  dimnames(pmf_summary)[[3]] <- c("lower", "upper", "mean")

  # Long format, then one column per statistic for the requested observation.
  obs_summary <- melt(pmf_summary)
  colnames(obs_summary) <- c("obs_id_internal", "time_step", "stat", "value")
  obs_summary <- obs_summary %>%
    spread(.data$stat, .data$value) %>%
    filter(.data$obs_id_internal == obs_id) %>%
    # Keep what follows the first ten characters of the label (presumably a
    # "time_step_" prefix) as the numeric time step.
    mutate(time_step = as.numeric(str_sub(.data$time_step, start = 11)))

  bars <- ggplot(obs_summary,
                 aes(.data$time_step, .data$mean)) +
    geom_col(alpha = 0.5, width = 1) +
    ylab("Probability") +
    xlab("Time") +
    theme_hc(base_size = 16,
             base_family = "serif") +
    theme(legend.position = "right", legend.title = element_text(size = 14))

  # Without intervals the bar chart alone is the result.
  if (!(confidence_intervals == TRUE)) {
    return(bars)
  }

  # Outline each bar's interval with an unfilled rectangle; colour and size
  # are mapped to a constant so the interval gets a legend entry.
  interval_boxes <- geom_rect(aes(xmin = .data$time_step - 0.5,
                                  xmax = .data$time_step + 0.5,
                                  ymin = .data$lower,
                                  ymax = .data$upper,
                                  color = "95% CI",
                                  size = "95% CI"),
                              fill = NA,
                              alpha = 0.25)

  bars +
    interval_boxes +
    scale_color_manual(name = NULL, values = "deepskyblue3") +
    scale_size_manual(name = NULL, values = 0.15)
}
vlandau/tempo documentation built on March 18, 2020, 12:04 a.m.