
Commit

Small tweaks, doing test run after refactor
lukemelas committed Apr 12, 2021
1 parent b5e26b3 commit d72e371
Showing 5 changed files with 14 additions and 9 deletions.
12 changes: 8 additions & 4 deletions README.md
@@ -24,13 +24,16 @@ The pretrained weights for most GANs are downloaded automatically. For those tha

There are also some standard dependencies:
- PyTorch (tested on version 1.7.1, but should work on any version)
+- [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
- [Hydra](https://github.com/facebookresearch/hydra) 1.1
-- [Albumentations](https://github.com/albumentations-team/albumentations), [Kornia](https://github.com/kornia/kornia), [Retry](https://github.com/invl/retry)
+- [Albumentations](https://github.com/albumentations-team/albumentations)
+- [Kornia](https://github.com/kornia/kornia)
+- [Retry](https://github.com/invl/retry)
- [Optional] [Weights and Biases](https://wandb.ai/)

Install them with:
```bash
-pip install hydra-core==1.1.0dev5 albumentations tqdm retry kornia
+pip install hydra-core==1.1.0dev5 pytorch_lightning albumentations tqdm retry kornia
```


@@ -154,7 +157,7 @@ In the example commands below, we use BigBiGAN. You can easily switch out BigBiG
```bash
PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME
```
-The output will be saved in `outputs/optimization/fixed-BigBiGAN-NAME/DATE/`, with the final checkpoint in `latest.pth`.
+This should take less than 5 minutes to run. The output will be saved in `outputs/optimization/fixed-BigBiGAN-NAME/DATE/`, with the final checkpoint in `latest.pth`.
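If you want to sanity-check the result, the checkpoint can be opened directly from Python. This is a hypothetical sketch rather than repository code: it assumes `latest.pth` is a standard `torch.save` artifact, and `NAME`/`DATE` must be replaced with your actual run name and timestamp.
```python
import torch

# Hypothetical inspection of the optimization checkpoint; the contents of
# latest.pth are an assumption -- adapt to whatever optimization/main.py saves.
ckpt_path = "outputs/optimization/fixed-BigBiGAN-NAME/DATE/latest.pth"
ckpt = torch.load(ckpt_path, map_location="cpu")
if isinstance(ckpt, dict):
    print("checkpoint keys:", list(ckpt.keys()))
else:
    print("checkpoint object of type:", type(ckpt))
```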

**Segmentation with precomputed generations**

@@ -170,7 +173,7 @@ data_gen.save_size=1000000 \
data_gen.kwargs.batch_size=1 \
data_gen.kwargs.generation_batch_size=128
```
-This will generate 1 million image-label pairs and save them to `YOUR_OUTPUT_DIR/images`. Note that `YOUR_OUTPUT_DIR` should be an _absolute path_, not a relative one, because Hydra changes the working directory. You may also want to tune the `generation_batch_size` to maximize GPU utilization on your machine.
+This will generate 1 million image-label pairs and save them to `YOUR_OUTPUT_DIR/images`. Note that `YOUR_OUTPUT_DIR` should be an _absolute path_, not a relative one, because Hydra changes the working directory. You may also want to tune the `generation_batch_size` to maximize GPU utilization on your machine. It takes around 3-4 hours to generate 1 million images on a single V100 GPU.
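The absolute-path requirement comes from Hydra changing the process working directory into a per-run output folder before your code executes. The snippet below is a minimal standalone sketch (not part of this repo) illustrating that behavior and the `hydra.utils.to_absolute_path` helper, which resolves a path relative to the original launch directory:
```python
import os

import hydra
from omegaconf import DictConfig


@hydra.main(config_path=None, config_name=None)
def run(cfg: DictConfig) -> None:
    # By the time this runs, Hydra has chdir'd into a per-run directory,
    # so a relative "images" path would land under outputs/..., not under
    # the directory the command was launched from.
    print("cwd inside the run:  ", os.getcwd())
    print("resolved launch path:", hydra.utils.to_absolute_path("images"))


if __name__ == "__main__":
    run()
```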

Once you have generated data, you can train a segmentation model:
```bash
@@ -179,6 +182,7 @@ name=NAME \
data_gen=saved \
data_gen.data.root="YOUR_OUTPUT_DIR_FROM_ABOVE"
```
+It takes around 3 hours on 1 GPU to complete 18000 iterations, by which point the model has converged (in fact you can probably get away with fewer steps, I would guess around ~5000).

**Segmentation with on-the-fly generations**

2 changes: 1 addition & 1 deletion src/config/segment.yaml
@@ -57,7 +57,7 @@ dataloader:
trainer:
  # See https://pytorch-lightning.readthedocs.io/en/stable/trainer.html#trainer-flags
  gpus: 1
-  max_steps: 12000
+  max_steps: 18000
  accelerator: null # "ddp_spawn"
  num_sanity_val_steps: 5
  fast_dev_run: False
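For context, this trainer block is unpacked into the PyTorch Lightning `Trainer` (see `trainer = pl.Trainer(..., **cfg.trainer)` in `src/segmentation/main.py` below). A rough, assumed illustration of the resulting call with Lightning 1.x flags:
```python
import pytorch_lightning as pl

# Assumed equivalent of the segment.yaml trainer block above (Lightning 1.x flags).
trainer = pl.Trainer(
    gpus=1,
    max_steps=18000,          # raised from 12000 in this commit
    accelerator=None,         # or "ddp_spawn" for multi-process training
    num_sanity_val_steps=5,
    fast_dev_run=False,
)
```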
1 change: 1 addition & 0 deletions src/models/__init__.py
@@ -1 +1,2 @@
from .unet_model import UNet
+from .latent_shift_model import MODELS
2 changes: 1 addition & 1 deletion src/segmentation/generate_and_save.py
@@ -46,7 +46,7 @@ def run(cfg: DictConfig):
    for output_dict in output_batch:
        img = tensor_to_image(output_dict['img'])
        mask = tensor_to_mask(output_dict['mask'])
-        y = int(output_dict['y'])
+        y = int(output_dict['y']) if 'y' in output_dict else 0
        stem = f'{i:08d}-seed_{cfg.seed}-class_{y:03d}'
        img.save(save_dir / f'{stem}.jpg')
        mask.save(save_dir / f'{stem}.png')
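With this fallback, samples whose `output_dict` has no `'y'` entry are simply filed under class `000`. A tiny illustrative check of the filename stem format (the values here are made up):
```python
# Illustrative only: the stem format used by generate_and_save.py, with the
# class id falling back to 0 when no label is present.
i, seed, y = 42, 0, 0
stem = f"{i:08d}-seed_{seed}-class_{y:03d}"
print(stem)  # -> 00000042-seed_0-class_000
```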
6 changes: 3 additions & 3 deletions src/segmentation/main.py
@@ -8,8 +8,8 @@
import logging
import hydra

-from . import utils
-from . import metrics
+from segmentation import utils
+from segmentation import metrics
from models import UNet
from datasets import SegmentationDataset, create_gan_dataset, create_train_and_val_dataloaders
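The switch from relative to package-absolute imports matters when `main.py` is executed directly as a script (as in the README's `PYTHONPATH=. python ...` commands): a top-level script has no parent package, so `from . import utils` raises an `ImportError`, whereas the absolute form works as long as the directory containing the `segmentation/` and `models/` packages is on the import path. A minimal sketch, assuming the repo's package root is the current directory:
```python
import sys

# Same effect as launching with PYTHONPATH=. from the package root
# (an assumption about the intended launch directory).
sys.path.insert(0, ".")

from segmentation import utils, metrics  # resolve as top-level packages
from models import UNet
```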

@@ -189,7 +189,7 @@ def main(cfg: DictConfig):
    ]

    # Logging
-    logger = pl.loggers.WandbLogger(name=cfg.name) if cfg.wanbd else True
+    logger = pl.loggers.WandbLogger(name=cfg.name) if cfg.wandb else True

    # Trainer
    trainer = pl.Trainer(logger=logger, callbacks=callbacks, **cfg.trainer)
