Skip to content

Experiment Tracking

simplegrad.track.tracker.Tracker

Tracker class to manage experiments and runs using ExperimentDBManager. It allows setting the current experiment, and starting and ending runs.

There is a main directory where all experiment databases are stored. One database corresponds to one experiment. Each experiment can have multiple runs, each with its own metrics and computational graphs.

Source code in simplegrad/track/tracker.py
class Tracker:
    """
    Tracker class to manage experiments and runs using ExperimentDBManager.
    It allows setting the current experiment, and starting and ending runs.

    There is a main directory where all experiment databases are stored.
    One database corresponds to one experiment.
    Each experiment can have multiple runs, each with its own metrics and computational graphs.

    Typical usage: call ``set_experiment()`` once, then ``start_run()``,
    ``record()`` as often as needed, and finally ``end_run()``.
    """

    def __init__(self, all_exp_dir: str = "./experiments"):
        """Create the tracker and make sure the experiments directory exists."""
        self.all_exp_dir = Path(all_exp_dir)
        # Create the experiments directory if it doesn't exist
        self.all_exp_dir.mkdir(parents=True, exist_ok=True)

        self.cur_exp_path = None
        self.db_manager = None
        self.current_run_id = None
        self.current_run_name = None

    def _require_db(self):
        """Return the active database manager.

        Raises:
            RuntimeError: If no experiment has been selected yet.
        """
        if self.db_manager is None:
            raise RuntimeError("No experiment set. Call set_experiment() first.")
        return self.db_manager

    def get_all_exp_paths(self) -> list[Path]:
        """Get all experiment database paths (``*.db`` files in the experiments directory)."""
        if not self.all_exp_dir.exists():
            return []
        return list(self.all_exp_dir.glob("*.db"))

    def set_all_exp_dir(self, directory: str):
        """Set the experiments directory, creating it if it doesn't exist."""
        # Assign the attribute every other method reads (``all_exp_dir``).
        self.all_exp_dir = Path(directory)
        self.all_exp_dir.mkdir(parents=True, exist_ok=True)

    def set_experiment(self, exp_name: str):
        """Set the current experiment by name, initializing its database manager.

        ``exp_name`` may be given with or without the ``.db`` suffix. The schema
        is initialized when the database is not yet reachable.
        """
        db_name = exp_name if exp_name.endswith(".db") else f"{exp_name}.db"
        self.cur_exp_path = self.all_exp_dir / db_name
        self.db_manager = ExperimentDBManager(db_path=self.cur_exp_path)
        if self.db_manager.check_connection():
            print(f"Connected to existing experiment database at {self.cur_exp_path}")
        else:
            self.db_manager.init_exp_db()

    def start_run(self, name: str | None = None, config: dict | None = None) -> int:
        """Start a new run and return the run_id.

        Raises:
            RuntimeError: If no experiment has been selected yet.
        """
        run_id = self._require_db().create_run(name=name, config=config)
        self.current_run_id = run_id
        self.current_run_name = name or f"run_{run_id}"
        return run_id

    def record(self, metric_name: str, value: float, step: int):
        """Log a metric value at a given step for the active run.

        Raises:
            RuntimeError: If there is no active run.
        """
        if self.current_run_id is None:
            raise RuntimeError("No active run. Call start_run() first.")
        self._require_db().record(self.current_run_id, metric_name, step, value)

    def end_run(self, status: str = "completed") -> int:
        """End the current run with a given status and return its run_id.

        Raises:
            RuntimeError: If there is no active run.
        """
        if self.current_run_id is None:
            raise RuntimeError("No active run. Call start_run() first.")
        run_id = self.current_run_id
        self._require_db().update_run_status(run_id, status)
        self.current_run_id = None
        self.current_run_name = None
        return run_id

    def get_all_runs(self) -> list[RunInfo]:
        """Get all runs of the current experiment."""
        return self._require_db().get_all_runs()

    def get_run(self, run_id: int) -> RunInfo | None:
        """Get a specific run by id"""
        return self._require_db().get_run(run_id)

    def delete_run(self, run_id: int):
        """Delete a run and all its data.

        If the deleted run is the active one, the tracker forgets it so that
        later ``record()`` calls cannot write records for a run that is gone.
        """
        self._require_db().delete_run(run_id)
        if run_id == self.current_run_id:
            self.current_run_id = None
            self.current_run_name = None

    def get_metrics(self, run_id: int) -> list[str]:
        """Get all metric names for a given run"""
        return self._require_db().get_metrics(run_id)

    def get_records(self, run_id: int, metric_name: str) -> list[RecordInfo]:
        """Get the records of one metric for a given run"""
        return self._require_db().get_records(run_id, metric_name)

    def get_results(self, run_id: int) -> dict[str, list[RecordInfo]]:
        """Get all metric records for a given run, keyed by metric name"""
        metrics = self.get_metrics(run_id)
        return {metric: self.get_records(run_id, metric) for metric in metrics}

    def save_comp_graph(self, tensor: Tensor, run_id: int | None = None):
        """Save the computation graph of ``tensor`` for a run.

        Uses ``run_id`` when given, otherwise the active run.

        Raises:
            RuntimeError: If no ``run_id`` is given and there is no active run.
        """
        target_run_id = run_id if run_id is not None else self.current_run_id
        if target_run_id is None:
            raise RuntimeError("No active run. Call start_run() first.")
        print(f"Saving computation graph for run {target_run_id}...")
        graph_data = _build_graph_data(tensor)
        self._require_db().save_comp_graph(run_id=target_run_id, graph_data=graph_data)

    def get_comp_graph(self, graph_id: int) -> dict | None:
        """Get a single computation graph by its graph id"""
        return self._require_db().get_comp_graph(graph_id)

    def get_comp_graphs(self, run_id: int) -> list[dict]:
        """Get all computation graphs for a given run"""
        return self._require_db().get_comp_graphs(run_id)

