R/estimates.R

#' @include flatten.R
NULL

#' @title Function \code{estimates}
#' @description Extracts estimates of posterior means, posterior standard deviations,
#' and approximate credible intervals from a \code{Chain} object or list of
#' \code{Chain} objects generated by \code{fbseq()}.
#' @details Means and standard deviations are pooled across chains, with each chain
#' weighted by \code{iterations * thin}. Credible intervals are approximations:
#' normal quantiles for parameters whose names contain \code{beta}, \code{epsilon},
#' or \code{theta}, and moment-matched inverse-gamma quantiles for parameters whose
#' names contain \code{gamma}, \code{nu}, \code{sigmaSquared}, \code{tau}, or \code{xi}.
#' When such a parameter has zero posterior standard deviation, its inverse-gamma
#' interval collapses to the posterior mean. Upper limits for \code{nu} and
#' \code{sigmaSquared} are capped at the \code{d} slot and the square of the
#' \code{s} slot of the first chain, respectively.
#' @export
#' @return a data frame of marginal posterior parameter estimates: posterior means,
#' posterior standard deviations, and approximate credible intervals.
#' Parameters not updated in the MCMC are excluded.
#' @param obj a \code{Chain} object or nonempty list of \code{Chain} objects generated
#' by \code{fbseq()}.
#' @param level level of the credible intervals, a single number from 0 to 1.
estimates = function(obj, level = 0.95){
  if (inherits(obj, "Chain")) obj = list(obj)
  if (!length(obj)) stop("obj must be a Chain object or a nonempty list of Chain objects.")
  stopifnot(is.numeric(level), length(level) == 1, level >= 0, level <= 1)

  # Pool chains: each chain's per-parameter averages are weighted by iterations * thin.
  # as.numeric() keeps the product in double precision even if the slots are integers.
  weights = vapply(obj, function(ch) as.numeric(ch@iterations) * ch@thin, numeric(1))
  total = sum(weights)
  # Summing a list of weighted vectors with Reduce() works for any number of
  # parameters; rowSums(sapply(...)) broke when there was only one parameter,
  # because sapply() then returns a plain vector rather than a matrix.
  pooled = function(f) Reduce(`+`, Map(function(ch, w) f(ch) * w, obj, weights))/total
  Mean = pooled(function(ch) flatten_post(ch))
  MeanSq = pooled(function(ch) flatten_post(ch, square = TRUE))
  # Round-off can make MeanSq - Mean^2 slightly negative for (nearly) constant
  # draws; clamp at zero so the sd is 0 rather than NaN.
  Sd = sqrt(total*pmax(MeanSq - Mean^2, 0)/(total - 1))

  d = data.frame(mean = Mean, sd = Sd, lower = NA_real_, upper = NA_real_)
  params = rownames(d)
  p = 1 - (1 - level)/2  # upper-tail probability of an equal-tailed interval

  # Normal approximation.
  for (v in c("beta", "epsilon", "theta")) {
    i = grep(v, params)
    d$lower[i] = qnorm(1 - p, mean = d$mean[i], sd = d$sd[i])
    d$upper[i] = qnorm(p, mean = d$mean[i], sd = d$sd[i])
  }

  # Inverse-gamma approximation matched to the pooled mean and variance:
  # mean = scale/(shape - 1) and variance = mean^2/(shape - 2).
  for (v in c("gamma", "nu", "sigmaSquared", "tau", "xi")) {
    i = grep(v, params)
    m = d$mean[i]
    shape = m^2/d$sd[i]^2 + 2
    scale = m*(shape - 1)
    lower = upper = m
    # A zero sd gives shape = Inf (or NaN when the mean is also 0); such intervals
    # stay collapsed at the mean.
    ok = is.finite(shape)
    if (any(ok)) {
      lower[ok] = qinvgamma(1 - p, shape = shape[ok], scale = scale[ok])
      upper[ok] = qinvgamma(p, shape = shape[ok], scale = scale[ok])
    }
    d$lower[i] = lower
    d$upper[i] = upper
  }

  # Cap upper limits using slots of the first chain. NOTE(review): these look like
  # the prior upper bounds for nu and sigma; confirm. The slots are read only when
  # the corresponding parameter is present, and every matching row is capped.
  i = grep("nu", params)
  if (length(i)) d$upper[i] = pmin(d$upper[i], obj[[1]]@d)
  i = grep("sigmaSquared", params)
  if (length(i)) d$upper[i] = pmin(d$upper[i], obj[[1]]@s^2)

  out = data.frame(mean = d$mean, sd = d$sd, lower = d$lower, upper = d$upper)
  colnames(out) = c("mean", "sd", paste0(c("lower", "upper"), "_ci_", level))
  rownames(out) = names(Mean)
  out
}
wlandau/fbseq documentation built on May 4, 2019, 8:43 a.m.