diff --git a/README.md b/README.md
index 04d41cf..748cf0f 100644
--- a/README.md
+++ b/README.md
@@ -24,14 +24,13 @@ The pretrained weights for most GANs are downloaded automatically. For those tha
 
 There are also some standard dependencies:
  - PyTorch (tested on version 1.7.1, but should work on any version)
- - Hydra 1.1
- - Albumentations, Retry
- - [Optional] Weights and Biases
+ - [Hydra](https://github.com/facebookresearch/hydra) 1.1
+ - [Albumentations](https://github.com/albumentations-team/albumentations), [Kornia](https://github.com/kornia/kornia), [Retry](https://github.com/invl/retry)
+ - [Optional] [Weights and Biases](https://wandb.ai/)
 
 Install them with:
 ```bash
-pip install hydra-core --pre
-pip install albumentations tqdm retry tensorboard
+pip install hydra-core==1.1.0dev5 albumentations tqdm retry kornia
 ```
 
 
@@ -151,7 +150,7 @@ _Note:_ All commands are called from within the `src` directory.
 
 In the example commands below, we use BigBiGAN. You can easily switch out BigBiGAN for another model if you would like to.
 
-**Optimizaion**
+**Optimization**
 ```bash
 PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME
 ```
diff --git a/src/models/__init__.py b/src/models/__init__.py
index 41f3fb0..2e9b63b 100644
--- a/src/models/__init__.py
+++ b/src/models/__init__.py
@@ -1,2 +1 @@
 from .unet_model import UNet
-from .ensemble import Ensemble
diff --git a/src/optimization/main.py b/src/optimization/main.py
index 8b6f87c..a8c98e4 100644
--- a/src/optimization/main.py
+++ b/src/optimization/main.py
@@ -11,7 +11,7 @@
 from models.latent_shift_model import MODELS
 from optimization import utils
 
-# from pytorch_pretrained_gans import make_gan
+from pytorch_pretrained_gans import make_gan
 
 
 class UnsupervisedSegmentationLoss(torch.nn.Module):
@@ -65,7 +65,7 @@ def run(cfg: DictConfig):
     device = torch.device('cuda')
 
     # Load GAN
-    G = make_gan(gan_type=cfg.generator.gan_type, **cfg.generator.kwargs)
+    G = make_gan(gan_type=cfg.data_gen.generator.gan_type, **cfg.data_gen.generator.kwargs)
     G.eval().to(device)
     utils.set_requires_grad(G, False)
 
@@ -86,7 +86,7 @@ def run(cfg: DictConfig):
     scheduler = Scheduler(optimizer, **cfg.scheduler.kwargs)
 
     # Loss function
-    criterion = UnsupervisedSegmentationLoss(cfg.losses, image_size=cfg.generator.image_size)
+    criterion = UnsupervisedSegmentationLoss(cfg.losses, image_size=cfg.data_gen.generator.image_size)
     criterion.to(device)
 
     # Fixed vectors for visualization
@@ -138,8 +138,7 @@ def run(cfg: DictConfig):
         # Visualize with Tensorboard and save to file
         if i % cfg.vis_every == 0:
             img_grid = utils.create_grid(G=G, model=model, zs=z_vis_fixed, ys=y_vis_fixed, n_imgs=8)
-            img_file = "images" / f"{i}.png"
-            img_file.parent.mkdir(parents=True, exist_ok=True)
+            img_file = f"visualization-{i:05d}.png"
             img_grid = (img_grid * 255).astype(np.uint8)
             Image.fromarray(img_grid).save(str(img_file))
 
diff --git a/src/optimization/utils.py b/src/optimization/utils.py
index 8c2ca33..79a1659 100644
--- a/src/optimization/utils.py
+++ b/src/optimization/utils.py
@@ -1,3 +1,4 @@
+from pathlib import Path
 import random
 import numpy as np
 import torch
@@ -30,6 +31,7 @@ def create_grid(G, model, zs, ys, n_imgs=8, rs=[0, 2, 4, 8]):
 
 
 def save_checkpoint(model, checkpoint_dir='.', name='latest.pth', **kwargs):
+    checkpoint_dir = Path(checkpoint_dir)
     checkpoint_dir.mkdir(exist_ok=True)
     torch.save(dict(
         state_dict=model.state_dict(),
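
A quick illustration of the `pathlib` change to `save_checkpoint` in `src/optimization/utils.py`: a plain `str` has no `.mkdir()` method, so the function would fail whenever `checkpoint_dir` was passed as a string (including the default `'.'`); converting to `Path` first lets both `str` and `Path` arguments work. The sketch below is a minimal standalone version of that pattern, not the repo's full function: the save destination (`checkpoint_dir / name`) and the toy `Linear` model are assumptions for illustration only.

```python
from pathlib import Path

import torch


def save_checkpoint(model, checkpoint_dir='.', name='latest.pth', **kwargs):
    # A plain str has no .mkdir(), so normalize to Path first (the fix applied in the diff).
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(exist_ok=True)
    # Assumed destination and payload: model weights plus any extra metadata passed as kwargs.
    torch.save(dict(state_dict=model.state_dict(), **kwargs), checkpoint_dir / name)


if __name__ == '__main__':
    # Toy model just to exercise the function; writes checkpoints/latest.pth.
    save_checkpoint(torch.nn.Linear(4, 2), checkpoint_dir='checkpoints', step=0)
```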