delete_run(run_id: int)

Delete a run and all its data

Source code in simplegrad/track/tracker.py
def delete_run(self, run_id: int):
    """Remove the run ``run_id`` and its stored data via the database manager."""
    manager = self.db_manager
    manager.delete_run(run_id)

end_run(status: str = 'completed')

End the current run with a given status

Source code in simplegrad/track/tracker.py
def end_run(self, status: str = "completed"):
    """Close the active run, storing ``status`` for it, and return its id.

    Raises RuntimeError when no run is active.
    """
    finished_run_id = self.current_run_id
    if finished_run_id is None:
        raise RuntimeError("No active run. Call start_run() first.")
    self.db_manager.update_run_status(finished_run_id, status)
    # Forget the run only after the status update went through.
    self.current_run_id = self.current_run_name = None
    return finished_run_id

get_all_exp_paths() -> list[Path]

Get all experiment database paths.

Source code in simplegrad/track/tracker.py
def get_all_exp_paths(self) -> list[Path]:
    """Return the ``*.db`` files found directly in the experiments directory.

    An empty list is returned when the directory does not exist.
    """
    root = self.all_exp_dir
    return list(root.glob("*.db")) if root.exists() else []

get_all_runs() -> list[RunInfo]

Get all runs

Source code in simplegrad/track/tracker.py
def get_all_runs(self) -> list[RunInfo]:
    """Return every run known to the current experiment's database manager."""
    all_runs = self.db_manager.get_all_runs()
    return all_runs

get_comp_graph(graph_id: int) -> dict | None

Get a single computation graph by its graph id

Source code in simplegrad/track/tracker.py
def get_comp_graph(self, graph_id: int) -> dict | None:
    """Get computation graph for a given run"""
    return self.db_manager.get_comp_graph(graph_id)

get_comp_graphs(run_id: int) -> list[dict]

Get all computation graphs for a given run

Source code in simplegrad/track/tracker.py
def get_comp_graphs(self, run_id: int) -> list[dict]:
    """Return the computation graphs stored for run ``run_id``."""
    graphs = self.db_manager.get_comp_graphs(run_id)
    return graphs

get_metrics(run_id: int) -> list[str]

Get all metric names for a given run

Source code in simplegrad/track/tracker.py
def get_metrics(self, run_id: int) -> list[str]:
    """Return the names of the metrics recorded for run ``run_id``."""
    metric_names = self.db_manager.get_metrics(run_id)
    return metric_names

get_records(run_id: int, metric_name: str) -> list[RecordInfo]

Get the records of one metric for a given run

Source code in simplegrad/track/tracker.py
def get_records(self, run_id: int, metric_name: str) -> list[RecordInfo]:
    """Return the records of metric ``metric_name`` for run ``run_id``."""
    records = self.db_manager.get_records(run_id, metric_name)
    return records

get_results(run_id: int) -> dict[str, list[RecordInfo]]

Get all metric records for a given run

Source code in simplegrad/track/tracker.py
def get_results(self, run_id: int) -> dict[str, list[RecordInfo]]:
    """Collect the records of every metric of run ``run_id``, keyed by metric name."""
    collected = {}
    for metric_name in self.get_metrics(run_id):
        collected[metric_name] = self.get_records(run_id, metric_name)
    return collected

