# R/predict.R

# Defines functions add_attributes_predict get_post_pred_link check_newdata

#' @include occumb.R
NULL

#' @title Predict method for occumbFit class.
#' @description Obtain predictions of parameters related to species occupancy
#'  and detection from an \code{occumbFit} model object.
#' @param object An \code{occumbFit} object.
#' @param newdata An optional \code{occumbData} object with covariates to be
#'  used for prediction. If omitted or \code{NULL}, the fitted covariates are
#'  used.
#' @param parameter The parameter to be predicted.
#' @param scale The scale on which the prediction is made.
#'  \code{scale = "response"} returns the prediction on the original scale of
#'  the parameter. \code{scale = "link"} returns the prediction on the link
#'  scale of the parameter (log for \code{phi}; logit for \code{theta} and
#'  \code{psi}).
#' @param type The type of prediction.
#'  \code{type = "quantiles"} returns 50% quantile as the posterior median of
#'  the prediction in addition to 2.5 and 97.5% quantiles as the lower and
#'  upper limits of the 95% credible interval of the prediction.
#'  \code{type = "mean"} returns the posterior mean of the prediction.
#'  \code{type = "samples"} returns the posterior samples of the prediction.
#' @return
#'  Predictions are obtained as a matrix or array that can have dimensions
#'  corresponding to statistics (or samples), species, sites, and replicates.
#'  The \code{dimension} and \code{label} attributes are added to the output
#'  object to inform these dimensions.
#'  If the sequence read count data \code{y} has species, site, or replicate
#'  names appended as the \code{dimnames} attribute (see Details in
#'  \code{\link{occumbData}()}), they will be copied into the \code{label}
#'  attribute of the returned object.
#' @details
#'  Applying \code{predict()} to an \code{occumbFit} object generates predictions
#'  for the specified parameter (\code{phi}, \code{theta}, or \code{psi}) based
#'  on the estimated effects and the given covariates. It is important to
#'  recognize that the predictions are specific to the individual species being
#'  modeled since they depend on the estimated species-specific effects (i.e.,
#'  \code{alpha}, \code{beta}, and \code{gamma}; see
#'  \href{https://fukayak.github.io/occumb/articles/model_specification.html}{the package vignette} for details).
#'  When \code{newdata} is provided, it must therefore contain the same set of
#'  species as the fitted data, in the same order: species are matched by
#'  position, only the number of species is checked, and a warning is issued
#'  if the species names differ (the fitted species names are then used in the
#'  \code{label} attribute).
#' @export
setMethod("predict", signature(object = "occumbFit"),
  function(object,
           newdata = NULL,
           parameter = c("phi", "theta", "psi"),
           scale = c("response", "link"),
           type = c("quantiles", "mean", "samples")) {

    parameter <- match.arg(parameter)
    scale     <- match.arg(scale)
    type      <- match.arg(type)

    # An explicit `newdata = NULL` behaves exactly like omitting it.
    if (is.null(newdata)) {
      data <- object@data
    } else {
      check_newdata(newdata, object)
      data <- newdata
      fitted_species <- rownames(object@data@y)
      if (!identical(rownames(data@y), fitted_species)) {
        # `rownames<-` rebuilds the complete dimnames list, so this also works
        # when newdata@y carries no dimnames at all or when the fitted species
        # names are NULL (subassigning dimnames(x)[[1]] errors in both cases).
        rownames(data@y) <- fitted_species
      }
    }

    # Inverse of the link function: log link for phi, logit for theta/psi.
    inv_link <- switch(parameter,
                       phi   = exp,
                       theta = stats::plogis,
                       psi   = stats::plogis)

    # Posterior samples on the link scale; the first dimension indexes samples.
    post_pred_link <- get_post_pred_link(object, data, parameter)
    margins <- seq_along(dim(post_pred_link))[-1]

    if (type == "quantiles") {

      out <- apply(post_pred_link, margins,
                   stats::quantile, probs = c(0.5, 0.025, 0.975))

      # Quantiles commute with the monotone inverse link, so transforming the
      # summaries is equivalent to summarising transformed samples.
      if (scale == "response") {
        out <- inv_link(out)
      }

    } else if (type == "mean") {

      # The mean does not commute with a nonlinear link: transform first.
      draws <- if (scale == "response") inv_link(post_pred_link) else post_pred_link
      out <- apply(draws, margins, mean)

      # Prepend a length-one "Statistics" dimension so the layout matches the
      # other types (apply() returns a plain vector for matrix input).
      out <- array(out, c(1, if (is.null(dim(out))) length(out) else dim(out)))

    } else if (type == "samples") {

      out <- post_pred_link

      if (scale == "response") {
        out <- inv_link(out)
      }

    }

    return(add_attributes_predict(out, parameter, scale, type, data))
  }
)


