mfinzi/LieConv

View on GitHub
lie_conv/lieGroups.py

Summary

Maintainability
C
1 day
Test Coverage
import torch
import numpy as np
from lie_conv.utils import export, Named

@export
def norm(x,dim):
    """ Euclidean (L2) norm of x, reducing over dimension ``dim``: sqrt(sum(x^2)).
        Note: since the derivative of sqrt is infinite at 0, gradients are NaN for
        slices that are exactly zero."""
    squares = x.pow(2)
    return torch.sqrt(squares.sum(dim=dim))

class LieGroup(object):
    """ The abstract Lie Group requiring additional implementation of exp,log, and lifted_elems
        to use as a new group for LieConv. rep_dim,lie_dim,q_dim should additionally be specified.
        Groups that want to use bracket/BCH must also implement components2matrix and
        matrix2components."""
    rep_dim = NotImplemented # dimension on which G acts. (e.g. 2 for SO(2))
    lie_dim = NotImplemented # dimension of the lie algebra of G. (e.g. 1 for SO(2))
    q_dim = NotImplemented # dimension which the quotient space X/G is embedded. (e.g. 1 for SO(2) acting on R2)

    def __init__(self,alpha=.2):
        """ alpha: weight of the group element term in `distance`; the orbit (q) term
            is weighted by (1-alpha)."""
        super().__init__()
        self.alpha=alpha

    def exp(self,a):
        """ Computes (matrix) exponential Lie algebra elements (in a given basis).
            ie out = exp(sum_i a_i A_i) where A_i are the exponential generators of G.
            Input: [a (*,lie_dim)] where * is arbitrarily shaped
            Output: [exp(a) (*,rep_dim,rep_dim)] returns the matrix for each."""
        raise NotImplementedError

    def log(self,u):
        """ Computes (matrix) logarithm for collection of matrices and converts to Lie algebra basis.
            Input [u (*,rep_dim,rep_dim)]
            Output [coeffs of log(u) in basis (*,lie_dim)] """
        raise NotImplementedError

    def lifted_elems(self,xyz,nsamples):
        """ Takes in coordinates xyz and lifts them to Lie algebra elements a (in basis)
            and embedded orbit identifiers q. For groups where lifting is multivalued
            specify nsamples>1 as number of lifts to do for each point.
            Inputs: [xyz (*,n,rep_dim)], [nsamples (int)]
            Outputs: [a (*,n*nsamples,lie_dim)],[q (*,n*nsamples,q_dim) or None]"""
        raise NotImplementedError

    def components2matrix(self,a):
        """ Maps Lie algebra coordinates a (*,lie_dim) to matrices (*,rep_dim,rep_dim).
            Required by bracket (and therefore BCH)."""
        raise NotImplementedError

    def matrix2components(self,A):
        """ Inverse of components2matrix: matrices (*,rep_dim,rep_dim) -> coordinates (*,lie_dim).
            Required by bracket (and therefore BCH)."""
        raise NotImplementedError

    def inv(self,g):
        """ We can compute the inverse of elements g (*,rep_dim,rep_dim) as exp(-log(g))"""
        return self.exp(-self.log(g))

    def distance(self,abq_pairs):
        """ Compute distance of size (*) from [abq_pairs (*,lie_dim+2*q_dim)].
            Simply computes alpha*norm(log(v^{-1}u)) +(1-alpha)*norm(q_a-q_b),
            combined distance from group element distance and orbit distance."""
        ab_dist = norm(abq_pairs[...,:self.lie_dim],dim=-1)
        qa = abq_pairs[...,self.lie_dim:self.lie_dim+self.q_dim]
        qb = abq_pairs[...,self.lie_dim+self.q_dim:self.lie_dim+2*self.q_dim]
        qa_qb_dist = norm(qa-qb,dim=-1)
        return ab_dist*self.alpha + (1-self.alpha)*qa_qb_dist

    def lift(self,x,nsamples,**kwargs):
        """ Lifts x = (p,v,m) with p (*,n,rep_dim), values v (*,n,c), mask m (*,n).
            Returns (embedded_locations, expanded_v, expanded_mask) with shapes
            (*,n*ns,n*ns,lie_dim+2*qd), (*,n*ns,c), (*,n*ns), where ns is inferred from the
            output of lifted_elems and qd is the last dim of its q (no q part if q is None).
            embedded_locations[...,i,j,:] = [pair_ij, q_i, q_j]."""
        p,v,m = x
        expanded_a,expanded_q = self.lifted_elems(p,nsamples,**kwargs) # (bs,n*ns,d), (bs,n*ns,qd)
        # expand v and mask like a (lifted_elems may choose its own number of samples)
        expanded_v,expanded_mask = self.expand_like(v,m,expanded_a)
        # convert from elems to pairs
        paired_a = self.elems2pairs(expanded_a) #(bs,n*ns,d) -> (bs,n*ns,n*ns,d)
        if expanded_q is None:
            return (paired_a,expanded_v,expanded_mask)
        # use the actual q width instead of assuming q_dim == 1 (subgroups append coordinates)
        qd = expanded_q.shape[-1]
        q_in = expanded_q.unsqueeze(-2).expand(*paired_a.shape[:-1],qd)
        q_out = expanded_q.unsqueeze(-3).expand(*paired_a.shape[:-1],qd)
        embedded_locations = torch.cat([paired_a,q_in,q_out],dim=-1)
        return (embedded_locations,expanded_v,expanded_mask)

    def expand_like(self,v,m,a):
        """ Repeats values v (*,n,c) and mask m (*,n) nsamples times per point to match the
            lifted elements a (*,n*nsamples,d). Returns ((*,n*ns,c), (*,n*ns))."""
        nsamples = a.shape[-2]//m.shape[-1]
        expanded_v = v[...,None,:].repeat((1,)*len(v.shape[:-1])+(nsamples,1)) # (bs,n,c) -> (bs,n,1,c) -> (bs,n,ns,c)
        expanded_v = expanded_v.reshape(*a.shape[:-1],v.shape[-1]) # (bs,n,ns,c) -> (bs,n*ns,c)
        expanded_mask = m[...,None].repeat((1,)*len(m.shape)+(nsamples,)) # (bs,n) -> (bs,n,ns)
        expanded_mask = expanded_mask.reshape(*a.shape[:-1]) # (bs,n,ns) -> (bs,n*ns)
        return expanded_v, expanded_mask

    def elems2pairs(self,a):
        """ computes log(e^-b e^a) for all a b pairs along n dimension of input.
            inputs: [a (bs,n,d)] outputs: [pairs_ab (bs,n,n,d)] where [...,i,j,:] = log(e^-a_j e^a_i)"""
        vinv = self.exp(-a.unsqueeze(-3))
        u = self.exp(a.unsqueeze(-2))
        return self.log(vinv@u)    # ((bs,1,n,d) -> (bs,1,n,r,r))@((bs,n,1,d) -> (bs,n,1,r,r))

    def BCH(self,a,b,order=2):
        """ Baker Campbell Hausdorff formula: approximates log(exp(a)exp(b)) truncated
            at the given order (1 to 4). a, b are Lie algebra coordinate vectors."""
        assert order <= 4, "BCH only supported up to order 4"
        B = self.bracket
        z = a+b  # new tensor, so the in-place updates below do not touch a or b
        if order==1: return z
        ab = B(a,b)
        z += (1/2)*ab
        if order==2: return z
        aab = B(a,ab)
        bba = B(b,-ab)  # [b,[b,a]]
        z += (1/12)*(aab+bba)
        if order==3: return z
        baab = B(b,aab)
        z += -(1/24)*baab
        return z

    def bracket(self,a,b):
        """Computes the lie bracket between a and b, assumes a,b expressed as vectors"""
        A = self.components2matrix(a)
        B = self.components2matrix(b)
        return self.matrix2components(A@B-B@A)

    def __str__(self):
        return f"{self.__class__}({self.alpha})" if self.alpha!=.2 else f"{self.__class__}"
    def __repr__(self):
        return str(self)

