# R/get_map_old.R

#' @title Try to find the MAP using coordinate descent
#' @description Alternates two updates until the largest absolute change in
#'   \code{c(pars, pi)} between iterations falls below \code{tol}:
#'   (1) for fixed parameters, the mixture proportions \code{pi} are fit with
#'   \code{ashr:::mixIP}; (2) for fixed \code{pi}, each parameter is updated
#'   in turn with \code{optimize()} within its bounds.
#'   \code{v4_li_mat_func} and \code{v4_prior_func} use working model 4
#'   described here: https://jean997.github.io/sh2Ash/mh_test.html
#'
#'   The prior on pi is hard coded as dirichlet((10, 1, 1, ...)).
#' @param pars.start Starting parameters: eg rho, b, q. Should be appropriate
#'   for the likelihood function selected by \code{version}.
#' @param data Data passed to the likelihood function. Its \code{wts} element
#'   supplies one weight per row of \code{data}.
#' @param grid A data.frame with one grid point per row. Version 4 requires at
#'   least two columns (S1 and S2); version 6 requires at least one.
#' @param par.bounds Data.frame or matrix with one row per parameter and two
#'   columns: the lower and upper bound searched by \code{optimize()}.
#' @param split Optimization can be split over the range of a parameter if
#'   desired. \code{NA} means no split; otherwise parameter \code{j} is
#'   optimized separately below and above \code{split[j]} (which must lie
#'   within its bounds) and the better optimum is kept.
#' @param tol Convergence tolerance on the maximum absolute change in
#'   \code{c(pars, pi)} between consecutive iterations.
#' @param n.iter Maximum number of iterations (must be at least 1).
#' @param version Determines likelihood and prior functions: 4 or 6.
#' @param type One of "null", "causal", "shared". Passed to the likelihood and
#'   prior functions; for version 6 it also sets the number of mixture
#'   components.
#' @param ... Extra arguments forwarded to the prior function; also passed to
#'   \code{hessian()} when computing \code{var}.
#' @return A list with elements
#'   \item{pars}{Final parameter estimates.}
#'   \item{pi}{Final mixture proportions.}
#'   \item{grid}{The input grid.}
#'   \item{loglik}{Unnormalized log posterior (weighted mixture
#'     log-likelihood plus parameter prior plus log Dirichlet prior on pi)
#'     after the last coordinate update.}
#'   \item{PIS}{Matrix with one column of \code{pi} per iteration.}
#'   \item{PARAMS}{Matrix with one column of parameters per iteration.}
#'   \item{LLS}{Trace of the log posterior: one value after each pi update
#'     and one after each single-parameter update.}
#'   \item{version, type}{As used.}
#'   \item{converged}{\code{TRUE} if \code{tol} was reached within
#'     \code{n.iter} iterations.}
#'   \item{prior}{Parameter prior at \code{pars} plus the log Dirichlet
#'     density of \code{pi}.}
#'   \item{var}{Inverse of \code{hessian()} of the negative log posterior at
#'     \code{pars}, with \code{pi} held fixed.}
#'   \item{llmat}{Log-likelihood matrix from the likelihood function,
#'     evaluated at the parameters used at the start of the final
#'     iteration.}
#' @export
get_map_old <- function(pars.start, data, grid, par.bounds,
                        split = rep(NA, length(pars.start)),
                        tol = 1e-7, n.iter = Inf, version = 4,
                        type = c("null", "causal", "shared"), ...) {
  type <- match.arg(type)
  stopifnot(version %in% c(4, 6))
  if (version == 4) {
    li.mat.func <- v4_li_mat_func
    prior.func <- v4_prior_func
    K <- nrow(grid)
    stopifnot(ncol(grid) >= 2)
  } else {
    li.mat.func <- v6_li_mat_func
    prior.func <- v6_prior_func
    # Version 6: one extra component plus 2 (3 for "shared") per grid row.
    K <- if (type == "shared") 1 + 3 * nrow(grid) else 1 + 2 * nrow(grid)
    stopifnot(ncol(grid) >= 1)
  }

  J <- length(pars.start)
  stopifnot(J >= 1, length(split) == J,
            nrow(par.bounds) == J, ncol(par.bounds) == 2,
            n.iter >= 1)
  ix <- which(!is.na(split))
  if (length(ix) > 0) {
    stopifnot(all(par.bounds[ix, 1] <= split[ix] &
                    split[ix] <= par.bounds[ix, 2]))
  }

  p <- nrow(data)
  wts <- data$wts
  # Dirichlet(10, 1, ..., 1) prior on pi, shared by mixIP and ddirichlet1.
  dir_alpha <- c(10, rep(1, K - 1))

  # Weighted mixture log-likelihood:
  #   sum_i wts[i] * log(sum_k pi_k * exp(llik[i, k]))
  mix_loglik <- function(llik, pi_hat) {
    sum(vapply(seq_len(p), function(i) {
      wts[i] * logSumExp(log(pi_hat) + llik[i, ])
    }, numeric(1)))
  }

  # Negative unnormalized log posterior as a function of the parameters.
  # pi_hat and pi.prior are looked up at call time, so this always uses the
  # proportions from the most recent pi update.
  li_func <- function(pars, ...) {
    llik <- li.mat.func(pars, data, grid, type = type)
    -(mix_loglik(llik, pi_hat) + prior.func(pars, type = type, ...) + pi.prior)
  }

  pars <- pars.old <- pars.start
  pi.old <- rep(0, K)
  PIS <- matrix(nrow = K, ncol = 0)
  PARAMS <- matrix(nrow = J, ncol = 0)
  LLS <- c()
  converged <- FALSE
  ct <- 1
  while (!converged && ct <= n.iter) {
    # Step 1: fit pi for fixed pars. Subtracting each row's max before exp()
    # avoids underflow; a per-row rescaling does not change the optimal pi.
    matrix_llik1 <- li.mat.func(pars, data, grid, type)
    matrix_lik <- exp(matrix_llik1 - apply(matrix_llik1, 1, max))
    w_res <- ashr:::mixIP(matrix_lik = matrix_lik, prior = dir_alpha,
                          weights = wts)
    pi_hat <- w_res$pihat
    pi.prior <- ddirichlet1(pi_hat, dir_alpha)
    LLS <- c(LLS, mix_loglik(matrix_llik1, pi_hat) +
               prior.func(pars, type = type, ...) + pi.prior)
    PIS <- cbind(PIS, pi = pi_hat)

    # Step 2: update each parameter in turn, holding pi fixed.
    for (j in seq_len(J)) {
      f1 <- function(x) {
        pars.new <- pars
        pars.new[j] <- x
        li_func(pars.new, ...)
      }
      if (is.na(split[j])) {
        upd.parj <- optimize(f = f1, lower = par.bounds[j, 1],
                             upper = par.bounds[j, 2], maximum = FALSE)
      } else {
        # optimize() finds one local optimum per interval; search both sides
        # of the split and keep whichever gives the lower objective.
        upd.parj1 <- optimize(f = f1, lower = par.bounds[j, 1],
                              upper = split[j], maximum = FALSE)
        upd.parj2 <- optimize(f = f1, lower = split[j],
                              upper = par.bounds[j, 2], maximum = FALSE)
        best <- which.min(c(upd.parj1$objective, upd.parj2$objective))
        upd.parj <- list(upd.parj1, upd.parj2)[[best]]
      }
      pars[j] <- upd.parj$minimum
      LLS <- c(LLS, -upd.parj$objective)
    }
    PARAMS <- cbind(PARAMS, pars)

    test <- max(abs(c(pars, pi_hat) - c(pars.old, pi.old)))
    cat(ct, test, "\n")
    if (test < tol) converged <- TRUE
    pars.old <- pars
    pi.old <- pi_hat
    ct <- ct + 1
  }

  # Each update should not decrease the log posterior; a drop (or NA/NaN)
  # in the trace means the result may not be a local maximum.
  d <- diff(LLS)
  if (!isTRUE(all(d > -1e-7))) {
    warning("This may not be a local maximum: ", min(d), call. = FALSE)
  }

  fit <- list("pars" = pars, "pi" = pi_hat, "grid" = grid,
              "loglik" = -upd.parj$objective,
              "PIS" = PIS, "PARAMS" = PARAMS, "LLS" = LLS,
              "version" = version, "type" = type,
              "converged" = converged)
  fit$prior <- prior.func(fit$pars, type = type, ...) + pi.prior
  hes <- hessian(li_func, pars, ...)
  fit$var <- solve(hes)
  fit$llmat <- matrix_llik1
  fit
}

# Log density of a Dirichlet(alpha) distribution evaluated at x.
#
# x     : numeric vector, expected to be a point on the simplex (not checked).
# alpha : numeric vector of concentration parameters, same length as x.
# Returns a single number: the log density.
#
# Components with alpha == 1 contribute x^0 = 1 to the density and are
# skipped. This keeps the result finite when such a component of x is exactly
# 0; computing 0 * log(0) directly would give NaN.
ddirichlet1 <- function(x, alpha) {
  stopifnot(length(x) == length(alpha))
  log_beta <- sum(lgamma(alpha)) - lgamma(sum(alpha))
  active <- alpha != 1
  sum((alpha[active] - 1) * log(x[active])) - log_beta
}
# jean997/sherlockAsh documentation built on May 18, 2019, 11:45 p.m.