# oil/architectures/parts/deconv.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import conv
from torch.nn.modules.utils import _pair
from ...utils.utils import Expression,export,Named
#import cv2
#This is a reference implementation using im2col, and is not used anywhere else
class Conv2d(conv._ConvNd):
    """Reference 2-D convolution computed as im2col (``unfold``) + matmul.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Int or pair giving the kernel height/width.
        stride: Int or pair giving the stride.
        padding: Int or pair giving the zero padding added on both sides.
        dilation: Int or pair giving the kernel dilation.
        bias: If True, a learnable bias is created and added to the output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=False):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # Positional flags: transposed=False, output_padding=(0, 0), groups=1.
        # padding_mode is passed explicitly, as the other _ConvNd subclasses
        # in this file do.
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), 1, bias, padding_mode='zeros')
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def forward(self, x):
        """Convolve ``x`` of shape ``(N, C, H, W)`` and return ``(N, out_channels, out_h, out_w)``."""
        N, C, H, W = x.shape
        (kh, kw), (sh, sw) = self.kernel_size, self.stride
        (ph, pw), (dh, dw) = self.padding, self.dilation
        # Number of sliding-window positions per axis; this must match the
        # number of columns produced by unfold, including for stride > 1 and
        # dilation > 1, and must use the per-axis kernel/padding values.
        out_h = (H + 2 * ph - dh * (kh - 1) - 1) // sh + 1
        out_w = (W + 2 * pw - dw * (kw - 1) - 1) // sw + 1
        # im2col: (N, C*kh*kw, out_h*out_w)
        inp_unf = F.unfold(x, self.kernel_size, self.dilation, self.padding, self.stride)
        # Matrix multiplication against the flattened kernels, then reshape.
        w = self.weight
        out = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
        out = out.reshape(N, -1, out_h, out_w)
        if self.bias is not None:
            out = out + self.bias.view(1, -1, 1, 1)
        return out
#iteratively solve for inverse sqrt of a matrix
def isqrt_newton_schulz_autograd(A, numIters):
    """Iteratively approximate the inverse square root of a square matrix.

    ``A`` is first divided by ``A.norm()``; the coupled iteration
    ``T = (3I - ZY) / 2, Y <- YT, Z <- TZ`` is then run ``numIters`` times
    and ``Z`` is rescaled by ``1 / sqrt(A.norm())``. Only differentiable
    tensor ops are used, so gradients flow back to ``A``.

    Args:
        A: Square 2-D tensor (only ``A.shape[0]`` is used to size the identity).
        numIters: Number of iterations to run.

    Returns:
        Tensor with the same shape, dtype and device as ``A``.
    """
    size = A.shape[0]
    scale = A.norm()
    identity = torch.eye(size, dtype=A.dtype, device=A.device)
    Y = A / scale
    Z = identity.clone()
    for _ in range(numIters):
        T = 0.5 * (3.0 * identity - Z @ Y)
        Y, Z = Y @ T, T @ Z
    # Y * sqrt(scale) would be the matching square-root estimate; only the
    # inverse square root is returned.
    return Z / scale.sqrt()
#deconvolve channels
class ChannelDeconv(nn.Module):
    """Decorrelate activations across channel groups.

    The first ``c = (C // num_groups) * num_groups`` channels are flattened
    to ``num_groups`` rows, shifted by a per-row mean and multiplied by an
    approximate inverse square root of a ``num_groups x num_groups`` matrix
    built from (optionally subsampled) activations. Any remaining ``C - c``
    channels are standardised with a single scalar mean and variance.

    Running statistics are updated only in training mode and are used as-is
    in eval mode, so evaluating before any training step applies the initial
    buffers (zero mean, identity transform, unit variance).

    Args:
        num_groups: Number of rows the leading channels are folded into.
        eps: Value added to the matrix diagonal and to the variance.
        n_iter: Iterations passed to ``isqrt_newton_schulz_autograd``.
        momentum: Weight of the current batch in the running statistics.
        sampling_stride: When > 1 (and H, W are at least this large), every
            ``sampling_stride ** 2``-th flattened element is used for the
            statistics.
        debug: Stored but not used by this class.
    """

    def __init__(self, num_groups, eps=1e-2, n_iter=5, momentum=0.1, sampling_stride=3, debug=False):
        super(ChannelDeconv, self).__init__()
        self.eps = eps
        self.n_iter = n_iter
        self.momentum = momentum
        self.num_groups = num_groups
        self.debug = debug
        self.register_buffer('running_mean1', torch.zeros(num_groups, 1))
        self.register_buffer('running_deconv', torch.eye(num_groups))
        self.register_buffer('running_mean2', torch.zeros(1, 1))
        self.register_buffer('running_var', torch.ones(1, 1))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        self.sampling_stride = sampling_stride

    def _update_running(self, running, value, first_batch):
        """Blend ``value`` into the buffer ``running`` (seeded from ``value`` on the first batch)."""
        value = value.detach()
        if first_batch:
            running.copy_(value)
        running.mul_(1 - self.momentum)
        running.add_(value * self.momentum)

    def forward(self, x):
        """Transform ``x`` of shape ``(N, C)`` or ``(N, C, H, W)``; the output has the input's shape.

        Raises:
            ValueError: If ``x`` is not 2-D or 4-D, or if ``C < num_groups``.
        """
        x_shape = x.shape
        if x.dim() == 2:
            x = x.view(x.shape[0], x.shape[1], 1, 1)
        if x.dim() != 4:
            raise ValueError(
                'Unsupported tensor shape {}: expected a 2-D or 4-D input.'.format(tuple(x_shape)))
        N, C, H, W = x.size()
        G = self.num_groups

        # Only the first c channels (a multiple of G) are decorrelated.
        c = (C // G) * G
        if c == 0:
            raise ValueError(
                'num_groups ({}) should be set smaller: input has only {} channels.'.format(G, C))

        # Running buffers are seeded from the first *training* batch only;
        # eval mode never touches them.
        first_batch = self.training and bool(self.num_batches_tracked == 0)
        subsample = (self.sampling_stride > 1
                     and H >= self.sampling_stride and W >= self.sampling_stride)
        step = self.sampling_stride ** 2

        # step 1. remove mean
        x1 = x[:, :c].permute(1, 0, 2, 3).contiguous().view(G, -1)
        x1_s = x1[:, ::step] if subsample else x1
        if self.training:
            mean1 = x1_s.mean(-1, keepdim=True)
            self._update_running(self.running_mean1, mean1, first_batch)
        else:
            mean1 = self.running_mean1
        x1 = x1 - mean1

        # step 2. calculate deconv@x1 = cov^(-0.5)@x1
        if self.training:
            # NOTE(review): x1_s is the sample taken before the mean was
            # subtracted, so this is an uncentred second moment. Kept as-is
            # to preserve numerics; confirm whether that is intended.
            cov = x1_s @ x1_s.t() / x1_s.shape[1] + self.eps * torch.eye(G, dtype=x.dtype, device=x.device)
            deconv = isqrt_newton_schulz_autograd(cov, self.n_iter)
            self._update_running(self.running_deconv, deconv, first_batch)
        else:
            deconv = self.running_deconv
        x1 = deconv @ x1

        # reshape to N,c,H,W
        x1 = x1.view(c, N, H, W).contiguous().permute(1, 0, 2, 3)

        # normalize the remaining channels
        if c != C:
            rest = x[:, c:]
            if self.training:
                flat = rest.reshape(N, -1)
                sample = flat[:, ::step] if subsample else flat
                mean2 = sample.mean()
                var = sample.var()
                self._update_running(self.running_mean2, mean2, first_batch)
                self._update_running(self.running_var, var, first_batch)
            else:
                mean2 = self.running_mean2
                var = self.running_var
            rest = (rest - mean2) / (var + self.eps).sqrt()
            x1 = torch.cat([x1, rest], dim=1)

        if self.training:
            self.num_batches_tracked.add_(1)

        if len(x_shape) == 2:
            x1 = x1.reshape(x_shape)
        return x1
@export
class DeConv2d(conv._ConvNd):
    """2-D convolution applied to decorrelated im2col patches.

    Modes:
        1: remove channel correlation (``ChannelDeconv``), then pixel correlation
        2: only remove pixel correlation
        3: only remove channel correlation, then a plain ``F.conv2d``
        4: remove channel correlation and pixel correlation together

    Running mean / transform buffers are updated only in training mode and
    used as-is in eval mode.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation, bias:
            Forwarded to the ``_ConvNd`` base constructor (ints are expanded
            to pairs).
        eps: Value added to the covariance diagonal.
        n_iter: Iterations passed to ``isqrt_newton_schulz_autograd``.
        momentum: Weight of the current batch in the running statistics.
        mode: One of the modes listed above.
        num_groups: Channel group count; clamped to ``self.weight.shape[1]``.
        debug: Accepted for interface compatibility; not used here.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, eps=1e-2, n_iter=5, momentum=0.1, mode=4, num_groups=16, debug=False):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # Positional flags: transposed=False, output_padding=(0, 0), groups=1.
        # padding_mode='zeros' is required from pytorch 1.1 on.
        super(DeConv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), 1, bias, padding_mode='zeros')
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

        self.momentum = momentum
        self.mode = mode
        self.n_iter = n_iter
        self.eps = eps

        num_features = self.weight.shape[2] * self.weight.shape[3]  # k*k
        if self.mode != 2:
            if num_groups > self.weight.shape[1]:
                num_groups = self.weight.shape[1]
            self.num_groups = num_groups
            if self.mode != 4:
                self.channel_deconv = ChannelDeconv(num_groups, eps=eps, n_iter=n_iter, momentum=momentum, debug=False)
            else:
                # NOTE(review): the row layout built in forward() assumes the
                # channel count is divisible by num_groups -- confirm callers.
                num_features *= num_groups
        self.num_features = num_features

        if self.mode != 3:
            self.register_buffer('running_mean', torch.zeros(num_features, 1))
            self.register_buffer('running_deconv', torch.eye(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    def forward(self, x):
        """Apply the (de)correlated convolution to ``x`` of shape ``(N, C, H, W)``."""
        N, C, H, W = x.shape

        if self.mode == 3:
            x = self.channel_deconv(x)
            return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, 1)
        if self.mode == 1:
            x = self.channel_deconv(x)

        (kh, kw), (sh, sw) = self.kernel_size, self.stride
        (ph, pw), (dh, dw) = self.padding, self.dilation
        # Sliding-window positions per axis; must equal the number of unfold
        # columns for any stride/dilation and for non-square kernels/padding.
        out_h = (H + 2 * ph - dh * (kh - 1) - 1) // sh + 1
        out_w = (W + 2 * pw - dw * (kw - 1) - 1) // sw + 1

        # 1. im2col, reshape
        # N * cols * pixels
        inp_unf = F.unfold(x, self.kernel_size, self.dilation, self.padding, self.stride)
        # (k*k, C*N*H*W) for pixel deconv
        # (k*k*G, C//G*N*H*W) for grouped pixel deconv
        X = inp_unf.permute(1, 0, 2).contiguous().view(self.num_features, -1)

        # Running buffers are seeded from the first *training* batch only.
        # In eval mode nothing is written, and no training-only value is read.
        first_batch = self.training and bool(self.num_batches_tracked == 0)

        # 2. subtract mean
        if self.training:
            X_mean = X.mean(-1, keepdim=True)
            if first_batch:
                self.running_mean.copy_(X_mean.detach())
            self.running_mean.mul_(1 - self.momentum)
            self.running_mean.add_(X_mean.detach() * self.momentum)
        else:
            X_mean = self.running_mean
        X = X - X_mean

        # 3. calculate COV, COV^(-0.5), then deconv
        if self.training:
            Cov = X / X.shape[1] @ X.t() + self.eps * torch.eye(X.shape[0], dtype=X.dtype, device=X.device)
            deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
            if first_batch:
                self.running_deconv.copy_(deconv.detach())
            self.running_deconv.mul_(1 - self.momentum)
            self.running_deconv.add_(deconv.detach() * self.momentum)
        else:
            deconv = self.running_deconv

        X_deconv = deconv @ X
        # back to (N, pixels, C*k*k)
        X_deconv = X_deconv.view(-1, N, out_h * out_w).contiguous().permute(1, 2, 0)

        # 4. convolve
        w = self.weight
        out_unf = X_deconv.matmul(w.view(w.size(0), -1).t()).transpose(1, 2).view(N, -1, out_h, out_w)
        if self.bias is not None:
            out_unf = out_unf + self.bias.view(1, -1, 1, 1)

        if self.training:
            self.num_batches_tracked.add_(1)
        return out_unf
#this version is faster but slightly weaker. We approximately remove the mean.
@export
class FastDeconv(conv._ConvNd):
    """Convolution whose kernel is multiplied by an approximate ``Cov^(-1/2)`` each forward.

    Instead of transforming every im2col patch, the transform is estimated
    from patches sampled with a coarser stride and applied to the weights,
    after which a regular ``F.conv2d`` is run on the mean-shifted input.
    The mean is removed per channel group, which is a fast approximation.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation, bias:
            Forwarded to the ``_ConvNd`` base constructor (ints are expanded
            to pairs).
        eps: Value added to the covariance diagonal.
        n_iter: Iterations passed to ``isqrt_newton_schulz_autograd``.
        momentum: Weight of the current batch in the running statistics.
        num_groups: Channel group count; clamped to ``in_channels``.
        sampling_stride: Multiplier applied to ``stride`` when sampling the
            patches used to estimate the covariance.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, eps=1e-2, n_iter=5, momentum=0.1, num_groups=16, sampling_stride=3):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # Positional flags: transposed=False, output_padding=(0, 0), groups=1.
        super(FastDeconv, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), 1, bias, padding_mode='zeros')
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride
        self.momentum = momentum
        self.n_iter = n_iter
        self.eps = eps

        if num_groups > in_channels:
            num_groups = in_channels
        # NOTE(review): forward() folds channels into num_groups-sized rows and
        # appears to assume in_channels % num_groups == 0 -- confirm callers.
        self.num_groups = num_groups
        self.num_features = self.kernel_size[0] * self.kernel_size[1] * num_groups
        self.register_buffer('running_mean', torch.zeros(1, self.num_groups, 1, 1))
        self.register_buffer('running_deconv', torch.eye(self.num_features))
        self.sampling_stride = [sampling_stride * s for s in stride]

    def forward(self, x):
        """Apply the decorrelated convolution to ``x`` of shape ``(N, C, H, W)``."""
        N, C, H, W = x.shape
        G = self.num_groups

        # 1. subtract mean (this is a fast approximation)
        # reshape (not view) so non-contiguous inputs, e.g. channels_last,
        # are accepted instead of raising.
        x = x.reshape(N * C // G, G, H, W)
        if self.training:
            x_mean = x.mean((0, 2, 3), keepdim=True)
            # track stats for evaluation
            self.running_mean.mul_(1 - self.momentum)
            self.running_mean.add_(x_mean.detach() * self.momentum)
        else:
            x_mean = self.running_mean
        x = (x - x_mean).view(N, C, H, W)

        # 2./3. estimate Cov^(-0.5) from sampled patches (training only)
        if self.training:
            # im2col: N x cols x pixels, sampled with the coarser stride
            inp_unf = F.unfold(x, self.kernel_size, self.dilation, self.padding, self.sampling_stride)
            # (k*k*G, C//G*N*H*W) for grouped pixel deconv
            X = inp_unf.transpose(0, 1).contiguous().view(self.num_features, -1)
            Cov = X / X.shape[1] @ X.t() + self.eps * torch.eye(X.shape[0], dtype=X.dtype, device=X.device)
            deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
            # track stats for evaluation
            self.running_deconv.mul_(1 - self.momentum)
            self.running_deconv.add_(deconv.detach() * self.momentum)
        else:
            deconv = self.running_deconv

        # 4. deconv + conv: fold the transform into the kernel
        out_channels = self.weight.shape[0]
        w = self.weight.view(out_channels, -1).t().contiguous().view(self.num_features, -1)
        w = deconv @ w
        w = w.view(-1, out_channels).t().view(self.weight.shape)
        return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, 1)