library(SingleCellExperiment)
library(reticulate)
library(tidyverse)
library(ggplot2)
library(scibet)
library(Seurat)
library(scmap)
library(fmsb)
# Location of the serialized benchmark object read by this script.
pan_path <- '/home/pauling/projects/01_classifier/01_data/05_CrossPlatform/pan_fit_res.rds.gz'

# Read the object, keep only its third column (used as `tpm` further down the
# script) and attach one accession string per row in a `GEO` column.
pan_data <- readr::read_rds(pan_path)[, 3] %>%
  dplyr::mutate(GEO = c("GSE84133","GSE85241","E-MTAB-5061","GSE81608"))

# Process the data

# Rebuild every per-dataset table in `tpm` so that all non-label columns are
# numeric: the `label` column is set aside, the remaining columns are coerced
# with as.numeric() and reassembled via data.frame(), and `label` is appended
# again as the last column.
pan_data <- dplyr::mutate(
  pan_data,
  tpm = purrr::map(tpm, function(expr) {
    cell_label <- expr$label
    expr$label <- NULL
    numeric_cols <- lapply(expr, as.numeric)
    expr <- do.call("data.frame", numeric_cols)
    expr$label <- cell_label
    expr
  })
)
# Benchmark scmap-cluster on one train/test pair.
#
# expr_train: table with one row per cell, one column per gene and a `label`
#   column holding the reference cell types.
# expr_test:  table laid out like `expr_train`; its `label` column is only
#   used as ground truth for scoring.
#
# Returns a one-row tibble with
#   ac     - fraction of test cells whose predicted label equals the true one,
#   time   - elapsed seconds spent on feature selection, indexing and
#            projection,
#   method - the constant 'scmap'.
scmap_ck <- function(expr_train, expr_test){

  train_label <- expr_train$label
  test_label <- expr_test$label

  # Drop the labels and transpose so that genes are rows and cells are columns.
  expr_train <- expr_train %>% dplyr::select(-label) %>% t()
  expr_test <- expr_test %>% dplyr::select(-label) %>% t()

  # Reference object: scmap reads the cell types from colData `cell_type1`
  # and the gene names from rowData `feature_symbol`.
  ann <- data.frame(cell_type1 = train_label)
  sce <- SingleCellExperiment(assays = list(normcounts = as.matrix(expr_train)), colData = ann)
  logcounts(sce) <- log2(normcounts(sce) + 1)
  rowData(sce)$feature_symbol <- rownames(sce)
  sce <- sce[!duplicated(rownames(sce)), ]

  # Query object, built the same way but without cell-type annotation.
  tx_sce <- SingleCellExperiment(assays = list(normcounts = as.matrix(expr_test)))
  logcounts(tx_sce) <- log2(normcounts(tx_sce) + 1)
  rowData(tx_sce)$feature_symbol <- rownames(tx_sce)

  time <- system.time({
    # TRUE is spelled out: the shorthand `T` is an ordinary variable that can
    # be reassigned.
    sce <- selectFeatures(sce, n_features = 500, suppress_plot = TRUE)
    sce <- indexCluster(sce)
    scmapCluster_results <- scmapCluster(
      projection = tx_sce,
      threshold = 0,
      index_list = list(
        yan = metadata(sce)$scmap_cluster_index
      )
    )
  })

  tmp <- tibble(
    ori = as.character(test_label),
    prd = unlist(scmapCluster_results$combined_labs)
  )

  # Rows where the comparison is NA are not counted as correct.
  num1 <- sum(tmp$ori == tmp$prd, na.rm = TRUE)
  ac <- num1/nrow(tmp)
  tibble(
    ac = ac,
    time = as.numeric(time[3]),
    method = 'scmap'
  )
}
# Benchmark SciBet on one train/test pair.
#
# expr_train: training table, handed to scibet::SciBet() unchanged.
# expr_test:  test table; its last column is dropped before prediction and
#   its `label` column is used as ground truth.
#
# Returns a one-row tibble with `ac` (fraction of test cells predicted
# correctly), `time` (elapsed seconds of the SciBet call) and `method`.
scibet_fun <- function(expr_train, expr_test){
  elapsed <- system.time({
    predicted <- scibet::SciBet(expr_train, expr_test[,-ncol(expr_test)], k = 500)
  })

  comparison <- tibble(
    ori = as.character(expr_test$label),
    prd = predicted
  )

  # Count the rows where truth and prediction agree; NA comparisons are not
  # counted as hits.
  n_correct <- sum(comparison$ori == comparison$prd, na.rm = TRUE)

  tibble(
    ac = n_correct/nrow(comparison),
    time = as.numeric(elapsed[3]),
    method = 'SciBet'
  )
}
# Benchmark Seurat v3 label transfer on one train/test pair.
#
# expr_train: reference table; all columns except the last are treated as
#   expression values, and `label` holds the cell types (the last column is
#   dropped positionally, so `label` is assumed to be last).
# expr_test:  query table laid out like `expr_train`; `label` is only used as
#   ground truth.
#
# Returns a one-row tibble with `ac` (fraction of query cells whose predicted
# id equals the true label), `time` (elapsed seconds of normalisation,
# variable-feature selection, anchor finding and label transfer) and `method`.
Seurat3 <- function(expr_train, expr_test){

  # Per-cell metadata; `tech` exists only to split reference from query below.
  metadata <- data.frame(
    celltype = expr_train$label,
    tech = 'xx'
  )

  metadata1 <- data.frame(
    celltype = expr_test$label,
    tech = 'yy'
  )

  # NOTE(review): the values are log2(x + 1)-transformed here, then passed to
  # CreateSeuratObject() as `counts`, and NormalizeData() is applied again
  # below -- confirm that this double transformation is intended.
  expr_train[,-ncol(expr_train)] <- log2(expr_train[,-ncol(expr_train)] +1)
  expr_test[,-ncol(expr_test)] <- log2(expr_test[,-ncol(expr_test)] +1)

  ori <- expr_test$label
  X_train <- as.matrix(t(expr_train[,-ncol(expr_train)]))
  X_test <- as.matrix(t(expr_test[,-ncol(expr_test)]))

  # One combined genes-by-cells matrix; cells get running numbers as names so
  # that matrix columns and metadata rows match.
  matr <- cbind(X_train, X_test)
  metadata <- rbind(metadata, metadata1)
  colnames(matr) <- as.character(seq_len(ncol(matr)))
  rownames(metadata) <- as.character(seq_len(nrow(metadata)))

  ttest <- CreateSeuratObject(counts = matr, meta.data = metadata)
  ttest.list <- SplitObject(object = ttest, split.by = "tech")

  time <- system.time({
    # seq_along() instead of 1:length(): the latter iterates over c(1, 0)
    # when the list is empty.
    for (i in seq_along(ttest.list)) {
      ttest.list[[i]] <- NormalizeData(object = ttest.list[[i]], verbose = FALSE)
      ttest.list[[i]] <- FindVariableFeatures(object = ttest.list[[i]],
                                              selection.method = "vst", nfeatures = 500, verbose = FALSE)
    }

    # Element 1 is the reference ('xx', training cells come first in `matr`),
    # element 2 the query ('yy').
    anchors <- FindTransferAnchors(reference = ttest.list[[1]],
                                   query = ttest.list[[2]],
                                   dims = 1:30)

    predictions <- TransferData(anchorset = anchors,
                                refdata = ttest.list[[1]]$celltype,
                                dims = 1:30)
  })

  tmp <- tibble(
    ori = ori,
    prd = predictions$predicted.id
  )

  # Rows where the comparison is NA are not counted as correct.
  num1 <- sum(tmp$ori == tmp$prd, na.rm = TRUE)
  ac <- num1/nrow(tmp)
  tibble(
    ac = ac,
    time = as.numeric(time[3]),
    method = 'Seurat3'
  )
}
# Run the three benchmarks on one pair of datasets.
#
# .x: index into the global `pan_data$tpm` selecting the training table.
# .y: index into the global `pan_data$tpm` selecting the test table.
#
# Prints 1, 2 and 3 as progress markers after scmap, SciBet and Seurat v3
# finish, and returns their one-row results stacked in that order.
pipe_fun <- function(.x, .y){
  train_set <- pan_data$tpm[[.x]]
  test_set <- pan_data$tpm[[.y]]

  scmap_res <- scmap_ck(train_set, test_set)
  print(1)
  scibet_res <- scibet_fun(train_set, test_set)
  print(2)
  seurat_res <- Seurat3(train_set, test_set)
  print(3)

  dplyr::bind_rows(scmap_res, scibet_res, seurat_res)
}
# Run all three classifiers on every unordered pair of the four datasets; the
# first index of a pair selects the training set, the second the test set.
res2 <- combn(1:4, 2) %>%
  t() %>%
  # as.data.frame() names the unnamed matrix columns V1/V2; this replaces the
  # deprecated as.tibble(), which warns about missing column names.
  as.data.frame() %>%
  tibble::as_tibble() %>%
  dplyr::rename(num1 = V1, num2 = V2) %>%
  dplyr::mutate(
    res = purrr::map2(
      .x = num1,
      .y = num2,
      .f = pipe_fun
    )
  )

