#' @title Try to find the MAP using coordinate descent
#' @description Alternates between (1) one-dimensional optimisation of each
#' model parameter with \code{optimize} and (2) re-estimation of the mixing
#' proportions with \code{ashr:::mixIP}, until the largest absolute change in
#' \code{c(pars, pi)} between sweeps drops below \code{tol} or \code{n.iter}
#' sweeps have been run.
#' v4_li_mat_func and v4_prior_func use working model 4 described here:
#' https://jean997.github.io/sh2Ash/mh_test.html
#' @param pars_start Starting parameters: eg rho, b, q. Should be appropriate
#' for v4_li_mat_func. Must not contain missing values.
#' @param data data. Passed to v4_li_mat_func; \code{data$wts} is used as the
#' per-row weight in the likelihood and in the mixing-proportion update.
#' @param grid Should be a data.frame with at least two columns, S1 and S2.
#' If \code{grid$pi} is all zero, starting mixing proportions are estimated
#' from \code{pars_start}; otherwise \code{grid$pi} is used as the start.
#' @param par_bounds should be data.frame or matrix number of pars by 2.
#' Column 1 is used as the lower and column 2 as the upper bound passed to
#' \code{optimize}.
#' @param tol Convergence tolerance on the largest absolute change in
#' \code{c(pars, pi)} between two consecutive sweeps.
#' @param n.iter Maximum number of sweeps.
#' @param type One of the models null, causal or shared. If null, pars should
#' be length 1 (rho). If causal, pars should be length 2 (rho, b), If shared
#' pars should be length 3 (rho, b, q).
#' The prior on pi is dirichlet((null_wt, 1, 1, ...)).
#' @param fix_pi if TRUE, only estimate pars and leave the mixing parameters
#' fixed at their starting values; if FALSE (the default) the mixing
#' parameters are re-estimated after every sweep over the parameters.
#' @param fix_pars vector of indices of parameters to fix
#' @param ... Additional parameters (z_prior_func, b_prior_func, q_prior_func).
#' They are forwarded to v4_prior_func, and also to \code{hessian}, which is
#' given the objective that forwards them to v4_prior_func.
#' @param null_wt Specifies the prior weight on the first entry of grid
#' @return A list with elements \code{pars}, \code{pi}, \code{grid},
#' \code{loglik} (last recorded objective value), \code{PIS} and
#' \code{PARAMS} (histories, one column per update), \code{LLS} (history of
#' objective values), \code{type}, \code{converged}, \code{prior},
#' \code{var} (inverse of the Hessian of the negative objective at
#' \code{pars}) and \code{llmat} (log-likelihood matrix at \code{pars}).
#' @export
get_map <- function(pars_start, data, grid, par_bounds,
                    tol=1e-7, n.iter=Inf, type=c("null", "causal", "shared"),
                    fix_pi=FALSE, fix_pars = c(), null_wt = 10, ...){
  type <- match.arg(type)

  # Only version 4 now
  K <- nrow(grid)
  stopifnot(ncol(grid) >= 2)

  # Check inputs
  p <- nrow(data)
  J <- length(pars_start)
  stopifnot(nrow(par_bounds) == J, ncol(par_bounds) == 2)
  if(any(is.na(pars_start))) stop("No missing parameter values please.\n")
  pars <- pars_old <- pars_start

  # Dirichlet prior weights on the mixing proportions
  dir_alpha <- c(null_wt, rep(1, K-1))

  # Weighted mixture log-likelihood for a matrix of component
  # log-likelihoods, using the current value of `pi` in this function.
  mix_loglik <- function(llmat){
    sum(vapply(seq_len(p), FUN = function(i){
      data$wts[i]*logSumExp(log(pi) + llmat[i,])
    }, FUN.VALUE = numeric(1)))
  }

  # Mixing proportions returned by mixIP for a matrix of component
  # log-likelihoods. Row maxima are subtracted before exponentiating so the
  # largest entry in every row is exp(0) = 1.
  update_pi <- function(llmat){
    lik <- exp(llmat - apply(llmat, 1, max))
    ashr:::mixIP(matrix_lik = lik, prior = dir_alpha, weights = data$wts)$pihat
  }

  # Negative log posterior as a function of pars. `pi` and `pi_prior` are
  # looked up in this function's environment at call time, so the objective
  # always uses their current values. Defined before the loop so that it also
  # exists for the Hessian when the loop body never runs.
  li_func <- function(pars, ...){
    llmat <- v4_li_mat_func(pars, data, grid, type = type)
    -(mix_loglik(llmat) + v4_prior_func(pars, type = type, ...) + pi_prior)
  }

  # Log-likelihood matrix at the starting parameters. Computed up front so
  # that `fit$llmat` is defined even if no sweep is run.
  matrix_llik1 <- v4_li_mat_func(pars, data, grid, type = type)

  # We can start with either params or pi. If there is no initial grid
  # estimate of pi, we estimate it first.
  if(all(grid$pi==0)){
    pi <- pi_old <- update_pi(matrix_llik1)
  }else{
    pi <- pi_old <- grid$pi
  }
  pi_prior <- ddirichlet1(pi, dir_alpha)

  converged <- FALSE
  PIS <- matrix(nrow=K, ncol=0)
  PIS <- cbind(PIS, pi)
  PARAMS <- matrix(nrow=J, ncol=0)
  LLS <- c()
  ct <- 1
  while(!converged && ct <= n.iter){
    # Maximize in each parameter, holding the others and pi fixed
    for(j in seq_len(J)){
      if(j %in% fix_pars){
        # Fixed parameter: only record the current objective value
        LLS <- c(LLS, -li_func(pars, ...))
        next
      }
      f1 <- function(x){
        pars_new <- pars
        pars_new[j] <- x
        li_func(pars_new, ...)
      }
      upd_parj <- optimize(f = f1, lower=par_bounds[j,1],
                           upper = par_bounds[j,2], maximum=FALSE)
      pars[j] <- upd_parj$minimum
      LLS <- c(LLS, -upd_parj$objective)
    }
    PARAMS <- cbind(PARAMS, pars)
    matrix_llik1 <- v4_li_mat_func(pars, data, grid, type = type)

    if(!fix_pi){
      # Maximize in pis
      pi <- update_pi(matrix_llik1)
      # Objective at the updated pi
      pi_prior <- ddirichlet1(pi, dir_alpha)
      loglik <- mix_loglik(matrix_llik1) +
        v4_prior_func(pars, type = type, ...) + pi_prior
      LLS <- c(LLS, loglik)
      PIS <- cbind(PIS, pi)
    }

    # Test for convergence
    test <- max(abs(c(pars, pi)-c(pars_old, pi_old)))
    cat(ct, test, "\n")
    if(test < tol) converged <- TRUE
    pars_old <- pars
    pi_old <- pi
    ct <- ct + 1
  }

  # Every recorded update should leave the objective (numerically) no lower
  if(!all(diff(LLS) > -1e-7)) cat("Warning: This may not be a local maximum ", min(diff(LLS)), "\n")

  fit <- list("pars"=pars, "pi"=pi, "grid"=grid,
              "loglik"=LLS[length(LLS)],
              "PIS"=PIS, "PARAMS"= PARAMS, "LLS"=LLS,
              "type"=type,
              "converged" = converged)
  fit$prior <- v4_prior_func(fit$pars, type = type, ...) + pi_prior
  # Hessian of the negative objective in pars (pi held at its final value)
  hes <- hessian(li_func, pars, ...)
  fit$var <- solve(hes)
  fit$llmat <- matrix_llik1
  return(fit)
}
# Log density of a Dirichlet(alpha) distribution evaluated at x.
#
# x:     numeric vector at which to evaluate the density.
# alpha: numeric vector of Dirichlet parameters, matched elementwise to x.
#
# Returns sum((alpha - 1) * log(x)) minus the log normalising constant
# sum(lgamma(alpha)) - lgamma(sum(alpha)).
ddirichlet1 <- function(x, alpha) {
  logD <- sum(lgamma(alpha)) - lgamma(sum(alpha))
  terms <- (alpha - 1) * log(x)
  # A component with alpha exactly 1 and x exactly 0 evaluates to
  # 0 * -Inf = NaN, although its factor in the density is x^0 = 1.
  # Set those terms to 0 so a zero proportion with a unit prior weight does
  # not turn the whole log density into NaN. The exact comparisons are
  # intentional: only an exact 0 * -Inf produces the NaN.
  terms[which(alpha == 1 & x == 0)] <- 0
  sum(terms) - logD
}
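# NOTE: the two lines below are leftover web-page text, not R code; they are
# kept as comments so the file parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.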