# R/get_map.R

#' @title Try to find the MAP using coordinate descent
#' @description Alternates between one-dimensional optimisation of each entry
#' of \code{pars} (via \code{optimize}) and, unless \code{fix_pi} is TRUE, an
#' update of the mixing proportions (via \code{ashr:::mixIP}) until the largest
#' absolute change in \code{c(pars, pi)} falls below \code{tol}.
#' v4_li_mat_func and v4_prior_func use working model 4 described here: https://jean997.github.io/sh2Ash/mh_test.html
#' @param pars_start Starting parameters: eg rho, b, q. Should be appropriate for v4_li_mat_func.
#' Missing values are not allowed.
#' @param data data. Must have a \code{wts} column/element, which is used to
#' weight each row's log likelihood contribution.
#' @param grid Should be a data.frame with at least two columns, S1 and S2.
#' If \code{grid$pi} is all zero (or absent) the mixing proportions are
#' estimated from \code{pars_start} before iterating; otherwise \code{grid$pi}
#' is used as the starting value.
#' @param par_bounds should be data.frame or matrix number of pars by 2
#' (lower bound in column 1, upper bound in column 2).
#' @param tol Convergence threshold on the maximum absolute change in
#' \code{c(pars, pi)} between consecutive iterations.
#' @param n.iter Maximum number of coordinate descent iterations.
#' @param type One of the models null, causal or shared. If null, pars should be length 1 (rho).
#' If causal, pars should be length 2 (rho, b), If shared pars should be length 3 (rho, b, q)
#' The prior on pi is dirichlet((null_wt, 1, 1, ...)).
#' @param fix_pi if TRUE, only estimate pars and leave the mixing parameters
#' fixed at their starting value. If FALSE (the default) the mixing parameters
#' are re-estimated at every iteration.
#' @param fix_pars vector of indices of parameters to fix
#' @param null_wt Specifies the prior weight on the first entry of grid
#' @param ... Additional parameters (z_prior_func, b_prior_func, q_prior_func),
#' forwarded to v4_prior_func.
#' @return A list with elements \code{pars}, \code{pi}, \code{grid},
#' \code{loglik} (last entry of \code{LLS}), \code{PIS} and \code{PARAMS}
#' (one column per stored update of pi / pars), \code{LLS} (trace of the
#' objective), \code{type}, \code{converged}, \code{prior} (log prior at the
#' final estimate), \code{var} (inverse of the Hessian of the negative
#' objective with respect to pars) and \code{llmat} (log likelihood matrix at
#' the final pars).
#' @export
get_map <- function(pars_start, data, grid, par_bounds,
                    tol=1e-7, n.iter=Inf, type=c("null", "causal", "shared"),
                    fix_pi=FALSE, fix_pars = c(), null_wt = 10, ...){

  type <- match.arg(type)

  # Only version 4 now
  K <- nrow(grid)
  stopifnot(ncol(grid) >= 2)

  # Check inputs. The explicit dimension check matters: without it a plain
  # vector would pass silently because `NULL == J` is logical(0).
  p <- nrow(data)
  J <- length(pars_start)
  stopifnot(length(dim(par_bounds)) == 2,
            nrow(par_bounds) == J,
            ncol(par_bounds) == 2)

  if (any(is.na(pars_start))) stop("No missing parameter values please.\n")

  # Dirichlet parameter for the mixing proportions
  pi_alpha <- c(null_wt, rep(1, K - 1))

  # Weighted log likelihood of the data given a log likelihood matrix and the
  # current value of `pi` (looked up in this function's environment).
  weighted_loglik <- function(llmat) {
    sum(vapply(seq_len(p), function(i) {
      data$wts[i] * logSumExp(log(pi) + llmat[i, ])
    }, numeric(1)))
  }

  # Re-estimate the mixing proportions from a log likelihood matrix. Rows are
  # shifted by their maximum before exponentiating to avoid underflow.
  update_pi <- function(llmat) {
    matrix_lik <- exp(llmat - apply(llmat, 1, max))
    w_res <- ashr:::mixIP(matrix_lik = matrix_lik, prior = pi_alpha,
                          weights = data$wts)
    w_res$pihat
  }

  # Negative log posterior as a function of pars. Defined once, outside the
  # loop, so that it exists for the Hessian even when no iteration is run.
  # It reads the current `pi` and `pi_prior` each time it is called.
  li_func <- function(pars, ...) {
    matrix_llik <- v4_li_mat_func(pars, data, grid, type = type)
    loglik <- weighted_loglik(matrix_llik) +
      v4_prior_func(pars, type = type, ...) + pi_prior
    -loglik
  }

  # We can start with either params or pi. If pi is missing, we estimate it first.
  pars <- pars_old <- pars_start
  matrix_llik1 <- NULL
  if (all(grid$pi == 0)) {
    # No initial grid estimate
    matrix_llik1 <- v4_li_mat_func(pars, data, grid, type)
    pi <- pi_old <- update_pi(matrix_llik1)
  } else {
    pi <- pi_old <- grid$pi
  }
  pi_prior <- ddirichlet1(pi, pi_alpha)

  converged <- FALSE
  # NOTE: the local names `pi` and `pars` are deliberate; cbind() uses them as
  # the column names of PIS and PARAMS.
  PIS <- matrix(nrow = K, ncol = 0)
  PIS <- cbind(PIS, pi)
  PARAMS <- matrix(nrow = J, ncol = 0)
  LLS <- c()
  ct <- 1

  while (!converged && ct <= n.iter) {
    # Maximize in each parameter in turn
    for (j in seq_len(J)) {
      if (j %in% fix_pars) {
        LLS <- c(LLS, -li_func(pars, ...))
        next
      }
      f1 <- function(x) {
        pars_new <- pars
        pars_new[j] <- x
        li_func(pars_new, ...)
      }
      upd_parj <- optimize(f = f1, lower = par_bounds[j, 1],
                           upper = par_bounds[j, 2], maximum = FALSE)
      pars[j] <- upd_parj$minimum
      LLS <- c(LLS, -upd_parj$objective)
    }
    PARAMS <- cbind(PARAMS, pars)
    matrix_llik1 <- v4_li_mat_func(pars, data, grid, type)
    if (!fix_pi) {
      # Maximize in pis
      pi <- update_pi(matrix_llik1)
      # Objective at the updated pi
      pi_prior <- ddirichlet1(pi, pi_alpha)
      loglik <- weighted_loglik(matrix_llik1) +
        v4_prior_func(pars, type = type, ...) + pi_prior
      LLS <- c(LLS, loglik)
      PIS <- cbind(PIS, pi)
    }
    # Test for convergence
    test <- max(abs(c(pars, pi) - c(pars_old, pi_old)))
    cat(ct, test, "\n")
    if (test < tol) converged <- TRUE
    pars_old <- pars
    pi_old <- pi
    ct <- ct + 1
  }
  # The objective should be non-decreasing along the trace
  if (!all(diff(LLS) > -1e-7)) cat("Warning: This may not be a local maximum ", min(diff(LLS)), "\n")

  # If the loop never ran and pi came from the grid, the log likelihood matrix
  # has not been computed yet.
  if (is.null(matrix_llik1)) {
    matrix_llik1 <- v4_li_mat_func(pars, data, grid, type)
  }

  fit <- list("pars"=pars, "pi"=pi, "grid"=grid,
              "loglik"=LLS[length(LLS)],
              "PIS"=PIS, "PARAMS"= PARAMS, "LLS"=LLS,
              "type"=type,
              "converged" = converged)
  fit$prior <- v4_prior_func(fit$pars, type = type, ...) + pi_prior
  hes <- hessian(li_func, pars, ...)
  fit$var <- solve(hes)
  fit$llmat <- matrix_llik1
  return(fit)

}

# Log density of a Dirichlet(alpha) distribution evaluated at x.
#
# x:     vector of proportions
# alpha: vector of Dirichlet parameters, same length as x
#
# Returns sum((alpha - 1) * log(x)) minus the log normalising constant
# sum(lgamma(alpha)) - lgamma(sum(alpha)).
ddirichlet1 <- function(x, alpha) {
  log_norm <- sum(lgamma(alpha)) - lgamma(sum(alpha))
  coef <- alpha - 1
  terms <- coef * log(x)
  # A component with alpha == 1 contributes x^0 = 1 to the density whatever x
  # is, but at x == 0 the product 0 * log(0) evaluates to NaN and would poison
  # the whole sum. Set those terms to their limiting value of zero.
  n <- length(terms)
  degenerate <- which(rep_len(coef, n) == 0 & rep_len(x, n) == 0)
  terms[degenerate] <- 0
  sum(terms) - log_norm
}
# jean997/sherlockAsh documentation built on May 18, 2019, 11:45 p.m.