check_newdata <- function(newdata, object) {

  # Check that `newdata` is compatible with the data `object` was fitted to.
  # Stops at the first incompatibility (number of species, covariate names,
  # covariate classes, discrete covariate values, factor level order) and
  # warns when only the species names differ. Called for its side effects.

  fitted <- object@data

  ## Local helpers

  # Covariate names of a list, or "(None)" when the list carries no names.
  cov_names <- function(covs) {
    if (is.null(names(covs))) "(None)" else names(covs)
  }

  # One class string per covariate, named by covariate. Replicate covariates
  # are flattened with c() so only their element type is compared. Objects
  # with several classes (e.g. ordered factors) are collapsed into a single
  # string so the fitted and new vectors always stay aligned.
  cov_classes <- function(data) {
    one_class <- function(x) paste(class(x), collapse = "/")
    c(vapply(data@spec_cov, one_class, character(1)),
      vapply(data@site_cov, one_class, character(1)),
      vapply(data@repl_cov, function(x) one_class(c(x)), character(1)))
  }

  # For each covariate in `covs_fit` satisfying `keep()`, store `extract()` of
  # it and of its counterpart in `covs_new` (matched by position; names were
  # verified identical beforehand) in `acc$fit` / `acc$new`, keyed by the
  # covariate name. `[[<-` works for non-syntactic names as well.
  collect <- function(acc, covs_fit, covs_new, keep, extract) {
    for (i in seq_along(covs_new)) {
      if (keep(covs_fit[[i]])) {
        nm <- names(covs_fit)[i]
        acc$fit[[nm]] <- extract(covs_fit[[i]])
        acc$new[[nm]] <- extract(covs_new[[i]])
      }
    }
    acc
  }

  # Collapse each element of a list into a single comma-separated string.
  collapse_each <- function(x) vapply(x, paste, character(1), collapse = ", ")

  # Header followed by one "\n  label: new (newdata)<sep> fit (fitted data)"
  # line per mismatching item.
  mismatch_message <- function(header, labels, new_vals, fit_vals, sep) {
    paste(c(header,
            sprintf(paste0("\n  %s: %s (newdata)", sep, " %s (fitted data)"),
                    labels, new_vals, fit_vals)),
          collapse = "")
  }

  levels_header <- "The levels of discrete covariates in 'newdata' must match those in the fitted data."

  ## Stop if the number of species does not match
  n_new <- dim(newdata@y)[[1]]
  n_fit <- dim(fitted@y)[[1]]
  if (!identical(n_new, n_fit)) {
    stop(sprintf("The number of species in 'newdata' (%s) differs from that in the fitted data (%s).",
                 n_new, n_fit))
  }

  ## Stop if covariate names and their order do not match
  name_fit <- list(spec_cov = cov_names(fitted@spec_cov),
                   site_cov = cov_names(fitted@site_cov),
                   repl_cov = cov_names(fitted@repl_cov))
  name_new <- list(spec_cov = cov_names(newdata@spec_cov),
                   site_cov = cov_names(newdata@site_cov),
                   repl_cov = cov_names(newdata@repl_cov))
  name_mismatch <- vapply(names(name_fit),
                          function(s) !identical(name_fit[[s]], name_new[[s]]),
                          logical(1))
  if (any(name_mismatch)) {
    stop(mismatch_message(
      "The names of the covariates in 'newdata' and their order must match those in the fitted data.",
      names(name_fit)[name_mismatch],
      collapse_each(name_new)[name_mismatch],
      collapse_each(name_fit)[name_mismatch],
      ";"))
  }

  ## Stop if covariate classes do not match
  class_fit <- cov_classes(fitted)
  class_new <- cov_classes(newdata)
  class_mismatch <- class_fit != class_new
  if (any(class_mismatch)) {
    stop(mismatch_message(
      "The covariate classes in 'newdata' must match those in the fitted data.",
      names(class_fit)[class_mismatch],
      class_new[class_mismatch],
      class_fit[class_mismatch],
      ","))
  }

  ## Stop if the set of values of a discrete covariate differs
  ## (order of appearance is irrelevant; level order is checked below)
  is_discrete    <- function(x) inherits(x, c("character", "factor"))
  is_char_matrix <- function(x) inherits(c(x), "character")
  values <- list(fit = list(), new = list())
  values <- collect(values, fitted@spec_cov, newdata@spec_cov, is_discrete, unique)
  values <- collect(values, fitted@site_cov, newdata@site_cov, is_discrete, unique)
  # unique() on a matrix returns unique rows; flatten to get unique values.
  values <- collect(values, fitted@repl_cov, newdata@repl_cov, is_char_matrix,
                    function(x) unique(c(x)))
  value_mismatch <- vapply(seq_along(values$fit),
                           function(i) !setequal(values$fit[[i]], values$new[[i]]),
                           logical(1))
  if (any(value_mismatch)) {
    stop(mismatch_message(levels_header,
                          names(values$fit)[value_mismatch],
                          collapse_each(values$new)[value_mismatch],
                          collapse_each(values$fit)[value_mismatch],
                          ";"))
  }

  ## Stop if an order of factor levels does not match
  is_factor <- function(x) inherits(x, "factor")
  lvls <- list(fit = list(), new = list())
  lvls <- collect(lvls, fitted@spec_cov, newdata@spec_cov, is_factor, levels)
  lvls <- collect(lvls, fitted@site_cov, newdata@site_cov, is_factor, levels)
  order_mismatch <- vapply(seq_along(lvls$fit),
                           function(i) !identical(lvls$fit[[i]], lvls$new[[i]]),
                           logical(1))
  if (any(order_mismatch)) {
    stop(mismatch_message(levels_header,
                          names(lvls$fit)[order_mismatch],
                          collapse_each(lvls$new)[order_mismatch],
                          collapse_each(lvls$fit)[order_mismatch],
                          ";"))
  }

  ## Warn if the list of species names does not match
  if (!identical(dimnames(newdata@y)[[1]], dimnames(fitted@y)[[1]])) {
    warning("The list of species names in 'newdata' does not match that in the fitted data; the list of species names in the fitted data will be added to the 'label' attribute of the returned object.")
  }
}


