R/ctStanFit.R

Defines functions ctStanFit T0VARredundancies ctStanFitUpdate

Documented in ctStanFit ctStanFitUpdate

#' Update a ctStanFit object
#' 
#' Either to include different data, or because you have upgraded ctsem and the internal data structure has changed.
#'
#' @param oldfit fit object to be upgraded
#' @param data replacement long format data object. If left as the default \code{NA}, the data stored in
#' \code{oldfit} is converted back to long format and reused.
#' @param recompile whether to force a recompile -- safer but slower and usually unnecessary.
#' @param refit if TRUE, refits the model using the old estimates as a starting point. Only applicable for
#' optimized fits, not sampling (refit is ignored for sampled fits).
#' @param ... extra arguments to pass to ctStanFit
#'
#' @details Arguments stored in \code{oldfit} that are no longer formal arguments of
#' \code{\link{ctStanFit}} are dropped (with a message). Arguments supplied via \code{...}
#' are always passed on.
#'
#' @return updated ctStanFit object. When refitting, the newly fitted object; otherwise
#' \code{oldfit} with its Stan data and Stan model replaced.
#' @export
#'
#' @examples
#' newfit <- ctStanFitUpdate(ctstantestfit,refit=FALSE)

ctStanFitUpdate <- function(oldfit, data=NA, recompile=FALSE,refit=FALSE,...){
  
  # Sampled fits (non-empty @sim) are never refit -- decide this before messaging
  # so the quick-update message is shown whenever a quick update actually happens.
  if(length(oldfit$stanfit$stanfit@sim) > 0) refit <- FALSE
  if(!refit) message('Trying to do a quick update -- if there are problems, try with refit=TRUE for more robustness')
  
  # Start from the arguments the old fit was created with, overridden by ...
  dots <- list(...)
  fitargs <- as.list(oldfit$args)
  for(n in names(dots)){
    fitargs[[n]] <- dots[[n]]
  }
  
  # Drop stored arguments that ctStanFit no longer accepts. Names supplied via ...
  # in this call are kept, because ctStanFit forwards extra arguments onwards.
  validargs <- setdiff(names(formals(ctStanFit)), '...')
  for(argi in setdiff(names(fitargs), c(validargs, names(dots), ''))){
    message(argi, ' is no longer a valid argument, dropping...')
    fitargs[[argi]] <- NULL
  }
  
  fitargs$fit <- refit
  fitargs$inits <- oldfit$stanfit$rawest
  fitargs$ctstanmodel <- oldfit$ctstanmodelbase
  
  # Use replacement data whenever it was supplied (any data frame / matrix, even
  # with a single column). Otherwise rebuild long format data from the old fit.
  usenewdata <- is.data.frame(data) || is.matrix(data) || length(data) > 1
  if(usenewdata) {
    fitargs$datalong <- data
  } else {
    fitargs$datalong <- standatatolong(oldfit$standata,origstructure = TRUE,ctm=oldfit$ctstanmodelbase)
  }
  
  newfit <- do.call(ctStanFit,fitargs)
  if(refit) return(newfit)
  
  # Quick update: keep the old estimates, refresh the Stan data and model.
  # isTRUE() guards against old fit objects that lack the recompile flag.
  oldfit$standata <- newfit$standata
  if(isTRUE(oldfit$ctstanmodel$recompile) || recompile) {
    oldfit$stanmodel <- rstan::stan_model(model_code = newfit$stanmodeltext)
  } else {
    oldfit$stanmodel <- stanmodels$ctsm
  }
  oldfit
}


T0VARredundancies <- function(ctm) {
  # Fix free T0VAR elements that involve a latent whose T0MEANS varies across
  # subjects: diagonal elements become 1e-3, off-diagonal elements 0, and they
  # stop being free / individually varying parameters. Returns the updated ctm.
  pars <- ctm$pars

  # row indices (latents) of T0MEANS entries flagged as indvarying
  indvaryingLatents <- pars$row[pars$matrix %in% 'T0MEANS' & pars$indvarying]

  # free T0VAR elements whose row or column touches one of those latents
  redundant <- pars$matrix %in% 'T0VAR' &
    is.na(pars$value) &
    (pars$row %in% indvaryingLatents | pars$col %in% indvaryingLatents)

  if (!any(redundant)) return(ctm)

  message('Free T0VAR parameters as well as indvarying T0MEANS -- fixing T0VAR pars to diag matrix of 1e-3')
  onDiagonal <- pars$col == pars$row
  pars$value[redundant & onDiagonal] <- 1e-3
  pars$value[redundant & !onDiagonal] <- 0
  pars$param[redundant] <- NA
  pars$transform[redundant] <- NA
  pars$indvarying[redundant] <- FALSE

  ctm$pars <- pars
  ctm
}



# verbosify<-function(sf,verbose=2){
#   sm <- sf$stanmodel
#   sd <- sf$standata
#   sd$verbose=as.integer(verbose)
#   
#   sfr <- stan_reinitsf(sm,sd)
#   log_prob(sfr,sf$stanfit$rawest)
# }

