R/jsdm_stancode.R

Defines functions print.jsdmstan_model .modelcode jsdm_stancode

Documented in jsdm_stancode print.jsdmstan_model

#' Make stancode for the jsdm model
#'
#' This function returns the Stan code used to fit the model as specified by the data
#' list, family and method.
#'
#' @details
#'   Environmental covariate effects (\code{"betas"}) can be parameterised in two
#'   ways. With the \code{"cor"} parameterisation all covariate effects are assumed
#'   to be constrained by a correlation matrix between the covariates. With the
#'   \code{"unstruct"} parameterisation all covariate effects are assumed to draw
#'   from a simple distribution with no correlation structure. Both parameterisations
#'   can be modified using the prior object.
#'
#' @param method The method, one of \code{"gllvm"} or \code{"mglmm"}
#' @param family is the response family, must be one of \code{"gaussian"},
#'   \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"},
#'   \code{"bernoulli"}, \code{"zi_poisson"}, or
#'    \code{"zi_neg_binomial"}. Regular expression
#'   matching is supported.
#' @param prior The prior, given as the result of a call to [jsdm_prior()]
#' @param log_lik Whether the log likelihood should be calculated in the generated
#'   quantities (by default \code{TRUE}), required for loo
#' @param site_intercept Whether a site intercept should be included, potential
#'   values \code{"none"} (no site intercept), \code{"grouped"} (a site intercept
#'   with hierarchical grouping) or \code{"ungrouped"} (site intercept with no
#'   grouping)
#' @param beta_param The parameterisation of the environmental covariate effects, by
#'   default \code{"cor"}. See details for further information.
#' @param zi_param For the zero-inflated families, whether the zero-inflation parameter
#'   is a species-specific constant (default, \code{"constant"}), or varies by
#'   environmental covariates (\code{"covariate"}).
#'
#' @return A character vector of Stan code, class "jsdmstan_model"
#' @export
#'
#' @examples
#' jsdm_stancode(family = "gaussian", method = "gllvm")
#' jsdm_stancode(family = "poisson", method = "mglmm")
#'
jsdm_stancode <- function(method, family, prior = jsdm_prior(),
                          log_lik = TRUE, site_intercept = "none",
                          beta_param = "cor", zi_param = "constant") {
  # checks
  family <- match.arg(family, c("gaussian", "bernoulli", "poisson",
                                "neg_binomial","binomial","zi_poisson",
                                "zi_neg_binomial"))
  method <- match.arg(method, c("gllvm", "mglmm"))
  beta_param <- match.arg(beta_param, c("cor","unstruct"))
  site_intercept <- match.arg(site_intercept, c("none","grouped","ungrouped"))
  zi_param <- match.arg(zi_param, c("constant","covariate"))
  if (class(prior)[1] != "jsdmprior") {
    stop("Prior must be given as a jsdmprior object")
  }

  # data processing steps
  scode <- .modelcode(
    method = method, family = family,
    phylo = FALSE, prior = prior, log_lik = log_lik, site_intercept = site_intercept,
    beta_param = beta_param, zi_param = zi_param
  )
  class(scode) <- c("jsdmstan_model", "character")
  return(scode)
}


