R/analyses.R

Defines functions plot_single_chain gelman_diagnostics ess_diagnostics estimate_mode median.quantile summarise_chain load_mcmc_chains calculate_AIC calc_DIC calculate_WAIC calculate_BIC

Documented in calc_DIC calculate_BIC calculate_WAIC ess_diagnostics estimate_mode gelman_diagnostics load_mcmc_chains median.quantile plot_single_chain summarise_chain

#' BIC calculation
#'
#' Given an MCMC chain, calculates the Bayesian information criterion, using
#' the highest log likelihood visited by the chain as the maximised likelihood.
#' The number of free parameters is the number of rows of \code{parTab} with
#' \code{fixed == 0}; every cell of \code{dat} counts as one observation.
#' @param chain the MCMC chain to be tested; must contain a \code{lnlike} column
#' @param parTab the parameter table used for this chain; must contain \code{fixed} and \code{values} columns
#' @param dat the data (matrix or data frame)
#' @return a single BIC value, \code{-2*max(lnlike) + k*log(n)}
#' @export
calculate_BIC <- function(chain, parTab, dat){
  n_obs <- nrow(dat) * ncol(dat)
  free_values <- parTab[parTab$fixed == 0, "values"]
  penalty <- length(free_values) * log(n_obs)
  best_lnlike <- max(chain$lnlike)
  -2 * best_lnlike + penalty
}

#' WAIC calculation
#'
#' Given an MCMC chain, calculates the widely applicable information criterion
#' (WAIC) on the deviance scale, using the first variant of the effective
#' number of parameters, \code{p_waic1 = 2 * (lppd - sum_i E_s[log p(y_i | theta_s)])}.
#' For every row of \code{chain} the model is solved with \code{f}, and the
#' pointwise likelihood of every cell of \code{dat} is computed with
#' \code{norm_error}.
#' @param chain the MCMC chain to be tested; one posterior draw per row
#' @param parTab the parameter table used for this chain (currently unused; kept for interface compatibility)
#' @param dat the data matrix
#' @param f the model solving function; given the parameter vector for one draw it must return a matrix indexable like \code{dat}
#' @return a single WAIC value, \code{-2 * (lppd - p_waic1)}
#' @export
calculate_WAIC <- function(chain, parTab, dat, f){
  n_samp <- nrow(chain)
  n_row <- nrow(dat)
  n_col <- ncol(dat)

  ## Pointwise likelihoods: one row per posterior draw, one column per data point
  lik <- matrix(NA_real_, nrow = n_samp, ncol = n_row * n_col)
  for (i in seq_len(n_samp)) {
    ## NOTE(review): assumes get_index_pars returns a named vector containing
    ## "S" and "MAX_TITRE" - confirm against zikaProj
    pars <- zikaProj::get_index_pars(chain, i)
    y <- f(pars)
    index <- 1
    for (j in seq_len(n_row)) {
      for (x in seq_len(n_col)) {
        ## norm_error is treated as returning a likelihood (not a log likelihood)
        lik[i, index] <- norm_error(y[j, x], dat[j, x], pars["S"], pars["MAX_TITRE"])
        index <- index + 1
      }
    }
  }

  ## Log pointwise predictive density: log of the posterior-mean likelihood per point
  lppd <- sum(log(colMeans(lik)))
  ## Posterior mean of the log likelihood, averaged over draws for each point
  mean_log_lik <- sum(colMeans(log(lik)))
  p_waic1 <- 2 * (lppd - mean_log_lik)
  -2 * (lppd - p_waic1)
}

#' DIC calculation
#'
#' Given an MCMC chain, calculates the deviance information criterion and
#' related quantities.
#' @param lik.fun the log likelihood function; called with the unnamed vector of posterior mean parameter values and must return a single log likelihood
#' @param chain the MCMC chain to be tested, as a data frame whose first column is the iteration number, whose last column is the log likelihood (named \code{lnlike}), and whose columns in between are the parameters
#' @return a list with elements \code{DIC} (\code{Dbar + pD}), \code{IC} (\code{Dbar + 2*pD}), \code{pD} (\code{Dbar - Dhat}), \code{pV} (\code{var(deviance)/2}), \code{Dbar} (posterior mean deviance) and \code{Dhat} (deviance at the posterior mean parameters)
#' @export
calc_DIC <- function(lik.fun,chain){
  D.bar <- -2 * mean(chain$lnlike)
  ## Posterior mean of every parameter column (everything between the
  ## iteration column and the log likelihood column). drop = FALSE keeps this
  ## working when there is only one parameter column.
  par_cols <- chain[, 2:(ncol(chain) - 1), drop = FALSE]
  theta.bar <- as.numeric(colMeans(par_cols))
  D.hat <- -2 * lik.fun(theta.bar)
  pD <- D.bar - D.hat
  pV <- var(-2 * chain$lnlike) / 2
  list(DIC = D.bar + pD, IC = D.bar + 2 * pD, pD = pD, pV = pV,
       Dbar = D.bar, Dhat = D.hat)
}

