Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add inpaint.py #295

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ __pycache__/
# C extensions
*.so

*.ckpt

# General MacOS
.DS_Store
.AppleDouble
Expand Down
5 changes: 5 additions & 0 deletions inpaint.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Example run: inpaint the masked region of the sample image from a text
# prompt, drawing 10 samples with 100 DDIM steps each.
# Each continuation backslash is preceded by a space so that consecutive
# arguments stay separate words regardless of how the next line is indented.
python scripts/inpaint.py --src inpaintng_Birman_9.jpg \
    --mask inpaintmask_Birman_9.jpg \
    --prompt "A cat is sitting on the TV" \
    --steps 100 \
    --n_sample 10
Binary file added inpaint_0_inpaintng_Birman_9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inpaint_1_inpaintng_Birman_9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inpaint_2_inpaintng_Birman_9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inpaint_3_inpaintng_Birman_9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inpaint_4_inpaintng_Birman_9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inpaint_5_inpaintng_Birman_9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inpaint_6_inpaintng_Birman_9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inpaint_7_inpaintng_Birman_9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inpaint_8_inpaintng_Birman_9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inpaint_9_inpaintng_Birman_9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inpaintmask_Birman_9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inpaintng_Birman_9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
187 changes: 187 additions & 0 deletions scripts/inpaint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import sys
import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf
from einops import repeat
from imwatermark import WatermarkEncoder
from pathlib import Path
import argparse

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config


# This script only runs sampling, so autograd is switched off globally.
torch.set_grad_enabled(False)


def put_watermark(img, wm_encoder=None):
    """Embed an invisible watermark into ``img`` when an encoder is given.

    Without an encoder the input is handed back untouched. Otherwise the
    image is converted to a BGR array, encoded with the ``'dwtDct'`` method
    and wrapped back into a PIL image in RGB channel order.
    """
    if wm_encoder is None:
        return img

    bgr_pixels = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    marked = wm_encoder.encode(bgr_pixels, 'dwtDct')
    # Reversing the last axis turns BGR back into RGB.
    return Image.fromarray(marked[:, :, ::-1])


