-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
(shortfin-sd) Usability and logging improvements. #491
Changes from 19 commits
d853eef
d7a5e9f
1df6aa8
32c6011
959963c
6ada498
e2b23f7
71855bc
50580d1
2c1d9b2
fd68b5e
b847134
06ff2b1
b88ad15
e4267f3
109e208
f75ffae
f4767d7
937be94
33dd327
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,7 +24,7 @@ | |
sfnp.bfloat16: "bf16", | ||
} | ||
|
||
ARTIFACT_VERSION = "11022024" | ||
ARTIFACT_VERSION = "11132024" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can artifact version bump be automated in future? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure |
||
SDXL_BUCKET = ( | ||
f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/" | ||
) | ||
|
@@ -51,7 +51,9 @@ def get_mlir_filenames(model_params: ModelParams, model=None): | |
return filter_by_model(mlir_filenames, model) | ||
|
||
|
||
def get_vmfb_filenames(model_params: ModelParams, model=None, target: str = "gfx942"): | ||
def get_vmfb_filenames( | ||
model_params: ModelParams, model=None, target: str = "amdgpu-gfx942" | ||
): | ||
vmfb_filenames = [] | ||
file_stems = get_file_stems(model_params) | ||
for stem in file_stems: | ||
|
@@ -216,6 +218,8 @@ def sdxl( | |
|
||
mlir_bucket = SDXL_BUCKET + "mlir/" | ||
vmfb_bucket = SDXL_BUCKET + "vmfbs/" | ||
if "gfx" in target: | ||
target = "amdgpu-" + target | ||
|
||
mlir_filenames = get_mlir_filenames(model_params, model) | ||
mlir_urls = get_url_map(mlir_filenames, mlir_bucket) | ||
|
@@ -247,7 +251,7 @@ def sdxl( | |
params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET) | ||
for f, url in params_urls.items(): | ||
out_file = os.path.join(ctx.executor.output_dir, f) | ||
if update or needs_file(f, ctx): | ||
if needs_file(f, ctx): | ||
fetch_http(name=f, url=url) | ||
filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames] | ||
return filenames | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from iree.build import * | ||
from iree.build.executor import FileNamespace | ||
import itertools | ||
import os | ||
import shortfin.array as sfnp | ||
import copy | ||
|
||
from shortfin_apps.sd.components.config_struct import ModelParams | ||
|
||
this_dir = os.path.dirname(os.path.abspath(__file__)) | ||
parent = os.path.dirname(this_dir) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you use this anywhere? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nope, copied from the other builder. Thanks. |
||
|
||
dtype_to_filetag = { | ||
sfnp.float16: "fp16", | ||
sfnp.float32: "fp32", | ||
sfnp.int8: "i8", | ||
sfnp.bfloat16: "bf16", | ||
} | ||
|
||
ARTIFACT_VERSION = "11132024" | ||
SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/" | ||
|
||
|
||
def get_url_map(filenames: list[str], bucket: str): | ||
file_map = {} | ||
for filename in filenames: | ||
file_map[filename] = f"{bucket}{filename}" | ||
return file_map | ||
|
||
|
||
def needs_update(ctx): | ||
stamp = ctx.allocate_file("version.txt") | ||
stamp_path = stamp.get_fs_path() | ||
if os.path.exists(stamp_path): | ||
with open(stamp_path, "r") as s: | ||
ver = s.read() | ||
if ver != ARTIFACT_VERSION: | ||
return True | ||
else: | ||
with open(stamp_path, "w") as s: | ||
s.write(ARTIFACT_VERSION) | ||
return True | ||
return False | ||
|
||
|
||
def needs_file(filename, ctx, namespace=FileNamespace.GEN): | ||
out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() | ||
if os.path.exists(out_file): | ||
needed = False | ||
else: | ||
# name_path = "bin" if namespace == FileNamespace.BIN else "" | ||
# if name_path: | ||
# filename = os.path.join(name_path, filename) | ||
filekey = os.path.join(ctx.path, filename) | ||
ctx.executor.all[filekey] = None | ||
needed = True | ||
return needed | ||
|
||
|
||
@entrypoint(description="Retreives a set of SDXL configuration files.") | ||
def sdxlconfig( | ||
target=cl_arg( | ||
"target", | ||
default="gfx942", | ||
help="IREE target architecture.", | ||
), | ||
model=cl_arg("model", type=str, default="sdxl", help="Model architecture"), | ||
topology=cl_arg( | ||
"topology", | ||
type=str, | ||
default="spx_single", | ||
help="System topology configfile keyword", | ||
), | ||
): | ||
ctx = executor.BuildContext.current() | ||
update = needs_update(ctx) | ||
|
||
model_config_filenames = [f"{model}_config_i8.json"] | ||
model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET) | ||
for f, url in model_config_urls.items(): | ||
out_file = os.path.join(ctx.executor.output_dir, f) | ||
if update or needs_file(f, ctx): | ||
fetch_http(name=f, url=url) | ||
|
||
topology_config_filenames = [f"topology_config_{topology}.txt"] | ||
topology_config_urls = get_url_map(topology_config_filenames, SDXL_CONFIG_BUCKET) | ||
for f, url in topology_config_urls.items(): | ||
out_file = os.path.join(ctx.executor.output_dir, f) | ||
if update or needs_file(f, ctx): | ||
fetch_http(name=f, url=url) | ||
|
||
flagfile_filenames = [f"{model}_flagfile_{target}.txt"] | ||
flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET) | ||
for f, url in flagfile_urls.items(): | ||
out_file = os.path.join(ctx.executor.output_dir, f) | ||
if update or needs_file(f, ctx): | ||
fetch_http(name=f, url=url) | ||
|
||
tuning_filenames = ( | ||
[f"attention_and_matmul_spec_{target}.mlir"] if target == "gfx942" else [] | ||
) | ||
tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET) | ||
for f, url in tuning_urls.items(): | ||
out_file = os.path.join(ctx.executor.output_dir, f) | ||
if update or needs_file(f, ctx): | ||
fetch_http(name=f, url=url) | ||
filenames = [ | ||
*model_config_filenames, | ||
*topology_config_filenames, | ||
*flagfile_filenames, | ||
*tuning_filenames, | ||
] | ||
return filenames | ||
|
||
|
||
if __name__ == "__main__": | ||
iree_build_main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,7 +24,8 @@ | |
from .metrics import measure | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
logger = logging.getLogger("shortfin-sd.service") | ||
logger.setLevel(logging.DEBUG) | ||
|
||
prog_isolations = { | ||
"none": sf.ProgramIsolation.NONE, | ||
|
@@ -119,26 +120,29 @@ def load_inference_parameters( | |
|
||
def start(self): | ||
# Initialize programs. | ||
# This can work if we only initialize one set of programs per service, as our programs | ||
# in SDXL are stateless and | ||
for component in self.inference_modules: | ||
component_modules = [ | ||
sf.ProgramModule.parameter_provider( | ||
self.sysman.ls, *self.inference_parameters.get(component, []) | ||
), | ||
*self.inference_modules[component], | ||
] | ||
|
||
for worker_idx, worker in enumerate(self.workers): | ||
worker_devices = self.fibers[ | ||
worker_idx * (self.fibers_per_worker) | ||
].raw_devices | ||
|
||
logger.info( | ||
f"Loading inference program: {component}, worker index: {worker_idx}, device: {worker_devices}" | ||
) | ||
self.inference_programs[worker_idx][component] = sf.Program( | ||
modules=component_modules, | ||
devices=worker_devices, | ||
isolation=self.prog_isolation, | ||
trace_execution=self.trace_execution, | ||
) | ||
logger.info("Program loaded.") | ||
|
||
for worker_idx, worker in enumerate(self.workers): | ||
self.inference_functions[worker_idx]["encode"] = {} | ||
for bs in self.model_params.clip_batch_sizes: | ||
|
@@ -170,7 +174,6 @@ def start(self): | |
] = self.inference_programs[worker_idx]["vae"][ | ||
f"{self.model_params.vae_module_name}.decode" | ||
] | ||
# breakpoint() | ||
self.batcher.launch() | ||
|
||
def shutdown(self): | ||
|
@@ -212,8 +215,8 @@ class BatcherProcess(sf.Process): | |
into batches. | ||
""" | ||
|
||
STROBE_SHORT_DELAY = 0.1 | ||
STROBE_LONG_DELAY = 0.25 | ||
STROBE_SHORT_DELAY = 0.5 | ||
STROBE_LONG_DELAY = 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does this impact user experience? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these are the increments by which the batcher checks for incoming requests |
||
|
||
def __init__(self, service: GenerateService): | ||
super().__init__(fiber=service.fibers[0]) | ||
|
@@ -356,7 +359,6 @@ async def run(self): | |
logger.error("Executor process recieved disjoint batch.") | ||
phase = req.phase | ||
phases = self.exec_requests[0].phases | ||
|
||
req_count = len(self.exec_requests) | ||
device0 = self.service.fibers[self.fiber_index].device(0) | ||
if phases[InferencePhase.PREPARE]["required"]: | ||
|
@@ -424,8 +426,12 @@ async def _prepare(self, device, requests): | |
async def _encode(self, device, requests): | ||
req_bs = len(requests) | ||
entrypoints = self.service.inference_functions[self.worker_index]["encode"] | ||
if req_bs not in list(entrypoints.keys()): | ||
for request in requests: | ||
await self._encode(device, [request]) | ||
return | ||
for bs, fn in entrypoints.items(): | ||
if bs >= req_bs: | ||
if bs == req_bs: | ||
break | ||
|
||
# Prepare tokenized input ids for CLIP inference | ||
|
@@ -462,6 +468,7 @@ async def _encode(self, device, requests): | |
fn, | ||
"".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]), | ||
) | ||
await device | ||
pe, te = await fn(*clip_inputs, fiber=self.fiber) | ||
|
||
for i in range(req_bs): | ||
|
@@ -477,8 +484,12 @@ async def _denoise(self, device, requests): | |
cfg_mult = 2 if self.service.model_params.cfg_mode else 1 | ||
# Produce denoised latents | ||
entrypoints = self.service.inference_functions[self.worker_index]["denoise"] | ||
if req_bs not in list(entrypoints.keys()): | ||
for request in requests: | ||
await self._denoise(device, [request]) | ||
return | ||
for bs, fns in entrypoints.items(): | ||
if bs >= req_bs: | ||
if bs == req_bs: | ||
break | ||
|
||
# Get shape of batched latents. | ||
|
@@ -613,8 +624,12 @@ async def _decode(self, device, requests): | |
req_bs = len(requests) | ||
# Decode latents to images | ||
entrypoints = self.service.inference_functions[self.worker_index]["decode"] | ||
if req_bs not in list(entrypoints.keys()): | ||
for request in requests: | ||
await self._decode(device, [request]) | ||
return | ||
for bs, fn in entrypoints.items(): | ||
if bs >= req_bs: | ||
if bs == req_bs: | ||
break | ||
|
||
latents_shape = [ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For release should we move this back to info/warn? if you want I can just clean up the logging after you submit. I feel tempted to just add a debug/verbose option and pipe that through here.