R/full_em_algo_functs.R

Defines functions run_full_glmeiv_given_pilot_params run_full_glmeiv_given_fitted_means compute_tolerance extract_glm_fitted_means run_m_step_full augment_inputs run_full_glmeiv_given_weights run_e_step

Documented in run_e_step run_full_glmeiv_given_pilot_params run_full_glmeiv_given_weights

#' Run E step
#'
#' Runs the E step of the GLM-EIV EM algorithm. Also returns the updated value of pi (thereby running a portion of the subsequent M step).
#'
#' Ensures that the value of pi in the next M iteration will be at most 0.5: if the average membership probability exceeds 0.5, the membership probabilities (and pi) are replaced by their complements. This core function is called in all variants of GLM-EIV (weights/estimates first, full/reduced, etc.).
#'
#' Per-observation likelihoods that are exactly zero (e.g., because of numerical underflow) are replaced by the smallest nonzero per-observation likelihood before logs are taken. If every per-observation likelihood is zero, no such replacement value exists and the returned log-likelihood is -Inf.
#'
#' @param m_fam family object describing mRNA counts; must provide `log_py_given_mu` and `d_log_py` (only used when `use_mrna_modality` is TRUE)
#' @param g_fam family object describing gRNA counts; must provide `log_py_given_mu` and `d_log_py`
#' @param m mRNA counts
#' @param g gRNA counts
#' @param m_mus_pert0 fitted (or hypothesized) means mu^m(0)
#' @param m_mus_pert1 fitted (or hypothesized) means mu^m(1)
#' @param g_mus_pert0 fitted (or hypothesized) means mu^g(0)
#' @param g_mus_pert1 fitted (or hypothesized) means mu^g(1)
#' @param fit_pi fitted (or hypothesized) value for pi
#' @param use_mrna_modality if TRUE (default), the mRNA modality contributes to the log-likelihood and to the membership probabilities; if FALSE, only the gRNA modality is used.
#'
#' @return a list containing (i) the membership probabilities (Ti1s), (ii) the model log-likelihood (log_lik), and (iii) the new value of pi (new_pi; computed ahead of the subsequent M step for convenience).
run_e_step <- function(m_fam, g_fam, m, g, m_mus_pert0, m_mus_pert1, g_mus_pert0, g_mus_pert1, fit_pi, use_mrna_modality = TRUE) {
  # first, compute the per-observation mixture likelihoods
  p0 <- exp(log(1 - fit_pi) +
              g_fam$log_py_given_mu(g, g_mus_pert0) +
              (if (use_mrna_modality) m_fam$log_py_given_mu(m, m_mus_pert0) else 0))
  p1 <- exp(log(fit_pi) +
              g_fam$log_py_given_mu(g, g_mus_pert1) +
              (if (use_mrna_modality) m_fam$log_py_given_mu(m, m_mus_pert1) else 0))
  s <- p0 + p1

  # patch exact zeros with the smallest nonzero likelihood; when every entry is
  # zero there is nothing to patch with (min() of an empty vector would warn and
  # return Inf), so s is left as is and the log-likelihood below becomes -Inf
  zero_idx <- which(s == 0)
  if (length(zero_idx) > 0L && length(zero_idx) < length(s)) {
    s[zero_idx] <- min(s[-zero_idx])
  }
  log_lik <- sum(log(s))

  # second, compute membership probabilities; quotient is the log-odds of the
  # unperturbed versus the perturbed component
  quotient <- log(1 - fit_pi) - log(fit_pi) +
    g_fam$d_log_py(g, g_mus_pert0, g_mus_pert1) +
    (if (use_mrna_modality) m_fam$d_log_py(m, m_mus_pert0, m_mus_pert1) else 0)
  Ti1s <- 1/(exp(quotient) + 1)

  # estimate new_pi and ensure new_pi does not exceed 0.5
  new_pi <- sum(Ti1s)/length(Ti1s)
  if (new_pi > 0.5) {
    Ti1s <- 1 - Ti1s
    new_pi <- 1 - new_pi
  }

  # return quantities
  return(list(Ti1s = Ti1s, log_lik = log_lik, new_pi = new_pi))
}


