oil/utils/utils.py
import numpy as np
import torch
import numbers
import time
import torch.nn as nn
import inspect
import copy
import os
import dill
import itertools
import sys
import torch.utils.data
import collections
import random
class Named(type):
    """Metaclass whose classes print as their bare ``__name__``.

    Both ``str(cls)`` and ``repr(cls)`` return ``cls.__name__`` instead of
    the default ``<class 'module.Name'>`` form.
    """
    def __str__(self):
        return self.__name__
    __repr__ = __str__
def export(fn):
    """Decorator: add ``fn.__name__`` to its defining module's ``__all__``.

    Creates ``__all__`` on the module if it does not exist yet. Returns
    ``fn`` unchanged so it can be used as ``@export``.
    """
    module = sys.modules[fn.__module__]
    try:
        public_names = module.__all__
    except AttributeError:
        module.__all__ = [fn.__name__]
    else:
        public_names.append(fn.__name__)
    return fn
def log_uniform(low, high, size=None):
    """Sample values whose logarithm is uniformly distributed.

    Draws ``np.random.uniform(log(low), log(high), size)`` from numpy's
    global RNG and exponentiates, so results lie in ``[low, high)``.
    ``size=None`` yields a single scalar, as with ``np.random.uniform``.
    """
    log_low, log_high = np.log(low), np.log(high)
    exponents = np.random.uniform(log_low, log_high, size)
    return np.exp(exponents)
class ReadOnlyDict(dict):
    """A dict that raises RuntimeError on any attempt to modify it.

    Construction works exactly as for ``dict``. Afterwards, item
    assignment/deletion, the mutating methods, and in-place ``|=`` all
    raise ``RuntimeError("Cannot modify ReadOnlyDict")``. Copying
    (``copy.copy``/``copy.deepcopy``) and pickling produce another
    ReadOnlyDict.
    """
    def __readonly__(self, *args, **kwargs):
        raise RuntimeError("Cannot modify ReadOnlyDict")
    __setitem__ = __readonly__
    __delitem__ = __readonly__
    pop = __readonly__
    popitem = __readonly__
    clear = __readonly__
    update = __readonly__
    setdefault = __readonly__
    # dict.__ior__ (Python 3.9+, `d |= other`) mutates in place without
    # going through update(), so it must be blocked separately.
    __ior__ = __readonly__
    del __readonly__

    def __reduce__(self):
        # The default dict-subclass reduction restores items one by one via
        # __setitem__, which raises here; rebuild through the constructor
        # (which does not call __setitem__) instead. Instance attributes, if
        # any, are carried along as the pickle state.
        return (type(self), (dict(self),), self.__dict__ or None)
# class map_with_len(object):
# def __init__(self, func, iter_with_len):
# self._func = func
# self._iter = iter_with_len
# def __iter__(self):
# return map(self._func, self._iter)
# def __len__(self):
# return len(self._iter)
# class imap(torch.utils.data.DataLoader):
# def __init__(self,func,loader):
# try: loader_func = loader._func
# except AttributeError: loader_func=lambda x:x
# self.__dict__ = loader.__dict__
# self._dl = loader
# self._func = lambda x: func(loader_func(x))
# def __iter__(self):
# return map(self._func,self._dl.__iter__())
# def imap(func,loader):
# class _imap(loader.__class__):
# def __init__(self,loader):
# self.__dict__ = loader.__dict__
# def __iter__(self):
# return map(func,super().__iter__())
# return _imap(loader)
class Wrapper(object):
    """Transparent proxy that forwards attribute access to a wrapped object.

    Special methods are dispatched by what is defined in the class rather
    than the instance, so they bypass ``__getattr__``. As a result, for a
    wrapper to support ``len(w)``, ``w[i]``, ``iter(w)``, ``w(...)`` and so
    on, the special methods named in ``smethods`` below are installed on the
    class, and each one dispatches dynamically (using getattr) to the wrapped
    object.
    """
    def __init__(self, obj):
        self._wrapped_obj = obj
    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. If _wrapped_obj
        # itself is missing (e.g. before __init__ has run), fail fast instead
        # of recursing forever through self._wrapped_obj below.
        if attr == '_wrapped_obj':
            raise AttributeError(attr)
        if attr == '__dict__':
            # Explicit raise (unlike `assert False`) is not stripped by -O.
            raise AssertionError("Wrapper.__getattr__ should never be asked for __dict__")
        return getattr(self._wrapped_obj, attr)


