# Source code for chainercv.datasets.voc.voc_instance_segmentation_dataset

import numpy as np
import os

from chainercv.chainer_experimental.datasets.sliceable import GetterDataset
from chainercv.datasets.voc import voc_utils
from chainercv.utils import read_image


class VOCInstanceSegmentationDataset(GetterDataset):

    """Instance segmentation dataset for PASCAL `VOC2012`_.

    .. _`VOC2012`: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/

    Args:
        data_dir (string): Path to the root of the training data. If this is
            :obj:`auto`, this class will automatically download data for you
            under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/voc`.
        split ({'train', 'val', 'trainval'}): Select a split of the dataset.

    This dataset returns the following data.

    .. csv-table::
        :header: name, shape, dtype, format

        :obj:`img`, ":math:`(3, H, W)`", :obj:`float32`, \
        "RGB, :math:`[0, 255]`"
        :obj:`mask`, ":math:`(R, H, W)`", :obj:`bool`, --
        :obj:`label`, ":math:`(R,)`", :obj:`int32`, \
        ":math:`[0, \\#fg\\_class - 1]`"
    """

    def __init__(self, data_dir='auto', split='train'):
        super(VOCInstanceSegmentationDataset, self).__init__()

        if split not in ['train', 'trainval', 'val']:
            raise ValueError(
                'please pick split from \'train\', \'trainval\', \'val\'')

        if data_dir == 'auto':
            data_dir = voc_utils.get_voc('2012', split)

        id_list_file = os.path.join(
            data_dir, 'ImageSets/Segmentation/{0}.txt'.format(split))
        # Close the id list deterministically instead of leaking the file
        # handle until garbage collection.
        with open(id_list_file) as f:
            self.ids = [id_.strip() for id_ in f]

        self.data_dir = data_dir

        # 'img' and the ('mask', 'label') pair are registered as separate
        # getters so that each can be loaded on its own.
        self.add_getter('img', self._get_image)
        self.add_getter(('mask', 'label'), self._get_annotations)

    def __len__(self):
        return len(self.ids)

    def _get_image(self, i):
        """Load the RGB image of the ``i``-th sample from ``JPEGImages``."""
        data_id = self.ids[i]
        img_file = os.path.join(
            self.data_dir, 'JPEGImages', data_id + '.jpg')
        return read_image(img_file, color=True)

    def _get_annotations(self, i):
        """Return the ``(mask, label)`` instance annotations of sample ``i``.

        The per-pixel class and instance maps are converted to per-instance
        masks and labels by :func:`voc_utils.image_wise_to_instance_wise`.
        """
        data_id = self.ids[i]
        label_img, inst_img = self._load_label_inst(data_id)
        mask, label = voc_utils.image_wise_to_instance_wise(
            label_img, inst_img)
        return mask, label

    def _load_label_inst(self, data_id):
        """Read the class map and instance map PNGs for ``data_id``.

        Returns:
            tuple: ``(label_img, inst_img)``, both 2D ``int32`` arrays. Pixel
            values that should be ignored are set to ``-1``.
        """
        label_file = os.path.join(
            self.data_dir, 'SegmentationClass', data_id + '.png')
        inst_file = os.path.join(
            self.data_dir, 'SegmentationObject', data_id + '.png')

        label_img = read_image(label_file, dtype=np.int32, color=False)
        # Drop the leading channel axis (assumes read_image returns a
        # (1, H, W) array when color=False -- confirm).
        label_img = label_img[0]
        # 255 is remapped to -1; presumably the "void"/boundary value in the
        # VOC annotations -- confirm.
        label_img[label_img == 255] = -1

        inst_img = read_image(inst_file, dtype=np.int32, color=False)
        inst_img = inst_img[0]
        # Both 0 and 255 are remapped to -1; presumably these are treated as
        # "no instance" by image_wise_to_instance_wise -- confirm.
        inst_img[inst_img == 0] = -1
        inst_img[inst_img == 255] = -1
        return label_img, inst_img