mfinzi/pristine-ml

View on GitHub
oil/architectures/img_classifiers/transformer.py

Summary

Maintainability
A
1 hr
Test Coverage
import torch
from torch.autograd import Variable
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.init as init

import numpy as np
from torch.nn.utils import weight_norm
import math
from ..parts import conv2d
from ...utils.utils import Expression,export,Named


def RestrictedAttention(Q,K,V,P=0):
    """Attention of every point over its own (pre-gathered) neighbourhood.

    Computes O = softmax(Q K^T / sqrt(d) + P) V independently for each of the n points.

    Args:
        Q: queries, shape (bs, n, d)
        K: neighbourhood keys, shape (bs, n, r, d), r = number of neighbours
        V: neighbourhood values, shape (bs, n, r, d)
        P: additive bias on the attention logits, broadcastable to (bs, n, r)
    Returns:
        Tensor of shape (bs, n, d).
    """
    head_dim = Q.shape[-1]
    scale = np.sqrt(head_dim)
    # (bs,n,1,d) @ (bs,n,d,r) -> (bs,n,1,r) -> (bs,n,r)
    logits = torch.matmul(Q.unsqueeze(2), K.transpose(2, 3)).squeeze(2) / scale
    attn = torch.softmax(logits + P, dim=-1)
    # (bs,n,1,r) @ (bs,n,r,d) -> (bs,n,1,d) -> (bs,n,d)
    return torch.matmul(attn.unsqueeze(2), V).squeeze(2)

def fold_heads_into_batchdim(x,num_heads):
    """Split the channel axis into heads and stack the heads along the batch axis.

    (bs,*,d) -> (num_heads*bs,*,d//num_heads). Channels are assigned to heads in an
    interleaved fashion: output row ``h*bs + b`` holds channels ``h, h+num_heads, ...``
    of batch element ``b``.
    """
    batch, *middle, channels = x.shape
    head_dim = channels // num_heads
    split = x.view(batch, *middle, head_dim, num_heads)  # (bs,*,d//nh,nh)
    last = split.dim() - 1
    heads_first = split.permute(last, *range(last))  # (nh,bs,*,d//nh)
    return heads_first.reshape(batch * num_heads, *middle, head_dim)

def fold_heads_outof_batchdim(x,num_heads):
    """Inverse of ``fold_heads_into_batchdim``: (num_heads*bs,*,d//num_heads) -> (bs,*,d).

    Row ``h*bs + b`` of the input becomes batch element ``b``; its feature ``j``
    lands at output channel ``j*num_heads + h``.
    """
    stacked, *middle, head_dim = x.shape
    batch = stacked // num_heads
    per_head = x.view(num_heads, batch, *middle, head_dim)  # (nh,bs,*,d//nh)
    heads_last = per_head.permute(*range(1, per_head.dim()), 0)  # (bs,*,d//nh,nh)
    return heads_last.reshape(batch, *middle, head_dim * num_heads)

class FFNPositionalNetwork(nn.Module):
    """Per-head attention-logit biases computed from relative pixel positions.

    A small MLP maps every relative offset (p' - p) between a pixel p and one of its
    neighbours p' to ``num_heads`` scores.

    Args:
        ch: hidden width of the MLP.
        nbhd_extractor: callable mapping (bs, n, d) -> (bs, n, r, d).
        num_heads: number of scores produced per (pixel, neighbour) pair.
    """
    def __init__(self,ch,nbhd_extractor,num_heads):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2,ch),nn.ReLU(),nn.Linear(ch,num_heads))
        self.nbhd_extractor = nbhd_extractor # (bs,n,d) -> (bs,n,r,d)
    def forward(self,x):
        """Return scores of shape (bs, h*w, r, num_heads) for an image x of shape (bs, c, h, w).

        Only the shape and device of x are used, not its values.
        """
        bs,c,h,w = x.shape
        # Pixel coordinates on a [-3,3]^2 grid in row-major order: (1,h*w,2).
        # Created on x's device (instead of a hard-coded .cuda()) so the module
        # also runs on CPU and on non-default GPUs.
        rows = torch.linspace(-3,3,h,device=x.device)
        cols = torch.linspace(-3,3,w,device=x.device)
        coords = torch.stack(torch.meshgrid([rows,cols]),dim=-1).view(h*w,2).unsqueeze(0)
        relative_positional_enc = self.nbhd_extractor(coords) - coords.unsqueeze(2) #(p'-p), (1,h*w,r,2)
        positional_scores = self.net(relative_positional_enc) # (1,h*w,r,2) -> (1,h*w,r,nh)
        return positional_scores.repeat(bs,1,1,1) #(bs,h*w,r,nh)