smethods = '''__bool__ __int__ __float__ __complex__ __index__
__len__ __getitem__ __setitem__ __delitem__ __contains__
__iter__ __next__ __reversed__
__call__ __enter__ __exit__
__str__ __repr__ __bytes__ __format__
__eq__ __ne__ __lt__ __le__ __gt__ __ge__ __hash__
__add__ __mul__ __sub__ __truediv__ __floordiv__ __mod__
__and__ __or__ __xor__ __invert__ __lshift__ __rshift__
__pos__ __neg__ __abs__ __pow__ __divmod__
__round__ __ceil__ __floor__ __trunc__
__radd__ __rmul__ __rsub__ __rtruediv__ __rfloordiv__ __rmod__
__rand__ __ror__ __rxor__ __rlshift__ __rrshift__
__rpow__ __rdivmod__
__get__ __set__ __delete__
__dir__ __sizeof__'''.split()


def _make_forwarder(name):
    """Build a special method that forwards `name` to the wrapped object."""
    def forward(self, *args, **kwargs):
        # Keyword arguments are forwarded too, so e.g. w(x, key=1) reaches
        # the wrapped object's __call__.
        return Wrapper.__getattr__(self, name)(*args, **kwargs)
    forward.__name__ = name
    return forward


def _wrapper_bool(self):
    """Truth value of the wrapped object.

    Many objects (lists, anything defining only __len__) have no __bool__
    attribute, so forwarding getattr(obj, '__bool__') would raise
    AttributeError; bool() applies Python's full truth-testing protocol.
    """
    return bool(self._wrapped_obj)


for sm in smethods:
    setattr(Wrapper, sm, _make_forwarder(sm))
Wrapper.__bool__ = _wrapper_bool
class dmap(Wrapper):
    """Dataset wrapper whose indexing returns ``func(dataset[i])``.

    Everything else (``len``, other attributes, other special methods) is
    forwarded to the wrapped dataset by Wrapper.
    """
    def __init__(self, func, dataset):
        super().__init__(dataset)
        self._func = func
    def __getitem__(self, i):
        raw_item = super().__getitem__(i)
        return self._func(raw_item)
class imap(Wrapper):
    """Loader wrapper that yields ``func(batch)`` for each batch of the loader.

    Everything else is forwarded to the wrapped loader by Wrapper.
    """
    def __init__(self, func, loader):
        super().__init__(loader)
        self._func = func
    def __iter__(self):
        upstream = super().__iter__()
        return map(self._func, upstream)
class islice(Wrapper):
    """Loader wrapper whose iteration yields only a slice of the loader.

    ``islice(loader, stop)`` or ``islice(loader, start, stop[, step])`` take
    the same positional arguments as ``itertools.islice``; every ``iter()``
    call slices a fresh iterator over the wrapped loader. Other attribute
    access is forwarded to the loader by Wrapper.
    """
    def __init__(self, loader, *args, **kwargs):
        super().__init__(loader)
        self._args = args
        self._kwargs = kwargs
    def __iter__(self):
        # itertools.islice objects are already iterators; no iter() needed.
        return itertools.islice(super().__iter__(), *self._args, **self._kwargs)
    def __len__(self):
        # Report how many items __iter__ will actually yield instead of the
        # wrapped loader's full length. For the non-negative arguments that
        # itertools.islice accepts, slicing a range gives the same count.
        return len(range(len(self._wrapped_obj))[slice(*self._args)])
class icycle(Wrapper):
    """Wraps a dataloader and cycles through it repeatedly.

    Each pass calls the wrapped loader's ``__iter__`` again. Iteration ends
    only if a complete pass yields nothing (an empty loader); otherwise it
    is endless.
    """
    def __init__(self, dataloader):
        super().__init__(dataloader)
    def __iter__(self):
        while True:
            produced_any = False
            for data in super().__iter__():
                produced_any = True
                yield data
            if not produced_any:
                # An empty loader would otherwise spin here forever.
                return
    def __len__(self):
        # Stand-in for "infinite": len() must return a finite int.
        return 10**10
