R/validations.R

Defines functions as_one_character find_stan find_jags validate_threads check_nmix check_obs_intercept check_gp_terms check_priorsim validate_trend_restrictions validate_trendmap validate_pos_real validate_even validate_predictors validate_pos_integers validate_pos_integer validate_equaldims validate_proportional validate_gr_subgr validate_trend_formula validate_obs_formula validate_trend_model validate_family_restrictions validate_family validate_silent collapse deparse0 as_one_integer as_one_logical deparse_variable validate_var_exists validate_series_groups validate_series_time

#'Argument validation functions
#'
#'Validate the `series` and `time` variables of a dataset, creating them
#'when they can be inferred, and append the internal index columns
#'`index..orig..order` and `index..time..index`.
#'
#'@param data Data to be validated (list or data.frame)
#'@param name Label used to refer to `data` in error messages
#'@param trend_model Either an object inheriting from `mvgam_trend` or a
#'  single character string naming the trend model
#'@param check_levels Logical; if `TRUE`, stop when a level of `series`
#'  has no observations in `data`
#'@param check_times Logical; if `TRUE`, stop unless the sorted time index
#'  of every series equals the full integer sequence running from the
#'  overall minimum to the overall maximum time index
#'@return `data` with `series`, `time`, `index..orig..order` and
#'  `index..time..index` present, and with attribute `implicit_vars`
#'  naming whichever of `series` / `time` were flagged as implicit
#'  (the attribute is removed if neither was)
#'@noRd
validate_series_time = function(
  data,
  name = 'data',
  trend_model,
  check_levels = TRUE,
  check_times = TRUE
) {
  # First validation requires the full trend_model object; a bare
  # character string is wrapped with the default unit / group settings
  if (!inherits(trend_model, 'mvgam_trend')) {
    trend_model <- list(
      trend_model = trend_model,
      unit = 'time',
      gr = 'NA',
      subgr = 'series'
    )
  }

  # Respect any original additions of implicit_vars
  implicit_series <- implicit_time <- FALSE
  if (!is.null(attr(data, 'implicit_vars'))) {
    implicit_series <- 'series' %in% attr(data, 'implicit_vars')
    implicit_time <- 'time' %in% attr(data, 'implicit_vars')
  }

  # Validate any grouping structure and update the data accordingly;
  # the flags are re-read afterwards only if the attribute is still set
  data <- validate_series_groups(
    data = data,
    trend_model = trend_model,
    name = name
  )
  if (!is.null(attr(data, 'implicit_vars'))) {
    implicit_series <- 'series' %in% attr(data, 'implicit_vars')
    implicit_time <- 'time' %in% attr(data, 'implicit_vars')
  }

  # Now we only need the character trend_model string
  trend_model <- trend_model$trend_model

  # Series label must be present as a factor and
  # must contain a time variable
  if (inherits(data, 'data.frame')) {
    data %>%
      dplyr::ungroup() -> data

    if (!'series' %in% colnames(data)) {
      data$series <- factor('series1')
      implicit_series <- TRUE
    }

    # Series must be supplied as a factor
    if (!is.factor(data$series)) {
      stop('Variable "series" must be a factor type', call. = FALSE)
    }

    if (!'time' %in% colnames(data)) {
      if (trend_model == 'None') {
        # Add a time indicator if missing (row position within each series)
        data %>%
          dplyr::group_by(series) %>%
          dplyr::mutate(time = dplyr::row_number()) %>%
          dplyr::ungroup() -> data
        implicit_time <- TRUE
      } else {
        stop(name, " does not contain a 'time' variable", call. = FALSE)
      }
    }
  }

  if (inherits(data, 'list')) {
    if (!'series' %in% names(data)) {
      data$series <- factor('series1')
      implicit_series <- TRUE
    }

    # Series must be supplied as a factor
    if (!is.factor(data$series)) {
      stop('Variable "series" must be a factor type', call. = FALSE)
    }

    if (!'time' %in% names(data)) {
      if (trend_model == 'None') {
        # Add a time indicator if missing (row position within each series)
        data.frame(series = data$series) %>%
          dplyr::group_by(series) %>%
          dplyr::mutate(time = dplyr::row_number()) %>%
          dplyr::pull(time) -> times
        implicit_time <- TRUE
        data$time <- times
      } else {
        stop(name, " does not contain a 'time' variable", call. = FALSE)
      }
    }
  }

  # Add an identifier so post-processing functions know
  # what the original supplied data ordering was; this is needed
  # for ensuring that functions such as fitted() and
  # residuals() return objects that match the original order that
  # the user supplied, if no newdata are given.
  # seq_along() is used rather than 1:length() so that empty input
  # produces an empty index instead of c(1, 0)
  data$index..orig..order <- seq_along(data$time)

  # Add a new 'time' variable that will be useful for rearranging data for
  # modeling, in case 'time' is also supplied as a covariate or if this is
  # a continuous time model (CAR1)
  data$index..time..index <- data$time

  # Use the data ordering to set the index of time for CAR1: each
  # observation receives its position within its own series once rows
  # are arranged by time, then the original row order is restored
  if (trend_model == 'CAR1') {
    data.frame(series = data$series, time = data$time) %>%
      dplyr::mutate(orig_rows = dplyr::row_number()) %>%
      dplyr::group_by(series) %>%
      dplyr::mutate(idx = dplyr::row_number()) %>%
      dplyr::arrange(time) %>%
      dplyr::mutate(time. = dplyr::row_number()) %>%
      dplyr::ungroup() %>%
      dplyr::arrange(orig_rows) %>%
      dplyr::pull(time.) -> times
    data$index..time..index <- times
  }

  # Series factor must have all unique levels present if this is a
  # forecast check
  if (check_levels) {
    if (!all(levels(data$series) %in% unique(data$series))) {
      stop(
        paste0(
          'Mismatch between factor levels of "series" and unique values of "series"',
          '\n',
          'Use\n  `setdiff(levels(data$series), unique(data$series))` \nand',
          '\n',
          '  `intersect(levels(data$series), unique(data$series))`\nfor guidance'
        ),
        call. = FALSE
      )
    }
  }

  # Ensure each series has an observation, even if NA, for each
  # unique timepoint (only for trend models that require discrete time with
  # regularly spaced sampling intervals)
  if (check_times) {
    # TRUE only if the sorted times are exactly min_time, min_time + 1,
    # ..., max_time (no gaps, no duplicates)
    all_times_avail = function(time, min_time, max_time) {
      identical(
        as.numeric(sort(time)),
        as.numeric(seq.int(from = min_time, to = max_time))
      )
    }
    min_time <- as.numeric(min(data$index..time..index))
    max_time <- as.numeric(max(data$index..time..index))
    data.frame(series = data$series, time = data$index..time..index) %>%
      dplyr::group_by(series) %>%
      dplyr::summarise(
        all_there = all_times_avail(time, min_time, max_time)
      ) -> checked_times
    if (any(checked_times$all_there == FALSE)) {
      stop(
        "One or more series in ",
        name,
        " is missing observations for one or more timepoints",
        call. = FALSE
      )
    }
  }

  # Record which of 'series' / 'time' were flagged as implicit
  if (implicit_series & implicit_time) {
    attr(data, 'implicit_vars') <- c('series', 'time')
  }

  if (implicit_series & !implicit_time) {
    attr(data, 'implicit_vars') <- 'series'
  }

  if (implicit_time & !implicit_series) {
    attr(data, 'implicit_vars') <- 'time'
  }

  if (!implicit_time & !implicit_series) {
    attr(data, 'implicit_vars') <- NULL
  }

  return(data)
}

