User Guide
torchrecorder is pure Python 3 code; it does not contain any C modules.
Installation
Requirements:
- PyTorch v1.3 or greater (only the cpu version is required)
- The Graphviz library and its Python interface
Install via pip and PyPI:
$ pip install torchrecorder
Install via pip and the GitHub repo:
$ pip install git+https://github.com/ahgamut/torchrecorder/
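After installing, you may want to confirm that both requirements are met. The script below is a sketch that is not part of torchrecorder. It assumes Python 3.8+ (for importlib.metadata), the PyPI distribution names torch and graphviz, and that the Graphviz system library puts a dot executable on your PATH. The hypothetical major_minor helper strips local version tags such as "+cpu", which the cpu builds of PyTorch carry.

```python
# Sketch (not part of torchrecorder): check the requirements listed above.
# Assumes Python 3.8+, the PyPI names "torch" and "graphviz", and that the
# Graphviz system library installs a `dot` executable on PATH.
import re
import shutil
from importlib.metadata import PackageNotFoundError, version


def major_minor(version_string):
    """Return (major, minor) from strings such as "1.13.1+cpu" or "2.0.0"."""
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError("unrecognized version: {!r}".format(version_string))
    return int(match.group(1)), int(match.group(2))


def check_requirements(min_torch=(1, 3)):
    """Map each requirement to True (found) or False (missing or too old)."""
    torch_key = "torch>={}.{}".format(*min_torch)
    results = {}
    try:
        results[torch_key] = major_minor(version("torch")) >= min_torch
    except PackageNotFoundError:
        results[torch_key] = False
    try:
        version("graphviz")  # the Python interface to Graphviz
        results["graphviz (Python)"] = True
    except PackageNotFoundError:
        results["graphviz (Python)"] = False
    # the Graphviz library itself ships the `dot` layout program
    results["graphviz (dot)"] = shutil.which("dot") is not None
    return results


if __name__ == "__main__":
    for requirement, ok in check_requirements().items():
        print("{}: {}".format(requirement, "ok" if ok else "missing"))
```

The version check compares (major, minor) tuples, so "1.13.1+cpu" correctly counts as newer than 1.3, which a plain string comparison would get wrong.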
Examples
The default way to use torchrecorder is the render_network wrapper function, which records a network and renders it in a single call.
```python
import sys

import torch

import torchrecorder


class SampleNet(torch.nn.Module):
    def __init__(self):
        torch.nn.Module.__init__(self)
        self.linear_1 = torch.nn.Linear(in_features=3, out_features=3, bias=True)
        self.linear_2 = torch.nn.Linear(in_features=3, out_features=3, bias=True)
        self.linear_3 = torch.nn.Linear(in_features=6, out_features=1, bias=True)
        self.my_special_relu = torch.nn.ReLU()

    def forward(self, inputs):
        x = self.linear_1(inputs)
        y = self.linear_2(inputs)
        # concatenate the two (1, 3) outputs into a (1, 6) tensor
        z = torch.cat([x, y], dim=1)
        z = self.my_special_relu(self.linear_3(z))
        return z


def main():
    # render depth is read from the command line, e.g. `python sample.py 2`
    i = int(sys.argv[1])
    net = SampleNet()
    torchrecorder.render_network(
        net,
        name="Sample Net",
        input_shapes=(1, 3),
        directory="./",
        fmt="svg",
        render_depth=i,
    )


if __name__ == "__main__":
    main()
```
The render_network function calls record and make_dot internally, so the call to render_network in the example above can be split into those two steps, as below. This gives you direct access to the graphviz Digraph, which you can modify before rendering it.
```python
def main2():
    i = int(sys.argv[1])
    net = SampleNet()
    # equivalent to calling render_network
    rec = torchrecorder.record(net, name="Sample Net", input_shapes=(1, 3))
    g = torchrecorder.make_dot(rec, render_depth=i)
    # g is a graphviz.Digraph object, so it can be modified before rendering
    g.format = "svg"
    g.attr(label="{} at depth={}".format("Sample Net", i))
    g.render("{}-{}".format("Sample Net", i), directory="./", cleanup=True)
```
Styling graphviz attributes
To change the default styling attributes of every node, you can pass any number of graphviz-related attributes [1] as keyword arguments to render_network (or make_dot).
The example below sets Lato as the default font.
```python
def main():
    net = SampleNet()
    rec = torchrecorder.record(net, name="Sample Net", input_shapes=(1, 3))
    # fontname="Lato" is passed through as a graphviz attribute for every node
    g = torchrecorder.make_dot(rec, render_depth=1, fontname="Lato")
    g.format = "svg"
    g.attr(label="Font Change via styler_args")
    g.render("{}-{}".format("StyleArgs", 1), directory="./", cleanup=True)
```
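Conceptually, each keyword argument overrides a per-node default while leaving the other defaults alone. The hypothetical helper below is not torchrecorder's code, and its default values are invented for illustration; it only shows that layering with plain dictionaries. Values are converted to strings on the assumption that graphviz attributes are written out as text.

```python
# Hypothetical sketch: layering keyword arguments such as fontname="Lato"
# over per-node defaults. These defaults are invented for illustration and
# are NOT torchrecorder's actual defaults.
DEFAULT_NODE_ATTRS = {"shape": "box", "style": "filled", "fontname": "Times-Roman"}


def node_attributes(defaults, **styler_args):
    """Return a new dict: defaults first, then user-supplied overrides."""
    attrs = dict(defaults)     # copy so the shared defaults stay untouched
    attrs.update(styler_args)  # user keyword arguments win on conflicts
    # graphviz attribute values are text, so stringify everything
    return {key: str(value) for key, value in attrs.items()}
```

For example, node_attributes(DEFAULT_NODE_ATTRS, fontname="Lato", penwidth=2) keeps shape and style, replaces the font, and adds penwidth="2".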
Custom Styler Objects
If the default styling of node shapes and colors is not sufficient, you can create a subclass of GraphvizStyler and pass it to make_dot via the styler_cls argument. The subclass needs to accept graphviz attributes as keyword arguments and override the style_node and style_edge methods.
In the example below, I construct a styler subclass that shows the kernel size, padding, and stride of Conv2d objects, draws thick orange edges out of Conv2d objects, and draws thick blue edges into ReLU objects:
```python
import torch

import torchrecorder

# assumed module path for GraphvizStyler; check your installed version
from torchrecorder.renderer.styler import GraphvizStyler


class ConvSample(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels=1, out_channels=5, kernel_size=5, stride=2, padding=2
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=5, out_channels=5, kernel_size=3, stride=1, padding=1
        )
        self.relu = torch.nn.ReLU(inplace=False)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = x1 + x2  # conv2 preserves the shape of x1, so the sum is valid
        return self.relu(x3)


class MyStyler(GraphvizStyler):
    def style_node(self, node):
        # start from the default attributes, then add Conv2d details
        default = super().style_node(node)
        if isinstance(node.fn, torch.nn.Conv2d):
            params = {}
            params["kernel_size"] = node.fn.kernel_size
            params["padding"] = node.fn.padding
            params["stride"] = node.fn.stride
            default["label"] = (
                node.name
                + "\n("
                + ",\n".join("{}={}".format(k, v) for k, v in params.items())
                + ")"
            )
            default["penwidth"] = "2.4"
        return default

    def style_edge(self, fnode, tnode):
        # thick orange edges out of Conv2d layers
        if isinstance(fnode.fn, torch.nn.Conv2d) and isinstance(tnode.fn, torch.Tensor):
            return {"penwidth": "4.8", "color": "#ee8800"}
        # thick blue edges into ReLU layers
        elif isinstance(tnode.fn, torch.nn.ReLU) and isinstance(fnode.fn, torch.Tensor):
            return {"penwidth": "4.8", "color": "#00228f"}
        else:
            return super().style_edge(fnode, tnode)


def main():
    net = ConvSample()
    rec = torchrecorder.record(net, name="ConvSample", input_shapes=(1, 1, 10, 10))
    g = torchrecorder.make_dot(rec, render_depth=1, styler_cls=MyStyler, fontname="Lato")
    g.format = "svg"
    g.attr(label="Custom Styler Class")
    g.render("{}-{}".format("CustomStyler", 1), directory="./", cleanup=True)
```
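The styler contract (keyword arguments in, one attribute dict per node and per edge out) can be tried without PyTorch or Graphviz. The classes below are hypothetical stand-ins, not torchrecorder's API: SketchNode mimics the name and fn fields that MyStyler reads, SketchStyler stores the keyword arguments and returns default dicts, and ThickOutputStyler overrides both hooks while falling back to super() the same way MyStyler does.

```python
# Hypothetical stand-ins for the styler pattern; these are NOT
# torchrecorder classes and only mirror the shape of the contract.
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class SketchNode:
    name: str
    fn: Any  # the object the node wraps, e.g. a module or a tensor


class SketchStyler:
    def __init__(self, **styler_args):
        # attributes applied to every node, e.g. fontname="Lato"
        self.styler_args = styler_args

    def style_node(self, node) -> Dict[str, str]:
        attrs = {"label": node.name, "shape": "box"}
        attrs.update(self.styler_args)
        return attrs

    def style_edge(self, fnode, tnode) -> Dict[str, str]:
        return {}  # no special edge styling by default


class ThickOutputStyler(SketchStyler):
    """Overrides both hooks, deferring to the parent for the defaults."""

    def style_node(self, node):
        attrs = super().style_node(node)
        if node.name.startswith("output"):
            attrs["penwidth"] = "2.4"
        return attrs

    def style_edge(self, fnode, tnode):
        if tnode.name.startswith("output"):
            return {"color": "#ee8800"}
        return super().style_edge(fnode, tnode)
```

Calling super() first in style_node keeps every default (including the keyword arguments) and only layers changes on top, which is why MyStyler can add a label without losing the font setting.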
Rendering into different formats
Currently, the torchrecorder package only renders into graphviz objects, but rendering can be extended to other formats by subclassing the BaseRenderer class, the same way GraphvizRenderer does.
The source code of GraphvizRenderer shows how this subclassing is done.
If you create a subclass of BaseRenderer for a new rendering format, submit a pull request!
I’ve been trying to render in a SigmaJS-compatible format, but haven’t been able to.
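The general pattern is a base class that walks the recorded graph and calls format-specific hooks, with each output format supplying its own hooks. The sketch below shows that pattern with a JSON nodes/edges document, a layout many browser graph libraries consume. Every name here is hypothetical: the real BaseRenderer has its own interface, so read its source before porting anything. Whether SigmaJS accepts this output directly is not checked; it likely also needs layout fields such as node coordinates.

```python
# Hypothetical sketch of the renderer-subclassing pattern. None of these
# names come from torchrecorder; the real BaseRenderer interface differs.
import json
from abc import ABC, abstractmethod


class SketchBaseRenderer(ABC):
    def render(self, nodes, edges):
        """nodes: list of node names; edges: list of (source, target) pairs."""
        self.begin()
        for name in nodes:
            self.add_node(name)
        for source, target in edges:
            self.add_edge(source, target)
        return self.finish()

    @abstractmethod
    def begin(self): ...

    @abstractmethod
    def add_node(self, name): ...

    @abstractmethod
    def add_edge(self, source, target): ...

    @abstractmethod
    def finish(self): ...


class JSONRenderer(SketchBaseRenderer):
    def begin(self):
        self.doc = {"nodes": [], "edges": []}

    def add_node(self, name):
        self.doc["nodes"].append({"id": name, "label": name})

    def add_edge(self, source, target):
        edge_id = "e{}".format(len(self.doc["edges"]))  # sequential edge ids
        self.doc["edges"].append({"id": edge_id, "source": source, "target": target})

    def finish(self):
        return json.dumps(self.doc, indent=2)
```

Keeping the traversal in the base class means a new format only has to say how to write one node and one edge.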
[1] The list of graphviz node attributes can be seen at https://graphviz.gitlab.io/_pages/doc/info/attrs.html