# Remove every object from the current environment before the analysis runs.
rm(list = ls())
library(magrittr)
library(ggplot2)

# Directory that receives the fitted decompositions.
dir_output <- "results/FARMM"
dir.create(dir_output, showWarnings = TRUE, recursive = TRUE)

# The figures and tables below are written into these directories; create
# them up front so that a fresh checkout does not fail at the first
# ggsave() / write_csv() call after the (expensive) model fits have run.
for (dir_extra in c("figures/figure6",
                    "figures/suppFigure",
                    "figures/figure2",
                    "tables")) {
  dir.create(dir_extra, showWarnings = FALSE, recursive = TRUE)
}

# Provides X_array, the three-way array analysed by the rest of the script.
load("data/FARMM/X_array.RData")

# Fit microTensor five times, seeding the RNG with 1..5, and keep the run
# whose smallest recorded objective value is lowest.
l_fit_microTensor <- lapply(seq_len(5), function(seed) {
  set.seed(seed)
  fit_control <- list(L_init = 2,
                      gamma = 2,
                      maxit = 1000,
                      verbose = TRUE,
                      debug_dir = paste0(dir_output, "/fit/", seed))
  microTensor::microTensor(
    X = X_array, R = 3,
    nn_t = TRUE, ortho_m = TRUE,
    weighted = TRUE,
    control = fit_control)
})

# Lowest objective value seen in each run.
obj_microTensor <- vapply(l_fit_microTensor,
                          function(fit) min(fit$fitting$obj),
                          numeric(1))
fit_microTensor <- l_fit_microTensor[[order(obj_microTensor)[1]]]
save(fit_microTensor, file = paste0(dir_output, "/fit_microTensor.RData"))

# Fit CTF five times, seeding the RNG with 1..5, and keep the run whose
# smallest recorded objective value is lowest.
l_fit_ctf <- lapply(seq_len(5), function(seed) {
  set.seed(seed)
  microTensor::ctf(X = X_array, R = 3)
})

# Lowest objective value seen in each run.
obj_ctf <- vapply(l_fit_ctf,
                  function(fit) min(fit$fitting$obj),
                  numeric(1))
fit_ctf <- l_fit_ctf[[order(obj_ctf)[1]]]
save(fit_ctf, file = paste0(dir_output, "/fit_ctf.RData"))

# Single PCA fit with three components, saved next to the other fits.
fit_pca <- microTensor::pca(X_array, R = 3)
save(fit_pca, file = paste0(dir_output, "/fit_pca.RData"))

# Reload the three saved fits and the data array, in the same order as before.
for (rdata_path in c(paste0(dir_output, "/fit_microTensor.RData"),
                     paste0(dir_output, "/fit_ctf.RData"),
                     paste0(dir_output, "/fit_pca.RData"),
                     "data/FARMM/X_array.RData")) {
  load(rdata_path)
}

# Divide each slice along the first dimension of a three-way array by the
# matrix of sums over that first dimension, so that, for every (dim 2, dim 3)
# cell, the entries along dimension 1 add up to one. NA sums propagate to NA
# proportions.
normalize_over_features <- function(arr) {
  totals <- apply(arr, c(2, 3), sum)
  sweep(arr, c(2, 3), totals, "/")
}

# microTensor: exponentiate the fitted values and renormalise.
Yhat_microTensor <- microTensor::get_Yhat(fit_microTensor)
phat_microTensor <- normalize_over_features(exp(Yhat_microTensor))

# CTF: same transformation.
Yhat_ctf <- microTensor::get_Yhat(fit_ctf)
phat_ctf <- normalize_over_features(exp(Yhat_ctf))

# PCA: rebuild the fitted matrix from the loadings (rows of t(s_full) are
# scaled by lambda), then spread each row over the remaining two array
# dimensions so it lines up with X_array.
Yhat_pca_mat <- fit_pca$m %*% (t(fit_pca$s_full) * fit_pca$lambda)
Yhat_pca <- array(NA, dim(X_array))
for (i in seq_len(nrow(Yhat_pca_mat))) {
  Yhat_pca[i, , ] <- Yhat_pca_mat[i, ]
}
phat_pca <- normalize_over_features(exp(Yhat_pca))

