# R/extr_outs.R
#
# Defines functions: print.extr, plot.extr, extr_outs
#
# Documented in: extr_outs, plot.extr, print.extr

## Function extr_outs
##
##' @title
##' Extract outputs from `stanfit` objects obtained from [ProbBreed::bayes_met]
##'
##' @description
##' Extracts outputs of the Bayesian model fitted
##' using [ProbBreed::bayes_met()], and provides some diagnostics.
##'
##' @details
##' More details about the usage of `extr_outs` and other functions of
##' the `ProbBreed` package can be found at \url{https://saulo-chaves.github.io/ProbBreed_site/}.
##'
##'
##' @param model An object of class `stanfit`, obtained using [ProbBreed::bayes_met()]
##' @param probs A numeric vector with two elements representing the probabilities
##' (in decimal scale, i.e. between 0 and 1) that will be considered for
##' computing the quantiles. An error is raised otherwise.
##' @param verbose A logical value. If `TRUE`, the function will indicate the
##' completed steps. Defaults to `FALSE`
##' @return The function returns an object of class `extr`, which is a list with:
##' \itemize{
##' \item \code{post} : a list with the posterior draws of the effects, and the
##' data generated by the model (element \code{sampled.Y}).
##' \item \code{variances} : a data frame containing, for each variance component
##' of the model, the mean of its posterior draws, their standard deviation,
##' naive standard error, and the two quantiles requested in \code{probs}
##' (stored in columns named \code{HPD_} followed by the probability).
##' \item \code{map} : a list with the maximum posterior values of each effect
##' \item \code{ppcheck} : a one-column matrix containing the p-values of maximum,
##' minimum, median, mean and standard deviation; effective number of parameters,
##' WAIC2 value, mean Rhat and mean effective sample size divided by the number
##' of posterior draws.
##' }
##' The object also carries the attributes \code{control} (iterations, warmup,
##' thinning and number of chains), \code{data}, \code{modmean} and
##' \code{declared}.
##'
##' @seealso [rstan::stan_diag], [ggplot2::ggplot], [rstan::check_hmc_diagnostics], [ProbBreed::plot.extr]
##'
##' @import rstan
##' @importFrom utils write.csv
##' @importFrom rlang .data
##'
##' @examples
##' \donttest{
##' mod = bayes_met(data = maize,
##'                 gen = "Hybrid",
##'                 loc = "Location",
##'                 repl = c("Rep","Block"),
##'                 trait = "GY",
##'                 reg = "Region",
##'                 year = NULL,
##'                 res.het = TRUE,
##'                 iter = 2000, cores = 2, chain = 4)
##'
##' outs = extr_outs(model = mod,
##'                  probs = c(0.05, 0.95),
##'                  verbose = TRUE)
##' }
##' @export

