# R/cause_grid_adapt_v7.R

#' Approximate the posterior of (q, gamma, gamma_prime) on an adaptive grid
#'
#' Builds a grid over the free parameters, evaluates `ll_v7()` at each cell
#' midpoint, splits any cell holding more than `max_post_per_bin` of the
#' normalized posterior mass (via `adapt_grid_1d/2d/3d`), and summarizes the
#' final grid into per-parameter marginal posteriors with `marge_dists()`.
#'
#' @param dat Data with columns `beta_hat_1`, `beta_hat_2`, `seb1`, `seb2`,
#'   passed to `ll_v7()`. Required.
#' @param mix_grid Mixture grid with columns `S1`, `S2`, `pi`, passed to
#'   `ll_v7()`. Required.
#' @param rho Passed unchanged to `ll_v7()`. Required.
#' @param gamma_range,gamma_prime_range Length-2 `c(lower, upper)` bounds of the
#'   grid for gamma and gamma_prime. The q grid always spans `c(0, 1)`.
#' @param n_q_start,n_gamma_start,n_gamma_prime_start Number of equal-width
#'   starting bins per dimension.
#' @param gamma_prior_func,gamma_prime_prior_func,q_prior_func Prior densities.
#'   Each is integrated over every grid cell to get its prior mass.
#' @param max_post_per_bin Cells whose normalized posterior mass exceeds this
#'   value are split. Must be a single positive number.
#' @param fix_gamma_0 If `TRUE`, gamma is fixed at 0 and the grid is over
#'   (q, gamma_prime).
#' @param fix_gamma_0_and_q_1 If `TRUE`, gamma is fixed at 0 and q at 1, and
#'   the grid is over gamma_prime only. Cannot be combined with `fix_gamma_0`.
#' @return A list with elements `post` (the grid data.frame), `post_marge`
#'   (marginal posteriors, parallel to `params`), `mix_grid`, `params`,
#'   `gamma_prior_func`, `q_prior_func`, `gamma_prime_prior_func` (deparsed
#'   source of the priors), `rho`, and `ranges`.
#' @export
cause_grid_adapt_v7 <-  function(dat, mix_grid, rho, gamma_range = c(-1, 1),
                                 gamma_prime_range = c(-1, 1),
                                 n_q_start = 10, n_gamma_start = 100, n_gamma_prime_start = 10,
                                 gamma_prior_func = function(b){dnorm(b, 0, 0.6)},
                                 gamma_prime_prior_func = function(b){dnorm(b, 0, 0.6)},
                                 q_prior_func = function(q){dbeta(q, 0.1, 1)},
                                 max_post_per_bin = 0.01, fix_gamma_0 = FALSE,
                                 fix_gamma_0_and_q_1=FALSE){

  # Scalar conditions: use the short-circuit operators.
  if (missing(mix_grid) || missing(rho)) {
    stop("For now please provide mix_grid and rho\n")
  }
  if (fix_gamma_0_and_q_1 && fix_gamma_0) {
    stop("Please use only one of fix_gamma_0_and_q_1 or fix_gamma_0\n")
  }
  if (missing(dat)) {
    stop("Please provide dat.\n")
  }
  # The refinement loops compare against log(max_post_per_bin). With 0 they
  # never terminate (every finite mass exceeds -Inf), and with a negative
  # value the threshold is NaN and `while()` fails on NA.
  if (!is.numeric(max_post_per_bin) || !isTRUE(max_post_per_bin > 0)) {
    stop("max_post_per_bin must be a single positive number.\n")
  }

  if (fix_gamma_0_and_q_1) {
    param_grid <- adapt_grid_1d(gamma_prime_range,
                                n_gamma_prime_start,
                                gamma_prime_prior_func,
                                mix_grid, rho,
                                dat, max_post_per_bin)
    # Round bounds so they line up with the regular grid in marge_dists(),
    # then record the fixed parameters as degenerate columns.
    param_grid <- param_grid %>%
      mutate(gpstart = round(gpstart, digits = 6),
             gpstop = round(gpstop, digits = 6)) %>%
      mutate(g = 0, gstart = 0, gstop = 0, q = 1, qstart = 1, qstop = 1)
    params <- c("gp")
    priors <- list(gamma_prime_prior_func)
    ranges <- list(gamma_prime_range)
  } else if (fix_gamma_0) {
    param_grid <- adapt_grid_2d(gamma_prime_range,
                                n_gamma_prime_start, n_q_start,
                                gamma_prime_prior_func, q_prior_func,
                                mix_grid, rho,
                                dat, max_post_per_bin)
    param_grid <- param_grid %>%
      mutate(qstart = round(qstart, digits = 6),
             qstop = round(qstop, digits = 6),
             gpstart = round(gpstart, digits = 6),
             gpstop = round(gpstop, digits = 6)) %>%
      mutate(g = 0, gstart = 0, gstop = 0)
    params <- c("q", "gp")
    priors <- list(q_prior_func, gamma_prime_prior_func)
    ranges <- list(c(0, 1), gamma_prime_range)
  } else {
    param_grid <- adapt_grid_3d(gamma_range, gamma_prime_range,
                                n_gamma_start, n_gamma_prime_start, n_q_start,
                                gamma_prior_func, gamma_prime_prior_func, q_prior_func,
                                mix_grid, rho,
                                dat, max_post_per_bin)
    param_grid <- param_grid %>%
      mutate(qstart = round(qstart, digits = 6),
             qstop = round(qstop, digits = 6),
             gstart = round(gstart, digits = 6),
             gstop = round(gstop, digits = 6),
             gpstart = round(gpstart, digits = 6),
             gpstop = round(gpstop, digits = 6))
    params <- c("q", "g", "gp")
    priors <- list(q_prior_func, gamma_prior_func, gamma_prime_prior_func)
    ranges <- list(c(0, 1), gamma_range, gamma_prime_range)
  }

  post_marge <- marge_dists(param_grid, params, priors, ranges)

  list("post" = param_grid, "post_marge" = post_marge, "mix_grid" = mix_grid,
       params = params,
       "gamma_prior_func" = deparse(gamma_prior_func),
       "q_prior_func" = deparse(q_prior_func),
       "gamma_prime_prior_func" = deparse(gamma_prime_prior_func),
       "rho" = rho, "ranges" = ranges)
}

