
Flatten

Flatten is a Module wrapper that collapses the non-batch dimensions of a tensor (for image data, the channel and spatial dimensions) into a single feature vector per sample. This bridges convolutional layers and fully connected layers. It calls the functional flatten op under the hood and holds no learnable parameters. Use it inside a Sequential model so you don't have to write a custom forward method just to reshape.

import simplegrad.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),                # (N, 16, 28, 28) -> (N, 12544)
    nn.Linear(12544, 10),
)
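The 12544 passed to nn.Linear is the product of the channel and spatial sizes that reach Flatten. The plain-Python sketch below reproduces that arithmetic. It assumes a (N, 1, 28, 28) input, which is implied by the shape comment above, and uses the standard convolution output-size formula. Both `conv2d_out_size` and `flattened_features` are hypothetical helpers and are not part of simplegrad.

```python
# Sketch of the shape arithmetic behind the Sequential example above.
# Assumption: input is (N, 1, 28, 28). conv2d_out_size and
# flattened_features are hypothetical helpers, not simplegrad APIs.

def conv2d_out_size(size: int, kernel_size: int, padding: int = 0, stride: int = 1) -> int:
    """Spatial output size of a 2D convolution along one axis."""
    return (size + 2 * padding - kernel_size) // stride + 1


def flattened_features(channels: int, height: int, width: int) -> int:
    """Number of features after flattening (C, H, W) into one vector."""
    return channels * height * width


# Conv2d(1, 16, kernel_size=3, padding=1); ReLU does not change the shape.
h = conv2d_out_size(28, kernel_size=3, padding=1)  # padding=1 keeps 28
w = conv2d_out_size(28, kernel_size=3, padding=1)
features = flattened_features(16, h, w)

print(h, w, features)  # 28 28 12544
```

If you change the kernel size, padding, or input resolution, recompute this product so that it matches the `in_features` of the following Linear layer.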

Flatten

Bases: Module

Flatten a range of tensor dimensions into a single dimension.

Parameters:

  • start_dim (int, default: 1) –

    First dimension to flatten (inclusive). Defaults to 1 (preserves the batch dimension).

  • end_dim (int, default: -1) –

    Last dimension to flatten (inclusive). Defaults to -1 (the last dimension).
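The sketch below shows how `start_dim` and `end_dim` determine the output shape. Both bounds are inclusive, and negative values count from the end. `flatten_shape` is a hypothetical helper written for illustration. Its error handling is an assumption and is not taken from simplegrad.

```python
from typing import Sequence, Tuple


def flatten_shape(shape: Sequence[int], start_dim: int = 1, end_dim: int = -1) -> Tuple[int, ...]:
    """Output shape after merging dims start_dim..end_dim (both inclusive).

    Hypothetical helper; the exceptions raised here are illustrative only.
    """
    ndim = len(shape)
    # Negative indices count from the end, as in Python indexing.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not (0 <= start < ndim and 0 <= end < ndim):
        raise IndexError(f"dims out of range for a {ndim}-D shape")
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    # Multiply the sizes of the merged dimensions together.
    merged = 1
    for size in shape[start:end + 1]:
        merged *= size
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flatten_shape((8, 16, 28, 28)))                         # (8, 12544)
print(flatten_shape((8, 16, 28, 28), start_dim=2))            # (8, 16, 784)
print(flatten_shape((8, 16, 28, 28), start_dim=1, end_dim=2)) # (8, 448, 28)
```

With the defaults, every dimension after the batch dimension is merged. Passing `start_dim=0` would also fold the batch dimension into the result, which is rarely what you want before a Linear layer.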

Attributes

Attribute    Type   Description
.start_dim   int    First dimension to flatten (inclusive). Defaults to 1 (preserves the batch dimension).
.end_dim     int    Last dimension to flatten (inclusive). Defaults to -1 (the last dimension, so all remaining dimensions are flattened).

Methods

Method       Description
.forward()   Flatten the input tensor from start_dim to end_dim (both inclusive).
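The sketch below illustrates the wrapper pattern this page describes: a module whose only state is its configuration and whose forward delegates to a functional op. It is written in plain Python over a flat row-major buffer plus a shape tuple. `flatten` and `FlattenSketch` are hypothetical stand-ins and are not simplegrad's functional op or Module class. The sketch does rely on one general fact: merging adjacent dimensions of a contiguous row-major tensor never reorders its elements, so only the shape changes.

```python
from typing import List, Tuple


def flatten(data: List[float], shape: Tuple[int, ...],
            start_dim: int = 1, end_dim: int = -1) -> Tuple[List[float], Tuple[int, ...]]:
    """Hypothetical functional op on a contiguous row-major buffer.

    Element order is unchanged; only the shape is rewritten.
    The data is returned as a copy.
    """
    ndim = len(shape)
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not (0 <= start <= end < ndim):
        raise ValueError(f"invalid dim range [{start_dim}, {end_dim}] for {ndim}-D input")
    merged = 1
    for size in shape[start:end + 1]:
        merged *= size
    new_shape = tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])
    return list(data), new_shape


class FlattenSketch:
    """Illustrative parameter-free module; not simplegrad's Flatten."""

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        # Configuration only; nothing here is learnable.
        self.start_dim = start_dim
        self.end_dim = end_dim

    def parameters(self) -> list:
        return []  # no weights or biases

    def forward(self, data: List[float], shape: Tuple[int, ...]):
        # Delegate to the functional op, mirroring the wrapper described above.
        return flatten(data, shape, self.start_dim, self.end_dim)

    __call__ = forward


# Usage: a batch of 2 samples, each of shape (3, 2, 2).
x = [float(i) for i in range(2 * 3 * 2 * 2)]
out, out_shape = FlattenSketch()(x, (2, 3, 2, 2))
print(out_shape)  # (2, 12)
print(out[12:])   # the second sample's 12 features, in their original order
```

Because the data is left untouched, the features of each sample occupy one contiguous slice of the output. This is the layout a fully connected layer expects.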

Inherits all methods from Module: .parameters(), .submodules(), .to_device(), .summary(), .set_train_mode(), .set_eval_mode().