# ## Wraps a dataloader and cycles repeatedly
# class icycle(object):
# def __init__(self,dataloader):
# self.dataloader = dataloader
# def __iter__(self):
# while True:
# for data in self.dataloader:
# yield data
# def __len__(self):
# return 10**10
# class imap(object):
# def __init__(self,func,loader):
# self.func = func
# self.loader = loader
# def __iter__(self):
# return map(self.func,self.loader.__iter__())
# def __getattr__(self,name):
# if name==
# return self.loader.__getattribute__(name)
# def __setattr__(self,name,value):
# if name not in ['func','loader']:
# self.loader.__setattr__(name,value)
# else: super().__setattr__(name,value)
def minibatch_to(mb, device=None, dtype=None):
    """Recursively move a (possibly nested) minibatch to `device` / `dtype`.

    Objects with a ``.to`` method (e.g. tensors) are converted with
    ``mb.to(device=device, dtype=dtype)``. Dicts are rebuilt with the same
    type and converted values. Namedtuples and other iterable containers are
    rebuilt element-wise with the same type. Strings, bytes and non-iterable
    leaves (ints, floats, None, ...) are returned unchanged.
    """
    try:
        return mb.to(device=device, dtype=dtype)
    except AttributeError:
        pass
    if isinstance(mb, dict):
        return type(mb)(((k, minibatch_to(v, device, dtype)) for k, v in mb.items()))
    if isinstance(mb, (str, bytes)):
        # Iterating a str yields 1-char strs, which would recurse forever.
        return mb
    if isinstance(mb, tuple) and hasattr(mb, '_fields'):
        # namedtuple constructors take fields positionally, not an iterable.
        return type(mb)(*(minibatch_to(elem, device, dtype) for elem in mb))
    try:
        elems = iter(mb)
    except TypeError:
        return mb  # non-iterable leaf, e.g. an int label
    return type(mb)(minibatch_to(elem, device, dtype) for elem in elems)
import functools
def LoaderTo(loader, device=None, dtype=None):
    """Wrap `loader` so every minibatch it yields goes through minibatch_to.

    Returns an ``imap`` around the loader, so other attributes of the loader
    remain accessible on the result.
    """
    convert = functools.partial(minibatch_to, device=device, dtype=dtype)
    return imap(convert, loader)
# class LoaderTo(torch.utils.data.DataLoader):
# def __init__(self,loader, device):
# self.__dict__ = loader.__dict__
# self._device = device
# def __iter__(self):
# def minibatch_map(mb):
# try: return mb.to(self._device)
# except AttributeError:
# return type(mb)(minibatch_map(elem) for elem in mb)#map(lambda x:x.to(self._device),mb))
# return map(minibatch_map,super().__iter__())
# class islice(object):
# def __init__(self,dataloader,k):
# """ Wraps a dataloader, but only takes the first k elements with iter,
# if shuffling is enabled, this may be different from different
# calls to iter """
# self._k = k
# self.loader = dataloader
# def __iter__(self):
# return iter(itertools.islice(self.loader),self._k)
# def __getattr__(self,name):
# return self.loader.__getattribute__(name)
# def __setattr__(self,name,value):
# if name not in ['_k','loader']:
# self.loader.__setattr__(name,value)
# else: super().__setattr__(name,value)
# class islice(torch.utils.data.DataLoader):
# def __init__(self,dataloader,k):
# """ Wraps a dataloader, but only takes the first k elements with iter,
# if shuffling is enabled, this may be different from different
# calls to iter """
# self.__dict__= dataloader.__dict__
# self.dl = dataloader
# self._k = k
# def __iter__(self):
# return iter(itertools.islice(self.dl,self._k))
def to_device_layer(device):
    """Return an Expression module that moves its minibatch input to `device`.

    Delegates to ``minibatch_to``, so dict minibatches and nested containers
    are handled the same way as in ``LoaderTo``. The previous inline version
    called ``.to`` on dict keys and on inner lists, and crashed on both.
    """
    def minibatch_map(mb):
        return minibatch_to(mb, device)
    return Expression(minibatch_map)
