#!source envir::attach_source(.file)
# ---- setup attach ----
# Silence dplyr's "`summarise()` has grouped output by ..." messages globally.
options(dplyr.summarise.inform = FALSE)
# Attach `package` like library(), but exclude exports whose lifecycle badge
# matches `lifecycle_exclude` (deprecated/experimental/etc.), so only the
# stable API surface is attached. `package` is taken unevaluated (NSE), like
# library(). Additional names to drop can be passed via `exclude`; `...` is
# forwarded to library().
library_stable <-
function(package,
lifecycle_exclude = c(
"superseded",
"deprecated",
"questioning",
"defunct",
"experimental",
"soft-deprecated",
"retired"
), ...,
exclude = NULL) {
# capture the bare (unquoted) package name, same trick library() uses
package <- deparse1(substitute(package))
exclude <- lifecycle::pkg_lifecycle_statuses(
package, lifecycle_exclude)$fun |> c(exclude)
# message(package, ", excludeing: ", paste0(exclude, collapse = ", "))
library(package, character.only = TRUE, ...,
exclude = exclude)
}
library_stable(fs)
library_stable(readr)
library_stable(stringr)
library_stable(tibble)
library_stable(tidyr)
library_stable(dplyr, warn.conflicts = FALSE)
library_stable(rlang)
library_stable(purrr)
library_stable(stringr)
library_stable(glue)
library_stable(cli)
library(envir)
library(commafree)
library(magrittr, include.only = c("%>%", "%<>%"))
library(reticulate)
library(assertthat, include.only = c("assert_that"))
import_from(roxygen2, block_has_tags)
import_from(memoise, memoise)
import_from(reticulate, system2t)
# ---- utils utils ----
# TRUE when `x` has exactly one element (of any type); FALSE otherwise.
is_scalar <- function(x) {
  length(x) == 1L
}
# Replace every element of `x` that appears in `old` with the single value
# `new`. Errors if `new` is not length 1, to avoid silent recycling.
replace_val <- function(x, old, new) {
  if (length(new) != 1L)
    stop("Unexpected length of replacement value in replace_val().\n",
         "`new` must be length 1, not ", length(new))
  needs_swap <- x %in% old
  x[needs_swap] <- new
  x
}
# Infix fallback: evaluate `x`; if it errors, return `y` instead (both lazy).
# zap_srcref keeps deparsed/dumped wrappers free of srcref noise.
`%error%` <- rlang::zap_srcref(function(x, y) tryCatch(x, error = function(e) y))
# Are `x` and `y` the very same Python object (identity via py_id())?
py_is <- function(x, y) identical(py_id(x), py_id(y))
# Split the single string `x` at the first match of `pattern`, returning the
# pieces before and after the match (length-2 character vector; the whole
# string unchanged if there is no match). `...` is passed on to regexpr().
str_split1_on_first <- function(x, pattern, ...) {
  stopifnot(length(x) == 1, is.character(x))
  first_match <- regexpr(pattern, x, ...)
  regmatches(x, first_match, invert = TRUE)[[1L]]
}
# Flatten any number of character vectors into one string joined by "\n".
str_flatten_lines <- function(..., na.rm = FALSE) {
  pieces <- unlist(c(...))
  stringr::str_flatten(pieces, collapse = "\n", na.rm = na.rm)
}
# Inverse-ish of str_flatten_lines(): take one or more strings and return a
# character vector with one element per line.
str_split_lines <- function(...) {
  input <- c(...)
  if (length(input) > 1)
    input <- str_flatten_lines(input)
  stringr::str_split_1(input, "\n")
}
# Flatten lines into a single string, strip trailing whitespace, and compact
# runs of 4+ consecutive newlines down to 3 (i.e. at most two blank lines).
# With `roxygen = TRUE`, also compact runs of 3+ empty "#'" roxygen lines
# down to two.
str_flatten_and_compact_lines <- function(..., roxygen = FALSE) {
  flat <- str_split_lines(...) |>
    str_trim("right") |>
    str_flatten_lines() |>
    str_replace_all("\n{4,}", "\n\n\n")
  if (roxygen)
    flat <- str_replace_all(flat, "\n(#'\n){3,}", "\n#'\n#'\n")
  flat
}
# str_compact_newlines <- function(x, max_consecutive_new_lines = 2) {
# x <- x |> str_flatten_lines()
# while (nchar(x) != nchar(x <- gsub(
# strrep("\n", max_consecutive_new_lines + 1),
# strrep("\n", max_consecutive_new_lines),
# x, fixed = TRUE))) {}
# x
# }
# For each string, report whether it contains non-ASCII content.
# Trick: marking a string "UTF-8" via Encoding<- only sticks when the string
# actually contains non-ASCII bytes; pure-ASCII strings keep an "unknown"
# encoding mark. So after (re)marking, Encoding() != "unknown" implies
# non-ASCII content is present.
str_detect_non_ascii <- function(chr) {
map_lgl(chr, \(x) {
if (Encoding(x) == "unknown") {
Encoding(x) <- "UTF-8"
}
!isTRUE(Encoding(x) == "unknown")
})
}
# Split input into lines and strip non-ASCII characters from any line that
# has them, via iconv() to ASCII with unconvertible characters dropped.
str_remove_non_ascii <- function(chr) {
  lines <- str_split_lines(chr)
  has_non_ascii <- str_detect_non_ascii(lines)
  if (any(has_non_ascii)) {
    # iconv() is vectorized, so the per-line map is unnecessary
    lines[has_non_ascii] <- iconv(lines[has_non_ascii], to = "ASCII", sub = "")
  }
  lines
}
# Strip trailing whitespace from every line of each file matching the glob
# patterns in `...`, drop trailing all-blank lines, and rewrite in place.
trimws_file <- function(...) {
  for (f in Sys.glob(c(...))) {
    txt <- f |> readLines() |> trimws("right")
    # Position() returns NULL when *every* line is blank; previously that
    # made `length(txt) <- NULL` error. Treat it as "keep zero lines".
    last_nonblank <- Position(nzchar, txt, right = TRUE)
    length(txt) <- if (is.null(last_nonblank)) 0L else last_nonblank
    writeLines(txt, f, useBytes = TRUE)
  }
}
# Replacement-function form of append(): `append(x, after) <- value` inserts
# `value` into `x` after position `after` (defaults to the end of `x`).
`append<-` <- function(x, after = length(x), value) {
  append(x, values = value, after = after)
}
# Inverse of which(): turn integer indices `x` into a logical vector of
# length `len` that is TRUE exactly at those positions.
unwhich <- function(x, len) {
  out <- rep(FALSE, len)
  out[x] <- TRUE
  out
}
# TRUE if the function that called this helper already appears earlier on
# the call stack, i.e. the caller is executing inside a recursive invocation
# of itself. Compares call heads (function names/expressions);
# head(sys.calls(), -2) drops the caller's current frame and this helper's
# own frame so only *earlier* occurrences count.
in_recursive_call <- function() {
fn <- sys.call(-1)[[1]]
for(cl in head(sys.calls(), -2))
if(identical(cl[[1]], fn))
return(TRUE)
return(FALSE)
}
# use like:
# chat_messages(system = "optional",
# user = "question",
# assistant = "answer",
# user = "question")
# Build a list of chat messages for an LLM API from named arguments, e.g.
# chat_messages(system = ..., user = ..., assistant = ...). Each argument
# becomes list(role = <arg name>, content = <flattened, trimmed text>).
# Empty trailing args are ignored; roles are validated.
chat_messages <- function(...) {
  msgs <- rlang::dots_list(..., .named = TRUE, .ignore_empty = "all")
  stopifnot(all(names(msgs) %in% c("system", "user", "assistant")))
  unname(Map(
    function(content, role) {
      list(role = role,
           content = trim(str_flatten_lines(content)))
    },
    msgs, names(msgs)
  ))
}
# dput() `x` to the system clipboard (via clipr) for easy pasting; with
# `echo = TRUE` also print a (width-truncated) preview to the console.
# Returns `x` invisibly so it can sit mid-pipe.
dput_cb <- function (x, echo = TRUE) {
obj_nm <- deparse(substitute(x))
out <- clipr::write_clip(capture.output(dput(x)))
if (echo)
cat("Object", sQuote(obj_nm), "is now in clipboard:\n",
glue::glue_collapse(out, sep = "\n", width = 160), "\n")
invisible(x)
}
# cat() `x` to the system clipboard. If the clipboard is unavailable (e.g.
# a headless session), warn and print the text to the console instead.
# Returns `x` invisibly so it can sit mid-pipe.
cat_cb <- function(x, echo = TRUE, trunc_echo = TRUE, obj_nm = deparse(substitute(x))) {
force(obj_nm) # resolve the NSE default before `x` is used below
catted <- capture.output(cat(x))
tryCatch({
out <- clipr::write_clip(catted)
if (echo) {
width <- if (trunc_echo) 160 else Inf
message("output of 'cat(", obj_nm, ")' is now in clipboard:\n",
glue::glue_collapse(out, sep = "\n", width = width), "\n")
}
}, error = function(e) {
# clipboard not writable: fall back to printing, don't error
warning("Could not write to clipboard")
print(e)
cat("This is what would have been written to clipboard:\n\n",
catted, "\n", sep = "")
})
invisible(x)
}
# Split a data frame into a named list of data frames by grouping variables
# (like dplyr::group_split(), but names each piece by its group values
# joined with `.sep`). Grouping can come from `...` or from an existing
# group_by(). With `.drop = TRUE` the grouping columns are removed from
# each piece. Ungrouped input with no `...` splits one row per piece,
# unnamed.
split_by <- function (.data, ..., .drop = FALSE, .sep = ".") {
if (!missing(...))
.data <- group_by(.data, ..., .add = FALSE)
if (is_grouped_df(.data)) {
idx <- group_indices(.data)
vars_chr <- group_vars(.data)
.data <- ungroup(.data)
}
else {
idx <- seq_len(nrow(.data))
vars_chr <- NULL
}
out <- split.data.frame(.data, idx)
if (!is.null(vars_chr)) {
# name each piece from the group-variable values of its first row
names(out) <- vapply(out, function(df)
do.call(paste, c(unclass(as.data.frame(df[1L, vars_chr, drop = FALSE])),
sep = .sep)), "")
}
if (isTRUE(.drop))
for (i in seq_along(out))
for (v in vars_chr)
out[[i]][[v]] <- NULL
out
}
# attach_eval({
# }, mask.ok = "c")
.conflicts.OK <- "c"
# ignore empty trailing arg
# Deliberately shadow base::c with an rlang::list2()-based version so that
# calls with a trailing comma, e.g. c("a", "b", ), don't error. Used by the
# %(% block-style vectors below. NOTE(review): this masks base::c for
# everything sourced into this session — intentional here, but global.
c <- function(...)
base::do.call(base::c, rlang::list2(...), quote = TRUE)
# glue transformer adding sprintf-style formats: "{expr:fmt}" evaluates
# `expr` in `envir` and formats it with sprintf("%fmt", ...); without a
# ":fmt" suffix it falls back to plain "%s".
.sprintf_transformer <- function(text, envir) {
  hit <- regexpr(":.+$", text)
  if (hit == -1L) {
    fmt <- "s"
  } else {
    # everything after the colon is the sprintf format spec
    fmt <- substring(regmatches(text, hit), 2)
    regmatches(text, hit) <- ""
  }
  value <- eval(str2lang(text), envir)
  sprintf(paste0("%", fmt), value)
}
# glue() with sprintf-style format support, e.g. glue_fmt("{x:.2f}").
glue_fmt <- function(..., .envir = parent.frame()) {
glue(..., .transformer = .sprintf_transformer, .envir = .envir)
}
# ---- venv ----
# reticulate::virtualenv_remove("r-keras")
# if(!virtualenv_exists("r-keras")) install_keras()
# Bind to the "r-keras" virtualenv; error (rather than fall back) if absent.
use_virtualenv("r-keras", required = TRUE)
inspect <- import("inspect")
keras <- import("keras")
tf <- import("tensorflow")
# Expose keras/tf by name in Python __main__ so the py_eval() string
# expressions used throughout this file (e.g. "keras.layers") resolve.
local({
`__main__` <- reticulate::import_main()
`__main__`$keras <- keras
`__main__`$tf <- tf
`__main__`$tensorflow <- tf
})
py_eval("tf.experimental.numpy.experimental_enable_numpy_behavior()")
source_python("tools/common.py") # keras_class_type()
# ---- collect endpoints ----
# Recursively walk a Python module (default "keras") up to `max_depth`
# levels and return the fully qualified names ("endpoints") of all classes
# and functions found. `module` may be a character vector of roots. Private
# implementation modules (`skip`) and names matching `skip_regex` are
# skipped; submodules themselves are included only if `include_modules`.
list_endpoints <- function(
module = "keras", max_depth = 4,
skip = "keras.src",
skip_regex = sprintf("\\.%s$", c("experimental", "deserialize", "serialize", "get")),
include_modules = FALSE) {
.list_endpoints <- function(.module, depth) {
if (depth > max_depth) return()
module_py_obj <- py_eval(.module)
lapply(names(module_py_obj), \(nm) {
endpoint <- paste0(.module, ".", nm)
if(length(endpoint) > 1) browser()
if (endpoint %in% skip) return()
if (any(str_detect(endpoint, skip_regex))) return()
# message(endpoint)
endpoint_py_obj <- module_py_obj[[nm]]
if (inherits(endpoint_py_obj, "python.builtin.module")) {
# recurse into submodules
out <- .list_endpoints(endpoint, depth = depth + 1L)
if(include_modules)
append(out) <- endpoint
return(out)
}
if (inherits(endpoint_py_obj, c("python.builtin.type",
"python.builtin.function")))
return(endpoint)
NULL
})
}
unlist(lapply(module, .list_endpoints, depth = 1L))
}
# Given a character vector of endpoints, drop duplicates that resolve to the
# same underlying python object (aliases), keeping one preferred name each:
# - pooling-layer aliases: keep the longest/clearest name
# - losses aliased under keras.metrics: keep both (for 'metric_' autocomplete)
# - otherwise keep the name matching the object's `_api_export_path`
# Then, among endpoints that differ only in capitalization (function vs
# class handles), keep the class handle (the one with more capitals).
filter_out_endpoint_aliases <- function(endpoints) {
stopifnot(is.character(endpoints))
# some endpoints are aliases. resolve unique endoints.
endpoints <-
endpoints %>%
tibble(endpoint = .) %>%
mutate(py_obj = map(endpoint, py_eval)) %>%
# filter non py objects, e.g., version strings
filter(map_lgl(py_obj, inherits, "python.builtin.object")) %>%
mutate(
id = map_chr(py_obj, py_id),
py_type = map_chr(py_obj, \(o) class(o) %>%
grep("^python\\.builtin\\.", ., value = TRUE) %>%
sub("python.builtin.", "", ., fixed = TRUE) %>%
.[1])) %>%
# filter out literal aliases, i.e., identical py ids.
dplyr::group_split(id) %>%
map(\(df) {
if(nrow(df) == 1) return(df)
# for the pooling layer aliases, pick the longer/clearer name
if(all(grepl(r"(keras\.layers\.(Global)?(Avg|Average|Max)Pool(ing)?[1234]D)",
df$endpoint)))
return(df |> slice_max(nchar(endpoint)))
# keep aliases of losses under metrics, for complete autocomplete with 'metric_'
if(all(df$endpoint |> str_detect("^keras\\.(metrics|losses)\\."))) {
# message("keeping aliases: ", str_flatten_comma(df$endpoint))
return(df)
}
# `_api_export_path`
# `_api_export_symbol_id`
# if(any(df$endpoint %>% str_detect("keras.ops.average_pool"))) browser()
out <- try(df %>% filter(py_obj[[1]]$`_api_export_path`[[1]] == endpoint))
# if(inherits(out, "try-error")) browser()
return(out)
# NOTE(review): everything below in this lambda is unreachable — the
# return(out) above is unconditional. Kept as documentation of the
# intended "shortest, most precise name" fallback heuristic.
# otherwise, default to picking the shortest name, but most precise
# sort keras.preprocessing before keras.utils
df |>
mutate(
name = str_split(endpoint, fixed(".")) %>% map_chr(., ~.x[length(.x)]),
submodule = str_split(endpoint, fixed(".")) %>% map_chr(., ~.x[length(.x)-1])
) |>
arrange(nchar(name), submodule) |>
slice(1)
# slice_min(nchar(endpoint), n = 1, with_ties = FALSE)
}) %>%
list_rbind() %>%
# filter duplicate endpoint names, where all that differs is capitalization.
# this mostly affects endpoints offered as both function and class handles:
# metrics, losses, and merging layers.
# if we have both functional and type api interfaces, we
# just want the type right now (we dynamically fetch the matching
# functional counterpart later as needed, e.g., metrics)
split(., snakecase::to_upper_camel_case(.$endpoint)) %>%
map(\(df) {
if(nrow(df) == 1) return(df)
# prefer names w/ more capital letters (i.e., the class handles)
i <- which.max(nchar(gsub("[^A-Z]*", "", df$endpoint)))
# message(sprintf("keep: %s drop: %s",
# str_pad(df$endpoint[i], 50, "right"),
# str_flatten_comma(df$endpoint[-i])))
df[i,]
}) %>%
list_rbind() %>%
.$endpoint %>% unlist()
endpoints
}
# Interactive scratch / exploration snippets; never run when sourced.
if(FALSE) {
list_endpoints(c("keras.activations", "keras.regularizers"))
all_endpoints <- list_endpoints()
all_endpoints %>%
grep("ImageDataGenerator", ., value = T)
all_endpoints %>%
grep("Image", ., value = T, ignore.case = TRUE)
}
# ---- docstrings ----
# All section headings recognized in upstream keras docstrings. Written in
# commafree's `%(%` block form: one element per line, no commas needed.
# NOTE(review): "Returns:" appears twice; harmless for the match()-based
# lookups below, but could be deduplicated together with
# tidy_section_headings (they must stay index-aligned).
known_section_headings <- c %(% {
"Args:"
"Arguments:"
"Attributes:"
"Returns:"
"Raises:"
"Input shape:"
"Inputs:"
# "Input:",
"Output shape:"
# "Outputs:",
"Output:"
"Call arguments:"
"Call Args:"
"Returns:"
"Example:"
"References:"
"Reference:"
"Examples:"
"Notes:"
"Note:"
"Usage:"
# "Standalone usage:"
}
# PROMPT: keyword arguments -> named list
# PROMPT: list/tuple -> list
# PROMPT: None -> NULL
# PROMPT: True -> TRUE
# PROMPT: False -> FALSE
# TODO: @references tag for "References:" section
# Canonical snake_case section names, index-aligned with
# known_section_headings: strip the trailing ":", merge synonyms
# (Args/Arguments, Example/Examples, Call Args/Call arguments).
tidy_section_headings <- known_section_headings |>
str_sub(end = -2) |>
replace_val("Args", "Arguments") |>
replace_val("Example", "Examples") |>
replace_val("Call Args", "Call arguments") |>
# replace_val("Inputs", "Input shape") |>
# replace_val("Outputs", "Output shape") |>
snakecase::to_snake_case()
# Parse a python docstring into a named list of sections keyed by the
# canonical names in tidy_section_headings, plus "title" (first line) and
# "description". Lines are attributed to the most recent heading seen
# (last-observation-carried-forward), with heuristics for badly formatted
# upstream docstrings. Returns NULL for NULL input.
split_docstring_into_sections <- function(docstring) {
if(is.null(docstring))
return(NULL)
assert_that(docstring |> is_string()) %error% browser()
docstring <- docstring |>
str_split_lines() |> str_trim("right") |> str_flatten_lines()
docstring <- docstring |> fence_in_examples_w_prompt_prefix()
docstring <- docstring |> backtick_not_links()
for(h in known_section_headings) {
# "Input shape:" not always on own line :*-(
docstring <- gsub(paste0(h, " "), paste0("\n", h, "\n"), docstring, fixed = TRUE)
}
# docstring <- docstring %>%
# str_split_lines() %>%
# # Input shape section in keras.layers.ConvLSTM{12}D has badly formatted item list
# # (indentation)(wtf) - (If) -> (indentation)(wtf)\n(indentation) - (If)
# sub("^([ ]*)(.*[^ ].*)+[ ]+- (If|Else)", "\\1\\2\n\\1- \\3", .) %>%
# str_flatten_lines()
x <- str_split_1(docstring, "\n")
# short docstrings: no sections, just title + description
if(length(x) <= 3)
return(list(title = x[1],
description = str_flatten(x[-1], "\n")))
# m[i]: canonical section name governing line i (NA until filled below)
m <- match(str_trim(x), known_section_headings)
m <- zoo::na.locf0(m)
m <- tidy_section_headings[m]
# lines before the first heading are title (line 1) + description
m[1:2] <- m[1:2] %|% c("title", "description")
m <- zoo::na.locf0(m)
# if(grepl("FeatureSpace"))
# If sections follow Args
# if(grepl("FeatureSpace", docstring)) browser()
# if(grepl("Leaky relu activation function.", docstring)) browser()
# TODO: in BatchNormalization, there is more prose after "Reference:" section
# Heuristic: an unindented, non-bullet line after the first line of a
# "Reference" section is likely the start of more prose, not a reference.
ind_lvl <- str_width(str_extract(x, "^[ ]*"))
maybe_new_sec <- which(m == "reference" &
ind_lvl == 0 &
str_width(str_trim(x)) > 0 &
!startsWith(str_trim(x), "- "))[-1]
maybe_new_sec %<>% .[. > which.max(m == "arguments")]
if(length(maybe_new_sec)) {
m[maybe_new_sec[1]:length(m)] %<>% replace_val("reference", "description")
append(m, after = maybe_new_sec[1]) <- "description"
x[maybe_new_sec[1]] %<>% paste0("\n", .)
x %<>% str_split_lines()
}
# keras.utils.FeatureSpace
if("arguments" %in% m) {
# check if args designation overflowed into another section
# args are always termianted by an empty line
is_arg <- m == "arguments"
w_is_arg <- which(is_arg)
new_lines_in_section <- which(is_arg & x == "")
new_lines_in_section %<>% .[. < (length(x)-1)]
for (i in new_lines_in_section) {
if (x[i + 1] == "" || ind_lvl[i + 1] == ind_lvl[w_is_arg[1]]) {
w_is_arg2 <- w_is_arg %>% .[. < i]
not_a <- setdiff(w_is_arg, w_is_arg2)
if (length(not_a))
m[not_a] <- "description"
}
}
}
sections <- split(x, m) |>
imap(\(s, nm) {
if (!nm %in% c("description", "title", "details")) # "note"
s <- s[-1] # drop heading
s |> str_trim("right") |> str_flatten("\n") |> trim()
})
sections <- sections[unique(c(m))] # split() reorders
sections
}
# Names of the *args / **kwargs parameters of a python callable's signature
# (NULL if it has none). Uses python's inspect module.
# NOTE(review): the "paramater" misspelling is the established name — do not
# rename without updating callers (possibly outside this file).
vararg_paramater_names <- function(py_obj) {
params <- inspect$signature(py_obj)$parameters$values()
params %>%
as_iterator() %>% iterate(\(p) {
if (p$kind %in% c(inspect$Parameter$VAR_POSITIONAL,
inspect$Parameter$VAR_KEYWORD))
p$name
else
NULL
}) %>% unlist()
}
# Parse the "Args:" section of a docstring into a named list mapping
# parameter name -> description text. Continuation lines (no "name:" match)
# attach to the preceding parameter. *args/**kwargs (and any names in
# `treat_as_dots`) are collapsed under "...". Empty input -> empty list().
parse_params_section <- function(docstring_args_section, treat_as_dots = NULL) {
for (chk_empty_fn in c(is.null, is.na, partial(`==`, "")))
if(chk_empty_fn(docstring_args_section))
return(list())
if(!is_scalar(docstring_args_section))
docstring_args_section %<>% str_flatten("\n")
# if(str_detect(docstring_args_section, fixed("recall: A scalar value in range `[0,"))) browser()
x <- docstring_args_section |>
c("") |> # append final new line to ensure glue::trim() still works when length(x) == 1
str_flatten("\n") |> glue::trim() |> str_split_1("\n")
## Replace with this smaller regex once https://github.com/keras-team/keras/pull/18790 is merged
# m <- str_match(x, "^(?<name>[*]{0,2}[A-Za-z0-9_]+): *(?<desc>.*)$")
m <- str_match(x, "^\\s{0,3}(?<name>[*]{0,2}[A-Za-z0-9_]+) ?: *(?<desc>.*)$")
description <- m[,"desc"] %|% x
param_name <- m[,"name"]
# "Default: ..." lines are prose, not a parameter named "Default"
not_param <- which(param_name == "Default")
if(length(not_param)) {
description[not_param] <- x[not_param]
param_name[not_param] <- NA
}
param_name[startsWith(param_name, "*")] <- "..."
# param_name[param_name == "kwargs"] <- "..."
param_name[param_name %in% treat_as_dots] <- "..."
# carry the last seen param name down over continuation lines
param_name %<>% zoo::na.locf0()
# {if(any(startsWith(param_name, "*"))) browser()} %error% browser()
# if both *args and **kwargs are documented, collapse them into the "..." param (todo)
out <- split(description, param_name) |>
map(\(s) s|> str_trim("right") |> str_flatten("\n"))
out <- out[unique(param_name)] # preserve original order
out
}
# Wrap doctest-style examples (">>> " / "... " prompt lines) in ```python
# fences, stripping the prompts and commenting-out the interleaved output
# lines. Existing fenced blocks are passed through untouched. Returns a
# single flattened string.
fence_in_examples_w_prompt_prefix <- function(docstring) {
# fence in code examples that are missing fences
# but start with >>>
docstring <- str_split_1(docstring, "\n") |> str_trim("right")
# docstring0: fully trimmed copy used for all lookups; docstring keeps text
docstring0 <- str_trim(docstring)
in_code_block <- FALSE
i <- 0L
while(i < length(docstring)) {
i <- i + 1L
if(startsWith(docstring0[[i]], "```")) {
# already-fenced block: skip forward to its closing fence
in_code_block <- TRUE
while(i < length(docstring0)) {
if (startsWith(docstring0[[i <- i + 1L]], "```")) {
in_code_block <- FALSE
break
}
}
next
}
if (!in_code_block) {
# not currently in a code block, look for the start of a code block
if (startsWith(docstring0[[i]], ">>>")) {
in_code_block <- TRUE
docstring[[i]] <- str_c("```python\n",
str_replace(docstring[[i]], ">>> ", ""))
}
} else {
# code block context open,
# look for the end of the code block
if (docstring0[[i]] == "") {
docstring[[i]] <- "```\n"
in_code_block <- FALSE
# tidy up example commands
} else if(any(startsWith(docstring0[[i]], c(">>> ", "... ")))) {
docstring[[i]] <- str_sub(docstring[[i]], 5L)
# } else if(docstring0[[i]] == "```") {
# # code block closed in docstring
# in_code_block <- FALSE
} else
# tidy up example output
docstring[[i]] <- str_c("# ", docstring[[i]])
}
}
# close a fence left open at end-of-docstring
if(in_code_block)
docstring <- c(docstring, "```\n")
docstring <- str_flatten(docstring, collapse = "\n")
}
# Split a string on `pattern` while *keeping* each match as its own element
# in the result, by splitting on the zero-width boundaries just before and
# just after every match.
str_split_1_inclusive <- function(string, pattern) {
  boundary <- paste0(
    "(?<=(", pattern, "))", # boundary right after a match
    "|",                    # or
    "(?=(", pattern, "))"   # boundary right before a match
  )
  unlist(stringi::stri_split_regex(string, pattern = boundary))
}
# Apply .backtick_not_links() to prose only: partition the docstring into
# prose / inline-code / fenced-code segments (tracking ``` fences and single
# backticks), transform just the prose segments, and reassemble. Relies on
# the magrittr dot-block idiom: `.` is the piped value being mutated.
backtick_not_links <- function(d) {
d %>%
str_split_1_inclusive("\n```") %>%
as.list() %>%
{
# mark prose and code sections.
type <- "prose"
for(i in seq_along(.)) {
if(.[[i]] == "\n```") {
# fence delimiter: flip prose <-> code state
names(.)[i] <- "delim"
type <- switch(type, prose = "code", code = "prose")
next
}
if(type == "prose") {
# within prose, further split on single backticks (inline code)
s <- str_split_1_inclusive(.[[i]], "`")
stype <- "prose"
for(si in seq_along(s)) {
if(s[[si]] == "`") {
names(s)[si] <- "delim"
stype <- switch(stype, prose = "code", code = "prose")
next
}
names(s)[si] <- stype
}
.[i] <- list(s)
next
}
if(type == "code")
names(.)[i] <- "code"
}
. <- unlist(.)
names(.) <- gsub("NA.", "", names(.), fixed = TRUE)
.
} %> % {
i <- which(names(.) == "prose")
names(.) <- NULL
.[i] <- .backtick_not_links(.[i])
.
} %>%
str_flatten()
}
# Wrap markdown-link-like spans "[...]" in backticks *unless* they are
# followed by "(", i.e. unless they are actual markdown links [text](url).
# This keeps literal python-ish index/list syntax from being rendered as a
# broken link.
.backtick_not_links <- function(x) {
# if(any(str_detect(x, fixed("["))))
# browser()
# [0-9a-zA-Z,\[\]\ -><.]+ # Capture numbers, brackets, and specific symbols
re <- regex(comments = TRUE, pattern = r"--(
( # Open capture group for main pattern
[^\ \n]* # anything not a space or newline
\[ # Match an opening square bracket
.+ # capture anything greedily
\] # Match a closing square bracket
[^\ \n\(]* # anything not a space or bounary or open parens
) # Close capture group
(?!\() # Lookahead to ensure no opening parenthesis follows
)--")
rep <- r"--(`\1`)--"
str_replace_all(x, re, rep)
}
# ---- roxygen -----
import_from("tools/family-maps.R", r_name_to_family_map)
# Look up the roxygen @family for an endpoint via its R wrapper name
# (using the map sourced from tools/family-maps.R); NULL if unmapped.
get_family <- function(endpoint, r_name = make_r_name(endpoint)) {
r_name_to_family_map[[r_name]]
}
# Assemble the standard roxygen tags for an endpoint's R wrapper:
# @export, @family, @tether (the python endpoint), and @seealso links to
# the keras.io and tensorflow.org docs pages.
# NOTE(review): `py_obj` and `type` are currently unused; kept for a stable
# call signature.
make_roxygen_tags <- function(endpoint, py_obj = py_eval(endpoint), type) {
  links <- c(get_keras_doc_link(endpoint),
             get_tf_doc_link(endpoint))
  list(
    export = "",
    family = get_family(endpoint),
    tether = endpoint,
    seealso = str_flatten_lines("", glue("+ <{links}>"))
  )
}
# Collapse the autogenerated family-name variants for learning-rate
# schedules into one canonical family name; pass everything else through.
# NOTE(review): the "learing" typo is load-bearing — it must match the same
# string in .keeper_families below; fix both together or not at all.
# NOTE(review): "rate learning schedules optimizers" appears twice in the
# switch(); the duplicate is harmless (the first occurrence falls through).
remap_families <- function(x) map_chr(x, \(autogend_fam) {
switch(autogend_fam,
"schedules optimizers" =,
"learning schedules optimizers" =,
"rate learning schedules optimizers" =,
"schedule rate learning schedules optimizers" =,
"rate learning schedules optimizers" = "optimizer learing rate schedules",
autogend_fam)
})
# Allow-list of family names that get kept as roxygen @family tags
# (normalized through remap_families() so it matches its output).
.keeper_families <- #NULL
c("io utils", "attention layers", "constant initializers", "constraints",
"dataset utils", "regularizers", "saving", "accuracy metrics",
"image ops", "iou metrics", "normalization layers", "probabilistic metrics",
"activation layers", "global pooling layers", "optimizer learing rate schedules",
"config backend", "core layers", "random preprocessing layers",
"regularization layers", "backend", "merging layers", "random",
"regression metrics", "convolutional layers", "callbacks", "confusion metrics",
"image utils", "random initializers", "optimizers", "pooling layers",
"reshaping layers", "rnn layers", "core ops", "initializers",
"math ops", "activations", "preprocessing layers", "losses",
"nn ops", "utils", "metrics", "layers", "numpy ops", "ops") |> remap_families() |> unique()
# URL of the tensorflow.org API docs page for a keras endpoint
# ("keras.foo.Bar" -> ".../python/tf/keras/foo/Bar").
get_tf_doc_link <- function(endpoint) {
  url_tail <- gsub(".", "/", endpoint, fixed = TRUE)
  glue("https://www.tensorflow.org/api_docs/python/tf/{url_tail}")
}
import_from("tools/endpoints-to-keras-io-urls.R", upstream_keras_io_urls_map)
# keras.io docs URL for an endpoint, or NULL if not in the scraped map.
get_keras_doc_link <- function(endpoint) {
upstream_keras_io_urls_map[[endpoint]]
}
# Derive the roxygen @family for a layer class from its python __module__
# path (the segment after ".layers."), e.g. "...layers.convolutional..." ->
# "convolutional layers". "rnn" is renamed to "recurrent".
get_layer_family <- function(layer) {
family <- layer$`__module__` |>
str_extract(".*\\.layers\\.(.*)s?\\.", group = 1) |>
replace_val("rnn", "recurrent") |>
str_replace_all("_", " ")
# keras$layers$Input lives outside a layers.<family> submodule
if(is.na(family) && py_is(layer, keras$layers$Input))
family <- "core"
str_c(family, " layers")
}
# ---- r wrapper --------
# Derive the R wrapper name for a python endpoint, e.g.
# "keras.layers.Dense" -> "layer_dense". Strategy: explicit one-off renames
# first; otherwise build a singular-ized snake_case prefix from the module
# path plus a snake_case name, then apply a long chain of spelling fixups
# (version suffixes, model-family names, etc.). Order of the fixups matters.
make_r_name <- function(endpoint, module = py_eval(endpoint)$`__module__`) {
type <- keras_class_type(py_eval(endpoint))
# manual renames
if(!is.null(r_name <- switch %(% { endpoint
"keras.utils.array_to_img" = "image_from_array"
"keras.utils.img_to_array" = "image_to_array"
"keras.utils.load_img" = "image_load"
"keras.utils.save_img" = "image_array_save"
"keras.utils.pad_sequences" = "pad_sequences"
"keras.layers.LayerNormalization" = "layer_layer_normalization"
"keras.utils.FeatureSpace" = "layer_feature_space"
"keras.ops.in_top_k" = "op_in_top_k"
"keras.ops.top_k" = "op_top_k"
"keras.ops.image.pad_images" = "op_image_pad"
"keras.layers.StackedRNNCells" = "rnn_cells_stack"
"keras.layers.SimpleRNNCell" = "rnn_cell_simple"
"keras.random.randint" = "random_integer"
"keras.layers.Bidirectional" = "bidirectional"
NULL
})) return(r_name)
# split the endpoint into path components; drop/rename some modules
x <- endpoint |> str_split_1(fixed("."))
x <- lapply(x, function(.x) switch(
.x,
"keras" = character(),
"preprocessing" = character(),
"utils" = character(),
"distribution" = "distributed",
.x
))
name <- x[length(x)]
# prefix: remaining module path, singularized ("layers" -> "layer")
prefix <- x[-length(x)] |> unlist() |>
str_replace("e?s$", "") |> str_flatten("_")
if(type == "learning_rate_schedule")
prefix <- "learning_rate_schedule"
# pre-normalize mixed-case acronyms so to_snake_case() splits sensibly
name <- name |>
str_replace("NaN", "Nan") |>
str_replace("RMSprop$", "Rmsprop") |>
str_replace("ReLU$", "Relu") |>
str_replace("ResNet", "Resnet") |>
str_replace("ConvNeXt", "Convnext") |>
str_replace("XLarge", "Xlarge") |>
str_replace("IoU", "Iou") |>
str_replace("FBeta", "Fbeta")
if(str_detect(name, "[[:upper:]]"))
name %<>% snakecase::to_snake_case()
# post-snake_case spelling fixups
name <- name |>
str_replace("_([0-9])_d(_|$)", "_\\1d\\2") |> # conv_1_d -> conv_1d
str_replace_all("(^|_)l_([12])", "\\1l\\2") |> # l_1_l_2 -> l1_l2
# applications
str_replace_all("_v_([1234])($|_)", "_v\\1\\2") |>
str_replace_all("^resnet_([0-9]+)", "resnet\\1") |>
str_replace_all("^vgg_([0-9]+)", "vgg\\1") |>
str_replace_all("^dense_net_", "densenet") |>
str_replace_all("^mobile_net?", "mobilenet") |>
str_replace("f_1", "f1") |>
str_replace("r_2", "r2") |>
str_replace_all("max_norm", "maxnorm") |>
str_replace_all("non_neg", "nonneg") |>
str_replace_all("min_max", "minmax") |>
str_replace_all("unit_norm$", "unitnorm") |>
str_replace("tensor_board", "tensorboard") |>
identity()
if (str_detect(name, "efficient_net_"))
name <- name |>
str_split_1("efficient_net_") |> str_replace_all("_", "") |>
str_flatten("efficientnet_")
if(str_detect(name, "nas_net_"))
name %<>% str_replace_all("_", "")
name <- str_flatten(c(prefix, name), collapse = "_")
# drop repeated path components, e.g. metric_metrics_* -> metric_*
name %<>%
str_split_1("_") %>%
unique() %>%
.[nzchar(.)] %>%
str_flatten("_")
## fixes for specific wrappers
# no _ in up_sampling
name %<>% str_replace("^layer_up_sampling_", "layer_upsampling_")
# activation layers get a prefix (except for layer_activation())
if(startsWith(name, "layer_") &&
str_detect(module, "\\.activations?\\.") &&
name != "layer_activation")
name %<>% str_replace("^layer_", "layer_activation_")
name <- name |>
replace_val("layer_activation_p_relu", "layer_activation_parametric_relu") |>
replace_val("regularizer_orthogonal_regularizer", "regularizer_orthogonal") |>
identity()
# RNN cells are rnn_cell_*, not layer_*_cell
if(grepl("^layer_.*_cell$", name))
name <- sub("layer_(.*)_cell$", "rnn_cell_\\1", name)
name
}
# Registry of per-endpoint argument transformers, read from
# tools/arg-transformers.yml. Keys are endpoints or globs; values map
# argument name -> transformer (a function name, or an expression that gets
# wrapped into function(x) <expr>). For exact (non-glob) endpoints the arg
# names get a "+" prefix, meaning "always add", vs "add only if the arg is
# present" for glob entries (see get_arg_transformers()).
transformers_registry <-
yaml::read_yaml("tools/arg-transformers.yml") %>%
imap(\(args, endpoint) {
if(!str_detect(endpoint, "[*?]")) # not a glob
names(args) <- str_c("+", names(args))
# if(endpoint == "keras.random.randint") browser()
lapply(args, function(fn) {
# if(grepl("normalize_padding", fn)) browser()
# if(!is_string(fn)) browser()
if(is.null(fn)) return(fn)
fn <- str2lang(fn)
# bare expressions become single-arg functions; bare names stay symbols
if (is.call(fn) && !identical(fn[[1]], quote(`function`)))
fn <- as.function.default(c(alist(x =), fn))
fn
})
})
# Build the named list of argument transformers for an endpoint's wrapper:
# registry entries (exact +always entries and glob if-present entries)
# merged with a default heuristic — any formal with an integer default, or
# whose docstring description mentions "int"/"integer", gets as_integer.
# Returns NULL when there are none.
get_arg_transformers <- function(endpoint, py_obj = py_eval(endpoint), params = NULL) {
if(is.null(params)) {
# param name -> docstring description, used by the "int" heuristic below
params <- get_fixed_docstring(endpoint) |> trim() |>
split_docstring_into_sections() |>
_$arguments |> parse_params_section()
}
transformers <- list()
frmls <- formals(py_obj)
pre_registered <- transformers_registry %>%
.[str_detect(endpoint, glob2rx(names(.)))] %>%
unname() %>% unlist(recursive = FALSE, use.names = TRUE)
# "+name" entries are added unconditionally; others only if arg exists
add_if_present <- pre_registered %>% .[!startsWith(names(.), "+")]
add_always <- pre_registered %>% .[startsWith(names(.), "+")]
names(add_always) %<>% str_sub(2, -1)
rm(pre_registered)
# build up a default list
# - integer default values get a "as_integer"
# - "int" in param description gets a "as_integer"
for (key in c(names(frmls), names(params))) {
if (!is.null(tr <- add_if_present[[key]])) {
transformers[[key]] <- tr
} else if (typeof(frmls[[key]]) == "integer" ||
str_detect(params[[key]] %||% "",
regex("\\bint(eger)?\\b", ignore_case = TRUE))) {
transformers[[key]] <- quote(as_integer)
}
}
transformers %<>% modifyList(add_always)
if (!length(transformers))
transformers <- NULL
transformers
}
# Generate the R wrapper function for a python endpoint, dispatching on the
# endpoint kind (op / activation / layer / metric / loss / default). After
# building the function, 1-based-index defaults are patched in for any
# formal whose transformer is as_axis/as_index (python 0-based -> R 1-based).
make_r_fn <- function(endpoint,
py_obj = py_eval(endpoint),
type = keras_class_type(py_obj),
transformers = get_arg_transformers(endpoint, py_obj, params = NULL)) {
# TODO: rename keras_class_type to keras_endpoint
if(endpoint |> startsWith("keras.ops.")) {
fn <- make_r_fn.op(endpoint, py_obj, transformers)
} else if(endpoint |> startsWith("keras.activation")) {
fn <- make_r_fn.activation(endpoint, py_obj, transformers)
}
# if(endpoint == "keras.random.randint") {
# frmls <- formals(keras$random$randint)
# frmls$minval <- 0L
# frmls$maxval <- 1L
# return(as.function.default(, quote({
# args <- capture_args(list(shape = normalize_shape, seed = as_integer))
# do.call(keras$random$randint, args)
# }))
# }
else {
fn <- switch(keras_class_type(py_obj),
layer = make_r_fn.layer(endpoint, py_obj, transformers),
metric = , loss = make_r_fn.loss(endpoint, py_obj, transformers),
make_r_fn.default(endpoint, py_obj, transformers))
}
if (length(transformers)) {
frmls <- formals(fn)
for (name in intersect(names(frmls), names(transformers))) {
default <- frmls[[name]]
if(rlang::is_missing(default)) next
# if(deparse(t))
transformer_expr <- transformers[[name]]
if(deparse1(transformer_expr) %in% c("as_axis", "as_index")) {
# shift non-negative integer defaults up by one (0-based -> 1-based),
# recursing into list defaults like c(1L, 2L)
default2 <- rapply(list(default),
\(x) if(x >= 0L) x + 1L else x,
classes = "integer", how = "replace")[[1L]]
frmls[name] <- list2("{name}" := default2)
}
# is.integer(default) && default >= 0L) {
#
# }
# transformer <- eval(transformers[[name]], asNamespace('keras'))
# frmls[name] <- list2("{name}" := transformer(default)) # } %error% browser()
}
# preserve attributes (e.g. py_function_name) across the formals() swap
ofn <- fn
formals(fn) <- frmls
if(length(attrs <- attributes(ofn)))
attributes(fn) <- attrs
}
fn
}
# Default wrapper generator: same formals as the python callable, body
# captures args (applying `transformers`) and forwards to the python object.
# Special-cases set_random_seed (also seeds R and python RNGs) and save_img
# (reorders formals so the image is the first, pipeable argument).
make_r_fn.default <- function(endpoint, py_obj, transformers) {
# transformers <- get_arg_transformers(py_obj)
# if(py_is(py_obj, keras$losses$BinaryCrossentropy)) browser()
py_obj_expr <- endpoint_to_expr(endpoint)
if (!length(transformers))
transformers <- NULL
frmls <- formals(py_obj)
body <- bquote({
args <- capture_args(.(transformers))
do.call(.(py_obj_expr), args)
})
if(endpoint == "keras.utils.set_random_seed") {
body <- bquote({
args <- capture_args(.(transformers))
set.seed(args$seed)
# set Python/NumPy random seed
reticulate::py_set_seed(args$seed)
do.call(.(py_obj_expr), args)
})
}
# if(endpoint == "keras.preprocessing.image.save_img")
if(endpoint == "keras.utils.save_img")
frmls <- frmls[unique(c("x", "path", names(frmls)))] # swap so img is first arg, better for pipe
# frmls <- frmls[c(2, 1, 3:length(frmls))] # swap so img is first arg, better for pipe
as.function.default(c(frmls, body))
}
# Wrapper generator for keras.ops.* endpoints. Without transformers, the
# wrapper is a direct positional call into the python op. Several ops get
# hand-written bodies (meshgrid, shape, array, reshape, vectorized_map) to
# handle R<->python conventions; anything with transformers falls back to
# make_r_fn.default().
make_r_fn.op <- function(endpoint, py_obj, transformers) {
frmls <- formals(py_obj)
py_obj_expr <- endpoint_to_expr(endpoint)
syms <- lapply(names(frmls), as.symbol)
# names(syms) <- names(frmls) # TODO: args passed should be named
body <- cl <- as.call(c(py_obj_expr, syms))
if(endpoint == "keras.ops.meshgrid") {
# browser()
# doubles must become int64 numpy arrays before python sees them
body <- rlang::zap_srcref(quote({
args <- lapply(list(...), function(x) {
if(storage.mode(x) == "double")
np_array(x, 'int64')
else
x
})
keras$ops$meshgrid(!!!args, indexing = indexing)
}))
transformers <- NULL
}
if(endpoint == "keras.ops.shape") {
stopifnot(length(frmls) == 1)
# tag the result so it prints via the keras_shape S3 method
return(rlang::zap_srcref(function(x) {
out <- keras$ops$shape(x)
class(out) <- "keras_shape"
out
}))
}
if(endpoint == "keras.ops.array") {
# guard: if upstream changes the signature, stop and update manually
if(!identical(frmls, as.pairlist(alist(x = , dtype = NULL)))) { browser(); stop()}
# k_array(c(1,2,3), "int32") fails because implicitly casting eager
# float tensors to int throws a python error. So we explicitly cast first.
return(function(x, dtype = NULL) {
if(!is.null(dtype) && !is_py_object(x))
x <- np_array(x, dtype)
keras$ops$array(x, dtype)
})
}
if(endpoint == "keras.ops.reshape") {
if(!identical(names(frmls), c("x", "newshape")))
browser()
return(function(x, ..., newshape = list(...)) {
keras$ops$reshape(x, tuple(lapply(shape(newshape), function(d) d %||% -1L)))
})
}
if(endpoint == "keras.ops.vectorized_map") {
names(frmls) %<>% replace_val("function", "f")
# swap the tensor is the first arg, for better pipe-ability
stopifnot(length(frmls) == 2) # new arg, update manually
return(function(elements, f) keras$ops$vectorized_map(f, elements))
# frmls <- frmls[unique(c(2, 1, seq_along(frmls)))]
}
if(!length(transformers)) {
fn <- as.function.default(c(frmls, body))
} else
fn <- make_r_fn.default(endpoint, py_obj, transformers)
fn
}
# Like make_r_fn.default(), but also record the upstream python function
# name on the wrapper so it can be identified when passed to keras by name.
make_r_fn.activation <- function(endpoint, py_obj, transformers) {
  wrapper <- make_r_fn.default(endpoint, py_obj, transformers)
  attr(wrapper, "py_function_name") <- py_obj$`__name__`
  wrapper
}
# Wrapper generator for metrics/losses that may have both a class handle
# (e.g. keras.losses.BinaryCrossentropy) and a snake_case function handle
# (keras.losses.binary_crossentropy). The generated wrapper dispatches: if
# y_true/y_pred are supplied it calls the function handle, otherwise it
# constructs the class instance. Falls back to the default wrapper (with a
# leading `...`) when no function counterpart exists.
make_r_fn.loss <- function(endpoint, py_obj, transformers) {
# see if there is a function counterpart handle to the class handle
name <- py_obj$`__name__`
endpoint_fn <- str_replace(endpoint, name, snakecase::to_snake_case(name))
py_obj_fn <- tryCatch(py_eval(endpoint_fn),
python.builtin.AttributeError = function(e) NULL)
if(is.null(py_obj_fn)) {
# class handle only: default wrapper, with `...` prepended to formals
fn <- make_r_fn.default(endpoint, py_obj, transformers)
formals(fn) <- formals(py_obj) %>%
c(alist(... =), .) %>% .[unique(names(.))]
return(fn)
}
py_obj_expr <- endpoint_to_expr(endpoint)
frmls <- formals(py_obj)
frmls <- c(alist(... = ), formals(py_obj))
frmls <- frmls[unique(names(frmls))]
#%error% browser()
### browser() if we need a simple fallback in case there is a loss with only a
### class handle and not a matching function handle
py_obj_expr_fn <- endpoint_to_expr(endpoint_fn)
# merge transformers from both the class and function endpoints
transformers <- modifyList(transformers %||% list(),
get_arg_transformers(endpoint_fn) %||% list()) %error% browser()
if(!length(transformers))
transformers <- NULL
# union of formals, function handle's (y_true, y_pred) leading
frmls <- c(formals(py_obj_fn), frmls)
frmls <- frmls[unique(names(frmls))]
stopifnot(names(frmls)[1:2] == c("y_true", "y_pred"))
body <- bquote({
args <- capture_args(.(transformers))
callable <- if (missing(y_true) && missing(y_pred))
.(py_obj_expr) else .(py_obj_expr_fn)
do.call(callable, args)
})
fn <- as.function.default(c(frmls, body))
attr(fn, "py_function_name") <- py_obj_fn$`__name__`
# TODO: compile should check for the bare metric/loss R wrapper being passed,
# and maybe unwrap it to avoid the R overhead.
# TODO: dump r docstring for the r function wrapper too
fn
}
# Build the R wrapper for a keras layer endpoint.
#
# The default wrapper takes `object` (model/tensor/missing) as the first
# argument so layers compose with pipes; several layer families get bespoke
# signatures and bodies below.
#
# @param endpoint string, e.g. "keras.layers.Dense".
# @param py_obj the python layer class.
# @param transformers named list of R-side argument transformers, or NULL.
# @return an R function (srcrefs zapped) that constructs, and possibly
#   immediately calls, the layer.
make_r_fn.layer <- function(endpoint, py_obj, transformers) {
  ## build fn
  frmls <- formals(py_obj)
  frmls$self <- NULL # python's `self` is not part of the R signature
  py_obj_expr <- str2lang(str_replace_all(endpoint, fixed("."), "$"))
  ## first deal w/ special cases / exceptions
  if (grepl("layers.merging.", py_obj$`__module__`, fixed = TRUE)) {
    # merging layers Add, Subtract, Dot, etc:
    # - accept unnamed args as tensors in ...
    # - first arg is `inputs`, not `object`
    # if any tensors were supplied, the built layer is called on them
    # immediately; otherwise the bare layer is returned
    frmls <- c(alist(inputs = , ... =), frmls)
    frmls <- frmls[unique(names(frmls))] # remove dup `...` if present
    fn_body <- bquote({
      args <- capture_args(.(transformers), ignore = c("...", "inputs"))
      dots <- split_dots_named_unnamed(list(...))
      if (missing(inputs))
        inputs <- NULL
      else if (!is.null(inputs) && !is.list(inputs))
        inputs <- list(inputs)
      inputs <- c(inputs, dots$unnamed)
      args <- c(args, dots$named)
      layer <- do.call(.(py_obj_expr), args)
      if(length(inputs))
        layer(inputs)
      else
        layer
    })
  } else if (grepl(".rnn.", py_obj$`__module__`, fixed = TRUE) &&
             grepl("Cells?$", py_obj$`__name__`)) {
    # rnn_cell_gru() and friends don't compose w/ `object`
    # TODO: consider renaming these in keras 3, maybe something like
    # rnn_cell_{gru,simple,stacked,lstm}()
    fn_body <- bquote({
      args <- capture_args(.(transformers))
      do.call(.(py_obj_expr), args)
    })
  } else if (endpoint == "keras.layers.MultiHeadAttention") {
    # first arg is `inputs`, a list
    # TODO: in keras 3, change signature to:
    # alist(query, key = query, value = key, ..., call_args = list())
    frmls <- c(alist(inputs = ), frmls)
    fn_body <- bquote({
      args <- capture_args(.(transformers), ignore = "inputs")
      layer <- do.call(.(py_obj_expr), args)
      if (missing(inputs) || is.null(inputs))
        return(layer)
      if (!is.list(inputs))
        inputs <- list(inputs)
      do.call(layer, inputs)
    })
  } else if (endpoint == "keras.layers.Lambda") {
    # rename `function` -> `f`, because roxygen gets very confused
    # if you have an R parser reserved word as an arg name;
    # the generated body renames it back before calling python
    names(frmls) %<>% replace_val("function", "f")
    frmls <- c(alist(object = ), frmls)
    fn_body <- bquote({
      args <- capture_args(.(transformers), ignore = "object")
      names(args)[match("f", names(args))] <- "function"
      create_layer(.(py_obj_expr), object, args)
    })
  }
  # else if (endpoint == "keras.ops.vectorized_map") {
  #   # rename `function` -> `func`. because roxygen gets very confused
  #   # if you have an R parser reserved word as an arg name
  #   names(frmls) %<>% replace_val("function", "f")
  #   fn_body <- bquote({
  #     args <- capture_args(.(transformers), ignore = "object")
  #     names(args)[match("f", names(args))] <- "function"
  #     create_layer(.(py_obj_expr), object, args)
  #   })
  #
  # }
  else {
    # default path for all other layers
    frmls <- c(alist(object = ), frmls)
    fn_body <- bquote({
      args <- capture_args(.(transformers), ignore = "object")
      create_layer(.(py_obj_expr), object, args)
    })
  }
  fn <- as.function.default(c(frmls, fn_body), envir = emptyenv())
  fn <- rlang::zap_srcref(fn)
  fn
}
# Registry of manually-authored @param descriptions keyed by endpoint glob;
# consulted by mk_export() for params the upstream docstring leaves
# undocumented. Each description is flattened/squished to a single string.
# NOTE(review): "param-descripations.yml" looks misspelled, but it must match
# the actual file name on disk -- don't "fix" one without the other.
param_augment_registry <- yaml::read_yaml("tools/param-descripations.yml") %>%
  map_depth(2, ~.x |> str_flatten_lines() |> str_squish())
# ---- make, format, dump ----
# Metric class endpoints known to have no snake_case function counterpart;
# used to silence "not found" messages in mk_export().
known_metrics_without_function_handles <- c(
  "keras.metrics.AUC",
  "keras.metrics.BinaryIoU",
  "keras.metrics.CosineSimilarity",
  "keras.metrics.F1Score",
  "keras.metrics.FalseNegatives",
  "keras.metrics.FalsePositives",
  "keras.metrics.FBetaScore",
  "keras.metrics.IoU",
  "keras.metrics.LogCoshError",
  "keras.metrics.Mean",
  "keras.metrics.MeanIoU",
  "keras.metrics.MeanMetricWrapper",
  "keras.metrics.OneHotIoU",
  "keras.metrics.OneHotMeanIoU",
  "keras.metrics.Precision",
  "keras.metrics.PrecisionAtRecall",
  "keras.metrics.R2Score",
  "keras.metrics.Recall",
  "keras.metrics.RecallAtPrecision",
  "keras.metrics.RootMeanSquaredError",
  "keras.metrics.SensitivityAtSpecificity",
  "keras.metrics.SpecificityAtSensitivity",
  "keras.metrics.Sum",
  "keras.metrics.TrueNegatives",
  "keras.metrics.TruePositives"
)
# Build the full export record for one keras endpoint: the python object,
# parsed/merged docs, roxygen text, the generated R function, and its final
# source dump. Memoised; `.mk_export` is the uncached worker.
#
# @param endpoint string, e.g. "keras.layers.Dense".
# @param quiet if TRUE, skip the undocumented-param reporting branch.
# @return a named list of everything in the function frame (via
#   `as.list(environment())`); notable elements: doc, params, tags, r_name,
#   r_fn, roxygen, dump, type, module.
mk_export <- memoise(.mk_export <- function(endpoint, quiet = FALSE) {
  # message(glue(".mk_export('{endpoint}', {quiet})"))
  # py parts
  py_obj <- py_eval(endpoint)
  # NOTE(review): this `<<-` leaks the current py_obj into the calling/global
  # scope -- looks like a debugging convenience; confirm before removing.
  py_obj <<- py_obj
  # properties expose their metadata on `fget`
  name <- py_obj$`__name__` %error% py_obj$fget$`__name__`
  module <- py_obj$`__module__` %error% py_obj$fget$`__module__`
  docstring <- get_fixed_docstring(endpoint)
  type <- keras_class_type(py_obj)
  doc <- split_docstring_into_sections(docstring)
  # roxygen parts
  r_name <- make_r_name(endpoint, module)
  params <- parse_params_section(doc$arguments, treat_as_dots = vararg_paramater_names(py_obj))
  if(endpoint %in% c("keras.layers.Lambda", "keras.ops.vectorized_map")) {
    # keep param docs in sync with the `function` -> `f` rename in the wrappers
    names(params)[match("function", names(params))] <- "f"
  }
  if ((endpoint |> startsWith("keras.losses.") ||
       endpoint |> startsWith("keras.metrics.")) &&
      inherits(py_obj, "python.builtin.type"))
    local({
      # each metric and loss is provided as two handles: a function and a class
      # here, we fetch the matching function handle for the type handle
      # we make an export obj for this function handle, and then recursively
      # combine the docs and params from both.
      endpoint2 <- str_replace(endpoint, name, snakecase::to_snake_case(name))
      if(endpoint == endpoint2) {
        message("Function handle not found for: ", endpoint)
        # if `name` not same as endpoint tail, e.g., keras.losses.Reduction / ReductionV2
        return()
      }
      py_obj2 <- NULL
      tryCatch(
        py_obj2 <- py_eval(endpoint2),
        python.builtin.AttributeError = function(e) {
          if(!endpoint %in% known_metrics_without_function_handles)
            message(endpoint2, " not found")
        }
      )
      if(is.null(py_obj2)) return()
      ep2 <- mk_export(endpoint2, quiet = TRUE)
      # params <<- modifyList(params, ep2$params)
      # for params documented by both handles, keep the longer description
      for(nm in names(ep2$params)) {
        pdoc1 <- params[[nm]]; pdoc2 <- ep2$params[[nm]]
        if(str_squish(pdoc1 %||% "") == str_squish(pdoc2 %||% "")) next
        params[[nm]] <- c(pdoc1, pdoc2) %>% .[which.max(str_length(str_squish(.)))] #str_split_lines() |> unique() |> str_flatten_lines()
      }
      # we don't really want to concatenate the sessions for metrics, only the doc
      # parameters.
      for(section in setdiff(names(ep2$doc), "title")) {
        if (endpoint |> startsWith("keras.metrics.")) {
          if (section == "examples") next
          if (section == "returns") next
        }
        # prepend the function handle's section text above the class's
        doc[[section]] %<>% str_flatten_lines(ep2$doc[[section]], .)
      }
      doc <<- doc
      params <<- params
    })
  tags <- make_roxygen_tags(endpoint, py_obj, type)
  if(!is.null(doc$call_arguments) &&
     endpoint != "keras.layers.Bidirectional") {
    # convert `Call Arguments:` section to a md list.
    doc$call_arguments %<>%
      parse_params_section() %>%
      {glue("- `{names(.)}`: {unname(.)}") }
    # NOTE(review): the next call's result is discarded -- a `%>%` may be
    # missing above. Harmless today because downstream str_flatten()s the
    # section anyway, but confirm intent.
    str_flatten_lines()
    ## This kinda works, but fails if there is a nested list within,
    ## like w/ layer_additive_attention
    # doc$call_arguments %<>%
    #   parse_params_section() %>%
    #   {glue("  \\item{<names(.)>}{<unname(.)>}",
    #         .open = "<", .close = ">")} %>%
    #   str_flatten_lines() %>%
    #   sprintf("\\describe{\n%s\n}", .)
    # browser()
  }
  # r wrapper parts
  arg_transformers <- get_arg_transformers(endpoint, py_obj, params)
  r_fn <- make_r_fn(endpoint, py_obj, type, arg_transformers)
  local({
    # fill in undocumented params from the manual registry
    if (length(undocumented_params <-
               setdiff(names(formals(r_fn)),
                       unlist(strsplit(names(params) %||% character(), ","))))) {
      # if(endpoint == "keras.layers.Add") browser()
      maybe_add_params <- param_augment_registry %>%
        .[str_detect(endpoint, glob2rx(names(.))) |
            str_detect(paste0(module, ".", name), glob2rx(names(.)))] %>%
        unname() %>% unlist(recursive = FALSE)
      for(new_param_name in intersect(names(maybe_add_params), undocumented_params))
        params[[new_param_name]] <<- maybe_add_params[[new_param_name]]
    }
  })
  if(!quiet)
    local({
      if (length(undocumented_params <-
                 setdiff(names(formals(r_fn)),
                         unlist(strsplit(names(params) %||% character(), ","))))) {
        # NOTE(review): `x` is computed but never used or emitted -- this
        # reporting branch appears unfinished.
        x <- list2("{endpoint}" := map(set_names(undocumented_params), ~"see description"))
      }
    })
  local({
    # add placeholder NULL formals for params that are documented but absent
    # from the generated signature, so every @param matches an argument
    frmls <- formals(r_fn)
    params_documented_but_missing <- names(params) |> setdiff(names(frmls))
    if(length(params_documented_but_missing)) {
      frmls <- c(frmls, lapply(purrr::set_names(params_documented_but_missing), \(n) NULL))
      formals(r_fn) <<- frmls
    }
  })
  # finish
  roxygen <- dump_roxygen(doc, params, tags)
  dump <- dump_keras_export(roxygen, r_name, r_fn)
  rm(quiet)
  as.list(environment())
})
# carry the worker's attributes over onto the memoised wrapper
attributes(mk_export) <- attributes(.mk_export)
# Render one exported symbol as R source text: the roxygen block (each line
# prefixed with "#' ") followed by `<r_name> <- <deparsed r_fn>`, with
# trailing whitespace stripped from every line.
dump_keras_export <- function(roxygen, r_name, r_fn) {
  roxy_lines <- str_c("#' ", str_split_lines(roxygen))
  definition <- str_flatten_lines(str_c(r_name, " <-"), deparse(r_fn))
  out_lines <- str_split_lines(str_flatten_lines(roxy_lines, definition))
  str_flatten_lines(str_trim(out_lines, "right"))
}
# Render a roxygen comment body (without the leading "#' ") from parsed
# docstring sections, @param descriptions, and roxygen tag values.
#
# @param doc named list of docstring sections (title, description, ...).
# @param params named list/character of parameter descriptions.
# @param tags named list of roxygen tag values (@export, @family, ...).
# @return a single string of roxygen text, blank-line-compacted.
dump_roxygen <- function(doc, params, tags) {
  # render each param description as a "@param <name>\n<text>" chunk
  params <- params %>%
    lapply(str_trim) %>%
    lapply(glue::trim) %>%
    { glue("@param {names(.)}\n{list_simplify(., ptype = '')}") } %>%
    str_flatten("\n\n")
  doc$arguments <- NULL # already rendered as @param above
  # for(rd_sec in c("description", "details", "note")) {
  # if(grepl("allback that terminates traini", doc$title)) browser()
  for(rd_sec_name in c("description")) {
    rd_sec <- doc[[rd_sec_name]] %||% ""
    # if(nzchar(rd_sec))
    doc[[rd_sec_name]] <- str_flatten(c(glue("@{rd_sec_name}"), str_trim(rd_sec)), "\n")
    # else
    #   doc[[rd_sec_name]] <- NULL
  }
  # [^(] # is not followed by an opening parens (indicating a md link)
  # r"---((?<!`)\[([0-9,\-\[\]]+)\](?!\())---"
  # m <- "(?<!`)\\[([0-9,\\- >.\\[\\]]+)\\](?!\\()"
  # r <- "`\\[\\1\\]`"
  # doc$description %<>% gsub(m, r, ., perl = TRUE)
  if(!is.null(returns <- doc$returns)) {
    returns %<>% str_flatten_lines() %>%
      glue::trim() %>%
      str_flatten_lines("@returns", .)
    doc$returns <- NULL
  }
  # remaining sections become markdown blocks; implicit leading sections
  # (title/description/details) get no "# Heading"
  main <- lapply(names(doc), \(nm) {
    md_heading <- if (nm %in% c("description", "title", "details")) # "note"
      character() else
        snakecase::to_title_case(nm, prefix = "# ")
    # if(nm == "description" && doc$description == "@")
    str_flatten(c(md_heading, doc[[nm]]), "\n")
  })
  tags <- imap(tags, function(x, name) {
    glue("@{name} {x}") |> str_trim("right")
  }) |> str_flatten_lines()
  # family <- tags$family %>%
  #   {glue("@family {.}")} %>%
  #   str_flatten_lines()
  #
  # tags$family <- NULL
  # tags <- tags %>%
  #   {glue("@{names(.)} {list_simplify(.)}")} %>%
  #   str_flatten("\n")
  # tags <- c(family, tags) %>% str_flatten_lines()
  roxygen <- c(main, returns, params, tags) %>%
    str_flatten("\n\n") %>%
    str_split_lines() %>%
    str_trim("right") %>%
    str_flatten_lines()
  # compact roxygen: iteratively collapse runs of blank (roxygen or plain)
  # lines until the text stops shrinking
  roxygen %<>% {
    while (nchar(.) != nchar(. <- gsub(strrep("#'\n", 2), "#'\n", ., fixed = TRUE))) {}
    while (nchar(.) != nchar(. <- gsub(strrep("\n", 3), "\n\n", ., fixed = TRUE))) {}
    .
  }
  # drop an empty @description that directly precedes another tag
  roxygen %<>% str_replace(fixed("@description\n\n@"), "@")
  str_flatten_lines(roxygen) |>
    str_split_lines() |>
    str_trim("right") |>
    str_flatten_lines()
}
# Synthesize the layer_activation_selu() export.
#
# keras has no dedicated SELU layer class; we fake one by wrapping the
# keras.activations.selu docs around a thin call to layer_activation().
#
# @return an export record shaped like mk_export()'s result.
mk_layer_activation_selu <- function() {
  selu <- mk_export("keras.activations.selu")
  selu$name <- "selu"
  selu$module <- keras$layers$ReLU$`__module__` %>%
    gsub("relu", "selu", ., ignore.case = TRUE)
  selu$endpoint <- "keras.layers.~selu" # pretend
  selu$r_name <- "layer_activation_selu"
  selu$r_fn <- zap_srcref(function(object, ...) {
    layer_activation(object, activation = "selu", ...)
  })
  selu$tags$inheritDotParams <- "layer_activation"
  selu$tags$family <- "activation layers"
  selu$params$object <- "tensor or model"
  # dump_keras_export() takes (roxygen, r_name, r_fn): render the roxygen
  # text first via dump_roxygen(). (Previously all five values were passed
  # positionally, which errored with "unused arguments".)
  selu$dump <- with(selu, dump_keras_export(dump_roxygen(doc, params, tags),
                                            r_name, r_fn))
  selu$type <- "layer"
  selu
}
# Fetch an endpoint's python docstring, patched for known upstream quirks.
# `get_docstring` is an alias of `get_fixed_docstring`.
#
# @param endpoint string python endpoint.
# @return the (possibly patched) docstring as a single string.
get_docstring <-
get_fixed_docstring <- function(endpoint) {
  d <- inspect$getdoc(py_eval(endpoint)) %||% ""
  if(d == "") {
    # hand-written fallbacks for endpoints missing an upstream docstring
    return(switch(endpoint,
      "keras.random.dropout" = r"{random dropout
Randomly set some portion of values in the tensor to 0.}",
      ""))
  }
  d %<>% str_remove_non_ascii()
  # if(str_detect_non_ascii(d))
  #   d <- d %>% iconv(to = "ASCII") %>% gsub("??", "", ., fixed = TRUE)
  # patch helper: replace the first fixed occurrence in `d` in place
  replace <- function(old_txt, new_txt) {
    d <<- stringi::stri_replace_first_fixed(d, old_txt, new_txt)
  }
  # per-endpoint docstring patches (commafree switch: `,` means fallthrough)
  switch %(% { endpoint
    keras.ops.average_pool = `,`
    keras.ops.depthwise_conv = `,`
    keras.ops.separable_conv = `,`
    keras.ops.conv_transpose = replace(
      "`(batch_size,)` + inputs_spatial_shape ",
      "`(batch_size,) + inputs_spatial_shape ")
    keras.ops.scatter = replace(
      " updates: A tensor, the values to be set at `indices`.",
      " values: A tensor, the values to be set at `indices`."
    )
    # NOTE(review): the old and new strings below appear identical, making
    # this a no-op replacement -- the intended edit (likely an indentation
    # change) may have been lost. Confirm against upstream.
    keras.metrics.RecallAtPrecision = replace(
      " num_thresholds: (Optional) Defaults to 200. The number of thresholds",
      " num_thresholds: (Optional) Defaults to 200. The number of thresholds"
    )
    # keras.ops.einsum = replace(
    #   "operands: The operands to compute the Einstein sum of.",
    #   "*args: The operands to compute the Einstein sum of.")
    NULL
  }
  d <- trim(d)
  d %<>% str_split_lines()
  # ensure "Standalone usage:" examples appear under a "Usage:" heading
  if("Standalone usage:" %in% d &&
     !any(c("Usage:", "Examples:") %in% d)) {
    d %<>% replace_val("Standalone usage:", "Usage:\n\nStandalone usage:")
  }
  d %<>% str_flatten_lines()
  # d %>% str_split_lines() %>%
  # d %>% str_split_lines() %>%
  # d <- str_replace("^(\\s*)Standalone usage:$", "\\1Examples:\n\\1Standalone usage:")
  # replace("Standalone usage:\n", "Examples:\n\nStandalone usage:\n")
  d
}
# ---- man-src ----
# Pretty-print a python callable's signature. Accepts an endpoint string, a
# python object, or an inspect.Signature. Signatures with four or more
# comma-separated pieces are wrapped one piece per line.
format_py_signature <- function(x, name = NULL) {
  if (is_string(x)) # endpoint
    x <- py_eval(x)
  if (!inherits(x, "inspect.Signature"))
    x <- inspect$signature(x)
  sig <- py_str(x)
  pieces <- str_split_1(sig, ",")
  if (length(pieces) >= 4) {
    sig <- str_flatten(str_trim(pieces), ",\n  ")
    str_sub(sig, 1, 1) <- "(\n  "   # open paren on its own line
    str_sub(sig, -1, -1) <- "\n)"   # close paren on its own line
  }
  as_glue(str_flatten(c(name, sig)))
}
# format_py_signature(keras$layers$Dense)
# format_py_signature(keras$Model)
# Render the "0-upstream" snapshot for an endpoint: endpoint name, python
# signature, then docstring, in a fixed order for diffing.
format_man_src_0 <- function(endpoint) {
  sections <- c(
    endpoint,
    "__signature__",
    format_py_signature(endpoint),
    "__doc__",
    get_docstring(endpoint)
  )
  as_glue(str_flatten_lines(sections))
}
# Like format_man_src_0(), but with __doc__ before __signature__.
get_endpoint_tether_doc <- function(endpoint) {
  sections <- c(
    endpoint,
    "__doc__",
    get_docstring(endpoint),
    "__signature__",
    format_py_signature(endpoint)
  )
  as_glue(str_flatten_lines(sections))
}
# format_man_src_0("keras.Model")
#
# df %>%
# rowwise() %>%
# mutate(dump_man_src_0 = {
# dir_create(path("man-src", r_name))
# endpoint |>
# format_man_src_0() |>
# write_lines(path("man-src", r_name, "0-upstream.md"))
# })
# unlink("man-src/k_absolute", recursive = T)
# Run a git subcommand via system2t(), retrying on exit code 128 (usually
# .git/index.lock contention, e.g. from a concurrently-running IDE).
#
# @param ... arguments passed to `git`; the first element should start with
#   the subcommand (e.g. "restore R").
# @param retries how many attempts before giving up on a 128 exit.
# @param valid_exit_codes integer codes treated as success (e.g. `git diff`
#   exits 1 when the files differ).
# @return the exit code (one of `valid_exit_codes`); errors otherwise.
git <- function(..., retries = 4L, valid_exit_codes = 0L) {
  res <- 128L  # sentinel so the post-loop check is defined even if retries == 0
  for (i in seq_len(retries)) {
    res <- suppressWarnings(system2t("git", c(...)))
    if (identical(res, 128L)) {
      # probably .git/index.lock contention with vscode
      Sys.sleep(.1)
      next
    } else if (any(map_lgl(valid_exit_codes, identical, res))) {
      break
    } else {
      cat("res <- "); dput(res)
      stop("non-0 exit from git ", str_split1_on_first(..1, " ")[[1]])
    }
  }
  # previously a persistent 128 was returned silently as if it were success;
  # error instead once retries are exhausted
  if (identical(res, 128L))
    stop("git exited with 128 after ", retries, " attempts: ",
         str_split1_on_first(..1, " ")[[1]])
  res
}
# Discard local changes in the generated directories and unstage man-src.
git_restore <- function() {
  cmds <- c("restore R",
            "restore man",
            "restore man-src",
            "restore --staged man-src")
  for (cmd in cmds)
    git(cmd)
}
# Open a VS Code diff between an export's formatted and translated man-src files.
view_translation_diff <- function(r_name) {
  formatted <- sprintf("man-src/%s/1-formatted.md", r_name)
  translated <- sprintf("man-src/%s/2-translated.Rmd", r_name)
  system2("code", c("--diff", formatted, translated))
}
# Refresh man-src/*/ from current upstream keras docstrings, carrying manual
# translation edits forward via git three-way patching.
#
# Each man-src/<name>/ directory holds:
#   0-upstream.md     raw endpoint doc snapshot (first line = endpoint)
#   1-formatted.md    roxygen-formatted docs from mk_export()
#   2-translated.Rmd  manually edited translation
# The diff between 1 and 2 is saved as translate.patch, 1 and 2 are
# regenerated, and the patch is re-applied so manual edits survive upstream
# doc changes; failures are flagged as conflicts.
#
# @param directories man-src directories to process.
# @param write_out_only_formatted if TRUE, only refresh 1-formatted.md.
man_src_pull_upstream_updates <- function(directories = dir_ls("man-src/", type = "directory"),
                                          write_out_only_formatted = FALSE) {
  message(deparse1(sys.call()))
  # temporarily disable vscode git auto-refresh to reduce index.lock contention
  vscode_settings_file <- ".vscode/settings.json"
  if(file.exists(vscode_settings_file)) {
    vscode_settings <- og_vscode_settings <-
      jsonlite::read_json(vscode_settings_file)
    vscode_settings %<>% modifyList(list("git.autorefresh" = FALSE,
                                         "git.autofetch" = FALSE))
    jsonlite::write_json(vscode_settings, vscode_settings_file)
    withr::defer(jsonlite::write_json(og_vscode_settings, vscode_settings_file,
                                      pretty = TRUE))
    system("code -s", intern = TRUE) # force rereading of settings.json?
  }
  conflicts <- FALSE
  directories |>
    set_names(basename) |>
    lapply(\(dir) {
      dir %<>% as_fs_path()
      old_upstream <- read_lines(path(dir, "0-upstream.md"))
      endpoint <- old_upstream[1]
      py_obj <- NULL
      tryCatch({ py_obj <- py_eval(endpoint) },
               error = function(e) NULL,
               python.builtin.AttributeError = function(e) NULL)
      # endpoint no longer exists upstream: drop its man-src directory
      if(is.null(py_obj)) {
        unlink(dir, recursive = TRUE)
        return()
      }
      old_upstream <- str_flatten_lines(old_upstream)
      new_upstream <- format_man_src_0(endpoint)
      # if(new_upstream == old_upstream) return() # nothing to update
      export <- mk_export(endpoint)
      old_formatted <- read_file(dir/ "1-formatted.md") |> str_trim()
      new_formatted <- export$roxygen
      if(new_formatted == old_formatted) return() # nothing to update
      # write_lines(new_formatted); return()
      if(write_out_only_formatted) {
        write_lines(export$roxygen, dir/"1-formatted.md")
        return()
      }
      # snapshot the manual edits (1-formatted vs 2-translated) as a patch
      # before regenerating both files
      if (file.exists(dir / "2-translated.Rmd"))
        git(
          "diff --no-index",
          "-U1",
          # "-U0",
          # "--diff-algorithm=minimal",
          paste0("--output=", dir / "translate.patch"),
          dir / "1-formatted.md",
          dir / "2-translated.Rmd",
          valid_exit_codes = c(0L, 1L) # 1 == files differ
        )
      write_lines(new_upstream, dir/"0-upstream.md")
      write_lines(export$roxygen, dir/"1-formatted.md")
      write_lines(export$roxygen, dir/"2-translated.Rmd")
      if (!file.exists(dir / "translate.patch") ||
          !length(patch <- read_lines(dir / "translate.patch")))
        return()
      # return()
      # retarget the patch header at 2-translated.Rmd before re-applying
      patch[c(1L, 3L)] %<>% str_replace(fixed("/1-formatted.md"), "/2-translated.Rmd")
      # patch <- patch[-2] # drop index <hash>..<hash> line
      write_lines(patch, dir / "translate.patch")
      git("add", dir/"2-translated.Rmd")
      exit_code <-
        git("apply --3way --recount --allow-empty",
            "--unidiff-zero",
            "--ignore-whitespace",
            # "--whitespace=fix",
            dir/"translate.patch",
            valid_exit_codes = c(0L, 1L))
      if(exit_code)
        conflicts <<- TRUE
    }) |>
    invisible()
  if (conflicts) {
    continue <- FALSE
    message("Translation conflicts encountered, resolve before continuing.")
    if (interactive()) {
      continue <- askYesNo("Type 'yes' when ready to continue.")
    }
    # if(length(git("diff --cached --name-only --diff-filter=U")))
    # NOTE(review): stopping when `continue` is TRUE looks inverted -- the
    # prompt says "ready to continue", yet answering yes aborts (and
    # "continuting" is a typo). Confirm intent before changing.
    if (continue)
      stop("Merge Conflicts! Resolve before continuting.")
  }
}
# Knit every man-src/<name>/2-translated.Rmd via knit_man_src() (provided by
# tools/knit.R).
man_src_render_translated <- function(directories = dir_ls("man-src/", type = "directory")) {
  message(deparse1(sys.call()))
  include("tools/knit.R") # provides knit_man_src()
  dirs <- as_fs_path(directories)
  purrr::walk(dirs, \(d) knit_man_src(d / "2-translated.Rmd"))
}
# ---- vignettes ----
# Convert a keras "tutobook" (python tutorial script with `"""`-fenced prose)
# to R Markdown.
#
# @param path_to_tutobook path to read (also used to infer `outfile` and the
#   `tether` frontmatter field).
# @param outfile NA = write next to the input with an .Rmd extension;
#   a string = write to that path; NULL/FALSE = return the text, no file.
# @param tutobook_text optionally, the tutobook contents as a character
#   vector, bypassing the read.
# @return the outfile path (invisibly) when writing; the Rmd text otherwise;
#   NULL (invisibly) when conversion fails.
tutobook_to_rmd <- function(path_to_tutobook = NULL, outfile = NA, tutobook_text = NULL) {
  if (is.null(tutobook_text))
    tutobook_text <- readLines(path_to_tutobook)
  stopifnot(isTRUE(is.na(outfile)) || is_string(outfile) || is.null(outfile) || isFALSE(outfile))
  # NA == infer file path from tutobook path
  # string == outfile path
  # NULL/FALSE == return munged text, no outfile
  if (isTRUE(is.na(outfile))) {
    if (!missing(path_to_tutobook)) {
      outfile <- path_ext_set(path_to_tutobook, "Rmd")
    } else
      outfile <- FALSE
  }
  rmd_text <- try({ .tutobook_to_rmd(tutobook_text, tether = path_to_tutobook) }, silent = TRUE)
  if (inherits(rmd_text, "try-error")) {
    # report and skip rather than aborting the whole batch
    # (a leftover `browser()` call was removed here: it would hang
    # non-interactive batch conversions)
    message("converting failed: ", path_to_tutobook)
    warning(rmd_text)
    return()
  }
  # new_filename <- basename(path_to_tutobook) %>% fs::path_ext_set(".Rmd")
  if (!is_string(outfile))
    return(rmd_text)
  fs::dir_create(dirname(outfile))
  writeLines(rmd_text, outfile)
  invisible(outfile)
}
# Worker for tutobook_to_rmd(): convert tutobook text to a single Rmd string.
#
# Tutobooks alternate `"""`-fenced prose sections with code sections. The
# first prose section is "key: value" frontmatter, rewritten here into Rmd
# yaml frontmatter (adding output/knit/tether fields). Code sections get
# fenced code blocks (defaulting to ```python).
#
# @param tutobook character vector of tutobook lines.
# @param tether optional tether path recorded in the frontmatter.
# @return a single Rmd string.
.tutobook_to_rmd <- function(tutobook, tether = NULL) {
  stopifnot(is.character(tutobook))
  df <- tibble(
    line = str_split_lines(tutobook)
  )
  df %>%
    mutate(
      is_delim = startsWith(line, '"""'),
      # number sections by counting delimiters; section parity determines
      # whether a section is prose or code
      section_id = cumsum(is_delim),
      is_code = !(section_id %% 2) & !is_delim,
      delim_header = if_else(is_delim, str_replace(line, '^"""', ""), NA)) %>%
    group_by(section_id) %>%
    # carry the delimiter's header text (e.g. a language tag) down its section
    mutate(section_type = zoo::na.locf0(delim_header)) %>%
    ungroup() %>%
    filter(!is_delim) %>%
    group_by(section_id, is_code, section_type) %>%
    dplyr::group_map(\(.x, .grp) {
      if(.grp$section_id == 1) {
        # first section: "key: value" frontmatter -> yaml
        x <- str_split_fixed(.x$line, ": ", 2)
        # NOTE(review): chaining `%<>%` twice on one line is unusual --
        # verify both transformations actually stick before relying on it
        x[,1] %<>% snakecase::to_snake_case() %<>% str_replace_all("_", "-")
        x <- rlang::set_names(nm = x[,1], as.list(x[,2]))
        x$output <- "rmarkdown::html_vignette"
        x$knit <- '({source(here::here("tools/knit.R")); knit_vignette})'
        x$tether <- x$tether %||% tether
        # # x$repo <- https://github.com/rstudio/keras
        frontmatter <- yaml::as.yaml(x) |> str_trim("right")
        out <- str_flatten_lines("---", frontmatter, "---")
        return(out)
      }
      out <- .x$line |>
        str_trim("right") |>
        str_flatten_lines() |>
        str_trim()
      if(out == "")
        return("")
      if(.grp$is_code) {
        # fence code sections, defaulting the language to python
        type <- .grp$section_type
        if(is.na(type) || type == "")
          type <- "python"
        out <- str_flatten_lines(sprintf("```%s", type), out, "```")
      } else {
        out <- str_flatten_and_compact_lines(out)
      }
      out
    }) %>%
    keep(., . != "") %>%
    str_flatten("\n\n")
}
# ---- misc utils ----
# Convert a dotted endpoint like "keras.ops.foo" into the R expression
# `keras`$`ops`$`foo` (backticked so reserved words survive str2lang()).
endpoint_to_expr <- function(endpoint) {
  parts <- str_split_1(endpoint, fixed("."))
  parts |>
    glue::backtick() |>
    str_flatten("$") |>
    str2lang()
}
# Apply `fn` to every roxygen block of the package at ".", with srcrefs kept
# so blocks can later be mapped back to file/line positions.
walk_roxy_blocks <- function(fn) {
  env <- withr::with_options(
    list(keep.source.pkgs = TRUE, keep.parse.data.pkgs = TRUE),
    roxygen2::env_package(".")
  )
  blocks <- roxygen2::parse_package(env = env)
  walk(blocks, fn, .progress = TRUE)
}
# Compute the 1-based source line range covered by a roxygen2 block.
#
# @param block a roxygen2 block object.
# @param section which spans to include: "roxygen" (the comment lines),
#   "code" (the following object definition), or both (default).
# @return integer c(start, end), or NULL when no source location was found.
get_block_line_range <- function(block, section = c("roxygen", "code")) {
  ## Returns the line number pair c(<start_of_roxygen_block>, <end_of_fn_body>)
  # this only contains the start line for each tag
  roxy_line_range <- code_line_range <- NULL
  if("roxygen" %in% section) {
    roxy_line_range <- range(unlist(lapply(block$tags, \(tag) {
      if (is.null(tag$raw)) NULL # skip '<generated>' tags that point to sourcecode
      else tag$line
    })))
  }
  if ("code" %in% section) {
    # -1L is a "no srcref available" sentinel; it propagates through range()
    # so the final check below returns NULL
    call_start_line <- getSrcLocation(block$object$value, "parse") %||% -1L
    call_end_line <- getSrcLocation(block$object$value, "parse", FALSE) %||% -1L
    call_body_start_line <- tryCatch(
      getSrcLocation(body(block$object$value), "line")[1],
      warning = function(w) NULL,
      error = function(e) NULL
    ) %||% call_end_line
    code_line_range <- range(c(call_start_line, call_end_line))
    # signature_line_range <- c(call_start_line, call_body_start_line)
  }
  line_range <- range(c(roxy_line_range, code_line_range))
  if(any(line_range == -1L)) NULL
  else as.integer(line_range)
}
# For each roxygen block in the package, pass its comment lines to `fn`
# (along with the block object); when `fn` returns a character vector,
# replace the block's lines in the source file, then write all edits out.
modify_roxy_block_lines <- function(fn) {
  env <- withr::with_options(
    list(keep.source.pkgs = TRUE, keep.parse.data.pkgs = TRUE),
    roxygen2::env_package(".")
  )
  blocks <- roxygen2::parse_package(env = env)
  files <- file_lines_manager()
  for (block in blocks) {
    rng <- get_block_line_range(block, section = "roxygen")
    current <- files$get_lines(block$file, rng)
    replacement <- fn(current, block = block)
    if (is.character(replacement))
      files$set_lines(block$file, rng, replacement)
  }
  files$write_out()
}
# Append the contents of each existing file in `...` to `out`, each preceded
# by `sep` written as its own line. Missing paths are skipped silently.
# Returns `out` invisibly.
combine_files <- function(out, ..., sep = "\n") {
  con <- file(out, open = "at")
  on.exit(close(con))
  for (p in list(...)) {
    if (!file.exists(p)) next
    writeLines(c(sep, readLines(p)), con = con)
  }
  invisible(out)
}
# Delete files (found via list.files(...)) whose contents are entirely
# whitespace/blank lines, reporting each deletion via message().
delete_empty_files <- function(...) {
  for (path in list.files(..., full.names = TRUE)) {
    content <- trimws(readLines(path))
    if (any(nzchar(content))) next
    message("deleting: ", path)
    unlink(path)
  }
}
# Remove empty directories under each of the given paths, deepest-first so
# directories that only contained (now-removed) empty directories are
# removed too.
#
# @param path,... root directories to sweep (default ".").
delete_empty_directories <- function(path = ".", ...) {
  for (p in c(path, ...)) {
    for (d in rev(list.dirs(p, recursive = TRUE))) {
      if (!length(list.files(d, all.files = TRUE, no.. = TRUE))) { # empty directory
        message("Unlinking: ", d)
        # unlink() refuses to delete directories unless recursive = TRUE
        # (see ?unlink), so the previous bare unlink(d) was a silent no-op
        unlink(d, recursive = TRUE)
      }
    }
  }
}
# Lazily read, edit, and write back source files one line-range at a time.
#
# Returns an environment exposing:
#   get_lines(file, line_range)       read (and cache) lines [start, end]
#   set_lines(file, line_range, new)  replace the range: every line in the
#                                     range is marked NA (dropped on write)
#                                     and the first line holds the new text
#   write_out()                       write back each file marked modified
file_lines_manager <- function() {
  .files <- new.env(parent = emptyenv())  # cache: file path -> lines
  get_lines <- function(file, line_range) {
    file_lines <- .files[[file]] %||% { .files[[file]] <- readLines(file) }
    # check_line_range(line_range, max = length(file_lines))
    file_lines[line_range[1]:line_range[2]]
  }
  set_lines <- function(file, line_range, new_lines) {
    file_lines <- .files[[file]] %||% { .files[[file]] <- readLines(file) }
    # check_line_range(line_range, max = length(file_lines))
    # NA marks lines for deletion; the replacement text (possibly multi-line)
    # is stored in the first slot of the range
    file_lines[line_range[1]:line_range[2]] <- NA
    file_lines[line_range[1]] <- str_flatten_and_compact_lines(new_lines)
    attr(file_lines, "modified") <- TRUE
    .files[[file]] <- file_lines
  }
  write_out <- function() {
    for(file in names(.files)) {
      # message("Writing out: ", file)
      file_lines <- .files[[file]]
      if(!isTRUE(attr(file_lines, "modified", TRUE))) next
      # if (str_flatten_and_compact_lines(readLines(file)) ==
      #     str_flatten_and_compact_lines(file_lines)) next
      cli_alert_warning("Updating source file: {.file {file}}")
      writeLines(file_lines[!is.na(file_lines)], file)
      rm(list = file, envir = .files)
    }
  }
  environment()
}
# One-shot migration: move each multi-line roxygen block out of the R/
# sources into man-src/<name>/<name>.Rmd, leaving behind an
# `@eval readLines(<rendered md>)` stub in the source file.
move_roxy_blocks_to_mansrc <- function() {
  .files <- new.env(parent = emptyenv())
  # local (pre-file_lines_manager) copies of the line-editing helpers
  get_file_lines <- function(file, line_range) {
    file_lines <- .files[[file]] %||% { .files[[file]] <- readLines(file) }
    # check_line_range(line_range, max = length(file_lines))
    file_lines[line_range[1]:line_range[2]]
  }
  set_file_lines <- function(file, line_range, new_lines) {
    file_lines <- .files[[file]] %||% { .files[[file]] <- readLines(file) }
    # check_line_range(line_range, max = length(file_lines))
    file_lines[line_range[1]:line_range[2]] <- NA
    file_lines[line_range[1]] <- str_flatten_and_compact_lines(new_lines)
    .files[[file]] <- file_lines
  }
  write_out_updated_files <- function() {
    for(file in names(.files)) {
      cli_alert_warning("Updating source file: {.file {file}}")
      # message("Writing out: ", file)
      file_lines <- .files[[file]]
      writeLines(file_lines[!is.na(file_lines)], file)
    }
  }
  # local copy of get_block_line_range() (shadows the top-level definition)
  get_block_line_range <- function(block, section = c("roxygen", "code")) {
    ## Returns the line number pair c(<start_of_roxygen_block>, <end_of_fn_body>)
    # this only contains the start line for each tag
    roxy_line_range <- code_line_range <- NULL
    if("roxygen" %in% section) {
      roxy_line_range <- range(unlist(lapply(block$tags, \(tag) {
        if (is.null(tag$raw)) NULL # skip '<generated>' tags that point to sourcecode
        else tag$line
      })))
    }
    if ("code" %in% section) {
      call_start_line <- getSrcLocation(block$object$value, "parse") %||% -1L
      call_end_line <- getSrcLocation(block$object$value, "parse", FALSE) %||% -1L
      call_body_start_line <- tryCatch(
        getSrcLocation(body(block$object$value), "line")[1],
        warning = function(w) NULL,
        error = function(e) NULL
      ) %||% call_end_line
      code_line_range <- range(c(call_start_line, call_end_line))
      # signature_line_range <- c(call_start_line, call_body_start_line)
    }
    line_range <- range(c(roxy_line_range, code_line_range))
    if(any(line_range == -1L)) NULL
    else as.integer(line_range)
  }
  walk_roxy_blocks(function(block) {
    # `%||% return()` exits this callback early when the block has no range
    roxy_line_range <- get_block_line_range(block, "roxygen") %||% return()
    lines <- get_file_lines(block$file, line_range = roxy_line_range)
    if(length(lines) <= 4) return()  # leave short blocks in place
    # or @noRd or @describeIn or ...
    name <- doctether:::get_block_name(block)
    new_file <- fs::path("man-src", name, name, ext = "Rmd")
    new_rendered_file <- new_file |> fs::path_ext_set("md")
    if(file.exists(new_file))
      return(warning("File seemingly already exists, skipping, ", new_file))
    # replace the block in the source file with a pointer + @eval stub
    set_file_lines(block$file, roxy_line_range, c(
      glue('# "{new_file}"'),
      glue("#' @eval readLines('{new_rendered_file}')")
    ))
    fs::dir_create(dirname(new_file))
    # write the block (minus the "#' " prefixes) as an Rmd with knit
    # frontmatter
    writeLines(c(
      "---",
      'knit: ({source(here::here("tools/knit.R")); knit_man_src})',
      "---",
      "",
      sub("^#' ?", "", lines)
    ), new_file)
  })
  write_out_updated_files()
}
# TODO(stub): inverse of move_roxy_blocks_to_mansrc(); not yet implemented.
move_mansrc_to_roxy_blocks <- function() {
}
str_prefix <- function(x, prefix, ...) str_c(prefix, x, ...)
# Print python help() output for `object` as a roxygen-commented block,
# emitted via cat_cb() under a `roxify(py_help(<object>))` label.
py_help_roxified <- function(object) {
  # capture the unevaluated call for the label before `object` is replaced
  obj_nm <- deparse1(call("roxify", call("py_help", substitute(object))))
  if(is_string(object))
    object <- py_eval(object)
  help <- py_capture_output(import_builtins()$help(object), type = "stdout")
  help <- help |>
    str_split_lines() |>
    str_replace("^ \\| ", "    ") |>
    str_prefix("#' ") |>
    str_trim("right") |>
    str_flatten_lines()
  cat_cb(help, obj_nm = obj_nm)
  invisible(help)
}
# One-off maintenance snippet (guarded by if(FALSE), never run
# automatically): comment out tensorflow.org api_docs links in R/ sources.
# Kept for reference.
if(FALSE) {
  list.files("R", full.names = TRUE) %>%
    walk(\(f) {
      x <- readLines(f)
      i <- startsWith(x, "#' + <https://www.tensorflow.org/api_docs/python/tf/keras/")
      x[i] %<>% sub("#'", "# ", .)
      # x <- sub(
      #   "^#' (+ <https://www.tensorflow.org/api_docs/python/tf/keras/.+>)$",
      #   "# \\1",
      #   xx
      # )
      writeLines(x, f)
    })
}
# rename2(list(a = "b", a = z))
# rx <- regex(comments = TRUE, pattern = r"--(
# `? # Optional backtick at the start
# ( # Open capture group for main pattern
# \[ # Match an opening square bracket
# [0-9,\[\]\ -><.]+ # Capture numbers, brackets, and specific symbols
# \] # Match a closing square bracket
# ) # Close capture group
# `? # Optional backtick at the end
# (?!\() # Lookahead to ensure no opening parenthesis follows
# )--")
# rep <- r"--(`\1`)--"
# doc$description %<>% str_replace_all(rx, rep)
#
# x <- "e.g. `[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]`"
# rx <- regex(comments = TRUE, r"--(
# `?( # open capture group, without backtick
# \[ # opening bracket, maybe starts with a backtick
# [0-9,\[\]\ -><.]+ # anything that looks like a number range
# \]
# ) # close capture group
# `? # closing bracket maybe w backtick
# (?!\() # is not followed by an opening parens indicating a md link
# )--")
# rep <- "`\\1`"
# str_replace_all(x, rx, rep)
# One-off scratch snippet (guarded by if(FALSE), never run automatically):
# open VS Code diffs comparing checked-in r_wrapper.R files against
# gpt-generated candidates under tools/raw/.
if(FALSE) {
  # for(dir in list.dirs("tools/raw"))
  list.dirs("tools/raw") %>%
    lapply(function(d) {
      withr::local_dir(d)
      if(any(startsWith(dir(), "gpt")))
        system(paste("code -d r_wrapper.R", dir(pattern = "^gpt-4")[1]))
    }) %>%
    invisible()
}
# NOTE(review): the two lines below are scraped website boilerplate (not R
# code) that broke parsing; preserved here as comments.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.