R/ast-classes.R

Defines functions find_arma_cpp_types get_arma_type arma_element_type auto_or_arma_type make_random_var_name new_node_by_chr multi_dispatch_fun unary_reduction_node make_generic_unary_function_class make_binary_operator_class error_msg_unkown_expr

utils::globalVariables(c("private", "self"))
ast_node <- R6::R6Class("ast_node",
  public = list(
    initialize = function(sexp, head, tail_elements = NULL) {
      private$sexp <- sexp
      self$set_head(head)
      self$set_tail_elements(tail_elements)
    },
    set_head = function(node) {
      private$head <- node
    },
    set_scope = function(node) {
      private$scope <- node
    },
    has_scope = function() {
      !is.null(self$get_scope())
    },
    set_parent = function(node) {
      private$parent <- node
    },
    get_parent = function() {
      private$parent
    },
    get_head = function() {
      private$head
    },
    get_tail_elements = function() {
      private$tail_elements
    },
    set_tail_elements = function(nodes) {
      private$tail_elements <- nodes
    },
    get_scope = function() {
      private$scope
    },
    get_sexp = function() {
      private$sexp
    },
    set_sexp = function(sexp) {
      private$sexp <- sexp
    },
    get_cpp_type = function() {
      private$cpp_type
    },
    is_armadillo_cpp_type = function() {
      grepl("^arma", self$get_cpp_type())
    },
    is_scalar_cpp_type = function() {
      !self$is_armadillo_type() && !self$is_auto_type()
    },
    is_auto_cpp_type = function() {
      self$has_auto_cpp_type()
    },
    # TODO: only one function
    has_auto_cpp_type = function() {
      self$get_cpp_type() == "auto"
    },
    set_cpp_type = function(type) {
      private$cpp_type <- type
    },
    emit = function(...) {
      do.call(paste0, list(...))
    },
    compile = function() {
      stop(error_msg_unkown_expr(self$get_sexp()), call. = FALSE)
    },
    find_all_returns = function() {
      private$find_nodes("ast_node_return", stop_at = "ast_node_function")
    },
    find_all_names = function() {
      private$find_nodes("ast_node_name")
    }
  ),
  private = list(
    sexp = NULL,
    head = NULL,
    tail_elements = list(),
    scope = NULL,
    parent = NULL,
    cpp_type = "auto",
    find_nodes = function(node_type, stop_at = NULL) {
      find_elements_rec <- function(node) {
        if (!is.null(stop_at) && any(stop_at %in% class(node))) {
          return(list())
        }
        found_node <- list()
        if (any(node_type %in% class(node))) {
          found_node <- list(node)
        }
        c(found_node, lapply(node$get_tail_elements(), find_elements_rec))
      }
      unlist(find_elements_rec(self))
    }
  )
)

error_msg_unkown_expr <- function(sexp) {
  paste0(
    "Sorry, but the expression:\n\n",
    paste0(deparse(sexp), collapse = "\n"),
    "\n\ncannot be translated into a C++ construct.\n",
    "Not all R functions are supported."
  )
}

ast_node_arma_pairlist <- R6::R6Class(
  classname = "ast_node_arma_pairlist",
  inherit = ast_node,
  public = list(
    compile = function(fun_name = NULL, overwrite_return_type = NULL) {
      fun_args <- self$get_parameter_types()
      input_params <- generate_cpp_input_parameters_code(
        fun_args,
        is_reassigned = function(param_name) {
          self$get_scope()$count_assignments(param_name) > 0L
        }
      )
      self$emit(
        "(", paste0(input_params, collapse = ", "), ")"
      )
    },
    get_parameter_types = function() {
      fun_args <- self$get_tail_elements()
      names <- vapply(fun_args, function(x) x$get_name(), character(1L))
      setNames(lapply(fun_args, function(x) {
        x$get_cpp_type()
      }), names)
    },
    # needs to be done ones
    update_parameter_types = function() {
      tail <- self$get_tail_elements()
      var_names <- self$get_parameter_names()
      for (i in seq_along(tail)) {
        param <- tail[[i]]
        var_name <- as.name(var_names[[i]])
        stopifnot(any(c("ast_node_function_call","ast_node_type_spec", "ast_node_terminal") %in% class(param)))
        x <- param$get_sexp()
        type <- if (!missing(x) && "ast_node_type_spec" %in% class(param)) {
          param$get_cpp_type()
        } else {
          ast_node_type_matrix$new(as.name("type_matrix"), as.name("type_matrix"))$get_cpp_type()
        }
        new_param <- ast_node_name$new(var_name, var_name)
        new_param$set_cpp_type(type)
        new_param$set_parent(self)
        new_param$set_scope(self$get_scope())
        tail[[i]] <- new_param
      }
      self$set_tail_elements(tail)
    },
    get_parameter_names = function() {
      names(self$get_sexp()[-1L])
    }
  )
)