# Adaptively refine a grid over (q, gamma, gamma_prime) until no cell carries
# more than `max_post_per_bin` of the normalized posterior mass.
#
# The starting grid has n_q_start x n_gamma_start x n_gamma_prime_start cells
# over [0, 1] x gamma_range x gamma_prime_range. On each pass, every cell whose
# posterior mass exceeds the threshold is replaced by n_refine^3 equal
# sub-cells, and all masses are renormalized.
#
# A cell's prior mass is the product of the per-dimension prior integrals from
# get_vals(). Its likelihood is ll_v7() evaluated at the cell midpoint.
#
# n_refine: number of pieces each dimension of an over-threshold cell is split
#   into. The default of 3 is the value that used to be hard-coded. It must be
#   >= 2, otherwise a heavy cell would never shrink.
#
# Returns a data.frame with one row per cell:
#   - q_ix/g_ix/gp_ix: bin index within the grid that created the cell
#   - q/g/gp: cell midpoints
#   - *start/*stop: bounds
#   - *width: widths
#   - *prior: per-dimension prior masses
#   - log_prior, log_lik
#   - log_post: normalized so that sum(exp(log_post)) == 1
#
# NOTE(review): with a q prior whose density is unbounded at 0 (e.g. the
# default dbeta(q, 0.1, 1) in cause_grid_adapt_v7), cells touching q = 0 may
# need many refinement passes. Consider a minimum cell width.
adapt_grid_3d <- function(gamma_range, gamma_prime_range,
                          n_gamma_start, n_gamma_prime_start, n_q_start,
                          gamma_prior_func,  gamma_prime_prior_func, q_prior_func,
                          mix_grid, rho,
                          dat, max_post_per_bin, n_refine = 3) {
  stopifnot(n_refine >= 2)
  priors <- list(q_prior_func, gamma_prior_func, gamma_prime_prior_func)
  bound_cols <- c("qstart", "qstop", "gstart", "gstop", "gpstart", "gpstop")

  # Build one row per cell of the regular grid spanning `coords` (3 x 2 matrix
  # of [lower, upper] for q, g, gp), with `n_bins` pieces per dimension.
  # Fills in the prior masses and the log-likelihood at each midpoint.
  make_cells <- function(coords, n_bins) {
    vals <- get_vals(coords, n_bins, priors)
    idx <- expand.grid(q_ix = seq_len(n_bins[1]), g_ix = seq_len(n_bins[2]),
                       gp_ix = seq_len(n_bins[3]))
    qv <- vals[[1]][idx$q_ix, ]
    gv <- vals[[2]][idx$g_ix, ]
    gpv <- vals[[3]][idx$gp_ix, ]
    cells <- data.frame(q_ix = idx$q_ix, g_ix = idx$g_ix, gp_ix = idx$gp_ix,
                        q = qv$mid, g = gv$mid, gp = gpv$mid,
                        qstart = qv$begin, qstop = qv$end,
                        qwidth = qv$width, qprior = qv$prior,
                        gstart = gv$begin, gstop = gv$end,
                        gwidth = gv$width, gprior = gv$prior,
                        gpstart = gpv$begin, gpstop = gpv$end,
                        gpwidth = gpv$width, gpprior = gpv$prior)
    cells$log_prior <- log(cells$qprior) + log(cells$gprior) + log(cells$gpprior)
    cells$log_lik <- vapply(seq_len(nrow(cells)), function(j) {
      ll_v7(rho, cells$g[j], cells$gp[j], cells$q[j],
            mix_grid$S1, mix_grid$S2, mix_grid$pi,
            dat$beta_hat_1, dat$beta_hat_2,
            dat$seb1, dat$seb2)
    }, numeric(1))
    cells
  }

  # Normalize so that the posterior masses exp(log_post) sum to one.
  set_log_post <- function(cells) {
    lp <- cells$log_lik + cells$log_prior
    cells$log_post <- lp - logSumExp(lp)
    cells
  }

  thresh <- log(max_post_per_bin)
  start_coords <- matrix(c(0, 1, gamma_range, gamma_prime_range),
                         nrow = 3, byrow = TRUE)
  res <- set_log_post(make_cells(start_coords,
                                 c(n_q_start, n_gamma_start, n_gamma_prime_start)))

  while (any(res$log_post > thresh)) {
    heavy <- which(res$log_post > thresh)
    pieces <- lapply(heavy, function(i) {
      cell_coords <- matrix(as.numeric(res[i, bound_cols]), nrow = 3, byrow = TRUE)
      make_cells(cell_coords, rep(n_refine, 3))
    })
    new_cells <- do.call(rbind, pieces)
    # Placeholder so the column sets match for rbind(); recomputed just below.
    new_cells$log_post <- NA_real_
    res <- set_log_post(rbind(res[-heavy, ], new_cells))
  }
  res
}