@export
def LieSubGroup(liegroup,generators):
    """ Construct the subgroup of ``liegroup`` generated by a subset of its Lie algebra basis.

        Args:
            liegroup: the parent LieGroup subclass.
            generators: sequence of indices into the parent's Lie algebra basis that span
                the subgroup's Lie algebra.
        Returns:
            A subclass of ``liegroup`` whose Lie algebra coordinates are only the selected
            generators. In ``lifted_elems`` the parent coordinates along the remaining
            (complement) generators are appended to the parent's orbit identifier q, so
            q_dim = parent q_dim + parent lie_dim - len(generators)."""
    generators = list(generators)

    class subgroup(liegroup):

        def __init__(self,*args,**kwargs):
            super().__init__(*args,**kwargs)
            self.orig_dim = self.lie_dim
            self.lie_dim = len(generators)
            # lifted_elems concatenates the parent's q with the complement coordinates
            self.q_dim = self.q_dim+self.orig_dim-len(generators)
            # sorted() so the layout of the appended q coordinates is deterministic
            self.complement_generators = sorted(set(range(self.orig_dim))-set(generators))

        def exp(self,a_small):
            """ Embed subgroup coordinates (*,len(generators)) into the parent algebra and exponentiate."""
            a_full = torch.zeros(*a_small.shape[:-1],self.orig_dim,
                        device=a_small.device,dtype=a_small.dtype)
            a_full[...,generators] = a_small
            return super().exp(a_full)

        def log(self,U):
            """ Parent logarithm restricted to the subgroup generators."""
            return super().log(U)[...,generators]

        def components2matrix(self,a_small):
            a_full = torch.zeros(*a_small.shape[:-1],self.orig_dim,
                         device=a_small.device,dtype=a_small.dtype)
            a_full[...,generators] = a_small
            return super().components2matrix(a_full)

        def matrix2components(self,A):
            return super().matrix2components(A)[...,generators]

        def lifted_elems(self,pt,nsamples=1,**kwargs):
            """ Lift with the parent group, keep the subgroup coordinates as a and move the
                complement coordinates into the orbit identifier q."""
            a_full,q = super().lifted_elems(pt,nsamples,**kwargs)
            a_sub = a_full[...,generators]
            new_qs = a_full[...,self.complement_generators]
            q_sub = torch.cat([q,new_qs],dim=-1) if q is not None else new_qs
            return a_sub,q_sub
    return subgroup

