import torch
import torch.nn as nn
from torchvista import trace_model
class DeepMLP(nn.Module):
    """Four-layer MLP with a skip connection around the second layer.

    Layer widths default to 64 -> 128 -> 128 -> 64 -> 10, which reproduces the
    original hard-coded architecture; pass other values to build the same
    topology at a different size. Attribute names (``fc1``, ``fc2``, ``fc3``,
    ``out``) are unchanged, so ``state_dict`` keys stay compatible.

    Args:
        in_features: Size of the last dimension of the input tensor.
        hidden: Output width of both ``fc1`` and ``fc2``. The two must share
            this width because their activations are summed in ``forward``.
        bottleneck: Output width of ``fc3``.
        num_classes: Output width of the final ``out`` layer.
    """

    def __init__(
        self,
        in_features: int = 64,
        hidden: int = 128,
        bottleneck: int = 64,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        # fc2 maps hidden -> hidden so its output can be added to fc1's output.
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, bottleneck)
        self.out = nn.Linear(bottleneck, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x``.

        Args:
            x: Tensor whose last dimension is ``in_features``.

        Returns:
            Tensor whose last dimension is ``num_classes``. No activation is
            applied after the final ``out`` layer.
        """
        x1 = torch.relu(self.fc1(x))
        x2 = torch.relu(self.fc2(x1))
        # Skip connection: fc3 receives fc2's activation plus fc1's activation.
        x3 = torch.relu(self.fc3(x2 + x1))
        return self.out(x3)
# Build the demo model and one random example input, then hand both to
# torchvista. Statement order matters: DeepMLP() (via nn.Linear parameter
# initialization) and torch.randn both draw from torch's global RNG, so
# swapping these lines would change the generated values.
model = DeepMLP()
# Derive the feature width from the model instead of repeating the literal 64,
# so the example input stays valid if the model's input size is changed.
example_input = torch.randn(1, model.fc1.in_features)  # a batch of one sample
# NOTE(review): trace_model presumably runs example_input through the model
# and renders its computation graph — confirm against the torchvista docs.
trace_model(model, example_input)