R/mixing_test.R

Defines functions correlation_tile_sample summarise_sample plot_chain plot_acf plot_trace plot_gelman eff_sample_size median_step_size assess_chain

#This script contains functions for assessing the mixing of the chains:

#' Assess the mixing and burn-in of a set of MCMC chains
#'
#' Prints, for every element of `output$full_results`, the median step size
#' per dimension, the median acceptance probability and the effective sample
#' size per dimension. It then prints the "relative speed": the mean effective
#' sample size divided by `output$time`.
#'
#' Next, for each parameter dimension, it splits `output$sample` into
#' `n_chains` consecutive chains of equal length. It draws an ACF plot per
#' chain, a trace plot of all chains and the Gelman-Rubin plot, and prints
#' [coda::gelman.diag()].
#'
#' @param output A list with three elements:
#'   - `full_results`: a list whose elements hold `sample_proposals`,
#'     `sample_results` and `sample_probs`.
#'   - `sample`: a list of parameter vectors, with the chains stored back to
#'     back.
#'   - `time`: the divisor used for the relative speed.
#' @param n_chains Number of chains stored back to back in `output$sample`.
#'   Defaults to 8, the value that used to be hard-coded.
#' @return `NULL`, invisibly. Called for its printed output and plots. The
#'   `mfrow` graphics setting in effect before the call is restored on exit.
#' @export
assess_chain <- function(output, n_chains = 8) {
  results <- output$full_results

  # Median absolute distance between each proposal and the preceding entry
  # of sample_results, per dimension. Missing distances count as zero.
  step_sizes <- lapply(results, function(res) {
    steps <- mapply(
      res$sample_proposals[-1],
      res$sample_results[-length(res$sample_proposals)],
      FUN = function(proposal, current) abs(proposal - current),
      SIMPLIFY = FALSE
    )
    # One row per step, one column per dimension.
    steps <- do.call(rbind, unname(steps))
    steps[is.na(steps)] <- 0
    apply(steps, 2, median)
  })

  # Median acceptance probability; missing probabilities count as zero.
  acceptance <- lapply(results, function(res) {
    probs <- res$sample_probs
    probs[is.na(probs)] <- 0
    median(probs)
  })

  # Effective sample size of each dimension of sample_results.
  ess <- lapply(results, function(res) {
    # One row per dimension, one column per state.
    states <- do.call(cbind, unname(res$sample_results))
    apply(states, 1, function(y) effectiveSize(as.mcmc(y)))
  })

  print("Median step sizes are")
  print(step_sizes)
  print("Median acceptance probabilities are")
  print(acceptance)
  print("Effective Sample sizes are")
  print(ess)
  print("Relative Speed of the method is:")
  print(mean(vapply(ess, mean, numeric(1))) / output$time)

  # Per-dimension ACF, trace and Gelman-Rubin diagnostics.
  n_dim <- length(output$sample[[1]])
  chain_len <- length(output$sample) / n_chains
  old_mfrow <- par("mfrow")
  on.exit(par(mfrow = old_mfrow), add = TRUE)

  for (i in seq_len(n_dim)) {
    # NOTE(review): c(8, n_dim) / 2 evaluates to c(4, n_dim / 2); this layout
    # is kept from the original. Confirm it is the intended grid.
    par(mfrow = c(8, n_dim) / 2)
    chains <- lapply(seq_len(n_chains) - 1, function(k) {
      draws <- output$sample[(1 + k * chain_len):((k + 1) * chain_len)]
      values <- vapply(draws, function(state) state[[i]], numeric(1))
      acf(values, lag.max = length(values) / 2)
      as.mcmc(values)
    })
    chains <- mcmc.list(chains)
    plot(chains)
    print(gelman.diag(chains))
    gelman.plot(chains)
  }
  invisible(NULL)
}

#' Median step size of a chain in each dimension
#'
#' Takes the absolute difference between consecutive elements of
#' `proposals` and returns the per-dimension median of those differences.
#' Missing (`NA`) differences are counted as zero.
#'
#' @param proposals A list of equal-length numeric vectors, one per state of
#'   the chain.
#' @return A numeric vector with one median per dimension, named after the
#'   names of the vectors in `proposals` (if any). Returns `numeric(0)` when
#'   the chain has fewer than two states, because there are no steps.
#' @export
median_step_size <- function(proposals) {
  if (length(proposals) < 2) {
    return(numeric(0))
  }
  steps <- Map(
    function(nxt, prev) abs(nxt - prev),
    proposals[-1],
    proposals[-length(proposals)]
  )
  # rbind directly avoids as.data.frame(), which deparses every element of
  # an unnamed list just to invent column names (slow for long chains).
  # Result: one row per step, one column per dimension.
  steps <- do.call(rbind, unname(steps))
  steps[is.na(steps)] <- 0
  apply(steps, 2, median)
}

