Skip to content

Latest commit

 

History

History
76 lines (57 loc) · 3.29 KB

README.md

File metadata and controls

76 lines (57 loc) · 3.29 KB

Continual Distillation Learning

PyTorch code for the paper:
Continual Distillation Learning
Qifan Zhang, Yunhui Guo, Yu Xiang

arXiv, Project

Abstract

We study the problem of Continual Distillation Learning (CDL) that considers Knowledge Distillation (KD) in the Continual Learning (CL) setup. A teacher model and a student model need to learn a sequence of tasks, and the knowledge of the teacher model will be distillated to the student in order to improve the student model. We introduce a novel method named CDL-Prompt that leverages prompt- based continual learning models to build the teacher-student model. We investigate how to utilize the prompts of the teacher model in the student model for knowledge distillation, and propose an attention-based prompt mapping scheme to use the teacher prompts for the student. We demonstrate that our method can be applied to different prompt-based continual learning models such as L2P, DualPrompt and CODA-Prompt to improve their performance using powerful teacher models. While recent CL methods focus on prompt learning, we show that our method can be utilized to build efficient CL models using prompt-based knowledge distillation.

Setup

  • Install anaconda: https://www.anaconda.com/distribution/
  • set up conda environment w/ python 3.8, ex: conda create --name coda python=3.8
  • conda activate CDL
  • sh install_requirements.sh
  • NOTE: this framework was tested using torch == 2.0.0 but should work for previous versions

Datasets

Training

All commands should be run under the project root directory. The scripts are set up for 2 GPUs but can be modified for your hardware.

sh experiments/cifar100.sh
sh experiments/imagenet-r.sh

Or you can directly run the run.py and test on ImageNet-R dataset:

python -u run.py --config configs/imnet-r_prompt.yaml --overwrite 0 \
    --learner_type prompt --learner_name CODAPrompt \
    --prompt_param 100 8 0.0 \
    --log_dir Demo_test/coda-p \
    --t_model 'vit_base_patch16_224' \
    --s_model 'vit_small_patch16_224' \
    --random_s 1
  • You can change the learner_name for DualPrompt or L2P.(And change the prompt_param for different learner. Check the imagenet-r.sh)
  • You can adjust the teacher and student's model with --t_model and --s_model.
  • Change the random_s(random seed) for different results.

Results

The results will be saved in the created --log_dir folder, including the models for the teacher and student as well as the final average accuracy for both the teacher and student.

Citation

If you find the method useful in your research, please consider citing:

@misc{lu2024adapting,
    title={Continual Distillation Learning},
    author={Qifan Zhang and Yunhui Guo and Yu Xiang},
    year={2024},
    eprint={2407.13911},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

Acknowledgments

This project is based on the following repositories: