Nothing
#' Writes a list of tensors to the safetensors format
#'
#' @param tensors A named list of tensors. Currently only torch tensors are
#' supported. Every element must have a unique, non-empty name; the name
#' `__metadata__` is reserved for the header metadata and cannot be used.
#' @param path The path to save the tensors to. It can also be a binary
#' connection, e.g. one created with `file()`.
#' @param ... Currently unused.
#' @param metadata An optional named list of scalar character strings that is
#' stored in the file header under `__metadata__`, e.g. to add a description
#' of the weights.
#'
#' @examples
#' if (rlang::is_installed("torch") && torch::torch_is_installed()) {
#' tensors <- list(x = torch::torch_randn(10, 10))
#' temp <- tempfile()
#' safe_save_file(tensors, temp)
#' safe_load_file(temp, framework = "torch")
#'
#' ser <- safe_serialize(tensors)
#' }
#'
#' @returns `safe_save_file()` returns `path` invisibly; `safe_serialize()`
#' returns a raw vector.
#'
#' @export
safe_save_file <- function(tensors, path, ..., metadata = NULL) {
  tensor_names <- names(tensors)
  # The header describes each tensor by name while the data buffers are
  # written positionally, so every tensor needs a usable, unique name.
  # Validation happens before `path` is opened so no partial file is created.
  if (length(tensors) > 0L &&
    (is.null(tensor_names) || anyNA(tensor_names) ||
      any(tensor_names == ""))) {
    cli::cli_abort("All elements of {.arg tensors} must be named.")
  }
  if (any(duplicated(tensor_names))) {
    cli::cli_abort("Duplicated names are not allowed in {.arg tensors}")
  }
  # `__metadata__` is the header key used for the optional metadata.
  if ("__metadata__" %in% tensor_names) {
    cli::cli_abort(
      "{.code __metadata__} is reserved and can't be used as a name in {.arg tensors}."
    )
  }
  if (is.character(path)) {
    con <- file(path, open = "wb")
    on.exit(close(con), add = TRUE)
  } else {
    # Caller-supplied connection: the caller remains responsible for closing it.
    con <- path
  }
  write_safe(tensors, metadata, con)
  invisible(path)
}
#' @describeIn safe_save_file Serializes the tensors and returns a raw vector.
#' @export
safe_serialize <- function(tensors, ..., metadata = NULL) {
  # Serialize into an in-memory raw connection, then return the bytes it
  # accumulated. `...` is accepted but not forwarded.
  buffer_con <- rawConnection(raw(), open = "wb")
  on.exit(close(buffer_con), add = TRUE)
  safe_save_file(tensors, buffer_con, metadata = metadata)
  rawConnectionValue(buffer_con)
}
# Write `tensors` (plus optional `metadata`) to the binary connection `con`.
# Layout written here: an 8-byte header length, the JSON header produced from
# `make_meta()`, then each tensor's raw buffer in list order (the same order
# `make_meta()` uses to assign `data_offsets`).
write_safe <- function(tensors, metadata, con) {
  header <- charToRaw(
    jsonlite::toJSON(make_meta(tensors, metadata), auto_unbox = TRUE)
  )
  # The header length is a 64-bit little-endian integer in the safetensors
  # format; force little-endian so output does not depend on the platform.
  writeBin(length(header), con = con, size = 8L, endian = "little")
  writeBin(header, con = con)
  for (tensor in tensors) {
    writeBin(safe_tensor_buffer(tensor), con = con)
  }
  invisible(NULL)
}
# Build the header list: one entry per tensor name holding that tensor's
# metadata plus its `data_offsets` (start, end) byte range, and an optional
# `__metadata__` entry validated by `validate_metadata()`.
make_meta <- function(tensors, metadata) {
  meta_ <- structure(
    vector(mode = "list", length = length(tensors)),
    names = names(tensors)
  )
  if (!is.null(metadata)) {
    meta_[["__metadata__"]] <- validate_metadata(metadata)
  }
  pos <- 0L
  for (nm in names(tensors)) {
    meta <- safe_tensor_meta(tensors[[nm]])
    # Accumulate in double precision: integer addition overflows to NA once
    # the running offset passes .Machine$integer.max (~2 GiB).
    end <- as.numeric(pos) + as.numeric(size_from_meta(meta))
    if (is.na(end)) {
      cli::cli_abort("Could not compute the byte size of tensor {.val {nm}}.")
    }
    # Keep integers while they fit, so small files are unaffected.
    if (end <= .Machine$integer.max) {
      end <- as.integer(end)
    }
    meta$data_offsets <- c(pos, end)
    pos <- end
    meta_[[nm]] <- meta
  }
  # An unnamed empty list serializes to a JSON array; a named empty list
  # serializes to an object, which is what the header must be.
  if (length(meta_) == 0L) {
    meta_ <- structure(list(), names = character(0))
  }
  meta_
}
#' @title Get raw buffer from a tensor
#' @description
#' Convert a tensor object to a raw buffer in the format expected by
#' safetensors. The buffer is written verbatim after the file header, so its
#' length should match the byte size implied by the tensor's metadata.
#' @param x (any)\cr
#' Tensor object.
#' @returns (`raw`)
#' @export
safe_tensor_buffer <- function(x) {
  UseMethod("safe_tensor_buffer")
}
#' @title Get metadata from a tensor
#' @description
#' Get the metadata from a tensor. Methods should return a list with at least
#' a `shape` element (the tensor dimensions) and a `dtype` element (a
#' safetensors dtype string such as `"F32"`); `data_offsets` is added by the
#' writer.
#' @param x (any)\cr
#' Tensor object.
#' @returns (`list`)
#' @export
safe_tensor_meta <- function(x) {
  UseMethod("safe_tensor_meta")
}
# Number of bytes occupied by a tensor described by `meta` (a list with
# `shape` and `dtype`). Returns an integer when the size fits in R's integer
# range and a double otherwise, instead of overflowing to NA.
size_from_meta <- function(meta) {
  # Bytes per element for each dtype of the safetensors format.
  el_sizes <- c(
    BOOL = 1L, U8 = 1L, I8 = 1L, F8_E5M2 = 1L, F8_E4M3 = 1L,
    I16 = 2L, U16 = 2L, F16 = 2L, BF16 = 2L,
    I32 = 4L, U32 = 4L, F32 = 4L,
    I64 = 8L, U64 = 8L, F64 = 8L
  )
  dtype <- meta$dtype
  if (!is.character(dtype) || length(dtype) != 1L ||
    !(dtype %in% names(el_sizes))) {
    cli::cli_abort("Unsupported dtype {.val {meta$dtype}}")
  }
  # prod() of an empty shape is 1, i.e. a scalar tensor has one element.
  numel <- prod(as.numeric(meta$shape))
  n_bytes <- numel * el_sizes[[dtype]]
  if (n_bytes <= .Machine$integer.max) {
    as.integer(n_bytes)
  } else {
    n_bytes
  }
}
# Check that user-supplied `metadata` is a named list whose values are all
# single, non-missing strings; returns `x` unchanged on success.
validate_metadata <- function(x) {
  if (!rlang::is_list(x)) {
    cli::cli_abort("{.arg metadata} must be a list.")
  }
  if (!rlang::is_named(x)) {
    cli::cli_abort("{.arg metadata} must be a named list.")
  }
  # A plain loop: this check exists only for its side effect of aborting.
  # NA is rejected because it is not a usable string value.
  for (item in x) {
    if (!rlang::is_scalar_character(item) || is.na(item)) {
      cli::cli_abort(
        "{.arg metadata} must be a named list of scalar characters."
      )
    }
  }
  x
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.