# Observed proportions, normalised the same way directly from the data.
phat_data <- normalize_over_features(X_array)

# Absolute difference between one method's fitted proportions and the
# observed proportions, averaged over the first array dimension: one value
# per (dim 2, dim 3) cell, tagged with the method name.
mean_abs_diff <- function(i_phat, i_method) {
  abs_err <- abs(i_phat - phat_data)
  tibble::tibble(
    diff = as.vector(apply(abs_err, c(2, 3), mean)),
    method = i_method
  )
}

tb_diff <- purrr::map2_dfr(
  list(phat_microTensor, phat_ctf, phat_pca),
  c("microTensor", "CTF", "PCA"),
  mean_abs_diff
)
# Drop cells whose mean is NA and fix the display order of the methods.
tb_diff <- dplyr::filter(tb_diff, !is.na(diff))
tb_diff <- dplyr::mutate(
  tb_diff,
  method = factor(method, levels = c("PCA", "CTF", "microTensor"))
)

p_diff <- ggplot(tb_diff, aes(x = method, y = diff)) +
  geom_boxplot() +
  theme_bw() +
  theme(axis.title.x = element_blank()) +
  ylab("Absolute difference in observed and \n method-based microbial profiles")
load("data/FARMM/df_samples.RData")

# One row per subject: the first row of each SubjectID group.
df_subjects <- dplyr::ungroup(
  dplyr::slice(dplyr::group_by(df_samples, SubjectID), 1)
)

# Names along the three dimensions of X_array.
x_dimnames <- dimnames(X_array)
feature_names <- x_dimnames[[1]]
subject_names <- x_dimnames[[2]]
time_names <- x_dimnames[[3]]

# Split each "|"-separated feature name into seven columns.
mat_tax <- stringr::str_split_fixed(feature_names,
                                    stringr::fixed("|"),
                                    n = 7)

# Percentage of the summed squared lambda attributed to each component.
percs <- fit_microTensor$lambda^2 / sum(fit_microTensor$lambda^2) * 100