get_run(run_id: int) -> RunInfo | None

Get a specific run by id

Source code in simplegrad/track/tracker.py
def get_run(self, run_id: int) -> RunInfo | None:
    """Fetch the run with id ``run_id`` from the database manager."""
    run_info = self.db_manager.get_run(run_id)
    return run_info

record(metric_name: str, value: float, step: int)

Log a metric value at a given step

Source code in simplegrad/track/tracker.py
def record(self, metric_name: str, value: float, step: int):
    """Store ``value`` for ``metric_name`` at ``step`` under the active run.

    Raises RuntimeError when no run is active.
    """
    active_run_id = self.current_run_id
    if active_run_id is None:
        raise RuntimeError("No active run. Call start_run() first.")
    # The database manager takes step before value.
    self.db_manager.record(active_run_id, metric_name, step, value)

save_comp_graph(tensor: Tensor, run_id: int | None = None)

Save computation graph for the current run

Source code in simplegrad/track/tracker.py
def save_comp_graph(self, tensor: Tensor, run_id: int | None = None):
    """Store the computation graph of ``tensor`` for a run.

    The explicit ``run_id`` wins; without it the active run is used.
    Raises RuntimeError when neither is available.
    """
    target_run_id = run_id if run_id is not None else self.current_run_id
    if target_run_id is None:
        raise RuntimeError("No active run. Call start_run() first.")
    print(f"Saving computation graph for run {target_run_id}...")
    serialized_graph = _build_graph_data(tensor)
    self.db_manager.save_comp_graph(run_id=target_run_id, graph_data=serialized_graph)

set_all_exp_dir(directory: str)

Set the experiments directory

Source code in simplegrad/track/tracker.py
def set_all_exp_dir(self, directory: str):
    """Set the experiments directory, creating it if it doesn't exist.

    The path is stored in ``self.all_exp_dir``, the attribute that
    ``get_all_exp_paths`` and ``set_experiment`` read.
    """
    self.all_exp_dir = Path(directory)
    # Mirror __init__: make sure the directory is there before databases are opened in it.
    self.all_exp_dir.mkdir(parents=True, exist_ok=True)

set_experiment(exp_name: str)

Set the current experiment by name, initializing its database manager.

Source code in simplegrad/track/tracker.py
def set_experiment(self, exp_name: str):
    """Select the experiment ``exp_name`` and attach a database manager to it.

    The ``.db`` suffix is appended when missing. If the database is not yet
    reachable its schema is initialized; otherwise a confirmation is printed.
    """
    filename = exp_name
    if not filename.endswith(".db"):
        filename = f"{filename}.db"

    self.cur_exp_path = self.all_exp_dir / filename
    manager = ExperimentDBManager(db_path=self.cur_exp_path)
    self.db_manager = manager

    if not manager.check_connection():
        manager.init_exp_db()
        return
    print(f"Connected to existing experiment database at {self.cur_exp_path}")

start_run(name: str | None = None, config: dict | None = None) -> int

Start a new run and return the run_id

Source code in simplegrad/track/tracker.py
def start_run(self, name: str | None = None, config: dict | None = None) -> int:
    """Start a new run and return the run_id"""
    self.current_run_id = self.db_manager.create_run(name=name, config=config)
    self.current_run_name = name or f"run_{self.current_run_id}"
    return self.current_run_id

simplegrad.track.exp_db_manager.ExperimentDBManager

SQLite-based storage for training runs and metrics.

