TorchVista

Sample code (Not editable)
import torch
import torch.nn as nn
from torchvista import trace_model

class UNetPP(nn.Module):
    """Small two-level nested-U-Net-style convolutional network.

    The network has one full-resolution encoder stage (``conv00``), one
    half-resolution stage (``conv10``), and a single fusion stage
    (``conv01``) that consumes the full-resolution features concatenated
    with the upsampled half-resolution features. A 1x1 convolution
    (``final``) maps the fused features to a single output channel.

    Input is expected to be a 4-D tensor with exactly one channel. Its
    height and width should be even: the features are max-pooled by 2 and
    then upsampled by 2, and the result must match the full-resolution
    features spatially for the channel-wise concatenation to succeed.
    """

    def __init__(self):
        super().__init__()
        # Full-resolution stage: 1 -> 16 channels, size-preserving 3x3 conv.
        self.conv00 = nn.Conv2d(
            in_channels=1, out_channels=16, kernel_size=3, padding=1
        )
        # Half-resolution stage: 16 -> 32 channels, size-preserving 3x3 conv.
        self.conv10 = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=3, padding=1
        )
        # Fusion stage: takes the 16 skip channels plus the 32 upsampled ones.
        self.conv01 = nn.Conv2d(
            in_channels=16 + 32, out_channels=16, kernel_size=3, padding=1
        )
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.up = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=True
        )
        # 1x1 projection from the 16 fused channels to a single output channel.
        self.final = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return the single-channel output."""
        # Full-resolution features; reused below as the skip connection.
        full_res = torch.relu(self.conv00(x))

        # Downsample, then extract features at half resolution.
        downsampled = self.pool(full_res)
        half_res = torch.relu(self.conv10(downsampled))

        # Bring the half-resolution features back up and stack them with the
        # skip features along the channel dimension.
        restored = self.up(half_res)
        stacked = torch.cat([full_res, restored], dim=1)

        fused = torch.relu(self.conv01(stacked))
        return self.final(fused)

# Instantiate the network with its default (randomly initialised) weights.
model = UNetPP()

# Random example input: a batch of one single-channel 64x64 tensor.
example_input = torch.randn((1, 1, 64, 64))

# Trace the model on the example input with torchvista.
trace_model(model, example_input)

Error Output (if any)
Visualized Interactive Graph