nnablaRL is a deep reinforcement learning library built on top of Neural Network Libraries that is intended to be used for research, development and production.
Installing nnablaRL is easy!
$ pip install nnabla-rl
nnablaRL only supports Python version >= 3.8 and nnabla version >= 1.17.
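The version requirements above can be checked before installing. The following is a minimal sketch, assuming versions follow a plain `major.minor[.patch]` scheme; `version_tuple` and `check_requirements` are hypothetical helpers, not part of nnablaRL. It compares versions as integer tuples, because comparing them as strings would wrongly rank "1.9" above "1.17".

```python
import sys
from importlib import metadata  # standard library since Python 3.8


def version_tuple(version):
    """Turn "1.17.0" into (1, 17). Assumes a plain major.minor[.patch] string."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def check_requirements():
    """Return (python_ok, nnabla_ok) for the minimum versions stated above."""
    python_ok = sys.version_info >= (3, 8)
    try:
        nnabla_ok = version_tuple(metadata.version("nnabla")) >= (1, 17)
    except metadata.PackageNotFoundError:
        # nnabla is not installed in this environment
        nnabla_ok = False
    return python_ok, nnabla_ok


print(check_requirements())
```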
nnablaRL algorithms run on CPU by default. To run the algorithm on GPU, first install nnabla-ext-cuda as follows. (Replace [cuda-version] depending on the CUDA version installed on your machine.)
$ pip install nnabla-ext-cuda[cuda-version]
# Example installation. Supposing CUDA 11.0 is installed on your machine.
$ pip install nnabla-ext-cuda110
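Judging from the example above, the package suffix is the CUDA major and minor version with the dot removed (11.0 becomes 110). Below is a minimal sketch of that mapping, assuming the same naming holds for your CUDA version; `cuda_package_name` is a hypothetical helper, not part of nnablaRL.

```python
def cuda_package_name(cuda_version):
    """Build the nnabla-ext-cuda package name for a CUDA version such as "11.0".

    Hypothetical helper: assumes the suffix is <major><minor> with no dot.
    """
    parts = cuda_version.strip().split(".")
    if len(parts) < 2 or not all(part.isdigit() for part in parts[:2]):
        raise ValueError(f"expected a version like '11.0', got {cuda_version!r}")
    major, minor = parts[0], parts[1]
    return f"nnabla-ext-cuda{major}{minor}"


print(cuda_package_name("11.0"))  # nnabla-ext-cuda110
```

Check that the resulting package actually exists on PyPI before relying on the generated name.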
After installing nnabla-ext-cuda, set the id of the GPU to run on in the algorithm's configuration.
import nnabla_rl.algorithms as A
config = A.DQNConfig(gpu_id=0) # Use gpu 0. If negative, will run on CPU.
dqn = A.DQN(env, config=config)
...
nnablaRL has friendly Python APIs that let you start training with only 3 lines of Python code. (NOTE: The code below runs on CPU. See the instructions above to run on GPU.)
import nnabla_rl.algorithms as A
from nnabla_rl.utils.reproductions import build_classic_control_env
# Prerequisite:
# Run below to enable rendering!
# $ pip install nnabla-rl[render]
env = build_classic_control_env("Pendulum-v1", render=True) # 1
ddpg = A.DDPG(env, config=A.DDPGConfig(start_timesteps=200)) # 2
ddpg.train(env) # 3
For more details about nnablaRL, see the documentation and examples.
Most well-known and state-of-the-art deep reinforcement learning algorithms, such as DQN, SAC, BCQ, and GAIL, are implemented in nnablaRL. The implemented algorithms are carefully tested and evaluated, so you can easily start training your agent using these verified implementations.
For the list of implemented algorithms see here.
You can also find the reproduction and evaluation results of each algorithm here.
Note that you may not get exactly the same results when running the reproduction code on your computer. Results may vary slightly depending on your machine, the nnabla and nnabla-rl package versions, etc.
In reinforcement learning, there are two main procedures for training an agent: online and offline. Online training alternates between collecting data and updating the network. Offline training, in contrast, updates the network using only existing data. With nnablaRL, you can switch between these two procedures seamlessly. For example, as shown below, you can train a robot's controller online in a simulated environment and then fine-tune it offline with a dataset from the real robot.
import nnabla_rl
import nnabla_rl.algorithms as A
simulator = get_simulator() # This is just an example. Assuming that simulator exists
dqn = A.DQN(simulator)
# train online for 1M iterations
dqn.train_online(simulator, total_iterations=1000000)
real_data = get_real_robot_data() # This is also an example. Assuming that you have real robot data
# fine tune the agent offline for 10k iterations using real data
dqn.train_offline(real_data, total_iterations=10000)
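To make the difference between the two procedures concrete, here is a minimal, library-free sketch using tabular Q-learning on a toy corridor. It is not how nnablaRL implements `train_online` or `train_offline`; the environment, the function names, and the hyperparameters are all assumptions chosen for illustration. The online loop steps the environment and updates after every step, while the offline loop only samples from a fixed list of recorded transitions.

```python
import random

N_STATES, GOAL = 5, 4
ACTIONS = (-1, +1)  # index 0: move left, index 1: move right


def step(state, action_index):
    """Toy corridor: reward 1 when the goal state is reached."""
    next_state = min(max(state + ACTIONS[action_index], 0), GOAL)
    done = next_state == GOAL
    return next_state, (1.0 if done else 0.0), done


def update(q, transition, lr=0.5, gamma=0.9):
    """One tabular Q-learning update from a single transition."""
    s, a, r, s_next, done = transition
    target = r if done else r + gamma * max(q[s_next])
    q[s][a] += lr * (target - q[s][a])


def train_online(q, total_iterations, rng, epsilon=0.3):
    """Alternate data collection (one environment step) and one update."""
    state = 0
    for _ in range(total_iterations):
        if rng.random() < epsilon:
            action = rng.randrange(len(ACTIONS))
        else:
            action = max(range(len(ACTIONS)), key=lambda a: q[state][a])
        next_state, reward, done = step(state, action)
        update(q, (state, action, reward, next_state, done))
        state = 0 if done else next_state


def train_offline(q, dataset, total_iterations, rng):
    """Update from existing transitions only; the environment is never stepped."""
    for _ in range(total_iterations):
        update(q, rng.choice(dataset))


rng = random.Random(0)
q_table = [[0.0, 0.0] for _ in range(N_STATES)]
train_online(q_table, total_iterations=2000, rng=rng)

# Pre-recorded transitions standing in for the "real robot data".
recorded = []
for s in range(GOAL):
    for a in range(len(ACTIONS)):
        s_next, r, done = step(s, a)
        recorded.append((s, a, r, s_next, done))
train_offline(q_table, recorded, total_iterations=2000, rng=rng)

# Greedy action per non-terminal state; 1 means "move right".
greedy = [max(range(len(ACTIONS)), key=lambda a: q_table[s][a]) for s in range(GOAL)]
print(greedy)
```

Both loops share the same `update` function, which is why switching between them is cheap: only the source of transitions changes.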
nnablaRL supports visualization of training graphs and training progress with nnabla-browser!
import gym
import nnabla_rl.algorithms as A
import nnabla_rl.hooks as H
import nnabla_rl.writers as W
from nnabla_rl.utils.evaluator import EpisodicEvaluator
# save training computational graph
training_graph_hook = H.TrainingGraphHook(outdir="test")
# evaluation hook with nnabla's Monitor
eval_env = gym.make("Pendulum-v1")
evaluator = EpisodicEvaluator(run_per_evaluation=10)
evaluation_hook = H.EvaluationHook(
    eval_env,
    evaluator,
    timing=10,
    writer=W.MonitorWriter(outdir="test", file_prefix='evaluation_result'),
)
env = gym.make("Pendulum-v1")
sac = A.SAC(env)
sac.set_hooks([training_graph_hook, evaluation_hook])
sac.train_online(env, total_iterations=100)
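The hook mechanism used above can be pictured with a small, library-free sketch. It assumes that `timing=10` means "run the hook every 10 iterations"; the `Hook`, `RecordingHook`, and `train` names below are illustrative stand-ins, not nnablaRL's actual classes.

```python
class Hook:
    """Minimal stand-in for a training hook (not nnablaRL's actual class)."""

    def __init__(self, timing):
        self._timing = timing

    def __call__(self, iteration_num):
        # Fire only on iterations that are multiples of `timing`.
        if iteration_num % self._timing == 0:
            self.on_hook_called(iteration_num)

    def on_hook_called(self, iteration_num):
        raise NotImplementedError


class RecordingHook(Hook):
    """Records the iterations at which it fired, in place of a real evaluation."""

    def __init__(self, timing):
        super().__init__(timing)
        self.calls = []

    def on_hook_called(self, iteration_num):
        self.calls.append(iteration_num)


def train(total_iterations, hooks):
    """Toy training loop that invokes every hook after each iteration."""
    for iteration_num in range(1, total_iterations + 1):
        # ... one training step would run here ...
        for hook in hooks:
            hook(iteration_num)


hook = RecordingHook(timing=10)
train(total_iterations=100, hooks=[hook])
print(len(hook.calls))  # 10
```

Under this assumption, the 100-iteration run in the example above would trigger the evaluation hook 10 times.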
Try the interactive demos below to get started.
You can run them directly on Colab from the links in the table below.
Full documentation is here.
Any kind of contribution to nnablaRL is welcome! See the contribution guide for details.
nnablaRL is provided under the Apache License, Version 2.0.