# R/plot-functions.R

# Defines functions .gg_multibatch_copynumber .gg_multibatch_pooled multibatch_figure sb_predictive sb_pred_data .predictiveDataTable mb_predictive mb_pred_data .gg_multibatch predictiveTibble

## Long-format tibble of the posterior predictive draws stored in `object`.
##
## Columns: `s`, `oned` (predictive value), `batch` (with the "batch "
## prefix stripped), `model` (modelName(object)) and `component`.
## The component column produced by longFormatKB() for the predictive
## matrix is dropped (the original author noted those labels are wrong)
## and replaced by factor(value) taken from zstar(object), reshaped with
## the same K/B layout.
predictiveTibble <- function(object){
  n_comp <- k(object)
  n_batch <- numBatch(object)
  draws <- longFormatKB(predictive(object), K=n_comp, B=n_batch)
  draws <- set_colnames(draws, c("s", "oned", "batch", "component"))
  draws <- mutate(draws,
                  model=modelName(object),
                  batch=gsub("batch ", "", as.character(batch)))
  ## discard the (incorrect) component labels from the predictive matrix
  draws <- select(draws, -component)
  ## component assignments come from the zstar draws instead
  z_long <- longFormatKB(zstar(object), K=n_comp, B=n_batch)
  draws$component <- factor(z_long[["value"]])
  draws
}

