TorchVista

Sample code (Not editable)

import torch
import torch.nn as nn
from torchvista import trace_model

class SimpleMultiPathModule(nn.Module):
    """Small demo module mixing a linear layer with container-held parameters.

    Attributes:
        linear1: ``nn.Linear(in_features, out_features)`` applied to ``x``.
        path_weights: ``nn.ParameterDict`` with a single scalar entry
            ``'path1_weight'``; registered on the module but not read by
            ``forward``.
        global_biases: ``nn.ParameterList`` of two one-element parameters,
            both added to the output.
    """

    def __init__(self, in_features=10, out_features=5):
        super().__init__()

        self.linear1 = nn.Linear(in_features, out_features)

        self.path_weights = nn.ParameterDict({
            'path1_weight': nn.Parameter(torch.tensor(0.5)),
        })

        self.global_biases = nn.ParameterList([
            nn.Parameter(torch.randn(1)),
            nn.Parameter(torch.randn(1)),
        ])

    def forward(self, x, x2, x3):
        """Return ``5 + b0 + b1 + linear1(x) + linear1(x)[-5]``.

        Args:
            x: Input fed to ``linear1``; also combined with ``x2`` and ``x3``.
            x2: Tensor added to ``x``.
            x3: Tensor whose ReLU scales ``x + x2``.

        Returns:
            Tensor with ``x.size(0)`` rows and ``linear1.out_features`` columns.

        Raises:
            IndexError: If ``x`` has fewer than 5 rows, because row ``-5`` of
                the linear output is indexed.
        """
        output1 = self.linear1(x)

        # Out-of-place on purpose: ``x`` was saved by ``linear1`` for the
        # backward pass, so mutating it would both change the caller's tensor
        # and invalidate the saved value autograd needs.
        mixed = x + x2
        mixed = mixed * torch.relu(x3)

        # ``mixed`` only supplies the batch size and device of the accumulator.
        results = torch.zeros(
            mixed.size(0), self.linear1.out_features, device=mixed.device
        )

        bias = 5 + self.global_biases[0] + self.global_biases[1] + output1[:, :]
        # ``output1[-5]`` is one row, broadcast across every row of the result.
        results = results + bias + output1[-5]

        return results

# Build the module first so its parameter initialisation draws from the RNG
# before the example tensors do.
model = SimpleMultiPathModule()

# Three independent (3, 10) inputs, one per positional argument of forward().
# NOTE: the batch size of 3 is smaller than the 5 rows that forward()'s
# `output1[-5]` lookup needs, which matches the error recorded on this page.
x, y, z = (torch.randn(3, 10) for _ in range(3))
example_input = (x, y, z)


trace_model(model, example_input)

Error Output (if any)
index -5 is out of bounds for dimension 0 with size 3
Visualized Interactive Graph