class RestrictedSelfAttention(nn.Module):
    """Multi-head self-attention in which each point attends only to its neighbourhood.

    Args:
        ch_in: feature dimension d of the points (input and output); the head
            folding splits it into ``num_heads`` pieces of size d//num_heads.
        nbhd_extractor: callable mapping (bs, n, d) -> (bs, n, r, d), gathering the
            r neighbours of every point.
        num_heads: number of attention heads.
    """
    def __init__(self,ch_in,nbhd_extractor,num_heads=8):
        super().__init__()
        self.WQ = nn.Linear(ch_in,ch_in)
        self.WK = nn.Linear(ch_in,ch_in)
        self.WV = nn.Linear(ch_in,ch_in)
        self.WO = nn.Linear(ch_in,ch_in)            #TODO: unfix 2
        self.nbhd_extractor = nbhd_extractor
        self.num_heads = num_heads

    def forward(self,X,P=0):
        """Attend over neighbourhoods.

        Args:
            X: point features, shape (bs, n, d).
            P: additive attention-logit bias of shape (bs, n, r, num_heads), or a
               scalar (default 0) added uniformly to every logit.
        Returns:
            Tensor of shape (bs, n, d).
        """
        nh = self.num_heads
        Queries = fold_heads_into_batchdim(self.WQ(X),nh) #(nh*bs,n,d//nh)
        nbhd_Keys = fold_heads_into_batchdim(self.nbhd_extractor(self.WK(X)),nh) #(nh*bs,n,r,d//nh)
        nbhd_Vals = fold_heads_into_batchdim(self.nbhd_extractor(self.WV(X)),nh) #(nh*bs,n,r,d//nh)
        if torch.is_tensor(P):
            # one bias channel per head: (bs,n,r,nh) -> (nh*bs,n,r,1) -> (nh*bs,n,r)
            nbhd_Pos = fold_heads_into_batchdim(P,nh).squeeze(-1)
        else:
            # A scalar has no shape to fold; it broadcasts over the logits as-is.
            nbhd_Pos = P
        folded_attended_vals = RestrictedAttention(Queries,nbhd_Keys,nbhd_Vals,nbhd_Pos)
        attended_vals = fold_heads_outof_batchdim(folded_attended_vals,nh)  #(nh*bs,n,d//nh) ->(bs,n,d)
        return self.WO(attended_vals)

# Plan:
# 1) implement standalone self-attention conv replacement layer, verify that it works on cifar10
#       - (investigate replacing bottleneck block with a transformer block (w/ FFN)) #2 relus vs 1
#       - replace positional encoding style with that used in https://arxiv.org/pdf/1904.11491.pdf
#       - There are two options: multi-head self attention or just convolution
# 2) square 7 x 7 block -> nearest 50 neighbors precomputed on the images (could use kd-tree because 2d)
#       - optimize implementation efficiency
#       - tune number of neighbors
# 3) investigate non grid pooling mechanisms. Candidates:
#       a) random subsampling
#       b) randomly placed w/positional (or other) attention
#       c) bottom up superpixel segmentation/aggregation
#       d) neural net parametrizes a density p(x,y) proportional to e^{-f(x,y)}
#               - use HMC or NUTS to sample and use attention to aggregate
#               - restrict to subset of original points, subsample or use attention to aggregate
# 4) investigate predicting subsample factor sigmoid (0,1) and penalizing by the computation time
#       - So that this factor has pos signal, linearly interpolate perf at more/fewer points
#       - subsample factors to be shared across elements in the minibatch to ease batching
#       - design heuristic so that no scheduling of the cost factor is necessary, eg: 
# 5) Alternatively, control downsampling factor to maximize the derivative of train/val loss wrt time = dl/di di/dt
#       - can use 



