Nothing
# ---- general utils ----
is_backend <- function(name) {
identical(keras$config$backend(), name)
}
is_windows <- function() {
identical(.Platform$OS.type, "windows")
}
is_osx <- function() {
Sys.info()["sysname"] == "Darwin"
}
is_mac_arm64 <- function() {
sys_info <- Sys.info()
sys_info[["sysname"]] == "Darwin" &&
sys_info[["machine"]] == "arm64"
}
is_scalar <- function(x) identical(length(x), 1L)
# is_py_object <- function(x) is_py_object(x)
split_dots_named_unnamed <- function(dots) {
nms <- names(dots)
if (is.null(nms))
return(list(unnamed = dots, named = list()))
named <- nzchar(nms)
list(unnamed = dots[!named], named = dots[named])
}
drop_nulls <- function(x, i = NULL) {
if(is.null(i))
return(x[!vapply(x, is.null, FALSE, USE.NAMES = FALSE)])
drop <- logical(length(x))
names(drop) <- names(x)
drop[i] <- vapply(x[i], is.null, FALSE, USE.NAMES = FALSE)
x[!drop]
}
#' @importFrom rlang dots_list
# identical to rlang::list2(), except .named = TRUE
named_list <- function(...)
dots_list(...,
.named = TRUE,
# not the default
.ignore_empty = "trailing",
.preserve_empty = FALSE,
.homonyms = "error",
.check_assign = FALSE)
`append1<-` <- function(x, value) {
x[[length(x) + 1L]] <- value
x
}
`append<-` <- function(x, value) c(x, value)
`prepend<-` <- function(x, value) c(value, x) # c(x[integer()], value, x)
replace_val <- function(x, old, new) {
if (!is_scalar(new))
stop("Unexpected length of replacement value in replace_val().\n",
"`new` must be length 1, not ", length(new))
x[x %in% old] <- new
x
}
imap <- function(.x, .f, ...) {
out <- .mapply(.f, list(.x, names(.x) %||% seq_along(.x)), list(...))
names(out) <- names(.x)
out
}
map2 <- function(.x, .y, .f, ...) {
out <- .mapply(.f, list(.x, .y), list(...))
if(length(.x) == length(out))
names(out) <- names(.x)
out
}
map_chr <- function(.x, .f, ...) {
out <- vapply(X = .x, FUN = .f, FUN.VALUE = "", ..., USE.NAMES = FALSE)
names(out) <- names(.x)
out
}
map_lgl <- function(.x, .f, ...) {
out <- vapply(X = .x, FUN = .f, FUN.VALUE = TRUE, ..., USE.NAMES = FALSE)
names(out) <- names(.x)
out
}
map_int <- function(.x, .f, ...) {
out <- vapply(X = .x, FUN = .f, FUN.VALUE = 0L, ..., USE.NAMES = FALSE)
names(out) <- names(.x)
out
}
last <- function(x) x[[length(x)]]
second_to_last <- function(x)
if((lx <- length(x)) > 1) x[[lx-1L]]
rename <- function(x, ..., .skip_existing = TRUE) {
dots <- list(...)
nms <- names(x)
for(i in seq_along(dots)) {
newname <- names(dots)[[i]]
oldname <- dots[[i]]
if(.skip_existing && newname %in% nms)
next
nms[match(oldname, nms)] <- newname
}
names(x) <- nms
x
}
# arg_to_positional <- function(args, name, pos = 1L) {
# idx <- match(name, names(args))
# if(is.na(idx))
# return(args)
# arg <- args[idx]
# args <- args[-idx]
#
# head <- args[seq_len(pos-1)]
# tail <- args[seq.int(from = pos, to = length(args))]
# if((length(head)+1) != pos)
# stop("Not enough argument supplied")
# names(arg) <- NULL
# c(head, arg, tail)
# }
args_to_positional <- function(args, nms) {
idxs <- match(nms, names(args))
if(all(is.na(idxs)))
return(args)
if(any(is.na(idxs)))
stop("Required positional arguments not supplied: ",
paste0(nms, collapse = ", "))
positional_args <- args[idxs]
names(positional_args) <- NULL
named_args <- args[-idxs]
c(positional_args, named_args)
}
`%""%` <- function (x, y) {
if(!is.character(x))
stop("x must be character")
not_empty <- nzchar(x)
if(all(not_empty))
return(x)
if(!is.character(y))
stop("y must be character")
# don't force `y` unless needed
if (!identical(length(y), length(x))) {
stopifnot(identical(length(y), 1L))
y <- rep(y, length(x))
}
empty <- !not_empty
x[empty] <- y[empty]
x
}
# ---- arg checkers ----
check_bool <- function(x) {
if (identical(x, TRUE) || identical(x, FALSE))
x
else
stop(sprintf("`%s` arg must be `TRUE` or `FALSE`",
deparse1(substitute(x))))
}
# ---- arg transformers ----
as_array <- function(x)
if(is.null(x) || is_py_object(x) || is.array(x))
x else as.array(x)
as_py_array <- function(x)
if(is.null(x) || is_py_object(x))
x else np_array(x)
as_r_value <- function (x)
if (is_py_object(x))
py_to_r(x) else x
as_axis <- function(axis) {
if (is.null(axis))
return(NULL)
if (length(axis) > 1)
return(lapply(axis, as_axis))
axis <- as.integer(axis)
if (axis == 0L)
stop("`axis` argument is 1 based, received 0")
if (axis > 0L) axis - 1L
else axis
}
# Helper function to coerce shape arguments to tuple
# tf$reshape()/k_reshape() doesn't accept a tf.TensorShape object
normalize_shape <- function(shape) {
# reflect NULL back
if (is.null(shape))
return(shape)
# already fixed up
if (inherits(shape, "keras_shape"))
return(shape)
# if it's a list or a numeric vector then convert to integer
# NA's in are accepted as NULL
# also accept c(NA), as if it was a numeric
if (is.list(shape) || is.numeric(shape) ||
(is.logical(shape) && all(is.na(shape)))) {
shape <- lapply(shape, function(value) {
# Pass through python objects unmodified, only coerce R objects
# supplied shapes, e.g., to tf$random$normal, can be a list that's a mix
# of scalar integer tensors and regular integers
if (is_py_object(value))
return(value)
# accept NA,NA_integer_,NA_real_ as NULL
if ((is_scalar(value) && is.na(value)))
return(NULL)
if (!is.null(value))
as.integer(value)
else
NULL
})
}
if (inherits(shape, "tensorflow.python.framework.tensor_shape.TensorShape"))
shape <- as.list(shape$as_list()) # unpack for tuple()
# coerce to tuple so it's iterable
tuple(shape)
}
as_integer <- function(x) {
if (is.numeric(x))
as.integer(x)
else
x
}
as_integer_array <- function(x) {
if(is.atomic(x))
x <- as.array(x)
if(is.array(x) && storage.mode(x) != "integer")
storage.mode(x) <- "integer"
x
}
as_integer_tuple <- function(x, force_tuple = FALSE) {
if (is.null(x) || is_py_object(x))
x
else if (is.list(x) || force_tuple)
tuple(as.list(as.integer(x)))
else
as.integer(x)
}
as_scalar_or_tuple <- function(x) {
if(is_py_object(x))
return(x)
if(is.atomic(x) && is_scalar(x))
return(x)
tuple(as.list(x))
}
as_tuple <- function(x) {
if(is_py_object(x))
return(x)
tuple(as.list(x))
}
as_nullable_integer <- function(x) {
if (is.null(x))
x
else
as.integer(x)
}
as_layer_index <- function(x) {
if (is.null(x))
return(x)
x <- as.integer(x)
if (x == 0L)
stop("`index` for get_layer() is 1-based (0 was passed as the index)")
if (x > 0L)
x - 1L
else
x
}
as_node_index <- function(node_index) {
as.integer(node_index-1)
}
# Helper function to normalize paths
normalize_path <- function(path) {
if (is.null(path))
NULL
else
normalizePath(path.expand(path), mustWork = FALSE)
}
# unused
as_index <- function(x) {
if(storage.mode(x) == "double")
storage.mode(x) <- "integer"
# k_array() pass through here...
# TODO: implement an efficient way to check for negative slices
x - 1L
}
# Sketch for an alternative approach to offsetting indexes,
# so that they are 1 based in the R runtime, but convert into python
# as 0 based. Alternative implementaiton for Callback() epochs,
# LearningRateSchedule(), and similar.
#
# as_r_index <- function(x) {
# if(is.double(x))
# x <- as.integer(x)
# class(x) <- c("r_index", class(x))
# x
# }
#
# r_to_py.r_index <- function(x) {
# if (x > 0L) x - 1L else x
# }
#
# zero_to_one_index <- function(x) x + 1L
# ---- resolve_py_obj ----
resolve_wrapper_py_obj_expr <- function(x, prefer_class = TRUE) {
if (!identical(class(x), "function"))
return()
ns <- environment(sys.function()) # keras3 namespace
xe <- environment(x)
if (identical(xe, emptyenv()))
return()
# only inspect pkg functions, or pkg wrapped functions
## is a wrapper returned by new_wrapped_py_class(), like Layer()
if (identical(parent.env(xe), ns))
return(quote(`__class__`))
## is a pkg exported function
if (!(identical(xe, ns)))
return()
# standard builtin wrapper, like layer_dense, loss_*
# (or Layer(), though that's handled above)
last_cl <- last(body(x))
if (is.call(last_cl) &&
(identical(last_cl[[1L]], quote(do.call)) ||
identical(last_cl[[1L]], quote(create_layer)))) {
expr <- last_cl[[2L]]
if (identical(expr, quote(callable))) {
# loss_ or metric_ wrapper
if (prefer_class)
expr <- second_to_last(body(x))[[c(3, 3)]]
else
expr <- second_to_last(body(x))[[c(3, 4)]]
}
return(expr)
}
# application wrapper
if (is.call(last_cl) &&
identical(last_cl[[1L]], quote(set_preprocessing_attributes)) &&
is.call(last_cl2 <- as.list(body(x))[[length(body(x)) - 1L]]) &&
(identical(last_cl2[[c(3L, 1L)]], quote(do.call))))
return(last_cl2[[c(3L, 2L)]])
# bare builtin op_wrapper, like
# op_add <- function(x1, x2) keras$ops$add(x1, x2)
if (is.call(cl <- body(x)) &&
(is.call(cl0 <- cl1 <- cl[[1L]]) ||
(
identical(cl0, quote(`{`)) &&
length(cl1 <- as.list(cl[-1])) == 1 &&
is.call(cl <- cl1[[1L]]) &&
is.call(cl0 <- cl1 <- cl[[1L]])
)))
{
while (is.call(cl0) && identical(cl0[[1L]], quote(`$`)))
cl0 <- cl0[[2L]]
if (identical(cl0, quote(keras)))
return(cl1)
}
NULL
}
resolve_py_obj <- function(x, default_name = "anonymous_R_function",
env = asNamespace("keras3"),
prefer_class = TRUE,
convert = TRUE) {
# this function is used:
# - to resolve `inherit` args in the keras subclassing API
# (e.g., if `inherit` arg is a wrapper like `layer_dense`, or
# `layer_custom` returned by a Layer("Custom", ...))
# - to resolve args that can come in as callables to `compile()`
# (e.g., loss, metrics)
# - to resolve args that can come in as callables passed to layer_* constructors.
# (e.g., activations, initializers)
# - to resolve custom_objects supplied to the saving & serialization API,
# (e.g., with_custom_object_scope(), load_model(), ...)
# - `x` can come in as a language object, enabling lazy evaluation /
# delayed initialization python
# - If `x` is a package exported wrapper, like `layer_dense` or similar,
# this will return the py callable object, like `keras$layers$Dense`
# This should work with *all* exported wrappers
# (loss_, activation_, layer_, op_*, etc.)
# - Otherwise, If `x` is a bare R function, it will be coerced to
# a python function with `py_func2()`, which is similar to the default
# r_to_py(<func>) except:
# - the constructed python wrapper has an accurate signature that
# matches the R func (needed in some places where keras inspects the
# callable signature)
# - We work harder/better to resolve an appropriate __name__ (accepting
# R attributes "name", "__name__" and "py_function_name", and give an
# opportunity for us to provide a better default like "custom_metric"
# from methods like `compile()`)
# - Otherwise, we return `x` unmodified (assuming it will be coerced via
# r_to_py() downstream). If `convert = FALSE`, we eagerly call `r_to_py(x)`.
if (is.language(x))
x <- eval(x, env)
if (is.null(x) || is_py_object(x))
return(x)
if (is_bare_r_function(x)) {
py_obj_expr <- resolve_wrapper_py_obj_expr(x, prefer_class = prefer_class)
if (!is.null(py_obj_expr)) {
# eval in environment(x): wrapper env, where we might find `__class__`.
py_obj <- tryCatch(eval(py_obj_expr, environment(x)),
error = function(e) NULL)
if (is_py_object(py_obj))
return(py_obj)
}
return(as_py_function(x, default_name = default_name))
}
if (convert) x else r_to_py(x)
}
is_bare_r_function <- function(x) {
identical(class(x), "function")
}
as_py_name <- function(x) {
# sanitize a deparsed R expression into valid python symbol string
if(is.language(x))
x <- deparse(x, width.cutoff = 500L)[1]
x <- make.names(as.character(x))
x <- gsub(".", "_", x, fixed = TRUE)
x
}
as_py_function <- function(fn, default_name = "r_func") {
if(is_py_object(fn))
return(fn)
name <-
attr(fn, "py_function_name", TRUE) %||%
attr(fn, "__name__", TRUE) %||%
attr(fn, "name", TRUE) %||%
default_name
py_func2(fn, convert = TRUE, name = name)
}
get_function_name <- function(fn) {
if (is_py_object(fn))
return(py_to_r(py_get_attr(fn, "__name__")))
attr(fn, "py_function_name", TRUE) %||%
attr(fn, "__name__", TRUE) %||%
attr(fn, "name", TRUE)
}
# if(FALSE) {
# # TODO: use this to generate a static list for populating
# # a reverse lookup hashtable
# x <- lapply(asNamespace("keras3"), resolve_wrapper_py_obj_expr) |>
# purrr::map_chr(\(expr) if(is.null(expr)) "" else deparse1(expr))
# df <- tibble::enframe(x, value = "expr")
# df <- df[order(df$name),]
# success <- df$expr != ""
#
#
# df[success, ] |> print(n = Inf)
# df[!success, ] |> print(n = Inf)
#
# # prefer_class = FALSE
# x <- lapply(asNamespace("keras3"), resolve_wrapper_py_obj_expr,
# prefer_class = FALSE) |>
# purrr::map_chr(\(expr) if(is.null(expr)) "" else deparse1(expr))
# df <- tibble::enframe(x, value = "expr")
# df <- df[order(df$name),]
# success <- df$expr != ""
# df[success, ] |> print(n = Inf)
# df[!success, ] |> print(n = Inf)a
# }
# as_activation <- NULL
# on_load_make_as_activation <- function() {
# if (getRversion() < "4.2") {
# as_activation <<- .as_activation
# } else {
# as_activation <<- local({
# # make a hashtab to do reverse look ups, converting exported closures like
# # `activation_elu` to a builtin activation name string "elu". The
# # motivation is to avoid needlessly popping out to an R closure if we're
# # using a bultin. We have to do this at runtime since the hastab
# # needs the closure object address.
# delayedAssign("h", local({
# nms <- grep("^activation_", getNamespaceExports("keras3"), value = TRUE)
# h <- utils::hashtab("address", length(nms))
# ns <- asNamespace("keras3")
# for (name in nms)
# utils::sethash(h, getExportedValue(ns, name),
# substr(name, 12L, 999L))
# h
# }))
#
# function(x) utils::gethash(h, x) %||% .as_activation(x)
# })
# }
# }
#
# .as_activation <- function(x) {
# if (is.null(x) || is_py_object(x))
# return(x)
#
# name <- attr(x, "py_function_name", TRUE)
# if (is_string(name) && identical(x, get0(
# paste0("activation_", name),
# envir = environment(sys.function()),
# inherits = FALSE
# )))
# # it's a builtin; the name string will be resolved upstream via
# # keras.activations.get(name)
# return(name)
#
# if (is.function(x))
# return(as_py_function(x, default_name = "custom_activation"))
# x
# }
#
# ---- capture_args ----
# capture_args_v1 <-
function(cl, modifiers = NULL, ignore = NULL,
envir = parent.frame(), fn = sys.function(-1)) {
## bug: match.call() resolves incorrectly if dots are from not the default sys.parent()
## e.g, this fails if dots originate from the callers caller:
# cl <- eval(quote(match.call()), parent.frame())
## workaround: caller must call match.call() from the correct frame
## note: capture_args_v1() must always be called at the top level of the intended function body.
## sys.function(-1) resolves to the incorrect function if the capture_args()
## call is itself a promise in another call. E.g.,:
## do.call(foo, capture_args_v1(match.call())) fails because fn resolves to do.call()
fn_arg_nms <- names(formals(fn))
known_args <- intersect(names(cl), fn_arg_nms)
known_args <- setdiff(known_args, ignore)
names(known_args) <- known_args
cl2 <- c(quote(list), lapply(known_args, as.symbol))
if("..." %in% fn_arg_nms && !"..." %in% ignore) {
assert_all_dots_named(envir, cl)
# this might reorder args by assuming ... are last, but it doesn't matter
# since everything is supplied as a keyword arg to the Python side anyway
cl2 <- c(cl2, quote(...))
}
args <- eval(as.call(cl2), envir)
# check `ignore` again, since arg might have been in `...`
for(nm in intersect(names(args), ignore))
args[[nm]] <- NULL
nms_to_modify <- intersect(names(args), names(modifiers))
for (nm in nms_to_modify)
args[nm] <- list(modifiers[[nm]](args[[nm]]))
# list() so if modifier returns NULL, don't remove the arg
args
}
#' @importFrom rlang list2
capture_args <- function(modifiers = NULL, ignore = NULL, force = NULL,
enforce_all_dots_named = TRUE) {
call <- sys.call(-1L)
envir <- parent.frame(1L)
fn <- sys.function(-1L)
# if("capture_args" %in% all.names(call, unique = TRUE))
# stop("incorrect usage of capture_args(), must be evaluated as ",
# "a standard expression, not as not a promise (i.e., not as part ",
# "of a call of another function")
# match.call() automatically omits missing() args in the returned call. These
# user calls all standardize to the same thing:
# - layer_dense(, 10)
# - layer_dense(object = , 10)
# - layer_dense(object = , 10, )
# - layer_dense(, 10, )
# all standardize to:
# - layer_dense(units = 10)
call <- match.call(fn, call, expand.dots = TRUE, envir = parent.frame(2))
# message("call: ", deparse1(call))
fn_arg_nms <- names(formals(fn))
known_args <- intersect(names(call), fn_arg_nms)
if (length(ignore) && !is.character(ignore)) {
# e.g., ignore = c("object", \(nms) startsWith(nms, "."))
ignore <- as.character(unlist(lapply(ignore, function(ig) {
if (is.character(ig)) return(ig)
stopifnot(is.function(ig))
ig <- ig(known_args) # ignore fn can return either lgl or int for [
if (!is.character(ig))
ig <- known_args[ig]
ig
}), use.names = FALSE))
}
known_args <- setdiff(known_args, ignore)
known_args <- union(known_args, force)
names(known_args) <- known_args
if ("..." %in% fn_arg_nms && !"..." %in% ignore) {
if (enforce_all_dots_named)
assert_all_dots_named(envir, call)
# match.call already drops missing args that match to known args, but it
# doesn't protect from missing args that matched into ...
# use list2() to allow dropping a trailing missing arg in ... also
dots <- quote(...)
list_sym <- quote(list2)
} else {
dots <- NULL
list_sym <- quote(list)
}
# this might reorder args by assuming ... are last, but it doesn't matter
# since everything is supplied as a keyword arg to the Python side anyway
call <- as.call(c(list_sym, lapply(known_args, as.symbol), dots))
args <- eval(call, envir)
# filter out ignore again, in case any were in ...
# we could probably enhance the `call` constructed above to use, e.g.,
# ..1, ..2, ..4, to skip ignores, and avoid forcing them.
if (length(ignores_in_dots <- intersect(names(call), ignore)))
args[ignores_in_dots] <- NULL
# apply modifier functions. e.g., as_nullable_integer()
if (length(names_to_modify <-
intersect(names(args), names(modifiers))))
args[names_to_modify] <-
map2(modifiers[names_to_modify], args[names_to_modify],
function(modifier, arg) modifier(arg))
args
}
capture_args3 <-
function(modifiers = NULL, ignore = NULL) {
# currently unused
# like capture_args(), but will also unpack `!!!args`
# e.g.,
# constraints <- list(kernel_constraint = constraint_unitnorm(),
# bias_constraint = constraint_unitnorm())
# layer_dense(units = 2, !!!constraints)
cl0 <- cl <- sys.call(-1L)
envir <- parent.frame(2L)
fn <- sys.function(-1L)
# first defuse rlang !!! and := in calls
cl[[1L]] <- rlang::quos
cl_exprs <- eval(cl, envir)
# build up a call to base::list() using the exprs
cl <- as.call(c(list, cl_exprs))
# match.call()
cl <- match.call(fn, cl,
expand.dots = !"..." %in% ignore,
envir = envir)
# filter out args to ignore
for(ig in intersect(names(cl), ignore))
cl[[ig]] <- NULL
# eval and capture args
args <- rlang::eval_tidy(cl, env = envir)
# apply modifier functions. e.g., as_nullable_integer
nms_to_modify <- intersect(names(args), names(modifiers))
for (name in nms_to_modify)
# list() so if modifier returns NULL, don't remove the arg
args[name] <- list(modifiers[[name]](args[[name]]))
args
}
modify_intersection <- function(x, modifiers) {
for (name in intersect(names(x), names(modifiers))) {
x[[name]] <- modifiers[[name]](x[[name]])
}
x
}
assert_all_dots_named <- function(envir = parent.frame(), cl) {
x <- evalq(eval(substitute(alist(...))), envir)
if (!length(x)) return()
# ignore trailing missing arg
if (identical(x[length(x)], list(quote(expr =))))
x[[length(x)]] <- NULL
if (!length(x)) return()
x <- names(x)
if (is.character(x) && !anyNA(x) && all(x != ""))
return()
stop("All arguments provided to `...` must be named.\n",
"Call with unnamed arguments in dots:\n ",
paste(deparse(cl, 500L), collapse = "\n"))
}
# ---- py helpers ----
py_is <- function(x, y) {
is_py_object(x) &&
is_py_object(y) &&
identical(py_id(x), py_id(y))
}
have_module <- function(module) {
tryCatch({ import(module); TRUE; }, error = function(e) FALSE)
}
have_h5py <- function() {
have_module("h5py")
}
have_pyyaml <- function() {
have_module("yaml")
}
have_requests <- function() {
have_module("requests")
}
have_pillow <- function() {
have_module("PIL") # aka Pillow
}
# ---- unused / dead ----
relative_to <- function(dir, file) {
# normalize paths
dir <- normalizePath(dir, mustWork = FALSE, winslash = "/")
file <- normalizePath(file, mustWork = FALSE, winslash = "/")
# ensure directory ends with a /
if (!identical(substr(dir, nchar(dir), nchar(dir)), "/")) {
dir <- paste(dir, "/", sep="")
}
# if the file is prefixed with the directory, return a relative path
if (identical(substr(file, 1, nchar(dir)), dir))
file <- substr(file, nchar(dir) + 1, nchar(file))
# simplify ./
if (identical(substr(file, 1, 2), "./"))
file <- substr(file, 3, nchar(file))
file
}
# internal `[` method that ensures functions in this namespace use one-based
# indexing in case user has a global option set for zero-based indexing.
if (FALSE) {
# roxygen2 now wants this exported.
`[.tensorflow.tensor` <-
getS3method("[", "tensorflow.tensor", envir = asNamespace("tensorflow"))
formals(`[.tensorflow.tensor`)$style <- "R"
formals(`[.tensorflow.tensor`)$options <-
tensorflow::tf_extract_opts(
one_based = TRUE,
inclusive_stop = TRUE,
disallow_out_of_bounds = TRUE,
warn_tensors_passed_asis = FALSE,
warn_negatives_pythonic = FALSE
)
}
standard_layer_arg_modifiers <- list(
input_shape = normalize_shape,
batch_input_shape = normalize_shape,
batch_size = as_nullable_integer,
seed = as_nullable_integer
)
if (getRversion() < "4.0")
activeBindingFunction <- function(nm, env) {
as.list.environment(env, all.names = TRUE)[[nm]]
}
# don't dispatch to as.list(), just wrap in list()
as_list <- function(x) if (is.null(x) || is.list(x)) x else list(x)
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.