@export
class T(LieGroup):
    """ The k dimensional translation group acting on R^k. Lie algebra coordinates are the
        translation vector itself, so pairs are plain differences."""
    def __init__(self,k,alpha=.2):
        """ Returns the k dimensional translation group. Assumes lifting from R^k.
            alpha scales the translation term of LieGroup.distance (q_dim is 0, so the
            orbit term contributes nothing)."""
        super().__init__(alpha)
        self.q_dim = 0
        self.rep_dim = k # dimension on which G acts
        self.lie_dim = k # dimension that g is embedded into

    def lifted_elems(self,xyz,nsamples,**kwargs):
        """ Each point is lifted to the translation taking the origin to it; no orbit identifier."""
        assert nsamples==1, "Abelian group, no need for nsamples"
        return xyz,None

    def elems2pairs(self,a):
        """ a (*,n,k) -> (*,n,n,k) with [...,i,j,:] = a_i - a_j (group is abelian)."""
        deltas = a.unsqueeze(-2)-a.unsqueeze(-3)
        return deltas

# Helper functions for analytic exponential maps. Uses taylor expansions near x=0
# See http://ethaneade.com/lie_groups.pdf for derivations.
thresh =7e-2

def _taylor_safe(x,usetaylor):
    """ Returns x with the entries served by the Taylor branch replaced by 1.
        torch.where sends a zero gradient into the unselected branch, but zero times the
        NaN/inf derivative of a closed form evaluated at x=0 is still NaN, so the closed
        form must never be evaluated at the small-x entries."""
    return torch.where(usetaylor,torch.ones_like(x),x)

def sinc(x):
    """ sin(x)/x """
    x2=x*x
    usetaylor = (x.abs()<thresh)
    xs = _taylor_safe(x,usetaylor)
    return torch.where(usetaylor,1-x2/6*(1-x2/20*(1-x2/42)),xs.sin()/xs)

def sincc(x):
    """ (1-sinc(x))/x^2"""
    x2=x*x
    usetaylor = (x.abs()<thresh)
    xs = _taylor_safe(x,usetaylor)
    return torch.where(usetaylor,1/6*(1-x2/20*(1-x2/42*(1-x2/72))),(xs-xs.sin())/xs**3)

def cosc(x):
    """ (1-cos(x))/x^2"""
    x2 = x*x
    usetaylor = (x.abs()<thresh)
    xs = _taylor_safe(x,usetaylor)
    return torch.where(usetaylor,1/2*(1-x2/12*(1-x2/30*(1-x2/56))),(1-xs.cos())/xs**2)

def coscc(x):
    """ (1 - x*sin(x)/(2*(1-cos(x))))/x^2, equivalently (1-(x/2)*cot(x/2))/x^2 """
    x2 = x*x
    usetaylor = (x.abs()<thresh)
    xs = _taylor_safe(x,usetaylor)
    texpand = 1/12*(1+x2/60*(1+x2/42*(1+x2/40)))
    costerm = (2*(1-xs.cos())).clamp(min=1e-6)
    full = (1-xs*xs.sin()/costerm)/xs**2 #Nans can come up here when cos = 1
    output = torch.where(usetaylor,texpand,full)
    return output

def sinc_inv(x):
    """ x/sin(x) """
    usetaylor = (x.abs()<thresh)
    texpand = 1+(1/6)*x**2 +(7/360)*x**4
    # the message must be a str: concatenating a str and a Tensor would raise TypeError
    assert not torch.any(torch.isinf(texpand)|torch.isnan(texpand)), \
        f'sincinv texpand inf/nan (any inf: {bool(torch.any(torch.isinf(texpand)))})'
    xs = _taylor_safe(x,usetaylor)
    return torch.where(usetaylor,texpand,xs/xs.sin())

## Lie Groups acting on R2

