mfinzi/OMGchess
chess/transformer.py

from mxnet import nd, gluon
from mxnet.gluon import nn
import numpy as np

from Seq2SlateModel import FeatureEmbedding

# Architecture notes:
# Relative positional information enters through the attention scores
# (see MultiHeadAttention) rather than being added to the feature vectors.

# One transformer encoder block:
# Attention(x) = LayerNorm(x + Dropout(MultiHeadAtt(x)))

# FFN(x) = LayerNorm(x + Dropout(Dense(ReLU(Dense(x))))) # applied position-wise


class FFN(gluon.Block):
    def __init__(self,hidden_dim,inner_dim):
        super().__init__()
        with self.name_scope():
            self.linear1 = nn.Dense(inner_dim,flatten=False)
            self.linear2 = nn.Dense(hidden_dim,flatten=False)
            self.relu = nn.Activation('relu')
    def forward(self,X):
        """Expects X shape (bs,n,hidden_dim); applied position-wise."""
        return self.linear2(self.relu(self.linear1(X)))

def Attention(Q,K,V,P=None):
    """Scaled dot-product self attention, softmax(QK^T/sqrt(d))V.
       Assumes Q,K,V have shape (bs,n,d). Optionally includes a relative
       positional encoding term P with shape (n,d,n)."""
    bs,n,d = Q.shape
    Kt = K.transpose((0,2,1)) # (bs,n,d) -> (bs,d,n)
    scores = nd.linalg.gemm2(Q,Kt) # (bs,n,d) x (bs,d,n) -> (bs,n,n)
    # relative position scores: (n,bs,d) x (n,d,n) -> (n,bs,n) -> (bs,n,n)
    pos_scores = nd.linalg.gemm2(Q.transpose((1,0,2)),P).transpose((1,0,2)) if P is not None else 0
    weighting = nd.softmax((scores+pos_scores)/np.sqrt(d),axis=-1)
    weighted_values = nd.linalg.gemm2(weighting,V) # (bs,n,n) x (bs,n,d) -> (bs,n,d)
    return weighted_values
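
# Illustrative shape check (added for illustration, not part of the original file);
# the helper name and the sizes below are hypothetical.
def _attention_shape_check():
    bs,n,d = 2,8,16
    Q,K,V = [nd.random.normal(shape=(bs,n,d)) for _ in range(3)]
    P = nd.random.normal(shape=(n,d,n)) # optional relative position term
    assert Attention(Q,K,V,P).shape == (bs,n,d)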

class SelfAttentionHead(gluon.Block):
    def __init__(self,head_dim):
        super().__init__()
        with self.name_scope():
            self.WQ = nn.Dense(head_dim,flatten=False)
            self.WK = nn.Dense(head_dim,flatten=False)
            self.WV = nn.Dense(head_dim,flatten=False)

    def forward(self,X,P=None):
        """Expects X shape (bs,n,d) and optional relative position term P of shape (n,d,n)"""
        return Attention(self.WQ(X),self.WK(X),self.WV(X),P)

class MultiHeadAttention(gluon.Block):
    k = 2 # clipping distance for relative positions; k=2 is sufficient, see https://www.aclweb.org/anthology/N18-2074
    def __init__(self,hidden_dim,num_heads=4,rel_pos=False):
        """The query, key, and value dimensions are hidden_dim//num_heads."""
        super().__init__()
        with self.name_scope():
            self.heads = nn.Sequential()
            for _ in range(num_heads):
                self.heads.add(SelfAttentionHead(hidden_dim//num_heads))
            self.WO = nn.Dense(hidden_dim,flatten=False) 
            self.positionalEmbedding = nn.Embedding(2*self.k+1,hidden_dim//num_heads) if rel_pos else None
    def forward(self,X):
        """Expects X shape (bs,n,hidden_dim)"""
        bs,n,hd = X.shape
        if self.positionalEmbedding is not None:
            rel_positions = ((nd.arange(n).expand_dims(-1)-nd.arange(n)).clip(-self.k,self.k)+self.k).astype(int)
            P = self.positionalEmbedding(rel_positions).transpose((0,2,1)) # (n,n,d) -> (n,d,n)
        else: P = None
        return self.WO(nd.Concat(*[head(X,P) for head in self.heads],dim=-1))
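
# Worked example (added for illustration, not in the original file): the clipped
# relative position indices built in MultiHeadAttention.forward, reproduced here
# as a standalone helper with a hypothetical name. Entry (i,j) is clip(i-j,-k,k)+k,
# so positions more than k apart share an embedding (Shaw et al.,
# https://www.aclweb.org/anthology/N18-2074).
def _relative_position_indices(n,k=2):
    """Returns the (n,n) matrix of clipped relative position ids used above."""
    return ((nd.arange(n).expand_dims(-1)-nd.arange(n)).clip(-k,k)+k).astype(int)
# _relative_position_indices(5) ->
# [[2,1,0,0,0],
#  [3,2,1,0,0],
#  [4,3,2,1,0],
#  [4,4,3,2,1],
#  [4,4,4,3,2]]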


class AddAndNorm(gluon.Block):
    def __init__(self,block,dropout=0):
        super().__init__()
        with self.name_scope():
            self.block = block
            self.layerNorm = nn.LayerNorm()
            self.dropout = nn.Dropout(dropout)
    def forward(self,X):
        """Expects X shape (bs,n,d)"""
        return self.layerNorm(X+self.dropout(self.block(X)))

class TransformerBlock(gluon.Block):
    def __init__(self,hidden_dim,inner_dim,num_heads=4,dropout=0,rel_pos=True):
        super().__init__()
        with self.name_scope():
            MHA = MultiHeadAttention(hidden_dim,num_heads,rel_pos)
            self.attention_block = AddAndNorm(MHA,dropout)
            FF = FFN(hidden_dim,inner_dim)
            self.feed_forward_block = AddAndNorm(FF,dropout)
    def forward(self,X):
        """Expects X shape (bs,n,hidden_dim)"""
        return self.feed_forward_block(self.attention_block(X))
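
# Usage sketch (added for illustration, not in the original file): a TransformerBlock
# preserves the shape of its input. The helper name and hyperparameters are hypothetical.
def _transformer_block_shape_check():
    block = TransformerBlock(hidden_dim=64,inner_dim=256,num_heads=4,dropout=0.1)
    block.initialize()
    Y = block(nd.random.normal(shape=(2,8,64)))
    assert Y.shape == (2,8,64)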


class SlateTransformer(gluon.Block):
    def __init__(self,hidden_dim,vocab_size,num_blocks,inner_dim,num_heads=4,embed_dim=16,dropout=0):
        """Hidden dim specifies the dimension of x, inner dim the middle size of the FFN """
        super().__init__()
        with self.name_scope():
            self.embedding = FeatureEmbedding(vocab_size,embed_dim,dropout=dropout)
            self.net = nn.Sequential()
            self.net.add(nn.Dense(hidden_dim,flatten=False))
            with self.net.name_scope():
                for _ in range(num_blocks):
                    self.net.add(TransformerBlock(hidden_dim,inner_dim,num_heads,dropout))
                self.net.add(nn.Dense(1,flatten=False))
    def forward(self, x):
        """Assumes that x has shape (n, bs, .), i.e. sequence first, batch second."""
        embedded_inputs = self.embedding(x).transpose((1,0,2)) # (n,bs,d) -> (bs,n,d)
        return self.net(embedded_inputs)[:,:,0].T # (bs,n,1) -> (bs,n) -> (n,bs)
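
# Minimal smoke-test sketch (added for illustration, not part of the original file).
# It assumes FeatureEmbedding maps integer feature ids of shape (n,bs,num_features)
# to embeddings of shape (n,bs,embed_dim); that interface lives in Seq2SlateModel
# and is not shown here, so every number below is hypothetical.
if __name__ == '__main__':
    n,bs,num_features = 10,4,3
    model = SlateTransformer(hidden_dim=64,vocab_size=100,num_blocks=2,
                             inner_dim=256,num_heads=4,embed_dim=16,dropout=0.1)
    model.initialize()
    x = nd.random.randint(0,100,shape=(n,bs,num_features)).astype('float32')
    scores = model(x) # per-position scores, expected shape (n,bs)
    print(scores.shape)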