# Stack the per-pair result tibbles (three rows each, in pair order).
comb_res2 <- dplyr::bind_rows(res2$res)

# Human-readable label ("<train accession> - <test accession>") for every
# dataset pair, in the same pair order as combn(1:4, 2).
map.da <- combn(1:4, 2) %>%
  t() %>%
  # as.data.frame() names the unnamed matrix columns V1/V2; this replaces the
  # deprecated as.tibble().
  as.data.frame() %>%
  tibble::as_tibble() %>%
  dplyr::rename(num1 = V1, num2 = V2) %>%
  dplyr::mutate(
    # paste() is vectorised, so no per-row map is needed.
    res = paste(pan_data$GEO[num1], pan_data$GEO[num2], sep = " - ")
  )
# Number of rows (cells) in the test table of every dataset pair, in the same
# pair order as combn(1:4, 2).
cell_num <- combn(1:4, 2) %>%
  t() %>%
  # as.data.frame() names the unnamed matrix columns V1/V2; this replaces the
  # deprecated as.tibble().
  as.data.frame() %>%
  tibble::as_tibble() %>%
  dplyr::rename(num1 = V1, num2 = V2) %>%
  dplyr::mutate(
    num = purrr::map_dbl(
      .x = num2,
      .f = function(idx) nrow(pan_data$tpm[[idx]])
    )
  ) %>%
  dplyr::pull(num)

