# Source code for chainercv.extensions.vis_report.detection_vis_report

import copy
import os
import warnings

import chainer

from chainercv.visualizations.vis_bbox import vis_bbox

# Probe for matplotlib once at import time. Plotting is optional, so a failed
# import only clears the ``_available`` flag instead of propagating the error.
try:
    import matplotlib  # NOQA
    _available = True

except (ImportError, TypeError):
    # NOTE(review): TypeError is caught alongside ImportError -- presumably
    # some environments raise it while importing matplotlib; confirm.
    _available = False


def _check_available():
    """Warn the user when matplotlib could not be imported.

    Does nothing when the module-level ``_available`` flag is true.
    """
    if _available:
        return

    message = (
        'matplotlib is not installed on your environment, '
        'so nothing will be plotted at this time. '
        'Please install matplotlib to plot figures.\n\n'
        '  $ pip install matplotlib\n'
    )
    warnings.warn(message)


class DetectionVisReport(chainer.training.extension.Extension):

    """An extension that visualizes output of a detection model.

    This extension visualizes the predicted bounding boxes together with the
    ground truth bounding boxes.

    Internally, this extension takes examples from an iterator,
    predicts bounding boxes from the images in the examples,
    and visualizes them using :meth:`chainercv.visualizations.vis_bbox`.
    The process can be illustrated in the following code.

    .. code:: python

        batch = next(iterator)
        # Convert batch -> imgs, gt_bboxes, gt_labels
        pred_bboxes, pred_labels, pred_scores = target.predict(imgs)
        # Visualization code
        for img, gt_bbox, gt_label, pred_bbox, pred_label, pred_score \\
                in zip(imgs, gt_bboxes, gt_labels,
                       pred_bboxes, pred_labels, pred_scores):
            # the ground truth
            vis_bbox(img, gt_bbox, gt_label)
            # the prediction
            vis_bbox(img, pred_bbox, pred_label, pred_score)

    .. note::
        :obj:`gt_bbox` and :obj:`pred_bbox` are float arrays
        of shape :math:`(R, 4)`, where :math:`R` is the number of
        bounding boxes in the image. Each bounding box is organized
        by :math:`(y_{min}, x_{min}, y_{max}, x_{max})` in the second axis.

        :obj:`gt_label` and :obj:`pred_label` are integer arrays
        of shape :math:`(R,)`. Each label indicates the class of
        the bounding box.

        :obj:`pred_score` is a float array of shape :math:`(R,)`.
        Each score indicates how confident the prediction is.

    Args:
        iterator: Iterator object that produces images and ground truth.
        target: Link object used for detection.
        label_names (iterable of strings): Name of labels ordered according
            to label ids. If this is :obj:`None`, labels will be skipped.
        filename (str): Basename for the saved image. It can contain two
            keywords, :obj:`'{iteration}'` and :obj:`'{index}'`. They are
            replaced with the iteration of the trainer and the index of
            the sample when this extension saves an image. The default value
            is :obj:`'detection_iter={iteration}_idx={index}.jpg'`.

    """

    def __init__(
            self, iterator, target, label_names=None,
            filename='detection_iter={iteration}_idx={index}.jpg'):
        # Warn at construction time if matplotlib is missing; the extension
        # is still created and simply does nothing when called.
        _check_available()

        self.iterator = iterator
        self.target = target
        self.label_names = label_names
        self.filename = filename

    @staticmethod
    def available():
        """Return whether matplotlib was importable, warning if it was not."""
        _check_available()
        return _available

    def __call__(self, trainer):
        """Save one ground-truth/prediction figure per sample.

        The iterator is consumed until it raises :obj:`StopIteration`.
        Each figure is written under :obj:`trainer.out` using the
        :obj:`filename` template. Nothing is done when matplotlib is not
        available.

        Args:
            trainer: Trainer object that invokes this extension. Its
                :obj:`out` and :obj:`updater.iteration` attributes are read.

        """
        if not _available:
            return

        # Dynamically import pyplot so that the backend of matplotlib
        # can be configured after importing chainercv.
        import matplotlib.pyplot as plt

        if hasattr(self.iterator, 'reset'):
            self.iterator.reset()
            it = self.iterator
        else:
            # Iterators without ``reset`` are shallow-copied so that the
            # original one is not consumed by this pass.
            it = copy.copy(self.iterator)

        idx = 0
        while True:
            try:
                batch = next(it)
            except StopIteration:
                break

            imgs = [img for img, _, _ in batch]
            pred_bboxes, pred_labels, pred_scores = self.target.predict(imgs)

            for (img, gt_bbox, gt_label), pred_bbox, pred_label, pred_score \
                    in zip(batch, pred_bboxes, pred_labels, pred_scores):

                pred_bbox = chainer.backends.cuda.to_cpu(pred_bbox)
                pred_label = chainer.backends.cuda.to_cpu(pred_label)
                pred_score = chainer.backends.cuda.to_cpu(pred_score)

                out_file = self.filename.format(
                    index=idx, iteration=trainer.updater.iteration)
                out_file = os.path.join(trainer.out, out_file)

                fig = plt.figure()
                try:
                    ax_gt = fig.add_subplot(2, 1, 1)
                    ax_gt.set_title('ground truth')
                    vis_bbox(
                        img, gt_bbox, gt_label,
                        label_names=self.label_names, ax=ax_gt)

                    ax_pred = fig.add_subplot(2, 1, 2)
                    ax_pred.set_title('prediction')
                    vis_bbox(
                        img, pred_bbox, pred_label, pred_score,
                        label_names=self.label_names, ax=ax_pred)

                    # Save through the figure handle rather than pyplot's
                    # implicit "current figure" state.
                    fig.savefig(out_file)
                finally:
                    # Always release the figure, even if drawing or saving
                    # raised, so figures do not accumulate across samples.
                    plt.close(fig)

                idx += 1