set.seed(13254)

# Import libraries
library(tidyverse)
library(magrittr)
library(rstan)
library(tidybayes)
library(here)
library(foreach)
library(doParallel)
library(loo)
library(matrixStats)
library(DESeq2)

devtools::load_all()

# Register the default doParallel backend so foreach loops can run in parallel.
registerDoParallel()
# Alternative backend, currently disabled:
# library(future)
# plan(multiprocess)

# Stan settings: keep compiled models on disk between runs and allow as many
# parallel chains as there are detected cores.
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

# Disabled alternative input: test data loaded through the package helpers.
# N = 50
# G = 10
# counts_source = "Acute_Myeloid_Leukemia_Primary_Blood_Derived_Cancer_-_Peripheral_Blood"
#
# counts_tidy <- load_test_data(N, G, counts_source)
# data_for_stan = data_for_stan_from_tidy(N, G, counts_tidy)

# Problem size: N caps the number of samples kept, G the number of genes kept.
N <- 50
G <- 10


# Read the tidy table (one row per sample/gene pair) with an explicit column
# specification so no type guessing takes place.
data_problematic_raw <- readr::read_csv(
  here("local_temp_data", "t_cells_with_error_qq_plot_means.csv"),
  col_types = cols(
    sample = col_character(),
    symbol = col_character(),
    `read count` = col_double(),
    `read count normalised` = col_double(),
    `gene error mean` = col_double()
  )
)

# Number of distinct samples present in the file.
N_samples_total <- n_distinct(data_problematic_raw$sample)

# Genes observed in every sample, ordered by decreasing `gene error mean`;
# the first G symbols are retained.
problematic_genes <- data_problematic_raw %>%
  group_by(symbol, `gene error mean`) %>%
  summarise(N_samples = n()) %>%
  ungroup() %>%
  filter(N_samples == N_samples_total) %>%
  distinct() %>%
  arrange(desc(`gene error mean`)) %>%
  head(G) %>%
  pull(symbol)

# Restrict to the first N distinct samples and the selected genes, rename the
# gene column to `ens_iso`, turn character columns into factors and store the
# read counts as integers.
data_bad_genes <- data_problematic_raw %>%
  filter(
    sample %in% ((.) %>% distinct(sample) %>% head(N) %>% pull(sample)),
    symbol %in% problematic_genes
  ) %>%
  dplyr::rename(ens_iso = symbol) %>%
  mutate_if(is.character, as.factor) %>%
  mutate(`read count` = as.integer(`read count`))

# Convert to the list structure consumed by the Stan models and switch on the
# `generate_quantities` flag.
data_for_stan_bad_genes <- data_for_stan_from_tidy(N, G, data_bad_genes)
data_for_stan_bad_genes$generate_quantities <- 1
# Models to compare: one row per model, giving the model name and the
# adapt_delta value associated with it (presumably forwarded to the Stan
# sampler by run_kfold() -- confirm in the package code).
# Commented-out names are models that are currently excluded from the run;
# the trailing commented values are their matching adapt_delta entries.
model_defs <- data.frame(
  model_name = c(
    "negBinomial_deseq2",
    "negBinomial"
    # "gamma_multinomial",
    # "logNormal_multinomial",
    # "multinomial" ,
    # "gamma_multinomial_2"
    # "negBinomial_deseq2_multinomial"
  ),
  adapt_delta = c(0.8, 0.8) # , 0.95, 0.95, 0.8, 0.95
)

# DESeq2 size factors computed from the transposed count matrix.
deseq_normalization <- estimateSizeFactorsForMatrix(t(data_for_stan_bad_genes$counts))

# Number of cross-validation folds.
K <- 2

# Value stored as `N_samples_log_lik` for the gamma-multinomial models.
N_samples_gamma_log_lik <- 10000

# Copy the data for all runs
kfold_def <- prepare_kfold(K, model_defs, data_for_stan_bad_genes, seed = 1346842485)