# Function to ensure units of analysis, groups and subgroups are arranged
# and formatted properly for mvgam processing and modelling.
#
# `data` is a list or data.frame, `trend_model` is a list-like object with
# elements `trend_model`, `unit`, `gr` and `subgr` (the latter three are
# defaulted to 'time', 'NA' and 'series' when NULL), and `name` is the label
# used for `data` in error messages.
#
# Returns `data` with `series` and `time` (re)written from the chosen
# subgroup / unit variables when a trend model is used, and with attribute
# `implicit_vars` naming whichever of 'series' / 'time' were flagged as
# implicit here (the attribute is removed if neither was).
#'@noRd
validate_series_groups = function(data, trend_model, name = 'data') {
  implicit_series <- implicit_time <- FALSE

  # Checks only needed if trend_model isn't 'None'
  if (trend_model$trend_model != 'None') {
    # Check that unit and subgr exist in data and are the correct type
    if (is.null(trend_model$gr)) trend_model$gr <- 'NA'
    if (is.null(trend_model$unit)) trend_model$unit <- 'time'
    if (is.null(trend_model$subgr)) trend_model$subgr <- 'series'

    # Default setup (no grouping, unit = 'time', subgr = 'series'):
    # create whichever of 'series' / 'time' can be inferred
    if (
      trend_model$gr == 'NA' &
        trend_model$unit == 'time' &
        trend_model$subgr == 'series'
    ) {
      if (!'series' %in% names(data)) {
        data$series <- factor('series1')
        implicit_series <- TRUE
      }

      if (!'time' %in% names(data)) {
        if (trend_model$trend_model %in% c('ZMVNcor', 'ZMVNhiercor')) {
          # Add a time indicator if missing (row position within series)
          data.frame(series = data$series) %>%
            dplyr::group_by(series) %>%
            dplyr::mutate(time = dplyr::row_number()) %>%
            dplyr::pull(time) -> times
          implicit_time <- TRUE
          data$time <- times
        } else {
          stop(name, " does not contain a 'time' variable", call. = FALSE)
        }
      }
    }
    # The unit variable must exist and be numeric / integer
    validate_var_exists(
      data = data,
      variable = trend_model$unit,
      type = 'num/int',
      name = name,
      trend_char = trend_model$trend_model
    )

    # The subgroup variable must exist and be a factor
    validate_var_exists(
      data = data,
      variable = trend_model$subgr,
      type = 'factor',
      name = name,
      trend_char = trend_model$trend_model
    )

    # If gr is supplied, check it exists and is the correct type
    if (trend_model$gr != 'NA') {
      # With a grouping variable both 'series' and 'time' are derived
      # below, so both are flagged as implicit
      implicit_series <- implicit_time <- TRUE
      validate_var_exists(
        data = data,
        variable = trend_model$gr,
        type = 'factor',
        name = name,
        trend_char = trend_model$trend_model
      )
      if (trend_model$subgr == 'series') {
        stop(
          'argument "subgr" cannot be set to "series" if "gr" is also supplied',
          call. = FALSE
        )
      }

      # Add necessary 'time' variable
      gr_dat <- data.frame(
        time = data[[trend_model$unit]],
        gr = data[[trend_model$gr]],
        subgr = data[[trend_model$subgr]]
      )

      # Check that each level of gr contains all possible levels of subgr;
      # this is tested by requiring the number of distinct subgroups to be
      # the same in every group (zero variance of the counts)
      gr_total_levels <- gr_dat %>%
        dplyr::group_by(gr) %>%
        dplyr::summarise(tot_subgrs = length(unique(subgr))) %>%
        dplyr::ungroup() %>%
        dplyr::pull(tot_subgrs)
      if (length(gr_total_levels) > 1) {
        if (stats::var(gr_total_levels) != 0) {
          stop(
            paste0(
              'Some levels of "',
              trend_model$gr,
              '" do not contain all\n',
              'unique levels of "',
              trend_model$subgr,
              '"',
              " in ",
              name
            ),
            call. = FALSE
          )
        }
      }

      # 'series' becomes the gr-by-subgr combination ("<gr>_<subgr>"),
      # keeping only combinations that occur in the data
      gr_dat %>%
        dplyr::mutate(
          series = interaction(
            gr,
            subgr,
            drop = TRUE,
            sep = '_',
            lex.order = TRUE
          )
        ) -> gr_dat
    } else {
      # No grouping: 'time' / 'series' are implicit whenever they are
      # taken from differently named unit / subgroup variables
      if (trend_model$unit != 'time') {
        implicit_time <- TRUE
      }
      if (trend_model$subgr != 'series') {
        implicit_series <- TRUE
      }
      gr_dat <- data.frame(
        time = data[[trend_model$unit]],
        subgr = data[[trend_model$subgr]]
      ) %>%
        dplyr::mutate(series = as.factor(subgr))
    }

    # Add the possibly new 'series' and 'time' variables to
    # the data and return
    data$series <- gr_dat$series
    data$time <- gr_dat$time
  }

  # Record which of 'series' / 'time' were flagged as implicit
  if (implicit_series & implicit_time) {
    attr(data, 'implicit_vars') <- c('series', 'time')
  }

  if (implicit_series & !implicit_time) {
    attr(data, 'implicit_vars') <- 'series'
  }

  if (implicit_time & !implicit_series) {
    attr(data, 'implicit_vars') <- 'time'
  }

  if (!implicit_time & !implicit_series) {
    attr(data, 'implicit_vars') <- NULL
  }

  return(data)
}

#'Check that a variable exists in the data and has the expected type
#'
#'@param data A list or data.frame
#'@param variable Single character string; name of the variable to check
#'@param type Either `'factor'` or `'num/int'`; any other value skips the
#'  type check
#'@param name Label used to refer to `data` in error messages
#'@param trend_char Character trend model; all checks are skipped when
#'  this is `'None'`
#'@return `NULL`, invisibly. Called for its side effect of stopping with
#'  an informative error.
#'@noRd
validate_var_exists = function(
  data,
  variable,
  type = 'factor',
  name = 'data',
  trend_char
) {
  # Nothing to validate for models without a trend component
  if (trend_char == 'None') {
    return(invisible(NULL))
  }

  # Look the name up directly. exists(variable, data) would first coerce
  # `data` to an environment, which fails for lists that contain unnamed
  # elements
  if (!variable %in% names(data)) {
    stop(
      paste0('Variable "', variable, '" not found in ', name),
      call. = FALSE
    )
  }

  if (type == 'num/int' && !is.numeric(data[[variable]])) {
    stop(
      paste0(
        'Variable "',
        variable,
        '" must be either numeric or integer type'
      ),
      call. = FALSE
    )
  }

  if (type == 'factor' && !is.factor(data[[variable]])) {
    stop(
      paste0('Variable "', variable, '" must be a factor type'),
      call. = FALSE
    )
  }

  invisible(NULL)
}