@export
class SO2(LieGroup):
    """ The rotation group SO(2) acting on R^2. The Lie algebra coordinate is the angle
        theta; the orbit identifier of a point is its radius |x|."""
    lie_dim = 1
    rep_dim = 2
    q_dim = 1
    def exp(self,a):
        """ a (*,1) angles -> rotation matrices (*,2,2)."""
        R = torch.zeros(*a.shape[:-1],2,2,device=a.device,dtype=a.dtype)
        sin = a[...,0].sin()
        cos = a[...,0].cos()
        R[...,0,0] = cos
        R[...,1,1] = cos
        R[...,0,1] = -sin
        R[...,1,0] = sin
        return R
    def log(self,R):
        """ R (*,2,2) -> angle (*,1) in (-pi,pi]."""
        return torch.atan2(R[...,1,0]-R[...,0,1],R[...,0,0]+R[...,1,1])[...,None]
    def components2matrix(self,a): # a: (*,lie_dim)
        A = torch.zeros(*a.shape[:-1],2,2,device=a.device,dtype=a.dtype)
        A[...,0,1] = -a[...,0]
        A[...,1,0] = a[...,0]
        return A
    def matrix2components(self,A): # A: (*,rep_dim,rep_dim)
        """ Inverse of components2matrix, read from the antisymmetric part of A.
            Returns shape (*,1) (allocating with A.shape[:-1] gave (*,2,1))."""
        return ((A[...,1,0]-A[...,0,1])/2).unsqueeze(-1)
    def lifted_elems(self,pt,nsamples=1):
        """ Lift points pt (*,n,2) to angles theta (*,n,1) and radii r=|pt| (*,n,1),
            so that exp(theta) @ [r,0] = pt. Only nsamples=1 is supported."""
        assert nsamples==1, "Abelian group, no need for nsamples"
        D = pt.shape[-1] # origin = [1,0]
        assert D==2, "Lifting from R^2 to SO(2) supported only"
        r = norm(pt,dim=-1).unsqueeze(-1)
        theta = torch.atan2(pt[...,1],pt[...,0]).unsqueeze(-1)
        return theta,r
    def distance(self,abq_pairs):
        """ abq_pairs (*,3) = [angle difference, r_a, r_b]. Combines the absolute angle
            difference with the relative radial difference |ra-rb|/(ra+rb+1e-3)."""
        angle_pairs = abq_pairs[...,0]
        ra = abq_pairs[...,1]
        rb = abq_pairs[...,2]
        return angle_pairs.abs()*self.alpha + (1-self.alpha)*(ra-rb).abs()/(ra+rb+1e-3)

@export
class RxSO2(LieGroup):
    """ Rotation scaling group. Equivalent to log polar convolution.
        Lie algebra coordinates are [log r, theta]; group elements are r*R(theta)."""
    lie_dim=2
    rep_dim=2
    q_dim=0
    def exp(self,a):
        """ a (*,2) = [log r, theta] -> r*R(theta) (*,2,2)."""
        r = a[...,0].exp()
        R = torch.zeros(*a.shape[:-1],2,2,device=a.device,dtype=a.dtype)
        rsin = r*a[...,1].sin()
        rcos = r*a[...,1].cos()
        R[...,0,0] = rcos
        R[...,1,1] = rcos
        R[...,0,1] = -rsin
        R[...,1,0] = rsin
        return R
    def log(self,R):
        """ r*R(theta) (*,2,2) -> [log r, theta] (*,2)."""
        rsin = (R[...,1,0]-R[...,0,1])/2
        rcos = (R[...,0,0]+R[...,1,1])/2
        theta = torch.atan2(rsin,rcos)
        r = (rsin**2+rcos**2).sqrt()
        return torch.stack([r.log(),theta],dim=-1)
    def components2matrix(self,a): # a: (*,2) = [log r, theta]
        """ Lie algebra element log r * I + theta * J as a (*,2,2) matrix (needed by bracket)."""
        A = torch.zeros(*a.shape[:-1],2,2,device=a.device,dtype=a.dtype)
        A[...,0,0] = a[...,0]
        A[...,1,1] = a[...,0]
        A[...,0,1] = -a[...,1]
        A[...,1,0] = a[...,1]
        return A
    def matrix2components(self,A): # A: (*,2,2)
        """ Inverse of components2matrix: [trace/2, antisymmetric part] -> (*,2)."""
        logr = (A[...,0,0]+A[...,1,1])/2
        theta = (A[...,1,0]-A[...,0,1])/2
        return torch.stack([logr,theta],dim=-1)
    def lifted_elems(self,pt,nsamples=1):
        """ pt (*,n,2) -> ([log|pt|, angle(pt)] (*,n,2), None); exp(a) @ [1,0] = pt."""
        D = pt.shape[-1] # origin = [1,0]
        assert D==2, "Lifting from R^2 to RxSO(2) supported only"
        r = norm(pt,dim=-1).unsqueeze(-1)
        theta = torch.atan2(pt[...,1],pt[...,0]).unsqueeze(-1)
        return torch.cat([r.log(),theta],dim=-1),None
    def distance(self,abq_pairs):
        angle_dist = abq_pairs[...,1].abs()
        r_dist = abq_pairs[...,0].abs()
        return angle_dist*self.alpha + (1-self.alpha)*r_dist