#' Run full GLM-EIV given weights
#'
#' Runs the full GLM-EIV model given starting weights (i.e., starting membership probabilities). Starting from these weights, the function alternates M steps and E steps (in that order) until the relative change in the log-likelihood falls below `ep_tol` or `max_it` iterations have been run.
#'
#' @param m mRNA counts
#' @param g gRNA counts
#' @param m_fam family describing m
#' @param g_fam family describing g
#' @param covariate_matrix the matrix of covariates; NULL if there are no covariates.
#' @param initial_Ti1s starting membership probabilities; these probabilities should be such that sum(initial_Ti1s)/length(initial_Ti1s) < 0.5.
#' @param prev_log_lik optional starting log-likelihood value (useful if an E step has been called ahead of running this function).
#' @param m_offset offsets for m
#' @param g_offset offsets for g
#' @param ep_tol (optional) EM convergence threshold
#' @param max_it  (optional) maximum number of EM iterations; at least one iteration is always run.
#' @param use_mrna_modality (optional) if TRUE (default), both the mRNA and gRNA modalities are used; if FALSE, the mRNA model is not fitted and only the gRNA modality is used.
#'
#' @return a fitted GLM-EIV object: a list with entries fit_m, fit_g, fit_pi (from the final M step), n_iterations (number of EM iterations run), log_liks (log-likelihood after each iteration, preceded by `prev_log_lik` if one was supplied), log_lik (final log-likelihood), converged, n, and posterior_perturbation_probs (membership probabilities from the final E step).
#' @export
#'
#' @examples
#' m_fam <- g_fam <- augment_family_object(poisson())
#' n <- 5000
#' lib_size <- rpois(n = n, lambda = 5000)
#' m_offset <- g_offset <- log(lib_size)
#' pi <- 0.1
#' m_intercept <- log(0.05)
#' m_perturbation <- log(0.8)
#' g_intercept <- log(0.025)
#' g_perturbation <- log(1.2)
#' covariate_matrix <- data.frame(batch = rbinom(n = n, size = 1, prob = 0.5))
#' m_covariate_coefs <- log(0.9)
#' g_covariate_coefs <- log(1.1)
#' dat <- generate_full_data(m_fam = m_fam, m_intercept = m_intercept,
#' m_perturbation = m_perturbation, g_fam = g_fam, g_intercept = g_intercept,
#' g_perturbation = g_perturbation, pi = pi, n = n, B = 2,
#' covariate_matrix = covariate_matrix, m_covariate_coefs = m_covariate_coefs,
#' g_covariate_coefs = g_covariate_coefs, m_offset = m_offset, g_offset = g_offset)[[1]]
#' m <- dat$m
#' g <- dat$g
#' initial_Ti1s <- runif(n)
#' fit <- run_full_glmeiv_given_weights(m, g, m_fam, g_fam, covariate_matrix,
#' initial_Ti1s, m_offset, g_offset)
run_full_glmeiv_given_weights <- function(m, g, m_fam, g_fam, covariate_matrix, initial_Ti1s, m_offset, g_offset, prev_log_lik = -Inf, ep_tol = 1e-4, max_it = 15, use_mrna_modality = TRUE) {
  # verify column names ok
  check_col_names(covariate_matrix)

  # define some basic quantities
  n <- length(g)
  iteration <- 1L # index of the EM iteration currently being run
  converged <- FALSE
  curr_Ti1s <- initial_Ti1s
  # -Inf is the "no previous log-likelihood" sentinel and is not recorded
  log_liks <- if (prev_log_lik == -Inf) numeric() else prev_log_lik
  augmented_inputs <- augment_inputs(covariate_matrix, m, g, m_offset, g_offset, n)

  # iterate through M and E steps (in that order) until convergence
  repeat {
    # M step
    m_step <- run_m_step_full(curr_Ti1s = curr_Ti1s,
                              m_augmented = augmented_inputs$m_augmented,
                              m_fam = m_fam,
                              m_offset_augmented = augmented_inputs$m_offset_augmented,
                              g_augmented = augmented_inputs$g_augmented,
                              g_fam = g_fam,
                              g_offset_augmented = augmented_inputs$g_offset_augmented,
                              Xtilde_augmented = augmented_inputs$Xtilde_augmented, n = n,
                              use_mrna_modality = use_mrna_modality)
    # E step
    e_step <- run_e_step(m_fam = m_fam, g_fam = g_fam, m = m, g = g,
                         m_mus_pert0 = m_step$m_mus_pert0,
                         m_mus_pert1 = m_step$m_mus_pert1,
                         g_mus_pert0 = m_step$g_mus_pert0,
                         g_mus_pert1 = m_step$g_mus_pert1,
                         fit_pi = m_step$fit_pi,
                         use_mrna_modality = use_mrna_modality)
    # append current log-likelihood, check for convergence
    curr_Ti1s <- e_step$Ti1s
    curr_log_lik <- e_step$log_lik
    log_liks <- c(log_liks, curr_log_lik)
    tol <- compute_tolerance(curr_log_lik, prev_log_lik)
    if (tol < ep_tol) {
      converged <- TRUE
      break
    }
    # stop once max_it iterations have actually been run; checking before the
    # increment keeps `iteration` equal to the number of iterations performed
    if (iteration >= max_it) break
    prev_log_lik <- curr_log_lik
    iteration <- iteration + 1L
  }

  out <- list(fit_m = m_step$fit_m, fit_g = m_step$fit_g, fit_pi = m_step$fit_pi,
              n_iterations = iteration, log_liks = log_liks, log_lik = curr_log_lik,
              converged = converged, n = n, posterior_perturbation_probs = curr_Ti1s)
  return(out)
}


