From c1176b69ccc0928d6b64fbf5e67fd9b1827b8a2c Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Tue, 29 Oct 2024 19:37:47 -0500 Subject: [PATCH] (shortfin-sd) Adds program isolation optionality and fibers_per_device. (#360) --- .github/workflows/ci-sdxl.yaml | 2 +- .../shortfin_apps/sd/components/service.py | 136 ++++++++++------- shortfin/python/shortfin_apps/sd/server.py | 47 +++++- shortfin/tests/apps/sd/e2e_test.py | 140 ++++++++++++++---- 4 files changed, 232 insertions(+), 93 deletions(-) diff --git a/.github/workflows/ci-sdxl.yaml b/.github/workflows/ci-sdxl.yaml index b86c7dc1e..17dc5abec 100644 --- a/.github/workflows/ci-sdxl.yaml +++ b/.github/workflows/ci-sdxl.yaml @@ -99,4 +99,4 @@ jobs: working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | ctest --timeout 30 --output-on-failure --test-dir build - pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu + HIP_VISIBLE_DEVICES=0 pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index d6cf71e48..2deec49c0 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -24,6 +24,12 @@ logger = logging.getLogger(__name__) +prog_isolations = { + "none": sf.ProgramIsolation.NONE, + "per_fiber": sf.ProgramIsolation.PER_FIBER, + "per_call": sf.ProgramIsolation.PER_CALL, +} + class GenerateService: """Top level service interface for image generation.""" @@ -39,6 +45,9 @@ def __init__( sysman: SystemManager, tokenizers: list[Tokenizer], model_params: ModelParams, + fibers_per_device: int, + prog_isolation: str = "per_fiber", + show_progress: bool = False, ): self.name = name @@ -50,17 +59,20 @@ def __init__( self.inference_modules: dict[str, sf.ProgramModule] = {} self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {} self.inference_programs: dict[str, sf.Program] = {} - self.procs_per_device = 1 + self.trace_execution = False + self.show_progress = show_progress + self.fibers_per_device = fibers_per_device + self.prog_isolation = prog_isolations[prog_isolation] self.workers = [] self.fibers = [] - self.locks = [] + self.fiber_status = [] for idx, device in enumerate(self.sysman.ls.devices): - for i in range(self.procs_per_device): + for i in range(self.fibers_per_device): worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}") fiber = sysman.ls.create_fiber(worker, devices=[device]) self.workers.append(worker) self.fibers.append(fiber) - self.locks.append(asyncio.Lock()) + self.fiber_status.append(0) # Scope dependent objects. 
self.batcher = BatcherProcess(self) @@ -99,7 +111,8 @@ def start(self): self.inference_programs[component] = sf.Program( modules=component_modules, devices=fiber.raw_devices, - trace_execution=False, + isolation=self.prog_isolation, + trace_execution=self.trace_execution, ) # TODO: export vmfbs with multiple batch size entrypoints @@ -169,6 +182,7 @@ def __init__(self, service: GenerateService): self.strobe_enabled = True self.strobes: int = 0 self.ideal_batch_size: int = max(service.model_params.max_batch_size) + self.num_fibers = len(service.fibers) def shutdown(self): self.batcher_infeed.close() @@ -199,6 +213,7 @@ async def run(self): logger.error("Illegal message received by batcher: %r", item) self.board_flights() + self.strobe_enabled = True await strober_task @@ -210,28 +225,40 @@ def board_flights(self): logger.info("Waiting a bit longer to fill flight") return self.strobes = 0 + batches = self.sort_batches() + for idx, batch in batches.items(): + for fidx, status in enumerate(self.service.fiber_status): + if ( + status == 0 + or self.service.prog_isolation == sf.ProgramIsolation.PER_CALL + ): + self.board(batch["reqs"], index=fidx) + break - batches = self.sort_pending() - for idx in batches.keys(): - self.board(batches[idx]["reqs"], index=idx) - - def sort_pending(self): - """Returns pending requests as sorted batches suitable for program invocations.""" + def sort_batches(self): + """Files pending requests into sorted batches suitable for program invocations.""" + reqs = self.pending_requests + next_key = 0 batches = {} - for req in self.pending_requests: + for req in reqs: is_sorted = False req_metas = [req.phases[phase]["metadata"] for phase in req.phases.keys()] - next_key = 0 + for idx_key, data in batches.items(): if not isinstance(data, dict): logger.error( "Expected to find a dictionary containing a list of requests and their shared metadatas." 
) - if data["meta"] == req_metas: - batches[idx_key]["reqs"].append(req) + if len(batches[idx_key]["reqs"]) >= self.ideal_batch_size: + # Batch is full + next_key = idx_key + 1 + continue + elif data["meta"] == req_metas: + batches[idx_key]["reqs"].extend([req]) is_sorted = True break - next_key = idx_key + 1 + else: + next_key = idx_key + 1 if not is_sorted: batches[next_key] = { "reqs": [req], @@ -251,7 +278,8 @@ def board(self, request_bundle, index): if exec_process.exec_requests: for flighted_request in exec_process.exec_requests: self.pending_requests.remove(flighted_request) - print(f"launching exec process for {exec_process.exec_requests}") + if self.service.prog_isolation != sf.ProgramIsolation.PER_CALL: + self.service.fiber_status[index] = 1 exec_process.launch() @@ -284,22 +312,22 @@ async def run(self): phases = self.exec_requests[0].phases req_count = len(self.exec_requests) - async with self.service.locks[self.worker_index]: - device0 = self.fiber.device(0) - if phases[InferencePhase.PREPARE]["required"]: - await self._prepare(device=device0, requests=self.exec_requests) - if phases[InferencePhase.ENCODE]["required"]: - await self._encode(device=device0, requests=self.exec_requests) - if phases[InferencePhase.DENOISE]["required"]: - await self._denoise(device=device0, requests=self.exec_requests) - if phases[InferencePhase.DECODE]["required"]: - await self._decode(device=device0, requests=self.exec_requests) - if phases[InferencePhase.POSTPROCESS]["required"]: - await self._postprocess(device=device0, requests=self.exec_requests) + device0 = self.service.fibers[self.worker_index].device(0) + if phases[InferencePhase.PREPARE]["required"]: + await self._prepare(device=device0, requests=self.exec_requests) + if phases[InferencePhase.ENCODE]["required"]: + await self._encode(device=device0, requests=self.exec_requests) + if phases[InferencePhase.DENOISE]["required"]: + await self._denoise(device=device0, requests=self.exec_requests) + if phases[InferencePhase.DECODE]["required"]: + await self._decode(device=device0, requests=self.exec_requests) + if phases[InferencePhase.POSTPROCESS]["required"]: + await self._postprocess(device=device0, requests=self.exec_requests) for i in range(req_count): req = self.exec_requests[i] req.done.set_success() + self.service.fiber_status[self.worker_index] = 0 except Exception: logger.exception("Fatal error in image generation") @@ -345,7 +373,6 @@ async def _prepare(self, device, requests): sfnp.fill_randn(sample_host, generator=generator) request.sample.copy_from(sample_host) - await device return async def _encode(self, device, requests): @@ -385,15 +412,13 @@ async def _encode(self, device, requests): clip_inputs[idx].copy_from(host_arrs[idx]) # Encode tokenized inputs. - logger.info( + logger.debug( "INVOKE %r: %s", fn, "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]), ) - await device pe, te = await fn(*clip_inputs, fiber=self.fiber) - await device for i in range(req_bs): cfg_mult = 2 requests[i].prompt_embeds = pe.view(slice(i * cfg_mult, (i + 1) * cfg_mult)) @@ -477,20 +502,23 @@ async def _denoise(self, device, requests): ns_host.items = [step_count] num_steps.copy_from(ns_host) - await device + init_inputs = [ + denoise_inputs["sample"], + num_steps, + ] + # Initialize scheduler. 
- logger.info( - "INVOKE %r: %s", + logger.debug( + "INVOKE %r", fns["init"], - "".join([f"\n 0: {latents_shape}"]), ) (latents, time_ids, timesteps, sigmas) = await fns["init"]( - denoise_inputs["sample"], num_steps, fiber=self.fiber + *init_inputs, fiber=self.fiber ) - - await device for i, t in tqdm( enumerate(range(step_count)), + disable=(not self.service.show_progress), + desc=f"Worker #{self.worker_index} DENOISE (bs{req_bs})", ): step = sfnp.device_array.for_device(device, [1], sfnp.sint64) s_host = step.for_transfer() @@ -498,14 +526,10 @@ async def _denoise(self, device, requests): s_host.items = [i] step.copy_from(s_host) scale_inputs = [latents, step, timesteps, sigmas] - logger.info( - "INVOKE %r: %s", + logger.debug( + "INVOKE %r", fns["scale"], - "".join( - [f"\n {i}: {ary.shape}" for i, ary in enumerate(scale_inputs)] - ), ) - await device latent_model_input, t, sigma, next_sigma = await fns["scale"]( *scale_inputs, fiber=self.fiber ) @@ -519,32 +543,25 @@ async def _denoise(self, device, requests): time_ids, denoise_inputs["guidance_scale"], ] - logger.info( - "INVOKE %r: %s", + logger.debug( + "INVOKE %r", fns["unet"], - "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(unet_inputs)]), ) - await device (noise_pred,) = await fns["unet"](*unet_inputs, fiber=self.fiber) - await device step_inputs = [noise_pred, latents, sigma, next_sigma] - logger.info( - "INVOKE %r: %s", + logger.debug( + "INVOKE %r", fns["step"], - "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(step_inputs)]), ) - await device (latent_model_output,) = await fns["step"](*step_inputs, fiber=self.fiber) latents.copy_from(latent_model_output) - await device for idx, req in enumerate(requests): req.denoised_latents = sfnp.device_array.for_device( device, latents_shape, self.service.model_params.vae_dtype ) req.denoised_latents.copy_from(latents.view(idx)) - await device return async def _decode(self, device, requests): @@ -569,6 +586,11 @@ async def _decode(self, device, requests): await device # Decode the denoised latents. 
+ logger.debug( + "INVOKE %r: %s", + fn, + "".join([f"\n 0: {latents.shape}"]), + ) (image,) = await fn(latents, fiber=self.fiber) await device diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index 5e7abd1fc..0327b0a9f 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -31,7 +31,9 @@ from .components.tokenizer import Tokenizer -logger = logging.getLogger(__name__) +from shortfin.support.logging_setup import configure_main_logger + +logger = configure_main_logger("server") @asynccontextmanager @@ -87,7 +89,13 @@ def configure(args) -> SystemManager: model_params = ModelParams.load_json(args.model_config) sm = GenerateService( - name="sd", sysman=sysman, tokenizers=tokenizers, model_params=model_params + name="sd", + sysman=sysman, + tokenizers=tokenizers, + model_params=model_params, + fibers_per_device=args.fibers_per_device, + prog_isolation=args.isolation, + show_progress=args.show_progress, ) sm.load_inference_module(args.clip_vmfb, component="clip") sm.load_inference_module(args.unet_vmfb, component="unet") @@ -188,10 +196,40 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): nargs="*", help="Parameter archives to load", ) + parser.add_argument( + "--fibers_per_device", + type=int, + default=1, + help="Concurrency control -- how many fibers are created per device to run inference.", + ) + parser.add_argument( + "--isolation", + type=str, + default="per_fiber", + choices=["per_fiber", "per_call", "none"], + help="Concurrency control -- How to isolate programs.", + ) + parser.add_argument( + "--log_level", type=str, default="error", choices=["info", "debug", "error"] + ) + parser.add_argument( + "--show_progress", + action="store_true", + help="enable tqdm progress for unet iterations.", + ) + log_levels = { + "info": logging.INFO, + "debug": logging.DEBUG, + "error": logging.ERROR, + } + args = parser.parse_args(argv) + + log_level = log_levels[args.log_level] + logger.setLevel(log_level) + global sysman sysman = configure(args) - uvicorn.run( app, host=args.host, @@ -202,9 +240,6 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): if __name__ == "__main__": - from shortfin.support.logging_setup import configure_main_logger - - logger = configure_main_logger("server") main( sys.argv[1:], # Make logging defer to the default shortfin logging config. diff --git a/shortfin/tests/apps/sd/e2e_test.py b/shortfin/tests/apps/sd/e2e_test.py index b8331946b..05b9ef69b 100644 --- a/shortfin/tests/apps/sd/e2e_test.py +++ b/shortfin/tests/apps/sd/e2e_test.py @@ -1,9 +1,3 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - import json import requests import time @@ -13,6 +7,7 @@ import os import socket import sys +import copy from contextlib import closing from datetime import datetime as dt @@ -27,7 +22,7 @@ "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], "height": [1024], "width": [1024], - "steps": [20], + "steps": [5], "guidance_scale": [7.5], "seed": [0], "output_type": ["base64"], @@ -51,12 +46,7 @@ def sd_artifacts(target: str = "gfx942"): cache = os.path.abspath("./tmp/sharktank/sd/") -@pytest.fixture(scope="module") -def sd_server(): - # Create necessary directories - - os.makedirs(cache, exist_ok=True) - +def start_server(fibers_per_device=1, isolation="per_fiber"): # Download model if it doesn't exist vmfbs_bucket = "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/" weights_bucket = ( @@ -88,9 +78,67 @@ def sd_server(): for arg in sd_artifacts().keys(): artifact_arg = f"--{arg}={cache}/{sd_artifacts()[arg]}" srv_args.extend([artifact_arg]) + srv_args.extend( + [ + f"--fibers_per_device={fibers_per_device}", + f"--isolation={isolation}", + ] + ) runner = ServerRunner(srv_args) # Wait for server to start - time.sleep(5) + time.sleep(3) + return runner + + +@pytest.fixture(scope="module") +def sd_server_fpd1(): + # Create necessary directories + + os.makedirs(cache, exist_ok=True) + + runner = start_server(fibers_per_device=1) + + yield runner + + # Teardown: kill the server + del runner + + +@pytest.fixture(scope="module") +def sd_server_fpd1_per_call(): + # Create necessary directories + + os.makedirs(cache, exist_ok=True) + + runner = start_server(fibers_per_device=1, isolation="per_call") + + yield runner + + # Teardown: kill the server + del runner + + +@pytest.fixture(scope="module") +def sd_server_fpd2(): + # Create necessary directories + + os.makedirs(cache, exist_ok=True) + + runner = start_server(fibers_per_device=2) + + yield runner + + # Teardown: kill the server + del runner + + +@pytest.fixture(scope="module") +def sd_server_fpd8(): + # Create necessary directories + + os.makedirs(cache, exist_ok=True) + + runner = start_server(fibers_per_device=8) yield runner @@ -99,19 +147,46 @@ def sd_server(): @pytest.mark.system("amdgpu") -def test_sd_server(sd_server): - imgs, status_code = send_json_file(sd_server.url) +def test_sd_server(sd_server_fpd1): + imgs, status_code = send_json_file(sd_server_fpd1.url) assert len(imgs) == 1 assert status_code == 200 +@pytest.mark.system("amdgpu") +def test_sd_server_bs4_dense(sd_server_fpd1): + imgs, status_code = send_json_file(sd_server_fpd1.url, num_copies=4) + assert len(imgs) == 4 + assert status_code == 200 + + +@pytest.mark.system("amdgpu") +def test_sd_server_bs8_percall(sd_server_fpd1_per_call): + imgs, status_code = send_json_file(sd_server_fpd1_per_call.url, num_copies=8) + assert len(imgs) == 8 + assert status_code == 200 + + +@pytest.mark.system("amdgpu") +def test_sd_server_bs4_dense_fpd2(sd_server_fpd2): + imgs, status_code = send_json_file(sd_server_fpd2.url, num_copies=4) + assert len(imgs) == 4 + assert status_code == 200 + + +@pytest.mark.system("amdgpu") +def test_sd_server_bs8_dense_fpd8(sd_server_fpd8): + imgs, status_code = send_json_file(sd_server_fpd8.url, num_copies=8) + assert len(imgs) == 8 + assert status_code == 200 + + class ServerRunner: def __init__(self, args): port = str(find_free_port()) self.url = "http://0.0.0.0:" + port env = os.environ.copy() env["PYTHONUNBUFFERED"] = "1" - env["HIP_VISIBLE_DEVICES"] = "0" 
self.process = subprocess.Popen( [ *args, @@ -158,27 +233,24 @@ def bytes_to_img(bytes, idx=0, width=1024, height=1024): return image -def send_json_file(url="http://0.0.0.0:8000"): +def send_json_file(url="http://0.0.0.0:8000", num_copies=1): # Read the JSON file - data = sample_request + data = copy.deepcopy(sample_request) imgs = [] # Send the data to the /generate endpoint + data["prompt"] = ( + [data["prompt"]] + if isinstance(data["prompt"], str) + else data["prompt"] * num_copies + ) try: response = requests.post(url + "/generate", json=data) response.raise_for_status() # Raise an error for bad responses request = json.loads(response.request.body.decode("utf-8")) for idx, item in enumerate(response.json()["images"]): - width = ( - request["width"][idx] - if isinstance(request["height"], list) - else request["height"] - ) - height = ( - request["height"][idx] - if isinstance(request["height"], list) - else request["height"] - ) + width = getbatched(request, idx, "width") + height = getbatched(request, idx, "height") img = bytes_to_img(item.encode("utf-8"), idx, width, height) imgs.append(img) @@ -188,6 +260,16 @@ def send_json_file(url="http://0.0.0.0:8000"): return imgs, response.status_code +def getbatched(req, idx, key): + if isinstance(req[key], list): + if len(req[key]) == 1: + return req[key][0] + elif len(req[key]) > idx: + return req[key][idx] + else: + return req[key] + + def find_free_port(): """This tries to find a free port to run a server on for the test.
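
The new server flags compose as follows. A minimal launch sketch, assuming the `shortfin_apps.sd.server.main` entry point shown in the diff; the `--host`/`--port` flag names are inferred from `args.host`/`args.port` in `uvicorn.run()`, and the model-config/VMFB/parameter arguments that `configure()` also requires are elided since their exact names are outside this patch:

    # Hypothetical launch sketch -- only options visible in this patch are
    # spelled out; the model_config / VMFB / parameter-archive flags that
    # configure() consumes are omitted here.
    from shortfin_apps.sd.server import main

    main([
        "--host=0.0.0.0",
        "--port=8000",
        "--fibers_per_device=2",   # two worker/fiber pairs per device
        "--isolation=per_call",    # maps to sf.ProgramIsolation.PER_CALL
        "--log_level=debug",
        "--show_progress",         # tqdm bar over the denoising steps
        # ...model/VMFB/parameter flags omitted...
    ])

With `--isolation=per_fiber` (the default), a program invocation is pinned to the fiber it boarded on and the batcher will not board a second flight onto a busy fiber; `per_call` isolates program state per invocation instead, so any fiber may accept work at any time.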
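The heart of the batcher change is the boarding rule in `board_flights()`: a flight may board an idle fiber, or any fiber when programs are isolated per call. A standalone paraphrase of that decision (the `ProgramIsolation` enum here is a stand-in for `sf.ProgramIsolation`; `fiber_status` uses 0 = idle, 1 = busy, exactly as `board()` and the executor set them):

    from enum import Enum, auto

    class ProgramIsolation(Enum):  # stand-in for sf.ProgramIsolation
        NONE = auto()
        PER_FIBER = auto()
        PER_CALL = auto()

    def pick_fiber(fiber_status: list[int], isolation: ProgramIsolation):
        """Return the index of a fiber the next batch may board, or None."""
        for fidx, status in enumerate(fiber_status):
            # An idle fiber is always eligible; under PER_CALL the program
            # keeps no per-fiber state, so a busy fiber is eligible too.
            if status == 0 or isolation is ProgramIsolation.PER_CALL:
                return fidx
        return None  # all fibers busy under PER_FIBER: wait for the next strobe

Under PER_FIBER, the executor clears `fiber_status[worker_index]` back to 0 once its requests complete, which is what replaces the per-fiber `asyncio.Lock` the old code held across the whole pipeline.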
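On the client side, the reworked `send_json_file` amounts to a plain POST of the request dict to `/generate`. A trimmed sketch, with an illustrative prompt and endpoint address; the response's "images" entries are base64 payloads that the test's `bytes_to_img` decodes:

    import requests

    # Illustrative payload mirroring sample_request in the test.
    payload = {
        "prompt": ["a photo of a shark riding a bicycle"] * 4,  # batch of 4
        "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"],
        "height": [1024],
        "width": [1024],
        "steps": [5],
        "guidance_scale": [7.5],
        "seed": [0],
        "output_type": ["base64"],
    }
    resp = requests.post("http://0.0.0.0:8000/generate", json=payload)
    resp.raise_for_status()
    images = resp.json()["images"]  # one base64-encoded image per prompt

Single-element lists broadcast across the whole batch, which is the case `getbatched` handles by returning `req[key][0]` for length-1 lists.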