TorchVista

An interactive tool for visualizing the forward pass of a PyTorch model directly in the notebook, with a single line of code.

Features

Interactive Graph

Drag and zoom to explore your model architecture with full interactive control.

Collapsible Nodes

Hierarchical modules can be expanded or collapsed to manage complexity in large models.

Notebook First

Works in web-based notebooks such as Jupyter, Colab, and VS Code.

Error Tolerant

Shows a partial visualization when errors occur (e.g., shape mismatches), which makes debugging easier.

Compressed View

Option to compress repeating nodes with identical structure to simplify large models.

Export Options

Export visualization as PNG, SVG, or HTML to showcase your models.

Quick Start

Run from your web-based notebook (Jupyter, Colab, VS Code notebooks, etc.):

Install

pip install torchvista

Define Model

import torch
import torch.nn as nn
from torchvista import trace_model

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** 0.5

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        attn = torch.softmax((q @ k.transpose(-2, -1)) / self.scale, dim=-1)
        return attn @ v

class AttentionClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(32, 64)
        self.attn = SelfAttention(64)
        self.classifier = nn.Linear(64, 5)

    def forward(self, x):
        x = self.embed(x)
        x = self.attn(x)
        return self.classifier(x.mean(dim=1))

model = AttentionClassifier()

Trace Model

example_input = torch.randn(1, 10, 32)
trace_model(model, example_input)

Interactive Visualization

API: trace_model

trace_model(
    model,
    inputs,
    show_non_gradient_nodes=True,
    collapse_modules_after_depth=1,
    forced_module_tracing_depth=None,
    height=800,
    width=None,
    export_format=None,
    show_module_attr_names=False,
    export_path=None,
    show_compressed_view=False,
)

Parameters

Each parameter is tagged with its category: Tracing, Visual, or Export.

model (torch.nn.Module) [Tracing]
Model instance to visualize.
inputs (Any) [Tracing]
Input(s) passed to the model's forward pass; pass a single input, or a tuple for multiple inputs.
show_non_gradient_nodes (bool, default: True) [Visual]
Display nodes for constants and other values outside the gradient graph.
collapse_modules_after_depth (int, default: 1) [Visual]
Depth up to which nested modules are initially expanded; 0 collapses everything. Nodes can still be expanded interactively.
forced_module_tracing_depth (int | None, default: None) [Tracing]
Maximum depth of module internals to trace; None traces only user-defined modules.
height (int, default: 800) [Visual]
Canvas height in pixels.
width (int | str, default: None) [Visual]
Canvas width, in pixels or as a percentage; when omitted, the canvas uses the full available width.
export_format (str | None, default: None) [Export]
Format for exporting the graph as a file: png, svg, or html. When None, the graph is displayed within the notebook.
show_module_attr_names (bool, default: False) [Visual]
Display modules' attribute names, when available, instead of only their class names.
export_path (str | None, default: None) [Export]
Custom file path for the export. Custom export paths are currently supported only for the HTML format.
show_compressed_view (bool, default: False) [Visual] (Experimental)
Compress the graph by grouping repeated nodes of the same type with identical input and output dimensions into a single "repeat" block. This feature currently recognizes repeated nodes only within nn.Sequential and nn.ModuleList. Warning: it may be expensive on large models.

Resources