Source code in simplegrad/track/exp_db_manager.py
class ExperimentDBManager:
    """SQLite-based storage for training runs, metrics and computation graphs.

    Each public method opens its own connection to ``db_path`` through
    ``_get_connection`` and closes it before returning.
    """

    def __init__(self, db_path: Path):
        self.db_path = db_path

    @contextmanager
    def _get_connection(self, readonly: bool = False):
        """Get a database connection with proper cleanup.

        With ``readonly=False`` the work is committed on success and rolled
        back when the body raises. With ``readonly=True`` the connection is
        put into autocommit mode and no commit/rollback is issued. The
        connection is always closed.
        """
        conn = sqlite3.connect(self.db_path, timeout=10.0)
        conn.row_factory = sqlite3.Row  # rows are addressable by column name
        if readonly:
            conn.isolation_level = None  # Autocommit mode
        try:
            yield conn
            if not readonly:
                conn.commit()
        except Exception:
            if not readonly:
                conn.rollback()
            raise
        finally:
            conn.close()

    def check_connection(self) -> bool:
        """Check if the database exists and is accessible."""
        if not self.db_path.exists():
            return False
        try:
            with self._get_connection() as conn:
                conn.execute("SELECT 1 FROM sqlite_master LIMIT 1")
            return True
        except sqlite3.DatabaseError:
            # sqlite3.OperationalError is a subclass, so it is covered here too.
            return False

    def init_exp_db(self):
        """Initialize database schema.

        Creates the ``runs``, ``metrics``, ``records`` and ``graphs`` tables and
        the ``idx_records_run_step`` index, each only if it does not exist yet.
        """
        with self._get_connection() as conn:
            conn.executescript("""
                CREATE TABLE IF NOT EXISTS runs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL,
                    created_at REAL NOT NULL,
                    status TEXT NOT NULL DEFAULT 'running' CHECK(status IN ('running', 'completed', 'failed')),
                    config TEXT DEFAULT '{}'
                );

                CREATE TABLE IF NOT EXISTS metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL UNIQUE
                );

                CREATE TABLE IF NOT EXISTS records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    run_id INTEGER NOT NULL,
                    metric_id INTEGER NOT NULL,
                    step INTEGER NOT NULL,
                    value REAL NOT NULL,
                    wall_time REAL NOT NULL,
                    FOREIGN KEY (run_id) REFERENCES runs(id),
                    FOREIGN KEY (metric_id) REFERENCES metrics(id)
                );

                CREATE INDEX IF NOT EXISTS idx_records_run_step
                ON records(run_id, step);

                CREATE TABLE IF NOT EXISTS graphs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    run_id INTEGER NOT NULL UNIQUE,
                    graph_json TEXT NOT NULL,
                    created_at REAL NOT NULL,
                    FOREIGN KEY (run_id) REFERENCES runs(id)
                );
            """)

    def create_run(self, name: str | None = None, config: dict | None = None) -> int:
        """Create a new training run. Returns run_id.

        A missing name defaults to ``run_<unix timestamp>`` and a missing
        config to ``{}``. The run starts with status ``'running'``.
        """
        created_at = time.time()
        config = config or {}
        name = name or f"run_{int(created_at)}"

        with self._get_connection() as conn:
            cursor = conn.execute(
                "INSERT INTO runs (name, created_at, status, config) VALUES (?, ?, ?, ?)",
                (name, created_at, "running", json.dumps(config)),
            )
            run_id = cursor.lastrowid

        return run_id

    @staticmethod
    def _row_to_run_info(conn, row) -> RunInfo:
        """Build a RunInfo from a ``runs`` row, adding metric stats for completed runs."""
        metrics = None
        num_records = None

        if row["status"] == "completed":
            # One ordered aggregate query feeds both lists, so num_records[i]
            # always belongs to metrics[i].
            stat_rows = conn.execute(
                """SELECT m.name, COUNT(*) AS count
                   FROM records r
                   JOIN metrics m ON r.metric_id = m.id
                   WHERE r.run_id = ?
                   GROUP BY m.name
                   ORDER BY m.name""",
                (row["id"],),
            ).fetchall()
            metrics = [stat["name"] for stat in stat_rows]
            num_records = [stat["count"] for stat in stat_rows]

        return RunInfo(
            run_id=row["id"],
            name=row["name"],
            created_at=_format_timestamp(row["created_at"]),
            status=row["status"],
            # The config column is nullable; treat NULL as an empty config.
            config=json.loads(row["config"] or "{}"),
            metrics=metrics,
            num_records=num_records,
        )

    def get_run(self, run_id: int) -> RunInfo | None:
        """Get run metadata, or None when no run has this id.

        ``metrics`` and ``num_records`` are filled only for completed runs.
        """
        with self._get_connection(readonly=True) as conn:
            row = conn.execute("SELECT * FROM runs WHERE id = ?", (run_id,)).fetchone()
            if row:
                return self._row_to_run_info(conn, row)
        return None

    def get_all_runs(self) -> list[RunInfo]:
        """List all runs, newest first.

        ``metrics`` and ``num_records`` are filled only for completed runs.
        """
        with self._get_connection(readonly=True) as conn:
            rows = conn.execute("SELECT * FROM runs ORDER BY created_at DESC").fetchall()
            return [self._row_to_run_info(conn, row) for row in rows]

    def update_run_status(self, run_id: int, status: str):
        """Update run status."""
        with self._get_connection() as conn:
            conn.execute("UPDATE runs SET status = ? WHERE id = ?", (status, run_id))

    def delete_run(self, run_id: int):
        """Delete a run and all its data (records, graphs, then the run row)."""
        with self._get_connection() as conn:
            conn.execute("DELETE FROM records WHERE run_id = ?", (run_id,))
            conn.execute("DELETE FROM graphs WHERE run_id = ?", (run_id,))
            conn.execute("DELETE FROM runs WHERE id = ?", (run_id,))

    def record(self, run_id: int, metric_name: str, step: int, value: float):
        """Log a single metric record, registering the metric name on first use."""
        log_time = time.time()

        with self._get_connection() as conn:
            conn.execute("INSERT OR IGNORE INTO metrics (name) VALUES (?)", (metric_name,))
            metric_id = conn.execute(
                "SELECT id FROM metrics WHERE name = ?", (metric_name,)
            ).fetchone()["id"]
            conn.execute(
                "INSERT INTO records (run_id, metric_id, step, value, wall_time) VALUES (?, ?, ?, ?, ?)",
                (run_id, metric_id, step, value, log_time),
            )

    def get_records(self, run_id: int, metric_name: str) -> list[RecordInfo]:
        """Get the records of one metric for a run, ordered by step."""
        with self._get_connection(readonly=True) as conn:
            rows = conn.execute(
                """SELECT m.name as metric_name, r.step, r.value, r.wall_time
                    FROM records r
                    JOIN metrics m ON r.metric_id = m.id
                    WHERE r.run_id = ? AND m.name = ?
                    ORDER BY r.step""",
                (run_id, metric_name),
            ).fetchall()

            return [
                RecordInfo(step=row["step"], value=row["value"], log_time=row["wall_time"])
                for row in rows
            ]

    def get_metrics(self, run_id: int) -> list[str]:
        """Get list of metric names for a run."""
        with self._get_connection(readonly=True) as conn:
            rows = conn.execute(
                """SELECT DISTINCT m.name
                   FROM records r
                   JOIN metrics m ON r.metric_id = m.id
                   WHERE r.run_id = ?""",
                (run_id,),
            ).fetchall()
            return [row["name"] for row in rows]

    def save_comp_graph(self, run_id: int, graph_data: dict):
        """Save computation graph as JSON, replacing the graph already stored for the run."""
        with self._get_connection() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO graphs (run_id, graph_json, created_at) VALUES (?, ?, ?)",
                (run_id, json.dumps(graph_data), time.time()),
            )

    def get_comp_graph(self, graph_id: int) -> dict | None:
        """Get a single computation graph by its ID."""
        with self._get_connection(readonly=True) as conn:
            row = conn.execute("SELECT graph_json FROM graphs WHERE id = ?", (graph_id,)).fetchone()
            if row:
                return json.loads(row["graph_json"])
        return None

    def get_comp_graphs(self, run_id: int) -> list[dict]:
        """Get all computation graphs for a run, oldest first.

        Each entry is a dict with the keys ``id``, ``graph`` and ``created_at``.
        """
        with self._get_connection(readonly=True) as conn:
            rows = conn.execute(
                "SELECT id, graph_json, created_at FROM graphs WHERE run_id = ? ORDER BY created_at",
                (run_id,),
            ).fetchall()
            return [
                {
                    "id": row["id"],
                    "graph": json.loads(row["graph_json"]),
                    "created_at": row["created_at"],
                }
                for row in rows
            ]