#'Deparse the expression captured from `...` into a single string
#'@noRd
deparse_variable <- function(...) {
  # Capture the unevaluated expression, then hand it to deparse0()
  captured <- substitute(...)
  deparse0(captured)
}

#'Coerce `x` to a single logical value or stop
#'
#'@param x Object to coerce with `as.logical()`
#'@param allow_na Logical; if `TRUE`, a single `NA` is accepted
#'@return A logical vector of length one
#'@noRd
as_one_logical <- function(x, allow_na = FALSE) {
  # The expression must be captured before `x` is overwritten
  expr <- substitute(x)
  x <- as.logical(x)

  # Guard clause: a single value that is either non-missing or allowed
  # to be missing is returned straight away
  acceptable <- length(x) == 1L && (!anyNA(x) || allow_na)
  if (acceptable) {
    return(x)
  }

  label <- deparse0(expr, max_char = 100L)
  stop("Cannot coerce '", label, "' to a single logical value.", call. = FALSE)
}

#'Coerce `x` to a single integer value or stop
#'
#'@param x Object to coerce with `as.integer()` (coercion warnings are
#'  suppressed)
#'@param allow_na Logical; if `TRUE`, a single `NA` is accepted
#'@return An integer vector of length one
#'@noRd
as_one_integer <- function(x, allow_na = FALSE) {
  # The expression must be captured before `x` is overwritten
  expr <- substitute(x)
  x <- suppressWarnings(as.integer(x))

  # Guard clause: a single value that is either non-missing or allowed
  # to be missing is returned straight away
  acceptable <- length(x) == 1L && (!anyNA(x) || allow_na)
  if (acceptable) {
    return(x)
  }

  label <- deparse0(expr, max_char = 100L)
  stop("Cannot coerce '", label, "' to a single integer value.", call. = FALSE)
}

#'Deparse an object into one string, optionally truncated
#'
#'@param x Object to deparse
#'@param max_char Optional positive number; if supplied, the result is
#'  cut to at most this many characters
#'@param ... Passed on to `deparse()`
#'@return A single character string
#'@noRd
deparse0 <- function(x, max_char = NULL, ...) {
  pieces <- deparse(x, ...)
  text <- collapse(pieces)

  # No (valid) limit requested: return the full string
  if (!isTRUE(max_char > 0)) {
    return(text)
  }
  substr(text, 1L, max_char)
}

#'Paste the inputs together and collapse the result to one string
#'
#'@param ... Objects passed on to `paste()`
#'@param sep Separator placed between the pasted terms
#'@return A single character string
#'@noRd
collapse <- function(..., sep = "") {
  paste_args <- c(list(...), list(sep = sep, collapse = ""))
  do.call(paste, paste_args)
}

#'Validate the `silent` argument
#'
#'@param silent Value coercible to a single integer between 0 and 2
#'@return `silent` as a single integer
#'@noRd
validate_silent <- function(silent) {
  silent <- as_one_integer(silent)
  within_range <- silent >= 0 && silent <= 2
  if (!within_range) {
    stop("'silent' must be between 0 and 2.", call. = FALSE)
  }
  silent
}

#'Validate the observation family
#'
#'@param family A family object, a family-generating function, or a
#'  character string naming one (`'beta'` is mapped to `betar()`)
#'@param use_stan Logical; `FALSE` applies the JAGS restrictions
#'@return The validated family object; a family named
#'  `'Beta regression'` is relabelled `'beta'`
#'@importFrom rlang warn
#'@noRd
validate_family = function(family, use_stan = TRUE) {
  if (is.character(family)) {
    if (family == 'beta') {
      # 'beta' is shorthand for betar(). The resulting family object must
      # not be passed on to parse(), which only makes sense for strings
      family <- betar()
    } else {
      # NOTE(review): this evaluates an arbitrary user-supplied string;
      # kept because callers may rely on it, but it should never be
      # reached with untrusted input
      family <- try(eval(parse(text = family)), silent = TRUE)

      if (inherits(family, 'try-error'))
        stop("family not recognized", call. = FALSE)
    }
  }

  # A family-generating function (e.g. `poisson`) is called with defaults
  if (is.function(family)) family <- family()

  if (is.null(family$family)) stop("family not recognized", call. = FALSE)

  if (!inherits(family, 'family')) stop('family not recognized', call. = FALSE)

  if (family$family == 'Beta regression') family$family <- 'beta'

  if (family$family == 'tweedie')
    insight::check_if_installed(
      "tweedie",
      reason = 'to simulate from Tweedie distributions'
    )

  if (
    !family$family %in% c('poisson', 'negative binomial', 'tweedie') & !use_stan
  )
    stop(
      'JAGS only supports poisson(), nb() or tweedie() families',
      call. = FALSE
    )

  # Stan cannot support Tweedie
  if (use_stan & family$family == 'tweedie')
    stop('Tweedie family not supported for stan', call. = FALSE)

  # Remind users (once per session) of the required cbind() convention
  if (family$family %in% c('binomial', 'beta_binomial'))
    rlang::warn(
      paste0(
        "Binomial and Beta-binomial families require cbind(n_successes, n_trials)\n",
        "in the formula left-hand side. Do not use cbind(n_successes, n_failures)!"
      ),
      .frequency = "once",
      .frequency_id = 'cbind_binomials'
    )
  return(family)
}

#'Check that observed responses are permissible for the chosen family
#'
#'@param response Vector of responses; missing values are ignored
#'@param family List-like family object whose `family` element names
#'  the distribution
#'@return `NULL`, invisibly. Called for its side effect of stopping when
#'  the responses violate the family's support.
#'@noRd
validate_family_restrictions <- function(response, family) {
  # Only observed values are checked
  response <- response[!is.na(response)]

  # Bernoulli: evaluate the binomial initialisation expression in this
  # frame. That expression reads `y`, `nobs` and `weights`, so those
  # exact names must exist here (presumably it errors for values
  # outside 0 / 1 - confirm against stats::binomial)
  if (family$family == 'bernoulli') {
    weights <- rep(1, length(response))
    nobs <- length(response)
    y <- response
    eval(binomial()$initialize)
  }

  # Beta responses must lie strictly inside (0, 1)
  if (family$family == 'beta') {
    if (any(response <= 0)) {
      stop('Values <= 0 not allowed for beta responses', call. = FALSE)
    }
    if (any(response >= 1)) {
      stop('Values >= 1 not allowed for beta responses', call. = FALSE)
    }
  }

  # Count-type families cannot take negative values
  count_families <- c(
    'poisson',
    'negative binomial',
    'tweedie',
    'binomial',
    'beta_binomial'
  )
  if (family$family %in% count_families) {
    if (any(response < 0)) {
      stop('Values < 0 not allowed for count family responses', call. = FALSE)
    }
  }

  # Strictly positive families cannot take zeros or negative values
  positive_families <- c('lognormal', 'Gamma')
  if (family$family %in% positive_families) {
    if (any(response <= 0)) {
      stop(
        paste0('Values <= 0 not allowed for ', family$family, ' responses'),
        call. = FALSE
      )
    }
  }
}

