Bases: Module
Flatten a range of tensor dimensions into a single dimension.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| start_dim | int | First dimension to flatten (inclusive). Defaults to 1 (preserves the batch dimension). | 1 |
| end_dim | int | Last dimension to flatten (inclusive). Defaults to -1 (the last dimension). | -1 |
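The parameters describe an inclusive range of dimensions, and negative indices count from the end. The following is a minimal pure-Python sketch of the resulting output shape. The helper `flattened_shape` is hypothetical (it is not part of simplegrad) and assumes the same conventions as PyTorch's `torch.flatten`.

```python
from math import prod


def flattened_shape(shape: tuple[int, ...], start_dim: int = 1, end_dim: int = -1) -> tuple[int, ...]:
    """Hypothetical helper: output shape after merging dims [start_dim, end_dim]."""
    ndim = len(shape)
    # Assumption: negative indices count from the end, so -1 is the last dimension.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not (0 <= start <= end < ndim):
        raise ValueError(f"invalid range [{start_dim}, {end_dim}] for {ndim}-d shape")
    # Both endpoints are inclusive, hence the slice up to end + 1.
    merged = prod(shape[start:end + 1])
    return shape[:start] + (merged,) + shape[end + 1:]


# Defaults keep the batch dimension: (N, C, H, W) -> (N, C*H*W).
print(flattened_shape((32, 3, 28, 28)))         # (32, 2352)
print(flattened_shape((32, 3, 28, 28), 2, 3))   # (32, 3, 784)
print(flattened_shape((32, 3, 28, 28), 0, -1))  # (75264,)
```

With the defaults, an image batch of shape (N, C, H, W) becomes (N, C*H*W). This is the usual step between convolutional and fully connected layers.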
Source code in simplegrad/nn/transform.py
```python
class Flatten(Module):
    """Flatten a range of tensor dimensions into a single dimension.

    Args:
        start_dim: First dimension to flatten (inclusive). Defaults to 1
            (preserves the batch dimension).
        end_dim: Last dimension to flatten (inclusive). Defaults to -1
            (the last dimension).
    """

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, x: Tensor) -> Tensor:
        """Flatten the input tensor.

        Args:
            x: Input tensor.

        Returns:
            Tensor with dimensions ``[start_dim, end_dim]`` merged into one.
        """
        return flatten(x, self.start_dim, self.end_dim)

    def __str__(self):
        return f"Flatten(start_dim={self.start_dim}, end_dim={self.end_dim})"
```
forward(x): Flatten the input tensor.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |
Returns:
| Type | Description |
| --- | --- |
| Tensor | Tensor with dimensions [start_dim, end_dim] merged into one. |
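To show concretely what "merged into one" means at the element level, the following sketch applies the same operation to a nested Python list. It assumes row-major (C-order) element ordering, as used by NumPy and PyTorch. `flatten_nested` and its helpers are hypothetical illustrations and are not simplegrad functions.

```python
def nested_ndim(x) -> int:
    """Count nesting depth, assuming a rectangular, non-empty nested list."""
    ndim = 0
    while isinstance(x, list):
        ndim += 1
        x = x[0]
    return ndim


def _merge(x: list, start: int, end: int) -> list:
    if start > 0:
        # Dimensions before start_dim are kept: recurse into each sub-list.
        return [_merge(child, start - 1, end - 1) for child in x]
    # Merging dims 0..end takes `end` concatenations, each joining the
    # two outermost remaining levels in row-major order.
    for _ in range(end):
        x = [item for sub in x for item in sub]
    return x


def flatten_nested(x: list, start_dim: int = 1, end_dim: int = -1) -> list:
    """Hypothetical helper: flatten dims [start_dim, end_dim] of a nested list."""
    ndim = nested_ndim(x)
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not (0 <= start <= end < ndim):
        raise ValueError(f"invalid range [{start_dim}, {end_dim}] for {ndim}-d input")
    return _merge(x, start, end)


batch = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # shape (2, 2, 2)
print(flatten_nested(batch))         # [[1, 2, 3, 4], [5, 6, 7, 8]]
print(flatten_nested(batch, 0, -1))  # [1, 2, 3, 4, 5, 6, 7, 8]
```

Only the grouping changes. The elements keep their original order, which is why a flatten on contiguous storage can usually be implemented as a reshape.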