# For axes 1 and 2 of the microTensor fit, build a list of two plots:
# [[1]] bar chart of the five largest and five smallest feature loadings,
# [[2]] per-group mean sample loading over time with standard-error bars.
ps <- c(1, 2) %>%
  purrr::map(function(r) {
    axis_col <- paste0("Axis ", r)

    # Strip a rank prefix ("s__" / "g__") and the first underscore.
    clean_taxon <- function(x, prefix) {
      x %>%
        stringr::str_replace(stringr::fixed(prefix), "") %>%
        stringr::str_replace(stringr::fixed("_"), " ")
    }

    # feature
    df_m <- microTensor::create_loading(
      fit_microTensor,
      feature_names = feature_names,
      subject_names = subject_names,
      time_names = time_names,
      class = "feature") %>%
      dplyr::mutate(
        # Label by column 7 of mat_tax; fall back to column 6 when column 7
        # is empty, matching the supplementary-figure code, so a bar is
        # never drawn with a blank label.
        feature_plot = ifelse(mat_tax[, 7] == "",
                              clean_taxon(mat_tax[, 6], "g__"),
                              clean_taxon(mat_tax[, 7], "s__"))) %>%
      dplyr::mutate(m = .data[[axis_col]])

    # Five largest loadings followed by the five smallest, the latter in
    # reversed order so the whole table runs from largest to smallest.
    # Reversing by the actual row count avoids NA rows when fewer than five
    # features are available.
    df_plot <- rbind(
      df_m %>%
        dplyr::arrange(dplyr::desc(m)) %>%
        dplyr::slice(seq_len(5)),
      df_m %>%
        dplyr::arrange(m) %>%
        dplyr::slice(seq_len(5)) %>%
        dplyr::slice(rev(seq_len(dplyr::n())))
    ) %>%
      dplyr::mutate(feature_plot = factor(feature_plot,
                                          levels = rev(feature_plot)))

    p_feature <- df_plot %>%
      ggplot(aes(y = m, x = feature_plot)) +
      geom_bar(stat = "identity") +
      coord_flip() +
      theme_bw() +
      theme(axis.title.y = element_blank()) +
      ylab("Feature loading") +
      ggtitle(" ")

    # time trajectory
    tb_loading <- microTensor::create_loading(
      fit_decomp = fit_microTensor,
      feature_names = feature_names,
      subject_names = subject_names,
      time_names = time_names,
      class = "sample") %>%
      dplyr::mutate(Time = as.integer(as.character(Time))) %>%
      # Attach subject-level metadata (first row per subject).
      dplyr::left_join(df_samples %>%
                         dplyr::group_by(SubjectID) %>%
                         dplyr::slice(1) %>%
                         dplyr::ungroup(),
                       by = c("Subject" = "SubjectID"))

    tb_loading_avg <- tb_loading %>%
      dplyr::group_by(Time, study_group) %>%
      dplyr::summarise(
        mean_loading = mean(.data[[axis_col]], na.rm = TRUE),
        # Standard error of the mean. The denominator counts only the
        # non-missing loadings, i.e. the same values sd() is computed from;
        # n() would also count the rows dropped by na.rm = TRUE.
        sd_loading = sd(.data[[axis_col]], na.rm = TRUE) /
          sqrt(sum(!is.na(.data[[axis_col]])))) %>%
      dplyr::ungroup()

    p_time <- tb_loading_avg %>%
      dplyr::filter(!is.na(study_group)) %>%
      ggplot(aes(x = Time, y = mean_loading,
                 color = study_group, shape = study_group)) +
      geom_line(position = position_dodge(width = 0.5)) +
      geom_point(position = position_dodge(width = 0.5), size = 2) +
      geom_errorbar(aes(ymax = mean_loading + sd_loading,
                        ymin = mean_loading - sd_loading),
                    position = position_dodge(width = 0.5),
                    width = 0.5) +
      theme_bw() +
      labs(colour="Diet group", shape = "Diet group") +
      xlab("Study day") +
      ylab("Group average sample loading") +
      ggtitle("")

    # Only the first axis carries the legend (top-left inside the panel).
    if (r == 1) {
      p_time <- p_time +
        theme(legend.position = c(0, 1),
              legend.justification = c(0, 1),
              legend.background = element_blank())
    } else {
      p_time <- p_time +
        theme(legend.position = "none")
    }

    list(p_feature, p_time)
  })


# Left column: feature-loading bar charts for axes 1 and 2, labelled with
# the share of each axis from `percs`.
p_main_features <- cowplot::plot_grid(
  ps[[1]][[1]], ps[[2]][[1]],
  ncol = 1,
  align = "v",
  labels = paste0(c("B", "C"), ") Axis ", c(1, 2),
                  " (", round(percs[c(1, 2)], 2),
                  "%)"))

# Right column: time trajectories for axes 1 and 2.
p_main_time <- cowplot::plot_grid(
  ps[[1]][[2]], ps[[2]][[2]],
  ncol = 1,
  align = "v")

# Panels B/C side by side.
p_main_loadings <- cowplot::plot_grid(
  plotlist = list(p_main_features, p_main_time),
  nrow = 1,
  align = "h",
  rel_widths = c(1, 0.8))

# Full figure: panel A (method comparison boxplot) to the left of B/C.
p_main <- cowplot::plot_grid(
  p_diff, p_main_loadings,
  nrow = 1,
  align = "h",
  rel_widths = c(0.5, 1),
  labels = c("A)", ""))

# Name `filename` and `plot` in full instead of relying on `file` partially
# matching `filename`.
ggsave(filename = "figures/figure6/figure_FARMM.pdf",
       plot = p_main,
       width = 11, height = 6)
