R/residual_conv_1d.R

Defines functions layer_wavenet_dilated_causal_convolution_1d

Documented in layer_wavenet_dilated_causal_convolution_1d

# R6 implementation of a WaveNet residual block: a gated, dilated causal
# convolution whose output is projected by a 1x1 convolution. call() returns
# both the residual sum (input + projection) and the projection itself.
# Instantiated through layer_wavenet_dilated_causal_convolution_1d().
#
# NOTE(review): the sub-layers are stored on the R6 object rather than created
# via add_weight(); confirm that their weights are tracked as trainable by the
# wrapping Keras layer for the keras version in use.
WaveNetDilatedCausalConvolution1D <- R6::R6Class(
  "WaveNetDilatedCausalConvolution",
  inherit = keras::KerasLayer,
  public = list(

    # Convolution hyper-parameters, set by initialize().
    filters = NULL,
    kernel_size = NULL,
    dilation_rate = NULL,

    # Sub-layers and the projection width, populated by build().
    conv_sigmoid = NULL,
    conv_tanh = NULL,
    conv_1x1 = NULL,
    conv_1x1_filters = NULL,

    # Record the hyper-parameters only. The sub-layers are created in build()
    # because the 1x1 projection width depends on the input shape.
    initialize = function(filters, kernel_size, dilation_rate) {
      self$filters <- filters
      self$kernel_size <- kernel_size
      self$dilation_rate <- dilation_rate
    },

    # Create the two gate convolutions and the 1x1 projection.
    # Reference implementation:
    # https://github.com/ibab/tensorflow-wavenet/blob/master/wavenet/model.py#L245
    build = function(input_shape) {

      # The gate branches differ only in their activation function.
      causal_conv <- function(activation) {
        keras::layer_conv_1d(
          filters = self$filters,
          kernel_size = self$kernel_size,
          dilation_rate = self$dilation_rate,
          activation = activation,
          padding = "causal",
          use_bias = FALSE
        )
      }

      self$conv_sigmoid <- causal_conv("sigmoid")
      self$conv_tanh <- causal_conv("tanh")

      # Project back to the size of the input's 3rd dimension (presumably the
      # channel axis of a (batch, time, channels) input) so that the residual
      # addition in call() is shape-compatible.
      self$conv_1x1_filters <- input_shape[[3]]
      self$conv_1x1 <- keras::layer_conv_1d(
        filters = self$conv_1x1_filters,
        kernel_size = 1,
        use_bias = FALSE
      )
    },

    # Gated activation sigmoid(conv(x)) * tanh(conv(x)), followed by the 1x1
    # projection. Returns list(x + projection, projection); the second element
    # is presumably consumed by callers as the skip-connection output.
    call = function(x, mask = NULL) {
      gate <- self$conv_sigmoid(x)
      filtered <- self$conv_tanh(x)
      gated <- keras::layer_multiply(list(gate, filtered))

      skip <- self$conv_1x1(gated)
      residual <- keras::layer_add(list(x, skip))

      list(residual, skip)
    },

    # Both outputs have the same shape as the input.
    compute_output_shape = function(input_shape) {
      rep(list(input_shape), 2)
    }

  )
)

#' WaveNet residual block
#'
#' A gated, dilated causal convolution followed by a 1x1 convolution, with a
#' residual connection back to the input. See sections 2.3 (gated activation
#' units) and 2.4 (residual and skip connections) of
#' \href{https://arxiv.org/abs/1609.03499}{van den Oord et al.,
#' \emph{WaveNet: A Generative Model for Raw Audio}}.
#'
#' @inheritParams keras::layer_conv_1d
#'
#' @return The result of `keras::create_layer()`. When applied to a tensor of
#'   shape `(batch, time, channels)`, the layer yields a list of two tensors,
#'   each shaped like the input:
#'   * the residual output (input plus the 1x1 projection);
#'   * the 1x1 projection itself.
#'
#' @export
layer_wavenet_dilated_causal_convolution_1d <- function(object,
                                                        filters,
                                                        kernel_size,
                                                        dilation_rate,
                                                        name = NULL,
                                                        trainable = TRUE) {
  # The R6 initialize() accepts only the three convolution settings. `name`
  # and `trainable` are presumably consumed by keras::create_layer() itself.
  layer_args <- list(
    filters = filters,
    kernel_size = kernel_size,
    dilation_rate = dilation_rate,
    name = name,
    trainable = trainable
  )

  keras::create_layer(WaveNetDilatedCausalConvolution1D, object, layer_args)
}
r-tensorflow/wavenet documentation built on Nov. 5, 2019, 2:06 a.m.