PiePline/piepline

View on GitHub
examples/notebooks/img_segmentation.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Human portrait segmentation example on images.\n",
    "\n",
    "The image dataset was taken from the [PicsArt](https://picsart.com/) AI Hackathon.\n",
    "\n",
    "The dataset can be downloaded [here](https://s3.eu-central-1.amazonaws.com/datasouls/public/picsart_hack_online_data.zip) and should be unpacked into `data/dataset`.\n",
    "\n",
    "This example requires the following dependencies:\n",
    "``pip install scikit-learn albumentations opencv-python``"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "import cv2\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from albumentations import Compose, HorizontalFlip, VerticalFlip, RandomRotate90, RandomGamma, RandomBrightnessContrast, RGBShift, \\\n",
    "    Resize, RandomCrop, OneOf\n",
    "\n",
    "from neural_pipeline import Trainer\n",
    "from neural_pipeline.builtin.models.albunet import resnet18\n",
    "from neural_pipeline.data_producer import AbstractDataset, DataProducer\n",
    "from neural_pipeline.train_config import AbstractMetric, MetricsProcessor, MetricsGroup, TrainStage, ValidationStage, TrainConfig\n",
    "from neural_pipeline.utils import FileStructManager\n",
    "from neural_pipeline.builtin.monitors.tensorboard import TensorboardMonitor\n",
    "from neural_pipeline.monitoring import LogMonitor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset location: the PicsArt archive is expected to be unpacked under data/dataset.\n",
    "datasets_dir = 'data/dataset'\n",
    "base_dir = os.path.join(datasets_dir, 'picsart_hack_online_data')\n",
    "\n",
    "# Bring every sample to 224x224: one of these two transforms is picked for each call.\n",
    "preprocess = OneOf([\n",
    "    RandomCrop(height=224, width=224),\n",
    "    Resize(width=224, height=224),\n",
    "], p=1)\n",
    "\n",
    "# Geometric flips/rotations; the whole group is applied with probability 0.5.\n",
    "transforms = Compose([HorizontalFlip(), VerticalFlip(), RandomRotate90()], p=0.5)\n",
    "\n",
    "# Photometric augmentations followed by the geometric group above.\n",
    "aug = Compose([\n",
    "    RandomBrightnessContrast(brightness_limit=0.4, contrast_limit=0.4),\n",
    "    RandomGamma(),\n",
    "    RGBShift(),\n",
    "    transforms,\n",
    "])\n",
    "\n",
    "\n",
    "def augmentate(item: dict) -> dict:\n",
    "    \"\"\"Resize/crop and then randomly augment an image together with its mask.\n",
    "\n",
    "    :param item: dict with 'data' (image array) and 'target' (mask array)\n",
    "    :return: dict with the same keys holding the transformed image and mask\n",
    "    \"\"\"\n",
    "    # The mask is passed alongside the image so spatial transforms are applied to both.\n",
    "    res = preprocess(image=item['data'], mask=item['target'])\n",
    "    res = aug(image=res['image'], mask=res['mask'])\n",
    "    return {'data': res['image'], 'target': res['mask']}\n",
    "\n",
    "\n",
    "def augmentate_and_to_pytorch(item: dict) -> dict:\n",
    "    \"\"\"Augment an image/mask pair and convert both to float32 torch tensors.\n",
    "\n",
    "    The image's last (channel) axis is moved to the front (HWC -> CHW) and the mask gets a\n",
    "    leading channel axis. Both are divided by 255 (assumes 8-bit inputs, e.g. 0/255 masks).\n",
    "\n",
    "    :param item: dict with 'data' (image array) and 'target' (mask array)\n",
    "    :return: dict with 'data' and 'target' torch tensors\n",
    "    \"\"\"\n",
    "    res = augmentate(item)\n",
    "    image = res['data'].astype(np.float32) / 255.\n",
    "    mask = res['target'].astype(np.float32) / 255.\n",
    "    return {'data': torch.from_numpy(np.moveaxis(image, -1, 0)),\n",
    "            'target': torch.from_numpy(np.expand_dims(mask, axis=0))}\n",
    "\n",
    "\n",
    "class PicsartDataset(AbstractDataset):\n",
    "    \"\"\"PicsArt portrait dataset: images from base_dir/train paired with masks from base_dir/train_mask.\n",
    "\n",
    "    Only images that have a '.png' mask with the same base name are kept.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, images_pathes: list, aug: callable):\n",
    "        \"\"\"\n",
    "        :param images_pathes: image file names (not full paths) inside the 'train' directory;\n",
    "            base names must be integers (e.g. '123.jpg') because they are sorted numerically\n",
    "        :param aug: callable applied to each {'data': image, 'target': mask} dict in __getitem__\n",
    "        \"\"\"\n",
    "        images_dir = os.path.join(base_dir, 'train')\n",
    "        masks_dir = os.path.join(base_dir, 'train_mask')\n",
    "        images_pathes = sorted(images_pathes, key=lambda p: int(os.path.splitext(p)[0]))\n",
    "        self.__image_pathes = []\n",
    "        self.__aug = aug\n",
    "        for p in images_pathes:\n",
    "            name = os.path.splitext(p)[0]\n",
    "            mask_img = os.path.join(masks_dir, name + '.png')\n",
    "            # Images without a corresponding mask are skipped.\n",
    "            if os.path.exists(mask_img):\n",
    "                path = {'data': os.path.join(images_dir, p), 'target': mask_img}\n",
    "                self.__image_pathes.append(path)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of image/mask pairs.\"\"\"\n",
    "        return len(self.__image_pathes)\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        \"\"\"Read the item-th image and mask from disk and return them passed through `aug`.\"\"\"\n",
    "        paths = self.__image_pathes[item]\n",
    "        # cv2.imread returns channels in BGR order.\n",
    "        # NOTE(review): the pretrained encoder presumably expects RGB; consider cv2.cvtColor(img, cv2.COLOR_BGR2RGB).\n",
    "        img = cv2.imread(paths['data'])\n",
    "        mask = cv2.imread(paths['target'], cv2.IMREAD_UNCHANGED)\n",
    "        # cv2.imread does not raise on a missing or corrupt file, it returns None; fail loudly here\n",
    "        # instead of with an obscure error deep inside the augmentations.\n",
    "        if img is None or mask is None:\n",
    "            raise OSError('Failed to read sample {}: {}'.format(item, paths))\n",
    "        return self.__aug({'data': img, 'target': mask})\n",
    "\n",
    "\n",
    "images_dir = os.path.join(base_dir, 'train')\n",
    "# os.listdir returns files in arbitrary order, so sort the list and fix the shuffle seed:\n",
    "# otherwise the train/validation split changes on every run.\n",
    "images_pathes = sorted(f for f in os.listdir(images_dir) if os.path.splitext(f)[1] == \".jpg\")\n",
    "train_pathes, val_pathes = train_test_split(images_pathes, shuffle=True, test_size=0.2, random_state=42)\n",
    "\n",
    "# NOTE(review): validation reuses the random training augmentations; a deterministic\n",
    "# resize-only pipeline would give more stable validation metrics.\n",
    "train_dataset = PicsartDataset(train_pathes, augmentate_and_to_pytorch)\n",
    "val_dataset = PicsartDataset(val_pathes, augmentate_and_to_pytorch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoothing term: keeps the ratios defined (equal to 1) when prediction and target are both empty.\n",
    "eps = 1e-6\n",
    "\n",
    "\n",
    "def dice(preds: torch.Tensor, trues: torch.Tensor) -> np.ndarray:\n",
    "    \"\"\"Per-sample Dice coefficient: (2 * sum(P * T) + eps) / (sum(P) + sum(T) + eps).\n",
    "\n",
    "    Both tensors are flattened to (batch, -1); one score is returned per batch element.\n",
    "    \"\"\"\n",
    "    preds_inner = preds.detach().cpu().numpy()\n",
    "    trues_inner = trues.detach().cpu().numpy()\n",
    "\n",
    "    preds_inner = preds_inner.reshape(preds_inner.shape[0], -1)\n",
    "    trues_inner = trues_inner.reshape(trues_inner.shape[0], -1)\n",
    "\n",
    "    intersection = (preds_inner * trues_inner).sum(1)\n",
    "    scores = (2. * intersection + eps) / (preds_inner.sum(1) + trues_inner.sum(1) + eps)\n",
    "\n",
    "    return scores\n",
    "\n",
    "\n",
    "def jaccard(preds: torch.Tensor, trues: torch.Tensor):\n",
    "    \"\"\"Per-sample Jaccard index (IoU): (sum(P * T) + eps) / (sum(P + T) - sum(P * T) + eps).\n",
    "\n",
    "    Both tensors are flattened to (batch, -1); one score is returned per batch element.\n",
    "    \"\"\"\n",
    "    preds_inner = preds.detach().cpu().numpy()\n",
    "    trues_inner = trues.detach().cpu().numpy()\n",
    "\n",
    "    preds_inner = preds_inner.reshape(preds_inner.shape[0], -1)\n",
    "    trues_inner = trues_inner.reshape(trues_inner.shape[0], -1)\n",
    "\n",
    "    intersection = (preds_inner * trues_inner).sum(1)\n",
    "    # Union = |P| + |T| - |P * T|.\n",
    "    scores = (intersection + eps) / ((preds_inner + trues_inner).sum(1) - intersection + eps)\n",
    "\n",
    "    return scores\n",
    "\n",
    "\n",
    "class DiceMetric(AbstractMetric):\n",
    "    \"\"\"Dice metric, computed on sigmoid probabilities of the model output.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__('dice')\n",
    "\n",
    "    def calc(self, output: torch.Tensor, target: torch.Tensor) -> np.ndarray or float:\n",
    "        # The network is trained with BCEWithLogitsLoss, so its raw output is logits;\n",
    "        # map them to [0, 1] before computing an overlap metric.\n",
    "        return dice(torch.sigmoid(output), target)\n",
    "\n",
    "\n",
    "class JaccardMetric(AbstractMetric):\n",
    "    \"\"\"Jaccard (IoU) metric, computed on sigmoid probabilities of the model output.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__('jaccard')\n",
    "\n",
    "    def calc(self, output: torch.Tensor, target: torch.Tensor) -> np.ndarray or float:\n",
    "        # The network is trained with BCEWithLogitsLoss, so its raw output is logits;\n",
    "        # map them to [0, 1] before computing an overlap metric.\n",
    "        return jaccard(torch.sigmoid(output), target)\n",
    "\n",
    "\n",
    "class SegmentationMetricsProcessor(MetricsProcessor):\n",
    "    \"\"\"Metrics processor holding one group, named after the stage, with Jaccard and Dice metrics.\"\"\"\n",
    "\n",
    "    def __init__(self, stage_name: str):\n",
    "        super().__init__()\n",
    "        # Kept as an attribute so callers can reach the Jaccard metric directly.\n",
    "        self.jaccard_metric = JaccardMetric()\n",
    "        metrics_group = MetricsGroup(stage_name).add(self.jaccard_metric).add(DiceMetric())\n",
    "        self.add_metrics_group(metrics_group)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from neural_pipeline.builtin.monitors.mpl import MPLMonitor\n",
    "\n",
    "train_data_producer = DataProducer([train_dataset], batch_size=2, num_workers=3)\n",
    "val_data_producer = DataProducer([val_dataset], batch_size=2, num_workers=3)\n",
    "\n",
    "# One output class (the portrait mask) from a 3-channel image, starting from pretrained weights.\n",
    "model = resnet18(classes_num=1, in_channels=3, pretrained=True)\n",
    "\n",
    "train_stage = TrainStage(train_data_producer, SegmentationMetricsProcessor('train'))\n",
    "val_metrics_processor = SegmentationMetricsProcessor('validation')\n",
    "val_stage = ValidationStage(val_data_producer, val_metrics_processor)\n",
    "\n",
    "# BCEWithLogitsLoss applies the sigmoid itself, so the model output is treated as logits.\n",
    "train_config = TrainConfig([train_stage, val_stage], torch.nn.BCEWithLogitsLoss(),\n",
    "                           torch.optim.Adam(model.parameters(), lr=1e-4))\n",
    "\n",
    "file_struct_manager = FileStructManager(checkpoint_dir_path=r\"data/checkpoints\", logdir_path=r\"data/logs\")\n",
    "\n",
    "trainer = Trainer(model, train_config, file_struct_manager).set_epoch_num(10)\n",
    "\n",
    "tensorboard = TensorboardMonitor(file_struct_manager, is_continue=False, network_name='PortraitSegmentation')\n",
    "mpl_monitor = MPLMonitor()\n",
    "log = LogMonitor(file_struct_manager, 'logs.json')\n",
    "trainer.monitor_hub.add_monitor(tensorboard).add_monitor(log).add_monitor(mpl_monitor)\n",
    "\n",
    "# LR decay driven by the mean training loss.\n",
    "# NOTE(review): patience=10 equals the epoch count above, so the decay may never trigger.\n",
    "trainer.enable_lr_decaying(coeff=0.5, patience=10, target_val_clbk=lambda: np.mean(train_stage.get_losses()))\n",
    "# Log the current learning rate to TensorBoard at the end of every epoch\n",
    "# ('/' groups scalars in TensorBoard; the previous 'params\\lr' relied on an invalid escape).\n",
    "trainer.add_on_epoch_end_callback(lambda: tensorboard.update_scalar('params/lr', trainer.data_processor().get_lr()))\n",
    "trainer.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}