###############################################################################
## Coordinate ascent for MOUTHWASH with uniform mixtures ----------------------
###############################################################################
#' Coordinate ascent for optimizing t likelihood with uniform mixtures.
#'
#' Each iteration performs three block updates, in order:
#' \enumerate{
#'   \item \code{z2} is updated with BFGS via \code{\link[stats]{optim}},
#'         using the analytic gradient \code{\link{mouthwash_z_grad}}.
#'   \item If \code{scale_var = TRUE}, the variance inflation term
#'         \code{xi} is updated with Brent's method on
#'         \eqn{[10^{-14}, 10]}.
#'   \item The mixing proportions are updated with
#'         \code{\link[ashr]{ash.workhorse}} (\code{prior = "nullbiased"},
#'         null weight taken from \code{lambda_seq} at the component with
#'         \code{a_seq = b_seq = 0}).
#' }
#' Iteration stops once the relative change in
#' \code{\link{uniform_mix_llike}} is at most \code{tol}, or after
#' \code{maxit} iterations.
#'
#' Set \code{degrees_freedom = Inf} if you want to use a normal likelihood.
#'
#' @inheritParams mouthwash_second_step
#' @param tol The tolerance for the stopping condition (relative change in
#' the objective between successive iterations).
#' @param maxit The maximum number of iterations to run in the
#' optimization scheme.
#' @param pi_init The initial values of the mixing proportions. A
#' numeric vector of positive values that sum to one, of the same length
#' as \code{a_seq}.
#' @param z_init The initial values of z2. A numeric vector.
#' @param xi_init The initial (or known) value of the variance
#' inflation term.
#' @param plot_update A logical. Should I plot the path of the
#' log-likelihood (\code{TRUE}) or not (\code{FALSE})?
#'
#' @return A list with elements \code{pi_vals} (the mixing proportions),
#' \code{z2}, and \code{xi} (the variance inflation term) at the final
#' iterate.
#'
#' @author David Gerard
mouthwash_coordinate <- function(pi_init, z_init, xi_init, betahat_ols, S_diag,
                                 alpha_tilde, a_seq, b_seq, lambda_seq,
                                 degrees_freedom, scale_var = TRUE,
                                 tol = 10 ^ -6, maxit = 100,
                                 plot_update = FALSE, var_inflate_pen = 0) {
  ## Make sure input is correct ---------------------------------------------
  ## assertthat::are_equal() only *returns* TRUE/FALSE, so each check is
  ## wrapped in assert_that() for a mismatch to actually raise an error.
  M <- length(a_seq)
  k <- length(z_init)
  p <- length(betahat_ols)
  assertthat::assert_that(assertthat::are_equal(length(b_seq), M))
  assertthat::assert_that(assertthat::are_equal(length(lambda_seq), M))
  assertthat::assert_that(assertthat::are_equal(length(pi_init), M))
  ## NROW()/NCOL() so that a plain length-p vector is still accepted as the
  ## single column of alpha_tilde when k = 1 (the products below allow it).
  assertthat::assert_that(assertthat::are_equal(NCOL(alpha_tilde), k))
  assertthat::assert_that(assertthat::are_equal(NROW(alpha_tilde), p))
  assertthat::assert_that(assertthat::are_equal(length(S_diag), p))
  assertthat::assert_that(assertthat::are_equal(sum(pi_init), 1))
  assertthat::assert_that(all(b_seq - a_seq > -10 ^ -14))
  assertthat::assert_that(all(lambda_seq - 1 > -10 ^ -14))
  assertthat::assert_that(xi_init > 0)
  ## Exactly one component must have a = b = 0: the pi update below moves
  ## that component to the front for ashr and back again afterwards.
  zero_spot <- which(abs(a_seq) < 10 ^ -14 & abs(b_seq) < 10 ^ -14)
  assertthat::assert_that(assertthat::are_equal(length(zero_spot), 1))

  ## Objective (uniform_mix_llike) with all fixed data arguments bound.
  llike_at <- function(pi_vals, z2, xi) {
    uniform_mix_llike(pi_vals = pi_vals, z2 = z2, xi = xi,
                      betahat_ols = betahat_ols, S_diag = S_diag,
                      alpha_tilde = alpha_tilde, a_seq = a_seq,
                      b_seq = b_seq, lambda_seq = lambda_seq,
                      degrees_freedom = degrees_freedom,
                      var_inflate_pen = var_inflate_pen)
  }

  pi_new <- pi_init
  z_new <- z_init
  xi_new <- xi_init
  llike_new <- llike_at(pi_vals = pi_new, z2 = z_new, xi = xi_new)
  llike_vec <- llike_new
  err <- tol + 1
  ## Number of completed iterations, so that at most `maxit` are run.
  iter_index <- 0
  while (err > tol && iter_index < maxit) {
    llike_old <- llike_new

    ## Update z with quasi-Newton -----------------------------------------
    ## fnscale = -1 makes optim() maximize instead of minimize.
    optim_out <- stats::optim(par = z_new, fn = uniform_mix_llike,
                              gr = mouthwash_z_grad, method = "BFGS",
                              pi_vals = pi_new, xi = xi_new,
                              betahat_ols = betahat_ols,
                              S_diag = S_diag, alpha_tilde = alpha_tilde,
                              a_seq = a_seq, b_seq = b_seq,
                              lambda_seq = lambda_seq,
                              degrees_freedom = degrees_freedom,
                              control = list(fnscale = -1),
                              var_inflate_pen = var_inflate_pen)
    z_new <- optim_out$par

    ## Update xi with Brent if scale_var = TRUE ----------------------------
    if (scale_var) {
      optim_out <- stats::optim(par = xi_new, fn = uniform_mix_llike,
                                method = "Brent",
                                lower = 10 ^ -14, upper = 10,
                                pi_vals = pi_new, z2 = z_new,
                                betahat_ols = betahat_ols,
                                S_diag = S_diag, alpha_tilde = alpha_tilde,
                                a_seq = a_seq, b_seq = b_seq,
                                lambda_seq = lambda_seq,
                                degrees_freedom = degrees_freedom,
                                control = list(fnscale = -1),
                                var_inflate_pen = var_inflate_pen)
      xi_new <- optim_out$par
    }

    ## Update pi_vals with ashr --------------------------------------------
    ## ashr needs the null (a = b = 0) component in the first location, so
    ## move it there and put the fitted proportion back afterwards.
    pi_temp <- c(pi_new[zero_spot], pi_new[-zero_spot])
    a_temp <- c(a_seq[zero_spot], a_seq[-zero_spot])
    b_temp <- c(b_seq[zero_spot], b_seq[-zero_spot])
    ash_g <- ashr::unimix(pi = pi_temp, a = a_temp, b = b_temp)
    resid_vec <- c(betahat_ols - alpha_tilde %*% z_new)
    ashout <- ashr::ash.workhorse(betahat = resid_vec,
                                  sebetahat = c(sqrt(xi_new * S_diag)),
                                  df = degrees_freedom,
                                  g = ash_g,
                                  prior = "nullbiased",
                                  nullweight = lambda_seq[zero_spot],
                                  outputlevel = 0,
                                  alpha = 0)
    pi_temp <- ashout$fitted_g$pi
    pi_new <- append(pi_temp[-1], pi_temp[1], zero_spot - 1)

    ## Stopping criterion --------------------------------------------------
    llike_new <- llike_at(pi_vals = pi_new, z2 = z_new, xi = xi_new)
    llike_vec <- c(llike_vec, llike_new)
    ## Each block update should not lower the objective (beyond round-off).
    assertthat::assert_that((llike_new - llike_old) > -10 ^ -12)
    err <- abs(llike_new / llike_old - 1)
    iter_index <- iter_index + 1
    if (plot_update) {
      graphics::plot(llike_vec, type = "l", xlab = "Iteration",
                     ylab = "log-likelihood")
      graphics::mtext(paste0("Diff: ", format(err, digits = 2)))
    }
  }
  list(pi_vals = pi_new, z2 = z_new, xi = xi_new)
}
#' Gradient wrt z of \code{\link{uniform_mix_llike}}.
#'
#' For each observation \eqn{j}, the function builds two quantities for
#' every mixture component \eqn{m}, evaluated at the residual
#' \code{betahat_ols - alpha_tilde z2} with scale \code{sqrt(xi * S_diag)}:
#' \itemize{
#'   \item a component likelihood \eqn{\tilde{f}_{jm}}, which is
#'         \code{ptdiff_mat(...) / (b - a)} for uniform components and
#'         \code{dt_wrap(...)} for the component with \code{a = b = 0};
#'   \item the matching derivative term \eqn{\bar{f}_{jm}}.
#' }
#' It then returns \code{crossprod(alpha_tilde, w)}, where
#' \eqn{w_j = \sum_m \pi_m \bar{f}_{jm} / \sum_m \pi_m \tilde{f}_{jm}}.
#'
#' @inheritParams uniform_mix_fix
#' @param var_inflate_pen Not used here. Just here for compatibility
#' with \code{\link[stats]{optim}} in \code{\link{mouthwash_coordinate}}.
#'
#' @return The gradient as a \code{k} by 1 matrix (the result of
#' \code{crossprod}).
#'
#' @author David Gerard
#'
#' @seealso \code{\link{uniform_mix_llike}} for the objective function
#' and \code{\link{mouthwash_coordinate}} for where this function
#' is used.
mouthwash_z_grad <- function(pi_vals, z2, xi, betahat_ols, S_diag, alpha_tilde,
                             a_seq, b_seq, lambda_seq, degrees_freedom,
                             var_inflate_pen = 0) {
  ## Make sure input is OK --------------------------------------------------
  ## assertthat::are_equal() only returns a logical; wrapping it in
  ## assert_that() is what turns a mismatch into an error.
  M <- length(a_seq)
  k <- length(z2)
  p <- length(betahat_ols)
  assertthat::assert_that(assertthat::are_equal(length(b_seq), M))
  assertthat::assert_that(assertthat::are_equal(length(lambda_seq), M))
  assertthat::assert_that(assertthat::are_equal(length(S_diag), p))
  ## NROW()/NCOL() also accept a length-p vector as a one-column alpha_tilde.
  assertthat::assert_that(assertthat::are_equal(NROW(alpha_tilde), p))
  assertthat::assert_that(assertthat::are_equal(NCOL(alpha_tilde), k))
  assertthat::assert_that(assertthat::are_equal(length(pi_vals), M))
  zero_spot <- which(abs(a_seq) < 10 ^ -14 & abs(b_seq) < 10 ^ -14)
  assertthat::assert_that(assertthat::are_equal(length(zero_spot), 1))

  az <- alpha_tilde %*% z2
  resid_vec <- c(betahat_ols - az)
  sd_vec <- sqrt(xi * S_diag)

  ## Calculate ftildes (likelihoods) --------------------------------------
  sd_mat <- matrix(rep(sd_vec, M), ncol = M)
  left_means <- outer(resid_vec, a_seq, FUN = "-") / sd_mat
  right_means <- outer(resid_vec, b_seq, FUN = "-") / sd_mat
  ## Width of each uniform, repeated down the p rows. It is zero in the
  ## zero_spot column, whose entries are overwritten below.
  denom_mat <- matrix(rep(b_seq - a_seq, times = p), byrow = TRUE, ncol = M)
  ftilde_mat <- ptdiff_mat(left_means, right_means,
                           degrees_freedom = degrees_freedom) / denom_mat
  dense_vec <- dt_wrap(x = betahat_ols, df = degrees_freedom,
                       mean = az, sd = sd_vec)
  ftilde_mat[, zero_spot] <- dense_vec

  ## Calculate fbars (derivatives) ---------------------------------------
  tdiff_mat <- stats::dt(right_means, df = degrees_freedom) / sd_mat -
    stats::dt(left_means, df = degrees_freedom) / sd_mat
  fbar_mat <- tdiff_mat / denom_mat
  if (degrees_freedom == Inf) {
    ## Limit of the t expression below as degrees_freedom -> Inf.
    fbar_mat[, zero_spot] <- dense_vec * resid_vec / (xi * S_diag)
  } else {
    fbar_mat[, zero_spot] <- dense_vec * resid_vec * (degrees_freedom + 1) /
      (degrees_freedom * xi * S_diag + resid_vec ^ 2)
  }

  ## Weights for rows of alpha_tilde, then the gradient ------------------
  ## Row-wise pi-weighted sums, written as matrix-vector products.
  weight_vec <- c(fbar_mat %*% pi_vals) / c(ftilde_mat %*% pi_vals)
  crossprod(alpha_tilde, weight_vec)
}
## Add the following code to your website.
## For more information on customizing the embed code, read Embedding Snippets.