#'Validate the trend model and reduce it to its character label
#'
#'@param trend_model Either an object inheriting from `mvgam_trend` or a
#'  single character string
#'@param drift Not used inside this function
#'@param noncentred Logical; triggers a message when non-centring is
#'  requested for a model outside None / RW / AR1 / AR2 / AR3 / CAR1
#'@param warn Logical; if `TRUE`, character input other than `'None'`
#'  raises a once-per-session deprecation warning
#'@return A single character string matched against
#'  `trend_model_choices()`, with the VAR aliases expanded
#'@noRd
validate_trend_model <- function(
  trend_model,
  drift = FALSE,
  noncentred = FALSE,
  warn = TRUE
) {
  if (inherits(trend_model, 'mvgam_trend')) {
    # Build the label as <model><p>[MA][hier][cor]; NULL pieces simply
    # vanish inside paste0()
    ma_term <- NULL
    if (trend_model$ma) {
      ma_term <- 'MA'
    }
    cor_term <- NULL
    if (trend_model$cor) {
      cor_term <- 'cor'
    }
    if (is.null(trend_model$gr)) trend_model$gr <- 'NA'
    gr_term <- if (trend_model$gr != 'NA') 'hier' else NULL

    trend_model <- paste0(
      trend_model$trend_model,
      trend_model$p,
      ma_term,
      gr_term,
      cor_term
    )
  } else if (trend_model != 'None' & warn) {
    rlang::warn(
      paste0(
        "Supplying trend_model as a character string is deprecated\n",
        "Please use the dedicated functions (i.e. RW() or ZMVN()) instead"
      ),
      .frequency = "once",
      .frequency_id = 'trend_characters'
    )
  }

  trend_model <- match.arg(arg = trend_model, choices = trend_model_choices())

  # Expand the VAR shorthands to their explicit labels; any other
  # label falls through unchanged
  trend_model <- switch(
    trend_model,
    VAR = 'VAR1',
    VARcor = 'VAR1cor',
    VARhiercor = 'VAR1hiercor',
    VARMA = ,
    VARMAcor = 'VARMA1,1cor',
    trend_model
  )

  supports_noncentring <- trend_model %in%
    c('None', 'RW', 'AR1', 'AR2', 'AR3', 'CAR1')
  if (!supports_noncentring & noncentred) {
    message('Non-centering of trends currently not available for this model')
  }

  if (trend_model %in% c('PWlinear', 'PWlogistic'))
    insight::check_if_installed(
      "extraDistr",
      reason = 'to simulate from piecewise trends'
    )
  return(trend_model)
}

#'Validate the observation formula and add the outcome as `y`
#'
#'@param formula Observation model formula; must have a left-hand side
#'  that is either a single variable or a call containing `cbind`
#'@param data List or data.frame holding the response variable(s)
#'@param refit Not used inside this function
#'@return `data` with an added element `y` that copies the (first)
#'  response variable
#'@noRd
validate_obs_formula = function(formula, data, refit = FALSE) {
  if (attr(terms(formula), "response") == 0L) {
    stop('response variable is missing from formula', call. = FALSE)
  }

  # Check that response terms are in the data; account for possible
  # 'cbind' in there if this is a binomial model
  resp_terms <- as.character(terms(formula(formula))[[2]])
  if (length(resp_terms) == 1) {
    out_name <- resp_terms
    if (!out_name %in% names(data)) {
      stop(
        paste0('variable ', out_name, ' not found in data'),
        call. = FALSE
      )
    }
  } else {
    if (any(grepl('cbind', resp_terms))) {
      # as.character() on a call puts the function name first, so the
      # head element is dropped to leave the variable names. (The
      # previous `-grepl()` index negated a logical vector and removed
      # the first element only by coincidence.)
      resp_terms <- resp_terms[-1L]
      out_name <- resp_terms[1]

      # Report the first response variable that is absent from the data
      absent <- resp_terms[!resp_terms %in% names(data)]
      if (length(absent) > 0) {
        stop(
          paste0('variable ', absent[1], ' not found in data'),
          call. = FALSE
        )
      }
    } else {
      stop(
        'Not sure how to deal with this response variable specification',
        call. = FALSE
      )
    }
  }

  # 'y' is reserved for the outcome column that is added below
  if (any(attr(terms(formula), 'term.labels') %in% 'y')) {
    stop(
      'due to internal data processing, "y" should not be used as the name of a predictor in mvgam',
      call. = FALSE
    )
  }

  # Add a y outcome for sending to the modelling backend
  data$y <- data[[out_name]]

  return(data)
}

#'Validate the trend formula
#'
#'Stops if the formula has a left-hand side, mentions "series" anywhere,
#'or contains an offset.
#'
#'@param formula The trend formula to check
#'@return `NULL`, invisibly
#'@noRd
validate_trend_formula <- function(formula) {
  has_lhs <- !is.null(rlang::f_lhs(formula))
  if (has_lhs) {
    stop(
      'Argument "trend_formula" should not have a left-hand side',
      call. = FALSE
    )
  }

  mentions_series <- any(grepl('series', as.character(formula)))
  if (mentions_series) {
    stop(
      'Argument "trend_formula" should not have the identifier "series" in it.\nUse "trend" instead for varying effects',
      call. = FALSE
    )
  }

  offset_terms <- attr(terms(formula(formula)), 'offset')
  if (!is.null(offset_terms)) {
    stop('Offsets not allowed in argument "trend_formula"', call. = FALSE)
  }
}

#'Validate the `gr` / `subgr` pairing
#'
#'Both arguments are captured unevaluated and deparsed; the string 'NA'
#'marks an argument as not supplied. Supplying one without the other is
#'an error, and supplying both forces `cor` to `TRUE`.
#'
#'@param gr,subgr Unquoted grouping and subgrouping variable names
#'@param cor Logical passed through (or overridden with `TRUE`)
#'@return A list with elements `.group`, `.subgroup` and `.cor`
#'@noRd
validate_gr_subgr <- function(gr, subgr, cor) {
  gr <- deparse0(substitute(gr))
  subgr <- deparse0(substitute(subgr))

  has_gr <- gr != 'NA'
  has_subgr <- subgr != 'NA'

  if (has_gr && !has_subgr) {
    stop(
      'argument "subgr" must be supplied if "gr" is also supplied',
      call. = FALSE
    )
  }

  if (has_subgr && !has_gr) {
    stop(
      'argument "gr" must be supplied if "subgr" is also supplied',
      call. = FALSE
    )
  }

  # Reaching this point with a subgroup means a group was supplied too
  if (has_subgr) {
    cor <- TRUE
  }

  list(.group = gr, .subgroup = subgr, .cor = cor)
}

