slm_lab/spec/spec_util.py
# The spec module
# Manages specification to run things in lab
from slm_lab import ROOT_DIR
from slm_lab.lib import logger, util
from string import Template
import itertools
import json
import os
import pydash as ps
SPEC_DIR = 'slm_lab/spec'
'''
All spec values are already param, inferred automatically.
To change from a value into param range, e.g.
- single: "explore_anneal_epi": 50
- continuous param: "explore_anneal_epi": {"min": 50, "max": 100, "dist": "uniform"}
- discrete range: "explore_anneal_epi": {"values": [50, 75, 100]}
'''
SPEC_FORMAT = {
"agent": [{
"name": str,
"algorithm": dict,
"memory": dict,
"net": dict,
}],
"env": [{
"name": str,
"max_t": (type(None), int, float),
"max_frame": (int, float),
}],
"body": {
"product": ["outer", "inner", "custom"],
"num": (int, list),
},
"meta": {
"max_session": int,
"max_trial": (type(None), int),
},
"name": str,
}
logger = logger.get_logger(__name__)
def check_comp_spec(comp_spec, comp_spec_format):
'''Base method to check component spec'''
for spec_k, spec_format_v in comp_spec_format.items():
comp_spec_v = comp_spec[spec_k]
if ps.is_list(spec_format_v):
v_set = spec_format_v
assert comp_spec_v in v_set, f'Component spec value {ps.pick(comp_spec, spec_k)} needs to be one of {util.to_json(v_set)}'
else:
v_type = spec_format_v
assert isinstance(comp_spec_v, v_type), f'Component spec {ps.pick(comp_spec, spec_k)} needs to be of type: {v_type}'
if isinstance(v_type, tuple) and int in v_type and isinstance(comp_spec_v, float):
# cast if it can be int
comp_spec[spec_k] = int(comp_spec_v)
def check_body_spec(spec):
'''Base method to check body spec for multi-agent multi-env'''
ae_product = ps.get(spec, 'body.product')
body_num = ps.get(spec, 'body.num')
if ae_product == 'outer':
pass
elif ae_product == 'inner':
agent_num = len(spec['agent'])
env_num = len(spec['env'])
assert agent_num == env_num, 'Agent and Env spec length must be equal for body `inner` product. Given {agent_num}, {env_num}'
else: # custom
assert ps.is_list(body_num)
def check_compatibility(spec):
'''Check compatibility among spec setups'''
# TODO expand to be more comprehensive
if spec['meta'].get('distributed') == 'synced':
assert ps.get(spec, 'agent.0.net.gpu') == False, f'Distributed mode "synced" works with CPU only. Set gpu: false.'
def check(spec):
'''Check a single spec for validity'''
try:
spec_name = spec.get('name')
assert set(spec.keys()) >= set(SPEC_FORMAT.keys()), f'Spec needs to follow spec.SPEC_FORMAT. Given \n {spec_name}: {util.to_json(spec)}'
for agent_spec in spec['agent']:
check_comp_spec(agent_spec, SPEC_FORMAT['agent'][0])
for env_spec in spec['env']:
check_comp_spec(env_spec, SPEC_FORMAT['env'][0])
check_comp_spec(spec['body'], SPEC_FORMAT['body'])
check_comp_spec(spec['meta'], SPEC_FORMAT['meta'])
# check_body_spec(spec)
check_compatibility(spec)
except Exception as e:
logger.exception(f'spec {spec_name} fails spec check')
raise e
return True
def check_all():
'''Check all spec files, all specs.'''
spec_files = ps.filter_(os.listdir(SPEC_DIR), lambda f: f.endswith('.json') and not f.startswith('_'))
for spec_file in spec_files:
spec_dict = util.read(f'{SPEC_DIR}/{spec_file}')
for spec_name, spec in spec_dict.items():
# fill-in info at runtime
spec['name'] = spec_name
spec = extend_meta_spec(spec)
try:
check(spec)
except Exception as e:
logger.exception(f'spec_file {spec_file} fails spec check')
raise e
logger.info(f'Checked all specs from: {ps.join(spec_files, ",")}')
return True
def extend_meta_spec(spec, experiment_ts=None):
'''
Extend meta spec with information for lab functions
@param dict:spec
@param str:experiment_ts Use this experiment_ts if given; used for resuming training
'''
extended_meta_spec = {
'rigorous_eval': ps.get(spec, 'meta.rigorous_eval', 0),
# reset lab indices to -1 so that they tick to 0
'experiment': -1,
'trial': -1,
'session': -1,
'cuda_offset': int(os.environ.get('CUDA_OFFSET', 0)),
'resume': experiment_ts is not None,
'experiment_ts': experiment_ts or util.get_ts(),
'prepath': None,
'git_sha': util.get_git_sha(),
'random_seed': None,
}
spec['meta'].update(extended_meta_spec)
return spec
def get(spec_file, spec_name, experiment_ts=None):
'''
Get an experiment spec from spec_file, spec_name.
Auto-check spec.
@param str:spec_file
@param str:spec_name
@param str:experiment_ts Use this experiment_ts if given; used for resuming training
@example
spec = spec_util.get('demo.json', 'dqn_cartpole')
'''
spec_file = spec_file.replace(SPEC_DIR, '') # guard
spec_file = f'{SPEC_DIR}/{spec_file}' # allow direct filename
spec_dict = util.read(spec_file)
assert spec_name in spec_dict, f'spec_name {spec_name} is not in spec_file {spec_file}. Choose from:\n {ps.join(spec_dict.keys(), ",")}'
spec = spec_dict[spec_name]
# fill-in info at runtime
spec['name'] = spec_name
spec = extend_meta_spec(spec, experiment_ts)
check(spec)
return spec
def get_param_specs(spec):
'''Return a list of specs with substituted spec_params'''
assert 'spec_params' in spec, 'Parametrized spec needs a spec_params key'
spec_params = spec.pop('spec_params')
spec_template = Template(json.dumps(spec))
keys = spec_params.keys()
specs = []
for idx, vals in enumerate(itertools.product(*spec_params.values())):
spec_str = spec_template.substitute(dict(zip(keys, vals)))
spec = json.loads(spec_str)
spec['name'] += f'_{"_".join(vals)}'
# offset to prevent parallel-run GPU competition, to mod in util.set_cuda_id
spec['meta']['cuda_offset'] += idx * spec['meta']['max_session']
specs.append(spec)
return specs
def _override_dev_spec(spec):
spec['meta']['max_session'] = 1
spec['meta']['max_trial'] = 2
return spec
def _override_enjoy_spec(spec):
spec['meta']['max_session'] = 1
return spec
def _override_test_spec(spec):
for agent_spec in spec['agent']:
# onpolicy freq is episodic
freq = 1 if agent_spec['memory']['name'] == 'OnPolicyReplay' else 8
agent_spec['algorithm']['training_frequency'] = freq
agent_spec['algorithm']['time_horizon'] = freq
agent_spec['algorithm']['training_start_step'] = 1
agent_spec['algorithm']['training_iter'] = 1
agent_spec['algorithm']['training_batch_iter'] = 1
for env_spec in spec['env']:
env_spec['max_frame'] = 40
if env_spec.get('num_envs', 1) > 1:
env_spec['num_envs'] = 2
env_spec['max_t'] = 12
spec['meta']['log_frequency'] = 10
spec['meta']['eval_frequency'] = 10
spec['meta']['max_session'] = 1
spec['meta']['max_trial'] = 2
return spec
def override_spec(spec, mode):
'''Override spec based on the (lab_)mode, do nothing otherwise.'''
overrider = {
'dev': _override_dev_spec,
'enjoy': _override_enjoy_spec,
'test': _override_test_spec,
}.get(mode)
if overrider is not None:
overrider(spec)
return spec
def save(spec, unit='experiment'):
'''Save spec to proper path. Called at Experiment or Trial init.'''
prepath = util.get_prepath(spec, unit)
util.write(spec, f'{prepath}_spec.json')
def tick(spec, unit):
'''
Method to tick lab unit (experiment, trial, session) in meta spec to advance their indices
Reset lower lab indices to -1 so that they tick to 0
spec_util.tick(spec, 'session')
session = Session(spec)
'''
if util.get_lab_mode() == 'enjoy': # don't tick in enjoy mode
return spec
meta_spec = spec['meta']
if unit == 'experiment':
meta_spec['experiment_ts'] = util.get_ts()
meta_spec['experiment'] += 1
meta_spec['trial'] = -1
meta_spec['session'] = -1
elif unit == 'trial':
if meta_spec['experiment'] == -1:
meta_spec['experiment'] += 1
meta_spec['trial'] += 1
meta_spec['session'] = -1
elif unit == 'session':
if meta_spec['experiment'] == -1:
meta_spec['experiment'] += 1
if meta_spec['trial'] == -1:
meta_spec['trial'] += 1
meta_spec['session'] += 1
else:
raise ValueError(f'Unrecognized lab unit to tick: {unit}')
# set prepath since it is determined at this point
meta_spec['prepath'] = prepath = util.get_prepath(spec, unit)
for folder in ('graph', 'info', 'log', 'model'):
folder_prepath = util.insert_folder(prepath, folder)
folder_predir = os.path.dirname(f'{ROOT_DIR}/{folder_prepath}')
os.makedirs(folder_predir, exist_ok=True)
assert os.path.exists(folder_predir)
meta_spec[f'{folder}_prepath'] = folder_prepath
return spec