mfinzi/pristine-ml

View on GitHub
oil/tuning/slurmExecutor.py

Summary

Maintainability
A
0 mins
Test Coverage

import dill
from ast import literal_eval
import sys,os,stat
import time
import tempfile
import atexit
import subprocess
from concurrent import futures
from functools import partial
import itertools
import torch
from oil.tuning.localGpuExecutor import LocalGpuExecutor

def kwargs_to_list(kwargs):
    """Convert a mapping of options into ``srun``-style command-line arguments.

    Keys longer than one character become a single ``--key=value`` argument.
    Single-character keys become a short option with the value attached
    (``-kvalue``), the getopt form for a short option taking an argument; an
    empty value yields a bare flag (``-k``).

    Returns a list with exactly one argv element per item of *kwargs*, in the
    mapping's iteration order.
    """
    # Every element is handed to subprocess as its own argv entry (no shell
    # word-splitting), so a short flag and its value must not be separated by
    # a space inside one string -- srun would see the value with a leading space.
    return ['--%s=%s' % (k, v) if len(k) > 1 else '-%s%s' % (k, v)
            for k, v in kwargs.items()]
               
def tmp_file_name(suffix=".sh"):
    """Create an empty temporary file in the current directory; return its absolute path.

    The file is created atomically with ``tempfile.mkstemp`` (the deprecated
    ``mktemp`` only picked an unused name, leaving a race before creation) and
    is scheduled for removal when the interpreter exits.
    """
    fd, path = tempfile.mkstemp(dir='.', suffix=suffix)
    os.close(fd)  # callers reopen the file by name
    # Absolute, so the atexit cleanup still finds it if the cwd changes later.
    path = os.path.abspath(path)
    atexit.register(_unlink_if_exists, path)
    return path


def _unlink_if_exists(path):
    """Remove *path*, tolerating the case where it has already been deleted."""
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass

class SlurmExecutor(futures.ThreadPoolExecutor):
    """Thread pool whose tasks each run as a separate Slurm job via ``srun``.

    Each ``submit`` occupies one worker thread that:
      1. dill-pickles ``(fn, args, kwargs)`` to a temporary file,
      2. writes a small ``/bin/sh`` script that runs this module as
         ``__main__`` on that file (plus the session file, if any),
      3. runs the script with ``srun`` using flags built from ``slurm_cfg``,
      4. unpickles the result the job wrote back into the same file.

    Args:
        *args, **kwargs: forwarded to ``ThreadPoolExecutor`` (e.g.
            ``max_workers``, which bounds how many ``srun`` calls run at once).
        slurm_cfg: mapping of ``srun`` options, converted by
            ``kwargs_to_list`` (single-letter keys -> short flags, longer
            keys -> ``--key=value``). Defaults to no options.
        clone_session: if true, the current interpreter session is saved with
            ``dill.dump_session`` and its path handed to every job.
    """
    def __init__(self,*args,slurm_cfg=None,clone_session=True,**kwargs):
        # Fresh dict per instance: a ``{}`` default would be shared by all executors.
        self.slurm_cfg = {} if slurm_cfg is None else slurm_cfg
        # Dump the python session so jobs can restore the caller's globals.
        if clone_session:
            self.session_file_name = tmp_file_name(".pkl")
            dill.dump_session(self.session_file_name)
        else:
            # Sentinel checked by this module's ``__main__`` entry point.
            self.session_file_name = 'no_session'
        super().__init__(*args,**kwargs)

    def submit(self,fn,*args,**kwargs):
        """Schedule ``fn(*args, **kwargs)`` to run as a Slurm job; return a Future.

        The Future resolves to ``fn``'s return value, unpickled from the job's
        output file. If ``srun`` exits non-zero (presumably also when ``fn``
        raises inside the job, since the job script then exits non-zero and
        srun propagates task exit codes), the Future raises
        ``subprocess.CalledProcessError`` instead of returning a bogus value.
        """
        def todo():
            job_path = tmp_file_name()
            with open(job_path, 'wb+') as job_file:
                dill.dump((fn,args,kwargs),job_file)
            script_path = self._write_job_script(job_path)
            cfg_args = kwargs_to_list(self.slurm_cfg)
            # check_call, not call: after a failed job the file still holds the
            # pickled *input* (fn, args, kwargs), which must not be returned.
            subprocess.check_call(['srun',*cfg_args,script_path])
            with open(job_path, 'rb') as job_file:
                return dill.load(job_file)
        return super().submit(todo)

    def _write_job_script(self, job_path):
        """Write an executable sh script running this module on *job_path*; return its path."""
        import shlex  # only needed to quote the command line below
        command = ' '.join(shlex.quote(part) for part in (
            sys.executable, os.path.realpath(__file__),
            job_path, self.session_file_name))
        script_path = tmp_file_name()
        with open(script_path, "wb+") as sh_script:
            sh_script.write(os.fsencode('#!/bin/sh\n' + command + '\n'))
            os.fchmod(sh_script.fileno(),stat.S_IRWXU|stat.S_IRWXG|stat.S_IROTH|stat.S_IXOTH)
        return script_path

    def map(self, fn, *iterables, timeout=None, chunksize=1):
        """Like ``Executor.map``, but batch inputs so each job handles ``chunksize`` calls.

        Mirrors ``ProcessPoolExecutor.map``'s chunking, using this module's
        private chunk helpers: every chunk becomes one ``submit`` (one Slurm
        job) and results are yielded lazily in input order.

        Raises:
            ValueError: if ``chunksize`` < 1.
        """
        if chunksize < 1:
            raise ValueError("chunksize must be >= 1.")
        results = super().map(partial(_process_chunk, fn),
                              _get_chunks(*iterables, chunksize=chunksize),
                              timeout=timeout)
        return _chain_from_iterable_of_lists(results)

