
Pooling

simplegrad.functions.pooling.max_pool2d(x: Tensor, kernel_size: int | tuple[int, int], stride: int | tuple[int, int] | None = None, pad_width: int | tuple[int, int, int, int] = 0, pad_mode: str = 'constant', pad_value: int = 0) -> Tensor

Apply 2D max pooling over the input tensor.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch, channels, H, W) or (channels, H, W). | required |
| kernel_size | int or tuple[int, int] | Pooling window size: an int or (kH, kW). | required |
| stride | int, tuple[int, int], or None | Step between pooling windows: an int or (sH, sW). Defaults to kernel_size if not specified. | None |
| pad_width | int or tuple[int, int, int, int] | Padding applied before pooling: an int (same on all sides) or (top, bottom, left, right). | 0 |
| pad_mode | str | Padding mode. | 'constant' |
| pad_value | int | Fill value for constant padding. | 0 |
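Because padding is applied before the maximum is taken, the default pad_value=0 can leak into the output: when every real value under a window is negative, a zero from the border wins. The NumPy sketch below illustrates this with a stand-alone helper, pad_for_pooling (hypothetical, not part of simplegrad), that follows the (top, bottom, left, right) convention documented above and leaves the leading batch/channel axes unpadded. It assumes simplegrad's pad behaves like np.pad in constant mode; this page does not show its implementation.

```python
import numpy as np


def pad_for_pooling(x: np.ndarray, pad_width, value=0.0) -> np.ndarray:
    """Hypothetical helper: pad only the last two axes (H, W) of x.

    pad_width is an int (all sides) or (top, bottom, left, right),
    mirroring the convention documented for max_pool2d.
    """
    if isinstance(pad_width, int):
        top = bottom = left = right = pad_width
    else:
        top, bottom, left, right = pad_width
    # Leading axes (batch and/or channels) get no padding.
    spec = [(0, 0)] * (x.ndim - 2) + [(top, bottom), (left, right)]
    return np.pad(x, spec, mode="constant", constant_values=value)


# An all-negative 1x1x2x2 input: every real element is below zero.
x = -np.ones((1, 1, 2, 2))

zero_padded = pad_for_pooling(x, 1, value=0)
neg_inf_padded = pad_for_pooling(x, 1, value=-np.inf)

# The top-left 2x2 window covers three padding cells and one real element.
print(zero_padded[0, 0, :2, :2].max())     # 0.0: a padding cell wins
print(neg_inf_padded[0, 0, :2, :2].max())  # -1.0: only the real element can win
```

Padding with negative infinity keeps the border neutral; PyTorch's max_pool2d, for example, documents its padding as implicit negative-infinity padding. Note that pad_value is annotated as int here, so whether simplegrad accepts a float such as -inf is not confirmed by this page.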

Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor of shape (batch, channels, out_H, out_W). |
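This page does not give a formula for out_H and out_W. The sketch below computes them under the usual assumption for pooling without dilation or ceil mode: only windows that fit entirely inside the padded input are emitted, so out_H = (H + top + bottom - kH) // sH + 1 and likewise for the width. The function names are hypothetical, and the floor behavior is an assumption about _MaxPool2d, whose body is not shown here.

```python
def _pair(v):
    """Normalize an int or a 2-tuple to a 2-tuple."""
    return (v, v) if isinstance(v, int) else tuple(v)


def max_pool2d_out_hw(h, w, kernel_size, stride=None, pad_width=0):
    """Expected (out_H, out_W), assuming only full windows are kept (floor division)."""
    kh, kw = _pair(kernel_size)
    # Same default as max_pool2d: stride falls back to the kernel size.
    sh, sw = (kh, kw) if stride is None else _pair(stride)
    if isinstance(pad_width, int):
        top = bottom = left = right = pad_width
    else:
        top, bottom, left, right = pad_width
    out_h = (h + top + bottom - kh) // sh + 1
    out_w = (w + left + right - kw) // sw + 1
    return out_h, out_w


print(max_pool2d_out_hw(28, 28, 2))                         # (14, 14)
print(max_pool2d_out_hw(28, 28, 3, stride=2, pad_width=1))  # (14, 14)
print(max_pool2d_out_hw(5, 5, 2))                           # (2, 2): last row and column dropped
```

The last call shows the practical consequence of floor division: with an odd input size and a stride equal to the kernel, the final row and column never fall inside a window.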

Source code in simplegrad/functions/pooling.py
def max_pool2d(
    x: Tensor,
    kernel_size: int | tuple[int, int],
    stride: int | tuple[int, int] | None = None,
    pad_width: int | tuple[int, int, int, int] = 0,
    pad_mode: str = "constant",
    pad_value: int = 0,
) -> Tensor:
    """Apply 2D max pooling over the input tensor.

    Args:
        x: Input tensor of shape ``(batch, channels, H, W)`` or ``(channels, H, W)``.
        kernel_size: Pooling window size. Int or ``(kH, kW)``.
        stride: Step between pooling windows. Int or ``(sH, sW)``. Defaults to
            ``kernel_size`` if not specified.
        pad_width: Padding before pooling. Int (all sides) or ``(top, bottom, left, right)``.
        pad_mode: Padding mode. Defaults to ``"constant"``.
        pad_value: Fill value for constant padding. Defaults to 0.

    Returns:
        Output tensor of shape ``(batch, channels, out_H, out_W)``.
    """
    assert (
        len(x.shape) == 4 or len(x.shape) == 3
    ), "Input tensor must be 4D (batch, channels, H, W) or 3D (channels, H, W)"

    kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    if stride is None:
        sh, sw = kh, kw
    elif isinstance(stride, int):
        sh, sw = stride, stride
    else:
        sh, sw = stride

    if pad_width == 0 or pad_width == (0, 0, 0, 0):
        padded_input = x
    else:
        # Batch/channel axes are never padded: one such axis for 3D input, two for 4D.
        no_pad = ((0, 0),) * (len(x.shape) - 2)
        if isinstance(pad_width, int):
            pad_width_np = no_pad + ((pad_width, pad_width), (pad_width, pad_width))
        elif isinstance(pad_width, tuple) and len(pad_width) == 4:
            # (top, bottom, left, right) -> ((top, bottom), (left, right)) for H and W.
            pad_width_np = no_pad + (
                (pad_width[0], pad_width[1]),
                (pad_width[2], pad_width[3]),
            )
        else:
            raise ValueError("pad_width must be an int or tuple of 4 ints")
        padded_input = pad(x=x, width=pad_width_np, mode=pad_mode, value=pad_value)

    return _MaxPool2d.apply(padded_input, kh, kw, sh, sw)
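The actual window reduction happens in _MaxPool2d, which is not shown on this page. For sanity-checking results, a forward-only NumPy reference can be written with sliding_window_view (NumPy 1.20 or later). This is a sketch, not simplegrad's implementation: it omits padding and gradients, assumes floor semantics (only full windows), and max_pool2d_reference is a hypothetical name. Converting a simplegrad Tensor to a NumPy array is not covered here.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def max_pool2d_reference(x: np.ndarray, kernel_size, stride=None) -> np.ndarray:
    """Forward-only max pooling over the last two axes of x (no padding)."""
    # Normalize kernel_size and stride exactly as max_pool2d does.
    kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    if stride is None:
        sh, sw = kh, kw
    elif isinstance(stride, int):
        sh, sw = stride, stride
    else:
        sh, sw = stride
    # Every stride-1 window: shape (..., H - kh + 1, W - kw + 1, kh, kw).
    windows = sliding_window_view(x, (kh, kw), axis=(-2, -1))
    # Keep every sh-th / sw-th window position, then take the max inside each window.
    return windows[..., ::sh, ::sw, :, :].max(axis=(-2, -1))


x = np.arange(16, dtype=float).reshape(1, 1, 4, 4)
print(max_pool2d_reference(x, 2)[0, 0].tolist())  # [[5.0, 7.0], [13.0, 15.0]]
```

sliding_window_view returns a view rather than copying the input, and subsampling the window grid by the stride reproduces the same floor behavior as the output-size formula out_H = (H - kH) // sH + 1. Because it works on the last two axes, the same function handles both (channels, H, W) and (batch, channels, H, W) arrays.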