Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example:
  - Add get_qkv_linear to handle different dimensionality in linears
  - Add use_quant_conv and use_post_quant_conv for vae in stable diffusion; adapt existing AutoEncoderKLConfig to the change
  - Add forward_until_encoder_layer to ClipTextTransformer
  - Rename sd3 config to sd3_medium in mmdit; minor clean-up
  - Enable flash-attn for mmdit impl when the feature is enabled
  - Add sd3 example codebase
  - Add document crediting references
  - Pass the cargo fmt and clippy tests
* Fix typos
* Expose cfg_scale and time_shift as options
* Replace the sample image with JPG version; change image output format accordingly
* Make meaningful error messages
* Remove the tail-end assignment in sd3_vae_vb_rename
* Remove the CUDA requirement
* Use default_value in clap args
* Add use_flash_attn to turn on/off flash-attn for MMDiT at runtime
* Resolve clippy errors and warnings
* Use default_value_t
* Pin the web-sys dependency
* Clippy fix

Co-authored-by: Laurent <[email protected]>

1 parent 0d96ec3 · commit ca7cf5c · 16 changed files with 751 additions and 34 deletions.

New file (+54 lines):

# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3 Medium

![](assets/stable-diffusion-3.jpg)

*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*

Stable Diffusion 3 Medium is a text-to-image model based on the Multimodal Diffusion Transformer (MMDiT) architecture.

- [huggingface repo](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
- [research paper](https://arxiv.org/pdf/2403.03206)
- [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium)

## Getting access to the weights

The weights of Stable Diffusion 3 Medium are released by Stability AI under the Stability Community License. To gain access to the weights for your HuggingFace account, you will need to accept the license conditions by visiting the [repo on HuggingFace Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium).

On the first run, the weights are automatically downloaded from the HuggingFace Hub. You might be prompted to configure a [HuggingFace User Access Token](https://huggingface.co/docs/hub/en/security-tokens) (recommended) on your computer if you haven't done so before. After the download, the weights are [cached](https://huggingface.co/docs/datasets/en/cache) and remain accessible locally.

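For example, one way to set up the token (assuming the `huggingface-cli` tool from the Python `huggingface_hub` package is installed; depending on the `hf-hub` version, the `HF_TOKEN` environment variable may also be picked up):

```shell
# Store the token locally so future runs can pick it up.
huggingface-cli login
# Or export it for the current shell session only.
export HF_TOKEN=<your-access-token>
```
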
## Running the model

```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- \
  --height 1024 --width 1024 \
  --prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k'
```

To display the other available options:

```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- --help
```

If your GPU supports it, Flash-Attention is strongly recommended, as it can greatly improve inference speed: MMDiT is a transformer model that relies heavily on attention. To use [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) in the demo, you will need both the `--features flash-attn` build flag and the `--use-flash-attn` runtime flag.

```shell
cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- --use-flash-attn ...
```

## Performance Benchmark

The benchmark below generates a 1024x1024 image with 28 steps of Euler sampling and measures the average speed (iterations per second).

[candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) are based on commit [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc).

System specs (desktop, PCIe 5.0 x8/x8 dual-GPU setup):

- Operating System: Ubuntu 23.10
- CPU: i9 12900K w/o overclocking
- RAM: 64 GB dual-channel DDR5 @ 4800 MT/s

| Speed (iter/s) | w/o flash-attn | w/ flash-attn |
| -------------- | -------------- | ------------- |
| RTX 3090 Ti    | 0.83           | 2.15          |
| RTX 4090       | 1.72           | 4.06          |

Binary file added (+81.4 KB): candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg

New file (+201 lines):

```rust
use anyhow::{Error as E, Ok, Result};
use candle::{DType, IndexOp, Module, Tensor, D};
use candle_transformers::models::{stable_diffusion, t5};
use tokenizers::tokenizer::Tokenizer;

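// Pairs a CLIP text transformer with its tokenizer. `max_position_embeddings`
// is the fixed token-sequence length the encoder expects (77 for SD3).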
struct ClipWithTokenizer {
    clip: stable_diffusion::clip::ClipTextTransformer,
    config: stable_diffusion::clip::Config,
    tokenizer: Tokenizer,
    max_position_embeddings: usize,
}

impl ClipWithTokenizer {
    fn new(
        vb: candle_nn::VarBuilder,
        config: stable_diffusion::clip::Config,
        tokenizer_path: &str,
        max_position_embeddings: usize,
    ) -> Result<Self> {
        let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?;
        let path_buf = hf_hub::api::sync::Api::new()?
            .model(tokenizer_path.to_string())
            .get("tokenizer.json")?;
        let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg(
            "Failed to serialize huggingface PathBuf of CLIP tokenizer",
        ))?)
        .map_err(E::msg)?;
        Ok(Self {
            clip,
            config,
            tokenizer,
            max_position_embeddings,
        })
    }

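    // Encode a prompt into (penultimate-layer hidden states, pooled embedding).
    // The token sequence is padded up to `max_position_embeddings`, and the
    // pooled embedding is taken at the EOS position.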
    fn encode_text_to_embedding(
        &self,
        prompt: &str,
        device: &candle::Device,
    ) -> Result<(Tensor, Tensor)> {
        let pad_id = match &self.config.pad_with {
            Some(padding) => *self
                .tokenizer
                .get_vocab(true)
                .get(padding.as_str())
                .ok_or(E::msg("Failed to tokenize CLIP padding."))?,
            None => *self
                .tokenizer
                .get_vocab(true)
                .get("<|endoftext|>")
                .ok_or(E::msg("Failed to tokenize CLIP end-of-text."))?,
        };

        let mut tokens = self
            .tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        let eos_position = tokens.len() - 1;

        while tokens.len() < self.max_position_embeddings {
            tokens.push(pad_id)
        }
        let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
        let (text_embeddings, text_embeddings_penultimate) = self
            .clip
            .forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
        let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;

        Ok((text_embeddings_penultimate, text_embeddings_pooled))
    }
}

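// A T5-XXL encoder plus its tokenizer; token ids are padded (with zeros) up to
// the fixed sequence length before encoding.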
struct T5WithTokenizer {
    t5: t5::T5EncoderModel,
    tokenizer: Tokenizer,
    max_position_embeddings: usize,
}

impl T5WithTokenizer {
    fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result<Self> {
        let api = hf_hub::api::sync::Api::new()?;
        let repo = api.repo(hf_hub::Repo::with_revision(
            "google/t5-v1_1-xxl".to_string(),
            hf_hub::RepoType::Model,
            "refs/pr/2".to_string(),
        ));
        let config_filename = repo.get("config.json")?;
        let config = std::fs::read_to_string(config_filename)?;
        let config: t5::Config = serde_json::from_str(&config)?;
        let model = t5::T5EncoderModel::load(vb, &config)?;

        let tokenizer_filename = api
            .model("lmz/mt5-tokenizers".to_string())
            .get("t5-v1_1-xxl.tokenizer.json")?;

        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        Ok(Self {
            t5: model,
            tokenizer,
            max_position_embeddings,
        })
    }

    fn encode_text_to_embedding(
        &mut self,
        prompt: &str,
        device: &candle::Device,
    ) -> Result<Tensor> {
        let mut tokens = self
            .tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        tokens.resize(self.max_position_embeddings, 0);
        let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
        let embeddings = self.t5.forward(&input_token_ids)?;
        Ok(embeddings)
    }
}

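// SD3's triple text encoder: CLIP-L, CLIP-G (with its text projection), and
// T5-XXL, each loaded together with its tokenizer.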
pub struct StableDiffusion3TripleClipWithTokenizer {
    clip_l: ClipWithTokenizer,
    clip_g: ClipWithTokenizer,
    clip_g_text_projection: candle_nn::Linear,
    t5: T5WithTokenizer,
}

impl StableDiffusion3TripleClipWithTokenizer {
    pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
        let max_position_embeddings = 77usize;
        let clip_l = ClipWithTokenizer::new(
            vb_fp16.pp("clip_l.transformer"),
            stable_diffusion::clip::Config::sdxl(),
            "openai/clip-vit-large-patch14",
            max_position_embeddings,
        )?;

        let clip_g = ClipWithTokenizer::new(
            vb_fp16.pp("clip_g.transformer"),
            stable_diffusion::clip::Config::sdxl2(),
            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
            max_position_embeddings,
        )?;

        let text_projection = candle_nn::linear_no_bias(
            1280,
            1280,
            vb_fp16.pp("clip_g.transformer.text_projection"),
        )?;

        // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
        // This is a temporary workaround until the T5 implementation is updated to support fp16.
        // Also see:
        // https://github.com/huggingface/candle/issues/2480
        // https://github.com/huggingface/candle/pull/2481
        let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;

        Ok(Self {
            clip_l,
            clip_g,
            clip_g_text_projection: text_projection,
            t5,
        })
    }

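    // Build the MMDiT conditioning: `context` concatenates the CLIP-L and
    // CLIP-G penultimate embeddings (zero-padded from 2048 up to T5's 4096
    // channels) with the T5 embeddings along the sequence dimension, while
    // `y` concatenates the two pooled CLIP embeddings.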
    pub fn encode_text_to_embedding(
        &mut self,
        prompt: &str,
        device: &candle::Device,
    ) -> Result<(Tensor, Tensor)> {
        let (clip_l_embeddings, clip_l_embeddings_pooled) =
            self.clip_l.encode_text_to_embedding(prompt, device)?;
        let (clip_g_embeddings, clip_g_embeddings_pooled) =
            self.clip_g.encode_text_to_embedding(prompt, device)?;

        let clip_g_embeddings_pooled = self
            .clip_g_text_projection
            .forward(&clip_g_embeddings_pooled.unsqueeze(0)?)?
            .squeeze(0)?;

        let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)?
            .unsqueeze(0)?;
        let clip_embeddings_concat = Tensor::cat(
            &[&clip_l_embeddings, &clip_g_embeddings],
            D::Minus1,
        )?
        .pad_with_zeros(D::Minus1, 0, 2048)?;

        let t5_embeddings = self
            .t5
            .encode_text_to_embedding(prompt, device)?
            .to_dtype(DType::F16)?;
        let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;

        Ok((context, y))
    }
}
```

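For orientation, here is a minimal sketch of how this helper might be driven (the weight file name and prompt are assumptions for illustration; the actual example wires the embeddings into the full SD3 sampling pipeline):

```rust
// Hypothetical usage sketch: assumes `sd3_medium.safetensors` contains the
// `clip_l`, `clip_g` and `t5xxl` prefixes referenced above.
let device = candle::Device::cuda_if_available(0)?;
let vb_fp16 = unsafe {
    candle_nn::VarBuilder::from_mmaped_safetensors(&["sd3_medium.safetensors"], DType::F16, &device)?
};
let vb_fp32 = unsafe {
    candle_nn::VarBuilder::from_mmaped_safetensors(&["sd3_medium.safetensors"], DType::F32, &device)?
};
let mut triple = StableDiffusion3TripleClipWithTokenizer::new(vb_fp16, vb_fp32)?;
// `context` conditions the MMDiT blocks; `y` feeds the pooled-embedding path.
let (context, y) = triple.encode_text_to_embedding("a rusty robot", &device)?;
```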