kyle0x54/medvision

View on GitHub
medvision/visualization/image.py

Summary

Maintainability
B
5 hrs
Test Coverage
import collections
from enum import Enum, unique
import math
import cv2
import matplotlib.pyplot as plt
import numpy as np
import medvision as mv


@unique
class Color(Enum):
    """ An enum that defines common colors.

    Each value is a 3-tuple of 0-255 channel intensities in
    (blue, green, red) order -- e.g. ``red`` is ``(0, 0, 255)`` -- i.e. the
    BGR convention used by OpenCV. Drawing code passes ``Color.<name>.value``
    (the raw tuple) straight to ``cv2`` drawing functions. ``@unique``
    guarantees no two members share the same tuple.

    NOTE(review): ``imshow_bboxes`` draws these tuples onto images that may
    come from ``mv.gray2rgb`` / ``mv.imread`` and are displayed with
    matplotlib; if those images are RGB, red and blue will appear swapped.
    Confirm the channel order ``mv`` uses.
    """
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)


class Switcher:
    """ Cycle through several images shown on one matplotlib axes.

    The first image is drawn when the switcher is created. Every call to
    ``on_press`` moves to the next image, wrapping back to the first one
    after the last.

    Args:
        ax: the axes to draw on (its ``imshow`` creates the image artist).
        imgs (list[ndarray]): images to cycle through.
    """

    def __init__(self, ax, imgs):
        self.ax = ax
        self.imgs = imgs
        self.num_imgs = len(imgs)
        self.index = 0

        # The image artist is created once; switching only swaps its data.
        self.im = ax.imshow(imgs[0])
        self.update()

    def on_press(self, event):
        """ Event callback: advance to the next image (``event`` is unused). """
        _, self.index = divmod(self.index + 1, self.num_imgs)
        self.update()

    def update(self):
        """ Load the current image into the artist and request a redraw. """
        current = self.imgs[self.index]
        self.im.set_data(current)
        canvas = self.im.axes.figure.canvas
        canvas.draw_idle()
        try:
            canvas.flush_events()
        except NotImplementedError:
            # Backends without event flushing simply skip this step.
            pass


def _imshow_tight(img, title):
    """ Draw ``img`` on the current axes without axis decorations.

    2-D (single-channel) arrays get a gray colormap; the view is clipped
    exactly to the image extent with row 0 at the top.

    Args:
        img (ndarray): image to draw.
        title (str): title for the current axes.
    """
    cmap = None
    if img.ndim == 2:
        # matplotlib's default colormap is not grayscale.
        cmap = 'gray'
    plt.imshow(img, cmap=cmap)
    plt.axis('off')

    rows, cols = img.shape[0], img.shape[1]
    plt.xlim([0, cols])
    # Inverted y-limits keep the first image row at the top.
    plt.ylim([rows, 0])
    plt.title(title)


def _imshow_switcher(imgs, title=''):
    """ imshow with an image switcher.

    An enhanced version of 'imshow' that shows one image at a time on a
    single axes. Each mouse click on the figure switches to the next image
    in 'imgs', wrapping around after the last one.

    Args:
        imgs (list[ndarray]): images to be displayed. The axis limits are
            taken from the first image, so all images should share its
            height and width.
        title (str): title of the plot (also used as the figure name).
    """
    fig, ax = plt.subplots(num=title)
    fig.tight_layout()
    ax.set_title(title)
    ax.axis('off')
    ref_img = imgs[0]
    ax.set_xlim([0, ref_img.shape[1]])
    ax.set_ylim([ref_img.shape[0], 0])
    switcher = Switcher(ax, imgs)
    # Matplotlib's callback registry keeps only weak references to bound
    # methods. Keep the Switcher alive for the figure's lifetime; otherwise,
    # when plt.show() does not block (interactive mode, notebooks), it is
    # garbage-collected on return and clicks silently stop working.
    fig._mv_switcher = switcher
    fig.canvas.mpl_connect("button_press_event", switcher.on_press)
    plt.show()