def extract_image_patches(x, kernel, stride=1, dilation=1):
    """Extract every ``kernel`` x ``kernel`` patch of an image batch.

    Uses TensorFlow-style 'SAME' zero padding, so with ``stride=1`` there is one patch
    per pixel (centred on it for odd ``kernel``).

    Args:
        x: image tensor of shape (bs, c, h, w).
        kernel: patch side length k.
        stride: step between patch origins (same along both spatial axes).
        dilation: spacing between the sampled pixels inside a patch.
    Returns:
        Contiguous tensor of shape (bs, ceil(h/stride), ceil(w/stride), c, k, k).
    """
    _, _, h, w = x.shape
    h2 = math.ceil(h / stride)
    w2 = math.ceil(w / stride)
    # spatial extent of one (dilated) patch
    span = (kernel - 1) * dilation + 1
    pad_row = max((h2 - 1) * stride + span - h, 0)
    pad_col = max((w2 - 1) * stride + span - w, 0)
    # F.pad lists the LAST dim first: (left, right, top, bottom)
    x = F.pad(x, (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2))

    # Unfold rows, then columns, keeping every ``dilation``-th pixel of each window
    patches = x.unfold(2, span, stride)[..., ::dilation]          # (bs,c,h2,W,k)
    patches = patches.unfold(3, span, stride)[..., ::dilation]    # (bs,c,h2,w2,k,k)
    return patches.permute(0, 2, 3, 1, 4, 5).contiguous()         # (bs,h2,w2,c,k,k)


def square_nbhd_extractor(diameter):
    """Return a function that gathers each point's square spatial neighbourhood.

    The returned function maps x of shape (bs, n, d), whose n points are the pixels of
    a square h x h image in row-major order, to a tensor of shape (bs, n, r, d) with
    r = diameter**2. Entry [:, i*h + j, a*diameter + b] holds the features of pixel
    (i + a - o, j + b - o), o = (diameter - 1)//2, or zeros where that pixel lies
    outside the image.
    """
    r = diameter ** 2

    def diam_square_nbhd_extractor(x):
        # (bs,n,d) -> (bs,n,r,d)
        bs, n, d = x.shape
        h = w = int(round(np.sqrt(n)))  # the image is assumed to be square
        if h * w != n:
            raise ValueError(f"square_nbhd_extractor expects a square number of points, got n={n}")
        image = x.permute(0, 2, 1).reshape(bs, d, h, w)        # (bs,d,h,w)
        patches = extract_image_patches(image, diameter)       # (bs,h,w,d,k,k)
        return patches.view(bs, n, d, r).permute(0, 1, 3, 2)   # (bs,n,r,d)
    return diam_square_nbhd_extractor



class AttConvReplacement(nn.Module):
    """Convolution-like layer built from multi-head self-attention restricted to a
    square ``ksize`` x ``ksize`` neighbourhood, with learned positional biases.

    Maps (bs, channels, h, w) -> (bs, channels, h, w).
    """
    def __init__(self, channels, ksize, num_heads=8):
        super().__init__()
        self.nbhd_extractor = square_nbhd_extractor(ksize)
        # the positional MLP's hidden width is 8x the channel count
        hidden = channels * 8
        self.position_network = FFNPositionalNetwork(hidden, self.nbhd_extractor, num_heads)
        self.mha = RestrictedSelfAttention(channels, self.nbhd_extractor, num_heads)

    def forward(self, X):
        positional_bias = self.position_network(X)             # (bs,h*w,r,num_heads)
        bs, c, h, w = X.shape
        points = X.permute(0, 2, 3, 1).reshape(bs, h * w, c)   # pixels as a (bs,n,c) point set
        attended = self.mha(points, positional_bias)           # (bs,n,c)
        return attended.permute(0, 2, 1).reshape(X.shape)      # back to (bs,c,h,w)