# Supplementary figure: for the PCA and CTF fits, draw the same
# feature-loading and time-trajectory panels as the main figure, one row of
# panels per method.
p_supps <- list(fit_pca,
                fit_ctf) %>%
  purrr::map2(
    c("PCA", "CTF"),
    function(fit_decomp, method) {
      # Percentage of the summed squared lambda attributed to each component.
      percs <- fit_decomp$lambda^2 / sum(fit_decomp$lambda^2) * 100

      # Strip a rank prefix ("s__" / "g__") and the first underscore.
      clean_taxon <- function(x, prefix) {
        x %>%
          stringr::str_replace(stringr::fixed(prefix), "") %>%
          stringr::str_replace(stringr::fixed("_"), " ")
      }

      ps <- c(1, 2) %>%
        purrr::map(function(r) {
          axis_col <- paste0("Axis ", r)

          # feature
          df_m <- microTensor::create_loading(
            fit_decomp,
            feature_names = feature_names,
            subject_names = subject_names,
            time_names = time_names,
            class = "feature") %>%
            dplyr::mutate(
              # Column 7 of mat_tax as label, column 6 when column 7 is empty.
              feature_plot = ifelse(mat_tax[, 7] == "",
                                    clean_taxon(mat_tax[, 6], "g__"),
                                    clean_taxon(mat_tax[, 7], "s__"))) %>%
            dplyr::mutate(m = .data[[axis_col]])

          # Five largest loadings followed by the five smallest (reversed so
          # the table runs from largest to smallest). Reversing by the real
          # row count avoids NA rows when fewer than five features exist.
          df_plot <- rbind(
            df_m %>%
              dplyr::arrange(dplyr::desc(m)) %>%
              dplyr::slice(seq_len(5)),
            df_m %>%
              dplyr::arrange(m) %>%
              dplyr::slice(seq_len(5)) %>%
              dplyr::slice(rev(seq_len(dplyr::n())))
          ) %>%
            dplyr::mutate(feature_plot = factor(feature_plot,
                                                levels = rev(feature_plot)))

          p_feature <- df_plot %>%
            ggplot(aes(y = m, x = feature_plot)) +
            geom_bar(stat = "identity") +
            coord_flip() +
            theme_bw() +
            theme(axis.title.y = element_blank()) +
            ylab("Feature loading") +
            ggtitle(" ")

          # time trajectory
          tb_loading <- microTensor::create_loading(
            fit_decomp = fit_decomp,
            feature_names = feature_names,
            subject_names = subject_names,
            time_names = time_names,
            class = "sample") %>%
            # Convert through character, as the main-figure code does: on a
            # factor, as.integer() alone would return level codes rather
            # than the study days written in the labels.
            dplyr::mutate(Time = as.integer(as.character(Time))) %>%
            # Attach subject-level metadata (first row per subject).
            dplyr::left_join(df_samples %>%
                               dplyr::group_by(SubjectID) %>%
                               dplyr::slice(1) %>%
                               dplyr::ungroup(),
                             by = c("Subject" = "SubjectID"))

          tb_loading_avg <- tb_loading %>%
            dplyr::group_by(Time, study_group) %>%
            dplyr::summarise(
              mean_loading = mean(.data[[axis_col]], na.rm = TRUE),
              # Standard error of the mean; the denominator counts only the
              # non-missing loadings that sd() is computed from.
              sd_loading = sd(.data[[axis_col]], na.rm = TRUE) /
                sqrt(sum(!is.na(.data[[axis_col]])))) %>%
            dplyr::ungroup()

          p_time <- tb_loading_avg %>%
            dplyr::filter(!is.na(study_group)) %>%
            ggplot(aes(x = Time, y = mean_loading,
                       color = study_group, shape = study_group)) +
            geom_line(position = position_dodge(width = 0.5)) +
            geom_point(position = position_dodge(width = 0.5), size = 2) +
            geom_errorbar(aes(ymax = mean_loading + sd_loading,
                              ymin = mean_loading - sd_loading),
                          position = position_dodge(width = 0.5),
                          width = 0.5) +
            theme_bw() +
            labs(colour="Diet group", shape = "Diet group") +
            xlab("Study day") +
            ylab("Group average sample loading") +
            ggtitle("")

          # Only the first axis carries the legend.
          if (r == 1) {
            p_time <- p_time +
              theme(legend.position = c(0, 1),
                    legend.justification = c(0, 1),
                    legend.background = element_blank())
          } else {
            p_time <- p_time +
              theme(legend.position = "none")
          }

          list(p_feature, p_time)
        })

      # Feature panels on the left, time panels on the right.
      list(
        cowplot::plot_grid(ps[[1]][[1]], ps[[2]][[1]],
                           ncol = 1,
                           align = "v",
                           labels = paste0("     Axis ", c(1, 2),
                                           " (", round(percs[c(1, 2)], 2),
                                           "%)")),
        cowplot::plot_grid(ps[[1]][[2]], ps[[2]][[2]],
                           ncol = 1,
                           align = "v")
      ) %>%
        cowplot::plot_grid(plotlist = .,
                           nrow = 1,
                           align = "h",
                           rel_widths = c(1, 0.8))
    })