## .meltMultiBatchChains <- function(model){
##   ch <- chains(model)
##   th <- as.data.frame(theta(ch))
##   th$param <- "theta"
##   th$iter <- factor(1:nrow(th))
##   ##
##   ## to suppress annoying notes
##   ##
##   param <- iter <- NULL
##   th.m <- gather(th, key="variable", values=-c(param, iter)) %>%
##     as_tibble
##   ##
##   ##
##   ##
##   K <- k(model)
##   B <- nBatch(model)
##   btch <- matrix(uniqueBatch(model), B, K, byrow=FALSE)
##   btch <- rep(as.character(btch), each=iter(model))
##   comp <- matrix(1:K, B, K, byrow=TRUE)
##   comp <- rep(as.numeric(comp), each=iter(model))
##   th.m$batch <- factor(btch)
##   th.m$comp <- factor(paste0("k=", comp))
##   th.m$iter <- as.integer(th.m$iter)
## 
##   s <- as.data.frame(sigma(ch))
##   s$iter <- th$iter
##   s$param <- "sigma"
##   . <- value <- NULL
##   s.m <- gather(s, key="variable", values=-c(param, iter)) %>%
##     as_tibble
##   s.m$batch <- th.m$batch
##   s.m$comp <- th.m$comp
## 
##   nu0 <- as.data.frame(nu.0(ch))
##   nu0$iter <- th$iter
##   nu0$param <- "nu0"
##   nu0.m <- melt(nu0)
## 
##   s20 <- as.data.frame(sigma2.0(ch))
##   s20$iter <- th$iter
##   s20$param <- "s20"
##   s20.m <- melt(s20)
## 
##   mus <- as.data.frame(mu(ch))
##   mus$iter <- th$iter
##   mus$param <- "mu"
##   mus.m <- melt(mus)
##   mu.comp <- rep(seq_len(k(model)), each=iter(model))
##   mus.m$comp <- factor(paste0("k=", mu.comp))
## 
##   prob <- as.data.frame(p(ch))
##   prob$iter <- th$iter
##   prob$param <- "p"
##   prob.m <- melt(prob)
##   prob.m$comp <- mus.m$comp
## 
##   taus <- as.data.frame(tau(ch))
##   taus$iter <- th$iter
##   taus$param <- "tau"
##   taus.m <- melt(taus)
##   ##taus.m$comp <- factor(rep(seq_len(k(model)), each=iter(model)))
##   taus.m$comp <- mus.m$comp
##   prob.m$iter <- taus.m$iter <- mus.m$iter <- as.integer(mus.m$iter)
## 
##   dat.batch <- rbind(th.m, s.m)
##   dat.comp <- rbind(mus.m, taus.m, prob.m)
##   dat <- rbind(nu0.m, s20.m)
##   dat$iter <- as.integer(dat$iter)
##   list(batch=dat.batch, comp=dat.comp, single=dat)
## }
## 
## .meltMultiBatchPooledChains <- function(model){
##   ch <- chains(model)
##   th <- as.data.frame(theta(ch))
##   th$param <- "theta"
##   th$iter <- factor(1:nrow(th))
##   . <- param <- iter <- NULL
##   th.m <- gather(th, key="variable", values=-c(param, iter)) %>%
##     as_tibble
##   ##
##   ##
##   ##
##   K <- k(model)
##   B <- nBatch(model)
##   btch <- matrix(uniqueBatch(model), B, K, byrow=FALSE)
##   btch <- rep(as.character(btch), each=iter(model))
##   comp <- matrix(1:K, B, K, byrow=TRUE)
##   comp <- rep(as.numeric(comp), each=iter(model))
##   th.m$batch <- factor(btch)
##   th.m$comp <- factor(paste0("k=", comp))
##   th.m$iter <- as.integer(th.m$iter)
## 
##   value <- NULL
##   s <- as.data.frame(sigma(ch))
##   s.m <- s %>% as_tibble %>%
##     mutate(iter=seq_len(iter(model))) %>%
##     gather(param, value, -iter) %>%
##     mutate(param="sigma",
##            iter=factor(iter),
##            batch=factor(rep(uniqueBatch(model), each=iter(model))))
## 
##   nu0.m <- as_tibble(nu.0(ch)) %>%
##     mutate(iter=th$iter,
##            param="nu0")
##   s20.m <- as_tibble(sigma2.0(ch)) %>%
##     mutate(iter=th$iter,
##            param="s20")
## 
##   mus.m <- as_tibble(mu(ch)) %>% 
##     mutate(iter=th$iter) %>%
##     gather(param, value, -iter) %>%
##     mutate(param="mu",
##            comp=rep(seq_len(k(model)), each=iter(model))) %>%
##     mutate(comp=factor(paste0("k=", .$comp)))
## 
##   prob.m <- as_tibble(p(ch)) %>%
##     mutate(iter=th$iter) %>%
##     gather(param, value, -iter) %>%
##     mutate(param="p",
##            comp=mus.m$comp)
## 
##   taus.m <- as_tibble(tau(ch)) %>%
##     mutate(iter=th$iter) %>%
##     gather(param, value, -iter) %>%
##     mutate(param="tau",
##            comp=mus.m$comp)
## 
##   ##dat.batch <- bind_rows(th.m, s.m)
##   dat.comp <- rbind(mus.m, taus.m, prob.m)
##   dat <- bind_rows(nu0.m, s20.m)
##   dat$iter <- as.integer(dat$iter)
##   list(theta=th.m,
##        sigma=s.m,
##        comp=dat.comp,
##        single=dat)
## }
## 
## 
## meltSingleBatchChains <- function(model){
##   parameter <- NULL
##   ch <- chains(model)
##   th <- as_tibble(theta(ch)) %>%
##     set_colnames(paste0("theta", seq_len(k(model))))
##   th$iter <- 1:nrow(th)
##   th.m <- gather(th, key="parameter", value="value", -iter) %>%
##     mutate(comp=gsub("theta", "", parameter))
##   sig <- as_tibble(sigma(ch)) %>%
##     set_colnames(paste0("sigma", seq_len(k(model))))
##   if(class(model) == "SingleBatchPooled"){
##     sig$iter <- 1:nrow(sig)
##   } else sig$iter <- th$iter
##   s.m <- gather(sig, key="parameter", value="value", -iter) %>%
##     mutate(comp=gsub("sigma", "", parameter))
##   nu0 <- as_tibble(nu.0(ch)) %>%
##     set_colnames("nu0")
##   nu0$iter <- th$iter
##   nu0.m <- gather(nu0, key="parameter", value="value", -iter)
## 
##   s20 <- as_tibble(sigma2.0(ch)) %>%
##     set_colnames("s20")
##   s20$iter <- th$iter
##   s20.m <- gather(s20, key="parameter", value="value", -iter)
## 
##   mus <- as_tibble(mu(ch))%>%
##     set_colnames("mu")
##   mus$iter <- th$iter
##   mus.m <- gather(mus, key="parameter", value="value", -iter)
## 
##   prob <- as_tibble(p(ch)) %>%
##     set_param_names("p")
##   prob$iter <- th$iter
##   prob.m <- gather(prob, key="parameter", value="value", -iter) %>%
##     mutate(comp=gsub("p", "", parameter))
## 
##   taus <- as_tibble(tau(ch)) %>%
##     set_colnames("tau")
##   taus$iter <- th$iter
##   taus.m <- gather(taus, key="parameter", value="value", -iter)
## 
##   ll <- as_tibble(log_lik(ch)) %>%
##     set_colnames("log lik") %>%
##     mutate(iter=th$iter) %>%
##     gather(key="parameter", value="value", -iter)
##   dat.comp <- bind_rows(th.m,
##                         s.m,
##                         prob.m)
## 
##   dat <- bind_rows(nu0.m,
##                    s20.m,
##                    mu=mus.m,
##                    tau=taus.m,
##                    loglik=ll)
##   list(comp=dat.comp,
##        single=dat)
## }
## 
## meltSingleBatchPooledChains <- function(model){
##   ch <- chains(model)
##   th <- as_tibble(theta(ch)) %>%
##     set_colnames(paste0("theta", seq_len(k(model))))
##   th$iter <- 1:nrow(th)
##   parameter <- NULL
##   th.m <- gather(th, key="parameter", value="value", -iter) %>%
##     mutate(comp=gsub("theta", "", parameter))
##   sig <- as_tibble(sigma(ch)) %>%
##     set_colnames("value") %>%
##     mutate(iter=seq_len(iter(model))) %>%
##     mutate(parameter="sigma")
## ##  if(class(model) == "SingleBatchPooled"){
## ##    sig$iter <- 1:nrow(sig)
## ##  } else sig$iter <- th$iter
## ##  s.m <- gather(sig, key="parameter", value="value", -iter) %>%
## ##    mutate(comp=gsub("sigma", "", parameter))
##   nu0 <- as_tibble(nu.0(ch)) %>%
##     set_colnames("nu0")
##   nu0$iter <- th$iter
##   nu0.m <- gather(nu0, key="parameter", value="value", -iter)
##   nu0$iter <- th$iter
##   nu0.m <- gather(nu0, key="parameter", value="value", -iter)
## 
##   prob <- as_tibble(p(ch)) %>%
##     set_param_names("p")
##   prob$iter <- th$iter
##   prob.m <- gather(prob, key="parameter", value="value", -iter) %>%
##     mutate(comp=gsub("p", "", parameter))
## 
##   s20 <- as_tibble(sigma2.0(ch)) %>%
##     set_colnames("s20")
##   s20$iter <- th$iter
##   s20.m <- gather(s20, key="parameter", value="value", -iter)
## 
##   mus <- as_tibble(mu(ch))%>%
##     set_colnames("mu")
##   mus$iter <- th$iter
##   mus.m <- gather(mus, key="parameter", value="value", -iter)
## 
##   taus <- as_tibble(tau(ch)) %>%
##     set_colnames("tau")
##   taus$iter <- th$iter
##   taus.m <- gather(taus, key="parameter", value="value", -iter)
## 
##   ll <- as_tibble(log_lik(ch)) %>%
##     set_colnames("log lik") %>%
##     mutate(iter=th$iter) %>%
##     gather(key="parameter", value="value", -iter)
##   dat.comp <- bind_rows(th.m,
##                         ##sig,
##                         prob.m)
##   sig <- sig[ , colnames(nu0.m) ]
##   dat <- bind_rows(sig,
##                    nu0.m,
##                    s20.m,
##                    mu=mus.m,
##                    tau=taus.m,
##                    loglik=ll)
##   list(comp=dat.comp,
##        single=dat)
## }
## 
## setMethod("gatherChains", "MultiBatchModel", function(object){
##   .meltMultiBatchChains(object)
## })
## 
## setMethod("gatherChains", "MultiBatchPooled", function(object){
##   .meltMultiBatchPooledChains(object)
## })
## 
## .ggMultiBatchChains <- function(model){
##   melt.ch <- gatherChains(model)
##   dat.batch <- melt.ch$batch %>%
##     mutate(iter=as.integer(iter))
##   dat.comp <- melt.ch$comp %>%
##     as_tibble
##   dat.single <- melt.ch$single %>%
##     as_tibble
##   iter <- value <- batch <- param <- comp <- NULL
##   p.batch <- ggplot(dat.batch, aes(iter, value, group=batch)) +
##     geom_point(size=0.3, aes(color=batch)) +
##     geom_line(aes(color=batch)) +
##     facet_grid(param ~ comp, scales="free_y")
## 
##   p.comp <- ggplot(dat.comp, aes(iter, value, group=comp)) +
##     geom_point(size=0.3, aes(color=comp)) +
##     geom_line(aes(color=comp)) +
##     facet_wrap(~param, scales="free_y")
## 
##   p.single <- ggplot(dat.single, aes(iter, value, group="param")) +
##     geom_point(size=0.3, color="gray") +
##     geom_line(color="gray") +
##     facet_wrap(~param, scales="free_y")
##   list(batch=p.batch, comp=p.comp, single=p.single)
## }
## 
## .ggMultiBatchPooledChains <- function(model){
##   melt.ch <- gatherChains(model)
##   for(i in seq_along(melt.ch)){
##     melt.ch[[i]]$iter <- as.integer(melt.ch[[i]]$iter)
##   }
##   iter <- value <- batch <- param <- comp <- NULL
##   p.theta <- ggplot(melt.ch$theta, aes(iter, value, group=batch)) +
##     geom_point(size=0.3, aes(color=batch)) +
##     geom_line(aes(color=batch)) +
##     facet_grid(param ~ comp, scales="free_y")
## 
##   p.sigma <- ggplot(melt.ch$sigma, aes(iter, value, group=batch)) +
##     geom_point(size=0.3, aes(color=batch)) +
##     geom_line(aes(color=batch))
##     ##facet_grid(param ~ comp, scales="free_y")
## 
##   p.comp <- ggplot(melt.ch$comp, aes(iter, value, group=comp)) +
##     geom_point(size=0.3, aes(color=comp)) +
##     geom_line(aes(color=comp)) +
##     facet_wrap(~param, scales="free_y")
## 
##   p.single <- ggplot(melt.ch$single, aes(iter, value, group="param")) +
##     geom_point(size=0.3, color="gray") +
##     geom_line(color="gray") +
##     facet_wrap(~param, scales="free_y")
##   list(theta=p.theta, sigma=p.sigma, comp=p.comp, single=p.single)
## }

