# R/fit_cause.R

#' @title Fit CAUSE
#' @description Fits two posterior models with \code{cause_grid_adapt2_v7}:
#'   a "conf" model with parameters \code{eta} and \code{q}, and a "full"
#'   model with parameters \code{gamma}, \code{eta} and \code{q}. The two
#'   fits are compared (together with a null model that has no posterior)
#'   via \code{in_sample_elpd_loo}, and summary tables and plots are built.
#' @param dat Data; passed unchanged to \code{cause_grid_adapt2_v7} and
#'   \code{in_sample_elpd_loo}.
#' @param mix_grid Mixing parameters; passed unchanged to
#'   \code{cause_grid_adapt2_v7}.
#' @param rho rho; passed unchanged to \code{cause_grid_adapt2_v7}.
#' @param sigma_g Standard deviation of the zero-mean normal prior used for
#'   both \code{gamma} and \code{eta}.
#' @param qalpha,qbeta Shape parameters of the Beta prior used for \code{q}.
#' @return A list with items
#'   \describe{
#'     \item{conf}{Fit of the (eta, q) model.}
#'     \item{full}{Fit of the (gamma, eta, q) model.}
#'     \item{elpd}{Result of \code{in_sample_elpd_loo} (unrounded).}
#'     \item{summary}{Data frame with one row per model (conf, then full)
#'       and columns gamma, eta, q. Each cell is
#'       "median (2.5\% quantile, 97.5\% quantile)" rounded to two digits,
#'       or \code{NA} when the model does not contain that parameter.}
#'     \item{plts}{List of four elements: marginal prior/posterior plots for
#'       conf and full, a \code{tableGrob} of \code{summary}, and a
#'       \code{tableGrob} of the rounded elpd table.}
#'   }
#' @export
fit_cause <- function(dat, mix_grid, rho,
                      sigma_g = 0.6, qalpha = 1, qbeta = 10) {
  # Canonical parameter order used for factor levels and summary columns.
  full_params <- c("gamma", "eta", "q")

  # Priors: the same normal prior is used for gamma and eta.
  effect_prior <- function(b) {
    dnorm(b, 0, sigma_g)
  }
  q_prior <- function(q) {
    dbeta(q, qalpha, qbeta)
  }

  # A gamma-only model is currently disabled; kept for reference.
  # fit1 <- cause_grid_adapt2_v7(dat, mix_grid=mix_grid, rho=rho,
  #                              max_post_per_bin = 0.001,
  #                              params = c("gamma"),
  #                              priors = list(function(b){dnorm(b, 0, sigma_g)}),
  #                              n_start = c(20),
  #                              ranges = list(c(-1, 1)),
  #                              range_fixed = c(FALSE))

  # "conf" model: eta and q only.
  fit2 <- cause_grid_adapt2_v7(dat, mix_grid = mix_grid, rho = rho,
                               max_post_per_bin = 0.001,
                               params = c("eta", "q"),
                               priors = list(effect_prior, q_prior),
                               n_start = c(20, 10),
                               ranges = list(c(-1, 1), c(0, 1)),
                               range_fixed = c(FALSE, TRUE))

  # "full" model: gamma, eta and q.
  fit3 <- cause_grid_adapt2_v7(dat, mix_grid = mix_grid, rho = rho,
                               max_post_per_bin = 0.001,
                               params = c("gamma", "eta", "q"),
                               priors = list(effect_prior, effect_prior,
                                             q_prior),
                               n_start = c(20, 20, 10),
                               ranges = list(c(-1, 1), c(-1, 1), c(0, 1)),
                               range_fixed = c(FALSE, FALSE, TRUE))

  # The null model has no posterior; it only carries rho and the grid.
  fit0 <- list("post" = NULL, rho = rho, mix_grid = mix_grid)
  elpd <- in_sample_elpd_loo(dat,
                             list("null" = fit0, "conf" = fit2, "full" = fit3))
  res <- list("conf" = fit2, "full" = fit3, elpd = elpd)

  # Unnamed on purpose: the returned plts list is indexed by position.
  model_fits <- list(fit2, fit3)

  # Plot prior and posterior marginals of one fit, one facet per parameter.
  plot_marginals <- function(fit) {
    marginals <- lapply(seq_along(fit$post_marge), function(k) {
      m <- fit$post_marge[[k]]
      m$param <- fit$params[k]
      m
    })
    marginals <- do.call(rbind, marginals)
    marginals <- select(marginals, mid, width, param, post, prior) %>%
      gather("dist", "pdf", -mid, -param, -width)
    marginals$param <- factor(marginals$param, levels = full_params)
    # pdf / width: presumably converts per-bin mass to a density -- TODO confirm
    ggplot(marginals) +
      geom_line(aes(x = mid, y = pdf / width, linetype = dist)) +
      xlab("parameter value") +
      theme_bw() +
      theme(legend.position = "none",
            axis.title.y = element_blank(),
            axis.title.x = element_blank()) +
      # `scales` spelled out in full; the original relied on partial matching.
      facet_wrap(~param, scales = "free")
  }

  # Format "median (2.5%, 97.5%)" for each parameter present in one fit;
  # parameters the fit does not have stay NA.
  summarise_fit <- function(fit) {
    cells <- rep(NA_character_, length(full_params))
    names(cells) <- full_params
    for (p in intersect(full_params, fit$params)) {
      ix <- which(fit$params == p)
      marg <- fit$post_marge[[ix]]
      med <- with(marg, step_quantile(0.5, begin, end, post))
      x025 <- with(marg, step_quantile(0.025, begin, end, post))
      x975 <- with(marg, step_quantile(0.975, begin, end, post))
      cells[p] <- paste0(round(med, digits = 2),
                         " (", round(x025, digits = 2), ", ",
                         round(x975, digits = 2), ")")
    }
    cells
  }

  plts <- lapply(model_fits, plot_marginals)

  # vapply yields a parameters-by-models character matrix; transpose so that
  # each model is a row and each parameter a column.
  tab <- vapply(model_fits, summarise_fit, character(length(full_params)))
  rownames(tab) <- full_params
  tab <- data.frame(t(tab))
  res$summary <- tab
  plts[[3]] <- tableGrob(tab)

  # Rounded copy of the elpd table for display; res$elpd stays unrounded.
  elpd_display <- elpd %>%
    mutate(delta_elpd = round(delta_elpd, digits = 2),
           se_delta_elpd = round(se_delta_elpd, digits = 2),
           z = round(z, digits = 2)) %>%
    rename(se = se_delta_elpd)
  plts[[4]] <- tableGrob(elpd_display)

  # Suggested arrangement of the four grobs (not applied here):
  # h <- arrangeGrob(grobs = plts,
  #                  layout_matrix = rbind(c(4, 4, 4, 3, 3, 3), c(NA, NA, 1, 1, 1, 1), c(2, 2, 2, 2,2,2)))
  res$plts <- plts
  res
}
# jean997/sherlockAsh documentation built on May 18, 2019, 11:45 p.m.