@export
class RxSQ(LieGroup):
    """ Rotation Squeeze group. Equivalent to log hyperbolic coordinate convolution.
        Acts on the positive orthant R2+. Lie algebra coordinates are [log r, log s] and the
        group elements are the diagonal matrices diag(e^(logr+logs), e^(logr-logs))."""
    lie_dim=2
    rep_dim=2
    q_dim=0
    def exp(self,a):
        """ a (*,2) = [log r, log s] -> diag(e^(logr+logs), e^(logr-logs)) (*,2,2)."""
        R = torch.zeros(*a.shape[:-1],2,2,device=a.device,dtype=a.dtype)
        R[...,0,0] = (a[...,0]+a[...,1]).exp()
        R[...,1,1] = (a[...,0]-a[...,1]).exp()
        return R
    def log(self,R):
        """ Positive diagonal matrices (*,2,2) -> [log r, log s] (*,2)."""
        lx = R[...,0,0].log()
        ly = R[...,1,1].log()
        return torch.stack([(lx+ly)/2,(lx-ly)/2],dim=-1)
    def components2matrix(self,a): # a: (*,2) = [log r, log s]
        A = torch.zeros(*a.shape[:-1],2,2,device=a.device,dtype=a.dtype)
        A[...,0,0] = a[...,0]+a[...,1]
        A[...,1,1] = a[...,0]-a[...,1]
        return A
    def matrix2components(self,A): # A: (*,2,2)
        return torch.stack([(A[...,0,0]+A[...,1,1])/2,(A[...,0,0]-A[...,1,1])/2],dim=-1)
    def lifted_elems(self,pt,nsamples=1):
        """ pt (*,n,2) in the positive orthant -> ([log r, log s] (*,n,2), None),
            such that exp(a) @ [1,1] = pt."""
        D = pt.shape[-1] # origin = [1,1]
        assert nsamples==1, "Abelian group, no need for nsamples"
        assert D==2, "Lifting from R^2 to RxSQ supported only"
        lxy = pt.log()
        logs = (lxy[...,0]-lxy[...,1])/2
        logr = (lxy[...,0]+lxy[...,1])/2
        # stack (not cat): logr and logs are (*,n); the result must be (*,n,2)
        return torch.stack([logr,logs],dim=-1),None
    def distance(self,abq_pairs):
        s_dist = abq_pairs[...,1].abs()
        r_dist = abq_pairs[...,0].abs()
        return s_dist*self.alpha + (1-self.alpha)*r_dist
        
@export
class Rx(LieSubGroup(RxSO2,(0,))):
    """ Scaling subgroup of RxSO2: keeps only generator 0 (log r)."""

@export
class SQ(LieSubGroup(RxSQ,(1,))):
    """ Squeeze subgroup of RxSQ: keeps only generator 1 (log s)."""

@export
class Tx(LieSubGroup(T,(0,))):
    """ Subgroup of T keeping only translations along the first axis (generator 0)."""

@export
class Ty(LieSubGroup(T,(1,))):
    """ Subgroup of T keeping only translations along the second axis (generator 1)."""

