ramon-oliveira/aorun

aorun/optimizers.py

import torch
from torch.autograd import Variable


class Optimizer(object):
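    """Base class: holds the parameters to update and the shared
    hyperparameters (learning-rate decay and numerical epsilon)."""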

    def __init__(self, params=None, decay=0.0, epsilon=1e-8):
        self.params = params
        self.decay = decay
        self.epsilon = epsilon

    def clear_gradients(self):
        if self.params is not None:
            for p in self.params:
                # parameters that never received a gradient have grad == None
                if p.grad is not None:
                    p.grad.data.zero_()

    def step(self):
        if self.params is None:
            raise Exception('Optimizer has no parameters to update')


class SGD(Optimizer):
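    """Stochastic gradient descent with classical momentum and a linear
    learning-rate decay applied after every step."""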

    def __init__(self, lr=0.001, momentum=0.0, *args, **kwargs):
        super(SGD, self).__init__(*args, **kwargs)
        self.lr = lr
        self.momentum = momentum
        self.updates = []

    def step(self):
        super(SGD, self).step()

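        # lazily allocate one momentum buffer per parameter on the first step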
        if len(self.updates) == 0:
            for p in self.params:
                update = torch.zeros(p.size())
                self.updates.append(update)

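        # classical momentum: buffer <- momentum * buffer + lr * grad, then p <- p - buffer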
        for i, (p, update) in enumerate(zip(self.params, self.updates)):
            cur_update = self.momentum * update + self.lr * p.grad.data
            p.data.sub_(cur_update)
            self.updates[i] = cur_update

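        # linear learning-rate decay, floored at 1e-9 so lr stays positive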
        self.lr = max(1e-9, self.lr - self.decay)
        self.clear_gradients()


class Adam(Optimizer):
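    """Adam (Kingma & Ba, 2014): adapts the step size per parameter using
    bias-corrected running averages of the gradient and its square."""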

    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, *args, **kwargs):
        super(Adam, self).__init__(*args, **kwargs)
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        # first moment: exponential moving average of the gradients
        self.m = {}
        # second moment: exponential moving average of the squared gradients
        self.v = {}
        # timestep
        self.t = 0

    def step(self):
        super(Adam, self).step()
        self.t += 1

        if len(self.m) == 0:
            for p in self.params:
                self.m[p] = torch.zeros(p.size())
                self.v[p] = torch.zeros(p.size())

        for p in self.params:
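            # exponential moving averages of the gradient and of its square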
            mt = self.beta1 * self.m[p] + (1 - self.beta1) * p.grad.data
            vt = self.beta2 * self.v[p] + (1 - self.beta2) * p.grad.data**2
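            # bias correction for the zero-initialized moment estimates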
            m = mt / (1 - self.beta1**self.t)
            v = vt / (1 - self.beta2**self.t)

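            # adaptive per-parameter step size: lr scaled by 1 / (sqrt(v_hat) + epsilon)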
            rate = self.lr / (torch.sqrt(v) + self.epsilon)
            p.data.sub_(rate * m)

            self.m[p] = mt
            self.v[p] = vt

        self.clear_gradients()

# Lowercase aliases so string names ('sgd', 'adam') resolve via get()
sgd = SGD
adam = Adam


def get(obj):
    """Resolve `obj` to an optimizer: pass through objects that already
    implement step(), or instantiate one by name (e.g. 'sgd', 'adam')."""
    if hasattr(obj, 'step'):
        return obj
    elif isinstance(obj, str):
        if obj in globals():
            return globals()[obj]()
        else:
            raise Exception(f'Unknown optimizer: {obj}')
    else:
        raise Exception('Optimizer must be an optimizer instance or str')
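

# Minimal usage sketch, assuming parameters are Variables with requires_grad=True
# and that gradients are populated by backward() before each step; step() zeroes
# them afterwards via clear_gradients(). The toy tensors w, b, x, y below stand
# in for real model parameters and data.
if __name__ == '__main__':
    w = Variable(torch.randn(3, 1), requires_grad=True)
    b = Variable(torch.zeros(1), requires_grad=True)
    x = Variable(torch.randn(32, 3))
    y = Variable(torch.randn(32, 1))

    optimizer = Adam(lr=0.01, params=[w, b])  # or: get('sgd'), then set .params

    for _ in range(100):
        loss = ((x.mm(w) + b - y) ** 2).mean()
        loss.backward()
        optimizer.step()  # applies the update and zeroes the gradients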