ast_node_function <- R6::R6Class(
  classname = "ast_node_function",
  inherit = ast_node,
  public = list(
    register_input_variables_in_child_scope = function() {
      arguments <- self$get_function_parameters()
      # also set the function body as the scope of the function parameters
      body <- self$get_function_body()
      arguments$set_scope(body)
      parameters <- arguments$get_tail_elements()
      stopifnot("ast_node_block" %in% class(body))
      for (i in seq_along(parameters)) {
        param <- parameters[[i]]
        stopifnot("ast_node_name" %in% class(param))
        dummy_value <- ast_node_dummy$new()
        dummy_value$set_cpp_type(param$get_cpp_type())
        dummy_value$set_parent(self)
        dummy_value$set_scope(body)
        body$register_new_name(param$get_name(), dummy_value)
        # Also we need to change the scope of the input variables to the body
        # as we register the names there
        # TODO: revisit that
        param$set_scope(body)
      }
    },
    is_recursive = function(assigned_name) {
      # check if the assigned name is used as a function call in the function
      length(private$find_all_function_calls(of = assigned_name)) > 0L
    },
    get_cpp_function_type = function() {
      return_type <- self$get_cpp_type()
      if (return_type == "auto") {
        stop("Recursive functions must have an explicit return type.", call. = FALSE)
      }
      argument_types <- self$get_function_parameters()$get_parameter_types()
      # TODO: the following line should be handled by the pairlist, not here
      argument_types <- generate_cpp_input_types(argument_types)
      paste0(
        "std::function<", return_type,
        "(",
        paste0(argument_types, collapse = ", "),
        ")>"
      )
    },
    get_function_body = function() {
      self$get_tail_elements()[[2L]]
    },
    get_function_parameters = function() {
      self$get_tail_elements()[[1L]]
    },
    get_cpp_type = function() {
      body <- self$get_function_body()
      return_types <- body$find_all_returns()
      return_types <- lapply(return_types, function(x) x$get_cpp_type())
      return_type <- if (length(return_types) == 0L) {
        "void"
      } else {
        if (length(unique(return_types)) != 1L) {
          stop("There seem to be multiple possible return types in your function. ",
            "Consider adding type annotations to all `return` statements.",
            "\n\n",
            paste0(deparse(body$get_sexp()), collapse = "\n"),
            call. = FALSE
          )
        }
        return_types[[1L]]
      }
      return_type
    },
    compile = function(fun_name = NULL, overwrite_return_type = NULL, is_recursive = FALSE) {
      stopifnot(length(self$get_tail_elements()) == 2L)
      arguments <- self$get_function_parameters()
      body <- self$get_function_body()
      is_lambda <- is.null(fun_name)
      # first compile arguments and body
      # currently the compilation can alter the scope
      body_compiled <- body$compile()
      args_compiled <- arguments$compile()

      return_type <- if (!is.null(overwrite_return_type)) {
        overwrite_return_type
      } else {
        self$get_cpp_type()
      }
      if (is_lambda) {
        function_prefix_code <- "[&]"
        function_suffix_code <- if (grepl("^arma", return_type) || is_recursive) {
          paste0(" mutable -> ", return_type)
        } else {
          " mutable"
        }
      } else {
        function_prefix_code <- paste0(return_type, " ", fun_name)
        function_suffix_code <- ""
      }
      self$emit(
        function_prefix_code,
        args_compiled,
        function_suffix_code,
        "\n",
        body_compiled
      )
    }
  ),
  private = list(
    find_all_function_calls = function(of) {
      # TODO: implement a get_name or something for function_calls
      Filter(function(x) {
        paste0(deparse(x$get_head()), collapse = "") == of
      }, super$find_nodes("ast_node_function_call"))
    }
  )
)

ast_node_function_call <- R6::R6Class(
  classname = "ast_node_function_call",
  inherit = ast_node,
  public = list(
    get_name = function() {
      paste0(deparse(self$get_sexp()[[1L]]), collapse = "")
    },
    compile = function() {
      val <- self$get_name()
      if (self$has_scope()) {
        if (!self$get_scope()$is_name_defined(val)) {
          # this is an unknown function call
          stop(error_msg_unkown_expr(self$get_sexp()), call. = FALSE)
        }
      }
      arguments <- vapply(self$get_tail_elements(), function(x) {
        x$compile()
      }, character(1L))
      suffix <- if ("ast_node_block" %in% class(self$get_parent())) {
        ";\n"
      } else {
        ""
      }
      self$emit(
        as.character(val),
        "(",
        paste0(arguments, collapse = ", "),
        ")",
        suffix
      )
    },
    get_cpp_type = function() {
      if (self$has_scope() && is.call(self$get_sexp())) {
        scope <- self$get_scope()
        fun_name <- self$get_sexp()[[1L]]
        if (scope$is_name_defined(fun_name)) {
          fun <- scope$get_value_node_for_name(fun_name)
          return(fun$get_cpp_type())
        }
      }
      private$cpp_type
    }
  )
)

ast_node_terminal <- R6::R6Class(
  classname = "ast_node_terminal",
  inherit = ast_node,
  public = list(
    compile = function() {
      if (!self$has_value()) {
        return(self$emit(""))
      }
      val <- self$get_head()
      # val %% 1 == 0, TRUE if val is integer like otherwise FALSE
      # see https://stackoverflow.com/a/3477158/2798441
      # TODO: move to ast_node_scalar
      if (is.numeric(val) && !is.integer(val) && val %% 1 == 0) {
        self$emit(val, ".0")
      } else {
        self$emit(as.character(val))
      }
    },
    has_value = function() {
      TRUE
    }
  )
)

ast_node_scalar <- R6::R6Class(
  classname = "ast_node_scalar",
  inherit = ast_node_terminal
)

ast_node_scalar_double <- R6::R6Class(
  classname = "ast_node_scalar",
  inherit = ast_node_scalar,
  public = list(
    get_cpp_type = function() {
      "double"
    }
  )
)

ast_node_scalar_int <- R6::R6Class(
  classname = "ast_node_int",
  inherit = ast_node_scalar,
  public = list(
    get_cpp_type = function() {
      "int"
    }
  )
)

ast_node_scalar_bool <- R6::R6Class(
  classname = "ast_node_bool",
  inherit = ast_node_scalar,
  public = list(
    get_cpp_type = function() {
      "bool"
    }
  )
)

ast_node_pi <- R6::R6Class(
  classname = "ast_node_pi",
  inherit = ast_node_scalar_double,
  public = list(
    compile = function() {
      self$emit("arma::datum::pi")
    }
  )
)

ast_node_true <- R6::R6Class(
  classname = "ast_node_true",
  inherit = ast_node_scalar_bool,
  public = list(
    compile = function() {
      self$emit("true")
    }
  )
)

ast_node_false <- R6::R6Class(
  classname = "ast_node_false",
  inherit = ast_node_scalar_bool,
  public = list(
    compile = function() {
      self$emit("false")
    }
  )
)

ast_node_name <- R6::R6Class(
  classname = "ast_node_name",
  inherit = ast_node_terminal,
  public = list(
    get_cpp_type = function() {
      scope <- self$get_scope()
      if (self$has_scope() && scope$is_name_defined(self$get_name())) {
        value_node <- scope$get_value_node_for_name(self$get_name())
        # TODO: check for self reference
        return(value_node$get_cpp_type())
      }
      private$cpp_type
    },
    get_name = function() {
      paste0(deparse(self$get_head()), collapse = "")
    },
    has_value = function() {
      self$get_name() != ""
    }
  )
)

ast_node_empty <- R6::R6Class(
  classname = "ast_node_empty",
  inherit = ast_node_terminal
)

# a dummy node is just used to store a type
ast_node_dummy <- R6::R6Class(
  classname = "ast_node_dummy",
  inherit = ast_node_terminal,
  public = list(
    initialize = function() {
      super$initialize(quote(`dummy`), quote(`dummy`))
    }
  )
)

