R/deeppam.R

Defines functions deeppam

Documented in deeppam

#' Construct a DeepPAM model with mgcv and keras.
#'
#' This function constructs a DeepPAM as described in Kopper et al. 2020.
#' It combines different deep models and a structured PAM(M) to a single
#' deep keras model.
#' @param networks_unstructured A list of all unstructured deep models to be
#' considered in the DeepPAM. If empty, a PAM will be constructed.
#' Each entry of the list must be a list itself featuring two entries:
#' The input_layer and the net itself. For example this is a working input:
#' \code{image_input <- layer_input(shape = dim(unstructured_data$image)[2:3])
#'       image_output <- image_input %>%
#'         layer_conv_1d(1, 1, trainable = FALSE,
#'                       kernel_initializer = initializer_constant(0),
#'                       bias_initializer = initializer_constant(0)) %>%
#'      layer_global_max_pooling_1d() %>%
#'      layer_dense(1, "linear", trainable = FALSE,
#'                  kernel_initializer = initializer_constant(0))
#'      image_model <- list(input = image_input, output = image_output)
#'      networks_unstructured <- list(image = image_model)}
#' You may find the tutorial / vignette useful.
#'
#' @param structured_ped an object of class ped_tensor constructed using
#' from_ped_to_tensor(). This is the tensor that is used to fit the structured
#' part of the DeepPAM.
#' @param interactions Not implemented yet, thus must be FALSE.
#' However, this argument indicates whether the separate sub networks are
#' supposed to interact with one another. The argument should be a boolean or
#' a matrix featuring booleans. The matrix should indicate which sub nets should
#' interact.
#' @param orthogonal Not implemented yet, thus must be FALSE.
#' However, this argument indicates whether there should be an orthogonalisation
#' (as proposed in Ruegamer et al. (2020)) layer. Either a single boolean or a
#' vector of booleans.
#' @param warm_start logical. Should the PAM part be initialised with the
#' values from a fit of \code{pammtools::pamm()}?
#' @param trainable_structured logical. Should the PAM part be frozen during
#' training (FALSE) or trainable (TRUE)?
#' @param ... potential further arguments.
#' @return a keras model of class "deeppam".
#' The object can use all regular keras methods, such as compile or fit in the
#' regular manner.
#' However, for fitting, make use of the weights that are stored in
#' attr(structured_ped, "weights"). More details are available in the tutorial.
#' @export
#' @author Philipp Kopper
#' @import tensorflow
#' @import keras
#' @import checkmate
deeppam <- function(networks_unstructured = list(),
                    structured_ped,
                    interactions = FALSE,
                    warm_start = FALSE,
                    trainable_structured = TRUE,
                    orthogonal = FALSE,
                    ...
) {
  # `interactions` may be a logical matrix and `orthogonal` a logical vector
  # (see the parameter documentation), so collapse them with any() instead of
  # handing a length > 1 condition to if().
  if (any(interactions)) {
    stop("Interactions between different sub models not implemented yet.")
  }
  if (any(orthogonal)) {
    stop("Orthogonalisation not yet implemented.")
  }
  assert_class(structured_ped, "ped_tensor")

  X <- structured_ped[["X"]]
  timesteps <- dim(X)[2]
  P <- attr(structured_ped, "P")
  indicator <- attr(structured_ped, "indicator")
  coeffs <- attr(structured_ped, "coeffs")

  input_X <- layer_input(shape = dim(X)[2:3])
  my_batch_size <- tf$shape(input_X)[1]

  # Builds the sub network for the columns of X whose `indicator` equals `i`.
  # Every term is built inside its own function call so that the closures
  # handed to keras (lambda layers, kernel regulariser) capture the column
  # range and penalty matrix of *this* term. Closures created directly in a
  # for loop would all share one environment and see the values of the last
  # iteration whenever keras calls them again after construction.
  build_term <- function(i) {
    in_term <- indicator == i
    cols <- which(in_term)
    from <- min(cols)
    to <- max(cols)
    n_cols <- length(cols)

    term <- input_X %>%
      layer_lambda(f = function(x) x[, , from:to]) %>%
      layer_lambda(f = function(x) backend_reshape(x, timesteps, n_cols))

    if (i == 0) {
      # Term 0: fixed, non-trainable unit weights without bias, i.e. the
      # columns are passed through (summed) unchanged.
      # NOTE(review): presumably the offset column -- confirm against
      # from_ped_to_tensor().
      layer_dense(term, 1, activation = "linear", use_bias = FALSE,
                  trainable = FALSE,
                  kernel_initializer = initializer_constant(1))
    } else if (i == 1) {
      # Term 1: unpenalised dense layer; the only term that carries a bias.
      if (warm_start) {
        layer_dense(term, 1, activation = "linear", use_bias = TRUE,
                    trainable = trainable_structured,
                    kernel_initializer = initializer_constant(coeffs[in_term]),
                    bias_initializer = initializer_constant(coeffs[1]))
      } else {
        layer_dense(term, 1, activation = "linear", use_bias = TRUE,
                    trainable = trainable_structured)
      }
    } else {
      # Terms >= 2: dense layer without bias whose kernel w is penalised with
      # the mean of w' P w, where P is the (i - 1)-th entry of attribute "P".
      penalty_matrix <- P[[i - 1]]
      quadratic_penalty <- function(w) {
        k_mean(k_batch_dot(w, k_dot(
          tf$constant(penalty_matrix, dtype = "float32"), w),
          axes = 2))
      }
      if (warm_start) {
        layer_dense(term, 1, activation = "linear", use_bias = FALSE,
                    trainable = trainable_structured,
                    kernel_initializer = initializer_constant(coeffs[in_term]),
                    kernel_regularizer = quadratic_penalty)
      } else {
        layer_dense(term, 1, activation = "linear", use_bias = FALSE,
                    trainable = trainable_structured,
                    kernel_regularizer = quadratic_penalty)
      }
    }
  }

  # One sub network per indicator value 0, 1, ..., max(indicator); their
  # outputs are summed to form the structured predictor.
  subnets <- lapply(0:max(indicator), build_term)
  structured_nn <- layer_add(subnets)

  n_unstructured <- length(networks_unstructured)
  if (n_unstructured == 0L) {
    message("No deep network given. Building a PAMM.")
    predictor <- structured_nn
    inputs <- input_X
  } else {
    # Names are dropped on purpose: a named R list would be converted to a
    # Python dict instead of a list when handed to keras.
    inputs_unstructured <- unname(lapply(networks_unstructured,
                                         function(net) net$input))
    # Repeat every deep output along the time axis so that it can be added to
    # the structured predictor.
    nn_unstructured <- unname(lapply(networks_unstructured, function(net) {
      net$output %>%
        layer_repeat_vector(timesteps) %>%
        layer_lambda(f = function(x) backend_reshape(x, timesteps, 1L))
    }))
    # layer_add() is only used for two or more tensors; a single deep
    # network is taken as is.
    unstructured_nn <- if (n_unstructured > 1L) {
      layer_add(nn_unstructured)
    } else {
      nn_unstructured[[1]]
    }
    predictor <- layer_add(list(structured_nn, unstructured_nn))
    inputs <- append(list(input_X), inputs_unstructured)
  }

  # Exponentiate the summed predictor and pass it through backend_reduce().
  merged <- predictor %>%
    layer_lambda(f = function(x) k_exp(x)) %>%
    layer_lambda(f = function(x) backend_reduce(x, as.integer(timesteps),
                                                my_batch_size, 1L))

  complete_nn <- keras_model(inputs = inputs, outputs = merged)
  class(complete_nn) <- c("deeppam", class(complete_nn))
  complete_nn
}
pkopper/deeppam documentation built on Jan. 19, 2021, 12:39 a.m.