test_that("resnet18", {
  # Feed a random 1 x 3 x 256 x 256 batch through `net` and check that the
  # result has shape (1, 1000).
  check_output_shape <- function(net) {
    force(net) # build the model before the random input is drawn
    batch <- torch::torch_randn(1, 3, 256, 256)
    expect_tensor_shape(net(batch), c(1, 1000))
  }

  # Model built with default arguments, then with `pretrained = TRUE`.
  check_output_shape(model_resnet18())
  check_output_shape(model_resnet18(pretrained = TRUE))
})
test_that("resnet34", {
  # Not run on Windows or macOS.
  skip_on_os(c("windows", "mac"))

  expected_shape <- c(1, 1000)

  # Model built with default arguments.
  net <- model_resnet34()
  batch <- torch::torch_randn(1, 3, 256, 256)
  expect_tensor_shape(net(batch), expected_shape)

  # Same check for the model built with `pretrained = TRUE`.
  pretrained_net <- model_resnet34(pretrained = TRUE)
  batch <- torch::torch_randn(1, 3, 256, 256)
  expect_tensor_shape(pretrained_net(batch), expected_shape)
})
test_that("resnet50", {
  # Not run on Windows or macOS.
  skip_on_os(c("windows", "mac"))

  # Run `net` on a fresh random 1 x 3 x 256 x 256 batch and expect a
  # (1, 1000) result.
  expect_logits_shape <- function(net) {
    force(net) # construct the model before sampling the input
    x <- torch::torch_randn(1, 3, 256, 256)
    y <- net(x)
    expect_tensor_shape(y, c(1, 1000))
  }

  expect_logits_shape(model_resnet50())
  expect_logits_shape(model_resnet50(pretrained = TRUE))
})
test_that("resnet101", {
  # Not run on Windows or macOS.
  skip_on_os(c("windows", "mac"))

  expected_shape <- c(1, 1000)

  # Model built with default arguments.
  net <- model_resnet101()
  batch <- torch::torch_randn(1, 3, 256, 256)
  expect_tensor_shape(net(batch), expected_shape)

  # The `timeout` option is set to 360 only while the pretrained model is
  # being constructed (presumably to leave room for fetching the weights).
  pretrained_net <- withr::with_options(
    list(timeout = 360),
    model_resnet101(pretrained = TRUE)
  )
  batch <- torch::torch_randn(1, 3, 256, 256)
  expect_tensor_shape(pretrained_net(batch), expected_shape)
})
test_that("resnet152", {
  # Not run on Windows or macOS.
  skip_on_os(c("windows", "mac"))

  # Forward a random 1 x 3 x 256 x 256 batch through `net`; the output must
  # have shape (1, 1000).
  check_output_shape <- function(net) {
    x <- torch::torch_randn(1, 3, 256, 256)
    expect_tensor_shape(net(x), c(1, 1000))
  }

  # Model built with default arguments.
  plain_net <- model_resnet152()
  check_output_shape(plain_net)

  # Pretrained model, constructed with the `timeout` option temporarily set
  # to 360.
  pretrained_net <- withr::with_options(
    list(timeout = 360),
    model_resnet152(pretrained = TRUE)
  )
  check_output_shape(pretrained_net)
})
# Output-shape check for resnext50_32x4d, with default arguments and with
# `pretrained = TRUE`. This is a single test: running the identical test twice
# under the same name only repeats the pretrained model construction.
test_that("resnext50_32x4d", {
  # Not run on Windows or macOS.
  skip_on_os(c("windows", "mac"))

  # Model built with default arguments: a random 1 x 3 x 256 x 256 batch must
  # produce a (1, 1000) output.
  model <- model_resnext50_32x4d()
  input <- torch::torch_randn(1, 3, 256, 256)
  out <- model(input)
  expect_tensor_shape(out, c(1, 1000))

  # Pretrained model, constructed with the `timeout` option temporarily set
  # to 360 (presumably to leave room for fetching the weights).
  model <- withr::with_options(
    list(timeout = 360),
    model_resnext50_32x4d(pretrained = TRUE)
  )
  input <- torch::torch_randn(1, 3, 256, 256)
  out <- model(input)
  expect_tensor_shape(out, c(1, 1000))
})
test_that("resnext101_32x8d", {
  # Not run on Windows or macOS.
  skip_on_os(c("windows", "mac"))

  expected_shape <- c(1, 1000)

  # Model built with default arguments.
  net <- model_resnext101_32x8d()
  x <- torch::torch_randn(1, 3, 256, 256)
  expect_tensor_shape(net(x), expected_shape)

  # Pretrained model; `timeout` is 360 only for the duration of the
  # constructor call.
  pretrained_net <- withr::with_options(
    list(timeout = 360),
    model_resnext101_32x8d(pretrained = TRUE)
  )
  x <- torch::torch_randn(1, 3, 256, 256)
  expect_tensor_shape(pretrained_net(x), expected_shape)
})
test_that("wide_resnet50_2", {
  # Not run on Windows or macOS.
  skip_on_os(c("windows", "mac"))

  # Run `net` on a random 1 x 3 x 256 x 256 batch and expect a (1, 1000)
  # output.
  expect_logits_shape <- function(net) {
    batch <- torch::torch_randn(1, 3, 256, 256)
    result <- net(batch)
    expect_tensor_shape(result, c(1, 1000))
  }

  # Model built with default arguments.
  plain_net <- model_wide_resnet50_2()
  expect_logits_shape(plain_net)

  # Pretrained model, built while the `timeout` option is set to 360.
  pretrained_net <- withr::with_options(
    list(timeout = 360),
    model_wide_resnet50_2(pretrained = TRUE)
  )
  expect_logits_shape(pretrained_net)
})
test_that("wide_resnet101_2", {
  # Not run on Windows or macOS.
  skip_on_os(c("windows", "mac"))

  expected_shape <- c(1, 1000)

  # Model built with default arguments.
  net <- model_wide_resnet101_2()
  batch <- torch::torch_randn(1, 3, 256, 256)
  expect_tensor_shape(net(batch), expected_shape)

  # Pretrained model; the `timeout` option is raised to 360 just for the
  # constructor call.
  pretrained_net <- withr::with_options(
    list(timeout = 360),
    model_wide_resnet101_2(pretrained = TRUE)
  )
  batch <- torch::torch_randn(1, 3, 256, 256)
  expect_tensor_shape(pretrained_net(batch), expected_shape)
})
# Pruning one head layer from a pretrained resnet34 must succeed and leave a
# 9-element module ending in adaptive average pooling.
test_that("we can prune head of resnet34 models", {
  resnet34 <- model_resnet34(pretrained = TRUE)

  # `expect_error(..., NA)` asserts that the call raises no error; the
  # assignment makes the pruned module available to the checks below.
  expect_error(prune <- nn_prune_head(resnet34, 1), NA)
  # NOTE(review): this class check is disabled here but enabled in the
  # resnet50 / resnext101 pruning tests -- confirm whether it can be restored.
  # expect_true(inherits(prune, "nn_sequential"))
  expect_equal(length(prune), 9)
  expect_true(inherits(prune[[length(prune)]], "nn_adaptive_avg_pool2d"))

  # A random 1 x 3 x 256 x 256 batch must come out as (1, 512, 1, 1).
  input <- torch::torch_randn(1, 3, 256, 256)
  out <- prune(input)
  expect_tensor_shape(out, c(1, 512, 1, 1))
})
# Pruning one head layer from a pretrained resnet50 must succeed and yield an
# `nn_sequential` of 9 elements ending in adaptive average pooling.
test_that("we can prune head of resnet50 models", {
  resnet50 <- model_resnet50(pretrained = TRUE)

  # `expect_error(..., NA)` asserts that the call raises no error; the
  # assignment makes the pruned module available to the checks below.
  expect_error(prune <- nn_prune_head(resnet50, 1), NA)
  expect_true(inherits(prune, "nn_sequential"))
  expect_equal(length(prune), 9)
  expect_true(inherits(prune[[length(prune)]], "nn_adaptive_avg_pool2d"))

  # A random 1 x 3 x 256 x 256 batch must come out as (1, 2048, 1, 1).
  input <- torch::torch_randn(1, 3, 256, 256)
  out <- prune(input)
  expect_tensor_shape(out, c(1, 2048, 1, 1))
})
# Pruning one head layer from a pretrained resnext101_32x8d must succeed and
# yield an `nn_sequential` of 9 elements ending in adaptive average pooling.
test_that("we can prune head of resnext101 models", {
  # Build the pretrained model with `timeout = 360`, the same setting the
  # dedicated resnext101_32x8d test uses for this constructor.
  resnext101 <- withr::with_options(
    list(timeout = 360),
    model_resnext101_32x8d(pretrained = TRUE)
  )

  # `expect_error(..., NA)` asserts that the call raises no error; the
  # assignment makes the pruned module available to the checks below.
  # `nn_prune_head` is reached through the public `::` accessor rather than
  # `:::`, which is meant for unexported internals.
  expect_error(prune <- torch::nn_prune_head(resnext101, 1), NA)
  expect_true(inherits(prune, "nn_sequential"))
  expect_equal(length(prune), 9)
  expect_true(inherits(prune[[length(prune)]], "nn_adaptive_avg_pool2d"))

  # A random 1 x 3 x 256 x 256 batch must come out as (1, 2048, 1, 1).
  input <- torch::torch_randn(1, 3, 256, 256)
  out <- prune(input)
  expect_tensor_shape(out, c(1, 2048, 1, 1))
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.