R/nesmr_backselect.R

Defines functions nesmr_backselect

#' NESMR backselect algorithm
#'
#' @param mod_list List of initial models to search
#' @param beta_hat Estimated beta coefficients from GWAS
#' @param se_beta_hat Standard errors of beta coefficients from GWAS
#' @param mod_log_liks Optional log-likelihoods of the model. Otherwise, look at model$log_lik and then attempt to call log_py(model)
#' @param Z_true Optional true Z-scores to use for selection
#' @param method Currently only does AIC based select within `aic_cutoff` of the best current model
#' @param aic_cutoff AIC difference cutoff for selecting models
#' @param pvalue_cutoff P-value cutoff for selecting edges
#' @param alpha Significance level for selecting variants
#' @param ... Additional parameters to pass into esmr
#'
#' @return List of models selected by NESMR backselect
#' @export
#'
#' @examples
#' set.seed(13)
#' d <- 4
#' G <- matrix(0, nrow = d, ncol = d)
#' G_low <- lower.tri(G)
#' B_lower <- G_low + 0
#' G[d, 1] <- 0.1
#'
#' B_true <- (G != 0) + 0
#'
#' h2 <- 0.3
#' J <- 5000
#' N <- 30000
#' pi_J <- 0.1
#' alpha <- 5e-8
#' lambda <- qnorm(1 - alpha / 2)
#' dat <- GWASBrewer::sim_mv(
#'   G = G,
#'   N = N,
#'   J = J,
#'   h2 = h2,
#'   pi = pi_J,
#'   sporadic_pleiotropy = TRUE,
#'   est_s = TRUE
#' )
#'
#' Ztrue <- with(dat, beta_marg/se_beta_hat)
#' pval_true <- 2*pnorm(-abs(Ztrue))
#' minp <- apply(pval_true, 1, min)
#' ix <- which(minp < alpha)
#'
#' backselect_results <- nesmr_backselect(
#'   list(full_mod),
#'   beta_hat = dat$beta_hat,
#'   se_beta_hat = dat$s_estimate, Z_true = Ztrue,
#'   aic_cutoff = 10, pvalue_cutoff = 0.05 / sum(B_lower),
#'   alpha = 5e-8)
#'
#' # Print the AIC of the models and indicate the true generating model
#' back_aic <- sapply(backselect_results, function(x) x$aic)
#' min_aic <- which.min(back_aic)
#' true_ix <- which(sapply(backselect_results, function(x) all(x$B_template == B_true)))
#' plot(min(back_aic) - back_aic,
#'     pch = ifelse(1:length(back_aic) == true_ix, 8, 1)
#' )
#' legend("topleft", legend = c("True model"),
#'       pch = c(8, 1))
nesmr_backselect <- function(
    mod_list,
    beta_hat, se_beta_hat,
    variant_ix = NULL,
    method = c("aic", "elbo"),
    aic_cutoff = 2,
    pvalue_cutoff = 0.05,
    alpha = 5e-8,
    ...) {
  method <- match.arg(method)
  n_params <- sapply(mod_list, function(x) { sum(x$B_template) })
  d <- ncol(beta_hat)
  stopifnot(ncol(se_beta_hat) == d)

  if (method == "aic") {
    mod_list <- lapply(mod_list, function(x) {
      if (!"log_lik" %in% names(x)) x$log_lik <- log_py(x)
      if (! "aic" %in% names(x)) x$aic <- -2 * x$log_lik + 2 * sum(x$B_template)

      return(x)
    })
  }

  # This is called log_liks but can be either log_lik or elbo
  mod_obj <- sapply(mod_list, function(x) {
    if (method == "aic") x$log_lik
    else if (method == "elbo") x$elbo
    })

  # Psuedo AIC if we are using ELBO
  mod_aic <- -2 * mod_obj + 2 * n_params

  best_mod_aic <- min(mod_aic)

  if (is.null(variant_ix)) {
    Z <- beta_hat/se_beta_hat
    select_pval <- 2*pnorm(-abs(Z))
    minp <- apply(select_pval, 1, min)
    variant_ix <- which(minp < alpha)
  }

  # Use a queue to simulate breadth first search
  # First we remove all possible edges from the starting model and check stopping criterion
  # Then, we look at all possible edges from each of these subsets
  queue <- rstackdeque::rpqueue()
  for (mod in mod_list) {
    queue <- queue %>% rstackdeque::insert_back(mod)
  }

  visited <- list()
  return_mods <- mod_list

  while(! rstackdeque::empty(queue)) {
    curr_mod <- rstackdeque::peek_front(queue)
    queue <- rstackdeque::without_front(queue)

    B_template <- curr_mod$B_template

    # Do this to get only the maximum non-zero p-value
    edge_ix <- which(B_template != 0, arr.ind = TRUE)

    log_pvalues <- curr_mod$pvals_dm[edge_ix]
    # Filter out the ones < log(pvalue_cutoff)
    pvalue_order <- order(log_pvalues, decreasing = TRUE)
    log_pvalues <- log_pvalues[pvalue_order]
    edge_ix <- edge_ix[pvalue_order,, drop = FALSE]
    non_sig <- log_pvalues > log(pvalue_cutoff)
    n_non_sig <- sum(non_sig)

    # If all edges are significant, skip this model
    # The model will already be in the return_mods list from the previous iteration
    if (n_non_sig <= 0 || all(! non_sig)) {
      next
    }

    for (i in seq_len(n_non_sig)) {
      curr_edge <- edge_ix[i,, drop = FALSE]
      # Remove the edge with the highest p-value
      B_template[curr_edge] <- 0
      if (sum(B_template) == 0) {
        next
      }

      # Check if we have already visited this configuration
      # If so, skip all checks as it is either in the model list or the queue
      B_template_chr <- paste(B_template, collapse = "")
      if (B_template_chr %in% visited) {
        next
      } else {
        visited <- append(visited, B_template_chr)
      }

      # Fit the new model
      # Currently use the same variants across all models but this could change
      new_mod <- esmr(
        beta_hat_X = beta_hat,
        se_X = se_beta_hat,
        variant_ix = variant_ix,
        G = diag(d),
        direct_effect_template = B_template,
        direct_effect_init = B_template * mod$direct_effects,
        ...
      )

      new_mod$num_params <- sum(B_template)
      # TODO: Eventually want to do something faster than computing likelihood each time
      if (method == "aic") {
        new_mod$log_lik <- log_py(new_mod)
        new_mod$aic <- -2 * new_mod$log_lik + 2 * new_mod$num_params
        obj_value <- new_mod$aic
      } else {
        obj_value <- -2 * new_mod$elbo + 2 * new_mod$num_params
      }


      # TODO: Generalize this to "stopping criterion" function
      # Either aic difference, or posterior probability with different priors
      if (abs(best_mod_aic - obj_value) <= aic_cutoff) {
        best_mod_aic <- min(best_mod_aic, obj_value)
        all_B_strings <- sapply(return_mods, function(x) paste(x$B_template, collapse = ""))
        if (B_template_chr %in% all_B_strings) {
          warning("Duplicate element...")
          next
        } else {
          return_mods <- append(return_mods, list(new_mod))
        }
        queue <- queue %>% rstackdeque::insert_back(new_mod)
      } else {
        # TODO: Can we stop looking at edges with lower likelihood?
        # I think we can make the claim that if p-value for an edge is higher then
        # the likelihood will be lower but this is not guaranteed to be true
        # for all subsets of the graph
        # break
      }

      # Put the edge back
      B_template[curr_edge] <- 1
    }
  }

  # TODO: Figure out which condition causes the duplicate
#  remove_dups <- !duplicated(sapply(return_mods, function(x) x$B_template))
  return(return_mods)
}
jean997/mrScan documentation built on Dec. 20, 2024, 3:39 a.m.