R/plot.ctStanModel.R

Defines functions plot.ctStanModel

Documented in plot.ctStanModel

#' Prior plotting
#'
#' Plots priors for free model parameters in a ctStanModel.
#' 
#' @param x ctStanModel object as generated by \code{\link{ctModel}} with type='stanct' or 'standt'.
#' @param rows vector of integers denoting which rows of ctstanmodel$pars to plot priors for. 
#' Character string 'all' plots all rows with parameters to be estimated.
#' @param wait If true, user is prompted to continue before plotting next graph.
#' @param nsamples Numeric. Higher values increase fidelity (smoothness / accuracy) of density plots, at cost of speed.
#' @param rawpopsd Either 'marginalise' to sample from the specified (in the ctstanmodel) 
#' prior distribution for the raw population standard deviation, or a single numeric value to use for the raw population standard deviation
#' for all subject level prior plots (the red and blue densities).
#' @param inddifdevs numeric vector of length 2, setting the (raw scale) means of the distributions from which the population means
#' used to generate the two individual differences distributions are drawn. Alternatively the character string 'marginalise',
#' in which case a single subject level prior is plotted, obtained by drawing subject parameters around population means sampled
#' from the population mean prior.
#' @param inddifsd numeric, setting the standard deviation of the population means used to generate individual
#' difference distributions. 
#' @param plot If FALSE, outputs a list of ggplot2 objects that can be further modified.
#' @param ... not used.
#' @return If \code{plot=TRUE}, invisibly \code{NULL} (the plots are printed). If \code{plot=FALSE}, a list of ggplot2 objects,
#' with the selected parameters spread across plots at roughly six parameters per plot.
#' @details Plotted in black is the prior for the population mean. In red and blue are the subject level priors that result
#' given that the population mean is estimated as 1 std deviation above the mean of the prior, or 1 std deviation below
#' (with the default \code{inddifdevs=c(-1,1)}). 
#' The distributions around these two points are then obtained by marginalising over the prior for the raw population std deviation - 
#' so the red and blue distributions do not represent any specific subject level prior, but rather characterise the general amount
#' and shape of possible subject level priors at the specific points of the population mean prior.
#' @method plot ctStanModel
#' @export
#' @examples
#' model <- ctModel(type='stanct',
#' manifestNames='sunspots', 
#' latentNames=c('ss_level', 'ss_velocity'),
#' LAMBDA=matrix(c( 1, 'ma1' ), nrow=1, ncol=2),
#' DRIFT=matrix(c(0, 1,   'a21', 'a22'), nrow=2, ncol=2, byrow=TRUE),
#' MANIFESTMEANS=matrix(c('m1'), nrow=1, ncol=1),
#' # MANIFESTVAR=matrix(0, nrow=1, ncol=1),
#' CINT=matrix(c(0, 0), nrow=2, ncol=1),
#' DIFFUSION=matrix(c(
#'   0, 0,
#'   0, "diffusion"), ncol=2, nrow=2, byrow=TRUE))
#'   
#' plot(model,rows=8)

