diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py index 347e4cb3b..ca4f9799d 100644 --- a/shortfin/python/shortfin_apps/sd/components/generate.py +++ b/shortfin/python/shortfin_apps/sd/components/generate.py @@ -7,6 +7,7 @@ import asyncio import io import logging +import json import shortfin as sf import shortfin.array as sfnp @@ -95,12 +96,8 @@ async def run(self): # TODO: stream image outputs logging.debug("Responding to one shot batch") - out = io.BytesIO() - result_images = [p.result_image for p in gen_processes] - for idx, result_image in enumerate(result_images): - out.write(result_image) - # TODO: save or return images - logging.debug("Wrote images as bytes to response.") - self.responder.send_response(out.getvalue()) + response_data = {"images": [p.result_image for p in gen_processes]} + json_str = json.dumps(response_data) + self.responder.send_response(json_str) finally: self.responder.ensure_response() diff --git a/shortfin/python/shortfin_apps/sd/components/io_struct.py b/shortfin/python/shortfin_apps/sd/components/io_struct.py index e69bc8d82..c1e417f1d 100644 --- a/shortfin/python/shortfin_apps/sd/components/io_struct.py +++ b/shortfin/python/shortfin_apps/sd/components/io_struct.py @@ -41,23 +41,12 @@ def post_init(self): ): raise ValueError("Either text or input_ids should be provided.") - prev_input_len = None - for i in [self.prompt, self.neg_prompt, self.input_ids, self.neg_input_ids]: - if isinstance(i, str): - self.num_output_images = 1 - continue - elif not i: - continue - if not isinstance(i, list): - raise ValueError("Text inputs should be strings or lists.") - if prev_input_len and not (prev_input_len == len(i)): - raise ValueError("Positive, Negative text inputs should be same length") - self.num_output_images = len(i) - prev_input_len = len(i) - if not self.num_output_images: - self.num_output_images = ( - len[self.prompt] if self.prompt is not None else 
len(self.input_ids) - ) + if isinstance(self.prompt, str): + self.prompt = [self.prompt] + + self.num_output_images = ( + len(self.prompt) if self.prompt is not None else len(self.input_ids) + ) if self.rid is None: self.rid = [uuid.uuid4().hex for _ in range(self.num_output_images)] @@ -65,4 +54,4 @@ def post_init(self): if not isinstance(self.rid, list): raise ValueError("The rid should be a list.") if self.output_type is None: - self.output_type = ["base64"] * self.num_output_images + self.output_type = ["PIL"] * self.num_output_images diff --git a/shortfin/python/shortfin_apps/sd/components/messages.py b/shortfin/python/shortfin_apps/sd/components/messages.py index 9966b7fa8..dea9217d0 100644 --- a/shortfin/python/shortfin_apps/sd/components/messages.py +++ b/shortfin/python/shortfin_apps/sd/components/messages.py @@ -6,11 +6,15 @@ from enum import Enum +import logging + import shortfin as sf import shortfin.array as sfnp from .io_struct import GenerateReqInput +logger = logging.getLogger(__name__) + class InferencePhase(Enum): # Tokenize prompt, negative prompt and get latents, timesteps, time ids, guidance scale as device arrays @@ -54,6 +58,8 @@ def __init__( image_array: sfnp.device_array | None = None, ): super().__init__() + self.print_debug = False + self.phases = {} self.phase = None self.height = height @@ -87,6 +93,7 @@ def __init__( self.image_array = image_array self.result_image = None + self.img_metadata = None self.done = sf.VoidFuture() @@ -96,8 +103,18 @@ def __init__( self.return_host_array: bool = True self.post_init() - print(self.phases) - print(self.phase) + + def __setattr__(self, name, value): + self.__dict__[name] = value + if getattr(self, "print_debug", False): + self.on_change(name, value) + + def on_change(self, name, value): + if isinstance(value, sfnp.device_array): + val_host = value.for_transfer() + val_host.copy_from(value) + logger.info("NAME: %s", name) + logger.info("VALUE: \n%s", val_host) @staticmethod def from_batch(gen_req: 
GenerateReqInput, index: int) -> "InferenceExecRequest": @@ -114,8 +131,13 @@ def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest": for item in gen_inputs: received = getattr(gen_req, item, None) if isinstance(received, list): - if index >= len(received): - rec_input = None + if index >= len(received): + if len(received) == 1: + rec_input = received[0] + else: + rec_input = None + logging.error( + "Inputs in request must be singular or as many as the list of prompts.") else: rec_input = received[index] else: diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index 16c0e8c4b..e1435fbd5 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -11,6 +11,8 @@ from tqdm.auto import tqdm from pathlib import Path from PIL import Image +import io +import base64 import shortfin as sf import shortfin.array as sfnp @@ -51,12 +53,14 @@ def __init__( self.procs_per_device = 1 self.workers = [] self.fibers = [] + self.locks = [] for idx, device in enumerate(self.sysman.ls.devices): for i in range(self.procs_per_device): worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}") fiber = sysman.ls.create_fiber(worker, devices=[device]) self.workers.append(worker) self.fibers.append(fiber) + self.locks.append(asyncio.Lock()) # Scope dependent objects. 
self.batcher = BatcherProcess(self) @@ -192,6 +196,7 @@ async def run(self): self.strobes += 1 else: logger.error("Illegal message received by batcher: %r", item) + self.board_flights() self.strobe_enabled = True await strober_task @@ -207,7 +212,7 @@ def board_flights(self): batches = self.sort_pending() for idx in batches.keys(): - self.board(batches[idx]["reqs"]) + self.board(batches[idx]["reqs"], index=idx) def sort_pending(self): """Returns pending requests as sorted batches suitable for program invocations.""" @@ -233,11 +238,11 @@ def sort_pending(self): } return batches - def board(self, request_bundle): + def board(self, request_bundle, index): pending = request_bundle if len(pending) == 0: return - exec_process = InferenceExecutorProcess(self.service, 0) + exec_process = InferenceExecutorProcess(self.service, index) for req in pending: if len(exec_process.exec_requests) >= self.ideal_batch_size: break @@ -264,6 +269,7 @@ def __init__( ): super().__init__(fiber=service.fibers[index]) self.service = service + self.worker_index = index self.exec_requests: list[InferenceExecRequest] = [] async def run(self): @@ -273,22 +279,22 @@ async def run(self): if phase: if phase != req.phase: logger.error("Executor process recieved disjoint batch.") + phase = req.phase phases = self.exec_requests[0].phases req_count = len(self.exec_requests) - device0 = self.fiber.device(0) - await device0 - - if phases[InferencePhase.PREPARE]["required"]: - await self._prepare(device=device0, requests=self.exec_requests) - if phases[InferencePhase.ENCODE]["required"]: - await self._encode(device=device0, requests=self.exec_requests) - if phases[InferencePhase.DENOISE]["required"]: - await self._denoise(device=device0, requests=self.exec_requests) - if phases[InferencePhase.DECODE]["required"]: - await self._decode(device=device0, requests=self.exec_requests) - if phases[InferencePhase.POSTPROCESS]["required"]: - await self._postprocess(device=device0, requests=self.exec_requests) + 
async with self.service.locks[self.worker_index]: + device0 = self.fiber.device(0) + if phases[InferencePhase.PREPARE]["required"]: + await self._prepare(device=device0, requests=self.exec_requests) + if phases[InferencePhase.ENCODE]["required"]: + await self._encode(device=device0, requests=self.exec_requests) + if phases[InferencePhase.DENOISE]["required"]: + await self._denoise(device=device0, requests=self.exec_requests) + if phases[InferencePhase.DECODE]["required"]: + await self._decode(device=device0, requests=self.exec_requests) + if phases[InferencePhase.POSTPROCESS]["required"]: + await self._postprocess(device=device0, requests=self.exec_requests) for i in range(req_count): req = self.exec_requests[i] @@ -575,6 +581,7 @@ async def _decode(self, device, requests): ] images_host = sfnp.device_array.for_host(device, images_shape, sfnp.float16) images_host.copy_from(image) + await device for idx, req in enumerate(requests): image_array = images_host.view(idx).items dtype = image_array.typecode @@ -591,5 +598,8 @@ async def _postprocess(self, device, requests): # TODO: reimpl with sfnp permuted = np.transpose(req.image_array, (0, 2, 3, 1))[0] cast_image = (permuted * 255).round().astype("uint8") - req.result_image = Image.fromarray(cast_image).tobytes() + image_bytes = Image.fromarray(cast_image).tobytes() + + image = base64.b64encode(image_bytes).decode("utf-8") + req.result_image = image return diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs2.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs2.json new file mode 100644 index 000000000..0ded22888 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs2.json @@ -0,0 +1,18 @@ +{ + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal" + ], + "neg_prompt": 
"Watermark, blurry, oversaturated, low resolution, pollution", + "height": 1024, + "width": 1024, + "steps": 20, + "guidance_scale": [ + 7.5, + 7.9 + ], + "seed": 0, + "output_type": [ + "base64" + ] +} diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs4.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs4.json new file mode 100644 index 000000000..ede3d7f02 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs4.json @@ -0,0 +1,22 @@ +{ + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a dog under the snow with brown eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal" + ], + "neg_prompt": "Watermark, blurry, oversaturated, low resolution, pollution", + "height": 1024, + "width": 1024, + "steps": 20, + "guidance_scale": [ + 7.5, + 7.9, + 7.5, + 7.5 + ], + "seed": 0, + "output_type": [ + "base64" + ] +} diff --git a/shortfin/python/shortfin_apps/sd/examples/send_request.py b/shortfin/python/shortfin_apps/sd/examples/send_request.py index a273ebe60..032846ca9 100644 --- a/shortfin/python/shortfin_apps/sd/examples/send_request.py +++ b/shortfin/python/shortfin_apps/sd/examples/send_request.py @@ -1,6 +1,7 @@ import json import requests import argparse +import base64 from datetime import datetime as dt from PIL import Image @@ -8,7 +9,9 @@ def bytes_to_img(bytes, idx=0, width=1024, height=1024): timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") - image = Image.frombytes(mode="RGB", size=(width, height), data=bytes) + image = Image.frombytes( + mode="RGB", size=(width, height), data=base64.b64decode(bytes) + ) image.save(f"shortfin_sd_output_{timestamp}_{idx}.png") 
print(f"Saved to shortfin_sd_output_{timestamp}_{idx}.png") @@ -28,25 +31,19 @@ def send_json_file(file_path): response.raise_for_status() # Raise an error for bad responses print("Saving response as image...") timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") - breakpoint() request = json.loads(response.request.body.decode("utf-8")) - if isinstance(response.content, list): - for idx, item in enumerate(response.content): - width = ( - request["width"][idx] - if isinstance(request["height"], list) - else request["height"] - ) - height = ( - request["height"][idx] - if isinstance(request["height"], list) - else request["height"] - ) - bytes_to_img(item, idx, width, height) - elif isinstance(response.content, bytes): - width = request["width"] - height = request["height"] - bytes_to_img(response.content, width=width, height=height) + for idx, item in enumerate(response.json()["images"]): + width = ( + request["width"][idx] + if isinstance(request["width"], list) + else request["width"] + ) + height = ( + request["height"][idx] + if isinstance(request["height"], list) + else request["height"] + ) + bytes_to_img(item.encode("ascii"), idx, width, height) except requests.exceptions.RequestException as e: print(f"Error sending the request: {e}")