Add Parallel Q-Networks algorithm (PQN) #472

Open
wants to merge 6 commits into master
Conversation

@roger-creus commented Jul 17, 2024

Description

Adding PQN from Simplifying Deep Temporal Difference Learning

I have implemented both pqn.py and pqn_atari_envpool.py. The results are promising for the CartPole version. Check them out here. I am now running some debugging experiments for the Atari version.

Some details about the implementations:

  • Both use envpool
  • Hyperparameters try to match the configs from the official implementation, but some are changed (the epsilon-decay schedule matches the DQN implementation from CleanRL; I haven't checked the importance of this hyperparameter, but the CleanRL defaults made more sense to me).
  • For comparing pqn.py and dqn.py on CartPole, I multiplied the rewards from the environment by 0.1, as done in the official implementation of PQN. Performance increases for both algorithms.
  • Using LayerNorm in the networks instead of letting the user choose between LayerNorm and BatchNorm; LayerNorm should work better (see the sketch after this list).
  • Not giving the user the option to add BatchNorm to the network inputs (i.e., the states), as in the official implementation.
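
For reference, here is a minimal sketch of the kind of LayerNorm MLP Q-network pqn.py builds for CartPole (the class name and layer sizes are illustrative, not the exact ones in the file):

```python
import torch
import torch.nn as nn


class QNetwork(nn.Module):
    """MLP Q-network with LayerNorm after every hidden linear layer."""

    def __init__(self, obs_dim: int, n_actions: int, hidden_dim: int = 120):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),  # normalization is the key ingredient PQN adds over vanilla DQN
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),  # one Q-value per discrete action
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x)
```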

Overall, the implementation is similar to PPO with envpool (so very fast!) but with the sample efficiency of Q-learning. Nice algorithm! :)

Let me know how to proceed!

Types of changes

  • Bug fix
  • New feature
  • New algorithm
  • Documentation

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the tests accordingly (if applicable).
  • I have updated the documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers.

If you need to run benchmark experiments for performance-impacting changes:

  • I have contacted @vwxyzjn to obtain access to the openrlbenchmark W&B team.
  • I have used the benchmark utility to submit the tracked experiments to the openrlbenchmark/cleanrl W&B project, optionally with --capture_video.
  • I have performed RLops with python -m openrlbenchmark.rlops.
    • For new feature or bug fix:
      • I have used the RLops utility to understand the performance impact of the changes and confirmed there is no regression.
    • For new algorithm:
      • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves generated by the python -m openrlbenchmark.rlops utility to the documentation.
    • I have added links to the tracked experiments in W&B, generated by python -m openrlbenchmark.rlops ....your_args... --report, to the documentation.


@sdpkjc (Collaborator) commented Jul 17, 2024

Hey Roger, it's really cool to see you adding PQN to CleanRL! I've read the paper before, and I think your implementation is great. When it comes time to run benchmarks or add documentation, let's collaborate to see how we can best do it. Looking forward to seeing the completed PR! 🚀👍

@roger-creus (Author) commented Jul 17, 2024

I think the code might be ready to be benchmarked. These are some results in Breakout. It seems to converge to a score of 400 in 10M steps, which would match DQN. The official implementation reports a score of 515 after 400M steps. Should I be added to the openrlbenchmark W&B team?

[image: Breakout training curve]

@sdpkjc (Collaborator) commented Jul 18, 2024

I noticed that the epsilon-greedy implementation in our current setup differs from the official one: there, each environment independently performs epsilon-greedy exploration, whereas in our implementation all environments share a single random number. This might have an impact when running many environments in parallel. Of course, there could be other reasons for the performance differences too. Let's start by running some benchmark tests to see if the performance also falls short in other environments. Looking forward to working through this together!

https://github.com/mttga/purejaxql/blob/9878a74439593c5d0acc8e506fefc44daa230c51/purejaxql/pqn_atari.py#L312-L325
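
Something along these lines would let each env draw its own random number (a rough PyTorch sketch, not the exact code in this PR):

```python
import torch


def epsilon_greedy_per_env(q_values: torch.Tensor, epsilon: float) -> torch.Tensor:
    """q_values has shape (num_envs, n_actions); returns one action per env.

    Each environment draws its own uniform sample, so at any step some envs
    can explore while others exploit, as in the official implementation.
    """
    num_envs, n_actions = q_values.shape
    greedy_actions = q_values.argmax(dim=1)
    random_actions = torch.randint(n_actions, (num_envs,), device=q_values.device)
    explore = torch.rand(num_envs, device=q_values.device) < epsilon  # one draw per env
    return torch.where(explore, random_actions, greedy_actions)
```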

… some envs can explore and some exploit, like in the official implementation
@roger-creus (Author)

Very nice catch! Let me try to set up the benchmark experiments :)

@roger-creus (Author)

Here are some first results!
I think they look pretty good, but maybe in BeamRider-v5 it is falling a bit short. Let me double-check the implementation and run some more experiments.

@vwxyzjn (Owner) commented Jul 19, 2024

Been watching this from afar, very cool work!!

@pseudo-rnd-thoughts (Collaborator) commented Jul 19, 2024

Nice job! Your results show it takes 25 minutes for 10 million frames, while the paper reports 200 million frames in an hour.
Do you know why there is such a significant difference in speed?

No equivalent to jax.jit or jax.lax.scan?

@roger-creus (Author)

Updated results here. I wonder how I should generate the DQN/PQN comparison with the rlops function, since I am using envpool and I am not able to compare pqn_atari_envpool.py on Breakout-v5 against dqn_atari.py on BreakoutNoFrameskip-v4, for instance. Should I make a version of PQN that doesn't use envpool?

@pseudo-rnd-thoughts It is probably because of jax.lax.scan. I am not used to coding in JAX, but after searching online, it seems PyTorch does not have a function like scan...

@sdpkjc (Collaborator) commented Jul 20, 2024

Maybe try torch.compile?
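
For example, something like this could compile the hot path (a rough sketch with hypothetical names, not the PR's actual training loop):

```python
import torch
import torch.nn.functional as F


# Hypothetical update step; q_network, optimizer, and the loss stand in for
# the PR's actual training code.
def update_step(q_network, optimizer, obs, actions, returns):
    q_values = q_network(obs).gather(1, actions.unsqueeze(1)).squeeze(1)
    loss = F.mse_loss(q_values, returns)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss


# Compiling the update step may recover part of the gap to jax.jit, although
# PyTorch has no direct analogue of jax.lax.scan for the rollout loop itself.
compiled_update_step = torch.compile(update_step)
```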

@roger-creus (Author)

Hey! How do you think we should proceed?

I believe it will be hard to match the speed of the JAX-based original implementation with this torch implementation, but at least it provides a Q-learning alternative with envpool that matches CleanRL's envpool PPO, which can already be very useful! :)

@roger-creus (Author)

I realized I was re-computing the Q-values for each state in the rollouts when computing the Q(lambda) returns. I now use a values buffer (as the PPO implementations do, actually) and reuse the stored values in the Q(lambda) computation. Performance remains the same and the code is now approximately 150% faster.
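
For reference, the backward recursion now looks roughly like this (tensor names and default values are illustrative, not the exact ones in pqn.py):

```python
import torch


@torch.no_grad()
def compute_q_lambda_returns(rewards, next_dones, next_max_q, gamma=0.99, q_lambda=0.65):
    """Q(lambda) returns from values stored during the rollout (nothing is re-computed).

    rewards, next_dones, next_max_q: shape (num_steps, num_envs).
    next_max_q[t] holds max_a Q(s_{t+1}, a) recorded while collecting the rollout;
    next_dones[t] is 1.0 if the episode terminated at step t.
    """
    num_steps = rewards.shape[0]
    returns = torch.zeros_like(rewards)
    # The recursion bottoms out with a one-step bootstrap at the last step.
    returns[-1] = rewards[-1] + gamma * (1.0 - next_dones[-1]) * next_max_q[-1]
    for t in reversed(range(num_steps - 1)):
        # G_t = r_t + gamma * [(1 - lambda) * max_a Q(s_{t+1}, a) + lambda * G_{t+1}]
        returns[t] = rewards[t] + gamma * (1.0 - next_dones[t]) * (
            (1 - q_lambda) * next_max_q[t] + q_lambda * returns[t + 1]
        )
    return returns
```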

Also, I added pqn_atari_lstm_envpool! First results in the Atari environments show the implementation is correct. Please double-check! :)

Please let me know how we should continue!

@pseudo-rnd-thoughts (Collaborator)

@roger-creus There is a larger issue with EnvPool rollouts and computing the loss function; see #475.
