# Source code for chainercv.utils.testing.assertions.assert_is_bbox_dataset

import numpy as np
import six

from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image


def assert_is_bbox_dataset(dataset, n_fg_class, n_example=None):
    """Verifies that a dataset conforms to the bounding box dataset API.

    Each inspected example is handed to :func:`_check_example`. An
    :class:`AssertionError` is raised if the dataset is empty or as soon as
    one inspected example violates the API.

    Args:
        dataset: The dataset to be verified. It must support :func:`len`
            and integer indexing.
        n_fg_class (int): The number of foreground classes.
        n_example (int): The number of examples to inspect. If a non-zero
            value is given, that many indices are drawn at random (each draw
            is independent, so the same example may be picked more than
            once). Otherwise, every example is inspected in order.

    """
    n_total = len(dataset)
    assert n_total > 0, 'The length of dataset must be greater than zero.'

    if not n_example:
        # No sample size requested: walk through the whole dataset.
        indices = six.moves.range(n_total)
    else:
        # Lazily draw one random index per requested example.
        indices = (
            np.random.randint(0, n_total)
            for _ in six.moves.range(n_example))

    for index in indices:
        _check_example(dataset[index], n_fg_class)


def _check_example(example, n_fg_class): assert len(example) >= 3, \ 'Each example must have at least three elements:' \ 'img, bbox and label.' img, bbox, label = example[:3] assert_is_image(img, color=True) assert_is_bbox(bbox, size=img.shape[1:]) assert isinstance(label, np.ndarray), \ 'label must be a numpy.ndarray.' assert label.dtype == np.int32, \ 'The type of label must be numpy.int32.' assert label.shape[1:] == (), \ 'The shape of label must be (*,).' assert len(label) == len(bbox), \ 'The length of label must be same as that of bbox.' if len(label) > 0: assert label.min() >= 0 and label.max() < n_fg_class, \ 'The value of label must be in [0, n_fg_class - 1].'