TorchVista

Sample code (Not editable)

import torch
import torch.nn as nn
from torchvista import trace_model

class DeepMLP(nn.Module):
    """Four-layer perceptron with one additive skip connection.

    Feature widths along the main path are 64 -> 128 -> 128 -> 64 -> 10.
    The activation entering ``fc2`` is added back onto ``fc2``'s activated
    output before it is fed to ``fc3``.
    """

    def __init__(self):
        super().__init__()
        # Layers are built in this fixed order. Each nn.Linear draws its
        # initial weights from torch's global RNG, so the order decides which
        # random values every layer receives under a given seed.
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 64)
        self.out = nn.Linear(64, 10)

    def forward(self, x):
        """Run the network on ``x`` and return raw (pre-activation) scores.

        ``x`` must have 64 features in its last dimension (``fc1`` expects
        ``in_features=64``). The result has 10 features in its last dimension.
        """
        hidden = torch.relu(self.fc1(x))
        # Skip connection: fc2's activated output plus its own input. Both
        # are 128 wide, which is why fc3 takes 128 input features.
        merged = torch.relu(self.fc2(hidden)) + hidden
        return self.out(torch.relu(self.fc3(merged)))

# Build one model instance and one example input, then hand both to
# torchvista's trace_model.
# NOTE: DeepMLP() (its nn.Linear weight init) and torch.randn both draw from
# torch's global RNG, so swapping these two statements changes the values.
model = DeepMLP()
# Shape (1, 64). The last dim matches fc1's in_features=64; the leading 1 is
# presumably a batch of one sample.
example_input = torch.randn(1, 64)

# Presumably traces model(example_input) and renders the interactive graph
# (third-party torchvista; its behavior is not visible here — confirm in docs).
trace_model(model, example_input)

Error Output (if any)
Visualized Interactive Graph