tests/testthat/test-layer-attention.R

test_succeeds("multi_head_attention", {

  if (tensorflow::tf_version() < "2.4")
    skip("requires tf_version() >= 2.4")

  layer <- layer_multi_head_attention(num_heads = 2, key_dim = 2, name = "hello")
  target <- layer_input(shape = c(8, 16))
  source <- layer_input(shape = c(4, 16))

  expect_equal(layer$name, "hello")

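  # query = target, value = source (key defaults to value when not supplied)
  c(output_tensor, weights) %<-% layer(target, source, return_attention_scores = TRUE)

  # Output keeps the query's shape: (batch, target_len = 8, query_dim = 16)
  expect_equal(output_tensor$shape$as_list(), list(NULL, 8, 16))
  # Attention scores: (batch, num_heads = 2, target_len = 8, source_len = 4)
  expect_equal(weights$shape$as_list(), list(NULL, 2, 8, 4))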
})
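
The shape expectations in this test follow from how multi-head attention is computed: each head produces a target_len x source_len matrix of softmax scores, and the concatenated head outputs are projected back to the query's feature size. The base-R sketch here illustrates that arithmetic for a single example (no batch dimension). It is a minimal sketch under stated assumptions, not the keras implementation: `multi_head_attention_sketch` and `softmax_rows` are hypothetical helpers, random matrices stand in for learned weights, and biases, dropout, and masking are omitted.

```r
# Hypothetical sketch of multi-head scaled dot-product attention in base R.
# Assumptions: random projections replace learned weights; no biases,
# dropout, or masking; a single example, so there is no batch dimension.

softmax_rows <- function(m) {
  # Row-wise softmax, shifted by each row's max for numerical stability
  e <- exp(m - apply(m, 1, max))
  e / rowSums(e)
}

multi_head_attention_sketch <- function(query, value, num_heads, key_dim) {
  # query: T x d_q matrix; value: S x d_v matrix (used as the key too)
  d_q <- ncol(query)
  d_v <- ncol(value)
  t_len <- nrow(query)
  s_len <- nrow(value)
  scores <- array(0, dim = c(num_heads, t_len, s_len))
  head_outputs <- vector("list", num_heads)
  for (h in seq_len(num_heads)) {
    w_q <- matrix(rnorm(d_q * key_dim), d_q, key_dim)
    w_k <- matrix(rnorm(d_v * key_dim), d_v, key_dim)
    w_v <- matrix(rnorm(d_v * key_dim), d_v, key_dim)
    q <- query %*% w_q                              # T x key_dim
    k <- value %*% w_k                              # S x key_dim
    v <- value %*% w_v                              # S x key_dim
    a <- softmax_rows(q %*% t(k) / sqrt(key_dim))   # T x S, rows sum to 1
    scores[h, , ] <- a
    head_outputs[[h]] <- a %*% v                    # T x key_dim
  }
  concat <- do.call(cbind, head_outputs)            # T x (num_heads * key_dim)
  w_o <- matrix(rnorm(ncol(concat) * d_q), ncol(concat), d_q)
  list(output = concat %*% w_o, scores = scores)    # output: T x d_q
}

set.seed(1)
target_x <- matrix(rnorm(8 * 16), 8, 16)
source_x <- matrix(rnorm(4 * 16), 4, 16)
res <- multi_head_attention_sketch(target_x, source_x, num_heads = 2, key_dim = 2)
dim(res$output)  # 8 16
dim(res$scores)  # 2 8 4
```

Dropping the leading batch dimension, these dimensions line up with the `list(NULL, 8, 16)` and `list(NULL, 2, 8, 4)` expectations in the test.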

keras documentation built on Aug. 21, 2021, 9:07 a.m.