Skip to content

Commit

Permalink
Update SD3 init parameters (replacing height, width with `image_s…
Browse files Browse the repository at this point in the history
…hape`) (#1951)

* Replace SD3 `height` and `width` with `image_shape`

* Update URI

* Revert comment

* Update SD3 handle

* Replace `height` and `width` with `image_shape`

* Update docstrings

* Fix CI
  • Loading branch information
james77777778 authored Oct 24, 2024
1 parent 1283e70 commit e24a516
Show file tree
Hide file tree
Showing 13 changed files with 34 additions and 46 deletions.
2 changes: 1 addition & 1 deletion keras_hub/src/models/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def normalize_images(x):
input_is_scalar = True
x = ops.image.resize(
x,
(self.backbone.height, self.backbone.width),
(self.backbone.image_shape[0], self.backbone.image_shape[1]),
interpolation="nearest",
data_format=data_format,
)
Expand Down
12 changes: 6 additions & 6 deletions keras_hub/src/models/inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def normalize(x):
input_is_scalar = True
x = ops.image.resize(
x,
(self.backbone.height, self.backbone.width),
(self.backbone.image_shape[0], self.backbone.image_shape[1]),
interpolation="nearest",
data_format=data_format,
)
Expand Down Expand Up @@ -240,7 +240,7 @@ def normalize(x):
x = ops.cast(x, "float32")
x = ops.image.resize(
x,
(self.backbone.height, self.backbone.width),
(self.backbone.image_shape[0], self.backbone.image_shape[1]),
interpolation="nearest",
data_format=data_format,
)
Expand Down Expand Up @@ -303,7 +303,7 @@ def normalize_images(x):
input_is_scalar = True
x = ops.image.resize(
x,
(self.backbone.height, self.backbone.width),
(self.backbone.image_shape[0], self.backbone.image_shape[1]),
interpolation="nearest",
data_format=data_format,
)
Expand All @@ -323,7 +323,7 @@ def normalize_masks(x):
x = ops.cast(x, "float32")
x = ops.image.resize(
x,
(self.backbone.height, self.backbone.width),
(self.backbone.image_shape[0], self.backbone.image_shape[1]),
interpolation="nearest",
data_format=data_format,
)
Expand Down Expand Up @@ -384,8 +384,8 @@ def generate(
Typically, `inputs` is a dict with `"images"` `"masks"` and `"prompts"`
keys. `"images"` are reference images within a value range of
`[-1.0, 1.0]`, which will be resized to `self.backbone.height` and
`self.backbone.width`, then encoded into latent space by the VAE
`[-1.0, 1.0]`, which will be resized to height and width from
`self.backbone.image_shape`, then encoded into latent space by the VAE
encoder. `"masks"` are mask images with a boolean dtype, where white
pixels are repainted while black pixels are preserved. `"prompts"` are
strings that will be tokenized and encoded by the text encoder.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ class StableDiffusion3Backbone(Backbone):
model. Defaults to `1000`.
shift: float. The shift value for the timestep schedule. Defaults to
`3.0`.
height: optional int. The output height of the image.
width: optional int. The output width of the image.
image_shape: tuple. The input shape without the batch size. Defaults to
`(1024, 1024, 3)`.
data_format: `None` or str. If specified, either `"channels_last"` or
`"channels_first"`. The ordering of the dimensions in the
inputs. `"channels_last"` corresponds to inputs with shape
Expand Down Expand Up @@ -270,23 +270,21 @@ def __init__(
output_channels=3,
num_train_timesteps=1000,
shift=3.0,
height=None,
width=None,
image_shape=(1024, 1024, 3),
data_format=None,
dtype=None,
**kwargs,
):
height = int(height or 1024)
width = int(width or 1024)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(
"`height` and `width` must be divisible by 8. "
f"Received: height={height}, width={width}"
)
data_format = standardize_data_format(data_format)
if data_format != "channels_last":
raise NotImplementedError
image_shape = (height, width, int(vae.input_channels))
height = image_shape[0]
width = image_shape[1]
if height % 8 != 0 or width % 8 != 0:
raise ValueError(
"height and width in `image_shape` must be divisible by 8. "
f"Received: image_shape={image_shape}"
)
latent_shape = (height // 8, width // 8, int(latent_channels))
context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
Expand Down Expand Up @@ -452,8 +450,7 @@ def __init__(
self.output_channels = output_channels
self.num_train_timesteps = num_train_timesteps
self.shift = shift
self.height = height
self.width = width
self.image_shape = image_shape

@property
def latent_shape(self):
Expand Down Expand Up @@ -585,8 +582,7 @@ def get_config(self):
"output_channels": self.output_channels,
"num_train_timesteps": self.num_train_timesteps,
"shift": self.shift,
"height": self.height,
"width": self.width,
"image_shape": self.image_shape,
}
)
return config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

class StableDiffusion3BackboneTest(TestCase):
def setUp(self):
height, width = 64, 64
image_shape = (64, 64, 3)
height, width = image_shape[0], image_shape[1]
vae = VAEBackbone(
[32, 32, 32, 32],
[1, 1, 1, 1],
Expand All @@ -36,8 +37,7 @@ def setUp(self):
"vae": vae,
"clip_l": clip_l,
"clip_g": clip_g,
"height": height,
"width": width,
"image_shape": image_shape,
}
self.input_data = {
"images": ops.ones((2, height, width, 3)),
Expand Down Expand Up @@ -82,7 +82,6 @@ def test_all_presets(self):
preset=preset,
input_data=self.input_data,
init_kwargs={
"height": self.init_kwargs["height"],
"width": self.init_kwargs["width"],
"image_shape": self.init_kwargs["image_shape"],
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class StableDiffusion3ImageToImage(ImageToImage):
Use `generate()` to do image generation.
```python
image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
"stable_diffusion_3_medium", height=512, width=512
"stable_diffusion_3_medium", image_shape=(512, 512, 3)
)
image_to_image.generate(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def setUp(self):
clip_g=CLIPTextEncoder(
20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
),
height=64,
width=64,
image_shape=(64, 64, 3),
)
self.init_kwargs = {
"preprocessor": self.preprocessor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class StableDiffusion3Inpaint(Inpaint):
reference_image = np.ones((1024, 1024, 3), dtype="float32")
reference_mask = np.ones((1024, 1024), dtype="float32")
inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
"stable_diffusion_3_medium", height=512, width=512
"stable_diffusion_3_medium", image_shape=(512, 512, 3)
)
inpaint.generate(
reference_image,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def setUp(self):
clip_g=CLIPTextEncoder(
20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
),
height=64,
width=64,
image_shape=(64, 64, 3),
)
self.init_kwargs = {
"preprocessor": self.preprocessor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
"path": "stable_diffusion_3",
"model_card": "https://arxiv.org/abs/2110.00476",
},
"kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/2",
"kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/3",
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class StableDiffusion3TextToImage(TextToImage):
Use `generate()` to do image generation.
```python
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
"stable_diffusion_3_medium", height=512, width=512
"stable_diffusion_3_medium", image_shape=(512, 512, 3)
)
text_to_image.generate(
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def setUp(self):
clip_g=CLIPTextEncoder(
20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
),
height=64,
width=64,
image_shape=(64, 64, 3),
)
self.init_kwargs = {
"preprocessor": self.preprocessor,
Expand Down
6 changes: 2 additions & 4 deletions keras_hub/src/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,8 @@ def get_backbone_kwargs(self, **kwargs):
backbone_kwargs["dtype"] = kwargs.pop("dtype", None)

# Forward `height` and `width` to backbone when using `TextToImage`.
if "height" in kwargs:
backbone_kwargs["height"] = kwargs.pop("height", None)
if "width" in kwargs:
backbone_kwargs["width"] = kwargs.pop("width", None)
if "image_shape" in kwargs:
backbone_kwargs["image_shape"] = kwargs.pop("image_shape", None)

return backbone_kwargs, kwargs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ def convert_model(preset, height, width):
vae,
clip_l,
clip_g,
height=height,
width=width,
image_shape=(height, width, 3),
name="stable_diffusion_3_backbone",
)
return backbone
Expand Down Expand Up @@ -532,8 +531,7 @@ def main(_):

keras_preprocessor.save_to_preset(preset)
# Set the image size to 1024, the same as in huggingface/diffusers.
keras_model.height = 1024
keras_model.width = 1024
keras_model.image_shape = (1024, 1024, 3)
keras_model.save_to_preset(preset)
print(f"🏁 Preset saved to ./{preset}.")

Expand Down

0 comments on commit e24a516

Please sign in to comment.