p_supp <- cowplot::plot_grid(plotlist = p_supps,
                             labels = c("A) PCA", "B) CTF"),
                             ncol = 1)
# Name `filename` and `plot` in full instead of relying on `file` partially
# matching `filename`.
ggsave(filename = "figures/suppFigure/figure_FARMM_supp.pdf",
       plot = p_supp,
       width = 8, height = 10)
# Reload the fits, the data array and the sample metadata.
for (rdata_path in c(paste0(dir_output, "/fit_microTensor.RData"),
                     paste0(dir_output, "/fit_ctf.RData"),
                     paste0(dir_output, "/fit_pca.RData"),
                     "data/FARMM/X_array.RData",
                     "data/FARMM/df_samples.RData")) {
  load(rdata_path)
}

# For one decomposition, regress the subject loadings of axes 1 and 2 on an
# indicator of diet group (Vegan or EEN, each against Omnivore) and collect
# the slope row of every linear model.
test_diet_by_subject <- function(fit_decomp, method) {
  x_names <- dimnames(X_array)
  tb_loading_subject <- microTensor::create_loading(
    fit_decomp,
    feature_names = x_names[[1]],
    subject_names = x_names[[2]],
    time_names = x_names[[3]],
    class = "subject")

  # First metadata row of every subject, with the subject loadings attached.
  tb_first_rows <- dplyr::ungroup(
    dplyr::slice(dplyr::group_by(df_samples, SubjectID), 1)
  )
  tb_subject <- dplyr::left_join(tb_first_rows,
                                 tb_loading_subject,
                                 by = c("SubjectID" = "Subject"))

  tb_test <- purrr::map_dfr(c("Vegan", "EEN"), function(group) {
    tb_group <- tb_subject %>%
      dplyr::filter(study_group %in% c("Omnivore", group)) %>%
      dplyr::mutate(test_var = as.numeric(study_group == group))

    purrr::map_dfr(c("Axis 1", "Axis 2"), function(axis) {
      tb_axis <- dplyr::mutate(tb_group, outcome = .data[[axis]])
      fit_lm <- lm(outcome ~ test_var, data = tb_axis)
      # Row 2 of the coefficient table is the test_var slope.
      results <- summary(fit_lm)$coefficients
      tibble::tibble(
        Diet = group,
        Axis = axis,
        Coefficient = results[2, 1],
        `Standard error` = results[2, 2],
        `T statistic` = results[2, 3],
        p = results[2, 4]
      )
    })
  })

  dplyr::mutate(tb_test, Method = method)
}

tb_test_diet <- purrr::map2_dfr(
  list(fit_ctf, fit_microTensor),
  c("CTF", "microTensor"),
  test_diet_by_subject
)
# Sample-level PCA loadings. Time is converted through character, as the
# CTF/microTensor time test does: on a factor, as.numeric() alone would
# return level codes, and Time is used as a join key against study_day.
tb_samples <- microTensor::create_loading(
  fit_pca,
  feature_names = dimnames(X_array)[[1]],
  subject_names = dimnames(X_array)[[2]],
  time_names = dimnames(X_array)[[3]],
  class = "sample") %>%
  dplyr::mutate(Time = as.numeric(as.character(Time)))

# Attach the loadings to the sample metadata and keep samples that have one.
tb_samples <- df_samples %>%
  dplyr::left_join(tb_samples, by = c("SubjectID" = "Subject",
                                      "study_day" = "Time")) %>%
  dplyr::filter(!is.na(`Axis 1`))

