# The data visualization module
# Defines plotting methods for analysis
from glob import glob
from plotly import graph_objs as go, io as pio, subplots
from plotly.offline import init_notebook_mode, iplot
from slm_lab.lib import logger, util
import colorlover as cl
import os
import pydash as ps
logger = logger.get_logger(__name__)
pio.templates.default = 'none'  # set default white background for plots
# moving-average window size for plotting
# NOTE(review): the assignment itself was missing under this comment even though calc_sr_ma reads the name; 100 is a reconstructed value - confirm
PLOT_MA_WINDOW = 100
if util.is_jupyter():
    # NOTE(review): this branch had no body; init_notebook_mode is imported at the top of the file and not used anywhere else - confirm the arguments
    init_notebook_mode(connected=True)
def calc_sr_ma(sr):
    '''
    Calculate the moving-average of a series to be plotted.
    sr must expose `.rolling(window, min_periods=...)` (presumably a pandas Series).
    The window size is the module-level PLOT_MA_WINDOW; min_periods=1 means the first
    points are averaged over however many values are available instead of being dropped.
    Returns the result of `.mean()` on that rolling window.
    '''
    return sr.rolling(PLOT_MA_WINDOW, min_periods=1).mean()
def create_label(y_col, x_col, title=None, y_title=None, x_title=None, legend_name=None):
    '''
    Create label dict for go.Layout with smart resolution.
    - legend_name defaults to y_col
    - y_col, x_col and legend_name are each passed through util.cast_list (presumably wrapping scalars into lists)
    - y_title / x_title default to the comma-joined column lists and are always cast to str
    - title defaults to '<y_title> vs <x_title>'
    Returns a dict with keys: y_title, x_title, title, y_col_list, x_col_list, legend_name_list.
    '''
    legend_name = legend_name or y_col
    y_col_list, x_col_list, legend_name_list = ps.map_(
        [y_col, x_col, legend_name], util.cast_list)
    y_title = str(y_title or ','.join(y_col_list))
    x_title = str(x_title or ','.join(x_col_list))
    title = title or f'{y_title} vs {x_title}'
    label = {
        'y_title': y_title,
        'x_title': x_title,
        'title': title,
        'y_col_list': y_col_list,
        'x_col_list': x_col_list,
        'legend_name_list': legend_name_list,
    }
    return label
def create_layout(title, y_title, x_title, x_type=None, width=500, height=500, layout_kwargs=None):
    '''
    Simplified method to generate Layout.
    - title, y_title, x_title: figure title and axis titles
    - x_type: passed as the x-axis `type`
    - width, height: figure size
    - layout_kwargs: optional dict of extra layout properties applied on top of the defaults
    Returns the go.Layout with a horizontal legend placed below the plot and a y-axis using rangemode 'tozero'.
    '''
    layout = go.Layout(
        title=title,  # previously accepted but never forwarded to the layout
        legend=dict(x=0.0, y=-0.25, orientation='h'),
        yaxis=dict(rangemode='tozero', title=y_title),
        xaxis=dict(type=x_type, title=x_title),
        width=width, height=height,
        margin=go.layout.Margin(l=60, r=30, t=60, b=60),
    )
    if layout_kwargs:
        # caller overrides win over the defaults above
        layout.update(layout_kwargs)
    return layout
def get_palette(size):
    '''
    Get the suitable palette of a certain size.
    Up to 8 colors are read from the colorlover qualitative 'Set2' scale, using the scale keyed by
    max(3, size) (so the result can hold more than `size` colors for size < 3).
    Above 8, the 8-color 'Set2' scale is interpolated to `size` colors with cl.interp.
    '''
    if size <= 8:
        palette = cl.scales[str(max(3, size))]['qual']['Set2']
    else:
        # without this branch the interpolated palette would overwrite the lookup above for every size
        palette = cl.interp(cl.scales['8']['qual']['Set2'], size)
    return palette
def lower_opacity(rgb, opacity):
    '''
    Turn an 'rgb(...)' or 'hsl(...)' color string into its 'rgba(...)' / 'hsla(...)' form with the given opacity.
    This is plain string replacement: the prefix gains an 'a' and ', <opacity>' is inserted before every ')'.
    A string that already carries an alpha component is not detected and would get a second one appended.
    '''
    with_alpha_prefix = rgb.replace('rgb(', 'rgba(').replace('hsl(', 'hsla(')
    return with_alpha_prefix.replace(')', f', {opacity})')
