# Source code for torchrecorder.helpers

# -*- coding: utf-8 -*-
"""
    torchrecorder.helpers
    ~~~~~~~~~~~~~~~~~~~~~

    Helper functions for torchrecorder

    :copyright: (c) 2020 by Gautham Venkatasubramanian.
    :license: see LICENSE for more details.
"""
from torch import randn
from .recorder import Recorder
from .renderer.gv import GraphvizRenderer
from graphviz import Digraph


def render_network(
    net,
    name,
    input_shapes,
    directory,
    filename=None,
    fmt="svg",
    input_data=None,
    render_depth=1,
    **styler_args
):
    """Render the structure of a `torch.nn.Module` to an image via `graphviz`.

    The network is moved to the CPU and switched to training mode, recorded
    with a single forward pass (see `record`), converted to a
    `graphviz.Digraph` (see `make_dot`) and rendered to disk.

    Args:
        net (`torch.nn.Module`): network to render
        name (str): name of the network; used in the graph label and in the
            default output file name
        input_shapes (None, tuple or list(tuple)): `tuple` if ``net`` has a
            single input, `list` ( `tuple` ) if it has several, `None` if
            ``input_data`` is provided
        directory (str): directory to store the rendered image
        filename (str, optional): base name of the output file, without the
            format suffix (`graphviz` appends it). Defaults to
            ``"{name}-{render_depth}"`` when `None` or empty.
        fmt (str, optional): image format. Default ``"svg"``.
        input_data (`torch.Tensor` or `tuple` (`torch.Tensor` ), optional): if
            ``net`` requires normalized inputs, provide them here instead of
            setting ``input_shapes``.
        render_depth (int, optional): Default ``1``.
        **styler_args : node attributes to pass to `graphviz`
    """
    net = net.cpu().train()
    rec = record(net, name, input_shapes, input_data)
    g = make_dot(rec, render_depth, styler_cls=None, **styler_args)
    g.format = fmt
    g.attr(label="{} at depth = {}".format(name, render_depth))

    # Honor an explicit file name; otherwise keep the historical default so
    # existing callers get the same output path as before.
    target = filename or "{}-{}".format(name, render_depth)
    g.render(target, directory=directory, cleanup=True)
def record(net, name, input_shapes, input_data=None):
    """Record the graph by running a single pass of a `torch.nn.Module`.

    Hooks are registered on ``net`` for the duration of one forward pass and
    are always removed again, even if the forward pass raises.

    Args:
        net (`torch.nn.Module`): network to record
        name (str): name of the network
        input_shapes (None, tuple or list(tuple)): `tuple` if ``net`` has a
            single input, `list` ( `tuple` ) if it has several, `None` if
            ``input_data`` is provided
        input_data (`torch.Tensor` or `tuple` (`torch.Tensor` ), optional): if
            ``net`` requires normalized inputs, provide them here instead of
            setting ``input_shapes``. Takes precedence over ``input_shapes``.

    Returns:
        a `~.Recorder` object containing the execution graph
    """
    rec = Recorder()
    rec.register_hooks(net, depth=0, parent=None, name=name)

    try:
        # Build the inputs: explicit data wins, otherwise sample random
        # tensors of the requested shape(s).
        if input_data is not None:
            data = input_data
            single_input = not isinstance(input_data, tuple)
        elif isinstance(input_shapes, list):
            data = tuple(randn(shape) for shape in input_shapes)
            single_input = False
        elif isinstance(input_shapes, tuple):
            data = randn(input_shapes)
            single_input = True
        else:
            # Neither data nor a usable shape was given: as before, the
            # network is called without arguments.
            data = ()
            single_input = False

        # Inputs are marked as requiring grad and registered as graph nodes
        # before the forward pass.
        if single_input:
            data.requires_grad = True
            rec.add_node(data, depth=0, parent=None, name="Input")
            pred = net(data)
        else:
            for i, d in enumerate(data):
                d.requires_grad = True
                rec.add_node(
                    d, depth=0, parent=None, name="Input-{i}".format(i=i + 1)
                )
            pred = net(*data)

        # Rename the node(s) recorded for the network's result.
        if not isinstance(pred, tuple):
            rec.nodes[pred].name = "Output"
        else:
            for i, p in enumerate(pred):
                rec.nodes[p].name = "Output-{i}".format(i=i + 1)
    finally:
        # Never leave recording hooks attached to the caller's module.
        rec.remove_hooks()

    return rec
def make_dot(rec, render_depth=256, styler_cls=None, **styler_args):
    """Produce a Graphviz representation from a `~torchrecorder.recorder.Recorder`.

    Args:
        rec (`~torchrecorder.recorder.Recorder`): recorder holding the
            execution graph
        render_depth (int): depth until which nodes should be rendered
        styler_cls: styler class to instantiate when styling nodes.
            If `None`, defaults to `.GraphvizStyler`.

    Kwargs:
        styler_args (optional): styler properties to be set for all nodes.
            A non-`None` ``fontname`` entry is also applied to the graph and
            node defaults.

    Returns:
        a `graphviz.Digraph` with the rendered nodes
    """
    graph_attr = {"compound": "true", "ranksep": "0.5", "fontsize": "24"}
    node_attr = {"fontsize": "20"}

    # Keep graph-level and node-level fonts in sync with the styler's font.
    fontname = styler_args.get("fontname")
    if fontname is not None:
        graph_attr["fontname"] = fontname
        node_attr["fontname"] = fontname

    g = Digraph(graph_attr=graph_attr, node_attr=node_attr)
    renderer = GraphvizRenderer(
        rec=rec, render_depth=render_depth, styler_cls=styler_cls, **styler_args
    )
    return renderer(g)