## Plot the posterior predictive mixture of a multi-batch model.
##
## One facet per batch (stacked in a single column). Each facet shows:
##  - a gray, density-scaled histogram of the observed oned(model) values,
##  - density curves of the posterior predictive draws from
##    predictiveTibble(), filled by mixture component,
##  - an "n=<count>" label with the number of observations in the batch.
##
## @param model fitted model; k(), oned(), batch() and modelName() are
##   called on it.
## @param bins number of histogram bins.
## @param mixtheme optional ggplot2 theme. A white-background default is
##   used when missing.
## @param shift_homozygous optional numeric vector. Values of `oned` that
##   are <= shift_homozygous[1] (in both the data and the predictive draws)
##   are shifted by adding shift_homozygous[2]. A dashed gray vertical line
##   is drawn at the largest shifted observed value, if any.
## @return a ggplot object.
.gg_multibatch <- function(model, bins=100, mixtheme, shift_homozygous){
  ## NULL placeholders for names used via non-standard evaluation
  ## (silences R CMD check notes)
  is_homozygous <- ..count.. <- n_facet <- ..density.. <- x <- NULL
  colors <- c("#999999", "#56B4E9", "#E69F00", "#0072B2",
              "#D55E00", "#CC79A7",  "#009E73")
  pred <- predictiveTibble(model)
  ## number of predictive draws per batch, used to normalise the counts
  predictive.summary <- pred %>%
    group_by(model, batch) %>%
    summarize(n=n())
  pred <- left_join(pred, predictive.summary,
                    by=c("model", "batch")) %>%
    mutate(batch=factor(paste("Batch", batch)))
  colors <- colors[seq_len(k(model))]
  full.data <- tibble(oned=oned(model),
                      batch=batch(model)) %>%
    mutate(batch=paste("Batch", batch)) %>%
    mutate(model=modelName(model))
  vertical_break <- geom_vline(xintercept=-Inf)
  if(!missing(shift_homozygous)){
    full.data <- full.data %>%
      mutate(is_homozygous=oned <= shift_homozygous[1],
             oned=ifelse(is_homozygous,
                         oned + shift_homozygous[2], oned))
    pred <- pred %>%
      mutate(is_homozygous=oned <= shift_homozygous[1],
             oned=ifelse(is_homozygous,
                         oned + shift_homozygous[2], oned))
    ## which() drops NA flags. When nothing is homozygous, skip the dashed
    ## break instead of calling max() on an empty vector (warning, -Inf).
    shifted <- full.data$oned[which(full.data$is_homozygous)]
    if(length(shifted) > 0){
      vertical_break <- geom_vline(xintercept=max(shifted),
                                   linetype="dashed",
                                   color="gray")
    }
  }
  batch.data <- full.data %>%
    group_by(batch, model) %>%
    summarize(n=n()) %>%
    mutate(x=-Inf, y=Inf)
  if(missing(mixtheme)){
    mixtheme <- theme(panel.background=element_rect(fill="white"),
                      axis.line=element_line(color="black"),
                      legend.position="bottom",
                      legend.direction="horizontal",
                      strip.text.y=element_text(angle=0),
                      strip.background=element_rect(fill="gray95",
                                                    color="gray90"))
  }
  geom_dens <- geom_density(adjust=1, alpha=0.4, size=0.75, color="gray30")
  ## with no finite predictive values, draw an invisible layer instead
  if(all(is.na(pred$oned))) geom_dens <- geom_vline(xintercept=0, color="transparent")
  ## NOTE(review): as.numeric() on a factor returns the level codes, not the
  ## labels. If some component never occurs in zstar, components are
  ## renumbered. Confirm the expected zstar coding before changing this.
  pred$component <- factor(as.numeric(pred$component))
  fig <- ggplot(pred, aes(x=oned, n_facet=n,
                          y=..count../n_facet,
                          fill=component)) +
    geom_histogram(data=full.data, aes(oned, ..density..),
                   bins=bins,
                   inherit.aes=FALSE,
                   color="gray70",
                   fill="gray70",
                   alpha=0.6) +
    vertical_break +
    facet_wrap(~batch, ncol=1, strip.position="right") +
    geom_dens +
    geom_text(data=batch.data, aes(x=x, y=y, label=paste0("  n=", n)),
              hjust="inward", vjust="inward",
              inherit.aes=FALSE,
              size=3) +
    mixtheme +
    scale_y_sqrt() +
    scale_color_manual(values=colors) +
    scale_fill_manual(values=colors) +
    xlab("Median log R ratio") +
    ylab("Density") +
    ## "none" replaces FALSE, which is deprecated since ggplot2 3.3.4
    guides(color="none", fill=guide_legend(title="Mixture\ncomponent"))
  fig
}

