# R/add_binomial.R

# Defines functions add_binomial

#' Add Binomial, Beta-Binomial or Bernoulli likelihoods to an mvgam model
#'
#' Rewrites a Stan model template built around a Poisson GLM likelihood
#' (`poisson_log_glm()` / `poisson_log_rng()`) so that it uses a Binomial,
#' Beta-Binomial or Bernoulli observation model instead. For the
#' Binomial-type families it also adds the number-of-trials data.
#'
#' For `'binomial'` and `'beta_binomial'` families, the trials variable is the
#' second argument of a `cbind()` response (e.g. `cbind(y, n_trials) ~ ...`).
#' Its values are validated in `data_train` (and in `data_test` when
#' supplied), then arranged into a time x series matrix. The matrix is added
#' to `model_data` twice: as `flat_trials` (all observations) and as
#' `flat_trials_train` (entries where `model_data$y_observed == 1`).
#'
#' @param formula Model formula (or an object coercible with `formula()`).
#' @param model_file Character vector of Stan model code lines.
#' @param model_data List of data for Stan; must contain `y_observed` for the
#'   Binomial-type families.
#' @param data_train Training data containing `series`, `time` and, for the
#'   Binomial-type families, the trials variable.
#' @param data_test Optional test data, or `NULL`.
#' @param family_char Name of the observation family as a character string.
#' @return A list with elements `model_file` (updated Stan code lines),
#'   `model_data` (updated Stan data) and `trials` (time x series matrix of
#'   trials, or `NULL` for families that do not use trials).
#' @noRd
add_binomial = function(
  formula,
  model_file,
  model_data,
  data_train,
  data_test,
  family_char
) {
  # Split elements that contain embedded newlines back into one line per
  # element. The connection is closed on exit so repeated calls do not leave
  # unused text connections open
  split_lines <- function(x) {
    con <- textConnection(x)
    on.exit(close(con), add = TRUE)
    readLines(con, n = -1)
  }

  # Replace the line(s) matching `pattern` (the opening line of a two-line
  # Poisson GLM likelihood call) with `replacement`, drop the call's
  # continuation line and re-split into single lines. Returns `model_file`
  # unchanged when the pattern is absent
  replace_glm_call <- function(model_file, pattern, replacement) {
    linenum <- grep(pattern, model_file, fixed = TRUE)
    if (length(linenum) == 0) {
      return(model_file)
    }
    model_file[linenum] <- replacement
    model_file <- model_file[-(linenum + 1)]
    split_lines(model_file)
  }

  # Order trial records by time, then by series
  sort_by_time_series <- function(df) {
    df[order(df$time, df$series), , drop = FALSE]
  }

  # Template lines that are rewritten for each family
  glm_line <- "flat_ys ~ poisson_log_glm(flat_xs,"
  glm_trend_line <-
    "flat_ys ~ poisson_log_glm(append_col(flat_xs, flat_trends),"
  ypred_line <- "ypred[1:n, s] = poisson_log_rng(mus[1:n, s]);"

  # Add trial information if necessary
  if (family_char %in% c('binomial', 'beta_binomial')) {
    # The trials variable is the second argument of the cbind() response,
    # i.e. `n_trials` in cbind(y, n_trials) ~ ...
    resp_terms <- as.character(terms(formula(formula))[[2]])
    is_cbind <- grepl('cbind', resp_terms)
    trial_name <- resp_terms[!is_cbind][2]
    if (!any(is_cbind) || is.na(trial_name)) {
      stop(
        'Binomial-type models require a response of the form cbind(y, trials)',
        call. = FALSE
      )
    }

    # Pull the trials variable from the data and validate
    train_trials <- data_train[[trial_name]]

    if (is.null(train_trials)) {
      stop(
        paste0('variable ', trial_name, ' not found in data'),
        call. = FALSE
      )
    }

    if (any(is.na(train_trials))) {
      stop(
        paste0('variable ', trial_name, ' contains missing values'),
        call. = FALSE
      )
    }

    if (any(is.infinite(train_trials))) {
      stop(
        paste0('variable ', trial_name, ' contains infinite values'),
        call. = FALSE
      )
    }

    # Long-format record of trials per series and time point
    all_trials <- sort_by_time_series(data.frame(
      series = as.numeric(data_train$series),
      time = data_train$time,
      trials = train_trials
    ))

    # Same for data_test
    if (!is.null(data_test)) {
      if (!(trial_name %in% names(data_test))) {
        stop(
          'Number of trials must also be supplied in "newdata" for Binomial models',
          call. = FALSE
        )
      }

      all_trials <- sort_by_time_series(rbind(
        all_trials,
        data.frame(
          series = as.numeric(data_test$series),
          time = data_test$time,
          trials = data_test[[trial_name]]
        )
      ))

      if (
        any(is.na(all_trials$trials)) || any(is.infinite(all_trials$trials))
      ) {
        stop(
          paste0(
            'Missing or infinite values found in ',
            trial_name,
            ' variable'
          ),
          call. = FALSE
        )
      }
    }

    # Time x series matrix of trials, filled one series (column) at a time so
    # that as.vector() flattens it series by series. Assumes the numeric
    # series codes run from 1 to the number of series and that each series
    # has one record per time point
    n_series <- length(unique(all_trials$series))
    trials <- matrix(
      NA,
      nrow = length(unique(all_trials$time)),
      ncol = n_series
    )
    for (i in seq_len(n_series)) {
      trials[, i] <- all_trials$trials[which(all_trials$series == i)]
    }

    # Add trial info to the model data
    flat_trials <- as.vector(trials)
    model_data$flat_trials <- flat_trials
    model_data$flat_trials_train <- flat_trials[which(
      as.vector(model_data$y_observed) == 1
    )]

    # Add trial vectors to the data block
    model_file[grep(
      "int<lower=0> flat_ys[n_nonmissing]; // flattened nonmissing observations",
      model_file,
      fixed = TRUE
    )] <-
      paste0(
        "array[n_nonmissing] int<lower=0> flat_ys; // flattened nonmissing observations\n",
        "array[total_obs] int<lower=0> flat_trials; // flattened trial vector\n",
        "array[n_nonmissing] int<lower=0> flat_trials_train; // flattened nonmissing trial vector\n"
      )
    model_file <- split_lines(model_file)
  } else {
    trials <- NULL
  }

  # Update parameters block
  if (family_char == 'beta_binomial') {
    model_file[grep("vector[num_basis] b_raw;", model_file, fixed = TRUE)] <-
      paste0("vector[num_basis] b_raw;\n", "vector<lower=0>[n_series] phi;")
    model_file <- split_lines(model_file)
  }

  # Update functions block
  if (family_char == 'beta_binomial') {
    # Stan helper that repeats each element of x K times consecutively
    rep_each_fun <- paste0(
      "vector rep_each(vector x, int K) {\n",
      "int N = rows(x);\n",
      "vector[N * K] y;\n",
      "int pos = 1;\n",
      "for (n in 1 : N) {\n",
      "for (k in 1 : K) {\n",
      "y[pos] = x[n];\n",
      "pos += 1;\n",
      "}\n",
      "}\n",
      "return y;\n",
      "}"
    )
    if (any(grepl('functions {', model_file, fixed = TRUE))) {
      model_file[grep('functions {', model_file, fixed = TRUE)] <-
        paste0("functions {\n", rep_each_fun)
    } else {
      # No functions block yet: create one below the header comment
      model_file[grep('Stan model code', model_file)] <-
        paste0(
          '// Stan model code generated by package mvgam\n',
          'functions {\n',
          rep_each_fun,
          "\n}"
        )
    }
  }

  # Update model block
  if (family_char == 'binomial') {
    model_file <- replace_glm_call(
      model_file,
      glm_line,
      "flat_ys ~ binomial(flat_trials_train, inv_logit(flat_xs * b));"
    )
    model_file <- replace_glm_call(
      model_file,
      glm_trend_line,
      "flat_ys ~ binomial(flat_trials_train, inv_logit(append_col(flat_xs, flat_trends) * append_row(b, 1.0)));"
    )
  }

  if (family_char == 'beta_binomial') {
    model_file[grep("model {", model_file, fixed = TRUE)] <-
      paste0(
        "model {\n",
        "// priors for Beta dispersion parameters\n",
        "phi ~ gamma(0.01, 0.01);"
      )
    model_file <- replace_glm_call(
      model_file,
      glm_line,
      paste0(
        "vector[n_nonmissing] flat_phis;\n",
        "flat_phis = rep_each(phi, n)[obs_ind];\n",
        "flat_ys ~ beta_binomial(flat_trials_train, inv_logit(flat_xs * b) .* flat_phis, (1 - inv_logit(flat_xs * b)) .* flat_phis);"
      )
    )
    model_file <- replace_glm_call(
      model_file,
      glm_trend_line,
      paste0(
        "vector[n_nonmissing] flat_phis;\n",
        "flat_phis = rep_each(phi, n)[obs_ind];\n",
        "flat_ys ~ beta_binomial(flat_trials_train, inv_logit(append_col(flat_xs, flat_trends) * append_row(b, 1.0)) .* flat_phis, (1 - inv_logit(append_col(flat_xs, flat_trends) * append_row(b, 1.0))) .* flat_phis);"
      )
    )
  }

  if (family_char == 'bernoulli') {
    # The Bernoulli GLM takes the same arguments, so only the opening line
    # of the call is swapped
    model_file[grep(glm_line, model_file, fixed = TRUE)] <-
      "flat_ys ~ bernoulli_logit_glm(flat_xs,"
    model_file[grep(glm_trend_line, model_file, fixed = TRUE)] <-
      "flat_ys ~ bernoulli_logit_glm(append_col(flat_xs, flat_trends),"
  }

  # Update the generated quantities block
  if (family_char == 'binomial') {
    model_file[grep(ypred_line, model_file, fixed = TRUE)] <-
      "ypred[1:n, s] = binomial_rng(flat_trials[ytimes[1:n, s]], inv_logit(mus[1:n, s]));"
  }

  if (family_char == 'beta_binomial') {
    model_file[grep("vector[total_obs] eta;", model_file, fixed = TRUE)] <-
      paste0("vector[total_obs] eta;\n", "matrix[n, n_series] phi_vec;")

    model_file[grep("eta = X * b;", model_file, fixed = TRUE)] <-
      paste0(
        "eta = X * b;\n",
        "for (s in 1 : n_series) {\n",
        "phi_vec[1 : n, s] = rep_vector(phi[s], n);\n",
        "}"
      )

    model_file[grep(ypred_line, model_file, fixed = TRUE)] <-
      "ypred[1:n, s] = beta_binomial_rng(flat_trials[ytimes[1:n, s]], inv_logit(mus[1:n, s]) .* phi_vec[1:n, s], (1 - inv_logit(mus[1:n, s])) .* phi_vec[1:n, s]);"
  }

  if (family_char == 'bernoulli') {
    model_file[grep(ypred_line, model_file, fixed = TRUE)] <-
      "ypred[1:n, s] = bernoulli_logit_rng(mus[1:n, s]);"
  }

  #### Return ####
  list(
    model_file = model_file,
    model_data = model_data,
    trials = trials
  )
}
# nicholasjclark/mvgam documentation built on April 17, 2025, 9:39 p.m.