In this work, we propose a new “editing-based” method, Attribute Group Editing (AGE), for few-shot image generation. The basic assumption is that any image is a collection of attributes and that the editing direction for a specific attribute is shared across all categories. AGE examines the internal representation learned in GANs and identifies semantically meaningful directions. Specifically, the class embedding, i.e., the mean vector of the latent codes from a specific category, represents the category-relevant attributes, while the category-irrelevant attributes are learned globally by Sparse Dictionary Learning on the difference between the sample embedding and the class embedding. Given a GAN well trained on seen categories, diverse images of unseen categories can be synthesized by editing the category-irrelevant attributes while keeping the category-relevant attributes unchanged. Without re-training the GAN, AGE not only produces more realistic and diverse images for downstream visual applications with limited data but also achieves controllable image editing with interpretable category-irrelevant directions.
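The decomposition described above can be written compactly. The notation here is ours, introduced only for illustration: $w_i^c$ is the latent code of the $i$-th of $N_c$ images of category $c$, $A \in \mathbb{R}^{D \times K}$ is a dictionary of $K$ editing directions shared by all categories, and $n$ is a sparse coefficient vector.

$$
\bar{w}_c = \frac{1}{N_c} \sum_{i=1}^{N_c} w_i^c, \qquad
w_i^c - \bar{w}_c \approx A\, n_i, \qquad
w_{\text{new}} = \bar{w}_c + A\, n .
$$

The class embedding $\bar{w}_c$ fixes the category-relevant attributes, so varying the sparse code $n$ changes only category-irrelevant attributes. How $n$ is obtained and constrained in practice is defined by the training and inference scripts.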
Comparison between images generated by MatchingGAN, LoFGAN, and AGE on Flowers, Animal Faces, and VGGFaces.
Official implementation of AGE for few-shot image generation. Our code is modified from pSp.
- Linux
- NVIDIA GPU + CUDA CuDNN (CPU may be possible with some modifications, but is not inherently supported)
- Python 3
- Clone this repo:
```shell
git clone https://github.com/UniBester/AGE.git
cd AGE
```
- Dependencies:
We recommend running this repository using Anaconda. All dependencies for defining the environment are provided in `environment/environment.yaml`.
We use pSp to find the latent codes of real images in the latent space of a pretrained StyleGAN generator. First, follow the pSp instructions to train a pSp model, or directly download the pre-trained pSp models we provide.
- You should first download the Animal Faces / Flowers / VGGFaces datasets and organize the file structure as follows:
```
└── data_root
    ├── train
    |   ├── cate-id_sample-id.jpg      # train-img
    |   └── ...                        # ...
    └── valid
        ├── cate-id_sample-id.jpg      # valid-img
        └── ...                        # ...
```
Here, we provide the organized Animal Faces dataset as an example:
```
└── data_root
    ├── train
    |   ├── n02085620_25.JPEG_238_24_392_167.jpg
    |   └── ...
    └── valid
        ├── n02093754_14.JPEG_80_18_239_163.jpg
        └── ...
```
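Because the category id is encoded in each file name, images can be grouped by category with a few lines of Python. This is a minimal sketch assuming, as in both trees above, that the category id is everything before the first underscore; `category_of` and `group_by_category` are hypothetical helpers, not functions from this repository.

```python
from collections import defaultdict
from pathlib import Path


def category_of(filename: str) -> str:
    # Assumed naming scheme: the category id is everything before the first underscore,
    # e.g. "n02085620_25.JPEG_238_24_392_167.jpg" -> "n02085620".
    return Path(filename).name.split("_", 1)[0]


def group_by_category(filenames):
    """Group image file names by category id (hypothetical helper)."""
    groups = defaultdict(list)
    for name in sorted(filenames):
        groups[category_of(name)].append(name)
    return dict(groups)
```

For a real split, pass something like `[p.name for p in Path("/path/to/data_root/train").glob("*.jpg")]`.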
- Currently, we provide configurations for several datasets.
    - Refer to `configs/paths_config.py` to define the necessary data paths and model paths for training and evaluation.
    - Refer to `configs/transforms_config.py` for the transforms defined for each dataset.
    - Finally, refer to `configs/data_configs.py` for the data paths for the train and valid sets as well as the transforms.
- If you wish to experiment with your own dataset, you can simply make the necessary adjustments in
    1. `configs/data_configs.py` to define your data paths.
    2. `configs/transforms_config.py` to define your own data transforms.
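As an illustration of adding your own dataset, the sketch below mirrors the dictionary layout used by pSp's `configs/data_configs.py`, from which this repository is derived. The key names (`transforms`, `train_source_root`, and so on), the dataset name `my_encode`, the paths, and the `MyEncodeTransforms` class are assumptions for illustration; check the actual files in `configs/` before copying them. In the repository, `dataset_paths` and the transforms class would be imported from `configs/paths_config.py` and `configs/transforms_config.py` rather than defined inline.

```python
# Hypothetical paths; in the repository these would live in configs/paths_config.py.
dataset_paths = {
    "my_train": "/path/to/my/data_root/train",
    "my_valid": "/path/to/my/data_root/valid",
}


class MyEncodeTransforms:
    """Hypothetical stand-in for a transforms class from configs/transforms_config.py."""

    def __init__(self, opts):
        self.opts = opts


# Layout assumed from pSp's configs/data_configs.py: one entry per --dataset_type value.
DATASETS = {
    "my_encode": {
        "transforms": MyEncodeTransforms,
        "train_source_root": dataset_paths["my_train"],
        "train_target_root": dataset_paths["my_train"],
        "test_source_root": dataset_paths["my_valid"],
        "test_target_root": dataset_paths["my_valid"],
    },
}
```

Assuming AGE keeps pSp's lookup of `DATASETS[opts.dataset_type]`, such an entry would be selected with `--dataset_type=my_encode` in place of `--dataset_type=af_encode`.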
Before training AGE, first compute the class embedding of each category in both the train and test splits using `tools/get_class_embedding.py`:
```shell
python tools/get_class_embedding.py \
--class_embedding_path=/path/to/save/class/embeddings \
--psp_checkpoint_path=/path/to/pretrained/pSp/checkpoint \
--train_data_path=/path/to/training/data \
--test_batch_size=4 \
--test_workers=4
```
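Following the definition in the introduction, a class embedding is the mean of the pSp latent codes of all images in a category. The sketch below shows only that averaging step, assuming the latent codes have already been computed; the function name `compute_class_embeddings`, its input layout, and how the script actually stores the result at `--class_embedding_path` are assumptions, not the script's interface.

```python
import numpy as np


def compute_class_embeddings(codes, categories):
    """Average latent codes per category (hypothetical helper).

    codes: array of shape (N, ...) with one pSp latent code per image,
           e.g. (N, 18, 512) for W+ codes.
    categories: length-N sequence of category ids, e.g. parsed from file names.
    Returns a dict mapping category id -> mean latent code.
    """
    codes = np.asarray(codes)
    categories = np.asarray(categories)
    return {
        str(cat): codes[categories == cat].mean(axis=0)
        for cat in np.unique(categories)
    }
```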
The main training script can be found in `tools/train.py`.
Intermediate training results, including checkpoints, train outputs, and test outputs, are saved to `opts.exp_dir`.
Additionally, if you have TensorBoard installed, you can visualize the TensorBoard logs in `opts.exp_dir/logs`.
```shell
# Set the GPUs to use.
export CUDA_VISIBLE_DEVICES=0,1,2,3
# Begin training.
python -m torch.distributed.launch \
--nproc_per_node=4 \
tools/train.py \
--dataset_type=af_encode \
--exp_dir=/path/to/experiment/output \
--workers=8 \
--batch_size=8 \
--valid_batch_size=8 \
--valid_workers=8 \
--val_interval=2500 \
--save_interval=5000 \
--start_from_latent_avg \
--l2_lambda=1 \
--sparse_lambda=0.005 \
--orthogonal_lambda=0.0005 \
--A_length=100 \
--psp_checkpoint_path=/path/to/pretrained/pSp/checkpoint \
--class_embedding_path=/path/to/class/embeddings
```
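The weights `--l2_lambda`, `--sparse_lambda`, and `--orthogonal_lambda`, together with `--A_length`, suggest an objective of roughly the following shape for the dictionary-learning part. This is a hypothetical sketch for intuition only, assuming a squared reconstruction term on the residual (latent code minus class embedding), an L1 penalty on the coefficients, and a Frobenius orthogonality penalty on the dictionary; the actual losses in `models/age.py` and `criteria/` may use different forms and additional terms.

```python
import numpy as np


def dictionary_loss(residual, A, n, l2_lambda=1.0, sparse_lambda=0.005,
                    orthogonal_lambda=0.0005):
    """Hypothetical sparse-dictionary objective for one sample.

    residual: (D,) difference between a sample's latent code and its class embedding.
    A: (D, K) dictionary of K editing directions (K plays the role of --A_length).
    n: (K,) coefficients selecting directions for this sample.
    """
    K = A.shape[1]
    # Reconstruct the category-irrelevant residual from the dictionary.
    recon = l2_lambda * np.sum((residual - A @ n) ** 2)
    # L1 penalty: only a few attribute directions should be active per sample.
    sparse = sparse_lambda * np.sum(np.abs(n))
    # Push the directions toward orthonormality so they stay distinct.
    ortho = orthogonal_lambda * np.sum((A.T @ A - np.eye(K)) ** 2)
    return recon + sparse + ortho
```

With an orthonormal `A` and a perfectly reconstructed residual, only the sparsity term remains, which is the trade-off `--sparse_lambda` is assumed to control.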
Having trained your model, or using the pre-trained models we provide, you can use `tools/inference.py` to apply the model to a set of images.
For example:
```shell
python tools/inference.py \
--output_path=/path/to/output \
--checkpoint_path=/path/to/checkpoint \
--test_data_path=/path/to/test/input \
--train_data_path=/path/to/training/data \
--class_embedding_path=/path/to/class/embeddings \
--n_distribution_path=/path/to/save/n/distribution \
--test_batch_size=4 \
--test_workers=4 \
--n_images=5 \
--alpha=1 \
--beta=0.005
```
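At inference time, new images of an unseen category come from keeping its class embedding fixed and varying the category-irrelevant coefficients. One plausible way to do this, sketched below, is to fit a distribution to coefficients collected from training images (the role we assume for `--n_distribution_path`) and sample from it `--n_images` times. The per-dimension Gaussian, the function names, and the array shapes are assumptions; the sketch does not model `--alpha` or `--beta`, whose meaning is defined in `tools/inference.py`.

```python
import numpy as np


def fit_coefficient_distribution(coeffs):
    """Fit a per-dimension Gaussian to coefficient vectors (hypothetical helper).

    coeffs: (M, K) array, one coefficient vector n per training image.
    Returns (mean, std), each of shape (K,).
    """
    coeffs = np.asarray(coeffs)
    return coeffs.mean(axis=0), coeffs.std(axis=0)


def sample_latents(class_embedding, A, mean, std, n_images, rng):
    """Draw n_images latent codes for one (possibly unseen) category.

    The class embedding stays fixed; only the category-irrelevant part A @ n varies.
    class_embedding: (D,), A: (D, K), mean/std: (K,).
    """
    n = rng.normal(mean, std, size=(n_images, mean.shape[0]))  # (n_images, K)
    return class_embedding + n @ A.T                          # (n_images, D)
```

For example, `sample_latents(class_embedding, A, *fit_coefficient_distribution(train_coeffs), n_images=5, rng=np.random.default_rng(0))` returns five latent codes that would then be decoded by the StyleGAN generator.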
Path | Description |
---|---|
AGE | Repository root folder |
├ configs | Folder containing configs defining model/data paths and data transforms |
├ criteria | Folder containing various loss criteria for training |
├ datasets | Folder with various dataset objects and augmentations |
├ environment | Folder containing Anaconda environment used in our experiments |
├ models | Folder containing all the models and training objects |
│ ├ encoders | Folder containing our pSp encoder architecture implementation and ArcFace encoder implementation from TreB1eN |
│ ├ stylegan2 | StyleGAN2 model from rosinality |
│ └ age.py | Implementation of our AGE |
├ options | Folder with training and test command-line options |
├ tools | Folder with running scripts for training and inference |
├ optimizer | Folder with Ranger implementation from lessw2020 |
└ utils | Folder with various utility functions |
If you use this code for your research, please cite our paper Attribute Group Editing for Reliable Few-shot Image Generation:
```bibtex
@inproceedings{ding2022attribute,
  title={Attribute Group Editing for Reliable Few-shot Image Generation},
  author={Ding, Guanqi and Han, Xinzhe and Wang, Shuhui and Wu, Shuzhe and Jin, Xin and Tu, Dandan and Huang, Qingming},
  booktitle=CVPR,
  year={2022},
}
```