ast_node_assignment <- R6::R6Class(
  classname = "ast_node_assignment",
  inherit = ast_node,
  public = list(
    compile = function() {
      operands <- self$get_tail_elements()
      stopifnot(length(operands) == 2L)
      lhs_compiled <- operands[[1L]]$compile()

      lhs_is_name <- "ast_node_name" %in% class(operands[[1L]])
      is_initial_assignment <- self$is_initial_assignment()
      if (is_initial_assignment && lhs_is_name) {
        private$register_assignment()
      }

      # check if assignment is a QR decomposition
      rhs <- operands[[2L]]
      if (any(c("ast_node_qr_init", "ast_node_svd_init") %in% class(rhs))) {
        return(rhs$compile(lhs_compiled))
      }


      lhs_is_call <- "ast_node_function_call" %in% class(operands[[1L]]) ||
        "ast_node_element_access" %in% class(operands[[1L]])
      type <- if (is_initial_assignment && !lhs_is_call) {
        type <- operands[[2L]]$get_cpp_type()
        # we use auto for anything non armadillo
        type <- auto_or_arma_type(type)
        # also let the name point to this inital expression
        # operands[[1L]]$set_value_node(operands[[2L]])
        paste0(type, " ")
      } else {
        ""
      }

      # lambda
      is_lambda <- "ast_node_function" %in% class(rhs)
      is_lambda_recursive <- FALSE
      if (is_lambda) {
        if (rhs$is_recursive(lhs_compiled)) {
          is_lambda_recursive <- TRUE
          type <- paste0(rhs$get_cpp_function_type(), " ")
        } else {
          type <- paste0("auto", " ")
        }
      }

      # something wrong here
      if (is_lambda && is_lambda_recursive) {
        self$emit(
          type, lhs_compiled, ";\n",
          lhs_compiled, " = ",
          operands[[2L]]$compile(is_recursive = TRUE),
          ";"
        )
      } else {
        self$emit(
          type,
          lhs_compiled,
          " = ",
          operands[[2L]]$compile(),
          ";"
        )
      }
    },
    check_and_register_assignment = function() {
      tail_elements <- self$get_tail_elements()
      stopifnot(length(tail_elements) == 2L)
      tail_elements[[1L]]$set_cpp_type(tail_elements[[2L]]$get_cpp_type())
      self$set_cpp_type(tail_elements[[2L]]$get_cpp_type())
    },
    is_initial_assignment = function() {
      var_name <- paste0(deparse(self$get_tail_elements()[[1L]]$get_sexp()), collapse = "")
      if (self$has_scope()) {
        if (self$get_scope()$is_name_defined(var_name)) {
          return(FALSE)
        }
        if (self$get_scope()$has_scope()) {
          # maybe it was previously already defined
          already_defined <- self$get_scope()$get_scope()$is_name_defined(var_name)
          if (already_defined) {
            return(FALSE)
          }
        }
      }
      TRUE
    }
  ),
  private = list(
    register_assignment = function() {
      if (self$has_scope()) {
        var_name <- self$get_tail_elements()[[1L]]$get_name()
        self$get_scope()$register_new_name(var_name, self$get_tail_elements()[[2L]])
      }
    }
  )
)

ast_node_return <- R6::R6Class(
  classname = "ast_node_return",
  inherit = ast_node,
  public = list(
    compile = function() {
      operands <- self$get_tail_elements()
      stopifnot(length(operands) %in% c(1L, 2L))
      self$emit(
        "return ",
        operands[[1L]]$compile(),
        ";"
      )
    },
    get_cpp_type = function() {
      operands <- self$get_tail_elements()
      if (length(operands) == 2L) {
        return(operands[[2]]$get_cpp_type())
      }
      operands[[1L]]$get_cpp_type()
    }
  )
)

make_binary_operator_class <- function(name, op) {
  R6::R6Class(
    classname = name,
    inherit = ast_node,
    public = list(
      compile = function() {
        operands <- self$get_tail_elements()
        stopifnot(length(operands) %in% c(1:2))
        if (length(operands) == 1L) {
          self$emit(op, operands[[1L]]$compile())
        } else {
          self$emit(
            operands[[1L]]$compile(),
            " ", op, " ",
            operands[[2L]]$compile()
          )
        }
      },
      get_cpp_type = function() {
        operands <- self$get_tail_elements()
        arma_type <- if (operands[[1L]]$is_armadillo_cpp_type()) {
          operands[[1L]]
        } else if (length(operands) == 2L && operands[[2L]]$is_armadillo_cpp_type()) {
          operands[[2L]]
        }
        if (!is.null(arma_type)) {
          if (op %in% c("*", "/", "+", "-")) {
            # TODO: HEURISTIC: we return the type of the first armadillo type
            return(arma_type$get_cpp_type())
          }
          if (op %in% c("<=", ">=", "<", ">", "==")) {
            # it is an arma::umat
            # TODO: can we figure out if rowvec/colvec?
            return("arma::umat")
          }
        }
        private$cpp_type
      }
    )
  )
}
ast_node_plus <- make_binary_operator_class("ast_node_plus", "+")
ast_node_minus <- make_binary_operator_class("ast_node_minus", "-")
ast_node_div <- make_binary_operator_class("ast_node_div", "/")
ast_node_greater <- make_binary_operator_class("ast_node_greater", ">")
ast_node_less <- make_binary_operator_class("ast_node_less", "<")
ast_node_leq <- make_binary_operator_class("ast_node_leq", "<=")
ast_node_geq <- make_binary_operator_class("ast_node_geq", ">=")
ast_node_equal <- make_binary_operator_class("ast_node_equal", "==")
ast_node_nequal <- make_binary_operator_class("ast_node_nequal", "!=")
ast_node_logical_and <- make_binary_operator_class("ast_node_logical_and", "&&")
ast_node_logical_or <- make_binary_operator_class("ast_node_logical_or", "||")

