Reduction Operations
simplegrad.functions.reduction.sum(x: Tensor, dim: int | None = None) -> Tensor
Compute the sum of tensor elements along a dimension.
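
Below is a minimal pure-Python sketch of what `sum` computes. Nested lists stand in for `Tensor`. It assumes NumPy-style semantics that this page does not confirm: `dim=None` reduces every element to a scalar, and the reduced dimension is dropped from the result. It also shows the backward rule. Each input element has local derivative 1, so it receives the upstream gradient of the output slot it was summed into. The helpers `sum_2d` and `sum_2d_backward` are hypothetical and not part of simplegrad.

```python
from typing import List, Optional, Union

Matrix = List[List[float]]


def sum_2d(x: Matrix, dim: Optional[int] = None) -> Union[float, List[float]]:
    """Hypothetical helper: sum over a 2D nested list (not simplegrad's API)."""
    if dim is None:
        # Reduce over every element to a single scalar.
        return sum(v for row in x for v in row)
    if dim == 0:
        # Collapse the rows: one total per column.
        return [sum(col) for col in zip(*x)]
    if dim == 1:
        # Collapse the columns: one total per row.
        return [sum(row) for row in x]
    raise ValueError(f"dim must be None, 0 or 1 for a 2D input, got {dim}")


def sum_2d_backward(x: Matrix, grad_out, dim: Optional[int] = None) -> Matrix:
    """Each input has local derivative 1, so it receives the upstream
    gradient of the output slot it contributed to."""
    rows, cols = len(x), len(x[0])
    if dim is None:
        return [[grad_out] * cols for _ in range(rows)]
    if dim == 0:
        return [list(grad_out) for _ in range(rows)]
    if dim == 1:
        return [[grad_out[i]] * cols for i in range(rows)]
    raise ValueError(f"dim must be None, 0 or 1 for a 2D input, got {dim}")


x = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
print(sum_2d(x))                                # 21.0
print(sum_2d(x, dim=0))                         # [5.0, 7.0, 9.0]
print(sum_2d(x, dim=1))                         # [6.0, 15.0]
print(sum_2d_backward(x, [1.0, 2.0], dim=1))    # [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]
```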
simplegrad.functions.reduction.mean(x: Tensor, dim: int | None = None) -> Tensor
Compute the mean of tensor elements along a dimension.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
| dim | int \| None | Dimension to reduce. If None, averages all elements. | None |

Returns:
| Type | Description |
|---|---|
| Tensor | Reduced tensor. |

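
The following sketch shows `mean` as a sum divided by the element count. Nested lists stand in for `Tensor`, and the reduced dimension is assumed to be dropped, as in NumPy. It also checks the backward rule. For `dim=None`, each of the N inputs gets gradient 1/N, and the analytic value is compared against central finite differences. The helpers `mean_2d`, `mean_2d_grad` and `numeric_grad` are hypothetical and not part of simplegrad.

```python
import math
from typing import List, Optional, Union

Matrix = List[List[float]]


def mean_2d(x: Matrix, dim: Optional[int] = None) -> Union[float, List[float]]:
    """Hypothetical helper: mean over a 2D nested list (not simplegrad's API)."""
    if dim is None:
        flat = [v for row in x for v in row]
        return sum(flat) / len(flat)
    if dim == 0:
        # Average down each column.
        return [sum(col) / len(col) for col in zip(*x)]
    if dim == 1:
        # Average across each row.
        return [sum(row) / len(row) for row in x]
    raise ValueError(f"dim must be None, 0 or 1 for a 2D input, got {dim}")


def mean_2d_grad(x: Matrix) -> Matrix:
    """Analytic gradient of mean(x) with dim=None: every element gets 1/N."""
    n = len(x) * len(x[0])
    return [[1.0 / n for _ in row] for row in x]


def numeric_grad(f, x: Matrix, eps: float = 1e-6) -> Matrix:
    """Central finite differences, used only to check the analytic rule."""
    grad = [[0.0] * len(row) for row in x]
    for i, row in enumerate(x):
        for j in range(len(row)):
            plus = [r[:] for r in x]
            minus = [r[:] for r in x]
            plus[i][j] += eps
            minus[i][j] -= eps
            grad[i][j] = (f(plus) - f(minus)) / (2 * eps)
    return grad


x = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
print(mean_2d(x))         # 3.5
print(mean_2d(x, dim=0))  # [2.5, 3.5, 4.5]
print(mean_2d(x, dim=1))  # [2.0, 5.0]

ok = all(
    math.isclose(a, b, abs_tol=1e-6)
    for ra, rb in zip(numeric_grad(mean_2d, x), mean_2d_grad(x))
    for a, b in zip(ra, rb)
)
print(ok)  # True
```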
simplegrad.functions.reduction.trace(x: Tensor) -> Tensor
Compute the trace (sum of diagonal elements) of a square matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | 2D square tensor. | required |

Returns:
| Type | Description |
|---|---|
| Tensor | Scalar tensor containing the trace. |

Raises:
| Type | Description |
|---|---|
| ValueError | If x is not a 2D square tensor (checked in eager mode). |

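
Here is a sketch of the documented contract for `trace`. It sums the diagonal, raises `ValueError` for input that is not 2D and square, and returns the gradient. Because tr(X) is the sum of the X[i][i], its derivative with respect to X[i][j] is 1 when i == j and 0 otherwise, so the gradient is the identity matrix. Nested lists stand in for `Tensor`, and `trace_2d` and `trace_2d_grad` are hypothetical helpers, not simplegrad's implementation.

```python
from typing import List

Matrix = List[List[float]]


def trace_2d(x: Matrix) -> float:
    """Hypothetical helper: trace of a square nested list (not simplegrad's API)."""
    n = len(x)
    # Mirror the documented contract: reject input that is not 2D and square.
    if any(not isinstance(row, list) or len(row) != n for row in x):
        raise ValueError("trace expects a 2D square tensor")
    return sum(x[i][i] for i in range(n))


def trace_2d_grad(x: Matrix) -> Matrix:
    """d tr(X) / dX[i][j] is 1 on the diagonal and 0 elsewhere: the identity."""
    n = len(x)
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


x = [[1.0, 2.0],
     [3.0, 4.0]]
print(trace_2d(x))       # 5.0
print(trace_2d_grad(x))  # [[1.0, 0.0], [0.0, 1.0]]

try:
    trace_2d([[1.0, 2.0, 3.0]])  # 1x3, not square
except ValueError as err:
    print(err)  # trace expects a 2D square tensor
```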
simplegrad.functions.reduction.argmax(x: Tensor, dim: int | None = None, dtype: str = 'int32') -> Tensor
Return indices of maximum values along a dimension.
Not differentiable — comp_grad is always False on the output.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
| dim | int \| None | Dimension to reduce. If None, returns the flat index. | None |
| dtype | str | Integer dtype for the output indices. | 'int32' |

|
Returns:
| Type | Description |
|---|---|
| Tensor | Integer tensor of argmax indices. |

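
The following pure-Python sketch shows what `argmax` returns for each `dim` setting. With `dim=None` it gives an index into the row-major flattened input. With `dim=0` or `dim=1` it gives one index per column or per row. It assumes ties resolve to the first occurrence, as in NumPy, which this page does not state for simplegrad. The `dtype` parameter is not modelled, so plain Python ints come back. The last lines show why the output carries no gradient: nudging an input leaves the indices unchanged, so argmax is piecewise constant and its derivative is zero wherever it exists. `argmax_2d` is a hypothetical helper.

```python
from typing import List, Optional, Union

Matrix = List[List[float]]


def _first_max_index(values: List[float]) -> int:
    # max() returns the first maximal item it encounters, so ties
    # resolve to the earliest position (assumed NumPy-style behaviour).
    return max(range(len(values)), key=lambda i: values[i])


def argmax_2d(x: Matrix, dim: Optional[int] = None) -> Union[int, List[int]]:
    """Hypothetical helper: argmax over a 2D nested list (not simplegrad's API)."""
    if dim is None:
        # Index into the row-major flattened input.
        return _first_max_index([v for row in x for v in row])
    if dim == 0:
        return [_first_max_index(list(col)) for col in zip(*x)]
    if dim == 1:
        return [_first_max_index(row) for row in x]
    raise ValueError(f"dim must be None, 0 or 1 for a 2D input, got {dim}")


x = [[1.0, 9.0, 3.0],
     [7.0, 2.0, 9.0]]
print(argmax_2d(x))         # 1  (first of the two 9.0s in row-major order)
print(argmax_2d(x, dim=0))  # [1, 0, 1]
print(argmax_2d(x, dim=1))  # [1, 2]

# A small change to a non-maximal entry leaves the indices unchanged.
nudged = [row[:] for row in x]
nudged[0][0] += 1e-3
print(argmax_2d(nudged, dim=1) == argmax_2d(x, dim=1))  # True
```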
simplegrad.functions.reduction.argmin(x: Tensor, dim: int | None = None, dtype: str = 'int32') -> Tensor
Return indices of minimum values along a dimension.
Not differentiable — comp_grad is always False on the output.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
| dim | int \| None | Dimension to reduce. If None, returns the flat index. | None |
| dtype | str | Integer dtype for the output indices. | 'int32' |

|
Returns:
| Type | Description |
|---|---|
| Tensor | Integer tensor of argmin indices. |
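
Here is a sketch of `argmin` together with the identity argmin(x) == argmax(-x). Negation swaps the smallest and largest values. Under a first-occurrence tie rule, tied entries stay tied after negation, so the two sides agree even with ties, assuming there are no NaNs. The tie rule follows NumPy and is an assumption for simplegrad. Nested lists stand in for `Tensor`, `dtype` is not modelled, and `argmin_2d`, `argmax_2d` and `_reduce_index` are hypothetical helpers.

```python
from typing import Callable, List, Optional, Union

Matrix = List[List[float]]


def _first_index(values: List[float], pick: Callable) -> int:
    # pick is the built-in min or max; both return the first extreme item
    # they encounter, so ties go to the earliest index (assumed, NumPy-style).
    return pick(range(len(values)), key=lambda i: values[i])


def _reduce_index(x: Matrix, dim: Optional[int], pick: Callable) -> Union[int, List[int]]:
    if dim is None:
        # Index into the row-major flattened input.
        return _first_index([v for row in x for v in row], pick)
    if dim == 0:
        return [_first_index(list(col), pick) for col in zip(*x)]
    if dim == 1:
        return [_first_index(row, pick) for row in x]
    raise ValueError(f"dim must be None, 0 or 1 for a 2D input, got {dim}")


def argmin_2d(x: Matrix, dim: Optional[int] = None) -> Union[int, List[int]]:
    return _reduce_index(x, dim, min)


def argmax_2d(x: Matrix, dim: Optional[int] = None) -> Union[int, List[int]]:
    return _reduce_index(x, dim, max)


x = [[4.0, 0.5, 2.0],
     [0.5, 3.0, 1.0]]
print(argmin_2d(x))         # 1
print(argmin_2d(x, dim=0))  # [1, 0, 1]
print(argmin_2d(x, dim=1))  # [1, 0]

# argmin(x) matches argmax(-x), including the tie between the two 0.5s.
neg = [[-v for v in row] for row in x]
print(argmin_2d(x) == argmax_2d(neg))  # True
```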