##.gg_multibatch2 <- function(model, bins=100, mixtheme, shift_homozygous){
##    is_homozygous <- NULL
##  colors <- c("#999999", "#56B4E9", "#E69F00", "#0072B2",
##              "#D55E00", "#CC79A7",  "#009E73")
##  pred <- predictiveTibble(model)
##  predictive.summary <- pred %>%
##    group_by(model, batch) %>%
##    summarize(n=n())
##  pred <- left_join(pred, predictive.summary,
##                    by=c("model", "batch")) %>%
##    mutate(batch=factor(paste("Batch", batch)))
##  colors <- colors[seq_len(k(model))]
##  ##df <- multiBatchDensities(model)
##  full.data <- tibble(oned=oned(model),
##                      batch=batch(model)) %>%
##    mutate(batch=paste("Batch", batch)) %>%
##    ##mutate(batch=factor(paste("Batch", batch))) %>%
##    mutate(model=modelName(model))
##  if(length(unique(full.data$batch)) > 1){
##    full.data2 <- full.data
##    full.data2$batch <- "Overall"
##    full.data3 <- bind_rows(full.data2, full.data) %>%
##      mutate(batch=factor(batch, levels=c("Overall",
##                                          unique(full.data$batch))))
##    full.data <- full.data3
##  }
##  if(!missing(shift_homozygous)){
##    full.data <- full.data %>%
##      mutate(is_homozygous=oned <= shift_homozygous[1],
##             oned=ifelse(is_homozygous,
##                         oned + shift_homozygous[2], oned))
##    pred <- pred %>%
##      mutate(is_homozygous=oned <= shift_homozygous[1],
##             oned=ifelse(is_homozygous,
##                         oned + shift_homozygous[2], oned))
##    xint <- max(full.data$oned[full.data$is_homozygous])
##    vertical_break <- geom_vline(xintercept=xint,
##                                 linetype="dashed",
##                                 color="gray")
##
##    ##scale_x <- scale_x_continuous(labels=c())
##  } else {
##    vertical_break <- geom_vline(xintercept=-Inf)
##  }
##  batch.data <- full.data %>%
##    group_by(batch, model) %>%
##    summarize(n=n()) %>%
##    mutate(x=-Inf, y=Inf)
##  if("Overall" %in% batch.data$batch){
##    pred2 <- pred[1:3000, ]
##    ix <- sample(seq_len(nrow(pred)), 3000)
##    pred2$oned <- pred$oned[ix]
##    pred2$component <- pred$component[ix]
##    pred2$batch <- "Overall"
##    pred$batch <- as.character(pred$batch)
##    pred3 <- bind_rows(pred, pred2) %>%
##      mutate(batch=factor(batch, levels=levels(full.data$batch)))
##    pred <- pred3
##  }
##  ..count.. <- NULL
##  n_facet <- NULL
##  ..density.. <- NULL
##  if(missing(mixtheme)){
##    mixtheme <- theme(panel.background=element_rect(fill="white"),
##                      axis.line=element_line(color="black"),
##                      legend.position="bottom",
##                      legend.direction="horizontal",
##                      strip.text.y=element_text(angle=0),
##                      strip.background=element_rect(fill="gray95",
##                                                    color="gray90"))
##  }
##  geom_dens <- geom_density(adjust=1, alpha=0.4, size=0.75, color="gray30")
##  if(all(is.na(pred$oned))) geom_dens <- geom_vline(xintercept=0, color="transparent")
##  x <- NULL
##  pred$component <- factor(as.numeric(pred$component))
##  fig <- ggplot(pred, aes(x=oned, n_facet=n,
##                          y=..count../n_facet,
##                          fill=component)) +
##    geom_histogram(data=full.data, aes(oned, ..density..),
##                   bins=bins,
##                   inherit.aes=FALSE,
##                   color="gray70",
##                   fill="gray70",
##                   alpha=0.6) +
##    vertical_break +
##    facet_wrap(~batch, ncol=1, strip.position="right") +
##    geom_dens +
##    geom_text(data=batch.data, aes(x=x, y=y, label=paste0("  n=", n)),
##              hjust="inward", vjust="inward",
##              inherit.aes=FALSE,
##              size=3) +
##    mixtheme +
##    scale_y_sqrt() +
##    scale_color_manual(values=colors) +
##    scale_fill_manual(values=colors) +
##    xlab("Median log R ratio") +
##    ylab("Density") +
##    guides(color=FALSE, fill=guide_legend(title="Mixture\ncomponent"))
##  return(fig)
##}