# Adaptively refine a grid over (q, gamma_prime), with gamma fixed at 0, until
# no cell carries more than `max_post_per_bin` of the normalized posterior mass.
#
# The starting grid has n_q_start x n_gamma_prime_start cells over
# [0, 1] x gamma_prime_range. On each pass, every cell whose posterior mass
# exceeds the threshold is replaced by n_refine^2 equal sub-cells, and all
# masses are renormalized.
#
# A cell's prior mass is the product of the per-dimension prior integrals from
# get_vals(). Its likelihood is ll_v7(rho, 0, gp, q, ...) at the cell midpoint.
#
# n_refine: number of pieces each dimension of an over-threshold cell is split
#   into. The default of 3 is the value that used to be hard-coded. It must be
#   >= 2, otherwise a heavy cell would never shrink.
#
# Returns a data.frame with one row per cell: q_ix/gp_ix, midpoints q/gp,
# bounds *start/*stop, widths *width, prior masses *prior, log_prior, log_lik,
# and log_post (normalized so that sum(exp(log_post)) == 1).
#
# NOTE(review): with a q prior whose density is unbounded at 0 (e.g. the
# default dbeta(q, 0.1, 1) in cause_grid_adapt_v7), cells touching q = 0 may
# need many refinement passes.
adapt_grid_2d <- function(gamma_prime_range,
                          n_gamma_prime_start, n_q_start,
                          gamma_prime_prior_func, q_prior_func,
                          mix_grid, rho,
                          dat, max_post_per_bin, n_refine = 3) {
  stopifnot(n_refine >= 2)
  priors <- list(q_prior_func, gamma_prime_prior_func)
  bound_cols <- c("qstart", "qstop", "gpstart", "gpstop")

  # Build one row per cell of the regular grid spanning `coords` (2 x 2 matrix
  # of [lower, upper] for q, gp), with `n_bins` pieces per dimension.
  # Fills in the prior masses and the log-likelihood at each midpoint.
  make_cells <- function(coords, n_bins) {
    vals <- get_vals(coords, n_bins, priors)
    idx <- expand.grid(q_ix = seq_len(n_bins[1]), gp_ix = seq_len(n_bins[2]))
    qv <- vals[[1]][idx$q_ix, ]
    gpv <- vals[[2]][idx$gp_ix, ]
    cells <- data.frame(q_ix = idx$q_ix, gp_ix = idx$gp_ix,
                        q = qv$mid, gp = gpv$mid,
                        qstart = qv$begin, qstop = qv$end,
                        qwidth = qv$width, qprior = qv$prior,
                        gpstart = gpv$begin, gpstop = gpv$end,
                        gpwidth = gpv$width, gpprior = gpv$prior)
    cells$log_prior <- log(cells$qprior) + log(cells$gpprior)
    cells$log_lik <- vapply(seq_len(nrow(cells)), function(j) {
      ll_v7(rho, 0, cells$gp[j], cells$q[j],
            mix_grid$S1, mix_grid$S2, mix_grid$pi,
            dat$beta_hat_1, dat$beta_hat_2,
            dat$seb1, dat$seb2)
    }, numeric(1))
    cells
  }

  # Normalize so that the posterior masses exp(log_post) sum to one.
  set_log_post <- function(cells) {
    lp <- cells$log_lik + cells$log_prior
    cells$log_post <- lp - logSumExp(lp)
    cells
  }

  thresh <- log(max_post_per_bin)
  start_coords <- matrix(c(0, 1, gamma_prime_range), nrow = 2, byrow = TRUE)
  res <- set_log_post(make_cells(start_coords, c(n_q_start, n_gamma_prime_start)))

  while (any(res$log_post > thresh)) {
    heavy <- which(res$log_post > thresh)
    pieces <- lapply(heavy, function(i) {
      cell_coords <- matrix(as.numeric(res[i, bound_cols]), nrow = 2, byrow = TRUE)
      make_cells(cell_coords, rep(n_refine, 2))
    })
    new_cells <- do.call(rbind, pieces)
    # Placeholder so the column sets match for rbind(); recomputed just below.
    new_cells$log_post <- NA_real_
    res <- set_log_post(rbind(res[-heavy, ], new_cells))
  }
  res
}