.modelcode <- function(method, family, phylo, prior, log_lik, site_intercept,
                       beta_param, zi_param) {
  model_functions <- "
"
  data <- paste(
    " int<lower=1> N; // Number of sites
  int<lower=1> S; // Number of species
",
ifelse(method == "gllvm",
      " int<lower=1> D; // Number of latent dimensions", ""
    ),
    "
  int<lower=0> K; // Number of predictor variables
  matrix[N, K] X; // Predictor matrix
",
ifelse(site_intercept == "grouped",
            "
  int<lower=1> ngrp; // Number of groups in site intercept
  int<lower=0, upper = ngrp> grps[N]; // Vector matching sites to groups
 ",""),
  switch(family,
      "gaussian" = "real",
      "bernoulli" = "int<lower=0,upper=1>",
      "neg_binomial" = "int<lower=0>",
      "poisson" = "int<lower=0>",
      "zi_poisson" = "int<lower=0>",
      "zi_neg_binomial" = "int<lower=0>",
      "binomial" = "int<lower=0>"
    ), "Y[N,S]; //Species matrix",
 ifelse(family == "binomial",
        "
  int<lower=0> Ntrials[N]; // Number of trials",""),
  ifelse(grepl("zi_", family),"
  int<lower=0> N_zero[S]; // number of zeros per species
  int<lower=0> N_nonzero[S]; //number of nonzeros per species
  int<lower=0> Sum_nonzero; //Total number of nonzeros across all species
  int<lower=0> Sum_zero; //Total number of zeros across all species
  int<lower=0> Y_nz[Sum_nonzero]; //Y values for nonzeros
  int<lower=0> ss[Sum_nonzero]; //species index for Y_nz
  int<lower=0> nn[Sum_nonzero]; //site index for Y_nz
  int<lower=0> sz[Sum_zero]; //species index for Y_z
  int<lower=0> nz[Sum_zero]; //site index for Y_z",""),
ifelse(grepl("zi_", family) & zi_param == "covariate","
  int<lower=1> zi_k; //number of covariates for env effects on zi
  matrix[N, zi_k] zi_X; //environmental covariate matrix for zi","")
)


  transformed_data <- ifelse(method == "gllvm", "
  // Ensures identifiability of the model - no rotation of factors
  int<lower=1> M;
  M = D * (S - D) + choose(D, 2) + D;
", "")



  site_inter_par <- switch(site_intercept,
                           "ungrouped" = "
  // Site intercepts
  real a_bar;
  real<lower=0> sigma_a;
  vector[N] a;",
  "none" = "",
  "grouped" = "
  // Site intercepts
  real a_bar;
  real<lower=0> sigma_a;
  vector[ngrp] a;")
  species_pars <- switch(beta_param, "cor" = "
  //betas are hierarchical with covariance model
  vector<lower=0>[K] sigmas_preds;
  matrix[K, S] z_preds;
  // covariance matrix on betas by predictors
  corr_matrix[K] cor_preds;", "unstruct" = "
  matrix[K, S] betas;")
  mglmm_spcov_pars <- "
  // species covariances
  vector<lower=0>[S] sigmas_species;
  matrix[S, N] z_species;
  corr_matrix[S] cor_species;"
  gllvm_pars <- "
  // Factor parameters
  vector[M] L; // Non-zero factor loadings
  real<lower=0> sigma_L; // variance of species loadings
  // Latent variables
  matrix[D, N] LV_uncor; // Per-site latent variable"
  var_pars <- switch(family,
    "gaussian" = "
  real<lower=0> sigma[S]; // Gaussian parameters",
    "bernoulli" = "",
    "neg_binomial" = "
  real<lower=0> kappa[S]; // neg_binomial parameters",
    "poisson" = "",
    "zi_poisson" = switch(zi_param,
    "constant" = "
  real<lower=0,upper=1> zi[S]; // zero-inflation parameter",
    "covariate" = "
  matrix[zi_k,S] zi_betas; //environmental effects for zi"),
  "zi_neg_binomial" = switch(zi_param,
                             "constant" = "
  real<lower=0> kappa[S]; // neg_binomial parameters
  real<lower=0,upper=1> zi[S]; // zero-inflation parameter",
                             "covariate" = "
  real<lower=0> kappa[S]; // neg_binomial parameters
  matrix[zi_k,S] zi_betas; //environmental effects for zi")
  )

  pars <- paste(
    site_inter_par, species_pars,
    switch(method,
      "gllvm" = gllvm_pars,
      "mglmm" = mglmm_spcov_pars
    ),
    var_pars
  )
  transformed_pars <- paste(if(beta_param == "cor"){ "
  // covariance matrix on betas by preds
  matrix[K, S] betas;
  "} else {""}, switch(method,
    "gllvm" = "
  // Construct factor loading matrix
  matrix[S, D] Lambda_uncor;
  // Constraints to allow identifiability of loadings
  for (i in 1:(D-1)) {
    for (j in (i+1):(D)){
      Lambda_uncor[i,j] = 0;
    }
  }
  {
    int index;
    index = 0;
    for (j in 1:D) {
      for (i in j:S) {
        index = index + 1;
        Lambda_uncor[i, j] = L[index];
      }
    }
  }
  ",
    "mglmm" = "
  matrix[N, S] u;
  u = (diag_pre_multiply(sigmas_species, cor_species) * z_species)';
  "
  ), if(beta_param == "cor") {"
  betas = diag_pre_multiply(sigmas_preds, cor_preds) * z_preds;
"} else {""})



  gllvm_model <- switch(site_intercept, "ungrouped" = "
  // model
  matrix[N, S] LV_sum = ((Lambda_uncor * sigma_L) * LV_uncor)';
  matrix[N, S] alpha = rep_matrix(a_bar + a * sigma_a, S);
  mu = alpha + (X * betas) + LV_sum;
  ", "none" = "
  // model
  matrix[N, S] LV_sum = ((Lambda_uncor * sigma_L) * LV_uncor)';
  mu = (X * betas) + LV_sum;
  ", "grouped" = "
  matrix[N, S] alpha;
  matrix[N, S] LV_sum = ((Lambda_uncor * sigma_L) * LV_uncor)';
  {
    vector[ngrp] theta = a_bar + a * sigma_a;
    for (n in 1:N){
      alpha[n,] = rep_row_vector(theta[grps[n]],S);
    }
  }
  mu = alpha + (X * betas) + LV_sum;
  ")
  mglmm_model <- switch(site_intercept, "ungrouped" = "
  // model
  matrix[N, S] alpha = rep_matrix(a_bar + a * sigma_a, S);
  mu = alpha + (X * betas) + u;
  ", "none" = "
  // model
  mu = (X * betas) + u;
  ", "grouped" = "
  matrix[N, S] alpha;
  {
    vector[ngrp] theta = a_bar + a * sigma_a;
    for (n in 1:N){
      alpha[n,] = rep_row_vector(theta[grps[n]],S);
    }
  }
  mu = alpha + (X * betas) + u;
  ")
  model <- paste("
  matrix[N,S] mu;
  ", ifelse(grepl("zi_",family),paste0("
  real mu_nz[Sum_nonzero];
  real mu_z[Sum_zero];
  int pos;
  int neg;",switch(zi_param,"constant" = "",
                   "covariate" = "
  real zi_nz[Sum_nonzero];
  real zi_z[Sum_zero];")),""),
  switch(method,
    "gllvm" = gllvm_model,
    "mglmm" = mglmm_model
  ),ifelse(grepl("zi_",family),paste0(ifelse(zi_param == "covariate", "
  matrix[N,S] zi = zi_X * zi_betas;",""),"
  for(i in 1:Sum_nonzero){
    mu_nz[i] = mu[nn[i],ss[i]];",
    switch(zi_param, "constant" = "",
           "covariate" = "
    zi_nz[i] = zi[nn[i],ss[i]];"),"
  }
  for(i in 1:Sum_zero){
    mu_z[i] = mu[nz[i],sz[i]];",
    switch(zi_param, "constant" = "",
           "covariate" = "
    zi_z[i] = zi[nz[i],sz[i]];"),"
  }
  "),""))
  model_priors <- paste(
    ifelse(site_intercept %in% c("ungrouped","grouped"), paste("
  // Site-level intercept priors
  a ~ ", prior[["a"]], ";
  a_bar ~ ", prior[["a_bar"]], ";
  sigma_a ~ ", prior[["sigma_a"]], ";
  "), "
  "), switch(beta_param, "cor" = paste(
  "
  // Species parameter priors
  sigmas_preds ~ ", prior[["sigmas_preds"]], ";
  to_vector(z_preds) ~ ", prior[["z_preds"]], ";
  // covariance matrix priors
  cor_preds ~ ", prior[["cor_preds"]], ";
"), "unstruct" = paste(
  "
  // Species parameter priors
  to_vector(betas) ~ ", prior[["betas"]],";
  ")),
    switch(method,
      "gllvm" = paste("
  // Factor priors
  to_vector(LV_uncor) ~ ", prior[["LV"]], ";
  L ~ ", prior[["L"]], ";
  sigma_L ~ ", prior[["sigma_L"]], "; // Variance of factor loadings
"),
      "mglmm" = paste("
  // Species parameter priors
  sigmas_species ~ ", prior[["sigmas_species"]], ";
  to_vector(z_species) ~ ", prior[["z_species"]], ";
  cor_species ~ ", prior[["cor_species"]], ";
")
    ),
    switch(family,
      "gaussian" = paste("
  //Standard deviation parameters
  sigma ~ ", prior[["sigma"]], ";
"),
      "neg_binomial" = paste("
  //Scale parameter
  kappa ~ ", prior[["kappa"]], ";
"),
      "bern" = "",
      "poisson" = "",
      "binomial" = "",
      "zi_poisson" = switch(zi_param,"constant" = paste("
  //zero-inflation parameter
  zi ~ ", prior[["zi"]], ";
"), "covariate" = paste("
  //zero-inflation parameter
  to_vector(zi_betas) ~ ", prior[["zi_betas"]], ";
")),
"zi_neg_binomial" = switch(zi_param, "constant" = paste("
  //zero-inflation parameter
  zi ~ ", prior[["zi"]], ";
  kappa ~ ", prior[["kappa"]], ";
"), "covariate" = paste("
  //zero-inflation parameter
  to_vector(zi_betas) ~ ", prior[["zi_betas"]], ";
  kappa ~ ", prior[["kappa"]], ";
")
    )
  ))
  model_pt2 <- if(!grepl("zi_", family)){ paste(
    "
  for(i in 1:N) Y[i,] ~ ",
    switch(family,
      "gaussian" = "normal(mu[i,], sigma);",
      "bernoulli" = "bernoulli_logit(mu[i,]);",
      "neg_binomial" = "neg_binomial_2_log(mu[i,], kappa);",
      "poisson" = "poisson_log(mu[i,]);",
      "binomial" = "binomial_logit(Ntrials[i], mu[i,]);"
    )
  )} else{paste("
  pos = 1;
  neg = 1;
  for(s in 1:S){
    target
      += N_zero[s]
           * log_sum_exp(",
           switch(zi_param,"constant" = "log(zi[s]),
                         log1m(zi[s])
                           +",
                  "covariate" = "bernoulli_logit_lpmf(1 | segment(zi_z, neg, N_zero[s])),
                         bernoulli_logit_lpmf(0 | segment(zi_z, neg, N_zero[s]))
                           +"),
  switch(family,
         "zi_poisson" = "poisson_log_lpmf(0 | segment(mu_z, neg, N_zero[s])));",
         "zi_neg_binomial" = "neg_binomial_2_log_lpmf(0 | segment(mu_z, neg, N_zero[s]), kappa[s]));"),"
    target += N_nonzero[s] * ",switch(zi_param,
    "constant" = "log1m(zi[s]);",
    "covariate" = "bernoulli_logit_lpmf(0 | segment(zi_nz, pos, N_nonzero[s]));"),"
    target +=",
    switch(family,
           "zi_poisson" = "poisson_log_lpmf(segment(Y_nz,pos,N_nonzero[s]) |
                                 segment(mu_nz, pos, N_nonzero[s]));",
           "zi_neg_binomial" = "neg_binomial_2_log_lpmf(segment(Y_nz,pos,N_nonzero[s]) |
                                 segment(mu_nz, pos, N_nonzero[s]), kappa[s]);"),"
    pos = pos + N_nonzero[s];
    neg = neg + N_zero[s];
  }
")
  }

  generated_quantities <- paste(
    ifelse(isTRUE(log_lik), "
  // Calculate linear predictor, y_rep, log likelihoods for LOO
  matrix[N, S] log_lik;
  ", ""),
    ifelse(method == "gllvm", "
  // Sign correct factor loadings and factors
  matrix[D, N] LV;
  matrix[S, D] Lambda;
  for(d in 1:D){
    if(Lambda_uncor[d,d] < 0){
      Lambda[,d] = -1 * Lambda_uncor[,d];
      LV[d,] = -1 * LV_uncor[d,];
    } else {
      Lambda[,d] = Lambda_uncor[,d];
      LV[d,] = LV_uncor[d,];
    }
  }", ""), ifelse(isTRUE(log_lik), paste(
      "
  {
    matrix[N, S] linpred;",ifelse(grepl("zi", family) & zi_param == "covariate","
    matrix[N,S] zi = zi_X * zi_betas;",""),
    switch(site_intercept, "ungrouped" = paste("
    linpred = rep_matrix(a_bar + a * sigma_a, S) + (X * betas) +",
    switch(method,
           "gllvm" = "((Lambda_uncor * sigma_L) * LV_uncor)'",
           "mglmm" = "u"
    ), ";
    "), "none" = paste("
    linpred = (X * betas) +",
    switch(method,
           "gllvm" = "((Lambda_uncor * sigma_L) * LV_uncor)'",
           "mglmm" = "u"
      ), ";
    "), "grouped" = paste("
  matrix[N, S] alpha;
  for (n in 1:N){
    alpha[n,] = rep_row_vector(a[grps[n]],S);
  }
  linpred = alpha + (X * betas) +",
  switch(method,
         "gllvm" = "((Lambda_uncor * sigma_L) * LV_uncor)'",
         "mglmm" = "u"
  ), ";
    ")),"
      for(i in 1:N) {
      for(j in 1:S) {",
      switch(family,
        "gaussian" = "log_lik[i, j] = normal_lpdf(Y[i, j] | linpred[i, j], sigma[j]);",
        "bernoulli" = "log_lik[i, j] = bernoulli_logit_lpmf(Y[i, j] | linpred[i, j]);",
        "neg_binomial" = "log_lik[i, j] = neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa[j]);",
        "poisson" = "log_lik[i, j] = poisson_log_lpmf(Y[i, j] | linpred[i, j]);",
        "binomial" = "log_lik[i, j] = binomial_logit_lpmf(Y[i, j] | Ntrials[i], linpred[i, j]);",
        "zi_poisson" = switch(zi_param,"constant" = "if (Y[i,j] == 0){
          log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]),
                          bernoulli_lpmf(0 |zi[j])
                          + poisson_log_lpmf(Y[i,j] | linpred[i,j]));
            } else {
              log_lik[i, j] = bernoulli_lpmf(0 | zi[j])
              + poisson_log_lpmf(Y[i,j] | linpred[i,j]);
            }",
                              "covariate" = "if (Y[i,j] == 0){
          log_lik[i, j] = log_sum_exp(bernoulli_logit_lpmf(1 | zi[i,j]),
                          bernoulli_logit_lpmf(0 |zi[i,j])
                          + poisson_log_lpmf(Y[i,j] | linpred[i,j]));
            } else {
              log_lik[i, j] = bernoulli_logit_lpmf(0 | zi[i,j])
              + poisson_log_lpmf(Y[i,j] | linpred[i,j]);
            }"),
        "zi_neg_binomial" = switch(zi_param,"constant" = "if (Y[i,j] == 0){
          log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]),
                          bernoulli_lpmf(0 |zi[j])
                          + neg_binomial_2_log_lpmf(Y[i,j] | linpred[i,j], kappa[j]));
            } else {
              log_lik[i, j] = bernoulli_lpmf(0 | zi[j])
              + neg_binomial_2_log_lpmf(Y[i,j] | linpred[i,j], kappa[j]);
            }",
                                   "covariate" = "if (Y[i,j] == 0){
          log_lik[i, j] = log_sum_exp(bernoulli_logit_lpmf(1 | zi[i,j]),
                          bernoulli_logit_lpmf(0 |zi[i,j])
                          + poisson_log_lpmf(Y[i,j] | linpred[i,j]));
            } else {
              log_lik[i, j] = bernoulli_logit_lpmf(0 | zi[i,j])
              + poisson_log_lpmf(Y[i,j] | linpred[i,j]);
            }")
      ),"
      }
    }
  }
  "
    ), "")
  )

  res <- paste(
    "//Generated by jsdmstan\n",
    "functions{\n",
    model_functions,
    "\n}\ndata{\n",
    data,
    "\n}\ntransformed data{\n",
    transformed_data,
    "\n}\nparameters{\n",
    pars,
    "\n}\ntransformed parameters{\n",
    transformed_pars,
    "\n}\nmodel{\n",
    model, "\n", model_priors, "\n", model_pt2, "\n",
    "\n}\ngenerated quantities{\n",
    generated_quantities,
    "\n}\n\n"
  )

  return(res)
}

#' @export
#' @describeIn jsdm_stancode A printing function for jsdmstan_model objects
#' @param x The jsdm_stancode object
#' @param ... Currently unused
print.jsdmstan_model <- function(x, ...) {
  cat(x)
}
fseaton/jsdmstan documentation built on Sept. 29, 2024, 6:40 p.m.