# tests/testthat/test-models-mobilenetv2.R

test_that("mobilenetv2 works", {
  # Randomly initialised model: only the output shape is checked.
  model <- model_mobilenet_v2()
  input <- torch::torch_randn(1, 3, 256, 256)
  out <- model(input)

  expect_tensor_shape(out, c(1, 1000))

  # Pretrained model. eval() disables training-only behaviour (e.g. dropout),
  # so the output is deterministic for a seeded input. with_no_grad() avoids
  # recording an autograd graph for a forward-only check.
  model <- model_mobilenet_v2(pretrained = TRUE)
  model$eval()
  torch::torch_manual_seed(1)
  input <- torch::torch_randn(1, 3, 256, 256)
  out <- torch::with_no_grad(model(input))

  expect_tensor_shape(out, c(1, 1000))
  # expect_equal_to_r(out[1,1], -1.1959798336029053, tolerance = 1e-5) # value taken from pytorch
})

test_that("we can prune head of mobilenetv2 models", {
  mobilenet <- model_mobilenet_v2(pretrained = TRUE)

  # Removing the head should leave an nn_sequential with exactly one child,
  # which is itself an nn_sequential.
  expect_no_error(prune <- nn_prune_head(mobilenet, 1))
  expect_s3_class(prune, "nn_sequential")
  # expect_length() is avoided on purpose: nn modules are functions, not
  # vectors, so compare length() explicitly.
  expect_equal(length(prune), 1)
  expect_s3_class(prune[[length(prune)]], "nn_sequential")

  # Forward-only shape check: inference mode, no autograd graph.
  prune$eval()
  input <- torch::torch_randn(1, 3, 256, 256)
  out <- torch::with_no_grad(prune(input))
  # Spatial size 256 -> 8 (overall stride 32), 1280 channels.
  expect_tensor_shape(out, c(1, 1280, 8, 8))
})
# mlverse/torchvision documentation built on Sept. 18, 2024, 4:03 p.m.