check_connection() -> bool

Check if the database exists and is accessible.

Source code in simplegrad/track/exp_db_manager.py
def check_connection(self) -> bool:
    """Report whether the database file exists and answers a trivial query."""
    if self.db_path.exists():
        try:
            with self._get_connection() as conn:
                # Cheap probe against SQLite's built-in catalog table.
                conn.execute("SELECT 1 FROM sqlite_master LIMIT 1")
        except (sqlite3.DatabaseError, sqlite3.OperationalError):
            return False
        else:
            return True
    return False

create_run(name: str | None = None, config: dict | None = None) -> int

Create a new training run. Returns run_id.

Source code in simplegrad/track/exp_db_manager.py
def create_run(self, name: str | None = None, config: dict | None = None) -> int:
    """Create a new training run. Returns run_id."""
    created_at = time.time()
    config = config or {}
    name = name or f"run_{int(created_at)}"

    with self._get_connection() as conn:
        cursor = conn.execute(
            "INSERT INTO runs (name, created_at, status, config) VALUES (?, ?, ?, ?)",
            (name, created_at, "running", json.dumps(config)),
        )
        run_id = cursor.lastrowid

    return run_id

delete_run(run_id: int)

Delete a run and all its data.

Source code in simplegrad/track/exp_db_manager.py
def delete_run(self, run_id: int):
    """Remove a run together with its records and graphs in one transaction."""
    # Dependent rows go first, the run row itself last.
    statements = (
        "DELETE FROM records WHERE run_id = ?",
        "DELETE FROM graphs WHERE run_id = ?",
        "DELETE FROM runs WHERE id = ?",
    )
    with self._get_connection() as conn:
        for sql in statements:
            conn.execute(sql, (run_id,))