def plot(*args, **kwargs):
    '''
    Display a figure inline when running inside Jupyter.
    All arguments are forwarded unchanged to plotly's iplot and its result is returned.
    Outside Jupyter (util.is_jupyter() is falsy) nothing is drawn and None is returned.
    '''
    if util.is_jupyter():
        return iplot(*args, **kwargs)
    return None
def plot_sr(sr, time_sr, title, y_title, x_title, color=None):
    '''
    Plot a series as a single thin line.
    - sr: the y values
    - time_sr: the x values; must expose `.tolist()`
    - title, y_title, x_title: passed to create_layout
    - color: line color; defaults to the first color of get_palette(1)
    Returns a go.Figure holding one Scatter trace (no legend entry).
    '''
    x = time_sr.tolist()
    color = color or get_palette(1)[0]
    main_trace = go.Scatter(
        x=x, y=sr, mode='lines', showlegend=False,
        line={'color': color, 'width': 1},
    )
    data = [main_trace]
    layout = create_layout(title=title, y_title=y_title, x_title=x_title)
    fig = go.Figure(data, layout)
    return fig
def plot_mean_sr(sr_list, time_sr, title, y_title, x_title, color=None):
    '''
    Plot a list of series using its mean, with error bar using std.
    - sr_list: the series to aggregate; mean and std come from util.calc_srs_mean_std
    - time_sr: the shared x values; must expose `.tolist()`
    - title, y_title, x_title: passed to create_layout
    - color: line color; defaults to the first color of get_palette(1)
    Returns a go.Figure whose data is [mean line trace, filled mean +/- std envelope trace].
    '''
    mean_sr, std_sr = util.calc_srs_mean_std(sr_list)
    max_sr = mean_sr + std_sr
    min_sr = mean_sr - std_sr
    max_y = max_sr.tolist()
    min_y = min_sr.tolist()
    x = time_sr.tolist()
    color = color or get_palette(1)[0]
    main_trace = go.Scatter(
        x=x, y=mean_sr, mode='lines', showlegend=False,
        line={'color': color, 'width': 1},
    )
    # the envelope is one closed outline: upper bound left-to-right, then lower bound right-to-left
    envelope_trace = go.Scatter(
        x=x + x[::-1], y=max_y + min_y[::-1], showlegend=False,
        line={'color': 'rgba(0, 0, 0, 0)'},  # fully transparent outline, only the fill is visible
        fill='tozerox', fillcolor=lower_opacity(color, 0.15),
    )
    data = [main_trace, envelope_trace]
    layout = create_layout(title=title, y_title=y_title, x_title=x_title)
    fig = go.Figure(data, layout)
    return fig
def save_image(figure, filepath):
    '''
    Write a figure to an image file on a best-effort basis.
    - figure: the figure passed to pio.write_image
    - filepath: target path; resolved through util.smart_path before writing
    Does nothing when the PY_ENV environment variable equals 'test'.
    Any exception raised while writing is logged as a warning and not re-raised. Returns None.
    '''
    # .get so an unset PY_ENV means "not a test run" instead of raising KeyError
    if os.environ.get('PY_ENV') == 'test':
        # NOTE(review): the body of this guard was missing; skipping image generation is the reconstructed intent - confirm
        return
    filepath = util.smart_path(filepath)
    try:
        pio.write_image(figure, filepath, scale=2)
    except Exception as e:
        # deliberate best-effort: a failed export must not abort the run
        logger.warning(f'Failed to generate graph. Run retro-analysis to generate graphs later. {e}\nIf running on a headless server, prepend your Python command with `xvfb-run -a `, for example `xvfb-run -a python`')
