R/cause_grid_adapt2_v7.R

#' Adaptive grid approximation to the CAUSE posterior (ll_v7 likelihood)
#'
#' Builds a grid over the parameters in `params` with `adapt2_grid()`.
#' First, the ranges of parameters whose `range_fixed` entry is `FALSE` are
#' widened until the marginal posterior at the edge bins is negligible.
#' Then bins whose posterior mass exceeds `max_post_per_bin` are refined.
#' Marginal posteriors of the final grid are computed with `marge_dists()`.
#'
#' @param dat Data passed to the likelihood; its `beta_hat_1`, `beta_hat_2`,
#'   `seb1` and `seb2` elements are used. Required.
#' @param mix_grid Mixture grid; its `S1`, `S2` and `pi` elements are used.
#'   Required.
#' @param rho Passed as the first argument to `ll_v7()`. Required.
#' @param params Names of the parameters to put on the grid.
#' @param priors List of prior density functions, one per element of
#'   `params`.
#' @param n_start Initial number of bins for each parameter.
#' @param ranges List of length-2 numeric vectors giving each parameter's
#'   starting range.
#' @param range_fixed Logical per parameter; `TRUE` keeps that range fixed.
#' @param max_post_per_bin Bins whose posterior mass exceeds this value are
#'   split during refinement.
#' @return A list with elements `post` (the grid data frame), `post_marge`,
#'   `mix_grid`, `params`, `rho`, `ranges` (the final, possibly widened,
#'   ranges), `n_start` (as supplied) and `max_post_per_bin`.
#' @export
cause_grid_adapt2_v7 <- function(dat, mix_grid, rho,
                                 params = c("q", "gamma", "eta"),
                                 priors = list(function(q){dbeta(q, 1, 10)},
                                               function(b){dnorm(b, 0, 0.6)},
                                               function(b){dnorm(b, 0, 0.6)}),
                                 n_start = c(10, 50, 10),
                                 ranges = list(c(0, 1), c(-1, 1), c(-1, 1)),
                                 range_fixed = c(TRUE, FALSE, FALSE),
                                 max_post_per_bin = 0.001) {
  # Scalar condition: short-circuit `||` rather than elementwise `|`.
  if (missing(mix_grid) || missing(rho)) {
    stop("For now please provide mix_grid and rho\n")
  }
  if (missing(dat)) {
    stop("Please provide dat.\n")
  }
  grid <- adapt2_grid(params, ranges, priors, n_start,
                      range_fixed,
                      mix_grid, rho,
                      dat, max_post_per_bin)
  # adapt2_grid() may widen the ranges; report and use the final ones.
  ranges <- grid$ranges
  post <- grid$res
  post_marge <- marge_dists(post, params, priors, ranges)

  # Each element name appears exactly once.
  list(post = post, post_marge = post_marge, mix_grid = mix_grid,
       params = params, rho = rho, ranges = ranges, n_start = n_start,
       max_post_per_bin = max_post_per_bin)
}