get_all_runs() -> list[RunInfo]

List all runs, newest first.

Source code in simplegrad/track/exp_db_manager.py
def get_all_runs(self) -> list[RunInfo]:
    """List all runs, newest first.

    For runs whose status is ``'completed'`` the result also carries the
    recorded metric names and, at the same index, the number of records per
    metric. For other runs both fields stay ``None``.
    """
    with self._get_connection(readonly=True) as conn:
        rows = conn.execute("SELECT * FROM runs ORDER BY created_at DESC").fetchall()
        runs = []
        for row in rows:
            metrics = None
            num_records = None

            if row["status"] == "completed":
                # One ordered aggregate query feeds both lists, so
                # num_records[i] always belongs to metrics[i].
                stat_rows = conn.execute(
                    """SELECT m.name, COUNT(*) AS count
                       FROM records r
                       JOIN metrics m ON r.metric_id = m.id
                       WHERE r.run_id = ?
                       GROUP BY m.name
                       ORDER BY m.name""",
                    (row["id"],),
                ).fetchall()
                metrics = [stat["name"] for stat in stat_rows]
                num_records = [stat["count"] for stat in stat_rows]

            runs.append(
                RunInfo(
                    run_id=row["id"],
                    name=row["name"],
                    created_at=_format_timestamp(row["created_at"]),
                    status=row["status"],
                    # A NULL config is treated as an empty config instead of
                    # failing the whole listing.
                    config=json.loads(row["config"] or "{}"),
                    metrics=metrics,
                    num_records=num_records,
                )
            )
        return runs

get_comp_graph(graph_id: int) -> dict | None

Get a single computation graph by its ID.

Source code in simplegrad/track/exp_db_manager.py
def get_comp_graph(self, graph_id: int) -> dict | None:
    """Get a single computation graph by its ID."""
    with self._get_connection(readonly=True) as conn:
        row = conn.execute("SELECT graph_json FROM graphs WHERE id = ?", (graph_id,)).fetchone()
        if row:
            return json.loads(row["graph_json"])
    return None

get_comp_graphs(run_id: int) -> list[dict]

Get all computation graphs for a run.

Source code in simplegrad/track/exp_db_manager.py
def get_comp_graphs(self, run_id: int) -> list[dict]:
    """Return the graphs stored for a run, ordered by creation time.

    Each entry is a dict with the keys ``id``, ``graph`` (decoded JSON) and
    ``created_at``.
    """
    with self._get_connection(readonly=True) as conn:
        cursor = conn.execute(
            "SELECT id, graph_json, created_at FROM graphs WHERE run_id = ? ORDER BY created_at",
            (run_id,),
        )
        graphs = []
        for entry in cursor.fetchall():
            decoded = json.loads(entry["graph_json"])
            graphs.append(
                {"id": entry["id"], "graph": decoded, "created_at": entry["created_at"]}
            )
        return graphs

get_metrics(run_id: int) -> list[str]

Get list of metric names for a run.

Source code in simplegrad/track/exp_db_manager.py
def get_metrics(self, run_id: int) -> list[str]:
    """Return the distinct names of the metrics that have records for a run."""
    with self._get_connection(readonly=True) as conn:
        cursor = conn.execute(
            """SELECT DISTINCT m.name 
               FROM records r 
               JOIN metrics m ON r.metric_id = m.id 
               WHERE r.run_id = ?""",
            (run_id,),
        )
        metric_names = []
        for entry in cursor.fetchall():
            metric_names.append(entry["name"])
        return metric_names

get_records(run_id: int, metric_name: str) -> list[RecordInfo]

Get the records of one metric for a run, ordered by step. Returns a list of RecordInfo.

Source code in simplegrad/track/exp_db_manager.py
def get_records(self, run_id: int, metric_name: str) -> list[RecordInfo]:
    """Return the records of one metric for a run as a list ordered by step."""
    with self._get_connection(readonly=True) as conn:
        cursor = conn.execute(
            """SELECT m.name as metric_name, r.step, r.value, r.wall_time 
                FROM records r 
                JOIN metrics m ON r.metric_id = m.id 
                WHERE r.run_id = ? AND m.name = ? 
                ORDER BY r.step""",
            (run_id, metric_name),
        )

        records: list[RecordInfo] = []
        for entry in cursor.fetchall():
            # The stored wall_time is exposed as RecordInfo.log_time.
            records.append(
                RecordInfo(step=entry["step"], value=entry["value"], log_time=entry["wall_time"])
            )
        return records

