Skip to content

Commit

Permalink
Add tests and small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Oct 27, 2024
1 parent ecdbd99 commit e423afc
Show file tree
Hide file tree
Showing 7 changed files with 313 additions and 8 deletions.
17 changes: 17 additions & 0 deletions shortfin/python/shortfin_apps/sd/components/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,23 @@ def post_init(self):
len(self.prompt) if self.prompt is not None else len(self.input_ids)
)

batchable_args = [
self.prompt,
self.neg_prompt,
self.height,
self.width,
self.steps,
self.guidance_scale,
self.seed,
self.input_ids,
self.neg_input_ids,
]
for arg in batchable_args:
if isinstance(arg, list):
if len(arg) != self.num_output_images and len(arg) != 1:
raise ValueError(
f"Batchable arguments should either be singular or as many as the full batch ({self.num_output_images})."
)
if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(self.num_output_images)]
else:
Expand Down
1 change: 1 addition & 0 deletions shortfin/python/shortfin_apps/sd/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {}
self.inference_programs: dict[str, sf.Program] = {}
self.procs_per_device = 1
self.streaming = True
self.workers = []
self.fibers = []
self.locks = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
"width": 1024,
"steps": 20,
"guidance_scale": [
7.5,
7.9,
7.5,
7.5
10,
10,
10,
10
],
"seed": 0,
"output_type": [
Expand Down
25 changes: 21 additions & 4 deletions shortfin/python/shortfin_apps/sd/examples/send_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@
from datetime import datetime as dt
from PIL import Image

# Built-in request body. send_json_file() uses this dict instead of reading
# a file when it is called with the file path "default".
# Every field is given as a one-element list.
sample_request = {
"prompt": [
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
],
"neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"],
"height": [1024],
"width": [1024],
"steps": [20],
"guidance_scale": [7.5],
"seed": [0],
"output_type": ["base64"],
"rid": ["string"],
}


def bytes_to_img(bytes, idx=0, width=1024, height=1024):
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
Expand All @@ -19,8 +33,11 @@ def bytes_to_img(bytes, idx=0, width=1024, height=1024):
def send_json_file(file_path):
# Read the JSON file
try:
with open(file_path, "r") as json_file:
data = json.load(json_file)
if file_path == "default":
data = sample_request
else:
with open(file_path, "r") as json_file:
data = json.load(json_file)
except Exception as e:
print(f"Error reading the JSON file: {e}")
return
Expand All @@ -43,14 +60,14 @@ def send_json_file(file_path):
if isinstance(request["height"], list)
else request["height"]
)
bytes_to_img(item.encode("ascii"), idx, width, height)
bytes_to_img(item.encode("utf-8"), idx, width, height)

except requests.exceptions.RequestException as e:
print(f"Error sending the request: {e}")


if __name__ == "__main__":
p = argparse.ArgumentParser()
p.add_argument("file", type=str)
p.add_argument("--file", type=str, default="default")
args = p.parse_args()
send_json_file(args.file)
55 changes: 55 additions & 0 deletions shortfin/tests/apps/sd/components/tokenizer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest


@pytest.fixture
def clip_tokenizer():
    """Return a ``Tokenizer`` built with ``Tokenizer.from_pretrained``.

    The ``Tokenizer`` class is imported inside the fixture body rather than at
    module import time.
    """
    from shortfin_apps.sd.components.tokenizer import Tokenizer

    # Positional arguments forwarded unchanged to ``from_pretrained``.
    pretrained_args = ("stabilityai/stable-diffusion-xl-base-1.0", "tokenizer")
    tokenizer = Tokenizer.from_pretrained(*pretrained_args)
    return tokenizer


def test_transformers_tokenizer(clip_tokenizer):
    """Check the first ten ids of each row when encoding a two-prompt batch."""
    # Every expected row opens with 49406 and is filled out with 49407.
    leading_id = 49406
    trailing_id = 49407

    encoded = clip_tokenizer.encode(["This is sequence 1", "Sequence 2"])
    first_row = encoded.input_ids[0, :10]
    second_row = encoded.input_ids[1, :10]

    expected_first = [leading_id, 589, 533, 18833, 272] + [trailing_id] * 5
    expected_second = [leading_id, 18833, 273] + [trailing_id] * 7

    assert first_row.tolist() == expected_first
    assert second_row.tolist() == expected_second


def test_tokenizer_to_array(cpu_fiber, clip_tokenizer):
    """Check the first five ids of each view of the array built from encodings."""
    batch_seq_len = 64
    prompts = ["This is sequence 1", "Sequence 2"]

    encodings = clip_tokenizer.encode(prompts)
    device = cpu_fiber.device(0)
    array = clip_tokenizer.encodings_to_array(device, encodings, batch_seq_len)
    print(array)

    expected_prefixes = (
        [49406, 589, 533, 18833, 272],
        [49406, 18833, 273, 49407, 49407],
    )
    for row, expected in enumerate(expected_prefixes):
        assert array.view(row).items.tolist()[:5] == expected
17 changes: 17 additions & 0 deletions shortfin/tests/apps/sd/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest

from shortfin.support.deps import ShortfinDepNotFoundError


@pytest.fixture(autouse=True)
def require_deps():
    """Skip the current test if importing ``shortfin_apps.sd`` reports a missing dep.

    The fixture is autouse, so it is applied without being requested by name.
    """
    try:
        # Imported only to probe availability; the module object is unused.
        import shortfin_apps.sd
    except ShortfinDepNotFoundError as missing:
        reason = f"Dep not available: {missing}"
        pytest.skip(reason)
198 changes: 198 additions & 0 deletions shortfin/tests/apps/sd/e2e_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import json
import requests
import time
import base64
import pytest
import subprocess
import os
import socket
import sys
from contextlib import closing

from datetime import datetime as dt
from PIL import Image

# NOTE(review): BATCH_SIZES is not referenced anywhere else in this file as
# shown; presumably reserved for parametrizing tests by batch size - confirm.
BATCH_SIZES = [1]

# Request body that send_json_file() posts to the server's /generate endpoint.
# Every field is given as a one-element list.
sample_request = {
"prompt": [
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
],
"neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"],
"height": [1024],
"width": [1024],
"steps": [20],
"guidance_scale": [7.5],
"seed": [0],
"output_type": ["base64"],
"rid": ["string"],
}


def sd_artifacts(target: str = "gfx942"):
    """Map server artifact names to the file names used for ``target``.

    Args:
        target: Suffix inserted before ``.vmfb`` in each compiled-module name.

    Returns:
        A dict whose keys are, in order, ``model_config``, the four ``*_vmfb``
        entries, and the three ``*_params`` entries. Only the ``*_vmfb`` values
        depend on ``target``.
    """
    compiled_modules = {
        "clip_vmfb": f"stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_{target}.vmfb",
        "scheduler_vmfb": f"stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_20_{target}.vmfb",
        "unet_vmfb": f"stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_{target}.vmfb",
        "vae_vmfb": f"stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_{target}.vmfb",
    }
    parameter_files = {
        "clip_params": "clip_splat_fp16.irpa",
        "unet_params": "punet_splat_i8.irpa",
        "vae_params": "vae_splat_fp16.irpa",
    }

    # Insertion order matters to callers that iterate the mapping, so the
    # pieces are merged in the same order as the keys listed above.
    artifacts = {"model_config": "sdxl_config_i8.json"}
    artifacts.update(compiled_modules)
    artifacts.update(parameter_files)
    return artifacts


# Directory the sd_server fixture creates and downloads artifacts into.
# Resolved against the current working directory when this module is imported.
cache = os.path.abspath("./tmp/sharktank/sd/")


@pytest.fixture(scope="module")
def sd_server():
    """Download any missing artifacts, start the SD server, and yield its runner.

    Each file listed by ``sd_artifacts()`` is fetched with ``wget`` into
    ``cache`` unless it is already there. The server is then launched through
    ``ServerRunner`` with one ``--<artifact>=<path>`` flag per artifact.

    Yields:
        The ``ServerRunner`` wrapping the running server process.

    Raises:
        subprocess.CalledProcessError: If a download fails.
    """
    os.makedirs(cache, exist_ok=True)

    vmfbs_bucket = "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/"
    weights_bucket = (
        "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/"
    )
    configs_bucket = (
        "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/configs/"
    )

    # Build the artifact table once and reuse it for download and CLI flags.
    artifacts = sd_artifacts()
    local_paths = {}
    for artifact, filename in artifacts.items():
        if "vmfb" in artifact:
            bucket = vmfbs_bucket
        elif "params" in artifact:
            bucket = weights_bucket
        else:
            bucket = configs_bucket

        destination = os.path.join(cache, filename)
        local_paths[artifact] = destination
        if os.path.exists(destination):
            continue
        try:
            # An argv list avoids shell parsing of the URL and path.
            subprocess.run(["wget", bucket + filename, "-O", destination], check=True)
        except BaseException:
            # ``wget -O`` creates the output file before the transfer finishes.
            # Remove it so the next run does not mistake a partial file for a
            # cached artifact.
            if os.path.exists(destination):
                os.remove(destination)
            raise

    # Use the interpreter running the tests so the server sees the same
    # installed packages.
    srv_args = [sys.executable, "-m", "shortfin_apps.sd.server"]
    srv_args.extend(f"--{name}={path}" for name, path in local_paths.items())

    runner = ServerRunner(srv_args)
    try:
        # Extra settle time after the runner reports the server as ready.
        time.sleep(5)
        yield runner
    finally:
        # Stop the server explicitly: ``del runner`` only unbinds the local
        # name and does not guarantee that ``__del__`` runs at this point.
        runner.process.terminate()
        try:
            runner.process.wait(timeout=30)
        except subprocess.TimeoutExpired:
            runner.process.kill()
            runner.process.wait()


@pytest.mark.system("amdgpu")
def test_sd_server(sd_server):
    """Post the sample request to the running server and check the outcome.

    Expects exactly one saved image and an HTTP 200 status.
    """
    result = send_json_file(sd_server.url)
    saved_images, http_status = result
    image_count = len(saved_images)
    assert image_count == 1
    assert http_status == 200


