neat/main.py
"""Implements the NEAT algorithm."""
import hashlib
import json
import os
from time import time
import gym
import numpy as np
import requests
from gym import wrappers
from neat.creature import Creature
from neat.genome import Genome
from neat.population import Population
from neat.pso import PSO
from neat.species import Species
class NeatAlgorithm:
    """An implementation of the NEAT algorithm based off the original paper.

    The instance owns a gym environment and a `Population`, evaluates every
    creature once per episode (generation) and asks the population to breed
    the next generation until the environment counts as solved or the episode
    budget runs out. Unless started offline, the run is also registered with
    the tracking API found at `api_url`.
    """

    api_url = 'http://localhost:5000/api'

    def __init__(self, env, n_pops=150, offline=False):
        """Create the initial population and register the run.

        Arguments:
            env: The gym environment to solve. Its spec provides `trials`
                 and `reward_threshold`.
            n_pops: The number of creatures in the population.
            offline: If True, the tracking API is never contacted.
        """
        self.env = env
        self.n_trials = env.spec.trials
        self.reward_threshold = env.spec.reward_threshold

        genesis = Creature(env.observation_space.shape[0], env.action_space.n)
        self.population = Population(genesis, n_pops)

        # One list of creature fitnesses per evaluated episode.
        self.fitness_history = []
        self.snapshot_i = 0

        # The run id is the first 16 hex digits of the SHA-1 of the
        # current timestamp.
        run_id_hash = hashlib.sha1()
        run_id_hash.update(str(time()).encode('utf-8'))
        self.run_id = run_id_hash.hexdigest()[:16]

        self.offline = offline
        if not self.offline:
            self._api_request(
                'POST', NeatAlgorithm.api_url + '/runs', 201,
                'WARNING: Could not add run data to database.',
                json=dict(id=self.run_id))

    @staticmethod
    def _api_request(method, url, expected_status, warning, **kwargs):
        """Send a best-effort request to the tracking API.

        Failures are only reported on stdout: an unreachable server or an
        error body that is not JSON must not abort a training run.

        Arguments:
            method: The HTTP method to use.
            url: The URL to send the request to.
            expected_status: The status code that signals success.
            warning: The message printed when the request fails.
            **kwargs: Passed through to `requests.request`.
        Returns: True if the expected status code came back, False otherwise.
        """
        try:
            response = requests.request(method, url, **kwargs)
        except requests.RequestException as err:
            print(warning)
            print(err)
            return False

        if response.status_code == expected_status:
            return True

        print(warning)
        try:
            print(response.json())
        except ValueError:
            # The error body is not JSON (e.g. an HTML error page).
            print(response.text)
        return False

    def train(self, n_episodes=100, n_steps=200, n_pso_episodes=5,
              debug_mode=False):
        """Train species of individuals.

        Arguments:
            n_episodes: The number of episodes to train for.
            n_steps: The maximum number of steps per individual per episode.
            n_pso_episodes: The number of PSO episodes to run for each
                            species before the population is evaluated. PSO
                            is skipped when this is not greater than zero.
            debug_mode: If set to True, some features that aren't intended for
                        testing environments and such are disabled (snapshots,
                        video recording and the final dump).
        """
        sim_start = time()
        episode_complete_msg_format = "\r{:03d}/{:03d} - " \
                                      "mean fitness: {:.2f} - " \
                                      "median fitness: {:.2f} - " \
                                      "mean time per creature: {:02.4f}s - " \
                                      "total time: {:.4f}s"

        for episode in range(n_episodes):
            print('Episode {:02d}/{:02d}'.format(episode + 1, n_episodes))
            episode_start = time()
            self.fitness_history.append([])

            if n_pso_episodes > 0:
                print('Acquiring Collective Intelligence...')
                for species in self.population.species:
                    pso = PSO(self.env, species.members)
                    pso.train(n_episodes=n_pso_episodes, n_steps=n_steps)
                    pso.apply()
                print('\nCollective Intelligence acquired in {:.4f}s.'
                      .format(time() - episode_start))

            print('Evaluating Population Goodness...')
            for pop_i, creature in enumerate(self.population.creatures):
                observation = self.env.reset()
                for step in range(n_steps):
                    action = creature.get_action(observation)
                    observation, _, done, _ = self.env.step(action)
                    if done:
                        # Fitness is the number of steps survived.
                        creature.fitness = step + 1
                        break
                else:
                    # The environment never reported `done`.
                    creature.fitness = n_steps

                self.fitness_history[episode].append(creature.fitness)

                # Progress line, rewritten in place after every creature.
                mean_fitness = np.mean(self.fitness_history[episode])
                median_fitness = np.median(self.fitness_history[episode])
                episode_time = time() - episode_start
                mean_time_per_creature = episode_time / (pop_i + 1)
                print(episode_complete_msg_format
                      .format(pop_i + 1, self.population.n_pops, mean_fitness,
                              median_fitness, mean_time_per_creature,
                              episode_time),
                      end='')

            if self._is_solved(episode):
                print('\nSolved in %d episodes :)' % (episode + 1))
                break

            print()
            if not debug_mode:
                self.make_snapshot()
            self.population.next_generation()
            print('Total episode time: %.4fs.\n' % (time() - episode_start))
        else:
            # Only reached when the loop was not left through `break`.
            print('Could not solve in %d episodes :(' % n_episodes)

        print('Total run time: {:.2f}s'.format(time() - sim_start))
        print()

        if not self.offline:
            # Best effort: a tracking failure must not prevent the
            # post-training report and dump below.
            self._api_request(
                'PATCH',
                '%s/runs/%s/finished' % (NeatAlgorithm.api_url, self.run_id),
                204,
                "WARNING: Was not able to update run finished status.")

        self.post_training_stuff(n_steps, debug_mode)

    def _is_solved(self, episode):
        """Check whether the run counts as solved after `episode`.

        The mean fitness over the last `n_trials` evaluated episodes,
        including the one that just finished, has to reach the reward
        threshold.

        Arguments:
            episode: The zero-based index of the episode that just finished.
        Returns: True if the environment counts as solved, False otherwise.
        """
        n_completed = episode + 1
        if n_completed < self.n_trials:
            return False
        window = self.fitness_history[n_completed - self.n_trials:n_completed]
        return bool(np.mean(window) >= self.reward_threshold)

    def post_training_stuff(self, n_steps, debug_mode=False):
        """Report on the final population and store the results.

        Arguments:
            n_steps: The maximum number of steps per evaluation episode.
            debug_mode: If True, no video is recorded and no dump is written.
        """
        print('Here are the species that made it to the end and the number of '
              'creatures in each of them:')
        self.population.list_species()
        best_species = self.population.best_species
        print()

        oldest_creature = self.population.oldest_creature
        print('The oldest creature was %s, who lived for %d generations.' %
              (oldest_creature, oldest_creature.age))
        print('Out of these species, the best species was %s.' % best_species)

        champion = best_species.champion
        print('The overall champion was %s who had %d nodes and %d '
              'connections in its neural network.' %
              (champion,
               len(champion.phenotype.nodes),
               len(champion.phenotype.connections)))
        print()

        print('Checking if %s makes the grade...' % champion, end='')
        makes_the_grade = self.makes_the_grade(champion, n_steps)
        print(('\r%s makes the grade :)' if makes_the_grade else
               '\r%s doesn\'t make the grade :(') % champion)
        print()

        if not debug_mode:
            print("Recording %s" % champion)
            self.record_video(champion, n_steps=n_steps)
            self.dump()

        self.env.close()

    def makes_the_grade(self, creature, n_steps):
        """Check if the creature 'passes' the environment.

        The creature is run for `n_trials` episodes of at most `n_steps`
        steps each and passes when its mean episode reward reaches the
        reward threshold.

        Arguments:
            creature: The creature to evaluate.
            n_steps: The maximum number of steps per episode.
        Returns: True if the creature passes, False otherwise.
        """
        env = self.env
        total_reward = 0
        for _ in range(self.n_trials):
            observation = env.reset()
            episode_reward = 0
            for _ in range(n_steps):
                action = creature.get_action(observation)
                observation, reward, done, _ = env.step(action)
                episode_reward += reward
                if done:
                    break
            total_reward += episode_reward
        return (total_reward / self.n_trials) >= self.reward_threshold

    def record_video(self, creature, n_episodes=20, n_steps=200):
        """Record a video of the creature trying to solve the problem.

        The environment is wrapped in a gym `Monitor` that writes to a
        timestamped directory under './data/videos/'.

        Arguments:
            creature: the creature to record.
            n_episodes: how many episodes to record.
            n_steps: how many steps to run each episode for.
        """
        env = wrappers.Monitor(self.env, './data/videos/%s' % time())
        for _ in range(n_episodes):
            observation = env.reset()
            for step in range(n_steps):
                env.render()
                action = creature.get_action(observation)
                observation, _, done, _ = env.step(action)
                if done:
                    print("Episode finished after {} timesteps"
                          .format(step + 1))
                    break

    def _run_directory(self, path):
        """Return the directory for this run's files below `path`.

        Arguments:
            path: The base directory. An empty string means the current
                  working directory.
        Returns: the directory path, always ending in '/'.
        """
        if path and not path.endswith('/'):
            path += '/'
        return path + self.run_id + '/'

    def dump(self, path='data/training/', filename=None):
        """Save the complete training state to file.

        Arguments:
            path: The base directory; a sub-directory named after the run id
                  is created in it.
            filename: The name of the file. Defaults to 'dump.json'.
        """
        directory = self._run_directory(path)
        fullpath = directory + (filename if filename else 'dump.json')
        os.makedirs(directory, exist_ok=True)
        with open(fullpath, 'w') as f:
            json.dump(self.to_json(), f)
        print('Saved training data to: %s.' % fullpath)

    def make_snapshot(self, path='data/training/', filename=None):
        """Save a numbered snapshot of the current population to file.

        Every call increments the snapshot counter, even when an explicit
        filename is given.

        Arguments:
            path: The base directory; a sub-directory named after the run id
                  is created in it.
            filename: The name of the file. Defaults to 'snapshot-NN.json'
                      where NN is the snapshot counter.
        """
        directory = self._run_directory(path)
        self.snapshot_i += 1
        fullpath = directory + (filename if filename else
                                'snapshot-%02d.json' % self.snapshot_i)
        os.makedirs(directory, exist_ok=True)
        with open(fullpath, 'w') as f:
            json.dump(self.population.to_json(), f)

    def to_json(self):
        """Encode the current state of the algorithm as JSON.

        This saves pretty much everything from parameters to individual
        creatures.

        Returns: the generated JSON.
        """
        settings = dict(
            survival_threshold=Population.survival_threshold,
            compatibility_threshold=Species.compatibility_threshold,
            p_interspecies_mating=Species.p_interspecies_mating,
            disjointedness_importance=Creature.disjointedness_importance,
            excessivity_importance=Creature.excessivity_importance,
            weight_unsameness_importance=(
                Creature.weight_unsameness_importance),
            p_mate_only=Creature.p_mate_only,
            p_mutate_only=Creature.p_mutate_only,
            p_mate_average=Genome.p_mate_average,
            p_mate_choose=Genome.p_mate_choose,
            p_add_node=Genome.p_add_node,
            p_add_connection=Genome.p_add_connection,
            p_re_enable_connection=Genome.p_re_enable_connection,
            p_perturb=Genome.p_perturb,
            perturb_range=Genome.perturb_range,
        )
        return dict(
            run_id=self.run_id,
            env=self.env.unwrapped.spec.id,
            population=self.population.to_json(),
            settings=settings,
        )

    @staticmethod
    def from_json(config, offline=False):
        """Load an instance of the NEAT algorithm from JSON.

        Arguments:
            config: the JSON dictionary loaded from file.
            offline: Whether the NEAT instance should be started 'offline'.
        Returns: an instance of the NEAT algorithm.
        """
        env = gym.make(config['env'])
        # NOTE(review): when not offline, the constructor registers a freshly
        # generated run id with the API before it is replaced by the saved
        # one below - confirm that this is intended.
        algo = NeatAlgorithm(env, offline=offline)
        algo.run_id = config['run_id']
        algo.population = Population.from_json(config['population'])
        NeatAlgorithm.set_config(config['settings'])
        return algo

    @staticmethod
    def set_config(config):
        """Set the parameters of the NEAT algorithm.

        The values are stored as class attributes and therefore apply to
        every instance.

        Arguments:
            config: A dictionary containing the key-value pairs for the
                    algorithm parameters.
        """
        Population.survival_threshold = config['survival_threshold']
        Species.compatibility_threshold = config['compatibility_threshold']
        Species.p_interspecies_mating = config['p_interspecies_mating']
        Creature.disjointedness_importance = \
            config['disjointedness_importance']
        Creature.excessivity_importance = config['excessivity_importance']
        # `to_json` saves this setting, so it is restored as well. The key is
        # optional so that configs written without it keep loading.
        if 'weight_unsameness_importance' in config:
            Creature.weight_unsameness_importance = \
                config['weight_unsameness_importance']
        Creature.p_mate_only = config['p_mate_only']
        Creature.p_mutate_only = config['p_mutate_only']
        Genome.p_mate_average = config['p_mate_average']
        Genome.p_mate_choose = config['p_mate_choose']
        Genome.p_add_node = config['p_add_node']
        Genome.p_add_connection = config['p_add_connection']
        Genome.p_re_enable_connection = config['p_re_enable_connection']
        Genome.p_perturb = config['p_perturb']
        Genome.perturb_range = config['perturb_range']