#' ctStanFit
#'
#' Fits a ctsem model specified via \code{\link{ctModel}} with type either 'stanct' or 'standt'.
#' 
#' @param datalong long format data containing columns for subject id (numeric values, 1 to max subjects), manifest variables, 
#' any time dependent (i.e. varying within subject) predictors, 
#' and any time independent (not varying within subject) predictors.
#' @param ctstanmodel model object as generated by \code{\link{ctModel}} with type='stanct' or 'standt', for continuous or discrete time
#' models respectively.
#' @param stanmodeltext already specified Stan model character string, generally leave NA unless modifying Stan model directly.
#' (Possible after modification of output from fit=FALSE)
#' @param intoverstates logical indicating whether or not to integrate over latent states using a Kalman filter. 
#' Generally recommended to set TRUE unless using non-gaussian measurement model. 
#' @param binomial Deprecated. Logical indicating the use of binary rather than Gaussian data, as with IRT analyses.
#' This now sets \code{intoverstates = FALSE} and the \code{manifesttype} of every indicator to 1, for binary.
#' @param fit If TRUE, fit specified model using Stan, if FALSE, return stan model object without fitting.
#' @param intoverpop if 'auto', set to TRUE if optimizing and any free parameters vary across subjects (indvarying), otherwise FALSE (e.g. when using hmc). 
#' if TRUE, integrates over population distribution of parameters rather than full sampling.
#' Allows for optimization of non-linearities and random effects.
#' @param sameInitialTimes if TRUE, include an empty observation for every subject that has no observation 
#' at the earliest observation time of the dataset. This ensures that the T0MEANS apply to every subject at the same time,
#' rather than just at the earliest observation for that subject. Important when modelling trends over time, age, etc. 
#' @param plot if TRUE, for sampling, a Shiny program is launched upon fitting to interactively plot samples. 
#' May struggle with many (e.g., > 5000) parameters. For optimizing, various optimization details are plotted -- in development.
#' @param derrind deprecated, latents involved in dynamic error calculations are determined automatically now.
#' @param optimize if TRUE, use \code{\link{stanoptimis}} function for maximum a posteriori / importance sampling estimates, 
#' otherwise use the HMC sampler from Stan, which is (much) slower, but generally more robust, accurate, and informative.
#' @param optimcontrol list of parameters sent to \code{\link{stanoptimis}} governing optimization / importance sampling.
#' @param nopriors deprecated, use priors argument. logical. If TRUE, any priors are disabled -- sometimes desirable for optimization. 
#' @param priors if TRUE, priors are included in computations, otherwise specified priors are ignored.
#' @param iter used when \code{optimize=FALSE}. number of iterations, half of which will be devoted to warmup by default when sampling.
#' When optimizing, this is the maximum number of iterations to allow -- convergence hopefully occurs before this!
#' @param inits either character string 'optimize', NULL, or vector of (unconstrained)
#' parameter start values, as returned by the rstan function \code{rstan::unconstrain_pars}, or the parameter values
#' found in a ctsem fit object \code{myfit$stanfit$rawest} (or \code{$rawposterior}) for instance. 
#' @param chains used when \code{optimize=FALSE}. Number of chains to sample, during HMC or post-optimization importance sampling. Unless the cores
#' argument is also set, the number of chains determines the number of cpu cores used, up to 
#' the maximum available minus one. Irrelevant when \code{optimize=TRUE}.
#' @param cores number of cpu cores to use. Either 'maxneeded' to use as many as available minus one,
#' up to the number of chains, or a positive integer. If \code{optimize=TRUE}, more cores are generally faster.
#' @param control Used when \code{optimize=FALSE}. List of arguments sent to \code{\link[rstan]{stan}} control argument, 
#' regarding warmup / sampling behaviour. Unless specified, values used are:
#' list(adapt_delta = .8, adapt_window=5, max_treedepth=10, adapt_init_buffer=2, stepsize = .001, metric='diag_e')
#' @param nlcontrol List of non-linear control parameters. 
#' \code{maxtimestep} must be a positive numeric,  specifying the largest time
#' span covered by the numerical integration. The large default ensures that for each observation time interval, 
#' only a single step of exponential integration is used. When \code{maxtimestep} is smaller than the observation time interval, 
#' the integration is nested within an Euler like loop. 
#' Smaller values may offer greater accuracy, but are slower and not always necessary. Given the exponential integration,
#' linear model elements are fit exactly with only a single step. 
#' @param verbose Integer from 0 to 2. Higher values print more information during model fit -- for debugging.
#' @param stationary Logical. If TRUE, T0VAR and T0MEANS input matrices are ignored, 
#' the parameters are instead fixed to long run expectations. More control over this can be achieved
#' by instead setting parameter names of T0MEANS and T0VAR matrices in the input model to 'stationary', for
#' elements that should be fixed to stationarity. Note: this option is currently unavailable -- either form
#' of the request stops with an error.
#' @param forcerecompile logical. For development purposes. 
#' If TRUE, stan model is recompiled, regardless of apparent need for compilation.
#' @param saveCompile if TRUE and compilation is needed / requested, writes the stan model to
#' the parent frame as ctsem.compiled (unless that object already exists and is not from ctsem), to avoid unnecessary recompilation.
#' @param savescores Logical. If TRUE, output from the Kalman filter is saved in output. For datasets with many variables
#' or time points, will increase file size substantially.
#' @param savesubjectmatrices Logical. If TRUE, subject specific matrices are saved -- 
#' only relevant when either time dependent predictors or individual differences are 
#' used. Can increase memory usage dramatically in large models, and can be computed after fitting using ctExtract
#' or ctStanSubjectPars .
#' @param saveComplexPars Logical. If TRUE, also save rowwise output of any complex parameters specified,
#' i.e. combinations of parameters, functions and states. 
#' @param gendata Logical -- If TRUE, uses provided data for only covariates and a time and missingness structure, and 
#' generates random data according to the specified model / priors. 
#' Generated data is in the $Ygen subobject after running \code{extract} on the fit object.
#' For datasets with many manifest variables or time points, file size may be large.
#' To generate data based on the posterior of a fitted model, see \code{\link{ctStanGenerateFromFit}}.
#' @param vb Logical. Use variational Bayes algorithm from stan? Only kind of working, not recommended.
#' @param ... additional arguments to pass to \code{\link[rstan]{stan}} function.
#' @export
#' @examples
#' \donttest{
#' 
#' #generate a ctStanModel relying heavily on defaults
#' model<-ctModel(type='stanct',
#'   latentNames=c('eta1','eta2'),
#'   manifestNames=c('Y1','Y2'),
#'   MANIFESTVAR=diag(.1,2),
#'   TDpredNames='TD1', 
#'   TIpredNames=c('TI1','TI2','TI3'),
#'   LAMBDA=diag(2)) 
#' 
#' fit<-ctStanFit(ctstantestdat, model,priors=TRUE)
#' 
#' summary(fit) 
#' 
#' plot(fit,wait=FALSE)
#' 
#' #### extended examples
#' 
#' library(ctsem)
#' set.seed(3)
#' 
#' #  Data generation (run this, but no need to understand!) -----------------
#' 
#' Tpoints <- 20
#' nmanifest <- 4
#' nlatent <- 2
#' nsubjects<-20
#' 
#' #random effects
#' age <- rnorm(nsubjects) #standardised
#' cint1<-rnorm(nsubjects,2,.3)+age*.5
#' cint2 <- cint1*.5+rnorm(nsubjects,1,.2)+age*.5
#' tdpredeffect <- rnorm(nsubjects,5,.3)+age*.5
#' 
#' for(i in 1:nsubjects){
#'   #generating model
#'   gm<-ctModel(Tpoints=Tpoints,n.manifest = nmanifest,n.latent = nlatent,n.TDpred = 1,
#'     LAMBDA = matrix(c(1,0,0,0, 0,1,.8,1.3),nrow=nmanifest,ncol=nlatent),
#'     DRIFT=matrix(c(-.3, .2, 0, -.5),nlatent,nlatent),
#'     TDPREDMEANS=matrix(c(rep(0,Tpoints-10),1,rep(0,9)),ncol=1),
#'     TDPREDEFFECT=matrix(c(tdpredeffect[i],0),nrow=nlatent),
#'     DIFFUSION = matrix(c(1, 0, 0, .5),2,2),
#'     CINT = matrix(c(cint1[i],cint2[i]),ncol=1),
#'     T0VAR=diag(2,nlatent,nlatent),
#'     MANIFESTVAR = diag(.5, nmanifest))
#' 
#'   #generate data
#'   newdat <- ctGenerate(ctmodelobj = gm,n.subjects = 1,burnin = 2,
#'     dtmat = rbind(c(rep(.5,8),3,rep(.5,Tpoints-9))))
#'   newdat[,'id'] <- i #set id for each subject
#'   newdat <- cbind(newdat,age[i]) #include time independent predictor
#'   if(i ==1) {
#'     dat <- newdat[1:(Tpoints-10),] #pre intervention data
#'     dat2 <- newdat #including post intervention data
#'   }
#'   if(i > 1) {
#'     dat <- rbind(dat, newdat[1:(Tpoints-10),])
#'     dat2 <- rbind(dat2,newdat)
#'   }
#' }
#' colnames(dat)[ncol(dat)] <- 'age'
#' colnames(dat2)[ncol(dat2)] <- 'age'
#' 
#' 
#' #plot generated data for sanity
#' plot(age)
#' matplot(dat[,gm$manifestNames],type='l',pch=1)
#' plotvar <- 'Y1'
#' plot(dat[dat[,'id']==1,'time'],dat[dat[,'id']==1,plotvar],type='l',
#'   ylim=range(dat[,plotvar],na.rm=TRUE))
#' for(i in 2:nsubjects){
#'   points(dat[dat[,'id']==i,'time'],dat[dat[,'id']==i,plotvar],type='l',col=i)
#' }
#' 
#' 
#' dat2[,gm$manifestNames][sample(1:length(dat2[,gm$manifestNames]),size = 100)] <- NA
#' 
#' 
#' #data structure
#' head(dat2)
#' 
#' 
#' # Model fitting -----------------------------------------------------------
#' 
#' ##simple univariate default model
#' 
#' m <- ctModel(type = 'stanct', manifestNames = c('Y1'), LAMBDA = diag(1))
#' ctModelLatex(m)
#' 
#' #Specify univariate linear growth curve
#' 
#' m1 <- ctModel(type = 'stanct',
#'   manifestNames = c('Y1'), latentNames=c('eta1'),
#'   DRIFT=matrix(-.0001,nrow=1,ncol=1),
#'   DIFFUSION=matrix(0,nrow=1,ncol=1),
#'   T0VAR=matrix(0,nrow=1,ncol=1),
#'   CINT=matrix(c('cint1'),ncol=1),
#'   T0MEANS=matrix(c('t0m1'),ncol=1),
#'   LAMBDA = diag(1),
#'   MANIFESTMEANS=matrix(0,ncol=1),
#'   MANIFESTVAR=matrix(c('merror'),nrow=1,ncol=1))
#' 
#' ctModelLatex(m1)
#' 
#' #fit
#' f1 <- ctStanFit(datalong = dat2, ctstanmodel = m1, optimize=TRUE, priors=FALSE)
#' 
#' summary(f1)
#' 
#' #plots of individual subject models v data
#' ctKalman(f1,plot=TRUE,subjects=1,kalmanvec=c('y','yprior'),timestep=.01)
#' ctKalman(f1,plot=TRUE,subjects=1:3,kalmanvec=c('y','ysmooth'),timestep=.01,errorvec=NA)
#' 
#' ctStanPostPredict(f1, wait=FALSE) #compare randomly generated data from posterior to observed data
#' 
#' cf<-ctCheckFit(f1) #compare mean and covariance of randomly generated data to observed cov
#' plot(cf,wait=FALSE)
#' 
#'  ### Further example models
#' 
#' #Include intervention
#' m2 <- ctModel(type = 'stanct',
#'   manifestNames = c('Y1'), latentNames=c('eta1'),
#'   n.TDpred=1,TDpredNames = 'TD1', #this line includes the intervention
#'   TDPREDEFFECT=matrix(c('tdpredeffect'),nrow=1,ncol=1), #intervention effect
#'   DRIFT=matrix(-1e-5,nrow=1,ncol=1),
#'   DIFFUSION=matrix(0,nrow=1,ncol=1),
#'   CINT=matrix(c('cint1'),ncol=1),
#'   T0MEANS=matrix(c('t0m1'),ncol=1),
#'   T0VAR=matrix(0,nrow=1,ncol=1),
#'   LAMBDA = diag(1),
#'   MANIFESTMEANS=matrix(0,ncol=1),
#'   MANIFESTVAR=matrix(c('merror'),nrow=1,ncol=1))
#' 
#' 
#' 
#' #Individual differences in intervention, Bayesian estimation, covariates
#' m2i <- ctModel(type = 'stanct',
#'   manifestNames = c('Y1'), latentNames=c('eta1'),
#'   TIpredNames = 'age',
#'   TDpredNames = 'TD1', #this line includes the intervention
#'   TDPREDEFFECT=matrix(c('tdpredeffect||TRUE'),nrow=1,ncol=1), #intervention effect
#'   DRIFT=matrix(-1e-5,nrow=1,ncol=1),
#'   DIFFUSION=matrix(0,nrow=1,ncol=1),
#'   CINT=matrix(c('cint1'),ncol=1),
#'   T0MEANS=matrix(c('t0m1'),ncol=1),
#'   T0VAR=matrix(0,nrow=1,ncol=1),
#'   LAMBDA = diag(1),
#'   MANIFESTMEANS=matrix(0,ncol=1),
#'   MANIFESTVAR=matrix(c('merror'),nrow=1,ncol=1))
#'   
#'   
#' #Including covariate effects
#' m2ic <- ctModel(type = 'stanct',
#'   manifestNames = c('Y1'), latentNames=c('eta1'),
#'   n.TIpred = 1, TIpredNames = 'age',
#'   n.TDpred=1,TDpredNames = 'TD1', #this line includes the intervention
#'   TDPREDEFFECT=matrix(c('tdpredeffect'),nrow=1,ncol=1), #intervention effect
#'   DRIFT=matrix(-1e-5,nrow=1,ncol=1),
#'   DIFFUSION=matrix(0,nrow=1,ncol=1),
#'   CINT=matrix(c('cint1'),ncol=1),
#'   T0MEANS=matrix(c('t0m1'),ncol=1),
#'   T0VAR=matrix(0,nrow=1,ncol=1),
#'   LAMBDA = diag(1),
#'   MANIFESTMEANS=matrix(0,ncol=1),
#'   MANIFESTVAR=matrix(c('merror'),nrow=1,ncol=1))
#' 
#' m2ic$pars$indvarying[m2ic$pars$matrix %in% 'TDPREDEFFECT'] <- TRUE
#' 
#' 
#' #Include deterministic dynamics
#' m3 <- ctModel(type = 'stanct',
#'   manifestNames = c('Y1'), latentNames=c('eta1'),
#'   n.TDpred=1,TDpredNames = 'TD1', #this line includes the intervention
#'   TDPREDEFFECT=matrix(c('tdpredeffect'),nrow=1,ncol=1), #intervention effect
#'   DRIFT=matrix('drift11',nrow=1,ncol=1),
#'   DIFFUSION=matrix(0,nrow=1,ncol=1),
#'   CINT=matrix(c('cint1'),ncol=1),
#'   T0MEANS=matrix(c('t0m1'),ncol=1),
#'   T0VAR=matrix('t0var11',nrow=1,ncol=1),
#'   LAMBDA = diag(1),
#'   MANIFESTMEANS=matrix(0,ncol=1),
#'   MANIFESTVAR=matrix(c('merror1'),nrow=1,ncol=1))
#' 
#' 
#' 
#' 
#' 
#' #Add system noise to allow for fluctuations that persist in time
#' m3n <- ctModel(type = 'stanct',
#'   manifestNames = c('Y1'), latentNames=c('eta1'),
#'   n.TDpred=1,TDpredNames = 'TD1', #this line includes the intervention
#'   TDPREDEFFECT=matrix(c('tdpredeffect'),nrow=1,ncol=1), #intervention effect
#'   DRIFT=matrix('drift11',nrow=1,ncol=1),
#'   DIFFUSION=matrix('diffusion',nrow=1,ncol=1),
#'   CINT=matrix(c('cint1'),ncol=1),
#'   T0MEANS=matrix(c('t0m1'),ncol=1),
#'   T0VAR=matrix('t0var11',nrow=1,ncol=1),
#'   LAMBDA = diag(1),
#'   MANIFESTMEANS=matrix(0,ncol=1),
#'   MANIFESTVAR=matrix(c(0),nrow=1,ncol=1))
#' 
#' 
#' 
#' #include 2nd latent process
#' 
#' m4 <- ctModel(n.manifest = 2,n.latent = 2, type = 'stanct',
#'   manifestNames = c('Y1','Y2'), latentNames=c('L1','L2'),
#'   n.TDpred=1,TDpredNames = 'TD1',
#'   TDPREDEFFECT=matrix(c('tdpredeffect1','tdpredeffect2'),nrow=2,ncol=1),
#'   DRIFT=matrix(c('drift11','drift21','drift12','drift22'),nrow=2,ncol=2),
#'   DIFFUSION=matrix(c('diffusion11','diffusion21',0,'diffusion22'),nrow=2,ncol=2),
#'   CINT=matrix(c('cint1','cint2'),nrow=2,ncol=1),
#'   T0MEANS=matrix(c('t0m1','t0m2'),nrow=2,ncol=1),
#'   T0VAR=matrix(c('t0var11','t0var21',0,'t0var22'),nrow=2,ncol=2),
#'   LAMBDA = matrix(c(1,0,0,1),nrow=2,ncol=2),
#'   MANIFESTMEANS=matrix(c(0,0),nrow=2,ncol=1),
#'   MANIFESTVAR=matrix(c('merror1',0,0,'merror2'),nrow=2,ncol=2))
#' 
#' #dynamic factor model -- fixing CINT to 0 and freeing indicator level intercepts
#' 
#' m3df <- ctModel(type = 'stanct',
#'   manifestNames = c('Y2','Y3'), latentNames=c('eta1'),
#'   n.TDpred=1,TDpredNames = 'TD1', #this line includes the intervention
#'   TDPREDEFFECT=matrix(c('tdpredeffect'),nrow=1,ncol=1), #intervention effect
#'   DRIFT=matrix('drift11',nrow=1,ncol=1),
#'   DIFFUSION=matrix('diffusion',nrow=1,ncol=1),
#'   CINT=matrix(c(0),ncol=1),
#'   T0MEANS=matrix(c('t0m1'),ncol=1),
#'   T0VAR=matrix('t0var11',nrow=1,ncol=1),
#'   LAMBDA = matrix(c(1,'Y3loading'),nrow=2,ncol=1),
#'   MANIFESTMEANS=matrix(c('Y2_int','Y3_int'),nrow=2,ncol=1),
#'   MANIFESTVAR=matrix(c('Y2residual',0,0,'Y3residual'),nrow=2,ncol=2))
#' 
#' }

