#' Interpolation between baseline and prediction
#'
#' Builds a baseline for `input_seq` and returns `m_steps + 1` inputs that are
#' linearly interpolated between that baseline (alpha = 0) and `input_seq`
#' (alpha = 1).
#'
#' @param m_steps Number of steps between baseline and original input.
#' @param baseline_type Baseline sequence, either `"zero"` for all zeros,
#' `"shuffle"` for a random permutation of `input_seq` along its second
#' dimension or `"unif"` for values drawn with `stats::runif()`.
#' @param input_seq Input tensor, or a list of input tensors (one per input
#' layer). Each input is indexed with three subscripts.
#' @returns The interpolated inputs. If `input_seq` is a list, an unnamed list
#' with one entry per input.
#' @noRd
interpolate_seq <- function(m_steps = 50,
                            baseline_type = "shuffle",
                            input_seq) {

  stopifnot(
    is.character(baseline_type),
    length(baseline_type) == 1L,
    baseline_type %in% c("zero", "shuffle", "unif")
  )

  # Baseline with the same dimensions as a single input.
  make_baseline <- function(x) {
    x_dim <- dim(x)
    values <- switch(
      baseline_type,
      zero = rep(0, prod(x_dim)),
      shuffle = x[, sample(x_dim[2]), ],
      unif = stats::runif(prod(x_dim))
    )
    array(values, dim = x_dim)
  }

  m_steps <- as.integer(m_steps)
  # Generate m_steps intervals for integral_approximation().
  alphas <- tensorflow::tf$linspace(start = 0.0, stop = 1.0, num = m_steps + 1L)
  # Two trailing axes so that alphas broadcast against a 3D input.
  alphas_x <- alphas[, tensorflow::tf$newaxis, tensorflow::tf$newaxis]

  interpolate <- function(x, x_baseline) {
    delta <- x - x_baseline
    x_baseline + alphas_x * delta
  }

  if (is.list(input_seq)) {
    # Build the result by index so it stays an unnamed list, whatever the
    # names of input_seq are.
    sequences <- vector("list", length(input_seq))
    for (i in seq_along(input_seq)) {
      sequences[[i]] <- interpolate(input_seq[[i]], make_baseline(input_seq[[i]]))
    }
  } else {
    sequences <- interpolate(input_seq, make_baseline(input_seq))
  }

  sequences
}
#' Compute gradients
#'
#' Evaluates the model inside a `tf.GradientTape` in the Python main module and
#' returns the gradient of the output column `target_class_idx` with respect
#' to `input_tensor`.
#'
#' @param input_tensor Input tensor, or list of input tensors, to differentiate
#' with respect to.
#' @param input_idx Input layer to monitor for > 1 input (1-based index into
#' `input_tensor`). If `NULL`, the whole `input_tensor` is monitored.
#' @param target_class_idx Index of class to compute gradient for (1-based).
#' @param model Model to compute gradient for.
#' @param pred_stepwise Whether to do predictions with batch_size 1 rather than all at once. Can be used if
#' input is too big to handle at once. Not used inside this function.
#' @returns The value of `tape$gradient()`; if `input_idx` is given, only the
#' entry for that input.
#' @noRd
compute_gradients <- function(input_tensor, target_class_idx, model, input_idx = NULL, pred_stepwise = FALSE) {

  # Use the Python main module explicitly instead of relying on a `py` object
  # being visible on the search path.
  py_main <- reticulate::import_main()

  reticulate::py_run_string("import tensorflow as tf")
  py_main$input_tensor <- input_tensor
  # Python indices are zero-based.
  py_main$target_class_idx <- as.integer(target_class_idx - 1)
  py_main$model <- model

  # NOTE(review): the tape.watch() lines and the two return values below were
  # not visible in the reviewed source and are reconstructed; the gradient of
  # an unwatched plain tensor would be None. Please confirm.
  if (!is.null(input_idx)) {
    py_main$input_idx <- as.integer(input_idx - 1)
    py_code <- paste(
      "with tf.GradientTape() as tape:",
      "    tape.watch(input_tensor[input_idx])",
      "    probs = model(input_tensor)[:, target_class_idx]",
      sep = "\n"
    )
  } else {
    py_code <- paste(
      "with tf.GradientTape() as tape:",
      "    tape.watch(input_tensor)",
      "    probs = model(input_tensor)[:, target_class_idx]",
      sep = "\n"
    )
  }
  reticulate::py_run_string(py_code)

  grad <- py_main$tape$gradient(py_main$probs, py_main$input_tensor)
  if (!is.null(input_idx)) {
    grad[[input_idx]]
  } else {
    grad
  }
}
#' Approximate the path integral of gradients
#'
#' Averages neighbouring entries along the first axis of `gradients`
#' (trapezoidal rule) and then takes the mean over that axis.
#'
#' @param gradients Tensor of gradients; the first axis is the one that is
#' averaged over.
#' @returns Tensor of averaged gradients (first axis removed).
#' @noRd
integral_approximation <- function(gradients) {

  # Use the Python main module explicitly instead of relying on a `py` object
  # being visible on the search path.
  py_main <- reticulate::import_main()

  reticulate::py_run_string("import tensorflow as tf")
  py_main$gradients <- gradients
  # riemann_trapezoidal
  reticulate::py_run_string("grads = (gradients[:-1] + gradients[1:]) / tf.constant(2.0)")
  reticulate::py_run_string("integrated_gradients = tf.math.reduce_mean(grads, axis=0)")
  py_main$integrated_gradients
}
#' Compute integrated gradients
#'
#' Computes integrated gradients scores for model and an input sequence.
#' This can be used to visualize what part of the input is important for the model's decision.
#' Code is R implementation of python code from [here](https://www.tensorflow.org/tutorials/interpretability/integrated_gradients).
#' Tensorflow implementation is based on this [paper](https://arxiv.org/abs/1703.01365).
#'
#' @param baseline_type Baseline sequence, either `"zero"` for all zeros, `"shuffle"` for random permutation of `input_seq`
#' or `"unif"` for uniformly distributed random values.
#' @param m_steps Number of steps between baseline and original input.
#' @param input_seq Input tensor, or list of input tensors for models with several input layers.
#' @param target_class_idx Index of class to compute gradient for
#' @param model Model to compute gradient for.
#' @param pred_stepwise Whether to do predictions with batch size 1 rather than all at once. Can be used if
#' input is too big to handle at once.
#' @param num_baseline_repeats Number of different baseline estimations if baseline_type is not `"zero"` (estimate integrated
#' gradient repeatedly for different random baselines). Final result is average of \code{num_baseline_repeats} single calculations.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' library(reticulate)
#' model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 3, maxlen = 20, verbose = FALSE)
#' random_seq <- sample(0:3, 20, replace = TRUE)
#' input_seq <- array(keras::to_categorical(random_seq), dim = c(1, 20, 4))
#' integrated_gradients(
#'   input_seq = input_seq,
#'   target_class_idx = 3,
#'   model = model)
#' @returns A tensorflow tensor (a list of tensors if `input_seq` is a list).
#' @export
integrated_gradients <- function(m_steps = 50,
                                 baseline_type = "zero",
                                 input_seq,
                                 target_class_idx,
                                 model,
                                 pred_stepwise = FALSE,
                                 num_baseline_repeats = 1) {

  stopifnot(
    is.numeric(num_baseline_repeats),
    length(num_baseline_repeats) == 1L,
    num_baseline_repeats >= 1
  )

  reticulate::py_run_string("import tensorflow as tf")

  if (num_baseline_repeats > 1 && baseline_type == "zero") {
    warning('Ignoring num_baseline_repeats if baseline is of type "zero". Did you mean to use baseline_type = "shuffle"?')
  }

  if (num_baseline_repeats == 1 || baseline_type == "zero") {

    # Path of interpolated inputs from the baseline (first entry) to input_seq.
    baseline_seq <- interpolate_seq(m_steps = m_steps,
                                    baseline_type = baseline_type,
                                    input_seq = input_seq)

    if (is.list(baseline_seq)) {
      for (i in seq_along(baseline_seq)) {
        baseline_seq[[i]] <- tensorflow::tf$cast(baseline_seq[[i]], dtype = "float32")
      }
    } else {
      baseline_seq <- tensorflow::tf$cast(baseline_seq, dtype = "float32")
    }

    if (pred_stepwise) {
      path_gradients <- gradients_stepwise(
        model = model,
        baseline_seq = baseline_seq,
        target_class_idx = target_class_idx,
        input_idx = NULL)
    } else {
      path_gradients <- compute_gradients(
        model = model,
        input_tensor = baseline_seq,
        target_class_idx = target_class_idx,
        input_idx = NULL,
        pred_stepwise = pred_stepwise)
    }

    if (is.list(input_seq)) {
      ig <- vector("list", length(input_seq))
      for (i in seq_along(input_seq)) {
        avg_grads <- integral_approximation(gradients = path_gradients[[i]])
        # baseline_seq[[i]][1, , ] is the alpha = 0 step, i.e. the baseline.
        ig[[i]] <- ((input_seq[[i]] - baseline_seq[[i]][1, , ]) * avg_grads)[1, , ]
      }
    } else {
      avg_grads <- integral_approximation(gradients = path_gradients)
      # baseline_seq[1, , ] is the alpha = 0 step, i.e. the baseline.
      ig <- ((input_seq - baseline_seq[1, , ]) * avg_grads)[1, , ]
    }

  } else {

    # Average the scores over several independently drawn random baselines.
    n_repeats <- length(seq_len(num_baseline_repeats))
    ig_list <- vector("list", n_repeats)
    for (i in seq_len(n_repeats)) {
      ig_list[[i]] <- integrated_gradients(m_steps = m_steps,
                                           baseline_type = baseline_type,
                                           input_seq = input_seq,
                                           target_class_idx = target_class_idx,
                                           model = model,
                                           pred_stepwise = pred_stepwise,
                                           num_baseline_repeats = 1)
    }

    if (is.list(input_seq)) {
      # Each repeat returned one tensor per input; average input by input.
      ig <- vector("list", length(input_seq))
      for (k in seq_along(input_seq)) {
        ig_input_k <- lapply(ig_list, function(single_ig) single_ig[[k]])
        ig_stacked <- tensorflow::tf$stack(ig_input_k, axis = 0L)
        ig[[k]] <- tensorflow::tf$reduce_mean(ig_stacked, axis = 0L)
      }
    } else {
      ig_stacked <- tensorflow::tf$stack(ig_list, axis = 0L)
      ig <- tensorflow::tf$reduce_mean(ig_stacked, axis = 0L)
    }
  }

  ig
}
#' Compute gradients stepwise (one batch at a time)
#'
#' Calls `compute_gradients()` separately for every entry along the first
#' dimension of `baseline_seq` and stacks the results again.
#'
#' @param model Model to compute gradient for.
#' @param baseline_seq Tensor, or list of tensors, holding the inputs to
#' evaluate along the first dimension. Each tensor is indexed with three
#' subscripts.
#' @param target_class_idx Index of class to compute gradient for.
#' @param input_idx Not used inside this function; `compute_gradients()` is
#' always called with `input_idx = NULL`.
#' @returns Stacked gradients; a list with one tensor per input if
#' `baseline_seq` is a list.
#' @noRd
gradients_stepwise <- function(model, baseline_seq, target_class_idx,
                               input_idx = NULL) {

  # Entry j of the first dimension as a batch of size one.
  take_step <- function(x, j) {
    tensorflow::tf$expand_dims(x[j, , ], axis = 0L)
  }

  if (is.list(baseline_seq)) {
    first_dim <- dim(baseline_seq[[1]])[1]
    num_input_layers <- length(baseline_seq)

    step_gradients <- vector("list", first_dim)
    for (j in seq_len(first_dim)) {
      input_list <- vector("list", num_input_layers)
      for (k in seq_len(num_input_layers)) {
        input_list[[k]] <- take_step(baseline_seq[[k]], j)
      }
      output <- compute_gradients(
        model = model,
        input_tensor = input_list,
        target_class_idx = target_class_idx,
        input_idx = NULL)
      for (m in seq_along(output)) {
        # Drop only the batch axis; other axes of size one must be kept.
        output[[m]] <- tensorflow::tf$squeeze(output[[m]], axis = 0L)
      }
      step_gradients[[j]] <- output
    }

    # Regroup from "per step, per input" to one stacked tensor per input.
    path_gradients <- vector("list", num_input_layers)
    for (n in seq_len(num_input_layers)) {
      gradients_input_n <- lapply(step_gradients, function(g) g[[n]])
      path_gradients[[n]] <- tensorflow::tf$stack(gradients_input_n)
    }
  } else {
    first_dim <- dim(baseline_seq)[1]
    step_gradients <- vector("list", first_dim)
    for (j in seq_len(first_dim)) {
      output <- compute_gradients(
        model = model,
        input_tensor = take_step(baseline_seq, j),
        target_class_idx = target_class_idx,
        input_idx = NULL)
      # Drop only the batch axis; other axes of size one must be kept.
      step_gradients[[j]] <- tensorflow::tf$squeeze(output, axis = 0L)
    }
    path_gradients <- tensorflow::tf$stack(step_gradients)
  }

  path_gradients
}
#' Heatmap of integrated gradient scores
#'
#' Creates a heatmap from output of \code{\link{integrated_gradients}} function. The first row contains
#' the column-wise absolute sums of IG scores and the second row the sums. Rows 3 to 6 contain the IG scores for each
#' position and each nucleotide. The last row contains nucleotide information.
#'
#' @param integrated_grads Matrix of integrated gradient scores (output of \code{\link{integrated_gradients}} function).
#' A list of score matrices if `input_seq` is a list.
#' @param input_seq Input sequence for model. Should be the same as \code{input_seq} input for corresponding
#' \code{\link{integrated_gradients}} call that computed input for \code{integrated_grads} argument.
#' @examplesIf reticulate::py_module_available("tensorflow") && requireNamespace("ComplexHeatmap", quietly = TRUE)
#' library(reticulate)
#' model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 3, maxlen = 20, verbose = FALSE)
#' random_seq <- sample(0:3, 20, replace = TRUE)
#' input_seq <- array(keras::to_categorical(random_seq), dim = c(1, 20, 4))
#' ig <- integrated_gradients(
#'   input_seq = input_seq,
#'   target_class_idx = 3,
#'   model = model)
#' heatmaps_integrated_grad(integrated_grads = ig,
#'                          input_seq = input_seq)
#' @returns A list of heatmaps.
#' @export
heatmaps_integrated_grad <- function(integrated_grads,
                                     input_seq) {

  if (is.list(input_seq)) {
    stopifnot(
      is.list(integrated_grads),
      length(integrated_grads) == length(input_seq)
    )
    hm_list <- vector("list", length(input_seq))
    for (i in seq_along(input_seq)) {
      hm_list[[i]] <- ig_heatmap_single(
        integrated_grads = integrated_grads[[i]],
        input_seq = input_seq[[i]],
        name = paste0("hm_", i)
      )
    }
  } else {
    hm_list <- list(
      ig_heatmap_single(
        integrated_grads = integrated_grads,
        input_seq = input_seq,
        name = "hm"
      )
    )
  }

  hm_list
}

