# @title Refresh the precision, then the k-th loading and factor.
#
# @param data A flash data object.
#
# @param f A flash fit object.
#
# @param k Index of the factor/loading pair to update.
#
# @param var_type Variance structure to assume for residuals
#   (forwarded to compute_precision).
#
# @param ebnm_fn_l Function that solves the EBNM problem for the
#   loadings update.
#
# @param ebnm_param_l Parameters handed to ebnm_fn_l.
#
# @param ebnm_fn_f Function that solves the EBNM problem for the
#   factor update.
#
# @param ebnm_param_f Parameters handed to ebnm_fn_f.
#
# @param Rk Optional matrix of residuals (excluding factor k). When
#   NULL, it is obtained from flash_get_Rk.
#
# @param R2 Optional matrix of squared residuals. When NULL, it is
#   obtained from flash_get_R2.
#
# @return The updated flash fit object.
#
flash_update_single_fl <- function(data,
                                   f,
                                   k,
                                   var_type,
                                   ebnm_fn_l,
                                   ebnm_param_l,
                                   ebnm_fn_f,
                                   ebnm_param_f,
                                   Rk = NULL,
                                   R2 = NULL) {
  # Step 1: precision. Squared residuals are computed only when the
  # caller did not supply them.
  if (is.null(R2)) {
    R2 <- flash_get_R2(data, f)
  }
  f$tau <- compute_precision(R2, data, var_type)

  # Step 2: residuals with factor k left out (again, only if needed).
  # The same Rk is reused for both the loading and the factor update.
  if (is.null(Rk)) {
    Rk <- flash_get_Rk(data, f, k)
  }

  # Step 3: loading first, then factor; both record KL divergences.
  f <- flash_update_single_loading(data,
                                   f,
                                   k,
                                   ebnm_fn = ebnm_fn_l,
                                   ebnm_param = ebnm_param_l,
                                   Rk = Rk,
                                   calc_obj = TRUE)
  f <- flash_update_single_factor(data,
                                  f,
                                  k,
                                  ebnm_fn = ebnm_fn_f,
                                  ebnm_param = ebnm_param_f,
                                  Rk = Rk,
                                  calc_obj = TRUE)
  f
}
# @title Update the k-th loading of a flash fit
#
# @inheritParams flash_update_single_fl
#
# @param ebnm_fn Function that solves the EBNM problem.
#
# @param ebnm_param Parameters handed to ebnm_fn.
#
# @param calc_obj Specifies whether to calculate KL divergences.
#
# @return The flash fit object. It is returned unchanged when
#   calc_update_vals yields NULL (no update to perform).
#
flash_update_single_loading <- function(data,
                                        f,
                                        k,
                                        ebnm_fn,
                                        ebnm_param,
                                        Rk,
                                        calc_obj = TRUE) {
  # Entries flagged in column k of f$fixl are held fixed; only the
  # remaining rows are updated.
  is_fixed <- f$fixl[, k]
  subset <- which(!is_fixed)

  res <- calc_update_vals(data,
                          f,
                          k,
                          subset = subset,
                          any_fixed = any(is_fixed),
                          ebnm_fn = ebnm_fn,
                          ebnm_param = ebnm_param,
                          loadings = TRUE,
                          Rk = Rk,
                          calc_obj = calc_obj)

  # Nothing to write back.
  if (is.null(res)) {
    return(f)
  }

  f$EL[subset, k] <- res$EX
  f$EL2[subset, k] <- res$EX2
  f$gl[[k]] <- res$g
  # Remember which EBNM function and parameters produced this update.
  f$ebnm_fn_l[[k]] <- ebnm_fn
  f$ebnm_param_l[[k]] <- ebnm_param
  if (calc_obj) {
    f$KL_l[[k]] <- res$KL
  }
  f
}
# @title Update the k-th factor of a flash fit
#
# @inheritParams flash_update_single_fl
#
# @param ebnm_fn Function that solves the EBNM problem.
#
# @param ebnm_param Parameters handed to ebnm_fn.
#
# @param calc_obj Specifies whether to calculate KL divergences.
#
# @return The flash fit object. It is returned unchanged when
#   calc_update_vals yields NULL (no update to perform).
#
flash_update_single_factor <- function(data,
                                       f,
                                       k,
                                       ebnm_fn,
                                       ebnm_param,
                                       Rk,
                                       calc_obj = TRUE) {
  # Entries flagged in column k of f$fixf are held fixed; only the
  # remaining rows are updated.
  is_fixed <- f$fixf[, k]
  subset <- which(!is_fixed)

  res <- calc_update_vals(data,
                          f,
                          k,
                          subset = subset,
                          any_fixed = any(is_fixed),
                          ebnm_fn = ebnm_fn,
                          ebnm_param = ebnm_param,
                          loadings = FALSE,
                          Rk = Rk,
                          calc_obj = calc_obj)

  # Nothing to write back.
  if (is.null(res)) {
    return(f)
  }

  f$EF[subset, k] <- res$EX
  f$EF2[subset, k] <- res$EX2
  f$gf[[k]] <- res$g
  # Remember which EBNM function and parameters produced this update.
  f$ebnm_fn_f[[k]] <- ebnm_fn
  f$ebnm_param_f[[k]] <- ebnm_param
  if (calc_obj) {
    f$KL_f[[k]] <- res$KL
  }
  f
}
# @title Compute the new values for a factor or loading update
#
# @inheritParams flash_update_single_loading
#
# @param subset The subset of factor or loadings entries that are not
#   considered fixed (and can thus be updated).
#
# @param any_fixed Whether any entries are fixed; forwarded to the
#   EBNM argument builders.
#
# @param loadings Should be TRUE for loadings updates and FALSE for
#   factor updates.
#
# @return A list with elements EX, EX2, and g, plus KL when calc_obj
#   is TRUE (these are updated values of either EL, EL2, gl, and KL_l
#   or EF, EF2, gf, and KL_f). Returns NULL when subset is empty or
#   when the EBNM argument builder returns NULL.
#
calc_update_vals <- function(data,
                             f,
                             k,
                             subset,
                             any_fixed,
                             ebnm_fn,
                             ebnm_param,
                             loadings,
                             Rk,
                             calc_obj = TRUE) {
  # Every element is fixed, so there is nothing to update.
  if (length(subset) == 0) {
    return(NULL)
  }

  # Build the observations/standard errors for the EBNM problem.
  ebnm_args <- if (loadings) {
    calc_ebnm_l_args(data, f, k, subset, any_fixed, Rk)
  } else {
    calc_ebnm_f_args(data, f, k, subset, any_fixed, Rk)
  }
  if (is.null(ebnm_args)) {
    return(NULL)
  }

  # A warmstart flag, if present, is consumed here: when it is TRUE and
  # a previous prior exists for index k, that prior is passed on as g.
  # The flag itself is always stripped before calling ebnm_fn.
  warm <- ebnm_param$warmstart
  if (!is.null(warm)) {
    if (warm) {
      prev_g <- if (loadings) f$gl else f$gf
      if (length(prev_g) >= k) {
        ebnm_param$g <- prev_g[[k]]
      }
    }
    ebnm_param$warmstart <- NULL
  }

  obs <- ebnm_args$x
  se <- ebnm_args$s
  a <- do.call(ebnm_fn, list(obs, se, ebnm_param))

  res <- list(EX = a$postmean,
              EX2 = a$postmean2,
              g = a$fitted_g)
  if (calc_obj) {
    KL <- a$penloglik -
      NM_posterior_e_loglik(obs, se, a$postmean, a$postmean2)
    res <- c(res, KL = KL)
  }
  res
}
# @title Build the EBNM inputs for a loadings update
#
# @inheritParams calc_update_vals
#
# @return A list with elements x and s (vectors of observations and
#   standard errors to be passed into ebnm_fn), or NULL when no entry
#   of the computed variances is finite.
#
calc_ebnm_l_args <- function(data, f, k, subset, any_fixed, Rk) {
  tau <- f$tau
  na_mask <- data$missing

  # Subsetting can be expensive, so rows are only dropped when some
  # loadings are actually fixed.
  if (any_fixed) {
    Rk <- Rk[subset, , drop = FALSE]
    if (is.matrix(tau)) {
      tau <- tau[subset, , drop = FALSE]
    }
    if (data$anyNA) {
      na_mask <- na_mask[subset, , drop = FALSE]
    }
  }

  # Zero out the precision at missing entries. Multiplying a scalar tau
  # by the negated mask yields an array shaped like the mask.
  if (data$anyNA) {
    if (is.matrix(tau)) {
      tau[na_mask] <- 0
    } else {
      tau <- tau * !na_mask
    }
  }

  ef2 <- f$EF2[, k]
  if (is.matrix(tau) || data$anyNA) {
    s2 <- 1 / (tau %*% ef2)
  } else {
    # Scalar tau with no missing data.
    s2 <- 1 / (tau * sum(ef2))
  }

  if (!any(is.finite(s2))) {
    return(NULL)
  }

  x <- ((Rk * tau) %*% f$EF[, k]) * s2
  # Where s2 is infinite the value of x does not matter; zero it to
  # avoid NaNs.
  x[is.infinite(s2)] <- 0
  # Guard against s2 becoming numerically negative by flooring it at a
  # small positive number before taking the square root.
  s <- sqrt(pmax(s2, .Machine$double.eps))

  list(x = x, s = s)
}
# @title Build the EBNM inputs for a factor update
#
# @inheritParams calc_update_vals
#
# @return A list with elements x and s (observations and standard
#   errors to be passed into ebnm_fn), or NULL when no entry of the
#   computed variances is finite.
#
calc_ebnm_f_args <- function(data, f, k, subset, any_fixed, Rk) {
  # Branching below keys off the shape of the stored precision, not of
  # the local (possibly masked) copy.
  tau_is_matrix <- is.matrix(f$tau)
  tau <- f$tau
  na_mask <- data$missing

  # Only drop columns when some factor entries are actually fixed.
  if (any_fixed) {
    Rk <- Rk[, subset, drop = FALSE]
    if (tau_is_matrix) {
      tau <- tau[, subset, drop = FALSE]
    }
    if (data$anyNA) {
      na_mask <- na_mask[, subset, drop = FALSE]
    }
  }

  # Zero out the precision at missing entries.
  if (data$anyNA) {
    if (tau_is_matrix) {
      tau[na_mask] <- 0
    } else {
      tau <- tau * !na_mask
    }
  }

  el2 <- f$EL2[, k]
  if (tau_is_matrix) {
    s2 <- 1 / (t(el2) %*% tau)
  } else if (data$anyNA) {
    # Scalar precision that was expanded by the missingness mask above.
    s2 <- 1 / (el2 %*% tau)
  } else {
    # Scalar precision with no missing data.
    s2 <- 1 / (tau * sum(el2))
  }

  if (!any(is.finite(s2))) {
    return(NULL)
  }

  x <- (t(f$EL[, k]) %*% (Rk * tau)) * s2
  # Where s2 is infinite the value of x does not matter; zero it to
  # avoid NaNs.
  x[is.infinite(s2)] <- 0
  # Floor s2 at a small positive number before taking the square root.
  s <- sqrt(pmax(s2, .Machine$double.eps))

  list(x = x, s = s)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.