# PCA gives one loading per sample, so each diet contrast (Vegan or EEN
# against Omnivore) is tested with a random intercept per subject.
tb_test_diet_pca <- c("Vegan", "EEN") %>%
  purrr::map_dfr(function(group) {
    tb_subset <- tb_samples %>%
      dplyr::filter(study_group %in% c("Omnivore", group)) %>%
      dplyr::mutate(test_var = (study_group == group) * 1)

    c("Axis 1", "Axis 2") %>%
      purrr::map_dfr(function(axis) {
        tb_axis <- tb_subset %>%
          dplyr::mutate(outcome = .data[[axis]])
        # Fit through lmerTest's exported lmer() so that summary() yields
        # the coefficient table with df and p-value columns, instead of
        # reaching into the package with `:::` on a plain lme4 fit.
        fit_lmer <- lmerTest::lmer(
          outcome ~ test_var + (1 | SubjectID),
          data = tb_axis)
        results <- summary(fit_lmer)$coefficients
        # Select by name: the lmerTest table has an extra `df` column, so
        # positions differ from those of an lm() summary.
        tibble::tibble(
          Diet = group,
          Axis = axis,
          Coefficient = results["test_var", "Estimate"],
          `Standard error` = results["test_var", "Std. Error"],
          `T statistic` = results["test_var", "t value"],
          p = results["test_var", "Pr(>|t|)"]
        )
      })
  }) %>%
  dplyr::mutate(Method = "PCA")

# Format each number separately with three significant digits. vapply()
# guarantees a character vector even for zero-length input, where sapply()
# would return an empty list.
format_stat <- function(x) {
  vapply(x,
         format,
         character(1),
         digits = 3,
         scientific = -2,
         drop0trailing = TRUE)
}

# Publication table for the diet tests: PCA rows first, then CTF and
# microTensor. Method and Test are printed only on the first row of each
# run and blanked on the rows below it.
tb_test_diet_write <- rbind(tb_test_diet_pca,
                            tb_test_diet) %>%
  dplyr::ungroup() %>%
  dplyr::mutate(Test = paste0(Diet, " vs. Omnivore")) %>%
  dplyr::group_by(Method) %>%
  dplyr::mutate(Method_write = ifelse(dplyr::row_number() == 1,
                                      Method,
                                      "")) %>%
  dplyr::group_by(Method, Test) %>%
  dplyr::mutate(Test_write = ifelse(dplyr::row_number() == 1,
                                    Test,
                                    "")) %>%
  dplyr::ungroup() %>%
  dplyr::select(Method_write, Test_write, Axis,
                Coefficient, `Standard error`, `T statistic`, p) %>%
  dplyr::rename(Method = Method_write,
                Test = Test_write) %>%
  dplyr::mutate(
    Coefficient = format_stat(Coefficient),
    `Standard error` = format_stat(`Standard error`),
    `T statistic` = format_stat(`T statistic`),
    p = format_stat(p))

# The output file is passed positionally: readr deprecated the `path`
# argument name in favour of `file`, and the second positional argument is
# the output file under either name.
readr::write_csv(tb_test_diet_write, "tables/farmm_test_diet.csv")

# For one decomposition, regress the time loadings of axes 1 and 2 on a
# piecewise-linear time basis with knots at days 5 and 9, and collect the
# three slope rows of every linear model.
test_time_trend <- function(fit_decomp, method) {
  x_names <- dimnames(X_array)
  tb_time <- microTensor::create_loading(
    fit_decomp,
    feature_names = x_names[[1]],
    subject_names = x_names[[2]],
    time_names = x_names[[3]],
    class = "time")

  tb_time <- dplyr::mutate(tb_time, Time = as.numeric(as.character(Time)))
  tb_time <- dplyr::mutate(
    tb_time,
    # Time1 grows up to day 5, Time2 from day 5 to day 9 (capped at 4),
    # Time3 from day 9 onwards.
    Time1 = pmin(Time, 5),
    Time2 = pmin(pmax(Time - 5, 0), 4),
    Time3 = pmax(Time - 9, 0)
  )

  tb_test <- purrr::map_dfr(c("Axis 1", "Axis 2"), function(axis) {
    tb_axis <- dplyr::mutate(tb_time, outcome = .data[[axis]])
    fit_lm <- lm(outcome ~ Time1 + Time2 + Time3,
                 data = tb_axis)
    # Rows 2-4 of the coefficient table are the three slopes.
    results <- summary(fit_lm)$coefficients
    slope_rows <- seq(2, 4)
    tibble::tibble(
      Test = c("Daily change (day 0-5)",
               "Daily change (day 6-9)",
               "Daily change (day 10-15)"),
      Axis = axis,
      Coefficient = results[slope_rows, 1],
      `Standard error` = results[slope_rows, 2],
      `T statistic` = results[slope_rows, 3],
      p = results[slope_rows, 4]
    )
  })

  dplyr::mutate(tb_test, Method = method)
}

