# Source code for chainercv.chainer_experimental.datasets.sliceable.tuple_dataset

from collections import defaultdict
import six

from chainercv.chainer_experimental.datasets.sliceable.sliceable_dataset \
    import _as_tuple
from chainercv.chainer_experimental.datasets.sliceable import SliceableDataset


class TupleDataset(SliceableDataset):
    """A sliceable version of :class:`chainer.datasets.TupleDataset`.

    Here is an example.

    >>> # omit keys
    >>> dataset = TupleDataset([0, 1, 2], [0, 1, 4])
    >>> dataset.keys  # (None, None)
    >>> dataset.slice[:, 0][:]  # [0, 1, 2]
    >>>
    >>> dataset_more = TupleDataset(dataset, [0, 1, 8])
    >>> dataset_more.keys  # (None, None, None)
    >>> dataset_more.slice[:, [1, 2]][:]  # [(0, 0), (1, 1), (4, 8)]
    >>>
    >>> # specify the name of a key
    >>> named_dataset = TupleDataset(('feat0', [0, 1, 2]), [0, 1, 4])
    >>> named_dataset.keys  # ('feat0', None)
    >>> # slice takes both key and index (or their mixture)
    >>> named_dataset.slice[:, ['feat0', 1]][:]  # [(0, 0), (1, 1), (2, 4)]

    Args:
        datasets: The underlying datasets. The following datasets are
            acceptable.

            * An instance of
              :class:`~chainercv.chainer_experimental.datasets.sliceable.SliceableDataset`
              (including its subclasses). All of its keys are inherited.
            * A tuple of a name and a data array.
              The data array should be list or :class:`numpy.ndarray`.
            * A data array. In this case, the name of key is :obj:`None`.

    Raises:
        ValueError: If no dataset is given, or if the given datasets do
            not all have the same length.

    """

    def __init__(self, *datasets):
        if not datasets:
            raise ValueError('At least one dataset is required')

        self._len = None
        # Each entry is ``(key, dataset_index, inner_key_index)``.
        # ``inner_key_index`` is the position of the key inside a wrapped
        # SliceableDataset, or None when the underlying dataset is a plain
        # data array that is accessed with ``dataset[index]``.
        self._keys = []
        self._datasets = []

        for dataset in datasets:
            dataset_index = len(self._datasets)

            if isinstance(dataset, SliceableDataset):
                self._datasets.append(dataset)
                for inner_key_index, key in enumerate(
                        _as_tuple(dataset.keys)):
                    self._keys.append((key, dataset_index, inner_key_index))
            else:
                # A tuple is interpreted as ``(name, data array)``;
                # anything else is an unnamed data array.
                if isinstance(dataset, tuple):
                    key, dataset = dataset
                else:
                    key = None
                self._datasets.append(dataset)
                self._keys.append((key, dataset_index, None))

            # The first dataset fixes the length; the rest must match it.
            if self._len is None:
                self._len = len(dataset)
            if len(dataset) != self._len:
                raise ValueError(
                    'All datasets should have the same length')

    def __len__(self):
        """Return the common length of the underlying datasets."""
        return self._len

    @property
    def keys(self):
        """Tuple of key names, one per column, in registration order."""
        return tuple(key for key, _, _ in self._keys)

    def get_example_by_keys(self, index, key_indices):
        """Return the values of the requested keys for one example.

        Args:
            index: Index of the example, passed through to the
                underlying datasets.
            key_indices: Indices into :obj:`keys` selecting which values
                to return. It is iterated twice.

        Returns:
            tuple: The selected values, in the same order as
            :obj:`key_indices`.

        """
        # Group the requested keys by underlying dataset so that each
        # underlying dataset is accessed only once per call.
        # A value of None marks a plain data array.
        datasets_key_indices = defaultdict(set)
        for key_index in key_indices:
            _, dataset_index, inner_key_index = self._keys[key_index]
            if inner_key_index is None:
                datasets_key_indices[dataset_index] = None
            else:
                datasets_key_indices[dataset_index].add(inner_key_index)

        # Maps ``(dataset_index, inner_key_index)`` to the fetched value.
        values = {}
        for dataset_index, inner_key_indices in \
                six.iteritems(datasets_key_indices):
            dataset = self._datasets[dataset_index]
            if inner_key_indices is None:
                values[(dataset_index, None)] = dataset[index]
            else:
                # Freeze the set so that the order used for the request
                # and the order used to pair up the results agree.
                inner_key_indices = tuple(inner_key_indices)
                examples = dataset.get_example_by_keys(
                    index, inner_key_indices)
                for inner_key_index, value in six.moves.zip(
                        inner_key_indices, examples):
                    values[(dataset_index, inner_key_index)] = value

        # ``self._keys[i][1:]`` is ``(dataset_index, inner_key_index)``.
        return tuple(
            values[self._keys[key_index][1:]] for key_index in key_indices)