extr_outs = function(model, probs = c(0.025, 0.975), verbose = FALSE){

  requireNamespace('rstan')

  # Fail early with a clear message: the interval columns built below assume
  # exactly two probabilities.
  if (!is.numeric(probs) || length(probs) != 2 || anyNA(probs) ||
      any(probs < 0 | probs > 1)) {
    stop("'probs' must be a numeric vector with two values between 0 and 1",
         call. = FALSE)
  }

  # Data
  data <- attr(model, "data")

  # Extract stan results
  out <- rstan::extract(model, permuted = TRUE)

  effects <- names(out)[names(out) %in%
                          c('r', 'b', 'm', 'l', 't', 'g', 'gl', 'gt', 'gm')]
  declared <- attr(model, "declared_eff")
  trait <- declared$trait
  nenv <- ncol(out$l)
  modmean <- attr(model, 'modmean')

  # Replace the internal one/two-letter codes by the names declared by the
  # user. Matching is done against the untouched input labels, so a declared
  # name that happens to equal another code is never relabelled twice.
  relabel <- function(labels, codes) {
    new_names <- list(
      l = declared$loc,
      g = declared$gen,
      r = declared$repl,
      b = declared$blk,
      m = declared$reg,
      t = declared$year,
      gl = paste(declared$gen, declared$loc, sep = ':'),
      gm = paste(declared$gen, declared$reg, sep = ':'),
      gt = paste(declared$gen, declared$year, sep = ':')
    )
    original <- labels
    for (code in codes) {
      labels[which(original == code)] <- new_names[[code]]
    }
    labels
  }

  # Codes that are only relabelled when the corresponding effect was fitted
  optional_codes <- c(
    intersect(c('r', 'b'), effects),
    if ('m' %in% effects) c('m', 'gm'),
    if ('t' %in% effects) c('t', 'gt')
  )

  # Posterior effects ------------------------
  post <- out[effects]
  # When `modmean` is TRUE only the location and genotype entries are renamed;
  # every other entry keeps its internal code.
  names(post) <- relabel(
    names(post),
    if (modmean) c('l', 'g') else c('l', 'g', 'gl', optional_codes)
  )

  if(verbose)  message('-> Posterior effects extracted')

  # Variances --------------------
  # Summary of a set of variance draws: mean, sd, naive se and the two
  # quantiles requested in `probs` (always five values).
  summarise_draws <- function(draws) {
    c(
      mean(draws),
      sd(draws),
      sd(draws) / sqrt(length(draws)),
      unname(quantile(draws, probs = probs))
    )
  }

  effect_stats <- vapply(
    effects,
    function(eff) summarise_draws(out[[paste("s", eff, sep = '_')]]^2),
    numeric(5)
  )

  if ('sigma_vec' %in% names(out)) {
    # One residual variance per column of `sigma`
    error_stats <- apply(out[['sigma']]^2, 2, summarise_draws)
    error_labels <- paste0('error_env', seq_len(nenv))
  } else {
    # Single residual variance
    error_stats <- matrix(summarise_draws(out[['sigma']]^2), ncol = 1)
    error_labels <- 'error'
  }

  all_stats <- unname(cbind(effect_stats, error_stats))

  variances <- data.frame(
    'effect' = c(effects, error_labels),
    'var' = round(all_stats[1, ], 3),
    'sd' = round(all_stats[2, ], 3),
    'naive.se' = round(all_stats[3, ], 3),
    'HPD1' = round(all_stats[4, ], 3),
    'HPD2' = round(all_stats[5, ], 3),
    row.names = NULL
  )

  # Renamed after construction so data.frame() does not alter the names
  colnames(variances)[which(colnames(variances) %in% c("HPD1", 'HPD2'))] <- c(
    paste('HPD', probs[1], sep = '_'),
    paste('HPD', probs[2], sep = '_')
  )

  variances$effect <- relabel(
    variances$effect,
    c('l', 'g', if (!modmean) 'gl', optional_codes)
  )

  if(verbose) message('-> Variances extracted')

  # Maximum posterior values (MAP) ------------------------
  # Location of the highest point of the kernel density estimate
  density_peak <- function(values) {
    den <- density(values)
    den$x[which.max(den$y)]
  }
  get_map <- function(posterior) {
    posterior <- as.matrix(posterior)
    if (ncol(posterior) > 1) {
      apply(posterior, 2, density_peak)
    } else {
      density_peak(posterior)
    }
  }
  map <- lapply(post, get_map)

  if(verbose) message('-> Maximum posterior values extracted')

  # Data generated by the model -----------------
  # Added after the MAP step so that no MAP is computed for the generated data
  post[['sampled.Y']] <- out[['y_gen']]

  # Diagnostics -------------------
  ns <- length(out$mu)
  y <- as.numeric(data[,trait])

  # Statistics of the observed data do not change across draws: compute once
  observed <- c(
    max = max(y),
    min = min(y),
    median = unname(quantile(y, 0.5)),
    mean = mean(y),
    sd = sd(y)
  )
  # For each row of `y_gen`: does the generated statistic exceed the observed?
  exceeds <- apply(out$y_gen, 1, function(y_rep) {
    c(
      max = max(y_rep) > observed[["max"]],
      min = min(y_rep) > observed[["min"]],
      median = unname(quantile(y_rep, 0.5)) > observed[["median"]],
      mean = mean(y_rep) > observed[["mean"]],
      sd = sd(y_rep) > observed[["sd"]]
    )
  })
  p_values <- rowSums(exceeds) / ns

  # log((1 / ns) * sum(exp(x))) computed after subtracting the maximum, so
  # that exp() does not underflow to zero for very negative log-likelihoods.
  log_mean_exp <- function(x) {
    peak <- max(x)
    if (!is.finite(peak)) return(log((1 / ns) * sum(exp(x))))
    peak + log((1 / ns) * sum(exp(x - peak)))
  }
  pointwise <- apply(out$y_log_like, 2, function(x) {
    c(val = log_mean_exp(x),
      var = var(x))
  })
  lppd <- sum(pointwise["val", ])
  p_WAIC2 <- sum(pointwise["var", ]) # Effective number of parameters
  elppd_WAIC2 <- lppd - p_WAIC2
  WAIC2 <- -2 * elppd_WAIC2

  # Summarise the fit once and use the full element name (no partial matching)
  fit_summary <- summary(model)$summary

  output_p_check <- round(t(
    cbind(
      p.val_max = p_values[["max"]],
      p.val_min = p_values[["min"]],
      p.val_median = p_values[["median"]],
      p.val_mean = p_values[["mean"]],
      p.val_sd = p_values[["sd"]],
      Eff_No_parameters = p_WAIC2,
      WAIC2 = WAIC2,
      mean_Rhat = mean(fit_summary[,"Rhat"]),
      Eff_sample_size = mean(fit_summary[,"n_eff"])/ns
    )
  ), 4)
  colnames(output_p_check) <- "Diagnostics"

  if(verbose) message('-> Posterior predictive checks computed')

  results <- list(post, variances, map, output_p_check)
  names(results) <- c('post', 'variances', 'map', 'ppcheck')

  rstan::check_hmc_diagnostics(model)

  class(results) <- "extr"

  sampler_args <- model@stan_args
  attr(results, "control") <- data.frame(
    niter = sampler_args[[1]]$iter,
    warmup = sampler_args[[1]]$warmup,
    thin = sampler_args[[1]]$thin,
    nchains = length(sampler_args)
  )

  attr(results, "data") <- data
  attr(results, 'modmean') <- modmean
  attr(results, "declared") <- declared

  results
}

