takuseno/d3rlpy

View on GitHub
scripts/create_cartpole_dataset

Summary

Maintainability
Test Coverage
#!/usr/bin/env python

import gym
import d3rlpy

# Collect CartPole transitions with an online-trained DQN and dump them
# to disk as an MDPDataset.

# ---- settings --------------------------------------------------------
SEED = 100
ENV_NAME = 'CartPole-v0'
LEARNING_RATE = 1e-3
TARGET_UPDATE_INTERVAL = 100
BUFFER_MAXLEN = 1000000
EPSILON = 0.3
N_STEPS = 100000
OUTPUT_PATH = 'cartpole.h5'

# seed first, before any environment or algorithm object is created
d3rlpy.seed(SEED)

# ---- environments ----------------------------------------------------
# two separate instances of the same task: one is stepped by fit_online,
# the other is handed over as eval_env
env = gym.make(ENV_NAME)
eval_env = gym.make(ENV_NAME)

# ---- algorithm -------------------------------------------------------
dqn = d3rlpy.algos.DQN(
    learning_rate=LEARNING_RATE,
    target_update_interval=TARGET_UPDATE_INTERVAL,
)

# ---- data collection helpers -----------------------------------------
# replay buffer that receives the transitions, and a constant
# epsilon-greedy explorer
buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=BUFFER_MAXLEN, env=env)
explorer = d3rlpy.online.explorers.ConstantEpsilonGreedy(epsilon=EPSILON)

# ---- online training -------------------------------------------------
dqn.fit_online(
    env,
    buffer=buffer,
    explorer=explorer,
    eval_env=eval_env,
    n_steps=N_STEPS,
)

# ---- export ----------------------------------------------------------
# convert the replay buffer into an MDPDataset and write it to disk
dataset = buffer.to_mdp_dataset()
dataset.dump(OUTPUT_PATH)