# Evaluate the vignette's code chunks only when TensorFlow is reachable
# through reticulate; otherwise render the text with all chunks disabled.
if (reticulate::py_module_available("tensorflow")) {
  knitr::opts_chunk$set(eval = TRUE)
} else {
  knitr::opts_chunk$set(eval = FALSE)
}

library(deepG)
library(keras)
library(magrittr)
library(ggplot2)
library(reticulate)

options(rmarkdown.html_vignette.check_title = FALSE)
```{css, echo=FALSE}
/* Background colors for <mark> elements with classes "in" and "out"
   (presumably used by the rendered heatmap/sequence output — verify). */
mark.in { background-color: CornflowerBlue; }
mark.out { background-color: IndianRed; }
## Introduction

The <a href="https://arxiv.org/abs/1703.01365">Integrated Gradient</a> (IG) method can be used to determine what parts of an input sequence are important for the model's decision. We start with training a model that can differentiate sequences based on the GC content (as described in the <a href="getting_started.html">Getting started tutorial</a>).

## Model Training

We create two simple dummy training and validation data sets. Both consist of random <tt>ACGT</tt> sequences but the first category has a probability of 40% each for drawing <tt>G</tt> or <tt>C</tt> and the second has equal probability for each nucleotide (first category has around 80% <tt>GC</tt> content and the second one around 50%).

```r
set.seed(123)

# Create data: one temporary directory per train/validation split.
vocabulary <- c("A", "C", "G", "T")
data_type <- c("train_1", "train_2", "val_1", "val_2")
for (i in seq_along(data_type)) {
  temp_file <- tempfile()
  assign(paste0(data_type[i], "_dir"), temp_file)
  dir.create(temp_file)
  if (i %% 2 == 1) {
    # Odd entries: GC-rich class (40% each for C and G).
    header <- "label_1"
    prob <- c(0.1, 0.4, 0.4, 0.1)
  } else {
    # Even entries: uniform nucleotide distribution.
    header <- "label_2"
    prob <- rep(0.25, 4)
  }
  fasta_name_start <- paste0(header, "_", data_type[i], "file")
  create_dummy_data(file_path = temp_file,
                    num_files = 1,
                    seq_length = 20000,
                    num_seq = 1,
                    header = header,
                    prob = prob,
                    fasta_name_start = fasta_name_start,
                    vocabulary = vocabulary)
}

# Create model
maxlen <- 50
model <- create_model_lstm_cnn(maxlen = maxlen,
                               filters = c(8, 16),
                               kernel_size = c(8, 8),
                               pool_size = c(3, 3),
                               layer_lstm = 8,
                               layer_dense = c(4, 2),
                               model_seed = 3)

# Train model
hist <- train_model(model,
                    train_type = "label_folder",
                    run_name = "gc_model_1",
                    path = c(train_1_dir, train_2_dir),
                    path_val = c(val_1_dir, val_2_dir),
                    epochs = 6,
                    batch_size = 64,
                    steps_per_epoch = 50,
                    step = 50,
                    vocabulary_label = c("high_gc", "equal_dist"))
