# oil/architectures/img_classifiers/densenet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from ...utils.utils import Named
"""
DenseNet model definitions (CIFAR variant),
ported from https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/densenet.py
"""
__all__ = ['DenseNet','DenseNetBC12','DenseNetBC40']
from torch.autograd import Variable
class Bottleneck(nn.Module):
    """Dense layer with a 1x1 bottleneck: BN-ReLU-Conv(1x1), BN-ReLU-Conv(3x3).

    The 1x1 convolution maps the input to ``expansion * growthRate``
    channels, the 3x3 convolution reduces that to ``growthRate`` new feature
    maps, and the result is concatenated onto the input along dim 1.

    Args:
        inplanes: number of channels of the incoming tensor.
        expansion: multiplier on ``growthRate`` for the 1x1 conv width.
        growthRate: number of channels this layer appends to its input.
        dropRate: dropout probability applied to the new features
            (skipped entirely when it is not greater than 0).
    """

    def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0):
        super().__init__()
        hidden = expansion * growthRate
        # Registration order is kept (bn1, conv1, bn2, conv2, relu) so that
        # parameter ordering and state-dict layout stay the same.
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, hidden, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        self.conv2 = nn.Conv2d(
            hidden, growthRate, kernel_size=3, padding=1, bias=False
        )
        self.relu = nn.ReLU(inplace=True)
        self.dropRate = dropRate

    def forward(self, x):
        """Return ``x`` with ``growthRate`` freshly computed channels appended."""
        squeezed = self.conv1(self.relu(self.bn1(x)))
        fresh = self.conv2(self.relu(self.bn2(squeezed)))
        if self.dropRate > 0:
            fresh = F.dropout(fresh, p=self.dropRate, training=self.training)
        return torch.cat((x, fresh), 1)
class BasicBlock(nn.Module):
    """Plain dense layer: BN-ReLU-Conv(3x3), output concatenated onto the input.

    Args:
        inplanes: number of channels of the incoming tensor.
        expansion: accepted so the constructor signature matches
            ``Bottleneck``; it has no effect on this layer.
        growthRate: number of channels this layer appends to its input.
        dropRate: dropout probability applied to the new features
            (skipped entirely when it is not greater than 0).
    """

    def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0):
        super(BasicBlock, self).__init__()
        # The previously computed `expansion * growthRate` was never used:
        # the single 3x3 conv maps straight from inplanes to growthRate.
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3,
                               padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.dropRate = dropRate

    def forward(self, x):
        """Return ``x`` with ``growthRate`` freshly computed channels appended."""
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        if self.dropRate > 0:
            out = F.dropout(out, p=self.dropRate, training=self.training)
        # Concatenate along the channel dimension (dim 1).
        return torch.cat((x, out), 1)
class Transition(nn.Module):
    """Layer placed between dense blocks: BN-ReLU-Conv(1x1), then 2x2 avg pool.

    The 1x1 convolution changes the channel count from ``inplanes`` to
    ``outplanes``; the average pool (kernel size 2) halves height and width.

    Args:
        inplanes: number of channels of the incoming tensor.
        outplanes: number of channels produced by the 1x1 convolution.
    """

    def __init__(self, inplanes, outplanes):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Normalise, activate, project to ``outplanes`` channels and downsample."""
        projected = self.conv1(self.relu(self.bn1(x)))
        return F.avg_pool2d(projected, 2)
class DenseNet(nn.Module, metaclass=Named):
    """DenseNet image classifier made of three dense blocks.

    Layout: 3x3 stem conv -> dense1 -> trans1 -> dense2 -> trans2 -> dense3
    -> BatchNorm -> ReLU -> 8x8 average pool -> flatten -> linear classifier.

    Args:
        depth: nominal network depth; must satisfy ``(depth - 4) % 3 == 0``.
            Each dense block gets ``(depth - 4) // 3`` layers when ``block``
            is ``BasicBlock`` and ``(depth - 4) // 6`` layers otherwise.
        block: layer class used inside the dense blocks; it is called as
            ``block(inplanes, growthRate=..., dropRate=...)``.
        drop_rate: dropout rate forwarded to every dense layer.
        num_targets: output size of the final linear layer.
        k: growth rate, i.e. channels added by each dense layer. The stem
            convolution produces ``2 * k`` channels.
        compressionRate: the channel count is floor-divided by this value
            at each transition.

    Note:
        The stem expects 3 input channels. The fixed ``AvgPool2d(8)``
        followed by the flatten only matches the classifier's input size
        when the pooled feature map is 1x1 (e.g. 32x32 inputs, which the
        two transitions reduce to 8x8).
    """

    def __init__(self, depth=22, block=Bottleneck,
                 drop_rate=0, num_targets=10, k=12, compressionRate=2):
        super(DenseNet, self).__init__()

        assert (depth - 4) % 3 == 0, 'depth should be 3n+4'
        # Integer division in both branches: the count is fed to range(),
        # which rejects floats on Python 3.
        if block is BasicBlock:
            layers_per_block = (depth - 4) // 3
        else:
            layers_per_block = (depth - 4) // 6

        self.growthRate = k
        self.dropRate = drop_rate

        # self.inplanes is running state (the current channel count) that
        # the _make_* helpers read and update while the network is built.
        self.inplanes = k * 2
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1,
                               bias=False)
        self.dense1 = self._make_denseblock(block, layers_per_block)
        self.trans1 = self._make_transition(compressionRate)
        self.dense2 = self._make_denseblock(block, layers_per_block)
        self.trans2 = self._make_transition(compressionRate)
        self.dense3 = self._make_denseblock(block, layers_per_block)
        self.bn = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(self.inplanes, num_targets)

        # Weight initialization: conv weights ~ N(0, sqrt(2 / fan_out)),
        # batch-norm scale set to 1 and shift set to 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_denseblock(self, block, blocks):
        """Stack ``blocks`` dense layers, growing ``self.inplanes`` after each."""
        layers = []
        for _ in range(blocks):
            # Currently we fix the expansion ratio as the default value
            layers.append(block(self.inplanes, growthRate=self.growthRate,
                                dropRate=self.dropRate))
            self.inplanes += self.growthRate
        return nn.Sequential(*layers)

    def _make_transition(self, compressionRate):
        """Build a Transition that divides the channel count by ``compressionRate``."""
        inplanes = self.inplanes
        # Floor division already rounds down; int() covers a float rate.
        outplanes = int(inplanes // compressionRate)
        self.inplanes = outplanes
        return Transition(inplanes, outplanes)

    def forward(self, x):
        """Run the network on ``x`` and return the classifier outputs."""
        x = self.conv1(x)

        x = self.trans1(self.dense1(x))
        x = self.trans2(self.dense2(x))
        x = self.dense3(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.avgpool(x)
        # Flatten everything but the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
class DenseNetBC12(DenseNet):
    """DenseNet preset with depth 100 and growth rate k=12.

    Block type and compression rate are left at the DenseNet defaults.
    """

    def __init__(self, num_targets=10, drop_rate=0):
        preset = dict(depth=100, k=12)
        super().__init__(drop_rate=drop_rate, num_targets=num_targets, **preset)
class DenseNetBC40(DenseNet):
    """DenseNet preset with depth 190 and growth rate k=40.

    Block type and compression rate are left at the DenseNet defaults.
    """

    def __init__(self, num_targets=10, drop_rate=0):
        preset = dict(depth=190, k=40)
        super().__init__(drop_rate=drop_rate, num_targets=num_targets, **preset)