plot.ctStanModel<-function(x,rows='all',wait=FALSE,nsamples=1e6, rawpopsd='marginalise',
  inddifdevs=c(-1,1),inddifsd=.1,plot=TRUE,...){
  if(!inherits(x,'ctStanModel')) stop('not a ctStanModel object!')
  
  x <- ctModelTransformsToNum(x)
  x <- T0VARredundancies(x)
  m<-x$pars
  
  # Validate the distribution control arguments once, up front.
  marginalisepopsd <- isTRUE(rawpopsd[1]=='marginalise')
  if(!marginalisepopsd){
    # as.numeric so that e.g. '0.5' is accepted: rnorm() rejects a character sd.
    rawpopsdnum <- suppressWarnings(as.numeric(rawpopsd))
    if(length(rawpopsdnum) != 1 || is.na(rawpopsdnum)) stop('rawpopsd argument is ill specified!')
  }
  marginalisemeans <- isTRUE(inddifdevs[1]=='marginalise')
  if(!marginalisemeans){
    # Raw scale population means around which the high / low subject level priors are drawn.
    # Only sampled when numeric means are supplied (rnorm cannot take 'marginalise' as a mean).
    highmean=rnorm(nsamples,inddifdevs[2],inddifsd)
    lowmean= rnorm(nsamples,inddifdevs[1],inddifsd)
  }
  
  if(rows[1]=='all') rows<-which(is.na(m$value) & 
      !grepl('[',m$param,fixed=TRUE) &
      !duplicated(m$param))
  nplots<-ceiling(length(rows) /6)
  
  # Map raw samples to the parameter scale using the transform of parameter row 'rowi'.
  tformrow <- function(raw, rowi) tform(raw,m$transform[rowi], m$multiplier[rowi],
    m$meanscale[rowi], m$offset[rowi],m$inneroffset[rowi])
  
  plots <- list()
  for(ploti in seq_len(nplots)){ # seq_len: no plots (instead of iterating 1:0) when nothing is free
    plotrows <- if(length(rows) > 1) rows[as.integer(cut_number(rows,nplots))==ploti] else rows
    datlist <- list()
    for(rowi in plotrows){
      
      #rawpopsd samples
      if(marginalisepopsd){
        rawpopsdbase<-  stats::rnorm(ceiling(nsamples/2))
        rawpopsdbase<-  c(rawpopsdbase,-rawpopsdbase) #symmetry
        if(!is.na(x$rawpopsdbaselowerbound)) rawpopsdbase <- rawpopsdbase[rawpopsdbase>x$rawpopsdbaselowerbound]
        sdscale <- as.numeric(m$sdscale[rowi])
        # translate Stan's elementwise product operator to R before evaluating;
        # NOTE(review): the transform string comes from the model object and presumably refers to rawpopsdbase.
        sdtform <- gsub('.*', '*',x$rawpopsdtransform,fixed=TRUE)
        rawpopsdprior<-eval(parse(text=sdtform)) * sdscale
      } else {
        rawpopsdprior <- rep(rawpopsdnum,nsamples)
      }
      
      #population mean prior
      rawpopmeans=stats::rnorm(length(rawpopsdprior))
      denslist <- list(tformrow(rawpopmeans,rowi))
      leg <- 'Pop. mean prior'
      
      if(m$indvarying[rowi]){
        if(marginalisemeans){
          # subject parameters drawn around the sampled population means
          param=stats::rnorm(length(rawpopsdprior),rawpopmeans,rawpopsdprior)
          denslist[[2]] <- tformrow(param,rowi)
          leg <- c(leg, 'Subject prior')
        } else {
          # sampled in the same order as before (high, then low) to keep seeded results stable
          param=stats::rnorm(length(rawpopsdprior),highmean,rawpopsdprior)
          highdens <- tformrow(param,rowi)
          param=stats::rnorm(length(rawpopsdprior),lowmean,rawpopsdprior)
          lowdens <- tformrow(param,rowi)
          
          # densities ordered to match the labels: inddifdevs[1] (low) first, then inddifdevs[2] (high)
          denslist[2:3] <- list(lowdens, highdens)
          leg <- c(leg, paste0('Subject prior\nrawmean ~ N(',inddifdevs[1],',',inddifsd,')'  ),
            paste0('Subject prior\nrawmean ~ N(',inddifdevs[2],',',inddifsd,')'))
        }
      }
      
      dens <- ctDensityList(denslist,probs=c(.01,.99),plot=FALSE)
      for(i in seq_along(leg)){
        datlist[[length(datlist)+1]] <- data.table(Par.Value=dens$density[[i]]$x,
          Density=dens$density[[i]]$y, type=leg[i],param=m$param[rowi])
      }
    }
    if(length(datlist) == 0) next
    plots<-c(plots,list(.ctPriorDensityPlot(do.call(rbind,datlist))))
  }
  
  if(plot) {
    for(ploti in seq_along(plots)){
      if(wait && ploti > 1) readline("Press [return] for next plot.")
      suppressWarnings(print(plots[[ploti]]))
    }
    return(invisible(NULL))
  } else return(plots)
  
}

# Internal: build the prior density ggplot for one group of parameters.
# Kept as a separate function so that the environment captured by ggplot()/aes()
# holds only 'dat', not the large sample vectors created in plot.ctStanModel.
# dat: data.table with columns Par.Value, Density, type, param.
.ctPriorDensityPlot <- function(dat){
  Par.Value <- type <- Density <- param <- NULL # avoid R CMD check notes; columns of dat take precedence in aes/vars
  ggplot(dat,aes(x=Par.Value,fill=type,ymax=Density,y=Density) )+
    geom_line(alpha=.3) +
    geom_ribbon(alpha=.4,ymin=0) +
    scale_fill_manual(values=c('black','red','blue')) +
    theme_minimal()+
    theme(legend.title = element_blank(),legend.position='top')+
    facet_wrap(vars(param),scales='free')
}
cdriveraus/ctsem documentation built on May 3, 2024, 12:37 p.m.