# Tests for torch JIT tracing: jit_trace, jit_trace_module, jit_save, jit_load.
# Tracing a plain R function should yield a callable that matches the
# eager function on the example input.
# Fix: typo in the test description ("simnple" -> "simple").
test_that("simple tracing works", {
  fn <- function(x) {
    torch_relu(x)
  }
  input <- torch_tensor(c(-1, 0, 1))
  tr_fn <- jit_trace(fn, input)
  expect_equal_to_tensor(tr_fn(input), fn(input))
})
# Printing the `$graph` of a traced function should emit text containing
# the word "graph".
test_that("print the graph works", {
  relu_fn <- function(input) torch_relu(input)
  example <- torch_tensor(c(-1, 0, 1))
  traced <- jit_trace(relu_fn, example)
  expect_output(print(traced$graph), regexp = "graph")
})
# A traced wrapper around a small conv net (in eval mode, parameters
# detached) must produce outputs close to the eager network.
test_that("modules are equivalent", {
  Net <- nn_module(
    "Net",
    initialize = function() {
      self$conv1 <- nn_conv2d(1, 32, 3, 1)
      self$conv2 <- nn_conv2d(32, 64, 3, 1)
      self$dropout1 <- nn_dropout(0.25)
      self$dropout2 <- nn_dropout(0.5)
      self$fc1 <- nn_linear(9216, 128)
      self$fc2 <- nn_linear(128, 10)
    },
    forward = function(x) {
      x %>%
        self$conv1() %>%
        nnf_relu() %>%
        self$conv2() %>%
        nnf_relu() %>%
        nnf_max_pool2d(2) %>%
        self$dropout1() %>%
        torch_flatten(start_dim = 2) %>%
        self$fc1() %>%
        nnf_relu() %>%
        self$dropout2() %>%
        self$fc2() %>%
        nnf_log_softmax(dim = 1)
    }
  )
  net <- Net()
  net$eval()
  # currently we need to detach all parameters in order to
  # JIT compile. We need to support modules to avoid that.
  for (p in net$parameters) {
    p$detach_()
  }
  fn <- function(x) net(x)
  input <- torch_randn(100, 1, 28, 28)
  out <- fn(input)
  traced <- jit_trace(fn, input)
  expect_true(torch_allclose(fn(input), traced(input)))
})
# A traced function survives a jit_save / jit_load round trip.
# Fix: `fileext = "pt"` was missing the dot, so the temp file had no
# real ".pt" extension (every other test here uses ".pt").
test_that("can save and reload", {
  fn <- function(x) {
    torch_relu(x)
  }
  input <- torch_tensor(c(-1, 0, 1))
  tr_fn <- jit_trace(fn, input)
  tmp <- tempfile("tst", fileext = ".pt")
  jit_save(tr_fn, tmp)
  f <- jit_load(tmp)
  expect_equal_to_tensor(f(input), fn(input))
})
# Passing a non-tensor example input to jit_trace must raise an error.
test_that("errors gracefully when passing unsupported inputs", {
  relu_fn <- function(x) torch_relu(x)
  expect_error(jit_trace(relu_fn, "a"))
})
# jit_trace accepts a list of tensors as the example input.
test_that("can take lists of tensors as input", {
  stack_fn <- function(tensors) torch_stack(tensors)
  tensors <- list(torch_tensor(1), torch_tensor(2))
  traced <- jit_trace(stack_fn, tensors)
  expect_equal_to_tensor(stack_fn(tensors), traced(tensors))
})
# A traced function may return a list of tensors; each element must
# match the eager result.
test_that("can output a list of tensors", {
  make_pair <- function(t) list(t, t + 1)
  input <- torch_tensor(1)
  traced <- jit_trace(make_pair, input)
  eager <- make_pair(input)
  jitted <- traced(input)
  expect_equal_to_tensor(eager[[1]], jitted[[1]])
  expect_equal_to_tensor(eager[[2]], jitted[[2]])
})
# Traced functions support multiple positional arguments; named example
# inputs, however, are rejected.
test_that("fn can take more than 1 argument", {
  pair_fn <- function(x, y) list(x, x + y)
  a <- torch_tensor(1)
  b <- torch_tensor(2)
  traced <- jit_trace(pair_fn, a, b)
  eager <- pair_fn(a, b)
  jitted <- traced(a, b)
  expect_equal_to_tensor(eager[[1]], jitted[[1]])
  expect_equal_to_tensor(eager[[2]], jitted[[2]])
  # passing the example inputs by name is not supported
  expect_error(traced <- jit_trace(pair_fn, x = a, y = b))
})
# With strict = FALSE, tracing accepts named list inputs and preserves
# named list outputs.
test_that("can have named inputs and outputs", {
  select_fn <- function(inp) list(x = inp$t1, y = inp$t2)
  named_input <- list(
    t1 = torch_tensor(1),
    t2 = torch_tensor(2)
  )
  traced <- jit_trace(select_fn, named_input, strict = FALSE)
  expect_equal(traced(named_input), select_fn(named_input))
})
# jit_tuple-wrapped inputs/outputs trace correctly; tuple names are
# dropped on the way out, and a plain list is no longer accepted.
test_that("tuple inputs are correctly handled", {
  tuple_fn <- function(inp) jit_tuple(list(x = inp$t1, y = inp$t2))
  example <- jit_tuple(list(
    t1 = torch_tensor(1),
    t2 = torch_tensor(2)
  ))
  traced <- jit_trace(tuple_fn, example, strict = FALSE)
  # returned named tuples will lose their names
  expect_equal(traced(example), list(torch_tensor(1), torch_tensor(2)))
  # if the model has been traced with a named tuple we
  # will expect a tuple back too.
  expect_error(traced(list(t1 = torch_tensor(1), t2 = torch_tensor(2))))
})
# Once traced with a jit_tuple, only a jit_tuple is accepted as input.
test_that("tuple casting", {
  take_two <- function(y) jit_tuple(list(y[[1]], y[[2]]))
  example <- jit_tuple(list(torch_tensor(1), torch_tensor(2)))
  traced <- jit_trace(take_two, example)
  # a bare list must be rejected
  expect_error(traced(list(torch_tensor(1), torch_tensor(2))))
  expect_equal(
    traced(jit_tuple(list(torch_tensor(1), torch_tensor(2)))),
    list(torch_tensor(1), torch_tensor(2))
  )
})
# jit_trace_module traces several methods at once (forward plus extra
# named methods) and must preserve parameters, buffers, submodules and
# captured constants.
test_that("trace a nn module", {
  test_module <- nn_module(
    initialize = function() {
      self$linear <- nn_linear(10, 10)
      self$norm <- nn_batch_norm1d(10)
      self$par <- nn_parameter(torch_tensor(2))
      self$buff <- nn_buffer(torch_randn(10, 5))
      self$constant <- 1
      # mixed list: tensors plus a string, captured as constants
      self$hello <- list(torch_tensor(1), torch_tensor(2), "hello")
    },
    forward = function(x) {
      self$par * x
    },
    testing = function(x) {
      x %>%
        self$linear()
    },
    test_constant = function(x) {
      x + self$constant + self$hello[[2]]
    }
  )
  mod <- test_module()
  # regexp = NA asserts that tracing raises NO error
  expect_error(
    m <- jit_trace_module(
      mod,
      forward = torch_randn(1),
      testing = list(torch_randn(10, 10)),
      test_constant = list(torch_tensor(1))
    ),
    regexp = NA
  )
  # expected counts: 5 parameters (linear + norm weight/bias pairs, par),
  # 4 buffers and 3 modules (self + linear + norm) — presumably including
  # batch-norm running statistics; TODO confirm against nn_batch_norm1d
  expect_length(m$parameters, 5)
  expect_length(m$buffers, 4)
  expect_length(m$modules, 3)
  # forward is x * par, so 2 * 2 = 4
  expect_equal_to_tensor(m(torch_tensor(2)), torch_tensor(4))
  # zeroing the traced parameter in place must change the traced output
  with_no_grad(m$par$zero_())
  expect_equal_to_tensor(m(torch_tensor(2)), torch_tensor(0))
  x <- torch_randn(10, 10)
  expect_true(torch_allclose(m$testing(x), mod$testing(x)))
  # apply the same in-place weight update to both copies; outputs must
  # remain in agreement afterwards
  with_no_grad({
    m$linear$weight$zero_()$add_(1)
    mod$linear$weight$zero_()$add_(1)
  })
  expect_true(torch_allclose(m$testing(x), mod$testing(x)))
  # test_constant: 2 + constant (1) + hello[[2]] (2) = 5
  expect_true(torch_allclose(m$test_constant(torch_tensor(2)), torch_tensor(5)))
})
# Taking a reference to a traced method, dropping it, and garbage
# collecting must not invalidate the method on the module.
test_that("dont crash when gcing a method", {
  traced <- jit_trace(nn_linear(10, 10), torch_randn(10, 10))
  gc()
  method_ref <- traced$forward
  rm(method_ref)
  gc()
  gc()
  expect_error(regexp = NA, traced$forward)
})
# A module traced with multiple methods must survive a jit_save /
# jit_load round trip: parameters, buffers, constants and extra methods
# all intact.
# Fix: the test wrote "tracedmodule.pt" into the working directory and
# never removed it; use a temp file and clean it up on exit.
test_that("we can save traced modules", {
  test_module <- nn_module(
    initialize = function() {
      self$linear <- nn_linear(10, 10)
      self$norm <- nn_batch_norm1d(10)
      self$par <- nn_parameter(torch_tensor(2))
      self$buff <- nn_buffer(torch_randn(10, 5))
      self$constant <- 1
      self$hello <- list(torch_tensor(1), torch_tensor(2), "hello")
    },
    forward = function(x) {
      self$par * x
    },
    testing = function(x) {
      x %>%
        self$linear()
    },
    test_constant = function(x) {
      x + self$constant + self$hello[[2]]
    }
  )
  mod <- test_module()
  m <- jit_trace_module(
    mod,
    forward = torch_randn(1),
    testing = list(torch_randn(10, 10)),
    test_constant = list(torch_tensor(1))
  )
  path <- tempfile("tracedmodule", fileext = ".pt")
  on.exit(unlink(path), add = TRUE)
  jit_save(m, path)
  # drop the in-memory copy and force collection before reloading
  rm(m)
  gc()
  gc()
  m <- jit_load(path)
  expect_length(m$parameters, 5)
  expect_length(m$buffers, 4)
  expect_length(m$modules, 3)
  # forward is x * par (2): 2 * 2 = 4
  expect_equal_to_tensor(m(torch_tensor(2)), torch_tensor(4))
  with_no_grad(m$par$zero_())
  expect_equal_to_tensor(m(torch_tensor(2)), torch_tensor(0))
  x <- torch_randn(10, 10)
  expect_true(torch_allclose(m$testing(x), mod$testing(x)))
  # identical in-place weight updates on both copies keep them in sync
  with_no_grad({
    m$linear$weight$zero_()$add_(1)
    mod$linear$weight$zero_()$add_(1)
  })
  expect_true(torch_allclose(m$testing(x), mod$testing(x), atol = 1e-5))
  # 2 + constant (1) + hello[[2]] (2) = 5
  expect_true(torch_allclose(m$test_constant(torch_tensor(2)), torch_tensor(5)))
})
# Tracing an nn_module instance directly (not wrapped in a function)
# must give outputs matching the eager module.
test_that("trace a module", {
  Net <- nn_module(
    "Net",
    initialize = function() {
      self$conv1 <- nn_conv2d(1, 32, 3, 1)
      self$conv2 <- nn_conv2d(32, 64, 3, 1)
      self$dropout1 <- nn_dropout(0.25)
      self$dropout2 <- nn_dropout(0.5)
      self$fc1 <- nn_linear(9216, 128)
      self$fc2 <- nn_linear(128, 10)
    },
    forward = function(x) {
      x %>%
        self$conv1() %>%
        nnf_relu() %>%
        self$conv2() %>%
        nnf_relu() %>%
        nnf_max_pool2d(2) %>%
        self$dropout1() %>%
        torch_flatten(start_dim = 2) %>%
        self$fc1() %>%
        nnf_relu() %>%
        self$dropout2() %>%
        self$fc2() %>%
        nnf_log_softmax(dim = 1)
    }
  )
  net <- Net()
  net$eval()
  input <- torch_randn(100, 1, 28, 28)
  out <- net(input)
  traced <- jit_trace(net, input)
  expect_equal_to_tensor(net(input), traced(input), tolerance = 1e-6)
})
# Errors raised while tracing must surface cleanly, and invalid trace
# targets must produce informative messages.
test_that("Can recover from errors in the traced method", {
  failing_module <- nn_module(
    initialize = function() {},
    forward = function(x) {
      stop("The error abcde")
    }
  )
  # the original error message is propagated through the tracer
  expect_error(
    jit_trace(failing_module(), torch_tensor(1)),
    regexp = ".*abcde"
  )
  # passing the module generator instead of an instance is rejected
  expect_error(
    regexp = "You must initialize the nn_module before tracing",
    jit_trace(failing_module, torch_tensor(1))
  )
  # anything other than a function or nn_module is rejected
  expect_error(
    regexp = "jit_trace needs a function or nn_module",
    jit_trace(1, torch_tensor(1))
  )
})
# Only the traced entry points are callable; invoking a method on a
# submodule of a traced module must fail with a clear message.
test_that("we get a good error message when trying to call a method from a submodule", {
  wrapper <- nn_module(
    initialize = function() {
      self$linear <- nn_linear(10, 10)
    },
    forward = function(x) self$linear(x)
  )
  traced <- jit_trace(wrapper(), torch_randn(100, 10))
  expect_error(
    traced$linear(torch_randn(10, 10)),
    regexp = "Methods from submodules of traced modules are not traced"
  )
})
# The tracer rejects forward methods that return unsupported types and
# reports a useful error.
test_that("errors in the tracer are correctly captured", {
  bad_module <- nn_module(
    initialize = function() {
      self$linear <- nn_linear(10, 10)
    },
    forward = function(x) {
      self$linear(x)
      # a bare numeric return value is not a traceable output
      1
    }
  )
  expect_error(
    jit_trace(bad_module(), torch_randn(10, 10)),
    regexp = ".*Only tensors, lists, tuples of tensors"
  )
})
# An already-traced module may be nested inside another module, and the
# outer module can itself be traced.
test_that("we can include traced module as a submodule and trace", {
  outer <- nn_module(
    initialize = function() {
      self$linear <- jit_trace(nn_linear(10, 10), torch_randn(10, 10))
    },
    forward = function(x) self$linear(x)
  )
  eager <- outer()
  traced <- jit_trace(eager, torch_randn(10, 10))
  x <- torch_randn(10, 10)
  expect_equal_to_tensor(traced(x), eager(x))
  expect_equal_to_tensor(traced$linear(x), eager$linear(x))
})
# A traced conv net saved with jit_save_for_mobile reloads and matches
# the eager network's outputs.
test_that("can save module for mobile", {
  Net <- nn_module(
    "Net",
    initialize = function() {
      self$conv1 <- nn_conv2d(1, 32, 3, 1)
      self$conv2 <- nn_conv2d(32, 64, 3, 1)
      self$dropout1 <- nn_dropout(0.25)
      self$dropout2 <- nn_dropout(0.5)
      self$fc1 <- nn_linear(9216, 128)
      self$fc2 <- nn_linear(128, 10)
    },
    forward = function(x) {
      x %>%
        self$conv1() %>%
        nnf_relu() %>%
        self$conv2() %>%
        nnf_relu() %>%
        nnf_max_pool2d(2) %>%
        self$dropout1() %>%
        torch_flatten(start_dim = 2) %>%
        self$fc1() %>%
        nnf_relu() %>%
        self$dropout2() %>%
        self$fc2() %>%
        nnf_log_softmax(dim = 1)
    }
  )
  net <- Net()
  net$eval()
  input <- torch_randn(100, 1, 28, 28)
  out <- net(input)
  traced <- jit_trace(net, input)
  path <- tempfile("tst", fileext = ".pt")
  jit_save_for_mobile(traced, path)
  reloaded <- jit_load(path)
  expect_equal_to_tensor(net(input), reloaded(input), tolerance = 1e-6)
})
# A traced plain function saved with jit_save_for_mobile reloads and
# matches torch_relu on the example input.
test_that("can save function for mobile", {
  relu_fn <- function(x) torch_relu(x)
  input <- torch_tensor(c(-1, 0, 1))
  traced <- jit_trace(relu_fn, input)
  path <- tempfile("tst", fileext = ".pt")
  jit_save_for_mobile(traced, path)
  reloaded <- jit_load(path)
  expect_equal_to_tensor(torch_relu(input), reloaded(input))
})
# End of JIT tracing tests.