# analysis plot methods
def plot_session(session_spec, session_metrics, session_df, df_mode='eval', ma=False):
    '''
    Plot the session graphs, one line plot per entry (drawn with plot_sr, so no error bar):
    - mean_returns, strengths, sample_efficiencies, training_efficiencies, stabilities from session_metrics['local']
    - additional plots from session_df: loss, explore_var, entropy (only when ma is False)
    Arguments:
    - session_spec: reads 'name' and, under 'meta', 'prepath', 'graph_prepath', 'trial', 'session'
    - session_metrics: reads the 'local' entry, indexed by metric name and by time-axis name
    - session_df: indexed by the column names 'loss', 'explore_var', 'entropy', 'frame'
    - df_mode: only used as part of the output file names
    - ma: if True, plot the moving average of each metric and suffix its name with '_ma'
    Images go under graph_prepath; the mean_returns graph is additionally saved under prepath.
    '''
    meta_spec = session_spec['meta']
    prepath = meta_spec['prepath']
    graph_prepath = meta_spec['graph_prepath']
    title = f'session graph: {session_spec["name"]} t{meta_spec["trial"]} s{meta_spec["session"]}'

    local_metrics = session_metrics['local']
    name_time_pairs = [
        ('mean_returns', 'frames'),
        ('strengths', 'frames'),
        ('sample_efficiencies', 'frames'),
        ('training_efficiencies', 'opt_steps'),
        ('stabilities', 'frames'),
    ]
    for name, time in name_time_pairs:
        sr = local_metrics[name]
        if ma:
            sr = calc_sr_ma(sr)
            name = f'{name}_ma'  # for labeling
        fig = plot_sr(
            sr, local_metrics[time], title, name, time)
        save_image(fig, f'{graph_prepath}_session_graph_{df_mode}_{name}_vs_{time}.png')
        if name in ('mean_returns', 'mean_returns_ma'):  # save important graphs in prepath directly
            save_image(fig, f'{prepath}_session_graph_{df_mode}_{name}_vs_{time}.png')

    if ma:
        # NOTE(review): the body of this check was missing; returning here is reconstructed, since the
        # plots below ignore `ma` and would only rewrite the same files - confirm
        return
    # training plots from session_df
    name_time_pairs = [
        ('loss', 'frame'),
        ('explore_var', 'frame'),
        ('entropy', 'frame'),
    ]
    for name, time in name_time_pairs:
        fig = plot_sr(
            session_df[name], session_df[time], title, name, time)
        save_image(fig, f'{graph_prepath}_session_graph_{df_mode}_{name}_vs_{time}.png')
def plot_trial(trial_spec, trial_metrics, ma=False):
    '''
    Plot the trial graphs:
    - mean_returns, strengths, sample_efficiencies, training_efficiencies, stabilities (with error bar, via plot_mean_sr)
    - consistencies (no error bar, via plot_sr)
    Arguments:
    - trial_spec: reads 'name' and, under 'meta', 'prepath', 'graph_prepath', 'trial', 'max_session'
    - trial_metrics: reads the 'local' entry; 'consistencies' is used as a single series, every other metric as a list of series
    - ma: if True, plot moving averages and suffix each metric name with '_ma'
    Images go under graph_prepath; the mean_returns graph is additionally saved under prepath.
    '''
    meta_spec = trial_spec['meta']
    prepath = meta_spec['prepath']
    graph_prepath = meta_spec['graph_prepath']
    title = f'trial graph: {trial_spec["name"]} t{meta_spec["trial"]} {meta_spec["max_session"]} sessions'

    local_metrics = trial_metrics['local']
    name_time_pairs = [
        ('mean_returns', 'frames'),
        ('strengths', 'frames'),
        ('sample_efficiencies', 'frames'),
        ('training_efficiencies', 'opt_steps'),
        ('stabilities', 'frames'),
        ('consistencies', 'frames'),
    ]
    for name, time in name_time_pairs:
        if name == 'consistencies':
            sr = local_metrics[name]
            if ma:
                sr = calc_sr_ma(sr)
                name = f'{name}_ma'  # for labeling
            fig = plot_sr(
                sr, local_metrics[time], title, name, time)
        else:
            # without this branch the single-series figure above would always be replaced
            sr_list = local_metrics[name]
            if ma:
                sr_list = [calc_sr_ma(sr) for sr in sr_list]
                name = f'{name}_ma'  # for labeling
            fig = plot_mean_sr(
                sr_list, local_metrics[time], title, name, time)
        save_image(fig, f'{graph_prepath}_trial_graph_{name}_vs_{time}.png')
        if name in ('mean_returns', 'mean_returns_ma'):  # save important graphs in prepath directly
            save_image(fig, f'{prepath}_trial_graph_{name}_vs_{time}.png')
