User Guide

torchrecorder is written in pure Python 3; it does not contain any C modules.

Installation

Requirements:

Install via pip and PyPI:

$ pip install torchrecorder

Install via pip and the GitHub repo:

$ pip install git+https://github.com/ahgamut/torchrecorder/
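
To confirm which version pip installed, you can query the package metadata from Python. This is a minimal sketch that uses only the standard library (importlib.metadata, available from Python 3.8); the helper name installed_version is hypothetical and not part of torchrecorder.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        # pip has not installed a distribution with this name
        return None


if __name__ == "__main__":
    v = installed_version("torchrecorder")
    print("torchrecorder is not installed" if v is None else "torchrecorder " + v)
```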

Examples

The default way to use torchrecorder is the render_network wrapper function, which records a network and renders it in a single call.

import sys
import torch
import torchrecorder


class SampleNet(torch.nn.Module):
    def __init__(self):
        torch.nn.Module.__init__(self)

        self.linear_1 = torch.nn.Linear(in_features=3, out_features=3, bias=True)
        self.linear_2 = torch.nn.Linear(in_features=3, out_features=3, bias=True)
        self.linear_3 = torch.nn.Linear(in_features=6, out_features=1, bias=True)
        self.my_special_relu = torch.nn.ReLU()

    def forward(self, inputs):
        x = self.linear_1(inputs)
        y = self.linear_2(inputs)
        z = torch.cat([x, y], dim=1)
        z = self.my_special_relu(self.linear_3(z))
        return z


def main():
    i = int(sys.argv[1])
    net = SampleNet()
    torchrecorder.render_network(
        net,
        name="Sample Net",
        input_shapes=(1, 3),
        directory="./",
        fmt="svg",
        render_depth=i,
    )


if __name__ == "__main__":
    main()

[Figure: sample]
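
The example reads the render depth from sys.argv[1], so running it without an argument raises an IndexError. If you would rather fall back to a default depth, a small argparse-based helper could replace that line. This is only a sketch: the helper name parse_args and the default depth of 1 are assumptions, not part of torchrecorder.

```python
import argparse


def parse_args(argv=None):
    """Parse the render depth for the example script (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="Render SampleNet with torchrecorder")
    # Optional positional argument; falls back to depth 1 (an assumed default) when omitted.
    parser.add_argument(
        "depth",
        nargs="?",
        type=int,
        default=1,
        help="render_depth passed to torchrecorder.render_network",
    )
    # argv=None makes argparse read sys.argv[1:]
    return parser.parse_args(argv)
```

Inside main(), i = parse_args().depth would then stand in for i = int(sys.argv[1]). A non-integer argument makes argparse print a usage message and exit, instead of raising a bare ValueError.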

The render_network function calls record and make_dot, so the call to render_network in the above example could be written as below. Splitting the call this way lets you modify the Digraph before rendering it.

def main2():
    i = int(sys.argv[1])
    net = SampleNet()
    # equivalent to calling render_network
    rec = torchrecorder.record(net, name="Sample Net", input_shapes=(1, 3))
    g = torchrecorder.make_dot(rec, render_depth=i)
    # g is graphviz.Digraph object
    g.format = "svg"
    g.attr(label="{} at depth={}".format("Sample Net", i))
    g.render("{}-{}".format("Sample Net", i), directory="./", cleanup=True)

Styling graphviz attributes

To change the default styling attributes of every node, you can pass any number of graphviz-related attributes [1] as keyword arguments to render_network (or make_dot). The example below sets Lato as the default font.

def main():
    net = SampleNet()
    rec = torchrecorder.record(net, name="Sample Net", input_shapes=(1, 3))
    g = torchrecorder.make_dot(rec, render_depth=1, fontname="Lato")
    g.format = "svg"
    g.attr(label="Font Change via styler_args")
    g.render("{}-{}".format("StyleArgs", 1), directory="./", cleanup=True)

[Figure: styler1]
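
Keyword arguments such as fontname="Lato" ultimately become attributes in the DOT source that graphviz reads. The sketch below is not torchrecorder's code; it only illustrates the idea of laying user-supplied values over a set of defaults and serializing the result as a DOT attribute list. The function names merge_attrs and dot_attr_list and the DEFAULT_NODE_ATTRS values are hypothetical.

```python
def merge_attrs(defaults, **overrides):
    """Return a new attribute dict in which user-supplied values replace the defaults."""
    merged = dict(defaults)
    # Graphviz attribute values are strings, so convert overrides (e.g. 2.4 -> "2.4").
    merged.update({k: str(v) for k, v in overrides.items()})
    return merged


def dot_attr_list(attrs):
    """Serialize attributes as a DOT attribute list, e.g. [fontname="Lato", shape="box"]."""

    def quote(value):
        # Escape embedded double quotes so the DOT string literal stays well-formed.
        return '"' + str(value).replace('"', '\\"') + '"'

    # Sort keys so the output is deterministic.
    return "[" + ", ".join(f"{k}={quote(v)}" for k, v in sorted(attrs.items())) + "]"


# Hypothetical defaults standing in for whatever a styler uses internally.
DEFAULT_NODE_ATTRS = {"shape": "box", "fontname": "Times-Roman"}

node_attrs = merge_attrs(DEFAULT_NODE_ATTRS, fontname="Lato")
print("node " + dot_attr_list(node_attrs))
```

Running it prints node [fontname="Lato", shape="box"], a DOT attribute statement that applies those values to every node declared after it in the graph.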

Custom Styler Objects

If the default styling of node shapes and colors is not sufficient, you can create a subclass of GraphvizStyler and pass it to make_dot via the styler_cls argument. The subclass needs to accept graphviz attributes as keyword arguments and to override the style_node and style_edge methods.

In the example below, I construct a styler subclass that shows some parameters of Conv2d objects in their labels (with a thicker border), draws orange edges out of Conv2d objects, and draws blue edges into ReLU objects:


class ConvSample(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels=1, out_channels=5, kernel_size=5, stride=2, padding=2
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=5, out_channels=5, kernel_size=3, stride=1, padding=1
        )
        self.relu = torch.nn.ReLU(inplace=False)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = x1 + x2
        return self.relu(x3)

class MyStyler(GraphvizStyler):
    def style_node(self, node):
        default = super().style_node(node)
        if isinstance(node.fn, torch.nn.Conv2d):
            params = {}
            params["kernel_size"] = node.fn.kernel_size
            params["padding"] = node.fn.padding
            params["stride"] = node.fn.stride
            default["label"] = (
                node.name
                + "\n("
                + ",\n".join("{}={}".format(k, v) for k, v in params.items())
                + ")"
            )
            default["penwidth"] = "2.4"
        return default

    def style_edge(self, fnode, tnode):
        if isinstance(fnode.fn, torch.nn.Conv2d) and isinstance(tnode.fn, torch.Tensor):
            return {"penwidth": "4.8", "color": "#ee8800"}
        elif isinstance(tnode.fn, torch.nn.ReLU) and isinstance(fnode.fn, torch.Tensor):
            return {"penwidth": "4.8", "color": "#00228f"}
        else:
            return super().style_edge(fnode, tnode)


def main():
    net = ConvSample()
    rec = torchrecorder.record(net, name="ConvSample", input_shapes=(1, 1, 10, 10))
    g = torchrecorder.make_dot(rec, render_depth=1, styler_cls=MyStyler, fontname="Lato")
    g.format = "svg"
    g.attr(label="Custom Styler Class")
    g.render("{}-{}".format("CustomStyler", 1), directory="./", cleanup=True)


if __name__ == "__main__":
    main()


[Figure: styler2]
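
If you want to try out the style_node/style_edge override logic without torch or graphviz installed, the same pattern can be exercised with stand-in classes. Everything here (BaseStylerSketch, Node, and the Conv2d/ReLU/Tensor stand-ins) is hypothetical and only mimics the shape of the MyStyler example; the real GraphvizStyler and its default attributes will differ.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for torch.nn.Conv2d, torch.nn.ReLU and torch.Tensor.
@dataclass
class Conv2dStandIn:
    kernel_size: tuple
    padding: tuple
    stride: tuple


class ReLUStandIn:
    pass


class TensorStandIn:
    pass


@dataclass
class Node:
    name: str
    fn: object


class BaseStylerSketch:
    """Hypothetical base class: plain defaults for every node and edge."""

    def style_node(self, node):
        return {"label": node.name, "shape": "box"}

    def style_edge(self, fnode, tnode):
        return {}


class MyStylerSketch(BaseStylerSketch):
    def style_node(self, node):
        # Start from the base styling, then customize Conv2d nodes only.
        attrs = super().style_node(node)
        if isinstance(node.fn, Conv2dStandIn):
            params = {
                "kernel_size": node.fn.kernel_size,
                "padding": node.fn.padding,
                "stride": node.fn.stride,
            }
            attrs["label"] = (
                node.name + "\n(" + ",\n".join(f"{k}={v}" for k, v in params.items()) + ")"
            )
            attrs["penwidth"] = "2.4"
        return attrs

    def style_edge(self, fnode, tnode):
        # Orange edges out of Conv2d, blue edges into ReLU, base styling otherwise.
        if isinstance(fnode.fn, Conv2dStandIn) and isinstance(tnode.fn, TensorStandIn):
            return {"penwidth": "4.8", "color": "#ee8800"}
        if isinstance(tnode.fn, ReLUStandIn) and isinstance(fnode.fn, TensorStandIn):
            return {"penwidth": "4.8", "color": "#00228f"}
        return super().style_edge(fnode, tnode)


styler = MyStylerSketch()
conv = Node("conv1", Conv2dStandIn((5, 5), (2, 2), (2, 2)))
print(styler.style_node(conv)["label"])
```

Printing the label shows the node name followed by one line per Conv2d parameter, which is the text MyStyler attaches to Conv2d nodes.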

Rendering into different formats

Currently, the torchrecorder package only renders into graphviz objects, but the rendering functionality can be extended by subclassing the BaseRenderer class, in the same way that GraphvizRenderer does. See the source code for how the subclassing can be done.

If you create a subclass of BaseRenderer for a new rendering format, submit a pull request! I’ve been trying to render in a SigmaJS-compatible format, but haven’t been able to.
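
As a possible starting point, the sketch below converts a plain list of node names and (source, target) edge pairs into the nodes/edges JSON layout read by sigma.js v1's JSON parser (check the version you target; newer sigma.js releases load graphs through graphology instead). Nodes are placed on a circle because that format expects explicit x/y coordinates. It deliberately does not subclass BaseRenderer: the input format, the function name to_sigma_json, the circular layout, and the node names (loosely mirroring SampleNet's forward pass) are all assumptions, and wiring it into a BaseRenderer subclass would need the interface shown in the torchrecorder source.

```python
import json
import math


def to_sigma_json(node_names, edges, radius=1.0):
    """Build a sigma.js-v1-style graph dict from node names and (source, target) pairs.

    Hypothetical helper: nodes are laid out evenly on a circle because this format
    expects every node to carry explicit x/y coordinates.
    """
    n = max(len(node_names), 1)
    nodes = []
    for i, name in enumerate(node_names):
        angle = 2 * math.pi * i / n
        nodes.append({
            "id": name,
            "label": name,
            "x": radius * math.cos(angle),
            "y": radius * math.sin(angle),
            "size": 1,
        })
    known = set(node_names)
    sigma_edges = []
    for k, (src, dst) in enumerate(edges):
        # Reject edges that point at nodes that were never declared.
        if src not in known or dst not in known:
            raise ValueError(f"edge {src!r} -> {dst!r} references an unknown node")
        sigma_edges.append({"id": f"e{k}", "source": src, "target": dst})
    return {"nodes": nodes, "edges": sigma_edges}


graph = to_sigma_json(
    ["inputs", "linear_1", "linear_2", "cat", "linear_3", "my_special_relu"],
    [("inputs", "linear_1"), ("inputs", "linear_2"), ("linear_1", "cat"),
     ("linear_2", "cat"), ("cat", "linear_3"), ("linear_3", "my_special_relu")],
)
print(json.dumps(graph, indent=2))
```

Whether the output loads as-is depends on the sigma.js version and parser you use, so treat it as a starting point rather than a finished renderer.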

[1] The list of graphviz node attributes can be found at https://graphviz.gitlab.io/_pages/doc/info/attrs.html