diff --git a/ppfleetx/data/__init__.py b/ppfleetx/data/__init__.py index fc6607b5a..44a9ae2f9 100644 --- a/ppfleetx/data/__init__.py +++ b/ppfleetx/data/__init__.py @@ -58,6 +58,8 @@ def build_dataset(config, mode): def build_dataloader(config, mode): dataset = build_dataset(config, mode) + if dataset is None: + return None batch_sampler = None # build sampler diff --git a/ppfleetx/data/dataset/multimodal_dataset.py b/ppfleetx/data/dataset/multimodal_dataset.py index 1b3a7c8b3..9e42ff261 100644 --- a/ppfleetx/data/dataset/multimodal_dataset.py +++ b/ppfleetx/data/dataset/multimodal_dataset.py @@ -78,11 +78,11 @@ def data_augmentation_for_imagen(img, resolution): arr = deepcopy(img) while min(*arr.size) >= 2 * resolution: arr = arr.resize( - tuple(x // 2 for x in arr.size), resample=Image.Resampling.BOX) + tuple(x // 2 for x in arr.size), resample=Image.BOX) scale = resolution / min(*arr.size) arr = arr.resize( tuple(round(x * scale) for x in arr.size), - resample=Image.Resampling.BICUBIC) + resample=Image.BICUBIC) arr = np.array(arr.convert("RGB")) crop_y = (arr.shape[0] - resolution) // 2 diff --git a/ppfleetx/models/multimodal_model/imagen/modeling.py b/ppfleetx/models/multimodal_model/imagen/modeling.py index 4a8823225..ced95f633 100644 --- a/ppfleetx/models/multimodal_model/imagen/modeling.py +++ b/ppfleetx/models/multimodal_model/imagen/modeling.py @@ -21,8 +21,8 @@ import paddle.vision.transforms as T from .unet import Unet -from .utils import (GaussianDiffusionContinuousTimes, default, exists, - cast_tuple, first, maybe, eval_decorator, identity, +from .utils import (GaussianDiffusionContinuousTimes, default, cast_tuple, + first, maybe, eval_decorator, identity, pad_tuple_to_length, right_pad_dims_to, resize_image_to, normalize_neg_one_to_one, rearrange, repeat, reduce, unnormalize_zero_to_one, cast_uint8_images_to_float) @@ -195,9 +195,8 @@ def __init__(self, # randomly cropping for upsampler training self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets) - assert not exists( - first(self.random_crop_sizes) - ), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example' + assert first( + self.random_crop_sizes) is None, 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example' # lowres augmentation noise schedule self.lowres_noise_schedule = GaussianDiffusionContinuousTimes( @@ -284,22 +283,17 @@ def get_unet(self, unet_number): assert 0 < unet_number <= len(self.unets) index = unet_number - 1 - if isinstance(self.unets, nn.LayerList): - unets_list = [unet for unet in self.unets] - delattr(self, 'unets') - self.unets = unets_list self.unet_being_trained_index = index return self.unets[index] def reset_unets(self, ): - self.unets = nn.LayerList([*self.unets]) self.unet_being_trained_index = -1 @contextmanager def one_unet_in_gpu(self, unet_number=None, unet=None): - assert exists(unet_number) ^ exists(unet) + assert (unet_number is not None) ^ (unet is not None) - if exists(unet_number): + if unet_number is not None: unet = self.unets[unet_number - 1] yield @@ -320,7 +314,6 @@ def p_mean_variance(self, unet, x, t, - *, noise_scheduler, text_embeds=None, text_mask=None, @@ -370,7 +363,6 @@ def p_sample(self, unet, x, t, - *, noise_scheduler, t_next=None, text_embeds=None, @@ -412,7 +404,6 @@ def p_sample(self, def p_sample_loop(self, unet, shape, - *, noise_scheduler, lowres_cond_img=None, lowres_noise_times=None, @@ -433,7 +424,7 @@ def p_sample_loop(self, # prepare inpainting - has_inpainting = exists(inpaint_images) and exists(inpaint_masks) + has_inpainting = inpaint_images is not None and inpaint_masks is not None resample_times = inpaint_resample_times if has_inpainting else 1 if has_inpainting: @@ -532,18 +523,18 @@ def sample( batch_size = text_embeds.shape[0] assert not ( - self.condition_on_text and not exists(text_embeds) + self.condition_on_text and text_embeds is None ), 'text or text encodings must be passed into imagen if specified' assert not ( - not self.condition_on_text and exists(text_embeds) + not self.condition_on_text and text_embeds is not None ), 'imagen specified not to be conditioned on text, yet it is presented' assert not ( - exists(text_embeds) and + text_embeds is not None and text_embeds.shape[-1] != self.text_embed_dim ), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' assert not ( - exists(inpaint_images) ^ exists(inpaint_masks) + (inpaint_images is not None) ^ (inpaint_masks is not None) ), 'inpaint images and masks must be both passed in to do inpainting' outputs = [] @@ -609,8 +600,7 @@ def sample( outputs.append(img) - if exists(stop_at_unet_number - ) and stop_at_unet_number == unet_number: + if stop_at_unet_number is not None and stop_at_unet_number == unet_number: break output_index = -1 if not return_all_unet_outputs else slice( @@ -633,7 +623,6 @@ def p_losses(self, unet, x_start, times, - *, noise_scheduler, lowres_cond_img=None, lowres_aug_times=None, @@ -655,7 +644,7 @@ def p_losses(self, # random cropping during training # for upsamplers - if exists(random_crop_size): + if random_crop_size is not None: aug = K.RandomCrop((random_crop_size, random_crop_size), p=1.) # make sure low res conditioner and image both get augmented the same way # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop @@ -672,7 +661,7 @@ def p_losses(self, # at sample time, they then fix the noise level of 0.1 - 0.3 lowres_cond_img_noisy = None - if exists(lowres_cond_img): + if lowres_cond_img is not None: lowres_aug_times = default(lowres_aug_times, times) lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample( x_start=lowres_cond_img, @@ -715,11 +704,10 @@ def forward(self, assert images.shape[-1] == images.shape[ -2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}' assert not ( - len(self.unets) > 1 and not exists(unet_number) + len(self.unets) > 1 and unet_number is None ), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' unet_number = default(unet_number, 1) - assert not exists( - self.only_train_unet_number + assert (self.only_train_unet_number is None ) or self.only_train_unet_number == unet_number, 'you can only train on unet #{self.only_train_unet_number}' images = cast_uint8_images_to_float(images) @@ -748,19 +736,19 @@ def forward(self, text_masks, lambda: paddle.any(text_embeds != 0., axis=-1)) assert not ( - self.condition_on_text and not exists(text_embeds) + self.condition_on_text and text_embeds is None ), 'text or text encodings must be passed into decoder if specified' assert not ( - not self.condition_on_text and exists(text_embeds) + not self.condition_on_text and text_embeds is not None ), 'decoder specified not to be conditioned on text, yet it is presented' assert not ( - exists(text_embeds) and + (text_embeds is not None) and text_embeds.shape[-1] != self.text_embed_dim ), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' lowres_cond_img = lowres_aug_times = None - if exists(prev_image_size): + if prev_image_size is not None: lowres_cond_img = resize_image_to( images, prev_image_size, clamp_range=self.input_image_range) lowres_cond_img = resize_image_to( diff --git a/ppfleetx/models/multimodal_model/imagen/unet.py b/ppfleetx/models/multimodal_model/imagen/unet.py index 923f887d6..6087eb85d 100644 --- a/ppfleetx/models/multimodal_model/imagen/unet.py +++ b/ppfleetx/models/multimodal_model/imagen/unet.py @@ -21,7 +21,7 @@ from paddle import nn, einsum import paddle.nn.functional as F -from .utils import (zeros_, zero_init_, default, exists, cast_tuple, +from .utils import (zeros_, zero_init_, default, cast_tuple, resize_image_to, prob_mask_like, masked_mean, Identity, repeat, repeat_many, Rearrange, rearrange, rearrange_many, EinopsToAndFrom, Parallel, Always) @@ -66,7 +66,7 @@ def forward(self, x): class GlobalContext(nn.Layer): """ basically a superior form of squeeze-excitation that is attention-esque """ - def __init__(self, *, dim_in, dim_out): + def __init__(self, dim_in, dim_out): super().__init__() self.to_k = nn.Conv2D(dim_in, 1, 1) hidden_dim = max(3, dim_out // 2) @@ -84,7 +84,7 @@ def forward(self, x): class PerceiverAttention(nn.Layer): - def __init__(self, *, dim, dim_head=64, heads=8, cosine_sim_attn=False): + def __init__(self, dim, dim_head=64, heads=8, cosine_sim_attn=False): super().__init__() self.scale = dim_head**-0.5 if not cosine_sim_attn else 1 self.cosine_sim_attn = cosine_sim_attn @@ -129,7 +129,7 @@ def forward(self, x, latents, mask=None): sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale - if exists(mask): + if mask is not None: mask = F.pad(mask, (0, latents.shape[-2]), value=True) mask = rearrange(mask, 'b j -> b 1 1 j') sim = paddle.where(mask == 0, @@ -146,7 +146,6 @@ def forward(self, x, latents, mask=None): class PerceiverResampler(nn.Layer): def __init__( self, - *, dim, depth, dim_head=64, @@ -191,7 +190,7 @@ def forward(self, x, mask=None): latents = repeat(self.latents, 'n d -> b n d', b=x.shape[0]) - if exists(self.to_latents_from_mean_pooled_seq): + if self.to_latents_from_mean_pooled_seq is not None: meanpooled_seq = masked_mean( x, axis=1, mask=paddle.ones( x.shape[:2], dtype=paddle.bool)) @@ -209,7 +208,6 @@ def forward(self, x, mask=None): class CrossAttention(nn.Layer): def __init__(self, dim, - *, context_dim=None, dim_head=64, heads=8, @@ -272,7 +270,7 @@ def forward(self, x, context, mask=None): # masking - if exists(mask): + if mask is not None: mask = F.pad(mask, (1, 0), value=True) mask = rearrange(mask, 'b j -> b 1 1 j') sim = paddle.where(mask == 0, @@ -307,7 +305,7 @@ def forward(self, x, context, mask=None): # masking - if exists(mask): + if mask is not None: mask = F.pad(mask, (1, 0), value=True) mask = rearrange(mask, 'b n -> b n 1') k = paddle.where(mask == 0, @@ -336,7 +334,7 @@ def __init__(self, dim, dim_out, groups=8, norm=True): def forward(self, x, scale_shift=None): x = self.groupnorm(x) - if exists(scale_shift): + if scale_shift is not None: scale, shift = scale_shift x = x * (scale + 1) + shift @@ -348,7 +346,6 @@ class ResnetBlock(nn.Layer): def __init__(self, dim, dim_out, - *, cond_dim=None, time_cond_dim=None, groups=8, @@ -360,13 +357,13 @@ def __init__(self, self.time_mlp = None - if exists(time_cond_dim): + if time_cond_dim is not None: self.time_mlp = nn.Sequential( nn.Silu(), nn.Linear(time_cond_dim, dim_out * 2)) self.cross_attn = None - if exists(cond_dim): + if cond_dim is not None: attn_klass = CrossAttention if not linear_attn else LinearCrossAttention self.cross_attn = EinopsToAndFrom( @@ -387,15 +384,15 @@ def __init__(self, def forward(self, x, time_emb=None, cond=None): scale_shift = None - if exists(self.time_mlp) and exists(time_emb): + if self.time_mlp is not None and time_emb is not None: time_emb = self.time_mlp(time_emb) time_emb = time_emb[:, :, None, None] scale_shift = time_emb.chunk(2, axis=1) h = self.block1(x) - if exists(self.cross_attn): - assert exists(cond) + if self.cross_attn is not None: + assert cond is not None h = self.cross_attn(h, context=cond) + h h = self.block2(h, scale_shift=scale_shift) @@ -434,7 +431,6 @@ def ChanFeedForward( class Attention(nn.Layer): def __init__(self, dim, - *, dim_head=64, heads=8, context_dim=None, @@ -456,7 +452,7 @@ def __init__(self, self.to_context = nn.Sequential( nn.LayerNorm(context_dim), nn.Linear( - context_dim, dim_head * 2)) if exists(context_dim) else None + context_dim, dim_head * 2)) if context_dim is not None else None self.to_out = nn.Sequential( nn.Linear( @@ -480,8 +476,8 @@ def forward(self, x, context=None, mask=None, attn_bias=None): # add text conditioning, if present - if exists(context): - assert exists(self.to_context) + if context is not None: + assert self.to_context is not None ck, cv = self.to_context(context).chunk(2, axis=-1) k = paddle.concat((ck, k), axis=-2) v = paddle.concat((cv, v), axis=-2) @@ -497,12 +493,12 @@ def forward(self, x, context=None, mask=None, attn_bias=None): # relative positional encoding (T5 style) - if exists(attn_bias): + if attn_bias is not None: sim = sim + attn_bias # masking - if exists(mask): + if mask is not None: mask = F.pad(mask, (1, 0), value=True) mask = rearrange(mask, 'b j -> b 1 1 j') sim = paddle.where(mask == 0, @@ -532,7 +528,6 @@ def forward(self, x, **kwargs): class TransformerBlock(nn.Layer): def __init__(self, dim, - *, depth=1, heads=8, dim_head=32, @@ -639,7 +634,7 @@ def __init__(self, nn.LayerNorm(context_dim), nn.Linear( context_dim, inner_dim * 2, - bias_attr=False)) if exists(context_dim) else None + bias_attr=False)) if context_dim is not None else None self.to_out = nn.Sequential( nn.Conv2D( @@ -653,8 +648,8 @@ def forward(self, fmap, context=None): q, k, v = rearrange_many( (q, k, v), 'b (h c) x y -> (b h) (x y) c', h=h) - if exists(context): - assert exists(self.to_context) + if context is not None: + assert self.to_context is not None ck, cv = self.to_context(context).chunk(2, axis=-1) ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h=h) k = paddle.concat((k, ck), axis=-2) @@ -676,7 +671,6 @@ def forward(self, fmap, context=None): class LinearAttentionTransformerBlock(nn.Layer): def __init__(self, dim, - *, depth=1, heads=8, dim_head=32, @@ -778,7 +772,6 @@ def forward(self, x): class UpsampleCombiner(nn.Layer): def __init__(self, dim, - *, enabled=False, dim_ins=tuple(), dim_outs=tuple()): @@ -813,7 +806,6 @@ def forward(self, x, fmaps=None): class Unet(nn.Layer): def __init__(self, - *, dim, image_embed_dim=1024, text_embed_dim=1024, @@ -960,9 +952,7 @@ def __init__(self, self.text_to_cond = None if cond_on_text: - assert exists( - text_embed_dim - ), 'text_embed_dim must be given to the unet if cond_on_text is True' + assert text_embed_dim is not None, 'text_embed_dim must be given to the unet if cond_on_text is True' self.text_to_cond = nn.Linear(text_embed_dim, cond_dim) # finer control over whether to condition on text encodings @@ -1224,7 +1214,7 @@ def __init__(self, # if the current settings for the unet are not correct # for cascading DDPM, then reinit the unet with the right settings - def cast_model_parameters(self, *, lowres_cond, text_embed_dim, channels, + def cast_model_parameters(self, lowres_cond, text_embed_dim, channels, channels_out, cond_on_text): if lowres_cond == self.lowres_cond and \ channels == self.channels and \ @@ -1292,7 +1282,6 @@ def forward_with_cond_scale(self, *args, cond_scale=1., **kwargs): def forward(self, x, time, - *, lowres_cond_img=None, lowres_noise_times=None, text_embeds=None, @@ -1303,21 +1292,21 @@ def forward(self, # add low resolution conditioning, if present - assert not (self.lowres_cond and not exists(lowres_cond_img) + assert not (self.lowres_cond and lowres_cond_img is None ), 'low resolution conditioning image must be present' - assert not (self.lowres_cond and not exists(lowres_noise_times) + assert not (self.lowres_cond and lowres_noise_times is None ), 'low resolution conditioning noise time must be present' - if exists(lowres_cond_img): + if lowres_cond_img is not None: x = paddle.concat((x, lowres_cond_img), axis=1) # condition on input image assert not ( - self.has_cond_image ^ exists(cond_images) + self.has_cond_image ^ (cond_images is not None) ), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa' - if exists(cond_images): + if cond_images is not None: assert cond_images.shape[ 1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet' cond_images = resize_image_to(cond_images, x.shape[-1]) @@ -1359,7 +1348,7 @@ def forward(self, text_tokens = None - if exists(text_embeds) and self.cond_on_text: + if text_embeds is not None and self.cond_on_text: # conditional dropout @@ -1374,7 +1363,7 @@ def forward(self, text_tokens = text_tokens[:, :self.max_text_len] - if exists(text_mask): + if text_mask is not None: text_mask = text_mask[:, :self.max_text_len] text_tokens_len = text_tokens.shape[1] @@ -1384,7 +1373,7 @@ def forward(self, text_tokens = F.pad(text_tokens, (0, remainder), data_format='NLC') - if exists(text_mask): + if text_mask is not None: text_mask = text_mask[:, :, None] if remainder > 0: text_mask = F.pad(text_mask, (0, remainder), @@ -1401,7 +1390,7 @@ def forward(self, text_tokens, null_text_embed) - if exists(self.attn_pool): + if self.attn_pool is not None: text_tokens = self.attn_pool(text_tokens) # extra non-attention conditioning by projecting and then summing text embeddings to time @@ -1420,20 +1409,20 @@ def forward(self, # main conditioning tokens (c) - c = time_tokens if not exists(text_tokens) else paddle.concat( + c = time_tokens if text_tokens is None else paddle.concat( (time_tokens, text_tokens), axis=-2) # normalize conditioning tokens c = self.norm_cond(c) - if exists(self.init_resnet_block): + if self.init_resnet_block is not None: x = self.init_resnet_block(x, t) hiddens = [] for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs: - if exists(pre_downsample): + if pre_downsample is not None: x = pre_downsample(x) x = init_block(x, t, c) @@ -1445,12 +1434,12 @@ def forward(self, x = attn_block(x, c) hiddens.append(x) - if exists(post_downsample): + if post_downsample is not None: x = post_downsample(x) x = self.mid_block1(x, t, c) - if exists(self.mid_attn): + if self.mid_attn is not None: x = self.mid_attn(x) x = self.mid_block2(x, t, c) @@ -1476,10 +1465,10 @@ def forward(self, if self.init_conv_to_final_conv_residual: x = paddle.concat((x, init_conv_residual), axis=1) - if exists(self.final_res_block): + if self.final_res_block is not None: x = self.final_res_block(x, t) - if exists(lowres_cond_img): + if lowres_cond_img is not None: x = paddle.concat((x, lowres_cond_img), axis=1) return self.final_conv(x) diff --git a/ppfleetx/models/multimodal_model/imagen/utils.py b/ppfleetx/models/multimodal_model/imagen/utils.py index 9a36f6a41..d4aaf3379 100644 --- a/ppfleetx/models/multimodal_model/imagen/utils.py +++ b/ppfleetx/models/multimodal_model/imagen/utils.py @@ -23,10 +23,6 @@ # helper functions -def exists(val): - return val is not None - - def identity(t, *args, **kwargs): return t @@ -40,7 +36,7 @@ def first(arr, d=None): def maybe(fn): @wraps(fn) def inner(x): - if not exists(x): + if x is None: return x return fn(x) @@ -48,7 +44,7 @@ def inner(x): def default(val, d): - if exists(val): + if val is not None: return val return d() if callable(d) else d @@ -59,7 +55,7 @@ def cast_tuple(val, length=None): output = val if isinstance(val, tuple) else ((val, ) * default(length, 1)) - if exists(length): + if length is not None: assert len(output) == length return output @@ -95,7 +91,7 @@ def pad_tuple_to_length(t, length, fillvalue=None): def zero_init_(m): zeros_(m.weight) - if exists(m.bias): + if m.bias is not None: zeros_(m.bias) @@ -135,7 +131,7 @@ def resize_image_to(image, target_image_size, clamp_range=None): out = F.interpolate(image, target_image_size, mode='nearest') - if exists(clamp_range): + if clamp_range is not None: out = out.clip(*clamp_range) return out @@ -292,7 +288,7 @@ def log(t, eps: float=1e-12): def masked_mean(t, *, axis, mask=None): - if not exists(mask): + if mask is None: return t.mean(axis=axis) denom = mask.sum(axis=axis, keepdim=True) diff --git a/ppfleetx/models/multimodal_model/multimodal_module.py b/ppfleetx/models/multimodal_model/multimodal_module.py index bc18fc18c..ea6c0e39d 100644 --- a/ppfleetx/models/multimodal_model/multimodal_module.py +++ b/ppfleetx/models/multimodal_model/multimodal_module.py @@ -48,7 +48,7 @@ def training_step(self, batch): return loss def training_step_end(self, log_dict): - speed = self.configs.Engine.logging_freq / log_dict['train_cost'] + speed = 1. / log_dict['train_cost'] logger.info( "[train] epoch: %d, batch: %d, loss: %.9f, avg_batch_cost: %.5f sec, speed: %.2f step/s, learning rate: %.5e"