#' Effective sample size of a chain in each dimension
#'
#' @param chain A list of equal-length numeric vectors, one per state of the
#'   chain.
#' @return A numeric vector with one [coda::effectiveSize()] value per
#'   dimension, named after the names of the vectors in `chain` (if any).
#' @export
eff_sample_size <- function(chain) {
  # cbind directly avoids as.data.frame(), which deparses every element of an
  # unnamed list just to invent column names (slow for long chains).
  # Result: one row per dimension, one column per state.
  states <- do.call(cbind, unname(chain))
  apply(states, 1, function(y) effectiveSize(as.mcmc(y)))
}

#' Plot Gelman-Rubin diagnostics for a pooled MCMC sample
#'
#' Splits `output$sample` into `n_perms` consecutive chains of equal length.
#' For each parameter dimension it draws [coda::gelman.plot()] and keeps the
#' first row of the first element of [coda::gelman.diag()]'s result. In coda,
#' that element is the PSRF matrix, with columns "Point est." and
#' "Upper C.I.".
#'
#' @param output A list whose `sample` element is a list of parameter vectors,
#'   with the chains stored back to back.
#' @param n_perms Number of chains stored in `output$sample`.
#' @param names Plot titles, one per parameter dimension.
#' @return A data frame with one column per parameter dimension, named after
#'   `names(output$sample[[1]])`, and one row per value kept from
#'   `gelman.diag()`.
#' @export
plot_gelman <- function(output, n_perms, names) {
  chain_len <- length(output$sample) / n_perms
  n_dim <- length(output$sample[[1]])

  # vapply preallocates the 2 x n_dim result instead of growing it with cbind.
  result <- vapply(seq_len(n_dim), function(i) {
    chains <- lapply(seq_len(n_perms) - 1, function(k) {
      draws <- output$sample[(1 + k * chain_len):((k + 1) * chain_len)]
      as.mcmc(vapply(draws, function(state) state[[i]], numeric(1)))
    })
    chains <- mcmc.list(chains)
    diagnostics <- gelman.diag(chains)
    gelman.plot(chains, main = names[i])
    diagnostics[[1]][1, ]
  }, numeric(2))

  result <- as.data.frame(result)
  # The `names` argument shadows the function, so call base::names explicitly.
  names(result) <- base::names(output$sample[[1]])
  result
}

#' Plot traces of a pooled MCMC sample
#'
#' Splits `output$sample` into `n_perms` consecutive chains of equal length
#' and, for each parameter dimension, plots the resulting `mcmc.list`. The
#' plot is titled with the dimension's name. The graphics layout in effect
#' before the call is restored on exit.
#'
#' @param output A list whose `sample` element is a list of parameter vectors,
#'   with the chains stored back to back.
#' @param n_perms Number of chains stored in `output$sample`.
#' @return `NULL`, invisibly.
#' @export
plot_trace <- function(output, n_perms) {
  chain_len <- length(output$sample) / n_perms
  param_names <- names(output$sample[[1]])

  old_par <- par(mfrow = c(2, 2))
  on.exit(par(old_par), add = TRUE)

  # Trace plots, one per parameter dimension.
  for (i in seq_along(output$sample[[1]])) {
    chains <- lapply(seq_len(n_perms) - 1, function(k) {
      draws <- output$sample[(1 + k * chain_len):((k + 1) * chain_len)]
      as.mcmc(vapply(draws, function(state) state[[i]], numeric(1)))
    })
    plot(mcmc.list(chains), main = param_names[i])
  }
  invisible(NULL)
}

#' Plot autocorrelation functions for the first chains of a pooled sample
#'
#' Splits `output$sample` into `n_perms` consecutive chains of equal length.
#' For each of the first `n_plot` chains it draws one [stats::acf()] plot per
#' parameter dimension, in a 2 x 2 layout. The graphics layout in effect
#' before the call is restored on exit.
#'
#' @param output A list whose `sample` element is a list of parameter vectors,
#'   with the chains stored back to back.
#' @param n_perms Number of chains stored in `output$sample`.
#' @param n_plot Number of leading chains to plot; defaults to all of them.
#'   Values above `n_perms` are treated as `n_perms`.
#' @return `NULL`, invisibly.
#' @export
plot_acf <- function(output, n_perms, n_plot = n_perms) {
  chain_len <- length(output$sample) / n_perms
  param_names <- names(output$sample[[1]])

  old_par <- par(mfrow = c(2, 2))
  on.exit(par(old_par), add = TRUE)

  # Only the chains that are plotted need to be extracted.
  for (j in seq_len(min(n_plot, n_perms))) {
    draws <- output$sample[(1 + (j - 1) * chain_len):(j * chain_len)]
    for (i in seq_along(output$sample[[1]])) {
      chain <- vapply(draws, function(state) state[[i]], numeric(1))
      acf(chain,
          lag.max = length(chain) / 2,
          main = param_names[i],
          na.action = na.pass,
          ylim = c(-1, 1))
    }
  }
  invisible(NULL)
}