def plot_experiment(experiment_spec, experiment_df, metrics_cols):
    '''
    Plot the metrics vs. specs parameters of an experiment, where each point is a trial.
    Builds a grid of scatter subplots: one row per metric in metrics_cols, one column per remaining
    column of experiment_df (everything except the metrics and 'trial').
    Points are colored by the 'strength' column using the 'YlGnBu' colorscale.
    Arguments:
    - experiment_spec: reads 'name' and, under 'meta', 'graph_prepath' and 'prepath'
    - experiment_df: must contain the 'strength' column and every column in metrics_cols
    - metrics_cols: list of metric column names used as the y axes
    Saves the figure under both graph_prepath and prepath, then returns it.
    '''
    y_cols = metrics_cols
    x_cols = ps.difference(experiment_df.columns.tolist(), y_cols + ['trial'])
    fig = subplots.make_subplots(rows=len(y_cols), cols=len(x_cols), shared_xaxes=True, shared_yaxes=True, print_grid=False)
    strength_sr = experiment_df['strength']
    min_strength, max_strength = strength_sr.min(), strength_sr.max()
    for row_idx, y in enumerate(y_cols):
        for col_idx, x in enumerate(x_cols):
            x_sr = experiment_df[x]
            # object-dtype columns are cast to str so they are plotted as categories
            guard_cat_x = x_sr.astype(str) if x_sr.dtype == 'object' else x_sr
            trace = go.Scatter(
                y=experiment_df[y], yaxis=f'y{row_idx+1}',
                x=guard_cat_x, xaxis=f'x{col_idx+1}',
                showlegend=False, mode='markers',
                marker={
                    'symbol': 'circle-open-dot', 'color': strength_sr, 'opacity': 0.5,
                    # dump first portion of colorscale that is too bright
                    'cmin': min_strength - 0.5 * (max_strength - min_strength), 'cmax': max_strength,
                    'colorscale': 'YlGnBu', 'reversescale': False
                },
            )
            fig.add_trace(trace, row_idx + 1, col_idx + 1)
            # x titles only on the bottom row; long names are wrapped every 20 characters
            fig.update_xaxes(title_text='<br>'.join(ps.chunk(x, 20)), zerolinewidth=1, categoryarray=sorted(guard_cat_x.unique()), row=len(y_cols), col=col_idx+1)
        # y titles only on the first column
        fig.update_yaxes(title_text=y, rangemode='tozero', row=row_idx+1, col=1)
    fig.layout.update(
        title=f'experiment graph: {experiment_spec["name"]}',
        width=100 + 300 * len(x_cols), height=200 + 300 * len(y_cols))
    graph_prepath = experiment_spec['meta']['graph_prepath']
    save_image(fig, f'{graph_prepath}_experiment_graph.png')
    # save important graphs in prepath directly
    prepath = experiment_spec['meta']['prepath']
    save_image(fig, f'{prepath}_experiment_graph.png')
    return fig
def plot_multi_local_metrics(local_metrics_list, legend_list, name, time, title, palette=None, showlegend=True):
    '''
    Method to plot list local_metrics gathered from multiple trials, with ability to specify custom legend and title. Used by plot_multi_trial.
    - local_metrics_list: one local_metrics per trial, each indexed by `name` (list of series) and `time` (x values)
    - legend_list: legend label per trial, or None to leave the traces unlabeled
    - name, time: metric key and time-axis key; also used as the y and x axis titles
    - title: title of the combined figure
    - palette: one color per trial; defaults to get_palette(len(local_metrics_list))
    - showlegend: whether the labeled main traces appear in the legend
    Returns a single go.Figure containing the mean and envelope traces of every trial.
    '''
    palette = palette or get_palette(len(local_metrics_list))
    all_data = []
    for idx, local_metrics in enumerate(local_metrics_list):
        fig = plot_mean_sr(
            local_metrics[name], local_metrics[time], '', name, time, color=palette[idx])
        if legend_list is not None:
            # update legend for the main trace
            # NOTE(review): `fig.data` was missing here and in the list() call below; reconstructed from
            # plot_mean_sr building its figure as [main_trace, envelope_trace] - confirm
            fig.data[0].update({'showlegend': showlegend, 'name': legend_list[idx]})
        all_data += list(fig.data)
    layout = create_layout(title, name, time)
    fig = go.Figure(all_data, layout)
    return fig