# Build the "augmented" (stacked) inputs used by the weighted GLM fits.
# Every observation appears twice: rows 1..n carry perturbation = 0 and rows
# (n + 1)..2n carry perturbation = 1. Responses and offsets are duplicated to
# match; a NULL offset stays NULL. The perturbation indicator is placed in the
# first column of the design data frame, ahead of any covariates.
augment_inputs <- function(covariate_matrix, m, g, m_offset, g_offset, n) {
  # c(x, x), passing NULL through unchanged
  stack_twice <- function(x) if (is.null(x)) NULL else c(x, x)

  if (!is.null(covariate_matrix)) {
    unperturbed_rows <- dplyr::mutate(covariate_matrix, perturbation = 0)
    perturbed_rows <- dplyr::mutate(covariate_matrix, perturbation = 1)
    stacked_rows <- rbind(unperturbed_rows, perturbed_rows)
    Xtilde_augmented <- dplyr::select(stacked_rows, perturbation, everything())
  } else {
    # no covariates: the design consists of the perturbation indicator alone
    Xtilde_augmented <- data.frame(perturbation = rep(c(0, 1), each = n))
  }

  list(Xtilde_augmented = Xtilde_augmented,
       m_augmented = stack_twice(m),
       g_augmented = stack_twice(g),
       m_offset_augmented = stack_twice(m_offset),
       g_offset_augmented = stack_twice(g_offset))
}


# Run the M step of the full GLM-EIV EM algorithm.
#
# Given the current membership probabilities, updates pi and refits the mRNA
# and gRNA GLMs on the augmented (stacked) data, where each observation appears
# once with perturbation = 0 and once with perturbation = 1.
#
# Arguments:
#   curr_Ti1s: current membership probabilities (length n)
#   m_augmented, g_augmented: stacked responses (length 2n)
#   m_fam, g_fam: family objects for the mRNA and gRNA models
#   m_offset_augmented, g_offset_augmented: stacked offsets (or NULL)
#   Xtilde_augmented: stacked design data frame (perturbation plus covariates)
#   n: number of (un-stacked) observations
#   use_mrna_modality: if FALSE, the mRNA GLM is not fitted
#
# Returns a list with the fitted GLM objects (fit_m, fit_g), the updated pi
# (fit_pi), and the fitted means under perturbation = 0 / 1 for each modality.
# When use_mrna_modality is FALSE, fit_m and the m_mus_* entries are NULL.
#
# NOTE: the responses are referenced by name inside the formula strings
# ("m_augmented ~ .", "g_augmented ~ ."), so the argument names m_augmented and
# g_augmented must stay in sync with those strings.
run_m_step_full <- function(curr_Ti1s, m_augmented, m_fam, m_offset_augmented, g_augmented, g_fam, g_offset_augmented, Xtilde_augmented, n, use_mrna_modality = TRUE) {
  # updated pi: average membership probability
  fit_pi <- sum(curr_Ti1s)/n
  # fit models for m and g using weights; the first n entries weight the
  # perturbation = 0 copies, the last n entries weight the perturbation = 1 copies
  weights <- c(1 - curr_Ti1s, curr_Ti1s)

  if (use_mrna_modality) {
    fit_m <- stats::glm(formula = stats::formula("m_augmented ~ ."), data = Xtilde_augmented,
                        family = m_fam, weights = weights, offset = m_offset_augmented)
    fitted_means_m <- extract_glm_fitted_means(fit_m, m_fam, n)
  } else {
    # mRNA modality disabled: no fit and no fitted means
    fit_m <- fitted_means_m <- NULL
  }

  fit_g <- stats::glm(formula = stats::formula("g_augmented ~ ."), data = Xtilde_augmented,
                      family = g_fam, weights = weights, offset = g_offset_augmented)
  fitted_means_g <- extract_glm_fitted_means(fit_g, g_fam, n)

  # return a list of the fitted means, as well as the fitted GLM objects themselves
  # (indexing NULL with $ yields NULL, so the m_mus_* entries are NULL when fit_m is NULL)
  out <- list(fit_m = fit_m, fit_g = fit_g, fit_pi = fit_pi,
              m_mus_pert0 = fitted_means_m$mus_pert0, m_mus_pert1 = fitted_means_m$mus_pert1,
              g_mus_pert0 = fitted_means_g$mus_pert0, g_mus_pert1 = fitted_means_g$mus_pert1)
  return(out)
}