ast_node_block <- R6::R6Class(
  classname = "ast_node_block",
  inherit = ast_node,
  public = list(
    initialize = function(...) {
      super$initialize(...)
      private$value_map <- new.env(parent = emptyenv())
    },
    compile = function() {
      self$emit(
        "{\n",
        paste0(lapply(self$get_tail_elements(), function(x) x$compile()), collapse = "\n"),
        "\n}\n"
      )
    },
    register_new_name = function(name_chr, value_node) {
      # no reassignments, first one counts
      name_still_free <- !self$has_scope() || !self$get_scope()$is_name_defined(name_chr)
      if (is.null(private$value_map[[name_chr]]) && name_still_free) {
        private$value_map[[name_chr]] <- value_node
      }
    },
    new_random_variable_id = function() {
      prefix <- if (self$has_scope()) {
        self$get_scope()$new_random_variable_id()
      } else {
        ""
      }
      var_id <- paste0(prefix, private$rand_variable_name_count)
      private$rand_variable_name_count <- private$rand_variable_name_count + 1L
      var_id
    },
    is_name_defined = function(name) {
      if (nchar(name) == 0L) {
        return(FALSE)
      }
      !is.null(private$value_map[[as.character(name)]]) ||
        (!is.null(self$get_scope()) && self$get_scope()$is_name_defined(name))
    },
    get_value_node_for_name = function(name) {
      el <- private$value_map[[as.character(name)]]
      if (is.null(el) && self$has_scope()) {
        self$get_scope()$get_value_node_for_name(name)
      } else {
        el
      }
    },
    find_all_parent_names = function(name) {
      unique(c(
        names(private$name_store),
        if (self$has_scope()) {
          self$get_scope()$find_all_parent_names()
        }
      ))
    },
    count_assignments = function(of_name) {
      assignments <- private$find_nodes("ast_node_assignment")
      assignments <- Filter(function(assignment) {
        elements <- assignment$get_tail_elements()
        lhs <- elements[[1L]]
        "ast_node_name" %in% class(lhs) && lhs$get_name() == of_name ||
          (
            "ast_node_element_access" %in% class(lhs) &&
              "ast_node_name" %in% class(lhs$get_tail_elements()[[1L]]) &&
              lhs$get_tail_elements()[[1L]]$get_name() == of_name
          )
      }, assignments)
      length(assignments)
    }
  ),
  private = list(
    value_map = NULL,
    rand_variable_name_count = 0L
  )
)

ast_node_global_scope <- R6::R6Class(
  classname = "ast_node_global_scope",
  inherit = ast_node_block,
  public = list(
    initialize = function(sexp) {
      super$initialize(sexp, quote(`{`))
    }
  )
)

ast_node_length <- R6::R6Class(
  classname = "ast_node_length",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      self$emit(
        self$get_tail_elements()[[1L]]$compile(), ".n_elem"
      )
    },
    get_cpp_type = function() {
      "int"
    }
  )
)

ast_node_ncol <- R6::R6Class(
  classname = "ast_node_ncol",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      self$emit(
        self$get_tail_elements()[[1L]]$compile(), ".n_cols"
      )
    },
    get_cpp_type = function() {
      "int"
    }
  )
)

ast_node_nrow <- R6::R6Class(
  classname = "ast_node_nrow",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      self$emit(
        self$get_tail_elements()[[1L]]$compile(), ".n_rows"
      )
    },
    get_cpp_type = function() {
      "int"
    }
  )
)

ast_node_if <- R6::R6Class(
  classname = "ast_node_if",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) %in% c(2L, 3L))
      elements <- self$get_tail_elements()
      self$emit(
        "\nif ( ", elements[[1L]]$compile(), " ) \n",
        elements[[2L]]$compile(),
        if (length(elements) > 2L && !is.null(elements[[3L]])) {
          paste0(" else ", elements[[3L]]$compile())
        }
      )
    }
  )
)

ast_node_mult <- R6::R6Class(
  classname = "ast_node_mult",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 2L)
      elements <- self$get_tail_elements()

      op <- private$get_mult_operator()
      self$emit(
        elements[[1L]]$compile(),
        " ",
        op,
        " ",
        elements[[2L]]$compile()
      )
    },
    get_cpp_type = function() {
      elements <- self$get_tail_elements()
      if (private$get_mult_operator() == "%") {
        # TODO: need to figure out types for imat, umat, mat combinations
        return(elements[[1L]]$get_cpp_type())
      } else {
        arma_element <- if (elements[[1L]]$is_armadillo_cpp_type()) {
          elements[[1L]]
        } else if (elements[[2L]]$is_armadillo_cpp_type()) {
          elements[[2L]]
        }
        if (!is.null(arma_element)) {
          return(arma_element$get_cpp_type())
        } else {
          # TODO: figure more combinations
          # # "Otherwise, if either operand is double, the other shall be converted to double."
          #
          # if ("double" %in% c(elements[[2L]]$get_cpp_type(), elements[[2L]]$get_cpp_type())) {
          #   return("double")
          # }
        }
      }
      private$cpp_type
    }
  ),
  private = list(
    get_mult_operator = function() {
      elements <- self$get_tail_elements()
      # only use armadillo element-wise multiplication
      # if both operands are armadillo types
      if (!elements[[1L]]$is_armadillo_cpp_type() ||
        !elements[[2L]]$is_armadillo_cpp_type()) {
        "*"
      } else {
        "%"
      }
    }
  )
)

ast_node_matmul <- R6::R6Class(
  classname = "ast_node_matmul",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 2L)
      elements <- self$get_tail_elements()
      self$emit(
        elements[[1L]]$compile(),
        " * ",
        elements[[2L]]$compile()
      )
    },
    get_cpp_type = function() {
      arma_type <- find_arma_cpp_types(self$get_tail_elements())
      arma_cpp_types <- vapply(arma_type, function(x) x$get_cpp_type(), character(1L))
      if (length(arma_type) == 2L) {
        is_integer_arma_type <- function(type) {
          grepl("^arma::i", type)
        }
        is_unsigned_integer_arma_type <- function(type) {
          grepl("^arma::u", type)
        }
        is_double_arma_type <- function(type) {
          type %in% paste0("arma::", c("vec", "colvec", "rowvec", "mat"))
        }
        # TODO: revisit these rules and contemplate
        if (any(sapply(arma_cpp_types, is_double_arma_type))) {
          return("arma::mat")
        }
        if (any(sapply(arma_cpp_types, is_integer_arma_type))) {
          return("arma::imat")
        }
        if (all(sapply(arma_cpp_types, is_unsigned_integer_arma_type))) {
          return("arma::umat")
        }
      }
      if (length(arma_type) == 1L) {
        return(arma_type[[1L]]$get_cpp_type())
      }
      "arma::mat"
    }
  )
)

ast_node_pow <- R6::R6Class(
  classname = "ast_node_pow",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 2L)
      elements <- self$get_tail_elements()
      fun <- if (elements[[1L]]$is_armadillo_cpp_type()) {
        # we assume the second argument is compatible
        "arma::pow"
      } else {
        "std::pow"
      }
      self$emit(
        fun, "(",
        elements[[1L]]$compile(),
        ", ",
        elements[[2L]]$compile(),
        ")"
      )
    },
    get_cpp_type = function() {
      self$get_tail_elements()[[1L]]$get_cpp_type()
    }
  )
)

