Transform Functions

simplegrad.functions.tranform.flatten(x: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor

Flatten a range of dimensions into a single dimension.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |
| start_dim | int | First dimension to flatten (inclusive). Supports negative indexing. | 0 |
| end_dim | int | Last dimension to flatten (inclusive). Supports negative indexing. | -1 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Tensor with dimensions [start_dim, end_dim] merged into one. |

Source code in simplegrad/functions/tranform.py
def flatten(x: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor:
    """Flatten a range of dimensions into a single dimension.

    Args:
        x: Input tensor.
        start_dim: First dimension to flatten (inclusive). Supports negative indexing.
        end_dim: Last dimension to flatten (inclusive). Supports negative indexing.

    Returns:
        Tensor with dimensions ``[start_dim, end_dim]`` merged into one.
    """
    return _Flatten.apply(x, start_dim, end_dim)
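
The shape arithmetic described above can be sketched in plain Python. This is only an illustration inferred from the parameter descriptions (inclusive bounds, negative indexing); `flattened_shape` is a hypothetical helper, not part of simplegrad, and it does not call `_Flatten` or reproduce its error handling.

```python
from math import prod


def flattened_shape(shape, start_dim=0, end_dim=-1):
    """Hypothetical helper: shape produced by merging dims start_dim..end_dim."""
    ndim = len(shape)
    # Map negative indices onto the range [0, ndim), as Python sequences do.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not (0 <= start <= end < ndim):
        raise ValueError(f"invalid dimension range [{start_dim}, {end_dim}] for {ndim} dims")
    # Both bounds are inclusive, so the merged size is the product over start..end.
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flattened_shape((2, 3, 4, 5)))               # (120,)
print(flattened_shape((2, 3, 4, 5), start_dim=1))  # (2, 60)
print(flattened_shape((2, 3, 4, 5), 1, 2))         # (2, 12, 5)
print(flattened_shape((2, 3, 4, 5), -2, -1))       # (2, 3, 20)
```

With the defaults (`start_dim=0`, `end_dim=-1`) every dimension is merged, so the result is one-dimensional. Passing `start_dim=1` keeps the leading dimension intact.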

simplegrad.functions.tranform.reshape(x: Tensor, new_shape: tuple[int, ...]) -> Tensor

Reshape a tensor to a new shape.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |
| new_shape | tuple[int, ...] | Target shape. Total number of elements must match. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Tensor with values laid out in new_shape. |

Source code in simplegrad/functions/tranform.py
def reshape(x: Tensor, new_shape: tuple[int, ...]) -> Tensor:
    """Reshape a tensor to a new shape.

    Args:
        x: Input tensor.
        new_shape: Target shape. Total number of elements must match.

    Returns:
        Tensor with values laid out in ``new_shape``.
    """
    return _Reshape.apply(x, new_shape, oper=f"reshape({new_shape})")
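
The element-count constraint on `new_shape` can be illustrated without the library. The sketch below uses two hypothetical helpers, `can_reshape` and `nest_row_major`, that are not part of simplegrad. It also assumes row-major (C-style) element order, which the documentation above does not state.

```python
from math import prod


def can_reshape(old_shape, new_shape):
    """Hypothetical check: a reshape is valid only if element counts match."""
    return prod(old_shape) == prod(new_shape)


def nest_row_major(flat, shape):
    """Hypothetical helper: lay a flat list out as nested lists of the given shape.

    Assumes row-major order (last dimension varies fastest).
    """
    if prod(shape) != len(flat):
        raise ValueError(f"cannot lay out {len(flat)} elements as shape {shape}")
    if len(shape) == 0:
        return flat[0]      # a scalar holds exactly one element
    if len(shape) == 1:
        return list(flat)
    step = prod(shape[1:])  # number of elements in each sub-block
    return [nest_row_major(flat[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]


values = list(range(6))
print(can_reshape((2, 3), (3, 2)))     # True
print(can_reshape((2, 3), (4, 2)))     # False
print(nest_row_major(values, (2, 3)))  # [[0, 1, 2], [3, 4, 5]]
print(nest_row_major(values, (3, 2)))  # [[0, 1], [2, 3], [4, 5]]
```

Under this assumption the flat sequence of values is the same before and after the reshape, and only the grouping into dimensions changes.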