#'Check that `x` is a single proportion in [0, 1]
#'
#'@param x Value coercible to a single number
#'@return `NULL`, invisibly. Called for its side effect of stopping with
#'  an error that names the supplied argument.
#'@noRd
validate_proportional = function(x) {
  # Deparse the captured expression so that non-symbol arguments such as
  # `cfg$p` are reported verbatim; pasting the raw call would split it
  # into its components ("$", "cfg", "p")
  s <- paste(deparse(substitute(x)), collapse = "")
  x <- base::suppressWarnings(as.numeric(x))
  if (length(x) != 1L || anyNA(x)) {
    stop("Argument '", s, "' must be a single numeric value", call. = FALSE)
  }

  if (x < 0 || x > 1) {
    stop(
      "Argument '",
      s,
      "' must be a proportion ranging from 0 to 1, inclusive",
      call. = FALSE
    )
  }
}

#'Check that two objects have the same number of rows and columns
#'
#'@param x,y Objects compared with `NROW()` and `NCOL()`
#'@return `NULL`, invisibly. Called for its side effect of stopping with
#'  an error that names both supplied arguments.
#'@noRd
validate_equaldims = function(x, y) {
  # Deparse the captured expressions so that non-symbol arguments are
  # reported verbatim in the error message
  s <- paste(deparse(substitute(x)), collapse = "")
  q <- paste(deparse(substitute(y)), collapse = "")

  # Column and row mismatches share the same message
  if (NCOL(x) != NCOL(y) || NROW(x) != NROW(y)) {
    stop(
      "Argument '",
      s,
      "' and argument '",
      q,
      "' must have equal dimensions",
      call. = FALSE
    )
  }
}

#'Check that `x` is a single positive integer
#'
#'@param x Value coercible to a single number
#'@return `NULL`, invisibly. Called for its side effect of stopping with
#'  an error that names the supplied argument.
#'@noRd
validate_pos_integer = function(x) {
  # Deparse the captured expression so that non-symbol arguments are
  # reported verbatim in the error message
  s <- paste(deparse(substitute(x)), collapse = "")
  x <- base::suppressWarnings(as.numeric(x))
  if (length(x) != 1L || anyNA(x)) {
    stop("Argument '", s, "' must be a single numeric value", call. = FALSE)
  }

  # Infinite values are rejected explicitly; `Inf %% 1` is NaN and would
  # otherwise fail inside `if ()` with an unrelated error
  if (sign(x) != 1 || !is.finite(x) || x %% 1 != 0) {
    stop("Argument '", s, "' must be a positive integer", call. = FALSE)
  }
}

#'Check that every element of `x` is a positive integer
#'
#'Elements are checked one at a time and in order, so the error raised
#'describes the first offending element.
#'
#'@param x Vector (or list) of values, each coercible to a number
#'@return Invisibly, a list of `NULL`s with one entry per element of `x`.
#'  Called for its side effect of stopping on invalid input.
#'@noRd
validate_pos_integers = function(x) {
  # Deparse the captured expression so that non-symbol arguments are
  # reported verbatim in the error message
  s <- paste(deparse(substitute(x)), collapse = "")

  val_pos = function(y, s) {
    y <- base::suppressWarnings(as.numeric(y))

    # Missing / non-numeric elements get their own message rather than
    # failing inside `if ()` with "missing value where TRUE/FALSE needed"
    if (anyNA(y)) {
      stop("Missing or non-numeric values in ", s, " detected", call. = FALSE)
    }
    if (sign(y) != 1) {
      stop("Negative values in ", s, " detected", call. = FALSE)
    }
    if (!is.finite(y) || y %% 1 != 0) {
      stop("Non-integer values in ", s, " detected", call. = FALSE)
    }
  }
  res <- lapply(seq_along(x), function(i) val_pos(x[i], s))
  invisible(res)
}

#'Check that `newdata` holds every variable needed for prediction
#'
#'@param object Fitted model object; the elements `obs_data`,
#'  `trend_model`, `mgcv_model`, `trend_mgcv_model`, `trend_call` and
#'  `use_lv` are read here
#'@param newdata List or data.frame of prediction data
#'@return `NULL`, invisibly. Stops with a message listing any required
#'  variables that are absent from `newdata`.
#'@noRd
validate_predictors <- function(object, newdata) {
  # Variables the model reports as predictors, minus any that were
  # recorded as implicit on the original data
  needed <- insight::find_predictors(
    object,
    component = "all"
  )$conditional
  needed <- setdiff(
    needed,
    attr(object$obs_data, 'implicit_vars')
  )

  # If time and / or series were in original data but not used,
  # there is no need for them in the newdata for models with
  # no trend_model
  if (!inherits(object$trend_model, 'mvgam_trend')) {
    if (object$trend_model == 'None') {
      # Predictors used by the observation model ...
      used <- insight::find_predictors(
        object$mgcv_model,
        effects = 'all',
        component = 'all',
        flatten = TRUE
      )

      # ... plus those used by the trend model, when there is one
      if (!is.null(object$trend_call)) {
        trend_used <- insight::find_predictors(
          object$trend_mgcv_model,
          effects = 'all',
          component = 'all',
          flatten = TRUE
        )
        used <- unique(c(used, trend_used))
      }

      # 'time' is only required when it acts as a predictor
      if ('time' %in% needed & !'time' %in% used) {
        needed <- setdiff(needed, 'time')
      }

      # 'series' is only required when it acts as a predictor, unless
      # latent variables are used together with a trend formula
      series_unused <- 'series' %in%
        needed &
        !'series' %in% used &
        (!object$use_lv | is.null(object$trend_call))
      if (series_unused) {
        needed <- setdiff(needed, 'series')
      }
    }
  }

  if (length(needed)) {
    absent <- needed[which(!needed %in% names(newdata))]
    if (length(absent) > 0) {
      stop(
        paste0(
          'the following required variables are missing from newdata:\n ',
          paste(absent, collapse = ', ')
        ),
        call. = FALSE
      )
    }
  }
}

#'Check that `x` is a single even number
#'
#'@param x Value coercible to a single number
#'@return `NULL`, invisibly. Called for its side effect of stopping with
#'  an error that names the supplied argument.
#'@noRd
validate_even <- function(x) {
  # Deparse the captured expression so that non-symbol arguments are
  # reported verbatim in the error message
  s <- paste(deparse(substitute(x)), collapse = "")
  x <- base::suppressWarnings(as.numeric(x))

  # Same guard as the sibling validators: without it a missing or
  # multi-element value fails inside `if ()` with an unrelated error
  if (length(x) != 1L || anyNA(x)) {
    stop("Argument '", s, "' must be a single numeric value", call. = FALSE)
  }

  if (x %% 2 != 0) {
    stop("Argument '", s, "'  must be an even integer", call. = FALSE)
  }
}

