mfinzi/pristine-ml

View on GitHub
oil/datasetup/camvid.py

Summary

Maintainability
A
45 mins
Test Coverage

# Adapted from https://github.com/felixgwu/vision/blob/cf491d301f62ae9c77ff7250fb7def5cd55ec963/torchvision/datasets/camvid.py
import os
import torch
import torch.utils.data as data
import numpy as np
from PIL import Image
from torchvision.datasets.folder import default_loader
import torchvision.transforms as transforms
from .joint_transforms import JointRandomCrop, JointRandomHorizontalFlip

def make_dataset(dir):
    """Recursively list every file under ``dir`` in a deterministic order.

    No extension filtering is applied: every regular file found by
    ``os.walk`` is returned, so the directory is expected to contain only
    images.

    Args:
        dir: Root directory to walk.

    Returns:
        A list of file paths (each ``os.path.join(root, fname)``), ordered by
        directory path and then by file name. A missing directory yields an
        empty list because ``os.walk`` silently yields nothing for it.
    """
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        # os.walk reports file names in arbitrary filesystem order; sort them
        # so that a given index always refers to the same sample.
        images.extend(os.path.join(root, fname) for fname in sorted(fnames))
    return images

def LabelToLongTensor(pic):
    """Convert a single-channel, 8-bit label image into a 2-D ``LongTensor``.

    Args:
        pic: Image-like object exposing ``size`` as ``(width, height)`` and
            ``tobytes()`` returning exactly one byte per pixel in row-major
            order (presumably a PIL image in mode ``L`` or ``P`` -- confirm
            against the annotation files in use).

    Returns:
        A contiguous ``torch.int64`` tensor of shape ``(height, width)``
        holding the raw pixel values. The shape is preserved even when the
        height or width is 1.

    Raises:
        ValueError: If the byte count does not equal ``width * height``.
    """
    width, height = pic.size
    flat = np.frombuffer(pic.tobytes(), dtype=np.uint8)
    # frombuffer gives a read-only view of the bytes; astype produces the
    # writable int64 copy that torch.from_numpy can safely wrap.
    return torch.from_numpy(flat.reshape(height, width).astype(np.int64))

class CamVid(data.Dataset):
    """CamVid semantic-segmentation dataset read from a local directory.

    Images are listed from ``<root>/<split>``; the label image for each one is
    opened from the same relative path under ``<root>/<split>annot``.

    Args:
        root: Directory containing the split folders.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        joint_transform: Optional callable applied to the list
            ``[img, target]``; it must return the transformed pair.
        transform: Optional callable applied to the image only, after
            ``joint_transform``.
        download: If true, ``download()`` is called, which currently raises
            ``NotImplementedError``.
        loader: Callable that takes a file path and returns the loaded image.
    """

    # Class names indexed by label id. The last entry ('Void') has weight 0
    # and is not counted in ``num_classes``.
    # NOTE(review): 'Pedestrain' looks like a typo for 'Pedestrian'; left
    # unchanged because callers may look the name up.
    classes = ['Sky', 'Building', 'Column-Pole', 'Road',
           'Sidewalk', 'Tree', 'Sign-Symbol', 'Fence', 'Car', 'Pedestrain',
           'Bicyclist', 'Void']
    # weights when using median frequency balancing used in SegNet paper
    # https://arxiv.org/pdf/1511.00561.pdf
    # The numbers were generated by https://github.com/yandex/segnet-torch/blob/master/datasets/camvid-gen.lua
    class_weights = [0.58872014284134, 0.51052379608154, 2.6966278553009, 0.45021694898605, 1.1785038709641,
                0.77028578519821, 2.4782588481903, 2.5273461341858, 1.0122526884079, 3.2375309467316,
                4.1312313079834, 0]
    # RGB colour used by LabelToPILImage for each label id, in ``classes`` order.
    class_color = [(128, 128, 128),(128, 0, 0),(192, 192, 128),(128, 64, 128),(0, 0, 192),(128, 128, 0),
                    (192, 128, 128),(64, 64, 128),(64, 0, 128),(64, 64, 0),(0, 128, 192),(0, 0, 0),]
    num_classes=11
    # Presumably per-channel mean / std of the images for input
    # normalisation -- TODO confirm channel order and value scale.
    means = [0.41189489566336, 0.4251328133025, 0.4326707089857]
    stds = [0.27413549931506, 0.28506257482912, 0.28284674400252]

    def __init__(self, root, split='train', joint_transform=None,
                 transform=None, download=False,
                 loader=default_loader):
        self.root = root
        assert split in ('train', 'val', 'test')
        self.split = split
        self.transform = transform
        self.joint_transform = joint_transform
        # __getitem__ reads images through self.loader, so it must be stored.
        self.loader = loader

        if download:
            self.download()

        self.imgs = make_dataset(os.path.join(self.root, self.split))

    def __getitem__(self, index):
        """Return ``(img, target)`` for the sample at ``index``.

        ``img`` is the loaded image after ``joint_transform`` and
        ``transform``; ``target`` is the annotation image after
        ``joint_transform``, converted by ``LabelToLongTensor``.
        """
        path = self.imgs[index]
        img = self.loader(path)

        # Swap only the split directory for its '<split>annot' sibling. A
        # plain str.replace would also rewrite any other occurrence of the
        # split name in the root path or in the file name.
        image_dir = os.path.join(self.root, self.split)
        annot_dir = os.path.join(self.root, self.split + 'annot')
        target = Image.open(
            os.path.join(annot_dir, os.path.relpath(path, image_dir)))

        if self.joint_transform is not None:
            img, target = self.joint_transform([img, target])

        if self.transform is not None:
            img = self.transform(img)

        target = LabelToLongTensor(target)
        return img, target

    def __len__(self):
        """Return the number of image files found for this split."""
        return len(self.imgs)

    def download(self):
        """Placeholder: automatic download is not implemented.

        Raises:
            NotImplementedError: Always.
        """
        # TODO: please download the dataset from
        # https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid
        raise NotImplementedError

    @classmethod
    def LabelToPILImage(cls, label):
        """Render a 2-D tensor of label ids as an RGB PIL image.

        Args:
            label: Tensor of shape ``(H, W)`` holding label ids.

        Returns:
            A PIL image in which every pixel has the ``class_color`` of its
            label id. Ids with no entry in ``class_color`` stay black.
        """
        colored_label = torch.zeros(3, label.size(0), label.size(1)).byte()
        for i, color in enumerate(cls.class_color):
            # The mask keeps the same (H, W) shape as each channel plane so
            # the in-place fill needs no broadcasting.
            mask = label.eq(i)
            for j in range(3):
                colored_label[j].masked_fill_(mask, color[j])
        # Channels-first (3, H, W) -> channels-last (H, W, 3) for PIL.
        npimg = np.transpose(colored_label.numpy(), (1, 2, 0))
        return Image.fromarray(npimg)