# def loader_to(device):
# """Returns a function that sends dataloader output
# to the specified device"""
# def minibatch_map(mb):
# try: return mb.to(device)
# except AttributeError:
# return type(mb)(map(lambda x:x.to(device),mb))
# return lambda loader: map_with_len(minibatch_map, loader)
# # Wraps a generator so that calling __iter__ multiple
# # times produces distinct non-empty generators
class reusable(object):
    """Iterable that calls `generator_constructor()` on every ``__iter__``.

    Wrapping a generator function this way lets the same object be iterated
    several times, each time with a newly constructed iterator.
    """
    def __init__(self, generator_constructor):
        self._gen = generator_constructor
    def __iter__(self):
        fresh_iterator = self._gen()
        return fresh_iterator
# def __len__(self):
# return len(self._gen())
# class islice(object):
# def __init__(self,dataloader,k):
# self.dataloader = dataloader
# self.k = k
# def __iter__(self):
# return iter(itertools.islice(self.dataloader,self.k))
# def __len__(self):
# return self.k
class izip(object):
    """Sized, re-iterable counterpart of ``zip``.

    Each ``__iter__`` call builds a new ``zip`` over the stored inputs, so it
    can be iterated repeatedly as long as the inputs can. ``len`` is the
    length of the shortest input, or 0 when there are no inputs.
    """
    def __init__(self, *iters):
        self.iters = iters
    def __iter__(self):
        # zip objects are already iterators; no extra iter() needed.
        return zip(*self.iters)
    def __len__(self):
        # default=0 matches zip(), which yields nothing for zero inputs
        # (min() of an empty sequence would raise ValueError).
        return min((len(it) for it in self.iters), default=0)
class Eval(object):
    """Context manager that temporarily sets a model's train/eval mode.

    With ``on=True`` (default) the model is switched to eval mode
    (``model.train(False)``); with ``on=False`` it is switched to training
    mode. The model's previous ``training`` flag is restored on exit.
    ``__enter__`` returns None, so ``with Eval(m) as x`` binds None.
    """
    def __init__(self, model, on=True):
        self.model, self.on = model, on
    def __enter__(self):
        self.training_state = self.model.training
        wants_training = not self.on
        self.model.train(wants_training)
    def __exit__(self, *args):
        previous_mode = self.training_state
        self.model.train(previous_mode)
class FixedNumpySeed(object):
    """Context manager that seeds numpy's and Python's global RNGs.

    On entry it saves the states of ``np.random`` and ``random`` and seeds
    both with `seed`. On exit it restores both saved states, so code outside
    the block sees an unchanged random stream.
    """
    def __init__(self, seed):
        self.seed = seed
    def __enter__(self):
        self.np_rng_state = np.random.get_state()
        self.rand_rng_state = random.getstate()
        np.random.seed(self.seed)
        random.seed(self.seed)
    def __exit__(self, *args):
        saved_np, saved_py = self.np_rng_state, self.rand_rng_state
        np.random.set_state(saved_np)
        random.setstate(saved_py)
class FixedPytorchSeed(object):
    """Context manager that seeds torch and restores the CPU RNG state on exit.

    On entry it saves ``torch.random.get_rng_state()`` and calls
    ``torch.manual_seed(seed)``. On exit it restores the saved CPU state.
    NOTE(review): torch.manual_seed also seeds CUDA generators, but only the
    CPU generator state is saved/restored here -- confirm that is intended.
    """
    def __init__(self, seed):
        self.seed = seed
    def __enter__(self):
        self.pt_rng_state = torch.random.get_rng_state()
        torch.manual_seed(self.seed)
    def __exit__(self, *args):
        saved_state = self.pt_rng_state
        torch.random.set_rng_state(saved_state)
class Expression(nn.Module):
    """Wrap an arbitrary callable as an ``nn.Module``.

    ``forward`` passes all positional and keyword arguments straight to
    `func` and returns its result.
    """
    def __init__(self, func):
        super().__init__()
        self.func = func
    def forward(self, *args, **kwargs):
        fn = self.func
        return fn(*args, **kwargs)
