torchrecorder
torchrecorder is a Python package that can be used to record the execution graph of a torch.nn.Module and use it to render a visualization of the network structure via graphviz.
Licensed under MIT License.
Installation
Requirements:
- Python 3.6+
- PyTorch v1.3 or greater (the cpu version)
- The Graphviz library and the graphviz Python package
Install via pip:
$ pip install torchrecorder
Simple Example
import sys
import torch
import torchrecorder
class SampleNet(torch.nn.Module):
    """Small example network with two parallel linear branches.

    Both branches read the same input; their outputs are concatenated
    along dimension 1 and passed through a final linear layer and a ReLU.
    """

    def __init__(self):
        """Create the three linear layers and the ReLU activation."""
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features=3, out_features=3, bias=True)
        self.linear_2 = torch.nn.Linear(in_features=3, out_features=3, bias=True)
        # linear_3 takes 6 features: the 3 + 3 outputs of the two branches
        # after concatenation in forward().
        self.linear_3 = torch.nn.Linear(in_features=6, out_features=1, bias=True)
        self.my_special_relu = torch.nn.ReLU()

    def forward(self, inputs):
        """Run both branches on ``inputs`` and return the activated result.

        Args:
            inputs: Tensor whose last dimension has 3 features (as required
                by ``linear_1`` and ``linear_2``) and that has a dimension 1
                to concatenate along.

        Returns:
            The output of ``my_special_relu`` applied to ``linear_3``.
        """
        x = self.linear_1(inputs)
        y = self.linear_2(inputs)
        z = torch.cat([x, y], dim=1)
        z = self.my_special_relu(self.linear_3(z))
        return z
def main():
    """Build a ``SampleNet`` and pass it to ``torchrecorder.render_network``.

    The render depth is read from the first command-line argument and must
    be parseable as an integer.

    Raises:
        IndexError: If no command-line argument is given.
        ValueError: If the argument is not a valid integer.
    """
    i = int(sys.argv[1])
    net = SampleNet()
    torchrecorder.render_network(
        net,
        name="Sample Net",
        input_shapes=(1, 3),
        directory="./",
        fmt="svg",
        render_depth=i,
    )
|
|
---|---|