#'Check that `x` is a single positive real number
#'
#'@param x Value coercible to a single number
#'@return `NULL`, invisibly. Called for its side effect of stopping with
#'  an error that names the supplied argument.
#'@noRd
validate_pos_real = function(x) {
  # Deparse the captured expression so that non-symbol arguments are
  # reported verbatim in the error message
  s <- paste(deparse(substitute(x)), collapse = "")
  x <- base::suppressWarnings(as.numeric(x))
  if (length(x) != 1L || anyNA(x)) {
    stop("Argument '", s, "' must be a single numeric value", call. = FALSE)
  }

  if (sign(x) != 1) {
    stop("Argument '", s, "' must be a positive real value", call. = FALSE)
  }
}

#'Validate a user-supplied `trend_map`
#'
#'@param trend_map data.frame-like object with elements `series` and
#'  `trend`
#'@param data_train Training data containing a `series` factor
#'@param trend_model Not used inside this function
#'@param use_stan Logical; trend mapping is refused when `FALSE`
#'@return `NULL`, invisibly. Called for its side effect of stopping when
#'  the map is inconsistent with the data.
#'@noRd
validate_trendmap <- function(trend_map, data_train, trend_model, use_stan) {
  # Trend mapping not supported by JAGS
  if (!use_stan) {
    stop('trend mapping not available for JAGS', call. = FALSE)
  }

  # trend_map must have an entry for each unique time series
  mapped_series <- sort(trend_map$series)
  observed_series <- sort(unique(data_train$series))
  if (!all(mapped_series == observed_series)) {
    stop(
      'Argument "trend_map" must have an entry for every unique time series in "data"',
      call. = FALSE
    )
  }

  # trend_map must not specify a greater number of trends than there are series
  n_series <- length(unique(data_train$series))
  if (max(trend_map$trend) > n_series) {
    stop(
      'Argument "trend_map" specifies more latent trends than there are series in "data"',
      call. = FALSE
    )
  }

  # trend_map must not skip any trends, but can have zeros for some entries
  used_trends <- sort(unique(trend_map$trend))
  used_trends <- used_trends[used_trends != 0]
  expected_trends <- seq(1:max(trend_map$trend))
  if (!all(used_trends == expected_trends)) {
    stop(
      'Argument "trend_map" must link at least one series to each latent trend',
      call. = FALSE
    )
  }

  # series variable must be a factor with same levels as the series variable
  # in the data
  if (!is.factor(trend_map$series)) {
    stop(
      'trend_map$series must be a factor with levels matching levels of data$series',
      call. = FALSE
    )
  }

  same_levels <- all(levels(trend_map$series) == levels(data_train$series))
  if (!same_levels) {
    stop(
      'trend_map$series must be a factor with levels matching levels of data$series',
      call. = FALSE
    )
  }
}

#'Validate the trend setup and resolve dependent settings
#'
#'Applies the restrictions that link the trend model, dynamic factors,
#'drift terms, trend formulae and trend maps, adjusting `drift`,
#'`use_lv`, `n_lv` and `trend_map` where required.
#'
#'@param trend_model Character trend model label
#'@param formula Observation model formula
#'@param trend_formula Optional trend formula (may be missing)
#'@param trend_map Optional trend map (may be missing); created here when
#'  a trend formula is supplied without one
#'@param drift Logical; switched off with a warning where drift cannot
#'  be identified
#'@param drop_obs_intercept Logical; suppresses the piecewise-trend
#'  intercept warning
#'@param use_lv Logical; whether dynamic factors are requested
#'@param n_lv Optional number of latent variables (may be missing)
#'@param data_train Training data containing a `series` factor
#'@param use_stan Logical; `FALSE` applies the JAGS restrictions
#'@param priors Logical; if `TRUE`, the piecewise-trend checks are skipped
#'@return A list with elements `trend_model`, `add_cor`, `add_ma`,
#'  `use_var1`, `use_var1cor`, `use_lv`, `n_lv`, `trend_map` and `drift`
#'@noRd
validate_trend_restrictions = function(
  trend_model,
  formula,
  trend_formula,
  trend_map,
  drift = FALSE,
  drop_obs_intercept = FALSE,
  use_lv = FALSE,
  n_lv,
  data_train,
  use_stan = TRUE,
  priors = FALSE
) {
  # Assess whether additional moving average or correlated errors are needed.
  # list2env() copies every element of the returned list into this frame;
  # add_cor, add_ma, use_var1 and use_var1cor are read below, so
  # presumably ma_cor_additions() supplies at least those (and possibly
  # an updated trend_model) - confirm against its definition
  ma_cor_adds <- ma_cor_additions(trend_model)
  list2env(ma_cor_adds, envir = environment())

  if (length(unique(data_train$series)) == 1 & add_cor) {
    warning(
      'Correlated process errors not possible with only 1 series',
      call. = FALSE
    )
    add_cor <- FALSE
  }

  # Some checks on general trend setup restrictions
  if (!priors) {
    if (trend_model %in% c('PWlinear', 'PWlogistic')) {
      if (attr(terms(formula), 'intercept') == 1 & !drop_obs_intercept) {
        warning(
          paste0(
            'It is difficult / impossible to estimate intercepts\n',
            'and piecewise trend offset parameters. You may want to\n',
            'consider dropping the intercept from the formula'
          ),
          call. = FALSE
        )
      }

      if (use_lv)
        stop(
          'Cannot estimate piecewise trends using dynamic factors',
          call. = FALSE
        )
    }
  }

  if (use_lv & (add_ma | add_cor) & missing(trend_formula)) {
    stop(
      'Cannot estimate moving averages or correlated errors for dynamic factors',
      call. = FALSE
    )
  }

  if (use_lv & drift) {
    warning(
      'Cannot identify drift terms for this model\ninclude "time" as a fixed effect instead',
      call. = FALSE
    )
    drift <- FALSE
  }

  if (drift && trend_model == 'CAR1') {
    warning(
      'Cannot identify drift terms for CAR models; setting "drift = FALSE"',
      call. = FALSE
    )
    drift <- FALSE
  }

  if (
    trend_model %in% c('VAR', 'VAR1', 'VAR1cor', 'VARMA1,1cor', 'GP') & drift
  ) {
    warning(
      'Cannot identify drift terms for VAR or GP models; setting "drift = FALSE"',
      call. = FALSE
    )
    drift <- FALSE
  }

  if (use_lv & trend_model == 'VAR1' & missing(trend_formula)) {
    stop(
      'Cannot identify dynamic factor models that evolve as VAR processes',
      call. = FALSE
    )
  }

  if (!use_stan & trend_model %in% c('GP', 'VAR1', 'PWlinear', 'PWlogistic')) {
    stop(
      'Gaussian Process, VAR and piecewise trends not supported for JAGS',
      call. = FALSE
    )
  }

  # Check trend formula and create the trend_map if missing.
  # Assigning trend_map here also makes missing(trend_map) FALSE for
  # the checks further down
  if (!missing(trend_formula)) {
    validate_trend_formula(trend_formula)
    if (missing(trend_map)) {
      trend_map <- data.frame(
        series = factor(
          levels(data_train$series),
          levels = levels(data_train$series)
        ),
        trend = 1:length(unique(data_train$series))
      )
    }

    if (
      !trend_model %in%
        c('None', 'RW', 'AR1', 'AR2', 'AR3', 'VAR1', 'CAR1', 'ZMVN')
    ) {
      stop(
        'only None, ZMVN, RW, AR1, AR2, AR3, CAR1 and VAR trends currently supported for trend predictor models',
        call. = FALSE
      )
    }
  }

  # Check trend_map is correctly specified
  if (!missing(trend_map)) {
    validate_trendmap(
      trend_map = trend_map,
      data_train = data_train,
      trend_model = trend_model,
      use_stan = use_stan
    )

    # If trend_map correctly specified, set use_lv to TRUE for
    # most models (but not yet for VAR models, which require additional
    # modifications)
    if (trend_model == 'VAR1') {
      use_lv <- FALSE
    } else {
      use_lv <- TRUE
    }
    n_lv <- max(trend_map$trend)
  }

  # Number of latent variables cannot be greater than number of series
  if (use_lv) {
    if (missing(n_lv)) {
      # Default: half the number of series (rounded down), capped at 2
      n_lv <- min(2, floor(length(unique(data_train$series)) / 2))
    }
    if (n_lv > length(unique(data_train$series))) {
      # call. = FALSE for consistency with every other error in this file
      stop(
        'number of latent variables cannot be greater than number of series',
        call. = FALSE
      )
    }
  }

  # Arguments that are still missing are returned as NULL
  if (missing(trend_map)) {
    trend_map <- NULL
  }

  if (missing(n_lv)) {
    n_lv <- NULL
  }

  return(list(
    trend_model = trend_model,
    add_cor = add_cor,
    add_ma = add_ma,
    use_var1 = use_var1,
    use_var1cor = use_var1cor,
    use_lv = use_lv,
    n_lv = n_lv,
    trend_map = trend_map,
    drift = drift
  ))
}

