R/plot_chain.R

Defines functions plot_chain

Documented in plot_chain

#' @title Chain's plot
#'
#' @description
#' This function plots the chains generated by the MCMC algorithm in the hp()
#' and dlm() functions. Three types of plot are available: trace plots,
#' autocorrelation plots and posterior density plots.
#'
#' @usage
#' plot_chain(fit, param, type = c("trace", "acf", "density"))
#'
#' @param fit Object of the classes `HP` or `DLM`.
#' @param param Character vector specifying the parameters to be plotted. It is
#' used (and required) only when the class of the fit object is `DLM`: a
#' parameter is selected when its name contains any of the given strings
#' (literal, case-insensitive match), e.g. `"sigma2"`, `"mu[10]"` or `"[10]"`.
#' It is ignored for `HP` objects, for which all parameters are plotted.
#' @param type Character string specifying the type of plot to be returned.
#' There are three options: "trace" returns a plot of the sampled values of
#' the parameters along the iterations; "acf" returns a plot of the
#' autocorrelation of the parameters (up to lag 50); "density" returns a plot
#' of the posterior density of the parameters based on the samples generated
#' by the MCMC method.
#'
#' @return A `ggplot` object of the chosen type, with one panel per selected
#' parameter.
#'
#' @examples
#' ## Importing mortality data from the USA available on the Human Mortality Database (HMD):
#' \donttest{data(USA)
#'
#' ## Selecting the log mortality rate of the 2010 total population ranging from 0 to 90 years old
#' USA2010 = USA[USA$Year == 2010,]
#' x = 0:90
#' Ex = USA2010$Ex.Total[x+1]
#' Dx = USA2010$Dx.Total[x+1]
#' y = log(Dx/Ex)
#'
#' ## Fitting HP model
#' fit = hp(x = x, Ex = Ex, Dx = Dx, model = "lognormal",
#'          m = c(NA, 0.08, rep(NA, 6)),
#'          v = c(NA, 1e-4, rep(NA, 6)))
#'
#' ## Plotting all the available types of plot:
#' plot_chain(fit, type = "trace")
#' plot_chain(fit, type = "acf")
#' plot_chain(fit, type = "density")
#'
#'
#' ## Fitting DLM
#' fit = dlm(y, M = 100)
#'
#' plot_chain(fit, param = "sigma2", type = "trace")
#' plot_chain(fit, param = "mu[10]", type = "acf")
#'
#' ## Selecting all theta1 parameters whose age index starts with the digit 1
#' plot_chain(fit, param = "theta1[1", type = "density")
#'
#' ## Plotting all parameters indexed by age 10 and age 11
#' plot_chain(fit, param = c("[10]", "[11]"))
#' }
#'
#' @import ggplot2
#' @importFrom dplyr select
#' @importFrom dplyr starts_with
#' @importFrom tidyr gather
#'
#' @export
plot_chain <- function(fit, param, type = c("trace", "acf", "density")) {
  type <- match.arg(type)

  if (inherits(fit, "HP")) {
    chains <- chain_samples_hp(fit)
    text_size <- 13
  } else if (inherits(fit, "DLM")) {
    # `param` has no default; fail with a clear message instead of an
    # obscure error raised from inside dplyr.
    if (missing(param)) {
      stop("param argument must be provided when fit is a DLM object.",
           call. = FALSE)
    }
    chains <- chain_samples_dlm(fit, param)
    text_size <- 10
  } else {
    stop("fit must be an object of class 'HP' or 'DLM'.", call. = FALSE)
  }

  chain_render_plot(chains, type, text_size)
}

# Wide data frame of the HP posterior samples: columns A..H hold the first
# eight columns of mcmc_theta; a sigma2 column is added unless the model is
# binomial or poisson.
chain_samples_hp <- function(fit) {
  theta <- as.matrix(fit$post.samples$mcmc_theta)[, 1:8, drop = FALSE]
  chains <- as.data.frame(theta)
  colnames(chains) <- LETTERS[1:8]

  if (!(fit$info$model %in% c("binomial", "poisson"))) {
    chains$sigma2 <- as.numeric(fit$post.samples$sigma2)
  }
  chains
}

