R/moretrees_init_W_logistic.R

Defines functions moretrees_init_W_logistic

Documented in moretrees_init_W_logistic

# --------------------------------------------------------------------------------- #
# -------------------- moretrees initial values function -------------------------- #
# --------------------------------------------------------------------------------- #

#' Initialise covariate (W) parameters for logistic MOReTreeS models
#'
#'   \code{moretrees_init_W_logistic} takes a set of starting values for the
#'   sparse node coefficients and fills in starting values for the non-sparse
#'   (covariate) node coefficients: \code{delta}, \code{Omega},
#'   \code{Omega_inv} and \code{Omega_det}. It then calls
#'   \code{update_hyperparams_logistic_moretrees()} with hyperparameter
#'   updating switched off (to compute the initial ELBO).
#'
#' @export
#' @useDynLib moretrees
#'
#' @section Model Description:
#' Describe MOReTreeS model and all parameters here.
#'
#' @param X Exposure matrix with one row per unit; its number of columns is
#' used as K.
#' @param W Covariate matrix with one row per unit; its number of columns is
#' used as m.
#' @param y Vector of observed responses with one entry per unit; its length
#' is used as n.
#' @param A Currently unused by this function; retained so that existing
#' callers do not need to change.
#' @param initial_values A list with two elements:
#' vi_params (list of variational parameters; prob and mu are read here, and
#' Sigma, Sigma_det, tau_t, a_rho and b_rho are passed on unchanged)
#' hyperparams (list of hyperparameters; eta and g_eta are read here, as is
#' tau when update_hyper = TRUE)
#' @param outcomes_units List with one element per leaf; element u holds the
#' row indices (into X, W and y) of the units belonging to leaf u.
#' @param outcomes_nodes List with one element per node; element v holds the
#' indices of the leaf descendants of node v.
#' @param ancestors List with one element per leaf; element u holds the
#' indices of the nodes that are ancestors of leaf u.
#' @param xxT Computed from exposure matrix X. Currently unused by this
#' function.
#' @param wwT Computed from covariate matrix W.
#' @param update_hyper Update hyperparameters? If TRUE, the starting value of
#' omega is set equal to the starting value of tau. If FALSE, tau and omega
#' are taken from hyper_fixed.
#' @param hyper_fixed Fixed values of hyperparameters to use if
#' update_hyper = FALSE. This should be a list including the following
#' elements:
#' tau (prior variance for sparse node coefficients)
#' omega (prior variance for non-sparse node coefficients)
#' @return A list containing starting values, with elements vi_params and
#' hyperparams.
#' @family MOReTreeS functions

moretrees_init_W_logistic <- function(X, W, y, A,
                                initial_values,
                                outcomes_units,
                                outcomes_nodes,
                                ancestors,
                                xxT, wwT,
                                update_hyper,
                                hyper_fixed) {

  n <- length(y)
  m <- ncol(W)
  p <- length(unique(unlist(ancestors)))
  pL <- length(ancestors)
  K <- ncol(X)
  vi_params <- initial_values$vi_params
  hyperparams <- initial_values$hyperparams

  # Get starting value for omega -----------------------------------
  if (update_hyper) {
    hyperparams$omega <- hyperparams$tau
  } else {
    # Assigning NULL to a list element deletes it, so fail early with a
    # clear message rather than silently dropping tau / omega.
    if (is.null(hyper_fixed$tau) || is.null(hyper_fixed$omega)) {
      stop("hyper_fixed must contain 'tau' and 'omega' ",
           "when update_hyper = FALSE", call. = FALSE)
    }
    hyperparams$tau <- hyper_fixed$tau
    hyperparams$omega <- hyper_fixed$omega
  }

  # Get starting values for delta ----------------------------------
  zero_m <- matrix(0, nrow = m, ncol = 1)
  vi_params$delta <- lapply(seq_len(p), function(i) zero_m)

  # Contribution of the sparse coefficients to the linear predictor:
  # each leaf's coefficient is the sum of prob * mu over its ancestors.
  xi <- mapply(`*`, vi_params$prob, vi_params$mu, SIMPLIFY = FALSE)
  Xbeta <- numeric(n)
  for (u in seq_len(pL)) {
    units_u <- outcomes_units[[u]]
    beta_u <- Reduce(`+`, xi[ancestors[[u]]])
    # drop = FALSE keeps X a matrix even for a single unit or K = 1
    Xbeta[units_u] <- X[units_u, , drop = FALSE] %*% beta_u
  }

  g_eta <- hyperparams$g_eta
  wwT_g_eta <- lapply(X = outcomes_units, FUN = xxT_g_eta_fun,
                      xxT = wwT, g_eta = g_eta, K = m)

  # Nodes are updated one at a time; each update uses the current values
  # of delta at the other ancestors of each leaf below the node.
  for (v in seq_len(p)) {
    leaf_descendants <- outcomes_nodes[[v]]
    vi_params$Omega_inv[[v]] <-
      2 * Reduce(`+`, wwT_g_eta[leaf_descendants]) +
      diag(1 / hyperparams$omega, nrow = m)
    vi_params$Omega[[v]] <- solve(vi_params$Omega_inv[[v]])
    vi_params$Omega_det[v] <- det(vi_params$Omega[[v]])

    delta_v <- zero_m
    for (u in leaf_descendants) {
      anc_u_mv <- setdiff(ancestors[[u]], v)
      units_u <- outcomes_units[[u]]
      W_u <- W[units_u, , drop = FALSE]
      # The zero init keeps this an m x 1 matrix when v is the only
      # ancestor of u (Reduce() on an empty list would return NULL).
      theta_u_mv <- Reduce(`+`, vi_params$delta[anc_u_mv], zero_m)
      resid_u <- y[units_u] / 2 - 2 * g_eta[units_u] *
        (W_u %*% theta_u_mv + Xbeta[units_u])
      delta_v <- delta_v + crossprod(W_u, resid_u)
    }
    vi_params$delta[[v]] <- vi_params$Omega[[v]] %*% delta_v
  }

  # Compute initial ELBO
  hyperparams <- update_hyperparams_logistic_moretrees(
    X = X,
    W = W,
    y = y,
    outcomes_units = outcomes_units,
    ancestors = ancestors,
    n = n, K = K, p = p, m = m,
    prob = vi_params$prob, mu = vi_params$mu,
    Sigma = vi_params$Sigma, Sigma_det = vi_params$Sigma_det,
    tau_t = vi_params$tau_t, delta = vi_params$delta,
    Omega = vi_params$Omega, Omega_det = vi_params$Omega_det,
    eta = hyperparams$eta, g_eta = hyperparams$g_eta,
    omega = hyperparams$omega, tau = hyperparams$tau,
    a_rho = vi_params$a_rho, b_rho = vi_params$b_rho,
    update_hyper = FALSE
  )

  list(vi_params = vi_params, hyperparams = hyperparams)
}
IQSS/moretrees documentation built on March 20, 2020, 8:44 p.m.