# Adaptively refine a grid over gamma_prime, with gamma fixed at 0 and q fixed
# at 1, until no cell carries more than `max_post_per_bin` of the normalized
# posterior mass.
#
# The starting grid has n_gamma_prime_start equal cells over
# gamma_prime_range. On each pass, every cell whose posterior mass exceeds the
# threshold is split into n_refine equal sub-cells, and all masses are
# renormalized.
#
# A cell's prior mass is the integral of gamma_prime_prior_func from
# get_vals(). Its likelihood is ll_v7(rho, 0, gp, 1, ...) at the cell midpoint.
#
# n_refine: number of pieces an over-threshold cell is split into. The default
#   of 3 is the value that used to be hard-coded. It must be >= 2, otherwise a
#   heavy cell would never shrink.
#
# Returns a data.frame with one row per cell: gp_ix, midpoint gp, bounds
# gpstart/gpstop, width gpwidth, prior mass gpprior, log_prior, log_lik, and
# log_post (normalized so that sum(exp(log_post)) == 1).
adapt_grid_1d <- function(gamma_prime_range,
                          n_gamma_prime_start,
                          gamma_prime_prior_func,
                          mix_grid, rho,
                          dat, max_post_per_bin, n_refine = 3) {
  stopifnot(n_refine >= 2)
  priors <- list(gamma_prime_prior_func)
  bound_cols <- c("gpstart", "gpstop")

  # Build one row per cell of the regular grid spanning `coords` (1 x 2 matrix
  # [lower, upper] for gp), with `n_bins` pieces. Fills in the prior mass and
  # the log-likelihood at each midpoint.
  make_cells <- function(coords, n_bins) {
    gpv <- get_vals(coords, n_bins, priors)[[1]]
    cells <- data.frame(gp_ix = seq_len(n_bins[1]),
                        gp = gpv$mid,
                        gpstart = gpv$begin, gpstop = gpv$end,
                        gpwidth = gpv$width, gpprior = gpv$prior)
    cells$log_prior <- log(cells$gpprior)
    cells$log_lik <- vapply(cells$gp, function(gp) {
      ll_v7(rho, 0, gp, 1,
            mix_grid$S1, mix_grid$S2, mix_grid$pi,
            dat$beta_hat_1, dat$beta_hat_2,
            dat$seb1, dat$seb2)
    }, numeric(1))
    cells
  }

  # Normalize so that the posterior masses exp(log_post) sum to one.
  set_log_post <- function(cells) {
    lp <- cells$log_lik + cells$log_prior
    cells$log_post <- lp - logSumExp(lp)
    cells
  }

  thresh <- log(max_post_per_bin)
  start_coords <- matrix(gamma_prime_range, nrow = 1, byrow = TRUE)
  res <- set_log_post(make_cells(start_coords, n_gamma_prime_start))

  while (any(res$log_post > thresh)) {
    heavy <- which(res$log_post > thresh)
    pieces <- lapply(heavy, function(i) {
      cell_coords <- matrix(as.numeric(res[i, bound_cols]), nrow = 1, byrow = TRUE)
      make_cells(cell_coords, n_refine)
    })
    new_cells <- do.call(rbind, pieces)
    # Placeholder so the column sets match for rbind(); recomputed just below.
    new_cells$log_post <- NA_real_
    res <- set_log_post(rbind(res[-heavy, ], new_cells))
  }
  res
}