# Build the adaptive grid used by cause_grid_adapt2_v7().
#
# Phase 1 (ranges): this applies to every parameter i with
# range_fixed[i] == FALSE, using its marginal posterior from marge_dists().
# If the `post` value at the bin with the smallest (largest) `mid` exceeds
# end_bin_thresh, a slab of ceiling(n_start[i] / 2) bins of the initial
# width is added below (above) the current range. The slab spans the full
# current range of every other parameter. Passes repeat until none changes
# a range.
#
# Phase 2 (refinement): while any row has log_post > log(max_post_per_bin),
# that row and its neighbours (from get_neighbors()) are replaced by
# n_split^k sub-bins, where k = length(params).
#
# Arguments mirror cause_grid_adapt2_v7(). n_split (default 3, the value
# that used to be hard-coded) is the number of sub-bins per dimension used
# when a bin is refined.
# Returns list(res = grid data frame as built by post_from_vals(),
#              ranges = final parameter ranges).
adapt2_grid <- function(params, ranges, priors, n_start,
                        range_fixed,
                        mix_grid, rho,
                        dat, max_post_per_bin,
                        end_bin_thresh = 1e-8,
                        n_split = 3) {
  k <- length(params)
  stopifnot(length(ranges) == k)
  stopifnot(length(priors) == k)
  stopifnot(length(n_start) == k)
  stopifnot(length(range_fixed) == k)

  # Grid posterior for all bin combinations over `range_list`, with
  # `n_bins[j]` equal-width bins in dimension j.
  grid_post <- function(range_list, n_bins) {
    vals <- get_vals2(range_list, n_bins, priors)
    post_from_vals(vals, params, dat, mix_grid, rho)
  }

  # post_from_vals() normalises only the rows it creates. After rows are
  # added or replaced, log_post AND norm_log_post are recomputed over the
  # whole grid; otherwise norm_log_post keeps each slab's own normalisation.
  renormalize <- function(grid) {
    grid %>% mutate(log_post = log_lik + log_prior,
                    log_post = log_post - logSumExp(log_post),
                    norm_log_post = log_post - log_volume)
  }

  res <- grid_post(ranges, n_start)
  post_marge <- marge_dists(res, params, priors, ranges)
  # Bin widths stay at their initial values so added slabs line up with the
  # existing bins; n_start[i] then tracks the bin count of dimension i.
  widths <- vapply(seq_len(k), function(i) {
    (ranges[[i]][2] - ranges[[i]][1]) / n_start[i]
  }, numeric(1))
  n_add <- ceiling(n_start / 2)

  # Phase 1: set the ranges first, then refine the grid.
  range_set <- FALSE
  while (!range_set) {
    range_set <- TRUE
    for (i in seq_len(k)) {
      if (range_fixed[i]) next
      # A new slab has n_add[i] bins in dimension i and the full current grid
      # in every other dimension.
      n_new <- n_start
      n_new[i] <- n_add[i]
      if (with(post_marge[[i]], post[which.min(mid)]) > end_bin_thresh) {
        new_ranges <- ranges
        new_ranges[[i]][1] <- ranges[[i]][1] - n_add[i] * widths[i]
        new_ranges[[i]][2] <- ranges[[i]][1]
        ranges[[i]][1] <- new_ranges[[i]][1]
        res <- renormalize(rbind(res, grid_post(new_ranges, n_new)))
        range_set <- FALSE
        n_start[i] <- n_start[i] + n_add[i]
      }
      if (with(post_marge[[i]], post[which.max(mid)]) > end_bin_thresh) {
        new_ranges <- ranges
        new_ranges[[i]][2] <- ranges[[i]][2] + n_add[i] * widths[i]
        new_ranges[[i]][1] <- ranges[[i]][2]
        ranges[[i]][2] <- new_ranges[[i]][2]
        res <- renormalize(rbind(res, grid_post(new_ranges, n_new)))
        range_set <- FALSE
        n_start[i] <- n_start[i] + n_add[i]
      }
    }
    post_marge <- marge_dists(res, params, priors, ranges)
  }

  # Phase 2: refine bins that carry too much posterior mass.
  thresh <- log(max_post_per_bin)
  while (any(res$log_post > thresh)) {
    ix <- which(res$log_post > thresh)
    nbrs <- unique(unlist(get_neighbors(ix, res, params)))
    ix <- unique(c(ix, nbrs))
    nr <- map_df(ix, function(r) {
      # The bin's own [start, stop] interval in every dimension.
      bin_ranges <- lapply(params, function(p) {
        as.numeric(res[r, paste0(p, c("start", "stop"))])
      })
      grid_post(bin_ranges, rep(n_split, k))
    })
    res <- renormalize(rbind(res[-ix, ], nr))
  }
  list(res = res, ranges = ranges)
}

