Source code for torchrecorder.renderer.base

# -*- coding: utf-8 -*-
"""
    torchrecorder.renderer.base
    ~~~~~~~~~~~~~~~~~~~~~~

    Abstract base class for Renderer objects

    :param copyright: (c) 2020 by Gautham Venkatasubramanian.
    :param license: see LICENSE for more details.
"""
from collections import OrderedDict
from ..nodes import LayerNode


[docs]class BaseRenderer(object): """Base Class for rendering information from a `~torchrecorder.recorder.Recorder`. Attributes: rec (`~torchrecorder.recorder.Recorder`): render_depth (int): nodes having a greater depth than this value will not be rendered processed (`collections.OrderedDict`): An ``OrderedDict`` whose keys contain ``nodes`` and values contain the corresponding (directed) edge lists """ def __init__(self, rec, render_depth=256): self.rec = rec self.render_depth = render_depth self.processed = OrderedDict() def render_node(self, dest, node): raise NotImplementedError("Base Class") def render_recursive_node(self, dest, node): raise NotImplementedError("Base Class") def render_edge(self, dest, fnode, tnode): raise NotImplementedError("Base Class") def __call__(self, dest): """Render nodes and edges. Uses `.render_depth` to select nodes and edges from `.rec` required for rendering. calls `.render_node` on each node, followed by `render_edge` on each edge from that node. Args: dest: destination for the rendered information (`graphviz.Digraph`, `dict` etc.) Returns: ``dest`` after updating with necessary information """ self.processed.clear() self._process_nodes() self._process_edges() while len(self.processed) != 0: node = next(iter(self.processed)) targets = self.processed[node] self.render_node(dest, node) for t in targets: self.render_edge(dest, node, t) self.processed.pop(node) return dest def _process_nodes(self): """Filter out nodes that have a greater depth than required. """ for k, v in self.rec.nodes.items(): if k is not None and v.depth <= self.render_depth: self.processed[v] = [] def _process_edges(self): """Construct necessary edges between filtered nodes. After all nodes deeper than `.render_depth` have been removed, all the edges that between these nodes and those that remain must be transformed accordingly: such edges are "lifted up" for rendering, and ignored if they are internal to a node (i.e. both source and destination have been removed). """ def lifted_node(x): xnode = self.rec.nodes[x] while xnode.depth > self.render_depth: xnode = self.rec.nodes[xnode.parent] return xnode lifted_edges = set( (lifted_node(x), lifted_node(y), z) for x, y, z in self.rec.edges if lifted_node(x) != lifted_node(y) ) for fnode, tnode, _ in lifted_edges: self.processed[fnode].append(tnode)