R/ctStanPlotPost.R

Defines functions ctStanPlotPost

Documented in ctStanPlotPost

#' ctStanPlotPost
#'
#' Plots prior and posterior distributions of model parameters in a ctStanModel or ctStanFit object.
#' 
#' @param obj fit or model object as generated by \code{\link{ctStanFit}},
#' \code{\link{ctModel}}, or \code{\link{ctStanModel}}.
#' @param rows vector of integers denoting which rows of the parameter table to plot. 
#' The table is built from obj$setup$popsetup and obj$setup$popvalues, then filtered,
#' de-duplicated by parameter number and sorted by parameter number, so indices refer
#' to that reduced table rather than to the raw popsetup.
#' Character string 'all' plots all rows of the reduced table. 
#' @param npp Integer number of parameters to show per page. Parameters are spread
#' evenly over \code{ceiling(length(rows) / npp)} pages.
#' @param priorwidth if TRUE, plots will be scaled to show bulk of both the prior 
#' and posterior distributions. If FALSE, scale is based only on the posterior.
#' @param smoothness Positive numeric -- multiplier to modify smoothness of density plots, higher is smoother but
#' can cause plots to exceed natural boundaries, such as standard deviations below zero.
#' Note: this argument is accepted for compatibility but is not currently referenced 
#' in the function body.
#' @param priorsamples number of samples from prior to use. More is slower.
#' @param wait If true, user is prompted to continue before plotting next graph.  
#' If false, graphs are plotted one after another without waiting.
#' @param plot Logical, if FALSE, ggplot objects are returned in a list instead of plotting.
#' @param ... Additional arguments forwarded to \code{ctStanGenerate} when sampling 
#' from the prior. \code{cores = x} will speed things up,
#' where x is the number of cpu cores to use.
#' @return If \code{plot = TRUE}, the plots are printed and \code{NULL} is returned 
#' invisibly. Otherwise a list of ggplot objects, one per page.
#' @examples
#' \donttest{
#' ctStanPlotPost(ctstantestfit, rows=3:4)
#' }
#' @export

ctStanPlotPost <- function(obj, rows = 'all', npp = 6, priorwidth = TRUE, 
  smoothness = 1, priorsamples = 10000,
  plot = TRUE, wait = FALSE, ...) {
  
  # inherits() copes with objects carrying more than one class, where
  # `class(obj) %in% ...` would hand a length > 1 condition to if().
  if (!inherits(obj, c('ctStanFit', 'ctStanModel'))) stop('not a ctStanFit or ctStanModel object!')
  if (!is.numeric(npp) || length(npp) != 1 || is.na(npp) || npp < 1) {
    stop('npp must be a single number >= 1')
  }
  
  # Parameter table: keep rows with when in {0, -1}, a positive parameter number,
  # no copyrow reference and matrix index below 11; one row per parameter number,
  # ordered by parameter number.
  ps <- cbind(obj$setup$popsetup, obj$setup$popvalues)
  ps <- ps[ps$when %in% c(0, -1) & ps$param > 0 & ps$copyrow < 1 & ps$matrix < 11, ]
  ps <- ps[!duplicated(ps$param), ]
  ps <- ps[order(ps$param), ]
  
  # Prior samples come from a parameters-only generation run; posterior samples
  # are extracted once from the supplied object.
  # NOTE(review): ctExtract() is applied even when obj is a ctStanModel -- confirm
  # that this is supported for model (non-fit) objects.
  priors <- ctStanGenerate(cts = obj, parsonly = TRUE, nsamples = priorsamples, ...)
  priors <- priors$stanfit$transformedpars
  posteriors <- ctExtract(obj)
  
  # isTRUE() guards against a zero-length `rows`, where the comparison is NA.
  if (isTRUE(rows[1] == 'all')) rows <- seq_len(nrow(ps))
  nplots <- ceiling(length(rows) / npp)
  
  # Never executed; only declares the names used inside aes()/vars() as local
  # bindings so static code checks do not flag them as undefined globals.
  if (1 == 99) Par.Value <- type <- quantity <- Density <- NULL
  
  quantity <- c('Posterior', 'Prior')
  if (priorwidth) xlimsindex <- 'all' else xlimsindex <- 1
  
  # Converts the densities returned by ctDensityList (ordered posterior, prior)
  # into a list of long-format tables labelled with plot type and parameter name.
  densitytables <- function(dens, type, pname) {
    lapply(seq_along(dens$density), function(i) {
      data.table(quantity = quantity[i], Par.Value = dens$density[[i]]$x,
        Density = dens$density[[i]]$y, type = type, param = pname)
    })
  }
  
  # Page membership of each requested row, computed once. cut_number() splits the
  # row indices into nplots groups of roughly equal size; a single page needs no split.
  pageof <- if (nplots > 1) as.integer(cut_number(rows, nplots)) else rep(1L, length(rows))
  
  plots <- vector('list', nplots)
  for (ploti in seq_len(nplots)) {
    # Seed with an empty, correctly typed table so the combined result always
    # has the expected columns.
    pieces <- list(data.table(quantity = character(), Par.Value = numeric(),
      Density = numeric(), type = character(), param = character()))
    
    for (ri in rows[pageof == ploti]) {
      pname <- ps$parname[ri]
      pari <- ps$param[ri]
      
      mdens <- ctDensityList(list(posteriors$popmeans[, pari], priors$popmeans[, pari]),
        probs = c(.05, .95), plot = FALSE,
        xlimsindex = xlimsindex, cut = TRUE)
      pieces <- c(pieces, densitytables(mdens, 'Pop. Mean', pname))
      
      if (ps$indvarying[ri] > 0) { # individually varying: also plot the population sd
        sdi <- ps$indvarying[ri]
        sddens <- ctDensityList(list(posteriors$popsd[, sdi], priors$popsd[, sdi]),
          probs = c(.05, .95), plot = FALSE,
          xlimsindex = xlimsindex)
        pieces <- c(pieces, densitytables(sddens, 'Pop. SD', pname))
      }
    }
    dat <- do.call(rbind, pieces)
    
    plots[[ploti]] <- ggplot(dat, aes(x = Par.Value, fill = quantity, ymax = Density, y = Density)) +
      geom_line(alpha = .3) +
      geom_ribbon(alpha = .4, ymin = 0) +
      scale_fill_manual(values = c('red', 'blue')) +
      theme_minimal() +
      theme(legend.title = element_blank(),
        panel.grid.minor = element_line(size = 0.1), panel.grid.major = element_line(size = .2),
        strip.text.x = element_text(margin = margin(.01, 0, .01, 0, "cm"))) +
      facet_wrap(vars(type, param), scales = 'free')
  }
  
  if (plot) {
    for (i in seq_along(plots)) {
      # Only pause between pages, never before the first one.
      if (wait && i > 1) readline("Press [return] for next plot.")
      suppressWarnings(print(plots[[i]]))
    }
    return(invisible(NULL))
  }
  plots
}

Try the ctsem package in your browser

Any scripts or data that you put into this service are public.

ctsem documentation built on Nov. 2, 2023, 6:03 p.m.