get_post_pred_link <- function(object, data, parameter) {

  # Posterior samples of the linear predictor (link scale) of `parameter`
  # for every species (and site / replicate, depending on the covariate
  # type reported by set_covariates()). The result is an N x I matrix
  # (type "i") or an N x I x J / N x I x J x K array (types "ij" / "ijk"),
  # where N is the number of posterior samples.

  # Linear predictor for one cell. A dimensionless effect paired with a
  # single covariate value (intercept-only or single shared covariate) is
  # scaled elementwise; any other combination is a matrix product
  # flattened to a plain vector.
  linpred <- function(effect_samples, covariates) {
    if (!is.null(dim(effect_samples)) || length(covariates) != 1) {
      return(c(effect_samples %*% covariates))
    }
    effect_samples * covariates
  }

  occ_args <- object@occumb_args
  spec <- switch(parameter,
                 phi   = list(main          = occ_args$formula_phi,
                              shared        = occ_args$formula_phi_shared,
                              effect        = "alpha",
                              effect_shared = "alpha_shared"),
                 theta = list(main          = occ_args$formula_theta,
                              shared        = occ_args$formula_theta_shared,
                              effect        = "beta",
                              effect_shared = "beta_shared"),
                 psi   = list(main          = occ_args$formula_psi,
                              shared        = occ_args$formula_psi_shared,
                              effect        = "gamma",
                              effect_shared = "gamma_shared"))

  fml_main   <- formula(spec$main)
  fml_shared <- formula(spec$shared)

  list_cov   <- set_covariates(data, fml_main, fml_shared, parameter)
  has_shared <- !is.null(list_cov$cov_shared)

  y_dim <- dim(data@y)
  I <- y_dim[1]
  J <- y_dim[2]
  K <- y_dim[3]
  N <- object@fit$mcmc.info$n.samples

  post_effect <- get_post_samples(object, spec$effect)
  post_shared <- if (has_shared) get_post_samples(object, spec$effect_shared)

  # Each cell is the species-specific term plus, when a shared effect
  # exists, the shared term for the same cell.
  if (list_cov$type == "i") {
    pred_link <- matrix(nrow = N, ncol = I)
    for (i in seq_len(I)) {
      lp <- linpred(post_effect[, i, ], list_cov$cov)
      if (has_shared) {
        lp <- lp + linpred(post_shared, list_cov$cov_shared[i, ])
      }
      pred_link[, i] <- lp
    }
  } else if (list_cov$type == "ij") {
    pred_link <- array(dim = c(N, I, J))
    for (i in seq_len(I)) {
      for (j in seq_len(J)) {
        lp <- linpred(post_effect[, i, ], list_cov$cov[j, ])
        if (has_shared) {
          lp <- lp + linpred(post_shared, list_cov$cov_shared[i, j, ])
        }
        pred_link[, i, j] <- lp
      }
    }
  } else if (list_cov$type == "ijk") {
    pred_link <- array(dim = c(N, I, J, K))
    for (i in seq_len(I)) {
      for (j in seq_len(J)) {
        for (k in seq_len(K)) {
          lp <- linpred(post_effect[, i, ], list_cov$cov[j, k, ])
          if (has_shared) {
            lp <- lp + linpred(post_shared, list_cov$cov_shared[i, j, k, ])
          }
          pred_link[, i, j, k] <- lp
        }
      }
    }
  }

  return(pred_link)
}


