Reduction Operations

Reduction operations collapse one dimension, or all dimensions, of a tensor into a lower-dimensional tensor or a scalar. sum and mean are differentiable and are commonly used to reduce a batch of per-sample values to a scalar loss. argmax and argmin return indices rather than values, so they are not differentiable; they are typically used to compute accuracy metrics during evaluation.

import simplegrad as sg

x = sg.normal((4, 10), requires_grad=True)
loss = sg.mean(x)
loss.backward()

preds = sg.argmax(x, dim=1)  # predicted class indices (the documented keyword is dim)
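To make the evaluation use case concrete, here is a minimal sketch of turning predicted class indices into an accuracy figure. It uses NumPy as a stand-in rather than simplegrad itself, and the `logits` and `labels` arrays are made-up illustration data, not part of the library.

```python
import numpy as np

# Hypothetical scores for 4 samples and 3 classes (illustration data only).
logits = np.array([[0.1, 2.0, -1.0],
                   [1.5, 0.3, 0.2],
                   [0.0, 0.1, 3.0],
                   [0.9, 0.8, 0.7]])
labels = np.array([1, 0, 2, 1])  # hypothetical ground-truth class indices

# Index of the largest score in each row, i.e. the predicted class.
preds = np.argmax(logits, axis=1)

# Fraction of samples whose predicted class matches the label.
accuracy = float(np.mean(preds == labels))
```

Because the comparison works on integer indices, nothing in this computation needs a gradient, which is consistent with argmax being excluded from differentiation.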

sum(x: Tensor, dim: int | None = None) -> Tensor

Sum tensor elements along a dimension. The reduced dimension is kept with size 1 (keepdims=True).

Parameters:

  • x (Tensor) –

    Input tensor.

  • dim (int | None, default: None ) –

    Dimension to reduce. If None, sums all elements.
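The shape behavior and the backward rule of a kept-dimension sum can be sketched with NumPy. This is an analogue under stated assumptions, not simplegrad's implementation: in particular, the output shape for dim=None is assumed here to keep every dimension with size 1, and the `upstream` gradient is a hypothetical value.

```python
import numpy as np

x = np.arange(6, dtype=np.float64).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# Analogue of dim=1: the reduced dimension survives with size 1.
rows = x.sum(axis=1, keepdims=True)   # shape (2, 1)

# Analogue of dim=None (assumed shape): all elements summed.
total = x.sum(keepdims=True)          # shape (1, 1)

# Backward rule for a sum: each input element receives the upstream
# gradient of the output entry it contributed to.
upstream = np.array([[10.0], [20.0]])           # hypothetical gradient w.r.t. rows
grad_x = np.broadcast_to(upstream, x.shape)     # shape (2, 3)
```

Keeping the reduced dimension is what lets the upstream gradient broadcast back to the input shape without an explicit reshape.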

mean(x: Tensor, dim: int | None = None) -> Tensor

Compute the mean of tensor elements along a dimension.

Parameters:

  • x (Tensor) –

    Input tensor.

  • dim (int | None, default: None ) –

    Dimension to reduce. If None, averages all elements.
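As a sketch of why mean is a convenient loss reduction, the NumPy snippet below (a stand-in, not simplegrad code) computes a per-dimension mean and an all-element mean, then checks the analytic gradient 1/N of the all-element mean against a finite difference. The variable names are illustrative.

```python
import numpy as np

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

col_mean = x.mean(axis=0)   # reduce dimension 0: one mean per column
all_mean = x.mean()         # average over all N = 6 elements

# Analytic gradient of the all-element mean: every element contributes 1/N.
grad = np.full_like(x, 1.0 / x.size)

# Finite-difference check on a single element.
eps = 1e-6
x_shift = x.copy()
x_shift[0, 1] += eps
numeric = (x_shift.mean() - x.mean()) / eps
```

The uniform 1/N gradient means the loss scale does not grow with batch size, unlike a plain sum.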

trace(x: Tensor) -> Tensor

Compute the trace (sum of diagonal elements) of a square matrix.

Parameters:

  • x (Tensor) –

    2D square tensor.

Raises:

  • ValueError

    If x is not a 2D square tensor (checked in eager mode).
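The documented shape check can be illustrated with a small NumPy helper. `trace_checked` is a hypothetical function written for this example; it mirrors the described behavior but is not simplegrad's implementation.

```python
import numpy as np

def trace_checked(x: np.ndarray) -> float:
    """Hypothetical helper: sum of the diagonal, rejecting non-square input."""
    if x.ndim != 2 or x.shape[0] != x.shape[1]:
        raise ValueError(f"expected a 2D square array, got shape {x.shape}")
    return float(np.trace(x))

m = np.array([[1.0, 2.0],
              [3.0, 4.0]])
t = trace_checked(m)   # 1.0 + 4.0

# The gradient of the trace w.r.t. the matrix is the identity:
# only diagonal entries affect the result, each with weight 1.
grad_m = np.eye(m.shape[0])
```

Calling `trace_checked(np.ones((2, 3)))` raises ValueError, matching the error condition listed above.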

argmax(x: Tensor, dim: int | None = None, dtype: str = 'int32') -> Tensor

Return indices of maximum values along a dimension.

Not differentiable — comp_grad is always False on the output.

Parameters:

  • x (Tensor) –

    Input tensor.

  • dim (int | None, default: None ) –

    Dimension to reduce. If None, returns the flat index.

  • dtype (str, default: 'int32' ) –

    Integer dtype for the output indices.
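The difference between a flat index and per-dimension indices can be sketched with NumPy. This assumes the flat index follows row-major order, as NumPy's does; the source does not state which order simplegrad uses.

```python
import numpy as np

x = np.array([[3.0, 9.0, 1.0],
              [7.0, 2.0, 8.0]])

# No dimension given: index into the flattened (row-major) array.
flat = np.argmax(x)

# Convert the flat index back to (row, column) coordinates.
row, col = np.unravel_index(flat, x.shape)

# Reduce along dimension 1: one index per row, cast to a 32-bit integer
# dtype to mirror the documented default of 'int32'.
per_row = np.argmax(x, axis=1).astype(np.int32)
```

A flat index is only meaningful together with the original shape, so converting it back to coordinates is often the first step when dim is left as None.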

argmin(x: Tensor, dim: int | None = None, dtype: str = 'int32') -> Tensor

Return indices of minimum values along a dimension.

Not differentiable — comp_grad is always False on the output.

Parameters:

  • x (Tensor) –

    Input tensor.

  • dim (int | None, default: None ) –

    Dimension to reduce. If None, returns the flat index.

  • dtype (str, default: 'int32' ) –

    Integer dtype for the output indices.
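A short NumPy sketch of argmin semantics, including repeated minima. NumPy returns the first occurrence on ties; whether simplegrad follows the same tie rule is an assumption, not something the source confirms.

```python
import numpy as np

x = np.array([[4.0, 1.0, 1.0],
              [0.5, 2.0, 0.5]])

# One index per row; on ties NumPy picks the first minimum.
per_row = np.argmin(x, axis=1)

# argmin is equivalent to argmax of the negated input.
via_argmax = np.argmax(-x, axis=1)

# No dimension given: flat (row-major) index of the global minimum.
flat = np.argmin(x)
```

The equivalence with argmax of the negated input is a handy sanity check when testing either operation.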