Skip to content

Flatten

simplegrad.nn.transform.Flatten

Bases: Module

Flatten a range of tensor dimensions into a single dimension.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `start_dim` | `int` | First dimension to flatten (inclusive). Defaults to 1 (preserves the batch dimension). | `1` |
| `end_dim` | `int` | Last dimension to flatten (inclusive). Defaults to -1 (the last dimension). | `-1` |
Source code in simplegrad/nn/transform.py
class Flatten(Module):
    """Module that collapses a contiguous span of tensor dimensions into one.

    Dimensions ``start_dim`` through ``end_dim`` (both inclusive) are merged
    into a single dimension. The work is delegated to the functional
    ``flatten`` helper.

    Args:
        start_dim: Index of the first dimension to merge (inclusive).
            Defaults to 1, which preserves the batch dimension.
        end_dim: Index of the last dimension to merge (inclusive).
            Defaults to -1, i.e. the last dimension.
    """

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        # Kept as plain attributes; both ``forward`` and ``__str__`` read them.
        self.start_dim, self.end_dim = start_dim, end_dim

    def forward(self, x: Tensor) -> Tensor:
        """Apply the configured flattening to ``x``.

        Args:
            x: Tensor to reshape.

        Returns:
            A tensor in which dimensions ``start_dim`` through ``end_dim``
            have been merged into a single dimension.
        """
        first, last = self.start_dim, self.end_dim
        return flatten(x, first, last)

    def __str__(self):
        template = "Flatten(start_dim={}, end_dim={})"
        return template.format(self.start_dim, self.end_dim)

forward(x: Tensor) -> Tensor

Flatten the input tensor.

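Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Tensor` | Input tensor. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Tensor` | Tensor with dimensions `[start_dim, end_dim]` merged into one. |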

Source code in simplegrad/nn/transform.py
def forward(self, x: Tensor) -> Tensor:
    """Apply the configured flattening to ``x``.

    Args:
        x: Tensor to reshape.

    Returns:
        A tensor in which dimensions ``start_dim`` through ``end_dim``
        have been merged into a single dimension.
    """
    first, last = self.start_dim, self.end_dim
    return flatten(x, first, last)