# Nothing
#' Writes a list of tensors to the safetensors format
#'
#' @param tensors A named list of tensors. Currently only torch tensors are
#'   supported. Every element must have a unique, non-empty name, because the
#'   names are used as the keys of the file header.
#' @param path The path to save the tensors to. It can also be a binary
#'   connection, as e.g. created with `file()`. A connection supplied by the
#'   caller is not closed by this function.
#' @param ... Currently unused.
#' @param metadata An optional named list of scalar character strings that is
#'   added to the file header, e.g. to attach an additional description to the
#'   weights.
#'
#' @examples
#' if (rlang::is_installed("torch") && torch::torch_is_installed()) {
#'   tensors <- list(x = torch::torch_randn(10, 10))
#'   temp <- tempfile()
#'   safe_save_file(tensors, temp)
#'   safe_load_file(temp)
#'
#'   ser <- safe_serialize(tensors)
#' }
#'
#' @returns The path invisibly or a raw vector.
#'
#' @export
safe_save_file <- function(tensors, path, ..., metadata = NULL) {
  # Validate names before touching `path`, so that an invalid call never
  # creates or truncates the target file.
  tensor_names <- names(tensors)
  if (length(tensors) > 0L &&
        (is.null(tensor_names) || anyNA(tensor_names) ||
           !all(nzchar(tensor_names)))) {
    cli::cli_abort("{.arg tensors} must be a named list.")
  }
  if (any(duplicated(tensor_names))) {
    cli::cli_abort("Duplicated names are not allowed in {.arg tensors}")
  }

  if (is.character(path)) {
    # We opened the connection, so we are responsible for closing it.
    con <- file(path, open = "wb")
    on.exit(close(con), add = TRUE)
  } else {
    con <- path
  }

  write_safe(tensors, metadata, con)
  invisible(path)
}
#' @describeIn safe_save_file Serializes the tensors and returns a raw vector.
#' @export
safe_serialize <- function(tensors, ..., metadata = NULL) {
  # In-memory binary sink that collects the serialized bytes; it is closed
  # when the function exits.
  buffer_con <- rawConnection(raw(), open = "wb")
  on.exit(close(buffer_con), add = TRUE)

  safe_save_file(tensors = tensors, path = buffer_con, metadata = metadata)

  serialized <- rawConnectionValue(buffer_con)
  serialized
}
# Writes the safetensors payload to the binary connection `con`: first the
# header size, then the JSON header built by `make_meta()`, then the raw bytes
# of every tensor in list order.
write_safe <- function(tensors, metadata, con) {
  meta <- make_meta(tensors, metadata)
  meta_raw <- charToRaw(jsonlite::toJSON(meta, auto_unbox = TRUE))

  # The header size is stored as a little-endian 64-bit integer. Pin the byte
  # order explicitly instead of relying on the platform default, which would
  # produce an unreadable file on big-endian machines.
  writeBin(length(meta_raw), con = con, size = 8L, endian = "little")
  writeBin(meta_raw, con = con)

  for (tensor in tensors) {
    writeBin(tensor_buffer(tensor), con = con)
  }
  invisible(NULL)
}
# Builds the header structure: one entry per tensor (shape, dtype and the
# [begin, end) byte offsets of its data) plus an optional `__metadata__`
# entry holding the validated user metadata.
make_meta <- function(tensors, metadata) {
  header <- structure(
    vector(mode = "list", length = length(tensors)),
    names = names(tensors)
  )

  if (!is.null(metadata)) {
    header[["__metadata__"]] <- validate_metadata(metadata)
  }

  # Accumulate offsets as doubles: with integer arithmetic the running offset
  # overflows to NA as soon as the total payload exceeds 2^31 - 1 bytes, even
  # when every individual tensor is small enough.
  offset <- 0
  for (nm in names(tensors)) {
    entry <- tensor_meta(tensors[[nm]])
    end <- offset + as.numeric(size_from_meta(entry))
    entry$data_offsets <- c(offset, end)
    offset <- end
    header[[nm]] <- entry
  }

  header
}
# S3 generic: returns the raw byte buffer holding the data of tensor `x`.
tensor_buffer <- function(x) UseMethod("tensor_buffer")
# torch method: moves the tensor to the CPU via `$cpu()` and hands it to
# `torch::buffer_from_torch_tensor()`.
tensor_buffer.torch_tensor <- function(x) {
  cpu_tensor <- x$cpu()
  torch::buffer_from_torch_tensor(cpu_tensor)
}
# S3 generic: returns the header entry (shape and dtype) describing tensor `x`.
tensor_meta <- function(x) UseMethod("tensor_meta")
# torch method: describes a tensor by its shape and its safetensors dtype code.
tensor_meta.torch_tensor <- function(x) {
  # The shape is kept as a list (not an atomic vector) to avoid
  # simplification when the header is serialized.
  dims <- as.list(x$shape)
  safe_dtype <- torch_dtype_to_safe(x$dtype)
  list(shape = dims, dtype = safe_dtype)
}
# Maps a torch dtype to the dtype code used in the safetensors header.
# Aborts for any dtype that has no mapping below.
torch_dtype_to_safe <- function(x) {
  # Floating point types
  if (x == torch::torch_float()) return("F32")
  if (x == torch::torch_float16()) return("F16")
  if (x == torch::torch_float64()) return("F64")

  # Boolean
  if (x == torch::torch_bool()) return("BOOL")

  # Integer types
  if (x == torch::torch_uint8()) return("U8")
  if (x == torch::torch_int8()) return("I8")
  if (x == torch::torch_int16()) return("I16")
  if (x == torch::torch_int32()) return("I32")
  if (x == torch::torch_int64()) return("I64")

  cli::cli_abort("Unsupported data type {.val {x}}")
}
# Computes the number of bytes occupied by a tensor from its header entry.
#
# `meta` is a list with `shape` (list of dimension sizes; an empty shape
# counts as one element) and `dtype` (safetensors dtype code). Returns an
# integer when the size fits in one, and a double otherwise. Aborts on an
# unknown dtype code.
size_from_meta <- function(meta) {
  numel <- prod(as.numeric(meta$shape))

  el_size <- switch(
    meta$dtype,
    BOOL = ,
    U8 = ,
    I8 = 1L,
    F16 = ,
    I16 = 2L,
    F32 = ,
    I32 = 4L,
    F64 = ,
    I64 = 8L,
    cli::cli_abort("Unsupported dtype {.val {meta$dtype}}")
  )

  n_bytes <- numel * el_size
  # `as.integer()` turns anything above `.Machine$integer.max` into NA (with
  # only a warning), which would corrupt the offsets in the header. Keep large
  # sizes as doubles, which represent such byte counts exactly.
  if (!is.na(n_bytes) && n_bytes > .Machine$integer.max) {
    n_bytes
  } else {
    as.integer(n_bytes)
  }
}
# Checks that user metadata is a named list whose elements are all scalar
# character strings. Returns `x` unchanged when valid, aborts otherwise.
validate_metadata <- function(x) {
  if (!rlang::is_list(x)) {
    cli::cli_abort("{.arg metadata} must be a list.")
  }
  if (!rlang::is_named(x)) {
    cli::cli_abort("{.arg metadata} must be a named list.")
  }

  # Run the per-element check purely for its side effect (the abort).
  check_entry <- function(item) {
    if (!rlang::is_scalar_character(item)) {
      cli::cli_abort("{.arg metadata} must be a named list of scalar characters.")
    }
  }
  lapply(x, check_entry)

  x
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.