torchrecorder

Documentation: Read the Docs (rtd)

torchrecorder is a Python package that records the execution graph of a torch.nn.Module and uses it to render a visualization of the network structure via graphviz.

Licensed under MIT License.

Installation

Requirements: PyTorch (for the torch.nn.Module being recorded) and graphviz (for rendering).

Install via pip:

$ pip install torchrecorder

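Simple Example

The script below defines a small network and renders it with torchrecorder; `main` reads the render depth from the first command-line argument.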

import sys
import torch
import torchrecorder


class SampleNet(torch.nn.Module):
    """Small demo network with two parallel linear branches.

    ``linear_1`` and ``linear_2`` each map 3 input features to 3 output
    features. Their outputs are concatenated (6 features), passed through
    ``linear_3`` (6 -> 1), and finally through a ReLU.
    """

    def __init__(self):
        super().__init__()

        self.linear_1 = torch.nn.Linear(in_features=3, out_features=3, bias=True)
        self.linear_2 = torch.nn.Linear(in_features=3, out_features=3, bias=True)
        self.linear_3 = torch.nn.Linear(in_features=6, out_features=1, bias=True)
        self.my_special_relu = torch.nn.ReLU()

    def forward(self, inputs):
        """Run both branches on ``inputs`` and merge them.

        Args:
            inputs: Tensor expected to be 2-D ``(batch, 3)``. The branch
                outputs are concatenated on ``dim=1``, which must be the
                feature axis for ``linear_3`` (6 inputs) to apply.

        Returns:
            Tensor of shape ``(batch, 1)`` with non-negative values (ReLU).
        """
        x = self.linear_1(inputs)
        y = self.linear_2(inputs)
        # Join the two 3-feature branches into a single 6-feature tensor.
        z = torch.cat([x, y], dim=1)
        z = self.my_special_relu(self.linear_3(z))
        return z


def main(argv=None):
    """Render ``SampleNet`` as an SVG graph into the current directory.

    Args:
        argv: Command-line arguments, excluding the program name. The first
            item, if present, is the depth passed to ``render_network`` as
            ``render_depth``. Defaults to ``sys.argv[1:]``; when no depth is
            supplied, 1 is used.
    """
    if argv is None:
        argv = sys.argv[1:]
    # Fall back to depth 1 instead of crashing with IndexError when the
    # script is run without arguments.
    depth = int(argv[0]) if argv else 1
    net = SampleNet()
    torchrecorder.render_network(
        net,
        name="Sample Net",
        input_shapes=(1, 3),
        directory="./",
        fmt="svg",
        render_depth=depth,
    )


# Without this guard the example defines main() but never runs it.
if __name__ == "__main__":
    main()

Output with render_depth = 1:

img1

Output with render_depth = 2:

img2

Contents