nn_flatten: Flattens a contiguous range of dims of a tensor into a single dim.

Description

This module is intended for use within nn_sequential.
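A minimal sketch of nn_flatten inside nn_sequential follows. It assumes the torch package (with libtorch) is installed. The convolution and linear layer sizes are illustrative assumptions, chosen so that the flattened feature count (4 x 3 x 3 = 36) matches the input size of the linear layer.

```r
library(torch)

# Illustrative model: conv -> relu -> flatten -> linear (sizes are assumptions)
model <- nn_sequential(
  nn_conv2d(in_channels = 1, out_channels = 4, kernel_size = 3), # 1x5x5 -> 4x3x3
  nn_relu(),
  nn_flatten(),                                  # 4x3x3 -> 36 features; batch dim kept
  nn_linear(in_features = 36, out_features = 10) # 36 -> 10
)

x <- torch_randn(32, 1, 5, 5)  # batch of 32 single-channel 5x5 inputs
y <- model(x)
y$shape  # 32 10
```

Because the default start_dim = 2 leaves the first (batch) dim untouched, the linear layer sees one 36-element feature vector per sample.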

Usage

nn_flatten(start_dim = 2, end_dim = -1)

Arguments

start_dim

First dim to flatten (default = 2). Dims are 1-based, so the default keeps the first (typically batch) dim intact.

end_dim

Last dim to flatten (default = -1). Negative values count from the end, so -1 refers to the last dim.
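The sketch below shows how non-default start_dim and end_dim values change the result. It assumes the torch package is installed and uses the same 32 x 1 x 5 x 5 input as the Examples section.

```r
library(torch)

x <- torch_randn(32, 1, 5, 5)

# Flatten only dims 3 and 4: shape becomes 32 x 1 x 25
m1 <- nn_flatten(start_dim = 3)
m1(x)$shape

# end_dim = -2 refers to dim 3 of 4, so dims 2..3 merge: shape 32 x 5 x 5
m2 <- nn_flatten(start_dim = 2, end_dim = -2)
m2(x)$shape

# start_dim = 1 also flattens the batch dim: a 1-D tensor of 800 elements
m3 <- nn_flatten(start_dim = 1)
m3(x)$shape
```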

Shape

  • Input: (*, S_start, ..., S_i, ..., S_end, *), where S_i is the size at dimension i and * means any number of dimensions, including none.

  • Output: (*, S_start*...*S_i*...*S_end, *).
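The shape rule above can be sketched in base R without torch. flatten_shape() is a hypothetical helper written for illustration only and is not part of the torch package; it assumes 1-based dims and that negative dims count from the end.

```r
# Hypothetical helper (not part of torch): output shape of
# nn_flatten(start_dim, end_dim) for an input of shape `shape`.
flatten_shape <- function(shape, start_dim = 2, end_dim = -1) {
  n <- length(shape)
  # Negative dims count from the end: -1 is the last dim
  if (start_dim < 0) start_dim <- n + start_dim + 1
  if (end_dim < 0) end_dim <- n + end_dim + 1
  stopifnot(start_dim >= 1, end_dim <= n, start_dim <= end_dim)
  c(
    shape[seq_len(start_dim - 1)],         # leading dims, unchanged
    prod(shape[start_dim:end_dim]),        # S_start * ... * S_end
    shape[seq_len(n - end_dim) + end_dim]  # trailing dims, unchanged
  )
}

flatten_shape(c(32, 1, 5, 5))                             # 32 25
flatten_shape(c(32, 1, 5, 5), start_dim = 2, end_dim = 3) # 32 5 5
```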

See Also

nn_unflatten

Examples

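library(torch)
if (torch_is_installed()) {
  # A batch of 32 single-channel 5x5 inputs
  input <- torch_randn(32, 1, 5, 5)
  # Defaults (start_dim = 2, end_dim = -1) flatten dims 2 to 4: output is 32 x 25
  m <- nn_flatten()
  m(input)
}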

torch documentation built on June 7, 2023, 6:19 p.m.