def initialize_model(config, ckpt):
    """Build the model described by ``config`` and wrap it in a DDIM sampler.

    Args:
        config: Path to the OmegaConf YAML file; its ``model`` node is
            instantiated.
        ckpt: Path to a checkpoint whose ``"state_dict"`` entry holds the
            weights.

    Returns:
        A ``DDIMSampler`` around the model, which is moved to CUDA when
        available and to the CPU otherwise.
    """
    config = OmegaConf.load(config)
    model = instantiate_from_config(config.model)

    # Deserialize onto the CPU first so a checkpoint saved from a GPU still
    # loads on a CPU-only machine; the model is moved to its device below.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints.
    state_dict = torch.load(ckpt, map_location="cpu")["state_dict"]
    # strict=False tolerates keys that do not line up with the model.
    model.load_state_dict(state_dict, strict=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    return DDIMSampler(model)


def make_batch_sd(
        image,
        mask,
        txt,
        device,
        num_samples=1):
    """Assemble the conditioning batch for the inpainting model.

    Args:
        image: PIL image; converted to RGB and rescaled from ``[0, 255]`` to
            ``[-1, 1]``.
        mask: PIL image; converted to a single channel, rescaled to
            ``[0, 1]`` and binarized at 0.5.
        txt: Prompt, listed once per sample.
        device: Device the tensors are moved to.
        num_samples: Number of copies along the leading (batch) axis.

    Returns:
        Dict with the keys ``"image"``, ``"txt"``, ``"mask"`` and
        ``"masked_image"``; the masked image is zero wherever the mask is 1.
    """
    # H x W x 3 -> 1 x 3 x H x W, then map [0, 255] onto [-1, 1].
    rgb = np.array(image.convert("RGB"))[None].transpose(0, 3, 1, 2)
    image_t = torch.from_numpy(rgb).to(dtype=torch.float32) / 127.5 - 1.0

    # H x W -> 1 x 1 x H x W with hard 0/1 values.
    soft_mask = (np.array(mask.convert("L")).astype(np.float32) / 255.0)[None, None]
    mask_t = torch.from_numpy((soft_mask >= 0.5).astype(np.float32))

    # Blank out the region that is to be repainted.
    masked_t = image_t * (mask_t < 0.5)

    def tile(tensor):
        # Move to the target device and replicate along the batch axis.
        return repeat(tensor.to(device=device), "1 ... -> n ...", n=num_samples)

    return {
        "image": tile(image_t),
        "txt": num_samples * [txt],
        "mask": tile(mask_t),
        "masked_image": tile(masked_t),
    }


def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
    """Sample inpainted images with DDIM and return them watermarked.

    Args:
        sampler: Sampler whose ``model`` attribute is the inpainting model.
        image: PIL source image, forwarded to ``make_batch_sd``.
        mask: PIL mask image, forwarded to ``make_batch_sd``.
        prompt: Text prompt used for the cross-attention conditioning.
        seed: Seed for the NumPy RNG that draws the initial latent noise.
        scale: Value passed as ``unconditional_guidance_scale``.
        ddim_steps: Number of DDIM steps passed to the sampler.
        num_samples: Batch size, i.e. number of images produced.
        w: Image width in pixels; the latent width is ``w // 8``.
        h: Image height in pixels; the latent height is ``h // 8``.

    Returns:
        List of PIL images, one per sample, each passed through
        ``put_watermark``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = sampler.model

    print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark('bytes', "SDV2".encode('utf-8'))

    latent_h, latent_w = h // 8, w // 8

    # Seeded NumPy noise makes the starting latent reproducible per seed.
    rng = np.random.RandomState(seed)
    noise = rng.randn(num_samples, 4, latent_h, latent_w)
    start_code = torch.from_numpy(noise).to(device=device, dtype=torch.float32)

    with torch.no_grad(), torch.autocast("cuda"):
        batch = make_batch_sd(image, mask, txt=prompt,
                              device=device, num_samples=num_samples)

        text_cond = model.cond_stage_model.encode(batch["txt"])

        concat_parts = []
        for key in model.concat_keys:
            part = batch[key].float()
            if key == model.masked_image_key:
                # The masked image goes through the first-stage encoder.
                part = model.get_first_stage_encoding(
                    model.encode_first_stage(part))
            else:
                # Every other key is resized to the latent resolution.
                part = torch.nn.functional.interpolate(
                    part, size=[latent_h, latent_w])
            concat_parts.append(part)
        c_cat = torch.cat(concat_parts, dim=1)

        # Conditional branch: concat conditioning plus the prompt embedding.
        cond = {"c_concat": [c_cat], "c_crossattn": [text_cond]}

        # Unconditional branch: same concat conditioning, empty prompt.
        uc_cross = model.get_unconditional_conditioning(num_samples, "")
        uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}

        shape = [model.channels, latent_h, latent_w]
        samples, _ = sampler.sample(
            ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=1.0,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc_full,
            x_T=start_code,
        )
        decoded = model.decode_first_stage(samples)

        # Map the decoder output from [-1, 1] to [0, 1], clipping overshoot.
        unit_range = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)

        # N x C x H x W -> N x H x W x C, scaled to the 0..255 range.
        pixels = unit_range.cpu().numpy().transpose(0, 2, 3, 1) * 255

    outputs = []
    for frame in pixels:
        outputs.append(put_watermark(Image.fromarray(frame.astype(np.uint8)), wm_encoder))
    return outputs

def pad_image(input_image):
    """Pad a PIL image on the right and bottom to a multiple of 64 pixels.

    Each side is grown to the next multiple of 64, and to at least 128
    pixels, by replicating the edge pixels. Both multi-channel (H x W x C)
    and single-channel (H x W) pixel arrays are supported.

    Args:
        input_image: PIL image to pad.

    Returns:
        A new PIL image built from the padded pixel array.
    """
    width, height = input_image.size
    pixels = np.array(input_image)

    def padded_length(length):
        # Ceil-divide by 64 with integer arithmetic; enforce the 2 * 64 floor.
        return max(2, -(-length // 64)) * 64

    pad_w = padded_length(width) - width
    pad_h = padded_length(height) - height

    # Only rows and columns are padded; any trailing axes (channels) get
    # (0, 0). Building the spec from ndim keeps 2-D arrays working too.
    pad_spec = ((0, pad_h), (0, pad_w)) + ((0, 0),) * (pixels.ndim - 2)
    return Image.fromarray(np.pad(pixels, pad_spec, mode='edge'))

def predict(input_image, prompt, ddim_steps, num_samples, scale, seed, sampler):
    """Pad the image/mask pair and run inpainting at the padded size.

    Args:
        input_image: Mapping with the PIL images ``"image"`` and ``"mask"``.
        prompt: Text prompt forwarded to ``inpaint``.
        ddim_steps: Number of DDIM steps forwarded to ``inpaint``.
        num_samples: Number of images to generate.
        scale: Guidance scale forwarded to ``inpaint``.
        seed: Seed forwarded to ``inpaint``.
        sampler: Sampler forwarded to ``inpaint``.

    Returns:
        Whatever ``inpaint`` returns for the padded inputs.
    """
    # Both inputs are converted to RGB and then padded by pad_image, which
    # grows each side to a multiple of 64 pixels.
    image = pad_image(input_image["image"].convert("RGB"))
    mask = pad_image(input_image["mask"].convert("RGB"))

    width, height = image.size
    print("Inpainting...", width, height)

    return inpaint(
        sampler=sampler,
        image=image,
        mask=mask,
        prompt=prompt,
        seed=seed,
        scale=scale,
        ddim_steps=ddim_steps,
        num_samples=num_samples,
        h=height,
        w=width,
    )



def parse_args():
    """Parse the command-line options of the inpainting script.

    ``--src``, ``--mask`` and ``--prompt`` are mandatory: leaving any of them
    out makes argparse print a usage error and exit, instead of letting a
    ``None`` value reach the image loader or the model later on.

    Returns:
        The ``argparse.Namespace`` with ``config``, ``ckpt``, ``src``,
        ``prompt``, ``mask``, ``dir``, ``steps`` and ``n_sample``.
    """
    parser = argparse.ArgumentParser(description='Image inpainting')
    parser.add_argument('--config', type=str, default="configs/stable-diffusion/v2-inpainting-inference.yaml", help='config path')
    parser.add_argument('--ckpt', type=str, default="512-inpainting-ema.ckpt", help='Model checkpoint')
    parser.add_argument('--src', type=str, required=True, help='Source image path')
    parser.add_argument('--prompt', type=str, required=True, help='Description for source image')
    parser.add_argument('--mask', type=str, required=True, help='Mask path')
    parser.add_argument('--dir', type=str, default='', help='Directory where generated samples are saved')
    parser.add_argument('--steps', type=int, default=45, help='Number of DDIM sample steps')
    parser.add_argument('--n_sample', type=int, default=4, help='Number of samples')
    return parser.parse_args()


if __name__ == "__main__":
    import os
    import os.path as osp

    # Guidance scale and RNG seed passed to predict() for every run.
    GUIDANCE_SCALE = 12
    SEED = 991108

    args = parse_args()

    # Create the output directory before the expensive sampling starts, so a
    # bad path fails early. The default '' means "current directory".
    if args.dir:
        os.makedirs(args.dir, exist_ok=True)

    sampler = initialize_model(args.config, args.ckpt)

    # predict() converts both inputs to new RGB images, so the files can be
    # closed as soon as it returns.
    with Image.open(args.src) as src_image, Image.open(args.mask) as mask_image:
        input_pair = {"image": src_image,
                      "mask": mask_image}
        results = predict(input_pair, args.prompt, args.steps, args.n_sample, GUIDANCE_SCALE, SEED, sampler)

    # Outputs are named inpaint_<index>_<source file name>.
    for i, img in enumerate(results):
        img.save(osp.join(args.dir, f"inpaint_{i}_{osp.basename(args.src)}"))