TorchVista

Sample code (Not editable)
import torch
import torch.nn as nn
from torchvista import trace_model

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, mult * dim),
            nn.GELU(),
            nn.Linear(mult * dim, dim),
        )

    def forward(self, x):
        return self.net(x)

class AttentionBlock(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        # inner repetition: three identical FF nets
        self.ff_chain = nn.Sequential(*[FeedForward(dim) for _ in range(3)])
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm(x + attn_out)
        x = self.ff_chain(x)
        return attn_out

class Stage(nn.Module):
    def __init__(self, dim, heads, depth):
        super().__init__()
        # mid-level repetition: N AttentionBlocks
        self.blocks = nn.Sequential(*[AttentionBlock(dim, heads) for _ in range(depth)])

    def forward(self, x):
        return self.blocks(x)

class CompressionTest12(nn.Module):
    """
    Nested repetition structure with three levels:
    - Outer: multiple Stages
    - Mid: each Stage has multiple AttentionBlocks
    - Inner: each AttentionBlock has multiple FeedForward nets

    Tests compression of deeply nested repeated modules.
    """
    def __init__(self, dim=64, heads=4, stage_depth=1, num_stages=2):
        super().__init__()
        self.stages = nn.Sequential(*[Stage(dim, heads, stage_depth) for _ in range(num_stages)])

    def forward(self, x):
        return self.stages(x)

model = CompressionTest12(dim=64, heads=4, stage_depth=1, num_stages=2)
example_input = torch.randn(2, 32, 64)

trace_model(
    model,
    example_input,
    show_compressed_view=True,
    collapse_modules_after_depth=6,
    forced_module_tracing_depth=5,
    show_non_gradient_nodes=True,
)

Error Output (if any)
Visualized Interactive Graph