# inst/extdata/scripts/iss_parameter_investigation.R

# Load the discretised observational data file shipped with bninfo.
# stringsAsFactors = TRUE is passed explicitly: since R 4.0 read.delim()
# returns character columns by default, and the level relabelling below
# only works on factor columns.
d_file <- system.file("extdata", "sachs.discretised.txt", package = "bninfo")
obs_data <- read.delim(d_file, header = TRUE, sep = " ",
                       stringsAsFactors = TRUE)
data(isachs)
# The node names are taken from the first 11 columns of isachs.
node_names <- names(isachs)[1:11]

# Relabel the factor levels as "1"/"2"/"3". PIP2 and PKC receive the labels
# in the order 1, 3, 2; every other node receives 1, 2, 3.
# NOTE(review): presumably the raw level labels of PIP2 and PKC sort in a
# different order than those of the other nodes -- confirm against the file.
for (node in c("PIP2", "PKC")) {
  levels(obs_data[[node]]) <- c("1", "3", "2")
}
for (node in setdiff(node_names, c("PIP2", "PKC"))) {
  levels(obs_data[[node]]) <- c("1", "2", "3")
}

# Rebuild every node column so that its levels are stored in sorted order.
for (node in node_names) {
  obs_data[[node]] <- factor(obs_data[[node]],
                             levels = sort(levels(obs_data[[node]])))
}
int_targets <- c("PKC", "Akt", "PKA", "PIP2", "Mek")
############ Observation ############
# Sweep the imaginary sample size (iss) of the "bde" score on the
# observational data and record, for each value, the mean of
# orientation_entropy() over the bootstrapped arc strengths, once without
# and once with the whitelist.
cl <- makeCluster(4)
iss_vals <- seq(from = 100, to = 2500, by = 100)
base_no_prior <- rep(NA_real_, length(iss_vals))
base_prior <- rep(NA_real_, length(iss_vals))

# Arcs whitelisted in the "prior" runs. The script first assigns `wl` only
# in the later "Full" section, so it is defined here to make this section
# runnable on its own.
wl <- matrix(c("PKC", "Raf",
               "PKA", "Raf",
               "Raf", "Mek",
               "Mek", "Erk"),
             ncol = 2, byrow = TRUE)

# Bootstrap sample size. The original passed nrow(int_data), but int_data is
# only built in the later sections as isachs plus nrow(obs_data) resampled
# observational rows; this is the same row count computed directly.
boot_size <- nrow(isachs) + nrow(obs_data)

for (i in seq_along(iss_vals)) {
  print(iss_vals[i])
  ##### No whitelist, CPDAG
  algo_args <- list(score = "bde", prior = "uniform", tabu = 50,
                    whitelist = NULL)
  boot_vanilla <- boot.strength(obs_data, cluster = cl, R = 50,
                                m = boot_size,
                                algorithm = "tabu",
                                algorithm.args = c(algo_args,
                                                   list(iss = iss_vals[i])),
                                cpdag = TRUE)
  base_no_prior[i] <- mean(orientation_entropy(boot_vanilla))
  ##### With whitelist, no CPDAG conversion
  algo_args <- list(score = "bde", prior = "uniform", tabu = 50,
                    whitelist = wl)
  boot_directed <- boot.strength(obs_data, cluster = cl, R = 50,
                                 m = boot_size,
                                 algorithm = "tabu",
                                 algorithm.args = c(algo_args,
                                                    list(iss = iss_vals[i])),
                                 cpdag = FALSE)
  base_prior[i] <- mean(orientation_entropy(boot_directed))
}
# Release the workers; the next section creates its own cluster.
stopCluster(cl)
############ Full ############
# Repeat the iss sweep on isachs pooled with a bootstrap resample of the
# observational data, scoring with "mbde" and the whitelist. Only the
# "prior" results are filled in; the "no prior" run is commented out below,
# so dir_no_prior and ents_no_prior stay NA.
cl <- makeCluster(4)
ents_no_prior <- rep(NA_real_, length(iss_vals))
ents_prior <- rep(NA_real_, length(iss_vals))
dir_no_prior <- rep(NA_real_, length(iss_vals))
dir_prior <- rep(NA_real_, length(iss_vals))
wl <- matrix(c("PKC", "Raf",
               "PKA", "Raf",
               "Raf", "Mek",
               "Mek", "Erk"),
             ncol = 2, byrow = TRUE)
