Skip to content

Commit

Permalink
Updated optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemelas committed Apr 11, 2021
1 parent 046793b commit b5e26b3
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 12 deletions.
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@ The pretrained weights for most GANs are downloaded automatically. For those tha

There are also some standard dependencies:
- PyTorch (tested on version 1.7.1, but should work on any version)
- Hydra 1.1
- Albumentations, Retry
- [Optional] Weights and Biases
- [Hydra](https://github.com/facebookresearch/hydra) 1.1
- [Albumentations](https://github.com/albumentations-team/albumentations), [Kornia](https://github.com/kornia/kornia), [Retry](https://github.com/invl/retry)
- [Optional] [Weights and Biases](https://wandb.ai/)

Install them with:
```bash
pip install hydra-core --pre
pip install albumentations tqdm retry tensorboard
pip install hydra-core==1.1.0dev5 albumentations tqdm retry kornia
```


Expand Down Expand Up @@ -151,7 +150,7 @@ _Note:_ All commands are called from within the `src` directory.

In the example commands below, we use BigBiGAN. You can easily switch out BigBiGAN for another model if you would like to.

**Optimizaion**
**Optimization**
```bash
PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME
```
Expand Down
1 change: 0 additions & 1 deletion src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .unet_model import UNet
from .ensemble import Ensemble
9 changes: 4 additions & 5 deletions src/optimization/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from models.latent_shift_model import MODELS
from optimization import utils

# from pytorch_pretrained_gans import make_gan
from pytorch_pretrained_gans import make_gan


class UnsupervisedSegmentationLoss(torch.nn.Module):
Expand Down Expand Up @@ -65,7 +65,7 @@ def run(cfg: DictConfig):
device = torch.device('cuda')

# Load GAN
G = make_gan(gan_type=cfg.generator.gan_type, **cfg.generator.kwargs)
G = make_gan(gan_type=cfg.data_gen.generator.gan_type, **cfg.data_gen.generator.kwargs)
G.eval().to(device)
utils.set_requires_grad(G, False)

Expand All @@ -86,7 +86,7 @@ def run(cfg: DictConfig):
scheduler = Scheduler(optimizer, **cfg.scheduler.kwargs)

# Loss function
criterion = UnsupervisedSegmentationLoss(cfg.losses, image_size=cfg.generator.image_size)
criterion = UnsupervisedSegmentationLoss(cfg.losses, image_size=cfg.data_gen.generator.image_size)
criterion.to(device)

# Fixed vectors for visualization
Expand Down Expand Up @@ -138,8 +138,7 @@ def run(cfg: DictConfig):
# Visualize with Tensorboard and save to file
if i % cfg.vis_every == 0:
img_grid = utils.create_grid(G=G, model=model, zs=z_vis_fixed, ys=y_vis_fixed, n_imgs=8)
img_file = "images" / f"{i}.png"
img_file.parent.mkdir(parents=True, exist_ok=True)
img_file = f"visualization-{i:05d}.png"
img_grid = (img_grid * 255).astype(np.uint8)
Image.fromarray(img_grid).save(str(img_file))

Expand Down
2 changes: 2 additions & 0 deletions src/optimization/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
import random
import numpy as np
import torch
Expand Down Expand Up @@ -30,6 +31,7 @@ def create_grid(G, model, zs, ys, n_imgs=8, rs=[0, 2, 4, 8]):


def save_checkpoint(model, checkpoint_dir='.', name='latest.pth', **kwargs):
checkpoint_dir = Path(checkpoint_dir)
checkpoint_dir.mkdir(exist_ok=True)
torch.save(dict(
state_dict=model.state_dict(),
Expand Down

0 comments on commit b5e26b3

Please sign in to comment.