## Combine observed data and posterior predictive draws of a multi-batch
## model into one table for density comparison plots.
##
## @param model fitted model; y() and batch() are called on it.
## @param predict data frame of predictive draws. Its third column is
##   renamed to `predictive`, and it must carry `y` and `batch` columns.
##   Its distinct batch values are matched, in sorted order, to the sorted
##   distinct values of batch(model).
## @return tibble with columns `y`, `predictive` (factor with levels
##   "empirical" and "posterior\npredictive") and `batch` (factor labelled
##   "batch <id>").
mb_pred_data <- function(model, predict){
  predictive <- NULL
  ## factor() orders its levels with sort(), so the labels must follow the
  ## same order. unique() alone (order of appearance) mislabels batches
  ## whenever batch(model) is not sorted.
  batch_labels <- paste("batch", sort(unique(batch(model))))
  batches <- factor(batch(model), labels=batch_labels)
  dat <- tibble(y=y(model),
                batch=batches,
                predictive="empirical")
  predict$batch <- factor(predict$batch, labels=batch_labels)
  colnames(predict)[3] <- "predictive"
  predict$predictive <- "posterior\npredictive"
  dat2 <- rbind(dat, predict) %>%
    mutate(predictive=factor(predictive,
                             levels=c("empirical", "posterior\npredictive")))
  dat2[, c("y", "predictive", "batch")]
}

