tests/testthat/test-models-alexnet.R

test_that("alexnet", {

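  # Randomly initialised AlexNet: a batch of one 3-channel 256x256 image
  # should produce one row of 1000 ImageNet class scores.
  m <- model_alexnet()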
  input <- torch::torch_randn(1, 3, 256, 256)

  out <- m(input)

  expect_tensor_shape(out, c(1, 1000))

  # Same shape check with the pretrained ImageNet weights (downloaded on first use).
  m <- model_alexnet(pretrained = TRUE)
  input <- torch::torch_randn(1, 3, 256, 256)

  out <- m(input)

  expect_tensor_shape(out, c(1, 1000))

})

torchvision documentation built on April 14, 2023, 5:08 p.m.