diff --git a/shortfin/python/shortfin/support/logging_setup.py b/shortfin/python/shortfin/support/logging_setup.py index 5585e6a82..3cb373f1e 100644 --- a/shortfin/python/shortfin/support/logging_setup.py +++ b/shortfin/python/shortfin/support/logging_setup.py @@ -38,19 +38,15 @@ def __init__(self): native_handler.setFormatter(NativeFormatter()) # TODO: Source from env vars. -logger.setLevel(logging.INFO) +logger.setLevel(logging.DEBUG) logger.addHandler(native_handler) def configure_main_logger(module_suffix: str = "__main__") -> logging.Logger: """Configures logging from a main entrypoint. - Returns a logger that can be used for the main module itself. """ + logging.root.addHandler(native_handler) + logging.root.setLevel(logging.WARNING) # TODO: source from env vars main_module = sys.modules["__main__"] - logging.root.setLevel(logging.INFO) - logger = logging.getLogger(f"{main_module.__package__}.{module_suffix}") - logger.setLevel(logging.INFO) - logger.addHandler(native_handler) - - return logger + return logging.getLogger(f"{main_module.__package__}.{module_suffix}") diff --git a/shortfin/python/shortfin_apps/sd/README.md b/shortfin/python/shortfin_apps/sd/README.md index 4808cad08..6dd701c62 100644 --- a/shortfin/python/shortfin_apps/sd/README.md +++ b/shortfin/python/shortfin_apps/sd/README.md @@ -38,13 +38,10 @@ cd shortfin/ The server will prepare runtime artifacts for you. ``` -python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile +python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single" ``` - - Run with splat(empty) weights: -``` -python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --splat --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile -``` - - Run a request in a separate shell: + + - Run a CLI client in a separate shell: ``` -python shortfin/python/shortfin_apps/sd/examples/send_request.py --file=shortfin/python/shortfin_apps/sd/examples/sdxl_request.json +python -m shortfin_apps.sd.simple_client --interactive --save ``` diff --git a/shortfin/python/shortfin_apps/sd/components/builders.py b/shortfin/python/shortfin_apps/sd/components/builders.py index ed948bee9..f23922dd6 100644 --- a/shortfin/python/shortfin_apps/sd/components/builders.py +++ b/shortfin/python/shortfin_apps/sd/components/builders.py @@ -24,7 +24,7 @@ sfnp.bfloat16: "bf16", } -ARTIFACT_VERSION = "11022024" +ARTIFACT_VERSION = "11132024" SDXL_BUCKET = ( f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/" ) @@ -51,7 +51,9 @@ def get_mlir_filenames(model_params: ModelParams, model=None): return filter_by_model(mlir_filenames, model) -def get_vmfb_filenames(model_params: ModelParams, model=None, target: str = "gfx942"): +def get_vmfb_filenames( + model_params: ModelParams, model=None, target: str = "amdgpu-gfx942" +): vmfb_filenames = [] file_stems = get_file_stems(model_params) for stem in file_stems: @@ -216,6 +218,8 @@ def sdxl( mlir_bucket = SDXL_BUCKET + "mlir/" vmfb_bucket = SDXL_BUCKET + "vmfbs/" + if "gfx" in target: + target = "amdgpu-" + target mlir_filenames = get_mlir_filenames(model_params, model) mlir_urls = get_url_map(mlir_filenames, mlir_bucket) @@ -247,7 +251,7 @@ def sdxl( params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET) for f, url in params_urls.items(): out_file = os.path.join(ctx.executor.output_dir, f) - if update or needs_file(f, ctx): + if needs_file(f, ctx): fetch_http(name=f, url=url) filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames] return filenames diff --git a/shortfin/python/shortfin_apps/sd/components/config_artifacts.py b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py new file mode 100644 index 000000000..b5a1d682b --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py @@ -0,0 +1,123 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from iree.build import * +from iree.build.executor import FileNamespace +import itertools +import os +import shortfin.array as sfnp +import copy + +from shortfin_apps.sd.components.config_struct import ModelParams + +this_dir = os.path.dirname(os.path.abspath(__file__)) +parent = os.path.dirname(this_dir) + +dtype_to_filetag = { + sfnp.float16: "fp16", + sfnp.float32: "fp32", + sfnp.int8: "i8", + sfnp.bfloat16: "bf16", +} + +ARTIFACT_VERSION = "11132024" +SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/" + + +def get_url_map(filenames: list[str], bucket: str): + file_map = {} + for filename in filenames: + file_map[filename] = f"{bucket}{filename}" + return file_map + + +def needs_update(ctx): + stamp = ctx.allocate_file("version.txt") + stamp_path = stamp.get_fs_path() + if os.path.exists(stamp_path): + with open(stamp_path, "r") as s: + ver = s.read() + if ver != ARTIFACT_VERSION: + return True + else: + with open(stamp_path, "w") as s: + s.write(ARTIFACT_VERSION) + return True + return False + + +def needs_file(filename, ctx, namespace=FileNamespace.GEN): + out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() + if os.path.exists(out_file): + needed = False + else: + # name_path = "bin" if namespace == FileNamespace.BIN else "" + # if name_path: + # filename = os.path.join(name_path, filename) + filekey = os.path.join(ctx.path, filename) + ctx.executor.all[filekey] = None + needed = True + return needed + + +@entrypoint(description="Retreives a set of SDXL configuration files.") +def sdxlconfig( + target=cl_arg( + "target", + default="gfx942", + help="IREE target architecture.", + ), + model=cl_arg("model", type=str, default="sdxl", help="Model architecture"), + topology=cl_arg( + "topology", + type=str, + default="spx_single", + help="System topology configfile keyword", + ), +): + ctx = executor.BuildContext.current() + update = needs_update(ctx) + + model_config_filenames = [f"{model}_config_i8.json"] + model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET) + for f, url in model_config_urls.items(): + out_file = os.path.join(ctx.executor.output_dir, f) + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + + topology_config_filenames = [f"topology_config_{topology}.txt"] + topology_config_urls = get_url_map(topology_config_filenames, SDXL_CONFIG_BUCKET) + for f, url in topology_config_urls.items(): + out_file = os.path.join(ctx.executor.output_dir, f) + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + + flagfile_filenames = [f"{model}_flagfile_{target}.txt"] + flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET) + for f, url in flagfile_urls.items(): + out_file = os.path.join(ctx.executor.output_dir, f) + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + + tuning_filenames = ( + [f"attention_and_matmul_spec_{target}.mlir"] if target == "gfx942" else [] + ) + tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET) + for f, url in tuning_urls.items(): + out_file = os.path.join(ctx.executor.output_dir, f) + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + filenames = [ + *model_config_filenames, + *topology_config_filenames, + *flagfile_filenames, + *tuning_filenames, + ] + return filenames + + +if __name__ == "__main__": + iree_build_main() diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py index ebb5ea08a..1afa73d5e 100644 --- a/shortfin/python/shortfin_apps/sd/components/generate.py +++ b/shortfin/python/shortfin_apps/sd/components/generate.py @@ -20,7 +20,7 @@ from .service import GenerateService from .metrics import measure -logger = logging.getLogger(__name__) +logger = logging.getLogger("shortfin-sd.generate") class GenerateImageProcess(sf.Process): diff --git a/shortfin/python/shortfin_apps/sd/components/io_struct.py b/shortfin/python/shortfin_apps/sd/components/io_struct.py index d2952a818..d1d9cf41a 100644 --- a/shortfin/python/shortfin_apps/sd/components/io_struct.py +++ b/shortfin/python/shortfin_apps/sd/components/io_struct.py @@ -72,3 +72,10 @@ def post_init(self): raise ValueError("The rid should be a list.") if self.output_type is None: self.output_type = ["base64"] * self.num_output_images + # Temporary restrictions + heights = [self.height] if not isinstance(self.height, list) else self.height + widths = [self.width] if not isinstance(self.width, list) else self.width + if any(dim != 1024 for dim in [*heights, *widths]): + raise ValueError( + "Currently, only 1024x1024 output image size is supported." + ) diff --git a/shortfin/python/shortfin_apps/sd/components/manager.py b/shortfin/python/shortfin_apps/sd/components/manager.py index b44116b39..ea29b69a4 100644 --- a/shortfin/python/shortfin_apps/sd/components/manager.py +++ b/shortfin/python/shortfin_apps/sd/components/manager.py @@ -10,7 +10,7 @@ import shortfin as sf from shortfin.interop.support.device_setup import get_selected_devices -logger = logging.getLogger(__name__) +logger = logging.getLogger("shortfin-sd.manager") class SystemManager: @@ -25,7 +25,7 @@ def __init__(self, device="local-task", device_ids=None, async_allocs=True): sb.visible_devices = sb.available_devices sb.visible_devices = get_selected_devices(sb, device_ids) self.ls = sb.create_system() - logger.info(f"Created local system with {self.ls.device_names} devices") + logging.info(f"Created local system with {self.ls.device_names} devices") # TODO: Come up with an easier bootstrap thing than manually # running a thread. self.t = threading.Thread(target=lambda: self.ls.run(self.run())) diff --git a/shortfin/python/shortfin_apps/sd/components/messages.py b/shortfin/python/shortfin_apps/sd/components/messages.py index 88eb28ff4..6ae716bad 100644 --- a/shortfin/python/shortfin_apps/sd/components/messages.py +++ b/shortfin/python/shortfin_apps/sd/components/messages.py @@ -13,7 +13,7 @@ from .io_struct import GenerateReqInput -logger = logging.getLogger(__name__) +logger = logging.getLogger("shortfin-sd.messages") class InferencePhase(Enum): diff --git a/shortfin/python/shortfin_apps/sd/components/metrics.py b/shortfin/python/shortfin_apps/sd/components/metrics.py index f8fd30876..a1811beea 100644 --- a/shortfin/python/shortfin_apps/sd/components/metrics.py +++ b/shortfin/python/shortfin_apps/sd/components/metrics.py @@ -10,7 +10,7 @@ from typing import Callable, Any import functools -logger = logging.getLogger(__name__) +logger = logging.getLogger("shortfin-sd.metrics") def measure(fn=None, type="exec", task=None, num_items=None, freq=1, label="items"): diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index 1ee11569a..ad3fd9404 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -24,7 +24,8 @@ from .metrics import measure -logger = logging.getLogger(__name__) +logger = logging.getLogger("shortfin-sd.service") +logger.setLevel(logging.DEBUG) prog_isolations = { "none": sf.ProgramIsolation.NONE, @@ -119,8 +120,6 @@ def load_inference_parameters( def start(self): # Initialize programs. - # This can work if we only initialize one set of programs per service, as our programs - # in SDXL are stateless and for component in self.inference_modules: component_modules = [ sf.ProgramModule.parameter_provider( @@ -128,17 +127,22 @@ def start(self): ), *self.inference_modules[component], ] + for worker_idx, worker in enumerate(self.workers): worker_devices = self.fibers[ worker_idx * (self.fibers_per_worker) ].raw_devices - + logger.info( + f"Loading inference program: {component}, worker index: {worker_idx}, device: {worker_devices}" + ) self.inference_programs[worker_idx][component] = sf.Program( modules=component_modules, devices=worker_devices, isolation=self.prog_isolation, trace_execution=self.trace_execution, ) + logger.info("Program loaded.") + for worker_idx, worker in enumerate(self.workers): self.inference_functions[worker_idx]["encode"] = {} for bs in self.model_params.clip_batch_sizes: @@ -170,7 +174,6 @@ def start(self): ] = self.inference_programs[worker_idx]["vae"][ f"{self.model_params.vae_module_name}.decode" ] - # breakpoint() self.batcher.launch() def shutdown(self): @@ -212,8 +215,8 @@ class BatcherProcess(sf.Process): into batches. """ - STROBE_SHORT_DELAY = 0.1 - STROBE_LONG_DELAY = 0.25 + STROBE_SHORT_DELAY = 0.5 + STROBE_LONG_DELAY = 1 def __init__(self, service: GenerateService): super().__init__(fiber=service.fibers[0]) @@ -356,7 +359,6 @@ async def run(self): logger.error("Executor process recieved disjoint batch.") phase = req.phase phases = self.exec_requests[0].phases - req_count = len(self.exec_requests) device0 = self.service.fibers[self.fiber_index].device(0) if phases[InferencePhase.PREPARE]["required"]: @@ -424,8 +426,12 @@ async def _prepare(self, device, requests): async def _encode(self, device, requests): req_bs = len(requests) entrypoints = self.service.inference_functions[self.worker_index]["encode"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._encode(device, [request]) + return for bs, fn in entrypoints.items(): - if bs >= req_bs: + if bs == req_bs: break # Prepare tokenized input ids for CLIP inference @@ -462,6 +468,7 @@ async def _encode(self, device, requests): fn, "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]), ) + await device pe, te = await fn(*clip_inputs, fiber=self.fiber) for i in range(req_bs): @@ -477,8 +484,12 @@ async def _denoise(self, device, requests): cfg_mult = 2 if self.service.model_params.cfg_mode else 1 # Produce denoised latents entrypoints = self.service.inference_functions[self.worker_index]["denoise"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._denoise(device, [request]) + return for bs, fns in entrypoints.items(): - if bs >= req_bs: + if bs == req_bs: break # Get shape of batched latents. @@ -613,8 +624,12 @@ async def _decode(self, device, requests): req_bs = len(requests) # Decode latents to images entrypoints = self.service.inference_functions[self.worker_index]["decode"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._decode(device, [request]) + return for bs, fn in entrypoints.items(): - if bs >= req_bs: + if bs == req_bs: break latents_shape = [ diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json index 192a2be61..002f43f0e 100644 --- a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json @@ -29,6 +29,8 @@ " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo" ], "neg_prompt": [ diff --git a/shortfin/python/shortfin_apps/sd/examples/send_request.py b/shortfin/python/shortfin_apps/sd/examples/send_request.py deleted file mode 100644 index 9fce890d6..000000000 --- a/shortfin/python/shortfin_apps/sd/examples/send_request.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import json -import requests -import argparse -import base64 - -from datetime import datetime as dt -from PIL import Image - -sample_request = { - "prompt": [ - " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", - ], - "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], - "height": [1024], - "width": [1024], - "steps": [20], - "guidance_scale": [7.5], - "seed": [0], - "output_type": ["base64"], - "rid": ["string"], -} - - -def bytes_to_img(bytes, idx=0, width=1024, height=1024): - timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") - image = Image.frombytes( - mode="RGB", size=(width, height), data=base64.b64decode(bytes) - ) - image.save(f"shortfin_sd_output_{timestamp}_{idx}.png") - print(f"Saved to shortfin_sd_output_{timestamp}_{idx}.png") - - -def send_json_file(args): - # Read the JSON file - try: - if args.file == "default": - data = sample_request - else: - with open(args.file, "r") as json_file: - data = json.load(json_file) - except Exception as e: - print(f"Error reading the JSON file: {e}") - return - data["prompt"] = ( - [data["prompt"]] - if isinstance(data["prompt"], str) - else data["prompt"] * args.reps - ) - # Send the data to the /generate endpoint - try: - response = requests.post("http://0.0.0.0:8000/generate", json=data) - response.raise_for_status() # Raise an error for bad responses - timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") - request = json.loads(response.request.body.decode("utf-8")) - for idx, item in enumerate(response.json()["images"]): - width = get_batched(request, "width", idx) - height = get_batched(request, "height", idx) - if args.save: - print("Saving response as image...") - bytes_to_img(item.encode("utf-8"), idx, width, height) - print("Responses processed.") - - except requests.exceptions.RequestException as e: - print(f"Error sending the request: {e}") - - -def get_batched(request, arg, idx): - if isinstance(request[arg], list): - if len(request[arg]) == 1: - indexed = request[arg][0] - else: - indexed = request[arg][idx] - else: - indexed = request[arg] - return indexed - - -if __name__ == "__main__": - p = argparse.ArgumentParser() - p.add_argument("--file", type=str, default="default") - p.add_argument("--reps", type=int, default=1) - p.add_argument("--save", type=argparse.BooleanOptionalAction, help="save images") - args = p.parse_args() - send_json_file(args) diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index 7ace4d407..9ee81d1c4 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -15,8 +15,6 @@ import copy import subprocess -from iree.build import * - # Import first as it does dep checking and reporting. from shortfin.interop.fastapi import FastAPIResponder @@ -33,9 +31,12 @@ from .components.tokenizer import Tokenizer from .components.builders import sdxl -from shortfin.support.logging_setup import configure_main_logger +from shortfin.support.logging_setup import native_handler, configure_main_logger -logger = configure_main_logger("server") +logger = logging.getLogger("shortfin-sd") +logger.addHandler(native_handler) +logger.setLevel(logging.INFO) +logger.propagate = False THIS_DIR = Path(__file__).resolve().parent @@ -84,6 +85,7 @@ async def generate_request(gen_req: GenerateReqInput, request: Request): def configure(args) -> SystemManager: # Setup system (configure devices, etc). + model_config, topology_config, flagfile, tuning_spec, args = get_configs(args) sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations) # Setup each service we are hosting. @@ -92,7 +94,9 @@ def configure(args) -> SystemManager: subfolder = f"tokenizer_{idx + 1}" if idx > 0 else "tokenizer" tokenizers.append(Tokenizer.from_pretrained(tok_name, subfolder)) - model_params = ModelParams.load_json(args.model_config) + model_params = ModelParams.load_json(model_config) + vmfbs, params = get_modules(args, model_config, flagfile, tuning_spec) + sm = GenerateService( name="sd", sysman=sysman, @@ -104,7 +108,6 @@ def configure(args) -> SystemManager: show_progress=args.show_progress, trace_execution=args.trace_execution, ) - vmfbs, params = get_modules(args) for key, vmfblist in vmfbs.items(): for vmfb in vmfblist: sm.load_inference_module(vmfb, component=key) @@ -114,15 +117,80 @@ def configure(args) -> SystemManager: return sysman -def get_modules(args): +def get_configs(args): + # Returns one set of config artifacts. + modelname = "sdxl" + model_config = args.model_config if args.model_config else None + topology_config = None + tuning_spec = None + flagfile = args.flagfile if args.flagfile else None + topology_inp = args.topology if args.topology else "spx_single" + cfg_builder_args = [ + sys.executable, + "-m", + "iree.build", + os.path.join(THIS_DIR, "components", "config_artifacts.py"), + f"--target={args.target}", + f"--output-dir={args.artifacts_dir}", + f"--model={modelname}", + f"--topology={topology_inp}", + ] + outs = subprocess.check_output(cfg_builder_args).decode() + outs_paths = outs.splitlines() + for i in outs_paths: + if "sdxl_config" in i and not args.model_config: + model_config = i + elif "topology" in i and args.topology: + topology_config = i + elif "flagfile" in i and not args.flagfile: + flagfile = i + elif "attention_and_matmul_spec" in i and args.use_tuned: + tuning_spec = i + + if args.use_tuned and args.tuning_spec: + tuning_spec = os.path.abspath(args.tuning_spec) + + if topology_config: + with open(topology_config, "r") as f: + contents = [line.rstrip() for line in f] + for spec in contents: + if "--" in spec: + arglist = spec.strip("--").split("=") + arg = arglist[0] + if len(arglist) > 2: + value = arglist[1:] + for val in value: + try: + val = int(val) + except ValueError: + continue + elif len(arglist) == 2: + value = arglist[-1] + try: + value = int(value) + except ValueError: + continue + else: + # It's a boolean arg. + value = True + setattr(args, arg, value) + else: + # It's an env var. + arglist = spec.split("=") + os.environ[arglist[0]] = arglist[1] + + return model_config, topology_config, flagfile, tuning_spec, args + + +def get_modules(args, model_config, flagfile, td_spec): # TODO: Move this out of server entrypoint vmfbs = {"clip": [], "unet": [], "vae": [], "scheduler": []} params = {"clip": [], "unet": [], "vae": []} model_flags = copy.deepcopy(vmfbs) model_flags["all"] = args.compile_flags - if args.flagfile: - with open(args.flagfile, "r") as f: + if flagfile: + with open(flagfile, "r") as f: contents = [line.rstrip() for line in f] flagged_model = "all" for elem in contents: @@ -131,6 +199,10 @@ def get_modules(args): flagged_model = elem else: model_flags[flagged_model].extend([elem]) + if td_spec: + model_flags["unet"].extend( + [f"--iree-codegen-transform-dialect-library={td_spec}"] + ) filenames = [] for modelname in vmfbs.keys(): @@ -140,7 +212,7 @@ def get_modules(args): "-m", "iree.build", os.path.join(THIS_DIR, "components", "builders.py"), - f"--model-json={args.model_config}", + f"--model-json={model_config}", f"--target={args.target}", f"--splat={args.splat}", f"--build-preference={args.build_preference}", @@ -165,6 +237,8 @@ def get_modules(args): def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): + from pathlib import Path + parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=8000) @@ -212,8 +286,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): parser.add_argument( "--model_config", type=Path, - required=True, - help="Path to the model config file", + help="Path to the model config file. If None, defaults to i8 punet, batch size 1", ) parser.add_argument( "--workers_per_device", @@ -275,17 +348,36 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): ) parser.add_argument( "--artifacts_dir", + type=Path, + default=None, + help="Path to local artifacts cache.", + ) + parser.add_argument( + "--tuning_spec", type=str, default="", - help="Path to local artifacts cache.", + help="Path to transform dialect spec if compiling an executable with tunings.", + ) + parser.add_argument( + "--topology", + type=str, + default=None, + choices=["spx_single", "cpx_single", "spx_multi", "cpx_multi"], + help="Use one of four known performant preconfigured device/fiber topologies.", + ) + parser.add_argument( + "--use_tuned", + type=int, + default=1, + help="Use tunings for attention and matmul ops. 0 to disable.", ) args = parser.parse_args(argv) + if not args.artifacts_dir: + home = Path.home() + artdir = home / ".cache" / "shark" + args.artifacts_dir = str(artdir) - log_level = logging.INFO - - logging.root.setLevel(log_level) - logger.addHandler(logging.FileHandler("shortfin_sd.log")) global sysman sysman = configure(args) uvicorn.run( @@ -298,14 +390,31 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): if __name__ == "__main__": + logging.root.setLevel(logging.INFO) main( sys.argv[1:], # Make logging defer to the default shortfin logging config. log_config={ "version": 1, "disable_existing_loggers": False, - "formatters": {}, - "handlers": {}, - "loggers": {}, + "formatters": { + "default": { + "format": "%(asctime)s - %(levelname)s - %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S", + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + }, + }, + "loggers": { + "uvicorn": { + "handlers": ["console"], + "level": "INFO", + "propagate": False, + }, + }, }, ) diff --git a/shortfin/python/shortfin_apps/sd/simple_client.py b/shortfin/python/shortfin_apps/sd/simple_client.py new file mode 100644 index 000000000..f8aabd8e7 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/simple_client.py @@ -0,0 +1,229 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import requests +import argparse +import base64 +import time +import asyncio +import aiohttp +import sys +import os + +from datetime import datetime as dt +from PIL import Image + +sample_request = { + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + ], + "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], + "height": [1024], + "width": [1024], + "steps": [20], + "guidance_scale": [7.5], + "seed": [0], + "output_type": ["base64"], + "rid": ["string"], +} + + +def bytes_to_img(bytes, idx=0, width=1024, height=1024, outputdir="./gen_imgs"): + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + image = Image.frombytes( + mode="RGB", size=(width, height), data=base64.b64decode(bytes) + ) + if not os.path.isdir(outputdir): + os.mkdir(outputdir) + im_path = os.path.join(outputdir, f"shortfin_sd_output_{timestamp}_{idx}.png") + image.save(im_path) + print(f"Saved to {im_path}") + + +def get_batched(request, arg, idx): + if isinstance(request[arg], list): + if len(request[arg]) == 1: + indexed = request[arg][0] + else: + indexed = request[arg][idx] + else: + indexed = request[arg] + return indexed + + +async def send_request(session, rep, args, data): + try: + print("Sending request batch #", rep) + url = f"http://0.0.0.0:{args.port}/generate" + start = time.time() + async with session.post(url, json=data) as response: + end = time.time() + # Check if the response was successful + if response.status == 200: + response.raise_for_status() # Raise an error for bad responses + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + res_json = await response.json(content_type=None) + if args.save: + for idx, item in enumerate(res_json["images"]): + width = get_batched(data, "width", idx) + height = get_batched(data, "height", idx) + print("Saving response as image...") + bytes_to_img( + item.encode("utf-8"), idx, width, height, args.outputdir + ) + latency = end - start + print("Responses processed.") + return latency, len(data["prompt"]) + else: + print(f"Error: Received {response.status} from server") + raise Exception + except Exception as e: + print(f"Request failed: {e}") + raise Exception + + +async def static(args): + # Create an aiohttp session for sending requests + async with aiohttp.ClientSession() as session: + pending = [] + latencies = [] + sample_counts = [] + # Read the JSON file if supplied. Otherwise, get user input. + try: + if args.file == "default": + data = sample_request + else: + with open(args.file, "r") as json_file: + data = json.load(json_file) + except Exception as e: + print(f"Error reading the JSON file: {e}") + return + data["prompt"] = ( + [data["prompt"]] if isinstance(data["prompt"], str) else data["prompt"] + ) + start = time.time() + + async for i in async_range(args.reps): + pending.append(asyncio.create_task(send_request(session, i, args, data))) + await asyncio.sleep(1) # Wait for 1 second before sending the next request + while pending: + done, pending = await asyncio.wait( + pending, return_when=asyncio.ALL_COMPLETED + ) + for task in done: + latency, num_samples = await task + latencies.append(latency) + sample_counts.append(num_samples) + end = time.time() + if not any([i is None for i in [latencies, sample_counts]]): + total_num_samples = sum(sample_counts) + sps = str(total_num_samples / (end - start)) + print(f"Average throughput: {sps} samples per second") + else: + raise ValueError("Received error response from server.") + + +async def interactive(args): + # Create an aiohttp session for sending requests + async with aiohttp.ClientSession() as session: + pending = [] + latencies = [] + sample_counts = [] + # Read the JSON file if supplied. Otherwise, get user input. + try: + if args.file == "default": + data = sample_request + else: + with open(args.file, "r") as json_file: + data = json.load(json_file) + except Exception as e: + print(f"Error reading the JSON file: {e}") + return + data["prompt"] = ( + [data["prompt"]] if isinstance(data["prompt"], str) else data["prompt"] + ) + while True: + prompt = await ainput("Enter a prompt: ") + data["prompt"] = [prompt] + data["steps"] = [args.steps] + print("Sending request with prompt: ", data["prompt"]) + + async for i in async_range(args.reps): + pending.append( + asyncio.create_task(send_request(session, i, args, data)) + ) + await asyncio.sleep( + 1 + ) # Wait for 1 second before sending the next request + while pending: + done, pending = await asyncio.wait( + pending, return_when=asyncio.ALL_COMPLETED + ) + for task in done: + latency, num_samples = await task + pending = [] + if any([i is None for i in [latencies, sample_counts]]): + raise ValueError("Received error response from server.") + + +async def ainput(prompt: str) -> str: + return await asyncio.to_thread(input, f"{prompt} ") + + +async def async_range(count): + for i in range(count): + yield (i) + await asyncio.sleep(0.0) + + +def main(argv): + p = argparse.ArgumentParser() + p.add_argument( + "--file", + type=str, + default="default", + help="A non-default request to send to the server.", + ) + p.add_argument( + "--reps", + type=int, + default=1, + help="Number of times to duplicate each request in one second intervals.", + ) + p.add_argument( + "--save", + action=argparse.BooleanOptionalAction, + default=True, + help="Save images. To disable, use --no-save", + ) + p.add_argument( + "--outputdir", + type=str, + default="gen_imgs", + help="Directory to which images get saved.", + ) + p.add_argument("--port", type=str, default="8000", help="Server port") + p.add_argument( + "--steps", + type=int, + default="20", + help="Number of inference steps. More steps usually means a better image. Interactive only.", + ) + p.add_argument( + "--interactive", + action="store_true", + help="Start as an example CLI client instead of sending static requests.", + ) + args = p.parse_args() + if args.interactive: + asyncio.run(interactive(args)) + else: + asyncio.run(static(args)) + + +if __name__ == "__main__": + main(sys.argv)