#' Heatmap of integrated gradient scores for a single input
#'
#' @param integrated_grads Scores with 2 dimensions, or 3 dimensions of which
#' only the first entry of the first dimension is plotted.
#' @param input_seq Input with 3 dimensions; only the first entry of the first
#' dimension is used to derive the nucleotide labels.
#' @param name Name passed to `ComplexHeatmap::Heatmap()`.
#' @returns The object returned by `ComplexHeatmap::Heatmap()`.
#' @noRd
ig_heatmap_single <- function(integrated_grads, input_seq, name) {

  input_seq <- tensorflow::tf$cast(input_seq, dtype = "float32")
  integrated_grads <- tensorflow::tf$cast(integrated_grads, dtype = "float32")

  num_dims <- length(dim(integrated_grads))
  if (num_dims == 3) {
    nuc_matrix <- as.matrix(integrated_grads[1, , ])
  } else if (num_dims == 2) {
    nuc_matrix <- as.matrix(integrated_grads)
  } else {
    stop("integrated_grads must have 2 or 3 dimensions", call. = FALSE)
  }

  # Label every position with the channel holding the maximum; positions whose
  # maximum is not exactly 1 are treated as ambiguous and labelled "0".
  nucleotides <- c("A", "C", "G", "T")
  encoding <- as.matrix(input_seq[1, , ])
  amb_nuc <- as.character(apply(encoding, 1, max)) != "1"
  nuc_seq <- nucleotides[apply(encoding, 1, which.max)]
  nuc_seq[amb_nuc] <- "0"
  rownames(nuc_matrix) <- nuc_seq
  colnames(nuc_matrix) <- nucleotides

  # Per-position summaries, taken from the matrix that is actually plotted so
  # that their length always equals the number of heatmap columns.
  abs_sum <- unname(rowSums(abs(nuc_matrix)))
  sum_nuc <- unname(rowSums(nuc_matrix))

  ig_min <- keras::k_min(integrated_grads)$numpy()
  ig_max <- keras::k_max(integrated_grads)$numpy()
  col_fun <- circlize::colorRamp2(c(ig_min, mean(c(ig_max, ig_min)), ig_max), c("blue", "white", "red"))

  top_ha <- ComplexHeatmap::columnAnnotation(abs_sum = abs_sum, sum = sum_nuc)

  # Restrict the palette to the labels that occur, so any subset of
  # A/C/G/T/0 gets a colour mapping.
  nuc_palette <- c("A" = "red", "C" = "green", "G" = "blue", "T" = "yellow", "0" = "white")
  nuc_col <- nuc_palette[names(nuc_palette) %in% nuc_seq]
  bottom_ha <- ComplexHeatmap::HeatmapAnnotation(nuc = nuc_seq, col = list(nuc = nuc_col))

  ComplexHeatmap::Heatmap(matrix = t(nuc_matrix),
                          name = name,
                          top_annotation = top_ha,
                          bottom_annotation = bottom_ha,
                          col = col_fun,
                          cluster_rows = FALSE,
                          cluster_columns = FALSE,
                          column_names_rot = 0)
}