#' Plots for the `extr` object
#'
#' Build plots using the outputs stored in the `extr` object.
#'
#'
#' @param x An object of class `extr`.
#' @param category A string indicating which plot to build. See options in the Details section.
#' An error is raised if the value is not one of the listed options.
#' @param ... Passed to [ggplot2::geom_histogram], when `category = "histogram"`. Useful to change the
#' number of bins.
#' @method plot extr
#'
#' @return A `ggplot` object.
#'
#' @details The available options are:
#'   \itemize{
#'     \item \code{ppdensity} : Density plots of the empirical and sampled data, useful to assess the
#'     model's convergence.
#'     \item \code{density} : Density plots of the model's effects.
#'     \item \code{histogram} : Histograms of the model's effects.
#'     \item \code{traceplot}: Trace plot showing the changes in the effects' values across iterations and chains.
#'   }
#'
#'
#' @seealso  [ProbBreed::extr_outs]
#'
#' @import ggplot2
#' @importFrom rlang .data
#'
#' @rdname plot.extr
#' @export
#'
#' @examples
#' \donttest{
#' mod = bayes_met(data = maize,
#'                 gen = "Hybrid",
#'                 loc = "Location",
#'                 repl = c("Rep","Block"),
#'                 trait = "GY",
#'                 reg = "Region",
#'                 year = NULL,
#'                 res.het = TRUE,
#'                 iter = 2000, cores = 2, chain = 4)
#'
#' outs = extr_outs(model = mod,
#'                  probs = c(0.05, 0.95),
#'                  verbose = TRUE)
#' plot(outs, category = "ppdensity")
#' plot(outs, category = "density")
#' plot(outs, category = "histogram")
#' plot(outs, category = "traceplot")
#' }
#'

