# Nothing
#' @title
#' Parse einops pattern into AST
#'
#' @description
#' Splits the token sequence at its single arrow token and parses the tokens
#' on each side into axis nodes. The input side must be non-empty; an empty
#' output side yields an empty list of output axes.
#'
#' The python implementation preprocesses every 1 into () and then parses it
#' into the equivalent of a GroupAstNode. That rewrite is deliberately not
#' done here.
#'
#' @param tokens EinopsTokenSequence object from the lexer
#' @return EinopsAst object
#' @keywords internal
parse_einops_ast <- function(tokens) {
  n_tokens <- length(tokens)
  if (n_tokens == 0) {
    stop("Empty token sequence")
  }

  # Position of the one arrow separating the two sides
  split_at <- find_top_level_arrow_index(tokens)

  # Tokens strictly before / strictly after the arrow
  lhs_tokens <- if (split_at == 1) {
    list()
  } else {
    tokens[seq_len(split_at - 1)]
  }
  rhs_tokens <- if (split_at == n_tokens) {
    list()
  } else {
    tokens[(split_at + 1):n_tokens]
  }

  if (length(lhs_tokens) == 0) {
    stop("Empty input pattern before arrow")
  }

  # The output side is parsed first, then the input side
  output_axes <- if (length(rhs_tokens) == 0) {
    list()
  } else {
    parse_axes_iter(rhs_tokens)
  }
  input_axes <- parse_axes_iter(lhs_tokens)

  # The root node's source location is the start of the first token
  root_src <- list(start = tokens[[1]]$start)
  EinopsAst(input_axes, output_axes, root_src)
}
#' @title Find the arrow in a token sequence
#' @description
#' Locates the single token whose `type` is `"ARROW"`. Raises an error when
#' the sequence contains no arrow (including when it is empty) or more than
#' one arrow.
#' @param tokens EinopsTokenSequence object
#' @return one-indexed Integer position of the arrow token
#' @keywords internal
find_top_level_arrow_index <- function(tokens) {
  # vapply() always yields a logical vector, even for zero tokens. sapply()
  # would return an empty list in that case and make which() fail with an
  # unrelated type error instead of the "No '->' found" message below.
  # isTRUE() treats a token without a usable `type` as "not an arrow".
  is_arrow <- vapply(
    tokens,
    function(x) isTRUE(x$type == "ARROW"),
    logical(1)
  )
  arrow_positions <- which(is_arrow)
  if (length(arrow_positions) == 0) stop("No '->' found in expression")
  if (length(arrow_positions) > 1) stop("Multiple '->' found in expression")
  arrow_positions[1]
}
#' @title Parse a one-sided einops pattern into OneSidedAstNode
#' @description
#' Parses every token as an axis of a single pattern side and wraps the
#' resulting nodes in a OneSidedAstNode. An empty token sequence is an error.
#' @param tokens EinopsTokenSequence object from the lexer (no arrow expected)
#' @return OneSidedAstNode object
#' @keywords internal
parse_onesided_ast <- function(tokens) {
  if (length(tokens) < 1L) {
    stop("Empty token sequence")
  }
  parsed_axes <- parse_axes_iter(tokens)
  OneSidedAstNode(parsed_axes)
}
#' @title Parse a sequence of axis tokens
#' @description
#' Walks the tokens of one pattern side from left to right and builds one AST
#' node per axis. A parenthesised run of tokens becomes a single group node
#' whose children are parsed recursively; groups may be empty but may not be
#' nested. At most one ellipsis is accepted per call, whether it appears at
#' the top level or inside a group.
#'
#' Errors are raised for an empty token list, a second ellipsis, a nested
#' group, an unmatched parenthesis and an unrecognised token type.
#' @param tokens List of tokens representing one side of the pattern
#' @return List of AST nodes
#' @keywords internal
parse_axes_iter <- function(tokens) {
  if (length(tokens) == 0) {
    stop("Empty axis pattern")
  }
  n_tokens <- length(tokens)
  result <- list()
  has_ellipsis <- FALSE
  i <- 1
  while (i <= n_tokens) {
    token <- tokens[[i]]
    # Every node records the start position of the token that opened it
    src <- list(start = token$start)
    node <- switch(token$type,
      NAME = NamedAxisAstNode(token$value, src),
      INT = ConstantAstNode(token$value, src),
      ELLIPSIS = {
        if (has_ellipsis) {
          stop(
            "Multiple ellipses found in axis pattern at position ",
            token$start
          )
        }
        has_ellipsis <- TRUE
        EllipsisAstNode(src)
      },
      UNDERSCORE = UnderscoreAstNode(src),
      LPAREN = {
        # Scan forward for the matching ')'. Nesting is rejected, so the
        # first ')' encountered closes this group.
        group_start <- i + 1
        group_end <- group_start
        paren_depth <- 1
        group_has_ellipsis <- FALSE
        group_ellipsis_start <- NULL
        while (group_end <= n_tokens && paren_depth > 0) {
          inner <- tokens[[group_end]]
          if (inner$type == "LPAREN") {
            stop(
              "Groups cannot be nested at position ",
              inner$start
            )
          } else if (inner$type == "RPAREN") {
            paren_depth <- paren_depth - 1
          } else {
            # Remember the first ellipsis inside the group: the recursive
            # call below only sees the group's own tokens, so the enclosing
            # side has to account for it here.
            if (inner$type == "ELLIPSIS" && !group_has_ellipsis) {
              group_has_ellipsis <- TRUE
              group_ellipsis_start <- inner$start
            }
            group_end <- group_end + 1
          }
        }
        if (paren_depth > 0) {
          stop(
            "Unmatched opening parenthesis '(' at position ",
            token$start
          )
        }
        # group_end now indexes the closing ')'
        if (group_start > group_end - 1) {
          # Empty group
          group_children <- list()
        } else {
          group_tokens <- tokens[group_start:(group_end - 1)]
          group_children <- parse_axes_iter(group_tokens)
        }
        # Checked after the recursive parse so errors raised for the group's
        # own contents keep precedence.
        if (group_has_ellipsis) {
          if (has_ellipsis) {
            stop(
              "Multiple ellipses found in axis pattern at position ",
              group_ellipsis_start
            )
          }
          has_ellipsis <- TRUE
        }
        # Resume at the closing paren; the shared increment below steps past
        # it.
        i <- group_end
        GroupAstNode(group_children, src)
      },
      RPAREN = stop(
        "Unmatched closing parenthesis ')' at position ",
        token$start
      ),
      stop(
        "Unexpected token type '",
        token$type, "' at position ", token$start
      )
    )
    result <- c(result, list(node))
    i <- i + 1
  }
  result
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.