Nothing
#' Base class for Distributions
#'
#' Represents a modifiable Distribution family
#'
#' @examples
#' # Example for param_bounds:
#'
#' # Create an Exponential Distribution with rate constrained to (0, 2)
#' # instead of (0, Inf)
#' my_exp <- dist_exponential()
#' my_exp$param_bounds$rate <- interval(c(0, 2))
#' my_exp$get_param_bounds()
#'
#' fit_dist(my_exp, rexp(100, rate = 3), start = list(rate = 1))$params$rate
#'
#' @family Distributions
Distribution <- R6Class(
  classname = "Distribution",
  public = list(
    #' @details Construct a Distribution instance
    #'
    #' Used internally by the `dist_*` functions.
    #'
    #' @param type Type of distribution. This is a string constant for the
    #' default implementation. Distributions with non-constant type must
    #' override the `get_type()` function.
    #' @param caps Character vector of capabilities to fuel the default
    #' implementations of `has_capability()` and `require_capability()`.
    #' Distributions with dynamic capabilities must override the
    #' `has_capability()` function.
    #' @param params Initial parameter bounds structure, backing the
    #' `param_bounds` active binding (usually a list of intervals).
    #' @param name Name of the Distribution class. Should be `CamelCase` and end
    #' with `"Distribution"`.
    #' @param default_params Initial fixed parameters backing the
    #' `default_params` active binding (usually a list of numeric / NULLs).
    initialize = function(type, caps, params, name, default_params) {
      private$.type <- type
      private$.caps <- caps
      private$.params <- params
      private$.default_params <- default_params
      private$.name <- name
    },
    # nocov start (stubs with stop() as implementation)
    #' @details Sample from a Distribution
    #'
    #' @param n number of samples to draw.
    #' @param with_params Distribution parameters to use.
    #' Each parameter value can also be a numeric vector of length `n`. In that
    #' case the `i`-th sample will use the `i`-th parameters.
    #'
    #' @return A length `n` vector of i.i.d. random samples from the
    #' Distribution with the specified parameters.
    #'
    #' @examples
    #' dist_exponential(rate = 2.0)$sample(10)
    sample = function(n, with_params = list()) {
      stop("sample() is not implemented for ", class(self)[1L], ".")
    },
    #' @details Density of a Distribution
    #'
    #' @param x Vector of points to evaluate the density at.
    #' @param log Flag. If `TRUE`, return the log-density instead.
    #' @param with_params Distribution parameters to use.
    #' Each parameter value can also be a numeric vector of length `length(x)`.
    #' In that case, the `i`-th density point will use the `i`-th parameters.
    #'
    #' @return A numeric vector of (log-)densities
    #'
    #' @examples
    #' dist_exponential()$density(c(1.0, 2.0), with_params = list(rate = 2.0))
    density = function(x, log = FALSE, with_params = list()) {
      stop("density() is not implemented for ", class(self)[1L], ".")
    },
    #' @details Compile a TensorFlow function for log-density evaluation
    #'
    #' @return A `tf_function` taking arguments `x` and `args` returning the
    #' log-density of the Distribution evaluated at `x` with parameters `args`.
    tf_logdensity = function() {
      stop("tf_logdensity() is not implemented for ", class(self)[1L], ".")
    },
    #' @details Cumulative probability of a Distribution
    #'
    #' @param q Vector of points to evaluate the probability function at.
    #' @param lower.tail If `TRUE`, return P(X <= q). Otherwise return P(X > q).
    #' @param log.p If `TRUE`, probabilities are returned as `log(p)`.
    #' @param with_params Distribution parameters to use.
    #' Each parameter value can also be a numeric vector of length `length(q)`.
    #' In that case, the `i`-th probability point will use the `i`-th
    #' parameters.
    #'
    #' @return A numeric vector of (log-)probabilities
    #'
    #' @examples
    #' dist_exponential()$probability(
    #'   c(1.0, 2.0),
    #'   with_params = list(rate = 2.0)
    #' )
    probability = function(q, lower.tail = TRUE, log.p = FALSE,
                           with_params = list()) {
      stop("probability() is not implemented for ", class(self)[1L], ".")
    },
    #' @details Compile a TensorFlow function for log-probability evaluation
    #'
    #' @return A `tf_function` taking arguments `qmin`, `qmax` and `args`
    #' returning the log-probability of the Distribution evaluated over the
    #' closed interval \[`qmin`, `qmax`\] with parameters `args`.
    tf_logprobability = function() {
      stop("tf_logprobability() is not implemented for ", class(self)[1L], ".")
    },
    #' @details Quantile function of a Distribution
    #'
    #' @param p Vector of probabilities.
    #' @param lower.tail If `TRUE`, `p` is interpreted as P(X <= x).
    #' Otherwise it is interpreted as P(X > x).
    #' @param log.p If `TRUE`, probabilities `p` are given as `log(p)`.
    #' @param with_params Distribution parameters to use.
    #' Each parameter value can also be a numeric vector of length `length(p)`.
    #' In that case, the `i`-th quantile will use the `i`-th parameters.
    #'
    #' @return A numeric vector of quantiles
    #'
    #' @examples
    #' dist_exponential()$quantile(c(0.1, 0.5), with_params = list(rate = 2.0))
    quantile = function(p, lower.tail = TRUE, log.p = FALSE,
                        with_params = list()) {
      stop("quantile() is not implemented for ", class(self)[1L], ".")
    },
    #' @details Hazard function of a Distribution
    #'
    #' @param x Vector of points.
    #' @param log Flag. If `TRUE`, return the log-hazard instead.
    #' @param with_params Distribution parameters to use.
    #' Each parameter value can also be a numeric vector of length `length(x)`.
    #' In that case, the `i`-th hazard point will use the `i`-th parameters.
    #'
    #' @return A numeric vector of (log-)hazards
    #'
    #' @examples
    #' dist_exponential(rate = 2.0)$hazard(c(1.0, 2.0))
    hazard = function(x, log = FALSE, with_params = list()) {
      stop("hazard() is not implemented for ", class(self)[1L], ".")
    },
    #' @details Gradients of the density of a Distribution
    #'
    #' @param x Vector of points.
    #' @param log Flag. If `TRUE`, return the gradient of the log-density
    #' instead.
    #' @param with_params Distribution parameters to use.
    #' Each parameter value can also be a numeric vector of length `length(x)`.
    #' In that case, the `i`-th density point will use the `i`-th parameters.
    #'
    #' @return A list structure containing the (log-)density gradients of all
    #' free parameters of the Distribution evaluated at `x`.
    #'
    #' @examples
    #' dist_exponential()$diff_density(
    #'   c(1.0, 2.0),
    #'   with_params = list(rate = 2.0)
    #' )
    diff_density = function(x, log = FALSE, with_params = list()) {
      stop("diff_density() is not implemented for ", class(self)[1L], ".")
    },
    #' @details Gradients of the cumulative probability of a Distribution
    #'
    #' @param q Vector of points to evaluate the probability function at.
    #' @param lower.tail If `TRUE`, return P(X <= q). Otherwise return P(X > q).
    #' @param log.p If `TRUE`, probabilities are returned as `log(p)`.
    #' @param with_params Distribution parameters to use.
    #' Each parameter value can also be a numeric vector of length `length(q)`.
    #' In that case, the `i`-th probability point will use the `i`-th
    #' parameters.
    #'
    #' @return A list structure containing the cumulative (log-)probability
    #' gradients of all free parameters of the Distribution evaluated at `q`.
    #'
    #' @examples
    #' dist_exponential()$diff_probability(
    #'   c(1.0, 2.0),
    #'   with_params = list(rate = 2.0)
    #' )
    diff_probability = function(q, lower.tail = TRUE, log.p = FALSE,
                                with_params = list()) {
      stop("diff_probability() is not implemented for ", class(self)[1L], ".")
    },
    #' @details Determine if a value is in the support of a Distribution
    #'
    #' @param x Vector of points
    #' @param with_params Distribution parameters to use.
    #' Each parameter value can also be a numeric vector of length `length(x)`.
    #' In that case, the `i`-th point will use the `i`-th parameters.
    #'
    #' @return A logical vector with the same length as `x` indicating whether
    #' `x` is part of the support of the distribution given its parameters.
    #'
    #' @examples
    #' dist_exponential(rate = 1.0)$is_in_support(c(-1.0, 0.0, 1.0))
    is_in_support = function(x, with_params = list()) {
      stop("is_in_support() is not implemented for ", class(self)[1L], ".")
    },
    # nocov end
    #' @details Determine if a value has positive probability
    #'
    #' @param x Vector of points
    #' @param with_params Distribution parameters to use.
    #' Each parameter value can also be a numeric vector of length `length(x)`.
    #' In that case, the `i`-th point will use the `i`-th parameters.
    #'
    #' @return A logical vector with the same length as `x` indicating whether
    #' there is a positive probability mass at `x` given the Distribution
    #' parameters.
    #'
    #' @examples
    #' dist_dirac(point = 0.0)$is_discrete_at(c(0.0, 1.0))
    is_discrete_at = function(x, with_params = list()) {
      if (self$is_continuous()) {
        # Continuous distributions never have point masses.
        rep_len(FALSE, length(x))
      } else {
        stop("is_discrete_at() is not implemented for ", class(self)[1L], ".") # nocov # nolint: line_length_linter.
      }
    },
    #' @details Compile a TensorFlow function for discrete support checking
    #'
    #' @return A `tf_function` taking arguments `x` and `args` returning whether
    #' the Distribution has a point mass at `x` given parameters `args`.
    tf_is_discrete_at = function() {
      check_installed("tensorflow")
      if (self$is_continuous()) {
        private$.tf_retrieve_or_call(
          "c_false",
          function() function(x, args) { # nolint: brace.
            tensorflow::tf$broadcast_to(FALSE, shape = tensorflow::tf$shape(x))
          }
        )
      } else {
        # discrete and mixed case needs to check support as well
        # nocov start
        stop("tf_is_discrete_at() is not implemented for ",
             class(self)[1L], ".")
        # nocov end
      }
    },
    #' @details
    #' Check if a capability is present
    #'
    #' @param caps Character vector of capabilities
    #' @return A logical vector the same length as `caps`.
    #'
    #' @examples
    #' dist_exponential()$has_capability("density")
    has_capability = function(caps) {
      caps %in% private$.caps
    },
    #' @details
    #' Get the type of a Distribution. Type can be one of `discrete`,
    #' `continuous` or `mixed`.
    #'
    #' @return A string representing the type of the Distribution.
    #'
    #' @examples
    #' dist_exponential()$get_type()
    #' dist_dirac()$get_type()
    #'
    #' dist_mixture(list(dist_dirac(), dist_exponential()))$get_type()
    #' dist_mixture(list(dist_dirac(), dist_binomial()))$get_type()
    get_type = function() {
      private$.type
    },
    #' @details Get the component Distributions of a transformed Distribution.
    #'
    #' @return A possibly empty list of Distributions
    #'
    #' @examples
    #' dist_trunc(dist_exponential())$get_components()
    #' dist_dirac()$get_components()
    #' dist_mixture(list(dist_exponential(), dist_gamma()))$get_components()
    get_components = function() {
      list()
    },
    #' @details
    #' Check if a Distribution is discrete, i.e. it has a density with respect
    #' to the counting measure.
    #'
    #' @return `TRUE` if the Distribution is discrete, `FALSE` otherwise.
    #' Note that mixed distributions are not discrete but can have point masses.
    #'
    #' @examples
    #' dist_exponential()$is_discrete()
    #' dist_dirac()$is_discrete()
    is_discrete = function() {
      identical(self$get_type(), "discrete")
    },
    #' @details
    #' Check if a Distribution is continuous, i.e. it has a density with respect
    #' to the Lebesgue measure.
    #'
    #' @return `TRUE` if the Distribution is continuous, `FALSE` otherwise.
    #' Note that mixed distributions are not continuous.
    #'
    #' @examples
    #' dist_exponential()$is_continuous()
    #' dist_dirac()$is_continuous()
    is_continuous = function() {
      identical(self$get_type(), "continuous")
    },
    #' @details
    #' Ensure that a Distribution has all required capabilities.
    #' Will throw an error if any capability is missing.
    #'
    #' @param caps Character vector of Capabilities to require
    #' @param fun_name Friendly text to use for generating the error message in
    #' case of failure. If `NULL` (the default), the name of the calling
    #' function followed by `()` is used.
    #'
    #' @return Invisibly `TRUE`.
    #'
    #' @examples
    #' dist_exponential()$require_capability("diff_density")
    require_capability = function(caps, fun_name = NULL) {
      present <- self$has_capability(caps)
      if (!all(present)) {
        # Derive the missing capabilities from has_capability() rather than
        # private$.caps, so subclasses with dynamic capabilities (which
        # override has_capability()) report the right ones.
        missing_caps <- unique(caps[!present])
        if (length(missing_caps) > 1) {
          missing_caps <- paste0(
            paste(head(missing_caps, -1), collapse = ", "),
            " and ",
            tail(missing_caps, 1)
          )
        }
        if (is.null(fun_name)) {
          # sys.call(-1L) is the call of the function that invoked
          # require_capability(). Its head may itself be a call such as
          # `dist$sample`, so it is deparse()d. as.character() would split it
          # into several strings ("$", "dist", "sample").
          caller <- sys.call(-1L)
          fun_name <- paste0(
            if (is.null(caller)) "" else deparse(caller[[1L]])[1L],
            "()"
          )
        }
        stop(
          fun_name, " requires missing capabilities.\n",
          private$.name, " doesn't provide ", missing_caps, "."
        )
      }
      invisible(TRUE)
    },
    #' @details
    #' Get the number of degrees of freedom of a Distribution family.
    #' Only parameters without a fixed default are considered free.
    #'
    #' @return An integer representing the degrees of freedom suitable e.g. for
    #' AIC calculations.
    #'
    #' @examples
    #' dist_exponential()$get_dof()
    #' dist_exponential(rate = 1.0)$get_dof()
    get_dof = function() {
      res <- private$.default_params
      # Count NULL (free) leaves, recursing into nested lists and into
      # Distribution-valued parameters.
      sum_dof <- function(proto) {
        list_elems <- vapply(proto, is.list, logical(1))
        distr_elems <- vapply(proto, is.Distribution, logical(1))
        null_elems <- vapply(proto, is.null, logical(1))
        sum(null_elems) +
          sum(vapply(proto[list_elems], sum_dof, numeric(1))) +
          sum(vapply(
            proto[distr_elems], function(d) d$get_dof(), numeric(1)
          ))
      }
      sum_dof(res)
    },
    #' @details
    #' Get Placeholders of a Distribution family.
    #' Returns a list of free parameters of the family.
    #' Their values will be `NULL`.
    #'
    #' If the Distribution has Distributions as parameters, placeholders will be
    #' computed recursively.
    #'
    #' @return A named list containing any combination of (named or unnamed)
    #' lists and `NULL`s.
    #'
    #' @examples
    #' dist_exponential()$get_placeholders()
    #' dist_mixture(list(dist_dirac(), dist_exponential()))$get_placeholders()
    get_placeholders = function() {
      res <- private$.default_params
      # Drop fixed (non-NULL, non-list, non-Distribution) values, recursing
      # into nested lists and Distribution-valued parameters.
      prune <- function(proto) {
        null_elems <- vapply(proto, is.null, logical(1))
        list_elems <- vapply(proto, is.list, logical(1))
        distr_elems <- vapply(proto, is.Distribution, logical(1))
        keep <- null_elems | list_elems | distr_elems
        proto[list_elems] <- lapply(proto[list_elems], prune)
        proto[distr_elems] <- lapply(
          proto[distr_elems],
          function(d) d$get_placeholders()
        )
        proto[keep]
      }
      prune(res)
    },
    #' @details
    #' Get a full list of parameters, possibly including placeholders.
    #'
    #' @param with_params Optional parameter overrides with the same structure
    #' as `dist$get_params()`. Given Parameter values are expected to be length
    #' 1.
    #'
    #' @return A list representing the (recursive) parameter structure of the
    #' Distribution with values for specified parameters and `NULL` for free
    #' parameters that are missing both in the Distributions parameters and in
    #' `with_params`.
    #'
    #' @examples
    #' dist_mixture(list(dist_dirac(), dist_exponential()))$get_params(
    #'   with_params = list(probs = list(0.5, 0.5))
    #' )
    get_params = function(with_params = list()) {
      my_params <- private$.make_params(with_params, 1L)
      get_distr_params <- function(elem) {
        if (is.list(elem) &&
              length(elem) == 2L &&
              hasName(elem, "dist") &&
              hasName(elem, "params") &&
              is.Distribution(elem$dist)) {
          # A `list(dist = , params = )` wrapper produced by .make_params()
          elem$dist$get_params(elem$params)
        } else if (is.list(elem)) {
          lapply(elem, get_distr_params)
        } else {
          elem
        }
      }
      lapply(my_params, get_distr_params)
    },
    #' @details Get a list of constant TensorFlow parameters
    #'
    #' @param with_params Optional parameter overrides with the same structure
    #' as `dist$tf_make_constants()`. Given Parameter values are expected to be
    #' length 1.
    #'
    #' @return A list representing the (recursive) constant parameters of the
    #' Distribution with values specified by parameters. Each constant is a
    #' TensorFlow Tensor of dtype `floatx`.
    tf_make_constants = function(with_params = list()) {
      check_installed(c("keras3", "tensorflow"))
      my_params <- private$.make_params(with_params, 1L)
      get_tf_consts <- function(elem) {
        if (is.list(elem) &&
              length(elem) == 2L &&
              hasName(elem, "dist") &&
              hasName(elem, "params") &&
              is.Distribution(elem$dist)) {
          elem$dist$tf_make_constants(elem$params)
        } else if (is.list(elem)) {
          # Free (NULL) parameters have no constant value and are skipped.
          are_null <- vapply(elem, is.null, logical(1L))
          lapply(elem[!are_null], get_tf_consts)
        } else if (is.numeric(elem)) {
          keras3::as_tensor(elem, keras3::config_floatx())
        } else {
          # nocov start
          stop(
            "Unsupported parameter class ", class(elem)[1L],
            " for tf_make_constants of ", class(self)[1L]
          )
          # nocov end
        }
      }
      get_tf_consts(my_params)
    },
    #' @details Compile distribution parameters into tensorflow outputs
    #'
    #' @param input A keras layer to bind all outputs to
    #' @param name_prefix Prefix to use for layer names
    #'
    #' @return A list with two elements
    #'
    #'  * `outputs` a flat list of keras output layers, one for each parameter.
    #'  * `output_inflater` a function taking keras output layers and
    #'  transforming them into a list structure suitable for passing to the
    #'  loss function returned by [tf_compile_model()]
    tf_compile_params = function(input, name_prefix = "") {
      bounds <- self$get_param_bounds()
      out <- self$get_placeholders()
      for (ph in names(out)) {
        b <- bounds[[ph]]
        if (!is.Interval(b)) {
          stop("Unsupported parameter ", ph, " for ", class(self)[1L], ".") # nocov # nolint: line_length_linter.
        }
        layer_name <- paste0(name_prefix, ph)
        out[[ph]] <- b$tf_make_layer(input, layer_name)
      }
      ph_names <- names(out)
      list(
        outputs = out,
        # bquote() bakes the placeholder names into the inflater's body.
        output_inflater = eval(bquote(function(outputs) {
          if (!is.list(outputs)) outputs <- list(outputs)
          names(outputs) <- .(ph_names)
          outputs
        }))
      )
    },
    #' @details Get Interval bounds on all Distribution parameters
    #'
    #' @return A list representing the free (recursive) parameter structure of
    #' the Distribution with `Interval` objects as values representing the
    #' bounds of the respective free parameters.
    #'
    #' @examples
    #' dist_mixture(
    #'   list(dist_dirac(), dist_exponential()),
    #'   probs = list(0.5, 0.5)
    #' )$get_param_bounds()
    #'
    #' dist_mixture(
    #'   list(dist_dirac(), dist_exponential())
    #' )$get_param_bounds()
    #'
    #' dist_genpareto()$get_param_bounds()
    #' dist_genpareto1()$get_param_bounds()
    get_param_bounds = function() {
      proto <- private$.make_params(list(), 1L)
      # Pair each free (NULL) parameter with its bounds; fixed parameters map
      # to NULL and are removed by prune() below.
      get_bounds <- function(elem, ranges) {
        if (is.null(elem)) {
          ranges
        } else if (is.list(elem) &&
                     length(elem) == 2L &&
                     hasName(elem, "dist") &&
                     hasName(elem, "params") &&
                     is.Distribution(elem$dist)) {
          elem$dist$get_param_bounds()
        } else if (is.list(elem) && rlang::is_named(elem)) {
          mapply(get_bounds, elem,
                 ranges[names(elem)], SIMPLIFY = FALSE)
        } else if (is.list(elem)) {
          mapply(get_bounds, elem,
                 rep_len(ranges, length(elem)), SIMPLIFY = FALSE)
        } else {
          NULL
        }
      }
      prune <- function(elem) {
        if (is.list(elem)) {
          non_null <- !vapply(elem, is.null, logical(1L))
          lapply(elem[non_null], prune)
        } else if (is.Interval(elem)) {
          elem
        } else {
          stop("non-list and non-Interval encountered during pruning.") # nocov
        }
      }
      res <- get_bounds(elem = proto, ranges = private$.params[names(proto)])
      prune(res)
    },
    #' @details Get additional (non-linear) equality constraints on Distribution
    #' parameters
    #'
    #' @return `NULL` if the box constraints specified by
    #' `dist$get_param_bounds()` are sufficient, or a function taking full
    #' Distribution parameters and returning either a numeric vector
    #' (which must be 0 for valid parameter combinations) or a list with
    #' elements
    #'
    #'  * `constraints`: The numeric vector of constraints
    #'  * `jacobian`: The Jacobi matrix of the constraints with respect to the
    #'   parameters
    #'
    #' @examples
    #' dist_mixture(
    #'   list(dist_dirac(), dist_exponential())
    #' )$get_param_constraints()
    get_param_constraints = function() {
      NULL
    },
    #' @details Export sampling, density, probability and quantile functions
    #' to plain R functions
    #'
    #' Creates new functions in `envir` named `{r,d,p,q}<name>` which implement
    #' `dist$sample`, `dist$density`, `dist$probability` and `dist$quantile` as
    #' plain functions with default arguments specified by `with_params` or the
    #' fixed parameters.
    #'
    #' The resulting functions will have signatures taking all parameters as
    #' separate arguments. If evaluation fails, they warn and return `NaN`s
    #' with the length of their first argument.
    #'
    #' @param name common suffix of the exported functions
    #' @param envir Environment to export the functions to
    #' @param with_params Optional list of parameters to use as default values
    #' for the exported functions
    #'
    #' @return Invisibly `NULL`.
    #'
    #' @examples
    #' tmp_env <- new.env(parent = globalenv())
    #' dist_exponential()$export_functions(
    #'   name = "exp",
    #'   envir = tmp_env,
    #'   with_params = list(rate = 2.0)
    #' )
    #' evalq(
    #'   fitdistrplus::fitdist(rexp(100), "exp"),
    #'   envir = tmp_env
    #' )
    export_functions = function(name, envir = parent.frame(),
                                with_params = list()) {
      params <- self$get_params(with_params = with_params)
      params_flat <- tryCatch(
        flatten_params(params),
        # If flattening is impossible, allow no arguments
        error = function(e) numeric()
      )
      make_fun <- function(prefix, first_param, func_name,
                           general_params = list()) {
        if (length(params_flat)) {
          # Build the call `reservr::inflate_params(c(a = a, b = b, ...))`,
          # which re-assembles the flat formal arguments of the exported
          # function into the nested `with_params` structure.
          param_syms <- lapply(names(params_flat), as.name)
          names(param_syms) <- names(params_flat)
          param_list <- as.call(list(
            quote(reservr::inflate_params),
            as.call(c(list(as.name("c")), param_syms))
          ))
        } else {
          param_list <- list()
        }
        ffmls <- c(alist(x = ), params_flat, general_params) # nolint lintr/#532
        names(ffmls)[1L] <- first_param
        general_params <- c(list(x = as.name(first_param)), general_params)
        names(general_params)[1L] <- first_param
        # Can't use bquote(..., splice = TRUE) for backward compatibility with
        # R < 4.0. Instead we construct the spliced call manually.
        spliced_call <- as.call(c(
          call("$", quote(self), as.name(func_name)),
          general_params,
          alist(with_params = params)
        ))
        # The NaN fallback must have the length of the *argument* passed to
        # the exported function (e.g. `x`), not of the string `first_param`,
        # so a call to length() on that argument's symbol is spliced in.
        fallback_len <- call("length", as.name(first_param))
        fbody <- bquote({
          params <- .(param_list)
          tryCatch(
            .(spliced_call),
            error = function(e) {
              warning(
                "Error during evaluation; returning NaNs.\n",
                conditionMessage(e)
              )
              rep_len(NaN, .(fallback_len))
            }
          )
        })
        fun <- as.function(c(ffmls, fbody))
        fun_name <- paste0(prefix, name)
        message("Exported `", fun_name, "()`.")
        assign(fun_name, fun, envir = envir)
      }
      if (self$has_capability("density")) {
        make_fun("d", "x", "density", list(log = FALSE))
      }
      if (self$has_capability("sample")) {
        make_fun("r", "n", "sample")
      }
      if (self$has_capability("probability")) {
        make_fun("p", "q", "probability",
                 list(lower.tail = TRUE, log.p = FALSE))
      }
      if (self$has_capability("quantile")) {
        make_fun("q", "p", "quantile",
                 list(lower.tail = TRUE, log.p = FALSE))
      }
      invisible(NULL)
    }
  ),
  private = list(
    # Merge `with_params` into the default parameters for `n` evaluation
    # points.
    # 1. list(...) => list() (recursively, recycled to the prototype length)
    # 2. numeric => numeric (vectors recycled to length `n`)
    # 3. distribution => list(dist = distribution, params = params)
    .make_params = function(with_params, n) {
      make_params_elem <- function(proto, value) {
        if (is.list(proto)) {
          if (is.null(value)) {
            value <- vector("list", length(proto))
          } else if (length(value) != length(proto)) {
            value <- rep_len(value, length(proto))
          }
          mapply(make_params_elem, proto, value, SIMPLIFY = FALSE)
        } else if (is.Distribution(proto)) {
          if (is.list(value) &&
                setequal(names(value), c("dist", "params")) &&
                is.Distribution(value[["dist"]])) {
            # Prevent re-wrapping when called twice e.g. via hazard()
            value
          } else {
            list(
              dist = proto,
              params = value
            )
          }
        } else {
          if (!is.null(value)) {
            if (is.vector(value)) {
              rep_len(value, length.out = n)
            } else {
              value
            }
          } else {
            if (is.vector(proto)) {
              rep_len(proto, length.out = n)
            } else {
              proto
            }
          }
        }
      }
      res <- lapply(names(private$.params), function(nm) {
        make_params_elem(private$.default_params[[nm]], with_params[[nm]])
      })
      names(res) <- names(private$.params)
      res
    },
    .caps = character(),
    .type = character(),
    .params = list(),
    .default_params = list(),
    .name = "",
    # Cache of compiled TensorFlow functions, keyed by floatx and then by name.
    .tf_functions = list(
      nograph = list(),
      float32 = list(),
      float64 = list()
    ),
    .tf_retrieve_or_call = function(name, impl) {
      check_installed(c("keras3", "tensorflow"))
      cache_key <- keras3::config_floatx()
      res <- private$.tf_functions[[cache_key]][[name]]
      use_cache <- getOption("reservr.cache_tf_function", default = TRUE)
      # Rebuild on cache miss, on a stale (NULL) Python pointer, or when
      # caching is disabled via the option above.
      if (is.null(res) ||
            (cache_key != "nograph" && reticulate::py_is_null_xptr(res)) ||
            !use_cache) {
        fn <- impl()
        assign("tf", tensorflow::tf, environment(fn))
        res <- maybe_tf_function(fn)
        private$.tf_functions[[cache_key]][[name]] <- res
      }
      res
    }
  ),
  active = list(
    #' @field default_params Get or set (non-recursive) default parameters of a
    #' Distribution
    default_params = function(value) {
      if (missing(value)) {
        private$.default_params
      } else {
        assert_that(
          is.list(value),
          setequal(names(private$.default_params), names(value)),
          msg = paste0(
            "`default_params` must be a list with names ",
            enumerate_strings(names(private$.default_params), quote = "'")
          )
        )
        # Perform bounds checking
        bounds <- private$.params
        for (param in names(private$.default_params)) {
          curr_bounds <- bounds[[param]]
          if (is.Interval(curr_bounds)) {
            assert_that(
              is.null(value[[param]]) ||
                (is.numeric(value[[param]]) &&
                   all(curr_bounds$contains(value[[param]]))),
              msg = sprintf(
                "`default_params$%s` must be numeric in %s, or NULL.",
                param, curr_bounds
              )
            )
          } else if (is.list(curr_bounds) &&
                       length(curr_bounds) &&
                       is.Interval(curr_bounds[[1L]])) {
            # A list parameter whose elements all share the first Interval.
            assert_that(
              is.list(value[[param]]),
              all(vapply(
                value[[param]],
                function(val) {
                  is.null(val) ||
                    (is.numeric(val) && all(curr_bounds[[1L]]$contains(val)))
                },
                logical(1L)
              )),
              msg = sprintf(
                paste(
                  "`default_params$%s` must be a list of",
                  "numeric values in %s, or NULLs"
                ),
                param,
                curr_bounds[[1L]]
              )
            )
          }
        }
        # New defaults are assumed valid -> accept.
        # Also ensure the order of elements stays the same just in case.
        private$.default_params <- value[names(private$.default_params)]
      }
    },
    #' @field param_bounds Get or set (non-recursive) parameter bounds
    #' (box constraints) of a Distribution
    param_bounds = function(value) {
      if (missing(value)) {
        private$.params
      } else {
        assert_that(
          is.list(value),
          setequal(names(private$.params), names(value)),
          msg = paste0("`param_bounds` must be a list with names ",
                       enumerate_strings(names(private$.params), quote = "'"))
        )
        # Check type safety
        bounds <- private$.params
        # Recursively verify that `new` has the same shape as `old` and return
        # `new` with names/ordering aligned to `old`.
        check_types <- function(old, new, path = "") {
          if (is.list(old) &&
                length(old) == 1L &&
                is.null(names(old)) &&
                is.Interval(old[[1L]])) {
            expected <- "a list containing an Interval"
          } else if (is.list(old) && !length(old)) {
            expected <- "an empty list"
          } else if (is.list(old)) {
            expected <- paste0("a list with ", length(old), " elements")
          } else if (is.Interval(old)) {
            expected <- "an Interval"
          } else {
            # Unsupported current bound: no replacement is accepted below, so
            # make sure the error message can still be built.
            expected <- paste(
              "a valid bound, but the current bound has unsupported class",
              class(old)[1L]
            )
          }
          ok <- if (is.list(old)) {
            if (length(old) == 1L &&
                  is.null(names(old)) &&
                  is.Interval(old[[1L]])) {
              length(new) == 1L &&
                is.null(names(new)) &&
                is.Interval(new[[1L]])
            } else {
              is.list(new) && length(new) == length(old)
            }
          } else if (is.Interval(old)) {
            is.Interval(new)
          } else {
            FALSE
          }
          if (!ok) {
            stop(
              "param_bounds", path, " must be ", expected, ".",
              call. = FALSE
            )
          }
          if (is.list(old) &&
                length(old) &&
                !is.Interval(old[[1L]])) {
            sub_names <- if (!is.null(names(old))) {
              paste(path, names(old), sep = "$")
            } else {
              paste0(path, "[[", seq_along(old), "]]")
            }
            if (!is.null(names(old))) {
              if (!setequal(names(old), names(new))) {
                stop(
                  "param_bounds", path, " must be a list with names ",
                  enumerate_strings(names(old), quote = "'"), ".",
                  call. = FALSE
                )
              }
              new <- new[names(old)]
            } else {
              names(new) <- NULL
            }
            new <- mapply(
              check_types,
              old = old,
              new = new,
              path = sub_names,
              SIMPLIFY = FALSE
            )
          }
          # With fixed names / ordering
          new
        }
        private$.params <- check_types(bounds, value)
      }
    }
  )
)
# nocov start
# Internal factory creating a concrete Distribution R6 class.
#
# The generated class is named `<Name>Distribution`, where `<Name>` is `name`
# with its first letter upper-cased, and it inherits from `Distribution`.
# Its public methods (`sample()`, `density()`, ...) check capabilities,
# normalise parameters with `private$.make_params()`, return early for empty
# input and then delegate to the `*_impl` functions given here. Those are
# stored as private members of the generated class.
#
# Capabilities are derived from which implementation arguments were
# explicitly supplied (tested with `missing()`), not from their values.
# `hazard` and `is_discrete` have defaults and are not listed as capabilities.
#
# `support` may be an Interval, which is wrapped into a `contains()` check,
# or a function `function(x, params)`.
#
# Extra public members (e.g. `compile_*` methods) can be passed via `...`,
# and active bindings via `active`.
#
# NOTE(review): `initialize()` passes `name` (not the CamelCase class name)
# on to `Distribution$initialize()`. Error messages from
# `require_capability()` therefore show `name`; confirm this is intended.
distribution_class <- function(
name,
type = c("continuous", "discrete"),
params = list(),
sample = NULL,
density = NULL,
probability = NULL,
quantile = NULL,
diff_density = NULL,
diff_probability = NULL,
# Default hazard: density / survival function (difference on log scale).
hazard = function(x, log, params) {
if (log) {
self$density(x, log = TRUE, with_params = params) -
self$probability(x, lower.tail = FALSE, log.p = TRUE,
with_params = params)
} else {
self$density(x, with_params = params) /
self$probability(x, lower.tail = FALSE, with_params = params)
}
},
support = I_REALS,
# Default point-mass check: for discrete types every supported point has
# mass; otherwise no point has mass.
is_discrete = function(x, params) {
if (self$is_discrete()) {
self$is_in_support(x, params)
} else {
logical(length(x))
}
},
tf_logdensity = NULL,
tf_logprobability = NULL,
tf_is_discrete_at = NULL,
...,
active = list()
) {
clsname <- paste0(
toupper(substr(name, 1, 1)),
substr(name, 2, nchar(name)),
"Distribution"
)
type <- match.arg(type)
# Only implementations that were explicitly passed become capabilities.
caps <- c(
character(),
if (!missing(sample)) "sample" else NULL,
if (!missing(density)) "density" else NULL,
if (!missing(probability)) "probability" else NULL,
if (!missing(quantile)) "quantile" else NULL,
if (!missing(diff_density)) "diff_density" else NULL,
if (!missing(diff_probability)) "diff_probability" else NULL,
if (!missing(tf_logdensity)) "tf_logdensity" else NULL,
if (!missing(tf_logprobability)) "tf_logprobability" else NULL
)
# Normalise an Interval support into a function(x, params).
if (is.Interval(support)) {
support_interval <- support
support <- function(x, params) {
support_interval$contains(x)
}
}
R6Class(
classname = clsname,
inherit = Distribution,
public = list(
# `type`, `caps`, `params` and `name` are looked up from this factory's
# frame; presumably this relies on R6's default `parent_env`. Confirm
# before renaming these locals.
initialize = function(...) {
super$initialize(
type = type,
caps = caps,
params = params,
name = name,
default_params = list(...)
)
},
sample = function(n, with_params = list()) {
self$require_capability("sample")
params <- private$.make_params(with_params, n)
if (n == 0L) return(numeric())
private$.sample_impl(n = n, params = params)
},
density = function(x, log = FALSE, with_params = list()) {
self$require_capability("density")
params <- private$.make_params(with_params, length(x))
if (!length(x)) return(numeric())
private$.density_impl(x = x, log = log, params = params)
},
probability = function(q, lower.tail = TRUE, log.p = FALSE,
with_params = list()) {
self$require_capability("probability")
params <- private$.make_params(with_params, length(q))
if (!length(q)) return(numeric())
private$.probability_impl(q = q, lower.tail = lower.tail, log.p = log.p,
params = params)
},
quantile = function(p, lower.tail = TRUE, log.p = FALSE,
with_params = list()) {
self$require_capability("quantile")
params <- private$.make_params(with_params, length(p))
if (!length(p)) return(numeric())
private$.quantile_impl(p = p, lower.tail = lower.tail, log.p = log.p,
params = params)
},
hazard = function(x, log = FALSE, with_params = list()) {
self$require_capability(c("density", "probability"))
params <- private$.make_params(with_params, length(x))
if (!length(x)) return(numeric())
private$.hazard_impl(x = x, log = log, params = params)
},
# Gradients are taken with respect to the current placeholders (free
# parameters); empty input yields an empty derivative structure.
diff_density = function(x, log = FALSE, with_params = list()) {
self$require_capability("diff_density")
params <- private$.make_params(with_params, length(x))
if (!length(x)) return(empty_derivative(self$get_placeholders()))
private$.diff_density_impl(x = x, vars = self$get_placeholders(),
log = log, params = params)
},
diff_probability = function(q, lower.tail = TRUE, log.p = FALSE,
with_params = list()) {
self$require_capability("diff_probability")
params <- private$.make_params(with_params, length(q))
if (!length(q)) return(empty_derivative(self$get_placeholders()))
private$.diff_probability_impl(q = q, vars = self$get_placeholders(),
lower.tail = lower.tail, log.p = log.p,
params = params)
},
is_in_support = function(x, with_params = list()) {
params <- private$.make_params(with_params, length(x))
if (!length(x)) return(logical())
private$.support_impl(x = x, params = params)
},
is_discrete_at = function(x, with_params = list()) {
params <- private$.make_params(with_params, length(x))
if (!length(x)) return(logical())
private$.is_discrete_impl(x = x, params = params)
},
# The tf_* methods fall back to the base-class stubs (which stop()) when
# no implementation was supplied; otherwise the compiled function is
# cached via private$.tf_retrieve_or_call().
tf_logdensity = function() {
if (is.null(private$.tf_logdensity_impl)) {
super$tf_logdensity()
} else {
private$.tf_retrieve_or_call(
"tf_logdensity",
private$.tf_logdensity_impl
)
}
},
tf_logprobability = function() {
if (is.null(private$.tf_logprobability_impl)) {
super$tf_logprobability()
} else {
private$.tf_retrieve_or_call(
"tf_logprobability",
private$.tf_logprobability_impl
)
}
},
tf_is_discrete_at = function() {
if (is.null(private$.tf_is_discrete_at_impl)) {
super$tf_is_discrete_at()
} else {
private$.tf_retrieve_or_call(
"tf_is_discrete_at",
private$.tf_is_discrete_at_impl
)
}
},
...
),
# Implementations supplied to the factory; NULL where not provided.
private = list(
.sample_impl = sample,
.density_impl = density,
.probability_impl = probability,
.quantile_impl = quantile,
.hazard_impl = hazard,
.diff_density_impl = diff_density,
.diff_probability_impl = diff_probability,
.support_impl = support,
.is_discrete_impl = is_discrete,
.tf_logdensity_impl = tf_logdensity,
.tf_logprobability_impl = tf_logprobability,
.tf_is_discrete_at_impl = tf_is_discrete_at
),
active = active
)
}
# Internal convenience factory that turns a base-R style function family
# (`r<fun_name>`, `d<fun_name>`, `p<fun_name>`, `q<fun_name>`, looked up as
# functions in `envir`) into a Distribution class via `distribution_class()`.
#
# Parameters are forwarded to those functions by name with `do.call()`.
# The `compile_*` methods are added as extra public members through the
# `...` of `distribution_class()`.
distribution_class_simple <- function(name,
fun_name,
type = c("continuous", "discrete"),
params = list(),
support = I_REALS,
envir = parent.frame(),
...) {
type <- match.arg(type)
sample_fun <- get(paste0("r", fun_name), mode = "function", envir = envir)
density_fun <- get(paste0("d", fun_name), mode = "function", envir = envir)
probability_fun <- get(
paste0("p", fun_name), mode = "function", envir = envir
)
quantile_fun <- get(paste0("q", fun_name), mode = "function", envir = envir)
# substitute() replaces the symbols `sample_fun`, `density_fun`,
# `probability_fun` and `quantile_fun` inside the expression below with the
# function objects looked up above. The generated method bodies thus embed
# those functions directly instead of referencing them by name.
eval(substitute(
distribution_class(
name = name,
type = type,
params = params,
sample = function(n, params) {
do.call(sample_fun, c(list(n = n), params))
},
density = function(x, log, params) {
do.call(density_fun, c(list(x = x, log = log), params))
},
probability = function(q, lower.tail, log.p, params) {
do.call(probability_fun,
c(list(q = q, lower.tail = lower.tail, log.p = log.p), params))
},
quantile = function(p, lower.tail, log.p, params) {
do.call(quantile_fun,
c(list(p = p, lower.tail = lower.tail, log.p = log.p), params))
},
support = support,
compile_sample = function() {
compile_simple_function(sample_fun, self)
},
compile_density = function() {
compile_simple_function(density_fun, self)
},
compile_probability = function() {
compile_simple_function(probability_fun, self)
},
compile_quantile = function() {
compile_simple_function(quantile_fun, self)
},
# Interval probabilities need the density for discrete families
# (presumably to add point mass at the lower bound; see helper).
compile_probability_interval = function() {
if (self$is_continuous()) {
compile_simple_prob_continuous(probability_fun, self)
} else { # => self$is_discrete()
compile_simple_prob_discrete(probability_fun, density_fun, self)
}
},
...
),
env = list(
sample_fun = sample_fun,
density_fun = density_fun,
probability_fun = probability_fun,
quantile_fun = quantile_fun
)
))
}
# nocov end
#' Test if object is a Distribution
#'
#' Checks whether an object is a Distribution, i.e. whether its class vector
#' contains `"Distribution"`.
#'
#' @param object An R object.
#'
#' @return `TRUE` if `object` is a Distribution, `FALSE` otherwise.
#'
#' @examples
#' is.Distribution(dist_dirac())
#'
#' @export
is.Distribution <- function(object) {
  # Plain S3 class check: matches any object whose class vector includes
  # "Distribution", regardless of its position.
  inherits(x = object, what = "Distribution", which = FALSE)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.