From e20ab7d8a0b225304814a61775f7c503eb494640 Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Sat, 5 Oct 2024 19:01:37 +0800
Subject: [PATCH 1/8] Refactor `MMDiT` and add `ImageToImage`

---
 keras_hub/api/models/__init__.py              |   4 +
 keras_hub/src/models/image_to_image.py        | 345 ++++++++++++++
 keras_hub/src/models/preprocessor.py          |   8 +-
 .../src/models/stable_diffusion_3/mmdit.py    | 422 +++++++++++-------
 .../stable_diffusion_3_image_to_image.py      | 163 +++++++
 .../stable_diffusion_3_image_to_image_test.py | 161 +++++++
 keras_hub/src/tokenizers/tokenizer.py         |  14 +-
 keras_hub/src/utils/preset_utils.py           |  14 +-
 .../convert_stable_diffusion_3_checkpoints.py |  28 +-
 9 files changed, 976 insertions(+), 183 deletions(-)
 create mode 100644 keras_hub/src/models/image_to_image.py
 create mode 100644 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py
 create mode 100644 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py

diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 1450ddceb3..9983e1a8ea 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -180,6 +180,7 @@ from keras_hub.src.models.image_segmenter_preprocessor import (
     ImageSegmenterPreprocessor,
 )
+from keras_hub.src.models.image_to_image import ImageToImage
 from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
 from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM
 from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
@@ -270,6 +271,9 @@
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
     StableDiffusion3Backbone,
 )
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import (
+    StableDiffusion3ImageToImage,
+)
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import (
     StableDiffusion3TextToImage,
 )
diff --git a/keras_hub/src/models/image_to_image.py b/keras_hub/src/models/image_to_image.py
new file mode 100644
index 0000000000..2139b1af5d
--- /dev/null
+++ b/keras_hub/src/models/image_to_image.py
@@ -0,0 +1,345 @@
+import itertools
+from functools import partial
+
+import keras
+from keras import ops
+from keras import random
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.task import Task
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+try:
+    import tensorflow as tf
+except ImportError:
+    tf = None
+
+
+@keras_hub_export("keras_hub.models.ImageToImage")
+class ImageToImage(Task):
+    """Base class for image-to-image tasks.
+
+    `ImageToImage` tasks wrap a `keras_hub.models.Backbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    generation and generative fine-tuning.
+
+    `ImageToImage` tasks provide an additional, high-level `generate()`
+    function which can be used to generate images with an (image, string) in,
+    image out signature.
+
+    All `ImageToImage` tasks include a `from_preset()` constructor which can be
+    used to load a pre-trained config and weights.
+
+    Example:
+
+    ```python
+    # Load a Stable Diffusion 3 backbone with pre-trained weights.
+ reference_image = np.ones((1024, 1024, 3), dtype="float32") + image_to_image = keras_hub.models.ImageToImage.from_preset( + "stable_diffusion_3_medium", + ) + image_to_image.generate( + reference_image, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + + # Load a Stable Diffusion 3 backbone at bfloat16 precision. + image_to_image = keras_hub.models.ImageToImage.from_preset( + "stable_diffusion_3_medium", + dtype="bfloat16", + ) + image_to_image.generate( + reference_image, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + @property + def image_shape(self): + return tuple(self.backbone.image_shape) + + @property + def latent_shape(self): + return tuple(self.backbone.latent_shape) + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `ImageToImage` task for training. + + The `ImageToImage` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.MeanSquaredError` loss will be applied. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.MeanSquaredError` will be applied to + track the loss of the model during training. See + `keras.Model.compile` and `keras.metrics` for more info on + possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. 
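+
+        Example:
+        ```python
+        # A sketch of overriding the compilation defaults; assumes
+        # `image_to_image` was constructed via `from_preset()`.
+        image_to_image.compile(
+            optimizer=keras.optimizers.AdamW(5e-5),
+            loss=keras.losses.MeanSquaredError(),
+            metrics=[keras.metrics.MeanSquaredError()],
+        )
+        ```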
+ """ + # Ref: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414 + if optimizer == "auto": + optimizer = keras.optimizers.AdamW( + 1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0 + ) + if loss == "auto": + loss = keras.losses.MeanSquaredError() + if metrics == "auto": + metrics = [keras.metrics.MeanSquaredError()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) + self.generate_function = None + + def generate_step(self, *args, **kwargs): + """Run generation on batches of input.""" + raise NotImplementedError + + def make_generate_function(self): + """Create or return the compiled generation function.""" + if self.generate_function is not None: + return self.generate_function + + self.generate_function = self.generate_step + if keras.config.backend() == "torch": + import torch + + def wrapped_function(*args, **kwargs): + with torch.no_grad(): + return self.generate_step(*args, **kwargs) + + self.generate_function = wrapped_function + elif keras.config.backend() == "tensorflow" and not self.run_eagerly: + self.generate_function = tf.function( + self.generate_step, jit_compile=self.jit_compile + ) + elif keras.config.backend() == "jax" and not self.run_eagerly: + import jax + + @partial(jax.jit) + def compiled_function(state, *args, **kwargs): + ( + trainable_variables, + non_trainable_variables, + ) = state + mapping = itertools.chain( + zip(self.trainable_variables, trainable_variables), + zip(self.non_trainable_variables, non_trainable_variables), + ) + + with keras.StatelessScope(state_mapping=mapping): + outputs = self.generate_step(*args, **kwargs) + return outputs + + def wrapped_function(*args, **kwargs): + # Create an explicit tuple of all variable state. + state = ( + # Use the explicit variable.value to preserve the + # sharding spec of distribution. + [v.value for v in self.trainable_variables], + [v.value for v in self.non_trainable_variables], + ) + outputs = compiled_function(state, *args, **kwargs) + return outputs + + self.generate_function = wrapped_function + return self.generate_function + + def _normalize_generate_images(self, inputs): + """Normalize user image to the generate function. + + This function converts all inputs to tensors, adds a batch dimension if + necessary, and returns a iterable "dataset like" object (either an + actual `tf.data.Dataset` or a list with a single batch element). + """ + if tf and isinstance(inputs, tf.data.Dataset): + return inputs.as_numpy_iterator(), False + + def normalize(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 4: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.image.resize( + x, + (self.backbone.height, self.backbone.width), + interpolation="nearest", + data_format=data_format, + ) + return x, input_is_scalar + + if isinstance(inputs, dict): + for key in inputs: + inputs[key], input_is_scalar = normalize(inputs[key]) + else: + inputs, input_is_scalar = normalize(inputs) + + return inputs, input_is_scalar + + def _normalize_generate_inputs(self, inputs): + """Normalize user input to the generate function. + + This function converts all inputs to tensors, adds a batch dimension if + necessary, and returns a iterable "dataset like" object (either an + actual `tf.data.Dataset` or a list with a single batch element). 
+ """ + if tf and isinstance(inputs, tf.data.Dataset): + return inputs.as_numpy_iterator(), False + + def normalize(x): + if isinstance(x, str): + return [x], True + if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0: + return x[tf.newaxis], True + return x, False + + if isinstance(inputs, dict): + for key in inputs: + inputs[key], input_is_scalar = normalize(inputs[key]) + else: + inputs, input_is_scalar = normalize(inputs) + + return inputs, input_is_scalar + + def _normalize_generate_outputs(self, outputs, input_is_scalar): + """Normalize user output from the generate function. + + This function converts all output to numpy with a value range of + `[0, 255]`. If a batch dimension was added to the input, it is removed + from the output. + """ + + def normalize(x): + outputs = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0) + outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8") + outputs = ops.convert_to_numpy(outputs) + if input_is_scalar: + outputs = outputs[0] + return outputs + + if isinstance(outputs[0], dict): + normalized = {} + for key in outputs[0]: + normalized[key] = normalize([x[key] for x in outputs]) + return normalized + return normalize([x for x in outputs]) + + def generate( + self, + images, + inputs, + negative_inputs, + num_steps, + guidance_scale, + strength, + seed=None, + ): + """Generate image based on the provided `images` and `inputs`. + + If `images` and `inputs` are a `tf.data.Dataset`, outputs will be + generated "batch-by-batch" and concatenated. Otherwise, all inputs will + be processed as batches. + + Args: + images: python data, tensor data, or a `tf.data.Dataset`. + inputs: python data, tensor data, or a `tf.data.Dataset`. + negative_inputs: python data, tensor data, or a `tf.data.Dataset`. + Unlike `inputs`, these are used as negative inputs to guide the + generation. If not provided, it defaults to `""` for each input + in `inputs`. + num_steps: int. The number of diffusion steps to take. + guidance_scale: float. The classifier free guidance scale defined in + [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). A higher scale encourages + generating images more closely related to the prompts, typically + at the cost of lower image quality. + strength: float. Indicates the extent to which the reference + `images` are transformed. Must be between `0.0` and `1.0`. When + `strength=1.0`, `images` is essentially ignore and added noise + is maximum and the denoising process runs for the full number of + iterations specified in `num_steps`. + seed: optional int. Used as a random seed. + """ + num_steps = int(num_steps) + guidance_scale = float(guidance_scale) + strength = float(strength) + if strength < 0.0 or strength > 1.0: + raise ValueError( + "`strength` must be between `0.0` and `1.0`. " + f"Received strength={strength}." + ) + + # Setup our three main passes. + # 1. Preprocessing strings to dense integer tensors. + # 2. Generate outputs via a compiled function on dense tensors. + # 3. Postprocess dense tensors to a value range of `[0, 255]`. + generate_function = self.make_generate_function() + + def preprocess(x): + return self.preprocessor.generate_preprocess(x) + + # Normalize and preprocess inputs. 
+        images, image_is_scalar = self._normalize_generate_images(images)
+        inputs, _ = self._normalize_generate_inputs(inputs)
+        if negative_inputs is None:
+            negative_inputs = [""] * len(inputs)
+        negative_inputs, _ = self._normalize_generate_inputs(negative_inputs)
+
+        if self.preprocessor is not None:
+            inputs = preprocess(inputs)
+            negative_inputs = preprocess(negative_inputs)
+        if isinstance(inputs, dict):
+            batch_size = len(inputs[list(inputs.keys())[0]])
+        else:
+            batch_size = len(inputs)
+
+        # Get the starting step for denoising.
+        starting_step = int(num_steps * (1.0 - strength))
+
+        # Initialize random noises.
+        noise_shape = (batch_size,) + self.latent_shape[1:]
+        noises = random.normal(noise_shape, dtype="float32", seed=seed)
+
+        # Image-to-image.
+        outputs = generate_function(
+            ops.convert_to_tensor(images),
+            noises,
+            inputs,
+            negative_inputs,
+            ops.convert_to_tensor(starting_step, "int32"),
+            ops.convert_to_tensor(num_steps, "int32"),
+            ops.convert_to_tensor(guidance_scale),
+        )
+        return self._normalize_generate_outputs(outputs, image_is_scalar)
diff --git a/keras_hub/src/models/preprocessor.py b/keras_hub/src/models/preprocessor.py
index f0569a36f8..a53955640d 100644
--- a/keras_hub/src/models/preprocessor.py
+++ b/keras_hub/src/models/preprocessor.py
@@ -32,7 +32,7 @@ class Preprocessor(PreprocessingLayer):
     image_converter_cls = None
 
     def __init__(self, *args, **kwargs):
-        self.config_name = kwargs.pop("config_name", PREPROCESSOR_CONFIG_FILE)
+        self.config_file = kwargs.pop("config_file", PREPROCESSOR_CONFIG_FILE)
         super().__init__(*args, **kwargs)
         self._tokenizer = None
         self._image_converter = None
@@ -85,7 +85,7 @@ def get_config(self):
         )
         config.update(
             {
-                "config_name": self.config_name,
+                "config_file": self.config_file,
             }
         )
         return config
@@ -117,7 +117,7 @@ def presets(cls):
     def from_preset(
         cls,
         preset,
-        config_name=PREPROCESSOR_CONFIG_FILE,
+        config_file=PREPROCESSOR_CONFIG_FILE,
         **kwargs,
     ):
         """Instantiate a `keras_hub.models.Preprocessor` from a model preset.
@@ -167,7 +167,7 @@ def from_preset(
         # Detect the correct subclass if we need to.
         if cls.backbone_cls != backbone_cls:
             cls = find_subclass(preset, cls, backbone_cls)
-        return loader.load_preprocessor(cls, config_name, **kwargs)
+        return loader.load_preprocessor(cls, config_file, **kwargs)
 
     @classmethod
     def _add_missing_kwargs(cls, loader, kwargs):
diff --git a/keras_hub/src/models/stable_diffusion_3/mmdit.py b/keras_hub/src/models/stable_diffusion_3/mmdit.py
index 0fe78e571b..0a618a427c 100644
--- a/keras_hub/src/models/stable_diffusion_3/mmdit.py
+++ b/keras_hub/src/models/stable_diffusion_3/mmdit.py
@@ -2,7 +2,6 @@
 
 import keras
 from keras import layers
-from keras import models
 from keras import ops
 
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
@@ -11,7 +10,167 @@
 from keras_hub.src.utils.keras_utils import standardize_data_format
 
 
+class AdaptiveLayerNormalization(layers.Layer):
+    """Adaptive layer normalization.
+
+    Args:
+        hidden_dim: int. The number of units in the hidden layers.
+        residual_modulation: bool. Whether to output the modulation parameters
+            of the residual connection within the block of the diffusion
+            transformers. Defaults to `False`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    References:
+    - [FiLM: Visual Reasoning with a General Conditioning Layer](
+        https://arxiv.org/abs/1709.07871).
+    - [Scalable Diffusion Models with Transformers](
+        https://arxiv.org/abs/2212.09748).
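+
+    In effect, the layer computes
+    `output = norm(inputs) * (1 + scale) + shift`, where `shift` and `scale`
+    are produced by a dense projection of `embeddings`. When
+    `residual_modulation=True`, the additional gate and MLP modulation
+    parameters are returned for use in the block's residual branches.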
+ """ + + def __init__(self, hidden_dim, residual_modulation=False, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.residual_modulation = bool(residual_modulation) + num_modulations = 6 if self.residual_modulation else 2 + + self.silu = layers.Activation("silu", dtype=self.dtype_policy) + self.dense = layers.Dense( + num_modulations * hidden_dim, dtype=self.dtype_policy, name="dense" + ) + self.norm = layers.LayerNormalization( + epsilon=1e-6, + center=False, + scale=False, + dtype="float32", + name="norm", + ) + + def build(self, inputs_shape, embeddings_shape): + self.silu.build(embeddings_shape) + self.dense.build(embeddings_shape) + self.norm.build(inputs_shape) + + def call(self, inputs, embeddings, training=None): + x = inputs + emb = self.dense(self.silu(embeddings), training=training) + if self.residual_modulation: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + ops.split(emb, 6, axis=1) + ) + else: + shift_msa, scale_msa = ops.split(emb, 2, axis=1) + scale_msa = ops.expand_dims(scale_msa, axis=1) + shift_msa = ops.expand_dims(shift_msa, axis=1) + x = ops.add( + ops.multiply( + self.norm(x, training=training), + ops.add(1.0, scale_msa), + ), + shift_msa, + ) + if self.residual_modulation: + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + else: + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "residual_modulation": self.residual_modulation, + } + ) + return config + + def compute_output_shape(self, inputs_shape, embeddings_shape): + if self.residual_modulation: + return ( + inputs_shape, + embeddings_shape, + embeddings_shape, + embeddings_shape, + embeddings_shape, + ) + else: + return inputs_shape + + +class MLP(layers.Layer): + """A MLP block with architecture. + + Args: + hidden_dim: int. The number of units in the hidden layers. + output_dim: int. The number of units in the output layer. + activation: str of callable. Activation to use in the hidden layers. + Default to `None`. + """ + + def __init__(self, hidden_dim, output_dim, activation=None, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.output_dim = int(output_dim) + self.activation = keras.activations.get(activation) + + self.dense1 = layers.Dense( + hidden_dim, + activation=self.activation, + dtype=self.dtype_policy, + name="dense1", + ) + self.dense2 = layers.Dense( + output_dim, + activation=None, + dtype=self.dtype_policy, + name="dense2", + ) + + def build(self, inputs_shape): + self.dense1.build(inputs_shape) + inputs_shape = self.dense1.compute_output_shape(inputs_shape) + self.dense2.build(inputs_shape) + + def call(self, inputs, training=None): + x = self.dense1(inputs, training=training) + return self.dense2(x, training=training) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + "activation": keras.activations.serialize(self.activation), + } + ) + return config + + def compute_output_shape(self, inputs_shape): + outputs_shape = list(inputs_shape) + outputs_shape[-1] = self.output_dim + return outputs_shape + + class PatchEmbedding(layers.Layer): + """A layer that converts images into patches. + + Args: + patch_size: int. The size of one side of each patch. + hidden_dim: int. The number of units in the hidden layers. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. 
The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + def __init__(self, patch_size, hidden_dim, data_format=None, **kwargs): super().__init__(**kwargs) self.patch_size = int(patch_size) @@ -48,6 +207,15 @@ def get_config(self): class AdjustablePositionEmbedding(PositionEmbedding): + """A position embedding layer with adjustable height and width. + + The embedding will be cropped to match the input dimensions. + + Args: + height: int. The maximum height of the embedding. + width: int. The maximum width of the embedding. + """ + def __init__( self, height, @@ -89,6 +257,20 @@ def compute_output_shape(self, input_shape): class TimestepEmbedding(layers.Layer): + """A layer which learns embedding for input timesteps. + + Args: + embedding_dim: int. The size of the embedding. + frequency_dim: int. The size of the frequency. + max_period: int. Controls the maximum frequency of the embeddings. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + + Reference: + - [Denoising Diffusion Probabilistic Models]( + https://arxiv.org/abs/2006.11239). + """ + def __init__( self, embedding_dim, frequency_dim=256, max_period=10000, **kwargs ): @@ -98,15 +280,11 @@ def __init__( self.max_period = float(max_period) self.half_frequency_dim = self.frequency_dim // 2 - self.mlp = models.Sequential( - [ - layers.Dense( - embedding_dim, activation="silu", dtype=self.dtype_policy - ), - layers.Dense( - embedding_dim, activation=None, dtype=self.dtype_policy - ), - ], + self.mlp = MLP( + embedding_dim, + embedding_dim, + "silu", + dtype=self.dtype_policy, name="mlp", ) @@ -155,6 +333,18 @@ def compute_output_shape(self, inputs_shape): class DismantledBlock(layers.Layer): + """A dismantled block used to compute pre- and post-attention. + + Args: + num_heads: int. Number of attention heads. + hidden_dim: int. The number of units in the hidden layers. + mlp_ratio: float. The expansion ratio of `MLP`. + use_projection: bool. Whether to use an attention projection layer at + the end of the block. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. 
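+
+    The block is "dismantled" in that the usual transformer block is split
+    into `_compute_pre_attention` (adaptive layer norm plus the QKV
+    projection) and `_compute_post_attention` (output projection, gating and
+    `MLP`), so that `MMDiTBlock` can run a single joint attention operation
+    over both modalities in between the two halves.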
+ """ + def __init__( self, num_heads, @@ -173,25 +363,18 @@ def __init__( self.head_dim = head_dim mlp_hidden_dim = int(hidden_dim * mlp_ratio) self.mlp_hidden_dim = mlp_hidden_dim - num_modulations = 6 if use_projection else 2 - self.num_modulations = num_modulations - - self.adaptive_norm_modulation = models.Sequential( - [ - layers.Activation("silu", dtype=self.dtype_policy), - layers.Dense( - num_modulations * hidden_dim, dtype=self.dtype_policy - ), - ], - name="adaptive_norm_modulation", - ) - self.norm1 = layers.LayerNormalization( - epsilon=1e-6, - center=False, - scale=False, - dtype="float32", - name="norm1", - ) + + if use_projection: + self.ada_layer_norm = AdaptiveLayerNormalization( + hidden_dim, + residual_modulation=True, + dtype=self.dtype_policy, + name="ada_layer_norm", + ) + else: + self.ada_layer_norm = AdaptiveLayerNormalization( + hidden_dim, dtype=self.dtype_policy, name="ada_layer_norm" + ) self.attention_qkv = layers.Dense( hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv" ) @@ -206,73 +389,45 @@ def __init__( dtype="float32", name="norm2", ) - self.mlp = models.Sequential( - [ - layers.Dense( - mlp_hidden_dim, - activation=gelu_approximate, - dtype=self.dtype_policy, - ), - layers.Dense( - hidden_dim, - dtype=self.dtype_policy, - ), - ], + self.mlp = MLP( + mlp_hidden_dim, + hidden_dim, + gelu_approximate, + dtype=self.dtype_policy, name="mlp", ) def build(self, inputs_shape, timestep_embedding): - self.adaptive_norm_modulation.build(timestep_embedding) + self.ada_layer_norm.build(inputs_shape, timestep_embedding) self.attention_qkv.build(inputs_shape) - self.norm1.build(inputs_shape) if self.use_projection: self.attention_proj.build(inputs_shape) self.norm2.build(inputs_shape) self.mlp.build(inputs_shape) def _modulate(self, inputs, shift, scale): - shift = ops.expand_dims(shift, axis=1) - scale = ops.expand_dims(scale, axis=1) + inputs = ops.cast(inputs, self.compute_dtype) + shift = ops.cast(shift, self.compute_dtype) + scale = ops.cast(scale, self.compute_dtype) return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift) def _compute_pre_attention(self, inputs, timestep_embedding, training=None): batch_size = ops.shape(inputs)[0] if self.use_projection: - modulation = self.adaptive_norm_modulation( - timestep_embedding, training=training - ) - modulation = ops.reshape( - modulation, (batch_size, 6, self.hidden_dim) - ) - ( - shift_msa, - scale_msa, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) = ops.unstack(modulation, 6, axis=1) - qkv = self.attention_qkv( - self._modulate(self.norm1(inputs), shift_msa, scale_msa), - training=training, + x, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.ada_layer_norm( + inputs, timestep_embedding, training=training ) + qkv = self.attention_qkv(x, training=training) qkv = ops.reshape( qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) ) q, k, v = ops.unstack(qkv, 3, axis=2) return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp) else: - modulation = self.adaptive_norm_modulation( - timestep_embedding, training=training - ) - modulation = ops.reshape( - modulation, (batch_size, 2, self.hidden_dim) - ) - shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1) - qkv = self.attention_qkv( - self._modulate(self.norm1(inputs), shift_msa, scale_msa), - training=training, + x = self.ada_layer_norm( + inputs, timestep_embedding, training=training ) + qkv = self.attention_qkv(x, training=training) qkv = ops.reshape( qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) ) @@ -283,12 
+438,16 @@ def _compute_post_attention(
         self, inputs, inputs_intermediates, training=None
     ):
         x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
+        gate_msa = ops.expand_dims(gate_msa, axis=1)
+        shift_mlp = ops.expand_dims(shift_mlp, axis=1)
+        scale_mlp = ops.expand_dims(scale_mlp, axis=1)
+        gate_mlp = ops.expand_dims(gate_mlp, axis=1)
         attn = self.attention_proj(inputs, training=training)
-        x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn))
+        x = ops.add(x, ops.multiply(gate_msa, attn))
         x = ops.add(
             x,
             ops.multiply(
-                ops.expand_dims(gate_mlp, axis=1),
+                gate_mlp,
                 self.mlp(
                     self._modulate(self.norm2(x), shift_mlp, scale_mlp),
                     training=training,
@@ -328,6 +487,27 @@ def get_config(self):
 
 
 class MMDiTBlock(layers.Layer):
+    """An MMDiT block consisting of two `DismantledBlock` layers.
+
+    One `DismantledBlock` processes the input latents, and the other processes
+    the context embedding. This block integrates the two modalities within the
+    attention operation, allowing each representation to operate in its own
+    space while considering the other.
+
+    Args:
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. The number of units in the hidden layers.
+        mlp_ratio: float. The expansion ratio of `MLP`.
+        use_context_projection: bool. Whether to use an attention projection
+            layer at the end of the context block.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    Reference:
+    - [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
+        https://arxiv.org/abs/2403.03206)
+    """
+
     def __init__(
         self,
         num_heads,
@@ -453,74 +633,16 @@ def compute_output_shape(
         return inputs_shape
 
 
-class OutputLayer(layers.Layer):
-    def __init__(self, hidden_dim, output_dim, **kwargs):
-        super().__init__(**kwargs)
-        self.hidden_dim = hidden_dim
-        self.output_dim = output_dim
-        num_modulation = 2
-
-        self.adaptive_norm_modulation = models.Sequential(
-            [
-                layers.Activation("silu", dtype=self.dtype_policy),
-                layers.Dense(
-                    num_modulation * hidden_dim, dtype=self.dtype_policy
-                ),
-            ],
-            name="adaptive_norm_modulation",
-        )
-        self.norm = layers.LayerNormalization(
-            epsilon=1e-6,
-            center=False,
-            scale=False,
-            dtype="float32",
-            name="norm",
-        )
-        self.output_dense = layers.Dense(
-            output_dim,
-            use_bias=True,
-            dtype=self.dtype_policy,
-            name="output_dense",
-        )
-
-    def build(self, inputs_shape, timestep_embedding_shape):
-        self.adaptive_norm_modulation.build(timestep_embedding_shape)
-        self.norm.build(inputs_shape)
-        self.output_dense.build(inputs_shape)
-
-    def _modulate(self, inputs, shift, scale):
-        shift = ops.expand_dims(shift, axis=1)
-        scale = ops.expand_dims(scale, axis=1)
-        return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
-
-    def call(self, inputs, timestep_embedding, training=None):
-        x = inputs
-        modulation = self.adaptive_norm_modulation(
-            timestep_embedding, training=training
-        )
-        modulation = ops.reshape(modulation, (-1, 2, self.hidden_dim))
-        shift, scale = ops.unstack(modulation, 2, axis=1)
-        x = self._modulate(self.norm(x), shift, scale)
-        x = self.output_dense(x, training=training)
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "hidden_dim": self.hidden_dim,
-                "output_dim": self.output_dim,
-            }
-        )
-        return config
-
-    def compute_output_shape(self, inputs_shape):
-        outputs_shape = list(inputs_shape)
-        outputs_shape[-1] = self.output_dim
-        return outputs_shape
+class Unpatch(layers.Layer):
+    """A layer that reconstructs the image from hidden patches.
 
+    Args:
+        patch_size: int. The size of each square patch in the input image.
+        output_dim: int. The number of output channels of the reconstructed
+            image.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
 
-class Unpatch(layers.Layer):
     def __init__(self, patch_size, output_dim, **kwargs):
         super().__init__(**kwargs)
         self.patch_size = int(patch_size)
@@ -556,7 +678,7 @@ def compute_output_shape(self, inputs_shape):
 
 
 class MMDiT(Backbone):
-    """Multimodal Diffusion Transformer (MMDiT) model for Stable Diffusion 3.
+    """A Multimodal Diffusion Transformer (MMDiT) model.
 
     MMDiT is introduced in [
     Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
@@ -636,12 +758,8 @@ def __init__(
             dtype=dtype,
             name="context_embedding",
         )
-        self.vector_embedding = models.Sequential(
-            [
-                layers.Dense(hidden_dim, activation="silu", dtype=dtype),
-                layers.Dense(hidden_dim, activation=None, dtype=dtype),
-            ],
-            name="vector_embedding",
+        self.vector_embedding = MLP(
+            hidden_dim, hidden_dim, "silu", dtype=dtype, name="vector_embedding"
         )
         self.vector_embedding_add = layers.Add(
             dtype=dtype, name="vector_embedding_add"
@@ -660,8 +778,11 @@ def __init__(
             )
             for i in range(num_layers)
         ]
-        self.output_layer = OutputLayer(
-            hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
+        self.output_ada_layer_norm = AdaptiveLayerNormalization(
+            hidden_dim, dtype=dtype, name="output_ada_layer_norm"
+        )
+        self.output_dense = layers.Dense(
+            output_dim_in_final, dtype=dtype, name="output_dense"
         )
         self.unpatch = Unpatch(
             patch_size, output_dim, dtype=dtype, name="unpatch"
@@ -696,7 +817,8 @@ def __init__(
             x = block(x, context, timestep_embedding)
 
         # Output layer.
-        x = self.output_layer(x, timestep_embedding)
+        x = self.output_ada_layer_norm(x, timestep_embedding)
+        x = self.output_dense(x)
         outputs = self.unpatch(x, height=image_height, width=image_width)
 
         super().__init__(
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py
new file mode 100644
index 0000000000..7a6714c65b
--- /dev/null
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py
@@ -0,0 +1,163 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_to_image import ImageToImage
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+    StableDiffusion3TextToImagePreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.StableDiffusion3ImageToImage")
+class StableDiffusion3ImageToImage(ImageToImage):
+    """An end-to-end Stable Diffusion 3 model for image-to-image generation.
+
+    This model has a `generate()` method, which generates image based on a pair
+    of image and prompt.
+
+    Args:
+        backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
+        preprocessor: A
+            `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
+
+    Examples:
+
+    Use `generate()` to do image generation.
+    ```python
+    reference_image = np.ones((512, 512, 3), dtype="float32")
+    image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
+        "stable_diffusion_3_medium", height=512, width=512
+    )
+    image_to_image.generate(
+        reference_image,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+    )
+
+    # Generate with batched prompts.
+    reference_images = np.ones((2, 512, 512, 3), dtype="float32")
+    image_to_image.generate(
+        reference_images,
+        ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
+    )
+
+    # Generate with different `num_steps`, `guidance_scale` and `strength`.
+    image_to_image.generate(
+        reference_image,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        num_steps=50,
+        guidance_scale=5.0,
+        strength=0.6,
+    )
+    ```
+    """
+
+    backbone_cls = StableDiffusion3Backbone
+    preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        outputs = backbone.output
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Currently, `fit` is not supported for "
+            "`StableDiffusion3ImageToImage`."
+        )
+
+    def generate_step(
+        self,
+        images,
+        noises,
+        token_ids,
+        negative_token_ids,
+        starting_step,
+        num_steps,
+        guidance_scale,
+    ):
+        """A compilable generation function for batches of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for batched inputs.
+
+        Args:
+            images: A (batch_size, image_height, image_width, 3) tensor
+                containing the reference images.
+            noises: A (batch_size, latent_height, latent_width, channels) tensor
+                containing the noises to be added to the latents. Typically,
+                this tensor is sampled from the Gaussian distribution.
+            token_ids: A (batch_size, num_tokens) tensor containing the
+                tokens based on the input prompts.
+            negative_token_ids: A (batch_size, num_tokens) tensor
+                containing the negative tokens based on the input prompts.
+            starting_step: int. The index of the diffusion step to start from.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier-free guidance scale defined in
+                [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). A higher scale encourages
+                generating images more closely linked to the prompts, usually at
+                the expense of lower image quality.
+        """
+        # Encode images.
+        latents = self.backbone.encode_image_step(images)
+
+        # Add noises to latents.
+        latents = self.backbone.add_noise_step(
+            latents, noises, starting_step, num_steps
+        )
+
+        # Encode inputs.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )
+
+        # Denoise.
+        def body_fun(step, latents):
+            return self.backbone.denoise_step(
+                latents,
+                embeddings,
+                step,
+                num_steps,
+                guidance_scale,
+            )
+
+        latents = ops.fori_loop(starting_step, num_steps, body_fun, latents)
+
+        # Decode.
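+        # The VAE decoder maps the final latents back to pixel space; outputs
+        # are assumed to lie in `[-1.0, 1.0]` and are rescaled to `[0, 255]`
+        # by `_normalize_generate_outputs` in the base task.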
+        return self.backbone.decode_step(latents)
+
+    def generate(
+        self,
+        images,
+        inputs,
+        negative_inputs=None,
+        num_steps=28,
+        guidance_scale=7.0,
+        strength=0.8,
+        seed=None,
+    ):
+        return super().generate(
+            images,
+            inputs,
+            negative_inputs=negative_inputs,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            seed=seed,
+        )
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py
new file mode 100644
index 0000000000..1f9d4c19d6
--- /dev/null
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py
@@ -0,0 +1,161 @@
+import keras
+import pytest
+from keras import ops
+
+from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor
+from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder
+from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import (
+    StableDiffusion3ImageToImage,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+    StableDiffusion3TextToImagePreprocessor,
+)
+from keras_hub.src.models.vae.vae_backbone import VAEBackbone
+from keras_hub.src.tests.test_case import TestCase
+
+
+class StableDiffusion3ImageToImageTest(TestCase):
+    def setUp(self):
+        # Instantiate the preprocessor.
+        vocab = ["air", "plane", "port"]
+        vocab += ["<|endoftext|>", "<|startoftext|>"]
+        vocab = dict([(token, i) for i, token in enumerate(vocab)])
+        merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"]
+        merges += ["po rt", "pla ne"]
+        clip_l_tokenizer = CLIPTokenizer(vocab, merges, pad_with_end_token=True)
+        clip_g_tokenizer = CLIPTokenizer(vocab, merges)
+        clip_l_preprocessor = CLIPPreprocessor(clip_l_tokenizer)
+        clip_g_preprocessor = CLIPPreprocessor(clip_g_tokenizer)
+        self.preprocessor = StableDiffusion3TextToImagePreprocessor(
+            clip_l_preprocessor, clip_g_preprocessor
+        )
+
+        self.backbone = StableDiffusion3Backbone(
+            mmdit_patch_size=2,
+            mmdit_hidden_dim=16 * 2,
+            mmdit_num_layers=2,
+            mmdit_num_heads=2,
+            mmdit_position_size=192,
+            vae=VAEBackbone(
+                [32, 32, 32, 32],
+                [1, 1, 1, 1],
+                [32, 32, 32, 32],
+                [1, 1, 1, 1],
+                # Use `mode` to generate a deterministic output.
+ sampler_method="mode", + name="vae", + ), + clip_l=CLIPTextEncoder( + 20, 64, 64, 2, 2, 128, "quick_gelu", -2, name="clip_l" + ), + clip_g=CLIPTextEncoder( + 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g" + ), + height=64, + width=64, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.input_data = { + "images": ops.ones((2, 64, 64, 3)), + "latents": ops.ones((2, 8, 8, 16)), + "clip_l_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_l_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "num_steps": ops.ones((2,), dtype="int32"), + "guidance_scale": ops.ones((2,)), + } + + def test_text_to_image_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=StableDiffusion3ImageToImage, + init_kwargs=self.init_kwargs, + train_data=None, + expected_output_shape={ + "images": (2, 64, 64, 3), + "latents": (2, 8, 8, 16), + }, + ) + + def test_generate(self): + image_to_image = StableDiffusion3ImageToImage(**self.init_kwargs) + seed = 42 + image = self.input_data["images"][0] + # String input. + prompt = ["airplane"] + negative_prompt = [""] + output = image_to_image.generate( + image, prompt, negative_prompt, seed=seed + ) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess(prompt) + negative_prompt_ids = self.preprocessor.generate_preprocess( + negative_prompt + ) + image_to_image.preprocessor = None + output2 = image_to_image.generate( + image, prompt_ids, negative_prompt_ids, seed=seed + ) + self.assertAllClose(output, output2) + + def test_generate_with_lower_precision(self): + original_floatx = keras.config.floatx() + try: + for dtype in ["float16", "bfloat16"]: + keras.config.set_floatx(dtype) + image_to_image = StableDiffusion3ImageToImage( + **self.init_kwargs + ) + seed = 42 + image = self.input_data["images"][0] + # String input. + prompt = ["airplane"] + negative_prompt = [""] + output = image_to_image.generate( + image, prompt, negative_prompt, seed=seed + ) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess(prompt) + negative_prompt_ids = self.preprocessor.generate_preprocess( + negative_prompt + ) + image_to_image.preprocessor = None + output2 = image_to_image.generate( + image, prompt_ids, negative_prompt_ids, seed=seed + ) + self.assertAllClose(output, output2) + finally: + # Restore floatx to the original value to prevent impact on other + # tests even if there is an exception. + keras.config.set_floatx(original_floatx) + + def test_generate_compilation(self): + image_to_image = StableDiffusion3ImageToImage(**self.init_kwargs) + image = self.input_data["images"][0] + # Assert we do not recompile with successive calls. + image_to_image.generate(image, "airplane") + first_fn = image_to_image.generate_function + image_to_image.generate(image, "airplane") + second_fn = image_to_image.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. 
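+        # (`compile()` resets `generate_function` to `None`, so the next
+        # `generate()` call will rebuild it.)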
+ image_to_image.compile() + self.assertIsNone(image_to_image.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=StableDiffusion3ImageToImage, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/tokenizers/tokenizer.py b/keras_hub/src/tokenizers/tokenizer.py index b97efae444..5e8986a89e 100644 --- a/keras_hub/src/tokenizers/tokenizer.py +++ b/keras_hub/src/tokenizers/tokenizer.py @@ -66,7 +66,7 @@ def detokenize(self, inputs): backbone_cls = None def __init__(self, *args, **kwargs): - self.config_name = kwargs.pop("config_name", TOKENIZER_CONFIG_FILE) + self.config_file = kwargs.pop("config_file", TOKENIZER_CONFIG_FILE) super().__init__(*args, **kwargs) self.file_assets = None @@ -178,7 +178,7 @@ def get_config(self): config = super().get_config() config.update( { - "config_name": self.config_name, + "config_file": self.config_file, } ) return config @@ -199,11 +199,11 @@ def call(self, inputs, *args, training=None, **kwargs): def load_preset_assets(self, preset): asset_path = None for asset in self.file_assets: - subdir = self.config_name.split(".")[0] + subdir = self.config_file.split(".")[0] preset_path = os.path.join(ASSET_DIR, subdir, asset) asset_path = get_file(preset, preset_path) - tokenizer_config_name = os.path.dirname(asset_path) - self.load_assets(tokenizer_config_name) + tokenizer_config_file = os.path.dirname(asset_path) + self.load_assets(tokenizer_config_file) @classproperty def presets(cls): @@ -214,7 +214,7 @@ def presets(cls): def from_preset( cls, preset, - config_name=TOKENIZER_CONFIG_FILE, + config_file=TOKENIZER_CONFIG_FILE, **kwargs, ): """Instantiate a `keras_hub.models.Tokenizer` from a model preset. @@ -260,4 +260,4 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from backbone_cls = loader.check_backbone_class() if cls.backbone_cls != backbone_cls: cls = find_subclass(preset, cls, backbone_cls) - return loader.load_tokenizer(cls, config_name, **kwargs) + return loader.load_tokenizer(cls, config_file, **kwargs) diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index 65af19df7f..261d1eda50 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -578,7 +578,7 @@ def load_backbone(self, cls, load_weights, **kwargs): """Load the backbone model from the preset.""" raise NotImplementedError - def load_tokenizer(self, cls, config_name=TOKENIZER_CONFIG_FILE, **kwargs): + def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs): """Load a tokenizer layer from the preset.""" raise NotImplementedError @@ -609,7 +609,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): return cls(**kwargs) def load_preprocessor( - self, cls, config_name=PREPROCESSOR_CONFIG_FILE, **kwargs + self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs ): """Load a prepocessor layer from the preset. 
@@ -632,8 +632,8 @@ def load_backbone(self, cls, load_weights, **kwargs): backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE)) return backbone - def load_tokenizer(self, cls, config_name=TOKENIZER_CONFIG_FILE, **kwargs): - tokenizer_config = load_json(self.preset, config_name) + def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs): + tokenizer_config = load_json(self.preset, config_file) tokenizer = load_serialized_object(tokenizer_config, **kwargs) if hasattr(tokenizer, "load_preset_assets"): tokenizer.load_preset_assets(self.preset) @@ -678,13 +678,13 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): return task def load_preprocessor( - self, cls, config_name=PREPROCESSOR_CONFIG_FILE, **kwargs + self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs ): # If there is no `preprocessing.json` or it's for the wrong class, # delegate to the super class loader. - if not check_file_exists(self.preset, config_name): + if not check_file_exists(self.preset, config_file): return super().load_preprocessor(cls, **kwargs) - preprocessor_json = load_json(self.preset, config_name) + preprocessor_json = load_json(self.preset, config_file) if not issubclass(check_config_class(preprocessor_json), cls): return super().load_preprocessor(cls, **kwargs) # We found a `preprocessing.json` with a complete config for our class. diff --git a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py index 15b9691532..51b9082ccf 100644 --- a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py +++ b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py @@ -130,23 +130,23 @@ def convert_preprocessor(): vocabulary, merges, pad_with_end_token=True, - config_name="clip_l_tokenizer.json", + config_file="clip_l_tokenizer.json", name="clip_l_tokenizer", ) clip_g_tokenizer = CLIPTokenizer( vocabulary, merges, - config_name="clip_g_tokenizer.json", + config_file="clip_g_tokenizer.json", name="clip_g_tokenizer", ) clip_l_preprocessor = CLIPPreprocessor( clip_l_tokenizer, - config_name="clip_l_preprocessor.json", + config_file="clip_l_preprocessor.json", name="clip_l_preprocessor", ) clip_g_preprocessor = CLIPPreprocessor( clip_g_tokenizer, - config_name="clip_g_preprocessor.json", + config_file="clip_g_preprocessor.json", name="clip_g_preprocessor", ) preprocessor = StableDiffusion3TextToImagePreprocessor( @@ -310,19 +310,19 @@ def port_diffuser(preset, filename, model): ) port_dense(loader, model.context_embedding, "context_embedder") port_dense( - loader, model.vector_embedding.layers[0], "y_embedder.mlp.0" + loader, model.vector_embedding.dense1, "y_embedder.mlp.0" ) port_dense( - loader, model.vector_embedding.layers[1], "y_embedder.mlp.2" + loader, model.vector_embedding.dense2, "y_embedder.mlp.2" ) port_dense( loader, - model.timestep_embedding.mlp.layers[0], + model.timestep_embedding.mlp.dense1, "t_embedder.mlp.0", ) port_dense( loader, - model.timestep_embedding.mlp.layers[1], + model.timestep_embedding.mlp.dense2, "t_embedder.mlp.2", ) @@ -338,7 +338,7 @@ def port_diffuser(preset, filename, model): prefix = f"joint_blocks.{i}.{block_name}" port_dense( loader, - block.adaptive_norm_modulation.layers[1], + block.ada_layer_norm.dense, f"{prefix}.adaLN_modulation.1", ) port_dense( @@ -351,18 +351,16 @@ def port_diffuser(preset, filename, model): port_dense( loader, block.attention_proj, f"{prefix}.attn.proj" ) - port_dense(loader, 
block.mlp.layers[0], f"{prefix}.mlp.fc1") - port_dense(loader, block.mlp.layers[1], f"{prefix}.mlp.fc2") + port_dense(loader, block.mlp.dense1, f"{prefix}.mlp.fc1") + port_dense(loader, block.mlp.dense2, f"{prefix}.mlp.fc2") # Output layer port_dense( loader, - model.output_layer.adaptive_norm_modulation.layers[1], + model.output_ada_layer_norm.dense, "final_layer.adaLN_modulation.1", ) - port_dense( - loader, model.output_layer.output_dense, "final_layer.linear" - ) + port_dense(loader, model.output_dense, "final_layer.linear") return model def port_vae(preset, filename, model): From a7cc7f2197ebd9cfc2f3fe25f2dd7b1ba5da2691 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sat, 5 Oct 2024 19:25:41 +0800 Subject: [PATCH 2/8] Update model version --- .../src/models/stable_diffusion_3/stable_diffusion_3_presets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py index 2067fdb8dc..af55ac276a 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py @@ -13,6 +13,6 @@ "path": "stablediffusion3", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/3", + "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/4", } } From da16e670e85a72527427df4c6f30b2108f241327 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sat, 5 Oct 2024 20:24:38 +0800 Subject: [PATCH 3/8] Fix minor bugs. --- keras_hub/src/models/image_to_image.py | 5 +++++ keras_hub/src/models/stable_diffusion_3/mmdit.py | 12 ++++++++++++ .../stable_diffusion_3_image_to_image.py | 4 ++-- .../stable_diffusion_3_image_to_image_test.py | 2 +- 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/models/image_to_image.py b/keras_hub/src/models/image_to_image.py index 2139b1af5d..99dda6993f 100644 --- a/keras_hub/src/models/image_to_image.py +++ b/keras_hub/src/models/image_to_image.py @@ -268,6 +268,11 @@ def generate( ): """Generate image based on the provided `images` and `inputs`. + The `images` are reference images that will be resized to + `self.backbone.height` and `self.backbone.width`, then encoded into + latent space by the VAE encoder. The `inputs` are strings that will be + tokenized and encoded by the text encoder. + If `images` and `inputs` are a `tf.data.Dataset`, outputs will be generated "batch-by-batch" and concatenated. Otherwise, all inputs will be processed as batches. 
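For reference, the `strength` argument documented above selects the starting
diffusion step via `starting_step = int(num_steps * (1.0 - strength))` in
`ImageToImage.generate`. A standalone sketch of that arithmetic (the printed
values are illustrative, using the task's default of `num_steps=28`):

```python
# Mirrors `starting_step = int(num_steps * (1.0 - strength))` in `generate()`.
num_steps = 28  # default for `StableDiffusion3ImageToImage.generate`
for strength in (0.3, 0.8, 1.0):
    starting_step = int(num_steps * (1.0 - strength))
    print(f"strength={strength}: start at step {starting_step}, "
          f"run {num_steps - starting_step} denoising steps")
# strength=0.3: start at step 19, run 9 denoising steps
# strength=0.8: start at step 5, run 23 denoising steps
# strength=1.0: start at step 0, run 28 denoising steps
```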
diff --git a/keras_hub/src/models/stable_diffusion_3/mmdit.py b/keras_hub/src/models/stable_diffusion_3/mmdit.py index 0a618a427c..722bfdf273 100644 --- a/keras_hub/src/models/stable_diffusion_3/mmdit.py +++ b/keras_hub/src/models/stable_diffusion_3/mmdit.py @@ -252,6 +252,17 @@ def call(self, inputs, height=None, width=None): position_embedding = ops.expand_dims(position_embedding, axis=0) return position_embedding + def get_config(self): + config = super().get_config() + del config["sequence_length"] + config.update( + { + "height": self.height, + "width": self.width, + } + ) + return config + def compute_output_shape(self, input_shape): return input_shape @@ -321,6 +332,7 @@ def get_config(self): config.update( { "embedding_dim": self.embedding_dim, + "frequency_dim": self.frequency_dim, "max_period": self.max_period, } ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py index 7a6714c65b..3d551eb5ef 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py @@ -14,8 +14,8 @@ class StableDiffusion3ImageToImage(ImageToImage): """An end-to-end Stable Diffusion 3 model for image-to-image generation. - This model has a `generate()` method, which generates image based on a pair - of image and prompt. + This model has a `generate()` method, which generates images based + on a combination of a reference image and a text prompt. Args: backbone: A `keras_hub.models.StableDiffusion3Backbone` instance. diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py index 1f9d4c19d6..7374ea8e8c 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py @@ -73,7 +73,7 @@ def setUp(self): "guidance_scale": ops.ones((2,)), } - def test_text_to_image_basics(self): + def test_image_to_image_basics(self): pytest.skip( reason="TODO: enable after preprocessor flow is figured out" ) From 8aa738851c0265ba3f7247aa8cc19a0a77888a27 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 6 Oct 2024 23:13:22 +0800 Subject: [PATCH 4/8] Add `Inpaint` for SD3. 
---
 keras_hub/api/models/__init__.py              |   4 +
 keras_hub/src/models/image_to_image.py        |   9 +-
 keras_hub/src/models/inpaint.py               | 398 ++++++++++++++++++
 .../src/models/stable_diffusion_3/mmdit.py    |  51 ++-
 .../stable_diffusion_3_image_to_image.py      |   2 +-
 .../stable_diffusion_3_inpaint.py             | 200 +++++++++
 .../stable_diffusion_3_inpaint_test.py        | 162 +++++++
 7 files changed, 794 insertions(+), 32 deletions(-)
 create mode 100644 keras_hub/src/models/inpaint.py
 create mode 100644 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
 create mode 100644 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py

diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 9983e1a8ea..5a6359da2a 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -181,6 +181,7 @@
     ImageSegmenterPreprocessor,
 )
 from keras_hub.src.models.image_to_image import ImageToImage
+from keras_hub.src.models.inpaint import Inpaint
 from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
 from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM
 from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
@@ -274,6 +275,9 @@
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import (
     StableDiffusion3ImageToImage,
 )
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import (
+    StableDiffusion3Inpaint,
+)
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import (
     StableDiffusion3TextToImage,
 )
diff --git a/keras_hub/src/models/image_to_image.py b/keras_hub/src/models/image_to_image.py
index 99dda6993f..8f92b66031 100644
--- a/keras_hub/src/models/image_to_image.py
+++ b/keras_hub/src/models/image_to_image.py
@@ -268,10 +268,11 @@ def generate(
     ):
         """Generate image based on the provided `images` and `inputs`.
 
-        The `images` are reference images that will be resized to
-        `self.backbone.height` and `self.backbone.width`, then encoded into
-        latent space by the VAE encoder. The `inputs` are strings that will be
-        tokenized and encoded by the text encoder.
+        The `images` are reference images within a value range of `[-1.0, 1.0]`,
+        which will be resized to `self.backbone.height` and
+        `self.backbone.width`, then encoded into latent space by the VAE
+        encoder. The `inputs` are strings that will be tokenized and encoded by
+        the text encoder.
 
         If `images` and `inputs` are a `tf.data.Dataset`, outputs will be
         generated "batch-by-batch" and concatenated. Otherwise, all inputs will
diff --git a/keras_hub/src/models/inpaint.py b/keras_hub/src/models/inpaint.py
new file mode 100644
index 0000000000..013dba162b
--- /dev/null
+++ b/keras_hub/src/models/inpaint.py
@@ -0,0 +1,398 @@
+import itertools
+from functools import partial
+
+import keras
+from keras import ops
+from keras import random
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.task import Task
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+try:
+    import tensorflow as tf
+except ImportError:
+    tf = None
+
+
+@keras_hub_export("keras_hub.models.Inpaint")
+class Inpaint(Task):
+    """Base class for inpainting tasks.
+
+    `Inpaint` tasks wrap a `keras_hub.models.Backbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    generation and generative fine-tuning.
+
+    `Inpaint` tasks provide an additional, high-level `generate()` function
+    which can be used to generate images with an (image, mask, string) in,
+    image out signature.
+
+    All `Inpaint` tasks include a `from_preset()` constructor which can be
+    used to load a pre-trained config and weights.
+
+    Example:
+
+    ```python
+    # Load a Stable Diffusion 3 backbone with pre-trained weights.
+    reference_image = np.ones((1024, 1024, 3), dtype="float32")
+    reference_mask = np.ones((1024, 1024), dtype="float32")
+    inpaint = keras_hub.models.Inpaint.from_preset(
+        "stable_diffusion_3_medium",
+    )
+    inpaint.generate(
+        reference_image,
+        reference_mask,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+    )
+
+    # Load a Stable Diffusion 3 backbone at bfloat16 precision.
+    inpaint = keras_hub.models.Inpaint.from_preset(
+        "stable_diffusion_3_medium",
+        dtype="bfloat16",
+    )
+    inpaint.generate(
+        reference_image,
+        reference_mask,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+    )
+    ```
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Default compilation.
+        self.compile()
+
+    @property
+    def image_shape(self):
+        return tuple(self.backbone.image_shape)
+
+    @property
+    def latent_shape(self):
+        return tuple(self.backbone.latent_shape)
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        *,
+        metrics="auto",
+        **kwargs,
+    ):
+        """Configures the `Inpaint` task for training.
+
+        The `Inpaint` task extends the default compilation signature of
+        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        `metrics`. To override these defaults, pass any value
+        to these arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default optimizer
+                for the given model and task. See `keras.Model.compile` and
+                `keras.optimizers` for more info on possible `optimizer` values.
+            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a
+                `keras.losses.MeanSquaredError` loss will be applied. See
+                `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            metrics: `"auto"`, or a list of metrics to be evaluated by
+                the model during training and testing. Defaults to `"auto"`,
+                where a `keras.metrics.MeanSquaredError` will be applied to
+                track the loss of the model during training. See
+                `keras.Model.compile` and `keras.metrics` for more info on
+                possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        # Ref: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414
+        if optimizer == "auto":
+            optimizer = keras.optimizers.AdamW(
+                1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0
+            )
+        if loss == "auto":
+            loss = keras.losses.MeanSquaredError()
+        if metrics == "auto":
+            metrics = [keras.metrics.MeanSquaredError()]
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            **kwargs,
+        )
+        self.generate_function = None
+
+    def generate_step(self, *args, **kwargs):
+        """Run generation on batches of input."""
+        raise NotImplementedError
+
+    def make_generate_function(self):
+        """Create or return the compiled generation function."""
+        if self.generate_function is not None:
+            return self.generate_function
+
+        self.generate_function = self.generate_step
+        if keras.config.backend() == "torch":
+            import torch
+
+            def wrapped_function(*args, **kwargs):
+                with torch.no_grad():
+                    return self.generate_step(*args, **kwargs)
+
+            self.generate_function = wrapped_function
+        elif keras.config.backend() == "tensorflow" and not self.run_eagerly:
+            self.generate_function = tf.function(
+                self.generate_step, jit_compile=self.jit_compile
+            )
+        elif keras.config.backend() == "jax" and not self.run_eagerly:
+            import jax
+
+            @partial(jax.jit)
+            def compiled_function(state, *args, **kwargs):
+                (
+                    trainable_variables,
+                    non_trainable_variables,
+                ) = state
+                mapping = itertools.chain(
+                    zip(self.trainable_variables, trainable_variables),
+                    zip(self.non_trainable_variables, non_trainable_variables),
+                )
+
+                with keras.StatelessScope(state_mapping=mapping):
+                    outputs = self.generate_step(*args, **kwargs)
+                return outputs
+
+            def wrapped_function(*args, **kwargs):
+                # Create an explicit tuple of all variable state.
+                state = (
+                    # Use the explicit variable.value to preserve the
+                    # sharding spec of distribution.
+                    [v.value for v in self.trainable_variables],
+                    [v.value for v in self.non_trainable_variables],
+                )
+                outputs = compiled_function(state, *args, **kwargs)
+                return outputs
+
+            self.generate_function = wrapped_function
+        return self.generate_function
+
+    def _normalize_generate_images(self, inputs):
+        """Normalize user images to the generate function.
+
+        This function converts all inputs to tensors, adds a batch dimension if
+        necessary, and returns an iterable "dataset like" object (either an
+        actual `tf.data.Dataset` or a list with a single batch element).
+        """
+        if tf and isinstance(inputs, tf.data.Dataset):
+            return inputs.as_numpy_iterator(), False
+
+        def normalize(x):
+            data_format = getattr(
+                self.backbone, "data_format", standardize_data_format(None)
+            )
+            input_is_scalar = False
+            x = ops.convert_to_tensor(x)
+            if len(ops.shape(x)) < 4:
+                x = ops.expand_dims(x, axis=0)
+                input_is_scalar = True
+            x = ops.image.resize(
+                x,
+                (self.backbone.height, self.backbone.width),
+                interpolation="nearest",
+                data_format=data_format,
+            )
+            return x, input_is_scalar
+
+        if isinstance(inputs, dict):
+            for key in inputs:
+                inputs[key], input_is_scalar = normalize(inputs[key])
+        else:
+            inputs, input_is_scalar = normalize(inputs)
+
+        return inputs, input_is_scalar
+
+    def _normalize_generate_masks(self, inputs):
+        """Normalize user masks to the generate function.
+
+        This function converts all inputs to tensors, adds a batch dimension if
+        necessary, and returns an iterable "dataset like" object (either an
+        actual `tf.data.Dataset` or a list with a single batch element).
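+
+        For example, assuming a backbone resolution of 64x64, a boolean
+        `(512, 512)` mask becomes a `"float32"` tensor of shape `(1, 64, 64)`
+        after batching, casting, and nearest-neighbor resizing.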
+        """
+        if tf and isinstance(inputs, tf.data.Dataset):
+            return inputs.as_numpy_iterator(), False
+
+        def normalize(x):
+            data_format = getattr(
+                self.backbone, "data_format", standardize_data_format(None)
+            )
+            input_is_scalar = False
+            x = ops.convert_to_tensor(x)
+            if len(ops.shape(x)) < 3:
+                x = ops.expand_dims(x, axis=0)
+                input_is_scalar = True
+            x = ops.expand_dims(x, axis=-1)
+            if keras.backend.standardize_dtype(x.dtype) == "bool":
+                x = ops.cast(x, "float32")
+            x = ops.image.resize(
+                x,
+                (self.backbone.height, self.backbone.width),
+                interpolation="nearest",
+                data_format=data_format,
+            )
+            x = ops.squeeze(x, axis=-1)
+            return x, input_is_scalar
+
+        if isinstance(inputs, dict):
+            for key in inputs:
+                inputs[key], input_is_scalar = normalize(inputs[key])
+        else:
+            inputs, input_is_scalar = normalize(inputs)
+
+        return inputs, input_is_scalar
+
+    def _normalize_generate_inputs(self, inputs):
+        """Normalize user input to the generate function.
+
+        This function converts all inputs to tensors, adds a batch dimension if
+        necessary, and returns a iterable "dataset like" object (either an
+        actual `tf.data.Dataset` or a list with a single batch element).
+        """
+        if tf and isinstance(inputs, tf.data.Dataset):
+            return inputs.as_numpy_iterator(), False
+
+        def normalize(x):
+            if isinstance(x, str):
+                return [x], True
+            if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0:
+                return x[tf.newaxis], True
+            return x, False
+
+        if isinstance(inputs, dict):
+            for key in inputs:
+                inputs[key], input_is_scalar = normalize(inputs[key])
+        else:
+            inputs, input_is_scalar = normalize(inputs)
+
+        return inputs, input_is_scalar
+
+    def _normalize_generate_outputs(self, outputs, input_is_scalar):
+        """Normalize user output from the generate function.
+
+        This function converts all output to numpy with a value range of
+        `[0, 255]`. If a batch dimension was added to the input, it is removed
+        from the output.
+        """
+
+        def normalize(x):
+            outputs = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0)
+            outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8")
+            outputs = ops.convert_to_numpy(outputs)
+            if input_is_scalar:
+                outputs = outputs[0]
+            return outputs
+
+        if isinstance(outputs[0], dict):
+            normalized = {}
+            for key in outputs[0]:
+                normalized[key] = normalize([x[key] for x in outputs])
+            return normalized
+        return normalize([x for x in outputs])
+
+    def generate(
+        self,
+        images,
+        masks,
+        inputs,
+        negative_inputs,
+        num_steps,
+        guidance_scale,
+        strength,
+        seed=None,
+    ):
+        """Generate image based on the provided `images`, `masks` and `inputs`.
+
+        The `images` are reference images within a value range of `[-1.0, 1.0]`,
+        which will be resized to `self.backbone.height` and
+        `self.backbone.width`, then encoded into latent space by the VAE
+        encoder. The `masks` are mask images with a boolean dtype, where white
+        pixels are repainted while black pixels are preserved. The `inputs` are
+        strings that will be tokenized and encoded by the text encoder.
+
+        If `images`, `masks` and `inputs` are a `tf.data.Dataset`, outputs will
+        be generated "batch-by-batch" and concatenated. Otherwise, all inputs
+        will be processed as batches.
+
+        Args:
+            images: python data, tensor data, or a `tf.data.Dataset`.
+            masks: python data, tensor data, or a `tf.data.Dataset`.
+            inputs: python data, tensor data, or a `tf.data.Dataset`.
+            negative_inputs: python data, tensor data, or a `tf.data.Dataset`.
+                Unlike `inputs`, these are used as negative inputs to guide the
+                generation. If not provided, it defaults to `""` for each input
+                in `inputs`.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier free guidance scale defined in
+                [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). A higher scale encourages
+                generating images more closely related to the prompts, typically
+                at the cost of lower image quality.
+            strength: float. Indicates the extent to which the reference
+                `images` are transformed. Must be between `0.0` and `1.0`. When
+                `strength=1.0`, `images` are essentially ignored, the added
+                noise is maximal, and the denoising process runs for the full
+                number of iterations specified in `num_steps`.
+            seed: optional int. Used as a random seed.
+        """
+        num_steps = int(num_steps)
+        guidance_scale = float(guidance_scale)
+        strength = float(strength)
+        if strength < 0.0 or strength > 1.0:
+            raise ValueError(
+                "`strength` must be between `0.0` and `1.0`. "
+                f"Received strength={strength}."
+            )
+
+        # Setup our three main passes.
+        # 1. Preprocessing strings to dense integer tensors.
+        # 2. Generate outputs via a compiled function on dense tensors.
+        # 3. Postprocess dense tensors to a value range of `[0, 255]`.
+        generate_function = self.make_generate_function()
+
+        def preprocess(x):
+            return self.preprocessor.generate_preprocess(x)
+
+        # Normalize and preprocess inputs.
+        images, image_is_scalar = self._normalize_generate_images(images)
+        masks, _ = self._normalize_generate_masks(masks)
+        inputs, _ = self._normalize_generate_inputs(inputs)
+        if negative_inputs is None:
+            negative_inputs = [""] * len(inputs)
+        negative_inputs, _ = self._normalize_generate_inputs(negative_inputs)
+
+        if self.preprocessor is not None:
+            inputs = preprocess(inputs)
+            negative_inputs = preprocess(negative_inputs)
+        if isinstance(inputs, dict):
+            batch_size = len(inputs[list(inputs.keys())[0]])
+        else:
+            batch_size = len(inputs)
+
+        # Get the starting step for denoising.
+        starting_step = int(num_steps * (1.0 - strength))
+
+        # Initialize random noises.
+        noise_shape = (batch_size,) + self.latent_shape[1:]
+        noises = random.normal(noise_shape, dtype="float32", seed=seed)
+
+        # Inpaint.
+        outputs = generate_function(
+            ops.convert_to_tensor(images),
+            ops.convert_to_tensor(masks),
+            noises,
+            inputs,
+            negative_inputs,
+            ops.convert_to_tensor(starting_step, "int32"),
+            ops.convert_to_tensor(num_steps, "int32"),
+            ops.convert_to_tensor(guidance_scale),
+        )
+        return self._normalize_generate_outputs(outputs, image_is_scalar)
diff --git a/keras_hub/src/models/stable_diffusion_3/mmdit.py b/keras_hub/src/models/stable_diffusion_3/mmdit.py
index 722bfdf273..546d56f13a 100644
--- a/keras_hub/src/models/stable_diffusion_3/mmdit.py
+++ b/keras_hub/src/models/stable_diffusion_3/mmdit.py
@@ -289,7 +289,17 @@ def __init__(
         self.embedding_dim = int(embedding_dim)
         self.frequency_dim = int(frequency_dim)
         self.max_period = float(max_period)
-        self.half_frequency_dim = self.frequency_dim // 2
+        # Precomputed `freq`.
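+        # freq[i] = exp(-log(max_period) * i / half_frequency_dim), the
+        # standard sinusoidal-embedding frequencies, computed once at
+        # construction time instead of on every forward pass.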
+        half_frequency_dim = frequency_dim // 2
+        self.freq = ops.exp(
+            ops.divide(
+                ops.multiply(
+                    -math.log(max_period),
+                    ops.arange(0, half_frequency_dim, dtype="float32"),
+                ),
+                half_frequency_dim,
+            )
+        )
 
         self.mlp = MLP(
             embedding_dim,
@@ -307,16 +317,7 @@ def build(self, inputs_shape):
     def _create_timestep_embedding(self, inputs):
         compute_dtype = keras.backend.result_type(self.compute_dtype, "float32")
         x = ops.cast(inputs, compute_dtype)
-        freqs = ops.exp(
-            ops.divide(
-                ops.multiply(
-                    -math.log(self.max_period),
-                    ops.arange(0, self.half_frequency_dim, dtype="float32"),
-                ),
-                self.half_frequency_dim,
-            )
-        )
-        freqs = ops.cast(freqs, compute_dtype)
+        freqs = ops.cast(self.freq, compute_dtype)
         x = ops.multiply(x, ops.expand_dims(freqs, axis=0))
         embedding = ops.concatenate([ops.cos(x), ops.sin(x)], axis=-1)
         if self.frequency_dim % 2 != 0:
@@ -537,8 +538,6 @@ def __init__(
         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim)
-        self._dot_product_equation = "aecd,abcd->acbe"
-        self._combine_equation = "acbe,aecd->abcd"
 
         self.x_block = DismantledBlock(
             num_heads=num_heads,
@@ -563,20 +562,18 @@ def build(self, inputs_shape, context_shape, timestep_embedding_shape):
         self.context_block.build(context_shape, timestep_embedding_shape)
 
     def _compute_attention(self, query, key, value):
-        query = ops.multiply(
-            query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
-        )
-        attention_scores = ops.einsum(self._dot_product_equation, key, query)
-        attention_scores = self.softmax(attention_scores)
-        attention_scores = ops.cast(attention_scores, self.compute_dtype)
-        attention_output = ops.einsum(
-            self._combine_equation, attention_scores, value
-        )
-        batch_size = ops.shape(attention_output)[0]
-        attention_output = ops.reshape(
-            attention_output, (batch_size, -1, self.num_heads * self.head_dim)
-        )
-        return attention_output
+        # Ref: jax.nn.dot_product_attention
+        # https://github.com/jax-ml/jax/blob/db89c245ac66911c98f265a05956fdfa4bc79d83/jax/_src/nn/functions.py#L846
+        batch_size = ops.shape(query)[0]
+        logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
+        logits = ops.multiply(logits, self._inverse_sqrt_key_dim)
+        probs = self.softmax(logits)
+        probs = ops.cast(probs, self.compute_dtype)
+        encoded = ops.einsum("BNTS,BSNH->BTNH", probs, value)
+        encoded = ops.reshape(
+            encoded, (batch_size, -1, self.num_heads * self.head_dim)
+        )
+        return encoded
 
     def call(self, inputs, context, timestep_embedding, training=None):
         # Compute pre-attention.
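The refactored `_compute_attention` above follows the einsum convention of
`jax.nn.dot_product_attention`, where `B`, `T`/`S`, `N`, and `H` index batch,
query/key positions, attention heads, and head dimension. A minimal standalone
sketch of the same computation with `keras.ops`, using toy shapes and
illustrative variable names that are not part of the patch:

```python
import math

import numpy as np
from keras import ops

B, T, S, N, H = 2, 4, 4, 2, 8  # batch, query len, key len, heads, head dim
query = np.random.rand(B, T, N, H).astype("float32")
key = np.random.rand(B, S, N, H).astype("float32")
value = np.random.rand(B, S, N, H).astype("float32")

# logits[b, n, t, s] = sum_h query[b, t, n, h] * key[b, s, n, h]
logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
logits = ops.multiply(logits, 1.0 / math.sqrt(H))
probs = ops.softmax(logits, axis=-1)
# encoded[b, t, n, h] = sum_s probs[b, n, t, s] * value[b, s, n, h]
encoded = ops.einsum("BNTS,BSNH->BTNH", probs, value)
# Merge the heads back into a single hidden dimension.
encoded = ops.reshape(encoded, (B, -1, N * H))
print(ops.shape(encoded))  # (2, 4, 16)
```

One difference worth noting relative to the removed code: the old
implementation scaled `query` before the dot product, while the new one scales
the logits; the two are mathematically equivalent.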
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py
index 3d551eb5ef..9a8372be52 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py
@@ -147,7 +147,7 @@ def generate(
         images,
         inputs,
         negative_inputs=None,
-        num_steps=28,
+        num_steps=50,
         guidance_scale=7.0,
         strength=0.8,
         seed=None,
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
new file mode 100644
index 0000000000..2b9e21636f
--- /dev/null
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
@@ -0,0 +1,200 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.inpaint import Inpaint
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+    StableDiffusion3TextToImagePreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.StableDiffusion3Inpaint")
+class StableDiffusion3Inpaint(Inpaint):
+    """An end-to-end Stable Diffusion 3 model for inpaint generation.
+
+    This model has a `generate()` method, which generates images based
+    on a combination of a reference image, a mask, and a text prompt.
+
+    Args:
+        backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
+        preprocessor: A
+            `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
+
+    Examples:
+
+    Use `generate()` to do image generation.
+    ```python
+    reference_image = np.ones((1024, 1024, 3), dtype="float32")
+    reference_mask = np.ones((1024, 1024), dtype="float32")
+    inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
+        "stable_diffusion_3_medium", height=512, width=512
+    )
+    inpaint.generate(
+        reference_image,
+        reference_mask,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+    )
+
+    # Generate with batched prompts.
+    reference_images = np.ones((2, 512, 512, 3), dtype="float32")
+    reference_mask = np.ones((2, 512, 512), dtype="float32")
+    inpaint.generate(
+        reference_images,
+        reference_mask,
+        ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
+    )
+
+    # Generate with different `num_steps`, `guidance_scale` and `strength`.
+    inpaint.generate(
+        reference_image,
+        reference_mask,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        num_steps=50,
+        guidance_scale=5.0,
+        strength=0.6,
+    )
+    ```
+    """
+
+    backbone_cls = StableDiffusion3Backbone
+    preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        outputs = backbone.output
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Currently, `fit` is not supported for "
+            "`StableDiffusion3Inpaint`."
+        )
+
+    def generate_step(
+        self,
+        images,
+        masks,
+        noises,
+        token_ids,
+        negative_token_ids,
+        starting_step,
+        num_steps,
+        guidance_scale,
+    ):
+        """A compilable generation function for batches of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for batched inputs.
+
+        Args:
+            images: A (batch_size, image_height, image_width, 3) tensor
+                containing the reference images.
+            masks: A (batch_size, image_height, image_width, 1 or 3) tensor
+                containing the reference masks.
+            noises: A (batch_size, latent_height, latent_width, channels) tensor
+                containing the noises to be added to the latents. Typically,
+                this tensor is sampled from the Gaussian distribution.
+            token_ids: A (batch_size, num_tokens) tensor containing the
+                tokens based on the input prompts.
+            negative_token_ids: A (batch_size, num_tokens) tensor
+                containing the negative tokens based on the input prompts.
+            starting_step: int. The number of the starting diffusion step.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier free guidance scale defined in
+                [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). Higher scale encourages to
+                generate images that are closely linked to prompts, usually at
+                the expense of lower image quality.
+        """
+        # Get masked images.
+        masks = ops.cast(ops.expand_dims(masks, axis=-1) > 0.5, images.dtype)
+        masks_latent_size = ops.image.resize(
+            masks,
+            (self.backbone.latent_shape[1], self.backbone.latent_shape[2]),
+            interpolation="nearest",
+        )
+
+        # Encode images.
+        image_latents = self.backbone.encode_image_step(images)
+
+        # Add noises to latents.
+        latents = self.backbone.add_noise_step(
+            image_latents, noises, starting_step, num_steps
+        )
+
+        # Encode inputs.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )
+
+        # Denoise.
+        def body_fun(step, latents):
+            latents = self.backbone.denoise_step(
+                latents,
+                embeddings,
+                step,
+                num_steps,
+                guidance_scale,
+            )
+
+            def true_fn():
+                next_step = ops.add(step, 1)
+                return self.backbone.add_noise_step(
+                    image_latents, noises, next_step, num_steps
+                )
+
+            init_latents = ops.cond(
+                step < ops.subtract(num_steps, 1),
+                true_fn,
+                lambda: ops.cast(image_latents, noises.dtype),
+            )
+            latents = ops.add(
+                ops.multiply(
+                    ops.subtract(1.0, masks_latent_size), init_latents
+                ),
+                ops.multiply(masks_latent_size, latents),
+            )
+            return latents
+
+        latents = ops.fori_loop(starting_step, num_steps, body_fun, latents)
+
+        # Decode.
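+        # The denoised latents are mapped back from latent space to pixel
+        # space by the VAE decoder.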
+        return self.backbone.decode_step(latents)
+
+    def generate(
+        self,
+        images,
+        masks,
+        inputs,
+        negative_inputs=None,
+        num_steps=50,
+        guidance_scale=7.0,
+        strength=0.6,
+        seed=None,
+    ):
+        return super().generate(
+            images,
+            masks,
+            inputs,
+            negative_inputs=negative_inputs,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            seed=seed,
+        )
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py
new file mode 100644
index 0000000000..2d5e37c7c5
--- /dev/null
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py
@@ -0,0 +1,162 @@
+import keras
+import pytest
+from keras import ops
+
+from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor
+from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder
+from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import (
+    StableDiffusion3Inpaint,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+    StableDiffusion3TextToImagePreprocessor,
+)
+from keras_hub.src.models.vae.vae_backbone import VAEBackbone
+from keras_hub.src.tests.test_case import TestCase
+
+
+class StableDiffusion3InpaintTest(TestCase):
+    def setUp(self):
+        # Instantiate the preprocessor.
+        vocab = ["air", "plane", "port"]
+        vocab += ["<|endoftext|>", "<|startoftext|>"]
+        vocab = dict([(token, i) for i, token in enumerate(vocab)])
+        merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"]
+        merges += ["po rt", "pla ne"]
+        clip_l_tokenizer = CLIPTokenizer(vocab, merges, pad_with_end_token=True)
+        clip_g_tokenizer = CLIPTokenizer(vocab, merges)
+        clip_l_preprocessor = CLIPPreprocessor(clip_l_tokenizer)
+        clip_g_preprocessor = CLIPPreprocessor(clip_g_tokenizer)
+        self.preprocessor = StableDiffusion3TextToImagePreprocessor(
+            clip_l_preprocessor, clip_g_preprocessor
+        )
+
+        self.backbone = StableDiffusion3Backbone(
+            mmdit_patch_size=2,
+            mmdit_hidden_dim=16 * 2,
+            mmdit_num_layers=2,
+            mmdit_num_heads=2,
+            mmdit_position_size=192,
+            vae=VAEBackbone(
+                [32, 32, 32, 32],
+                [1, 1, 1, 1],
+                [32, 32, 32, 32],
+                [1, 1, 1, 1],
+                # Use `mode` to generate a deterministic output.
+                sampler_method="mode",
+                name="vae",
+            ),
+            clip_l=CLIPTextEncoder(
+                20, 64, 64, 2, 2, 128, "quick_gelu", -2, name="clip_l"
+            ),
+            clip_g=CLIPTextEncoder(
+                20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
+            ),
+            height=64,
+            width=64,
+        )
+        self.init_kwargs = {
+            "preprocessor": self.preprocessor,
+            "backbone": self.backbone,
+        }
+        self.input_data = {
+            "images": ops.ones((2, 64, 64, 3)),
+            "latents": ops.ones((2, 8, 8, 16)),
+            "clip_l_token_ids": ops.ones((2, 5), dtype="int32"),
+            "clip_l_negative_token_ids": ops.ones((2, 5), dtype="int32"),
+            "clip_g_token_ids": ops.ones((2, 5), dtype="int32"),
+            "clip_g_negative_token_ids": ops.ones((2, 5), dtype="int32"),
+            "num_steps": ops.ones((2,), dtype="int32"),
+            "guidance_scale": ops.ones((2,)),
+        }
+
+    def test_inpaint_basics(self):
+        pytest.skip(
+            reason="TODO: enable after preprocessor flow is figured out"
+        )
+        self.run_task_test(
+            cls=StableDiffusion3Inpaint,
+            init_kwargs=self.init_kwargs,
+            train_data=None,
+            expected_output_shape={
+                "images": (2, 64, 64, 3),
+                "latents": (2, 8, 8, 16),
+            },
+        )
+
+    def test_generate(self):
+        inpaint = StableDiffusion3Inpaint(**self.init_kwargs)
+        seed = 42
+        image = self.input_data["images"][0]
+        mask = self.input_data["images"][0][..., 0]  # (B, H, W)
+        # String input.
+        prompt = ["airplane"]
+        negative_prompt = [""]
+        output = inpaint.generate(
+            image, mask, prompt, negative_prompt, seed=seed
+        )
+        # Int tensor input.
+        prompt_ids = self.preprocessor.generate_preprocess(prompt)
+        negative_prompt_ids = self.preprocessor.generate_preprocess(
+            negative_prompt
+        )
+        inpaint.preprocessor = None
+        output2 = inpaint.generate(
+            image, mask, prompt_ids, negative_prompt_ids, seed=seed
+        )
+        self.assertAllClose(output, output2)
+
+    def test_generate_with_lower_precision(self):
+        original_floatx = keras.config.floatx()
+        try:
+            for dtype in ["float16", "bfloat16"]:
+                keras.config.set_floatx(dtype)
+                inpaint = StableDiffusion3Inpaint(**self.init_kwargs)
+                seed = 42
+                image = self.input_data["images"][0]
+                mask = self.input_data["images"][0][..., 0]  # (B, H, W)
+                # String input.
+                prompt = ["airplane"]
+                negative_prompt = [""]
+                output = inpaint.generate(
+                    image, mask, prompt, negative_prompt, seed=seed
+                )
+                # Int tensor input.
+                prompt_ids = self.preprocessor.generate_preprocess(prompt)
+                negative_prompt_ids = self.preprocessor.generate_preprocess(
+                    negative_prompt
+                )
+                inpaint.preprocessor = None
+                output2 = inpaint.generate(
+                    image, mask, prompt_ids, negative_prompt_ids, seed=seed
+                )
+                self.assertAllClose(output, output2)
+        finally:
+            # Restore floatx to the original value to prevent impact on other
+            # tests even if there is an exception.
+            keras.config.set_floatx(original_floatx)
+
+    def test_generate_compilation(self):
+        inpaint = StableDiffusion3Inpaint(**self.init_kwargs)
+        image = self.input_data["images"][0]
+        mask = self.input_data["images"][0][..., 0]  # (B, H, W)
+        # Assert we do not recompile with successive calls.
+        inpaint.generate(image, mask, "airplane")
+        first_fn = inpaint.generate_function
+        inpaint.generate(image, mask, "airplane")
+        second_fn = inpaint.generate_function
+        self.assertEqual(first_fn, second_fn)
+        # Assert we do recompile after compile is called.
+        inpaint.compile()
+        self.assertIsNone(inpaint.generate_function)
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=StableDiffusion3Inpaint,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )

From c7749fb3d7a3dfe45a4ad3378f88f889abec224a Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Mon, 7 Oct 2024 11:25:52 +0800
Subject: [PATCH 5/8] Fix warnings of MMDiT.

---
 .../stable_diffusion_3_backbone.py | 30 ++++++++++++-------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
index c5930a3460..0e8287a59e 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
@@ -293,6 +293,12 @@ def __init__(
             name="diffuser",
         )
         self.vae = vae
+        self.cfg_concat = ClassifierFreeGuidanceConcatenate(
+            dtype=dtype, name="classifier_free_guidance_concat"
+        )
+        self.cfg = ClassifierFreeGuidance(
+            dtype=dtype, name="classifier_free_guidance"
+        )
         # Set `dtype="float32"` to ensure the high precision for the noise
         # residual.
         self.scheduler = FlowMatchEulerDiscreteScheduler(
@@ -301,17 +307,11 @@ def __init__(
             dtype="float32",
             name="scheduler",
         )
-        self.cfg_concat = ClassifierFreeGuidanceConcatenate(
-            dtype="float32", name="classifier_free_guidance_concat"
-        )
-        self.cfg = ClassifierFreeGuidance(
-            dtype="float32", name="classifier_free_guidance"
-        )
         self.euler_step = EulerStep(dtype="float32", name="euler_step")
         self.latent_rescaling = layers.Rescaling(
             scale=1.0 / self.vae.scale,
             offset=self.vae.shift,
-            dtype="float32",
+            dtype=dtype,
             name="latent_rescaling",
         )
 
@@ -440,8 +440,12 @@ def encode_text_step(self, token_ids, negative_token_ids):
         t5_hidden_dim = self.t5_hidden_dim
 
         def encode(token_ids):
-            clip_l_outputs = self.clip_l(token_ids["clip_l"], training=False)
-            clip_g_outputs = self.clip_g(token_ids["clip_g"], training=False)
+            clip_l_outputs = self.clip_l(
+                {"token_ids": token_ids["clip_l"]}, training=False
+            )
+            clip_g_outputs = self.clip_g(
+                {"token_ids": token_ids["clip_g"]}, training=False
+            )
             clip_l_projection = self.clip_l_projection(
                 clip_l_outputs["sequence_output"],
                 token_ids["clip_l"],
@@ -468,7 +472,13 @@ def encode(token_ids):
                 [[0, 0], [0, 0], [0, t5_hidden_dim - clip_hidden_dim]],
             )
             if self.t5 is not None:
-                t5_outputs = self.t5(token_ids["t5"], training=False)
+                t5_outputs = self.t5(
+                    {
+                        "token_ids": token_ids["t5"],
+                        "padding_mask": ops.ones_like(token_ids["t5"]),
+                    },
+                    training=False,
+                )
                 embeddings = ops.concatenate([embeddings, t5_outputs], axis=-2)
             else:
                 padded_size = self.clip_l.max_sequence_length

From 37c519b3a98e03b28454d51522c6315ab194b5e4 Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Mon, 7 Oct 2024 11:32:34 +0800
Subject: [PATCH 6/8] Add comment to Inpaint

---
 .../src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
index 2b9e21636f..1202831b21 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
@@ -153,6 +153,7 @@ def body_fun(step, latents):
                 guidance_scale,
             )
 
+            # Compute the previous latents x_t -> x_t-1.
             def true_fn():
                 next_step = ops.add(step, 1)
                 return self.backbone.add_noise_step(

From 5ff2fa19f3ae3b4086ae81285724f889f3eb06de Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Mon, 7 Oct 2024 23:17:34 +0800
Subject: [PATCH 7/8] Simplify `MMDiT` implementation and improve `summary()`
 info.

---
 .../stable_diffusion_3_backbone.py | 112 +++++++++++-------
 1 file changed, 67 insertions(+), 45 deletions(-)

diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
index 0e8287a59e..485340fbd0 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
@@ -51,11 +51,52 @@ def compute_output_shape(self, inputs_shape):
         return (inputs_shape[0], self.hidden_dim)
 
 
-class ClassifierFreeGuidanceConcatenate(layers.Layer):
-    def __init__(self, axis=0, **kwargs):
-        super().__init__(**kwargs)
-        self.axis = axis
+class CLIPConcatenate(layers.Layer):
+    def call(
+        self,
+        clip_l_projection,
+        clip_g_projection,
+        clip_l_intermediate_output,
+        clip_g_intermediate_output,
+        padding,
+    ):
+        pooled_embeddings = ops.concatenate(
+            [clip_l_projection, clip_g_projection], axis=-1
+        )
+        embeddings = ops.concatenate(
+            [clip_l_intermediate_output, clip_g_intermediate_output], axis=-1
+        )
+        embeddings = ops.pad(embeddings, [[0, 0], [0, 0], [0, padding]])
+        return pooled_embeddings, embeddings
+
+
+class ImageRescaling(layers.Rescaling):
+    """Rescales inputs from image space to latent space.
+
+    The rescaling is performed using the formula: `(inputs - offset) * scale`.
+    """
+
+    def call(self, inputs):
+        dtype = self.compute_dtype
+        scale = self.backend.cast(self.scale, dtype)
+        offset = self.backend.cast(self.offset, dtype)
+        return (self.backend.cast(inputs, dtype) - offset) * scale
+
+
+class LatentRescaling(layers.Rescaling):
+    """Rescales inputs from latent space to image space.
+
+    The rescaling is performed using the formula: `inputs / scale + offset`.
+    """
+
+    def call(self, inputs):
+        dtype = self.compute_dtype
+        scale = self.backend.cast(self.scale, dtype)
+        offset = self.backend.cast(self.offset, dtype)
+        return (self.backend.cast(inputs, dtype) / scale) + offset
+
+
+class ClassifierFreeGuidanceConcatenate(layers.Layer):
     def call(
         self,
         latents,
@@ -66,20 +107,16 @@ def call(
         timestep,
     ):
         timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
-        latents = ops.concatenate([latents, latents], axis=self.axis)
+        latents = ops.concatenate([latents, latents], axis=0)
         contexts = ops.concatenate(
-            [positive_contexts, negative_contexts], axis=self.axis
+            [positive_contexts, negative_contexts], axis=0
        )
         pooled_projections = ops.concatenate(
-            [positive_pooled_projections, negative_pooled_projections],
-            axis=self.axis,
+            [positive_pooled_projections, negative_pooled_projections], axis=0
         )
-        timesteps = ops.concatenate([timestep, timestep], axis=self.axis)
+        timesteps = ops.concatenate([timestep, timestep], axis=0)
         return latents, contexts, pooled_projections, timesteps
 
-    def get_config(self):
-        return super().get_config()
-
 
 class ClassifierFreeGuidance(layers.Layer):
     """Perform classifier free guidance.
@@ -100,9 +137,6 @@ class ClassifierFreeGuidance(layers.Layer):
     - [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
""" - def __init__(self, **kwargs): - super().__init__(**kwargs) - def call(self, inputs, guidance_scale): positive_noise, negative_noise = ops.split(inputs, 2, axis=0) return ops.add( @@ -112,9 +146,6 @@ def call(self, inputs, guidance_scale): ), ) - def get_config(self): - return super().get_config() - def compute_output_shape(self, inputs_shape): outputs_shape = list(inputs_shape) if outputs_shape[0] is not None: @@ -142,16 +173,10 @@ class EulerStep(layers.Layer): https://arxiv.org/abs/2206.00364). """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - def call(self, latents, noise_residual, sigma, sigma_next): sigma_diff = ops.subtract(sigma_next, sigma) return ops.add(latents, ops.multiply(sigma_diff, noise_residual)) - def get_config(self): - return super().get_config() - def compute_output_shape(self, latents_shape): return latents_shape @@ -272,12 +297,13 @@ def __init__( self.clip_l_projection = CLIPProjection( clip_l.hidden_dim, dtype=dtype, name="clip_l_projection" ) - self.clip_l_projection.build([None, clip_l.hidden_dim], None) self.clip_g = clip_g self.clip_g_projection = CLIPProjection( clip_g.hidden_dim, dtype=dtype, name="clip_g_projection" ) - self.clip_g_projection.build([None, clip_g.hidden_dim], None) + self.clip_concatenate = CLIPConcatenate( + dtype=dtype, name="clip_concatenate" + ) self.t5 = t5 self.diffuser = MMDiT( mmdit_patch_size, @@ -308,8 +334,14 @@ def __init__( name="scheduler", ) self.euler_step = EulerStep(dtype="float32", name="euler_step") - self.latent_rescaling = layers.Rescaling( - scale=1.0 / self.vae.scale, + self.image_rescaling = ImageRescaling( + scale=self.vae.scale, + offset=self.vae.shift, + dtype=dtype, + name="image_rescaling", + ) + self.latent_rescaling = LatentRescaling( + scale=self.vae.scale, offset=self.vae.shift, dtype=dtype, name="latent_rescaling", @@ -456,20 +488,12 @@ def encode(token_ids): token_ids["clip_g"], training=False, ) - pooled_embeddings = ops.concatenate( - [clip_l_projection, clip_g_projection], - axis=-1, - ) - embeddings = ops.concatenate( - [ - clip_l_outputs["intermediate_output"], - clip_g_outputs["intermediate_output"], - ], - axis=-1, - ) - embeddings = ops.pad( - embeddings, - [[0, 0], [0, 0], [0, t5_hidden_dim - clip_hidden_dim]], + pooled_embeddings, embeddings = self.clip_concatenate( + clip_l_projection, + clip_g_projection, + clip_l_outputs["intermediate_output"], + clip_g_outputs["intermediate_output"], + padding=t5_hidden_dim - clip_hidden_dim, ) if self.t5 is not None: t5_outputs = self.t5( @@ -500,9 +524,7 @@ def encode(token_ids): def encode_image_step(self, images): latents = self.vae.encode(images) - return ops.multiply( - ops.subtract(latents, self.vae.shift), self.vae.scale - ) + return self.image_rescaling(latents) def add_noise_step(self, latents, noises, step, num_steps): return self.scheduler.add_noise(latents, noises, step, num_steps) From eda16fcebdc1219be49e358106a2a8601dc266f2 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 8 Oct 2024 17:18:46 +0800 Subject: [PATCH 8/8] Refactor `generate()` API of `TextToImage`, `ImageToImage` and `Inpaint`. 
--- keras_hub/src/models/image_to_image.py | 224 +++++++++++------ keras_hub/src/models/inpaint.py | 237 +++++++++++++----- .../stable_diffusion_3_image_to_image.py | 42 ++-- .../stable_diffusion_3_image_to_image_test.py | 32 ++- .../stable_diffusion_3_inpaint.py | 17 +- .../stable_diffusion_3_inpaint_test.py | 48 +++- .../stable_diffusion_3_text_to_image.py | 21 +- .../stable_diffusion_3_text_to_image_test.py | 26 +- keras_hub/src/models/text_to_image.py | 125 ++++++--- 9 files changed, 540 insertions(+), 232 deletions(-) diff --git a/keras_hub/src/models/image_to_image.py b/keras_hub/src/models/image_to_image.py index 8f92b66031..239f49b55d 100644 --- a/keras_hub/src/models/image_to_image.py +++ b/keras_hub/src/models/image_to_image.py @@ -60,6 +60,11 @@ def __init__(self, *args, **kwargs): # Default compilation. self.compile() + @property + def support_negative_prompts(self): + """Whether the model supports `negative_prompts` key in `generate()`.""" + return bool(True) + @property def image_shape(self): return tuple(self.backbone.image_shape) @@ -173,17 +178,52 @@ def wrapped_function(*args, **kwargs): self.generate_function = wrapped_function return self.generate_function - def _normalize_generate_images(self, inputs): - """Normalize user image to the generate function. + def _normalize_generate_inputs(self, inputs): + """Normalize user input to the generate function. This function converts all inputs to tensors, adds a batch dimension if necessary, and returns a iterable "dataset like" object (either an actual `tf.data.Dataset` or a list with a single batch element). + + The input format must be one of the following: + - A dict with "images", "prompts" and/or "negative_prompts" keys + - A tf.data.Dataset with "images", "prompts" and/or "negative_prompts" + keys + + The output will be a dict with "images", "prompts" and/or + "negative_prompts" keys. """ if tf and isinstance(inputs, tf.data.Dataset): - return inputs.as_numpy_iterator(), False + _inputs = { + "images": inputs.map(lambda x: x["images"]).as_numpy_iterator(), + "prompts": inputs.map( + lambda x: x["prompts"] + ).as_numpy_iterator(), + } + if self.support_negative_prompts: + _inputs["negative_prompts"] = inputs.map( + lambda x: x["negative_prompts"] + ).as_numpy_iterator() + return _inputs, False + + if ( + not isinstance(inputs, dict) + or "images" not in inputs + or "prompts" not in inputs + ): + raise ValueError( + '`inputs` must be a dict with "images" and "prompts" keys or a' + f"tf.data.Dataset. Received: inputs={inputs}" + ) def normalize(x): + if isinstance(x, str): + return [x], True + if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0: + return x[tf.newaxis], True + return x, False + + def normalize_images(x): data_format = getattr( self.backbone, "data_format", standardize_data_format(None) ) @@ -200,38 +240,23 @@ def normalize(x): ) return x, input_is_scalar - if isinstance(inputs, dict): - for key in inputs: + def get_dummy_prompts(x): + dummy_prompts = [""] * len(x) + if tf and isinstance(x, tf.Tensor): + return tf.convert_to_tensor(dummy_prompts) + else: + return dummy_prompts + + for key in inputs: + if key == "images": + inputs[key], input_is_scalar = normalize_images(inputs[key]) + else: inputs[key], input_is_scalar = normalize(inputs[key]) - else: - inputs, input_is_scalar = normalize(inputs) - return inputs, input_is_scalar - - def _normalize_generate_inputs(self, inputs): - """Normalize user input to the generate function. 
+ if self.support_negative_prompts and "negative_prompts" not in inputs: + inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"]) - This function converts all inputs to tensors, adds a batch dimension if - necessary, and returns a iterable "dataset like" object (either an - actual `tf.data.Dataset` or a list with a single batch element). - """ - if tf and isinstance(inputs, tf.data.Dataset): - return inputs.as_numpy_iterator(), False - - def normalize(x): - if isinstance(x, str): - return [x], True - if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0: - return x[tf.newaxis], True - return x, False - - if isinstance(inputs, dict): - for key in inputs: - inputs[key], input_is_scalar = normalize(inputs[key]) - else: - inputs, input_is_scalar = normalize(inputs) - - return inputs, input_is_scalar + return [inputs], input_is_scalar def _normalize_generate_outputs(self, outputs, input_is_scalar): """Normalize user output from the generate function. @@ -242,12 +267,11 @@ def _normalize_generate_outputs(self, outputs, input_is_scalar): """ def normalize(x): - outputs = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0) + outputs = ops.concatenate(x, axis=0) + outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0) outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8") - outputs = ops.convert_to_numpy(outputs) - if input_is_scalar: - outputs = outputs[0] - return outputs + outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs + return ops.convert_to_numpy(outputs) if isinstance(outputs[0], dict): normalized = {} @@ -258,33 +282,36 @@ def normalize(x): def generate( self, - images, inputs, - negative_inputs, num_steps, guidance_scale, strength, seed=None, ): - """Generate image based on the provided `images` and `inputs`. + """Generate image based on the provided `inputs`. - The `images` are reference images within a value range of `[-1.0, 1.0]`, - which will be resized to `self.backbone.height` and + Typically, `inputs` is a dict with `"images"` and `"prompts"` keys. + `"images"` are reference images within a value range of + `[-1.0, 1.0]`, which will be resized to `self.backbone.height` and `self.backbone.width`, then encoded into latent space by the VAE - encoder. The `inputs` are strings that will be tokenized and encoded by + encoder. `"prompts"` are strings that will be tokenized and encoded by the text encoder. - If `images` and `inputs` are a `tf.data.Dataset`, outputs will be - generated "batch-by-batch" and concatenated. Otherwise, all inputs will - be processed as batches. + Some models support a `"negative_prompts"` key, which helps steer the + model away from generating certain styles and elements. To enable this, + add `"negative_prompts"` to the input dict. + + If `inputs` are a `tf.data.Dataset`, outputs will be generated + "batch-by-batch" and concatenated. Otherwise, all inputs will be + processed as batches. Args: - images: python data, tensor data, or a `tf.data.Dataset`. - inputs: python data, tensor data, or a `tf.data.Dataset`. - negative_inputs: python data, tensor data, or a `tf.data.Dataset`. - Unlike `inputs`, these are used as negative inputs to guide the - generation. If not provided, it defaults to `""` for each input - in `inputs`. + inputs: python data, tensor data, or a `tf.data.Dataset`. The format + must be one of the following: + - A dict with `"images"`, `"prompts"` and/or + `"negative_prompts"` keys. + - A `tf.data.Dataset` with `"images"`, `"prompts"` and/or + `"negative_prompts"` keys. num_steps: int. 
The number of diffusion steps to take. guidance_scale: float. The classifier free guidance scale defined in [Classifier-Free Diffusion Guidance]( @@ -306,6 +333,32 @@ def generate( "`strength` must be between `0.0` and `1.0`. " f"Received strength={strength}." ) + starting_step = int(num_steps * (1.0 - strength)) + starting_step = ops.convert_to_tensor(starting_step, "int32") + num_steps = ops.convert_to_tensor(num_steps, "int32") + guidance_scale = ops.convert_to_tensor(guidance_scale) + + # Check `inputs` format. + required_keys = ["images", "prompts"] + if tf and isinstance(inputs, tf.data.Dataset): + spec = inputs.element_spec + if not all(key in spec for key in required_keys): + raise ValueError( + "Expected a `tf.data.Dataset` with the following keys:" + f"{required_keys}. Received: inputs.element_spec={spec}" + ) + else: + if not isinstance(inputs, dict): + raise ValueError( + "Expected a `dict` or `tf.data.Dataset`. " + f"Received: inputs={inputs} of type {type(inputs)}." + ) + if not all(key in inputs for key in required_keys): + raise ValueError( + "Expected a `dict` with the following keys:" + f"{required_keys}. " + f"Received: inputs.keys={list(inputs.keys())}" + ) # Setup our three main passes. # 1. Preprocessing strings to dense integer tensors. @@ -314,38 +367,45 @@ def generate( generate_function = self.make_generate_function() def preprocess(x): - return self.preprocessor.generate_preprocess(x) + if self.preprocessor is not None: + return self.preprocessor.generate_preprocess(x) + else: + return x + + def generate(images, x): + token_ids = x[0] if self.support_negative_prompts else x + + # Initialize noises. + if isinstance(token_ids, dict): + arbitrary_key = list(token_ids.keys())[0] + batch_size = ops.shape(token_ids[arbitrary_key])[0] + else: + batch_size = ops.shape(token_ids)[0] + noise_shape = (batch_size,) + self.latent_shape[1:] + noises = random.normal(noise_shape, dtype="float32", seed=seed) + + return generate_function( + images, noises, x, starting_step, num_steps, guidance_scale + ) # Normalize and preprocess inputs. - images, image_is_scalar = self._normalize_generate_images(images) - inputs, _ = self._normalize_generate_inputs(inputs) - if negative_inputs is None: - negative_inputs = [""] * len(inputs) - negative_inputs, _ = self._normalize_generate_inputs(negative_inputs) - - if self.preprocessor is not None: - inputs = preprocess(inputs) - negative_inputs = preprocess(negative_inputs) - if isinstance(inputs, dict): - batch_size = len(inputs[list(inputs.keys())[0]]) + inputs, input_is_scalar = self._normalize_generate_inputs(inputs) + if self.support_negative_prompts: + images = [x["images"] for x in inputs] + token_ids = [preprocess(x["prompts"]) for x in inputs] + negative_token_ids = [ + preprocess(x["negative_prompts"]) for x in inputs + ] + # Tuple format: (images, (token_ids, negative_token_ids)). + inputs = [ + x for x in zip(images, zip(token_ids, negative_token_ids)) + ] else: - batch_size = len(inputs) - - # Get the starting step for denoising. - starting_step = int(num_steps * (1.0 - strength)) - - # Initialize random noises. - noise_shape = (batch_size,) + self.latent_shape[1:] - noises = random.normal(noise_shape, dtype="float32", seed=seed) + images = [x["images"] for x in inputs] + token_ids = [preprocess(x["prompts"]) for x in inputs] + # Tuple format: (images, token_ids). + inputs = [x for x in zip(images, token_ids)] # Image-to-image. 
- outputs = generate_function( - ops.convert_to_tensor(images), - noises, - inputs, - negative_inputs, - ops.convert_to_tensor(starting_step, "int32"), - ops.convert_to_tensor(num_steps, "int32"), - ops.convert_to_tensor(guidance_scale), - ) - return self._normalize_generate_outputs(outputs, image_is_scalar) + outputs = [generate(*x) for x in inputs] + return self._normalize_generate_outputs(outputs, input_is_scalar) diff --git a/keras_hub/src/models/inpaint.py b/keras_hub/src/models/inpaint.py index 013dba162b..1c475f5b83 100644 --- a/keras_hub/src/models/inpaint.py +++ b/keras_hub/src/models/inpaint.py @@ -63,6 +63,11 @@ def __init__(self, *args, **kwargs): # Default compilation. self.compile() + @property + def support_negative_prompts(self): + """Whether the model supports `negative_prompts` key in `generate()`.""" + return bool(True) + @property def image_shape(self): return tuple(self.backbone.image_shape) @@ -256,9 +261,29 @@ def _normalize_generate_inputs(self, inputs): This function converts all inputs to tensors, adds a batch dimension if necessary, and returns a iterable "dataset like" object (either an actual `tf.data.Dataset` or a list with a single batch element). + + The input format must be one of the following: + - A dict with "images", "masks", "prompts" and/or "negative_prompts" + keys + - A tf.data.Dataset with "images", "masks", "prompts" and/or + "negative_prompts" keys + + The output will be a dict with "images", "masks", "prompts" and/or + "negative_prompts" keys. """ if tf and isinstance(inputs, tf.data.Dataset): - return inputs.as_numpy_iterator(), False + _inputs = { + "images": inputs.map(lambda x: x["images"]).as_numpy_iterator(), + "masks": inputs.map(lambda x: x["masks"]).as_numpy_iterator(), + "prompts": inputs.map( + lambda x: x["prompts"] + ).as_numpy_iterator(), + } + if self.support_negative_prompts: + _inputs["negative_prompts"] = inputs.map( + lambda x: x["negative_prompts"] + ).as_numpy_iterator() + return _inputs, False def normalize(x): if isinstance(x, str): @@ -267,13 +292,63 @@ def normalize(x): return x[tf.newaxis], True return x, False - if isinstance(inputs, dict): - for key in inputs: + def normalize_images(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 4: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.image.resize( + x, + (self.backbone.height, self.backbone.width), + interpolation="nearest", + data_format=data_format, + ) + return x, input_is_scalar + + def normalize_masks(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 3: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.expand_dims(x, axis=-1) + if keras.backend.standardize_dtype(x.dtype) == "bool": + x = ops.cast(x, "float32") + x = ops.image.resize( + x, + (self.backbone.height, self.backbone.width), + interpolation="nearest", + data_format=data_format, + ) + x = ops.squeeze(x, axis=-1) + return x, input_is_scalar + + def get_dummy_prompts(x): + dummy_prompts = [""] * len(x) + if tf and isinstance(x, tf.Tensor): + return tf.convert_to_tensor(dummy_prompts) + else: + return dummy_prompts + + for key in inputs: + if key == "images": + inputs[key], input_is_scalar = normalize_images(inputs[key]) + elif key == "masks": + inputs[key], input_is_scalar = normalize_masks(inputs[key]) + else: 
inputs[key], input_is_scalar = normalize(inputs[key]) - else: - inputs, input_is_scalar = normalize(inputs) - return inputs, input_is_scalar + if self.support_negative_prompts and "negative_prompts" not in inputs: + inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"]) + + return [inputs], input_is_scalar def _normalize_generate_outputs(self, outputs, input_is_scalar): """Normalize user output from the generate function. @@ -284,12 +359,11 @@ def _normalize_generate_outputs(self, outputs, input_is_scalar): """ def normalize(x): - outputs = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0) + outputs = ops.concatenate(x, axis=0) + outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0) outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8") - outputs = ops.convert_to_numpy(outputs) - if input_is_scalar: - outputs = outputs[0] - return outputs + outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs + return ops.convert_to_numpy(outputs) if isinstance(outputs[0], dict): normalized = {} @@ -300,36 +374,37 @@ def normalize(x): def generate( self, - images, - masks, inputs, - negative_inputs, num_steps, guidance_scale, strength, seed=None, ): - """Generate image based on the provided `images`, `masks` and `inputs`. + """Generate image based on the provided `inputs`. - The `images` are reference images within a value range of `[-1.0, 1.0]`, - which will be resized to `self.backbone.height` and + Typically, `inputs` is a dict with `"images"` `"masks"` and `"prompts"` + keys. `"images"` are reference images within a value range of + `[-1.0, 1.0]`, which will be resized to `self.backbone.height` and `self.backbone.width`, then encoded into latent space by the VAE - encoder. The `masks` are mask images with a boolean dtype, where white - pixels are repainted while black pixels are preserved. The `inputs` are + encoder. `"masks"` are mask images with a boolean dtype, where white + pixels are repainted while black pixels are preserved. `"prompts"` are strings that will be tokenized and encoded by the text encoder. - If `images`, `masks` and `inputs` are a `tf.data.Dataset`, outputs will - be generated "batch-by-batch" and concatenated. Otherwise, all inputs - will be processed as batches. + Some models support a `"negative_prompts"` key, which helps steer the + model away from generating certain styles and elements. To enable this, + add `"negative_prompts"` to the input dict. + + If `inputs` are a `tf.data.Dataset`, outputs will be generated + "batch-by-batch" and concatenated. Otherwise, all inputs will be + processed as batches. Args: - images: python data, tensor data, or a `tf.data.Dataset`. - masks: python data, tensor data, or a `tf.data.Dataset`. - inputs: python data, tensor data, or a `tf.data.Dataset`. - negative_inputs: python data, tensor data, or a `tf.data.Dataset`. - Unlike `inputs`, these are used as negative inputs to guide the - generation. If not provided, it defaults to `""` for each input - in `inputs`. + inputs: python data, tensor data, or a `tf.data.Dataset`. The format + must be one of the following: + - A dict with `"images"`, `"masks"`, `"prompts"` and/or + `"negative_prompts"` keys. + - A `tf.data.Dataset` with `"images"`, `"masks"`, `"prompts"` + and/or `"negative_prompts"` keys. num_steps: int. The number of diffusion steps to take. guidance_scale: float. The classifier free guidance scale defined in [Classifier-Free Diffusion Guidance]( @@ -351,6 +426,32 @@ def generate( "`strength` must be between `0.0` and `1.0`. 
" f"Received strength={strength}." ) + starting_step = int(num_steps * (1.0 - strength)) + starting_step = ops.convert_to_tensor(starting_step, "int32") + num_steps = ops.convert_to_tensor(num_steps, "int32") + guidance_scale = ops.convert_to_tensor(guidance_scale) + + # Check `inputs` format. + required_keys = ["images", "masks", "prompts"] + if tf and isinstance(inputs, tf.data.Dataset): + spec = inputs.element_spec + if not all(key in spec for key in required_keys): + raise ValueError( + "Expected a `tf.data.Dataset` with the following keys:" + f"{required_keys}. Received: inputs.element_spec={spec}" + ) + else: + if not isinstance(inputs, dict): + raise ValueError( + "Expected a `dict` or `tf.data.Dataset`. " + f"Received: inputs={inputs} of type {type(inputs)}." + ) + if not all(key in inputs for key in required_keys): + raise ValueError( + "Expected a `dict` with the following keys:" + f"{required_keys}. " + f"Received: inputs.keys={list(inputs.keys())}" + ) # Setup our three main passes. # 1. Preprocessing strings to dense integer tensors. @@ -359,40 +460,54 @@ def generate( generate_function = self.make_generate_function() def preprocess(x): - return self.preprocessor.generate_preprocess(x) + if self.preprocessor is not None: + return self.preprocessor.generate_preprocess(x) + else: + return x + + def generate(images, masks, x): + token_ids = x[0] if self.support_negative_prompts else x + + # Initialize noises. + if isinstance(token_ids, dict): + arbitrary_key = list(token_ids.keys())[0] + batch_size = ops.shape(token_ids[arbitrary_key])[0] + else: + batch_size = ops.shape(token_ids)[0] + noise_shape = (batch_size,) + self.latent_shape[1:] + noises = random.normal(noise_shape, dtype="float32", seed=seed) + + return generate_function( + images, + masks, + noises, + x, + starting_step, + num_steps, + guidance_scale, + ) # Normalize and preprocess inputs. - images, image_is_scalar = self._normalize_generate_images(images) - masks, _ = self._normalize_generate_masks(masks) - inputs, _ = self._normalize_generate_inputs(inputs) - if negative_inputs is None: - negative_inputs = [""] * len(inputs) - negative_inputs, _ = self._normalize_generate_inputs(negative_inputs) - - if self.preprocessor is not None: - inputs = preprocess(inputs) - negative_inputs = preprocess(negative_inputs) - if isinstance(inputs, dict): - batch_size = len(inputs[list(inputs.keys())[0]]) + inputs, input_is_scalar = self._normalize_generate_inputs(inputs) + if self.support_negative_prompts: + images = [x["images"] for x in inputs] + masks = [x["masks"] for x in inputs] + token_ids = [preprocess(x["prompts"]) for x in inputs] + negative_token_ids = [ + preprocess(x["negative_prompts"]) for x in inputs + ] + # Tuple format: (images, masks, (token_ids, negative_token_ids)). + inputs = [ + x + for x in zip(images, masks, zip(token_ids, negative_token_ids)) + ] else: - batch_size = len(inputs) - - # Get the starting step for denoising. - starting_step = int(num_steps * (1.0 - strength)) - - # Initialize random noises. - noise_shape = (batch_size,) + self.latent_shape[1:] - noises = random.normal(noise_shape, dtype="float32", seed=seed) + images = [x["images"] for x in inputs] + masks = [x["masks"] for x in inputs] + token_ids = [preprocess(x["prompts"]) for x in inputs] + # Tuple format: (images, masks, token_ids). + inputs = [x for x in zip(images, masks, token_ids)] # Inpaint. 
- outputs = generate_function( - ops.convert_to_tensor(images), - ops.convert_to_tensor(masks), - noises, - inputs, - negative_inputs, - ops.convert_to_tensor(starting_step, "int32"), - ops.convert_to_tensor(num_steps, "int32"), - ops.convert_to_tensor(guidance_scale), - ) - return self._normalize_generate_outputs(outputs, image_is_scalar) + outputs = [generate(*x) for x in inputs] + return self._normalize_generate_outputs(outputs, input_is_scalar) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py index 9a8372be52..29c939e759 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py @@ -26,30 +26,43 @@ class StableDiffusion3ImageToImage(ImageToImage): Use `generate()` to do image generation. ```python - reference_image = np.ones((512, 512, 3), dtype="float32") image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset( "stable_diffusion_3_medium", height=512, width=512 ) image_to_image.generate( - reference_image, - "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + { + "images": np.ones((512, 512, 3), dtype="float32"), + "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + } ) # Generate with batched prompts. - reference_images = np.ones((2, 512, 512, 3), dtype="float32") image_to_image.generate( - reference_images, - ["cute wallpaper art of a cat", "cute wallpaper art of a dog"] + { + "images": np.ones((2, 512, 512, 3), dtype="float32"), + "prompts": ["cute wallpaper art of a cat", "cute wallpaper art of a dog"], + } ) # Generate with different `num_steps`, `guidance_scale` and `strength`. image_to_image.generate( - reference_image, - "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + { + "images": np.ones((512, 512, 3), dtype="float32"), + "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + } num_steps=50, guidance_scale=5.0, strength=0.6, ) + + # Generate with `negative_prompts`. + text_to_image.generate( + { + "images": np.ones((512, 512, 3), dtype="float32"), + "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "negative_prompts": "green color", + } + ) ``` """ @@ -86,7 +99,6 @@ def generate_step( images, noises, token_ids, - negative_token_ids, starting_step, num_steps, guidance_scale, @@ -102,10 +114,8 @@ def generate_step( noises: A (batch_size, latent_height, latent_width, channels) tensor containing the noises to be added to the latents. Typically, this tensor is sampled from the Gaussian distribution. - token_ids: A (batch_size, num_tokens) tensor containing the - tokens based on the input prompts. - negative_token_ids: A (batch_size, num_tokens) tensor - containing the negative tokens based on the input prompts. + token_ids: A pair of (batch_size, num_tokens) tensor containing the + tokens based on the input prompts and negative prompts. starting_step: int. The number of the starting diffusion step. num_steps: int. The number of diffusion steps to take. guidance_scale: float. The classifier free guidance scale defined in @@ -114,6 +124,8 @@ def generate_step( generate images that are closely linked to prompts, usually at the expense of lower image quality. """ + token_ids, negative_token_ids = token_ids + # Encode images. 
@@ -144,18 +156,14 @@ def body_fun(step, latents):
 
     def generate(
         self,
-        images,
         inputs,
-        negative_inputs=None,
         num_steps=50,
         guidance_scale=7.0,
         strength=0.8,
         seed=None,
     ):
         return super().generate(
-            images,
             inputs,
-            negative_inputs=negative_inputs,
             num_steps=num_steps,
             guidance_scale=guidance_scale,
             strength=strength,
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
index 1202831b21..90c11a7238 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
@@ -92,7 +92,6 @@ def generate_step(
         masks,
         noises,
         token_ids,
-        negative_token_ids,
         starting_step,
         num_steps,
         guidance_scale,
@@ -105,15 +104,13 @@
         Args:
             images: A (batch_size, image_height, image_width, 3) tensor
                 containing the reference images.
-            masks: A (batch_size, image_height, image_width, 1 or 3) tensor
+            masks: A (batch_size, image_height, image_width) tensor
                 containing the reference masks.
             noises: A (batch_size, latent_height, latent_width, channels)
                 tensor containing the noises to be added to the latents.
                 Typically, this tensor is sampled from the Gaussian
                 distribution.
-            token_ids: A (batch_size, num_tokens) tensor containing the
-                tokens based on the input prompts.
-            negative_token_ids: A (batch_size, num_tokens) tensor
-                containing the negative tokens based on the input prompts.
+            token_ids: A pair of (batch_size, num_tokens) tensors containing
+                the token ids of the input prompts and negative prompts.
             starting_step: int. The number of the starting diffusion step.
             num_steps: int. The number of diffusion steps to take.
             guidance_scale: float. The classifier free guidance scale defined in
@@ -122,6 +119,8 @@
                 generate images that are closely linked to prompts, usually at
                 the expense of lower image quality.
         """
+        token_ids, negative_token_ids = token_ids
+
         # Get masked images.
         masks = ops.cast(ops.expand_dims(masks, axis=-1) > 0.5, images.dtype)
         masks_latent_size = ops.image.resize(
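Since `generate_step` now takes rank-3 masks, here is a small sketch of preparing one, assuming the convention implied by the thresholding above (values above 0.5 mark the region to repaint):

```python
import numpy as np

# (batch, height, width) float mask: values > 0.5 are repainted,
# values <= 0.5 keep the reference image's pixels.
masks = np.zeros((1, 512, 512), dtype="float32")
masks[:, 128:384, 128:384] = 1.0  # repaint the center square
```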
@@ -180,20 +179,14 @@ def true_fn():
 
     def generate(
         self,
-        images,
-        masks,
         inputs,
-        negative_inputs=None,
         num_steps=50,
         guidance_scale=7.0,
         strength=0.6,
         seed=None,
     ):
         return super().generate(
-            images,
-            masks,
             inputs,
-            negative_inputs=negative_inputs,
             num_steps=num_steps,
             guidance_scale=guidance_scale,
             strength=strength,
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py
index 2d5e37c7c5..faade4b1e4 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py
@@ -96,7 +96,13 @@ def test_generate(self):
         prompt = ["airplane"]
         negative_prompt = [""]
         output = inpaint.generate(
-            image, mask, prompt, negative_prompt, seed=seed
+            {
+                "images": image,
+                "masks": mask,
+                "prompts": prompt,
+                "negative_prompts": negative_prompt,
+            },
+            seed=seed,
         )
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess(prompt)
         negative_prompt_ids = self.preprocessor.generate_preprocess(
             negative_prompt
         )
         inpaint.preprocessor = None
         output2 = inpaint.generate(
-            image, mask, prompt_ids, negative_prompt_ids, seed=seed
+            {
+                "images": image,
+                "masks": mask,
+                "prompts": prompt_ids,
+                "negative_prompts": negative_prompt_ids,
+            },
+            seed=seed,
         )
         self.assertAllClose(output, output2)
 
@@ -122,7 +134,13 @@ def test_generate_with_lower_precision(self):
         prompt = ["airplane"]
         negative_prompt = [""]
         output = inpaint.generate(
-            image, mask, prompt, negative_prompt, seed=seed
+            {
+                "images": image,
+                "masks": mask,
+                "prompts": prompt,
+                "negative_prompts": negative_prompt,
+            },
+            seed=seed,
         )
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess(prompt)
         negative_prompt_ids = self.preprocessor.generate_preprocess(
             negative_prompt
         )
         inpaint.preprocessor = None
         output2 = inpaint.generate(
-            image, mask, prompt_ids, negative_prompt_ids, seed=seed
+            {
+                "images": image,
+                "masks": mask,
+                "prompts": prompt_ids,
+                "negative_prompts": negative_prompt_ids,
+            },
+            seed=seed,
         )
         self.assertAllClose(output, output2)
     finally:
@@ -144,9 +168,21 @@ def test_generate_compilation(self):
         image = self.input_data["images"][0]
         mask = self.input_data["images"][0][..., 0]  # (H, W)
         # Assert we do not recompile with successive calls.
-        inpaint.generate(image, mask, "airplane")
+        inpaint.generate(
+            {
+                "images": image,
+                "masks": mask,
+                "prompts": "airplane",
+            }
+        )
         first_fn = inpaint.generate_function
-        inpaint.generate(image, mask, "airplane")
+        inpaint.generate(
+            {
+                "images": image,
+                "masks": mask,
+                "prompts": "airplane",
+            }
+        )
         second_fn = inpaint.generate_function
         self.assertEqual(first_fn, second_fn)
         # Assert we do recompile after compile is called.
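For context on the `true_fn` branch touched above: inpainting keeps the unmasked region pinned to the reference image at each denoising step. This is a simplified sketch of that blend, not the exact KerasHub implementation; `reference_latents` stands in for the noised latent encoding of the input image:

```python
from keras import ops

def blend_known_region(latents, reference_latents, masks_latent_size):
    # Keep denoised content where the mask is 1, and re-inject the
    # reference image's (noised) latents where the mask is 0.
    return ops.add(
        ops.multiply(masks_latent_size, latents),
        ops.multiply(1.0 - masks_latent_size, reference_latents),
    )
```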
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
index 63f0ba6c28..623088f231 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
@@ -44,6 +44,14 @@ class StableDiffusion3TextToImage(TextToImage):
         num_steps=50,
         guidance_scale=5.0,
     )
+
+    # Generate with `negative_prompts`.
+    text_to_image.generate(
+        {
+            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "negative_prompts": "green color",
+        }
+    )
     ```
     """
 
@@ -79,7 +87,6 @@ def generate_step(
         self,
         latents,
         token_ids,
-        negative_token_ids,
         num_steps,
         guidance_scale,
     ):
@@ -92,10 +99,8 @@
             latents: A (batch_size, height, width, channels) tensor
                 containing the latents to start generation from. Typically,
                 this tensor is sampled from the Gaussian distribution.
-            token_ids: A (batch_size, num_tokens) tensor containing the
-                tokens based on the input prompts.
-            negative_token_ids: A (batch_size, num_tokens) tensor
-                containing the negative tokens based on the input prompts.
+            token_ids: A pair of (batch_size, num_tokens) tensors containing
+                the token ids of the input prompts and negative prompts.
             num_steps: int. The number of diffusion steps to take.
             guidance_scale: float. The classifier free guidance scale defined in
                 [Classifier-Free Diffusion Guidance](
                 https://arxiv.org/abs/2207.12598). Higher scale encourages to
                 generate images that are closely linked to prompts, usually at
                 the expense of lower image quality.
         """
-        # Encode inputs.
+        token_ids, negative_token_ids = token_ids
+
+        # Encode prompts.
         embeddings = self.backbone.encode_text_step(
             token_ids, negative_token_ids
         )
@@ -126,14 +133,12 @@ def body_fun(step, latents):
     def generate(
         self,
         inputs,
-        negative_inputs=None,
         num_steps=28,
         guidance_scale=7.0,
         seed=None,
     ):
         return super().generate(
             inputs,
-            negative_inputs=negative_inputs,
             num_steps=num_steps,
             guidance_scale=guidance_scale,
             seed=seed,
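As a reminder of what `guidance_scale` does numerically, this is the standard classifier-free guidance combine from the cited paper, shown as an illustrative sketch rather than the exact backbone code:

```python
from keras import ops

def apply_classifier_free_guidance(pred_cond, pred_uncond, guidance_scale):
    # Standard CFG combine (https://arxiv.org/abs/2207.12598): push the
    # prediction away from the unconditional branch and toward the
    # prompt-conditioned branch by `guidance_scale`.
    return ops.add(
        pred_uncond,
        ops.multiply(guidance_scale, ops.subtract(pred_cond, pred_uncond)),
    )
```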
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py
index 837c95fa37..bbbb55b27d 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py
@@ -93,7 +93,13 @@ def test_generate(self):
         # String input.
         prompt = ["airplane"]
         negative_prompt = [""]
-        output = text_to_image.generate(prompt, negative_prompt, seed=seed)
+        output = text_to_image.generate(
+            {
+                "prompts": prompt,
+                "negative_prompts": negative_prompt,
+            },
+            seed=seed,
+        )
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess(prompt)
         negative_prompt_ids = self.preprocessor.generate_preprocess(
@@ -101,7 +107,11 @@
         )
         text_to_image.preprocessor = None
         output2 = text_to_image.generate(
-            prompt_ids, negative_prompt_ids, seed=seed
+            {
+                "prompts": prompt_ids,
+                "negative_prompts": negative_prompt_ids,
+            },
+            seed=seed,
         )
         self.assertAllClose(output, output2)
 
@@ -116,7 +126,11 @@ def test_generate_with_lower_precision(self):
         prompt = ["airplane"]
         negative_prompt = [""]
         output = text_to_image.generate(
-            prompt, negative_prompt, seed=seed
+            {
+                "prompts": prompt,
+                "negative_prompts": negative_prompt,
+            },
+            seed=seed,
         )
         # Int tensor input.
         prompt_ids = self.preprocessor.generate_preprocess(prompt)
         negative_prompt_ids = self.preprocessor.generate_preprocess(
             negative_prompt
         )
         text_to_image.preprocessor = None
         output2 = text_to_image.generate(
-            prompt_ids, negative_prompt_ids, seed=seed
+            {
+                "prompts": prompt_ids,
+                "negative_prompts": negative_prompt_ids,
+            },
+            seed=seed,
         )
         self.assertAllClose(output, output2)
     finally:
diff --git a/keras_hub/src/models/text_to_image.py b/keras_hub/src/models/text_to_image.py
index 291a4b023e..54b8dcdae2 100644
--- a/keras_hub/src/models/text_to_image.py
+++ b/keras_hub/src/models/text_to_image.py
@@ -56,6 +56,11 @@ def __init__(self, *args, **kwargs):
         # Default compilation.
         self.compile()
 
+    @property
+    def support_negative_prompts(self):
+        """Whether the model supports `negative_prompts` in `generate()`."""
+        return True
+
     @property
     def latent_shape(self):
         return tuple(self.backbone.latent_shape)
 
@@ -171,9 +176,26 @@ def _normalize_generate_inputs(self, inputs):
         This function converts all inputs to tensors, adds a batch dimension
         if necessary, and returns an iterable "dataset like" object (either
         an actual `tf.data.Dataset` or a list with a single batch element).
+
+        The input format must be one of the following:
+
+        - A single string
+        - A list of strings
+        - A dict with "prompts" and/or "negative_prompts" keys
+        - A tf.data.Dataset with "prompts" and/or "negative_prompts" keys
+
+        The output is an iterable of dicts with "prompts" and/or
+        "negative_prompts" keys.
         """
         if tf and isinstance(inputs, tf.data.Dataset):
-            return inputs.as_numpy_iterator(), False
+            _inputs = {
+                "prompts": inputs.map(
+                    lambda x: x["prompts"]
+                ).as_numpy_iterator()
+            }
+            if self.support_negative_prompts:
+                _inputs["negative_prompts"] = inputs.map(
+                    lambda x: x["negative_prompts"]
+                ).as_numpy_iterator()
+            return _inputs, False
 
         def normalize(x):
             if isinstance(x, str):
@@ -182,13 +204,24 @@ def normalize(x):
                 return x[tf.newaxis], True
             return x, False
 
+        def get_dummy_prompts(x):
+            dummy_prompts = [""] * len(x)
+            if tf and isinstance(x, tf.Tensor):
+                return tf.convert_to_tensor(dummy_prompts)
+            else:
+                return dummy_prompts
+
         if isinstance(inputs, dict):
             for key in inputs:
                 inputs[key], input_is_scalar = normalize(inputs[key])
         else:
             inputs, input_is_scalar = normalize(inputs)
+            inputs = {"prompts": inputs}
 
-        return inputs, input_is_scalar
+        if self.support_negative_prompts and "negative_prompts" not in inputs:
+            inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"])
+
+        return [inputs], input_is_scalar
 
     def _normalize_generate_outputs(self, outputs, input_is_scalar):
         """Normalize user output from the generate function.
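To make the new contract concrete, here is roughly what `_normalize_generate_inputs` returns for a scalar string input on a model that supports negative prompts (values shown schematically; the exact container types depend on which `normalize` branch is taken):

```python
inputs, input_is_scalar = text_to_image._normalize_generate_inputs("a cat")
# input_is_scalar is True, and `inputs` is a one-element list holding a
# dict, roughly:
#   [{"prompts": ["a cat"], "negative_prompts": [""]}]
# Dict inputs take the same path, so an explicit negative prompt is
# preserved rather than replaced by the dummy "".
```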
@@ -199,12 +232,11 @@ def _normalize_generate_outputs(self, outputs, input_is_scalar):
         """
 
         def normalize(x):
-            outputs = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0)
+            outputs = ops.concatenate(x, axis=0)
+            outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0)
             outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8")
-            outputs = ops.convert_to_numpy(outputs)
-            if input_is_scalar:
-                outputs = outputs[0]
-            return outputs
+            outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs
+            return ops.convert_to_numpy(outputs)
 
         if isinstance(outputs[0], dict):
             normalized = {}
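The output normalization above maps decoder outputs from [-1, 1] into uint8 pixels. An equivalent standalone sketch in NumPy (ignoring the scalar-input squeeze):

```python
import numpy as np

def to_uint8_images(batches):
    # Concatenate per-batch outputs, rescale [-1, 1] -> [0, 1], then
    # round into uint8 pixel values in [0, 255].
    x = np.concatenate(batches, axis=0)
    x = np.clip((x + 1.0) / 2.0, 0.0, 1.0)
    return np.round(x * 255.0).astype("uint8")
```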
@@ -216,23 +248,40 @@ def generate(
     def generate(
         self,
         inputs,
-        negative_inputs,
         num_steps,
         guidance_scale,
         seed=None,
     ):
-        """Generate image based on the provided `inputs` and `negative_inputs`.
+        """Generate image based on the provided `inputs`.
+
+        Typically, `inputs` contains a text description (known as a prompt)
+        used to guide the image generation.
+
+        Some models support a `negative_prompts` key, which helps steer the
+        model away from generating certain styles and elements. To enable
+        this, pass `prompts` and `negative_prompts` as a dict:
+
+        ```python
+        text_to_image.generate(
+            {
+                "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+                "negative_prompts": "green color",
+            }
+        )
+        ```
 
         If `inputs` is a `tf.data.Dataset`, outputs will be generated
         "batch-by-batch" and concatenated. Otherwise, all inputs will be
         processed as batches.
 
         Args:
-            inputs: python data, tensor data, or a `tf.data.Dataset`.
-            negative_inputs: python data, tensor data, or a `tf.data.Dataset`.
-                Unlike `inputs`, these are used as negative inputs to guide the
-                generation. If not provided, it defaults to `""` for each input
-                in `inputs`.
+            inputs: python data, tensor data, or a `tf.data.Dataset`. The
+                format must be one of the following:
+                - A single string
+                - A list of strings
+                - A dict with "prompts" and/or "negative_prompts" keys
+                - A `tf.data.Dataset` with "prompts" and/or
+                  "negative_prompts" keys
             num_steps: int. The number of diffusion steps to take.
             guidance_scale: float. The classifier free guidance scale defined in
                 [Classifier-Free Diffusion Guidance](
@@ -251,32 +300,36 @@ def generate(
         generate_function = self.make_generate_function()
 
         def preprocess(x):
-            return self.preprocessor.generate_preprocess(x)
+            if self.preprocessor is not None:
+                return self.preprocessor.generate_preprocess(x)
+            else:
+                return x
+
+        def generate(x):
+            token_ids = x[0] if self.support_negative_prompts else x
+
+            # Initialize latents.
+            if isinstance(token_ids, dict):
+                arbitrary_key = list(token_ids.keys())[0]
+                batch_size = ops.shape(token_ids[arbitrary_key])[0]
+            else:
+                batch_size = ops.shape(token_ids)[0]
+            latent_shape = (batch_size,) + self.latent_shape[1:]
+            latents = random.normal(latent_shape, dtype="float32", seed=seed)
+
+            return generate_function(latents, x, num_steps, guidance_scale)
 
         # Normalize and preprocess inputs.
         inputs, input_is_scalar = self._normalize_generate_inputs(inputs)
-        if negative_inputs is None:
-            negative_inputs = [""] * len(inputs)
-        negative_inputs, _ = self._normalize_generate_inputs(negative_inputs)
-
-        if self.preprocessor is not None:
-            inputs = preprocess(inputs)
-            negative_inputs = preprocess(negative_inputs)
-        if isinstance(inputs, dict):
-            batch_size = len(inputs[list(inputs.keys())[0]])
+        if self.support_negative_prompts:
+            token_ids = [preprocess(x["prompts"]) for x in inputs]
+            negative_token_ids = [
+                preprocess(x["negative_prompts"]) for x in inputs
+            ]
+            inputs = list(zip(token_ids, negative_token_ids))
         else:
-            batch_size = len(inputs)
-
-        # Initialize random latents.
-        latent_shape = (batch_size,) + self.latent_shape[1:]
-        latents = random.normal(latent_shape, dtype="float32", seed=seed)
+            inputs = [preprocess(x["prompts"]) for x in inputs]
 
         # Text-to-image.
-        outputs = generate_function(
-            latents,
-            inputs,
-            negative_inputs,
-            num_steps,
-            guidance_scale,
-        )
+        outputs = [generate(x) for x in inputs]
         return self._normalize_generate_outputs(outputs, input_is_scalar)
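Since the tests above exercise it, here is the raw token-id path in one place, a sketch assuming `text_to_image` and `preprocessor` were built as in the test setup:

```python
# Precompute token ids, then bypass the attached preprocessor; ids pass
# through `preprocess()` unchanged when `self.preprocessor` is None.
prompt_ids = preprocessor.generate_preprocess(["airplane"])
negative_prompt_ids = preprocessor.generate_preprocess([""])
text_to_image.preprocessor = None
images = text_to_image.generate(
    {"prompts": prompt_ids, "negative_prompts": negative_prompt_ids},
    seed=42,
)
```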