Fix batched client request processing
eagarvey-amd committed Oct 27, 2024
1 parent 3f17aa3 commit ecdbd99
Showing 7 changed files with 120 additions and 65 deletions.
11 changes: 4 additions & 7 deletions shortfin/python/shortfin_apps/sd/components/generate.py
@@ -7,6 +7,7 @@
import asyncio
import io
import logging
import json

import shortfin as sf
import shortfin.array as sfnp
@@ -95,12 +96,8 @@ async def run(self):

# TODO: stream image outputs
logging.debug("Responding to one shot batch")
out = io.BytesIO()
result_images = [p.result_image for p in gen_processes]
for idx, result_image in enumerate(result_images):
out.write(result_image)
# TODO: save or return images
logging.debug("Wrote images as bytes to response.")
self.responder.send_response(out.getvalue())
response_data = {"images": [p.result_image for p in gen_processes]}
json_str = json.dumps(response_data)
self.responder.send_response(json_str)
finally:
self.responder.ensure_response()
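With this change, the one-shot batch response is a JSON object whose "images" field carries one base64-encoded image per generation process. A minimal client-side parsing sketch; `body` is a stand-in for the HTTP response text, not a name from this commit:

import base64
import json

data = json.loads(body)  # body: the JSON string sent via send_response above
for idx, b64_image in enumerate(data["images"]):
    raw_rgb = base64.b64decode(b64_image)  # raw RGB bytes for image idx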
25 changes: 7 additions & 18 deletions shortfin/python/shortfin_apps/sd/components/io_struct.py
@@ -41,28 +41,17 @@ def post_init(self):
):
raise ValueError("Either text or input_ids should be provided.")

prev_input_len = None
for i in [self.prompt, self.neg_prompt, self.input_ids, self.neg_input_ids]:
if isinstance(i, str):
self.num_output_images = 1
continue
elif not i:
continue
if not isinstance(i, list):
raise ValueError("Text inputs should be strings or lists.")
if prev_input_len and not (prev_input_len == len(i)):
raise ValueError("Positive, Negative text inputs should be same length")
self.num_output_images = len(i)
prev_input_len = len(i)
if not self.num_output_images:
self.num_output_images = (
len[self.prompt] if self.prompt is not None else len(self.input_ids)
)
if isinstance(self.prompt, str):
self.prompt = [self.prompt]

self.num_output_images = (
len(self.prompt) if self.prompt is not None else len(self.input_ids)
)

if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(self.num_output_images)]
else:
if not isinstance(self.rid, list):
raise ValueError("The rid should be a list.")
if self.output_type is None:
self.output_type = ["base64"] * self.num_output_images
self.output_type = ["PIL"] * self.num_output_images
30 changes: 26 additions & 4 deletions shortfin/python/shortfin_apps/sd/components/messages.py
@@ -6,11 +6,15 @@

from enum import Enum

import logging

import shortfin as sf
import shortfin.array as sfnp

from .io_struct import GenerateReqInput

logger = logging.getLogger(__name__)


class InferencePhase(Enum):
# Tokenize prompt, negative prompt and get latents, timesteps, time ids, guidance scale as device arrays
@@ -54,6 +58,8 @@ def __init__(
image_array: sfnp.device_array | None = None,
):
super().__init__()
self.print_debug = False

self.phases = {}
self.phase = None
self.height = height
@@ -87,6 +93,7 @@ def __init__(
self.image_array = image_array

self.result_image = None
self.img_metadata = None

self.done = sf.VoidFuture()

@@ -96,8 +103,18 @@ def __init__(
self.return_host_array: bool = True

self.post_init()
print(self.phases)
print(self.phase)

def __setattr__(self, name, value):
self.__dict__[name] = value
if getattr(self, "print_debug", False):
self.on_change(name, value)

def on_change(self, name, value):
if isinstance(value, sfnp.device_array):
val_host = value.for_transfer()
val_host.copy_from(value)
logger.info("NAME: %s", name)
logger.info("VALUE: \n%s", val_host)

@staticmethod
def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
@@ -114,8 +131,13 @@ def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
for item in gen_inputs:
received = getattr(gen_req, item, None)
if isinstance(received, list):
if index >= len(received):
rec_input = None
if index >= len(received):
if len(received) == 1:
rec_input = received[0]
else:
logging.error(
"Inputs in request must be singular or as many as the list of prompts."
)
else:
rec_input = received[index]
else:
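The attribute-change logging is opt-in via print_debug. A hedged usage sketch; `sample` stands for any device-array attribute and is illustrative here:

exec_req = InferenceExecRequest.from_batch(gen_req, 0)
exec_req.print_debug = True           # later assignments now route through on_change
exec_req.sample = some_device_array   # device arrays get copied to host and logged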
44 changes: 27 additions & 17 deletions shortfin/python/shortfin_apps/sd/components/service.py
@@ -11,6 +11,8 @@
from tqdm.auto import tqdm
from pathlib import Path
from PIL import Image
import io
import base64

import shortfin as sf
import shortfin.array as sfnp
@@ -51,12 +53,14 @@ def __init__(
self.procs_per_device = 1
self.workers = []
self.fibers = []
self.locks = []
for idx, device in enumerate(self.sysman.ls.devices):
for i in range(self.procs_per_device):
worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}")
fiber = sysman.ls.create_fiber(worker, devices=[device])
self.workers.append(worker)
self.fibers.append(fiber)
self.locks.append(asyncio.Lock())

# Scope dependent objects.
self.batcher = BatcherProcess(self)
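One asyncio.Lock per fiber lets executor processes that share a fiber serialize their device work. A self-contained sketch of the pattern; the names are illustrative, not from the service:

import asyncio

locks = [asyncio.Lock() for _ in range(2)]  # one lock per fiber, as above

async def execute(index: int, name: str):
    async with locks[index]:
        await asyncio.sleep(0.1)  # stands in for prepare/encode/denoise/decode
        print(f"{name} finished on fiber {index}")

async def main():
    # A and B contend for fiber 0 and run back to back; C overlaps on fiber 1.
    await asyncio.gather(execute(0, "A"), execute(0, "B"), execute(1, "C"))

asyncio.run(main())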
@@ -192,6 +196,7 @@ async def run(self):
self.strobes += 1
else:
logger.error("Illegal message received by batcher: %r", item)

self.board_flights()
self.strobe_enabled = True
await strober_task
@@ -207,7 +212,7 @@ def board_flights(self):

batches = self.sort_pending()
for idx in batches.keys():
self.board(batches[idx]["reqs"])
self.board(batches[idx]["reqs"], index=idx)

def sort_pending(self):
"""Returns pending requests as sorted batches suitable for program invocations."""
@@ -233,11 +238,11 @@ def sort_pending(self):
}
return batches

def board(self, request_bundle):
def board(self, request_bundle, index):
pending = request_bundle
if len(pending) == 0:
return
exec_process = InferenceExecutorProcess(self.service, 0)
exec_process = InferenceExecutorProcess(self.service, index)
for req in pending:
if len(exec_process.exec_requests) >= self.ideal_batch_size:
break
@@ -264,6 +269,7 @@ def __init__(
):
super().__init__(fiber=service.fibers[index])
self.service = service
self.worker_index = index
self.exec_requests: list[InferenceExecRequest] = []

async def run(self):
@@ -273,22 +279,22 @@ async def run(self):
if phase:
if phase != req.phase:
logger.error("Executor process recieved disjoint batch.")
phase = req.phase
phases = self.exec_requests[0].phases

req_count = len(self.exec_requests)
device0 = self.fiber.device(0)
await device0

if phases[InferencePhase.PREPARE]["required"]:
await self._prepare(device=device0, requests=self.exec_requests)
if phases[InferencePhase.ENCODE]["required"]:
await self._encode(device=device0, requests=self.exec_requests)
if phases[InferencePhase.DENOISE]["required"]:
await self._denoise(device=device0, requests=self.exec_requests)
if phases[InferencePhase.DECODE]["required"]:
await self._decode(device=device0, requests=self.exec_requests)
if phases[InferencePhase.POSTPROCESS]["required"]:
await self._postprocess(device=device0, requests=self.exec_requests)
async with self.service.locks[self.worker_index]:
device0 = self.fiber.device(0)
if phases[InferencePhase.PREPARE]["required"]:
await self._prepare(device=device0, requests=self.exec_requests)
if phases[InferencePhase.ENCODE]["required"]:
await self._encode(device=device0, requests=self.exec_requests)
if phases[InferencePhase.DENOISE]["required"]:
await self._denoise(device=device0, requests=self.exec_requests)
if phases[InferencePhase.DECODE]["required"]:
await self._decode(device=device0, requests=self.exec_requests)
if phases[InferencePhase.POSTPROCESS]["required"]:
await self._postprocess(device=device0, requests=self.exec_requests)

for i in range(req_count):
req = self.exec_requests[i]
@@ -575,6 +581,7 @@ async def _decode(self, device, requests):
]
images_host = sfnp.device_array.for_host(device, images_shape, sfnp.float16)
images_host.copy_from(image)
await device
for idx, req in enumerate(requests):
image_array = images_host.view(idx).items
dtype = image_array.typecode
@@ -591,5 +598,8 @@ async def _postprocess(self, device, requests):
# TODO: reimpl with sfnp
permuted = np.transpose(req.image_array, (0, 2, 3, 1))[0]
cast_image = (permuted * 255).round().astype("uint8")
req.result_image = Image.fromarray(cast_image).tobytes()
image_bytes = Image.fromarray(cast_image).tobytes()

image = base64.b64encode(image_bytes).decode("utf-8")
req.result_image = image
return
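result_image is now a base64 string over the raw RGB bytes, so clients decode before reconstructing the image. A hedged round-trip sketch; width, height, and the output path are assumptions:

import base64
from PIL import Image

raw = base64.b64decode(req.result_image)
img = Image.frombytes(mode="RGB", size=(width, height), data=raw)
img.save("decoded.png")  # hypothetical output path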
18 changes: 18 additions & 0 deletions shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs2.json
@@ -0,0 +1,18 @@
{
"prompt": [
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal"
],
"neg_prompt": "Watermark, blurry, oversaturated, low resolution, pollution",
"height": 1024,
"width": 1024,
"steps": 20,
"guidance_scale": [
7.5,
7.9
],
"seed": 0,
"output_type": [
"base64"
]
}
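A hedged sketch of posting this file, mirroring examples/send_request.py; the endpoint URL is an assumption, not part of this commit:

import json
import requests

with open("sdxl_request_bs2.json") as f:
    payload = json.load(f)
resp = requests.post("http://localhost:8000/generate", json=payload)  # hypothetical URL
resp.raise_for_status()
print(len(resp.json()["images"]))  # expect 2 base64-encoded images back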
22 changes: 22 additions & 0 deletions shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs4.json
@@ -0,0 +1,22 @@
{
"prompt": [
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a dog under the snow with brown eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal"
],
"neg_prompt": "Watermark, blurry, oversaturated, low resolution, pollution",
"height": 1024,
"width": 1024,
"steps": 20,
"guidance_scale": [
7.5,
7.9,
7.5,
7.5
],
"seed": 0,
"output_type": [
"base64"
]
}
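This request relies on the singular-or-per-prompt rule in InferenceExecRequest.from_batch: a list such as guidance_scale must match the prompt count, while scalars and length-one lists broadcast to every prompt. A simplified standalone mirror of that lookup (the error path is elided):

def resolve(received, index):
    if isinstance(received, list):
        if index >= len(received):
            return received[0] if len(received) == 1 else None
        return received[index]
    return received

assert resolve([7.5, 7.9, 7.5, 7.5], 1) == 7.9  # per-prompt guidance_scale
assert resolve(["base64"], 3) == "base64"       # length-one list broadcasts
assert resolve(0, 2) == 0                       # scalar seed broadcasts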
35 changes: 16 additions & 19 deletions shortfin/python/shortfin_apps/sd/examples/send_request.py
@@ -1,14 +1,17 @@
import json
import requests
import argparse
import base64

from datetime import datetime as dt
from PIL import Image


def bytes_to_img(bytes, idx=0, width=1024, height=1024):
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
image = Image.frombytes(mode="RGB", size=(width, height), data=bytes)
image = Image.frombytes(
mode="RGB", size=(width, height), data=base64.b64decode(bytes)
)
image.save(f"shortfin_sd_output_{timestamp}_{idx}.png")
print(f"Saved to shortfin_sd_output_{timestamp}_{idx}.png")

@@ -28,25 +31,19 @@ def send_json_file(file_path):
response.raise_for_status() # Raise an error for bad responses
print("Saving response as image...")
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
breakpoint()
request = json.loads(response.request.body.decode("utf-8"))
if isinstance(response.content, list):
for idx, item in enumerate(response.content):
width = (
request["width"][idx]
if isinstance(request["height"], list)
else request["height"]
)
height = (
request["height"][idx]
if isinstance(request["height"], list)
else request["height"]
)
bytes_to_img(item, idx, width, height)
elif isinstance(response.content, bytes):
width = request["width"]
height = request["height"]
bytes_to_img(response.content, width=width, height=height)
for idx, item in enumerate(response.json()["images"]):
width = (
request["width"][idx]
if isinstance(request["width"], list)
else request["width"]
)
height = (
request["height"][idx]
if isinstance(request["height"], list)
else request["height"]
)
bytes_to_img(item.encode("ascii"), idx, width, height)

except requests.exceptions.RequestException as e:
print(f"Error sending the request: {e}")
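An assumed invocation, since the argparse wiring sits outside the shown hunks:

if __name__ == "__main__":
    # Hypothetical direct call; the script presumably routes file_path through argparse.
    send_json_file("sdxl_request_bs2.json")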