ast_node_bracket <- R6::R6Class(
  classname = "ast_node_bracket",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      self$emit(
        "(",
        self$get_tail_elements()[[1L]]$compile(),
        ")"
      )
    },
    get_cpp_type = function() {
      self$get_tail_elements()[[1L]]$get_cpp_type()
    }
  )
)

ast_node_colsums <- R6::R6Class(
  classname = "ast_node_colsums",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      self$emit("arma::sum( ", self$get_tail_elements()[[1L]]$compile(), ", 0 )")
    },
    get_cpp_type = function() {
      input_type <- self$get_tail_elements()[[1L]]$get_cpp_type()
      arma_type <- arma_element_type(input_type)
      if (!is.null(arma_type)) {
        get_arma_type("rowvec", elements = arma_type)
      } else {
        private$cpp_type
      }
    }
  )
)

ast_node_rowsums <- R6::R6Class(
  classname = "ast_node_rowsums",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      self$emit("arma::sum( ", self$get_tail_elements()[[1L]]$compile(), ", 1 )")
    },
    get_cpp_type = function() {
      input_type <- self$get_tail_elements()[[1L]]$get_cpp_type()
      arma_type <- arma_element_type(input_type)
      if (!is.null(arma_type)) {
        get_arma_type("colvec", elements = arma_type)
      } else {
        private$cpp_type
      }
    }
  )
)

ast_node_colmeans <- R6::R6Class(
  classname = "ast_node_colmeans",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      self$emit("arma::mean( ", self$get_tail_elements()[[1L]]$compile(), ", 0 )")
    },
    get_cpp_type = function() {
      "arma::rowvec"
    }
  )
)

ast_node_rowmeans <- R6::R6Class(
  classname = "ast_node_rowmeans",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      self$emit("arma::mean( ", self$get_tail_elements()[[1L]]$compile(), ", 1 )")
    },
    get_cpp_type = function() {
      "arma::colvec"
    }
  )
)

ast_node_forwardsolve <- R6::R6Class(
  classname = "ast_node_forwardsolve",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 2L)
      elements <- self$get_tail_elements()
      self$emit(
        "arma::solve(arma::trimatl(",
        elements[[1L]]$compile(), "), ",
        elements[[2L]]$compile(),
        " )"
      )
    },
    get_cpp_type = function() {
      "arma::colvec"
    }
  )
)

ast_node_backsolve <- R6::R6Class(
  classname = "ast_node_backsolve",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 2L)
      elements <- self$get_tail_elements()
      self$emit(
        "arma::solve(arma::trimatu(",
        elements[[1L]]$compile(), "), ",
        elements[[2L]]$compile(),
        " )"
      )
    },
    get_cpp_type = function() {
      "arma::colvec"
    }
  )
)

ast_node_solve <- R6::R6Class(
  classname = "ast_node_solve",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) %in% 1:2)
      elements <- self$get_tail_elements()
      if (length(elements) == 1L) {
        self$emit("arma::inv(", elements[[1L]]$compile(), ")")
      } else {
        self$emit(
          "arma::solve(",
          elements[[1L]]$compile(),
          ", ",
          elements[[2L]]$compile(),
          ")"
        )
      }
    },
    get_cpp_type = function() {
      if (length(self$get_tail_elements()) == 1L) {
        "arma::mat"
      } else {
        "arma::colvec"
      }
    }
  )
)

ast_node_crossprod <- R6::R6Class(
  classname = "ast_node_crossprod",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) %in% 1:2)
      elements <- self$get_tail_elements()
      if (length(elements) == 1L) {
        expr1 <- elements[[1L]]$compile()
        expr2 <- expr1
      } else {
        expr1 <- elements[[1L]]$compile()
        expr2 <- elements[[2L]]$compile()
      }
      self$emit("arma::trans(", expr1, ") * (", expr2, ")")
    },
    get_cpp_type = function() {
      "arma::mat"
    }
  )
)

ast_node_tcrossprod <- R6::R6Class(
  classname = "ast_node_tcrossprod",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) %in% 1:2)
      elements <- self$get_tail_elements()
      if (length(elements) == 1L) {
        expr1 <- elements[[1L]]$compile()
        expr2 <- expr1
      } else {
        expr1 <- elements[[1L]]$compile()
        expr2 <- elements[[2L]]$compile()
      }
      self$emit("(", expr1, ") * arma::trans(", expr2, ")")
    },
    get_cpp_type = function() {
      "arma::mat"
    }
  )
)

ast_node_while <- R6::R6Class(
  classname = "ast_node_while",
  inherit = ast_node,
  public = list(
    compile = function() {
      elements <- self$get_tail_elements()
      stopifnot(length(elements) == 2L)
      self$emit(
        "while (", elements[[1L]]$compile(), ")\n",
        elements[[2L]]$compile()
      )
    }
  )
)

ast_node_for <- R6::R6Class(
  classname = "ast_node_for",
  inherit = ast_node,
  public = list(
    compile = function() {
      elements <- self$get_tail_elements()
      stopifnot(length(elements) == 3L)
      # we currently allow just one pattern
      # for (i in X) { ... }
      # will be translated to
      # for (const auto& i : X) {
      #   ...
      # }
      stopifnot("ast_node_name" %in% class(elements[[1L]]))
      iter_var_name <- elements[[1L]]$compile()
      container <- elements[[2L]]
      body <- elements[[3L]]
      is_iter_var_used <- private$is_loop_var_used(iter_var_name, body)
      if (is_iter_var_used) {
        # TODO: this modifies the input tree during compilation mmmmh
        # This means we now have to run the type deduction again for the body
        # Think about creating the structure during construction of the ast
        stopifnot("ast_node_block" %in% class(body))
        body_scope <- self$get_scope()
        dummy_iter_var <- ast_node_dummy$new()
        dummy_iter_var$set_cpp_type("auto")
        body_scope$register_new_name(iter_var_name, dummy_iter_var)
        self$emit(
          "for (const auto& ", iter_var_name,
          " : ", container$compile(), ")\n",
          body$compile(),
          "\n"
        )
      } else {

        container_var_name <- make_random_var_name(self$get_scope()$new_random_variable_id())
        iter_var_name <- make_random_var_name(self$get_scope()$new_random_variable_id())

        self$emit(
          "\nauto&& ", container_var_name, " = ", container$compile(), ";\n",
          "for (auto ",
          iter_var_name, " = ", container_var_name, ".begin();",
          iter_var_name, " != ", container_var_name, ".end();",
          "++", iter_var_name, ")\n",
          body$compile(),
          "\n"
        )
      }
    }
  ),
  private = list(
    is_loop_var_used = function(var_name, body) {
      all_names <- body$find_all_names()
      for (name_node in all_names) {
        if (var_name == name_node$get_head()) {
          return(TRUE)
        }
      }
      FALSE
    }
  )
)