def imshow(
    imgs,
    num_cols=None,
    fig_name=None,
    titles='',
    show=True,
    save_path=None,
):
    """ Show an image or multiple images in a single canvas.

    Images are laid out on a grid with ``num_cols`` columns, filled row by
    row. 2-D arrays are drawn with a gray colormap.

    Args:
        imgs (ndarray or tuple/list[ndarray]): images to be shown.
        num_cols (int): number of columns (images per row) of the grid.
            If not given, ``ceil(sqrt(len(imgs)))`` is used. Must be
            positive and not larger than the number of images.
        fig_name (str): name of the figure.
        titles (str or list[str]): sub-plot titles. A single string is
            used for every sub-plot.
        show (bool): whether to display the figure with ``plt.show()``.
        save_path (str, optional): if given, the figure is saved to this
            path (before it is shown).

    Raises:
        ValueError: if ``imgs`` is an empty sequence.
    """
    # ``collections.Sequence`` was removed in Python 3.10; use the ABC module.
    from collections.abc import Sequence

    if not isinstance(imgs, Sequence):
        imgs = [imgs]
    num_imgs = len(imgs)
    if num_imgs == 0:
        raise ValueError('imgs must contain at least one image')

    if isinstance(titles, str):
        titles = [titles] * num_imgs
    assert len(titles) == num_imgs

    if num_cols is None:
        num_cols = int(math.ceil(math.sqrt(num_imgs)))
    assert 0 < num_cols <= num_imgs
    num_rows = int((num_imgs + num_cols - 1) // num_cols)

    plt.figure(fig_name)
    for i, (img, title) in enumerate(zip(imgs, titles)):
        plt.subplot(num_rows, num_cols, i + 1)
        _imshow_tight(img, title)
    # tight_layout only adjusts axes that already exist, so it must run
    # after the sub-plots have been created.
    plt.tight_layout()

    if save_path is not None:
        plt.savefig(save_path)

    if show:
        plt.show()


def _draw_score_label(img, score, left_top, bg_color, font_scale,
                      font_thickness, font_color):
    """ Draw ``score`` (as '%.2f') on a filled box just above ``left_top``. """
    label_text = '%.2f' % score
    (text_width, text_height), _ = cv2.getTextSize(
        label_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness
    )
    x, y = left_top
    # Filled background (thickness -1) keeps the text readable on any image.
    cv2.rectangle(
        img,
        (x, y - int(1.3 * text_height)),
        (x + text_width, y),
        bg_color.value, -1
    )
    cv2.putText(
        img, label_text,
        (x, y - int(0.3 * text_height)),
        cv2.FONT_HERSHEY_SIMPLEX, font_scale,
        font_color.value,
        thickness=font_thickness, lineType=cv2.LINE_AA
    )


def imshow_bboxes(
    img,
    bboxes,
    score_thr=0,
    colors=Color.green,
    top_k=-1,
    thickness=1,
    font_scale=0.5,
    font_thickness=1,
    font_color=Color.white,
    title='',
    show=True,
    save_path=None
):
    """ Draw bounding boxes on an image.

    To display detection result or compare detection results by different
    algorithms: each group of boxes in ``bboxes`` is drawn in its own color.
    When shown, clicking the figure toggles between the annotated image and
    the original one.

    Args:
        img (str or ndarray): image (or file path) to be displayed. 2-D
            images are converted with ``mv.gray2rgb`` before drawing.
        bboxes (ndarray or list[ndarray]): one array, or one array per
            group, of shape (k, 4) or (k, 5) with columns
            (left, top, right, bottom[, score]) in pixels. The caller's list
            is not modified.
        score_thr (float): when positive, boxes whose score is not greater
            than it are dropped (every group must then have 5 columns).
        colors (Color or list[Color]): one color for all groups, or one
            color per group.
        top_k (int): plot the first k bboxes of each group only if set
            positive. Otherwise, plot all the bboxes.
        thickness (int): line thickness.
        font_scale (float): font scale of the score texts.
        font_thickness (int): font thickness.
        font_color (Color): color of the score text.
        title (str): title of the plot.
        show (bool): whether to display the result.
        save_path (str, optional): if given, the annotated image is written
            to this path with ``mv.imwrite``.
    """
    if isinstance(img, str):
        img = mv.imread(img)
    if img.ndim != 3:
        img = mv.gray2rgb(img)

    # Always work on our own list so filtering never mutates caller data.
    bboxes = [bboxes] if isinstance(bboxes, np.ndarray) else list(bboxes)

    if isinstance(colors, Color):
        colors = [colors] * len(bboxes)

    assert len(bboxes) == len(colors)

    if score_thr > 0:
        for i, _bboxes in enumerate(bboxes):
            assert _bboxes.shape[1] == 5
            bboxes[i] = _bboxes[_bboxes[:, -1] > score_thr]

    img_with_result = img.copy()
    for _bboxes, color in zip(bboxes, colors):
        # The score is drawn only for groups that actually carry a 5th column.
        plot_prob = _bboxes.shape[1] == 5
        num_shown = _bboxes.shape[0]
        if top_k > 0:
            num_shown = min(top_k, num_shown)
        # Truncate to int32 as before, then hand OpenCV plain Python ints,
        # the safest type for point arguments.
        corners = _bboxes[:num_shown, :4].astype(np.int32).tolist()
        for j, (x1, y1, x2, y2) in enumerate(corners):
            cv2.rectangle(
                img_with_result, (x1, y1), (x2, y2), color.value, thickness
            )
            if plot_prob:
                _draw_score_label(
                    img_with_result, _bboxes[j, -1], (x1, y1), color,
                    font_scale, font_thickness, font_color
                )

    if show:
        _imshow_switcher([img_with_result, img], title)

    if save_path is not None:
        mv.imwrite(save_path, img_with_result)


if __name__ == '__main__':
    import os

    # Resolve the sample image relative to this file so the demo works from
    # any working directory.
    img_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        '..', '..', 'tests', 'data', 'pngs', 'Blue-Ogi.png'
    )

    im = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread reports a missing/unreadable file by returning None.
    if im is None:
        raise FileNotFoundError('cannot read image: %s' % img_path)
    imshow(im, fig_name='show single image', titles='name')

    h, w = im.shape[:2]
    bbox = np.array([w // 3, h // 3, w * 2 // 3, h * 2 // 3,
                     0.5]).reshape(-1, 5)
    imshow_bboxes(im, bbox, score_thr=0.2, title='draw bounding boxes')

    im = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if im is None:
        raise FileNotFoundError('cannot read image: %s' % img_path)
    # OpenCV loads BGR; matplotlib expects RGB.
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    imshow([im] * 5, fig_name='show multiple images',
           titles=[str(i) for i in range(5)])