# R/cause_grid_approx.R

#' @title CAUSE with grid approximation and fixed rho
#' @description Approximates the posterior over a regular grid of b values
#' (model 2, with q fixed at 1) and over a regular grid of (b, q) pairs
#' (model 3). Each grid cell is weighted by its prior mass and by the
#' likelihood returned by \code{ll_v4} at the cell midpoint.
#' @param dat Data (data.frame). The columns \code{beta_hat_1},
#' \code{beta_hat_2}, \code{seb1}, \code{seb2} and \code{wts} are read.
#' @param seed Seed. Optional. Drawn at random if missing. It is printed,
#' passed to \code{set.seed} and returned.
#' @param mix_grid Grid of mixture parameters (columns \code{S1}, \code{S2},
#' \code{pi}). Will be estimated if not provided.
#' @param rho rho. If missing it is taken from the MAP fit under q=b=0.
#' @param brange Length-2 numeric vector giving the lower and upper end of
#' the grid for b. Must satisfy \code{brange[1] < brange[2]}.
#' @param nb Number of equal-width grid cells for b.
#' @param nq Number of equal-width grid cells for q on [0, 1].
#' @param z_prior_func optional. Unneccessary if you provide rho. Defaults to N(0, 0.5).
#' Must return the log density.
#' @param b_prior_func prior function for b. Defaults to N(0, 0.6).
#' Must return the log density.
#' @param q_prior_func prior function for q. Defaults to Beta(0.1, 1).
#' Must return the log density.
#' @param waic_samps Number of samples for calculating waic. WAIC is skipped
#' (returned as \code{NULL}) when this is not positive.
#' @param models Which models to run. 2=causal, 3=shared. Some subset of c(2,3).
#' @return A list with elements \code{post} (list of two data.frames: the b
#' grid and the b-by-q grid, with \code{log_lik} and \code{log_post} columns
#' for the models that were run), \code{mix_grid}, \code{marge_b},
#' \code{marge_q}, \code{waic}, the deparsed prior functions
#' (\code{b_prior_func}, \code{q_prior_func}, \code{z_prior_func}),
#' \code{rho} and \code{seed}.
#' @export
cause_grid_approx <- function(dat, seed, mix_grid, rho,
                              brange = c(-1, 1), nb = 100,
                              nq = 100,
                             z_prior_func = function(z){ dnorm(z, 0, 0.5, log=TRUE)},
                             b_prior_func = function(b){dnorm(b, 0, 0.6, log=TRUE)},
                             q_prior_func = function(q){dbeta(q, 0.1, 1, log=TRUE)},
                             waic_samps  = 1000,
                             models=2:3){

  # Validate inputs before touching the RNG or doing any work.
  if (missing(dat)) {
    stop("Please provide dat.\n")
  }
  stopifnot(is.numeric(brange), length(brange) == 2, brange[1] < brange[2],
            is.numeric(nb), nb >= 1, is.numeric(nq), nq >= 1)

  # missing() stops being informative once the argument is assigned below,
  # so record both flags up front.
  rho_missing <- missing(rho)
  grid_missing <- missing(mix_grid)

  if (missing(seed)) {
    seed <- ceiling(runif(n = 1, 1, 1e9))
  }
  cat("seed: ", seed, "\n")
  set.seed(seed)

  # Build a table of adjacent intervals from `starts`; the prior mass of each
  # interval is the integral of exp(log prior) over it (not yet normalized).
  interval_grid <- function(starts, log_prior_func) {
    n_cells <- length(starts) - 1
    begin <- starts[-(n_cells + 1)]
    end <- starts[-1]
    width <- end - begin
    mass <- vapply(seq_len(n_cells), function(i) {
      integrate(f = function(xx) exp(log_prior_func(xx)),
                lower = begin[i], upper = end[i])$value
    }, numeric(1))
    data.frame("begin" = begin, "end" = end, "width" = width,
               "mid" = begin + (width / 2), "prior" = mass)
  }

  # Grid for b; prior masses are renormalized to sum to one over the grid.
  b_grid <- interval_grid(seq(brange[1], brange[2], length.out = nb + 1),
                          b_prior_func)
  b_grid$prior <- b_grid$prior / sum(b_grid$prior)

  # Grid for q on [0, 1]; report if the supplied prior is not a proper
  # density there, then renormalize.
  q_grid <- interval_grid(seq(0, 1, length.out = nq + 1), q_prior_func)
  if (abs(sum(q_grid$prior) - 1) > 1e-6) {
    cat("Warning: q prior does not integrate to 1.\n")
  }
  q_grid$prior <- q_grid$prior / sum(q_grid$prior)

  n_b <- nrow(b_grid)
  n_q <- nrow(q_grid)

  # Build the mixture grid from the ash fits when none was supplied: every
  # pairing of the sd components with non-negligible weight in each fit.
  if (grid_missing) {
    cat("Getting grid.\n")
    b1_fit <- ash(betahat = dat$beta_hat_1, sebetahat = dat$seb1,
                  mixcompdist = "normal", prior = "nullbiased")
    b2_fit <- ash(betahat = dat$beta_hat_2, sebetahat = dat$seb2,
                  mixcompdist = "normal", prior = "nullbiased")
    sigma1 <- b1_fit$fitted_g$sd[zapsmall(b1_fit$fitted_g$pi) != 0]
    sigma2 <- b2_fit$fitted_g$sd[zapsmall(b2_fit$fitted_g$pi) != 0]
    mix_grid <- data.frame("S1" = rep(sigma1, length(sigma2)),
                           "S2" = rep(sigma2, each = length(sigma1)),
                           "pi" = rep(0, length(sigma1) * length(sigma2)))
  }

  # Fit the null model (q = b = 0) when rho and/or the mixture weights are
  # still unknown. A user-supplied rho is never overwritten.
  if (rho_missing || grid_missing) {
    cat("Finding MAP under q=b=0.\n")
    map_null <- get_map(pars.start = c(0),
                        data = dat, grid = mix_grid,
                        par.bounds = cbind(c(-1), c(1)),
                        type = "null", z_prior_func = z_prior_func, n.iter = 20)
    if (rho_missing) rho <- map_null$pars
    mix_grid$pi <- map_null$pi
  }

  param_grid <- list()
  # causal model (model 2): grid over b only
  param_grid[[1]] <- data.frame("bstart" = b_grid$begin, "bstop" = b_grid$end,
                                "bwidth" = b_grid$width, "b" = b_grid$mid,
                                "bprior" = b_grid$prior)

  # model 3: full b x q grid; b varies fastest, q is constant within each
  # consecutive run of n_b rows.
  param_grid[[2]] <- data.frame("b" = rep(b_grid$mid, n_q),
                                "q" = rep(q_grid$mid, each = n_b),
                                "bstart" = rep(b_grid$begin, n_q),
                                "bstop" = rep(b_grid$end, n_q),
                                "bwidth" = rep(b_grid$width, n_q),
                                "bprior" = rep(b_grid$prior, n_q),
                                "qstart" = rep(q_grid$begin, each = n_b),
                                "qstop" = rep(q_grid$end, each = n_b),
                                "qwidth" = rep(q_grid$width, each = n_b),
                                "qprior" = rep(q_grid$prior, each = n_b))
  param_grid[[2]]$prior <- param_grid[[2]]$bprior * param_grid[[2]]$qprior

  marge_post_b <- param_grid[[1]]
  marge_post_q <- data.frame("qstart" = q_grid$begin, "qstop" = q_grid$end,
                             "qwidth" = q_grid$width, "q" = q_grid$mid,
                             "qprior" = q_grid$prior)

  # Log-likelihood of the data at one (b, q) grid point.
  point_ll <- function(bb, qq) {
    ll_v4(rho, bb, qq,
          mix_grid$S1, mix_grid$S2, mix_grid$pi,
          dat$beta_hat_1, dat$beta_hat_2,
          dat$seb1, dat$seb2, dat$wts)
  }
  # Turn unnormalized log weights into log probabilities summing to one.
  log_normalize <- function(lw) lw - logSumExp(lw)

  if (2 %in% models) {
    cat("Model 2.\n")
    g1 <- param_grid[[1]]
    # q is fixed at 1 for this model.
    g1$log_lik <- vapply(g1$b, function(bb) point_ll(bb, 1), numeric(1)) +
      log(g1$bwidth)
    g1$log_post <- log_normalize(g1$log_lik + log(g1$bprior))
    param_grid[[1]] <- g1
    marge_post_b$post_m2 <- exp(g1$log_post)
  }
  if (3 %in% models) {
    cat("Model 3.\n")
    g2 <- param_grid[[2]]
    g2$log_lik <- vapply(seq_len(nrow(g2)), function(i) {
      point_ll(g2$b[i], g2$q[i])
    }, numeric(1)) + log(g2$bwidth) + log(g2$qwidth)
    g2$log_post <- log_normalize(g2$log_lik + log(g2$prior))
    param_grid[[2]] <- g2

    # Rows are laid out with b varying fastest, so the posterior reshapes to
    # an n_b x n_q matrix (rows = b cells, columns = q cells). Marginals are
    # then plain row/column sums; no float-equality matching is needed.
    post_mat <- matrix(exp(g2$log_post), nrow = n_b, ncol = n_q)
    marge_post_b$post_m3 <- rowSums(post_mat)
    marge_post_q$post_m3 <- colSums(post_mat)
  }

  #Calculate WAIC
  waic <- NULL
  if (waic_samps > 0) {
    llmats <- list()
    if (2 %in% models) {
      llmats[[1]] <- samp_from_grid_ll(post_grid = param_grid[[1]], rho, mix_grid, dat,
                                       fix.q = 1, waic_samps = waic_samps)
    }
    if (3 %in% models) {
      llmats[[2]] <- samp_from_grid_ll(post_grid = param_grid[[2]], rho, mix_grid, dat,
                                       waic_samps = waic_samps)
    }
    if (length(models) == 2) {
      waic <- my_waic(llmats)
    } else {
      # The original indexed the function itself (`my_waic[models-1]`), which
      # always errors. Pass only the matrices of the requested model(s).
      # NOTE(review): assumes my_waic accepts a list of any length -- confirm.
      waic <- my_waic(llmats[models - 1])
    }
  }

  list("post" = param_grid, "mix_grid" = mix_grid,
       "marge_b" = marge_post_b, "marge_q" = marge_post_q, "waic" = waic,
       "b_prior_func" = deparse(b_prior_func),
       "q_prior_func" = deparse(q_prior_func),
       "z_prior_func" = deparse(z_prior_func), "rho" = rho, "seed" = seed)

}
# jean997/sherlockAsh documentation built on May 18, 2019, 11:45 p.m.