Refactor MMDiT, add ImageToImage and Inpaint for SD3 #1909

Merged

divyashreepathihalli merged 9 commits into keras-team:master from james77777778:refactor-mmdit on Oct 8, 2024
Changes from 7 commits

Commits (9)
e20ab7d Refactor `MMDiT` and add `ImageToImage` (james77777778)
a7cc7f2 Update model version (james77777778)
da16e67 Fix minor bugs. (james77777778)
8aa7388 Add `Inpaint` for SD3. (james77777778)
c7749fb Fix warnings of MMDiT. (james77777778)
37c519b Add comment to Inpaint (james77777778)
5ff2fa1 Simplify `MMDiT` implementation and info of `summary()`. (james77777778)
eda16fc Refactor `generate()` API of `TextToImage`, `ImageToImage` and `Inpai…` (james77777778)
023789e Merge remote-tracking branch 'upstream/master' into refactor-mmdit (james77777778)
@@ -0,0 +1,351 @@
import itertools
from functools import partial

import keras
from keras import ops
from keras import random

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.task import Task
from keras_hub.src.utils.keras_utils import standardize_data_format

try:
    import tensorflow as tf
except ImportError:
    tf = None


@keras_hub_export("keras_hub.models.ImageToImage")
class ImageToImage(Task):
    """Base class for image-to-image tasks.

    `ImageToImage` tasks wrap a `keras_hub.models.Backbone` and
    a `keras_hub.models.Preprocessor` to create a model that can be used for
    generation and generative fine-tuning.

    `ImageToImage` tasks provide an additional, high-level `generate()`
    function, which can be used to generate images with an (image, string)
    in, image out signature.

    All `ImageToImage` tasks include a `from_preset()` constructor which can
    be used to load a pre-trained config and weights.

    Example:

    ```python
    # Load a Stable Diffusion 3 backbone with pre-trained weights.
    reference_image = np.ones((1024, 1024, 3), dtype="float32")
    image_to_image = keras_hub.models.ImageToImage.from_preset(
        "stable_diffusion_3_medium",
    )
    image_to_image.generate(
        reference_image,
        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    )

    # Load a Stable Diffusion 3 backbone at bfloat16 precision.
    image_to_image = keras_hub.models.ImageToImage.from_preset(
        "stable_diffusion_3_medium",
        dtype="bfloat16",
    )
    image_to_image.generate(
        reference_image,
        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    )
    ```
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Default compilation.
        self.compile()

    @property
    def image_shape(self):
        return tuple(self.backbone.image_shape)

    @property
    def latent_shape(self):
        return tuple(self.backbone.latent_shape)

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        *,
        metrics="auto",
        **kwargs,
    ):
        """Configures the `ImageToImage` task for training.

        The `ImageToImage` task extends the default compilation signature of
        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
        `metrics`. To override these defaults, pass any value to these
        arguments during compilation.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which uses the default
                optimizer for the given model and task. See
                `keras.Model.compile` and `keras.optimizers` for more info on
                possible `optimizer` values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, where a `keras.losses.MeanSquaredError`
                loss will be applied. See `keras.Model.compile` and
                `keras.losses` for more info on possible `loss` values.
            metrics: `"auto"`, or a list of metrics to be evaluated by the
                model during training and testing. Defaults to `"auto"`,
                where a `keras.metrics.MeanSquaredError` will be applied to
                track the loss of the model during training. See
                `keras.Model.compile` and `keras.metrics` for more info on
                possible `metrics` values.
            **kwargs: See `keras.Model.compile` for a full list of arguments
                supported by the compile method.
        """
        # Ref: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414
        if optimizer == "auto":
            optimizer = keras.optimizers.AdamW(
                1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0
            )
        if loss == "auto":
            loss = keras.losses.MeanSquaredError()
        if metrics == "auto":
            metrics = [keras.metrics.MeanSquaredError()]
        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            **kwargs,
        )
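        # Reset the cached generate function; `make_generate_function()` will
        # rebuild it lazily with the new compile settings.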
        self.generate_function = None

    def generate_step(self, *args, **kwargs):
        """Run generation on batches of input."""
        raise NotImplementedError

    def make_generate_function(self):
        """Create or return the compiled generation function."""
        if self.generate_function is not None:
            return self.generate_function

        self.generate_function = self.generate_step
        if keras.config.backend() == "torch":
            import torch

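            # Generation never needs autograd state, so wrap the call in
            # `torch.no_grad()` to save memory and compute.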
            def wrapped_function(*args, **kwargs):
                with torch.no_grad():
                    return self.generate_step(*args, **kwargs)

            self.generate_function = wrapped_function
        elif keras.config.backend() == "tensorflow" and not self.run_eagerly:
            self.generate_function = tf.function(
                self.generate_step, jit_compile=self.jit_compile
            )
        elif keras.config.backend() == "jax" and not self.run_eagerly:
            import jax

            @partial(jax.jit)
            def compiled_function(state, *args, **kwargs):
                (
                    trainable_variables,
                    non_trainable_variables,
                ) = state
                mapping = itertools.chain(
                    zip(self.trainable_variables, trainable_variables),
                    zip(self.non_trainable_variables, non_trainable_variables),
                )

                with keras.StatelessScope(state_mapping=mapping):
                    outputs = self.generate_step(*args, **kwargs)
                return outputs

            def wrapped_function(*args, **kwargs):
                # Create an explicit tuple of all variable state.
                state = (
                    # Use the explicit variable.value to preserve the
                    # sharding spec of distribution.
                    [v.value for v in self.trainable_variables],
                    [v.value for v in self.non_trainable_variables],
                )
                outputs = compiled_function(state, *args, **kwargs)
                return outputs

            self.generate_function = wrapped_function
        return self.generate_function

    def _normalize_generate_images(self, inputs):
        """Normalize user-provided images for the generate function.

        This function converts all inputs to tensors, adds a batch dimension
        if necessary, and returns an iterable "dataset like" object (either
        an actual `tf.data.Dataset` or a list with a single batch element).
        """
        if tf and isinstance(inputs, tf.data.Dataset):
            return inputs.as_numpy_iterator(), False

        def normalize(x):
            data_format = getattr(
                self.backbone, "data_format", standardize_data_format(None)
            )
            input_is_scalar = False
            x = ops.convert_to_tensor(x)
            if len(ops.shape(x)) < 4:
                x = ops.expand_dims(x, axis=0)
                input_is_scalar = True
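            # Resize to the backbone's expected resolution. "nearest"
            # interpolation avoids producing values outside the original
            # value range.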
            x = ops.image.resize(
                x,
                (self.backbone.height, self.backbone.width),
                interpolation="nearest",
                data_format=data_format,
            )
            return x, input_is_scalar

        if isinstance(inputs, dict):
            for key in inputs:
                inputs[key], input_is_scalar = normalize(inputs[key])
        else:
            inputs, input_is_scalar = normalize(inputs)

        return inputs, input_is_scalar

    def _normalize_generate_inputs(self, inputs):
        """Normalize user input to the generate function.

        This function converts all inputs to tensors, adds a batch dimension
        if necessary, and returns an iterable "dataset like" object (either
        an actual `tf.data.Dataset` or a list with a single batch element).
        """
        if tf and isinstance(inputs, tf.data.Dataset):
            return inputs.as_numpy_iterator(), False

        def normalize(x):
            if isinstance(x, str):
                return [x], True
            if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0:
                return x[tf.newaxis], True
            return x, False

        if isinstance(inputs, dict):
            for key in inputs:
                inputs[key], input_is_scalar = normalize(inputs[key])
        else:
            inputs, input_is_scalar = normalize(inputs)

        return inputs, input_is_scalar

    def _normalize_generate_outputs(self, outputs, input_is_scalar):
        """Normalize user output from the generate function.

        This function converts all output to numpy with a value range of
        `[0, 255]`. If a batch dimension was added to the input, it is
        removed from the output.
        """

        def normalize(x):
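            # Map outputs from the model's `[-1.0, 1.0]` range to `[0, 255]`:
            # `(x + 1) / 2` rescales to `[0, 1]`, then multiply by 255,
            # round, and cast to uint8.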
            outputs = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0)
            outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8")
            outputs = ops.convert_to_numpy(outputs)
            if input_is_scalar:
                outputs = outputs[0]
            return outputs

        if isinstance(outputs[0], dict):
            normalized = {}
            for key in outputs[0]:
                normalized[key] = normalize([x[key] for x in outputs])
            return normalized
        return normalize([x for x in outputs])

    def generate(
        self,
        images,
        inputs,
        negative_inputs,
        num_steps,
        guidance_scale,
        strength,
        seed=None,
    ):
        """Generate images based on the provided `images` and `inputs`.

        The `images` are reference images, with values in the range
        `[-1.0, 1.0]`, which will be resized to `self.backbone.height` and
        `self.backbone.width`, then encoded into latent space by the VAE
        encoder. The `inputs` are strings that will be tokenized and encoded
        by the text encoder.

        If `images` and `inputs` are a `tf.data.Dataset`, outputs will be
        generated "batch-by-batch" and concatenated. Otherwise, all inputs
        will be processed as batches.

        Args:
            images: python data, tensor data, or a `tf.data.Dataset`.
            inputs: python data, tensor data, or a `tf.data.Dataset`.
            negative_inputs: python data, tensor data, or a `tf.data.Dataset`.
                Unlike `inputs`, these are used as negative inputs to guide
                the generation. If not provided, it defaults to `""` for each
                input in `inputs`.
            num_steps: int. The number of diffusion steps to take.
            guidance_scale: float. The classifier free guidance scale defined
                in [Classifier-Free Diffusion Guidance](
                https://arxiv.org/abs/2207.12598). A higher scale encourages
                generating images more closely related to the prompts,
                typically at the cost of lower image quality.
            strength: float. Indicates the extent to which the reference
                `images` are transformed. Must be between `0.0` and `1.0`.
                When `strength=1.0`, `images` are essentially ignored: the
                added noise is at its maximum and the denoising process runs
                for the full number of iterations specified in `num_steps`.
            seed: optional int. Used as a random seed.
        """
        num_steps = int(num_steps)
        guidance_scale = float(guidance_scale)
        strength = float(strength)
        if strength < 0.0 or strength > 1.0:
            raise ValueError(
                "`strength` must be between `0.0` and `1.0`. "
                f"Received strength={strength}."
            )

        # Set up our three main passes.
        # 1. Preprocessing strings to dense integer tensors.
        # 2. Generate outputs via a compiled function on dense tensors.
        # 3. Postprocess dense tensors to a value range of `[0, 255]`.
        generate_function = self.make_generate_function()

        def preprocess(x):
            return self.preprocessor.generate_preprocess(x)

        # Normalize and preprocess inputs.
        images, image_is_scalar = self._normalize_generate_images(images)
        inputs, _ = self._normalize_generate_inputs(inputs)
        if negative_inputs is None:
            negative_inputs = [""] * len(inputs)
        negative_inputs, _ = self._normalize_generate_inputs(negative_inputs)

        if self.preprocessor is not None:
            inputs = preprocess(inputs)
            negative_inputs = preprocess(negative_inputs)
        if isinstance(inputs, dict):
            batch_size = len(inputs[list(inputs.keys())[0]])
        else:
            batch_size = len(inputs)

        # Get the starting step for denoising.
        starting_step = int(num_steps * (1.0 - strength))
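        # For example, `num_steps=50` with `strength=0.8` gives
        # `starting_step = int(50 * (1.0 - 0.8)) = 10`, so 40 of the 50
        # denoising steps are applied to the noised reference latents.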

        # Initialize random noises.
        noise_shape = (batch_size,) + self.latent_shape[1:]
        noises = random.normal(noise_shape, dtype="float32", seed=seed)

        # Image-to-image.
        outputs = generate_function(
            ops.convert_to_tensor(images),
            noises,
            inputs,
            negative_inputs,
            ops.convert_to_tensor(starting_step, "int32"),
            ops.convert_to_tensor(num_steps, "int32"),
            ops.convert_to_tensor(guidance_scale),
        )
        return self._normalize_generate_outputs(outputs, image_is_scalar)
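For reference, a minimal usage sketch of the `generate()` signature above. The preset name comes from the class docstring; the step count, guidance scale, and strength values are illustrative, not defaults from this PR:

```python
import numpy as np

import keras_hub

# Load a pre-trained image-to-image task (preset name from the docstring).
image_to_image = keras_hub.models.ImageToImage.from_preset(
    "stable_diffusion_3_medium",
)

# Reference image; `generate()` expects values in `[-1.0, 1.0]` and
# resizes to the backbone's resolution.
reference_image = np.zeros((1024, 1024, 3), dtype="float32")

# In this revision, all arguments besides `seed` are positional.
generated = image_to_image.generate(
    reference_image,
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    None,  # negative_inputs; falls back to "" for each prompt.
    28,  # num_steps (illustrative).
    7.0,  # guidance_scale (illustrative).
    0.8,  # strength (illustrative).
    seed=42,
)
# `generated` is a uint8 numpy array with values in `[0, 255]`; since the
# input was a single unbatched image, the batch dimension is stripped.
```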
please add more details about this arg `inputs`, seems too generic.
Thanks for the review. I have some ideas to refine the args for `TextToImage`, `ImageToImage` and `Inpaint`. I believe it would be better to align the API design with that of `CausalLM`. Working on it!
@divyashreepathihalli I have refactored the arguments for all `TextToImage`, `ImageToImage` and `Inpaint` tasks to use a single `inputs`, which aligns with `CausalLM`. Additionally, detailed docstrings have been added.

For `inputs`:

- `TextToImage`: `str`, `[str]`, `{"prompts": str|[str]}`, `{"prompts": str|[str], "negative_prompts": str|[str]}`, and `tf.data.Dataset`.
- `ImageToImage`: `{"images": array, "prompts": str|[str]}`, `{"images": array, "prompts": str|[str], "negative_prompts": str|[str]}`, and `tf.data.Dataset`.
- `Inpaint`: `{"images": array, "masks": array, "prompts": str|[str]}`, `{"images": array, "masks": array, "prompts": str|[str], "negative_prompts": str|[str]}`, and `tf.data.Dataset`.

I've also noticed that `Flux` doesn't support `"negative_prompts"`. To address this, I've introduced a new attr, `support_negative_prompts`, that toggles this feature. Please let me know if this looks good.
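A sketch of the refactored call patterns described above; the task instances (`text_to_image`, `image_to_image`, `inpaint`) and the `reference_image`/`mask` arrays are assumed to already exist, and the prompts are illustrative:

```python
# TextToImage: bare strings, or a dict keyed by "prompts"
# (optionally with "negative_prompts").
text_to_image.generate("Astronaut in a jungle")
text_to_image.generate(
    {"prompts": "Astronaut in a jungle", "negative_prompts": "blurry"}
)

# ImageToImage: a dict with "images" and "prompts"
# (optionally with "negative_prompts").
image_to_image.generate(
    {"images": reference_image, "prompts": "Astronaut in a jungle"}
)

# Inpaint: a dict that additionally carries "masks".
inpaint.generate(
    {
        "images": reference_image,
        "masks": mask,
        "prompts": "Astronaut in a jungle",
    }
)
```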