R/moretrees_init_rand.R

Defines functions moretrees_init_rand

Documented in moretrees_init_rand

# --------------------------------------------------------------------------------- #
# -------------------- moretrees initial values function -------------------------- #
# --------------------------------------------------------------------------------- #

#' Random starting values for MOReTreeS models
#'
#' \code{moretrees_init_rand} randomly generates starting values for the
#' variational inference (VI) parameters and hyperparameters of a MOReTreeS
#' model. The initial hyperparameter list is obtained by calling
#' \code{update_hyperparams_logistic_moretrees} with
#' \code{update_hyper = FALSE}. Not recommended if the model is converging
#' slowly!
#'
#' @export
#' @useDynLib moretrees
#'
#' @section Model Description:
#' Describe MOReTreeS model and all parameters here.
#'
#' @param X Exposure matrix. Its number of columns gives the dimension
#' \code{K} of each sparse node coefficient.
#' @param W Covariate matrix. Its number of columns gives the dimension
#' \code{m} of each non-sparse node coefficient; it may have zero columns.
#' @param y Outcome vector; \code{length(y)} gives the number of auxiliary
#' parameters \code{eta} that are drawn.
#' @param outcomes_units List whose elements are passed one at a time to
#' \code{xxT_g_eta_fun} (presumably the units belonging to each outcome).
#' @param outcomes_nodes List with one element per node; each element is a
#' vector of indices into \code{outcomes_units}.
#' @param ancestors List of ancestor indices; the number of distinct values
#' gives the number of nodes \code{p}.
#' @param xxT Computed from exposure matrix X
#' @param wwT Computed from covariate matrix W
#' @param initial_values Not used by this function; accepted for interface
#' compatibility.
#' @param update_hyper Should hyperparameters be initialised at random
#' (\code{TRUE}) or taken from \code{hyper_fixed} (\code{FALSE})?
#' @param hyper_fixed Fixed values of hyperparameters to use if
#' \code{update_hyper = FALSE}. This function reads the elements
#' \code{tau} (prior variance for sparse node coefficients) and \code{omega}
#' (prior variance for non-sparse node coefficients). When \code{W} has no
#' columns, \code{omega} is replaced by 1.
#' @param hyper_random_init If \code{update_hyper = TRUE}, a list containing
#' the maximum values of the hyperparameters. Each hyperparameter is drawn
#' uniformly at random between 0 and its maximum. If multiple random restarts
#' are being used, it is recommended to use a large range for these initial
#' values so that the parameter space can be more effectively explored.
#' Elements read by this function:
#' tau_max (maximum of prior sparse node variance)
#' omega_max (maximum of prior non-sparse node variance)
#' @param vi_random_init A list with the standard deviations of the
#' mean-zero normal distributions from which the initial VI parameters are
#' drawn. If multiple random restarts are being used, it is recommended to
#' use large standard deviations so that the parameter space can be more
#' effectively explored. Elements:
#' eta_sd (required; the auxiliary parameters eta are the absolute values of
#' these draws)
#' mu_sd (optional, default 1; posterior means of sparse node coefficients)
#' delta_sd (optional, default 1; posterior means of non-sparse node
#' coefficients)
#' @return A list with elements \code{vi_params} (mu, prob, Sigma, Sigma_inv,
#' Sigma_det, tau_t, delta, Omega, Omega_inv, Omega_det, a_rho, b_rho) and
#' \code{hyperparams} (the value returned by
#' \code{update_hyperparams_logistic_moretrees}).
#' @family MOReTreeS functions