def plot_multi_trial(trial_metrics_path_list, legend_list, title, graph_prepath, ma=False, name_time_pairs=None, frame_scales=None, palette=None, showlegend=True):
    '''
    Plot multiple trial graphs together
    This method can be used in analysis and also custom plotting by specifying the arguments manually
    Arguments:
    - trial_metrics_path_list: paths of the trial metrics files to load; the 'local' entry of each is plotted
    - legend_list: legend label per trial (or None)
    - title: figure title
    - graph_prepath: path prefix of the output images
    - ma: if True, plot moving averages; they are stored under '<name>_ma' in each loaded local_metrics
    - name_time_pairs: (metric, time-axis) pairs to plot; defaults to the five standard metrics
    - frame_scales: optional iterable of (index, scale) pairs; multiplies 'frames' of the indexed trial by scale
    - palette, showlegend: forwarded to plot_multi_local_metrics
    @example (paths are illustrative)

    trial_metrics_path_list = [
        'data/my_exp/info/my_exp_t0_trial_metrics.pkl',
        'data/my_exp/info/my_exp_t1_trial_metrics.pkl',
    ]
    legend_list = [
        '0',
        '1',
    ]
    title = f'Multi trial trial graphs'
    graph_prepath = 'data/my_exp'
    viz.plot_multi_trial(trial_metrics_path_list, legend_list, title, graph_prepath)
    '''
    # NOTE(review): the loader call was missing (the expression read `['local'] for path in ...`);
    # util.read(path) is a reconstruction - confirm the helper name
    local_metrics_list = [util.read(path)['local'] for path in trial_metrics_path_list]
    # for plotting with async runs to adjust frame scale
    if frame_scales is not None:
        for idx, scale in frame_scales:
            local_metrics_list[idx]['frames'] = local_metrics_list[idx]['frames'] * scale
    name_time_pairs = name_time_pairs or [
        ('mean_returns', 'frames'),
        ('strengths', 'frames'),
        ('sample_efficiencies', 'frames'),
        ('training_efficiencies', 'opt_steps'),
        ('stabilities', 'frames')
    ]
    for name, time in name_time_pairs:
        if ma:
            for local_metrics in local_metrics_list:
                sr_list = local_metrics[name]
                sr_list = [calc_sr_ma(sr) for sr in sr_list]
                local_metrics[f'{name}_ma'] = sr_list
            name = f'{name}_ma'  # for labeling
        fig = plot_multi_local_metrics(local_metrics_list, legend_list, name, time, title, palette, showlegend)
        save_image(fig, f'{graph_prepath}_multi_trial_graph_{name}_vs_{time}.png')
        if name in ('mean_returns', 'mean_returns_ma'):  # save important graphs in prepath directly
            prepath = graph_prepath.replace('/graph/', '/')
            save_image(fig, f'{prepath}_multi_trial_graph_{name}_vs_{time}.png')
def get_trial_legends(experiment_df, trial_idxs, metrics_cols):
    '''
    Format trial variables in experiment_df into legend strings.
    - experiment_df: DataFrame with a 'trial' column, the metric columns, and one column per varied parameter
    - trial_idxs: trial ids to build legends for, looked up in the 'trial' column
    - metrics_cols: metric column names, dropped so only the variables remain
    Returns one string per trial id, formatted 't<idx>: <var> <value> <var> <value> ...', where each
    variable name is shortened to the part after its last '.'.
    '''
    var_df = experiment_df.drop(metrics_cols, axis=1).set_index('trial')
    trial_legends = []
    for trial_idx in trial_idxs:
        trial_vars = var_df.loc[trial_idx].to_dict()
        var_list = []
        for k, v in trial_vars.items():
            if hasattr(v, '__round__'):
                v = round(v, 8)  # prevent long float digits in formatting
            var_list.append(f'{k.split(".").pop()} {v}')
        var_str = ' '.join(var_list)
        legend = f't{trial_idx}: {var_str}'
        # the legend was built but never collected, so the function always returned []
        trial_legends.append(legend)
    return trial_legends
def plot_experiment_trials(experiment_spec, experiment_df, metrics_cols):
    '''
    Plot the trials of an experiment together, once with raw metrics and once with moving averages.
    - experiment_spec: reads 'name' and, under 'meta', 'info_prepath' and 'graph_prepath'
    - experiment_df, metrics_cols: forwarded to get_trial_legends to build one legend per trial
    Trial metrics files are found by globbing '<info_prepath>*_trial_metrics.pkl' and are ordered by
    the first index returned by util.prepath_to_idxs (the trial id, per the comments below).
    '''
    meta_spec = experiment_spec['meta']
    info_prepath = meta_spec['info_prepath']
    trial_metrics_path_list = glob(f'{info_prepath}*_trial_metrics.pkl')
    # sort by trial id
    trial_metrics_path_list = sorted(trial_metrics_path_list, key=lambda k: util.prepath_to_idxs(k)[0])

    # get trial indices to build legends
    trial_idxs = [util.prepath_to_idxs(prepath)[0] for prepath in trial_metrics_path_list]
    legend_list = get_trial_legends(experiment_df, trial_idxs, metrics_cols)

    title = f'multi trial graph: {experiment_spec["name"]}'
    graph_prepath = meta_spec['graph_prepath']
    plot_multi_trial(trial_metrics_path_list, legend_list, title, graph_prepath)
    plot_multi_trial(trial_metrics_path_list, legend_list, title, graph_prepath, ma=True)