examples/notebooks/img_segmentation.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Human portrait segmentation example.\n",
"\n",
"Images dataset was taken from [PicsArt](https://picsart.com/) AI Hackathon.\n",
"\n",
"The dataset may be downloaded [here](https://s3.eu-central-1.amazonaws.com/datasouls/public/picsart_hack_online_data.zip)\n",
"\n",
"For this example you need to install these dependencies:\n",
"``pip install scikit-learn albumentations opencv-python``"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"import cv2\n",
"import os\n",
"import numpy as np\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from albumentations import Compose, HorizontalFlip, VerticalFlip, RandomRotate90, RandomGamma, RandomBrightnessContrast, RGBShift, \\\n",
" Resize, RandomCrop, OneOf\n",
"\n",
"from neural_pipeline import Trainer\n",
"from neural_pipeline.builtin.models.albunet import resnet18\n",
"from neural_pipeline.data_producer import AbstractDataset, DataProducer\n",
"from neural_pipeline.train_config import AbstractMetric, MetricsProcessor, MetricsGroup, TrainStage, ValidationStage, TrainConfig\n",
"from neural_pipeline.utils import FileStructManager\n",
"from neural_pipeline.builtin.monitors.tensorboard import TensorboardMonitor\n",
"from neural_pipeline.monitoring import LogMonitor"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"datasets_dir = 'data/dataset'\n",
"base_dir = os.path.join(datasets_dir, 'picsart_hack_online_data')\n",
"\n",
"preprocess = OneOf([RandomCrop(height=224, width=224), Resize(width=224, height=224)], p=1)\n",
"transforms = Compose([HorizontalFlip(), VerticalFlip(), RandomRotate90()], p=0.5)\n",
"aug = Compose([RandomBrightnessContrast(brightness_limit=0.4, contrast_limit=0.4), RandomGamma(),\n",
" RGBShift(), transforms])\n",
"\n",
"\n",
"def augmentate(item: dict):\n",
" res = preprocess(image=item['data'], mask=item['target'])\n",
" res = aug(image=res['image'], mask=res['mask'])\n",
" return {'data': res['image'], 'target': res['mask']}\n",
"\n",
"\n",
"def augmentate_and_to_pytorch(item: dict):\n",
" res = augmentate(item)\n",
" return {'data': torch.from_numpy(np.moveaxis(res['data'].astype(np.float32) / 255., -1, 0)),\n",
" 'target': torch.from_numpy(np.expand_dims(res['target'].astype(np.float32) / 255, axis=0))}\n",
"\n",
"\n",
"class PicsartDataset(AbstractDataset):\n",
"    def __init__(self, images_pathes: list, aug: callable):\n",
" images_dir = os.path.join(base_dir, 'train')\n",
" masks_dir = os.path.join(base_dir, 'train_mask')\n",
" images_pathes = sorted(images_pathes, key=lambda p: int(os.path.splitext(p)[0]))\n",
" self.__image_pathes = []\n",
" self.__aug = aug\n",
" for p in images_pathes:\n",
" name = os.path.splitext(p)[0]\n",
" mask_img = os.path.join(masks_dir, name + '.png')\n",
" if os.path.exists(mask_img):\n",
" path = {'data': os.path.join(images_dir, p), 'target': mask_img}\n",
" self.__image_pathes.append(path)\n",
"\n",
" def __len__(self):\n",
" return len(self.__image_pathes)\n",
"\n",
" def __getitem__(self, item):\n",
" img = cv2.imread(self.__image_pathes[item]['data'])\n",
" return self.__aug({'data': img,\n",
" 'target': cv2.imread(self.__image_pathes[item]['target'], cv2.IMREAD_UNCHANGED)})\n",
"\n",
"\n",
"images_dir = os.path.join(base_dir, 'train')\n",
"images_pathes = [f for f in os.listdir(images_dir) if os.path.splitext(f)[1] == \".jpg\"]\n",
"train_pathes, val_pathes = train_test_split(images_pathes, shuffle=True, test_size=0.2, random_state=42)\n",
"\n",
"train_dataset = PicsartDataset(train_pathes, augmentate_and_to_pytorch)\n",
"val_dataset = PicsartDataset(val_pathes, augmentate_and_to_pytorch)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"eps = 1e-6\n",
"\n",
"\n",
"def dice(preds: torch.Tensor, trues: torch.Tensor) -> np.ndarray:\n",
" preds_inner = preds.data.cpu().numpy().copy()\n",
" trues_inner = trues.data.cpu().numpy().copy()\n",
"\n",
" preds_inner = np.reshape(preds_inner, (preds_inner.shape[0], preds_inner.size // preds_inner.shape[0]))\n",
" trues_inner = np.reshape(trues_inner, (trues_inner.shape[0], trues_inner.size // trues_inner.shape[0]))\n",
"\n",
" intersection = (preds_inner * trues_inner).sum(1)\n",
" scores = (2. * intersection + eps) / (preds_inner.sum(1) + trues_inner.sum(1) + eps)\n",
"\n",
" return scores\n",
"\n",
"\n",
"def jaccard(preds: torch.Tensor, trues: torch.Tensor):\n",
" preds_inner = preds.cpu().data.numpy().copy()\n",
" trues_inner = trues.cpu().data.numpy().copy()\n",
"\n",
" preds_inner = np.reshape(preds_inner, (preds_inner.shape[0], preds_inner.size // preds_inner.shape[0]))\n",
" trues_inner = np.reshape(trues_inner, (trues_inner.shape[0], trues_inner.size // trues_inner.shape[0]))\n",
" intersection = (preds_inner * trues_inner).sum(1)\n",
" scores = (intersection + eps) / ((preds_inner + trues_inner).sum(1) - intersection + eps)\n",
"\n",
" return scores\n",
"\n",
"\n",
"class DiceMetric(AbstractMetric):\n",
" def __init__(self):\n",
" super().__init__('dice')\n",
"\n",
" def calc(self, output: torch.Tensor, target: torch.Tensor) -> np.ndarray or float:\n",
" return dice(output, target)\n",
"\n",
"\n",
"class JaccardMetric(AbstractMetric):\n",
" def __init__(self):\n",
" super().__init__('jaccard')\n",
"\n",
" def calc(self, output: torch.Tensor, target: torch.Tensor) -> np.ndarray or float:\n",
" return jaccard(output, target)\n",
"\n",
"\n",
"class SegmentationMetricsProcessor(MetricsProcessor):\n",
" def __init__(self, stage_name: str):\n",
" super().__init__()\n",
" self.jaccard_metric = JaccardMetric()\n",
" self.add_metrics_group(MetricsGroup(stage_name)\n",
" .add(self.jaccard_metric)\n",
" .add(DiceMetric()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from neural_pipeline.builtin.monitors.mpl import MPLMonitor\n",
"\n",
"train_data_producer = DataProducer([train_dataset], batch_size=2, num_workers=3)\n",
"val_data_producer = DataProducer([val_dataset], batch_size=2, num_workers=3)\n",
"\n",
"model = resnet18(classes_num=1, in_channels=3, pretrained=True)\n",
"\n",
"train_stage = TrainStage(train_data_producer, SegmentationMetricsProcessor('train'))\n",
"val_metrics_processor = SegmentationMetricsProcessor('validation')\n",
"val_stage = ValidationStage(val_data_producer, val_metrics_processor)\n",
"\n",
"train_config = TrainConfig([train_stage, val_stage], torch.nn.BCEWithLogitsLoss(),\n",
" torch.optim.Adam(model.parameters(), lr=1e-4))\n",
"\n",
"file_struct_manager = FileStructManager(checkpoint_dir_path=r\"data/checkpoints\", logdir_path=r\"data/logs\")\n",
"\n",
"trainer = Trainer(model, train_config, file_struct_manager).set_epoch_num(10)\n",
"\n",
"tensorboard = TensorboardMonitor(file_struct_manager, is_continue=False, network_name='PortraitSegmentation')\n",
"mpl_monitor = MPLMonitor()\n",
"log = LogMonitor(file_struct_manager, 'logs.json')\n",
"trainer.monitor_hub.add_monitor(tensorboard).add_monitor(log).add_monitor(mpl_monitor)\n",
"\n",
"trainer.enable_lr_decaying(coeff=0.5, patience=10, target_val_clbk=lambda: np.mean(train_stage.get_losses()))\n",
"trainer.add_on_epoch_end_callback(lambda: tensorboard.update_scalar('params/lr', trainer.data_processor().get_lr()))\n",
"trainer.train()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}