#Takes in vals as produced by get_vals2 and outputs a data frame with all combos of starts/stops for each parameter
#
# Each output row is one combination of bins (one bin per entry of vals).
# For every parameter p in `params` the row holds:
#   p (bin midpoint), pstart, pstop, pwidth, pprior (prior mass of the bin).
# Any of gamma, q, eta missing from `params` is fixed at 0 so that ll_v7()
# always receives all three.
# Columns added after that:
#   log_lik       from ll_v7()
#   log_prior     sum of the log prior masses
#   log_volume    sum of the log bin widths
#   log_post      normalised over the rows of THIS call only
#   norm_log_post log_post - log_volume
#head(res, 3)
#Var1 Var2 Var3          q     qstart      qstop     qwidth     qprior        gamma  gammastart  gammastop gammawidth  gammaprior
#5129    3    1    1 0.08333333 0.06666667 0.10000000 0.03333333 0.15293339  0.006666667  0.00000000 0.01333333 0.01333333 0.008864654
#5118    1    3    2 0.01666667 0.00000000 0.03333333 0.03333333 0.28752861 -0.006666667 -0.01333333 0.00000000 0.01333333 0.008864654
#5199    1    2    2 0.21666667 0.20000000 0.23333333 0.03333333 0.03721802  0.060000000  0.05333333 0.06666667 0.01333333 0.008820988
#eta   etastart     etastop   etawidth   etaprior   log_lik  log_prior  log_post log_volume norm_log_post
#5129 -0.1666667 -0.2000000 -0.13333333 0.06666667 0.04262911 -133.2121  -9.758654 -6.930810  -10.42674      3.495925
#5118 -0.1000000 -0.1333333 -0.06666667 0.06666667 0.04369367 -133.8830  -9.102668 -6.945751  -10.42674      3.480985
#5199 -0.1000000 -0.1333333 -0.06666667 0.06666667 0.04369367 -131.8405 -11.152136 -6.952689  -10.42674      3.474047
post_from_vals <- function(vals, params, dat, mix_grid, rho) {
  # Column Var<i> of `res` indexes the rows of vals[[i]]. seq_len() rather
  # than seq() keeps an empty bin table empty instead of yielding c(1, 0).
  bin_ix <- lapply(vals, function(x) seq_len(nrow(x)))
  res <- expand.grid(bin_ix)
  for (i in seq_along(params)) {
    rows <- res[[paste0("Var", i)]]
    p <- params[i]
    res[[p]] <- vals[[i]]$mid[rows]
    res[[paste0(p, "start")]] <- vals[[i]]$begin[rows]
    res[[paste0(p, "stop")]] <- vals[[i]]$end[rows]
    res[[paste0(p, "width")]] <- vals[[i]]$width[rows]
    res[[paste0(p, "prior")]] <- vals[[i]]$prior[rows]
  }
  # Parameters not on the grid are held at 0.
  for (p in setdiff(c("gamma", "q", "eta"), params)) {
    res[[p]] <- 0
    res[[paste0(p, "start")]] <- 0
    res[[paste0(p, "stop")]] <- 0
  }
  res[["log_lik"]] <- apply(res[, c("gamma", "eta", "q")], 1,
                            FUN = function(x) {
    ll_v7(rho, x[1], x[2], x[3],
          mix_grid$S1, mix_grid$S2, mix_grid$pi,
          dat$beta_hat_1, dat$beta_hat_2,
          dat$seb1, dat$seb2)
  })
  res$log_prior <- Reduce("+", lapply(params, function(p) {
    log(res[[paste0(p, "prior")]])
  }))
  res$log_volume <- Reduce("+", lapply(params, function(p) {
    log(res[[paste0(p, "width")]])
  }))
  # Normalised over these rows only; a caller combining several grids must
  # renormalise.
  res <- res %>% mutate(log_post = log_lik + log_prior,
                        log_post = log_post - logSumExp(log_post))
  res$norm_log_post <- with(res, log_post - log_volume)
  res
}