get_run(run_id: int) -> RunInfo | None

Get run metadata.

Source code in simplegrad/track/exp_db_manager.py
def get_run(self, run_id: int) -> RunInfo | None:
    """Get run metadata, or None when no run has this id.

    For a run whose status is ``'completed'`` the result also carries the
    recorded metric names and, at the same index, the number of records per
    metric. For other runs both fields stay ``None``.
    """
    with self._get_connection(readonly=True) as conn:
        row = conn.execute("SELECT * FROM runs WHERE id = ?", (run_id,)).fetchone()
        if row is None:
            return None

        metrics = None
        num_records = None

        if row["status"] == "completed":
            # One ordered aggregate query feeds both lists, so
            # num_records[i] always belongs to metrics[i].
            stat_rows = conn.execute(
                """SELECT m.name, COUNT(*) AS count
                   FROM records r
                   JOIN metrics m ON r.metric_id = m.id
                   WHERE r.run_id = ?
                   GROUP BY m.name
                   ORDER BY m.name""",
                (run_id,),
            ).fetchall()
            metrics = [stat["name"] for stat in stat_rows]
            num_records = [stat["count"] for stat in stat_rows]

        return RunInfo(
            run_id=row["id"],
            name=row["name"],
            created_at=_format_timestamp(row["created_at"]),
            status=row["status"],
            # A NULL config is treated as an empty config instead of raising.
            config=json.loads(row["config"] or "{}"),
            metrics=metrics,
            num_records=num_records,
        )

init_exp_db()

Initialize database schema.

