Skip to content

Commit

Permalink
fix no eval bug (#790)
Browse files Browse the repository at this point in the history
  • Loading branch information
firestonelib authored Sep 23, 2022
1 parent 14890de commit 13b4341
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 95 deletions.
2 changes: 2 additions & 0 deletions ppfleetx/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def build_dataset(config, mode):

def build_dataloader(config, mode):
dataset = build_dataset(config, mode)
if dataset is None:
return None

batch_sampler = None
# build sampler
Expand Down
4 changes: 2 additions & 2 deletions ppfleetx/data/dataset/multimodal_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ def data_augmentation_for_imagen(img, resolution):
arr = deepcopy(img)
while min(*arr.size) >= 2 * resolution:
arr = arr.resize(
tuple(x // 2 for x in arr.size), resample=Image.Resampling.BOX)
tuple(x // 2 for x in arr.size), resample=Image.BOX)
scale = resolution / min(*arr.size)
arr = arr.resize(
tuple(round(x * scale) for x in arr.size),
resample=Image.Resampling.BICUBIC)
resample=Image.BICUBIC)

arr = np.array(arr.convert("RGB"))
crop_y = (arr.shape[0] - resolution) // 2
Expand Down
52 changes: 20 additions & 32 deletions ppfleetx/models/multimodal_model/imagen/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import paddle.vision.transforms as T

from .unet import Unet
from .utils import (GaussianDiffusionContinuousTimes, default, exists,
cast_tuple, first, maybe, eval_decorator, identity,
from .utils import (GaussianDiffusionContinuousTimes, default, cast_tuple,
first, maybe, eval_decorator, identity,
pad_tuple_to_length, right_pad_dims_to, resize_image_to,
normalize_neg_one_to_one, rearrange, repeat, reduce,
unnormalize_zero_to_one, cast_uint8_images_to_float)
Expand Down Expand Up @@ -195,9 +195,8 @@ def __init__(self,
# randomly cropping for upsampler training

self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
assert not exists(
first(self.random_crop_sizes)
), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'
assert first(
self.random_crop_sizes) is None, 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'
# lowres augmentation noise schedule

self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(
Expand Down Expand Up @@ -284,22 +283,17 @@ def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1

if isinstance(self.unets, nn.LayerList):
unets_list = [unet for unet in self.unets]
delattr(self, 'unets')
self.unets = unets_list
self.unet_being_trained_index = index
return self.unets[index]

def reset_unets(self, ):
self.unets = nn.LayerList([*self.unets])
self.unet_being_trained_index = -1

@contextmanager
def one_unet_in_gpu(self, unet_number=None, unet=None):
assert exists(unet_number) ^ exists(unet)
assert (unet_number is not None) ^ (unet is not None)

if exists(unet_number):
if unet_number is not None:
unet = self.unets[unet_number - 1]

yield
Expand All @@ -320,7 +314,6 @@ def p_mean_variance(self,
unet,
x,
t,
*,
noise_scheduler,
text_embeds=None,
text_mask=None,
Expand Down Expand Up @@ -370,7 +363,6 @@ def p_sample(self,
unet,
x,
t,
*,
noise_scheduler,
t_next=None,
text_embeds=None,
Expand Down Expand Up @@ -412,7 +404,6 @@ def p_sample(self,
def p_sample_loop(self,
unet,
shape,
*,
noise_scheduler,
lowres_cond_img=None,
lowres_noise_times=None,
Expand All @@ -433,7 +424,7 @@ def p_sample_loop(self,

# prepare inpainting

has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
has_inpainting = inpaint_images is not None and inpaint_masks is not None
resample_times = inpaint_resample_times if has_inpainting else 1

if has_inpainting:
Expand Down Expand Up @@ -532,18 +523,18 @@ def sample(
batch_size = text_embeds.shape[0]

assert not (
self.condition_on_text and not exists(text_embeds)
self.condition_on_text and text_embeds is None
), 'text or text encodings must be passed into imagen if specified'
assert not (
not self.condition_on_text and exists(text_embeds)
not self.condition_on_text and text_embeds is not None
), 'imagen specified not to be conditioned on text, yet it is presented'
assert not (
exists(text_embeds) and
text_embeds is not None and
text_embeds.shape[-1] != self.text_embed_dim
), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

assert not (
exists(inpaint_images) ^ exists(inpaint_masks)
(inpaint_images is not None) ^ (inpaint_masks is not None)
), 'inpaint images and masks must be both passed in to do inpainting'

outputs = []
Expand Down Expand Up @@ -609,8 +600,7 @@ def sample(

outputs.append(img)

if exists(stop_at_unet_number
) and stop_at_unet_number == unet_number:
if stop_at_unet_number is not None and stop_at_unet_number == unet_number:
break

output_index = -1 if not return_all_unet_outputs else slice(
Expand All @@ -633,7 +623,6 @@ def p_losses(self,
unet,
x_start,
times,
*,
noise_scheduler,
lowres_cond_img=None,
lowres_aug_times=None,
Expand All @@ -655,7 +644,7 @@ def p_losses(self,
# random cropping during training
# for upsamplers

if exists(random_crop_size):
if random_crop_size is not None:
aug = K.RandomCrop((random_crop_size, random_crop_size), p=1.)
# make sure low res conditioner and image both get augmented the same way
# detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
Expand All @@ -672,7 +661,7 @@ def p_losses(self,
# at sample time, they then fix the noise level of 0.1 - 0.3

lowres_cond_img_noisy = None
if exists(lowres_cond_img):
if lowres_cond_img is not None:
lowres_aug_times = default(lowres_aug_times, times)
lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(
x_start=lowres_cond_img,
Expand Down Expand Up @@ -715,11 +704,10 @@ def forward(self,
assert images.shape[-1] == images.shape[
-2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}'
assert not (
len(self.unets) > 1 and not exists(unet_number)
len(self.unets) > 1 and unet_number is None
), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
assert not exists(
self.only_train_unet_number
assert (self.only_train_unet_number is None
) or self.only_train_unet_number == unet_number, 'you can only train on unet #{self.only_train_unet_number}'

images = cast_uint8_images_to_float(images)
Expand Down Expand Up @@ -748,19 +736,19 @@ def forward(self,
text_masks, lambda: paddle.any(text_embeds != 0., axis=-1))

assert not (
self.condition_on_text and not exists(text_embeds)
self.condition_on_text and text_embeds is None
), 'text or text encodings must be passed into decoder if specified'
assert not (
not self.condition_on_text and exists(text_embeds)
not self.condition_on_text and text_embeds is not None
), 'decoder specified not to be conditioned on text, yet it is presented'

assert not (
exists(text_embeds) and
(text_embeds is not None) and
text_embeds.shape[-1] != self.text_embed_dim
), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

lowres_cond_img = lowres_aug_times = None
if exists(prev_image_size):
if prev_image_size is not None:
lowres_cond_img = resize_image_to(
images, prev_image_size, clamp_range=self.input_image_range)
lowres_cond_img = resize_image_to(
Expand Down
Loading

0 comments on commit 13b4341

Please sign in to comment.