class PositionOnlyAtt(nn.Module):
    """Per-channel aggregation over a square neighbourhood, with weights that depend
    only on relative position (no content-based query/key scores).

    Args:
        ch: number of channels; the positional MLP emits one weight per channel, so
            inputs must have exactly ``ch`` channels.
        ksize: side length of the square neighbourhood.
    """
    def __init__(self, ch, ksize):
        super().__init__()
        self.nbhd_extractor = square_nbhd_extractor(ksize)
        self.position_network = nn.Sequential(nn.Linear(2,ch),nn.ReLU(),nn.Linear(ch,ch))
    def forward(self,x):
        """Expects x of shape (bs,c,h,w) with c == ch; returns a tensor of the same shape."""
        bs,c,h,w = x.shape
        # Pixel coordinates on a [-3,3]^2 grid in row-major order: (1,h*w,2).
        # Created on x's device (instead of a hard-coded .cuda()) so CPU runs work too.
        rows = torch.linspace(-3,3,h,device=x.device)
        cols = torch.linspace(-3,3,w,device=x.device)
        coords = torch.stack(torch.meshgrid([rows,cols]),dim=-1).view(h*w,2).unsqueeze(0)
        relative_positional_enc = self.nbhd_extractor(coords) - coords.unsqueeze(2) #(p'-p), (1,h*w,r,2)
        P = self.position_network(relative_positional_enc) # (1,h*w,r,2) -> (1,h*w,r,c)
        # softmax over the r neighbours, then fold channels into batch: (bs,n,r,c) -> (c*bs,n,r)
        weighting = fold_heads_into_batchdim(torch.softmax(P,dim=2).repeat(bs,1,1,1),c).squeeze(-1)
        X_as_points = x.permute(0,2,3,1).reshape(bs,h*w,c) #(bs,c,h,w) -> (bs,h*w,c)
        # (bs,n,c) -> (bs,n,r,c) -> (c*bs,n,r); each channel is its own "head"
        V = fold_heads_into_batchdim(self.nbhd_extractor(X_as_points),c).squeeze(-1)
        # weighted sum over neighbours: (c*bs,n,r) -> (c*bs,n,1) -> (bs,n,c) -> (bs,c,h,w)
        pooled = fold_heads_outof_batchdim((weighting*V).sum(-1).unsqueeze(-1),c)
        return pooled.permute(0,2,1).reshape(x.shape)

