Tutorial 2

Exploring Nested Modules

Here we look at a slightly more complex model with nested modules. The tracing code stays the same as before, but this time, try clicking the "+" button on the nested modules (block1, block2, and the Sequential inside each block) to see what lies inside.

Code

import torch
import torch.nn as nn
from torchvista import trace_model

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # This Sequential can be expanded in the visualization
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x + self.conv(x))

class SimpleResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 16, 3, padding=1)
        # These nested modules can be expanded too
        self.block1 = ResidualBlock(16)
        self.block2 = ResidualBlock(16)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.stem(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool(x).squeeze(-1).squeeze(-1)
        return self.fc(x)

model = SimpleResNet()
example_input = torch.randn(1, 3, 32, 32)

trace_model(model, example_input)
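If you want a text-only view of the same nesting, plain PyTorch can list it without TorchVista. The sketch below uses the standard nn.Module.named_children() method and recurses into each child, indenting one level per nesting depth; module_tree is a hypothetical helper written for this tutorial, not part of TorchVista or PyTorch. The printed tree should roughly correspond to what the "+" buttons reveal, though TorchVista's own display may differ in detail.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    # Same block as in the tutorial code above
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x + self.conv(x))


class SimpleResNet(nn.Module):
    # Same model as in the tutorial code above
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 16, 3, padding=1)
        self.block1 = ResidualBlock(16)
        self.block2 = ResidualBlock(16)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.stem(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool(x).squeeze(-1).squeeze(-1)
        return self.fc(x)


def module_tree(module, depth=0):
    """Hypothetical helper: one indented line per child module, recursing into nested children."""
    lines = []
    for name, child in module.named_children():
        lines.append("  " * depth + f"{name}: {type(child).__name__}")
        lines.extend(module_tree(child, depth + 1))
    return lines


model = SimpleResNet()
tree = module_tree(model)
for line in tree:
    print(line)

# Sanity check: a forward pass still produces one row of 10 logits
out = model(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 10])
```

The Sequential children print with the names 0 to 4 because nn.Sequential registers its layers under their positional indices.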

Interactive Visualization

Now you've seen how to expand and collapse nested modules in the visualization.

But you might have noticed that built-in modules you didn't define yourself, such as Conv2d and BatchNorm2d, cannot be expanded. This is because, by default, TorchVista only traces modules defined in your code, to avoid cluttering the visualization with low-level details.
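A related property you can check with plain PyTorch: the layers that stay collapsed (Conv2d, BatchNorm2d, ReLU) have no child modules at all, while Sequential and ResidualBlock do. The sketch below splits a block's submodules into containers and leaves using only standard nn.Module introspection (named_modules() and children()). containers_and_leaves is a hypothetical helper, and this is an illustration of the model's structure, not a description of how TorchVista decides what to expand.

```python
import torch.nn as nn


class ResidualBlock(nn.Module):
    # Same block as in the tutorial code above
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x + self.conv(x))


def containers_and_leaves(model):
    """Hypothetical helper: split submodules into those with child modules and those without."""
    containers, leaves = {}, {}
    for name, module in model.named_modules():
        if not name:  # named_modules() reports the root module with an empty name
            continue
        has_children = next(module.children(), None) is not None
        target = containers if has_children else leaves
        target[name] = type(module).__name__
    return containers, leaves


containers, leaves = containers_and_leaves(ResidualBlock(16))
print(containers)  # {'conv': 'Sequential'}
print(leaves)
```

The computation inside a leaf such as Conv2d happens through tensor operations in its forward method (for Conv2d, a call to torch.nn.functional.conv2d) rather than through further submodules, so there is no module-level structure left to expand there.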

Next, we will learn how to overcome this by forcing a deeper tracing depth.