#' Plot the marginals and pairwise projections of a sample
#'
#' When `verbose` is `TRUE`, draws a histogram per parameter and prints the
#' sample means, standard deviations and covariance matrix. For each entry of
#' `terms` it builds a ggplot: a hexbin plot for a pair of names, or a
#' coloured scatter plot for a triple (x, y, colour).
#'
#' @param sample A list of equal-length parameter vectors, one per draw.
#' @param terms A list of character vectors of length 2 or 3. Each vector
#'   names columns (as given in `names`) to plot against each other. The
#'   default assumes `edges`, `triangles`, `star.2` and `star.3` exist.
#' @param names Column names for the sample, one per parameter.
#' @param verbose If `TRUE`, draw histograms, print summaries and draw the
#'   arranged plots on the current device.
#' @return The object built by [gridExtra::arrangeGrob()] from the plots.
#' @export
plot_chain <- function(sample,
                       terms = list(c("edges", "triangles"),
                                    c("star.2", "star.3"),
                                    c("edges", "star.2"),
                                    c("star.2", "edges", "triangles")),
                       names,
                       verbose = TRUE) {
  # Only pairs and triples have a plot type; anything else used to leave a
  # NULL entry that made arrangeGrob() fail with an obscure error.
  if (!all(lengths(terms) %in% 2:3)) {
    stop("every element of `terms` must name 2 or 3 parameters", call. = FALSE)
  }

  draws <- do.call(rbind, sample)

  if (verbose) {
    for (k in seq_along(names)) {
      hist(draws[, k], main = names[k], xlab = NULL, ylab = NULL, freq = FALSE)
    }
    print("The means are:")
    print(apply(draws, 2, mean))
    print("The standard deviations are:")
    print(apply(draws, 2, sd))
    # Previously computed and discarded; printed so it is actually reported.
    print("The covariance matrix is:")
    print(cov(draws))
  }

  data <- as.data.frame(draws)
  names(data) <- names

  plots <- lapply(terms, function(vars) {
    panel <- data[vars]
    if (length(vars) == 2) {
      names(panel) <- c("x", "y")
      ggplot(panel, aes(x = x, y = y)) +
        ggtitle(paste0("x = ", vars[1], " and y = ", vars[2])) +
        geom_hex()
    } else {
      names(panel) <- c("x", "y", "colour")
      ggplot(panel, aes(x = x, y = y, colour = colour)) +
        ggtitle(paste0("x = ", vars[1], ", y = ", vars[2],
                       " colour = ", vars[3])) +
        geom_point(size = 2, alpha = 0.5) +
        scale_colour_gradient(low = "yellow", high = "red")
    }
  })

  # Build the arrangement once and reuse it for drawing and returning.
  arranged <- gridExtra::arrangeGrob(grobs = plots)
  if (verbose) {
    grid::grid.draw(arranged)
  }
  arranged
}

#' Summarise a sample per parameter
#'
#' @param sample A list of equal-length parameter vectors, one per draw.
#' @param names Row names for the summary, one per parameter.
#' @return A data frame with columns `post_mean`, `post_sd` and `post_pvalue`,
#'   rounded to four decimal places. `post_pvalue` is the share of draws that
#'   do not lie strictly on the same side of zero as the parameter's mean.
#' @export
summarise_sample <- function(sample, names) {
  draws <- do.call(rbind, sample)
  names(draws) <- NULL

  opposite_sign_share <- function(column) {
    centre_sign <- sign(mean(column))
    mean(centre_sign * column <= 0)
  }

  out <- data.frame(
    post_mean = apply(draws, 2, mean),
    post_sd = apply(draws, 2, sd),
    post_pvalue = apply(draws, 2, opposite_sign_share)
  )
  rownames(out) <- names
  round(out, 4)
}

#' Plot the correlation matrix of a sample as tiles
#'
#' @param sample A list of equal-length parameter vectors, one per draw.
#' @param names Column names for the parameters.
#' @return The plot returned by [ggcorrplot::ggcorrplot()] for the correlation
#'   matrix, rounded to four decimal places.
#' @export
correlation_tile_sample <- function(sample, names) {
  draws <- as.data.frame(do.call(rbind, sample))
  names(draws) <- names
  corr <- cor(draws)
  ggcorrplot::ggcorrplot(round(corr, 4))
}
duncan-clark/Blolog documentation built on June 22, 2022, 7:57 a.m.