Transform Functions
Shape-transform functions rearrange tensor data without changing its values, so tensors can flow between layers that expect different shapes. flatten collapses a contiguous range of dimensions into a single dimension (the typical step before a Linear layer in a CNN), while reshape gives full control over the output shape. Both are differentiable: during backpropagation, the incoming gradient is simply reshaped back to the shape of the input.
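import simplegrad as sg

# Batch of 4 feature maps, laid out as (N, C, H, W)
x = sg.normal((4, 16, 7, 7), requires_grad=True)
# Keep the batch dimension and merge C, H, W: 16 * 7 * 7 = 784
out = sg.flatten(x, start_dim=1)  # (4, 784)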
flatten(x: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor
Flatten a range of dimensions into a single dimension.
Parameters:
- x (Tensor) – Input tensor.
- start_dim (int, default: 0) – First dimension to flatten (inclusive). Supports negative indexing.
- end_dim (int, default: -1) – Last dimension to flatten (inclusive). Supports negative indexing.
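To make the inclusive range and the negative indexing concrete, the sketch below computes the shape that flattening would produce. The helper flattened_shape is hypothetical and not part of simplegrad; it works on plain shape tuples only. Treating a start_dim that falls after end_dim as an error is an assumption here, since the behavior of simplegrad in that case is not documented above.

```python
from math import prod


def flattened_shape(shape, start_dim=0, end_dim=-1):
    """Return the shape after merging dims start_dim..end_dim (both inclusive).

    Hypothetical helper for illustration; not part of simplegrad.
    """
    ndim = len(shape)
    # Map negative indices onto the range [0, ndim).
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    # Assumption: an empty or out-of-bounds range is rejected.
    if not (0 <= start <= end < ndim):
        raise ValueError(f"invalid range [{start_dim}, {end_dim}] for {ndim} dims")
    # The merged dimension holds the product of the sizes it replaces.
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flattened_shape((4, 16, 7, 7), start_dim=1))  # (4, 784)
print(flattened_shape((4, 16, 7, 7)))               # (3136,)
print(flattened_shape((4, 16, 7, 7), 1, -2))        # (4, 112, 7)
```

With the defaults (start_dim=0, end_dim=-1) every dimension is merged, which is why a CNN typically passes start_dim=1 to keep the batch dimension intact.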
reshape(x: Tensor, new_shape: tuple[int, ...]) -> Tensor
Reshape a tensor to a new shape.
Parameters:
- x (Tensor) – Input tensor.
- new_shape (tuple[int, ...]) – Target shape. The total number of elements must match that of the input.
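The element-count rule, and the way a gradient can be routed back through a reshape, can be sketched on a flat row-major buffer. This is an illustrative model under stated assumptions, not simplegrad's implementation: reshape_forward and reshape_backward are hypothetical names, and tensors are represented as a (flat list, shape tuple) pair.

```python
from math import prod


def reshape_forward(data, old_shape, new_shape):
    """Reinterpret a flat row-major buffer under new_shape; values stay in place.

    Hypothetical helper for illustration; not part of simplegrad.
    """
    # The element count is the product of the dimension sizes and must be preserved.
    if prod(old_shape) != prod(new_shape):
        raise ValueError(f"cannot reshape {tuple(old_shape)} into {tuple(new_shape)}")
    return list(data), tuple(new_shape)


def reshape_backward(grad_out, old_shape):
    """Route the upstream gradient back by giving it the input's original shape."""
    return list(grad_out), tuple(old_shape)


values, shape = reshape_forward([1, 2, 3, 4, 5, 6], (2, 3), (3, 2))
print(values, shape)  # [1, 2, 3, 4, 5, 6] (3, 2)

# The upstream gradient arrives with the output shape (3, 2) ...
grad_in, grad_shape = reshape_backward([1.0] * 6, (2, 3))
print(grad_shape)     # (2, 3), matching the input
```

Neither direction touches the values themselves; only the shape metadata changes, which is why the backward pass needs nothing beyond the input's original shape.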