# Extract the fitted means from a GLM fitted on the augmented (stacked) data.
#
# Arguments:
#   fit: fitted GLM object; only its `linear.predictors` entry is used
#   fam: family object whose `linkinv` maps linear predictors to means
#   n:   number of (un-stacked) observations
#
# Returns a list with `mus_pert0` (means for entries 1..n, the perturbation = 0
# copies) and `mus_pert1` (means for entries (n + 1)..2n, the perturbation = 1
# copies).
extract_glm_fitted_means <- function(fit, fam, n) {
  mus <- fam$linkinv(as.numeric(fit$linear.predictors))
  # seq_len() yields an empty index when n is 0, whereas seq(1, n) would count
  # backwards (1, 0) and select a spurious element
  first_half <- seq_len(n)
  list(mus_pert0 = mus[first_half], mus_pert1 = mus[n + first_half])
}


# Relative change between two successive log-likelihood values, used as the EM
# convergence criterion: |curr - prev| / min(|curr|, |prev|).
#
# Returns Inf ("not converged") when the current log-likelihood is infinite,
# and 0 when the two values are identical. Both cases would otherwise produce
# NaN (Inf/Inf or 0/0), which cannot be used in the caller's `if` comparison.
compute_tolerance <- function(curr_log_lik, prev_log_lik) {
  if (is.infinite(curr_log_lik)) {
    return(Inf)
  }
  abs_change <- abs(curr_log_lik - prev_log_lik)
  if (abs_change == 0) {
    return(0)
  }
  abs_change / min(abs(curr_log_lik), abs(prev_log_lik))
}


# Run full GLM-EIV starting from fitted (or hypothesized) conditional means and
# a value for pi. A single E step turns these into membership probabilities and
# a log-likelihood, which then seed the weights-first EM routine.
run_full_glmeiv_given_fitted_means <- function(m_fam, g_fam, m, g, m_mus_pert0, m_mus_pert1, g_mus_pert0, g_mus_pert1, fit_pi, covariate_matrix, m_offset, g_offset, ep_tol = 1e-4, max_it = 75, use_mrna_modality = TRUE) {
  # initial E step on the supplied means
  initial_e_step <- run_e_step(
    m_fam = m_fam, g_fam = g_fam,
    m = m, g = g,
    m_mus_pert0 = m_mus_pert0, m_mus_pert1 = m_mus_pert1,
    g_mus_pert0 = g_mus_pert0, g_mus_pert1 = g_mus_pert1,
    fit_pi = fit_pi,
    use_mrna_modality = use_mrna_modality
  )

  # continue with EM, using the E-step output as the starting point
  run_full_glmeiv_given_weights(
    m = m, g = g,
    m_fam = m_fam, g_fam = g_fam,
    covariate_matrix = covariate_matrix,
    initial_Ti1s = initial_e_step$Ti1s,
    m_offset = m_offset, g_offset = g_offset,
    prev_log_lik = initial_e_step$log_lik,
    ep_tol = ep_tol, max_it = max_it,
    use_mrna_modality = use_mrna_modality
  )
}


