Nothing
#' @include distributions.R
#' @include distributions-exp-family.R
#' @include distributions-utils.R
#' @include distributions-constraints.R
#' @include utils.R
Normal <- R6::R6Class(
  # Normal (Gaussian) distribution with mean `loc` and standard deviation
  # `scale`. Users construct it through distr_normal().
  "torch_Normal",
  # initialize() adds `loc` and `scale` fields that are not declared in
  # `public`, which R6 only allows when objects are not locked.
  lock_objects = FALSE,
  inherit = ExponentialFamily,
  public = list(
    # Constraints on the parameters, presumably checked by the base class
    # when argument validation is enabled.
    .arg_constraints = list(
      loc = constraint_real,
      scale = constraint_positive
    ),
    # Samples may take any real value; exposed through the `support` binding.
    .support = constraint_real,
    has_rsample = TRUE,
    # Backing value for the `.mean_carrier_measure` active binding.
    ._mean_carrier_measure = 0,

    # @param loc Mean of the distribution (numeric or tensor).
    # @param scale Standard deviation of the distribution (numeric or tensor).
    # @param validate_args Forwarded unchanged to the parent initializer.
    initialize = function(loc, scale, validate_args = NULL) {
      # Bring both parameters to a common shape so element-wise arithmetic
      # in the methods below lines up.
      broadcasted <- broadcast_all(list(loc, scale))
      self$loc <- broadcasted[[1]]
      self$scale <- broadcasted[[2]]
      # NOTE: PyTorch uses an empty batch shape when both parameters are
      # plain Python numbers. Here broadcast_all() is expected to return
      # tensors, so the batch shape is always taken from `loc`.
      batch_shape <- self$loc$size()
      super$initialize(batch_shape, validate_args = validate_args)
    },

    # Returns a new Normal whose parameters are expanded to `batch_shape`.
    # The new object is re-initialized with validation disabled and then
    # given this object's validation flag.
    expand = function(batch_shape, .instance = NULL) {
      .args <- list(
        loc = self$loc$expand(batch_shape),
        scale = self$scale$expand(batch_shape)
      )
      new <- private$.get_checked_instance(self, .instance, .args)
      new$.__enclos_env__$super$initialize(
        batch_shape,
        validate_args = FALSE
      )
      new$.validate_args <- self$.validate_args
      new
    },

    # Draws samples without recording gradients. The output shape comes from
    # the base-class helper `.extended_shape(sample_shape)`.
    sample = function(sample_shape = NULL) {
      shape <- self$.extended_shape(sample_shape)
      with_no_grad({
        torch_normal(self$loc$expand(shape), self$scale$expand(shape))
      })
    },

    # Reparameterized sample: loc + eps * scale, where eps is drawn by
    # .standard_normal() with the dtype and device of `loc`. Unlike sample(),
    # this is not wrapped in with_no_grad().
    rsample = function(sample_shape = NULL) {
      shape <- self$.extended_shape(sample_shape)
      eps <- .standard_normal(shape,
        dtype = self$loc$dtype,
        device = self$loc$device
      )
      self$loc + eps * self$scale
    },

    # Log density:
    #   -(value - loc)^2 / (2 * scale^2) - log(scale) - log(sqrt(2 * pi))
    log_prob = function(value) {
      if (self$.validate_args) {
        self$.validate_sample(value)
      }
      var <- self$scale^2
      # `scale` is always a tensor after broadcasting (other methods call
      # tensor methods on it), so take its log directly.
      log_scale <- self$scale$log()
      -((value - self$loc)^2) / (2 * var) - log_scale - log(sqrt(2 * pi))
    },

    # Cumulative distribution function:
    #   0.5 * (1 + erf((value - loc) / (scale * sqrt(2))))
    cdf = function(value) {
      if (self$.validate_args) {
        self$.validate_sample(value)
      }
      0.5 * (1 + torch_erf((value - self$loc) * self$scale$reciprocal() / sqrt(2)))
    },

    # Inverse CDF (quantile function):
    #   loc + scale * erfinv(2 * value - 1) * sqrt(2)
    icdf = function(value) {
      if (self$.validate_args) {
        self$.validate_sample(value)
      }
      self$loc + self$scale * torch_erfinv(2 * value - 1) * sqrt(2)
    },

    # Differential entropy: 0.5 + 0.5 * log(2 * pi) + log(scale).
    entropy = function() {
      0.5 + 0.5 * log(2 * pi) + torch_log(self$scale)
    }
  ),
  private = list(
    # Log-normalizer written in terms of two parameters `x` and `y`;
    # presumably called by ExponentialFamily with the pair returned by the
    # `.natural_params` binding (loc / scale^2, -1 / (2 * scale^2)).
    .log_normalizer = function(x, y) {
      -0.25 * x$pow(2) / y + 0.5 * torch_log(-pi / y)
    }
  ),
  active = list(
    mean = function() {
      self$loc
    },
    stddev = function() {
      self$scale
    },
    variance = function() {
      self$stddev$pow(2)
    },
    # Natural parameters: loc / scale^2 and -1 / (2 * scale^2).
    .natural_params = function() {
      list(self$loc / self$scale$pow(2), -0.5 * self$scale$pow(2)$reciprocal())
    },
    .mean_carrier_measure = function() {
      self$._mean_carrier_measure
    },
    # `.support` is declared in `public` above, so it must be read through
    # `self`; this class never sets `private$.support`.
    support = function() {
      self$.support
    }
  )
)
# add_class_definition() is defined elsewhere in the package; it replaces the
# raw R6 generator with whatever wrapper it returns.
Normal <- add_class_definition(Normal)
#' Creates a normal (also called Gaussian) distribution parameterized by
#' `loc` and `scale`.
#'
#' @param loc (numeric or tensor): mean of the distribution (often referred to as mu)
#' @param scale (numeric or tensor): standard deviation of the distribution (often referred to as sigma)
#' @param validate_args Logical flag controlling argument validation. When
#'   enabled, methods such as `log_prob()`, `cdf()` and `icdf()` validate their
#'   `value` argument. The default `NULL` is forwarded unchanged to the
#'   [Distribution] initializer, leaving the choice to the base class.
#'
#' @return Object of `torch_Normal` class
#'
#' @examples
#' m <- distr_normal(loc = 0, scale = 1)
#' m$sample() # normally distributed with loc=0 and scale=1
#' @seealso [Distribution] for details on the available methods.
#' @family distributions
#' @export
distr_normal <- function(loc, scale, validate_args = NULL) {
  Normal$new(loc = loc, scale = scale, validate_args = validate_args)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.