Function & Context

Function is the base class for every differentiable operation in simplegrad. To define an op, subclass it, implement forward (computes the NumPy result) and backward (returns the local gradients), and call cls.apply(...) to execute it. The apply method wires the result into the computation graph automatically. Context is a simple namespace that carries intermediate values computed in forward through to backward.

import simplegrad as sg
from simplegrad.core import Function, Context
import numpy as np

class _Square(Function):
    oper = "Square"  # label shown on the graph node

    @staticmethod
    def forward(ctx: Context, x) -> np.ndarray:
        # Save the input values; backward needs them for d(x^2)/dx = 2x.
        ctx.x_values = x.values
        return x.values ** 2

    @staticmethod
    def backward(ctx: Context, grad):
        # Chain rule: local gradient (2x) times the upstream gradient.
        return 2 * ctx.x_values * grad

x = sg.Tensor([2.0, 3.0], requires_grad=True)
y = _Square.apply(x)  # runs forward and wires y into the graph
y.sum().backward()
print(x.grad)  # [4. 6.]

Function

Base class for differentiable operations.

Subclass this and implement forward and backward as static methods. Call cls.apply(*inputs) to run the op — it handles creating the output tensor, wiring the computation graph, and setting up gradient accumulation.

forward computes and returns the numpy result (and saves anything needed for backward into ctx). backward receives the upstream gradient and returns one gradient array per Tensor input — pure computation, no accumulation. The apply method handles accumulating those gradients into .grad via +=, including broadcast dimension reduction.
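The broadcast dimension reduction mentioned above can be illustrated with a small standalone sketch. The helper name reduce_to_shape is hypothetical and this is not simplegrad's actual code; it only shows one common way to shrink an upstream gradient back to an input's shape before accumulating it with +=.

```python
import numpy as np

def reduce_to_shape(grad: np.ndarray, shape: tuple) -> np.ndarray:
    """Sum `grad` over broadcast axes so it matches `shape` (hypothetical helper)."""
    # Axes that broadcasting prepended on the left are summed away entirely.
    extra = grad.ndim - len(shape)
    if extra > 0:
        grad = grad.sum(axis=tuple(range(extra)))
    # Axes that had size 1 in the input but were stretched are summed with keepdims.
    stretched = tuple(
        i for i, n in enumerate(shape) if n == 1 and grad.shape[i] != 1
    )
    if stretched:
        grad = grad.sum(axis=stretched, keepdims=True)
    return grad

# A bias of shape (3,) added to a (2, 3) batch receives a (2, 3) upstream gradient.
bias_grad = np.zeros(3)
upstream = np.ones((2, 3))
bias_grad += reduce_to_shape(upstream, bias_grad.shape)  # accumulate, do not overwrite
```

Because backward stays pure, a reduction like this can live in one place instead of being repeated inside every op.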

Class attributes

| Attribute | Type | Description |
| --- | --- | --- |
| .oper | str | Short label shown on computation graph nodes. Defaults to the class name. |
| .differentiable | bool | Set to False for ops with no gradient (e.g. argmax). |
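As a rough illustration of how such class attributes could behave, the sketch below uses a hypothetical stand-in class, FunctionSketch, rather than simplegrad's Function. It assumes the label falls back to the class name through __init_subclass__ and that differentiable defaults to True; the library may implement both differently.

```python
class FunctionSketch:
    """Hypothetical stand-in for a Function-like base class, not simplegrad's code."""

    oper = None            # label for graph nodes
    differentiable = True  # assumed default: ops have a gradient unless they opt out

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Fall back to the class name when the subclass does not set `oper` itself.
        if "oper" not in cls.__dict__:
            cls.oper = cls.__name__


class Square(FunctionSketch):
    oper = "x^2"  # explicit label overrides the default


class ArgMax(FunctionSketch):
    differentiable = False  # index-valued output, so no gradient flows through
```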

Methods

| Method | Description |
| --- | --- |
| .apply() | Run the op, build the graph node, and wire up the backward step. |
| .forward() | Compute the forward pass. Save anything needed for backward into ctx. |
| .backward() | Compute gradients. Return one array per Tensor input. |
| .output_shape() | Infer the output shape from inputs without executing the op. |
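To make the idea behind shape inference concrete, here is a standalone sketch of how an output shape can be derived from input shapes alone. The function names and signatures are illustrative assumptions, not the signature of simplegrad's output_shape.

```python
import numpy as np

def elementwise_output_shape(a_shape: tuple, b_shape: tuple) -> tuple:
    """Shape of an elementwise binary op: the broadcast of both input shapes."""
    return np.broadcast_shapes(a_shape, b_shape)

def matmul_output_shape(a_shape: tuple, b_shape: tuple) -> tuple:
    """Shape of a 2-D matrix product (m, k) @ (k, n) -> (m, n)."""
    if len(a_shape) != 2 or len(b_shape) != 2:
        raise ValueError("this sketch only handles 2-D inputs")
    if a_shape[1] != b_shape[0]:
        raise ValueError(f"inner dimensions differ: {a_shape[1]} vs {b_shape[0]}")
    return (a_shape[0], b_shape[1])

# No arrays are allocated and no arithmetic runs; only shapes are inspected.
add_shape = elementwise_output_shape((2, 1), (3,))
mm_shape = matmul_output_shape((2, 4), (4, 5))
```

Working on shapes only is what would let a lazy graph report tensor shapes before anything has been computed.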

Context

Stores intermediate values computed during a forward pass for reuse in backward.

Every op that needs to carry state from forward to backward should create a Context, write to it inside the forward lambda, and read from it inside the backward function. This pattern works in both eager and lazy mode. In eager mode the forward lambda runs immediately; in lazy mode it runs at .realize() time. Either way, backward always runs after forward, so ctx attributes are populated by the time they are read.
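The ordering argument can be sketched in plain Python. Here types.SimpleNamespace stands in for Context, make_square is a hypothetical helper, and calling the deferred forward by hand plays the role of .realize(); none of this is simplegrad's actual machinery.

```python
from types import SimpleNamespace

def make_square(x: float, lazy: bool):
    """Build a squaring op whose forward runs now (eager) or later (lazy)."""
    ctx = SimpleNamespace()  # stand-in for Context

    def forward():
        ctx.x = x  # saved for backward
        return x * x

    def backward(grad: float) -> float:
        # Reads ctx.x, which exists only once forward has run.
        return 2 * ctx.x * grad

    # Eager: run forward immediately. Lazy: hand back the callable to run later.
    result = forward if lazy else forward()
    return result, backward

eager_value, eager_backward = make_square(3.0, lazy=False)
pending, lazy_backward = make_square(3.0, lazy=True)
lazy_value = pending()  # deferred forward runs here, populating ctx

eager_grad = eager_backward(1.0)
lazy_grad = lazy_backward(1.0)
```

In both modes the same closure writes ctx and the same closure reads it; only the moment forward runs changes.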

Two attributes are set automatically by Function.apply before forward is called:

device: Device string of the operation's tensors (e.g. "cpu").
backend: The compute module for this device, either numpy or cupy. Forward and backward methods should alias this as xp = ctx.backend and use xp.* instead of np.*.

Additional attributes are set freely with dot notation — use whatever names are meaningful for the op.
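A minimal sketch of the xp = ctx.backend pattern together with a free-form attribute follows. The ctx object is built by hand from types.SimpleNamespace as a stand-in for what apply would provide, and exp_forward and exp_backward are illustrative names written as plain functions rather than Function static methods.

```python
from types import SimpleNamespace
import numpy as np

def exp_forward(ctx, x_values):
    xp = ctx.backend      # numpy here; could be cupy on a GPU device
    out = xp.exp(x_values)
    ctx.out = out         # free-form attribute: any meaningful name works
    return out

def exp_backward(ctx, grad):
    xp = ctx.backend
    # d/dx exp(x) = exp(x), so reuse the saved output instead of recomputing it.
    return xp.multiply(ctx.out, grad)

# Hand-built stand-in for the Context that apply would populate.
ctx = SimpleNamespace(device="cpu", backend=np)
y = exp_forward(ctx, np.array([0.0, 1.0]))
dx = exp_backward(ctx, np.ones(2))
```

Writing the op against xp rather than np means the same body can run on either backend without modification, provided the backend exposes the same function names.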

Attributes

| Attribute | Type | Description |
| --- | --- | --- |
| .device | str | Device string of the operation's tensors (e.g. "cpu"). Set automatically by apply. |
| .backend | module | Compute module, either numpy or cupy. Alias as xp = ctx.backend in ops. |