ast_node_qr_init <- R6::R6Class(
  classname = "ast_node_qr_init",
  inherit = ast_node,
  public = list(
    compile = function(var_name) {
      stopifnot(is.character(var_name))
      stopifnot(length(self$get_tail_elements()) == 1L)
      rhs <- self$get_tail_elements()[[1L]]$compile()
      var_name_q <- paste0(var_name, "__Q")
      var_name_r <- paste0(var_name, "__R")
      self$emit(
        "arma::mat ", var_name_q, ", ", var_name_r, ";\n",
        "arma::qr_econ( ", var_name_q, ", ", var_name_r, ", ", rhs, " );\n"
      )
    }
  )
)

ast_node_dollar <- R6::R6Class(
  classname = "ast_node_dollar",
  inherit = ast_node,
  public = list(
    compile = function() {
      lhs_val <- private$get_lhs_value()
      if (!is.null(lhs_val)) {
        rhs <- self$get_tail_elements()[[2L]]
        self$emit(
          lhs_val$dollar(rhs)$compile()
        )
      }
    },
    get_cpp_type = function() {
      lhs_val <- private$get_lhs_value()
      if (!is.null(lhs_val)) {
        rhs <- self$get_tail_elements()[[2L]]
        return(lhs_val$dollar(rhs)$get_cpp_type())
      }
      private$cpp_type
    }
  ),
  private = list(
    get_lhs_value = function() {
      elements <- self$get_tail_elements()
      lhs <- elements[[1L]]
      rhs <- elements[[2L]]
      stopifnot("ast_node_name" %in% class(lhs))
      stopifnot("ast_node_name" %in% class(rhs))
      #stopifnot(self$get_scope()$is_name_defined(lhs$get_name()))
      self$get_scope()$get_value_node_for_name(lhs$get_name())
    }
  )
)

ast_node_svd_init <- R6::R6Class(
  classname = "ast_node_svd_init",
  inherit = ast_node,
  public = list(
    compile = function(assigned_var_name) {
      stopifnot(length(self$get_tail_elements()) == 1L)
      stopifnot("ast_node_assignment" %in% class(self$get_parent()))
      rhs <- self$get_tail_elements()[[1L]]$compile()
      var_name <- make_random_var_name(self$get_scope()$new_random_variable_id())
      private$set_base_varname(var_name)
      var_name_u <- private$get_varname("u")
      var_name_v <- private$get_varname("v")
      var_name_d <- private$get_varname("d")
      self$emit(
        "arma::mat ", var_name_u, ", ", var_name_v, ";\n",
        "arma::colvec ", var_name_d, ";\n",
        "arma::svd_econ(", var_name_u, ", ", var_name_d, ", ", var_name_v, ", ", rhs, ");\n"
      )
    },
    dollar = function(node) {
      stopifnot("ast_node_name" %in% class(node))
      type <- node$get_name()
      var_name <- private$get_varname(type)
      var_name_node <- ast_node_name$new(var_name, var_name)
      cpp_type <- if (type %in% c("u", "v")) {
        "arma::mat"
      } else {
        "arma::colvec"
      }
      var_name_node$set_cpp_type(cpp_type)
      var_name_node$set_parent(self)
      var_name_node$set_scope(self$get_scope())
      var_name_node
    }
  ),
  private = list(
    base_var_name = NULL,
    set_base_varname = function(name) {
      private$base_var_name <- name
    },
    get_varname = function(type) {
      stopifnot(type %in% c("u", "v", "d"))
      paste0(private$base_var_name, "__", type)
    }
  )
)

ast_node_qr_q <- R6::R6Class(
  classname = "ast_node_qr_q",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      self$emit(self$get_tail_elements()[[1L]]$compile(), "__Q")
    },
    get_cpp_type = function() {
      "arma::mat"
    }
  )
)

ast_node_qr_r <- R6::R6Class(
  classname = "ast_node_qr_q",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      self$emit(self$get_tail_elements()[[1L]]$compile(), "__R")
    },
    get_cpp_type = function() {
      "arma::mat"
    }
  )
)

ast_node_seq_len <- R6::R6Class(
  classname = "ast_node_seq_len",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      until <- self$get_tail_elements()[[1L]]$compile()
      self$emit(
        "arma::linspace<arma::colvec>(1, ", until, ", ", until, ")"
      )
    },
    get_cpp_type = function() {
      "arma::colvec"
    }
  )
)


ast_node_seq <- R6::R6Class(
  classname = "ast_node_seq",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 3L)
      elements <- self$get_tail_elements()
      from <- elements[[1L]]$compile()
      to <- elements[[2L]]$compile()
      len_out <- elements[[3L]]$compile()
      self$emit(
        "arma::linspace<arma::colvec>(", from, ", ", to, ", ", len_out, ")"
      )
    },
    get_cpp_type = function() {
      "arma::colvec"
    }
  )
)

ast_node_colon <- R6::R6Class(
  classname = "ast_node_colon",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 2L)
      elements <- self$get_tail_elements()
      from <- elements[[1L]]$compile()
      to <- elements[[2L]]$compile()
      self$emit(
        "arma::linspace<arma::colvec>(", from, ", ", to, ", ", to, " - ", from, " + 1)"
      )
    },
    get_cpp_type = function() {
      "arma::colvec"
    }
  )
)

ast_node_break <- R6::R6Class(
  classname = "ast_node_break",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 0L)
      self$emit("break;\n")
    }
  )
)

ast_node_next <- R6::R6Class(
  classname = "ast_node_next",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 0L)
      self$emit("continue;\n")
    }
  )
)