# Accuracy versus throughput (test cells classified per second): one point per
# method and dataset pair.
#
# Uses `comb_res2`, the result table built above (`comb_res` is never defined
# in this script), and derives `cell_per_second`, which the plot maps to y.
comb_res2 %>%
  # comb_res2 holds three consecutive rows per dataset pair, so every test-set
  # size is repeated in place (each = 3) to stay aligned with its pair.
  dplyr::mutate(cell_num = rep(cell_num, each = 3)) %>%
  dplyr::mutate(cell_per_second = cell_num/time) %>%
  dplyr::mutate(dataset = rep(1:6, each = 3)) %>%
  ggplot(aes(ac, cell_per_second)) +
  geom_point(aes(colour = method), size = 2) +
  scale_colour_manual(values = c("#FF34B3", "#FF6A6A", "#00CED1")) +
  scale_y_log10() +
  theme_classic() +
  theme(
    legend.position = 'none',
    axis.title = element_text(size = 13),
    axis.text = element_text(size = 13),
    legend.title = element_text(size = 13),
    legend.text = element_text(size = 13),
    axis.text.y = element_text(color="black"),
    axis.text.x = element_text(color="black"),
    panel.background = element_rect(colour = "black", fill = "white"),
    panel.grid = element_line(colour = "grey", linetype = "dashed"),
    panel.grid.major = element_line(
      colour = "grey",
      linetype = "dashed",
      size = 0.2
    )
  ) +
  labs(
    x = "Accuracy",
    y = "Cell number/Second"
  )
# Throughput per method as horizontal bars, one facet per dataset pair, with
# bars filled by accuracy (shown as PPV).
#
# Uses `comb_res2`, the result table built above (`comb_res` is never defined
# in this script).
comb_res2 %>%
  # comb_res2 holds three consecutive rows per dataset pair, so every test-set
  # size is repeated in place (each = 3) to stay aligned with its pair.
  dplyr::mutate(cell_num = rep(cell_num, each = 3)) %>%
  dplyr::mutate(cell_per_second = cell_num/time) %>%
  dplyr::mutate(dataset = rep(1:6, each = 3)) %>%
  dplyr::rename(PPV = ac) %>%
  #dplyr::filter(dataset == 1) %>%
  ggplot(aes(method, cell_per_second)) +
  geom_col(aes(fill = PPV)) +
  scale_y_log10() +
  theme_bw() +
  theme(
    axis.title = element_text(size = 15),
    axis.text = element_text(size = 10),
    legend.title = element_text(size = 15),
    legend.text = element_text(size = 15),
    axis.text.y = element_text(color="black", size = 15),
    axis.text.x = element_text(color="black"),
    strip.background = element_rect(colour = "black", fill = "white")
  ) +
  labs(
    y = "Cells per second",
    x = ""
  ) +
  coord_flip() +
  facet_wrap(vars(dataset), nrow = 2) +
  scale_fill_distiller(palette = "Spectral")
