This repo contains an implementation of the EDA algorithm from
Aligning Diffusion Behaviors with Q-functions for Efficient Continuous Control
Huayu Chen, Kaiwen Zheng, Hang Su, Jun Zhu
Tsinghua University
EDA lets you directly finetune a diffusion behavior policy, learned through imitation learning, into an RL-optimized policy (much like LLM alignment).
EDA leverages a specially designed diffusion model architecture that can calculate action likelihoods in a single forward step.
PyTorch, MuJoCo, and D4RL need to be installed.
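A minimal environment setup might look like the following. This is only a sketch: the repo does not pin package sources or versions, and D4RL additionally expects the MuJoCo 2.1 binaries under ~/.mujoco.

pip install torch
pip install "mujoco-py<2.2,>=2.1"
pip install git+https://github.com/Farama-Foundation/D4RL@master#egg=d4rl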
Before performing alignment, we first need to pretrain the bottleneck behavior model and the critic (Q) model. Run, respectively:
TASK="halfcheetah-medium-v2"; seed=0; python3 -u train_behavior.py --expid ${TASK}-baseline-seed${seed} --env $TASK --seed ${seed}
TASK="halfcheetah-medium-v2"; seed=0; python3 -u train_critic.py --expid ${TASK}-baseline-seed${seed} --env $TASK --seed ${seed}
Finally, run
TASK="halfcheetah-medium-v2"; seed=0; python3 -u finetune_policy.py --expid ${TASK}-baseline-seed${seed} --env $TASK --seed ${seed} --actor_load_path ./EDA_model_factory/${TASK}-baseline-seed${seed}/behavior_ckpt200.pth --critic_load_path ./EDA_model_factory/${TASK}-baseline-seed${seed}/critic_ckpt150.pth --beta=0.1
For default choices of beta, please check out Appendix E of the paper.
If you find our project helpful, please consider citing:
@article{chen2024aligning,
title={Aligning Diffusion Behaviors with Q-functions for Efficient Continuous Control},
author={Chen, Huayu and Zheng, Kaiwen and Su, Hang and Zhu, Jun},
journal={arXiv preprint arXiv:2407.09024},
year={2024}
}
This project is licensed under the MIT License.