tests/testthat/test-compiler-TransformRecipe.R

test_that("expand_ellipsis() works", {

    # Local shorthands for the AST constructors used below. Each one forwards
    # its arguments unchanged, so the nodes are the same as those built by
    # spelling the constructor calls out in full.
    named_node <- function(name, start) {
        NamedAxisAstNode(name = name, src = list(start = start))
    }
    ellipsis_node <- function(start) {
        EllipsisAstNode(src = list(start = start))
    }
    # Every pattern in this test has a top-level src starting at 1.
    whole_ast <- function(input_axes, output_axes) {
        EinopsAst(
            input_axes = input_axes,
            output_axes = output_axes,
            src = list(start = 1)
        )
    }

    # Patterns without any ellipsis are expected to come back unchanged.

    # "a b -> b"
    no_ellipsis_2d <- whole_ast(
        OneSidedAstNode(named_node("a", 1), named_node("b", 3)),
        OneSidedAstNode(named_node("b", 8))
    )
    expect_identical(expand_ellipsis(no_ellipsis_2d, 2), no_ellipsis_2d)

    # "a b c -> b c"
    no_ellipsis_3d <- whole_ast(
        OneSidedAstNode(
            named_node("a", 1), named_node("b", 3), named_node("c", 5)
        ),
        OneSidedAstNode(named_node("b", 10), named_node("c", 12))
    )
    expect_identical(expand_ellipsis(no_ellipsis_3d, 3), no_ellipsis_3d)

    # Flat ellipses on both sides ("a ... -> ...") are expected to become the
    # placeholder axes "...1" and "...2" on each side. The expected
    # placeholder nodes are built without a `src`.
    flat <- whole_ast(
        OneSidedAstNode(named_node("a", 1), ellipsis_node(3)),
        OneSidedAstNode(ellipsis_node(10))
    )
    flat_expanded <- whole_ast(
        OneSidedAstNode(
            named_node("a", 1),
            NamedAxisAstNode("...1"),
            NamedAxisAstNode("...2")
        ),
        OneSidedAstNode(NamedAxisAstNode("...1"), NamedAxisAstNode("...2"))
    )
    expect_identical(expand_ellipsis(flat, 3), flat_expanded)

    # An ellipsis inside a group ("... -> (...)") is expected to expand within
    # the group, while the group keeps its own `src`.
    grouped <- whole_ast(
        OneSidedAstNode(ellipsis_node(1)),
        OneSidedAstNode(
            GroupAstNode(
                children = list(ellipsis_node(9)),
                src = list(start = 8)
            )
        )
    )
    grouped_expanded <- whole_ast(
        OneSidedAstNode(NamedAxisAstNode("...1"), NamedAxisAstNode("...2")),
        OneSidedAstNode(
            GroupAstNode(
                children = list(
                    NamedAxisAstNode("...1"),
                    NamedAxisAstNode("...2")
                ),
                src = list(start = 8)
            )
        )
    )
    expect_identical(expand_ellipsis(grouped, 2L), grouped_expanded)

    # An ellipsis on the left with an empty right-hand side (" ... ->  ").
    to_nothing <- whole_ast(
        OneSidedAstNode(ellipsis_node(1)),
        OneSidedAstNode()
    )
    to_nothing_expanded <- whole_ast(
        OneSidedAstNode(NamedAxisAstNode("...1"), NamedAxisAstNode("...2")),
        OneSidedAstNode()
    )
    expect_identical(expand_ellipsis(to_nothing, 2L), to_nothing_expanded)

})

test_that("prepare_transformation_recipe() works", {

    # Asserts that compiling `pattern` yields exactly `expected`. The recipe
    # is computed first, then compared against `expected`.
    expect_recipe <- function(pattern, func, axes_names, ndim, expected) {
        expect_identical(
            prepare_transformation_recipe(pattern, func, axes_names, ndim),
            expected
        )
    }

    # `n` copies of the unknown / expected axis-length markers.
    unknowns <- function(n) rep(unknown_axis_length(), n)
    expecteds <- function(n) rep(expected_axis_length(), n)

    expect_no_error(prepare_transformation_recipe(
        "a b -> b", "mean", character(), 2L
    ))

    # Reduce the leading axis of a 2-d input.
    expect_recipe(
        "a b -> b", "mean", character(), 2L,
        expected = TransformRecipe(
            elementary_axes_lengths = unknowns(2L),
            axis_name2elementary_axis = r2r::hashmap(),
            input_composition_known_unknown = make_unknown_composition(1:2),
            axes_permutation = c(2L, 1L),
            first_reduced_axis = 2L,
            added_axes = r2r::hashmap(),
            output_composite_axes = list(2L)
        )
    )

    # Ellipsis covering a single axis (4-d input).
    expect_recipe(
        "... c h w -> ... h w", "mean", character(), 4L,
        expected = TransformRecipe(
            elementary_axes_lengths = unknowns(4L),
            axis_name2elementary_axis = r2r::hashmap(),
            input_composition_known_unknown = make_unknown_composition(1:4),
            axes_permutation = c(1L, 3L, 4L, 2L),
            first_reduced_axis = 4L,
            added_axes = r2r::hashmap(),
            output_composite_axes = list(1L, 3L, 4L)
        )
    )

    # Same pattern, ellipsis now covering two axes (5-d input).
    expect_recipe(
        "... c h w -> ... h w", "mean", character(), 5L,
        expected = TransformRecipe(
            elementary_axes_lengths = unknowns(5L),
            axis_name2elementary_axis = r2r::hashmap(),
            input_composition_known_unknown = make_unknown_composition(1:5),
            axes_permutation = c(1L, 2L, 4L, 5L, 3L),
            first_reduced_axis = 5L,
            added_axes = r2r::hashmap(),
            output_composite_axes = list(1L, 2L, 4L, 5L)
        )
    )

    # An anonymous axis of length 4 is the reduced one and its length is known.
    expect_recipe(
        "... 4 h w -> ... h w", "mean", character(), 4L,
        expected = TransformRecipe(
            elementary_axes_lengths = c(
                unknown_axis_length(), 4L, unknowns(2L)
            ),
            axis_name2elementary_axis = r2r::hashmap(),
            input_composition_known_unknown = list(
                list(known = integer(), unknown = 1L),
                list(known = 2L, unknown = integer()),
                list(known = integer(), unknown = 3L),
                list(known = integer(), unknown = 4L)
            ),
            axes_permutation = c(1L, 3L, 4L, 2L),
            first_reduced_axis = 4L,
            added_axes = r2r::hashmap(),
            output_composite_axes = list(1L, 3L, 4L)
        )
    )

    # Unit axes on the output side produce empty output groups.
    expect_recipe(
        "b c h w -> 1 c 1 1", "mean", character(), 4L,
        expected = TransformRecipe(
            elementary_axes_lengths = unknowns(4L),
            axis_name2elementary_axis = r2r::hashmap(),
            input_composition_known_unknown = make_unknown_composition(1:4),
            axes_permutation = c(2L, 1L, 3L, 4L),
            first_reduced_axis = 2L,
            added_axes = r2r::hashmap(),
            output_composite_axes = list(integer(), 2L, integer(), integer())
        )
    )

    # Pure regrouping on the output side: no permutation, nothing reduced.
    expect_recipe(
        "a b c d e f -> a (b) (c d e) f", "rearrange", character(), 6L,
        expected = TransformRecipe(
            elementary_axes_lengths = unknowns(6L),
            axis_name2elementary_axis = r2r::hashmap(),
            input_composition_known_unknown = make_unknown_composition(1:6),
            axes_permutation = 1:6,
            first_reduced_axis = 7L,
            added_axes = r2r::hashmap(),
            output_composite_axes = list(1L, 2L, 3:5, 6L)
        )
    )

    # TODO some test with 1's, brackets, anonymous axes, ellipses,
    # bracketed axes, etc.

    # Decomposed input axes with some lengths supplied by the caller.
    expect_recipe(
        "b (h1 h2 h3) (w1 w2 w3) c -> (h1 w2 h3) (b w1 h2 w3) c",
        "rearrange",
        c("h2", "w2", "w3", "h3"),
        4L,
        expected = TransformRecipe(
            elementary_axes_lengths = c(
                unknowns(2L),
                expecteds(2L),
                unknown_axis_length(),
                expecteds(2L),
                unknown_axis_length()
            ),
            axis_name2elementary_axis = r2r::hashmap(
                list("h2", 3L), list("w2", 6L), list("w3", 7L), list("h3", 4L)
            ),
            input_composition_known_unknown = list(
                list(known = integer(), unknown = 1L),
                list(known = c(3L, 4L), unknown = 2L),
                list(known = c(6L, 7L), unknown = 5L),
                list(known = integer(), unknown = 8L)
            ),
            axes_permutation = c(2L, 6L, 4L, 1L, 5L, 3L, 7L, 8L),
            first_reduced_axis = 9L,
            added_axes = r2r::hashmap(),
            output_composite_axes = list(
                c(2L, 6L, 4L), c(1L, 5L, 3L, 7L), 8L
            )
        )
    )

    # A fully-known composite input axis split apart on the output side.
    expect_recipe(
        "(b h w) c2 -> b c2 h w", "rearrange", c("b", "h", "w"), 2L,
        expected = TransformRecipe(
            elementary_axes_lengths = c(
                expecteds(3L), unknown_axis_length()
            ),
            axis_name2elementary_axis = r2r::hashmap(
                list("b", 1L), list("h", 2L), list("w", 3L)
            ),
            input_composition_known_unknown = list(
                list(known = 1:3, unknown = integer()),
                list(known = integer(), unknown = 4L)
            ),
            axes_permutation = c(1L, 4L, 2L, 3L),
            first_reduced_axis = 5L,
            added_axes = r2r::hashmap(),
            output_composite_axes = list(1L, 4L, 2L, 3L)
        )
    )

    # Repeat inserts a new anonymous axis of length 5 at output position 2.
    expect_recipe(
        "h w c -> h 5 w c", "repeat", character(), 3L,
        expected = TransformRecipe(
            elementary_axes_lengths = c(
                unknowns(3L), 5L
            ),
            axis_name2elementary_axis = r2r::hashmap(),
            input_composition_known_unknown = make_unknown_composition(1:3),
            axes_permutation = 1:3,
            first_reduced_axis = 4L,
            added_axes = r2r::hashmap(list(2L, 4L)),
            output_composite_axes = list(1L, 4L, 2L, 3L)
        )
    )

    # Reducing every axis leaves an empty output.
    expect_recipe(
        "a b c d e ->", "max", character(), 5L,
        expected = TransformRecipe(
            elementary_axes_lengths = unknowns(5L),
            axis_name2elementary_axis = r2r::hashmap(),
            input_composition_known_unknown = make_unknown_composition(1:5),
            axes_permutation = 1:5,
            first_reduced_axis = 1L,
            added_axes = r2r::hashmap(),
            output_composite_axes = list()
        )
    )

    # TODO complicated test with 1's, anon axes, brackets, ellipses, and axes

})

Try the einops package in your browser

Any scripts or data that you put into this service are public.

einops documentation built on Sept. 9, 2025, 5:29 p.m.