#Vals is a list of data frames with starts, stops, priors in each dimension
#length(vals) == length(params)
#
# For each dimension i, ranges[[i]] (a length-2 numeric vector) is split into
# n[i] equal-width bins. vals[[i]] is a data frame with columns begin, end,
# width, mid and prior. `prior` is the integral of priors[[i]] over
# [begin, end], computed with stats::integrate(), so each prior must accept
# a numeric vector.
#lapply(vals, function(x){x[1:2,]})
#[[1]]
#begin end width  mid     prior
#1   0.0 0.1   0.1 0.05 0.6513216
#2   0.1 0.2   0.1 0.15 0.2413043
#
#[[2]]
#begin   end width   mid       prior
#1 -1.00 -0.96  0.04 -0.98 0.007008939
#2 -0.96 -0.92  0.04 -0.94 0.007797581
#
#[[3]]
#begin  end width  mid      prior
#1  -1.0 -0.8   0.2 -0.9 0.04342087
#2  -0.8 -0.6   0.2 -0.7 0.06744403
get_vals2 <- function(ranges, n, priors) {
  k <- length(ranges)
  stopifnot(length(n) == k)
  stopifnot(length(priors) == k)
  vals <- vector("list", k)
  for (i in seq_len(k)) {
    breaks <- seq(ranges[[i]][1], ranges[[i]][2], length.out = n[i] + 1)
    lo <- breaks[-(n[i] + 1)]
    hi <- breaks[-1]
    width <- hi - lo
    bins <- data.frame(begin = lo, end = hi, width = width,
                       mid = lo + (width / 2))
    # `$value` is spelled out; `$val` only worked through partial matching.
    bins$prior <- vapply(seq_along(lo), function(j) {
      integrate(f = priors[[i]], lower = lo[j], upper = hi[j])$value
    }, numeric(1))
    vals[[i]] <- bins
  }
  vals
}

# For each row index in `ix`, find the rows of `post` whose bins overlap that
# row's bin in every parameter dimension.
# - Per parameter p, bins are Intervals() built from the <p>start / <p>stop
#   columns.
# - Overlaps come from interval_overlap().
# Returns a list with one element per entry of `ix`: the intersection, over
# parameters, of the overlapping row indices.
get_neighbors <- function(ix, post, params) {
  # overlaps_by_param[[d]][[m]]: rows overlapping row ix[m] in dimension d.
  overlaps_by_param <- lapply(params, function(p) {
    bins <- Intervals(post[, paste0(p, c("start", "stop"))])
    interval_overlap(bins[ix, ], bins)
  })
  # Regroup so each element holds one reference row's overlaps across all
  # dimensions, then keep only rows that overlap in every dimension.
  overlaps_by_row <- purrr::transpose(overlaps_by_param)
  lapply(overlaps_by_row, function(row_overlaps) {
    Reduce(dplyr::intersect, row_overlaps)
  })
}


# For a single grid row `ix`, find the rows of `post` whose bins overlap its
# bin in every parameter dimension. Because a bin overlaps itself, this set
# includes `ix`. Also compute the difference in norm_log_post between `ix`
# and each of those rows.
#
# Returns data.frame(ix = neighbouring row indices,
#                    dist = norm_log_post[ix] - norm_log_post[neighbour]).
#
# `thresh` is kept in the signature for compatibility but is unused. A
# thresholding branch used to follow an unconditional return, so it could
# never run; it has been removed.
neighbor_difference <- function(post, params, ix, thresh) {
  # With several ix values their neighbours would be pooled, and the
  # difference below would recycle silently, so require exactly one.
  stopifnot(length(ix) == 1L)
  ivl_overlap <- lapply(params, function(p) {
    ivl <- Intervals(post[, paste0(p, c("start", "stop"))])
    unlist(interval_overlap(ivl[ix, ], ivl))
  })
  nbrs <- Reduce(dplyr::intersect, ivl_overlap)
  d <- post$norm_log_post[ix] - post$norm_log_post[nbrs]
  data.frame(ix = nbrs, dist = d)
}
jean997/sherlockAsh documentation built on May 18, 2019, 11:45 p.m.