ctStanFit<-function(datalong, ctstanmodel, stanmodeltext=NA, iter=1000, intoverstates=TRUE, binomial=FALSE,
  fit=TRUE, intoverpop='auto', sameInitialTimes=FALSE, stationary=FALSE,plot=FALSE,  derrind=NA,
  optimize=TRUE,  optimcontrol=list(),
  nlcontrol = list(), nopriors=NA, priors=FALSE, chains=2,
  cores=ifelse(optimize,getOption("mc.cores", 2L),'maxneeded'),
  inits=NULL,
  forcerecompile=FALSE,saveCompile=TRUE,savescores=FALSE,
  savesubjectmatrices=FALSE, saveComplexPars=FALSE,
  gendata=FALSE,
  control=list(),verbose=0,vb=FALSE,...){
  
  if(!is.na(nopriors)){
    warning('nopriors argument is deprecated, use priors argument in future')
    priors <- !nopriors
  }
  
  if(any(!is.na(derrind))) warning('derrind argment is deprecated, computed automatically now')
  
  datalong <- data.frame(datalong)

  if(!ctstanmodel$timeName %in% colnames(datalong) && !ctstanmodel$continuoustime) {
    dtable <- data.table(datalong)
    dtable[,.ObsCount:=1:.N,by=ctstanmodel$id]
    datalong[[ctstanmodel$timeName]] <- dtable[['.ObsCount']]
    rm(dtable)
  }
  
  datalong <- datalong[order(datalong[[ctstanmodel$subjectIDname]],datalong[[ctstanmodel$timeName]]),] #sort by subject, time.
  
  datavars <- c(ctstanmodel$timeName,ctstanmodel$subjectIDname, ctstanmodel$manifestNames,ctstanmodel$TDpredNames,ctstanmodel$TIpredNames)
  sapply(datavars,function(x){
    if(!x %in% colnames(datalong)) stop(paste0(x,' column not found in data!'))
    if(!x %in% ctstanmodel$subjectIDname){ #if not an id column
      if(any(!is.numeric(as.numeric(datalong[!is.na(datalong[,x]),x])))) stop(x ,' column contains non-numeric data!')
    }
  })
  
  if(!'ctStanModel' %in% class(ctstanmodel)) stop('not a ctStanModel object')
  
  #set nlcontrol defaults
  if(is.null(nlcontrol$maxtimestep)) nlcontrol$maxtimestep = 999999
  if(is.null(nlcontrol$Jstep)) nlcontrol$Jstep = 1e-6
  
  args=c(as.list(environment()), list(...)) #as.list((match.call(expand.dots=FALSE)))
  args$datalong <- NULL
  args$ctstanmodel <- NULL
  
  ctm <- ctstanmodel
  
  if(!is.null(ctm$TIpredAuto) && ctm$TIpredAuto %in% c(1L,TRUE)){ #if auto tipred, set all effects to true
    for(tip in ctm$TIpredNames){
      ctm$pars[[paste0(tip,'_effect')]] <- TRUE
    }
  }
  
  if(optimize && !priors) message("Maximum likelihood estimation requested")
  if(optimize && priors && (is.null(optimcontrol$is)  || optimcontrol$is %in% FALSE)) message("Maximum a posteriori estimation requested")
  if(optimize && priors && (!is.null(optimcontrol$is)  && optimcontrol$is %in% TRUE)) message("Bayesian estimation via optimization and importance sampling requested")
  if(!optimize) message("Bayesian estimation via Stan's NUTS sampler requested")
  
  
  ###stationarity
  if(stationary) {
    stop('Stationary option temporarily unavailable -- reductions needed to pass all CRAN checks')
    ctm$pars$param[ctm$pars$matrix %in% c('T0VAR','T0MEANS')] <- 'stationary'
    ctm$pars$value[ctm$pars$matrix %in% c('T0VAR','T0MEANS')] <- NA
    ctm$pars$indvarying[ctm$pars$matrix %in% c('T0VAR','T0MEANS')] <- FALSE
  }
  
  #collect individual stationary elements and update ctm$pars
  if(any(ctm$pars$param %in% 'stationary')) stop('Stationary option temporarily unavailable -- reductions needed to pass all CRAN checks')
  ctm$t0varstationary <- as.matrix(rbind(ctm$pars[which(ctm$pars$param %in% 'stationary' & ctm$pars$matrix %in% 'T0VAR'),c('row','col')]))
  if(nrow(ctm$t0varstationary) > 0){ #ensure upper tri is consistent with lower
    for(i in 1:nrow(ctm$t0varstationary)){
      if(ctm$t0varstationary[i,1] != ctm$t0varstationary[i,2]) ctm$t0varstationary <- rbind(ctm$t0varstationary,ctm$t0varstationary[i,c(2,1)])
    }}
  ctm$t0varstationary = unique(ctm$t0varstationary) #remove any duplicated rows
  ctm$t0meansstationary <- as.matrix(rbind(ctm$pars[which(ctm$pars$param[ctm$pars$matrix %in% 'T0MEANS'] %in% 'stationary'),c('row','col')]))
  ctm$pars$value[ctm$pars$param %in% 'stationary'] <- -99 #does this get inserted?
  ctm$pars$indvarying[ctm$pars$param %in% 'stationary'] <- FALSE
  ctm$pars$transform[ctm$pars$param %in% 'stationary'] <- NA
  ctm$pars$param[ctm$pars$param %in% 'stationary'] <- NA
  
  
  if(length(unique(datalong[,ctm$subjectIDname]))==1 && any(ctm$pars$indvarying[is.na(ctm$pars$value)]==TRUE)){
    # is.null(ctm$fixedrawpopmeans) && is.null(ctm$fixedsubpars) & is.null(ctm$forcemultisubject)) {
    ctm$pars$indvarying <- FALSE
    message('Individual variation not possible as only 1 subject! indvarying set to FALSE on all parameters')
  }
  
  if(length(unique(datalong[,ctm$subjectIDname]))==1 & any(is.na(ctm$pars$value[ctm$pars$matrix %in% 'T0VAR']))){ 
    # is.null(ctm$fixedrawpopmeans) & is.null(ctm$fixedsubpars) & is.null(ctm$forcemultisubject)) {
    for(ri in 1:nrow(ctm$pars)){
      if(is.na(ctm$pars$value[ri]) && ctm$pars$matrix[ri] %in% 'T0VAR'){
        ctm$pars$value[ri] <- ifelse(ctm$pars$row[ri] == ctm$pars$col[ri], 1, 0)
      }
    }
    message('Free T0VAR parameters fixed to diagonal matrix of 1 as only 1 subject - consider appropriateness!')
  }
  
  if(binomial){
    message('Binomial argument deprecated -- in future set manifesttype in the model object to 1 for binary indicators')
    intoverstates <- FALSE
    ctm$manifesttype[] <- 1
  }
  
  recompile <- FALSE
  if(!optimize && !priors){
    message('HMC sampling requested, but priors disabled -- are you sure? consider setting priors=TRUE')
    # !priors <- FALSE
  }
  if(optimize && !intoverstates) warning('intoverstates=TRUE required for sensible optimization! Proceed onwards to weird output at own risk!')
  
  if(intoverpop == 'auto')  intoverpop <- 
    ifelse(optimize && any(ctm$pars$indvarying[is.na(ctm$pars$value)]),TRUE,FALSE)
  
  # if(optimize && !intoverpop && any(ctm$pars$indvarying[is.na(ctm$pars$value)]) && 
  #     is.null(ctm$fixedrawpopchol) && is.null(ctm$fixedsubpars)){
  #   intoverpop <- TRUE
  #   message('Setting intoverpop=TRUE to enable optimization of random effects...')
  # }
  
  # if(intoverpop==TRUE && !any(ctm$pars$indvarying[is.na(ctm$pars$value)])) {
  #   # message('No individual variation -- disabling intoverpop switch'); 
  #   intoverpop <- FALSE
  # }
  
  
  ctm <- ctModel0DRIFT(ctm, ctm$continuoustime) #offset 0 drift
  ctm$pars <- ctModelStatesAndPARS(ctm$pars,statenames = ctm$latentNames,tdprednames=ctm$TDpredNames) #replace latent states and PARS with state and PAR[] refs, need this early because we rely on [] detection
  if(intoverpop)   ctm <- ctStanModelIntOverPop(ctm) #extend system matrices for individual differences
  
  #jacobian addition
  ctm$jacobian <- try(ctJacobian(ctm))
  if('try-error' %in% class(ctm$jacobian)) ctm$jacobian <- ctJacobian(ctm,simplify=FALSE)
  # ctm$jacobian <- unfoldmats( #replaces matrix references with base parameter
  #   c(listOfMatrices(ctm$pars),ctm$jacobian))
  ctm$jacobian <- ctm$jacobian[names(ctStanMatricesList()$jacobian)]
  jl <- ctModelUnlist(ctm$jacobian,names(ctm$jacobian))
  jl <- jl[apply(jl,1,function(x) any(!is.na(x))),] #clean up messy leftovers of NA's
  jl2 <- as.data.frame(rbind(data.table(ctm$pars[1,]),data.table(jl),fill=TRUE))[-1,]
  
  
  for(i in 1:nrow(jl2)){ #copy base parameter transforms etc to jacobian when needed
    if(!is.na(jl2$param[i]) && jl2$param[i] %in% ctm$pars$param){
      jl2[i, !colnames(jl2) %in% list('matrix','row','col')] <-
        ctm$pars[which(ctm$pars$param %in% jl2$param[i])[1], !colnames(ctm$pars) %in% list('matrix','row','col')]
    }
  }
  
  ctm$pars <- rbind(ctm$pars,jl2)
  
  ctm$pars <- ctModelStatesAndPARS(ctm$pars,statenames = ctm$latentNames,tdprednames=ctm$TDpredNames) #replace any new state and par refs with square bracket refs
  
  ctm <- ctModelTransformsToNum(ctm)
  
  ctm$pars <- ctStanModelCleanctspec(ctm$pars)
  
  ctm <- T0VARredundancies(ctm)
  
  if(!all(ctm$pars$transform[!is.na(suppressWarnings(as.integer(ctm$pars$transform)))] %in% c(0,1,2,3,4))) stop('Unknown transform specified -- integers should be 0 to 4')
  
  #fix binary manifestvariance
  
  if(any(ctm$manifesttype > 0)){ #if any non continuous variables, (with free parameters)...
    errfix <- which(ctm$pars$matrix %in% 'MANIFESTVAR' & 
        (ctm$pars$row %in% which(ctm$manifesttype==1) | 
            ctm$pars$col %in% which(ctm$manifesttype==1)) &
        is.na(suppressWarnings(as.numeric(
          ctm$pars$value))))
    
    if(length(errfix) > 0){
      message('Fixing any free MANIFESTVAR parameters for binary indicators to deterministic calculation')
      ctm$pars$value[errfix] <- 1e-5
      ctm$pars[errfix,c('param','transform','multiplier','offset','meanscale','inneroffset','sdscale')] <- NA
      ctm$pars$indvarying[errfix] <- FALSE
    }}
  
  ctm$modelmats <- ctStanModelMatrices(ctm)
  ctm <- ctStanCalcsList(ctm,save=saveComplexPars) #get extra calculations and adjust model spec as needed???
  
  #store values in ctm
  ctm$intoverpop <- as.integer(intoverpop)
  ctm$nlatentpop <- as.integer(ifelse(ctm$intoverpop ==1, max(ctm$pars$row[ctm$pars$matrix %in% 'T0MEANS']),  ctm$n.latent))
  ctm$intoverstates <- as.integer(intoverstates)
  ctm$priors <- as.integer(priors)
  ctm$stationary <- as.integer(stationary)
  ctm$nlcontrol <- nlcontrol
  
  
  
  #recompile checks
  if(forcerecompile) recompile <- TRUE
  if(naf(!is.na(ctm$rawpopsdbaselowerbound))) recompile <- TRUE
  if(ctm$rawpopsdbase != 'normal(0,1)') recompile <- TRUE
  if(ctm$rawpopsdtransform != 'log1p_exp(2*rawpopsdbase-1) .* sdscale') ctm$recompile <- TRUE
  if(any(ctm$modelmats$matsetup[,'transform'] < -10)) recompile <- TRUE #if custom transforms needed
  
  ncalcsNoJ<- length(unlist(ctm$modelmats$calcs)[!grepl('JAx[',unlist(ctm$modelmats$calcs),fixed=TRUE)])
  if(ncalcsNoJ > 0) recompile <- TRUE
  
  if(!recompile && 
      length(unlist(ctm$modelmats$calcs)[!grepl('JAx[',unlist(ctm$modelmats$calcs),fixed=TRUE)])>0)  
    message('Finite difference jacobian used to avoid recompiling -- use forcerecompile=TRUE for analytic jacobians')
  
  
  #further model adjustments conditional on recompile
  if(!recompile){ #then use finite diffs for some elements
    
    #collect row and column of complicated jacobian elements into vector
    ctm$JAxfinite <- array(as.integer(unique(
      unlist(ctm$modelmats$matsetup[ctm$modelmats$matsetup$matrix %in% 52 & 
          ctm$modelmats$matsetup$when == -999, 'col']))))# & 
    # ctm$modelmats$matsetup$copyrow < 1,c('row','col')]))))
    
    ctm$Jyfinite <- array(as.integer(unique(
      unlist(ctm$modelmats$matsetup[ctm$modelmats$matsetup$matrix %in% 54 & 
          ctm$modelmats$matsetup$when == -999, 'col']))))
    
  }
  if(recompile) ctm$JAxfinite <- ctm$Jyfinite <- array(as.integer(c()))
  ctm$recompile <- recompile
  
  
  standata <- ctStanData(ctm,datalong,optimize=optimize, sameInitialTimes=sameInitialTimes) 
  standata$verbose=as.integer(verbose)
  standata$savesubjectmatrices=as.integer(savesubjectmatrices)
  standata$gendata=as.integer(gendata)
  
  if(standata$savesubjectmatrices==1L) savescores = TRUE
  standata$savescores=as.integer(savescores)
  
  # print(standata$savesubjectmatrices)
  
  #####post model / data checks
  if(cores=='maxneeded') cores=max(1,min(c(chains,parallel::detectCores()-1)))
  
  if(is.logical(stanmodeltext)) {
    stanmodeltext<- ctStanModelWriter(ctm, gendata, ctm$modelmats$extratforms,ctm$modelmats$matsetup)
  }
  
  
  
  
  if(fit){
    # if(gendata && stanmodels$ctsmgen@model_code != stanmodeltext) recompile <- TRUE
    # if(!gendata && paste0(stanmodels$ctsm@model_code) != paste0(stanmodeltext)) recompile <- TRUE
    
    # STAN_NUM_THREADS <- Sys.getenv('STAN_NUM_THREADS',unset=NA)
    # Sys.setenv(STAN_NUM_THREADS=cores)
    
    if(recompile || forcerecompile) {
      message('Compiling model...') 
      
      #r4.2 / windows / old rstan check
      if(.Platform$OS.type=="windows" && R.version$major %in% 4 && as.numeric(R.version$minor) >= 2 &&
          unlist(utils::packageVersion('rstan'))[2] < 25){
        stop('
*****
Compiling not possible with R version 4.2+ on Windows with Rstan version < 2.26
To upgrade rstan, close and restart all R sessions then run:
              
install.packages("StanHeaders", repos = c("https://mc-stan.org/r-packages/", getOption("repos")))
install.packages("rstan", repos = c("https://mc-stan.org/r-packages/", getOption("repos")))
*****')
        # a=readline('Attempt to upgrade Rstan from Stan repository? y/n')
        # if(a %in% c('yes','y','Y','Yes','YES')){
        #   message(
        #   install.packages("StanHeaders", repos = c("https://mc-stan.org/r-packages/", getOption("repos")))
        #   install.packages("rstan", repos = c("https://mc-stan.org/r-packages/", getOption("repos")))
        # }
      }
      
      sm <- stan_model(model_name='ctsem', model_code = c(stanmodeltext),auto_write=TRUE)
      # ,allow_undefined = TRUE,verbose=TRUE,
      # includes = paste0(
      #   '\n#include "', file.path(getwd(), 'syl2.hpp'),'"',
      #   '\n')
      if(saveCompile){
        if(exists(x = 'ctsem.compiled',envir= parent.frame()) 
          && !'stanmodel' %in% class(get('ctsem.compiled',envir = parent.frame()))){
          warning('ctsem.compiled object already exists, not saving compile')
        } else  assign(x = 'ctsem.compiled',sm,envir = parent.frame())
      }
    }
    if(!recompile && !forcerecompile) {
      if(!gendata) sm <- stanmodels$ctsm else sm <- stanmodels$ctsmgen
    }
    

# configure inits ---------------------------------------------------------
    initOptim <- (length(inits)==1 && (inits=='optimize' || inits=='Optimize' || inits=='optimise' || inits=='Optimise'))
    if(optimize || initOptim){ #then set optimcontrol list
      optimcontrol$cores <- cores
      optimcontrol$verbose <- verbose
      optimcontrol$priors <- as.logical(priors)
      optimcontrol$standata=standata
      optimcontrol$sm=sm
      optimcontrol$init=inits
      optimcontrol$plot=plot
      optimcontrol$matsetup <- data.frame(ctm$modelmats$matsetup)
    }

    if(!optimize && !is.null(inits)){
      if('list' %in% class(inits)){
        staninits=inits
      } else {
        if(initOptim){ #then first optimize to get inits
          if(!intoverpop && length(unique(datalong[[ctm$subjectIDname]]) > 1) && any(ctm$pars$indvarying)) stop('Cannot optimize to get inits unless intoverpop=TRUE')
          optimcontrol$init <- NULL
          optimcontrol$tol=1e-7
          if(!intoverpop & ! intoverstates) stop('Cannot initialize with optimization unless intoverpop and intoverstates are set to TRUE')
          inits <- do.call(stanoptimis,optimcontrol)$rawest
        }
      sf <- stan_reinitsf(sm,standata)
      staninits <- list() 
      if(chains > 1){ #set all chains to same inits
        for(i in 1:chains){
          staninits[[i]]<-constrain_pars(sf,inits+rnorm(length(inits),0,.01))
        }
      }
      }
  }

    
    if(!optimize){
      
      #control arguments for rstan
      # if(is.null(control$adapt_term_buffer)) control$adapt_term_buffer <- min(c(iter/10,max(iter-20,75)))
      if(is.null(control$adapt_delta)) control$adapt_delta <- .8
      if(is.null(control$adapt_window)) control$adapt_window <- 5
      if(is.null(control$max_treedepth)) control$max_treedepth <- 10
      if(is.null(control$adapt_init_buffer)) control$adapt_init_buffer=2
      if(is.null(control$stepsize)) control$stepsize=.001
      if(is.null(control$metric)) control$metric='diag_e'
      
      
      message('Sampling...')
      # 
      stanargs <- list(object = sm, 
        init_r=.03,
        save_warmup=as.logical(plot),
        refresh=20,
        iter=iter,
        data = standata, chains = chains, control=control,
        cores=cores,
        ...) 
      if(!is.null(inits)) stanargs$init=staninits
      
      if(vb){
        if(!intoverpop && standata$nindvarying > 0) warning('Poor results are expected with variational inference and sampling individual differences! Suggest disabling vb or enabling intoverpop.')
        stanfit <- list(stanfit=rstan::vb(object = sm,data=standata, importance_resampling=TRUE,tol_rel_obj=1e-3))
      } else{ #if not vi
        if(plot==TRUE) stanfit <- suppressWarnings(list(stanfit=do.call(stanWplot,stanargs))) else stanfit <- suppressWarnings(list(stanfit=do.call(sampling,stanargs)))
        
        #find the median sample and compute kalman scores etc for this
        # browser()
        e=rstan::extract(stanfit$stanfit)
        # middle <- which(abs(e$ll-quantile(e$ll,.5)) == min(abs(e$ll-quantile(e$ll,.5) )))
        # middle <- which(e$ll==max(e$ll)) 
        stanfit$rawposterior <- t(stan_unconstrainsamples(fit = stanfit$stanfit,standata = standata))
        stanfit$rawest <- apply(stanfit$rawposterior,2,median)
        
      }
    }
    
    if(optimize==TRUE) {
      stanfit <- do.call(stanoptimis,optimcontrol) #eval(parse(text=opcall))
      
      #update data that may have changed during optimization
      for(ni in names(stanfit$standata)){
        standata[[ni]] <- stanfit$standata[[ni]]
      }
      if(ctm$n.TIpred>0){
        ctm$modelmats$TIPREDEFFECTsetup <- stanfit$standata$TIPREDEFFECTsetup
        ms <- ctm$modelmats$matsetup
        ms$tipred <- 0L
        parswithtipreds <- sort(unique(ms$param[ms$param >0 & ms$when %in% c(0,-1) & ms$copyrow < 1]))
        parswithtipreds<-parswithtipreds[apply(stanfit$standata$TIPREDEFFECTsetup,1,sum)>0]
        ms$tipred[ms$param >0 & ms$when %in% c(0,-1) & ms$copyrow < 1 & ms$param %in% parswithtipreds] <- 1L
        ctm$modelmats$matsetup <- ms
      }
    }
    
    # if(is.na(STAN_NUM_THREADS)) Sys.unsetenv('STAN_NUM_THREADS') else Sys.setenv(STAN_NUM_THREADS = STAN_NUM_THREADS) #reset sys env
  } # end if fit==TRUE
  #convert missings back to NA's for data output
  standataout<-standata
  standataout$Y[standataout$Y==99999] <- NA
  standataout$tipreds[standataout$tipreds==99999] <- NA
  # standataout <- utils::relist((standataout),skeleton=standata)
  
  setup=list(recompile=recompile,idmap=standata$idmap,matsetup=ctm$modelmats$matsetup,matvalues=ctm$modelmats$matvalues,
    popsetup=ctm$modelmats$matsetup[ctm$modelmats$matsetup$when %in% c(0,-1) & ctm$modelmats$matsetup$param > 0,],
    popvalues=ctm$modelmats$matvalues[ctm$modelmats$matsetup$when %in% c(0,-1) & ctm$modelmats$matsetup$param > 0,],
    extratforms=ctm$modelmats$extratforms)
  if(fit) {
    stanfit$transformedparsfull <- suppressMessages(stan_constrainsamples(sm = sm,standata = standata,
      savesubjectmatrices = TRUE, samples = matrix(stanfit$rawest,1),cores=1,savescores=TRUE,pcovn=5000))
    
    out <- list(args=args,
      setup=setup,
      stanmodeltext=stanmodeltext, data=standataout, ctdatastruct=datalong[c(1,nrow(datalong)),],standata=standata, 
      ctstanmodelbase=ctstanmodel, ctstanmodel=ctm,stanmodel=sm, stanfit=stanfit)
    class(out) <- 'ctStanFit'
    out$stanfit$kalman<-suppressMessages(ctStanKalman(out,pointest = TRUE))
  }
  
  if(!fit) out=list(args=args,setup=setup,
    stanmodeltext=stanmodeltext,data=standataout,  ctdatastruct=datalong[c(1,nrow(datalong)),],standata=standata, 
    ctstanmodelbase=ctstanmodel,  ctstanmodel=ctm)
  
  
  return(out)
}
cdriveraus/ctsem documentation built on April 18, 2024, 5:24 a.m.