From bcb9360f50c8e9a850001789466c59ccf6fa26f8 Mon Sep 17 00:00:00 2001 From: Brian Walshe Date: Tue, 31 Oct 2023 11:56:30 +0000 Subject: [PATCH] This change makes it possible to specify a batch_size > 1 in order to generate multiple images. --- Stable Diffusion Deep Dive.ipynb | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/Stable Diffusion Deep Dive.ipynb b/Stable Diffusion Deep Dive.ipynb index 23e2267..9417fee 100644 --- a/Stable Diffusion Deep Dive.ipynb +++ b/Stable Diffusion Deep Dive.ipynb @@ -249,6 +249,8 @@ "text_input = tokenizer(prompt, padding=\"max_length\", max_length=tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n", "with torch.no_grad():\n", " text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]\n", + "text_embeddings = text_embeddings.repeat(batch_size, 1, 1)\n", + "\n", "max_length = text_input.input_ids.shape[-1]\n", "uncond_input = tokenizer(\n", " [\"\"] * batch_size, padding=\"max_length\", max_length=max_length, return_tensors=\"pt\"\n", @@ -304,7 +306,9 @@ "image = image.detach().cpu().permute(0, 2, 3, 1).numpy()\n", "images = (image * 255).round().astype(\"uint8\")\n", "pil_images = [Image.fromarray(image) for image in images]\n", - "pil_images[0]" + "for img in pil_images:\n", + " plt.figure()\n", + " plt.imshow(img)\n" ] }, { @@ -867,12 +871,13 @@ "num_inference_steps = 50 # Number of denoising steps\n", "guidance_scale = 8 # Scale for classifier-free guidance\n", "generator = torch.manual_seed(32) # Seed generator to create the inital latent noise\n", - "batch_size = 1\n", + "batch_size = encoded.size()[0]\n", "\n", "# Prep text (same as before)\n", "text_input = tokenizer(prompt, padding=\"max_length\", max_length=tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n", "with torch.no_grad():\n", " text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]\n", + "text_embeddings = text_embeddings.repeat(batch_size, 1, 1)\n", "max_length = text_input.input_ids.shape[-1]\n", "uncond_input = tokenizer(\n", " [\"\"] * batch_size, padding=\"max_length\", max_length=max_length, return_tensors=\"pt\"\n", @@ -911,7 +916,9 @@ " # compute the previous noisy sample x_t -> x_t-1\n", " latents = scheduler.step(noise_pred, t, latents).prev_sample\n", "\n", - "latents_to_pil(latents)[0]" + "for img in latents_to_pil(latents):\n", + " plt.figure()\n", + " plt.imshow(img)\n" ] }, {