docs/search_path_gif.py
def warn(*args, **kwargs):
pass
import warnings
warnings.warn = warn
import os
import gc
import glob
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.rcParams["figure.facecolor"] = "w"
mpl.use("agg")
from gradient_free_optimizers.optimizers.core_optimizer.converter import Converter
def plot_search_paths(
path,
optimizer,
opt_para,
n_iter_max,
objective_function,
search_space,
initialize,
random_state,
title,
):
if opt_para == {}:
show_opt_para = False
else:
show_opt_para = True
opt = optimizer(
search_space, initialize=initialize, random_state=random_state, **opt_para
)
opt.search(
objective_function,
n_iter=n_iter_max,
# memory=False,
verbosity=False,
)
conv = Converter(search_space)
for n_iter in tqdm(range(1, n_iter_max + 1)):
def objective_function_np(args):
params = {}
for i, para_name in enumerate(search_space):
params[para_name] = args[i]
return objective_function(params)
plt.figure(figsize=(7, 7))
plt.set_cmap("jet_r")
# jet_r
x_all, y_all = search_space["x0"], search_space["x1"]
xi, yi = np.meshgrid(x_all, y_all)
zi = objective_function_np((xi, yi))
zi = np.rot90(zi, k=1)
plt.imshow(
zi,
alpha=0.15,
# interpolation="antialiased",
# vmin=z.min(),
# vmax=z.max(),
# origin="lower",
extent=[x_all.min(), x_all.max(), y_all.min(), y_all.max()],
)
for n, opt_ in enumerate(opt.optimizers):
n_optimizers = len(opt.optimizers)
n_iter_tmp = int(n_iter / n_optimizers)
n_iter_mod = n_iter % n_optimizers
if n_iter_mod > n:
n_iter_tmp += 1
if n_iter_tmp == 0:
continue
pos_list = np.array(opt_.pos_new_list)
score_list = np.array(opt_.score_new_list)
# print("\n pos_list \n", pos_list, "\n")
# print("\n score_list \n", score_list, "\n")
if len(pos_list) == 0:
continue
values_list = conv.positions2values(pos_list)
values_list = np.array(values_list)
plt.plot(
values_list[:n_iter_tmp, 0],
values_list[:n_iter_tmp, 1],
linestyle="--",
marker=",",
color="black",
alpha=0.33,
label=n,
linewidth=0.5,
)
plt.scatter(
values_list[:n_iter_tmp, 0],
values_list[:n_iter_tmp, 1],
c=score_list[:n_iter_tmp],
marker="H",
s=15,
vmin=np.amin(score_list[:n_iter_tmp]),
vmax=np.amax(score_list[:n_iter_tmp]),
label=n,
edgecolors="black",
linewidth=0.3,
)
plt.xlabel("x")
plt.ylabel("y")
nth_iteration = "\n\nnth Iteration: " + str(n_iter)
opt_para_name = ""
opt_para_value = "\n\n"
if show_opt_para:
opt_para_name += "\n Parameter:"
for para_name, para_value in opt_para.items():
opt_para_name += "\n " + " " + para_name + ": "
opt_para_value += "\n " + str(para_value) + " "
if title == True:
title_name = opt.name + "\n" + opt_para_name
plt.title(title_name, loc="left")
plt.title(opt_para_value, loc="center")
elif isinstance(title, str):
plt.title(title, loc="left")
plt.title(nth_iteration, loc="right", fontsize=8)
# plt.xlim((-101, 201))
# plt.ylim((-101, 201))
clb = plt.colorbar()
clb.set_label("score", labelpad=-50, y=1.03, rotation=0)
# plt.legend(loc="upper left", bbox_to_anchor=(-0.10, 1.2))
# plt.axis("off")
if show_opt_para:
plt.subplots_adjust(top=0.75)
plt.tight_layout()
# plt.margins(0, 0)
plt.savefig(
path + "/_plots/" + opt._name_ + "_" + "{0:0=3d}".format(n_iter) + ".jpg",
dpi=150,
pad_inches=0,
# bbox_inches="tight",
)
plt.ioff()
# Clear the current axes.
plt.cla()
# Clear the current figure.
plt.clf()
# Closes all the figure windows.
plt.close("all")
gc.collect()
def search_path_gif(
path,
optimizer,
opt_para,
name,
n_iter,
objective_function,
search_space,
initialize,
random_state=0,
title=True,
):
path = os.path.join(os.getcwd(), path)
print("\n\nname", name)
plots_dir = path + "/_plots/"
print("plots_dir", plots_dir)
os.makedirs(plots_dir, exist_ok=True)
plot_search_paths(
path=path,
optimizer=optimizer,
opt_para=opt_para,
n_iter_max=n_iter,
objective_function=objective_function,
search_space=search_space,
initialize=initialize,
random_state=random_state,
title=title,
)
### ffmpeg
framerate = str(n_iter / 10)
# framerate = str(10)
_framerate = " -framerate " + framerate + " "
_opt_ = optimizer(search_space)
_input = " -i " + path + "/_plots/" + str(_opt_._name_) + "_" + "%03d.jpg "
_scale = " -vf scale=1200:-1:flags=lanczos "
_output = os.path.join(path, name)
ffmpeg_command = (
"ffmpeg -hide_banner -loglevel error -y"
+ _framerate
+ _input
+ _scale
+ _output
)
print("\n -----> ffmpeg_command \n", ffmpeg_command, "\n")
print("create " + name)
os.system(ffmpeg_command)
### remove _plots
rm_files = glob.glob(path + "/_plots/*.jpg")
for f in rm_files:
os.remove(f)
os.rmdir(plots_dir)