#' Run full GLM-EIV given pilot estimates
#'
#' Converts pilot parameter guesses into conditional means (under perturbation = 0 and perturbation = 1) for each modality, then runs the full GLM-EIV EM algorithm starting from those means and the pilot guess for pi.
#'
#' @param pi_guess pilot guess for pi
#' @param m_intercept_guess pilot guess for m_intercept
#' @param m_perturbation_guess pilot guess for m_perturbation
#' @param m_covariate_coefs_guess pilot guess for m_covariate_coefficients
#' @param g_intercept_guess pilot guess for g_intercept
#' @param g_perturbation_guess pilot guess for g_perturbation
#' @param g_covariate_coefs_guess pilot guess for g_covariate_coefs
#' @param use_mrna_modality (optional) if TRUE (default), both modalities are used; if FALSE, the mRNA conditional means are not computed and only the gRNA modality is used.
#' @return the fitted GLM-EIV object.
#' @export
#'
#' @inheritParams run_full_glmeiv_given_weights
#' @examples
#' set.seed(4)
#' m_fam <- g_fam <- augment_family_object(poisson())
#' n <- 5000
#' lib_size <- rpois(n = n, lambda = 5000)
#' m_offsets <- g_offsets <- log(lib_size)
#' pi <- 0.3
#' m_intercept <- log(0.05)
#' m_perturbation <- log(0.75)
#' g_intercept <- log(0.025)
#' g_perturbation <- log(1.4)
#' covariate_matrix <- data.frame(batch = rbinom(n = n, size = 1, prob = 0.5))
#' m_covariate_coefs <- log(0.9)
#' g_covariate_coefs <- log(1.1)
#' dat <- generate_full_data(m_fam = m_fam, m_intercept = m_intercept,
#' m_perturbation = m_perturbation, g_fam = g_fam, g_intercept = g_intercept,
#' g_perturbation = g_perturbation, pi = pi, n = n, B = 2,
#' covariate_matrix = covariate_matrix, m_covariate_coefs = m_covariate_coefs,
#' g_covariate_coefs = g_covariate_coefs, m_offset = m_offsets, g_offset = g_offsets)[[1]]
#' m <- dat$m
#' g <- dat$g
#' fit <- run_full_glmeiv_given_pilot_params(m = m, g = g, m_fam = m_fam, g_fam = g_fam,
#' pi_guess = 0.15, m_intercept_guess = log(0.1), m_perturbation_guess = log(1),
#' m_covariate_coefs_guess = log(1.4), g_intercept_guess = log(0.05),
#' g_perturbation_guess = log(1.4), g_covariate_coefs_guess = log(1.2),
#' covariate_matrix = covariate_matrix, m_offset = m_offsets, g_offset = g_offsets)
run_full_glmeiv_given_pilot_params <- function(m, g, m_fam, g_fam, pi_guess, m_intercept_guess, m_perturbation_guess, m_covariate_coefs_guess, g_intercept_guess, g_perturbation_guess, g_covariate_coefs_guess, covariate_matrix, m_offset, g_offset, ep_tol = 1e-5, max_it = 75, use_mrna_modality = TRUE) {
  # conditional means implied by one modality's pilot coefficients
  pilot_means <- function(intercept, perturbation_coef, covariate_coefs, fam, offset) {
    compute_theoretical_conditional_means(intercept = intercept,
                                          perturbation_coef = perturbation_coef,
                                          fam = fam,
                                          covariate_matrix = covariate_matrix,
                                          covariate_coefs = covariate_coefs,
                                          offset = offset)
  }

  # mRNA means are only needed when the mRNA modality is in use
  m_means <- if (use_mrna_modality) {
    pilot_means(m_intercept_guess, m_perturbation_guess, m_covariate_coefs_guess, m_fam, m_offset)
  } else {
    list(mu0 = NULL, mu1 = NULL)
  }
  g_means <- pilot_means(g_intercept_guess, g_perturbation_guess, g_covariate_coefs_guess, g_fam, g_offset)

  # run glmeiv given fitted means
  run_full_glmeiv_given_fitted_means(
    m_fam = m_fam, g_fam = g_fam,
    m = m, g = g,
    m_mus_pert0 = m_means$mu0, m_mus_pert1 = m_means$mu1,
    g_mus_pert0 = g_means$mu0, g_mus_pert1 = g_means$mu1,
    fit_pi = pi_guess,
    covariate_matrix = covariate_matrix,
    m_offset = m_offset, g_offset = g_offset,
    ep_tol = ep_tol, max_it = max_it,
    use_mrna_modality = use_mrna_modality
  )
}
timothy-barry/glmeiv documentation built on Jan. 30, 2024, 3:46 p.m.