@export
class SE2(SO2):
    """ Special Euclidean group SE(2) of rotations and translations of R^2, represented
        as 3x3 homogeneous matrices. Lie algebra coordinates use the basis [theta,tx,ty]."""
    lie_dim = 3
    rep_dim = 3
    q_dim = 0
    def log(self,g):
        """ g (*,3,3) -> Lie algebra coordinates [theta,tx,ty] (*,3)."""
        theta = super().log(g[...,:2,:2])
        I = torch.eye(2,device=g.device,dtype=g.dtype)
        K = super().components2matrix(torch.ones_like(theta))
        theta = theta.unsqueeze(-1)
        # closed form inverse of V = sinc(theta) I + theta cosc(theta) K
        Vinv = (sinc(theta)/(2*cosc(theta)))*I - theta*K/2
        a = torch.zeros(g.shape[:-1],device=g.device,dtype=g.dtype)
        a[...,0] = theta[...,0,0]
        a[...,1:] = (Vinv@g[...,:2,2].unsqueeze(-1)).squeeze(-1)
        return a

    def exp(self,a):
        """ assumes that a is expanded in the basis [theta,tx,ty] of the lie algebra
            a should have shape (*,3); returns homogeneous matrices (*,3,3)"""
        theta = a[...,0].unsqueeze(-1)
        I = torch.eye(2,device=a.device,dtype=a.dtype)
        K = super().components2matrix(torch.ones_like(theta))
        theta = theta.unsqueeze(-1)
        V = sinc(theta)*I + theta*cosc(theta)*K
        g = torch.zeros(*a.shape[:-1],3,3,device=a.device,dtype=a.dtype)
        g[...,:2,:2] = theta.cos()*I+theta.sin()*K
        g[...,:2,2] = (V@a[...,1:].unsqueeze(-1)).squeeze(-1)
        g[...,2,2] = 1
        return g

    def components2matrix(self,a):
        """takes an element in the lie algebra expressed in the standard basis [theta,tx,ty]
            and expands to the corresponding matrix. a: (*,3) -> (*,3,3).
            Translation goes in the last column and the rotation part follows the SO2
            convention, so that matrix2components inverts it."""
        A = torch.zeros(*a.shape[:-1],3,3,device=a.device,dtype=a.dtype)
        A[...,:2,2] = a[...,1:]
        A[...,0,1] = -a[...,0]
        A[...,1,0] = a[...,0]
        return A

    def matrix2components(self,A):
        """takes an element in the lie algebra expressed as a matrix (*,3,3) and
            expresses it in the standard basis [theta,tx,ty]"""
        a = torch.zeros(*A.shape[:-1],device=A.device,dtype=A.dtype)
        a[...,1:] = A[...,:2,2]
        a[...,0] = (A[...,1,0]-A[...,0,1])/2
        return a

    def lifted_elems(self,pt,nsamples=1):
        """ Lift points pt (*,n,2) to nsamples SE(2) elements each: trans(pt)@R(theta_k),
            where the rotation angles theta_k are evenly spaced on the circle with a random
            offset drawn per point. Returns (a (*,n*nsamples,3), None)."""
        #TODO: correctly handle masking, unnecessary for image data
        d=self.rep_dim
        device = pt.device
        dtype = pt.dtype if pt.is_floating_point() else torch.get_default_dtype()
        # Sample stabilizer of the origin: evenly spaced angles (-pi and pi are the same, drop one)
        thetas = torch.linspace(-np.pi,np.pi,nsamples+1)[1:].to(device=device,dtype=dtype)
        thetas = thetas.reshape((1,)*(pt.dim()-1)+(nsamples,))
        # random offset per point, uniform on the circle
        thetas = thetas + torch.rand(*pt.shape[:-1],1).to(device=device,dtype=dtype)*2*np.pi
        R = torch.zeros(*pt.shape[:-1],nsamples,d,d,device=device,dtype=dtype)
        sin,cos = thetas.sin(),thetas.cos()
        R[...,0,0] = cos
        R[...,1,1] = cos
        R[...,0,1] = -sin
        R[...,1,0] = sin
        R[...,2,2] = 1
        # Homogeneous translation to pt (not named T to avoid shadowing the T group)
        trans = torch.zeros_like(R)
        trans[...,0,0]=1
        trans[...,1,1]=1
        trans[...,2,2]=1
        trans[...,:2,2] = pt.unsqueeze(-2)
        flat_a = self.log(trans@R).reshape(*pt.shape[:-2],pt.shape[-2]*nsamples,d)
        return flat_a, None
    def distance(self,abq_pairs):
        d_theta = abq_pairs[...,0].abs()
        d_r = norm(abq_pairs[...,1:],dim=-1)
        return d_theta*self.alpha + (1-self.alpha)*d_r

## Lie Groups acting on R3

# Hodge star on R3
def cross_matrix(k):
    """Application of hodge star on R3, mapping Λ^1 R3 -> Λ^2 R3.
       Maps vectors k (*,3) to the antisymmetric matrices K (*,3,3) with K @ v = k x v."""
    kx, ky, kz = k[...,0], k[...,1], k[...,2]
    zero = torch.zeros_like(kx)
    rows = (
        (zero, -kz,  ky),
        (  kz, zero, -kx),
        ( -ky,  kx, zero),
    )
    return torch.stack([torch.stack(row,dim=-1) for row in rows],dim=-2)

def uncross_matrix(K):
    """Application of hodge star on R3, mapping Λ^2 R3 -> Λ^1 R3.
       Extracts k (*,3) from the antisymmetric part of K (*,3,3); inverse of cross_matrix."""
    half_antisym = lambda i, j: (K[...,i,j] - K[...,j,i])/2
    return torch.stack([half_antisym(2,1), half_antisym(0,2), half_antisym(1,0)],dim=-1)