#'Prepare the outcome columns for prior simulation or model fitting
#'
#'@param prior_simulation Logical; if `TRUE`, `y` is filled with `NA`
#'@param data_train Training data (list or data.frame) containing `y`
#'@param orig_y The originally supplied response values
#'@param formula Observation model formula, used to find the name of
#'  the response variable
#'@return `data_train` with `y` set to `NA` (prior simulation) or
#'  `orig_y`, and with the named response variable set to `orig_y`
#'@noRd
check_priorsim = function(prior_simulation, data_train, orig_y, formula) {
  # Fill y with NAs if this is a simulation from the priors
  if (prior_simulation) {
    data_train$y <- rep(NA, length(data_train$y))
  } else {
    data_train$y <- orig_y
  }

  # Fill response variable with original supplied values
  resp_terms <- as.character(terms(formula(formula))[[2]])
  if (length(resp_terms) == 1) {
    out_name <- resp_terms
  } else if (any(grepl('cbind', resp_terms))) {
    # as.character() on a call puts the function name first; the first
    # variable after it is the response. (The previous `-grepl()` index
    # negated a logical vector and dropped the head only by coincidence.)
    out_name <- resp_terms[-1L][1]
  } else {
    # Previously this case fell through and failed with
    # "object 'out_name' not found"
    stop(
      'Not sure how to deal with this response variable specification',
      call. = FALSE
    )
  }
  data_train[[out_name]] <- orig_y

  return(data_train)
}

#'Check the binomial response syntax and convert any `gp()` terms
#'
#'@param formula Model formula
#'@param data_train Training data; `time` is read when `gp()` terms are
#'  present
#'@param family Optional family (object, function or character string);
#'  when supplied, binomial-type families must use `cbind()` on the
#'  formula left-hand side
#'@return A list with elements `orig_formula`, `gp_terms`, `formula` and
#'  `gp_details`; all but `formula` are `NULL` when the formula has no
#'  `gp()` terms
#'@noRd
check_gp_terms = function(formula, data_train, family) {
  # Check for proper binomial specification
  if (!missing(family)) {
    if (is.character(family)) {
      if (family == 'beta') {
        # 'beta' is shorthand for betar(). The resulting family object
        # must not be passed on to parse(), which only makes sense for
        # strings
        family <- betar()
      } else {
        # NOTE(review): this evaluates an arbitrary user-supplied string
        family <- try(eval(parse(text = family)), silent = TRUE)

        if (inherits(family, 'try-error'))
          stop("family not recognized", call. = FALSE)
      }
    }

    if (is.function(family)) family <- family()

    if (family$family %in% c('binomial', 'beta_binomial')) {
      # Check that response terms use the cbind() syntax
      resp_terms <- as.character(terms(formula(formula))[[2]])
      if (length(resp_terms) == 1 || !any(grepl('cbind', resp_terms))) {
        stop(
          'Binomial family requires cbind() syntax in the formula left-hand side',
          call. = FALSE
        )
      }
    }
  }

  # Check for gp terms in the validated formula
  orig_formula <- gp_terms <- gp_details <- NULL
  if (any(grepl('gp(', attr(terms(formula), 'term.labels'), fixed = TRUE))) {
    formula <- interpret_mvgam(
      formula,
      N = max(data_train$time),
      family = family
    )
    orig_formula <- formula

    # Keep intercept?
    keep_intercept <- attr(terms(formula), 'intercept') == 1

    # Indices of gp() terms in formula
    gp_terms <- which_are_gp(formula)

    # Extract attributes
    gp_details <- get_gp_attributes(formula, data_train, family)

    # Replace with s() terms so the correct terms are included
    # in the model.frame
    formula <- gp_to_s(formula, data_train, family)
    if (!keep_intercept) formula <- update(formula, . ~ . - 1)
  }

  return(list(
    orig_formula = orig_formula,
    gp_terms = gp_terms,
    formula = formula,
    gp_details = gp_details
  ))
}

#'Replace an empty observation formula with an intercept-only one
#'
#'If the formula has neither predictor terms nor an intercept (e.g.
#'`y ~ -1`), it is rewritten to carry an intercept and
#'`drop_obs_intercept` is set so that callers can tell the intercept was
#'added here.
#'
#'@param formula Observation model formula
#'@param orig_formula A previously stored original formula, or `NULL`
#'@return A list with elements `orig_formula` (the supplied value, or
#'  the possibly rewritten `formula` when `NULL`), `formula` and
#'  `drop_obs_intercept`
#'@noRd
check_obs_intercept <- function(formula, orig_formula) {
  drop_obs_intercept <- FALSE

  obs_terms <- terms(formula)
  no_predictors <- length(attr(obs_terms, 'term.labels')) == 0
  no_intercept <- !attr(obs_terms, 'intercept') == 1

  if (no_predictors & no_intercept) {
    # The rebuilt formula must keep the environment of the original
    formula_envir <- attr(formula, '.Environment')

    # Name of the response, taken from the variables list when the
    # factors attribute is empty and from its row names otherwise
    resp <- if (length(attr(obs_terms, 'factors')) == 0) {
      as.character(attr(obs_terms, 'variables'))[2]
    } else {
      dimnames(attr(obs_terms, 'factors'))[[1]][1]
    }

    has_offset <- !is.null(attr(obs_terms, 'offset'))
    if (has_offset) {
      # Keep the offset: turn ' - 1' into ' + 1' in the right-hand side
      rhs_text <- paste(gsub(' - 1', ' + 1', rlang::f_text(formula)))
      formula <- formula(paste0(resp, ' ~ ', rhs_text))
    } else {
      formula <- formula(paste(resp, '~ 1'))
    }
    attr(formula, '.Environment') <- formula_envir
    drop_obs_intercept <- TRUE
  }

  if (is.null(orig_formula)) {
    orig_formula <- formula
  }

  return(list(
    orig_formula = orig_formula,
    formula = formula,
    drop_obs_intercept = drop_obs_intercept
  ))
}

