Skip to content

Visualization

Computation Graph

simplegrad.visual.inline_comp_graph.graph(tensor: Tensor, path: str | None = None) -> graphviz.Digraph

Render the computation graph of a tensor as an SVG diagram.

Functions decorated with @compound_op are enclosed in a labelled black-border rectangle. Each distinct call to a compound op gets its own rectangle, so two calls to softmax produce two separate boxes.

Node colors
  • Salmon: leaf tensors (inputs / parameters)
  • Steel blue: intermediate tensors
  • Gold: operation nodes

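Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `tensor` | `Tensor` | The output tensor whose computation graph to visualize. | *required* |
| `path` | `str \| None` | If provided, also save the SVG to this file path (without extension; `.svg` is appended). | `None` |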

Returns:

| Type | Description |
|------|-------------|
| `Digraph` | A `graphviz.Digraph` object. Displays inline in Jupyter notebooks. |

Source code in simplegrad/visual/inline_comp_graph.py
def graph(tensor: Tensor, path: str | None = None) -> graphviz.Digraph:
    """Build an SVG diagram of the computation graph that produced ``tensor``.

    Tensors produced inside a function decorated with ``@compound_op`` are
    drawn inside a rounded, black-bordered box labelled with the op's name.
    Boxes are keyed by the group id stored on each tensor, so every distinct
    call (e.g. two calls to ``softmax``) gets its own box.

    Node colors:
        - Salmon: leaf tensors (inputs / parameters)
        - Steel blue: intermediate tensors
        - Gold: operation nodes

    Args:
        tensor: Output tensor whose computation graph is drawn.
        path: Optional output path *without* extension. When given (and
            non-empty), the diagram is also rendered to disk as SVG and the
            intermediate DOT source file is removed.

    Returns:
        The ``graphviz.Digraph``; Jupyter notebooks render it inline.
    """
    diagram = graphviz.Digraph(
        format="svg",
        graph_attr=dict(rankdir="LR", nodesep="0.5", ranksep="0.7", bgcolor="white"),
    )
    # A strict graph collapses repeated edges between the same two nodes.
    diagram.strict = True

    # Split nodes into per-call buckets (tensors tagged with a compound-op
    # group) and loose nodes (no group). For a bucket, the label seen first wins.
    buckets: dict[int, tuple[str, list]] = {}
    loose: list = []
    for node in _collect_nodes(tensor):
        if node.group is None:
            loose.append(node)
            continue
        label, bucket_id = node.group
        buckets.setdefault(bucket_id, (label, []))[1].append(node)

    box_style = dict(
        labelloc="t",
        labeljust="l",
        color="black",
        style="rounded",
        fontname="monospace",
        fontsize="10",
    )
    for bucket_id, (label, members) in buckets.items():
        # The "cluster_" name prefix is what makes graphviz draw a box around the subgraph.
        with diagram.subgraph(name=f"cluster_{bucket_id}") as box:
            box.attr(label=label, **box_style)
            for node in members:
                _render_tensor_node(node, box)

    for node in loose:
        _render_tensor_node(node, diagram)

    _add_graph_edges(tensor, diagram)

    if path:
        diagram.render(filename=path, format="svg", cleanup=True)
    return diagram

Training Plots

simplegrad.visual.inline_training_graphs.plot(results: dict[str, list[RecordInfo]], selected: list[str] | None = None, num_cols: int = 2, cell_w: int = 8, cell_h: int = 5, path: Path | None = None, color: str | None = None)

Plot training metrics as line charts.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `results` | `dict[str, list[RecordInfo]]` | Mapping of metric name to list of `RecordInfo` data points. | *required* |
| `selected` | `list[str] \| None` | Subset of metric names to plot. Plots all if `None`. | `None` |
| `num_cols` | `int` | Number of subplot columns. | `2` |
| `cell_w` | `int` | Width of each subplot cell in inches. | `8` |
| `cell_h` | `int` | Height of each subplot cell in inches. | `5` |
| `path` | `Path \| None` | If provided, save the figure to this path. | `None` |
| `color` | `str \| None` | Fixed color for all lines. If `None`, each subplot gets a color picked at random from a fixed 10-color palette. | `None` |
Source code in simplegrad/visual/inline_training_graphs.py
def plot(
    results: dict[str, list[RecordInfo]],
    selected: list[str] | None = None,
    num_cols: int = 2,
    cell_w: int = 8,
    cell_h: int = 5,
    path: Path | None = None,
    color: str | None = None,
):
    """Plot training metrics as line charts.

    Draws one subplot per metric, laid out row-major in a grid with
    ``num_cols`` columns. Grid cells that are not used, and cells for
    selected names missing from ``results``, are hidden.

    Args:
        results: Mapping of metric name to list of ``RecordInfo`` data points.
            Each record's ``step`` is plotted on x and its ``value`` on y.
        selected: Subset of metric names to plot. Plots all if None (or empty).
        num_cols: Number of subplot columns. Defaults to 2. Must be >= 1.
        cell_w: Width of each subplot cell in inches. Defaults to 8.
        cell_h: Height of each subplot cell in inches. Defaults to 5.
        path: If provided, save the figure to this path.
        color: Fixed color for all lines. If None, each subplot gets a color
            picked at random from a 10-color palette.

    Raises:
        ValueError: If ``num_cols`` is less than 1.
    """
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]

    if num_cols < 1:
        raise ValueError(f"num_cols must be >= 1, got {num_cols}")

    # Determine which metrics to plot
    metrics_to_plot = selected if selected else list(results.keys())
    num_metrics = len(metrics_to_plot)
    if num_metrics == 0:
        # plt.subplots(0, ...) raises, and there is nothing to draw anyway.
        return

    num_rows = (num_metrics + num_cols - 1) // num_cols
    # squeeze=False always returns a 2-D array of Axes, so flatten() is valid
    # for every grid shape. A single metric with num_cols > 1 previously
    # produced an ndarray wrapped in a list, which has no .plot method.
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(cell_w * num_cols, cell_h * num_rows),
        squeeze=False,
    )
    axes = axes.flatten()

    # Hide the trailing grid cells that no metric will occupy.
    for ax in axes[num_metrics:]:
        ax.axis("off")

    for ax, metric_name in zip(axes, metrics_to_plot):
        if metric_name not in results:
            # Requested metric has no data; hide its cell rather than leave an empty frame.
            ax.axis("off")
            continue

        records = results[metric_name]
        steps = [record.step for record in records]
        values = [record.value for record in records]
        plot_color = color if color else random.choice(colors)
        # Draw point markers only on sparse series, where they stay readable.
        marker = "o" if len(steps) < cell_w * 8 else None

        ax.plot(steps, values, marker=marker, color=plot_color)
        ax.set_title(metric_name)
        ax.set_xlabel("Step")
        ax.set_ylabel("Value")
        ax.grid(True)

    fig.tight_layout()
    if path:
        fig.savefig(path)
    plt.show()

simplegrad.visual.inline_training_graphs.scatter(results: dict[str, list[RecordInfo]], selected: list[str] | None = None, num_cols: int = 2, cell_w: int = 8, cell_h: int = 5, path: Path | None = None, color: str | None = None)

Plot training metrics as scatter charts.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `results` | `dict[str, list[RecordInfo]]` | Mapping of metric name to list of `RecordInfo` data points. | *required* |
| `selected` | `list[str] \| None` | Subset of metric names to plot. Plots all if `None`. | `None` |
| `num_cols` | `int` | Number of subplot columns. | `2` |
| `cell_w` | `int` | Width of each subplot cell in inches. | `8` |
| `cell_h` | `int` | Height of each subplot cell in inches. | `5` |
| `path` | `Path \| None` | If provided, save the figure to this path. | `None` |
| `color` | `str \| None` | Fixed color for all points. If `None`, each subplot gets a color picked at random from a fixed 10-color palette. | `None` |
Source code in simplegrad/visual/inline_training_graphs.py
def scatter(
    results: dict[str, list[RecordInfo]],
    selected: list[str] | None = None,
    num_cols: int = 2,
    cell_w: int = 8,
    cell_h: int = 5,
    path: Path | None = None,
    color: str | None = None,
):
    """Plot training metrics as scatter charts.

    Draws one subplot per metric, laid out row-major in a grid with
    ``num_cols`` columns. Grid cells that are not used, and cells for
    selected names missing from ``results``, are hidden.

    Args:
        results: Mapping of metric name to list of ``RecordInfo`` data points.
            Each record's ``step`` is plotted on x and its ``value`` on y.
        selected: Subset of metric names to plot. Plots all if None (or empty).
        num_cols: Number of subplot columns. Defaults to 2. Must be >= 1.
        cell_w: Width of each subplot cell in inches. Defaults to 8.
        cell_h: Height of each subplot cell in inches. Defaults to 5.
        path: If provided, save the figure to this path.
        color: Fixed color for all points. If None, each subplot gets a color
            picked at random from a 10-color palette.

    Raises:
        ValueError: If ``num_cols`` is less than 1.
    """
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]

    if num_cols < 1:
        raise ValueError(f"num_cols must be >= 1, got {num_cols}")

    # Determine which metrics to plot
    metrics_to_plot = selected if selected else list(results.keys())
    num_metrics = len(metrics_to_plot)
    if num_metrics == 0:
        # plt.subplots(0, ...) raises, and there is nothing to draw anyway.
        return

    num_rows = (num_metrics + num_cols - 1) // num_cols
    # squeeze=False always returns a 2-D array of Axes, so flatten() is valid
    # for every grid shape. A single metric with num_cols > 1 previously
    # produced an ndarray wrapped in a list, which has no .scatter method.
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(cell_w * num_cols, cell_h * num_rows),
        squeeze=False,
    )
    axes = axes.flatten()

    # Hide the trailing grid cells that no metric will occupy.
    for ax in axes[num_metrics:]:
        ax.axis("off")

    for ax, metric_name in zip(axes, metrics_to_plot):
        if metric_name not in results:
            # Requested metric has no data; hide its cell rather than leave an empty frame.
            ax.axis("off")
            continue

        records = results[metric_name]
        steps = [record.step for record in records]
        values = [record.value for record in records]
        plot_color = color if color else random.choice(colors)

        ax.scatter(steps, values, color=plot_color)
        ax.set_title(metric_name)
        ax.set_xlabel("Step")
        ax.set_ylabel("Value")
        ax.grid(True)

    fig.tight_layout()
    if path:
        fig.savefig(path)
    plt.show()