def _process_chunk(fn, chunk):
    return [fn(*args) for args in chunk]
def _get_chunks(*iterables, chunksize):
    it = zip(*iterables)
    while True:
        chunk = tuple(itertools.islice(it, chunksize))
        if not chunk:
            return
        yield chunk
def _chain_from_iterable_of_lists(iterable):
    for element in iterable:
        element.reverse()
        while element:
            yield element.pop()

def LocalExecutor(max_workers=None):
    if max_workers==1 or torch.cuda.device_count()<=1 or os.environ.copy().get("WORLD_SIZE",0)!=0:
        print("local")
        return futures.ThreadPoolExecutor(max_workers=1)
    else:
        return LocalGpuExecutor(max_workers)

# #LocalExecutor = LocalGpuExecutor
# class LocalExecutor(futures.ThreadPoolExecutor):
#     """Wraps ProcessPoolExecutor but distributes local gpus to the
#         processes #TODO: restrict gpu allocation. At the moment restricts
#         to sequential (single core and gpu) execution."""
#     def __init__(self,max_workers,*args,**kwargs):
#         super().__init__(max_workers=1,*args,**kwargs)
#         #os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
#         #os.environ["CUDA_VISIBLE_DEVICES"]="0"
        

if __name__=='__main__':
    # Job entry point, run on the Slurm allocation by the sh script that
    # SlurmExecutor.submit generates: `<python> <this file> <job file> <session>`.
    #   sys.argv[1]: dill pickle of (fn, args, kwargs); overwritten below with
    #                fn's return value, which the submitting thread then loads.
    #   sys.argv[2]: file written by dill.dump_session, or the sentinel 'no_session'.
    # NOTE(review): dill.load_session (by default) restores the saved globals
    # into this __main__ module, so names bound before it could be clobbered;
    # presumably that is why argv is re-read rather than cached in variables.
    # Keep this statement order.
    if sys.argv[2]!='no_session':
        dill.load_session(sys.argv[2])
    with open(sys.argv[1], 'rb') as funcfile:
        (fn,args,kwargs) = dill.load(funcfile)
        out = fn(*args,**kwargs)
    # Only reached if fn returned normally; an exception above leaves the input
    # pickle in place and makes the process exit non-zero.
    with open(sys.argv[1], 'wb+') as funcfile:
        dill.dump(out,funcfile)