ast_node_element_access <- R6::R6Class(
  classname = "ast_node_element_access",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) %in% 2:3)
      # TODO: refactor and reflect "missing" names
      # in the OO structure properly. Currently that is indicated
      # by the method $has_value() of ast_node_name objects
      # possible allowed and assumed values are
      # x[i, j] => list(x, i, j)
      # x[i, ] => list(x, i, )
      # x[, j] => list(x, , j)
      # x[i] => list(x, i)
      head_code <- self$get_tail_elements()[[1L]]$compile()
      if (!private$is_subview()) {
        args <- self$get_tail_elements()[[2L]]$compile()
        if (length(self$get_tail_elements()) == 3L) {
          args <- c(args, self$get_tail_elements()[[3L]]$compile())
        }
        # counting starts at 0
        args <- paste0(args, " - 1")
        return(self$emit(
          head_code,
          "(", paste0(args, collapse = ", "), ")"
        ))
      }
      elements <- self$get_tail_elements()
      is_selected <- function(idx) {
      inherits(elements[[idx]], "ast_node_name") &&
        elements[[idx]]$has_value() ||
        (inherits(elements[[idx]], "ast_node_terminal") &&
          !inherits(elements[[idx]], "ast_node_name"))
      }
      row_selection <- is_selected(2L)
      col_selection <- is_selected(3L)
      stopifnot(xor(row_selection, col_selection))
      if (row_selection) {
        return(self$emit(
          head_code, ".row(", elements[[2L]]$compile(), " - 1)"
        ))
      }
      if (col_selection) {
        return(self$emit(
          head_code, ".col(", elements[[3L]]$compile(), " - 1)"
        ))
      }
      stop("unreachable")
    },
    get_cpp_type = function() {
      if (private$is_subview()) {
        return("arma::mat")
      }
      # TODO: depends on matrix type
      "auto"
    }
  ),
  private = list(
    is_subview = function() {
      !all(
        vapply(self$get_tail_elements(), function(x) {
          !inherits(x, "ast_node_name") || x$has_value()
        }, logical(1L))
      )
    }
  )
)

ast_node_norm <- R6::R6Class(
  classname = "ast_node_norm",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      self$emit(
        "arma::norm(", self$get_tail_elements()[[1L]]$compile(), ")"
      )
    },
    get_cpp_type = function() {
      "double"
    }
  )
)

ast_node_diag <- R6::R6Class(
  classname = "ast_node_diag",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 1L)
      self$emit(
        "arma::diagmat(", self$get_tail_elements()[[1L]]$compile(), ")"
      )
    },
    get_cpp_type = function() {
      "arma::mat"
    }
  )
)

ast_node_list <- R6::R6Class(
  classname = "ast_node_list",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) >= 1L)
      args <- vapply(
        self$get_tail_elements(),
        function (x) x$compile(),
        character(1L)
      )
      self$emit(
        "Rcpp::List::create(",
        paste0(args, collapse = ", "),
        ")"
      )
    },
    get_cpp_type = function() {
      "Rcpp::List"
    }
  )
)

ast_node_rep_int <- R6::R6Class(
  classname = "ast_node_rep_int",
  inherit = ast_node,
  public = list(
    compile = function() {
      stopifnot(length(self$get_tail_elements()) == 2L)
      self$emit(
        "arma::colvec(std::vector<double>(",
        self$get_tail_elements()[[2L]]$compile(),
        ", ",
        self$get_tail_elements()[[1L]]$compile(),
        "))"
      )
    },
    get_cpp_type = function() {
      "arma::colvec"
    }
  )
)

make_generic_unary_function_class <- function(class_name, fun) {
  R6::R6Class(
    classname = paste0("ast_node_generic_unary_function_", class_name),
    inherit = ast_node_function_call,
    public = list(
      compile = function() {
        stopifnot(length(self$get_tail_elements()) == 1L)
        fun_name <- if (is.character(fun)) {
          fun
        } else {
          arma_type <- self$get_tail_elements()[[1L]]$is_armadillo_cpp_type()
          if (arma_type) {
            fun$arma
          } else {
            fun$std
          }
        }
        self$emit(fun_name, "(", self$get_tail_elements()[[1L]]$compile(), ")")
      },
      get_cpp_type = function() {
        # the assumption is that the input type is the same as the return type
        # for these kind of functions
        self$get_tail_elements()[[1L]]$get_cpp_type()
      }
    )
  )
}

unary_reduction_node <- function(arma_fun,
                                 classname = paste0("ast_node_", arma_fun)) {
  R6::R6Class(
    classname = classname,
    inherit = ast_node,
    public = list(
      compile = function() {
        stopifnot(length(self$get_tail_elements()) == 1L)
        self$emit(
          "arma::", arma_fun, "(",
          self$get_tail_elements()[[1L]]$compile(),
          ")"
        )
      },
      get_cpp_type = function() {
        # TODO: depends on data type
        "auto"
      }
    )
  )
}

ast_node_sum <- unary_reduction_node("accu", "ast_node_sum")
ast_node_min <- unary_reduction_node("min")
ast_node_max <- unary_reduction_node("max")

ast_node_type_spec <- R6::R6Class(
  classname = "ast_node_type_spec",
  inherit = ast_node
)

ast_node_type_matrix <- R6::R6Class(
  classname = "ast_node_type_matrix",
  inherit = ast_node_type_spec,
  public = list(
    get_cpp_type = function() {
      "arma::mat"
    }
  )
)

ast_node_type_colvec <- R6::R6Class(
  classname = "ast_node_type_colvec",
  inherit = ast_node_type_spec,
  public = list(
    get_cpp_type = function() {
      "arma::colvec"
    }
  )
)

ast_node_type_rowvec <- R6::R6Class(
  classname = "ast_node_type_rowvec",
  inherit = ast_node_type_spec,
  public = list(
    get_cpp_type = function() {
      "arma::rowvec"
    }
  )
)

ast_node_type_scalar_integer <- R6::R6Class(
  classname = "ast_node_type_scalar_integer",
  inherit = ast_node_type_spec,
  public = list(
    get_cpp_type = function() {
      "int"
    }
  )
)

ast_node_type_scalar_logical <- R6::R6Class(
  classname = "ast_node_type_scalar_logical",
  inherit = ast_node_type_spec,
  public = list(
    get_cpp_type = function() {
      "bool"
    }
  )
)

ast_node_type_scalar_numeric <- R6::R6Class(
  classname = "ast_node_type_scalar_numeric",
  inherit = ast_node_type_spec,
  public = list(
    get_cpp_type = function() {
      "double"
    }
  )
)

