Skip to content

Optimizer

simplegrad.core.optimizer.Optimizer

Base class for all optimizers.

Subclasses must implement step() to define the parameter update rule.

Source code in simplegrad/core/optimizer.py
class Optimizer:
    """Common base class for all optimizers.

    Stores the learning rate and model reference, tracks how many update
    steps have been taken, and provides gradient clearing. Concrete
    optimizers override `step()` with their parameter update rule.
    """

    def __init__(self, lr: float, model):
        self.step_count = 0  # number of optimization steps performed so far

        # Fail fast on missing configuration; lr is checked before model
        # so the error messages match the argument order.
        for value, message in (
            (lr, "Learning rate (lr) must be provided."),
            (model, "Model must be provided."),
        ):
            if value is None:
                raise ValueError(message)

        self.lr = lr
        self.model = model

    def zero_grad(self):
        """Zero gradients for all model parameters."""
        for param in self.model.parameters().values():
            param.grad = np.zeros_like(param.values)

    def step(self):
        """Perform a single optimization step. Must be implemented by subclasses."""
        raise NotImplementedError("step() method is not implemented.")

    def reset_step_count(self):
        """Reset the internal step counter to zero."""
        self.step_count = 0

    def set_lr(self, new_lr: float):
        """Set a new learning rate."""
        self.lr = new_lr

reset_step_count()

Reset the internal step counter to zero.

Source code in simplegrad/core/optimizer.py
def reset_step_count(self):
    """Restart step counting from zero (e.g. when reusing the optimizer)."""
    self.step_count = 0

set_lr(new_lr: float)

Set a new learning rate.

Source code in simplegrad/core/optimizer.py
def set_lr(self, new_lr: float):
    """Replace the current learning rate with `new_lr` (no validation)."""
    self.lr = new_lr

step()

Perform a single optimization step. Must be implemented by subclasses.

Source code in simplegrad/core/optimizer.py
def step(self):
    """Apply one parameter update; concrete optimizers must override this."""
    raise NotImplementedError("step() method is not implemented.")

zero_grad()

Zero gradients for all model parameters.

Source code in simplegrad/core/optimizer.py
def zero_grad(self):
    """Reset every parameter gradient to an all-zeros array of matching shape."""
    params = self.model.parameters()
    for param in params.values():
        param.grad = np.zeros_like(param.values)