R/bayesian_multivariate_regression.R

Defines functions multivariate_lbf multivariate_regression

Documented in multivariate_lbf multivariate_regression

#' @title Multiviate regression object
#' @importFrom R6 R6Class
#' @importFrom MASS ginv
#' @keywords internal
BayesianMultivariateRegression <- R6Class("BayesianMultivariateRegression",
  inherit = BayesianSimpleRegression,
  public = list(
    initialize = function(J, residual_variance, prior_variance) {
      private$J = J
      private$.prior_variance = prior_variance
      private$.residual_variance = residual_variance
      private$.posterior_b1 = matrix(0, J, nrow(prior_variance))
      private$prior_variance_scale = 1
      tryCatch({
        private$.residual_variance_inv = invert_via_chol(residual_variance)
      }, error = function(e) {
        warning(paste0('Cannot compute inverse for residual variance due to error:\n', e, '\nELBO computation will thus be skipped.'))
      })
    },
    fit = function(d, prior_weights = NULL, use_residual = FALSE, save_summary_stats = FALSE, save_var = FALSE, estimate_prior_variance_method = NULL) {
      # d: data object
      # use_residual: fit with residual instead of with Y,
      # a special feature for when used with SuSiE algorithm
      if (d$Y_has_missing) stop("Cannot work with missing data in Bayesian Multivariate Regression module.")
      if (use_residual) XtY = d$XtR
      else XtY = d$XtY
      # OLS estimates
      # bhat is J by R
      bhat = XtY / d$X2_sum
      bhat[which(is.nan(bhat))] = 0
      if (d$Y_has_missing) sbhat2 = lapply(1:nrow(d$X2_sum), function(j) private$.residual_variance / d$X2_sum[j,])
      else sbhat2 = lapply(1:length(d$X2_sum), function(j) private$.residual_variance / d$X2_sum[j])
      for (j in 1:length(sbhat2)) {
        sbhat2[[j]][which(is.nan(sbhat2[[j]]) | is.infinite(sbhat2[[j]]))] = 1E6
      }
      if (save_summary_stats) {
        private$.bhat = bhat
        private$.sbhat = sqrt(do.call(rbind, lapply(1:length(sbhat2), function(j) diag(sbhat2[[j]]))))
      }
      if (d$Y_has_missing) stop("Computation involving missing data in Y has not been implemented in BayesianMultivariateRegression method.")
      # deal with prior variance: can be "estimated" across effects
      if(!is.null(estimate_prior_variance_method)) {
        if (estimate_prior_variance_method == "EM") {
          private$cache = list(b=bhat, s=sbhat2, update_scale=T)
        } else {
          private$prior_variance_scale = private$estimate_prior_variance(bhat,sbhat2,prior_weights,method=estimate_prior_variance_method)
        }
      }
      # posterior
      post = multivariate_regression(bhat, sbhat2, private$.prior_variance * private$prior_variance_scale)
      private$.posterior_b1 = post$b1
      private$.posterior_b2 = post$b2
      if (save_var) private$.posterior_variance = post$cov
      private$.lbf = post$lbf
    }
  ),
  active = list(
    residual_variance_inv = function() private$.residual_variance_inv,
    residual_variance = function(v) {
      if (missing(v)) private$.residual_variance
      else {
        private$.residual_variance = v
        private$.residual_variance_inv = invert_via_chol(v)
      }
    },
    prior_variance = function() private$prior_variance_scale
  ),
  private = list(
    .residual_variance_inv = NULL,
    .prior_variance_inv = NULL,
    prior_variance_scale = NULL,
    loglik = function(scalar, bhat, S, prior_weights) {
      U = private$.prior_variance * scalar
      lbf = multivariate_lbf(bhat, S, U)
      return(compute_softmax(lbf, prior_weights)$log_sum)
    },
    estimate_prior_variance_optim = function(betahat, shat2, prior_weights, ...) {
      # log(1) = 0
      lV = optim(par=0, fn=private$neg_loglik_logscale, betahat=betahat, shat2=shat2, prior_weights = prior_weights, ...)$par
      return(exp(lV))
    },
    estimate_prior_variance_em = function(pip) {
      if (length(dim(private$.posterior_b2)) == 3) {
        # when R > 1
        mu2 = Reduce("+", lapply(1:length(pip), function(j) pip[j] * private$.posterior_b2[,,j]))
      } else {
        # when R = 1 each post_b2 is a scalar.
        # Now make it a matrix to be compatable with later computations.
        if (ncol(private$.posterior_b2) != 1) stop("Data dimension is incorrect for posterior_b2")
        mu2 = matrix(sum(pip * private$.posterior_b2[,1]), 1,1)
      }
      if (is.null(private$.prior_variance_inv)) private$.prior_variance_inv = ginv(private$.prior_variance)
      V = sum(diag(private$.prior_variance_inv %*% mu2)) / nrow(private$.prior_variance)
      return(V)
    },
    estimate_prior_variance_simple = function() 1
  )
)

#' @title Multiviate regression calculations
#' @importFrom abind abind
#' @keywords internal
multivariate_regression = function(bhat, S, U) {
  # FIXME: this can be pre-computed to save some computations
  S_inv = lapply(1:length(S), function(j) invert_via_chol(S[[j]]))
  post_cov = lapply(1:length(S), function(j) U %*% solve(diag(nrow(U)) + S_inv[[j]] %*% U))
  lbf = sapply(1:length(S), function(j) 0.5 * (log(det(S[[j]])) - log(det(S[[j]]+U))) + 0.5*t(bhat[j,])%*%S_inv[[j]]%*%post_cov[[j]]%*%S_inv[[j]]%*%bhat[j,])
  lbf[which(is.nan(lbf))] = 0
  # lbf = multivariate_lbf(bhat, S, U)
  # using rbind here will end up with dimension issues for degenerated case on J; have to use t(...(cbind, )) instead
  post_b1 = t(do.call(cbind, lapply(1:length(S), function(j) post_cov[[j]] %*% (S_inv[[j]] %*% bhat[j,]))))
  post_b2 = lapply(1:length(post_cov), function(j) tcrossprod(post_b1[j,]) + post_cov[[j]])
  # deal with degerated case with 1 condition
  if (ncol(post_b1) == 1) {
    post_b2 = matrix(unlist(post_b2), length(post_b2), 1)
  } else {
    post_b2 = aperm(abind(post_b2, along = 3), c(2,1,3))
  }
  return(list(b1 = post_b1, b2 = post_b2, lbf = lbf, cov = post_cov))
}

#' @title Multiviate logBF
#' @importFrom mvtnorm dmvnorm
#' @keywords internal
multivariate_lbf = function(bhat, S, U) {
  lbf = sapply(1:length(S), function(j) dmvnorm(x = bhat[j,],sigma = S[[j]] + U,log = T) - dmvnorm(x = bhat[j,],sigma = S[[j]],log = T))
  lbf[which(is.nan(lbf))] = 0
  return(lbf)
}
gaow/mmbr documentation built on March 25, 2020, 4:26 p.m.