moretrees_init_rand <- function(X, W, y,
                                outcomes_units,
                                outcomes_nodes,
                                ancestors,
                                xxT, wwT,
                                initial_values,
                                update_hyper,
                                hyper_fixed,
                                vi_random_init,
                                hyper_random_init) {
  n <- length(y)
  m <- ncol(W)
  p <- length(unique(unlist(ancestors)))
  K <- ncol(X)

  # Optional standard deviations for the coefficient means. The default of 1
  # reproduces exactly the draws of rnorm(K) / rnorm(m). `[[` avoids `$`
  # partial matching on the list names.
  mu_sd <- vi_random_init[["mu_sd"]]
  if (is.null(mu_sd)) mu_sd <- 1
  delta_sd <- vi_random_init[["delta_sd"]]
  if (is.null(delta_sd)) delta_sd <- 1

  # Auxiliary parameters: absolute values of mean-zero normal draws
  eta <- abs(rnorm(n, mean = 0, sd = vi_random_init$eta_sd))
  g_eta <- gfun(eta)

  if (update_hyper) {
    # Hyperparameters will be updated later: draw them uniformly on (0, max)
    hyperparams <- list(omega = runif(1, 0, hyper_random_init$omega_max),
                        tau = runif(1, 0, hyper_random_init$tau_max))
  } else {
    # Otherwise, use the fixed values supplied by the caller
    hyperparams <- hyper_fixed
  }
  # When W has no columns, omega is fixed at 1
  if (m == 0) {
    hyperparams$omega <- 1
  }
  hyperparams$eta <- eta
  hyperparams$g_eta <- g_eta

  # Inverse covariance for each node: twice the sum of the matrices indexed
  # by that node's outcomes, plus 1 / prior_var on the diagonal.
  node_precisions <- function(per_outcome, dim, prior_var) {
    lapply(outcomes_nodes, function(outcomes) {
      2 * Reduce(`+`, per_outcome[outcomes]) + diag(1 / prior_var, nrow = dim)
    })
  }

  # Sparse node coefficients
  xxT_g_eta <- lapply(X = outcomes_units, FUN = xxT_g_eta_fun,
                      xxT = xxT, g_eta = g_eta, K = K)
  Sigma_inv <- node_precisions(xxT_g_eta, K, hyperparams$tau)
  Sigma <- lapply(Sigma_inv, solve)
  Sigma_det <- vapply(Sigma, det, numeric(1))
  mu <- lapply(seq_len(p),
               function(i) matrix(rnorm(K, sd = mu_sd), ncol = 1))
  prob <- runif(p, 0, 1)
  # a_rho and b_rho must be initialised using the VI updates so that terms
  # cancel in the ELBO; otherwise the first ELBO would be wrong.
  a_rho <- 1 + sum(prob)
  b_rho <- 1 + p - sum(prob)
  tau_t <- rep(hyperparams$tau, p)

  # Non-sparse node coefficients (empty when W has no columns)
  delta <- lapply(seq_len(p),
                  function(i) matrix(rnorm(m, sd = delta_sd), ncol = 1))
  if (m > 0) {
    wwT_g_eta <- lapply(X = outcomes_units, FUN = xxT_g_eta_fun,
                        xxT = wwT, g_eta = g_eta, K = m)
    Omega_inv <- node_precisions(wwT_g_eta, m, hyperparams$omega)
    Omega <- lapply(Omega_inv, solve)
    Omega_det <- vapply(Omega, det, numeric(1))
  } else {
    Omega <- rep(list(matrix(nrow = 0, ncol = 0)), p)
    Omega_inv <- rep(list(matrix(nrow = 0, ncol = 0)), p)
    Omega_det <- rep(1, p)
  }

  # Put VI parameters in list
  vi_params <- list(mu = mu, prob = prob, Sigma = Sigma,
                    Sigma_inv = Sigma_inv, Sigma_det = Sigma_det,
                    tau_t = tau_t, delta = delta,
                    Omega = Omega, Omega_inv = Omega_inv,
                    Omega_det = Omega_det,
                    a_rho = a_rho, b_rho = b_rho)

  # Compute initial ELBO (hyperparameters are not changed: update_hyper = FALSE)
  hyperparams <- update_hyperparams_logistic_moretrees(
    X = X,
    W = W,
    y = y,
    outcomes_units = outcomes_units,
    ancestors = ancestors,
    n = n, K = K, p = p, m = m,
    prob = prob, mu = mu,
    Sigma = Sigma, Sigma_det = Sigma_det,
    tau_t = tau_t, delta = delta,
    Omega = Omega, Omega_det = Omega_det,
    eta = hyperparams$eta, g_eta = hyperparams$g_eta,
    omega = hyperparams$omega, tau = hyperparams$tau,
    a_rho = a_rho, b_rho = b_rho,
    update_hyper = FALSE
  )

  list(vi_params = vi_params, hyperparams = hyperparams)
}
IQSS/moretrees documentation built on March 20, 2020, 8:44 p.m.