## Overlay density estimates of the observed data and the posterior
## predictive draws, one facet per batch (stacked in a single column).
##
## @param model fitted multi-batch model.
## @param predict data frame of predictive draws (see mb_pred_data()).
## @param adjust bandwidth multiplier passed to geom_density().
## @return a ggplot object.
mb_predictive <- function(model, predict, adjust=1/3){
  predictive <- NULL
  plot_data <- mb_pred_data(model, predict)
  density_layer <- geom_density(alpha=0.4, adjust=adjust)
  white_panel <- theme(panel.background=element_rect(fill="white"))
  ggplot(plot_data, aes(y, fill=predictive)) +
    density_layer +
    facet_wrap(~batch, ncol=1) +
    guides(fill=guide_legend(title="")) +
    white_panel
}

## Build the observed-vs-predictive table for any model and tag it with the
## model name. Models of class SingleBatchModel or SingleBatchPooled use
## sb_pred_data(); every other class uses mb_pred_data().
.predictiveDataTable <- function(model, predict){
  model.name <- modelName(model)
  single_batch <- class(model) %in% c("SingleBatchModel", "SingleBatchPooled")
  build_table <- if(single_batch) sb_pred_data else mb_pred_data
  dat <- build_table(model, predict)
  dat$model <- model.name
  dat
}

## Combine observed data and posterior predictive draws of a single-batch
## model into one table. Both parts are assigned the single batch
## "batch 1".
##
## @param model fitted single-batch model; y() is called on it.
## @param predict data frame of predictive draws. Its second column is
##   renamed to `predictive`.
## @return rows of the observed data followed by the predictive draws, with
##   `predictive` a factor with levels "empirical" and
##   "posterior\npredictive".
sb_pred_data <- function(model, predict){
  predictive <- NULL
  only_batch <- factor("batch 1")
  observed <- tibble(y=y(model), predictive="empirical")
  colnames(predict)[2] <- "predictive"
  predict$predictive <- "posterior\npredictive"
  predict$batch <- only_batch
  observed$batch <- only_batch
  combined <- rbind(observed, predict)
  mutate(combined,
         predictive=factor(predictive,
                           levels=c("empirical", "posterior\npredictive")))
}

## Overlay density estimates of the observed data and the posterior
## predictive draws of a single-batch model.
##
## @param model fitted single-batch model.
## @param predict data frame of predictive draws (see sb_pred_data()).
## @param adjust bandwidth multiplier passed to geom_density().
## @return a ggplot object.
sb_predictive <- function(model, predict, adjust=1/3){
  predictive <- NULL
  plot_data <- sb_pred_data(model, predict)
  base_plot <- ggplot(plot_data, aes(y, fill=predictive, color=predictive))
  base_plot +
    geom_density(alpha=0.4, adjust=adjust) +
    guides(fill=guide_legend(title="")) +
    theme(panel.background=element_rect(fill="white"))
}


