mfinzi/LieConv

View on GitHub
lie_conv/lieConv.py

Summary

Maintainability
D
1 day
Test Coverage
""" 
Note to the Reader:

In order to handle the generic spatial data we process features as a tuple: (coordinates,values,mask)
with shapes (bs,n,d), (bs,n,c), (bs,n) where bs is batchsize, n is the maximum number of points in the minibatch,
d is the dimension of the coordinates, and c is the number of channels of the feature map.
The mask specifies which elements are valid and ~mask specifies the elements that have been added through padding.
For the PointConv operation and networks, elements are passed through the network as this tuple (coordinates,values,mask).

Naively for LieConv we would process (lie_algebra_elems,values,mask) with the same shapes but with d as the dimension of the group.
However, as an optimization to avoid repeated calculation of the pairs log(v^{-1}u)=log(e^{-b}e^{a}), we instead compute this for all
pairs once at the lifting stage which has the name 'ab_pairs' in the code and shape (bs,n,n,d). Subsampling operates on this matrix
by subsampling both n axes. abq_pairs also includes the q pairs (embedded orbit identifiers) so abq_pairs =  [log(e^{-b}e^{a}),qa,qb].
So the tuple  (abq_pairs,values,mask) with shapes (bs,n,n,d) (bs,n,c) (bs,n) is passed through the network. 
The 'Pass' module is used extensively to operate on only one of these and return the tuple with the rest unchanged, 
such as for computing a swish nonlinearity on values.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from lie_conv.utils import Expression,export,Named, Pass
from lie_conv.utils import FarthestSubsample, knn_point, index_points
from lie_conv.lieGroups import T,SO2,SO3,SE2,SE3, norm
from lie_conv.masked_batchnorm import MaskBatchNormNd


@export
def Swish():
    """Return the swish nonlinearity, x * sigmoid(x), wrapped in an `Expression`."""
    def swish(x):
        return x*torch.sigmoid(x)
    return Expression(swish)

def LinearBNact(chin,chout,act='swish',bn=True):
    """Linear layer -> optional masked batchnorm -> activation, as an nn.Sequential.

    Assumes that the inputs to the net are shape (bs,n,mc_samples,c).
    The linear layer and the activation are wrapped in `Pass(..., dim=1)`, so they
    are applied to element 1 of the tuple flowing through the network.

    chin, chout: input / output width of the linear layer.
    act: 'swish' or 'relu'.
    bn: if False, an empty nn.Sequential is used in place of the norm layer.
    """
    assert act in ('relu','swish'), f"unknown activation type {act}"
    # Modules are created in the same order as they were originally
    # (norm layer, then linear, then activation).
    normlayer = MaskBatchNormNd(chout)
    linear = Pass(nn.Linear(chin,chout),dim=1)
    middle = normlayer if bn else nn.Sequential()
    if act=='swish':
        activation = Swish()
    else:
        activation = nn.ReLU()
    return nn.Sequential(linear, middle, Pass(activation,dim=1))

def WeightNet(in_dim,out_dim,act,bn,k=32):
    """Three stacked LinearBNact stages (in_dim -> k -> k -> out_dim), flattened
    into a single nn.Sequential."""
    widths = (in_dim, k, k, out_dim)
    layers = []
    for width_in, width_out in zip(widths[:-1], widths[1:]):
        # iterating the returned nn.Sequential yields its sub-modules
        layers.extend(LinearBNact(width_in, width_out, act, bn))
    return nn.Sequential(*layers)

class PointConv(nn.Module):
    """Point convolution on (coordinates, values, mask) tuples.

    For every query point a neighborhood of at most `mc_samples` points is
    gathered, a small MLP (`weightnet`) maps the coordinate offsets to kernel
    values, and those are contracted with the neighborhood values before a
    final linear layer produces `chout` channels.
    """
    def __init__(self,chin,chout,mc_samples=32,xyz_dim=3,ds_frac=1,knn_channels=None,act='swish',bn=False,mean=False):
        super().__init__()
        # input channels
        self.chin = chin
        # hyperparameter controlling size and bottleneck compute cost of weightnet
        self.cmco_ci = 16
        # dimension of the space on which convolution operates
        self.xyz_dim = xyz_dim
        # number of xyz dims on which to compute knn (None slices all of them)
        self.knn_channels = knn_channels
        # number of samples to use to estimate convolution
        self.mc_samples = mc_samples
        # MLP producing the penultimate kernel values (see A1)
        self.weightnet = WeightNet(xyz_dim, self.cmco_ci, act, bn)
        # final linear layer to compute kernel vals (see A1)
        self.linear = nn.Linear(self.cmco_ci * chin, chout)
        # whether or not to divide by the number of valid neighbors
        self.mean = mean
        assert ds_frac==1, "test branch no ds, will need to carefully check that subsample respects padding"
        # module for subsampling if ds_frac<1
        self.subsample = FarthestSubsample(ds_frac,knn_channels=knn_channels)

    def extract_neighborhood(self,inp,query_xyz):
        """Gather the nearest neighbors of each query point.

        inp: tuple (inp_xyz (bs,n,d), inp_vals (bs,n,c), mask (bs,n))
        query_xyz: (bs,m,d)
        returns: (neighbor_xyz, neighbor_values, neighbor_mask), each indexed
                 by the neighbor ids returned from `knn_point`.
        """
        xyz, vals, mask = inp
        num_neighbors = min(self.mc_samples, xyz.shape[1])
        knn_dims = slice(None, self.knn_channels)
        idx = knn_point(num_neighbors, xyz[:,:,knn_dims], query_xyz[:,:,knn_dims], mask)
        neighbor_xyz = index_points(xyz, idx)
        neighbor_values = index_points(vals, idx)
        neighbor_mask = index_points(mask, idx)
        return neighbor_xyz, neighbor_values, neighbor_mask

    def point_convolve(self,embedded_group_elems,nbhd_vals, nbhd_mask):
        """Contract neighborhood values with the kernel values from weightnet.

        embedded_group_elems: (bs,m,nbhd,d)
        nbhd_vals: (bs,m,mc_samples,ci)
        nbhd_mask: (bs,m,mc_samples)
        returns: convolved values (bs,m,co)
        """
        bs, m, _nbhd, _ci = nbhd_vals.shape
        keep = nbhd_mask.unsqueeze(-1)
        # (bs,m,mc_samples,d) -> (bs,m,mc_samples,cm*co/ci)
        _, kernel, _ = self.weightnet((None,embedded_group_elems,nbhd_mask))
        # zero out both operands at masked neighbors so they drop out of the sum
        kernel = torch.where(keep, kernel, torch.zeros_like(kernel))
        vals = torch.where(keep, nbhd_vals, torch.zeros_like(nbhd_vals))
        # (bs,m,ci,mc_samples) @ (bs,m,mc_samples,cmco/ci) -> (bs,m,ci,cmco/ci) -> (bs,m,cmco)
        contracted = torch.matmul(vals.transpose(-1,-2), kernel).view(bs, m, -1)
        out = self.linear(contracted)  # (bs,m,cmco) -> (bs,m,co)
        if self.mean:
            # divide by the number of valid neighbors (at least 1)
            out /= nbhd_mask.sum(-1,keepdim=True).clamp(min=1)
        return out

    def get_embedded_group_elems(self,output_xyz,nbhd_xyz):
        """Offsets fed to the weightnet: query coordinates minus neighbor coordinates."""
        offsets = output_xyz - nbhd_xyz
        return offsets

    def forward(self, inp):
        """inp: tuple (xyz (bs,n,d), vals (bs,n,c), mask (bs,n)).

        Returns (query_xyz, convolved_vals, sub_mask); the convolved values are
        set to zero wherever sub_mask is False.
        """
        query_xyz, _, sub_mask = self.subsample(inp)
        nbhd_xyz, nbhd_vals, nbhd_mask = self.extract_neighborhood(inp, query_xyz)
        offsets = self.get_embedded_group_elems(query_xyz.unsqueeze(2), nbhd_xyz)
        out = self.point_convolve(offsets, nbhd_vals, nbhd_mask)
        out = torch.where(sub_mask.unsqueeze(-1), out, torch.zeros_like(out))
        return query_xyz, out, sub_mask

def FPSindices(dists,frac,mask):
    """Farthest point sampling on a precomputed pairwise distance matrix.

    inputs: pairwise distances DISTS (bs,n,n), downsample_frac (float), valid atom mask (bs,n)
    outputs: chosen_indices (bs,m) with m = round(frac*n)

    The first point of every batch element is drawn at random among its valid
    points; each following point is the one farthest from the chosen set.
    Every batch element is expected to contain at least one valid point (the
    random start is taken modulo the number of valid points).
    """
    bs,n = dists.shape[:2]
    m = int(np.round(frac*n))
    device = dists.device
    chosen_indices = torch.zeros(bs, m, dtype=torch.long,device=device)
    # Running distance from every point to the chosen set. It is allocated in the
    # dtype of `dists` so that float64 (or other non-default) inputs do not hit a
    # dtype mismatch when the two are combined below.
    distances = torch.full((bs, n), 1e8, dtype=dists.dtype, device=device)
    # Padded points get a negative distance so they are never selected as farthest.
    padded = torch.full_like(distances, -100)
    num_valid = mask.sum(-1)
    a = torch.randint(0, n, (bs,), dtype=torch.long,device=device) #choose random start
    # torch.where(mask)[1] lists the valid column ids of all batch elements back to
    # back; the exclusive cumsum gives the offset where each batch element starts.
    offsets = torch.cumsum(num_valid,dim=0) - num_valid
    farthest = torch.where(mask)[1][a%num_valid + offsets]
    B = torch.arange(bs, dtype=torch.long,device=device)
    for i in range(m):
        chosen_indices[:, i] = farthest # add point that is farthest to chosen
        dist = torch.where(mask,dists[B,farthest],padded) # (bs,n) distance from new point to all others
        distances = torch.min(distances, dist) # update the chosen set's distance to all other points
        farthest = torch.max(distances, -1)[1] # select the point that is farthest from the set
    return chosen_indices


class FPSsubsample(nn.Module):
    """Subsamples (abq_pairs, vals, mask) with farthest point sampling.

    ds_frac: fraction of points to keep; with ds_frac==1 the input is passed through.
    cache: if True, the indices computed on the first call are stored and reused
           on every later call.
    group: object providing `distance(abq_pairs)`; when falsy, the norm over the
           last axis of abq_pairs is used instead.
    """
    def __init__(self,ds_frac,cache=False,group=None):
        super().__init__()
        self.ds_frac = ds_frac
        self.cache=cache
        self.cached_indices = None
        self.group = group

    def forward(self,inp,withquery=False):
        """Return (sub_abq_pairs, sub_vals, sub_mask), plus the chosen indices
        (None when nothing was subsampled) if `withquery` is set."""
        abq_pairs,vals,mask = inp
        if self.group:
            dist = self.group.distance
        else:
            dist = lambda ab: norm(ab,dim=-1)
        if self.ds_frac==1:
            result = (abq_pairs,vals,mask)
            query_idx = None
        else:
            if not self.cache:
                query_idx = FPSindices(dist(abq_pairs),self.ds_frac,mask)
            else:
                if self.cached_indices is None:
                    self.cached_indices = FPSindices(dist(abq_pairs),self.ds_frac,mask).detach()
                query_idx = self.cached_indices
            batch = torch.arange(query_idx.shape[0],device=query_idx.device).long()[:,None]
            # subsample the first point axis, then the second one
            rows = abq_pairs[batch,query_idx]
            # NOTE(review): in this second lookup the two index tensors are separated
            # by a slice, which under NumPy-style advanced indexing appears to place the
            # indexed axis before the sliced one, i.e. out[b,j,i] = abq_pairs[b,q_i,q_j].
            # Confirm that this axis order is what downstream layers expect.
            sub_pairs = rows[batch,:,query_idx]
            result = (sub_pairs, vals[batch,query_idx], mask[batch,query_idx])
        if withquery:
            return result + (query_idx,)
        return result

class LieConv(PointConv):
    """Group-equivariant convolution operating on (abq_pairs, values, mask) tuples.

    abq_pairs has shape (bs,n,n,d) and holds the precomputed pairwise elements
    (see the note at the top of the file). Neighborhoods are either the k nearest
    points under `group.distance` (knn=True) or up to `mc_samples` points sampled
    at random from a ball of radius `r`; `r` is adapted during training so that on
    average a fraction `fill` of the valid points falls inside the ball.

    The kernel evaluation itself (`point_convolve`) is inherited from PointConv.
    """
    def __init__(self,*args,group=T(3),ds_frac=1,fill=1/3,cache=False,knn=False,**kwargs):
        # xyz_dim is determined by the group, any value supplied by the caller is dropped
        kwargs.pop('xyz_dim',None)
        super().__init__(*args,xyz_dim=group.lie_dim+2*group.q_dim,**kwargs)
        self.group = group # Equivariance group for LieConv
        self.register_buffer('r',torch.tensor(2.)) # Internal variable for local_neighborhood radius, set by fill
        self.fill_frac = min(fill,1.) # Average Fraction of the input which enters into local_neighborhood, determines r
        self.knn=knn            # Whether or not to use the k nearest points instead of random samples for conv estimator
        self.subsample = FPSsubsample(ds_frac,cache=cache,group=self.group)
        self.coeff = .5  # Internal coefficient used for updating r
        self.fill_frac_ema = fill # Keeps track of average fill frac, used for logging only

    def extract_neighborhood(self,inp,query_indices):
        """ inputs: [pairs_abq (bs,n,n,d), inp_vals (bs,n,c), mask (bs,n), query_indices (bs,m)]
            outputs: [neighbor_abq (bs,m,mc_samples,d), neighbor_vals (bs,m,mc_samples,c),
                      neighbor_mask (bs,m,mc_samples)]
            query_indices may be None, in which case every point is a query (m = n).
            In training mode with knn disabled this also updates the ball radius `r`."""
        pairs_abq, inp_vals, mask = inp
        bs = inp_vals.shape[0]
        device = inp_vals.device

        # Subsample pairs_ab and mask to the query_indices (first point axis only)
        if query_indices is not None:
            B = torch.arange(bs,device=device).long()[:,None]
            abq_at_query = pairs_abq[B,query_indices]
            mask_at_query = mask[B,query_indices]
        else:
            abq_at_query = pairs_abq
            mask_at_query = mask
        dists = self.group.distance(abq_at_query) #(bs,m,n,d) -> (bs,m,n)
        # push padded points far away so they are picked last
        dists = torch.where(mask[:,None,:].expand(*dists.shape),dists,1e8*torch.ones_like(dists))
        k = min(self.mc_samples,inp_vals.shape[1])
        # (bs,m,n): True where both the query point and the candidate neighbor are valid
        pair_valid = mask[:,None,:]&mask_at_query[:,:,None]

        # Determine ids (and mask) for points sampled within neighborhood (A4)
        if self.knn: # NBHD: KNN
            nbhd_idx = torch.topk(dists,k,dim=-1,largest=False,sorted=False)[1] #(bs,m,nbhd)
            # valid indices are 0..n-1, so n itself is already out of range
            assert not torch.any(nbhd_idx>=dists.shape[-1]), f"error with topk,\
                        nbhd{k} nans|inf{torch.any(torch.isnan(dists)|torch.isinf(dists))}"
            # Validity has to be looked up at the selected neighbors: combining the
            # (bs,m,k) index tensor directly with the (bs,1,n) mask does not broadcast
            # unless k==n, and would be positionally wrong even then.
            valid_within_ball = torch.gather(pair_valid,-1,nbhd_idx)
        else: # NBHD: Sampled Distance Ball
            within_ball = (dists < self.r)&pair_valid # (bs,m,n)
            # Points inside the ball score in [1,2), all others in [0,1); topk therefore
            # returns a uniformly random subset of the ball before any outside point.
            noise = torch.rand(within_ball.shape,device=within_ball.device)
            scores, nbhd_idx = torch.topk(within_ball.float()+noise,k,dim=-1,largest=True,sorted=False)
            # >= so that an in-ball point whose noise is exactly 0 is not discarded
            valid_within_ball = scores>=1

        # Retrieve abq_pairs, values, and mask at the nbhd locations
        B = torch.arange(bs,device=device).long()[:,None,None].expand(*nbhd_idx.shape)
        M = torch.arange(abq_at_query.shape[1],device=device).long()[None,:,None].expand(*nbhd_idx.shape)
        nbhd_abq = abq_at_query[B,M,nbhd_idx]     #(bs,m,n,d) -> (bs,m,mc_samples,d)
        nbhd_vals = inp_vals[B,nbhd_idx]        #(bs,n,c) -> (bs,m,mc_samples,c)
        nbhd_mask = mask[B,nbhd_idx]            #(bs,n) -> (bs,m,mc_samples)

        if self.training and not self.knn: # update ball radius to match fraction fill_frac inside
            # average number of in-ball points per valid query; the clamp avoids 0/0
            # (and a NaN radius) when a batch contains no valid query point
            navg = (within_ball.float()).sum(-1).sum()/mask_at_query[:,:,None].sum().clamp(min=1)
            avg_fill = (navg/mask.sum(-1).float().mean()).cpu().item()
            self.r +=  self.coeff*(self.fill_frac - avg_fill)
            self.fill_frac_ema += .1*(avg_fill-self.fill_frac_ema)
        return nbhd_abq, nbhd_vals, (nbhd_mask&valid_within_ball.bool())

    # def log_data(self,logger,step,name):
    #     logger.add_scalars('info', {f'{name}_fill':self.fill_frac_ema}, step=step)
    #     logger.add_scalars('info', {f'{name}_R':self.r}, step=step)

    def forward(self, inp):
        """inputs: tuple ([pairs_abq (bs,n,n,d)], [inp_vals (bs,n,ci)], [mask (bs,n)])
           outputs: tuple ([subsampled_abq (bs,m,m,d)], [convolved_vals (bs,m,co)], [sub_mask (bs,m)])
           Convolved values are set to zero wherever sub_mask is False."""
        sub_abq, sub_vals, sub_mask, query_indices = self.subsample(inp,withquery=True)
        nbhd_abq, nbhd_vals, nbhd_mask = self.extract_neighborhood(inp, query_indices)
        convolved_vals = self.point_convolve(nbhd_abq, nbhd_vals, nbhd_mask)
        convolved_wzeros = torch.where(sub_mask.unsqueeze(-1),convolved_vals,torch.zeros_like(convolved_vals))
        return sub_abq, convolved_wzeros, sub_mask



@export
def pConvBNrelu(in_channels,out_channels,bn=True,act='swish',**kwargs):
    """PointConv -> optional masked batchnorm -> activation, as an nn.Sequential.

    `act` ('swish', anything else selects ReLU) only chooses the trailing
    nonlinearity; it is not forwarded to the PointConv layer. Extra keyword
    arguments are passed to PointConv.
    """
    conv = PointConv(in_channels,out_channels,bn=bn,**kwargs)
    norm_layer = MaskBatchNormNd(out_channels) if bn else nn.Sequential()
    activation = Swish() if act=='swish' else nn.ReLU()
    return nn.Sequential(conv, norm_layer, Pass(activation,dim=1))

@export
def LieConvBNrelu(in_channels,out_channels,bn=True,act='swish',**kwargs):
    """LieConv -> optional masked batchnorm -> activation, as an nn.Sequential.

    `act` ('swish', anything else selects ReLU) only chooses the trailing
    nonlinearity; it is not forwarded to the LieConv layer. Extra keyword
    arguments are passed to LieConv.
    """
    conv = LieConv(in_channels,out_channels,bn=bn,**kwargs)
    norm_layer = MaskBatchNormNd(out_channels) if bn else nn.Sequential()
    activation = Swish() if act=='swish' else nn.ReLU()
    return nn.Sequential(conv, norm_layer, Pass(activation,dim=1))


class BottleBlock(nn.Module):
    """ A bottleneck residual block as described in figure 5.

    Values are projected from chin to chin//4 channels, convolved to chout//4
    channels by `conv`, and projected up to chout; the (subsampled) input values
    are then added onto the first chin output channels.
    """
    def __init__(self,chin,chout,conv,bn=False,act='swish',fill=None):
        super().__init__()
        assert chin<= chout, f"unsupported channels chin{chin}, chout{chout}. No upsampling atm."
        make_act = Swish if act=='swish' else nn.ReLU

        def maybe_norm(channels):
            # masked batchnorm, or an empty nn.Sequential when bn is disabled
            return MaskBatchNormNd(channels) if bn else nn.Sequential()

        # `fill` is only forwarded to the conv factory when it was given
        conv_kwargs = {} if fill is None else {'fill': fill}
        self.conv = conv(chin//4,chout//4,**conv_kwargs)
        self.net = nn.Sequential(
            maybe_norm(chin),
            Pass(make_act(),dim=1),
            Pass(nn.Linear(chin,chin//4),dim=1),
            maybe_norm(chin//4),
            Pass(make_act(),dim=1),
            self.conv,
            maybe_norm(chout//4),
            Pass(make_act(),dim=1),
            Pass(nn.Linear(chout//4,chout),dim=1),
        )
        self.chin = chin

    def forward(self,inp):
        """Apply the block to a (coords, values, mask) tuple and add the skip values."""
        # NOTE(review): the skip values come from a separate call to the conv's
        # subsampler; if that subsampler is random and uncached, this call may pick
        # different points than the call made inside self.net - confirm.
        _, skip_values, _ = self.conv.subsample(inp)
        out_coords, out_values, out_mask = self.net(inp)
        # residual connection on the first chin channels only (chin <= chout)
        out_values[...,:self.chin] += skip_values
        return out_coords, out_values, out_mask

class GlobalPool(nn.Module):
    """computes values reduced over all spatial locations (& group elements) in the mask"""
    def __init__(self,mean=False):
        super().__init__()
        # if True, divide the masked sum by the number of valid elements
        self.mean = mean

    def forward(self,x):
        """x [xyz (bs,n,d), vals (bs,n,c), mask (bs,n)]

        Returns the masked sum (or masked mean if self.mean) of vals over axis 1,
        shape (bs,c). If x has only two elements there is no mask and the plain
        mean of x[1] over axis 1 is returned.
        """
        if len(x)==2: return x[1].mean(1)
        coords, vals,mask = x
        summed = torch.where(mask.unsqueeze(-1),vals,torch.zeros_like(vals)).sum(1)
        if self.mean:
            # Clamp the count so a batch element without any valid entry gives 0
            # instead of 0/0 = NaN.
            summed = summed / mask.sum(-1,keepdim=True).clamp(min=1)
        return summed

@export
class LieResNet(nn.Module):
    """ Generic LieConv architecture from Fig 5. Relevant Arguments:
        [fill] fraction of the input which is included in the local neighborhood.
                (a number is repeated for every layer, otherwise one value per layer)
        [nbhd] number of samples to use for Monte Carlo estimation (p)
        [chin] number of input channels: 1 for MNIST, 3 for RGB images, other for non images
        [ds_frac] downsampling fraction handed to every LieConv layer. In (0,1]
        [num_layers] number of BottleNeck Block layers in the network
        [k] channel width for the network. An int is repeated num_layers+1 times,
            otherwise one width per layer boundary.
        [liftsamples] number of samples to use in lifting. 1 for all groups with trivial stabilizer. Otherwise 2+
        [group] Chosen group to be equivariant to.
        [bn] whether or not to use batch normalization. Recommended in all cases except dynamical systems.
        [pool] if True the output is reduced with GlobalPool, otherwise the per-point values are returned.
        Remaining keyword arguments are forwarded to every LieConv layer.
        """
    def __init__(self, chin, ds_frac=1, num_outputs=1, k=1536, nbhd=np.inf,
                act="swish", bn=True, num_layers=6, mean=True, per_point=True,pool=True,
                liftsamples=1, fill=1/4, group=SE3,knn=False,cache=False, **kwargs):
        super().__init__()
        # broadcast scalar settings to one entry per layer
        if isinstance(fill,(float,int)):
            fill = [fill]*num_layers
        if isinstance(k,int):
            k = [k]*(num_layers+1)

        def conv(ki,ko,fill):
            # factory handed to each BottleBlock to build its LieConv layer
            return LieConv(ki, ko, mc_samples=nbhd, ds_frac=ds_frac, bn=bn,act=act, mean=mean,
                           group=group,fill=fill,cache=cache,knn=knn,**kwargs)

        layers = [Pass(nn.Linear(chin,k[0]),dim=1)] #embedding layer
        for i in range(num_layers):
            layers.append(BottleBlock(k[i],k[i+1],conv,bn=bn,act=act,fill=fill[i]))
        layers.append(MaskBatchNormNd(k[-1]) if bn else nn.Sequential())
        layers.append(Pass(Swish() if act=='swish' else nn.ReLU(),dim=1))
        layers.append(Pass(nn.Linear(k[-1],num_outputs),dim=1))
        # without pooling, only the values (element 1 of the tuple) are returned
        layers.append(GlobalPool(mean=mean) if pool else Expression(lambda x: x[1]))
        self.net = nn.Sequential(*layers)
        self.liftsamples = liftsamples
        self.per_point=per_point
        self.group = group

    def forward(self, x):
        """Lift x with the group (using self.liftsamples samples) and run the network on the result."""
        return self.net(self.group.lift(x,self.liftsamples))

@export
class ImgLieResnet(LieResNet):
    """ Lie Conv architecture specialized to images. Uses scaling rule to determine channel
         and downsampling scaling. Same arguments as LieResNet, except that
         [total_ds] is the downsampling over the whole network (split evenly over the layers),
         [num_targets] is the number of outputs, and
         [increase_channels] widens the layers as the image is downsampled."""
    def __init__(self,chin=1,total_ds=1/64,num_layers=6,group=T(2),fill=1/32,k=256,
        knn=False,nbhd=12,num_targets=10,increase_channels=True,**kwargs):
        # per-layer downsampling such that num_layers of them multiply to total_ds
        ds_frac = (total_ds)**(1/num_layers)
        # grow the fill fraction as the number of points shrinks
        layer_fill = [fill/ds_frac**i for i in range(num_layers)]
        if increase_channels: # whether or not to scale the channels as image is downsampled
            k = [int(k/ds_frac**(i/2)) for i in range(num_layers+1)]
        super().__init__(chin=chin,ds_frac=ds_frac,num_layers=num_layers,nbhd=nbhd,mean=True,
                        group=group,fill=layer_fill,k=k,num_outputs=num_targets,cache=True,knn=knn,**kwargs)
        # lifted coordinates are computed on the first forward pass and then reused
        self.lifted_coords = None

    def forward(self,x,coord_transform=None):
        """ assumes x is a regular image: (bs,c,h,w)

        Pixels are turned into a point cloud on a centered coordinate grid; only
        points with radius < 15 are kept. `coord_transform`, if given, is applied
        to the coordinates before lifting."""
        bs,c,h,w = x.shape
        # Construct coordinate grid
        rows = torch.linspace(-h/2.,h/2.,h)
        cols = torch.linspace(-w/2.,w/2.,w)
        grid = torch.stack(torch.meshgrid([rows,cols]),dim=-1).float()
        # Perform center crop
        inside = grid.norm(dim=-1)<15. # crop out corners (filled only with zeros)
        coords = grid[inside].view(-1,2).unsqueeze(0).repeat(bs,1,1).to(x.device)
        if coord_transform is not None:
            coords = coord_transform(coords)
        values = x.permute(0,2,3,1)[:,inside,:].reshape(bs,-1,c)
        mask = torch.ones(bs,values.shape[1],device=x.device)>0 # all true
        # Perform lifting of the coordinates and cache results
        with torch.no_grad():
            if self.lifted_coords is not None:
                lifted_vals,lifted_mask = self.group.expand_like(values,mask,self.lifted_coords)
            else:
                self.lifted_coords,lifted_vals,lifted_mask = self.group.lift((coords,values,mask),self.liftsamples)
        return self.net((self.lifted_coords,lifted_vals,lifted_mask))