#' Apply N-mixture specific checks and modifications
#'
#' For `family_char == 'nmix'` this checks that the data contain a `cap`
#' variable, switches the family to `poisson()` (unless `priors = TRUE`),
#' requires a `trend_formula` in that case, and replaces a `'None'` trend
#' model with `'RW'`. For any other family the inputs are returned as
#' supplied.
#'
#' @param family Family object (or other family representation) to return,
#'   possibly replaced by `poisson()`
#' @param family_char Character string naming the family
#' @param trend_formula Trend formula; only checked for being supplied
#' @param trend_model Trend model specification, compared against `'None'`
#' @param trend_map Optional trend map; only checked for being supplied
#' @param data_train Named `list` or `data.frame` of training data
#' @param priors Logical; if `TRUE` the family is left unchanged and
#'   `trend_formula` is not required
#' @return A `list` with elements `trend_model`, `add_nmix`,
#'   `nmix_trendmap`, `family` and `family_char`
#'@noRd
check_nmix = function(
  family,
  family_char,
  trend_formula,
  trend_model,
  trend_map,
  data_train,
  priors = FALSE
) {
  # Check for N-mixture modifications
  add_nmix <- FALSE
  nmix_trendmap <- TRUE

  if (family_char == 'nmix') {
    # Look the variable up by name rather than converting the whole
    # data object to an environment
    if (!('cap' %in% names(data_train))) {
      stop(
        'Max abundances must be supplied as a variable named "cap" for N-mixture models',
        call. = FALSE
      )
    }

    add_nmix <- TRUE
    if (!priors) {
      if (missing(trend_formula)) {
        stop('Argument "trend_formula" required for nmix models', call. = FALSE)
      }
      family <- poisson()
      family_char <- 'poisson'
    }

    # NOTE(review): nmix_trendmap is already TRUE at this point, so this
    # branch does not change the returned value; confirm whether the
    # initial value was meant to be FALSE
    if (!missing(trend_map)) {
      nmix_trendmap <- TRUE
    }

    if (trend_model == 'None') {
      trend_model <- 'RW'
    }
  }

  return(list(
    trend_model = trend_model,
    add_nmix = add_nmix,
    nmix_trendmap = nmix_trendmap,
    family = family,
    family_char = family_char
  ))
}

#' Validate the requested number of threads for a family
#'
#' If more than one thread is requested for a family that is not in the
#' supported set below, a warning is issued and `threads` is reset to 1.
#'
#' @param family_char Character string naming the family
#' @param threads Requested number of threads (single number)
#' @return `threads`, or `1` if the family is not in the supported set
#'@noRd
validate_threads = function(family_char, threads) {
  # Families that are allowed to keep threads > 1
  threadable_families <- c(
    'poisson',
    'negative binomial',
    'gaussian',
    'lognormal',
    'beta',
    'student',
    'Gamma'
  )

  # Scalar condition: use short-circuiting && rather than elementwise &
  if (threads > 1 && !(family_char %in% threadable_families)) {
    warning(
      'multithreading not yet supported for this family; setting threads = 1'
    )
    threads <- 1
  }
  return(threads)
}

#' Check that JAGS can be called through the runjags package
#'
#' Stops if the runjags package is not installed or if
#' `runjags::testjags()` reports that JAGS is not available at `jags_path`.
#' On Windows the availability test is repeated once after a short pause
#' before giving up.
#'
#' @param jags_path Path to the JAGS binary; if missing it is looked up
#'   with `runjags::findjags()`
#' @return `NULL`, invisibly; called for its side effects
#'@noRd
find_jags = function(jags_path) {
  runjags_installed <- requireNamespace('runjags', quietly = TRUE)
  if (!runjags_installed) {
    stop('runjags library is required but not found', call. = FALSE)
  }

  if (missing(jags_path)) {
    jags_path <- runjags::findjags()
  }

  # Code borrowed from the runjags package
  probe_jags <- function() {
    runjags::testjags(jags_path, silent = TRUE)
  }

  status <- probe_jags()
  if (!status$JAGS.available) {
    # Give Windows one more attempt after a brief wait
    if (status$os == "windows") {
      Sys.sleep(0.2)
      status <- probe_jags()
    }
  }

  if (!status$JAGS.available) {
    cat(
      "Unable to call JAGS using '",
      jags_path,
      "'\nTry specifying the path to the JAGS binary as jags_path argument, or re-installing the rjags package.\nUse the runjags::testjags() function for more detailed diagnostics.\n",
      sep = ""
    )
    stop(
      "Unable to call JAGS.\nEither use the Stan backend or follow examples in ?mvgam to generate data / model files and run outside of mvgam",
      call. = FALSE
    )
  }

  invisible(NULL)
}

#' Check that a Stan interface package is installed
#'
#' Returns quietly when rstan is installed. Otherwise a warning is issued
#' and cmdstanr is checked instead; if that is also not installed the
#' function stops with an error.
#'
#' @return `NULL`, invisibly; called for its side effects
#'@noRd
find_stan = function() {
  # Nothing more to do when rstan can be loaded
  if (requireNamespace('rstan', quietly = TRUE)) {
    return(invisible(NULL))
  }

  warning('rstan library not found; checking for cmdstanr library')

  cmdstanr_installed <- requireNamespace('cmdstanr', quietly = TRUE)
  if (!cmdstanr_installed) {
    stop('cmdstanr library not found', call. = FALSE)
  }

  invisible(NULL)
}

#' Coerce an object to a single character value
#'
#' @param x Object to be coerced with `as.character()`
#' @param allow_na Logical; if `TRUE` a single `NA` result is accepted
#' @return A character vector of length one. An error is raised if the
#'   coerced value does not have length one, or if it is `NA` while
#'   `allow_na` is `FALSE`
#'@noRd
as_one_character <- function(x, allow_na = FALSE) {
  # Capture the caller's expression for use in the error message
  arg_expr <- substitute(x)
  value <- as.character(x)

  wrong_length <- length(value) != 1L
  if (wrong_length || (anyNA(value) && !allow_na)) {
    arg_label <- deparse0(arg_expr, max_char = 100L)
    stop(
      "Cannot coerce '",
      arg_label,
      "' to a single character value.",
      call. = FALSE
    )
  }

  value
}
# nicholasjclark/mvgam documentation built on April 17, 2025, 9:39 p.m.