tb_test_time <- purrr::map2_dfr(
  list(fit_ctf, fit_microTensor),
  c("CTF", "microTensor"),
  test_time_trend
)
# Piecewise-linear time basis on the sample table, with knots at days 5 and
# 9: Time1 grows up to day 5, Time2 from day 5 to day 9 (capped at 4), and
# Time3 from day 9 onwards.
tb_samples <- tb_samples %>%
  dplyr::mutate(Time = study_day) %>%
  dplyr::mutate(Time1 = pmin(Time, 5),
                Time2 = pmin(pmax(Time - 5, 0), 4),
                Time3 = pmax(Time - 9, 0))

# PCA gives one loading per sample, so the time trend is tested with a
# random intercept per subject.
tb_test_time_pca <- c("Axis 1", "Axis 2") %>%
  purrr::map_dfr(function(axis) {
    tb_axis <- tb_samples %>%
      dplyr::mutate(outcome = .data[[axis]])
    # Fit through lmerTest's exported lmer() so that summary() yields the
    # coefficient table with df and p-value columns, instead of reaching
    # into the package with `:::` on a plain lme4 fit.
    fit_lmer <- lmerTest::lmer(
      outcome ~ Time1 + Time2 + Time3 + (1 | SubjectID),
      data = tb_axis)
    results <- summary(fit_lmer)$coefficients
    # Select rows and columns by name so a dropped term raises an error
    # rather than silently shifting the rows.
    slope_terms <- c("Time1", "Time2", "Time3")
    tibble::tibble(
      Test = c("Daily change (day 0-5)",
               "Daily change (day 6-9)",
               "Daily change (day 10-15)"),
      Axis = axis,
      Coefficient = results[slope_terms, "Estimate"],
      `Standard error` = results[slope_terms, "Std. Error"],
      `T statistic` = results[slope_terms, "t value"],
      p = results[slope_terms, "Pr(>|t|)"]
    )
  }) %>%
  dplyr::mutate(Method = "PCA")
# Format each number separately with three significant digits. vapply()
# guarantees a character vector even for zero-length input, where sapply()
# would return an empty list.
format_stat <- function(x) {
  vapply(x,
         format,
         character(1),
         digits = 3,
         scientific = -2,
         drop0trailing = TRUE)
}

# Publication table for the time tests. Rows are ordered by method, test and
# axis, and only the day 6-9 slope is kept. Method and Test are printed only
# on the first row of each run and blanked on the rows below it.
tb_test_time_write <- rbind(tb_test_time_pca,
                            tb_test_time) %>%
  dplyr::ungroup() %>%
  dplyr::mutate(Method = factor(Method,
                                levels = c("PCA",
                                           "CTF",
                                           "microTensor")),
                Test = factor(Test,
                              levels = c("Daily change (day 0-5)",
                                         "Daily change (day 6-9)",
                                         "Daily change (day 10-15)")),
                Axis = factor(Axis,
                              levels = c("Axis 1",
                                         "Axis 2"))) %>%
  dplyr::arrange(Method, Test, Axis) %>%
  dplyr::filter(Test == "Daily change (day 6-9)") %>%
  dplyr::group_by(Method) %>%
  dplyr::mutate(Method_write = ifelse(dplyr::row_number() == 1,
                                      as.character(Method),
                                      "")) %>%
  dplyr::group_by(Method, Test) %>%
  dplyr::mutate(Test_write = ifelse(dplyr::row_number() == 1,
                                    as.character(Test),
                                    "")) %>%
  dplyr::ungroup() %>%
  dplyr::select(Method_write, Test_write, Axis,
                Coefficient, `Standard error`, `T statistic`, p) %>%
  dplyr::rename(Method = Method_write,
                Test = Test_write) %>%
  dplyr::mutate(
    Coefficient = format_stat(Coefficient),
    `Standard error` = format_stat(`Standard error`),
    `T statistic` = format_stat(`T statistic`),
    p = format_stat(p))