add_attributes_predict <- function(x, parameter, scale, type, data) {

  # Attach metadata to a prediction array `x` and return it:
  #   "parameter", "scale": copied from the arguments;
  #   "dimension": "Statistics" (quantiles/mean) or "Samples" followed by
  #     Species[, Sites[, Replicates]] according to the number of dimensions
  #     of `x`;
  #   "label": the labels of each of those dimensions; species, site, and
  #     replicate labels are taken from dimnames(data@y) (NULL if absent).

  axes <- c("Species", "Sites", "Replicates")
  n_dim <- length(dim(x))
  y_dimnames <- dimnames(data@y)

  # Names of the dimensions after the leading one (NULL outside 2-4 dims).
  dim_pred <- if (n_dim %in% 2:4) axes[seq_len(n_dim - 1)]

  # Labels of the non-leading dimensions, one named list element per axis;
  # a dimensionless `x` is labelled as having a single species axis.
  axis_labels <- function() {
    n_axes <- switch(as.character(n_dim), "0" = 1, "2" = 1, "3" = 2, "4" = 3)
    labels <- lapply(seq_len(n_axes), function(k) y_dimnames[[k]])
    names(labels) <- axes[seq_len(n_axes)]
    labels
  }

  # Leading dimension: its name and its labels.
  lead <- switch(type,
                 quantiles = list(Statistics = c("50%", "2.5%", "97.5%")),
                 mean      = list(Statistics = c("mean")),
                 samples   = list(Samples = NULL))

  attr(x, "parameter") <- parameter
  attr(x, "scale")     <- scale

  if (!is.null(lead)) {
    attr(x, "dimension") <- c(names(lead), dim_pred)
    attr(x, "label")     <- c(lead, axis_labels())
  }

  return(x)
}
# fukayak/occumb documentation built on April 17, 2025, 11:50 a.m.