def cosLr(num_epochs, cycle_mult=1):
    """Cosine-annealing learning-rate multiplier with warm restarts.

    Args:
        num_epochs: length ``L`` of the first cosine cycle, or an iterable
            of lengths that is summed to give ``L``.
        cycle_mult: factor ``r`` by which each cycle is longer than the
            previous one (1 means every cycle has length ``L``).

    Returns:
        ``lrSched(epoch)``: a multiplier that is 1 at the start of each
        cycle, follows ``0.5 * (1 + cos(pi * t / cycle_length))`` within it,
        and is floored at 1e-3.
    """
    if isinstance(num_epochs, collections.abc.Iterable):
        num_epochs = sum(num_epochs)
    L = num_epochs
    r = cycle_mult
    def lrSched(epoch):
        if r == 1:
            # Equal-length cycles, computed exactly. The old r = 1 + 1e-8
            # trick made log(1+(r-1)k)/log(r) land just below the integer k
            # at cycle boundaries, delaying each restart.
            current_cycle = np.floor(epoch / L)
            current_cycle_length = L
            cycle_start = L * current_cycle
        else:
            # Cycle c covers [L*(r**c - 1)/(r - 1), L*(r**(c+1) - 1)/(r - 1)).
            current_cycle = np.floor(np.log(1 + (r - 1) * epoch / L) / np.log(r))
            current_cycle_length = L * r ** current_cycle
            cycle_start = L * (r ** current_cycle - 1) / (r - 1)
        cycle_iter = epoch - cycle_start
        # (cap lr from going too low)
        return max(.5 * (1 + np.cos(np.pi * cycle_iter / current_cycle_length)), 1e-3)
    return lrSched
def recursively_update(d, u):
    """Recursively merge mapping `u` into `d` in place and return `d`.

    Dict values in `u` are merged into the corresponding mutable mapping in
    `d` rather than replacing it; every other value overwrites. When `d` has
    no entry for a nested key, or holds a non-mapping value there (e.g. an
    int or None), a fresh container of the same type as `u`'s value is
    filled and stored instead.
    """
    for k, v in u.items():
        if isinstance(v, dict):
            existing = d.get(k)
            if not isinstance(existing, collections.abc.MutableMapping):
                # Missing or scalar entry: nothing to merge into, replace it.
                existing = type(v)()
            d[k] = recursively_update(existing, v)
        else:
            d[k] = v
    return d
def to_lambda(x):
    """ Turns constants into constant functions

    A number becomes a one-argument function that ignores its argument and
    returns the number; anything else (e.g. an existing schedule function)
    is returned unchanged.
    """
    if not isinstance(x, numbers.Number):
        return x
    def constant(e):
        return x
    return constant
def prettyPrintLog(logDict, *epochIts):
    """Print a one-line progress summary.

    Args:
        logDict: mapping of metric name -> number; each entry is printed as
            ``"name: %.3f "`` in iteration order.
        *epochIts: four integers filling ``"[%3d/%d][%6d/%d] "``; presumably
            (epoch, total epochs, iteration, total iterations) -- confirm
            against callers.
    """
    parts = ["[%3d/%d][%6d/%d] " % epochIts]
    # Names are inserted as values, not spliced into the format string, so
    # a metric name containing '%' (e.g. 'acc%') cannot break formatting.
    parts.extend("%s: %.3f " % (key, val) for key, val in logDict.items())
    print("".join(parts))
def logOneMinusSoftmax(x):
    """ numerically more stable version of log(1-softmax(x))

    Uses ``1 - softmax(x)_i = sum_{j != i} exp(x_j) / sum_j exp(x_j)`` and
    sums the other classes directly, rather than subtracting from 1, after
    shifting by the row max.

    Args:
        x: 2-D tensor of shape (batch, k); softmax is taken over dim 1.

    Returns:
        Tensor of shape (batch, k), on the same device and dtype as `x`.
    """
    max_vals, _ = torch.max(x, 1, keepdim=True)
    exp_x = torch.exp(x - max_vals)
    sum_exp_x = exp_x.sum(1, keepdim=True)
    k = x.size(1)
    # All-ones matrix with a zero diagonal, built on x's device/dtype (the
    # previous hard-coded .cuda() float32 matrix failed on CPU/float64 input).
    sum_except_matrix = (torch.ones(k, k, device=x.device, dtype=x.dtype)
                         - torch.eye(k, device=x.device, dtype=x.dtype))
    # Row b, column i: sum over j != i of exp_x[b, j] (the matrix is symmetric).
    sum_except_exp_x = torch.matmul(exp_x, sum_except_matrix)
    return torch.log(sum_except_exp_x) - torch.log(sum_exp_x)
