Source code for chainercv.chainer_experimental.datasets.sliceable.getter_dataset

from chainercv.chainer_experimental.datasets.sliceable.sliceable_dataset \
    import _as_key_indices
from chainercv.chainer_experimental.datasets.sliceable.sliceable_dataset \
    import _is_iterable
from chainercv.chainer_experimental.datasets.sliceable import SliceableDataset


class GetterDataset(SliceableDataset):
    """A sliceable dataset whose data are produced by registered getters.

    Subclasses register one or more getter callables with
    :meth:`add_getter` and implement :meth:`__len__`. Each getter takes an
    example index and returns the data for the key(s) it was registered
    under. Please refer to the tutorial for a more detailed explanation.

    Here is an example.

    >>> class SliceableLabeledImageDataset(GetterDataset):
    >>>     def __init__(self, pairs, root='.'):
    >>>         super(SliceableLabeledImageDataset, self).__init__()
    >>>         with open(pairs) as f:
    >>>             self._pairs = [l.split() for l in f]
    >>>         self._root = root
    >>>
    >>>         self.add_getter('img', self.get_image)
    >>>         self.add_getter('label', self.get_label)
    >>>
    >>>     def __len__(self):
    >>>         return len(self._pairs)
    >>>
    >>>     def get_image(self, i):
    >>>         path, _ = self._pairs[i]
    >>>         return read_image(os.path.join(self._root, path))
    >>>
    >>>     def get_label(self, i):
    >>>         _, label = self._pairs[i]
    >>>         return np.int32(label)
    >>>
    >>> dataset = SliceableLabeledImageDataset('list.txt')
    >>>
    >>> # get a subset with label = 0, 1, 2
    >>> # no images are loaded
    >>> indices = [i for i, label in
    ...            enumerate(dataset.slice[:, 'label']) if label in {0, 1, 2}]
    >>> dataset_012 = dataset.slice[indices]

    """

    def __init__(self):
        # Each entry of ``_keys`` is a triple
        # ``(key name, index into _getters, position in getter output)``;
        # the position is ``None`` when the getter's return value is used
        # as-is instead of being indexed.
        self._keys = []
        self._getters = []
        # Whether ``keys`` reports a tuple of names or a single name.
        self._return_tuple = True

    def __len__(self):
        """Return the number of examples; subclasses must override this."""
        raise NotImplementedError

    @property
    def keys(self):
        """Name(s) of the data currently exposed by this dataset.

        Returns:
            A tuple with the name of every registered key when the dataset
            is in tuple mode, otherwise the name of the first key alone.
        """
        if not self._return_tuple:
            return self._keys[0][0]
        return tuple(entry[0] for entry in self._keys)

    @keys.setter
    def keys(self, keys):
        # Resolve the request against the keys currently exposed, then keep
        # only the selected entries, in the requested order.
        selected = _as_key_indices(keys, self.keys)
        self._keys = [self._keys[position] for position in selected]
        self._return_tuple = _is_iterable(keys)

    def add_getter(self, keys, getter):
        """Register a getter function

        Args:
            keys (int or string or tuple of strings): The number or
                name(s) of data that the getter function returns.
            getter (callable): A getter function that takes an index and
                returns data of the corresponding example.

        """
        self._getters.append(getter)
        getter_position = len(self._getters) - 1

        # An integer only states how many values the getter returns; those
        # values are registered under the name ``None``.
        if isinstance(keys, int):
            keys = None if keys == 1 else (None,) * keys

        if not _is_iterable(keys):
            # Single key: the getter's return value is used without indexing.
            self._keys.append((keys, getter_position, None))
            return

        self._keys.extend(
            (name, getter_position, offset)
            for offset, name in enumerate(keys))

    def get_example_by_keys(self, index, key_indices):
        """Return the data of one example for the selected keys.

        Args:
            index (int): Index of the example, passed to the getters.
            key_indices (iterable of ints): Positions in the registered key
                list that select which data to return.

        Returns:
            tuple: The selected data, in the order of :obj:`key_indices`.

        """
        # A getter is invoked at most once per call, even when several of
        # the requested keys are produced by it.
        fetched = {}
        values = []
        for position in key_indices:
            _, getter_position, offset = self._keys[position]
            if getter_position not in fetched:
                getter = self._getters[getter_position]
                fetched[getter_position] = getter(index)
            data = fetched[getter_position]
            values.append(data if offset is None else data[offset])
        return tuple(values)