# Attach the model-specific extra inputs to each (model, fold) data set.
# seq_len() is used instead of 1:nrow() so that an empty definition table
# results in zero iterations rather than iterating over c(1, 0).
for (i in seq_len(nrow(kfold_def$model_defs_kfold))) {
  model_name_i <- kfold_def$model_defs_kfold$model_name[i]

  if (model_name_i %in% c("negBinomial_deseq2", "negBinomial_deseq2_multinomial")) {
    kfold_def$data_list[[i]]$normalization <- deseq_normalization
    kfold_def$data_list[[i]]$my_prior <- c(5, 3)
  }
  if (model_name_i %in% c("gamma_multinomial", "gamma_multinomial_2")) {
    kfold_def$data_list[[i]]$N_samples_log_lik <- N_samples_gamma_log_lik
  }
}

# Fit every (model, fold) combination; "bad_genes" is the run label passed on
# to run_kfold().
kfold_res <- run_kfold(kfold_def, "bad_genes")
#compare(x = kfold_res$kfold)

# Test the "sampled" version of gamma_multinomial
# NOTE(review): "gamma_multinomial" and "gamma_multinomial_2" are commented
# out of model_defs in this script, so the id look-ups below return nothing
# unless those models are re-enabled -- confirm before running this section.
id_gamma <- kfold_def$model_defs %>%
  filter(model_name == "gamma_multinomial") %>%
  pull(id)
indices_gamma <- kfold_def$model_defs_kfold$id == id_gamma
fits_gamma <- kfold_res$fits[indices_gamma]
holdout_gamma <- lapply(kfold_def$data_list[indices_gamma], "[[", "holdout")
log_lik_gamma_sampled <- extract_log_lik_K(fits_gamma, holdout_gamma, "log_lik_sampled")

# Same extraction for gamma_multinomial_2. The hold-out definition must be the
# one belonging to this model (holdout_gamma2), not the one of the first gamma
# model.
id_gamma2 <- kfold_def$model_defs %>%
  filter(model_name == "gamma_multinomial_2") %>%
  pull(id)
indices_gamma2 <- kfold_def$model_defs_kfold$id == id_gamma2
fits_gamma2 <- kfold_res$fits[indices_gamma2]
holdout_gamma2 <- lapply(kfold_def$data_list[indices_gamma2], "[[", "holdout")
log_lik_gamma2_sampled <- extract_log_lik_K(fits_gamma2, holdout_gamma2, "log_lik_sampled")

# Append the two "sampled" k-fold results to those returned by run_kfold()
# and compare all models together.
kfold_all <- c(
  kfold_res$kfold,
  list(
    gamma_multinomial_sampled = kfold(log_lik_gamma_sampled),
    gamma_multinomial_2_sampled = kfold(log_lik_gamma2_sampled)
  )
)
compare(x = kfold_all)
# For every model, collect the ranks of the held-out observations across its
# folds and reshape them into long format (sample, gene, rank, model). The
# per-model tables are then stacked into a single data frame.
holdout_ranks_all <- kfold_def$model_defs$id %>%
  map(function(id) {
    # Folds (fits and hold-out definitions) that belong to this model.
    indices <- kfold_def$model_defs_kfold$id == id
    fits_for_model <- kfold_res$fits[indices]
    holdout_for_model <- lapply(kfold_def$data_list[indices], "[[", "holdout")

    extract_holdout_ranks(fits_for_model, holdout_for_model, kfold_def$base_data) %>%
      # as_tibble() replaces the deprecated as.tibble() alias.
      as_tibble() %>%
      rownames_to_column("sample") %>%
      gather("gene", "rank", -sample) %>%
      mutate(model = kfold_def$model_defs$model_name[kfold_def$model_defs$id == id])
  }) %>%
  do.call(rbind, args = .)

# Rank histograms for all models, then a closer per-gene look at two of them.
plot_holdout_ranks(holdout_ranks_all, binwidth = 4)
holdout_ranks_all %>%
  filter(model %in% c("gamma_multinomial", "negBinomial")) %>%
  ggplot(aes(x = rank)) +
  geom_histogram(bins = 10) +
  facet_grid(gene ~ model, scales = "free_y")

# TODO: test the K-fold diagnostics on simulations from a known model



# stemangiola/RNAseq-noise-model documentation built on Oct. 18, 2019, 1:47 p.m.