## Overlay theoretical component densities on per-batch histograms of the
## empirical log R ratios.
##
## @param theoretical data frame with columns `x`, `y`, `component` and
##   `batch` (polygons of the theoretical densities).
## @param empirical data frame with columns `lrr` and `batch`.
## @param model fitted model; nBatch() sets the number of facet rows.
## @return a ggplot object. `theoretical$y` is rescaled to [0, upper end of
##   the histogram's y-range] so the polygons share the count axis.
multibatch_figure <- function(theoretical, empirical, model){
  lrr <- ..count.. <- component <- x <- y <- NULL
  nb <- nBatch(model)
  colors <- c("#999999", "#56B4E9", "#E69F00", "#0072B2",
              "#D55E00", "#CC79A7",  "#009E73")
  scale_col <-  scale_color_manual(values=colors)
  scale_fill <- scale_fill_manual(values=colors)
  ghist <- geom_histogram(data=empirical, aes(lrr, ..count..),
                          binwidth=0.01, inherit.aes=FALSE)
  ## build a histogram-only plot to learn the y-range of the counts
  gobj <- ggplot() + ghist + facet_wrap(~batch)
  gb <- ggplot_build(gobj)
  ## ggplot2 >= 3.0.0 renamed layout$panel_ranges to layout$panel_params.
  ## Reading only the old name gives NULL and a zero-length rescale target.
  panels <- gb$layout$panel_params
  if(is.null(panels)) panels <- gb$layout$panel_ranges
  ylimit <- panels[[1]][["y.range"]]
  theoretical$y <- rescale(theoretical$y, c(0, ylimit[2]))
  gpolygon <-  geom_polygon(aes(x, y, fill=component, color=component), alpha=0.4)
  ggplot(theoretical) +
    ghist +
    gpolygon +
    scale_col +
    scale_fill +
    xlab("quantiles") + ylab("count") +
    guides(fill=guide_legend(""), color=guide_legend("")) +
    facet_wrap(~batch, nrow=nb, as.table=TRUE, scales="free_y")
}

## Mixture plot for pooled-variance multi-batch models. Delegates to
## .gg_multibatch().
##
## @param model fitted pooled multi-batch model.
## @param bins number of histogram bins. The default matches
##   .gg_multibatch(). Without a default here, calling this function without
##   `bins` forwarded a missing argument and errored.
## @param ... further arguments (e.g. `mixtheme`, `shift_homozygous`)
##   passed on to .gg_multibatch().
## @return a ggplot object.
.gg_multibatch_pooled <- function(model, bins=100, ...){
  .gg_multibatch(model, bins=bins, ...)
}



## Plot the posterior predictive mixture of a copy-number model, with
## densities filled by the copy-number label from mapping(model).
##
## One facet per batch (stacked in a single column). Each facet shows a
## light-gray, density-scaled histogram of oned(model), density curves of
## the posterior predictive draws from predictiveTibble(), and an
## "n=<count>" label.
##
## @param model copy-number model; predictiveTibble(), mapping(), oned(),
##   batch() and modelName() are called on it.
## @param bins number of histogram bins.
## @return a ggplot object.
.gg_multibatch_copynumber <- function(model, bins=400){
  ## NULL placeholders for names used via non-standard evaluation
  ..count.. <- n_facet <- ..density.. <- x <- copynumber <- NULL
  colors <- c("#999999", "#56B4E9", "#E69F00", "#0072B2",
              "#D55E00", "#CC79A7",  "#009E73")
  pred <- predictiveTibble(model)
  ## number of predictive draws per batch, used to normalise the counts
  predictive.summary <- pred %>%
    group_by(model, batch) %>%
    summarize(n=n())
  pred <- left_join(pred, predictive.summary,
                    by=c("model", "batch")) %>%
    mutate(batch=paste("Batch", batch))
  comp_labels <- mapping(model)
  ## NOTE(review): `component` is a factor, so this indexes comp_labels by
  ## level codes rather than by label values. The two agree only when every
  ## component occurs in zstar. Confirm the zstar coding before changing.
  pred$copynumber <- comp_labels[pred$component]
  colors <- colors[seq_along(comp_labels)]
  full.data <- tibble(oned=oned(model),
                      batch=batch(model)) %>%
    mutate(batch=paste("Batch", batch)) %>%
    mutate(model=modelName(model))
  batch.data <- full.data %>%
    group_by(batch, model) %>%
    summarize(n=n()) %>%
    mutate(x=-Inf, y=Inf)
  fig <- ggplot(pred, aes(x=oned, n_facet=n,
                          y=..count../n_facet,
                          fill=copynumber)) +
    geom_histogram(data=full.data, aes(oned, ..density..),
                   bins=bins,
                   inherit.aes=FALSE,
                   color="gray70",
                   fill="gray70",
                   alpha=0.1) +
    facet_wrap(~batch, ncol=1, strip.position="right") +
    geom_density(adjust=1, alpha=0.4, size=0.75, color="gray30") +
    geom_text(data=batch.data,
              aes(x=x, y=y, label=paste0("  n=", n)),
              hjust="inward", vjust="inward",
              inherit.aes=FALSE,
              size=3) +
    theme(panel.background=element_rect(fill="white"),
          axis.line=element_line(color="black"),
          legend.position="bottom",
          legend.direction="horizontal") +
    scale_y_sqrt() +
    scale_color_manual(values=colors) +
    scale_fill_manual(values=colors) +
    xlab("average copy number") +
    ylab("density") +
    ## "none" replaces FALSE, which is deprecated since ggplot2 3.3.4
    guides(color="none")
  fig
}