# Save the result table together with test-set sizes and throughput.
#
# Uses `comb_res2`, the result table built above (`comb_res` is never defined
# in this script).
comb_res2 %>%
  # comb_res2 holds three consecutive rows per dataset pair, so every test-set
  # size is repeated in place (each = 3) to stay aligned with its pair.
  dplyr::mutate(cell_num = rep(cell_num, each = 3)) %>%
  dplyr::mutate(cell_per_second = cell_num/time) %>%
  readr::write_rds('/home/pauling/projects/01_classifier/01_data/05_CrossPlatform/ac_time/plot_da.rds')

# Save the table of (train, test) dataset index pairs that was benchmarked.
combn(1:4, 2) %>%
  t() %>%
  # as.data.frame() names the unnamed matrix columns V1/V2; this replaces the
  # deprecated as.tibble().
  as.data.frame() %>%
  tibble::as_tibble() %>%
  dplyr::rename(num1 = V1, num2 = V2) %>%
  readr::write_rds('/home/pauling/projects/01_classifier/01_data/05_CrossPlatform/ac_time/data_use.rds')
# Distribution of classification accuracy per method across dataset pairs.
#
# NOTE(review): the original piped an object called `a` that is not defined
# anywhere in this script. `comb_res2` provides the `method` and `ac` columns
# this plot needs and is used instead -- confirm that this is the intended
# input.
comb_res2 %>%
  dplyr::mutate(method = ifelse(method == "Seurat3", "Seurat v3", method)) %>%
  ggplot(aes(factor(method, levels = c("SciBet","Seurat v3", "scmap")), ac)) +
  geom_boxplot(aes(colour = method)) +
  geom_point(aes(colour = method)) +
  theme_classic() +
  theme(
    legend.position = 'none',
    axis.title = element_text(size = 15),
    axis.text = element_text(size = 15),
    legend.title = element_text(size = 0),
    legend.text = element_text(size = 13),
    axis.text.y = element_text(color="black"),
    axis.text.x = element_text(color="black")
  ) +
  # scale_colour_nejm() lives in ggsci, which this script never attaches;
  # the call is namespace-qualified so it resolves without library(ggsci).
  ggsci::scale_colour_nejm() +
  labs(
    y = "Classification accuracy",
    x = " "
  )
# SciBet accuracy against Seurat v3 accuracy, one point per dataset pair, with
# the identity line as a visual reference.
#
# NOTE(review): `my.co` is not defined anywhere in this script; it has to be a
# colour vector with at least 11 entries that exists in the session.
comb_res2 %>%
  # Label every row with its dataset pair directly (three consecutive rows
  # per pair) instead of numbering the rows and overwriting the numbers after
  # the reshape, which depended on the row order produced by spread().
  dplyr::mutate(Datasets = rep(map.da$res, each = 3)) %>%
  dplyr::select(ac, method, Datasets) %>%
  tidyr::spread(key = method, value = ac) %>%
  ggplot(aes(Seurat3, SciBet)) +
  # annotate() draws the reference line once; constants inside aes() would
  # redraw it for every data row.
  annotate("segment", x = 0.7, y = 0.7, xend = 1, yend = 1,
           linetype = "dashed", colour = "grey50") +
  geom_point(aes(colour = Datasets), size = 4) +
  xlim(0.7,1) +
  ylim(0.7,1) +
  theme_bw() +
  scale_color_manual(values = my.co[c(8,4,7,2,9,11)]) +
  theme(
      axis.title = element_text(size = 15),
      axis.text = element_text(size = 12),
      legend.title = element_text(size = 12),
      legend.text = element_text(size = 12),
      axis.text.y = element_text(color="black"),
      axis.text.x = element_text(color="black")
  ) +
  labs(
    x = "Seurat v3"
  )


# PaulingLiu/scibet documentation built on April 17, 2021, 7:33 a.m.