mfinzi/LieConv

View on GitHub
examples/train_springs.py

Summary

Maintainability
A
0 mins
Test Coverage
import copy, warnings
from oil.tuning.args import argupdated_config
from oil.datasetup.datasets import split_dataset
from oil.tuning.study import train_trial
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from oil.utils.utils import LoaderTo, islice, FixedNumpySeed, cosLr
from lie_conv.datasets import SpringDynamics
from lie_conv import datasets
from lie_conv.dynamicsTrainer import IntegratedDynamicsTrainer, FC, HLieResNet

import lie_conv.lieGroups as lieGroups
from lie_conv.lieGroups import Tx
from lie_conv import dynamicsTrainer
#from lie_conv.dynamics_trial import DynamicsTrial
try:
    import lie_conv.graphnets as graphnets
except ImportError:
    import lie_conv.lieConv as graphnets
    warnings.warn('Failed to import graphnets. Please install using \
                `pip install .[GN]` for this functionality', ImportWarning)

def makeTrainer(*,network=FC,net_cfg={},lr=1e-2,n_train=3000,regen=False,dataset=SpringDynamics,
                dtype=torch.float32,device=torch.device('cuda'),bs=200,num_epochs=2,
                trainer_config={}):
    # Create Training set and model
    splits = {'train':n_train,'val':200,'test':2000}
    dataset = dataset(n_systems=10000,regen=regen)
    with FixedNumpySeed(0):
        datasets = split_dataset(dataset,splits)
    model = network(sys_dim=dataset.sys_dim,d=dataset.space_dim,**net_cfg).to(device=device,dtype=dtype)
    # Create train and Dev(Test) dataloaders and move elems to gpu
    dataloaders = {k:LoaderTo(DataLoader(v,batch_size=min(bs,n_train),num_workers=0,shuffle=(k=='train')),
                                device=device,dtype=dtype) for k,v in datasets.items()}
    dataloaders['Train'] = islice(dataloaders['train'],len(dataloaders['val']))
    # Initialize optimizer and learning rate schedule
    opt_constr = lambda params: Adam(params, lr=lr)
    lr_sched = cosLr(num_epochs)
    return IntegratedDynamicsTrainer(model,dataloaders,opt_constr,lr_sched,
                                    log_args={'timeFrac':1/4,'minPeriod':0.0},**trainer_config)

Trial = train_trial(makeTrainer)
if __name__=='__main__':
    defaults = copy.deepcopy(makeTrainer.__kwdefaults__)
    defaults['save']=False
    defaults['trainer_config']['early_stop_metric']='val_MSE'
    print(Trial(argupdated_config(defaults,namespace=(dynamicsTrainer,lieGroups,datasets,graphnets))))