mfinzi/LieConv

View on GitHub
lie_conv/hamiltonian.py

Summary

Maintainability
A
0 mins
Test Coverage
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.animation as animation


class HamiltonianDynamics(nn.Module):
    """ Defines the dynamics given a hamiltonian. If wgrad=True, the dynamics can be backproped."""
    def __init__(self,H,wgrad=False):
        super().__init__()
        self.H = H
        self.wgrad=wgrad
        self.nfe=0
    def forward(self,t,z):
        self.nfe+=1
        with torch.enable_grad():
            z = torch.zeros_like(z, requires_grad=True) + z
            D = z.shape[-1]
            h = self.H(t,z).sum() # elements in mb are independent, gives mb gradients
            rg = torch.autograd.grad(h,z,create_graph=self.wgrad)[0] # riemannian gradient
        sg = torch.cat([rg[:,D//2:],-rg[:,:D//2]],dim=-1) # symplectic gradient = SdH
        return sg


def EuclideanK(momentums, masses):
    """ Shape (bs,n,d), and (bs,n),
        standard \sum_n p_n^2/{2m_n} kinetic energy"""
    p_sq_norms = momentums.pow(2).sum(-1)
    kinetic_energy = (p_sq_norms / masses).sum(-1) / 2
    return kinetic_energy
    # return (p*(p/m[:,:,None])).sum(-1).sum(-1)/2

def KeplerV(positions, masses):
    """ Shape (bs,n,d), and (bs,n),
        Gravitational PE: -\sum_{jk} m_jm_k/{\|q_j-q_k\|}"""
    grav_const = 1
    n = masses.shape[-1]
    row_ind, col_ind = torch.tril_indices(n, n, offset=-1)
    moments = (masses.unsqueeze(1) * masses.unsqueeze(2))[:, row_ind, col_ind]
    pair_diff = (positions.unsqueeze(1) - positions.unsqueeze(2))[:, row_ind, col_ind]
    pair_dist = pair_diff.norm(dim=-1) + 1e-8
    potential_energy = -grav_const * (moments / pair_dist).sum(-1)
    return potential_energy

def KeplerH(z,m):
    """ with shapes (bs,2nd)"""
    bs, D = z.shape # of ODE dims, 2*num_particles*space_dim
    q = z[:,:D//2].reshape(*m.shape,-1)
    p = z[:,D//2:].reshape(*m.shape,-1)
    potential_energy = KeplerV(q, m)
    kinetic_energy = EuclideanK(p, m)
    assert potential_energy.shape[0] == bs
    assert kinetic_energy.shape[0] == bs
    return potential_energy + kinetic_energy

def SpringV(q,k):
    """ Potential for a bunch particles connected by springs with kij
        Shape (bs,n,d), and (bs,n,n)"""
    K = k[:,:,None]*k[:,None,:] #(bs,n,n)
    n = K.shape[-1]
    radial = (q[:,:,None,:] - q[:,None,:,:]).norm(dim=-1)**2 # (bs, n, n)
    potential = .5*(K*radial).sum(-1).sum(-1)
    return potential #(bs,n,n) -> (bs)

def SpringH(z,m,k):
    """ with shapes (bs,2nd)"""
    D = z.shape[-1] # of ODE dims, 2*num_particles*space_dim
    q = z[:,:D//2].reshape(*m.shape,-1)
    p = z[:,D//2:].reshape(*m.shape,-1)
    return EuclideanK(p,m) + SpringV(q,k)

def BallV(q,r):
    """ Potential for a bunch of (almost) rigid balls and walls, each ball has radius r"""
    n = r.shape[-1]
    thresh = 0.1
    barrier = lambda dist: .5*(torch.exp(1/(dist-thresh*1.05) - 50*(dist-thresh))).sum(-1)#50*((dist-thresh)**2).sum(-1)#.5*(torch.exp(1/(dist-thresh*1.05))/dist).sum(-1)#1/(dist-thresh*1.05)
    separation = (q[:,:,None,:] - q[:,None,:,:]).norm(dim=-1)
    sum_r = r[:,:,None]+r[:,None,:]
    touching = (separation-sum_r < thresh)
    energy = barrier(separation[touching]-sum_r[touching])
    for i in range(q.shape[-1]):
        ld = q[:,:,i]+1-r
        rd = 1-q[:,:,i]-r
        energy += barrier(ld[ld<thresh])+barrier(rd[rd<thresh])
    return energy

def BallH(z,m,r):
    D = z.shape[-1] # of ODE dims, 2*num_particles*space_dim
    q = z[:,:D//2].reshape(*m.shape,-1)
    p = z[:,D//2:].reshape(*m.shape,-1)
    return EuclideanK(p,m) + BallV(q,r)


# TODO:
# Make animation plots look nicer. Why are there leftover points on the trails?
class Animation2d(object):
    def __init__(self, qt, ms=None, box_lim=(-1, 1)):
        if ms is None: ms = len(qt)*[6]
        self.qt = qt
        self.fig = plt.figure()
        self.ax = self.fig.add_axes([0, 0, 1, 1])#axes(projection='3d')
        self.ax.set_xlim(box_lim)
        self.ax.set_ylim(box_lim)
        self.lines = sum([self.ax.plot([],[],'-') for particle in self.qt],[])
        self.pts = sum([self.ax.plot([],[],'o',ms=ms[i]) for i in range(len(self.qt))],[])
    def init(self):
        for line,pt in zip(self.lines,self.pts):
            line.set_data([], [])
            pt.set_data([], [])
        return self.lines + self.pts
    def update(self,i=0):
        for line, pt, trajectory in zip(self.lines,self.pts,self.qt):
            x,y = trajectory[:,:i]
            line.set_data(x,y)
            pt.set_data(x[-1:], y[-1:])
        #self.fig.clear()
        self.fig.canvas.draw()
        return self.lines+self.pts
    def animate(self):
        return animation.FuncAnimation(self.fig,self.update,frames=self.qt.shape[-1],
                                       interval=33,init_func=self.init,blit=True)
    
class Animation3d(object):
    def __init__(self,qt,ms=None, box_lim=(-1, 1)):
        if ms is None: ms = len(qt)*[6]
        self.qt = qt
        self.fig = plt.figure()
        self.ax = self.fig.add_axes([0, 0, 1, 1],projection='3d')#axes(projection='3d')
        self.ax.set_xlim3d(box_lim)
        self.ax.set_ylim3d(box_lim)
        self.ax.set_zlim3d(box_lim)
        self.lines = sum([self.ax.plot([],[],[],'-') for _ in self.qt],[])
        self.pts = sum([self.ax.plot([],[],[],'o',ms=ms[i]) for i in range(len(self.qt))],[])
    def init(self):
        for line,pt in zip(self.lines,self.pts):
            line.set_data([], [])
            line.set_3d_properties([])
            pt.set_data([], [])
            pt.set_3d_properties([])
        return self.lines + self.pts
    def update(self,i=0):
        for line, pt, trajectory in zip(self.lines,self.pts,self.qt):
            x,y,z = trajectory[:,:i]
            line.set_data(x,y)
            line.set_3d_properties(z)
            pt.set_data(x[-1:], y[-1:])
            pt.set_3d_properties(z[-1:])
        #self.fig.clear()
        self.fig.canvas.draw()
        return self.lines+self.pts
    def animate(self):
        return animation.FuncAnimation(self.fig,self.update,frames=self.qt.shape[-1],
                                       interval=33,init_func=self.init,blit=True)
def AnimationNd(n):
    if n==2: return Animation2d
    elif n==3: return Animation3d
    else: assert False, "No animation for d={}".format(n)