#' AIC calculation
#'
#' Given an MCMC chain, calculates the Akaike information criterion, using the
#' highest log likelihood visited by the chain as the maximised likelihood.
#' @param chain the MCMC chain to be tested; must contain a \code{lnlike} column
#' @param parTab the parameter table used for this chain; rows with \code{fixed == 0} are counted as free parameters
#' @return a single AIC value, \code{2*k - 2*max(lnlike)}
#' @export
calculate_AIC <- function(chain, parTab){
  free_pars <- parTab[parTab$fixed == 0, ]
  n_free <- nrow(free_pars)
  best_lnlike <- max(chain$lnlike)
  2 * n_free - 2 * best_lnlike
}


#' Read in MCMC chains
#'
#' Loads all available MCMC chains from the chosen working directory, allowing the user to specify properties of the returned chain
#' @param location Either the full file path to the directory containing the MCMC chains, or a vector of file paths to the MCMC chains
#' @param parTab the parameter table that was used to solve the model (mainly used to find which were free parameters)
#' @param unfixed Boolean, if TRUE, only returns free parameters
#' @param thin thin chain by this much
#' @param burnin number of iterations to discard (rows with \code{sampno <= burnin} are dropped)
#' @param multi if TRUE, looks for chains generated using the multivariate sampler. Otherwise, looks for univariate sampled chains.
#' @param chainNo if TRUE, adds the chain number to the MCMC chain as a final column named \code{chain}
#' @param PTchain DEV - if TRUE, looks for chains generated by the parallel tempering algorithm (takes precedence over \code{multi})
#' @return a list with elements \code{list} (a coda mcmc.list of the individual chains, or NULL if it could not be built) and \code{chain} (all chains appended to each other as a single mcmc object); NULL if no chains were found
#' @export
load_mcmc_chains <- function(location="",parTab,unfixed=TRUE, thin=1,
                             burnin=100000, multi=TRUE, chainNo=FALSE,
                             PTchain=FALSE){
  ## Work out which files to read
  if (length(location) == 1 && dir.exists(location)) {
    print("Reading in chains from directory")
    pattern <- if (PTchain) {
      "*_chain.csv"
    } else if (multi) {
      "*multivariate_chain.csv"
    } else {
      "*univariate_chain.csv"
    }
    chains <- Sys.glob(file.path(location, pattern))
  } else {
    print("Reading in chains from filepaths")
    chains <- as.list(location)
  }

  print(chains)
  if (length(chains) < 1) {
    message("Error - no chains found")
    return(NULL)
  }

  ## Read in the MCMC chains with fread for speed
  read_chains <- lapply(chains, data.table::fread, data.table = FALSE)

  ## Thin and remove burn in. seq(1, 0, by = thin) errors, so empty chains are
  ## passed through unchanged.
  read_chains <- lapply(read_chains, function(x) {
    if (nrow(x) == 0) return(x)
    x[seq(1, nrow(x), by = thin), , drop = FALSE]
  })
  read_chains <- lapply(read_chains, function(x) x[x$sampno > burnin, , drop = FALSE])
  print(lapply(read_chains, nrow))

  ## Get the estimated parameters only. The selection keeps column 1, the free
  ## parameter columns and the last column (assumed to be the log likelihood),
  ## so it must run BEFORE the chain number column is appended - otherwise
  ## ncol(x) would pick the chain number and drop the log likelihood.
  if (unfixed) {
    free_cols <- which(parTab$fixed == 0) + 1
    read_chains <- lapply(read_chains, function(x) x[, c(1, free_cols, ncol(x))])
  }

  if (chainNo) {
    for (i in seq_along(read_chains)) read_chains[[i]]$chain <- i
  }

  ## Try to create an MCMC list. This might not work (coda may warn or error),
  ## in which case the list is NULL but the combined chain is still returned.
  list_chains <- tryCatch(
    coda::as.mcmc.list(lapply(read_chains, coda::as.mcmc)),
    warning = function(w) {
      print(w)
      NULL
    },
    error = function(e) {
      print(e)
      NULL
    }
  )

  chain <- coda::as.mcmc(do.call("rbind", read_chains))
  return(list("list" = list_chains, "chain" = chain))
}