for (i in seq_along(iss_vals)) {
  print(iss_vals[i])
  # Add some of original data, resampled. The appended rows get INT = NA.
  int_data <- isachs %>%
    rbind(mutate(sample_n(obs_data, nrow(obs_data), replace = TRUE),
                 INT = NA))
  # INT[[k]] holds the row indices whose INT column equals k; rows with
  # INT = NA are dropped by which() and so belong to no set.
  INT <- lapply(1:11, function(x) which(int_data$INT == x))
  names(INT) <- node_names
#   boot_no_prior <- boot.strength(int_data[, 1:11], cl, algorithm = "tabu",
#                           algorithm.args = list(score = "mbde",
#                                                 iss = iss_vals[i],
#                                                 exp = INT,
#                                                 tabu = 50), cpdag = FALSE)
#   dir_no_prior[i] <- mean(boot_no_prior$direction)
#   ents_no_prior[i] <- mean(orientation_entropy(boot_no_prior))
  boot_prior <- boot.strength(int_data[, 1:11], cluster = cl,
                              algorithm = "tabu",
                              algorithm.args = list(score = "mbde",
                                                    iss = iss_vals[i],
                                                    exp = INT,
                                                    whitelist = wl,
                                                    tabu = 50),
                              cpdag = FALSE)
  dir_prior[i] <- mean(boot_prior$direction)
  ents_prior[i] <- mean(orientation_entropy(boot_prior))
}
# Release the workers; the plotting section creates its own cluster.
stopCluster(cl)
##########################################################################################
##### Draw the picture
library(ggplot2)
library(stringr)
library(tidyr)
#boot <- boot_vanilla
# Names of the arcs in tcell_examples$net; used below to decide which
# bootstrapped arcs count as "good".
data(tcell_examples)
good_arcs <- tcell_examples$net %>%
  arcs() %>%
  arcs2names()
#' Mean bootstrap strength of the arcs that are not reference arcs
#'
#' Written in base R so that it does not depend on magrittr's `%$%`
#' operator, which no `library()` call in this script attaches.
#'
#' @param boot An object with a `strength` column that `arcs2names()`
#'   accepts (in this script, the result of `boot.strength()`).
#' @param reference_arcs Character vector of arc names treated as "good".
#'   Defaults to the script-level `good_arcs`, which preserves the original
#'   one-argument call.
#' @return The mean of `boot$strength` over the arcs whose name is not in
#'   `reference_arcs`; `NaN` when there is no such arc.
mean_null_strength <- function(boot, reference_arcs = good_arcs) {
  is_null_arc <- !(arcs2names(boot) %in% reference_arcs)
  mean(boot[["strength"]][is_null_arc])
}
# Arcs whitelisted in the "prior" runs of the sweep below.
wl <- matrix(c("PKC", "Raf",
               "PKA", "Raf",
               "Raf", "Mek",
               "Mek", "Erk"),
             ncol = 2, byrow = TRUE)
cl <- makeCluster(4)
iss_vals <- c(100, 200, 300, 400, 500, 600, 700, 1000, 1250, 1500, 1750,
              2000, 2500)

# One row per iss value. The first four result columns are filled with mean
# orientation entropies and the "_strength" columns with
# mean_null_strength() values by the loop that follows. They start as
# numeric NA so the columns have the right type from the outset.
# The resampled-observation columns (obs_prior_resampled,
# obs_no_prior_resampled_strength, obs_prior_resampled_strength) remain
# disabled, as in the original.
result_cols <- c("obs_no_prior",
                 "obs_prior",
                 "last_no_prior",
                 "last_prior",
                 "obs_no_prior_strength",
                 "obs_prior_strength",
                 "last_no_prior_strength",
                 "last_prior_strength")
raw_results <- data.frame(iss = iss_vals)
raw_results[result_cols] <- NA_real_

