#' Construct a DeepPAM model with mgcv and keras.
#'
#' This function constructs a DeepPAM as described in Kopper et al. 2020.
#' It combines different deep models and a structured PAM(M) into a single
#' deep keras model.
#' @param networks_unstructured A list of all unstructured deep models to be
#' considered in the DeepPAM. If empty, a PAM will be constructed.
#' Each entry of the list must itself be a list with two entries:
#' \code{input} (the input layer) and \code{output} (the output of the net).
#' For example, this is a working input:
#' \code{image_input <- layer_input(shape = dim(unstructured_data$image)[2:3])
#' image_output <- image_input %>%
#' layer_conv_1d(1, 1, trainable = FALSE,
#' kernel_initializer = initializer_constant(0),
#' bias_initializer = initializer_constant(0)) %>%
#' layer_global_max_pooling_1d() %>%
#' layer_dense(1, "linear", trainable = FALSE,
#' kernel_initializer = initializer_constant(0))
#' image_model <- list(input = image_input, output = image_output)
#' networks_unstructured <- list(image = image_model)}
#' You may find the tutorial / vignette useful.
#'
#' @param structured_ped an object of class ped_tensor constructed using
#' from_ped_to_tensor(). This is the tensor that is used to fit the structured
#' part of the DeepPAM.
#' @param interactions Not implemented yet, thus must be FALSE (or contain only
#' FALSE values). Intended to indicate whether the separate sub networks are
#' supposed to interact with one another. The argument should be a boolean or
#' a matrix of booleans; the matrix should indicate which sub nets should
#' interact.
#' @param orthogonal Not implemented yet, thus must be FALSE (or contain only
#' FALSE values). Intended to indicate whether there should be an
#' orthogonalisation layer (as proposed in Ruegamer et al. (2020)). Either a
#' single boolean or a vector of booleans.
#' @param warm_start logical. Should the PAM part be initialised with the
#' values from a fit of \code{pammtools::pamm()}?
#' @param trainable_structured logical. Should the PAM part be frozen during
#' training (FALSE) or trainable (TRUE)?
#' @param ... potential further arguments (currently unused).
#' @return a keras model of class "deeppam".
#' The object can use all regular keras methods, such as compile or fit, in
#' the regular manner.
#' However, for fitting, make use of the weights that are stored in
#' attr(structured_ped, "weights"). More details are available in the tutorial.
#' @export
#' @author Philipp Kopper
#' @import tensorflow
#' @import keras
#' @import checkmate
deeppam <- function(networks_unstructured = list(),
                    structured_ped,
                    interactions = FALSE,
                    warm_start = FALSE,
                    trainable_structured = TRUE,
                    orthogonal = FALSE,
                    ...
) {
  # `any()` accepts the documented vector/matrix forms of these flags; a bare
  # `if (matrix)` errors with "the condition has length > 1" in R >= 4.2.
  if (any(interactions)) {
    stop("Interactions between different sub models not implemented yet.")
  }
  if (any(orthogonal)) {
    stop("Orthogonalisation not yet implemented.")
  }
  assert_class(structured_ped, "ped_tensor")

  X <- structured_ped[["X"]]
  timesteps <- dim(X)[2]
  P <- attr(structured_ped, "P")
  indicator <- attr(structured_ped, "indicator")
  coeffs <- attr(structured_ped, "coeffs")
  input_X <- layer_input(shape = dim(X)[2:3])
  my_batch_size <- tf$shape(input_X)[1]

  # Closure factories. R `for` loops do not open a new scope, so a closure
  # that refers to a loop variable directly sees the value that variable has
  # when the closure is *run*, i.e. the last iteration's value. Keras may
  # re-invoke lambda layers and weight regularizers long after the loop has
  # finished (e.g. during fit/predict). Each factory call therefore creates
  # a fresh frame in which its arguments are forced immediately.
  make_slicer <- function(from, to) {
    force(from)
    force(to)
    function(x) x[, , from:to]
  }
  make_reshaper <- function(n_features) {
    force(n_features)
    function(x) backend_reshape(x, timesteps, n_features)
  }
  make_penalty <- function(P_c) {
    force(P_c)
    function(x) {
      k_mean(k_batch_dot(x, k_dot(tf$constant(P_c, dtype = "float32"), x),
                         axes = 2))
    }
  }

  # Structured part: one sub network per feature group. Each group is
  # identified by its value in `indicator`; its columns are sliced as the
  # range min(rg):max(rg), which assumes they are contiguous.
  subnets <- vector(mode = "list", length = length(unique(indicator)))
  for (i in 0:max(indicator)) {
    rg <- which(indicator == i)
    subnet <- input_X %>%
      layer_lambda(f = make_slicer(min(rg), max(rg))) %>%
      layer_lambda(f = make_reshaper(length(rg)))
    if (i == 0) {
      # Group 0: fixed, non-trainable kernel of ones and no bias.
      subnet <- subnet %>%
        layer_dense(1, activation = "linear", use_bias = FALSE,
                    trainable = FALSE,
                    kernel_initializer = initializer_constant(1))
    } else if (i == 1) {
      # Group 1: unpenalised dense layer with a bias term. The non-warm-start
      # initialisers are the layer_dense() defaults.
      kernel_init <- "glorot_uniform"
      bias_init <- "zeros"
      if (warm_start) {
        kernel_init <- initializer_constant(coeffs[indicator == i])
        bias_init <- initializer_constant(coeffs[1])
      }
      subnet <- subnet %>%
        layer_dense(1, activation = "linear", use_bias = TRUE,
                    trainable = trainable_structured,
                    kernel_initializer = kernel_init,
                    bias_initializer = bias_init)
    } else {
      # Groups >= 2: penalised dense layer without bias; the penalty uses the
      # matrix P[[i - 1]] of this group.
      kernel_init <- "glorot_uniform"
      if (warm_start) {
        kernel_init <- initializer_constant(coeffs[indicator == i])
      }
      subnet <- subnet %>%
        layer_dense(1, activation = "linear", use_bias = FALSE,
                    trainable = trainable_structured,
                    kernel_initializer = kernel_init,
                    kernel_regularizer = make_penalty(P[[i - 1]]))
    }
    subnets[[i + 1]] <- subnet
  }
  structured_nn <- subnets %>% layer_add()

  # Unstructured part: each network output is repeated over the time steps so
  # it can be added to the structured predictor.
  n_unstructured <- length(networks_unstructured)
  inputs_unstructured <- vector(mode = "list", length = n_unstructured)
  nn_unstructured <- vector(mode = "list", length = n_unstructured)
  for (i in seq_len(n_unstructured)) {
    inputs_unstructured[[i]] <- networks_unstructured[[i]]$input
    nn_unstructured[[i]] <- networks_unstructured[[i]]$output %>%
      layer_repeat_vector(timesteps) %>%
      layer_lambda(f = make_reshaper(1L))
  }

  if (n_unstructured == 0L) {
    message("No deep network given. Building a PAMM.")
    linear_predictor <- structured_nn
    inputs <- input_X
  } else {
    if (n_unstructured > 1L) {
      nn_unstructured <- nn_unstructured %>% layer_add()
    } else {
      nn_unstructured <- nn_unstructured[[1]]
    }
    linear_predictor <- list(structured_nn, nn_unstructured) %>% layer_add()
    inputs <- c(list(input_X), inputs_unstructured)
  }

  merged <- linear_predictor %>%
    layer_lambda(f = function(x) k_exp(x)) %>%
    layer_lambda(f = function(x) backend_reduce(x, as.integer(timesteps),
                                                my_batch_size, 1L))
  complete_nn <- keras_model(inputs = inputs, outputs = merged)
  class(complete_nn) <- c("deeppam", class(complete_nn))
  complete_nn
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.