class AttentionConv(nn.Module):
    """Local self-attention layer used in place of a ``kernel_size`` x ``kernel_size`` convolution.

    Each output pixel attends over the ``kernel_size`` x ``kernel_size`` window of the
    zero-padded key/value maps that starts at its own position (centred on it when
    ``2*padding == kernel_size - 1``). Logits are per-channel products q*k, so every
    channel gets its own softmax over the window.

    The output has shape (batch, out_channels, height, width), where height/width are
    the INPUT's spatial size (see the ``view`` calls in ``forward``).

    NOTE(review): because ``forward`` reshapes with the input height/width, only
    ``stride=1`` together with ``2*padding == kernel_size - 1`` gives consistent shapes;
    other settings fail inside ``view``. ``out_channels`` must also be even, since the
    value map is split into exactly two halves for the row/column embeddings.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=False):
        super(AttentionConv, self).__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups

        assert self.out_channels % self.groups == 0, "out_channels should be divided by groups. (example: out_channels: 40, groups: 4)"

        # Learned relative-position embeddings. rel_h varies along the window's row axis
        # and is used for the first out_channels//2 channels; rel_w varies along the
        # window's column axis and is used for the remaining channels.
        self.rel_h = nn.Parameter(torch.randn(out_channels // 2, 1, 1, kernel_size, 1), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn(out_channels // 2, 1, 1, 1, kernel_size), requires_grad=True)

        # 1x1 convolutions producing keys, queries and values
        self.key_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
        self.query_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
        self.value_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

        self.reset_parameters()

    def forward(self, x):
        """Expects x of shape (batch, in_channels, height, width); returns (batch, out_channels, height, width)."""
        batch, channels, height, width = x.size()

        # Queries come from the unpadded input; keys/values from the zero-padded input
        # so that border pixels also get a full window.
        padded_x = F.pad(x, [self.padding, self.padding, self.padding, self.padding])
        q_out = self.query_conv(x)
        k_out = self.key_conv(padded_x)
        v_out = self.value_conv(padded_x)

        # Sliding windows: (batch, out_channels, H_out, W_out, kernel_size, kernel_size)
        k_out = k_out.unfold(2, self.kernel_size, self.stride).unfold(3, self.kernel_size, self.stride)
        v_out = v_out.unfold(2, self.kernel_size, self.stride).unfold(3, self.kernel_size, self.stride)

        # NOTE(review): the relative embeddings are added to the VALUES, so they do not
        # affect the attention weights (those come from q*k only). The usual stand-alone
        # self-attention formulation puts positional terms in the query-key logits —
        # confirm this is intended.
        v_out_h, v_out_w = v_out.split(self.out_channels // 2, dim=1)
        v_out = torch.cat((v_out_h + self.rel_h, v_out_w + self.rel_w), dim=1)

        # Flatten each window into a last axis of size kernel_size**2
        k_out = k_out.contiguous().view(batch, self.groups, self.out_channels // self.groups, height, width, -1)
        v_out = v_out.contiguous().view(batch, self.groups, self.out_channels // self.groups, height, width, -1)

        q_out = q_out.view(batch, self.groups, self.out_channels // self.groups, height, width, 1)

        # Per-channel logits, softmax over the window positions (last axis).
        # NOTE(review): nothing below is summed over channels, so the `groups` split
        # is only a reshape and does not change the result.
        out = q_out * k_out
        out = F.softmax(out, dim=-1)
        # Weighted sum of the values over the window, then merge groups back into channels
        out = torch.einsum('bnchwk,bnchwk -> bnchw', out, v_out).view(batch, -1, height, width)

        return out

    def reset_parameters(self):
        """Kaiming-normal init for the q/k/v 1x1 conv weights and N(0, 1) init for the
        relative embeddings; conv biases (if enabled) keep their default init."""
        init.kaiming_normal_(self.key_conv.weight, mode='fan_out', nonlinearity='relu')
        init.kaiming_normal_(self.value_conv.weight, mode='fan_out', nonlinearity='relu')
        init.kaiming_normal_(self.query_conv.weight, mode='fan_out', nonlinearity='relu')

        init.normal_(self.rel_h, 0, 1)
        init.normal_(self.rel_w, 0, 1)

@export
class AttResBlock(nn.Module):
    """Pre-activation residual block:
    x + Dropout(AttentionConv(ReLU(Norm(Conv1x1(ReLU(Norm(x))))))).

    Norm is BatchNorm2d, or GroupNorm with c//16 groups when ``gn`` is set.
    ``stride`` and ``num_heads`` are accepted but not used by this block.
    """
    def __init__(self,k=64,ksize=7,drop_rate=0,stride=1,gn=False,num_heads=8):
        super().__init__()
        if gn:
            def norm_layer(c):
                return nn.GroupNorm(c//16,c)
        else:
            norm_layer = nn.BatchNorm2d
        layers = [
            norm_layer(k),
            nn.ReLU(),
            conv2d(k,k,1),
            norm_layer(k),
            nn.ReLU(),
            # "same"-padded local attention in place of a ksize x ksize convolution
            AttentionConv(k,k,kernel_size=ksize,padding=ksize//2,groups=8),
            nn.Dropout(p=drop_rate),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self,x):
        residual = self.net(x)
        return x + residual

@export
class layer13a(nn.Module,metaclass=Named):
    """
    Very small CNN built from attention residual blocks: a 1x1 stem on 3 input
    channels, three stages of three AttResBlocks (widths k, 2k, 2k) separated by
    2x average pooling, then global average pooling and a linear classifier.
    """
    def __init__(self, num_targets=10,k=64,ksize=7,num_heads=8):
        super().__init__()
        self.num_targets = num_targets

        def stage(width):
            return [AttResBlock(width,ksize,num_heads=num_heads) for _ in range(3)]

        self.net = nn.Sequential(
            conv2d(3,k,1),
            *stage(k),
            nn.AvgPool2d(2),
            conv2d(k,2*k,1),
            *stage(2*k),
            nn.AvgPool2d(2),
            *stage(2*k),
            Expression(lambda u:u.mean(-1).mean(-1)),  # global average pool: (bs,2k,h,w) -> (bs,2k)
            nn.Linear(2*k,num_targets),
        )
    def forward(self,x):
        return self.net(x)

# Version 2, more transformer style oriented

def FFN(k):
    """Position-wise feed-forward network for feature maps of shape (bs, k, h, w).

    Two 1x1 convolutions around a ReLU with a 4x wider hidden layer: k -> 4k -> k channels.
    """
    hidden = 4 * k
    expand = nn.Conv2d(k, hidden, 1)
    project = nn.Conv2d(hidden, k, 1)
    return nn.Sequential(expand, nn.ReLU(), project)

class AddAndNorm(nn.Module):
    """Residual wrapper with pre-normalisation: X + Dropout(block(Norm(X))).

    The norm is applied before the block rather than after the residual sum; see
    https://openreview.net/pdf?id=B1x8anVFPr for this placement.
    """
    def __init__(self,block,ch_in,dropout=0):
        super().__init__()
        self.block = block
        # despite the attribute name, this is a BatchNorm2d over ch_in channels
        self.layerNorm = nn.BatchNorm2d(ch_in)
        self.dropout = nn.Dropout(dropout)
    def forward(self,X):
        """Expects X of shape (bs, ch_in, h, w), as required by BatchNorm2d."""
        normed = self.layerNorm(X)
        update = self.dropout(self.block(normed))
        return X + update


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer for feature maps: restricted multi-head
    self-attention (AttConvReplacement), then a 1x1-conv feed-forward network,
    each wrapped in an AddAndNorm residual.
    """
    def __init__(self,hidden_dim,ksize=5,num_heads=8,dropout=0):
        super().__init__()
        attention = AttConvReplacement(hidden_dim, ksize, num_heads)
        feedforward = FFN(hidden_dim)
        self.net = nn.Sequential(
            AddAndNorm(attention,hidden_dim,dropout),
            AddAndNorm(feedforward,hidden_dim,dropout),
        )
    def forward(self,X):
        """Expects X of shape (bs,hidden_dim,h,w); returns the same shape."""
        return self.net(X)

@export
class layer13at(nn.Module,metaclass=Named):
    """
    Very small CNN built from transformer blocks: a 1x1 stem on 3 input channels,
    three stages of two TransformerBlocks (widths k, 2k, 2k) separated by 2x
    average pooling, then global average pooling and a linear classifier.
    """
    def __init__(self, num_targets=10,k=64,ksize=5,num_heads=8):
        super().__init__()
        self.num_targets = num_targets

        def stage(width):
            return [TransformerBlock(width,ksize,num_heads) for _ in range(2)]

        self.net = nn.Sequential(
            conv2d(3,k,1),
            *stage(k),
            nn.AvgPool2d(2),
            conv2d(k,2*k,1),
            *stage(2*k),
            nn.AvgPool2d(2),
            *stage(2*k),
            Expression(lambda u:u.mean(-1).mean(-1)),  # global average pool: (bs,2k,h,w) -> (bs,2k)
            nn.Linear(2*k,num_targets),
        )
    def forward(self,x):
        return self.net(x)