import torch
import torch.nn as nn
from torchvista import trace_model
class CompressionTest3(nn.Module):
    """Stack of identical blocks, each a chain of square ``nn.Linear`` layers.

    The module holds ``num_blocks`` ``nn.Sequential`` containers in an
    ``nn.ModuleList``; every container holds ``layers_per_block``
    ``nn.Linear(features, features)`` layers.  With the default arguments
    this is a 4 x 4 stack of ``nn.Linear(64, 64)`` layers, registered under
    the names ``layers.<block>.<layer>``.

    Args:
        num_blocks: Number of ``nn.Sequential`` blocks in ``self.layers``.
        layers_per_block: Number of ``nn.Linear`` layers inside each block.
        features: Input and output width of every ``nn.Linear`` layer.
    """

    def __init__(
        self,
        *,
        num_blocks: int = 4,
        layers_per_block: int = 4,
        features: int = 64,
    ) -> None:
        super().__init__()
        # Layers are instantiated block by block, first to last, so the
        # order in which parameters are initialised matches a hand-written
        # nested literal of the same shape.
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    *[
                        nn.Linear(features, features)
                        for _ in range(layers_per_block)
                    ]
                )
                for _ in range(num_blocks)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through every block in order and return the result.

        Args:
            x: Tensor whose last dimension equals ``features``.

        Returns:
            The output of the final block; same shape as ``x``.
        """
        for block in self.layers:
            x = block(x)
        return x
# Build the model first, then the sample input, so random-number consumption
# happens in the same order on every run (parameters, then input).
model = CompressionTest3()

# Example batch: 2 rows of 64 values, drawn from a standard normal.
example_input = torch.randn(2, 64)

# Trace one forward pass with torchvista, requesting its compressed view.
trace_model(
    model,
    example_input,
    show_compressed_view=True,
)