# Split each dimension's [lower, upper] interval into equal-width bins and
# attach the prior mass of every bin.
#
# @param coords k x 2 numeric matrix. Row i holds the [lower, upper] bounds for
#   dimension i.
# @param n numeric vector of length k. n[i] is the number of bins for
#   dimension i.
# @param priors list of k density functions. priors[[i]] is numerically
#   integrated over each bin of dimension i.
# @return list of k data.frames with columns begin, end, width, mid (bin
#   midpoint) and prior (the integral of priors[[i]] over [begin, end]).
get_vals <- function(coords, n, priors) {
  k <- nrow(coords)
  stopifnot(length(n) == k)
  stopifnot(length(priors) == k)

  vals <- vector("list", k)
  for (i in seq_len(k)) {
    breaks <- seq(coords[i, 1], coords[i, 2], length.out = n[i] + 1)
    begin <- breaks[-length(breaks)]
    end <- breaks[-1]
    width <- end - begin
    density <- priors[[i]]
    # Spell out `$value` instead of relying on partial matching of `$val`.
    prior <- vapply(seq_along(begin), function(j) {
      integrate(density, lower = begin[j], upper = end[j])$value
    }, numeric(1))
    vals[[i]] <- data.frame(begin = begin, end = end, width = width,
                            mid = begin + width / 2, prior = prior)
  }
  vals
}

# Project the adaptive posterior grid onto a regular grid for each parameter.
#
# For parameter params[i], the range ranges[[i]] is cut into bins as wide as
# the narrowest grid cell in that dimension. Each grid cell's posterior mass is
# treated as uniform along that dimension, so a marginal bin receives
# exp(log_post) * (bin width / cell width) from every cell covering it.
#
# Cell and bin boundaries are compared after rounding to 6 decimal places.
# Widths below 1e-6 cannot be resolved by that rounding, so a warning is raised
# in that case.
#
# @param param_grid data.frame with columns <p>start, <p>stop and <p>width for
#   every p in params, plus log_post (log posterior mass of each cell).
# @param params character vector of column prefixes, e.g. c("q", "g", "gp").
# @param priors list of prior densities parallel to params. Each is integrated
#   over the marginal bins.
# @param ranges list of length-2 c(lower, upper) ranges parallel to params.
# @return list parallel to params. Each element is a data.frame with columns
#   begin, end, width, mid, prior (prior mass of the bin) and post (marginal
#   posterior mass of the bin).
marge_dists <- function(param_grid, params, priors, ranges) {
  post_marge <- vector("list", length(params))
  for (i in seq_along(params)) {
    param <- params[i]
    r1 <- ranges[[i]][1]
    r2 <- ranges[[i]][2]

    cell_start <- round(param_grid[[paste0(param, "start")]], digits = 6)
    cell_stop <- round(param_grid[[paste0(param, "stop")]], digits = 6)
    cell_width <- param_grid[[paste0(param, "width")]]

    min_width <- min(cell_width)
    if (min_width < 1e-6) {
      warning("marge_dists: smallest '", param, "' cell width (",
              signif(min_width, 3), ") is below the 1e-6 rounding resolution; ",
              "the marginal posterior may be inaccurate.", call. = FALSE)
    }

    # Regular bin edges. Every cell start must land on one of them.
    starts <- round(seq(r1, r2, by = min_width), digits = 6)
    stopifnot(all(unique(cell_start) %in% starts))

    begin <- starts[-length(starts)]
    end <- starts[-1]
    width <- end - begin
    marge <- data.frame(begin = begin, end = end, width = width,
                        mid = begin + width / 2)
    marge$prior <- vapply(seq_along(begin), function(j) {
      integrate(priors[[i]], lower = begin[j], upper = end[j])$value
    }, numeric(1))

    # Log of each cell's mass share per fine bin (uniform-within-cell rule).
    share <- param_grid[["log_post"]] - log(cell_width) + log(min_width)
    marge$post <- vapply(begin, function(strt) {
      covering <- which(cell_start <= strt & cell_stop > strt)
      exp(logSumExp(share[covering]))
    }, numeric(1))

    post_marge[[i]] <- marge
  }
  post_marge
}
# jean997/sherlockAsh documentation built on May 18, 2019, 11:45 p.m.