kengz/SLM-Lab
slm_lab/agent/net/base.py

from abc import ABC, abstractmethod
from slm_lab.agent.net import net_util
import pydash as ps
import torch
import torch.nn as nn


class Net(ABC):
    '''Abstract Net class to define the API methods'''

    def __init__(self, net_spec, in_dim, out_dim):
        '''
        @param {dict} net_spec is the spec for the net
        @param {int|list} in_dim is the input dimension(s) for the network. Usually use in_dim=body.state_dim
        @param {int|list} out_dim is the output dimension(s) for the network. Usually use out_dim=body.action_dim
        '''
        self.net_spec = net_spec
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.grad_norms = None  # for debugging
        if self.net_spec.get('gpu'):
            if torch.cuda.device_count():
                self.device = f'cuda:{net_spec.get("cuda_id", 0)}'
            else:
                self.device = 'cpu'
        else:
            self.device = 'cpu'
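
    # Illustrative example (not part of the original file): with a net_spec such as
    # {'gpu': True, 'cuda_id': 1}, self.device resolves to 'cuda:1' when a CUDA device
    # is available, and falls back to 'cpu' otherwise.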

    @abstractmethod
    def forward(self):
        '''The forward step for a specific network architecture'''
        raise NotImplementedError

    @net_util.dev_check_train_step
    def train_step(self, loss, optim, lr_scheduler=None, clock=None, global_net=None):
        '''
        Run one optimization step on this network.
        @param {torch.Tensor} loss to backpropagate
        @param {torch.optim.Optimizer} optim used for the parameter update
        @param lr_scheduler optional scheduler, stepped with the clock's frame count
        @param clock optional Clock; read by the lr_scheduler and ticked with 'opt_step'
        @param global_net optional shared network for asynchronous training; local gradients are pushed to it before the optimizer step, and its parameters are copied back afterwards
        '''
        if lr_scheduler is not None:
            lr_scheduler.step(epoch=ps.get(clock, 'frame'))
        optim.zero_grad()
        loss.backward()
        if self.clip_grad_val is not None:  # expected to be set by the concrete subclass, typically from net_spec
            nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
        if global_net is not None:
            net_util.push_global_grads(self, global_net)
        optim.step()
        if global_net is not None:
            net_util.copy(global_net, self)
        if clock is not None:
            clock.tick('opt_step')
        return loss
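
    # Illustrative calls (hypothetical names, not from this file):
    #   net.train_step(loss, optim)  # plain synchronous update
    #   net.train_step(loss, optim, lr_scheduler, clock=clock, global_net=global_net)  # update against a shared net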

    def store_grad_norms(self):
        '''Stores the gradient norms for debugging.'''
        norms = [param.grad.norm().item() for param in self.parameters()]
        self.grad_norms = norms
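
A minimal sketch of how a concrete subclass might satisfy this API, assuming a small MLP built with torch.nn. The class name SimpleMLPNet, the layer sizes, the example spec, and the assumption that clip_grad_val is read from the net_spec are illustrative; they are not taken from SLM-Lab's actual net classes. The subclass inherits from both Net and nn.Module so that self.parameters() and self.to(), which train_step and device placement rely on, are available.

import torch
import torch.nn as nn
import torch.nn.functional as F

from slm_lab.agent.net.base import Net


class SimpleMLPNet(Net, nn.Module):
    '''Illustrative concrete Net: a small MLP exposing the attributes train_step expects.'''

    def __init__(self, net_spec, in_dim, out_dim):
        nn.Module.__init__(self)
        Net.__init__(self, net_spec, in_dim, out_dim)
        self.clip_grad_val = net_spec.get('clip_grad_val')  # assumed spec key, read by Net.train_step
        self.model = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, out_dim),
        )
        self.to(self.device)

    def forward(self, x):
        return self.model(x)


# Hypothetical usage: one regression-style update to exercise the API
net_spec = {'gpu': False, 'clip_grad_val': 0.5}
net = SimpleMLPNet(net_spec, in_dim=4, out_dim=2)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
x, target = torch.randn(8, 4), torch.randn(8, 2)
loss = F.mse_loss(net(x), target)
net.train_step(loss, optim)
net.store_grad_norms()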