Nothing
test_that("deeplabv3_resnet50 works with default aux_loss=NULL", {
  # Randomly initialised model (pretrained not requested): with aux_loss left
  # at its default, only the main "out" head is expected.
  model <- model_deeplabv3_resnet50(num_classes = 21)
  input <- torch::torch_randn(1, 3, 32, 32)
  out <- model(input)
  expect_true(is.list(out))
  expect_named(out, "out")
  expect_tensor_shape(out$out, c(1, 21, 32, 32))

  # Loading pretrained weights requires a large download; gate it behind the
  # same opt-in environment variable as the other pretrained tests so the
  # shape checks above still run everywhere.
  skip_if(Sys.getenv("TEST_LARGE_DATASETS", unset = "0") != "1",
          "Skipping test: set TEST_LARGE_DATASETS=1 to enable tests requiring large downloads.")
  # Raise the download timeout for the duration of the weight fetch only.
  withr::with_options(list(timeout = 360), {
    model <- model_deeplabv3_resnet50(pretrained = TRUE)
  })
  out <- model(input)
  expect_named(out, "out")
  expect_tensor_shape(out$out, c(1, 21, 32, 32))
})
test_that("deeplabv3_resnet101 works with default aux_loss=NULL", {
  # Opt-in only: this test needs large downloads.
  skip_if(
    Sys.getenv("TEST_LARGE_DATASETS", unset = "0") != "1",
    "Skipping test: set TEST_LARGE_DATASETS=1 to enable tests requiring large downloads."
  )
  expected_shape <- c(1, 21, 32, 32)

  # Freshly constructed 21-class model: only the main head is returned.
  scratch_net <- model_deeplabv3_resnet101(num_classes = 21)
  x <- torch::torch_randn(1, 3, 32, 32)
  scratch_pred <- scratch_net(x)
  expect_true(is.list(scratch_pred))
  expect_named(scratch_pred, "out")
  expect_tensor_shape(scratch_pred[["out"]], expected_shape)

  # Pretrained weights, fetched with an extended download timeout.
  pretrained_net <- withr::with_options(
    list(timeout = 360),
    model_deeplabv3_resnet101(pretrained = TRUE)
  )
  pretrained_pred <- pretrained_net(x)
  expect_named(pretrained_pred, "out")
  expect_tensor_shape(pretrained_pred[["out"]], expected_shape)
})
test_that("deeplabv3_resnet50 works with aux_loss = TRUE", {
  net <- model_deeplabv3_resnet50(aux_loss = TRUE, num_classes = 21)
  x <- torch::torch_randn(1, 3, 32, 32)
  preds <- net(x)
  # Both the main and auxiliary heads are returned, each producing a
  # full-resolution 21-channel map.
  expect_named(preds, c("out", "aux"))
  for (head in c("out", "aux")) {
    expect_tensor_shape(preds[[head]], c(1, 21, 32, 32))
  }
})
test_that("deeplabv3_resnet50 works with aux_loss = FALSE", {
  net <- model_deeplabv3_resnet50(aux_loss = FALSE, num_classes = 21)
  prediction <- net(torch::torch_randn(1, 3, 32, 32))
  # With the auxiliary head disabled, only the main output is returned.
  expect_named(prediction, "out")
  expect_tensor_shape(prediction[["out"]], c(1, 21, 32, 32))
})
test_that("custom num_classes works with aux_loss = TRUE", {
  n_classes <- 3
  net <- model_deeplabv3_resnet50(num_classes = n_classes, aux_loss = TRUE)
  prediction <- net(torch::torch_randn(1, 3, 32, 32))
  # Both heads should honour the requested number of output channels.
  expected <- c(1, n_classes, 32, 32)
  expect_named(prediction, c("out", "aux"))
  expect_tensor_shape(prediction[["out"]], expected)
  expect_tensor_shape(prediction[["aux"]], expected)
})
test_that("pretrained requires num_classes = 21", {
  # Requesting pretrained weights with a non-default class count must fail,
  # and the error message should mention the required value.
  expect_error(
    object = model_deeplabv3_resnet50(pretrained = TRUE, num_classes = 3),
    regexp = "num_classes = 21"
  )
})
test_that("model_deeplabv3_resnet50 segments an aeroplane in a Wikipedia image", {
  skip_if(Sys.getenv("TEST_LARGE_DATASETS", unset = "0") != "1",
          "Skipping test: set TEST_LARGE_DATASETS=1 to enable tests requiring large downloads.")
  skip_if_not_installed("magick")

  # Pascal VOC label names, assumed to follow the model's output-channel order
  # (first channel = background).
  voc_classes <- c(
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow", "dining table", "dog", "horse",
    "motorbike", "person", "potted plant", "sheep", "sofa", "train", "tv/monitor"
  )
  img_url <- "https://upload.wikimedia.org/wikipedia/commons/3/36/United_Airlines_Boeing_777-200_Meulemans.jpg"
  img <- magick::image_read(img_url)

  # Per-channel normalisation constants (the usual ImageNet statistics).
  norm_mean <- c(0.485, 0.456, 0.406)
  norm_std <- c(0.229, 0.224, 0.225)
  input <- transform_to_tensor(img)
  input <- transform_resize(input, c(520, 520))
  input <- transform_normalize(input, mean = norm_mean, std = norm_std)
  input <- input$unsqueeze(1) # add a leading batch dimension

  # Same extended timeout as the other pretrained downloads in this file.
  model <- withr::with_options(
    list(timeout = 360),
    model_deeplabv3_resnet50(pretrained = TRUE)
  )
  model$eval()
  # Inference only: skip autograd bookkeeping for the large forward pass.
  output <- torch::with_no_grad(model(input))

  # torch for R returns 1-based indices from argmax, so index k corresponds to
  # voc_classes[k]. Mapping with 0:20 would shift every label by one and count
  # background pixels as "aeroplane".
  mask <- output$out$argmax(dim = 2)
  label_array <- torch::as_array(mask)
  label_table <- table(
    factor(label_array, levels = seq_along(voc_classes), labels = voc_classes)
  )
  expect_gt(label_table[["aeroplane"]], 0)
  expect_gt(label_table[["aeroplane"]], label_table[["dog"]])
  expect_gt(label_table[["aeroplane"]], label_table[["person"]])
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.