diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 0c1219d760..d3e23b922c 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -122,3 +122,6 @@ required-features = ["onnx"]
 [[example]]
 name = "colpali"
 required-features = ["pdf2image"]
+
+[[example]]
+name = "stable-diffusion-3"
\ No newline at end of file
diff --git a/candle-examples/examples/stable-diffusion-3/README.md b/candle-examples/examples/stable-diffusion-3/README.md
new file mode 100644
index 0000000000..746a31fa1b
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/README.md
@@ -0,0 +1,54 @@
+# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3 Medium
+
+![](assets/stable-diffusion-3.jpg)
+
+*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*
+
+Stable Diffusion 3 Medium is a text-to-image model based on the Multimodal Diffusion Transformer (MMDiT) architecture.
+
+- [huggingface repo](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
+- [research paper](https://arxiv.org/pdf/2403.03206)
+- [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium)
+
+## Getting access to the weights
+
+The weights of Stable Diffusion 3 Medium are released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting the [repo on HuggingFace Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium) to gain access to the weights for your HuggingFace account.
+
+On the first run, the weights will be automatically downloaded from the Huggingface Hub. You might be prompted to configure a [Huggingface User Access Token](https://huggingface.co/docs/hub/en/security-tokens) (recommended) on your computer if you haven't done so before. After the download, the weights will be [cached](https://huggingface.co/docs/datasets/en/cache) and remain accessible locally.
+
+## Running the model
+
+```shell
+cargo run --example stable-diffusion-3 --release --features=cuda -- \
+  --height 1024 --width 1024 \
+  --prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k'
+```
+
+To display the other available options:
+
+```shell
+cargo run --example stable-diffusion-3 --release --features=cuda -- --help
+```
+
+If your GPU supports it, Flash-Attention is strongly recommended, as it can greatly improve inference speed: MMDiT is a transformer model that relies heavily on attention. To utilize [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) in the demo, you will need both the `--features flash-attn` build flag and the `--use-flash-attn` runtime flag.
+
+```shell
+cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- --use-flash-attn ...
+```
+
+## Performance Benchmark
+
+The benchmark below generates a 1024-by-1024 image with 28 steps of Euler sampling and measures the average speed (iterations per second).
+
+[candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) are based on commit [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc).
+
+System specs (Desktop PCIE 5 x8/x8 dual-GPU setup):
+
+- Operating System: Ubuntu 23.10
+- CPU: i9 12900K w/o overclocking
+- RAM: 64G dual-channel DDR5 @ 4800 MT/s
+
+| Speed (iter/s) | w/o flash-attn | w/ flash-attn |
+| -------------- | -------------- | ------------- |
+| RTX 3090 Ti    | 0.83           | 2.15          |
+| RTX 4090       | 1.72           | 4.06          |
diff --git a/candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg b/candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg
new file mode 100644
index 0000000000..58ca16c3bf
Binary files /dev/null and b/candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg differ
diff --git a/candle-examples/examples/stable-diffusion-3/clip.rs b/candle-examples/examples/stable-diffusion-3/clip.rs
new file mode 100644
index 0000000000..77263d968c
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/clip.rs
@@ -0,0 +1,201 @@
+use anyhow::{Error as E, Ok, Result};
+use candle::{DType, IndexOp, Module, Tensor, D};
+use candle_transformers::models::{stable_diffusion, t5};
+use tokenizers::tokenizer::Tokenizer;
+
+struct ClipWithTokenizer {
+    clip: stable_diffusion::clip::ClipTextTransformer,
+    config: stable_diffusion::clip::Config,
+    tokenizer: Tokenizer,
+    max_position_embeddings: usize,
+}
+
+impl ClipWithTokenizer {
+    fn new(
+        vb: candle_nn::VarBuilder,
+        config: stable_diffusion::clip::Config,
+        tokenizer_path: &str,
+        max_position_embeddings: usize,
+    ) -> Result<Self> {
+        let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?;
+        let path_buf = hf_hub::api::sync::Api::new()?
+            .model(tokenizer_path.to_string())
+            .get("tokenizer.json")?;
+        let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg(
+            "Failed to serialize huggingface PathBuf of CLIP tokenizer",
+        ))?)
+        .map_err(E::msg)?;
+        Ok(Self {
+            clip,
+            config,
+            tokenizer,
+            max_position_embeddings,
+        })
+    }
+
+    fn encode_text_to_embedding(
+        &self,
+        prompt: &str,
+        device: &candle::Device,
+    ) -> Result<(Tensor, Tensor)> {
+        let pad_id = match &self.config.pad_with {
+            Some(padding) => *self
+                .tokenizer
+                .get_vocab(true)
+                .get(padding.as_str())
+                .ok_or(E::msg("Failed to tokenize CLIP padding."))?,
+            None => *self
+                .tokenizer
+                .get_vocab(true)
+                .get("<|endoftext|>")
+                .ok_or(E::msg("Failed to tokenize CLIP end-of-text."))?,
+        };
+
+        let mut tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
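+            // Take the raw token ids; the position of the last one (the EOS token) is recorded
+            // below so the pooled embedding can be read from that slot of the final hidden states.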
+            .get_ids()
+            .to_vec();
+
+        let eos_position = tokens.len() - 1;
+
+        while tokens.len() < self.max_position_embeddings {
+            tokens.push(pad_id)
+        }
+        let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
+        let (text_embeddings, text_embeddings_penultimate) = self
+            .clip
+            .forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
+        let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;
+
+        Ok((text_embeddings_penultimate, text_embeddings_pooled))
+    }
+}
+
+struct T5WithTokenizer {
+    t5: t5::T5EncoderModel,
+    tokenizer: Tokenizer,
+    max_position_embeddings: usize,
+}
+
+impl T5WithTokenizer {
+    fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result<Self> {
+        let api = hf_hub::api::sync::Api::new()?;
+        let repo = api.repo(hf_hub::Repo::with_revision(
+            "google/t5-v1_1-xxl".to_string(),
+            hf_hub::RepoType::Model,
+            "refs/pr/2".to_string(),
+        ));
+        let config_filename = repo.get("config.json")?;
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: t5::Config = serde_json::from_str(&config)?;
+        let model = t5::T5EncoderModel::load(vb, &config)?;
+
+        let tokenizer_filename = api
+            .model("lmz/mt5-tokenizers".to_string())
+            .get("t5-v1_1-xxl.tokenizer.json")?;
+
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+        Ok(Self {
+            t5: model,
+            tokenizer,
+            max_position_embeddings,
+        })
+    }
+
+    fn encode_text_to_embedding(
+        &mut self,
+        prompt: &str,
+        device: &candle::Device,
+    ) -> Result<Tensor> {
+        let mut tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        tokens.resize(self.max_position_embeddings, 0);
+        let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        let embeddings = self.t5.forward(&input_token_ids)?;
+        Ok(embeddings)
+    }
+}
+
+pub struct StableDiffusion3TripleClipWithTokenizer {
+    clip_l: ClipWithTokenizer,
+    clip_g: ClipWithTokenizer,
+    clip_g_text_projection: candle_nn::Linear,
+    t5: T5WithTokenizer,
+}
+
+impl StableDiffusion3TripleClipWithTokenizer {
+    pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
+        let max_position_embeddings = 77usize;
+        let clip_l = ClipWithTokenizer::new(
+            vb_fp16.pp("clip_l.transformer"),
+            stable_diffusion::clip::Config::sdxl(),
+            "openai/clip-vit-large-patch14",
+            max_position_embeddings,
+        )?;
+
+        let clip_g = ClipWithTokenizer::new(
+            vb_fp16.pp("clip_g.transformer"),
+            stable_diffusion::clip::Config::sdxl2(),
+            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
+            max_position_embeddings,
+        )?;
+
+        let text_projection = candle_nn::linear_no_bias(
+            1280,
+            1280,
+            vb_fp16.pp("clip_g.transformer.text_projection"),
+        )?;
+
+        // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
+        // This is a temporary workaround until the T5 implementation is updated to support fp16.
+        // Also see:
+        // https://github.com/huggingface/candle/issues/2480
+        // https://github.com/huggingface/candle/pull/2481
+        let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;
+
+        Ok(Self {
+            clip_l,
+            clip_g,
+            clip_g_text_projection: text_projection,
+            t5,
+        })
+    }
+
+    pub fn encode_text_to_embedding(
+        &mut self,
+        prompt: &str,
+        device: &candle::Device,
+    ) -> Result<(Tensor, Tensor)> {
+        let (clip_l_embeddings, clip_l_embeddings_pooled) =
+            self.clip_l.encode_text_to_embedding(prompt, device)?;
+        let (clip_g_embeddings, clip_g_embeddings_pooled) =
+            self.clip_g.encode_text_to_embedding(prompt, device)?;
+
+        let clip_g_embeddings_pooled = self
+            .clip_g_text_projection
+            .forward(&clip_g_embeddings_pooled.unsqueeze(0)?)?
+            .squeeze(0)?;
+
+        let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)?
+            .unsqueeze(0)?;
+        let clip_embeddings_concat = Tensor::cat(
+            &[&clip_l_embeddings, &clip_g_embeddings],
+            D::Minus1,
+        )?
+        .pad_with_zeros(D::Minus1, 0, 2048)?;
+
+        let t5_embeddings = self
+            .t5
+            .encode_text_to_embedding(prompt, device)?
+            .to_dtype(DType::F16)?;
+        let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;
+
+        Ok((context, y))
+    }
+}
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
new file mode 100644
index 0000000000..164ae4205b
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -0,0 +1,185 @@
+mod clip;
+mod sampling;
+mod vae;
+
+use candle::{DType, IndexOp, Tensor};
+use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT};
+
+use crate::clip::StableDiffusion3TripleClipWithTokenizer;
+use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
+
+use anyhow::{Ok, Result};
+use clap::Parser;
+
+#[derive(Parser)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// The prompt to be used for image generation.
+    #[arg(
+        long,
+        default_value = "A cute rusty robot holding a candle torch in its hand, \
+        with glowing neon text \"LETS GO RUSTY\" displayed on its chest, \
+        bright background, high quality, 4k"
+    )]
+    prompt: String,
+
+    #[arg(long, default_value = "")]
+    uncond_prompt: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// The CUDA device ID to use.
+    #[arg(long, default_value = "0")]
+    cuda_device_id: usize,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// Use flash_attn to accelerate attention operation in the MMDiT.
+    #[arg(long)]
+    use_flash_attn: bool,
+
+    /// The height in pixels of the generated image.
+    #[arg(long, default_value_t = 1024)]
+    height: usize,
+
+    /// The width in pixels of the generated image.
+    #[arg(long, default_value_t = 1024)]
+    width: usize,
+
+    /// The number of inference steps to run during sampling.
+    #[arg(long, default_value_t = 28)]
+    num_inference_steps: usize,
+
+    // CFG scale.
+    #[arg(long, default_value_t = 4.0)]
+    cfg_scale: f64,
+
+    // Time shift factor (alpha).
+    #[arg(long, default_value_t = 3.0)]
+    time_shift: f64,
+
+    /// The seed to use when generating random samples.
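+    /// If no seed is given, the device's default RNG state is used.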
+    #[arg(long)]
+    seed: Option<u64>,
+}
+
+fn main() -> Result<()> {
+    let args = Args::parse();
+    run(args)
+}
+
+fn run(args: Args) -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let Args {
+        prompt,
+        uncond_prompt,
+        cpu,
+        cuda_device_id,
+        tracing,
+        use_flash_attn,
+        height,
+        width,
+        num_inference_steps,
+        cfg_scale,
+        time_shift,
+        seed,
+    } = args;
+
+    let _guard = if tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
+    // TODO: Support and test on Metal.
+    let device = if cpu {
+        candle::Device::Cpu
+    } else {
+        candle::Device::cuda_if_available(cuda_device_id)?
+    };
+
+    let api = hf_hub::api::sync::Api::new()?;
+    let sai_repo = {
+        let name = "stabilityai/stable-diffusion-3-medium";
+        api.repo(hf_hub::Repo::model(name.to_string()))
+    };
+    let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
+    let vb_fp16 = unsafe {
+        candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F16, &device)?
+    };
+
+    let (context, y) = {
+        let vb_fp32 = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(
+                &[model_file.clone()],
+                DType::F32,
+                &device,
+            )?
+        };
+        let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
+            vb_fp16.pp("text_encoders"),
+            vb_fp32.pp("text_encoders"),
+        )?;
+        let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
+        let (context_uncond, y_uncond) =
+            triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
+        (
+            Tensor::cat(&[context, context_uncond], 0)?,
+            Tensor::cat(&[y, y_uncond], 0)?,
+        )
+    };
+
+    let x = {
+        let mmdit = MMDiT::new(
+            &MMDiTConfig::sd3_medium(),
+            use_flash_attn,
+            vb_fp16.pp("model.diffusion_model"),
+        )?;
+
+        if let Some(seed) = seed {
+            device.set_seed(seed)?;
+        }
+        let start_time = std::time::Instant::now();
+        let x = sampling::euler_sample(
+            &mmdit,
+            &y,
+            &context,
+            num_inference_steps,
+            cfg_scale,
+            time_shift,
+            height,
+            width,
+        )?;
+        let dt = start_time.elapsed().as_secs_f32();
+        println!(
+            "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
+            dt,
+            num_inference_steps as f32 / dt
+        );
+        x
+    };
+
+    let img = {
+        let vb_vae = vb_fp16
+            .clone()
+            .rename_f(sd3_vae_vb_rename)
+            .pp("first_stage_model");
+        let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
+
+        // Apply the TAESD3 scale factor; this seems to significantly improve the quality of the image.
+        // https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
+        autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
+    };
+    let img = ((img.clamp(-1f32, 1f32)? + 1.0)?
* 127.5)?.to_dtype(candle::DType::U8)?; + candle_examples::save_image(&img.i(0)?, "out.jpg")?; + Ok(()) +} diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs new file mode 100644 index 0000000000..147d8e7380 --- /dev/null +++ b/candle-examples/examples/stable-diffusion-3/sampling.rs @@ -0,0 +1,55 @@ +use anyhow::{Ok, Result}; +use candle::{DType, Tensor}; + +use candle_transformers::models::flux; +use candle_transformers::models::mmdit::model::MMDiT; // for the get_noise function + +#[allow(clippy::too_many_arguments)] +pub fn euler_sample( + mmdit: &MMDiT, + y: &Tensor, + context: &Tensor, + num_inference_steps: usize, + cfg_scale: f64, + time_shift: f64, + height: usize, + width: usize, +) -> Result { + let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?; + let sigmas = (0..=num_inference_steps) + .map(|x| x as f64 / num_inference_steps as f64) + .rev() + .map(|x| time_snr_shift(time_shift, x)) + .collect::>(); + + for window in sigmas.windows(2) { + let (s_curr, s_prev) = match window { + [a, b] => (a, b), + _ => continue, + }; + + let timestep = (*s_curr) * 1000.0; + let noise_pred = mmdit.forward( + &Tensor::cat(&[x.clone(), x.clone()], 0)?, + &Tensor::full(timestep, (2,), x.device())?.contiguous()?, + y, + context, + )?; + x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?; + } + Ok(x) +} + +// The "Resolution-dependent shifting of timestep schedules" recommended in the SD3 tech report paper +// https://arxiv.org/pdf/2403.03206 +// Following the implementation in ComfyUI: +// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/ +// comfy/model_sampling.py#L181 +fn time_snr_shift(alpha: f64, t: f64) -> f64 { + alpha * t / (1.0 + (alpha - 1.0) * t) +} + +fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result { + Ok(((cfg_scale * noise_pred.narrow(0, 0, 1)?)? + - ((cfg_scale - 1.0) * noise_pred.narrow(0, 1, 1)?)?)?) +} diff --git a/candle-examples/examples/stable-diffusion-3/vae.rs b/candle-examples/examples/stable-diffusion-3/vae.rs new file mode 100644 index 0000000000..708e472eff --- /dev/null +++ b/candle-examples/examples/stable-diffusion-3/vae.rs @@ -0,0 +1,93 @@ +use anyhow::{Ok, Result}; +use candle_transformers::models::stable_diffusion::vae; + +pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result { + let config = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 16, + norm_num_groups: 32, + use_quant_conv: false, + use_post_quant_conv: false, + }; + Ok(vae::AutoEncoderKL::new(vb, 3, 3, config)?) +} + +pub fn sd3_vae_vb_rename(name: &str) -> String { + let parts: Vec<&str> = name.split('.').collect(); + let mut result = Vec::new(); + let mut i = 0; + + while i < parts.len() { + match parts[i] { + "down_blocks" => { + result.push("down"); + } + "mid_block" => { + result.push("mid"); + } + "up_blocks" => { + result.push("up"); + match parts[i + 1] { + // Reverse the order of up_blocks. + "0" => result.push("3"), + "1" => result.push("2"), + "2" => result.push("1"), + "3" => result.push("0"), + _ => {} + } + i += 1; // Skip the number after up_blocks. + } + "resnets" => { + if i > 0 && parts[i - 1] == "mid_block" { + match parts[i + 1] { + "0" => result.push("block_1"), + "1" => result.push("block_2"), + _ => {} + } + i += 1; // Skip the number after resnets. 
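+                    // These map to the `mid.block_1` / `mid.block_2` names used in the SD3 checkpoint.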
+ } else { + result.push("block"); + } + } + "downsamplers" => { + result.push("downsample"); + i += 1; // Skip the 0 after downsamplers. + } + "conv_shortcut" => { + result.push("nin_shortcut"); + } + "attentions" => { + if parts[i + 1] == "0" { + result.push("attn_1") + } + i += 1; // Skip the number after attentions. + } + "group_norm" => { + result.push("norm"); + } + "query" => { + result.push("q"); + } + "key" => { + result.push("k"); + } + "value" => { + result.push("v"); + } + "proj_attn" => { + result.push("proj_out"); + } + "conv_norm_out" => { + result.push("norm_out"); + } + "upsamplers" => { + result.push("upsample"); + i += 1; // Skip the 0 after upsamplers. + } + part => result.push(part), + } + i += 1; + } + result.join(".") +} diff --git a/candle-transformers/src/models/mmdit/blocks.rs b/candle-transformers/src/models/mmdit/blocks.rs index e2b924a013..a1777f915b 100644 --- a/candle-transformers/src/models/mmdit/blocks.rs +++ b/candle-transformers/src/models/mmdit/blocks.rs @@ -194,10 +194,16 @@ pub struct JointBlock { x_block: DiTBlock, context_block: DiTBlock, num_heads: usize, + use_flash_attn: bool, } impl JointBlock { - pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + pub fn new( + hidden_size: usize, + num_heads: usize, + use_flash_attn: bool, + vb: nn::VarBuilder, + ) -> Result { let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; @@ -205,13 +211,15 @@ impl JointBlock { x_block, context_block, num_heads, + use_flash_attn, }) } pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> { let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?; let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; - let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?; + let (context_attn, x_attn) = + joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?; let context_out = self.context_block .post_attention(&context_attn, context, &context_interm)?; @@ -224,16 +232,23 @@ pub struct ContextQkvOnlyJointBlock { x_block: DiTBlock, context_block: QkvOnlyDiTBlock, num_heads: usize, + use_flash_attn: bool, } impl ContextQkvOnlyJointBlock { - pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + pub fn new( + hidden_size: usize, + num_heads: usize, + use_flash_attn: bool, + vb: nn::VarBuilder, + ) -> Result { let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; Ok(Self { x_block, context_block, num_heads, + use_flash_attn, }) } @@ -241,7 +256,7 @@ impl ContextQkvOnlyJointBlock { let context_qkv = self.context_block.pre_attention(context, c)?; let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; - let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?; + let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?; let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?; Ok(x_out) @@ -266,7 +281,28 @@ fn flash_compatible_attention( attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2) } -fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> { +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + candle_flash_attn::flash_attn(q, k, v, 
softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { + unimplemented!("compile with '--features flash-attn'") +} + +fn joint_attn( + context_qkv: &Qkv, + x_qkv: &Qkv, + num_heads: usize, + use_flash_attn: bool, +) -> Result<(Tensor, Tensor)> { let qkv = Qkv { q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?, k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?, @@ -282,8 +318,12 @@ fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tenso let headdim = qkv.q.dim(D::Minus1)?; let softmax_scale = 1.0 / (headdim as f64).sqrt(); - // let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?; - let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?; + + let attn = if use_flash_attn { + flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)? + } else { + flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)? + }; let attn = attn.reshape((batch_size, seqlen, ()))?; let context_qkv_seqlen = context_qkv.q.dim(1)?; diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs index 1523836c7f..864b662377 100644 --- a/candle-transformers/src/models/mmdit/model.rs +++ b/candle-transformers/src/models/mmdit/model.rs @@ -23,7 +23,7 @@ pub struct Config { } impl Config { - pub fn sd3() -> Self { + pub fn sd3_medium() -> Self { Self { patch_size: 2, in_channels: 16, @@ -49,7 +49,7 @@ pub struct MMDiT { } impl MMDiT { - pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result { + pub fn new(cfg: &Config, use_flash_attn: bool, vb: nn::VarBuilder) -> Result { let hidden_size = cfg.head_size * cfg.depth; let core = MMDiTCore::new( cfg.depth, @@ -57,6 +57,7 @@ impl MMDiT { cfg.depth, cfg.patch_size, cfg.out_channels, + use_flash_attn, vb.clone(), )?; let patch_embedder = PatchEmbedder::new( @@ -135,6 +136,7 @@ impl MMDiTCore { num_heads: usize, patch_size: usize, out_channels: usize, + use_flash_attn: bool, vb: nn::VarBuilder, ) -> Result { let mut joint_blocks = Vec::with_capacity(depth - 1); @@ -142,6 +144,7 @@ impl MMDiTCore { joint_blocks.push(JointBlock::new( hidden_size, num_heads, + use_flash_attn, vb.pp(format!("joint_blocks.{}", i)), )?); } @@ -151,6 +154,7 @@ impl MMDiTCore { context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new( hidden_size, num_heads, + use_flash_attn, vb.pp(format!("joint_blocks.{}", depth - 1)), )?, final_layer: FinalLayer::new( diff --git a/candle-transformers/src/models/mmdit/projections.rs b/candle-transformers/src/models/mmdit/projections.rs index 1077398f5c..dc1e8ec941 100644 --- a/candle-transformers/src/models/mmdit/projections.rs +++ b/candle-transformers/src/models/mmdit/projections.rs @@ -42,7 +42,6 @@ pub struct QkvOnlyAttnProjections { impl QkvOnlyAttnProjections { pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { - // {'dim': 1536, 'num_heads': 24} let head_dim = dim / num_heads; let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?; Ok(Self { qkv, head_dim }) diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs index 5cc59e8203..c04e6aa1ff 100644 --- a/candle-transformers/src/models/stable_diffusion/attention.rs +++ b/candle-transformers/src/models/stable_diffusion/attention.rs @@ -467,6 +467,24 @@ pub struct AttentionBlock { config: AttentionBlockConfig, } +// In the .safetensor weights of official Stable 
Diffusion 3 Medium Huggingface repo +// https://huggingface.co/stabilityai/stable-diffusion-3-medium +// Linear layer may use a different dimension for the weight in the linear, which is +// incompatible with the current implementation of the nn::linear constructor. +// This is a workaround to handle the different dimensions. +fn get_qkv_linear(channels: usize, vs: nn::VarBuilder) -> Result { + match vs.get((channels, channels), "weight") { + Ok(_) => nn::linear(channels, channels, vs), + Err(_) => { + let weight = vs + .get((channels, channels, 1, 1), "weight")? + .reshape((channels, channels))?; + let bias = vs.get((channels,), "bias")?; + Ok(nn::Linear::new(weight, Some(bias))) + } + } +} + impl AttentionBlock { pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result { let num_head_channels = config.num_head_channels.unwrap_or(channels); @@ -478,10 +496,10 @@ impl AttentionBlock { } else { ("query", "key", "value", "proj_attn") }; - let query = nn::linear(channels, channels, vs.pp(q_path))?; - let key = nn::linear(channels, channels, vs.pp(k_path))?; - let value = nn::linear(channels, channels, vs.pp(v_path))?; - let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?; + let query = get_qkv_linear(channels, vs.pp(q_path))?; + let key = get_qkv_linear(channels, vs.pp(k_path))?; + let value = get_qkv_linear(channels, vs.pp(v_path))?; + let proj_attn = get_qkv_linear(channels, vs.pp(out_path))?; let span = tracing::span!(tracing::Level::TRACE, "attn-block"); Ok(Self { group_norm, diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs index 5254818e60..2f631248bc 100644 --- a/candle-transformers/src/models/stable_diffusion/clip.rs +++ b/candle-transformers/src/models/stable_diffusion/clip.rs @@ -388,6 +388,37 @@ impl ClipTextTransformer { let xs = self.encoder.forward(&xs, &causal_attention_mask)?; self.final_layer_norm.forward(&xs) } + + pub fn forward_until_encoder_layer( + &self, + xs: &Tensor, + mask_after: usize, + until_layer: isize, + ) -> Result<(Tensor, Tensor)> { + let (bsz, seq_len) = xs.dims2()?; + let xs = self.embeddings.forward(xs)?; + let causal_attention_mask = + Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?; + + let mut xs = xs.clone(); + let mut intermediate = xs.clone(); + + // Modified encoder.forward that returns the intermediate tensor along with final output. 
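+        // A negative `until_layer` counts from the end, e.g. -2 selects the penultimate encoder layer.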
+ let until_layer = if until_layer < 0 { + self.encoder.layers.len() as isize + until_layer + } else { + until_layer + } as usize; + + for (layer_id, layer) in self.encoder.layers.iter().enumerate() { + xs = layer.forward(&xs, &causal_attention_mask)?; + if layer_id == until_layer { + intermediate = xs.clone(); + } + } + + Ok((self.final_layer_norm.forward(&xs)?, intermediate)) + } } impl Module for ClipTextTransformer { diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs index 30f239756c..37f4cdbf59 100644 --- a/candle-transformers/src/models/stable_diffusion/mod.rs +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -65,6 +65,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let height = if let Some(height) = height { assert_eq!(height % 8, 0, "height has to be divisible by 8"); @@ -133,6 +135,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let scheduler = Arc::new(ddim::DDIMSchedulerConfig { prediction_type, @@ -214,6 +218,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let scheduler = Arc::new(ddim::DDIMSchedulerConfig { prediction_type, @@ -281,6 +287,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let scheduler = Arc::new( euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig { @@ -378,6 +386,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let scheduler = Arc::new(ddim::DDIMSchedulerConfig { ..Default::default() diff --git a/candle-transformers/src/models/stable_diffusion/vae.rs b/candle-transformers/src/models/stable_diffusion/vae.rs index 670b3f5638..b3aba80277 100644 --- a/candle-transformers/src/models/stable_diffusion/vae.rs +++ b/candle-transformers/src/models/stable_diffusion/vae.rs @@ -275,6 +275,8 @@ pub struct AutoEncoderKLConfig { pub layers_per_block: usize, pub latent_channels: usize, pub norm_num_groups: usize, + pub use_quant_conv: bool, + pub use_post_quant_conv: bool, } impl Default for AutoEncoderKLConfig { @@ -284,6 +286,8 @@ impl Default for AutoEncoderKLConfig { layers_per_block: 1, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, } } } @@ -315,8 +319,8 @@ impl DiagonalGaussianDistribution { pub struct AutoEncoderKL { encoder: Encoder, decoder: Decoder, - quant_conv: nn::Conv2d, - post_quant_conv: nn::Conv2d, + quant_conv: Option, + post_quant_conv: Option, pub config: AutoEncoderKLConfig, } @@ -342,20 +346,33 @@ impl AutoEncoderKL { }; let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?; let conv_cfg = Default::default(); - let quant_conv = nn::conv2d( - 2 * latent_channels, - 2 * latent_channels, - 1, - conv_cfg, - vs.pp("quant_conv"), - )?; - let post_quant_conv = nn::conv2d( - latent_channels, - latent_channels, - 1, - conv_cfg, - vs.pp("post_quant_conv"), - )?; + + let quant_conv = { + if config.use_quant_conv { + Some(nn::conv2d( + 2 * latent_channels, + 2 * latent_channels, + 1, + conv_cfg, + vs.pp("quant_conv"), + )?) 
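+                // (SD3's VAE config sets `use_quant_conv`/`use_post_quant_conv` to false, in which case these stay `None`.)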
+ } else { + None + } + }; + let post_quant_conv = { + if config.use_post_quant_conv { + Some(nn::conv2d( + latent_channels, + latent_channels, + 1, + conv_cfg, + vs.pp("post_quant_conv"), + )?) + } else { + None + } + }; Ok(Self { encoder, decoder, @@ -368,13 +385,19 @@ impl AutoEncoderKL { /// Returns the distribution in the latent space. pub fn encode(&self, xs: &Tensor) -> Result { let xs = self.encoder.forward(xs)?; - let parameters = self.quant_conv.forward(&xs)?; + let parameters = match &self.quant_conv { + None => xs, + Some(quant_conv) => quant_conv.forward(&xs)?, + }; DiagonalGaussianDistribution::new(¶meters) } /// Takes as input some sampled values. pub fn decode(&self, xs: &Tensor) -> Result { - let xs = self.post_quant_conv.forward(xs)?; - self.decoder.forward(&xs) + let xs = match &self.post_quant_conv { + None => xs, + Some(post_quant_conv) => &post_quant_conv.forward(xs)?, + }; + self.decoder.forward(xs) } } diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml index e03319a043..c492521005 100644 --- a/candle-wasm-examples/yolo/Cargo.toml +++ b/candle-wasm-examples/yolo/Cargo.toml @@ -35,7 +35,7 @@ yew-agent = "0.2.0" yew = { version = "0.20.0", features = ["csr"] } [dependencies.web-sys] -version = "0.3.70" +version = "=0.3.70" features = [ 'Blob', 'CanvasRenderingContext2d', diff --git a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs index 8705df4219..ae448078f0 100644 --- a/candle-wasm-tests/tests/quantized_tests.rs +++ b/candle-wasm-tests/tests/quantized_tests.rs @@ -1,3 +1,4 @@ +#![allow(unused)] use candle::{ quantized::{self, k_quants, GgmlDType, GgmlType}, test_utils::to_vec2_round,