Source code for chainercv.utils.testing.assertions.assert_is_label_dataset

import numpy as np
import six

from chainercv.utils.testing.assertions.assert_is_image import assert_is_image


def assert_is_label_dataset(dataset, n_class, n_example=None, color=True):
    """Asserts that a dataset conforms to the label dataset API.

    Every inspected example is validated with :func:`_check_example`.
    If any of them does not satisfy the API, an :class:`AssertionError`
    is raised.

    Args:
        dataset: The dataset to be checked. It must support :func:`len`
            and integer indexing.
        n_class (int): The number of classes.
        n_example (int): The number of examples to be checked.
            If this argument is given (and is non-zero), that many
            examples are drawn randomly, with replacement, and checked.
            Otherwise, every example in the dataset is checked.
        color (bool): A boolean that determines the expected channel size.
            If it is :obj:`True`, the number of channels
            should be :obj:`3`. Otherwise, it should be :obj:`1`.
            The default value is :obj:`True`.

    """
    n_total = len(dataset)
    assert n_total > 0, 'The length of dataset must be greater than zero.'

    if not n_example:
        # No sample size requested: walk through the whole dataset in order.
        for index in six.moves.range(n_total):
            _check_example(dataset[index], n_class, color)
        return

    # Spot-check ``n_example`` randomly chosen examples.
    for _ in six.moves.range(n_example):
        index = np.random.randint(0, n_total)
        _check_example(dataset[index], n_class, color)
def _check_example(example, n_class, color): assert len(example) >= 2, \ 'Each example must have at least two elements:' \ 'img and label.' img, label = example[:2] assert_is_image(img, color=color) assert isinstance(label, np.int32), \ 'label must be a numpy.int32.' assert label.ndim == 0, 'The ndim of label must be 0' assert label >= 0 and label < n_class, \ 'The value of label must be in [0, n_class - 1].'