TorchVista

Sample code (Not editable)

import torch
import torch.nn as nn
from torchvista import trace_model

class DynamicLayerSelector(nn.Module):
    """Parameter-free module that returns a zero tensor sized from its input.

    The output has shape ``(x.size(0), x.size(1))`` and lives on the same
    device as the input; the input's values never reach the output.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return zeros of shape ``(x.size(0), x.size(1))`` on ``x.device``.

        Args:
            x: Tensor whose first two dimension sizes define the output
                shape. It must be addable to a tensor of that shape.

        Returns:
            A newly created zero tensor; ``x`` itself is not returned.
        """
        n_rows, n_cols = x.size(0), x.size(1)
        zeros = torch.zeros(n_rows, n_cols).to(x.device)

        # The sum is only rebound to the local name and then dropped; the
        # addition is kept so the module still executes it on every call.
        x = x + zeros

        return zeros

class MainModule(nn.Module):
    """Top-level wrapper that delegates its forward pass to one submodule."""

    def __init__(self):
        super().__init__()
        # Registered as a child module under the attribute name ``dls``.
        self.dls = DynamicLayerSelector()

    def forward(self, x):
        """Pass ``x`` through ``self.dls`` and return that result as-is."""
        selected = self.dls(x)
        return selected

# Build the top-level module that will be traced.
model = MainModule()

# A 3x4 tensor of ones serves as the example input.
example_input = torch.ones(3,  4)

# Hand the model and the example input to torchvista's trace_model.
# NOTE(review): presumably this renders the interactive graph - confirm
# against the torchvista documentation.
trace_model(model, example_input)

Error Output (if any)
Visualized Interactive Graph