R/TMBam.R

Defines functions TMBam

## adapted from jagam which is (c) Simon Wood 2014. Released under GPL2. 


## TMBam: set up an mgcv GAM and write a TMB C++ template for it.
##
## The GAM is set up with mgcv's gam.setup (as mgcv::jagam does). A C++ template
## is then written to "<file>.cpp". Its objective is the negative
## log-likelihood (code supplied by tmbam.lp) plus a REML-style penalty formed
## from the smoothing penalty matrices of each smooth term.
##
## Arguments:
##   formula     mgcv-style GAM formula.
##   family      family object, family function, or its name as a string.
##   data        data frame / list containing the model variables.
##   file        base filename without extension; "<file>.cpp" is overwritten.
##   weights     optional prior weights. When given they are passed to TMB as
##               data `w`; for binomial families they also multiply y.
##   na.action, offset, drop.unused.levels
##               passed on to stats::model.frame via the matched call.
##   knots, sp   passed to mgcv's gam.setup.
##   control     gam.control() list; `keepData` and `scalePenalty` are used.
##   centred     passed to gam.setup as `absorb.cons`.
##   diagonalize passed to gam.setup as `diagonal.penalty`.
##
## Value: a list with
##   pregam   the gam.setup object (plus model frame, terms, family, call),
##   tmb.data data for the template: y, X (smooth columns only), optional
##            offset and w, and the penalty components S1, S2, ...,
##   tmb.ini  starting values: beta, mu (only if there are parametric terms),
##            log_sigma and log_lambda.
##
## Smooths whose penalties are separable/diagonal are not supported yet and
## stop with an error.
TMBam <- function(formula, family=gaussian, data=list(), file, weights=NULL,
                  na.action, offset=NULL, knots=NULL, sp=NULL,
                  drop.unused.levels=TRUE, control=gam.control(), centred=TRUE,
                  diagonalize=FALSE) {
## Notes inherited from jagam (where rho held the log smoothing parameters
## and b the model coefficients):
## diagonalize==TRUE actually seems to be faster for high dimensional terms
## in the Gaussian setting (Conjugate updates better than MH), otherwise
## diagonalize==FALSE faster as block MH is highly advantageous
## WARNING: centred=FALSE is usually a very bad idea!!

  ## `file` has no default: check missing() first, otherwise is.null() itself
  ## fails with R's generic "argument is missing" error
  if (missing(file) || is.null(file)) {
    stop("TMBam requires a base filename model specification")
  }

  # generate C++ filename
  cppfile <- paste0(file, ".cpp")
  # generate R filename (not written by this function yet)
  Rfile <- paste0(file, ".R")

  # append helper for the generated C++ file
  cppcat <- function(...) cat(..., file=cppfile, append=TRUE)

  # start the C++ file (this first write truncates any existing file)
  cat("// code auto-generated by TMBam
#include <TMB.hpp>

template<class Type>
Type objective_function<Type>::operator() ()
{
",file=cppfile)

  ## normalise `family` to a family object
  if (is.character(family))
            family <- eval(parse(text = family))
  if (is.function(family))
            family <- family()
  if (is.null(family$family))
            stop("family not recognized")

  #resp <- all.vars(update(formula, . ~ 1))
  resp <- "y"
  ## C++ snippets for parameters, linear predictor and log-likelihood
  lp_stuff <- tmbam.lp(resp, family, use.weights=FALSE, offset=FALSE)

  gp <- interpret.gam(formula) # interpret the formula
  cl <- match.call() # call needed in gam object for update to work
  mf <- match.call(expand.dots=FALSE)
  mf$formula <- gp$fake.formula
  ## strip the arguments model.frame does not understand
  mf$family <- mf$knots <- mf$sp <- mf$file <- mf$control <-
  mf$centred <- mf$sp.prior <- mf$diagonalize <- NULL
  mf$drop.unused.levels <- drop.unused.levels
  mf[[1]] <- quote(stats::model.frame) ##as.name("model.frame")
  pmf <- mf

  pmf$formula <- gp$pf
  pmf <- eval(pmf, parent.frame()) # pmf contains all data for parametric part
  pterms <- attr(pmf, "terms") ## pmf only used for this
  rm(pmf)

  mf <- eval(mf, parent.frame()) # the model frame now contains all the data
  if (nrow(mf) < 2) stop("Not enough (non-NA) data to do anything meaningful")
  terms <- attr(mf, "terms")

  ## summarize the *raw* input variables
  ## note can't use get_all_vars here -- buggy with matrices
  vars <- all.vars(gp$fake.formula[-2]) ## drop response here
  inp <- parse(text = paste("list(", paste(vars, collapse = ","), ")"))

  ## allow a bit of extra flexibility in what `data' is allowed to be (as model.frame actually does)
  if (!is.list(data) && !is.data.frame(data)) data <- as.data.frame(data)

  dl <- eval(inp, data, parent.frame())
  if (!control$keepData) { rm(data) } ## save space
  names(dl) <- vars ## list of all variables needed
  var.summary <- mgcv:::variable.summary(gp$pf, dl, nrow(mf)) ## summarize the input data
  rm(dl)

  G <- mgcv:::gam.setup(gp, pterms=pterms,
                 data=mf, knots=knots, sp=sp,
                 H=NULL, absorb.cons=centred, sparse.cons=FALSE, #select=TRUE,
                 idLinksBases=TRUE, scale.penalty=control$scalePenalty,
                 diagonal.penalty=diagonalize)
  G$model <- mf
  G$terms <- terms
  G$family <- family
  G$call <- cl
  G$var.summary <- var.summary

  ## the template is built around penalised smooth coefficients; with no
  ## smooths the loop over G$smooth below would fail with an opaque error
  if (length(G$smooth) == 0) {
    stop("TMBam requires at least one smooth term in the formula")
  }

  use.weights <- !is.null(weights)
#use.weights <- mgcv:::write.jagslp("y",family,file,use.weights,!is.null(G$offset))

  ## start the TMB data list...
  #jags.stuff <- list(y=G$y, n=length(G$y), X=G$X)
  tmb.stuff <- list(y=G$y, X=G$X)
  if (!is.null(G$offset)) tmb.stuff$offset <- G$offset
  if (use.weights) tmb.stuff$w <- weights

  ## FIXME: is this right? Carried over from jagam, whose binomial likelihood
  ## wanted success counts rather than observed proportions. jagam always had
  ## weights here; without them, treat every observation as a single trial
  ## instead of multiplying by NULL (which emptied y).
  if (family$family == "binomial") {
    n.trials <- if (use.weights) weights else rep(1, length(G$y))
    tmb.stuff$y <- G$y * n.trials
  }

  ## initial smoothing parameter values and coefficients
  lambda <- mgcv:::initial.spg(G$X, G$y, G$w, family, G$S, G$rank,
                               G$off, offset=G$offset, L=G$L)
  tmb.ini <- list()
  ## one value per penalty component (expanded through G$L when present)
  lam <- if (is.null(G$L)) lambda else G$L %*% lambda
  jin <- mgcv:::jini(G, lam)

  ## split coefficients into parametric (mu) and smooth (beta) parts and
  ## keep only the smooth columns of X so that X conforms with beta
  if (G$nsdf > 0) {
    fixed.idx <- seq_len(G$nsdf)
    tmb.ini$beta <- jin$beta[-fixed.idx]
    tmb.ini$mu <- jin$beta[fixed.idx]
    tmb.stuff$X <- tmb.stuff$X[, -fixed.idx, drop=FALSE]
  } else {
    tmb.ini$beta <- jin$beta
  }

  ## only used by the (not yet supported) separable-penalty branch below
  prior.tau <- signif(0.01 / (abs(jin$beta) + jin$se)^2, 2)

  ## declare the fixed effects. NOTE(review): with no parametric terms, no
  ## `mu` is declared; confirm that tmbam.lp's code does not reference it then
  pars_def <- ""
  if (G$nsdf > 0) {
    if (G$nsdf == 1) {
      pars_def <- "  PARAMETER(mu); // intercept\n"
    } else {
      pars_def <- "  PARAMETER_VECTOR(mu); // fixed effects\n"
    }
  }

  ## Work through smooths.
  ## Each inseparable smooth k gets
  ##  - its penalty components passed in as sparse data S<j>,
  ##  - a coefficient slice beta<k> of the beta parameter vector,
  ##  - a total penalty K<k> = sum_j lambda(j) * S<j>.
  ## lambda is indexed globally (0-based) across all penalty components.
  n.sp <- 0 ## count the smoothing parameters....
  pen_mats <- "  // penalty matrices\n"
  beta_def <- "  // split up the random effects coefs\n"
  c_sp_counter <- 0 ## next (0-based) lambda index / number of S<j> written
  K_defs <- "  // define the real penalties\n"
  K_count <- 1      ## index of the next K<k>/beta<k> pair
  for (i in seq_along(G$smooth)) {
    sm <- G$smooth[[i]]
    ## Are penalties separable...
    seperable <- FALSE
    M <- length(sm$S)
    p <- sm$last.para - sm$first.para + 1 ## number of params
    if (M <= 1) seperable <- TRUE else {
      overlap <- rowSums(sm$S[[1]])
      for (j in 2:M) overlap <- overlap & rowSums(sm$S[[j]])
      if (!sum(overlap)) seperable <- TRUE
    }
    if (seperable) { ## double check that they are diagonal
      if (M > 0) for (j in 1:M) {
        if (max(abs(sm$S[[j]] - diag(diag(sm$S[[j]]), nrow=p))) > 0) seperable <- FALSE
      }
    }
    # diagonalized bits
    if (seperable) {
      stop("blarg Dave didn't sort this yet")
      ## unreachable scaffold kept from jagam for the future implementation
      b0 <- sm$first.para
      if (M == 0) {
        b1 <- sm$last.para
        ptau <- min(prior.tau[b0:b1])
##        cat("  for (i in ",b0,":",b1,") { b[i] ~ dnorm(0,",ptau,") }\n",file=file,append=TRUE,sep="")
      } else for (j in 1:M) {
        D <- diag(sm$S[[j]]) > 0
        b1 <- sum(as.numeric(D)) + b0 - 1
        n.sp <- n.sp + 1
##        cat("  for (i in ",b0,":",b1,") { b[i] ~ dnorm(0, lambda[",n.sp,"]) }\n",file=file,append=TRUE,sep="")
        b0 <- b1 + 1
      }
    # non-diagonalized
    } else { ## inseparable - the penalty matrices are supplied as data
      b0 <- sm$first.para
      Kname <- paste0("K", K_count) ## total penalty matrix; matches REML code
      Sname <- paste0("S", c_sp_counter + 1) ## first penalty component

      pen_mats <- paste0(pen_mats, "  DATA_SPARSE_MATRIX(", Sname, ");\n")
      ## beta excludes the nsdf parametric coefficients and Eigen's segment()
      ## is 0-based, hence the offset from the 1-based first.para
      beta_def <- paste0(beta_def, "  vector<Type> beta", K_count,
                                   " = beta.segment(", b0 - G$nsdf - 1, ",", p, ");\n")

      K_defs <- paste0(K_defs,
                       "  Eigen::SparseMatrix<Type> ", Kname, " = lambda(", c_sp_counter,
                       ")*", Sname)
      c_sp_counter <- c_sp_counter + 1
      tmb.stuff[[Sname]] <- as(sm$S[[1]], "sparseMatrix")

      if (M > 1) { ## remaining components of the total penalty
        for (j in 2:M) {
          Sname <- paste0("S", c_sp_counter + 1)
          ## sparse, like the first component: the data are sparse matrices
          ## and are summed into an Eigen::SparseMatrix
          pen_mats <- paste0(pen_mats, "  DATA_SPARSE_MATRIX(", Sname, ");\n")
          K_defs <- paste0(K_defs,
                           "+\n                    lambda(", c_sp_counter, ")*",
                           Sname)
          c_sp_counter <- c_sp_counter + 1
          tmb.stuff[[Sname]] <- as(sm$S[[j]], "sparseMatrix")
        }
      }
      K_defs <- paste0(K_defs, ";\n")
      K_count <- K_count + 1
      n.sp <- n.sp + M
    }
  } ## smoothing penalties finished

  # FIXME: does this go elsewhere?
  pars_def <- paste0(pars_def, "  PARAMETER_VECTOR(beta);\n")

  sp_scale_defs <- "  // transform smoopars and scale\n"
  pars_def <- paste0(pars_def, "  PARAMETER_VECTOR(log_lambda); // smoopar\n")
  sp_scale_defs <- paste0(sp_scale_defs,
                          "  vector<Type> lambda = exp(log_lambda);\n")
  # scale / family hyperparameters come from tmbam.lp
  #sp_scale_defs <- paste0(sp_scale_defs, "  Type phi = exp(log_phi);\n")
  pars_def <- paste0(pars_def, lp_stuff$hyperpars_pars)
  sp_scale_defs <- paste0(sp_scale_defs, lp_stuff$hyperpars, "\n")

  tmb.ini$log_sigma <- 1
  ## the template indexes lambda once per penalty component, so use the
  ## L-expanded values (identical to `lambda` when G$L is NULL)
  tmb.ini$log_lambda <- log(as.vector(lam))

# FIXME: does this need to be sorted?
# FIXME: need starting values code below!
##!##  ## Write the smoothing parameter prior code, using L if it exists.
##!##  cat("  ## smoothing parameter priors CHECK...\n",file=file,append=TRUE,sep="")
##!##  if (is.null(G$L)) {
##!##    if (sp.prior=="log.uniform") {
##!##      cat("  for (i in 1:",n.sp,") {\n",file=file,append=TRUE,sep="")
##!##      cat("    rho[i] ~ dunif(-12,12)\n",file=file,append=TRUE,sep="")
##!##      cat("    lambda[i] <- exp(rho[i])\n",file=file,append=TRUE,sep="")
##!##      cat("  }\n",file=file,append=TRUE,sep="")
##!##      tmb.ini$rho <- log(lambda)
##!##    } else { ## gamma priors
##!##      cat("  for (i in 1:",n.sp,") {\n",file=file,append=TRUE,sep="")
##!##      cat("    lambda[i] ~ dgamma(.05,.005)\n",file=file,append=TRUE,sep="")
##!##      cat("    rho[i] <- log(lambda[i])\n",file=file,append=TRUE,sep="")
##!##      cat("  }\n",file=file,append=TRUE,sep="")
##!##      tmb.ini$lambda <- lambda
##!##    }
##!##  } else {
##!##    tmb.stuff$L <- G$L
##!##    rho.lo <- FALSE
##!##    if (any(G$lsp0!=0)) {
##!##      tmb.stuff$rho.lo <- G$lsp0
##!##      rho.lo <- TRUE
##!##    }
##!##    nr <- ncol(G$L)
##!##    if (sp.prior=="log.uniform") {
##!##      cat("  for (i in 1:",nr,") { rho0[i] ~ dunif(-12,12) }\n",file=file,append=TRUE,sep="")
##!##      if (rho.lo) cat("  rho <- rho.lo + L %*% rho0\n",file=file,append=TRUE,sep="")
##!##      else cat("  rho <- L %*% rho0\n",file=file,append=TRUE,sep="")
##!##      cat("  for (i in 1:",n.sp,") { lambda[i] <- exp(rho[i]) }\n",file=file,append=TRUE,sep="")
##!##      tmb.ini$rho0 <- log(lambda)
##!##    } else { ## gamma prior
##!##      cat("  for (i in 1:",nr,") {\n",file=file,append=TRUE,sep="")
##!##      cat("    lambda0[i] ~ dgamma(.05,.005)\n",file=file,append=TRUE,sep="")
##!##      cat("    rho0[i] <- log(lambda0[i])\n",file=file,append=TRUE,sep="")
##!##      cat("  }\n",file=file,append=TRUE,sep="")
##!##      if (rho.lo) cat("  rho <- rho.lo + L %*% rho0\n",file=file,append=TRUE,sep="")
##!##      else cat("  rho <- L %*% rho0\n",file=file,append=TRUE,sep="")
##!##      cat("  for (i in 1:",n.sp,") { lambda[i] <- exp(rho[i]) }\n",file=file,append=TRUE,sep="")
##!##      tmb.ini$lambda0 <- lambda
##!##    }
##!##  }
##!##  cat("}",file=file,append=TRUE)

  ## write the body of the template: data, parameters, transforms, penalties
  cppcat("  using namespace density;\n")
  cppcat("  using namespace Eigen;\n")
  cppcat("  using namespace R_inla;\n")
  cppcat("\n")
  cppcat("  DATA_MATRIX(X); // design matrix\n")
  cppcat(paste0("  DATA_VECTOR(", resp, "); // response\n"))
  cppcat(pen_mats)
  cppcat("\n")
  cppcat(pars_def)
  cppcat("\n")
  cppcat(sp_scale_defs)
  cppcat("\n")
  cppcat(beta_def)
  cppcat("\n")
  cppcat(K_defs)
  cppcat("\n")
  cppcat("  // initialize log-likelihood\n  Type nll=0;\n\n")

  # REML
  cppcat("  using namespace atomic; // for logdet\n\n")
  cppcat("  // calculate REML penalty\n")

  ## K1, ..., K<n.K> were defined above (n.K >= 1 since a smooth exists and
  ## separable smooths stop earlier)
  n.K <- K_count - 1
  k.idx <- seq_len(n.K)

  # nll = -0.5*(sum of log determinants) + 0.5*(sum of quadratic forms)
  cppcat("  nll = -0.5*(")
  cppcat(paste0("logdet(matrix<Type>(K", k.idx, "))", collapse=" + "))
  cppcat(") +\n")
  cppcat("    0.5*(")
  cppcat(paste0("GMRF(K", k.idx, ").Quadform(",
                "beta", k.idx, ")",
                collapse=" + "))
  cppcat(");\n\n")

  # calculate linear predictor
  cppcat("  // linear predictor\n")

#TODO: fix weights for the binomial case
  cppcat(lp_stuff$lp)
  cppcat("\n")
  # calculate likelihood
  cppcat(lp_stuff$ll)

  cppcat("\n")
  cppcat("  return nll;\n\n")

  cppcat("}\n")

  G$formula <- formula
  G$rank <- ncol(G$X) ## inherited from jagam: full rank is forced
  list(pregam=G, tmb.data=tmb.stuff, tmb.ini=tmb.ini)
} ## TMBam
dill/mgcvminusminus documentation built on Nov. 14, 2021, 4:13 p.m.