torchrecorder

torchrecorder is a Python package that can be used to record the execution graph of a torch.nn.Module and use it to render a visualization of the network structure via graphviz.
Licensed under MIT License.
Installation
Requirements:
- Python 3.6+
- PyTorch v1.3 or greater (the CPU version is sufficient)
- The Graphviz library and the graphviz Python package
Install via pip:
$ pip install torchrecorder
Simple Example

```python
import sys

import torch
import torchrecorder


class SampleNet(torch.nn.Module):
    def __init__(self):
        torch.nn.Module.__init__(self)
        self.linear_1 = torch.nn.Linear(in_features=3, out_features=3, bias=True)
        self.linear_2 = torch.nn.Linear(in_features=3, out_features=3, bias=True)
        self.linear_3 = torch.nn.Linear(in_features=6, out_features=1, bias=True)
        self.my_special_relu = torch.nn.ReLU()

    def forward(self, inputs):
        x = self.linear_1(inputs)
        y = self.linear_2(inputs)
        # concatenate the two 3-feature outputs into 6 features for linear_3
        z = torch.cat([x, y], dim=1)
        z = self.my_special_relu(self.linear_3(z))
        return z


def main():
    # the render depth is read from the command line, e.g. `python sample.py 2`
    i = int(sys.argv[1])
    net = SampleNet()
    torchrecorder.render_network(
        net,
        name="Sample Net",
        input_shapes=(1, 3),
        directory="./",
        fmt="svg",
        render_depth=i,
    )


if __name__ == "__main__":
    main()
```
Contents
User Guide
torchrecorder is pure Python 3 code; it does not contain any C modules.
Installation
Requirements:
- PyTorch v1.3 or greater (only the CPU version is required)
- The Graphviz library and its Python interface
Install via pip from PyPI:
$ pip install torchrecorder
Install via pip from the GitHub repo:
$ pip install git+https://github.com/ahgamut/torchrecorder/
Examples
The default usage is via the render_network wrapper function, as in the Simple Example above.
The render_network function calls record and make_dot, so the call to render_network in the above example could be written as below, which allows modifications to the Digraph before it is rendered to a file.

```python
# continues the Simple Example (uses sys, torchrecorder and SampleNet from there)
def main2():
    i = int(sys.argv[1])
    net = SampleNet()
    # equivalent to calling render_network
    rec = torchrecorder.record(net, name="Sample Net", input_shapes=(1, 3))
    g = torchrecorder.make_dot(rec, render_depth=i)
    # g is a graphviz.Digraph object, so it can be modified before rendering
    g.format = "svg"
    g.attr(label="{} at depth={}".format("Sample Net", i))
    g.render("{}-{}".format("Sample Net", i), directory="./", cleanup=True)
```
Styling graphviz attributes
To change the default styling attributes of every node, you can pass any number of graphviz-related attributes [1] as keyword arguments to render_network (or make_dot).
The example below sets Lato as the default font.
```python
def main():
    net = SampleNet()
    rec = torchrecorder.record(net, name="Sample Net", input_shapes=(1, 3))
    # fontname is passed through **styler_args and applied to every node
    g = torchrecorder.make_dot(rec, render_depth=1, fontname="Lato")
    g.format = "svg"
    g.attr(label="Font Change via styler_args")
    g.render("{}-{}".format("StyleArgs", 1), directory="./", cleanup=True)
```
Custom Styler Objects
If the default styling of node shapes/colors is not sufficient, you can create a subclass of GraphvizStyler and pass it to make_dot via the styler_cls argument. The subclass needs to accept graphviz attributes as keyword arguments and override the style_node and style_edge methods.
In the example below, I construct a styler subclass that shows some parameters of Conv2d objects, draws orange edges out of Conv2d objects, and draws blue edges into ReLU objects:
```python
import torch
import torchrecorder
from torchrecorder.renderer import GraphvizStyler


class ConvSample(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels=1, out_channels=5, kernel_size=5, stride=2, padding=2
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=5, out_channels=5, kernel_size=3, stride=1, padding=1
        )
        self.relu = torch.nn.ReLU(inplace=False)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = x1 + x2
        return self.relu(x3)


class MyStyler(GraphvizStyler):
    def style_node(self, node):
        # start from the default attributes, then customize Conv2d layers
        default = super().style_node(node)
        if isinstance(node.fn, torch.nn.Conv2d):
            params = {}
            params["kernel_size"] = node.fn.kernel_size
            params["padding"] = node.fn.padding
            params["stride"] = node.fn.stride
            # label: the layer name followed by its parameters, one per line
            default["label"] = (
                node.name
                + "\n("
                + ",\n".join("{}={}".format(k, v) for k, v in params.items())
                + ")"
            )
            default["penwidth"] = "2.4"
        return default

    def style_edge(self, fnode, tnode):
        # thick orange edges from Conv2d layers into their output tensors
        if isinstance(fnode.fn, torch.nn.Conv2d) and isinstance(tnode.fn, torch.Tensor):
            return {"penwidth": "4.8", "color": "#ee8800"}
        # thick blue edges from tensors into ReLU layers
        elif isinstance(tnode.fn, torch.nn.ReLU) and isinstance(fnode.fn, torch.Tensor):
            return {"penwidth": "4.8", "color": "#00228f"}
        else:
            return super().style_edge(fnode, tnode)


def main():
    net = ConvSample()
    rec = torchrecorder.record(net, name="ConvSample", input_shapes=(1, 1, 10, 10))
    g = torchrecorder.make_dot(rec, render_depth=1, styler_cls=MyStyler, fontname="Lato")
    g.format = "svg"
    g.attr(label="Custom Styler Class")
    g.render("{}-{}".format("CustomStyler", 1), directory="./", cleanup=True)


if __name__ == "__main__":
    main()
```
Rendering into different formats
Currently, the torchrecorder package only provides rendering into graphviz objects, but the rendering functionality can be extended by subclassing the BaseRenderer class in a manner similar to the GraphvizRenderer class.
You can read the source code to see how the subclassing can be done.
If you create a subclass of BaseRenderer for a new rendering format, submit a pull request!
I’ve been trying to render in a SigmaJS-compatible format, but haven’t been able to.
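The render_node/render_edge split that GraphvizRenderer uses can be mimicked for other outputs. Below is a minimal standalone sketch, not a BaseRenderer subclass (BaseRenderer's constructor and internals are not documented on this page), that renders a hand-written node/edge list into a SigmaJS-style {"nodes": [...], "edges": [...]} dictionary. The class name SigmaJSONRenderer and the exact JSON layout are assumptions for illustration, not a verified SigmaJS schema.

```python
import json


class SigmaJSONRenderer:
    """Hypothetical renderer (not part of torchrecorder).

    Mirrors only the render_node / render_edge split of GraphvizRenderer,
    writing into a plain dict instead of a graphviz.Digraph.
    """

    def __init__(self, nodes, edges):
        # nodes: mapping of node id -> label; edges: list of (from_id, to_id)
        self.nodes = nodes
        self.edges = edges

    def render_node(self, g, node_id):
        g["nodes"].append({"id": node_id, "label": self.nodes[node_id]})

    def render_edge(self, g, fnode, tnode):
        # edge ids are numbered in insertion order: e0, e1, ...
        g["edges"].append(
            {"id": "e{}".format(len(g["edges"])), "source": fnode, "target": tnode}
        )

    def render(self):
        g = {"nodes": [], "edges": []}
        for node_id in self.nodes:
            self.render_node(g, node_id)
        for fnode, tnode in self.edges:
            self.render_edge(g, fnode, tnode)
        return g


renderer = SigmaJSONRenderer(
    nodes={"inputs": "inputs", "linear_1": "Linear", "output": "output"},
    edges=[("inputs", "linear_1"), ("linear_1", "output")],
)
graph = renderer.render()
print(json.dumps(graph, indent=2))
```

A real subclass would walk the nodes and edges stored in a Recorder instead of a hand-written mapping.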
[1] The list of graphviz node attributes can be seen at https://graphviz.gitlab.io/_pages/doc/info/attrs.html
API Reference
Convenience Functions
torchrecorder.render_network(net, name, input_shapes, directory, filename=None, fmt='svg', input_data=None, render_depth=1, **styler_args)
Render the structure of a torch.nn.Module to an image via graphviz.
Parameters:
- net (torch.nn.Module) – the network to render
- name (str) – name of the network
- input_shapes (None, tuple or list(tuple)) – a tuple if net has a single input, a list(tuple) if it has multiple inputs, or None if input_data is provided
- directory (str) – directory to store the rendered image
- fmt (str, optional) – image format
- input_data (torch.Tensor or tuple(torch.Tensor), optional) – if net requires normalized inputs, provide them here instead of setting input_shapes.
- render_depth (int, optional) – depth until which nodes should be rendered. Default 1.
- **styler_args – node attributes to pass to graphviz
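The three documented forms of input_shapes (shared with record) can be illustrated with a small normalizing helper. normalize_input_shapes is hypothetical, not part of torchrecorder, and it assumes that the list(tuple) form means one shape per input of net.

```python
def normalize_input_shapes(input_shapes, input_data=None):
    """Hypothetical helper: map the documented input_shapes forms to a list.

    - tuple        -> single input, wrapped in a one-element list
    - list(tuple)  -> assumed to be one shape per input, returned as a list
    - None         -> only valid when input_data is provided
    """
    if input_shapes is None:
        if input_data is None:
            raise ValueError("provide input_shapes or input_data")
        return None  # shapes come from the provided input_data instead
    if isinstance(input_shapes, tuple):
        return [input_shapes]
    if isinstance(input_shapes, list) and all(isinstance(s, tuple) for s in input_shapes):
        return list(input_shapes)
    raise TypeError("input_shapes must be a tuple, a list of tuples, or None")


print(normalize_input_shapes((1, 3)))            # [(1, 3)]
print(normalize_input_shapes([(1, 3), (1, 6)]))  # [(1, 3), (1, 6)]
```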
torchrecorder.record(net, name, input_shapes, input_data=None)
Record the graph by running a single pass of a torch.nn.Module.
Parameters:
- net (torch.nn.Module) – the network to record
- name (str) – name of the network
- input_shapes (None, tuple or list(tuple)) – a tuple if net has a single input, a list(tuple) if it has multiple inputs, or None if input_data is provided
- input_data (torch.Tensor or tuple(torch.Tensor), optional) – if net requires normalized inputs, provide them here instead of setting input_shapes.
Returns:
a Recorder object containing the execution graph
torchrecorder.make_dot(rec, render_depth=256, styler_cls=None, **styler_args)
Produce a Graphviz representation from a Recorder object.
Parameters:
- rec (Recorder) – the recorded execution graph
- render_depth (int) – depth until which nodes should be rendered
- styler_cls – styler class to instantiate when styling nodes. If None, defaults to GraphvizStyler.
Kwargs:
- styler_args (optional) – styler properties to be set for all nodes
Returns:
a graphviz.Digraph with the rendered nodes
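How render_depth decides what gets drawn can be sketched with a toy layer tree. This assumes that a layer is expanded into its subnets only when its depth is below render_depth, with the network itself at depth 0 (matching add_node's default depth=0); the Layer class and visible_layers function are hypothetical stand-ins, not torchrecorder code.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Layer:
    # hypothetical stand-in for a LayerNode: a name, a depth, and subnets
    name: str
    depth: int
    subnets: List["Layer"] = field(default_factory=list)


def visible_layers(layer, render_depth):
    """Return the names of the layers drawn as single boxes.

    Assumption: a layer whose depth is below render_depth is expanded into
    its subnets; otherwise it is drawn as one collapsed box.
    """
    if layer.depth < render_depth and layer.subnets:
        out = []
        for sub in layer.subnets:
            out.extend(visible_layers(sub, render_depth))
        return out
    return [layer.name]


block = Layer("block", 1, [Layer("block.conv", 2), Layer("block.relu", 2)])
net = Layer("net", 0, [Layer("linear", 1), block])

print(visible_layers(net, 1))  # ['linear', 'block']
print(visible_layers(net, 2))  # ['linear', 'block.conv', 'block.relu']
```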
Custom graphviz styling
class torchrecorder.renderer.GraphvizStyler(**styler_args)
Bases: object
Provide styling options before rendering to graphviz.
The style_node and style_edge methods read the properties of BaseNode objects, so the overriding methods in any subclass of GraphvizStyler need to do the same.
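The styling contract (keyword arguments applied to every node, and a style_node that returns a dict of graphviz attributes which a subclass extends via super(), as MyStyler does in the Custom Styler Objects example) can be sketched without graphviz. MiniStyler, its per-kind defaults, the string kind argument, and the rule that styler_args override the defaults are all hypothetical simplifications; the real GraphvizStyler.style_node takes a BaseNode.

```python
class MiniStyler:
    """Hypothetical, simplified stand-in for GraphvizStyler."""

    # per-kind defaults are invented for this sketch
    KIND_DEFAULTS = {
        "tensor": {"shape": "box"},
        "layer": {"shape": "box3d"},
        "op": {"shape": "ellipse"},
    }

    def __init__(self, **styler_args):
        # e.g. fontname="Lato": applied to every node
        self.styler_args = styler_args

    def style_node(self, kind, name):
        attrs = dict(self.KIND_DEFAULTS.get(kind, {}))
        attrs.update(self.styler_args)  # assumed: styler_args win over defaults
        attrs["label"] = name
        return attrs

    def style_edge(self, from_kind, to_kind):
        return {}


class ThickLayers(MiniStyler):
    def style_node(self, kind, name):
        # same pattern as MyStyler: start from the parent's dict, then adjust
        attrs = super().style_node(kind, name)
        if kind == "layer":
            attrs["penwidth"] = "2.4"
        return attrs


styler = ThickLayers(fontname="Lato")
print(styler.style_node("layer", "conv1"))
# {'shape': 'box3d', 'fontname': 'Lato', 'label': 'conv1', 'penwidth': '2.4'}
```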
class torchrecorder.nodes.TensorNode(name='', fn=None, depth=-1, parent=None)
Bases: torchrecorder.nodes.BaseNode
Node to encapsulate a torch.Tensor.
Attributes:
- fn (torch.Tensor) – the encapsulated tensor
class torchrecorder.nodes.OpNode(name='', fn=None, depth=-1, parent=None)
Bases: torchrecorder.nodes.BaseNode
Node to encapsulate an Op, the grad_fn attribute of a torch.Tensor.
Attributes:
- fn – the encapsulated grad_fn
class torchrecorder.nodes.LayerNode(name='', fn=None, depth=-1, parent=None)
Bases: torchrecorder.nodes.BaseNode
Node to encapsulate a torch.nn.Module.
Attributes:
- fn (torch.nn.Module) – the encapsulated module
class torchrecorder.nodes.ParamNode(name='', fn=None, depth=-1, parent=None)
Bases: torchrecorder.nodes.TensorNode
Node to encapsulate a torch.nn.Parameter.
Attributes:
- fn (torch.nn.Parameter) – the encapsulated parameter
- parent (torch.nn.Module) – a Module whose parameters contain fn
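Because ParamNode derives from TensorNode, an isinstance(node, TensorNode) check in a custom styler will also match parameters. The plain-Python mirror of the documented hierarchy below shows this; the dataclass fields copy the documented constructor signature, but the classes are stand-ins, not the library's implementation, and the string values are placeholders for real tensors and modules.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class BaseNode:
    # fields copy the documented constructor: (name='', fn=None, depth=-1, parent=None)
    name: str = ""
    fn: Any = None
    depth: int = -1
    parent: Optional[Any] = None


class TensorNode(BaseNode):
    """Stand-in: wraps a tensor."""


class OpNode(BaseNode):
    """Stand-in: wraps a grad_fn."""


class LayerNode(BaseNode):
    """Stand-in: wraps a module."""


class ParamNode(TensorNode):
    """Stand-in: wraps a parameter; parent is the module that owns it."""


p = ParamNode(name="linear_1.weight", fn="<weight>", depth=1, parent="<linear_1>")
print(isinstance(p, TensorNode), isinstance(p, LayerNode))  # True False
```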
Custom Rendering
If you are creating a new format to render information from a Recorder, you would need to override the following methods of BaseRenderer in a subclass, as done in GraphvizRenderer:
class torchrecorder.renderer.GraphvizRenderer(rec, render_depth=256, styler_cls=None, **styler_args)
Bases: torchrecorder.renderer.base.BaseRenderer
Render information from a Recorder into a graphviz.Digraph.
Attributes:
- styler (class) – GraphvizStyler or a subclass
render_node(g, node)
Render a node in graphviz.
Renders node into the Digraph g, after applying appropriate styling. If node is a LayerNode, checks render_depth to see if its subnets have to be rendered.
Parameters:
- g (graphviz.Digraph)
- node (BaseNode)
render_recursive_node(g, node)
Render a LayerNode and its subnets.
Parameters:
- g (graphviz.Digraph)
- node (LayerNode) – has a depth greater than render_depth
The node is rendered as a separate Digraph and then is added as a graphviz.Digraph.subgraph to g.
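The separate-Digraph-as-subgraph approach can be pictured in raw DOT text: each expanded layer becomes a nested subgraph, and Graphviz draws a subgraph as a labelled box only when its name begins with "cluster". The to_dot function and the dict layout for layers are hypothetical; with the graphviz Python package the equivalent step is the Digraph.subgraph call described above.

```python
def to_dot(layer, indent=1):
    """Hypothetical sketch: emit a layer dict as nested DOT clusters.

    layer = {"name": str, "nodes": [str, ...], "subnets": [layer, ...]}
    """
    pad = "  " * indent
    lines = []
    for sub in layer["subnets"]:
        # the "cluster" prefix makes Graphviz draw a box around the subnet
        lines.append('{}subgraph "cluster_{}" {{'.format(pad, sub["name"]))
        lines.append('{}  label="{}";'.format(pad, sub["name"]))
        lines.extend(to_dot(sub, indent + 1))
        lines.append(pad + "}")
    for node in layer["nodes"]:
        lines.append('{}"{}";'.format(pad, node))
    return lines


net = {
    "name": "net",
    "nodes": ["inputs", "output"],
    "subnets": [{"name": "block", "nodes": ["conv", "relu"], "subnets": []}],
}
dot = "digraph {\n" + "\n".join(to_dot(net)) + "\n}"
print(dot)
```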
render_edge(g, fnode, tnode)
Render an edge in graphviz.
Parameters:
- g (graphviz.Digraph)
- fnode (BaseNode)
- tnode (BaseNode)
Custom Recording
Subclassing Recorder should be unnecessary in most cases.
class torchrecorder.recorder.Recorder
Bases: object
Record and store execution graph information.
add_node(net, depth=0, parent=None, name=None)
Construct a node of the recording graph.
Construct a BaseNode that will store information related to net as the neural network is run.
add_dummy(dummy, fn)
Point to an existing node to assist recording.
Instead of creating a separate node, the dummy object is used to point to an existing node containing fn. Used for dummy ops and AccumulateGradient ops (see leaf_dummy).
Parameters:
- dummy – a dummy torch.Tensor or op that should not be recorded
- fn – a recorded object that will be connected to further ops
add_edge(_from, _to)
Construct an edge of the recording graph.
Records an edge between two fn objects to be used while rendering. This will be used along with the nodes dictionary to map edges properly.
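The bookkeeping described for add_node, add_dummy and add_edge can be sketched with two plain containers: a nodes dictionary keyed by id(fn), so each object is recorded only once, and a list of (from, to) pairs that is resolved through that dictionary at render time. MiniRecorder and its simplified method signatures are hypothetical; the real add_node takes (net, depth, parent, name).

```python
class MiniRecorder:
    """Hypothetical sketch of Recorder-style bookkeeping (not the real class)."""

    def __init__(self):
        self.nodes = {}  # id(fn) -> node name
        self.edges = []  # (id(_from), id(_to)) pairs, resolved at render time

    def add_node(self, fn, name):
        # setdefault: an object that is already recorded keeps its node
        self.nodes.setdefault(id(fn), name)

    def add_dummy(self, dummy, fn):
        # no new node: the dummy resolves to the node that holds fn
        self.nodes[id(dummy)] = self.nodes[id(fn)]

    def add_edge(self, _from, _to):
        self.edges.append((id(_from), id(_to)))

    def resolved_edges(self):
        # map each stored pair back to node names via the nodes dictionary
        return [(self.nodes[a], self.nodes[b]) for a, b in self.edges]


rec = MiniRecorder()
x, x_plus_zero, layer = object(), object(), object()
rec.add_node(x, "inputs")
rec.add_node(layer, "linear_1")
rec.add_dummy(x_plus_zero, x)  # e.g. the "+ 0" copy made by leaf_dummy
rec.add_edge(x_plus_zero, layer)
print(rec.resolved_edges())  # [('inputs', 'linear_1')]
```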
register_hooks(net, depth=0, parent=None, name=None)
Register the hooks of the Recorder recursively on a torch.nn.Module.
The hooks registered are partial versions of prehook and posthook corresponding to each node.
Parameters:
- net (Module)
- depth (int)
- parent (torch.nn.Module) – the parent of net
- name (str) – name of net
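The partial versions of the hooks are needed because PyTorch calls a forward pre-hook as hook(module, input) and a forward hook as hook(module, input, output), while prehook and posthook also need rec and node. functools.partial binds those extra arguments. The sketch uses stub hook bodies, a plain list in place of a Recorder, and a string in place of a module; presumably the real register_hooks passes such partials to torch.nn.Module.register_forward_pre_hook and register_forward_hook, but that is an assumption.

```python
from functools import partial


def prehook(module, inputs, rec, node):
    # same parameter order as torchrecorder.recorder.prehook; stub body
    rec.append((node, "pre", inputs))
    return inputs


def posthook(module, inputs, outputs, rec, node):
    # same parameter order as torchrecorder.recorder.posthook; stub body
    rec.append((node, "post", outputs))
    return outputs


rec = []  # a plain list standing in for a Recorder
pre = partial(prehook, rec=rec, node="linear_1")
post = partial(posthook, rec=rec, node="linear_1")

# PyTorch-style calls: partial fills in rec and node
result = pre("<module>", (1.0, 2.0))
post("<module>", (1.0, 2.0), (3.0,))
print(rec)  # [('linear_1', 'pre', (1.0, 2.0)), ('linear_1', 'post', (3.0,))]
```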
torchrecorder.recorder.op_acc(gf, rec, node)
Operator Accumulator.
Creates an OpNode to record the newly-performed operation gf, if not already recorded. If gf is an initialization op (AccumulateGradient), then points gf to its connected torch.Tensor instead of creating an OpNode. Otherwise, recursively checks all operations that are connected to gf and adds them if necessary.
Parameters:
- gf – current operation, a grad_fn object obtained from a torch.Tensor
- rec – a Recorder object whose nodes are updated
- node – LayerNode whose fn the current operation is a part of
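The recursive walk in op_acc can be sketched with stand-in grad_fn objects. In PyTorch, grad_fn.next_functions is a tuple of (function, index) pairs (the function is None for inputs that do not require gradients), and the leaf op (what this page calls AccumulateGradient; PyTorch names the class AccumulateGrad) exposes its tensor as .variable. FakeFn and op_acc_sketch are hypothetical, and the op names are illustrative.

```python
class FakeFn:
    """Hypothetical stand-in for a grad_fn object."""

    def __init__(self, name, inputs=(), variable=None):
        self.name = name
        # like grad_fn.next_functions: (fn_or_None, index) pairs
        self.next_functions = tuple((fn, 0) for fn in inputs)
        # like AccumulateGrad.variable: set only for leaf ops
        self.variable = variable


def op_acc_sketch(gf, recorded, edges):
    """Record gf and, recursively, every op that feeds into it, once each."""
    if gf is None or id(gf) in recorded:
        return
    if gf.variable is not None:
        # leaf op: point at its tensor instead of creating an op node
        recorded[id(gf)] = "tensor:" + gf.variable
        return
    recorded[id(gf)] = "op:" + gf.name
    for prev, _ in gf.next_functions:
        if prev is not None:
            op_acc_sketch(prev, recorded, edges)
            edges.append((recorded[id(prev)], recorded[id(gf)]))


x = FakeFn("AccumulateGrad", variable="inputs")
w = FakeFn("AccumulateGrad", variable="linear_1.weight")
mm = FakeFn("MmBackward0", inputs=(x, w))
relu = FakeFn("ReluBackward0", inputs=(mm, None))  # None: input without grad

recorded, edges = {}, []
op_acc_sketch(relu, recorded, edges)
for edge in edges:
    print(edge)
```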
torchrecorder.recorder.tensor_acc(tensor, rec, node)
Tensor Accumulator.
Creates a TensorNode to record the newly-created tensor, if not already recorded. Note that the resulting TensorNode has the same parent as node, because the tensor is the output of/input to node.fn.
Parameters:
- tensor – a torch.Tensor
- rec – a Recorder object whose nodes are updated
- node – a LayerNode whose fn outputs/inputs tensor
torchrecorder.recorder.param_acc(param, rec, node)
Parameter Accumulator.
Creates a ParamNode to record the parameter param of node.fn, if not already recorded. Note that node.fn is the parent of param.
torchrecorder.recorder.leaf_dummy(tensor, rec)
Performs a dummy operation (adding 0) on a leaf Tensor.
This ensures that the (possibly in-place) operations performed on tensor hereafter can be correctly mapped. The dummy tensor (and operation) are not recorded separately; they merely point to the original tensor.
Parameters:
- tensor – a newly-formed leaf torch.Tensor
- rec – the Recorder object whose nodes are updated
Returns:
tensor after adding 0
torchrecorder.recorder.prehook(module, inputs, rec, node)
Hook to record BEFORE the given module is run.
Records parameters contained in module, then checks each tensor in inputs for any operations that may have run after the end of the previous module. The inputs are then converted to leaf tensors and recorded before being passed off to the module.
Parameters:
- module – a torch.nn.Module
- inputs – a torch.Tensor or a tuple of torch.Tensors
- rec – a Recorder object for global information
- node (LayerNode) – node.fn is module.
Returns:
the leaf-equivalent of inputs.
torchrecorder.recorder.posthook(module, inputs, outputs, rec, node)
Hook to record AFTER the given module has run and returned.
Records any operations that may have run as part of module, then checks if each tensor in the outputs has already been recorded by a submodule of the current module (the submodule’s posthook would execute first!). If necessary, the outputs are converted to leaf tensors to record operations afresh.
Parameters:
- module – a torch.nn.Module
- inputs – a torch.Tensor or a tuple of torch.Tensors
- outputs – a torch.Tensor or a tuple of torch.Tensors
- rec – a Recorder object for global information
- node (LayerNode) – node.fn is module.
Returns:
the leaf-equivalent of outputs.
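The note that a submodule's posthook executes first follows from PyTorch's hook order: the outer module's pre-hook fires, then the inner module's pre- and post-hooks, and the outer post-hook fires last, after the inner module has already seen the output. The FakeModule class below is a hypothetical stand-in that reproduces that order; its posthook simply keeps the first module that recorded an output, rather than converting outputs to leaf tensors as the real posthook may.

```python
order = []     # hook calls, in the order they happen
recorded = {}  # id(output) -> name of the first module that recorded it


def posthook(module, inputs, outputs):
    order.append("post:" + module.name)
    # a submodule's posthook has already run, so keep its claim if present
    recorded.setdefault(id(outputs), module.name)
    return outputs


class FakeModule:
    """Hypothetical stand-in reproducing PyTorch's forward-hook order."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

    def __call__(self, x):
        order.append("pre:" + self.name)  # forward pre-hook
        out = x
        for child in self.children:
            out = child(out)
        if not self.children:
            out = {"made_by": self.name}  # a leaf module creates a new output
        return posthook(self, x, out)  # forward hook


inner = FakeModule("inner")
outer = FakeModule("outer", [inner])
y = outer({"input": 0})
print(order)            # ['pre:outer', 'pre:inner', 'post:inner', 'post:outer']
print(recorded[id(y)])  # inner
```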
License
MIT License
Copyright (c) 2019-2020 Gautham Venkatasubramanian
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.