import torch
import torch.nn as nn
from torchvista import trace_model
class LeNet5(nn.Module):
    """LeNet-5-style convolutional classifier for single-channel images.

    The network has two parts:

    * ``features``: two rounds of 5x5 convolution -> tanh -> 2x2 average
      pooling, going from 1 input channel to 6 and then 16 channels.
    * ``classifier``: flatten, then three linear layers (120 -> 84 ->
      ``num_classes``) with tanh between them.

    The first linear layer expects ``16 * 4 * 4`` flattened features, which
    is what the feature extractor yields for 28x28 inputs (unpadded 5x5
    convolutions and 2x2 pooling: 28 -> 24 -> 12 -> 8 -> 4).

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Layers are instantiated in forward order, one list per stage, and
        # then unpacked into a Sequential so the submodule indices stay
        # positional (features.0, features.3, ...).
        conv_stack = [
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        # 16 channels times a 4x4 spatial map after the second pooling step.
        flat_features = 16 * 4 * 4
        head = [
            nn.Flatten(),
            nn.Linear(in_features=flat_features, out_features=120),
            nn.Tanh(),
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Run ``x`` through the feature extractor, then the classifier."""
        feature_maps = self.features(x)
        logits = self.classifier(feature_maps)
        return logits
# Build the network with its default number of output classes.
model = LeNet5()

# Random example batch: 2 single-channel 28x28 inputs, the spatial size the
# classifier's 16*4*4 flatten dimension corresponds to.
example_input = torch.randn((2, 1, 28, 28))

# Hand the model and the example batch to torchvista's tracer.
trace_model(model, example_input)