R/generate_cell_draws_and_summarize.R

Defines functions generate_cell_draws_and_summarize

Documented in generate_cell_draws_and_summarize

#' Generate cell draws and summary rasters from INLA model
#'
#' @description Use INLA posteriors to predict out across a grid
#'
#' @details
#' Based on a fitted INLA model, the survey area defined in an ID raster, and a set of
#' covariates, generate predictive grid cell draws and summary rasters across a study
#' area.
#'
#' The linear predictor is assembled in link space from whichever of the following terms
#' are present in the latent field of the posterior samples: `covariates` (or
#' `(Intercept)` when no covariate term is present), `space`, `nugget`, and `adm_effect`.
#' The inverse link function is then applied to obtain the predictive draws.
#'
#' If the model formula contains an `extraconstr` term, the covariate values are
#' transformed to link space before being multiplied by their coefficients. This requires
#' a known link for the chosen inverse link: one of 'plogis', 'identity', or 'logit'.
#'
#' @param inla_model Output from [fit_inla_model()]
#' @param inla_mesh An SPDE mesh used to define the spatial integration points of the INLA
#'   geostatistical model. Typically created using [INLA::inla.mesh.2d()] or a similar
#'   function.
#' @param n_samples (numeric) Number of posterior predictive samples to draw.
#' @param id_raster ([terra::SpatRaster]) raster showing all cell locations where
#'   predictions should be taken.
#' @param covariates (list) Named list of all covariate effects included in the model,
#'   typically generated by [load_covariates()].
#' @param inverse_link_function (character) If a link function was used in the INLA model,
#'   name of the R function to transform the predictive draws from link space to natural
#'   space. For example, in a logit-linked binomial model, pass 'plogis' (as a string is
#'   fine) to invert-logit the predictive draws.
#' @param nugget_in_predict (`logical(1)`, default TRUE) Should the nugget term be used as
#'   an IID noise term applied to each pixel-draw?
#' @param admin_boundaries ([sf][sf::sf] object, default NULL) The same admin boundaries
#'   used to create the admin-level effect, if one was defined in the model. Only used if
#'   an admin-level effect was defined in the model.
#' @param ui_width (numeric, default 0.95) Size of the uncertainty interval width when
#'   calculating the upper and lower summary rasters
#' @param verbose (`logical(1)`, default TRUE) Log progress for draw generation?
#'
#' @return Named list containing at least the following items:
#'   - "parameter_draws": posterior samples generated from [INLA::inla.posterior.sample()]
#'   - "cell_draws": A matrix of grid cell draws. Each row represents a non-NA pixel in
#'     the `id_raster`, in the same order that would be pulled by [terra::values()], and
#'     each column represents a different posterior draw.
#'   - "cell_pred_mean": Mean predictive estimate by grid cell, formatted as a terra
#'     SpatRaster
#'   - "cell_pred_lower": Lower bound of (X%) uncertainty interval, formatted as a terra
#'     SpatRaster
#'   - "cell_pred_upper": Upper bound of (X%) uncertainty interval, formatted as a terra
#'     SpatRaster
#'
#' @concept prediction
#'
#' @import data.table
#' @importFrom assertthat assert_that
#' @importFrom Matrix rowMeans
#' @importFrom matrixStats rowQuantiles
#' @importFrom purrr map map_dbl
#' @importFrom stats na.omit rnorm
#' @importFrom terra extract values
#' @export
generate_cell_draws_and_summarize <- function(
  inla_model, inla_mesh, n_samples, id_raster, covariates, inverse_link_function,
  nugget_in_predict = TRUE, admin_boundaries = NULL, ui_width = 0.95, verbose = TRUE
){
  if(verbose) logging_start_timer("Generating model predictions")
  # Look up the inverse link by name. `mode = "function"` skips any non-function object
  # that happens to share the same name.
  inverse_link <- get(inverse_link_function, mode = "function")
  # Forward link paired with each supported inverse link. Only needed when covariate
  # values have to be moved into link space (see the covariate section below); NULL when
  # the inverse link is not in this table.
  links <- list(plogis = stats::qlogis, identity = identity, logit = stats::plogis)
  link_fun <- links[[inverse_link_function]]

  # Generate INLA posterior samples
  if(verbose) logging_start_timer("Parameter posterior samples")
  posterior_samples <- INLA::inla.posterior.sample(
    n = n_samples, result = inla_model, add.names = FALSE
  )
  if(verbose) logging_stop_timer()

  # Reorder as a matrix with rows named after the corresponding model terms
  # (one column per posterior sample)
  if(verbose) logging_start_timer("Cell draws")
  latent_matrix <- purrr::map(posterior_samples, 'latent') |> do.call(what = cbind)
  # Latent row names look like "<term>:<index>"; keep only the term
  param_names <- rownames(posterior_samples[[1]]$latent) |>
    strsplit(split = ':') |>
    vapply(`[`, 1, FUN.VALUE = character(1))
  rownames(latent_matrix) <- param_names

  # Create the template for predictions, dimensions = N grid cells (by) N samples
  xy_fields <- c('x','y')
  id_raster_table <- data.table::as.data.table(id_raster, xy = TRUE) |>
    na.omit()
  # Cell-centre coordinates, reused for covariate extraction and the spatial projector
  cell_xy <- as.matrix(id_raster_table[, xy_fields, with = FALSE])
  transformed_cell_draws <- matrix(0, nrow = nrow(id_raster_table), ncol = n_samples)

  ## Optionally add covariate effects
  if('covariates' %in% param_names){
    cov_names <- names(covariates)
    for(cov_name in cov_names){
      id_raster_table[[cov_name]] <- terra::extract(
        x = covariates[[cov_name]],
        y = cell_xy
      )[, 1]
    }
    # If stacking was used, transform to link space
    sum_to_one_constraint <- inla_model$.args$formula |> as.character() |>
      grepl(pattern='extraconstr') |> any()
    if(sum_to_one_constraint){
      if(is.null(link_fun)){
        stop(
          "No link function is known for inverse link '", inverse_link_function,
          "', so covariates cannot be transformed to link space. Supported inverse ",
          "links: ", paste(names(links), collapse = ', '),
          call. = FALSE
        )
      }
      for(cov_name in cov_names) id_raster_table[, (cov_name) := link_fun(get(cov_name)) ]
    }
    # drop = FALSE keeps a (covariates x samples) matrix even with one covariate or draw
    fe_coefficients <- latent_matrix[param_names == 'covariates', , drop = FALSE]
    fe_draws <- as.matrix(id_raster_table[, cov_names, with = FALSE]) %*% fe_coefficients
  } else {
    # Intercept-only model: repeat each draw's intercept down every grid cell
    fe_draws <- latent_matrix[param_names == '(Intercept)', ] |>
      matrix(ncol = n_samples, nrow = nrow(transformed_cell_draws), byrow = TRUE)
  }
  transformed_cell_draws <- transformed_cell_draws + fe_draws

  ## Optionally add spatial GP effect
  if('space' %in% param_names){
    # Projector from mesh vertices to grid cell locations
    A_proj_predictions <- INLA::inla.spde.make.A(
      mesh = inla_mesh,
      loc = cell_xy
    )
    # drop = FALSE so that nrow() is defined even when n_samples == 1
    spatial_mesh_effects <- latent_matrix[param_names == 'space', , drop = FALSE]
    assertthat::assert_that(nrow(spatial_mesh_effects) == ncol(A_proj_predictions))
    space_draws <- as.matrix(A_proj_predictions %*% spatial_mesh_effects)
    # all.equal() returns a character vector on mismatch, so wrap it in isTRUE() to give
    # assert_that() the logical value it requires
    assertthat::assert_that(
      isTRUE(all.equal(dim(transformed_cell_draws), dim(space_draws)))
    )
    transformed_cell_draws <- transformed_cell_draws + space_draws
  }

  # Optionally add nugget effect
  if(('nugget' %in% param_names) && nugget_in_predict){
    # Get draws of nugget precision -> draws of nugget standard deviation
    nugget_precision <- purrr::map(posterior_samples, 'hyperpar') |>
      purrr::map_dbl("Precision for nugget")
    nugget_sigma <- 1 / sqrt(nugget_precision)
    # Generate IID noise for each draw. `sd` is recycled with period n_samples and the
    # matrix is filled by row, so every value in column j is drawn with nugget_sigma[j]
    nugget_draws <- rnorm(length(transformed_cell_draws), mean = 0, sd = nugget_sigma) |>
      matrix(ncol = n_samples, byrow = TRUE)
    # Add to the draws
    transformed_cell_draws <- transformed_cell_draws + nugget_draws
  }

  # Add an admin-level effect, if one was defined
  if('adm_effect' %in% param_names){
    if(is.null(admin_boundaries)){
      stop("Admin effects were defined but admin boundaries were not passed for prediction")
    }
    # drop = FALSE so that nrow() is defined with a single admin unit or a single draw
    admin_effects <- latent_matrix[param_names == 'adm_effect', , drop = FALSE]
    if(nrow(admin_boundaries) != nrow(admin_effects)) stop("Admin effects dimension mismatch")
    # Translate from grid cells to admin units
    admin_boundaries$admin_id <- seq_len(nrow(admin_boundaries))
    grid_cells_to_admin_ids <- id_raster |>
      as.data.frame(xy = TRUE) |>
      sf::st_as_sf(coords = xy_fields, crs = 'EPSG:4326') |>
      sf::st_join(y = admin_boundaries[, c('admin_id')], join = sf::st_nearest_feature)
    admin_effects_by_cell <- admin_effects[
      grid_cells_to_admin_ids$admin_id, , drop = FALSE
    ]
    # Add admin-level effect to the draws
    transformed_cell_draws <- transformed_cell_draws + admin_effects_by_cell
  }

  # Apply the inverse link function to get predictive draws by grid cell
  predictive_draws <- inverse_link(transformed_cell_draws)
  if(verbose) logging_stop_timer()

  ## Summarize as rasters
  if(verbose) logging_start_timer("Summarize draws")
  # Only non-NA cells of the ID raster receive predictions
  to_fill <- which(!is.na(terra::values(id_raster)))
  r_mean <- r_lower <- r_upper <- id_raster
  # Probability mass left in each tail outside the uncertainty interval
  tail_prob <- (1 - ui_width)/2
  terra::values(r_mean)[to_fill] <- Matrix::rowMeans(predictive_draws)
  terra::values(r_lower)[to_fill] <- (
    matrixStats::rowQuantiles(predictive_draws, probs = tail_prob)
  )
  terra::values(r_upper)[to_fill] <- (
    matrixStats::rowQuantiles(predictive_draws, probs = 1 - tail_prob)
  )
  if(verbose) logging_stop_timer() # End summarization

  # Return list of predictions
  predictions_list <- list(
    parameter_draws = posterior_samples,
    cell_draws = predictive_draws,
    cell_pred_mean = r_mean,
    cell_pred_lower = r_lower,
    cell_pred_upper = r_upper
  )
  if(verbose) logging_stop_timer() # End prediction
  return(predictions_list)
}

Try the mbg package in your browser

Any scripts or data that you put into this service are public.

mbg documentation built on April 4, 2025, 2:06 a.m.