slm_lab/experiment/control.py
# The control module
# Creates and runs control loops at levels: Experiment, Trial, Session
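# Typical usage (sketch; the spec file and name below are hypothetical, assuming a
# spec dict built via slm_lab.spec.spec_util):
#   spec = spec_util.get('demo.json', 'dqn_cartpole')
#   spec_util.tick(spec, 'trial')
#   Trial(spec).run()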
from copy import deepcopy
from slm_lab.agent import Agent, Body
from slm_lab.agent.net import net_util
from slm_lab.env import make_env
from slm_lab.experiment import analysis, search
from slm_lab.lib import logger, util
from slm_lab.spec import spec_util
import pydash as ps
import torch
import torch.multiprocessing as mp


def make_agent_env(spec, global_nets=None):
'''Helper to create agent and env given spec'''
env = make_env(spec)
body = Body(env, spec)
agent = Agent(spec, body=body, global_nets=global_nets)
return agent, env


def mp_run_session(spec, global_nets, mp_dict):
    '''Wrapper to run a Session in a subprocess, writing its metrics into the shared mp_dict keyed by session index'''
session = Session(spec, global_nets)
metrics = session.run()
mp_dict[session.index] = metrics


class Session:
'''
    The base lab unit to run an RL session for a spec.
    Given a spec, it creates the agent and env, runs the RL loop,
    then gathers data and analyzes it to produce session data.
'''
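    # Lifecycle: __init__ builds the agent and env(s); run() drives run_rl(),
    # analyzes body.eval_df into session metrics, then close() cleans up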
def __init__(self, spec, global_nets=None):
self.spec = spec
self.index = self.spec['meta']['session']
util.set_random_seed(self.spec)
util.set_cuda_id(self.spec)
util.set_logger(self.spec, logger, 'session')
spec_util.save(spec, unit='session')
self.agent, self.env = make_agent_env(self.spec, global_nets)
if ps.get(self.spec, 'meta.rigorous_eval'):
with util.ctx_lab_mode('eval'):
self.eval_env = make_env(self.spec)
else:
self.eval_env = self.env
logger.info(util.self_desc(self))
def to_ckpt(self, env, mode='eval'):
        '''Check with the clock whether to run a log/eval ckpt: at the start, at every log/eval frequency, and at the end (max_frame)'''
if mode == 'eval' and util.in_eval_lab_mode(): # avoid double-eval: eval-ckpt in eval mode
return False
clock = env.clock
frame = clock.get()
frequency = env.eval_frequency if mode == 'eval' else env.log_frequency
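        # e.g. with log_frequency=10000 and num_envs=4 the frame count advances 4 per step,
        # so (assuming util.frame_mod tests frame % frequency < num_envs) the ckpt fires on
        # the first step at or past each 10k-frame boundary, and again at max_frame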
to_ckpt = util.frame_mod(frame, frequency, env.num_envs) or frame == clock.max_frame
return to_ckpt
def try_ckpt(self, agent, env):
'''Check then run checkpoint log/eval'''
body = agent.body
if self.to_ckpt(env, 'log'):
body.ckpt(self.env, 'train')
body.log_summary('train')
agent.save() # save the latest ckpt
if body.total_reward_ma >= body.best_total_reward_ma:
body.best_total_reward_ma = body.total_reward_ma
agent.save(ckpt='best')
if len(body.train_df) > 2: # need more rows to calculate metrics
metrics = analysis.analyze_session(self.spec, body.train_df, 'train', plot=False)
body.log_metrics(metrics['scalar'], 'train')
if ps.get(self.spec, 'meta.rigorous_eval') and self.to_ckpt(env, 'eval'):
logger.info('Running eval ckpt')
analysis.gen_avg_return(agent, self.eval_env)
body.ckpt(self.eval_env, 'eval')
body.log_summary('eval')
if len(body.eval_df) > 2: # need more rows to calculate metrics
metrics = analysis.analyze_session(self.spec, body.eval_df, 'eval', plot=False)
body.log_metrics(metrics['scalar'], 'eval')
def run_rl(self):
'''Run the main RL loop until clock.max_frame'''
logger.info(f'Running RL loop for trial {self.spec["meta"]["trial"]} session {self.index}')
clock = self.env.clock
state = self.env.reset()
done = False
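        # each iteration: handle episode-boundary ckpt/reset, then frame-level ckpt and
        # termination, then one act-step-update interaction per clock tick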
while True:
if util.epi_done(done): # before starting another episode
self.try_ckpt(self.agent, self.env)
if clock.get() < clock.max_frame: # reset and continue
clock.tick('epi')
state = self.env.reset()
done = False
self.try_ckpt(self.agent, self.env)
if clock.get() >= clock.max_frame: # finish
break
clock.tick('t')
with torch.no_grad():
action = self.agent.act(state)
next_state, reward, done, info = self.env.step(action)
self.agent.update(state, action, reward, next_state, done)
state = next_state
def close(self):
'''Close session and clean up. Save agent, close env.'''
self.agent.close()
self.env.close()
self.eval_env.close()
torch.cuda.empty_cache()
logger.info(f'Session {self.index} done')
def run(self):
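        '''Run the session: the RL loop, then final analysis of eval_df into session metrics, then cleanup'''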
self.run_rl()
metrics = analysis.analyze_session(self.spec, self.agent.body.eval_df, 'eval')
self.agent.body.log_metrics(metrics['scalar'], 'eval')
self.close()
return metrics


class Trial:
'''
    The lab unit which runs repeated sessions for the same spec, i.e. a trial.
    Given a spec and a number s, the trial creates and runs s sessions,
    then gathers session data and analyzes it to produce trial data.
'''
def __init__(self, spec):
self.spec = spec
self.index = self.spec['meta']['trial']
util.set_logger(self.spec, logger, 'trial')
spec_util.save(spec, unit='trial')
def parallelize_sessions(self, global_nets=None):
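        '''Fork one subprocess per session (passing global_nets through if given) and gather their metrics ordered by session index'''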
mp_dict = mp.Manager().dict()
workers = []
spec = deepcopy(self.spec)
for _s in range(spec['meta']['max_session']):
spec_util.tick(spec, 'session')
w = mp.Process(target=mp_run_session, args=(spec, global_nets, mp_dict))
w.start()
workers.append(w)
for w in workers:
w.join()
session_metrics_list = [mp_dict[idx] for idx in sorted(mp_dict.keys())]
return session_metrics_list
def run_sessions(self):
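        '''Run max_session sessions: in-process if there is only one, else as parallel subprocesses'''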
logger.info('Running sessions')
if self.spec['meta']['max_session'] == 1:
spec = deepcopy(self.spec)
spec_util.tick(spec, 'session')
session_metrics_list = [Session(spec).run()]
else:
session_metrics_list = self.parallelize_sessions()
return session_metrics_list
def init_global_nets(self):
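        '''Build a throwaway Session just to construct the agent, then extract global (shared) networks from its algorithm for distributed sessions'''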
session = Session(deepcopy(self.spec))
        session.env.close()  # close the temporary env; only the agent's networks are needed here
global_nets = net_util.init_global_nets(session.agent.algorithm)
return global_nets
def run_distributed_sessions(self):
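        '''Run sessions that share global networks across processes (e.g. for Hogwild-style distributed training)'''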
logger.info('Running distributed sessions')
global_nets = self.init_global_nets()
session_metrics_list = self.parallelize_sessions(global_nets)
return session_metrics_list
def close(self):
logger.info(f'Trial {self.index} done')
def run(self):
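        '''Run the trial's sessions, plain or distributed (meta.distributed may be a mode string rather than a boolean, hence the explicit == False check), then analyze into trial metrics'''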
if self.spec['meta'].get('distributed') == False:
session_metrics_list = self.run_sessions()
else:
session_metrics_list = self.run_distributed_sessions()
metrics = analysis.analyze_trial(self.spec, session_metrics_list)
self.close()
return metrics['scalar']


class Experiment:
'''
The lab unit to run experiments.
    It generates a list of specs to search over, runs each as a trial with s repeated sessions,
    then gathers trial data and analyzes it to produce experiment data.
'''
def __init__(self, spec):
self.spec = spec
self.index = self.spec['meta']['experiment']
        util.set_logger(self.spec, logger, 'experiment')
spec_util.save(spec, unit='experiment')
def close(self):
logger.info('Experiment done')
def run(self):
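        '''Run a ray-based search over spec variants as trials, then analyze the trial data into an experiment dataframe'''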
trial_data_dict = search.run_ray_search(self.spec)
experiment_df = analysis.analyze_experiment(self.spec, trial_data_dict)
self.close()
return experiment_df