@export
class SO3(LieGroup):
    """ The rotation group SO(3) acting on R^3. Lie algebra coordinates are rotation vectors
        (components along the x, y, z rotation generators); the orbit identifier is |x|."""
    lie_dim = 3
    rep_dim = 3
    q_dim = 1
    def __init__(self,alpha=.2):
        super().__init__(alpha)

    def exp(self,w):
        """ Rodriguez's formula, assuming shape (*,3)
            where components 1,2,3 are the generators for xrot,yrot,zrot"""
        theta = norm(w,dim=-1)[...,None,None]
        K = cross_matrix(w)
        I = torch.eye(3,device=K.device,dtype=K.dtype)
        Rs = I + K*sinc(theta) + (K@K)*cosc(theta)
        return Rs

    def log(self,R):
        """ Computes components in terms of generators rx,ry,rz. Shape (*,3,3) -> (*,3)"""
        trR = R[...,0,0]+R[...,1,1]+R[...,2,2]
        costheta = ((trR-1)/2).clamp(max=1,min=-1).unsqueeze(-1)
        theta = torch.acos(costheta)
        # the antisymmetric part of R is sin(theta)*K_u; rescale to theta*u
        logR = uncross_matrix(R)*sinc_inv(theta)
        return logR

    def components2matrix(self,a): # a: (*,3)
        return cross_matrix(a)

    def matrix2components(self,A): # A: (*,rep_dim,rep_dim)
        return uncross_matrix(A)

    def sample(self,*shape,device=torch.device('cuda'),dtype=torch.float32):
        """ Draws rotation matrices (*shape,3,3) from normalized Gaussian quaternions."""
        q = torch.randn(*shape,4,device=device,dtype=dtype)
        q /= norm(q,dim=-1).unsqueeze(-1)
        theta_2 = torch.atan2(norm(q[...,1:],dim=-1),q[...,0]).unsqueeze(-1)
        so3_elem = 2*sinc_inv(theta_2)*q[...,1:] # (sin(x/2)u -> xu) for x angle and u direction
        R = self.exp(so3_elem)
        return R

    def lifted_elems(self,pt,nsamples,**kwargs):
        """ Lifting from R^3 -> SO(3) , R^3/SO(3). pt shape (*,n,3)
            First get a random rotation Rz about [1,0,0] by the appropriate angle
            and then rotate from [1,0,0] to p/|p| with Rp to get RpRz and then
            convert to logarithmic coordinates log(RpRz), |p|.
            Returns (a (*,n*nsamples,3), q (*,n*nsamples,1))."""
        d=self.rep_dim
        device,dtype = pt.device,pt.dtype
        # Sample stabilizer of the origin
        q = torch.randn(*pt.shape[:-1],nsamples,4,device=device,dtype=dtype)
        q /= norm(q,dim=-1).unsqueeze(-1)
        theta = 2*torch.atan2(norm(q[...,1:],dim=-1),q[...,0]).unsqueeze(-1)
        zhat = torch.zeros(*pt.shape[:-1],nsamples,3,device=device,dtype=dtype) # (*,n,ns,3)
        zhat[...,0] = 1
        Rz = self.exp(zhat*theta)

        # Compute the rotation between zhat and p
        r = norm(pt,dim=-1).unsqueeze(-1) # (*,n,1)
        assert not torch.any(torch.isinf(pt)|torch.isnan(pt))
        p_on_sphere = pt/r.clamp(min=1e-5)
        # explicit dim=-1: without it torch.cross uses the first axis of size 3,
        # which is a batch axis whenever e.g. bs, n or nsamples equals 3
        w = torch.cross(zhat,p_on_sphere[...,None,:].expand(*zhat.shape),dim=-1)
        sin = norm(w,dim=-1)
        cos = p_on_sphere[...,None,0]

        angle = torch.atan2(sin,cos).unsqueeze(-1)
        # |w| = sin(angle), so w*angle/sin(angle) is the rotation vector taking zhat to p
        Rp = self.exp(w*sinc_inv(angle))

        # Combine the rotations into one
        A = self.log(Rp@Rz)  # Convert to lie algebra element
        assert not torch.any(torch.isnan(A)|torch.isinf(A))
        q = r[...,None,:].expand(*r.shape[:-1],nsamples,1) # The orbit identifier is |x|
        flat_q = q.reshape(*r.shape[:-2],r.shape[-2]*nsamples,1)
        flat_a = A.reshape(*pt.shape[:-2],pt.shape[-2]*nsamples,d)
        return flat_a, flat_q

