TorchVista
An interactive tool to visualize the forward pass of a PyTorch model directly in the notebook, with a single line of code.
Features
Interactive Graph
Drag and zoom to explore your model architecture with full interactive control.
Collapsible Nodes
Hierarchical modules can be expanded or collapsed to manage complexity in large models.
Notebook First
Works in web-based notebooks such as Jupyter, Colab, and VS Code.
Error Tolerant
Shows a partial visualization when errors occur (e.g., shape mismatches), which makes debugging easier.
Compressed View
Option to compress repeating nodes with identical structure to simplify large models.
Export Options
Export the visualization as PNG, SVG, or HTML to showcase your models.
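The Compressed View option groups repeating nodes that share the same structure. The following is a rough illustration of the idea, not torchvista's actual implementation. It run-length encodes a flat list of layers and treats two nodes as identical when their type and input/output dims match. The `Node` type, the `compress` helper, and the layer list are hypothetical.

```python
from itertools import groupby
from typing import NamedTuple

class Node(NamedTuple):
    """Hypothetical flattened view of a layer: its type plus input/output dims."""
    kind: str
    in_dim: int
    out_dim: int

def compress(nodes):
    """Collapse runs of consecutive identical nodes into (node, repeat_count) pairs."""
    # groupby without a key groups consecutive elements that compare equal;
    # NamedTuples compare field by field, so type and dims must all match.
    return [(node, len(list(run))) for node, run in groupby(nodes)]

layers = [
    Node("Linear", 32, 64),
    Node("TransformerBlock", 64, 64),
    Node("TransformerBlock", 64, 64),
    Node("TransformerBlock", 64, 64),
    Node("Linear", 64, 5),
]

for node, count in compress(layers):
    label = f"{node.kind} ({node.in_dim} -> {node.out_dim})"
    print(f"{label} x{count}" if count > 1 else label)
```

Only consecutive repeats are merged here. Identical layers separated by a different layer stay as separate nodes.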
Quick Start
Run the following in a web-based notebook (Jupyter, Colab, VS Code notebooks, etc.):
Install
pip install torchvista
Define Model
import torch
import torch.nn as nn
from torchvista import trace_model

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** 0.5

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        attn = torch.softmax((q @ k.transpose(-2, -1)) / self.scale, dim=-1)
        return attn @ v

class AttentionClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(32, 64)
        self.attn = SelfAttention(64)
        self.classifier = nn.Linear(64, 5)

    def forward(self, x):
        x = self.embed(x)
        x = self.attn(x)
        return self.classifier(x.mean(dim=1))

model = AttentionClassifier()
Trace Model
example_input = torch.randn(1, 10, 32)
trace_model(model, example_input)
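The Quick Start model takes a single tensor, but the `inputs` parameter of `trace_model` also accepts a tuple. The sketch below assumes that torchvista unpacks a tuple into the positional arguments of `forward`, the same way `model(*inputs)` does. `MaskedPooling` is a hypothetical two-input model made up for illustration. Checking the forward pass in plain PyTorch first separates model bugs from tracing issues.

```python
import torch
import torch.nn as nn

class MaskedPooling(nn.Module):
    """Hypothetical two-input model: projects tokens, then averages unmasked positions."""
    def __init__(self, in_dim=32, out_dim=5):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x, mask):
        h = self.proj(x)                  # (batch, seq, out_dim)
        m = mask.unsqueeze(-1).float()    # (batch, seq, 1), 1.0 where a token is kept
        # Masked mean over the sequence; clamp avoids division by zero for empty masks.
        return (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)

model = MaskedPooling()
x = torch.randn(1, 10, 32)
mask = torch.ones(1, 10, dtype=torch.bool)
example_inputs = (x, mask)

# Sanity check: the tuple unpacks into forward(x, mask).
out = model(*example_inputs)
print(out.shape)  # torch.Size([1, 5])

# In a notebook, the same tuple would then be passed as `inputs`:
# from torchvista import trace_model
# trace_model(model, example_inputs)
```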
API: trace_model
trace_model(
    model,
    inputs,
    show_non_gradient_nodes=True,
    collapse_modules_after_depth=1,
    forced_module_tracing_depth=None,
    height=800,
    width=None,
    export_format=None,
    show_module_attr_names=False,
    export_path=None,
    show_compressed_view=False,
)
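Both `collapse_modules_after_depth` and `forced_module_tracing_depth` refer to module nesting depth, but this page does not say exactly how depth is counted. The sketch below makes two assumptions. First, the top-level model's direct children sit at depth 1. Second, a "user-defined" module is one whose class is not defined inside the `torch` package. It lists each submodule of a hypothetical `Net` with its depth and kind, using the standard `named_modules()` traversal. The `module_depths` and `is_builtin` helpers are illustrative, not part of torchvista.

```python
import torch.nn as nn

class Block(nn.Module):
    """Hypothetical user-defined block with a built-in Sequential inside it."""
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, x):
        return x + self.ff(x)  # residual connection around the feed-forward stack

class Net(nn.Module):
    """Hypothetical top-level model: built-in layers around one user-defined Block."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(32, 64)
        self.block = Block(64)
        self.head = nn.Linear(64, 5)

    def forward(self, x):
        return self.head(self.block(self.embed(x)))

def module_depths(model):
    """Map each submodule's dotted name to its depth, counting direct children as 1."""
    return {name: name.count(".") + 1 for name, _ in model.named_modules() if name}

def is_builtin(module):
    """Heuristic: a class defined inside the torch package counts as built-in."""
    return type(module).__module__.startswith("torch.")

net = Net()
depths = module_depths(net)
for name, module in net.named_modules():
    if not name:
        continue  # skip the root module, whose name is the empty string
    kind = "built-in" if is_builtin(module) else "user-defined"
    indent = "  " * (depths[name] - 1)
    print(f"{indent}{name}: {type(module).__name__}, depth {depths[name]}, {kind}")
```

Under these assumptions, the default `collapse_modules_after_depth=1` would initially show `embed`, `block`, and `head`, with `block` collapsed. The default `forced_module_tracing_depth=None` would presumably descend into `Block` but treat the built-in `Sequential`, `Linear`, and `ReLU` modules as leaves.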
Parameters
- model (Tracing): The model instance to visualize.
- inputs (Tracing): Input(s) forwarded into the model; pass a single input or a tuple of inputs.
- show_non_gradient_nodes (Visual): Display nodes for constants and other values outside the gradient graph.
- collapse_modules_after_depth (Visual): Depth to which nested modules are initially expanded; 0 collapses everything. Collapsed nodes can still be expanded interactively.
- forced_module_tracing_depth (Tracing): Maximum depth of module internals to trace; None traces only user-defined modules.
- height (Visual): Canvas height in pixels.
- width (Visual): Canvas width, in pixels or as a percentage; defaults to the full available width when omitted.
- export_format (Export): Optional export format (png, svg, or html) for saving the graph as a file. When omitted, the graph is shown within the notebook.
- show_module_attr_names (Visual): Display the attribute names of modules, when available, instead of just their class names.
- export_path (Export): Custom path for the exported file. Custom export paths are currently supported only for the HTML format.
- show_compressed_view (Visual, experimental): Compress the graph by showing repeating nodes of the same type with identical input and output dims as single "repeat" blocks. This feature currently recognises repeating nodes only within Sequential and ModuleList. Warning: this feature may be expensive on large models.

Resources
- Quick Google Colab Tutorial – Get started with an interactive tutorial
- Check out demos – Explore examples ranging from simple linear models to complex transformers
- GitHub Repository – Source code, issues, and contributions