class ServerRunner:
    """Launch a server command as a subprocess on a free port and wait for it.

    Attributes:
        url: ``http://0.0.0.0:<port>``, where ``<port>`` is the value passed
            to the server via ``--port``.
        process: ``subprocess.Popen`` handle of the launched command.
    """

    def __init__(self, args):
        """Start ``args`` plus ``--port``/``--device`` flags and block until ready.

        Args:
            args: Command line (program and arguments) used to start the server.

        Raises:
            RuntimeError: If the server exits or does not become healthy in time.
        """
        port = str(find_free_port())
        self.url = "http://0.0.0.0:" + port
        env = os.environ.copy()
        env["PYTHONUNBUFFERED"] = "1"
        # NOTE(review): presumably pins the server to a single GPU - confirm.
        env["HIP_VISIBLE_DEVICES"] = "0"
        self.process = subprocess.Popen(
            [
                *args,
                "--port=" + port,
                "--device=amdgpu",
            ],
            env=env,
            # TODO: Have a more robust way of forking a subprocess.
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        print(self.process.args)
        try:
            self._wait_for_ready()
        except BaseException:
            # Do not leave the child running if startup failed.
            self._stop()
            raise

    def _wait_for_ready(self):
        """Poll ``{url}/health`` until it answers 200.

        Raises:
            RuntimeError: If the server process has exited, or if more than
                30 seconds pass without a healthy response.
        """
        deadline = time.time() + 30
        while True:
            # ``poll()`` returns the exit code once the child has terminated.
            # ``Popen.errors`` is only the text-decoding option and cannot be
            # used to detect a dead process.
            if self.process.poll() is not None:
                raise RuntimeError("API server processs terminated")
            try:
                # The timeout keeps one stuck request from blocking forever.
                health = requests.get(f"{self.url}/health", timeout=5)
                if health.status_code == 200:
                    return
            except requests.exceptions.RequestException:
                # Expected while the server is still starting; retry below.
                pass
            if time.time() > deadline:
                raise RuntimeError("Timeout waiting for server start")
            time.sleep(2)

    def _stop(self):
        """Terminate the child process, if one was started, and reap it."""
        process = getattr(self, "process", None)
        if process is None:
            # ``__init__`` failed before the subprocess was created.
            return
        process.terminate()
        process.wait()

    def __del__(self):
        """Stop the server when the runner is garbage collected."""
        self._stop()


def bytes_to_img(bytes, idx=0, width=1024, height=1024):
    """Decode base64 RGB pixel data and save it as a timestamped PNG.

    Args:
        bytes: Base64-encoded raw RGB data (the name shadows the builtin but is
            part of the signature).
        idx: Index appended to the output file name.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        The file name the image was saved under, relative to the working
        directory.
    """
    stamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
    pixel_data = base64.b64decode(bytes)
    dimensions = (width, height)
    picture = Image.frombytes(mode="RGB", size=dimensions, data=pixel_data)
    output_name = f"shortfin_sd_test_output_{stamp}_{idx}.png"
    picture.save(output_name)
    return output_name


def _value_for_image(value, idx):
    """Return the entry of a request field that applies to image ``idx``.

    A non-list value applies to every image. A one-element list is treated
    the same way; otherwise the list is indexed by ``idx``.
    """
    if not isinstance(value, list):
        return value
    if len(value) == 1:
        return value[0]
    return value[idx]


def send_json_file(url="http://0.0.0.0:8000"):
    """Post ``sample_request`` to ``url + "/generate"`` and save returned images.

    Args:
        url: Base URL of the server.

    Returns:
        A ``(imgs, status_code)`` tuple. ``imgs`` lists the file names written
        by ``bytes_to_img``. ``status_code`` is the HTTP status of the
        response, or ``None`` if no response was received at all.
    """
    data = sample_request
    imgs = []
    # Stays None if the request fails before any response arrives, so the
    # return statement below never touches an unbound name.
    status_code = None
    # Send the data to the /generate endpoint
    try:
        response = requests.post(url + "/generate", json=data)
        status_code = response.status_code
        response.raise_for_status()  # Raise an error for bad responses
        request = json.loads(response.request.body.decode("utf-8"))

        for idx, item in enumerate(response.json()["images"]):
            # Width and height are each resolved from their own field.
            width = _value_for_image(request["width"], idx)
            height = _value_for_image(request["height"], idx)
            img_fn = bytes_to_img(item.encode("utf-8"), idx, width, height)
            imgs.append(img_fn)

    except requests.exceptions.RequestException as e:
        print(f"Error sending the request: {e}")

    return imgs, status_code


def find_free_port():
    """Ask the OS for a currently unused TCP port on localhost and return it.

    The probe socket is closed before returning, so another process may take
    the port before the server binds it; callers must tolerate that race.
    https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
    """
    # Socket objects are context managers; leaving the block closes the probe.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # Port 0 lets the OS pick any free port.
        probe.bind(("localhost", 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, chosen_port = probe.getsockname()
        return chosen_port

0 comments on commit e423afc

Please sign in to comment.