#' Summarise MCMC chain statistics
#'
#' Returns the coda summary statistics and quantiles for every parameter in
#' the chain, bound together column-wise.
#' @param chain the MCMC chain; anything accepted by \code{coda::as.mcmc}
#' @return the \code{statistics} and \code{quantiles} components of the coda summary, combined with \code{cbind}
#' @export
summarise_chain <- function(chain){
  chain_summary <- summary(coda::as.mcmc(chain))
  stats_part <- chain_summary$statistics
  quantile_part <- chain_summary$quantiles
  cbind(stats_part, quantile_part)
}

#' Find median
#'
#' Finds the median and the central 95% interval (2.5% and 97.5% quantiles)
#' of a vector, named so the result can be used directly as ggplot-style
#' \code{ymin}/\code{y}/\code{ymax} values.
#' NOTE(review): the dotted name makes this look like an S3 \code{median}
#' method for class "quantile"; it is intended to be called directly.
#' @param x the vector to be analysed
#' @return a named numeric vector with elements \code{ymin}, \code{y} and \code{ymax}
#' @export
median.quantile <- function(x){
  bounds <- c(0.025, 0.5, 0.975)
  summary_vals <- quantile(x, probs = bounds, names = FALSE)
  names(summary_vals) <- c("ymin", "y", "ymax")
  summary_vals
}

#' Calculate mode
#'
#' Estimates the posterior mode as the grid point with the highest value of a
#' default kernel density estimate (\code{density(x)}).
#' @param x the vector to be tested
#' @return the point estimate for the mode
#' @export
estimate_mode <- function(x) {
  kde <- density(x)
  peak <- which.max(kde$y)
  kde$x[peak]
}


#' Check ESS
#'
#' Checks the MCMC chain for effective sample sizes using the coda package, and highlights where these are less than a certain threshold
#' @param chain the MCMC chain to be tested; anything accepted by \code{coda::effectiveSize}
#' @param threshold the minimum allowable ESS
#' @return a list with \code{ESS} (effective sample size of every parameter) and \code{poorESS} (the subset strictly below \code{threshold})
#' @export
ess_diagnostics <- function(chain, threshold=200){
  ess <- coda::effectiveSize(chain)
  below <- ess < threshold
  list("ESS" = ess, "poorESS" = ess[below])
}

#' Check gelman diagnostics
#'
#' Checks the gelman diagnostics for the given MCMC chain list. Also finds the parameter with the highest PSRF.
#' The "worst" PSRF is taken from the second column of the coda PSRF table
#' (the upper confidence limit), so the check is conservative. If coda raises
#' a warning or an error, that condition object is returned instead.
#' @param chain the list of MCMC chains (e.g. a coda mcmc.list)
#' @param threshold the threshold for the gelman diagnostic above which chains should be rerun
#' @return a list with \code{GelmanDiag} (the coda result), \code{WorstGelman} (named vector: \code{Worst_PSRF}, \code{Which_worst}, \code{MPSRF}, where \code{MPSRF} is NA when only one parameter is present) and \code{Rerun} (TRUE if the worst PSRF or the MPSRF exceeds \code{threshold}); or the warning/error condition if the diagnostic could not be computed
#' @export
gelman_diagnostics <- function(chain, threshold=1.15){
  tryCatch({
    gelman <- coda::gelman.diag(chain)
    upper_ci <- gelman$psrf[, 2]
    psrf <- max(upper_ci)
    ## Look the name up via rownames: subsetting a one-row matrix drops names
    psrf_names <- rownames(gelman$psrf)[which.max(upper_ci)]
    ## coda returns no multivariate PSRF (NULL) when there is only one parameter
    mpsrf <- gelman$mpsrf
    has_mpsrf <- !is.null(mpsrf)
    worst <- c("Worst_PSRF" = psrf, "Which_worst" = psrf_names,
               "MPSRF" = if (has_mpsrf) mpsrf else NA)
    rerun <- psrf > threshold || (has_mpsrf && mpsrf > threshold)
    list("GelmanDiag" = gelman, "WorstGelman" = worst, "Rerun" = rerun)
  }, warning = function(w) {
    w
  }, error = function(e) {
    e
  })
}


#' Density and iteration plot single chain
#'
#' Intended to plot the posterior density and MCMC trace plots for the given
#' chain (requires ggplot2, cowplot and reshape2).
#' NOTE(review): currently unfinished - it only reshapes the chain to long
#' format with \code{melt(chain, id.vars = "sampno")} and returns that
#' invisibly; no plot is built and \code{realPars} is ignored.
#' @param chain the single MCMC chain; must contain a \code{sampno} column
#' @param realPars OPTIONAL - vector of real parameter values to add to plot (currently unused)
#' @return currently the melted chain, invisibly (a ggplot2 object is intended)
#' @export
plot_single_chain <- function(chain, realPars = NULL){
  invisible(melt(chain, id.vars = "sampno"))
}
jameshay218/lazymcmc documentation built on Sept. 16, 2021, 12:14 a.m.