Base class for all optimizers.
Subclasses must implement step() to define the parameter update rule.
Source code in simplegrad/core/optimizer.py
class Optimizer:
    """Base class for all optimizers.

    Stores the learning rate, the model being optimized and a step counter.
    Subclasses must implement `step()` to define the parameter update rule.

    Attributes:
        lr: Current learning rate.
        model: Object whose `parameters()` method returns a mapping of
            parameter objects exposing `values` and `grad`.
        step_count: Counter initialised to zero. This base class only
            resets it; it never increments it itself.
    """

    def __init__(self, lr: float, model):
        """Store the learning rate and model after validating both.

        Args:
            lr: Learning rate to start with.
            model: Model whose parameters will be optimized.

        Raises:
            ValueError: If `lr` or `model` is None.
        """
        # Validate first so a rejected call does not touch the instance.
        if lr is None:
            raise ValueError("Learning rate (lr) must be provided.")
        if model is None:
            raise ValueError("Model must be provided.")
        self.step_count = 0
        self.lr = lr
        self.model = model

    def zero_grad(self):
        """Zero gradients for all model parameters.

        Each parameter's `grad` is rebound to a new zero array with the
        same shape and dtype as its `values`.
        """
        # Only the parameter objects are needed, not their names.
        for param in self.model.parameters().values():
            param.grad = np.zeros_like(param.values)

    def step(self):
        """Perform a single optimization step. Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("step() method is not implemented.")

    def reset_step_count(self):
        """Reset the internal step counter to zero."""
        self.step_count = 0

    def set_lr(self, new_lr: float):
        """Set a new learning rate.

        Args:
            new_lr: Learning rate to use from now on.

        Raises:
            ValueError: If `new_lr` is None, mirroring the check done in
                `__init__`; the current learning rate is left unchanged.
        """
        if new_lr is None:
            raise ValueError("Learning rate (lr) must be provided.")
        self.lr = new_lr
|
reset_step_count()
Reset the internal step counter to zero.
Source code in simplegrad/core/optimizer.py
def reset_step_count(self):
    """Put the optimizer's `step_count` attribute back to zero."""
    self.step_count = 0
|
set_lr(new_lr: float)
Set a new learning rate.
Source code in simplegrad/core/optimizer.py
def set_lr(self, new_lr: float):
    """Set a new learning rate.

    Args:
        new_lr: Learning rate to store on the optimizer as `lr`.

    Raises:
        ValueError: If `new_lr` is None. The constructor refuses a missing
            learning rate, so the setter refuses one as well and leaves
            the current value unchanged.
    """
    if new_lr is None:
        raise ValueError("Learning rate (lr) must be provided.")
    self.lr = new_lr
|
step()
Perform a single optimization step. Must be implemented by subclasses.
Source code in simplegrad/core/optimizer.py
def step(self):
    """Run one optimization step.

    This base implementation has no update rule and always raises
    NotImplementedError; subclasses are expected to override it.
    """
    message = "step() method is not implemented."
    raise NotImplementedError(message)
|
zero_grad()
Zero gradients for all model parameters.
Source code in simplegrad/core/optimizer.py
def zero_grad(self):
    """Zero gradients for all model parameters.

    Walks the mapping returned by `self.model.parameters()` and rebinds
    each parameter's `grad` to a new zero array with the same shape and
    dtype as that parameter's `values`.
    """
    # The parameter names are not used, so iterate the values view only.
    for param in self.model.parameters().values():
        param.grad = np.zeros_like(param.values)
|