# File: tests/testthat/test-parse_shape.R

test_that("preprocess_shape_ast expands ellipsis correctly", {

    # Every scenario hands preprocess_shape_ast() an AST holding a single
    # EllipsisAstNode together with a shape vector, and expects the ellipsis
    # to be replaced by src-less UnderscoreAstNode()s while all other nodes
    # come back unchanged.

    # Scenario 1: ellipsis in last position, shape vector of length 4.
    # Two explicit nodes precede it, so two underscores are expected.
    const_one <- ConstantAstNode(count = 1, src = list(start = 1))
    axis_c_at_3 <- NamedAxisAstNode(name = "c", src = list(start = 3))

    trailing_input <- OneSidedAstNode(
        const_one,
        axis_c_at_3,
        EllipsisAstNode(src = list(start = 5))
    )
    trailing_expected <- OneSidedAstNode(
        const_one,
        axis_c_at_3,
        UnderscoreAstNode(),
        UnderscoreAstNode()
    )
    expect_identical(
        preprocess_shape_ast(trailing_input, c(1, 2, 3, 4)),
        trailing_expected
    )

    # Scenario 2: ellipsis in last position, shape vector of length 8.
    # Five explicit nodes (three named axes, two underscores carrying src)
    # precede it, so three src-less underscores are expected.
    axis_a <- NamedAxisAstNode(name = "a", src = list(start = 1))
    axis_b <- NamedAxisAstNode(name = "b", src = list(start = 3))
    axis_c_at_5 <- NamedAxisAstNode(name = "c", src = list(start = 5))
    skip_at_7 <- UnderscoreAstNode(src = list(start = 7))
    skip_at_9 <- UnderscoreAstNode(src = list(start = 9))

    long_input <- OneSidedAstNode(
        axis_a,
        axis_b,
        axis_c_at_5,
        skip_at_7,
        skip_at_9,
        EllipsisAstNode(src = list(start = 11))
    )
    long_expected <- OneSidedAstNode(
        axis_a,
        axis_b,
        axis_c_at_5,
        skip_at_7,
        skip_at_9,
        UnderscoreAstNode(),
        UnderscoreAstNode(),
        UnderscoreAstNode()
    )
    expect_identical(
        preprocess_shape_ast(long_input, 1:8),
        long_expected
    )

    # Scenario 3: ellipsis between two named axes, shape vector of length 4.
    # The two named axes stay in place; two underscores fill the middle.
    axis_w <- NamedAxisAstNode(name = "w", src = list(start = 7))

    central_input <- OneSidedAstNode(
        axis_a,
        EllipsisAstNode(src = list(start = 3)),
        axis_w
    )
    central_expected <- OneSidedAstNode(
        axis_a,
        UnderscoreAstNode(),
        UnderscoreAstNode(),
        axis_w
    )
    expect_identical(
        preprocess_shape_ast(central_input, 3:6),
        central_expected
    )
})

test_in_all_tensor_types_that("parse_shape works", {
    # Two-axis inputs: both named axes are reported with integer lengths.
    square <- create_seq_tensor(c(4, 4))
    expect_identical(
        parse_shape(square, "height width"),
        list(height = 4L, width = 4L)
    )

    rectangle <- create_seq_tensor(3:4)
    expect_identical(
        parse_shape(rectangle, "height width"),
        list(height = 3L, width = 4L)
    )

    # Four-axis input; the expectations below treat its axis lengths as
    # 3, 4, 5, 6. Each entry maps a pattern to the named list expected back:
    # named axes appear in the result, while constants, "_" and "..." do not.
    x <- create_seq_tensor(3:6)

    expected_by_pattern <- list(
        "b c h w" = list(b = 3L, c = 4L, h = 5L, w = 6L),
        "... c h w" = list(c = 4L, h = 5L, w = 6L),
        "3 c h w" = list(c = 4L, h = 5L, w = 6L),
        "3 c 5 w" = list(c = 4L, w = 6L),
        "3 c ..." = list(c = 4L),
        "c h w1 w2" = list(c = 3L, h = 4L, w1 = 5L, w2 = 6L),
        "b _ h w" = list(b = 3L, h = 5L, w = 6L),
        "b _ _ w" = list(b = 3L, w = 6L),
        "b _ _ _" = list(b = 3L),
        "b ..." = list(b = 3L),
        "... h _ w" = list(h = 4L, w = 6L)
    )

    for (pattern in names(expected_by_pattern)) {
        # `info` only labels a failure with the pattern that produced it.
        expect_identical(
            parse_shape(x, pattern),
            expected_by_pattern[[pattern]],
            info = pattern
        )
    }
})

test_in_all_tensor_types_that("parse_shape works on central ellipsis", {

    # Patterns here place "..." between other axes; only named axes are
    # expected in the result.
    four_axes <- create_seq_tensor(3:6)
    five_axes <- create_seq_tensor(2:6)

    # Ellipsis flanked by two named axes.
    only_outer <- list(a = 3L, w = 6L)
    expect_identical(parse_shape(four_axes, "a ... w"), only_outer)

    # Ellipsis followed by an underscore, then a named axis.
    expect_identical(parse_shape(four_axes, "a ... _ w"), only_outer)

    # Named axis between two underscores after the ellipsis.
    expect_identical(
        parse_shape(five_axes, "a ... _ w _"),
        list(a = 2L, w = 5L)
    )

    # Same layout, with a constant in place of the named axis.
    expect_identical(
        parse_shape(five_axes, "a ... _ 5 _"),
        list(a = 2L)
    )
})

# Try the einops package in your browser

# Any scripts or data that you put into this service are public.

# einops documentation built on Sept. 9, 2025, 5:29 p.m.