# For every iss value run four bootstraps (R = 200 each) and store, per run,
# the mean orientation entropy and the mean strength of the non-reference
# arcs in raw_results:
#   obs_no_prior  - observational data, "bde", no whitelist, CPDAG
#   obs_prior     - observational data, "bde", whitelist
#   last_no_prior - pooled interventional data, "mbde", no whitelist
#   last_prior    - pooled interventional data, "mbde", whitelist
for (i in seq_along(raw_results$iss)) {
  iss_i <- raw_results$iss[i]
  print(iss_i)

  ##### No prior, observation
  algo_args <- list(score = "bde", prior = "uniform", tabu = 50,
                    whitelist = NULL)
  boot_vanilla <- boot.strength(obs_data, cluster = cl, R = 200,
                                algorithm = "tabu",
                                algorithm.args = c(algo_args,
                                                   list(iss = iss_i)),
                                cpdag = TRUE)
  raw_results$obs_no_prior[i] <- mean(orientation_entropy(boot_vanilla))
  raw_results$obs_no_prior_strength[i] <- mean_null_strength(boot_vanilla)

  ##### With prior, observation
  algo_args <- list(score = "bde", prior = "uniform", tabu = 50,
                    whitelist = wl)
  boot_directed <- boot.strength(obs_data, cluster = cl, R = 200,
                                 algorithm = "tabu",
                                 algorithm.args = c(algo_args,
                                                    list(iss = iss_i)),
                                 cpdag = FALSE)
  raw_results$obs_prior[i] <- mean(orientation_entropy(boot_directed))
  raw_results$obs_prior_strength[i] <- mean_null_strength(boot_directed)

  ##### Pooled data for the interventional runs
  # Add some of original data, resampled. The appended rows get INT = NA.
  # The original built this data set twice per iteration and discarded the
  # first copy unused; it is now built once, right before it is needed.
  int_data <- isachs %>%
    rbind(mutate(sample_n(obs_data, nrow(obs_data), replace = TRUE),
                 INT = NA))
  # INT[[k]] holds the row indices whose INT column equals k; rows with
  # INT = NA are dropped by which() and so belong to no set.
  INT <- lapply(1:11, function(x) which(int_data$INT == x))
  names(INT) <- node_names

  ##### Intervention, no prior
  boot_int_no_prior <- boot.strength(int_data[, 1:11], cluster = cl,
                                     R = 200, algorithm = "tabu",
                                     algorithm.args = list(score = "mbde",
                                                           iss = iss_i,
                                                           exp = INT,
                                                           tabu = 50),
                                     cpdag = FALSE)
  raw_results$last_no_prior[i] <- mean(orientation_entropy(boot_int_no_prior))
  raw_results$last_no_prior_strength[i] <- mean_null_strength(boot_int_no_prior)

  ##### Intervention, with prior
  # R is an argument of boot.strength(); the original passed it inside
  # algorithm.args, i.e. to the structure learning algorithm instead.
  boot_int_prior <- boot.strength(int_data[, 1:11], cluster = cl,
                                  R = 200, algorithm = "tabu",
                                  algorithm.args = list(score = "mbde",
                                                        iss = iss_i,
                                                        exp = INT,
                                                        whitelist = wl,
                                                        tabu = 50),
                                  cpdag = FALSE)
  raw_results$last_prior[i] <- mean(orientation_entropy(boot_int_prior))
  raw_results$last_prior_strength[i] <- mean_null_strength(boot_int_prior)
}
# All bootstraps are done; release the workers.
stopCluster(cl)


# Long format, one row per (iss, experiment) pair. The "_strength" columns
# are dropped. `group` records whether the column name contained
# "no_prior", and `experiment` is reduced to "obs" or "last" by stripping
# the "_prior" and then the "_no" part of the name.
full_plot_data <- raw_results %>%
  gather(experiment, causal_entropy, -iss) %>%
  filter(!grepl("_strength", experiment)) %>%
  mutate(
    group = ifelse(grepl("no_prior", experiment), "no_prior", "prior"),
    experiment = str_replace_all(
      str_replace_all(experiment, "_prior", ""),
      "_no", ""
    ),
    experiment = factor(experiment, levels = c("obs", "last"))
  )


# Build one line plot per iss value (experiment on the x axis, causal
# entropy on the y axis, one line per prior group), stored in a list named
# by the iss value.
iss_vals <- unique(full_plot_data$iss)
plot_list <- lapply(iss_vals, function(iss_val) {
  plot_data <- filter(full_plot_data, iss == iss_val)
  ggplot(plot_data, aes(experiment, causal_entropy, group = group)) +
    geom_line(aes(colour = group)) +
    labs(title = paste("Causal Entropy at ISS =", iss_val)) +
    ylim(c(.2, 1))
})
names(plot_list) <- paste(iss_vals)

# Write every plot to its own page of a single PDF.
# The plots are drawn with print(), which needs only ggplot2; the original
# went through grid.arrange(), but gridExtra is never attached by this
# script.
# NOTE(review): the output path is machine specific -- adjust before running
# elsewhere.
pdf("/Users/robertness/Downloads/ggplot.pdf", onefile = TRUE)
for (i in seq_along(plot_list)) {
  print(plot_list[[i]])
}
dev.off()
# robertness/bninfo documentation built on May 27, 2019, 10:32 a.m.