#' Adaptive-grid posterior for CAUSE parameters using the v7 likelihood
#'
#' Builds an adaptively refined grid over the model parameters with
#' `adapt_grid_3d()`, `adapt_grid_2d()` or `adapt_grid_1d()` (depending on
#' which parameters are fixed), rounds the cell bounds to 6 decimal places,
#' and computes per-parameter marginal posteriors with `marge_dists()`.
#'
#' @param dat Object with elements `beta_hat_1`, `beta_hat_2`, `seb1`, `seb2`,
#'   passed to `ll_v7()`. Required.
#' @param mix_grid Object with elements `S1`, `S2`, `pi`, passed to `ll_v7()`.
#'   Required.
#' @param rho Value passed unchanged as the first argument of `ll_v7()`.
#'   Required.
#' @param gamma_range Length-2 range (lower, upper) for gamma; unused when
#'   gamma is fixed.
#' @param gamma_prime_range Length-2 range (lower, upper) for gamma_prime.
#' @param n_q_start,n_gamma_start,n_gamma_prime_start Number of bins per
#'   parameter in the initial grid (q always spans [0, 1]).
#' @param gamma_prior_func,gamma_prime_prior_func,q_prior_func Prior density
#'   functions. Each is integrated over grid bins with `integrate()`, so it
#'   must accept a numeric vector.
#' @param max_post_per_bin Grid cells whose normalized posterior mass exceeds
#'   this value are subdivided until no cell exceeds it.
#' @param fix_gamma_0 If `TRUE`, gamma is fixed at 0 and the grid covers
#'   (q, gamma_prime).
#' @param fix_gamma_0_and_q_1 If `TRUE`, gamma is fixed at 0 and q at 1, and
#'   the grid covers gamma_prime only. Cannot be combined with `fix_gamma_0`.
#' @return A list with elements `post` (the grid data frame), `post_marge`
#'   (output of `marge_dists()`), `mix_grid`, `params`, the deparsed prior
#'   functions `gamma_prior_func`, `q_prior_func` and
#'   `gamma_prime_prior_func`, plus `rho` and `ranges`.
#' @export
cause_grid_adapt_v7 <- function(dat, mix_grid, rho, gamma_range = c(-1, 1),
                                gamma_prime_range = c(-1, 1),
                                n_q_start = 10, n_gamma_start = 100, n_gamma_prime_start = 10,
                                gamma_prior_func = function(b){dnorm(b, 0, 0.6)},
                                gamma_prime_prior_func = function(b){dnorm(b, 0, 0.6)},
                                q_prior_func = function(q){dbeta(q, 0.1, 1)},
                                max_post_per_bin = 0.01, fix_gamma_0 = FALSE,
                                fix_gamma_0_and_q_1 = FALSE) {
  # Scalar conditions: use short-circuit operators, which also reject
  # accidental vector-valued flags instead of silently using element 1.
  if (missing(mix_grid) || missing(rho)) {
    stop("For now please provide mix_grid and rho\n")
  }
  if (fix_gamma_0_and_q_1 && fix_gamma_0) {
    stop("Please use only one of fix_gamma_0_and_q_1 or fix_gamma_0\n")
  }
  if (missing(dat)) {
    stop("Please provide dat.\n")
  }

  # Bounds are rounded so they compare exactly against the lattice of bin
  # starts that marge_dists() builds (also rounded to 6 digits).
  round6 <- function(x) round(x, digits = 6)

  if (fix_gamma_0_and_q_1) {
    param_grid <- adapt_grid_1d(gamma_prime_range,
                                n_gamma_prime_start,
                                gamma_prime_prior_func,
                                mix_grid, rho,
                                dat, max_post_per_bin)
    # Add constant columns for the fixed parameters (g = 0, q = 1).
    param_grid <- param_grid %>%
      mutate(gpstart = round6(gpstart), gpstop = round6(gpstop)) %>%
      mutate(g = 0, gstart = 0, gstop = 0, q = 1, qstart = 1, qstop = 1)
    params <- c("gp")
    priors <- list(gamma_prime_prior_func)
    ranges <- list(gamma_prime_range)
  } else if (fix_gamma_0) {
    param_grid <- adapt_grid_2d(gamma_prime_range,
                                n_gamma_prime_start, n_q_start,
                                gamma_prime_prior_func, q_prior_func,
                                mix_grid, rho,
                                dat, max_post_per_bin)
    # Add constant columns for the fixed parameter (g = 0).
    param_grid <- param_grid %>%
      mutate(qstart = round6(qstart), qstop = round6(qstop),
             gpstart = round6(gpstart), gpstop = round6(gpstop)) %>%
      mutate(g = 0, gstart = 0, gstop = 0)
    params <- c("q", "gp")
    priors <- list(q_prior_func, gamma_prime_prior_func)
    ranges <- list(c(0, 1), gamma_prime_range)
  } else {
    param_grid <- adapt_grid_3d(gamma_range, gamma_prime_range,
                                n_gamma_start, n_gamma_prime_start, n_q_start,
                                gamma_prior_func, gamma_prime_prior_func, q_prior_func,
                                mix_grid, rho,
                                dat, max_post_per_bin)
    param_grid <- param_grid %>%
      mutate(qstart = round6(qstart), qstop = round6(qstop),
             gstart = round6(gstart), gstop = round6(gstop),
             gpstart = round6(gpstart), gpstop = round6(gpstop))
    params <- c("q", "g", "gp")
    priors <- list(q_prior_func, gamma_prior_func, gamma_prime_prior_func)
    ranges <- list(c(0, 1), gamma_range, gamma_prime_range)
  }

  post_marge <- marge_dists(param_grid, params, priors, ranges)
  list("post" = param_grid, "post_marge" = post_marge, "mix_grid" = mix_grid,
       "params" = params,
       "gamma_prior_func" = deparse(gamma_prior_func),
       "q_prior_func" = deparse(q_prior_func),
       "gamma_prime_prior_func" = deparse(gamma_prime_prior_func),
       "rho" = rho, "ranges" = ranges)
}
# Adaptively refine a 3-d grid over (q, gamma, gamma_prime).
#
# Starts from a regular grid of n_q_start x n_gamma_start x n_gamma_prime_start
# cells (q spanning [0, 1]) and evaluates ll_v7() at every cell midpoint. It
# then repeatedly replaces each cell whose normalized posterior mass exceeds
# max_post_per_bin with 3 x 3 x 3 sub-cells, until no cell exceeds it.
# NOTE(review): there is no iteration cap; termination relies on cell mass
# shrinking under subdivision.
#
# Returns a data frame with one row per final cell and these columns:
# - q_ix, g_ix, gp_ix: indices within the grid the cell was created in.
# - q, g, gp: cell midpoints.
# - *start, *stop, *width, *prior: per-dimension bounds, width, and prior
#   mass from get_vals().
# - log_prior, log_lik.
# - log_post: log posterior mass, normalized so exp(log_post) sums to 1.
adapt_grid_3d <- function(gamma_range, gamma_prime_range,
                          n_gamma_start, n_gamma_prime_start, n_q_start,
                          gamma_prior_func, gamma_prime_prior_func, q_prior_func,
                          mix_grid, rho,
                          dat, max_post_per_bin) {
  priors <- list(q_prior_func, gamma_prior_func, gamma_prime_prior_func)

  # Build one row per cell of the product grid defined by `coords` (3 x 2
  # matrix of lower/upper bounds for q, g, gp) and `n_bins` (bins per
  # dimension), and attach the prior and the likelihood. log_post is left as
  # NA so every cell table has the same columns in the same order.
  make_cells <- function(coords, n_bins) {
    vals <- get_vals(coords, n_bins, priors)
    cells <- expand.grid(q_ix = seq(n_bins[1]), g_ix = seq(n_bins[2]),
                         gp_ix = seq(n_bins[3])) %>%
      mutate(q = vals[[1]]$mid[q_ix], g = vals[[2]]$mid[g_ix], gp = vals[[3]]$mid[gp_ix],
             qstart = vals[[1]]$begin[q_ix], qstop = vals[[1]]$end[q_ix],
             qwidth = vals[[1]]$width[q_ix], qprior = vals[[1]]$prior[q_ix],
             gstart = vals[[2]]$begin[g_ix], gstop = vals[[2]]$end[g_ix],
             gwidth = vals[[2]]$width[g_ix], gprior = vals[[2]]$prior[g_ix],
             gpstart = vals[[3]]$begin[gp_ix], gpstop = vals[[3]]$end[gp_ix],
             gpwidth = vals[[3]]$width[gp_ix], gpprior = vals[[3]]$prior[gp_ix],
             log_prior = log(qprior) + log(gprior) + log(gpprior))
    # ll_v7 takes (rho, g, gp, q, ...) positionally.
    cells$log_lik <- apply(cells[, c("g", "gp", "q")], 1, FUN = function(x) {
      ll_v7(rho, x[1], x[2], x[3],
            mix_grid$S1, mix_grid$S2, mix_grid$pi,
            dat$beta_hat_1, dat$beta_hat_2,
            dat$seb1, dat$seb2)
    })
    cells$log_post <- NA
    cells
  }

  # Recompute the log posterior and normalize it over all cells.
  normalise <- function(cells) {
    cells %>% mutate(log_post = log_lik + log_prior,
                     log_post = log_post - logSumExp(log_post))
  }

  start_coords <- matrix(c(0, 1, gamma_range, gamma_prime_range), nrow = 3, byrow = TRUE)
  res <- normalise(make_cells(start_coords, c(n_q_start, n_gamma_start, n_gamma_prime_start)))

  thresh <- log(max_post_per_bin)
  n_split <- 3
  bound_cols <- c("qstart", "qstop", "gstart", "gstop", "gpstart", "gpstop")
  while (any(res$log_post > thresh)) {
    ix <- which(res$log_post > thresh)
    refined <- lapply(ix, FUN = function(i) {
      coords <- matrix(as.numeric(res[i, bound_cols]), nrow = 3, byrow = TRUE)
      make_cells(coords, rep(n_split, 3))
    })
    # Replace the over-full cells with their sub-cells, then renormalize.
    res <- normalise(rbind(res[-ix, ], do.call(rbind, refined)))
  }
  res
}
# Adaptively refine a 2-d grid over (q, gamma_prime) with gamma fixed at 0.
#
# Starts from a regular grid of n_q_start x n_gamma_prime_start cells (q
# spanning [0, 1]) and evaluates ll_v7() at every cell midpoint, passing 0 as
# gamma. It then repeatedly replaces each cell whose normalized posterior mass
# exceeds max_post_per_bin with 3 x 3 sub-cells, until no cell exceeds it.
# NOTE(review): there is no iteration cap; termination relies on cell mass
# shrinking under subdivision.
#
# Returns a data frame with one row per final cell and these columns:
# - q_ix, gp_ix: indices within the grid the cell was created in.
# - q, gp: cell midpoints.
# - *start, *stop, *width, *prior: per-dimension bounds, width, and prior mass.
# - log_prior, log_lik.
# - log_post: log posterior mass, normalized so exp(log_post) sums to 1.
adapt_grid_2d <- function(gamma_prime_range,
                          n_gamma_prime_start, n_q_start,
                          gamma_prime_prior_func, q_prior_func,
                          mix_grid, rho,
                          dat, max_post_per_bin) {
  priors <- list(q_prior_func, gamma_prime_prior_func)

  # Build one row per cell of the product grid defined by `coords` (2 x 2
  # matrix of lower/upper bounds for q, gp) and `n_bins` (bins per
  # dimension), and attach the prior and the likelihood. log_post is left as
  # NA so every cell table has the same columns in the same order.
  make_cells <- function(coords, n_bins) {
    vals <- get_vals(coords, n_bins, priors)
    cells <- expand.grid(q_ix = seq(n_bins[1]), gp_ix = seq(n_bins[2])) %>%
      mutate(q = vals[[1]]$mid[q_ix], gp = vals[[2]]$mid[gp_ix],
             qstart = vals[[1]]$begin[q_ix], qstop = vals[[1]]$end[q_ix],
             qwidth = vals[[1]]$width[q_ix], qprior = vals[[1]]$prior[q_ix],
             gpstart = vals[[2]]$begin[gp_ix], gpstop = vals[[2]]$end[gp_ix],
             gpwidth = vals[[2]]$width[gp_ix], gpprior = vals[[2]]$prior[gp_ix],
             log_prior = log(qprior) + log(gpprior))
    # ll_v7 takes (rho, g, gp, q, ...) positionally; g is fixed at 0 here.
    cells$log_lik <- apply(cells[, c("gp", "q")], 1, FUN = function(x) {
      ll_v7(rho, 0, x[1], x[2],
            mix_grid$S1, mix_grid$S2, mix_grid$pi,
            dat$beta_hat_1, dat$beta_hat_2,
            dat$seb1, dat$seb2)
    })
    cells$log_post <- NA
    cells
  }

  # Recompute the log posterior and normalize it over all cells.
  normalise <- function(cells) {
    cells %>% mutate(log_post = log_lik + log_prior,
                     log_post = log_post - logSumExp(log_post))
  }

  start_coords <- matrix(c(0, 1, gamma_prime_range), nrow = 2, byrow = TRUE)
  res <- normalise(make_cells(start_coords, c(n_q_start, n_gamma_prime_start)))

  thresh <- log(max_post_per_bin)
  n_split <- 3
  bound_cols <- c("qstart", "qstop", "gpstart", "gpstop")
  while (any(res$log_post > thresh)) {
    ix <- which(res$log_post > thresh)
    refined <- lapply(ix, FUN = function(i) {
      coords <- matrix(as.numeric(res[i, bound_cols]), nrow = 2, byrow = TRUE)
      make_cells(coords, rep(n_split, 2))
    })
    # Replace the over-full cells with their sub-cells, then renormalize.
    res <- normalise(rbind(res[-ix, ], do.call(rbind, refined)))
  }
  res
}
# Adaptively refine a 1-d grid over gamma_prime with gamma fixed at 0 and q
# fixed at 1.
#
# Starts from n_gamma_prime_start equal bins over gamma_prime_range and
# evaluates ll_v7() at each bin midpoint, passing 0 as gamma and 1 as q. It
# then repeatedly splits each bin whose normalized posterior mass exceeds
# max_post_per_bin into 3 sub-bins, until no bin exceeds it.
# NOTE(review): there is no iteration cap; termination relies on bin mass
# shrinking under subdivision.
#
# Returns a data frame with one row per final bin and these columns:
# - gp_ix: index within the grid the bin was created in.
# - gp: bin midpoint.
# - gpstart, gpstop, gpwidth, gpprior: bin bounds, width, and prior mass.
# - log_prior, log_lik.
# - log_post: log posterior mass, normalized so exp(log_post) sums to 1.
adapt_grid_1d <- function(gamma_prime_range,
                          n_gamma_prime_start,
                          gamma_prime_prior_func,
                          mix_grid, rho,
                          dat, max_post_per_bin) {
  priors <- list(gamma_prime_prior_func)

  # Build one row per bin defined by `coords` (1 x 2 matrix of lower/upper
  # bounds for gp) and `n_bins`, and attach the prior and the likelihood.
  # log_post is left as NA so every bin table has the same columns in the
  # same order.
  make_cells <- function(coords, n_bins) {
    vals <- get_vals(coords, n_bins, priors)
    cells <- data.frame(gp_ix = seq(n_bins[1])) %>%
      mutate(gp = vals[[1]]$mid[gp_ix],
             gpstart = vals[[1]]$begin[gp_ix], gpstop = vals[[1]]$end[gp_ix],
             gpwidth = vals[[1]]$width[gp_ix], gpprior = vals[[1]]$prior[gp_ix],
             log_prior = log(gpprior))
    # ll_v7 takes (rho, g, gp, q, ...) positionally; g = 0 and q = 1 here.
    cells$log_lik <- apply(cells[, c("gp"), drop = FALSE], 1, FUN = function(x) {
      ll_v7(rho, 0, x[1], 1,
            mix_grid$S1, mix_grid$S2, mix_grid$pi,
            dat$beta_hat_1, dat$beta_hat_2,
            dat$seb1, dat$seb2)
    })
    cells$log_post <- NA
    cells
  }

  # Recompute the log posterior and normalize it over all bins.
  normalise <- function(cells) {
    cells %>% mutate(log_post = log_lik + log_prior,
                     log_post = log_post - logSumExp(log_post))
  }

  start_coords <- matrix(c(gamma_prime_range), nrow = 1, byrow = TRUE)
  res <- normalise(make_cells(start_coords, c(n_gamma_prime_start)))

  thresh <- log(max_post_per_bin)
  n_split <- 3
  while (any(res$log_post > thresh)) {
    ix <- which(res$log_post > thresh)
    refined <- lapply(ix, FUN = function(i) {
      coords <- matrix(as.numeric(res[i, c("gpstart", "gpstop")]), nrow = 1, byrow = TRUE)
      make_cells(coords, rep(n_split, 1))
    })
    # Replace the over-full bins with their sub-bins, then renormalize.
    res <- normalise(rbind(res[-ix, ], do.call(rbind, refined)))
  }
  res
}
# Split each parameter range into equal-width bins and integrate its prior.
#
# coords: k x 2 matrix; row i holds the lower and upper bound of parameter i.
# n:      length-k vector; n[i] is the number of bins for parameter i.
# priors: length-k list of density functions. Each is integrated with
#         integrate(), so it must accept a numeric vector.
#
# Returns a length-k list of data frames. Each has one row per bin and
# columns begin, end, width, mid (the bin midpoint) and prior (the integral
# of the prior density over the bin).
get_vals <- function(coords, n, priors) {
  k <- nrow(coords)
  stopifnot(length(n) == k)
  stopifnot(length(priors) == k)
  vals <- vector("list", k)
  for (i in seq_len(k)) {
    edges <- seq(coords[i, 1], coords[i, 2], length.out = n[i] + 1)
    begin <- edges[-(n[i] + 1)]
    end <- edges[-1]
    width <- end - begin
    density_i <- priors[[i]]
    vals[[i]] <- data.frame(begin = begin, end = end, width = width,
                            mid = begin + (width / 2))
    # Spell out `$value`: `$val` only worked through partial name matching.
    vals[[i]]$prior <- vapply(seq_along(begin), function(j) {
      integrate(f = density_i, lower = begin[j], upper = end[j])$value
    }, numeric(1))
  }
  vals
}
# Marginal posterior of each parameter on a regular lattice of bins.
#
# param_grid: data frame of grid cells with, for each name p in `params`,
#             columns p"start", p"stop" and p"width", plus a "log_post" column
#             holding the log posterior mass of each cell.
# params:     character vector of parameter name prefixes (e.g. "q", "gp").
# priors:     list of prior densities, one per param, integrated per bin.
# ranges:     list of length-2 ranges, one per param.
#
# For parameter i, the lattice runs over ranges[[i]] with bin width equal to
# the smallest cell width. Every cell start (rounded to 6 digits) must lie on
# it; the stopifnot below checks this. Each cell's mass is spread uniformly
# over the lattice bins it covers.
#
# Returns a list of data frames, one per param, with columns begin, end,
# width, mid, prior (prior mass of the bin) and post (posterior mass).
marge_dists <- function(param_grid, params, priors, ranges) {
  post_marge <- vector("list", length(params))
  log_post <- param_grid[["log_post"]]
  for (i in seq_along(params)) {
    param <- params[i]
    # `[[` extracts plain vectors for both data.frames and tibbles.
    cell_start <- round(param_grid[[paste0(param, "start")]], digits = 6)
    cell_stop <- round(param_grid[[paste0(param, "stop")]], digits = 6)
    cell_width <- param_grid[[paste0(param, "width")]]
    min_width <- min(cell_width)
    starts <- round(seq(ranges[[i]][1], ranges[[i]][2], by = min_width), digits = 6)
    stopifnot(all(unique(cell_start) %in% starts))

    begin <- starts[-length(starts)]
    end <- starts[-1]
    width <- end - begin
    bins <- data.frame(begin = begin, end = end, width = width,
                       mid = begin + (width / 2))
    density_i <- priors[[i]]
    bins$prior <- vapply(seq_along(begin), function(j) {
      integrate(f = density_i, lower = begin[j], upper = end[j])$value
    }, numeric(1))
    # A lattice bin starting at `strt` receives, from each cell covering it,
    # that cell's mass scaled by min_width / cell width.
    bins$post <- vapply(begin, function(strt) {
      hit <- which(cell_start <= strt & cell_stop > strt)
      exp(logSumExp(log_post[hit] - log(cell_width[hit]) + log(min_width)))
    }, numeric(1))
    post_marge[[i]] <- bins
  }
  post_marge
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.