plot.extr = function(x, ..., category = "ppdensity"){

  obj <- x
  # Namespaces
  requireNamespace('ggplot2')

  # inherits() also accepts objects that carry additional classes
  stopifnot("Object is not of class 'extr'" = inherits(obj, "extr"))

  # Reject unknown categories instead of silently returning NULL
  categories <- c("ppdensity", "density", "histogram", "traceplot")
  if (!is.character(category) || length(category) != 1 ||
      !category %in% categories) {
    stop("'category' must be one of: ", paste(categories, collapse = ", "),
         call. = FALSE)
  }

  control <- attr(obj, "control")
  declared <- attr(obj, "declared")
  post <- obj$post

  if (category == "ppdensity") {
    # Only the generated data are needed here; the other effects are skipped
    sampled <- data.frame(value = c(post[["sampled.Y"]]), eff = "Sampled")
    empirical <- data.frame(value = attr(obj, "data")[,declared$trait],
                            eff = 'Empirical')
    pp_data <- rbind(sampled, empirical)

    ggplot(data = pp_data,
           aes(x = .data$value, color = .data$eff, fill = .data$eff)) +
      geom_density(linewidth = 1.3, alpha = .25) +
      labs(x = declared$trait, y = 'Frequency', fill = '', color = '',
           title = "Density: Empirical vs Sampled phenotypic values") +
      theme_bw() +
      theme(legend.position = 'top') +
      scale_fill_manual(values = c('Empirical' = '#1f78b4', 'Sampled' = "#33a02c")) +
      scale_color_manual(values = c('Empirical' = '#1f78b4', 'Sampled' = "#33a02c"))

  } else {

    # Long format: one row per stored value, labelled with a draw index and a
    # chain index. The number of draws per chain is taken from the object
    # itself (first dimension / number of chains) instead of niter - warmup,
    # so the lengths always agree, including when thinning was used.
    # NOTE(review): extr_outs() extracts the draws with `permuted = TRUE`;
    # the labels below assume the rows are ordered chain by chain - confirm
    # that this is what the traceplot is expected to show.
    to_long <- function(draws, label) {
      n_draws <- NROW(draws)
      if (n_draws %% control$nchains != 0) {
        stop("The number of posterior draws is not a multiple of the number of chains",
             call. = FALSE)
      }
      per_chain <- n_draws / control$nchains
      data.frame(
        value = c(draws),
        iter = rep(seq_len(per_chain), times = length(draws) / per_chain),
        chain = rep(rep(seq_len(control$nchains), each = per_chain),
                    times = length(draws) / n_draws),
        eff = label
      )
    }

    panel_labels <- names(post)
    panel_labels[which(panel_labels == "sampled.Y")] <-
      paste("Sampled:", declared$trait)
    df.post <- do.call(rbind, Map(to_long, post, panel_labels))

    switch(
      category,
      density = ggplot(df.post, aes(x = .data$value)) +
        facet_wrap(.~.data$eff, scales = "free")+
        theme_bw() + theme(legend.position = "none", axis.title = element_blank()) +
        geom_density(aes(fill = .data$eff), colour = "black", alpha = .5) +
        scale_fill_viridis_d(option = "turbo") +
        labs(title = "Density of the posterior values"),
      histogram = ggplot(df.post, aes(x = .data$value)) +
        facet_wrap(.~.data$eff, scales = "free")+
        theme_bw() + theme(legend.position = "none", axis.title = element_blank()) +
        geom_histogram(aes(fill = .data$eff), colour = "black", alpha = .5, ...) +
        scale_fill_viridis_d(option = "turbo") +
        labs(title = "Histogram of the posterior values"),
      traceplot = ggplot(df.post, aes(y = .data$value, x = .data$iter,
                                      colour = factor(.data$chain))) +
        facet_wrap(.~.data$eff, scales = "free")+
        theme_bw() + theme(legend.position = "top", axis.title.y = element_blank()) +
        geom_line(aes(group = factor(.data$chain), linetype = factor(.data$chain))) +
        scale_colour_viridis_d(option = "turbo")  +
        labs(x = 'Iterations', colour = 'Chain', linetype = 'Chain', title = "Traceplot")
    )
  }

}




#' Print an object of class `extr`
#'
#' Print a `extr` object in R console
#'
#' @param x An object of class `extr`
#' @param ... currently not used
#' @method print extr
#'
#' @return The object `x`, invisibly. Called for its side effect of showing the
#' `variances` and `ppcheck` elements.
#'
#' @seealso [ProbBreed::extr_outs]
#'
#' @export
#'

print.extr = function(x, ...){

  message("======> Variances")
  print(x$variances)
  message("======> Posterior predictive checks")
  print(x$ppcheck)

  # Print methods return their argument invisibly, so the object can still be
  # assigned or piped after being printed.
  invisible(x)

}

# Try the ProbBreed package in your browser
#
# Any scripts or data that you put into this service are public.
#
# ProbBreed documentation built on April 4, 2025, 5:07 a.m.