# Wide data frame of the DLM chains (mu, theta and sigma2), restricted to the
# columns whose names contain any of the strings in `param`.
chain_samples_dlm <- function(fit, param) {
  ages <- fit$info$ages
  m <- length(fit$info$Ft)
  n_iter <- NROW(fit$mu)

  theta_names <- paste0("theta", rep(seq_len(m), each = length(fit$info$y)),
                        "[", ages, "]")
  col_names <- c(paste0("mu[", ages, "]"), theta_names, "sigma2")

  # fit$theta is indexed [iteration, age, state]. matrix() flattens it
  # column-major, which yields the same column order as
  # cbind(theta[, , 1], ..., theta[, , m]) and also accepts a theta that is
  # already 2-D (single state) or 3-D with a single state.
  theta <- matrix(fit$theta, nrow = n_iter)

  chains <- as.data.frame(cbind(fit$mu, theta, fit$sig2))
  colnames(chains) <- col_names

  chains <- dplyr::select(chains, dplyr::contains(param))
  if (ncol(chains) == 0) {
    stop("param argument is not valid.", call. = FALSE)
  }
  chains
}

# Long format (param, iteration, samples) of a wide chain data frame; the
# param factor keeps the column order so facets appear in that order.
chain_long_format <- function(chains) {
  n_iter <- nrow(chains)
  data.frame(
    param = factor(rep(names(chains), each = n_iter), levels = names(chains)),
    iteration = rep(seq_len(n_iter), times = ncol(chains)),
    samples = unlist(chains, use.names = FALSE)
  )
}

# Long format (param, lag, autocor) of the autocorrelations of each chain.
# acf() caps lag.max at n - 1, so the number of lags is taken from its output
# instead of being assumed to be lag_max + 1.
chain_acf_long <- function(chains, lag_max = 50) {
  autocor <- lapply(chains, function(x) {
    as.numeric(stats::acf(x, lag.max = lag_max, plot = FALSE)$acf)
  })
  n_lags <- unname(lengths(autocor))
  data.frame(
    param = factor(rep(names(chains), times = n_lags), levels = names(chains)),
    lag = sequence(n_lags) - 1L,
    autocor = unlist(autocor, use.names = FALSE)
  )
}

# Builds the requested plot type from a wide chain data frame; text_size sets
# the size of the axis and legend text.
chain_render_plot <- function(chains, type, text_size, lag_max = 50) {
  theme_layers <- list(
    ggplot2::theme_bw(),
    ggplot2::theme(
      plot.title = ggplot2::element_text(lineheight = 1.2),
      axis.title.x = ggplot2::element_text(color = "black", size = text_size),
      axis.title.y = ggplot2::element_text(color = "black", size = text_size),
      axis.text = ggplot2::element_text(color = "black", size = text_size),
      legend.text = ggplot2::element_text(size = text_size),
      strip.background = ggplot2::element_rect(fill = "deepskyblue4"),
      strip.text = ggplot2::element_text(color = "white", size = 12)
    )
  )

  switch(
    type,
    trace = ggplot2::ggplot(chain_long_format(chains)) +
      ggplot2::xlab("") + ggplot2::ylab("") + theme_layers +
      ggplot2::geom_line(
        ggplot2::aes(x = .data$iteration, y = .data$samples),
        col = "deepskyblue4"
      ) +
      ggplot2::facet_wrap(~param, scales = "free"),
    acf = ggplot2::ggplot(chain_acf_long(chains, lag_max)) +
      # Breaks span [-1, 1] so negative autocorrelations are labelled too.
      ggplot2::scale_y_continuous(breaks = seq(-1, 1, by = 0.2)) +
      ggplot2::scale_x_continuous(breaks = seq(0, lag_max, by = 10),
                                  limits = c(-0.5, lag_max + 0.5)) +
      ggplot2::xlab("Lag") + ggplot2::ylab("Autocorrelation") + theme_layers +
      ggplot2::geom_bar(
        ggplot2::aes(x = .data$lag, y = .data$autocor),
        stat = "identity", col = "black", fill = "orangered4"
      ) +
      ggplot2::facet_wrap(~param),
    density = ggplot2::ggplot(chain_long_format(chains)) +
      ggplot2::xlab("") + ggplot2::ylab("") + theme_layers +
      ggplot2::geom_density(
        ggplot2::aes(x = .data$samples),
        col = "black", fill = "deepskyblue4", alpha = .7
      ) +
      ggplot2::facet_wrap(~param, scales = "free"),
    stop("Invalid type.")
  )
}

Try the BayesMortalityPlus package in your browser

Any scripts or data that you put into this service are public.

BayesMortalityPlus documentation built on June 22, 2024, 7 p.m.