TorchVista

Sample code (Not editable)
import torch
import torch.nn as nn
from torchvista import trace_model

class MLP(nn.Module):
    """Three ``Linear(dim, dim)`` layers with a ReLU between consecutive layers.

    Input and output share the same trailing feature size ``dim``.
    """

    def __init__(self, dim=128):
        super().__init__()
        # Two (Linear, ReLU) pairs, then a final Linear with no activation.
        # Layers are created in the same order as before, so parameter
        # initialisation and state_dict indices (net.0 / net.2 / net.4) match.
        layers = []
        for _ in range(2):
            layers += [nn.Linear(dim, dim), nn.ReLU()]
        self.net = nn.Sequential(*layers, nn.Linear(dim, dim))

    def forward(self, x):
        out = self.net(x)
        return out

class AttentionBlock(nn.Module):
    """Self-attention followed by an MLP, each wrapped in a residual connection.

    The attention residual is normalised with LayerNorm (post-norm). The MLP
    residual is added without a further norm.

    Args:
        dim: embedding size; ``nn.MultiheadAttention`` requires it to be
            divisible by ``nhead``.
        nhead: number of attention heads.
    """

    def __init__(self, dim=128, nhead=4):
        super().__init__()
        # batch_first=True: inputs are laid out as (batch, seq, dim).
        self.attn = nn.MultiheadAttention(dim, nhead, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.mlp = MLP(dim)

    def forward(self, x):
        # Only the attention output is used. Tell the layer not to compute
        # (and average) the attention-weight matrix; its second return is None.
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm(x + attn_out)
        return x + self.mlp(x)

class EncoderStage(nn.Module):
    """``num_blocks`` AttentionBlocks in sequence, then a Linear + LayerNorm projection.

    Despite its name, ``downsample`` maps ``dim -> dim`` on the last axis, so
    the output shape equals the input shape.
    """

    def __init__(self, dim=128, nhead=4, num_blocks=4):
        super().__init__()
        attention_layers = [AttentionBlock(dim, nhead) for _ in range(num_blocks)]
        self.blocks = nn.ModuleList(attention_layers)
        # Same-width projection. The attribute name is kept for state_dict compatibility.
        self.downsample = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim))

    def forward(self, x):
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.downsample(hidden)

class DecoderStage(nn.Module):
    """A Linear + LayerNorm projection, then ``num_blocks`` AttentionBlocks in sequence.

    Despite its name, ``upsample`` maps ``dim -> dim`` on the last axis, so
    the output shape equals the input shape.
    """

    def __init__(self, dim=128, nhead=4, num_blocks=4):
        super().__init__()
        # Built before the blocks to keep the original parameter-initialisation order.
        self.upsample = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim))
        attention_layers = [AttentionBlock(dim, nhead) for _ in range(num_blocks)]
        self.blocks = nn.ModuleList(attention_layers)

    def forward(self, x):
        hidden = self.upsample(x)
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden

class CompressionTest7(nn.Module):
    """Encoder stages -> middle AttentionBlocks -> decoder stages -> MLP head.

    The model is deep and built from many identical submodules, presumably to
    exercise torchvista's compressed view (TODO confirm).

    Args:
        dim: feature size used by every layer; the input's last axis must equal it.
        nhead: attention heads per AttentionBlock.
        num_stages: number of EncoderStage modules and of DecoderStage modules.
        blocks_per_stage: AttentionBlocks inside each encoder/decoder stage.
        num_middle_blocks: AttentionBlocks between encoder and decoder
            (previously hard-coded to 6; the default keeps that behaviour).
    """

    def __init__(self, dim=128, nhead=4, num_stages=3, blocks_per_stage=4,
                 num_middle_blocks=6):
        super().__init__()
        # Submodules are created in the original order (encoder, middle,
        # decoder, head) so default-configured parameter init is unchanged.
        self.encoder = nn.ModuleList(
            [EncoderStage(dim, nhead, blocks_per_stage) for _ in range(num_stages)]
        )
        self.middle = nn.ModuleList(
            [AttentionBlock(dim, nhead) for _ in range(num_middle_blocks)]
        )
        self.decoder = nn.ModuleList(
            [DecoderStage(dim, nhead, blocks_per_stage) for _ in range(num_stages)]
        )
        self.head = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        # Run every encoder stage, middle block and decoder stage in order,
        # then project with the head.
        for stage in self.encoder:
            x = stage(x)
        for block in self.middle:
            x = block(x)
        for stage in self.decoder:
            x = stage(x)
        return self.head(x)

# Feature width shared by the model and the example input, so the two cannot drift apart.
DIM = 128
model = CompressionTest7(dim=DIM)
example_input = torch.randn(2, 16, DIM)  # (batch, seq_len, dim)

# Render the interactive graph; collapse_modules_after_depth and
# show_compressed_view are torchvista display options.
trace_model(model, example_input, collapse_modules_after_depth=0, show_compressed_view=True)

Error Output (if any)
Visualized Interactive Graph