plot(hist)
```
We can try to visualize which parts of an input sequence are important for the model's decision, using Integrated Gradient. Let's create a sequence with a high GC content. We use the same number of Cs as Gs and of As as Ts.
set.seed(321)

# Build a sequence of length `maxlen` with g_count Gs, g_count Cs and the
# remaining positions split evenly between A and T, then shuffle the order.
g_count <- 17
stopifnot(2 * g_count < maxlen)       # leave room for the A/T positions
a_count <- (maxlen - (2 * g_count)) / 2
high_gc_seq <- c(rep("G", g_count), rep("C", g_count),
                 rep("A", a_count), rep("T", a_count))
high_gc_seq <- high_gc_seq[sample(maxlen)] %>% paste(collapse = "") # shuffle nt order
high_gc_seq
We need to one-hot encode the sequence before applying Integrated Gradient.
# One-hot encode the character sequence; use the `maxlen` variable defined
# at model creation instead of re-hard-coding its value.
high_gc_seq_one_hot <- seq_encoding_label(char_sequence = high_gc_seq,
                                          maxlen = maxlen,
                                          start_ind = 1,
                                          vocabulary = vocabulary)
head(high_gc_seq_one_hot[1, , ])
Our model should be confident that this sequence belongs to the first class.
# Predict class probabilities for the one-hot encoded sequence and label the
# output columns with the class names used during training.
pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
colnames(pred) <- c("high_gc", "equal_dist")
pred
We can visualize which parts were important for the prediction.
# Attribution scores for the first class ("high_gc").
ig <- integrated_gradients(input_seq = high_gc_seq_one_hot,
                           target_class_idx = 1,
                           model = model)

# The heatmap visualization needs the optional ComplexHeatmap package.
if (requireNamespace("ComplexHeatmap", quietly = TRUE)) {
  heatmaps_integrated_grad(integrated_grads = ig,
                           input_seq = high_gc_seq_one_hot)
} else {
  message("Skipping ComplexHeatmap-related code because the package is not installed.")
}
We may test how our model's prediction changes if we exchange certain nucleotides in the input sequence. First, we look for the positions with the smallest IG score.
# Convert the IG result to a plain R array and locate the entry with the
# smallest attribution score (row = sequence position, col = nucleotide).
ig <- as.array(ig)
min_score <- min(ig)
smallest_index <- which(ig == min_score, arr.ind = TRUE)
smallest_index
We may change the nucleotide with the lowest score and observe the change in prediction confidence.
# copy original sequence
high_gc_seq_one_hot_changed <- high_gc_seq_one_hot

# prediction for original sequence
predict(model, high_gc_seq_one_hot, verbose = 0)

# Position/nucleotide with the smallest IG score. `which(..., arr.ind = TRUE)`
# returns one row per tie, so keep only the first hit (the later chunk that
# maximizes the score already guards against this).
smallest_index <- which(ig == min(ig), arr.ind = TRUE)
smallest_index
row_index <- smallest_index[1, "row"]
col_index <- smallest_index[1, "col"]

# Replace the one-hot row at that position with the nucleotide that has the
# highest IG score there.
new_row <- rep(0, 4)
nt_index_old <- col_index
nt_index_new <- which.max(ig[row_index, ])
new_row[nt_index_new] <- 1
high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
cat("At position", row_index, "changing", vocabulary[nt_index_old],
    "to", vocabulary[nt_index_new], "\n")

pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
print(pred)
Let's repeatedly apply the previous step and change the sequence after each iteration.
# copy original sequence
high_gc_seq_one_hot_changed <- high_gc_seq_one_hot

# Preallocate: one slot for the original prediction plus one per iteration.
n_iter <- 20
pred_list <- vector("list", n_iter + 1)
pred_list[[1]] <- pred <- predict(model, high_gc_seq_one_hot, verbose = 0)

# change nts
for (i in seq_len(n_iter)) {
  # update ig scores for changed input
  ig <- integrated_gradients(input_seq = high_gc_seq_one_hot_changed,
                             target_class_idx = 1,
                             model = model) %>% as.array()
  # `which()` lists every tied minimum; use only the first one so the
  # one-hot assignment below stays a single row.
  smallest_index <- which(ig == min(ig), arr.ind = TRUE)
  row_index <- smallest_index[1, "row"]
  col_index <- smallest_index[1, "col"]
  new_row <- rep(0, 4)
  nt_index_old <- col_index
  nt_index_new <- which.max(ig[row_index, ])
  new_row[nt_index_new] <- 1
  high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
  cat("At position", row_index, "changing", vocabulary[nt_index_old],
      "to", vocabulary[nt_index_new], "\n")
  pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
  pred_list[[i + 1]] <- pred
}

# Collect predictions per iteration and plot the confidence trajectory.
pred_df <- do.call(rbind, pred_list)
pred_df <- data.frame(pred_df, iteration = 0:(nrow(pred_df) - 1))
names(pred_df) <- c("high_gc", "equal_dist", "iteration")
ggplot(pred_df, aes(x = iteration, y = high_gc)) +
  geom_line() +
  ylab("high GC confidence")
We can try the same in the opposite direction, i.e. replace big IG scores.
# copy original sequence
high_gc_seq_one_hot_changed <- high_gc_seq_one_hot

# Preallocate: one slot for the original prediction plus one per iteration.
n_iter <- 20
pred_list <- vector("list", n_iter + 1)
pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
pred_list[[1]] <- pred

# change nts
for (i in seq_len(n_iter)) {
  # update ig scores for changed input
  ig <- integrated_gradients(input_seq = high_gc_seq_one_hot_changed,
                             target_class_idx = 1,
                             model = model) %>% as.array()
  # Index the FIRST tie row for both coordinates; previously only the row
  # was truncated, so `col_index` could keep several tied columns.
  biggest_index <- which(ig == max(ig), arr.ind = TRUE)
  row_index <- biggest_index[1, "row"]
  col_index <- biggest_index[1, "col"]
  new_row <- rep(0, 4)
  nt_index_old <- col_index
  nt_index_new <- which.min(ig[row_index, ])
  new_row[nt_index_new] <- 1
  high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
  cat("At position", row_index, "changing", vocabulary[nt_index_old],
      "to", vocabulary[nt_index_new], "\n")
  pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
  pred_list[[i + 1]] <- pred
}

# Collect predictions per iteration and plot the confidence trajectory.
pred_df <- do.call(rbind, pred_list)
pred_df <- data.frame(pred_df, iteration = 0:(nrow(pred_df) - 1))
names(pred_df) <- c("high_gc", "equal_dist", "iteration")
ggplot(pred_df, aes(x = iteration, y = high_gc)) +
  geom_line() +
  ylab("high GC confidence")
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.