def init_args():
    """Return a dict of the calling function's local variables minus ``self``.

    Presumably meant to be called first thing in an ``__init__``, where the
    locals are exactly the constructor's arguments.

    Raises:
        KeyError: if the caller has no local named ``self``.
    """
    # Only the immediate caller is needed; inspect.getouterframes would build
    # (and read source context for) the entire stack.
    caller_frame = inspect.currentframe().f_back
    try:
        # Copy, so popping "self" does not mutate the frame's f_locals mapping.
        args = dict(caller_frame.f_locals)
    finally:
        del caller_frame  # drop the frame reference promptly
    args.pop("self")
    return args
# Coded by Massimiliano Tomassoli, 2012.
def genCur(func, unique = True, minArgs = None):
    """ Generates a 'curried' version of a function.

    Returns ``g``. Calling ``g(*a, **kw)`` never evaluates ``func``: it only
    stores ``a``/``kw`` and returns an accumulator ``f``. Calling ``f`` with
    more arguments appends them to the stored ones. Then either ``func`` is
    evaluated (when enough arguments are present, see below), or a new ``f``
    holding everything so far is returned. Calling ``f()`` with no arguments
    forces evaluation with whatever has been stored.

    Args:
        func: the function to curry.
        unique: if True, passing the same keyword argument in two different
            calls raises ValueError; if False, later values override.
        minArgs: if not None, ``func`` is evaluated once the total count of
            stored positional + keyword arguments reaches ``minArgs``. If
            None, ``func`` is evaluated once the number of stored
            *positional* arguments covers ``func``'s parameters that have no
            default (keyword arguments do not count towards this).
    """
    def g(*myArgs, **myKwArgs):
        # Accumulator holding the arguments collected so far.
        def f(*args, **kwArgs):
            if args or kwArgs: # some more args!
                # Allocates data to assign to the next 'f'.
                newArgs = myArgs + args
                newKwArgs = dict.copy(myKwArgs)
                # If unique is True, we don't want repeated keyword arguments.
                if unique and not kwArgs.keys().isdisjoint(newKwArgs):
                    raise ValueError("Repeated kw arg while unique = True")
                # Adds/updates keyword arguments.
                newKwArgs.update(kwArgs)
                # Checks whether it's time to evaluate func.
                numArgsIn = len(newArgs) + len(newKwArgs)
                # Recomputed on every call; parameters with defaults are
                # subtracted to get the number of required positional args.
                totalArgs = len(inspect.getfullargspec(func).args)
                namedArgs = 0 if func.__defaults__ is None else len(func.__defaults__)
                numArgsRequired = totalArgs - namedArgs
                if (minArgs is not None and minArgs <= numArgsIn) \
                        or (minArgs is None and numArgsRequired <= len(newArgs)):
                    #print(newArgs)
                    #print(newKwArgs)
                    return func(*newArgs, **newKwArgs) # time to evaluate func
                else:
                    return g(*newArgs, **newKwArgs) # returns a new 'f'
            else: # the evaluation was forced
                return func(*myArgs, **myKwArgs)
        return f
    return g
def curry(f, minArgs=None):
    """Curry `f` with genCur; repeating a keyword argument raises ValueError."""
    return genCur(f, unique=True, minArgs=minArgs)
def cur(f, minArgs=None):
    """Alias of ``curry``: genCur with unique keyword arguments enforced."""
    return genCur(f, unique=True, minArgs=minArgs)
def curr(f, minArgs=None):
    """Curry `f` with genCur; later keyword arguments override earlier ones."""
    return genCur(f, unique=False, minArgs=minArgs)
def dillcopy(obj):
    """Deep-copy `obj` by serializing it with dill and deserializing it."""
    serialized = dill.dumps(obj)
    return dill.loads(serialized)