element_type_map <- new.env(parent = emptyenv())
element_type_map[["<-"]] <- ast_node_assignment
element_type_map[["="]] <- ast_node_assignment
element_type_map[["{"]] <- ast_node_block
element_type_map[["function"]] <- ast_node_function
element_type_map[["arma_pairlist"]] <- ast_node_arma_pairlist
element_type_map[["["]] <- ast_node_element_access
element_type_map[["$"]] <- ast_node_dollar
element_type_map[["+"]] <- ast_node_plus
element_type_map[["-"]] <- ast_node_minus
element_type_map[["^"]] <- ast_node_pow
element_type_map[["*"]] <- ast_node_mult
element_type_map[["/"]] <- ast_node_div
element_type_map[["<"]] <- ast_node_less
element_type_map[[">"]] <- ast_node_greater
element_type_map[["<="]] <- ast_node_leq
element_type_map[[">="]] <- ast_node_geq
element_type_map[["=="]] <- ast_node_equal
element_type_map[["!="]] <- ast_node_nequal
element_type_map[["&&"]] <- ast_node_logical_and
element_type_map[["||"]] <- ast_node_logical_or
element_type_map[["%*%"]] <- ast_node_matmul
element_type_map[["("]] <- ast_node_bracket
element_type_map[["svd"]] <- ast_node_svd_init
element_type_map[["qr"]] <- ast_node_qr_init
element_type_map[["qr.Q"]] <- ast_node_qr_q
element_type_map[["qr.R"]] <- ast_node_qr_r
element_type_map[["solve"]] <- ast_node_solve
element_type_map[["sum"]] <- ast_node_sum
element_type_map[["min"]] <- ast_node_min
element_type_map[["max"]] <- ast_node_max
element_type_map[["nrow"]] <- ast_node_nrow
element_type_map[["ncol"]] <- ast_node_ncol
element_type_map[["length"]] <- ast_node_length
element_type_map[["seq_len"]] <- ast_node_seq_len
element_type_map[["seq"]] <- ast_node_seq
element_type_map[[":"]] <- ast_node_colon
element_type_map[["rep.int"]] <- ast_node_rep_int
element_type_map[["norm"]] <- ast_node_norm
element_type_map[["diag"]] <- ast_node_diag
element_type_map[["list"]] <- ast_node_list
element_type_map[["crossprod"]] <- ast_node_crossprod
element_type_map[["tcrossprod"]] <- ast_node_tcrossprod
element_type_map[["colSums"]] <- ast_node_colsums
element_type_map[["rowSums"]] <- ast_node_rowsums
element_type_map[["colMeans"]] <- ast_node_colmeans
element_type_map[["rowMeans"]] <- ast_node_rowmeans
element_type_map[["return"]] <- ast_node_return
element_type_map[["backsolve"]] <- ast_node_backsolve
element_type_map[["forwardsolve"]] <- ast_node_forwardsolve
element_type_map[["if"]] <- ast_node_if
element_type_map[["for"]] <- ast_node_for
element_type_map[["while"]] <- ast_node_while
element_type_map[["type_matrix"]] <- ast_node_type_matrix
element_type_map[["type_colvec"]] <- ast_node_type_colvec
element_type_map[["type_rowvec"]] <- ast_node_type_rowvec
element_type_map[["type_scalar_integer"]] <- ast_node_type_scalar_integer
element_type_map[["type_scalar_numeric"]] <- ast_node_type_scalar_numeric
element_type_map[["type_scalar_logical"]] <- ast_node_type_scalar_logical

# TODO: not pretty
multi_dispatch_fun <- function(arma, std) {
  structure(
    list(arma = arma, std = std),
    class = "multi_dispatch_fun"
  )
}
unary_function_mapping <- new.env(hash = TRUE, parent = emptyenv())
unary_function_mapping$t <- "arma::trans"
unary_function_mapping$`!` <- "!"
unary_function_mapping$exp <- multi_dispatch_fun(arma = "arma::exp", std = "std::exp")
unary_function_mapping$abs <- multi_dispatch_fun(arma = "arma::abs", std = "std::abs")
unary_function_mapping$log <- multi_dispatch_fun(arma = "arma::log", std = "std::log")
unary_function_mapping$chol <- "arma::chol"
unary_function_mapping$cumsum <- "arma::cumsum"
unary_function_mapping$sqrt <- multi_dispatch_fun(arma = "arma::sqrt", std = "std::sqrt")
unary_function_mapping$floor <- multi_dispatch_fun(arma = "arma::floor", std = "std::floor")
unary_function_mapping$ceiling <- multi_dispatch_fun(arma = "arma::ceil", std = "std::ceil")
unary_function_mapping$sort <- "arma::sort"
unary_function_mapping$unique <- "arma::unique"
unary_function_mapping$pnorm <- "arma::normcdf"
unary_function_mapping$det <- "arma::det"
unary_function_mapping$cos <- multi_dispatch_fun(arma = "arma::cos", std = "std::cos")
unary_function_mapping$acos <- "arma::acos"
unary_function_mapping$cosh <- "arma::cosh"
unary_function_mapping$acosh <- "arma::acosh"
unary_function_mapping$sin <- multi_dispatch_fun(arma = "arma::sin", std = "std::sin")
unary_function_mapping$asin <- "arma::asin"
unary_function_mapping$sinh <- "arma::sinh"
unary_function_mapping$asinh <- "arma::asinh"
unary_function_mapping$tan <- multi_dispatch_fun(arma = "arma::tan", std = "std::tan")
unary_function_mapping$atan <- "arma::atan"
unary_function_mapping$tanh <- "arma::tanh"
unary_function_mapping$atanh <- "arma::atanh"
unary_function_mapping$cumsum <- "arma::cumsum"
unary_function_mapping$cumprod <- "arma::cumprod"

new_node_by_chr <- function(head_chr, is_call = FALSE) {
  if (!is.null(element_type_map[[head_chr]])) {
    return(element_type_map[[head_chr]])
  }
  if (!is.null(unary_function_mapping[[head_chr]])) {
    return(make_generic_unary_function_class(head_chr, unary_function_mapping[[head_chr]]))
  }
  if (is_call) {
    return(ast_node_function_call)
  }
  ast_node
}

make_random_var_name <- function(id) {
  paste0("arma___tmp_", id)
}

# our type system
auto_or_arma_type <- function(type) {
  if (grepl("^arma", type)) {
    type
  } else {
    "auto"
  }
}

arma_element_type <- function(type) {
  dbl_types <- c("arma::mat", "arma::colvec", "arma::rowvec")
  int_types <- c("arma::imat", "arma::icolvec", "arma::irowvec")
  uint_types <- c("arma::umat", "arma::ucolvec", "arma::urowvec")
  if (type %in% dbl_types) {
    "double"
  } else if (type %in% int_types) {
    "int"
  } else if (type %in% uint_types) {
    "uint"
  }
}

get_arma_type <- function(base_type, elements) {
  prefix <- ""
  if (elements == "int") prefix <- "i"
  if (elements == "uint") prefix <- "u"
  paste0("arma::", prefix, base_type)
}

find_arma_cpp_types <- function(nodes) {
  Filter(function(x) {
    x$is_armadillo_cpp_type()
  }, nodes)
}
dirkschumacher/armacmp documentation built on Oct. 22, 2021, 7:10 p.m.