# R/cause_grid_approx_v7.R

#'@title CAUSE with grid approximation and fixed rho
#'@description Evaluates the log prior, the log likelihood (via \code{ll_v7})
#'and the normalized log posterior at every point of a three-dimensional grid
#'over (q, gamma, gamma_prime), holding \code{rho} and \code{mix_grid} fixed.
#'The evaluated grid is then passed to \code{marge_dists}.
#'@param dat Data (data.frame). The columns \code{beta_hat_1},
#'\code{beta_hat_2}, \code{seb1} and \code{seb2} are read.
#'@param mix_grid Grid of mixture parameters. The elements \code{S1}, \code{S2}
#'and \code{pi} are read. Currently required: it is not estimated here.
#'@param rho rho, passed unchanged to the likelihood. Currently required.
#'@param gamma_range Length-2 numeric vector giving the range of the grid for gamma.
#'@param gamma_prime_range Length-2 numeric vector giving the range of the grid
#'for gamma_prime.
#'@param n_gamma Number of grid cells for gamma.
#'@param n_gamma_prime Number of grid cells for gamma_prime.
#'@param n_q Number of grid cells for q. The range of q is fixed at (0, 1).
#'@param gamma_prior_func prior function for gamma. Defaults to N(0, 0.6)
#'@param gamma_prime_prior_func prior function for gamma_prime. Defaults to N(0, 0.6)
#'@param q_prior_func prior function for q. Defaults to Beta(0.1, 1)
#'@return A list with elements
#'\itemize{
#'\item \code{post}: data.frame with one row per grid point holding the grid
#'indices, the cell midpoints, starts, stops, widths and priors taken from
#'\code{get_vals}, and the columns \code{log_prior}, \code{log_lik} and
#'\code{log_post} (normalized by subtracting \code{logSumExp(log_post)}).
#'\item \code{post_marge}: the result of \code{marge_dists}.
#'\item \code{mix_grid}: the \code{mix_grid} argument.
#'\item \code{parmas}: parameter names \code{c("q", "g", "gp")}. Misspelled
#'name kept for backward compatibility; prefer \code{params}.
#'\item \code{gamma_prior_func}, \code{q_prior_func}: deparsed prior functions.
#'\item \code{rho}: the \code{rho} argument.
#'\item \code{params}: same content as \code{parmas}.
#'\item \code{gamma_prime_prior_func}: deparsed prior function for gamma_prime.
#'}
#'@export
cause_grid_approx_v7 <- function(dat, mix_grid, rho,
                              gamma_range = c(-1, 1), gamma_prime_range = c(-1, 1),
                              n_gamma= 100, n_gamma_prime = 100, n_q = 100,
                              gamma_prior_func = function(b){dnorm(b, 0, 0.6)},
                              gamma_prime_prior_func = function(b){dnorm(b, 0, 0.6)},
                              q_prior_func = function(q){dbeta(q, 0.1, 1)}){
  # Scalar condition: use the short-circuiting `||`, not the vectorised `|`.
  if(missing(mix_grid) || missing(rho)){
    stop("For now please provide mix_grid and rho\n")
  }

  if(missing(dat)){
    stop("Please provide dat.\n")
  }

  # One row per parameter, in the order q, gamma, gamma_prime. The priors list
  # and the grid sizes passed to get_vals() follow the same order.
  coords <- matrix(c(0, 1, gamma_range, gamma_prime_range), nrow = 3, byrow = TRUE)
  priors <- list(q_prior_func, gamma_prior_func, gamma_prime_prior_func)
  vals <- get_vals(coords, c(n_q, n_gamma, n_gamma_prime), priors)

  # Full Cartesian grid of cell indices, annotated with the per-cell summaries
  # from get_vals(). seq_len() is used so a zero size yields an empty index
  # rather than c(1, 0).
  param_grid <- expand.grid(q_ix = seq_len(n_q), g_ix = seq_len(n_gamma),
                            gp_ix = seq_len(n_gamma_prime)) %>%
                mutate(q = vals[[1]]$mid[q_ix], g = vals[[2]]$mid[g_ix], gp = vals[[3]]$mid[gp_ix],
                       qstart = vals[[1]]$begin[q_ix], qstop = vals[[1]]$end[q_ix],
                       qwidth = vals[[1]]$width[q_ix], qprior = vals[[1]]$prior[q_ix],
                       gstart = vals[[2]]$begin[g_ix], gstop = vals[[2]]$end[g_ix],
                       gwidth = vals[[2]]$width[g_ix], gprior = vals[[2]]$prior[g_ix],
                       gpstart = vals[[3]]$begin[gp_ix], gpstop = vals[[3]]$end[gp_ix],
                       gpwidth = vals[[3]]$width[gp_ix], gpprior = vals[[3]]$prior[gp_ix],
                       log_prior = log(qprior) + log(gprior) + log(gpprior) )

  # Log likelihood at each grid point. The columns are selected in the order
  # (g, gp, q), so x[1] = gamma, x[2] = gamma_prime, x[3] = q in the ll_v7 call.
  param_grid$log_lik <- apply(param_grid[, c("g","gp", "q")], 1, FUN = function(x){
    ll_v7(rho, x[1], x[2], x[3],
          mix_grid$S1, mix_grid$S2, mix_grid$pi,
          dat$beta_hat_1, dat$beta_hat_2,
          dat$seb1, dat$seb2)
  })

  # Unnormalized log posterior, then normalized over the whole grid.
  param_grid <- param_grid %>% mutate(log_post  =  log_lik + log_prior,
                        log_post = log_post - logSumExp(log_post))

  params <- c("q", "g", "gp")
  ranges <- list(c(0, 1), gamma_range, gamma_prime_range)

  post_marge <- marge_dists(param_grid, params, priors, ranges)

  # The first seven elements keep their original names and positions (including
  # the misspelled `parmas`) so existing callers keep working; the new elements
  # are appended at the end.
  return(list("post" = param_grid, "post_marge" = post_marge, "mix_grid" = mix_grid, parmas = params,
              "gamma_prior_func" = deparse(gamma_prior_func), "q_prior_func" = deparse(q_prior_func),
              "rho" = rho,
              "params" = params,
              "gamma_prime_prior_func" = deparse(gamma_prime_prior_func)))

}
# jean997/sherlockAsh documentation built on May 18, 2019, 11:45 p.m.