#'Fit Joint Species Distribution Models in \pkg{mvgam}
#'
#'This function sets up a Joint Species Distribution Model whereby the residual associations among
#'species can be modelled in a reduced-rank format using a set of latent factors. The factor
#'specification is extremely flexible, allowing users to include spatial, temporal or any other type
#'of predictor effects to more efficiently capture unmodelled residual associations, while the
#'observation model can also be highly flexible (including all smooth, GP and other effects that
#'\pkg{mvgam} can handle)
#'
#'@inheritParams mvgam
#'@inheritParams ZMVN
#'@param formula A \code{formula} object specifying the GAM observation model formula. These are exactly like the formula
#'for a GLM except that smooth terms, `s()`, `te()`, `ti()`, `t2()`, as well as time-varying
#'`dynamic()` terms, nonparametric `gp()` terms and offsets using `offset()`, can be added to the right hand side
#'to specify that the linear predictor depends on smooth functions of predictors
#'(or linear functionals of these). Details of the formula syntax used by \pkg{mvgam}
#'can be found in \code{\link{mvgam_formulae}}
#'@param factor_formula A \code{formula} object specifying the linear predictor
#'effects for the latent factors. Use `by = trend` within calls to functional terms
#'(i.e. `s()`, `te()`, `ti()`, `t2()`, `dynamic()`, or `gp()`) to ensure that each factor
#'captures a different axis of variation. See the example below as an illustration
#'@param factor_knots An optional \code{list} containing user specified knot values to
#' be used for basis construction of any smooth terms in `factor_formula`.
#'For most bases the user simply supplies the knots to be used, which must match up with the `k` value supplied
#'(note that the number of knots is not always just `k`). Different terms can use different numbers of knots,
#'unless they share a covariate
#'@param data A \code{dataframe} or \code{list} containing the model response variable and covariates
#'required by the GAM \code{formula} and \code{factor_formula} objects
#'@param family \code{family} specifying the observation family for the outcomes. Currently supported
#'families are:
#'\itemize{
#' \item`gaussian()` for real-valued data
#' \item`betar()` for proportional data on `(0,1)`
#' \item`lognormal()` for non-negative real-valued data
#' \item`student_t()` for real-valued data
#' \item`Gamma()` for non-negative real-valued data
#' \item`bernoulli()` for binary data
#' \item`poisson()` for count data
#' \item`nb()` for overdispersed count data
#' \item`binomial()` for count data with imperfect detection when the number of trials is known;
#' note that the `cbind()` function must be used to bind the discrete observations and the discrete number
#' of trials (see the sketch below)
#' \item`beta_binomial()` as for `binomial()` but allows for overdispersion}
#'Default is `poisson()`. See \code{\link{mvgam_families}} for more details
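#'
#'As a minimal sketch of the `cbind()` syntax referenced above (the column names
#'`captures`, `n_traps` and the covariate `temperature` are hypothetical and should be
#'replaced with variables present in `data`):
#'```r
#'# bind the observed successes and the known number of trials on the
#'# left-hand side of the observation formula
#'form <- cbind(captures, n_traps) ~ s(temperature)
#'```
#'The resulting formula can then be supplied as `formula` alongside
#'`family = binomial()` or `family = beta_binomial()`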
#' @param species The unquoted name of the `factor` variable that indexes
#' the different response units in `data` (usually `'species'` in a JSDM).
#' Defaults to `series` to be consistent with other `mvgam` models
#'@param n_lv \code{integer} the number of latent factors to use for modelling
#'residual associations.
#'Cannot be `> n_species`. Defaults arbitrarily to `2`
#'@param threads \code{integer} Experimental option to use multithreading for within-chain
#'parallelisation in \code{Stan}. We recommend its use only if you are experienced with
#'\code{Stan}'s `reduce_sum` function and have a slow running model that cannot be sped
#'up by any other means. Currently works for all families when using \code{Cmdstan}
#'as the backend
#'@param priors An optional \code{data.frame} with prior
#'definitions (in Stan syntax) or, preferentially, a vector containing
#' objects of class `brmsprior` (see \code{\link[brms]{prior}} for details).
#' See [get_mvgam_priors] for more information on changing default prior distributions
#'@param ... Other arguments to pass to [mvgam]
#'@author Nicholas J Clark
#'@details Joint Species Distribution Models allow for responses of multiple species to be
#'learned hierarchically, whereby responses to environmental variables in `formula` can be partially
#'pooled and any latent, unmodelled residual associations can also be learned. In \pkg{mvgam}, both of
#'these effects can be modelled with the full power of latent factor Hierarchical GAMs, providing unmatched
#'flexibility to model full communities of species. When calling [jsdgam], an initial State-Space model using
#'`trend = 'None'` is set up and then modified to include the latent factors and their linear predictors.
#'Consequently, you can inspect priors for these models using [get_mvgam_priors] by supplying the relevant
#'`formula`, `factor_formula`, `data` and `family` arguments and keeping the default `trend = 'None'`.
#'
#' In a JSDGAM, the expectation of response \eqn{Y_{ij}} is modelled with
#'
#' \deqn{g(\mu_{ij}) = X_i\beta + u_i\theta_j,}
#'
#' where \eqn{g(.)} is a known link function,
#' \eqn{X} is a design matrix of linear predictors (with associated \eqn{\beta} coefficients),
#' \eqn{u} are \eqn{n_{lv}}-variate latent factors
#' (\eqn{n_{lv}}<<\eqn{n_{species}}) and
#' \eqn{\theta_j} are species-specific loadings on the latent factors. The design matrix
#' \eqn{X} and \eqn{\beta} coefficients are constructed and modelled using `formula` and can contain
#' any of `mvgam`'s predictor effects, including random intercepts and slopes, multidimensional penalized
#' smooths, GP effects etc... The factor loadings \eqn{\theta_j} are constrained for identifiability but can
#' be used to reconstruct an estimate of the species' residual variance-covariance matrix
#' using \eqn{\Theta \Theta'} (see the example below and [residual_cor()] for details).
#' The latent factors are further modelled using:
#'\deqn{u_i \sim \text{Normal}(Q_i\beta_{factor}, 1)}
#'where the second design matrix \eqn{Q} and associated \eqn{\beta_{factor}} coefficients are
#'constructed and modelled using `factor_formula`. Again, the effects that make up this linear
#'predictor can contain any of `mvgam`'s allowed predictor effects, providing enormous flexibility for
#'modelling species' communities.
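#'
#'As a minimal sketch of the reconstruction described above, assuming `Theta` holds a
#'posterior point estimate of the \eqn{n_{species} \times n_{lv}}{n_species x n_lv}
#'loading matrix (in practice, use [residual_cor()], which summarises the full
#'posterior of these correlations):
#'```r
#'# residual variance-covariance matrix implied by the loadings
#'Sigma_resid <- Theta %*% t(Theta)
#'
#'# corresponding residual correlation matrix
#'R_resid <- cov2cor(Sigma_resid)
#'```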
#'@seealso [mvgam()], [residual_cor()]
#'@references Nicholas J Clark & Konstans Wells (2023). Dynamic generalised additive models (DGAMs) for forecasting discrete ecological time series.
#'Methods in Ecology and Evolution. 14:3, 771-784.
#' \cr
#' \cr
#'David I Warton, F Guillaume Blanchet, Robert B O’Hara, Otso Ovaskainen, Sara Taskinen, Steven C
#'Walker & Francis KC Hui (2015). So many variables: joint modeling in community ecology.
#'Trends in Ecology & Evolution 30:12, 766-779.
#'@return A \code{list} object of class \code{mvgam} containing model output,
#'the text representation of the model file,
#'the mgcv model output (for easily generating simulations at
#'unsampled covariate values), Dunn-Smyth residuals for each species and key information needed
#'for other functions in the package. See \code{\link{mvgam-class}} for details.
#'Use `methods(class = "mvgam")` for an overview on available methods
#'@examples
#'\donttest{
#' # Fit a JSDGAM to the portal_data captures
#' mod <- jsdgam(
#' formula = captures ~
#' # Fixed effects of NDVI and mintemp, row effect as a GP of time
#' ndvi_ma12:series + mintemp:series + gp(time, k = 15),
#' factor_formula = ~ -1,
#' data = portal_data,
#' unit = time,
#' species = series,
#' family = poisson(),
#' n_lv = 2,
#' silent = 2,
#' chains = 2
#' )
#'
#' # Plot covariate effects
#' library(ggplot2); theme_set(theme_bw())
#' plot_predictions(
#' mod,
#' condition = c('ndvi_ma12','series', 'series')
#' )
#'
#' plot_predictions(
#' mod,
#' condition = c('mintemp','series', 'series')
#' )
#'
#' # A residual correlation plot
#' plot(
#' residual_cor(mod)
#' )
#'
#' # An ordination biplot can also be constructed
#' # from the factor scores and their loadings
#' if(requireNamespace('ggrepel', quietly = TRUE)){
#' ordinate(mod, alpha = 0.7)
#' }
#'
#'
#' # A more complicated example showing how to include predictors
#' # in the factor_formula
#'
#' # Simulate latent count data for 500 spatial locations and 10 species
#' set.seed(0)
#' N_points <- 500
#' N_species <- 10
#'
#' # Species-level intercepts (on the log scale)
#' alphas <- runif(N_species, 2, 2.25)
#'
#' # Simulate a covariate and species-level responses to it
#' temperature <- rnorm(N_points)
#' betas <- runif(N_species, -0.5, 0.5)
#'
#' # Simulate points uniformly over a space
#' lon <- runif(N_points, min = 150, max = 155)
#' lat <- runif(N_points, min = -20, max = -19)
#'
#' # Set up spatial basis functions as a tensor product of lat and lon
#' sm <- mgcv::smoothCon(mgcv::te(lon, lat, k = 5),
#' data = data.frame(lon, lat),
#' knots = NULL)[[1]]
#'
#' # The design matrix for this smooth is in the 'X' slot
#' des_mat <- sm$X
#' dim(des_mat)
#'
#' # Function to generate a random covariance matrix where all variables
#' # have unit variance (i.e. diagonals are all 1)
#' random_Sigma = function(N){
#' L_Omega <- matrix(0, N, N);
#' L_Omega[1, 1] <- 1;
#' for (i in 2 : N) {
#' bound <- 1;
#' for (j in 1 : (i - 1)) {
#' L_Omega[i, j] <- runif(1, -sqrt(bound), sqrt(bound));
#' bound <- bound - L_Omega[i, j] ^ 2;
#' }
#' L_Omega[i, i] <- sqrt(bound);
#' }
#' Sigma <- L_Omega %*% t(L_Omega);
#' return(Sigma)
#' }
#'
#' # Simulate a variance-covariance matrix for the correlations among
#' # basis coefficients
#' Sigma <- random_Sigma(N = NCOL(des_mat))
#'
#' # Now simulate the species-level basis coefficients hierarchically, where
#' # spatial basis function correlations are a convex sum of a base correlation
#' # matrix and a species-level correlation matrix
#' basis_coefs <- matrix(NA, nrow = N_species, ncol = NCOL(Sigma))
#' base_field <- mgcv::rmvn(1, mu = rep(0, NCOL(Sigma)), V = Sigma)
#' for(t in 1:N_species){
#' corOmega <- (cov2cor(Sigma) * 0.7) +
#' (0.3 * cov2cor(random_Sigma(N = NCOL(des_mat))))
#' basis_coefs[t, ] <- mgcv::rmvn(1, mu = rep(0, NCOL(Sigma)), V = corOmega)
#' }
#'
#' # Simulate the latent spatial processes
#' st_process <- do.call(rbind, lapply(seq_len(N_species), function(t){
#' data.frame(lat = lat,
#' lon = lon,
#' species = paste0('species_', t),
#' temperature = temperature,
#' process = alphas[t] +
#' betas[t] * temperature +
#' des_mat %*% basis_coefs[t,])
#' }))
#'
#' # Now take noisy observations at some of the points (60)
#' obs_points <- sample(1:N_points, size = 60, replace = FALSE)
#' obs_points <- data.frame(lat = lat[obs_points],
#' lon = lon[obs_points],
#' site = 1:60)
#'
#' # Keep only the process data at these points
#' st_process %>%
#' dplyr::inner_join(obs_points, by = c('lat', 'lon')) %>%
#' # now take noisy Poisson observations of the process
#' dplyr::mutate(count = rpois(NROW(.), lambda = exp(process))) %>%
#' dplyr::mutate(species = factor(species,
#' levels = paste0('species_', 1:N_species))) %>%
#' dplyr::group_by(lat, lon) -> dat
#'
#' # View the count distributions for each species
#' ggplot(dat, aes(x = count)) +
#' geom_histogram() +
#' facet_wrap(~ species, scales = 'free')
#'
#' ggplot(dat, aes(x = lon, y = lat, col = log(count + 1))) +
#' geom_point(size = 2.25) +
#' facet_wrap(~ species, scales = 'free') +
#' scale_color_viridis_c()
#'
#' # Inspect default priors for a joint species model with three spatial factors
#' priors <- get_mvgam_priors(formula = count ~
#' # Environmental model includes random slopes for
#' # a linear effect of temperature
#' s(species, bs = 're', by = temperature),
#'
#' # Each factor estimates a different nonlinear spatial process, using
#' # 'by = trend' as in other mvgam State-Space models
#' factor_formula = ~ gp(lon, lat, k = 6, by = trend) - 1,
#' n_lv = 3,
#'
#' # The data and grouping variables
#' data = dat,
#' unit = site,
#' species = species,
#'
#' # Poisson observations
#' family = poisson())
#' head(priors)
#'
#' # Fit a JSDM that estimates hierarchical temperature responses
#' # and that uses three latent spatial factors
#' mod <- jsdgam(formula = count ~
#' # Environmental model includes random slopes for a
#' # linear effect of temperature
#' s(species, bs = 're', by = temperature),
#'
#' # Each factor estimates a different nonlinear spatial process, using
#' # 'by = trend' as in other mvgam State-Space models
#' factor_formula = ~ gp(lon, lat, k = 6, by = trend) - 1,
#' n_lv = 3,
#'
#' # Change default priors for fixed random effect variances and
#' # factor GP marginal deviations to standard normal
#' priors = c(prior(std_normal(),
#' class = sigma_raw),
#' prior(std_normal(),
#' class = `alpha_gp_trend(lon, lat):trendtrend1`),
#' prior(std_normal(),
#' class = `alpha_gp_trend(lon, lat):trendtrend2`),
#' prior(std_normal(),
#' class = `alpha_gp_trend(lon, lat):trendtrend3`)),
#'
#' # The data and the grouping variables
#' data = dat,
#' unit = site,
#' species = species,
#'
#' # Poisson observations
#' family = poisson(),
#' chains = 2,
#' silent = 2)
#'
#' # Plot the implicit species-level intercept estimates
#' plot_predictions(mod, condition = 'species',
#' type = 'link')
#'
#' # Plot species' hierarchical responses to temperature
#' plot_predictions(mod, condition = c('temperature', 'species', 'species'),
#' type = 'link')
#'
#' # Plot posterior median estimates of the latent spatial factors
#' plot(mod, type = 'smooths', trend_effects = TRUE)
#'
#' # Or using gratia, if you have it installed
#' if(requireNamespace('gratia', quietly = TRUE)){
#' gratia::draw(mod, trend_effects = TRUE, dist = 0)
#' }
#'
#' # Plot species' randomized quantile residual distributions
#' # as a function of latitude
#' pp_check(mod,
#' type = 'resid_ribbon_grouped',
#' group = 'species',
#' x = 'lat',
#' ndraws = 200)
#'
#' # Calculate residual spatial correlations
#' post_cors <- residual_cor(mod)
#' names(post_cors)
#' # Look at lower and upper credible interval estimates for
#' # some of the estimated correlations
#' post_cors$cor[1:5, 1:5]
#' post_cors$cor_upper[1:5, 1:5]
#' post_cors$cor_lower[1:5, 1:5]
#'
#' # Plot of the posterior median correlations for those estimated
#' # to be non-zero
#' plot(post_cors, cluster = TRUE)
#'
#' # An ordination biplot can also be constructed
#' # from the factor scores and their loadings
#' if(requireNamespace('ggrepel', quietly = TRUE)){
#' ordinate(mod)
#' }
#'
#' # Posterior predictive checks and ELPD-LOO can ascertain model fit
#' pp_check(mod,
#' type = "pit_ecdf_grouped",
#' group = "species",
#' ndraws = 200)
#' loo(mod)
#'
#' # Forecast log(counts) for entire region (site value doesn't matter as long
#' # as each spatial location has a different and unique site identifier);
#' # note this calculation takes a few minutes because of the need to calculate
#' # draws from the stochastic latent factors
#' newdata <- st_process %>%
#' dplyr::mutate(species = factor(species,
#' levels = paste0('species_',
#' 1:N_species))) %>%
#' dplyr::group_by(lat, lon) %>%
#' dplyr::mutate(site = dplyr::cur_group_id()) %>%
#' dplyr::ungroup()
#' preds <- predict(mod, newdata = newdata)
#'
#' # Plot the median log(count) predictions on a grid
#' newdata$log_count <- preds[,1]
#' ggplot(newdata, aes(x = lon, y = lat, col = log_count)) +
#' geom_point(size = 1.5) +
#' facet_wrap(~ species, scales = 'free') +
#' scale_color_viridis_c() +
#' theme_classic()
#'
#' \dontshow{
#' # For R CMD check: make sure any open connections are closed afterward
#' closeAllConnections()
#' }
#'}
#'@export
jsdgam = function(
formula,
factor_formula = ~ -1,
knots,
factor_knots,
data,
newdata,
family = poisson(),
unit = time,
species = series,
share_obs_params = FALSE,
priors,
n_lv = 2,
backend = getOption("brms.backend", "cmdstanr"),
algorithm = getOption("brms.algorithm", "sampling"),
control = list(max_treedepth = 10, adapt_delta = 0.8),
chains = 4,
burnin = 500,
samples = 500,
thin = 1,
parallel = TRUE,
threads = 1,
silent = 1,
run_model = TRUE,
return_model_data = FALSE,
residuals = TRUE,
...
) {
#### Validate arguments and initialise the model skeleton ####
validate_pos_integer(n_lv)
# Prep the trend so that the data can be structured in the usual
# mvgam fashion (with 'time' and 'series' variables)
unit <- deparse0(substitute(unit))
subgr <- deparse0(substitute(species))
prepped_trend <- prep_jsdgam_trend(unit = unit, subgr = subgr, data = data)
data_train <- validate_series_time(data = data, trend_model = prepped_trend)
# Set up a simple trend_map to get the model dimensions correct;
# this requires that we only have n_lv trends and that each series
# only maps to one distinct trend, resulting in a loading matrix of
# the correct size (n_series x n_lv)
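# (e.g. with 4 species and n_lv = 2 the map assigns trend = c(1, 2, 1, 2));
# this mapping is only a placeholder, as the Z matrix is removed below and
# replaced by estimated factor loadings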
trend_map <- prep_jsdgam_trendmap(data_train, n_lv)
# Set up the model structure but leave autoformat off so that the
# model file can be easily modified
mod <- suppressWarnings(mvgam(
formula = formula,
trend_formula = factor_formula,
knots = knots,
trend_knots = factor_knots,
family = family,
share_obs_params = share_obs_params,
priors = priors,
trend_model = 'None',
trend_map = trend_map,
data = data_train,
noncentred = TRUE,
run_model = FALSE,
autoformat = FALSE,
backend = backend,
...
))
model_file <- mod$model_file
#### Modify model data and model file ####
# Remove the Z declaration from the Stan data block (the loading matrix
# lv_coefs will be estimated rather than supplied as data)
model_file <- model_file[
-grep(
"matrix[n_series, n_lv] Z; // matrix mapping series to latent states",
model_file,
fixed = TRUE
)
]
# Add M to data block
model_file[grep(
'int<lower=0> n_lv; // number of dynamic factors',
model_file,
fixed = TRUE
)] <- paste0(
'int<lower=0> n_lv; // number of dynamic factors\n',
'int<lower=0> M; // number of nonzero lower-triangular factor loadings'
)
model_file <- readLines(textConnection(model_file), n = -1)
# Update parameters
model_file <- model_file[
-grep("// latent state SD terms", model_file, fixed = TRUE)
]
model_file <- model_file[
-grep("vector<lower=0>[n_lv] sigma;", model_file, fixed = TRUE)
]
model_file[grep(
"matrix[n, n_lv] LV_raw;",
model_file,
fixed = TRUE
)] <- paste0(
"matrix[n, n_lv] LV_raw;\n\n",
"// factor lower triangle loadings\n",
"vector[M] L_lower;\n",
"// factor diagonal loadings\n",
"vector<lower=0>[n_lv] L_diag;"
)
model_file <- readLines(textConnection(model_file), n = -1)
# Update transformed parameters
model_file <- model_file[-grep("// latent states", model_file, fixed = TRUE)]
model_file <- model_file[-grep("lv_coefs = Z;", model_file, fixed = TRUE)]
model_file <- model_file[
-grep("matrix[n, n_lv] LV;", model_file, fixed = TRUE)
]
model_file <- model_file[
-grep("trend_mus = X_trend * b_trend;", model_file, fixed = TRUE)
]
model_file[grep(
"matrix[n_series, n_lv] lv_coefs;",
model_file,
fixed = TRUE
)] <- paste0(
"matrix[n_series, n_lv] lv_coefs = rep_matrix(0, n_series, n_lv);\n",
'matrix[n, n_lv] LV;\n'
)
starts <- grep(
"LV = LV_raw .* rep_matrix(sigma', rows(LV_raw));",
model_file,
fixed = TRUE
)
ends <- starts + 5
model_file <- model_file[-(starts:ends)]
# Simplified latent variable creation if no terms in factor_formula
if (
is.null(rownames(attr(terms.formula(factor_formula), 'factors'))) &
is.null(colnames(attr(terms.formula(factor_formula), 'factors')))
) {
model_file[grep(
"// latent process linear predictors",
model_file,
fixed = TRUE
)] <- paste0(
"// latent process linear predictors\n",
"trend_mus = X_trend * b_trend;\n\n",
"// constraints allow identifiability of loadings\n",
"{\n",
"int idx;\n",
"idx = 0;\n",
"for(j in 1 : n_lv) lv_coefs[j, j] = L_diag[j];\n",
"for(j in 1 : n_lv) {\n",
"for(k in (j + 1) : n_series) {\n",
"idx = idx + 1;\n",
"lv_coefs[k, j] = L_lower[idx];\n",
"}\n",
"}\n",
"}\n\n",
"// raw latent factors\n",
"LV = LV_raw;\n"
)
} else {
model_file[grep(
"// latent process linear predictors",
model_file,
fixed = TRUE
)] <- paste0(
"// latent process linear predictors\n",
"trend_mus = X_trend * b_trend;\n\n",
"// constraints allow identifiability of loadings\n",
"{\n",
"int idx;\n",
"idx = 0;\n",
"for(j in 1 : n_lv) lv_coefs[j, j] = L_diag[j];\n",
"for(j in 1 : n_lv) {\n",
"for(k in (j + 1) : n_series) {\n",
"idx = idx + 1;\n",
"lv_coefs[k, j] = L_lower[idx];\n",
"}\n",
"}\n",
"}\n\n",
"// raw latent factors (with linear predictors)\n",
"for (j in 1 : n_lv) {\n",
"for (i in 1 : n) {\n",
"LV[i, j] = trend_mus[ytimes_trend[i, j]] + LV_raw[i, j];\n",
"}\n}\n"
)
}
model_file <- model_file[
-grep("// derived latent states", model_file, fixed = TRUE)
]
model_file <- readLines(textConnection(model_file), n = -1)
# Update model block
sigma_prior <- grep(
"// priors for latent state SD parameters",
model_file,
fixed = TRUE
) +
1
model_file <- model_file[-sigma_prior]
# Use standard normal priors for the loadings in most models, apart from
# those using an identity link
if (family_links(mod$family) != 'identity') {
model_file[grep(
"// priors for latent state SD parameters",
model_file,
fixed = TRUE
)] <- paste0(
"// priors for factors and loading coefficients\n",
"L_lower ~ std_normal();\n",
"L_diag ~ std_normal();"
)
model_file <- readLines(textConnection(model_file), n = -1)
} else {
model_file[grep(
"// priors for latent state SD parameters",
model_file,
fixed = TRUE
)] <- paste0(
"// priors for factors and loading coefficients\n",
"L_lower ~ student_t(3, 0, 1);\n",
"L_diag ~ student_t(3, 0, 1);"
)
model_file <- readLines(textConnection(model_file), n = -1)
}
# Update generated quantities
model_file[grep(
'matrix[n, n_series] mus;',
model_file,
fixed = TRUE
)] <- paste0(
'matrix[n, n_series] mus;\n',
'vector[n_lv] sigma;'
)
model_file[grep(
"penalty = 1.0 / (sigma .* sigma);",
model_file,
fixed = TRUE
)] <- paste0(
"penalty = rep_vector(1.0, n_lv);\n",
"sigma = rep_vector(1.0, n_lv);"
)
model_file <- readLines(textConnection(model_file), n = -1)
model_file <- sanitise_modelfile(model_file)
# Remove Z from model_data as it is no longer needed
model_data <- mod$model_data
model_data$Z <- NULL
# Add M to model_data
n_series <- NCOL(model_data$ytimes)
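# M is the number of free (strictly lower-triangular) loadings once the n_lv
# diagonal loadings are stored in L_diag:
# sum_{j = 1}^{n_lv} (n_series - j) = n_lv * (n_series - n_lv) + n_lv * (n_lv - 1) / 2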
model_data$M <- n_lv * (n_series - n_lv) + n_lv * (n_lv - 1) / 2
#### Autoformat the Stan code ####
if (requireNamespace('cmdstanr', quietly = TRUE) & backend == 'cmdstanr') {
if (
requireNamespace('cmdstanr') &
cmdstanr::cmdstan_version() >= "2.29.0"
) {
model_file <- .autoformat(
model_file,
overwrite_file = FALSE,
backend = 'cmdstanr',
silent = silent >= 1L
)
}
model_file <- readLines(textConnection(model_file), n = -1)
} else {
model_file <- .autoformat(
model_file,
overwrite_file = FALSE,
backend = 'rstan',
silent = silent >= 1L
)
model_file <- readLines(textConnection(model_file), n = -1)
}
# Remove lp__ from monitor params if VB is to be used
param <- unique(c(mod$monitor_pars, 'Sigma', 'LV'))
if (algorithm %in% c('meanfield', 'fullrank', 'pathfinder', 'laplace')) {
param <- param[!param %in% 'lp__']
}
#### Determine what to return ####
if (!run_model) {
mod$model_file <- model_file
mod$monitor_pars <- param
attr(model_data, 'trend_model') <- 'None'
attr(model_data, 'prepped_trend_model') <- prepped_trend
attr(model_data, 'noncentred') <- NULL
attr(model_data, 'threads') <- threads
mod$model_data <- model_data
out <- mod
} else {
# Check if cmdstan is accessible; if not, use rstan
use_cmdstan <- FALSE
if (backend == 'cmdstanr') {
if (!requireNamespace('cmdstanr', quietly = TRUE)) {
warning('cmdstanr library not found. Defaulting to rstan')
use_cmdstan <- FALSE
} else {
use_cmdstan <- TRUE
if (is.null(cmdstanr::cmdstan_version(error_on_NA = FALSE))) {
warning(
'cmdstanr library found but Cmdstan not found. Defaulting to rstan'
)
use_cmdstan <- FALSE
}
}
}
#### Run the model ####
if (use_cmdstan) {
# Prepare threading and generate the model
cmd_mod <- .model_cmdstanr(model_file, threads = threads, silent = silent)
# Condition the model using Cmdstan
out_gam_mod <- .sample_model_cmdstanr(
model = cmd_mod,
algorithm = algorithm,
prior_simulation = FALSE,
data = model_data,
chains = chains,
parallel = parallel,
silent = silent,
max_treedepth = control$max_treedepth,
adapt_delta = control$adapt_delta,
threads = threads,
burnin = burnin,
samples = samples,
param = param,
save_all_pars = FALSE,
...
)
} else {
# Condition the model using rstan
requireNamespace('rstan', quietly = TRUE)
out_gam_mod <- .sample_model_rstan(
model = model_file,
algorithm = algorithm,
prior_simulation = FALSE,
data = model_data,
chains = chains,
parallel = parallel,
silent = silent,
max_treedepth = control$max_treedepth,
adapt_delta = control$adapt_delta,
threads = threads,
burnin = burnin,
samples = samples,
thin = thin,
...
)
}
# After modeling (add a new class to make predictions and other post-processing
# simpler)
out1 <- mod
out1$model_output <- out_gam_mod
class(out1) <- c('mvgam')
if (residuals) {
mod_residuals <- dsresids_vec(out1)
} else {
mod_residuals <- NULL
}
rm(out1)
# Add the posterior median coefficients to the mgcv objects
ss_gam <- mod$mgcv_model
V <- cov(mcmc_chains(out_gam_mod, 'b'))
ss_gam$Vp <- ss_gam$Vc <- V
p <- mcmc_summary(
out_gam_mod,
'b',
variational = algorithm %in%
c('meanfield', 'fullrank', 'pathfinder', 'laplace')
)[, c(4)]
names(p) <- names(ss_gam$coefficients)
ss_gam$coefficients <- p
trend_mgcv_model <- mod$trend_mgcv_model
V <- cov(mcmc_chains(out_gam_mod, 'b_trend'))
trend_mgcv_model$Vp <- trend_mgcv_model$Vc <- V
p <- mcmc_summary(
out_gam_mod,
'b_trend',
variational = algorithm %in%
c('meanfield', 'fullrank', 'pathfinder', 'laplace')
)[, c(4)]
names(p) <- names(trend_mgcv_model$coefficients)
trend_mgcv_model$coefficients <- p
#### Return the output as class mvgam ####
trim_data <- list()
attr(trim_data, 'threads') <- threads
attr(trim_data, 'noncentred') <- NULL
attr(trim_data, 'trend_model') <- 'None'
attr(trim_data, 'prepped_trend_model') <- prepped_trend
# Extract sampler arguments
dots <- list(...)
if ('adapt_delta' %in% names(dots)) {
message(
'argument "adapt_delta" should be supplied as an element in "control"'
)
adapt_delta <- dots$adapt_delta
dots$adapt_delta <- NULL
} else {
adapt_delta <- control$adapt_delta
if (is.null(adapt_delta)) adapt_delta <- 0.8
}
if ('max_treedepth' %in% names(dots)) {
message(
'argument "max_treedepth" should be supplied as an element in "control"'
)
max_treedepth <- dots$max_treedepth
dots$max_treedepth <- NULL
} else {
max_treedepth <- control$max_treedepth
if (is.null(max_treedepth)) max_treedepth <- 10
}
out <- structure(
list(
call = mod$call,
trend_call = factor_formula,
family = mod$family,
share_obs_params = mod$share_obs_params,
trend_model = 'None',
trend_map = trend_map,
drift = FALSE,
priors = mod$priors,
model_output = out_gam_mod,
model_file = model_file,
model_data = if (return_model_data) {
model_data
} else {
trim_data
},
inits = NULL,
monitor_pars = param,
sp_names = mod$sp_names,
trend_sp_names = mod$trend_sp_names,
mgcv_model = ss_gam,
trend_mgcv_model = trend_mgcv_model,
ytimes = mod$ytimes,
resids = mod_residuals,
use_lv = TRUE,
n_lv = n_lv,
upper_bounds = mod$upper_bounds,
obs_data = mod$obs_data,
test_data = mod$test_data,
fit_engine = 'stan',
backend = backend,
algorithm = algorithm,
max_treedepth = max_treedepth,
adapt_delta = adapt_delta
),
class = c('mvgam', 'jsdgam')
)
}
return(out)
}
#' Prep trend for jsdgam
#' @noRd
prep_jsdgam_trend = function(data, unit, subgr) {
unit <- as_one_character(unit)
subgr <- as_one_character(subgr)
validate_var_exists(
data = data,
variable = unit,
type = 'num/int',
name = 'data',
trend_char = 'ZMVN'
)
validate_var_exists(
data = data,
variable = subgr,
type = 'factor',
name = 'data',
trend_char = 'ZMVN'
)
structure(
list(
trend_model = 'ZMVN',
ma = FALSE,
cor = TRUE,
unit = unit,
gr = "NA",
subgr = subgr,
label = NULL
),
class = 'mvgam_trend'
)
}
#' @noRd
prep_jsdgam_trendmap = function(data, n_lv) {
if (n_lv > nlevels(data$series)) {
stop(
'Number of factors must be <= number of levels in species',
call. = FALSE
)
}
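# Cycle the factor index 1:n_lv across the series levels so that the implied
# loading matrix has the correct (n_series x n_lv) dimensions; the actual
# loadings are estimated in the modified Stan model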
data.frame(
trend = rep(1:n_lv, nlevels(data$series))[1:nlevels(data$series)],
series = factor(levels(data$series), levels = levels(data$series))
)
}