@export
class SE3(SO3):
    """ Special Euclidean group SE(3), represented as 4x4 homogeneous matrices. Lie algebra
        coordinates are [rotation vector w (3), u (3)] where the translation is V(w) @ u."""
    lie_dim = 6
    rep_dim = 4
    q_dim = 0
    def __init__(self,alpha=.2,per_point=True):
        """ per_point: if True every point gets its own random rotations in lifted_elems,
            otherwise the same nsamples rotations are shared by all points of a batch element."""
        super().__init__(alpha)
        self.per_point = per_point

    def exp(self,w):
        """ w (*,6) -> homogeneous matrices (*,4,4)."""
        theta = norm(w[...,:3],dim=-1)[...,None,None]
        K = cross_matrix(w[...,:3])
        R = super().exp(w[...,:3])
        I = torch.eye(3,device=w.device,dtype=w.dtype)
        V = I + cosc(theta)*K + sincc(theta)*(K@K)
        U = torch.zeros(*w.shape[:-1],4,4,device=w.device,dtype=w.dtype)
        U[...,:3,:3] = R
        U[...,:3,3] = (V@w[...,3:].unsqueeze(-1)).squeeze(-1)
        U[...,3,3] = 1
        return U

    def log(self,U):
        """ U (*,4,4) -> [w, u] (*,6), inverting exp via the closed form of V^-1."""
        w = super().log(U[...,:3,:3])
        I = torch.eye(3,device=w.device,dtype=w.dtype)
        K = cross_matrix(w[...,:3])
        theta = norm(w,dim=-1)[...,None,None]
        cosccc = coscc(theta)
        Vinv = I - K/2 + cosccc*(K@K)
        u = (Vinv@U[...,:3,3].unsqueeze(-1)).squeeze(-1)
        return torch.cat([w,u],dim=-1)

    def components2matrix(self,a): # a: (*,6)
        A = torch.zeros(*a.shape[:-1],4,4,device=a.device,dtype=a.dtype)
        A[...,:3,:3] = cross_matrix(a[...,:3])
        A[...,:3,3] = a[...,3:]
        return A

    def matrix2components(self,A): # A: (*,4,4)
        return torch.cat([uncross_matrix(A[...,:3,:3]),A[...,:3,3]],dim=-1)

    def lifted_elems(self,pt,nsamples,**kwargs):
        """ pt (bs,n,3) -> (a (bs,n*nsamples,6), None). Each point is lifted to
            trans(pt) @ R for nsamples random rotations R (per point if self.per_point).
            Extra kwargs forwarded by LieGroup.lift are accepted and ignored."""
        bs,n = pt.shape[:2]
        if self.per_point:
            q = torch.randn(bs,n,nsamples,4,device=pt.device,dtype=pt.dtype)
        else:
            q = torch.randn(bs,1,nsamples,4,device=pt.device,dtype=pt.dtype)
        q /= norm(q,dim=-1).unsqueeze(-1)
        theta_2 = torch.atan2(norm(q[...,1:],dim=-1),q[...,0]).unsqueeze(-1)
        so3_elem = 2*sinc_inv(theta_2)*q[...,1:] # (sin(x/2)u -> xu) for x angle and u direction
        se3_elem = torch.cat([so3_elem,torch.zeros_like(so3_elem)],dim=-1)
        R = self.exp(se3_elem)
        # homogeneous translation to pt (not named T to avoid shadowing the T group)
        trans = torch.zeros(bs,n,nsamples,4,4,device=pt.device,dtype=pt.dtype) # (bs,n,nsamples,4,4)
        trans[...,:,:] = torch.eye(4,device=pt.device,dtype=pt.dtype)
        trans[...,:3,3] = pt[:,:,None,:] # (bs,n,1,3)
        a = self.log(trans@R) # bs, n, nsamples, 6
        return a.reshape(bs,n*nsamples,self.lie_dim), None
    def distance(self,abq_pairs):
        dist_rot = norm(abq_pairs[...,:3],dim=-1)
        dist_trans = norm(abq_pairs[...,3:],dim=-1)
        return dist_rot*self.alpha + (1-self.alpha)*dist_trans

@export
class Trivial(LieGroup):
    """ The trivial group: no Lie algebra (lie_dim=0), so the embedded location of a pair
        of points is just their concatenated coordinates [p_i, p_j]."""
    lie_dim=0
    def __init__(self,dim=2):
        super().__init__()
        self.q_dim = dim
        self.rep_dim = dim

    def lift(self,x,nsamples,**kwargs):
        """ x = (p,v,m) with p (*,n,d). Returns (q (*,n,n,2d), v, m) where
            q[...,i,j,:] = [p_i, p_j]. Works for any number of leading batch dims."""
        assert nsamples==1, "Abelian group, no need for nsamples"
        p,v,m = x
        n,d = p.shape[-2:]
        pair_shape = (*p.shape[:-2],n,n,d)
        qa = p.unsqueeze(-2).expand(pair_shape)
        qb = p.unsqueeze(-3).expand(pair_shape)
        q = torch.cat([qa,qb],dim=-1)
        return q,v,m

@export
class FakeSchGroup(object):
    """ Stand-in 'group' that embeds each pair of points by their Euclidean distance only.
        Not a LieGroup subclass: it provides just lift and distance."""
    lie_dim = 0
    rep_dim = 3
    q_dim = 1

    def lift(self,x,nsamples,**kwargs):
        """ x = (p,v,m) with p (*,n,D), values v (*,n,c), mask m (*,n).
            Returns (q,v,m) where q (*,n,n,1) holds |p_i - p_j|; v and m are passed through
            unchanged and nsamples is ignored."""
        points, values, mask = x
        diffs = points.unsqueeze(-2) - points.unsqueeze(-3)
        pair_dists = torch.norm(diffs, dim=-1)
        return (pair_dists[..., None], values, mask)

    def distance(self,abq_pairs):
        """ The embedded pairs already are distances, so they are returned as-is."""
        return abq_pairs