## Shared ggMixture implementation for copy-number models, pooled or not.
## Only `bins` is used; `mixtheme` and `shift_homozygous` are accepted to
## match the generic's signature.
.ggMixtureCopyNumber <- function(model, bins=100, mixtheme, shift_homozygous){
  .gg_multibatch_copynumber(model, bins)
}

setMethod("ggMixture", "MultiBatchCopyNumber", .ggMixtureCopyNumber)

setMethod("ggMixture", "MultiBatchCopyNumberPooled", .ggMixtureCopyNumber)

## MultiBatchModel: the standard multi-batch mixture plot. All arguments
## are forwarded by name; missing ones stay missing in .gg_multibatch().
setMethod("ggMixture", "MultiBatchModel",
          function(model, bins=100, mixtheme, shift_homozygous){
  .gg_multibatch(model=model,
                 bins=bins,
                 mixtheme=mixtheme,
                 shift_homozygous=shift_homozygous)
})

#' @export
#' @rdname ggplot-functions
#' @aliases ggMixture,MultiBatch-method ggMixture,MultiBatchModel-method
setMethod("ggMixture", "MultiBatch",
          function(model, bins=100, mixtheme, shift_homozygous){
              ## Forward every argument by name so the call does not depend
              ## on .gg_multibatch()'s positional order. Missing arguments
              ## remain missing in the callee, so its defaults still apply.
              .gg_multibatch(model, bins=bins, mixtheme=mixtheme,
                             shift_homozygous=shift_homozygous)
})

## MultiBatchPooled: pooled-variance mixture plot. Only `bins` is
## forwarded; `mixtheme` and `shift_homozygous` are accepted but unused.
setMethod("ggMixture", "MultiBatchPooled",
          function(model, bins=100, mixtheme, shift_homozygous){
  .gg_multibatch_pooled(model, bins=bins)
})

##setMethod("ggChains", "MultiBatchModel", function(model){
##  .ggMultiBatchChains(model)
##})

###' @export
###' @rdname ggplot-functions
###' @aliases ggChains,MultiBatch-method
##setMethod("ggChains", "MultiBatch", function(model){
##  ch <- chains(model)
##  ch.list <- as(ch, "list")
##  melt.ch <- gatherChains(model)
##  dat.batch <- melt.ch$batch %>%
##    mutate(iter=as.integer(iter))
##  dat.comp <- melt.ch$comp %>%
##    as_tibble
##  dat.single <- melt.ch$single %>%
##    as_tibble
##  iter <- value <- batch <- param <- comp <- NULL
##  p.batch <- ggplot(dat.batch, aes(iter, value, group=batch)) +
##    geom_point(size=0.3, aes(color=batch)) +
##    geom_line(aes(color=batch)) +
##    facet_grid(param ~ comp, scales="free_y")
##
##  p.comp <- ggplot(dat.comp, aes(iter, value, group=comp)) +
##    geom_point(size=0.3, aes(color=comp)) +
##    geom_line(aes(color=comp)) +
##    facet_wrap(~param, scales="free_y")
##
##  p.single <- ggplot(dat.single, aes(iter, value, group="param")) +
##    geom_point(size=0.3, color="gray") +
##    geom_line(color="gray") +
##    facet_wrap(~param, scales="free_y")
##  list(batch=p.batch, comp=p.comp, single=p.single)  
##})
##
##setMethod("ggChains", "MultiBatchPooled", function(model){
##  .ggMultiBatchPooledChains(model)
##})
# scristia/CNPBayes documentation built on Aug. 9, 2020, 7:31 p.m.