## Super hacky function that returns a function constructing the same object
## as the one being constructed when this is called from an __init__ method
## (possibly the __init__ of a base class).
def reconstructor():
    """Return a factory that rebuilds the object currently being constructed.

    Must be called from inside an ``__init__`` (directly, or through a chain
    of ``super().__init__`` calls). The call stack is walked outward through
    consecutive ``__init__`` frames that share the same ``self``, and the
    outermost one -- the most-derived class's ``__init__`` -- supplies the
    class and the arguments (named parameters, ``*args``, ``**kwargs``).
    The arguments are deep-copied with ``dillcopy``.

    Returns:
        ``factory(**newKwArgs)``, which calls ``type(self)(...)`` with the
        captured arguments; entries in ``newKwArgs`` override captured ones.

    Raises:
        AssertionError: if not called from within an ``__init__`` method.
    """
    frame = inspect.currentframe()
    outer_frames = []
    try:
        outer_frames = inspect.getouterframes(frame)
        subclass_caller_frame = None
        instance = None
        # outer_frames[0] is this function's own frame; start at its caller.
        for frame_info in outer_frames[1:]:
            if frame_info.function != '__init__':
                break
            frame_self = frame_info.frame.f_locals.get('self')
            if subclass_caller_frame is not None and frame_self is not instance:
                # __init__ of a *different* object (e.g. an owner constructing
                # this one as a member) -- not part of our super() chain.
                break
            subclass_caller_frame = frame_info.frame
            instance = frame_self
        if subclass_caller_frame is None:
            # Explicit raise so the check survives `python -O`.
            raise AssertionError("Not called in an __init__ method")
        argnames, varargname, keywordname, localss = inspect.getargvalues(subclass_caller_frame)
    finally:
        # Break the reference cycle through this function's own frame.
        del frame, outer_frames
    args_in = {k: v for k, v in localss.items() if k in argnames}
    cls = args_in.pop("self").__class__
    args_in_copy = dillcopy(args_in)
    args = dillcopy(localss[varargname]) if varargname is not None else None
    kwargs = dillcopy(localss[keywordname]) if keywordname is not None else {}
    # Named parameters take precedence over same-named entries of **kwargs.
    kwargs.update(args_in_copy)
    if args is not None:
        return lambda **newKwArgs: cls(*args, **dict(kwargs, **newKwArgs))
    return lambda **newKwArgs: cls(**dict(kwargs, **newKwArgs))
def make_like(reconstructible):
    """Build a fresh object of the same kind as `reconstructible`.

    If `reconstructible` is a str, it is a checkpoint path loaded with
    ``torch.load(..., pickle_module=dill)``, and the stored
    ``state['reconstructor']`` is called. Otherwise it is a live object whose
    ``reconstructor`` method is called. Both are called with
    ``rebuildable=False``. If the path is not an existing file, a message is
    printed and None is returned.
    NOTE(review): torch.load with dill unpickles arbitrary code; only load
    trusted checkpoints.
    """
    if not isinstance(reconstructible, str):
        # A live object
        return reconstructible.reconstructor(rebuildable=False)
    load_path = reconstructible
    if not os.path.isfile(load_path):
        print("=> no checkpoint found at '{}'".format(load_path))
        return None
    state = torch.load(load_path, pickle_module=dill)
    return state['reconstructor'](rebuildable=False)
def full_load(reconstructible):
    """Build a new object like `reconstructible` and load state into it.

    If `reconstructible` is a str, it is a checkpoint path loaded with
    ``torch.load(..., pickle_module=dill)``. The stored reconstructor builds
    the model, and the checkpoint dict is passed to ``model.load_state``.
    Otherwise it is a live object: a fresh copy is built with its
    ``reconstructor`` and receives ``reconstructible.get_state()``. The
    previous code loaded the fresh model's own state into itself.

    Returns:
        The loaded model, or None (after printing a message) when the
        checkpoint path is not an existing file.
    NOTE(review): torch.load with dill unpickles arbitrary code; only load
    trusted checkpoints.
    """
    if isinstance(reconstructible, str):
        load_path = reconstructible
        if not os.path.isfile(load_path):
            print("=> no checkpoint found at '{}'".format(load_path))
            return None
        state = torch.load(load_path, pickle_module=dill)
        model = state['reconstructor'](rebuildable=False)
        model.load_state(state)
        return model
    model = reconstructible.reconstructor(rebuildable=False)
    # Copy the *source* object's state into the freshly built model.
    model.load_state(reconstructible.get_state())
    return model
import sys
import select
def maybe_get_input():
    """ Returns None if no enter has been pressed, otherwise the line

    Polls sys.stdin with ``select`` (0.1 ms timeout). If stdin is readable,
    it returns ``sys.stdin.readline()``; otherwise it returns None.
    NOTE(review): select() on stdin is not supported on Windows (sockets only).
    """
    readable, _, _ = select.select([sys.stdin], [], [], 0.0001)
    if sys.stdin in readable:
        return sys.stdin.readline()
    return None