R/utils.R

Defines functions as_integer_list as_tensors as_float_tensor as_tf_float as_axis as_nullable_integer normalize_shape

# from keras
# Helper function to coerce shape arguments to tuple
#
# Converts an R shape specification into a Python tuple via reticulate.
#
# @param shape NULL, a numeric vector, a list of numbers/NULLs, an all-NA
#   logical vector, or any other value (passed straight to
#   `reticulate::tuple()` unchanged).
# @return NULL when `shape` is NULL; otherwise the result of
#   `reticulate::tuple()`. For list/numeric/all-NA input, every entry is
#   coerced with `as.integer()` except NULL and scalar NA entries, which
#   become NULL (an unknown dimension).
normalize_shape <- function(shape) {
  # reflect NULL back
  if (is.null(shape))
    return(shape)

  # c(NA) / c(NA, NA) mean "all dimensions unknown"; is.numeric() is FALSE for
  # a logical vector, so such shapes have to be accepted explicitly
  all_na <- is.logical(shape) && length(shape) > 0L && all(is.na(shape))

  # if it's a list, a numeric vector or all-NA then convert to integer
  if (is.list(shape) || is.numeric(shape) || all_na) {
    shape <- lapply(shape, function(value) {
      # as.integer(NA) would give NA_integer_ rather than NULL, so treat a
      # scalar NA like NULL (unknown dimension)
      is_missing <- is.null(value) ||
        (is.atomic(value) && length(value) == 1L && is.na(value))
      if (is_missing)
        NULL
      else
        as.integer(value)
    })
  }

  # coerce to tuple so it's iterable
  reticulate::tuple(shape)
}

# from keras
# Coerce `x` to integer with as.integer(), letting NULL pass through as NULL.
#
# @param x NULL or anything as.integer() accepts.
# @return NULL when `x` is NULL, otherwise as.integer(x).
as_nullable_integer <- function(x) {
  if (!is.null(x)) {
    return(as.integer(x))
  }
  NULL
}

# from keras
# Convert a 1-based R axis (or several axes) to a 0-based Python axis.
#
# Positive axes are shifted down by one. Negative axes already count from the
# end, Python style, so they are passed through unchanged: -1 is the last
# axis and -2 the second-to-last. 0 is still mapped to -1L, as before.
# NOTE(review): 0 is not a valid 1-based axis; consider raising an error.
#
# @param axis NULL, a single number, or a vector/list of numbers.
# @return NULL for NULL input; an integer scalar for a scalar input; for
#   several axes, the result of sapply(): an integer vector when every
#   element converts to a scalar, otherwise a list (e.g. a list with NULLs).
as_axis <- function(axis) {
  if (is.null(axis))
    return(NULL)

  if (length(axis) > 1)
    return(sapply(axis, as_axis))

  axis <- as.integer(axis)

  # previously only -1 was passed through, so e.g. -2 was wrongly turned into
  # -3; every negative axis is already a valid Python index
  if (axis < 0L)
    axis
  else
    axis - 1L
}

# Cast `x` to float32 with tf$cast().
#
# @param x A value tf$cast() accepts (presumably a tensor or an R
#   number/array; not checked here).
# @return The value returned by tf$cast(x, tf$float32).
as_tf_float <- function(x) {
  target_dtype <- tf$float32
  tf$cast(x, target_dtype)
}

# Cast `x` to float32 using as_tf_float().
#
# A list is cast element by element (one level deep, names kept) and returned
# as a list; any other value is cast as a whole.
as_float_tensor <- function(x) {
  if (!is.list(x)) {
    return(as_tf_float(x))
  }
  lapply(x, as_tf_float)
}

# Convert `x` with tf$convert_to_tensor().
#
# A list is converted element by element (one level deep, names kept) and
# returned as a list; any other value is converted as a whole.
as_tensors <- function(x) {
  to_tensor <- tf$convert_to_tensor
  if (is.list(x)) lapply(x, to_tensor) else to_tensor(x)
}

# Coerce each element of `x` with as.integer(), returning a plain list.
#
# @param x A vector or list.
# @return A list of the same length as `x`, carrying the names of `x`.
as_integer_list <- function(x) {
  items <- as.list(x)
  result <- vector("list", length(items))
  names(result) <- names(items)
  for (idx in seq_along(items)) {
    result[[idx]] <- as.integer(items[[idx]])
  }
  result
}

# Reuse the internal `capture_args` helper from the keras namespace instead of
# duplicating it here (requires keras to be installed when this line runs).
capture_args <- get(x = "capture_args", envir = asNamespace("keras"))

Try the tfprobability package in your browser

Any scripts or data that you put into this service are public.

tfprobability documentation built on Sept. 1, 2022, 5:07 p.m.