# The output file is passed positionally: readr deprecated the `path`
# argument name in favour of `file`, and the second positional argument is
# the output file under either name.
readr::write_csv(tb_test_time_write, "tables/farmm_test_time.csv")
# Fitted values taken from the `unweighted` element of the microTensor fit.
Yhat <- microTensor::get_Yhat(fit_microTensor$unweighted)

# Per-(dim 2, dim 3) totals of the observed array and of exp(Yhat).
Xsum_m <- apply(X_array, c(2, 3), sum)
expY <- exp(Yhat)
expYsum_m <- apply(expY, c(2, 3), sum)

# Fitted proportions, and fitted values on the scale of the observed totals.
phat <- sweep(expY, c(2, 3), expYsum_m, "/")
Xhat <- sweep(phat, c(2, 3), Xsum_m, "*")

# Squared residuals against the variance term Xhat * (1 - phat).
var_obs <- (X_array - Xhat)^2
var_exp <- Xhat * (1 - phat)

# left panel
# Ratio of summed squared residuals to summed variance term, per sample.
OF <- apply(var_obs, c(2, 3), sum, na.rm = TRUE) /
  apply(var_exp, c(2, 3), sum, na.rm = TRUE)

tb_plot <- dplyr::filter(
  tibble::tibble(depth = as.vector(Xsum_m),
                 OF = as.vector(OF)),
  !is.na(depth)
)

p1 <- ggplot(tb_plot, aes(x = log10(depth), y = log10(OF))) +
  geom_point() +
  theme_bw() +
  xlab(expression(paste(log[10], " Read Depth"))) +
  ylab(expression(paste(log[10], " Overdispersion Factor"))) +
  ggtitle("Multinomial")

# Fitted values of the microTensor fit and the weight estimate derived
# from them.
Yhat <- microTensor::get_Yhat(fit_microTensor)
wt_estimate <- microTensor::estimate_wt(X = X_array, Yhat = Yhat)

# Per-(dim 2, dim 3) totals of the observed array and of exp(Yhat).
Xsum_m <- apply(X_array, c(2, 3), sum)
expY <- exp(Yhat)
expYsum_m <- apply(expY, c(2, 3), sum)

# Fitted proportions, and fitted values on the scale of the observed totals.
phat <- sweep(expY, c(2, 3), expYsum_m, "/")
Xhat <- sweep(phat, c(2, 3), Xsum_m, "*")

# Squared residuals against the variance term Xhat * (1 - phat).
var_obs <- (X_array - Xhat)^2
var_exp <- Xhat * (1 - phat)


# right panel
# Same ratio as the left panel, further divided by
# 1 + (total - 1) * phi, with phi taken from the weight estimate.
OF <- apply(var_obs, c(2, 3), sum, na.rm = TRUE) /
  apply(var_exp, c(2, 3), sum, na.rm = TRUE) /
  (1 + (Xsum_m - 1) * wt_estimate$phi)

# Keep samples with a known total above 10000.
tb_plot <- dplyr::filter(
  tibble::tibble(depth = as.vector(Xsum_m),
                 OF = as.vector(OF)),
  !is.na(depth),
  depth > 10000
)

p2 <- ggplot(tb_plot, aes(x = log10(depth), y = log10(OF))) +
  geom_point() +
  theme_bw() +
  xlab(expression(paste(log[10], " Read Depth"))) +
  ylab(expression(paste(log[10], " Overdispersion Factor"))) +
  ggtitle("Quasi-likelihood")


# Place the two overdispersion panels side by side and write the figure.
p <- cowplot::plot_grid(plotlist = list(p1, p2),
                        labels = c("A)", "B)"),
                        nrow = 1)
ggsave(filename = "figures/figure2/figure2.pdf",
       plot = p,
       width = 6.5, height = 3.5)


# syma-research/microTensor documentation built on July 1, 2022, 6:30 a.m.