Source code for chainercv.utils.testing.assertions.assert_is_semantic_segmentation_dataset

import numpy as np
import six

from chainercv.utils.testing.assertions.assert_is_image import assert_is_image


def assert_is_semantic_segmentation_dataset(dataset, n_class, n_example=None):
    """Checks if a dataset satisfies semantic segmentation dataset APIs.

    This function checks if a given dataset satisfies semantic segmentation
    dataset APIs or not.
    If the dataset does not satisfy the APIs, this function raises an
    :class:`AssertionError`.

    Args:
        dataset: A dataset to be checked. It must support :func:`len` and
            integer indexing.
        n_class (int): The number of classes including background.
        n_example (int): The number of examples to be checked.
            If this argument is specified, this function picks
            examples randomly and checks them. Otherwise,
            this function checks all examples.
            Note that :obj:`0` is treated the same as :obj:`None`
            (all examples are checked), and that indices are drawn
            independently, so the same example may be checked more
            than once.

    Raises:
        AssertionError: If the dataset is empty or if any checked example
            violates the APIs.

    """

    assert len(dataset) > 0, 'The length of dataset must be greater than zero.'

    # Truthiness test on purpose: both ``None`` and ``0`` fall through to
    # the exhaustive check below.
    if n_example:
        for _ in six.moves.range(n_example):
            # Upper bound of randint is exclusive, so ``i`` is always a
            # valid index into the dataset.
            i = np.random.randint(0, len(dataset))
            _check_example(dataset[i], n_class)
    else:
        for i in six.moves.range(len(dataset)):
            _check_example(dataset[i], n_class)
def _check_example(example, n_class): assert len(example) >= 2, \ 'Each example must have at least two elements:' \ 'img and label.' img, label = example[:2] assert_is_image(img, color=True) assert isinstance(label, np.ndarray), \ 'label must be a numpy.ndarray.' assert label.dtype == np.int32, \ 'The type of label must be numpy.int32.' assert label.shape == img.shape[1:], \ 'The shape of label must be (H, W).' assert label.min() >= -1 and label.max() < n_class, \ 'The value of label must be in [-1, n_class - 1].'