# Nothing
# Check a distribution's log_prob() against a reference, one sample row at a
# time.
#
# Draws a sample, evaluates `distribution$log_prob()` on it, flattens the
# log-probabilities to 1-D and reshapes the sample to one row per
# log-probability. `expect_fn(i, value, log_prob)` is then called once per
# row, where `value` is the squeezed row of the sample and `log_prob` is the
# matching entry of the flattened log-probabilities.
#
# @param distribution Object exposing `$sample()` and `$log_prob()`
#   (presumably a torch distribution -- confirm against callers).
# @param expect_fn Function `(i, value, log_prob)` that runs the expectations;
#   its return value is ignored.
# @return NULL; called for the expectations performed by `expect_fn`.
check_log_prob <- function(distribution, expect_fn) {
  s <- distribution$sample()
  log_probs <- distribution$log_prob(s)
  log_probs_data_flat <- log_probs$view(-1)
  # Number of log-probabilities, i.e. number of rows in the reshaped sample.
  n <- length(log_probs_data_flat)
  # One row per log-probability; any remaining (event) dims are flattened.
  s_data_flat <- s$view(c(n, -1))
  # Iterate over rows, not elements: seq_along() on the 2-D view would use
  # length(), which counts every element, and so would index past the last
  # row whenever each event has more than one element.
  for (i in seq_len(n)) {
    val <- s_data_flat[i]
    log_prob <- log_probs_data_flat[i]
    expect_fn(i, val$squeeze(), log_prob)
  }
}
# Check a distribution class's enumerate_support() against expected values.
#
# For every example, a distribution is built from the example's parameters
# and its enumerated support is compared twice: without expansion (against
# the expected values as given) and with expansion (against the expected
# values expanded to `c(-1, batch_shape, event_shape)`).
#
# @param distr_cls Constructor, called via `do.call()` with the example's
#   parameters after each one is passed through `torch_tensor()`.
# @param examples List of examples. Element 1 of each example holds the
#   constructor arguments; element 2 holds the expected support.
# @return NULL; called for its expectations.
check_enumerate_support <- function(distr_cls, examples) {
  for (example in examples) {
    # Convert each constructor argument to a tensor (Map keeps the names).
    ctor_args <- Map(torch_tensor, example[[1]])
    # TODO: consider ignoring arg types in this expect_equal (suggested in PyTorch)
    expected <- torch_tensor(example[[2]])
    dist <- do.call(distr_cls, ctor_args)

    # Without expansion the support must match the expected values directly.
    actual <- dist$enumerate_support(expand = FALSE)
    expect_equal(actual, expected)

    # With expansion the expected values are broadcast over the batch and
    # event dimensions; -1 leaves the leading dimension's size unchanged.
    actual <- dist$enumerate_support(expand = TRUE)
    expected_with_expand <- expected$expand(
      c(-1, dist$batch_shape, dist$event_shape)
    )
    expect_equal(actual, expected_with_expand)
  }
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.