
Defines functions names_if set_names with_dir deparse_str deparse_fun replace dust_type generate_dust_support_sum is_call odin_dust_file cpp_namespace cpp_when cpp_block cpp_args cpp_function dust_plus_1 dust_plus_y dust_minus_1 dust_array_access squote dquote dust_fold_call dust_flatten_eqs collector vcapply viapply vlapply `%||%`

## Null-coalescing operator: return `a` unless it is NULL, otherwise `b`.
## Only NULL triggers the fallback (NA, FALSE, zero-length vectors do not).
`%||%` <- function(a, b) { # nolint
  if (is.null(a)) b else a
}

## Type-stable lapply returning a logical vector (one TRUE/FALSE per element).
## Extra arguments in `...` are forwarded to `fun`.
vlapply <- function(x, fun, ...) {
  vapply(x, fun, logical(1), ...)
}

## Type-stable lapply returning an integer vector (one integer per element).
## Extra arguments in `...` are forwarded to `fun`.
viapply <- function(x, fun, ...) {
  vapply(x, fun, integer(1), ...)
}

## Type-stable lapply returning a character vector (one string per element).
## Extra arguments in `...` are forwarded to `fun`.
vcapply <- function(x, fun, ...) {
  vapply(x, fun, character(1), ...)
}

## NOTE(review): only the header of `collector` survives in this file; its
## body and closing brace are missing, so the file does not parse here.
## Restore the definition from upstream -- TODO confirm the original
## signature as well (it may not really be `...`).
collector <- function(...) {

## NOTE(review): only the header of `dust_flatten_eqs` survives in this
## file; its body and closing brace are missing. Restore from upstream --
## TODO confirm the original signature.
dust_flatten_eqs <- function(...) {

## NOTE(review): only the header of `dust_fold_call` survives in this file;
## its body and closing brace are missing. Restore from upstream -- TODO
## confirm the original signature.
dust_fold_call <- function(...) {

## NOTE(review): only the header of `dquote` survives in this file; its body
## and closing brace are missing. Presumably it wraps its input in double
## quotes for emitted C++ -- confirm against upstream before restoring.
dquote <- function(...) {

## NOTE(review): only the header of `squote` survives in this file; its body
## and closing brace are missing. Presumably it wraps its input in single
## quotes -- confirm against upstream before restoring.
squote <- function(...) {

## Build the C++ expression for the flattened offset into array `target`.
##
## `index` holds one index expression per dimension; each is shifted down by
## one with `dust_minus_1()` (presumably converting odin's 1-based indices
## to C++ 0-based ones). For dimension i > 1, the shifted index is scaled by
## the stride expression `mult[[i]]` from the array's `dimnames$mult`
## metadata. Terms are joined with " + ", highest dimension first, e.g.
##   "<mult_3> * (k - 1) + <mult_2> * (j - 1) + i - 1"
## The remaining arguments are passed through to the code generator.
dust_array_access <- function(target, index, data, meta, supported, gpu) {
  mult <- data$elements[[target]]$dimnames$mult

  f <- function(i) {
    ## Indices that will be multiplied by a stride need parentheses.
    index_i <- dust_minus_1(index[[i]], i > 1, data, meta, supported, gpu)
    if (i == 1) {
      ## The first dimension has no stride multiplier.
      index_i
    } else {
      mult_i <- generate_dust_sexp(mult[[i]], data, meta, supported, gpu)
      sprintf("%s * %s", mult_i, index_i)
    }
  }

  paste(vcapply(rev(seq_along(index)), f), collapse = " + ")
}

## Generate C++ for `x - 1`.
## Numeric `x` is decremented in R, so a single literal is emitted. Any other
## expression is rendered via `generate_dust_sexp()` and followed by " - 1",
## wrapped in parentheses when `protect` is TRUE (e.g. when the result will
## be multiplied).
dust_minus_1 <- function(x, protect, data, meta, supported, gpu) {
  if (is.numeric(x)) {
    generate_dust_sexp(x - 1L, data, meta, supported, gpu)
  } else {
    x_expr <- generate_dust_sexp(x, data, meta, supported, gpu)
    sprintf(if (protect) "(%s - 1)" else "%s - 1", x_expr)
  }
}

## Generate code for `x + y`, rendering operands with `rewrite`.
## If both operands are numeric, the sum is computed in R and rewritten as a
## single value. Otherwise "<x> + <y>" is emitted.
dust_plus_y <- function(x, y, rewrite) {
  if (is.numeric(x) && is.numeric(y)) {
    rewrite(x + y)
  } else {
    sprintf("%s + %s", rewrite(x), rewrite(y))
  }
}

## Generate code for `x + 1`. This is a thin wrapper around `dust_plus_y()`,
## so numeric `x` is folded into a single literal.
dust_plus_1 <- function(x, rewrite) {
  dust_plus_y(x, 1, rewrite)
}

## Emit a C++ function definition as a character vector of lines: the
## signature line built by `cpp_args()`, each `body` line indented by two
## spaces, then a closing brace.
cpp_function <- function(return_type, name, args, body, const = FALSE) {
  c(cpp_args(return_type, name, args, const), paste0("  ", body), "}")
}

## Build the opening line of a C++ function definition.
## `args` is a character vector whose names are the argument types and whose
## values are the argument names, e.g.
##   cpp_args("void", "f", c(int = "n", double = "x"), TRUE)
##   #> "void f(int n, double x) const {"
cpp_args <- function(return_type, name, args, const) {
  args_str <- paste(sprintf("%s %s", names(args), unname(args)),
                    collapse = ", ")
  sprintf("%s %s(%s)%s {",
          return_type, name, args_str, if (const) " const" else "")
}

## Wrap C++ lines in a brace-delimited block, indenting each by two spaces.
cpp_block <- function(body) {
  c("{", paste0("  ", body), "}")
}

## Wrap C++ lines in `if (condition) { ... }`, indenting the body by two
## spaces.
cpp_when <- function(condition, body) {
  c(sprintf("if (%s) {", condition), paste0("  ", body), "}")
}

## Wrap C++ lines in `namespace name { ... }`. Unlike the other cpp_*
## helpers, the contents are not indented.
cpp_namespace <- function(name, code) {
  c(sprintf("namespace %s {", name), code, "}")
}

## Return the full path to a file installed with the odin.dust package.
## Errors (mustWork = TRUE) if the file cannot be found.
odin_dust_file <- function(path) {
  system.file(path, package = "odin.dust", mustWork = TRUE)
}

## Is `expr` a call (or other recursive object) whose first element is the
## symbol named `symbol`? For example, is_call(quote(f(x)), "f") is TRUE.
is_call <- function(expr, symbol) {
  is.recursive(expr) && identical(expr[[1L]], as.name(symbol))
}

## Generate C++ support code for odin's sum over an array of rank `rank`,
## adapted for dust: templated on `real_type`/`container` and marked
## `__host__ __device__`.
##
## Returns a list with elements `name`, `declaration` and `definition`
## (character vectors of C++ lines). For rank 1 only a declaration is
## produced and `definition` is NULL. For higher ranks, the list is derived
## from odin's C support code. That derivation assumes
## `odin:::generate_c_support_sum()` returns a list with `declaration` and
## `definition` elements -- it is an unexported odin internal.
generate_dust_support_sum <- function(rank) {
  if (rank == 1) {
    ret <- list(
      name = "odin_sum1",
      declaration = c(
        "template <typename real_type, typename container>",
        paste("__host__ __device__ real_type",
              "odin_sum1(const container x, size_t from, size_t to);")),
      definition = NULL)
  } else {
    ## There are a series of substitutions that need to be made here,
    ## all of which are literal. Order matters: "double*" must be
    ## replaced before "double", or it would become "real_type*".
    tr <- c("double*" = "const container",
            "double" = "real_type")
    head <- "template <typename real_type, typename container>"
    ret <- lapply(odin:::generate_c_support_sum(rank), replace, tr)
    for (v in c("declaration", "definition")) {
      s <- ret[[v]]
      s[[1L]] <- paste("__host__ __device__", s[[1L]])
      ret[[v]] <- c(head, s)
    }
  }
  ret
}

## Map an odin type name to the C++ type name used in generated dust code.
## Visible mappings: "double" -> "real_type", "int" -> "int". Other entries
## cover the interpolate_data_* types, and a final `stop()` reports
## "Unknown type '<type>'".
## NOTE(review): this definition is damaged and does not parse as shown.
## The opening dispatch line (presumably `switch(type,`), the values for
## interpolate_data_constant / interpolate_data_linear /
## interpolate_data_spline and the closing brace are all missing. Restore
## from upstream; the missing values cannot be recovered from this file.
dust_type <- function(type) {
         double = "real_type",
         int = "int",
         interpolate_data_constant =
         interpolate_data_linear =
         interpolate_data_spline =
         stop(sprintf("Unknown type '%s'", type)))

## Apply literal (fixed = TRUE, not regex) string substitutions to `x`.
## `tr` is a named character vector mapping pattern (name) to replacement
## (value). Substitutions run in order, so a pattern that contains another
## (e.g. "double*" vs "double") must come first. With an empty `tr`, `x`
## is returned unchanged.
## NOTE: this shadows base::replace() wherever this definition is visible.
replace <- function(x, tr) {
  from <- names(tr)
  for (i in seq_along(tr)) {
    x <- gsub(from[[i]], tr[[i]], x, fixed = TRUE)
  }
  x
}

## Deparse `x` (typically a function) into a single newline-joined string,
## tidied towards tidyverse style:
##   * trailing whitespace is stripped from each line,
##   * "function (" becomes "function(",
##   * an opening brace on its own line is joined onto the preceding ")",
##   * "}" followed by "else" on the next line becomes "} else".
deparse_fun <- function(x) {
  str <- paste(sub("\\s+$", "", deparse(x)), collapse = "\n")
  ## Apply a few fixes:
  str <- gsub("function (", "function(", str, fixed = TRUE)
  str <- gsub("\\)\n\\{", ") {", str)
  str <- gsub("\\}\n\\s*else", "} else", str)
  str
}

## Deparse `x` into a single string, joining multi-line output with "\n".
deparse_str <- function(x) {
  paste(deparse(x), collapse = "\n")
}

## Evaluate `code` with the working directory temporarily set to `path`, and
## return its value. The previous working directory is restored on exit,
## including when `code` signals an error.
with_dir <- function(path, code) {
  owd <- setwd(path)
  on.exit(setwd(owd), add = TRUE)
  ## `code` is a lazy promise: forcing it here evaluates it inside `path`.
  force(code)
}

## Return `x` with its names set to `nms` (equivalent to stats::setNames()).
set_names <- function(x, nms) {
  names(x) <- nms
  x
}

## NOTE(review): only the header of `names_if` survives in this file; its
## body and closing brace are missing. Presumably it returns the names of
## the TRUE elements of a logical vector (e.g. `names(x)[x]`) -- confirm
## against upstream before restoring.
names_if <- function(x) {
