Source code for chainercv.links.model.pickable_sequential_chain

import chainer


class PickableSequentialChain(chainer.Chain):

    """A sequential chain that can pick intermediate layers.

    Callable objects, such as :class:`chainer.Link` and
    :class:`chainer.Function`, can be registered to this chain with
    :meth:`init_scope`.
    This chain keeps the order of registrations and :meth:`__call__`
    executes callables in that order.
    A :class:`chainer.Link` object in the sequence will be added as
    a child link of this link.

    :meth:`__call__` returns single or multiple layers that are picked up
    through a stream of computation.
    These layers can be specified by :obj:`pick`, which contains
    the names of the layers that are collected.
    When :obj:`pick` is a string, single layer is returned.
    When :obj:`pick` is an iterable of strings, a tuple of layers
    is returned.
    The order of the layers is the same as the order of the strings
    in :obj:`pick`.
    When :obj:`pick` is :obj:`None`, the last layer is returned.

    Examples:

        >>> import chainer.functions as F
        >>> import chainer.links as L
        >>> model = PickableSequentialChain()
        >>> with model.init_scope():
        ...     model.l1 = L.Linear(None, 1000)
        ...     model.l1_relu = F.relu
        ...     model.l2 = L.Linear(None, 1000)
        ...     model.l2_relu = F.relu
        ...     model.l3 = L.Linear(None, 10)
        >>> # This is layer l3
        >>> layer3 = model(x)
        >>> # The layers to be collected can be changed.
        >>> model.pick = ('l2_relu', 'l1_relu')
        >>> # These are layers l2_relu and l1_relu.
        >>> layer2, layer1 = model(x)

    Attributes:
        pick (string or iterable of strings):
            Names of layers that are collected during
            the forward pass.
        layer_names (iterable of strings):
            Names of layers that can be collected from
            this chain. The names are ordered in the order
            of computation.

    """

    def __init__(self):
        super(PickableSequentialChain, self).__init__()
        self.layer_names = []
        # Two attributes are initialized by the setter of pick.
        # self._pick -> None
        # self._return_tuple -> False
        self.pick = None

    def __setattr__(self, name, value):
        """Set an attribute and record callables registered in init_scope.

        A callable assigned while :meth:`init_scope` is active is treated
        as a layer and its name is appended to :obj:`layer_names`.
        A name that is already recorded keeps its original position and is
        not recorded a second time, so re-assigning a layer does not make
        :meth:`__call__` execute it twice.
        """
        super(PickableSequentialChain, self).__setattr__(name, value)
        # ``layer_names`` is created only after the base constructor
        # returns, so it is consulted last, once the init-scope check has
        # already passed.
        if (self.within_init_scope and callable(value)
                and name not in self.layer_names):
            self.layer_names.append(name)

    def __delattr__(self, name):
        """Delete an attribute and forget it as a layer.

        Raises:
            AttributeError: If :obj:`name` is one of the layers currently
                selected by :obj:`pick`.
        """
        if self._pick and name in self._pick:
            raise AttributeError(
                'layer {:s} is registered to pick.'.format(name))
        super(PickableSequentialChain, self).__delattr__(name)
        try:
            self.layer_names.remove(name)
        except ValueError:
            # The attribute was not a registered layer; nothing to forget.
            pass

    @property
    def pick(self):
        """Names of the layers returned by :meth:`__call__`.

        :obj:`None`, a single name or a tuple of names, mirroring the form
        in which the value was assigned.
        """
        if self._pick is None:
            return None

        if self._return_tuple:
            return self._pick
        return self._pick[0]

    @pick.setter
    def pick(self, pick):
        if pick is None:
            self._return_tuple = False
            self._pick = None
            return

        is_single_name = isinstance(pick, str)
        # Materialize the iterable exactly once. Validating a one-shot
        # iterator (e.g. a generator) and then converting it again would
        # silently store an empty pick.
        names = (pick,) if is_single_name else tuple(pick)

        if not is_single_name and all(isinstance(name, str) for name in names):
            return_tuple = True
        else:
            # Either a single name, or something that is not an iterable of
            # strings. The latter is treated as one (invalid) name and is
            # rejected by the membership test below.
            return_tuple = False
            names = (pick,)

        if any(name not in self.layer_names for name in names):
            raise ValueError('Invalid layer name')

        self._return_tuple = return_tuple
        self._pick = names

    def remove_unused(self):
        """Delete all layers that are not needed for the forward pass.

        Every layer registered after the last layer named in :obj:`pick`
        is deleted. Nothing is deleted when :obj:`pick` is :obj:`None`.

        """
        if self._pick is None:
            return

        # The biggest index among indices of the layers that are included
        # in pick.
        last_index = max(self.layer_names.index(name) for name in self._pick)
        # Slicing copies the list, so deleting while looping is safe.
        for name in self.layer_names[last_index + 1:]:
            delattr(self, name)

    def __call__(self, x):
        """Forward this model.

        Args:
            x (chainer.Variable or array): Input to the model.

        Returns:
            chainer.Variable or tuple of chainer.Variable:
            The returned layers are determined by :obj:`pick`.

        """
        if self._pick is None:
            pick = (self.layer_names[-1],)
        else:
            pick = self._pick

        # The biggest index among indices of the layers that are included
        # in pick. Layers after it are never executed.
        last_index = max(self.layer_names.index(name) for name in pick)

        layers = {}
        h = x
        for name in self.layer_names[:last_index + 1]:
            h = self[name](h)
            if name in pick:
                layers[name] = h

        if self._return_tuple:
            return tuple(layers[name] for name in pick)
        # Without a tuple result, pick holds exactly one name.
        return layers[pick[0]]