Source code in simplegrad/track/exp_db_manager.py
def init_exp_db(self):
    """Initialize database schema.

    Runs one SQL script that creates the tables ``runs``, ``metrics``,
    ``records`` and ``graphs`` plus the index ``idx_records_run_step`` on
    ``records(run_id, step)``. Every statement uses ``IF NOT EXISTS``, so
    objects that are already present are left alone.

    Schema notes visible in the script:
    - ``runs.status`` is limited to 'running', 'completed' or 'failed'.
    - ``metrics.name`` is unique.
    - ``graphs.run_id`` is unique, i.e. at most one graph row per run.
    """
    with self._get_connection() as conn:
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS runs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                created_at REAL NOT NULL,
                status TEXT NOT NULL DEFAULT 'running' CHECK(status IN ('running', 'completed', 'failed')),
                config TEXT DEFAULT '{}'
            );

            CREATE TABLE IF NOT EXISTS metrics (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL UNIQUE
            );

            CREATE TABLE IF NOT EXISTS records (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                run_id INTEGER NOT NULL,
                metric_id INTEGER NOT NULL,
                step INTEGER NOT NULL,
                value REAL NOT NULL,
                wall_time REAL NOT NULL,
                FOREIGN KEY (run_id) REFERENCES runs(id),
                FOREIGN KEY (metric_id) REFERENCES metrics(id)
            );

            CREATE INDEX IF NOT EXISTS idx_records_run_step 
            ON records(run_id, step);

            CREATE TABLE IF NOT EXISTS graphs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                run_id INTEGER NOT NULL UNIQUE,
                graph_json TEXT NOT NULL,
                created_at REAL NOT NULL,
                FOREIGN KEY (run_id) REFERENCES runs(id)
            );
        """)

record(run_id: int, metric_name: str, step: int, value: float)

Log a single metric record.

Source code in simplegrad/track/exp_db_manager.py
def record(self, run_id: int, metric_name: str, step: int, value: float):
    """Store one data point of ``metric_name`` for a run, stamped with the current time.

    The metric name is registered in the ``metrics`` table on first use.
    """
    timestamp = time.time()

    with self._get_connection() as conn:
        # Register the metric if it is new, then resolve its id.
        conn.execute("INSERT OR IGNORE INTO metrics (name) VALUES (?)", (metric_name,))
        metric_row = conn.execute(
            "SELECT id FROM metrics WHERE name = ?", (metric_name,)
        ).fetchone()
        values = (run_id, metric_row["id"], step, value, timestamp)
        conn.execute(
            "INSERT INTO records (run_id, metric_id, step, value, wall_time) VALUES (?, ?, ?, ?, ?)",
            values,
        )

save_comp_graph(run_id: int, graph_data: dict)

Save computation graph as JSON.

Source code in simplegrad/track/exp_db_manager.py
def save_comp_graph(self, run_id: int, graph_data: dict):
    """Store ``graph_data`` as JSON text for a run, stamped with the current time."""
    upsert_sql = "INSERT OR REPLACE INTO graphs (run_id, graph_json, created_at) VALUES (?, ?, ?)"
    with self._get_connection() as conn:
        values = (run_id, json.dumps(graph_data), time.time())
        conn.execute(upsert_sql, values)

update_run_status(run_id: int, status: str)

Update run status.

Source code in simplegrad/track/exp_db_manager.py
def update_run_status(self, run_id: int, status: str):
    """Write ``status`` into the ``runs`` row whose id is ``run_id``."""
    update_sql = "UPDATE runs SET status = ? WHERE id = ?"
    with self._get_connection() as conn:
        conn.execute(update_sql, (status, run_id))

simplegrad.track.exp_db_manager.RunInfo

Bases: BaseModel

Metadata for a training run.

Source code in simplegrad/track/exp_db_manager.py
class RunInfo(BaseModel):
    """Metadata for a training run.

    Built by ExperimentDBManager.get_run / get_all_runs from a row of the
    ``runs`` table. ``metrics`` and ``num_records`` are only filled in for
    runs whose status is 'completed'; otherwise they stay None.
    """

    run_id: int  # Primary key of the run in the ``runs`` table
    name: str  # Run name as stored in the database
    created_at: str  # Formatted datetime string for display
    status: str  # 'running', 'completed', 'failed'
    config: dict  # Run configuration decoded from the stored JSON text
    num_records: list[int] | None = None  # Record count per metric (completed runs only)
    metrics: list[str] | None = None  # Names of metrics with records (completed runs only)

simplegrad.track.exp_db_manager.RecordInfo

Bases: BaseModel

A single metric record (data point).

Source code in simplegrad/track/exp_db_manager.py
class RecordInfo(BaseModel):
    """A single metric record (data point).

    Built by ExperimentDBManager.get_records from a row of the ``records`` table.
    """

    step: int  # Step at which the value was recorded
    value: float  # Recorded metric value
    log_time: float  # Stored ``wall_time``: the time.time() value taken when the record was logged

simplegrad.track.comp_graph._build_graph_data(tensor: Tensor) -> dict

Build a JSON-serializable graph structure for D3.js visualization.

Source code in simplegrad/track/comp_graph.py
def _build_graph_data(tensor: Tensor) -> dict:
    """Build a JSON-serializable graph structure for D3.js visualization.

    Starting at ``tensor``, every reachable tensor becomes a ``"tensor"`` node.
    A tensor that has an ``oper`` additionally gets an ``"operation"`` node,
    an edge from that operation to the tensor, and one edge from each parent
    tensor in ``prev`` to the operation. Each tensor is emitted once, keyed by
    its ``_str_id``.

    The walk is depth-first and uses an explicit stack instead of recursion,
    so long operation chains cannot exceed Python's recursion limit. Nodes and
    edges are emitted in the same order as a recursive depth-first walk.

    Returns:
        A dict with the keys ``"nodes"`` and ``"edges"``.
    """
    nodes = []
    edges = []
    visited = set()

    def visit(t: Tensor):
        """Emit the nodes/edges owned by ``t``; return (oper_id, parents) if it has parents to walk."""
        visited.add(t._str_id)

        # Add tensor node
        nodes.append(
            {
                "id": t._str_id,
                "type": "tensor",
                "label": t.label or "",
                "shape": list(t.shape),
                "comp_grad": t.comp_grad,
                "is_leaf": t.is_leaf,
            }
        )

        if t.oper is None:
            return None

        # Add operation node and connect it to the tensor it produced
        oper_id = t._str_id + "_" + t.oper
        nodes.append(
            {
                "id": oper_id,
                "type": "operation",
                "label": t.oper,
            }
        )
        edges.append(
            {
                "source": oper_id,
                "target": t._str_id,
            }
        )
        return oper_id, iter(t.prev)

    # Each stack entry is (operation id, iterator over the parents still to process).
    pending = []
    root_frame = visit(tensor)
    if root_frame is not None:
        pending.append(root_frame)

    while pending:
        oper_id, parents = pending[-1]
        for parent in parents:
            # Connect parent tensor to operation (also for already-visited parents)
            edges.append(
                {
                    "source": parent._str_id,
                    "target": oper_id,
                }
            )
            if parent._str_id in visited:
                continue
            frame = visit(parent)
            if frame is not None:
                # Descend into this parent first; the partially consumed
                # iterator stays on the stack and resumes afterwards.
                pending.append(frame)
                break
        else:
            # All parents of this operation have been handled.
            pending.pop()

    return {
        "nodes": nodes,
        "edges": edges,
    }