Source code for torchrecorder.recorder

# -*- coding: utf-8 -*-
"""
    torchrecorder.recorder
    ~~~~~~~~~~~~~~~~~

    Uses hooks to record the traversal of the execution graph

    :copyright: (c) 2019 by Gautham Venkatasubramanian.
    :license: see LICENSE for more details.
"""
from torch.nn import Module
from collections import OrderedDict
from .nodes import BaseNode, TensorNode, ParamNode, OpNode, LayerNode
from functools import partial
import time


[docs]class Recorder(object): """Record and store execution graph information Attributes: fn_set (set): a set of objects ( `~torchrecorder.nodes.BaseNode.fn`\ s) that contain recordable information nodes (dict): a mapping of `~torchrecorder.nodes.BaseNode.fn`\ s to their corresponding `~torchrecorder.nodes.BaseNode`\ s fn_types (dict): a count of `~torchrecorder.nodes.BaseNode.fn`\ s by type for naming edges (set(tuple)): a set of edges, each a pair of `~torchrecorder.nodes.BaseNode.fn`\ s """ def __init__(self): self.nodes = OrderedDict() self.fn_types = dict() self.fn_set = set() self.edges = set() self._start_time = None self._create_context()
[docs] def add_node(self, net, depth=0, parent=None, name=None): """Construct a node of recording graph. Construct a `~.nodes.BaseNode` that will store information related to ``net`` as the neural network is run. Args: net : Object whose information will be stored as the ``fn`` attribute of a `~.nodes.BaseNode` depth : The scope depth at which ``net`` is found parent : The object as part of which net will be run name : a name to recognize the object during rendering, defaults to class name Returns: `None` """ classname = type(net).__name__ if self.fn_types.get(classname): self.fn_types[classname] += 1 else: self.fn_types[classname] = 1 if name is None: objname = classname if self.fn_types[classname] > 1: objname = objname + "-" + str(self.fn_types[classname]) else: objname = name if isinstance(net, Module): if depth > 0 and name is not None: objname = objname + "\n(" + classname + ")" x = LayerNode(name=objname, fn=net, parent=parent, depth=depth) x.pre = net.register_forward_pre_hook(partial(prehook, rec=self, node=x)) x.post = net.register_forward_hook(partial(posthook, rec=self, node=x)) # x.back = net.register_backward_hook(partial(backhook, rec=self, node=x)) elif "Tensor" in classname: x = TensorNode(name=objname, fn=net, parent=parent, depth=depth) elif "Parameter" in classname: x = ParamNode(name=objname, fn=net, parent=parent, depth=depth) elif hasattr(net, "next_functions"): x = OpNode(name=objname, fn=net, parent=parent, depth=depth) else: raise RuntimeError("Cannot create node for " + str(net)) self.nodes[net] = x self.fn_set.add(net) if x.parent is not None: pnode = self.nodes[x.parent] pnode.subnets.add(net)
[docs] def add_dummy(self, dummy, fn): """Point to an existing node to assist recording. Instead of creating a separate node, the ``dummy`` object is used to point to an existing node containing ``fn``. Used for dummy ops and ``AccumulateGradient``\ s (see `leaf_dummy` ). Args: dummy: a dummy `torch.Tensor` or op that should not be recorded fn : a recorded object that will be connected to further ops """ self.fn_set.add(dummy) self.nodes[dummy] = self.nodes[fn]
[docs] def add_edge(self, _from, _to): """Construct an edge of the recording graph. Records an edge between two `~torchrecorder.nodes.BaseNode.fn` objects to be used while rendering. This will be used along with the ``nodes`` dictionary to map edges properly. Args: _from (`~torchrecorder.nodes.BaseNode.fn`\ ): _to (`~torchrecorder.nodes.BaseNode.fn`\ ): """ if _from is None or _to is None: raise AssertionError("Cannot draw edge involving" + str((_from, _to))) if self._start_time is not None: timestamp = time.time() - self._start_time else: timestamp = 0 self._start_time = time.time() edge = (_from, _to, round(timestamp, 6)) self.edges.add(edge)
[docs] def register_hooks(self, net, depth=0, parent=None, name=None): """Register the hooks of the `.Recorder` recursively on a `torch.nn.Module`\ . The hooks registered are `~functools.partial` versions of `prehook` and `posthook` corresponding to each node. Args: net (`~torch.nn.Module`\ ): depth (int): parent (`torch.nn.Module`\ ): the parent of ``net`` name (str): name of ``net`` Returns: `None` """ self.add_node(net, depth, parent, name) for n, x in net.named_children(): self.register_hooks(x, depth=depth + 1, parent=net, name=n)
[docs] def remove_hooks(self): """Remove hooks from any `~torch.nn.Module`\ s in `~torchrecorder.nodes.LayerNode`\ s. After the recording is completed, the hooks in `~torchrecorder.nodes.LayerNode`\ s are unnecessary. They are removed to prevent any possible issues. """ for node in set(self.nodes.values()): if isinstance(node, LayerNode): node.pre.remove() node.post.remove()
def _create_context(self): """Construct a dummy node as the context for the recording graph. Adds a dummy node (mapped to `None`) to the recording graph as the context. This node has the least depth (-1); is its own parent, and will have only one child: the network whose execution is to be recorded. Returns: `None` """ self.fn_set.add(None) self.nodes[None] = BaseNode(fn=None, depth=-1, parent=None, name="ContextDummy")
[docs]def op_acc(gf, rec, node): """Operator Accumulator. Creates an `~.nodes.OpNode` to record the newly-performed operation ``gf``, if not already recorded. If ``gf`` is an initialization op (``AccumulateGradient``), then points ``gf`` to its connected `torch.Tensor` instead of creating an `~.nodes.OpNode`. Otherwise recursively checks all operations that are connected to ``gf`` and adds them if necessary. Args: gf: current operation, a ``grad_fn`` object obtained from a `torch.Tensor` rec: a `~.Recorder` object whose nodes are updated node: `~.nodes.LayerNode` whose ``fn`` the current operation is a part of Returns: `None` """ if gf in rec.fn_set: pass else: if hasattr(gf, "variable"): rec.add_dummy(dummy=gf, fn=gf.variable) elif hasattr(gf, "next_functions"): rec.add_node(gf, node.depth + 1, node.fn) for x, y in gf.next_functions: if x is not None: op_acc(x, rec, node) rec.add_edge(_from=x, _to=gf)
[docs]def tensor_acc(tensor, rec, node): """Tensor Accumulator. Creates a `~.nodes.TensorNode` to record the newly-created tensor, if not already recorded. Note that the resulting `~.nodes.TensorNode` has the same parent as ``node``, because the ``tensor`` is the output of/input to ``node.fn``. Args: tensor: a `torch.Tensor` rec: a `~.Recorder` object whose nodes are updated node: a `~.nodes.LayerNode` whose ``fn`` outputs/inputs ``tensor`` Returns: `None` """ if tensor in rec.fn_set: pass else: rec.add_node(tensor, depth=node.depth, parent=node.parent)
[docs]def param_acc(param, rec, node): """Parameter Accumulator. Creates a `~.nodes.ParamNode` to record the parameter ``param`` of ``node.fn``, if not already recorded. Note that ``node.fn`` is the *parent* of ``param``\ . Args: param: a `~torch.nn.Parameter` rec: the `~.Recorder` object whose nodes are updated node: `~.nodes.LayerNode` whose ``fn`` contains ``param`` Returns: `None` """ if param in rec.fn_set: pass else: rec.add_node(param, depth=node.depth + 1, parent=node.fn)
[docs]def leaf_dummy(tensor, rec): """Performs a dummy operation (adding 0) to a leaf `~torch.Tensor`. This ensures that the (possibly in-place) operations performed on ``tensor`` hereafter can be correctly mapped. The dummy tensor (and operation) are not recorded separately, they merely point to the original tensor. Args: tensor: a newly-formed leaf `torch.Tensor` rec: the `~.Recorder` object whose nodes are updated Returns: ``tensor`` after adding 0 """ dummy = tensor + 0 rec.add_dummy(dummy=dummy, fn=tensor) rec.add_dummy(dummy=dummy.grad_fn, fn=tensor) return dummy
[docs]def prehook(module, inputs, rec, node): """hook to record BEFORE the given ``module`` is run. Records parameters contained in ``module``, then checks each tensor in ``inputs`` for any operations that may have run after the end of the previous ``module``. The ``inputs`` are then converted to leaf tensors and recorded before being passed off to the ``module``. Args: module: a `torch.nn.Module` inputs: a `torch.Tensor` or a `tuple` of `torch.Tensor`\ s rec: a `~.Recorder` object for global information node (`~torchrecorder.nodes.LayerNode`\ ): ``node.fn`` is ``module``\ . Returns: the ``leaf``-equivalent of ``inputs``. """ for name, param in module.named_parameters(recurse=False): param_acc(param, rec, node) if name is not None and name != "": rec.nodes[param].name = name is_singleton = not isinstance(inputs, tuple) a = [inputs] if is_singleton else inputs # same input appearing multiple times? new_inputs = [] for x in a: gf = x.grad_fn op_acc(gf, rec, rec.nodes[node.parent]) tensor_acc(x, rec, node) if gf is not None: rec.add_edge(_from=gf, _to=x) new_inputs.append(leaf_dummy(x, rec)) return new_inputs[0] if is_singleton else tuple(new_inputs)
[docs]def posthook(module, inputs, outputs, rec, node): """hook to record AFTER the given ``module`` has run and returned. Records any operations that may have run as part of ``module``\ , then checks if each tensor in the ``outputs`` has already been recorded by a sub\ ``module`` of the current ``module`` (the sub\ ``module``\ 's `posthook` would execute first!). If necessary, the ``outputs`` are converted to leaf tensors to record operations afresh. Args: module: a `torch.nn.Module` inputs: a `torch.Tensor` or a tuple of `torch.Tensor`\ s outputs: a `torch.Tensor` or a tuple of `torch.Tensor`\ s rec: a `~.Recorder` object for global information node (`~torchrecorder.nodes.LayerNode`\ ): ``node.fn`` is ``module``. Returns: the ``leaf``-equivalent of ``outputs``. """ is_singleton = not isinstance(outputs, tuple) b = [outputs] if is_singleton else outputs new_outputs = [] for x in b: gf = x.grad_fn if gf not in rec.fn_set: x = x.detach() x.requires_grad = True tensor_acc(x, rec, node) op_acc(gf, rec, node) rec.add_edge(gf, x) new_outputs.append(leaf_dummy(x, rec)) else: # if the op has already been recorded # it has to be a dummy op rec.nodes[gf].parent = node.parent rec.nodes[gf].depth -= 1 if rec.nodes[gf].fn in node.subnets: node.subnets.remove(rec.nodes[gf].fn) new_outputs.append(x) return new_outputs[0] if is_singleton else tuple(new_outputs)
def backhook(module, grad_inputs, grad_outputs, rec, node): pass