(shortfin-sd) Usability and logging improvements. #491

Merged · 20 commits · Nov 13, 2024

Commits
d853eef  Logging fixes (eagarvey-amd, Nov 9, 2024)
d7a5e9f  Remove a few extra lines. (eagarvey-amd, Nov 11, 2024)
1df6aa8  (DNM) logging for multidevice (eagarvey-amd, Nov 11, 2024)
32c6011  Improvements to client example. (eagarvey-amd, Nov 12, 2024)
959963c  Sanitize for python3.11 and correct throughput measurement. (eagarvey-amd, Nov 13, 2024)
6ada498  Get abspath from tuning spec arg. (eagarvey-amd, Nov 13, 2024)
e2b23f7  Add interactive client interface, separate config artifact downloads (eagarvey-amd, Nov 13, 2024)
71855bc  Merge branch 'main' into sd-logging (monorimet, Nov 13, 2024)
50580d1  Send default topology to config builder. (eagarvey-amd, Nov 13, 2024)
2c1d9b2  Run formatting (eagarvey-amd, Nov 13, 2024)
fd68b5e  Rename examples/send_request.py -> simple_client.py (eagarvey-amd, Nov 13, 2024)
b847134  Add topology flag to README instructions. (eagarvey-amd, Nov 13, 2024)
06ff2b1  Remove unused code in client and switch to info log level (eagarvey-amd, Nov 13, 2024)
b88ad15  Bump artifact version to 11132024 (eagarvey-amd, Nov 13, 2024)
e4267f3  Merge branch 'main' into sd-logging (monorimet, Nov 13, 2024)
109e208  Don't update weights on vmfb/mlir version update (eagarvey-amd, Nov 13, 2024)
f75ffae  Restrict image size without leaving server in error state. (eagarvey-amd, Nov 13, 2024)
f4767d7  Fix paths for precompiled artifact updates and client prints (eagarvey-amd, Nov 13, 2024)
937be94  Suggest precompiled artifacts in README (eagarvey-amd, Nov 13, 2024)
33dd327  Make the default artifacts dir point to the user .cache (eagarvey-amd, Nov 13, 2024)
12 changes: 4 additions & 8 deletions shortfin/python/shortfin/support/logging_setup.py
@@ -38,19 +38,15 @@ def __init__(self):
native_handler.setFormatter(NativeFormatter())

# TODO: Source from env vars.
-    logger.setLevel(logging.INFO)
+    logger.setLevel(logging.DEBUG)
Contributor commented:
For release, should we move this back to info/warn? If you want, I can clean up the logging after you submit. I'm tempted to just add a debug/verbose option and pipe it through here.
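A minimal sketch of that debug/verbose option, assuming a hypothetical `SHORTFIN_LOG_LEVEL` environment variable (the name is an assumption; the TODO in this file only says to source the level from env vars):

```python
# Hypothetical sketch: map an environment variable to a logging level so the
# default can stay at INFO while developers opt in to DEBUG.
# SHORTFIN_LOG_LEVEL is an assumed name, not an existing shortfin flag.
import logging
import os


def level_from_env(var: str = "SHORTFIN_LOG_LEVEL", default: int = logging.INFO) -> int:
    # Accept standard names like DEBUG, INFO, WARNING; fall back to the default.
    name = os.environ.get(var, "").strip().upper()
    level = getattr(logging, name, None)
    return level if isinstance(level, int) else default


logger = logging.getLogger("shortfin")
logger.setLevel(level_from_env())
```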

logger.addHandler(native_handler)


def configure_main_logger(module_suffix: str = "__main__") -> logging.Logger:
"""Configures logging from a main entrypoint.

Returns a logger that can be used for the main module itself.
"""
    logging.root.addHandler(native_handler)
-    logging.root.setLevel(logging.WARNING)  # TODO: source from env vars
    main_module = sys.modules["__main__"]
+    logging.root.setLevel(logging.INFO)
-    logger = logging.getLogger(f"{main_module.__package__}.{module_suffix}")
-    logger.setLevel(logging.INFO)
-    logger.addHandler(native_handler)
-
-    return logger
+    return logging.getLogger(f"{main_module.__package__}.{module_suffix}")
11 changes: 4 additions & 7 deletions shortfin/python/shortfin_apps/sd/README.md
@@ -38,13 +38,10 @@ cd shortfin/
The server will prepare runtime artifacts for you.

```
-python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile
+python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single"
```
-- Run with splat(empty) weights:
-```
-python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --splat --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile
-```
-- Run a request in a separate shell:

+- Run a CLI client in a separate shell:
```
-python shortfin/python/shortfin_apps/sd/examples/send_request.py --file=shortfin/python/shortfin_apps/sd/examples/sdxl_request.json
+python -m shortfin_apps.sd.simple_client --interactive --save
```
```
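For a scripted request without the bundled client, something like the following should work; treat the port (8000) and the /generate route as assumptions carried over from the example client, not facts established by this diff:

```python
# Hedged sketch: POST a minimal SDXL request to a locally running server.
# Host, port, and route are assumptions; check the server flags and
# simple_client.py for the authoritative values.
import json
import urllib.request

payload = {
    "prompt": ["a cat under the snow with blue eyes, cinematic style"],
    "neg_prompt": ["watermark, blurry"],
}
req = urllib.request.Request(
    "http://localhost:8000/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(resp.status)  # 200 on success; the body carries the generated image(s)
```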
10 changes: 7 additions & 3 deletions shortfin/python/shortfin_apps/sd/components/builders.py
@@ -24,7 +24,7 @@
sfnp.bfloat16: "bf16",
}

ARTIFACT_VERSION = "11022024"
ARTIFACT_VERSION = "11132024"

Reviewer commented:

Can the artifact version bump be automated in the future?

Contributor (author) replied:

sure
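One way that could look, sketched as a possibility rather than the project's plan: derive the stamp from the build date, matching the MMDDYYYY convention these buckets already use.

```python
# Hypothetical sketch: compute ARTIFACT_VERSION from the current UTC date
# (MMDDYYYY, e.g. "11132024") instead of hand-editing the constant on every
# artifact refresh. Assumes artifacts are always published under that scheme.
from datetime import date, datetime, timezone


def artifact_version(today: date | None = None) -> str:
    d = today or datetime.now(timezone.utc).date()
    return d.strftime("%m%d%Y")


ARTIFACT_VERSION = artifact_version()  # "11132024" on Nov 13, 2024
```

The trade-off is that a date-derived version only works when artifacts are uploaded the same day the code is cut; otherwise the stamp would still need to be pinned by hand.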

SDXL_BUCKET = (
f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/"
)
@@ -51,7 +51,9 @@ def get_mlir_filenames(model_params: ModelParams, model=None):
return filter_by_model(mlir_filenames, model)


-def get_vmfb_filenames(model_params: ModelParams, model=None, target: str = "gfx942"):
+def get_vmfb_filenames(
+    model_params: ModelParams, model=None, target: str = "amdgpu-gfx942"
+):
vmfb_filenames = []
file_stems = get_file_stems(model_params)
for stem in file_stems:
Expand Down Expand Up @@ -216,6 +218,8 @@ def sdxl(

mlir_bucket = SDXL_BUCKET + "mlir/"
vmfb_bucket = SDXL_BUCKET + "vmfbs/"
if "gfx" in target:
target = "amdgpu-" + target

mlir_filenames = get_mlir_filenames(model_params, model)
mlir_urls = get_url_map(mlir_filenames, mlir_bucket)
@@ -247,7 +251,7 @@ def sdxl(
params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET)
for f, url in params_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
-        if update or needs_file(f, ctx):
+        if needs_file(f, ctx):
fetch_http(name=f, url=url)
filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames]
return filenames
123 changes: 123 additions & 0 deletions shortfin/python/shortfin_apps/sd/components/config_artifacts.py
@@ -0,0 +1,123 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from iree.build import *
from iree.build.executor import FileNamespace
import itertools
import os
import shortfin.array as sfnp
import copy

from shortfin_apps.sd.components.config_struct import ModelParams

this_dir = os.path.dirname(os.path.abspath(__file__))
parent = os.path.dirname(this_dir)
Contributor commented:

Do you use this anywhere?

Contributor (author) replied:

Nope, copied from the other builder. Thanks.


dtype_to_filetag = {
sfnp.float16: "fp16",
sfnp.float32: "fp32",
sfnp.int8: "i8",
sfnp.bfloat16: "bf16",
}

ARTIFACT_VERSION = "11132024"
SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/"


def get_url_map(filenames: list[str], bucket: str):
file_map = {}
for filename in filenames:
file_map[filename] = f"{bucket}{filename}"
return file_map


def needs_update(ctx):
stamp = ctx.allocate_file("version.txt")
stamp_path = stamp.get_fs_path()
if os.path.exists(stamp_path):
with open(stamp_path, "r") as s:
ver = s.read()
if ver != ARTIFACT_VERSION:
return True
else:
with open(stamp_path, "w") as s:
s.write(ARTIFACT_VERSION)
return True
return False


def needs_file(filename, ctx, namespace=FileNamespace.GEN):
out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path()
if os.path.exists(out_file):
needed = False
else:
# name_path = "bin" if namespace == FileNamespace.BIN else ""
# if name_path:
# filename = os.path.join(name_path, filename)
filekey = os.path.join(ctx.path, filename)
ctx.executor.all[filekey] = None
needed = True
return needed
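Taken together, `needs_update` and `needs_file` implement a version-stamped download cache: fetch when the stamp is missing or stale, or when the file itself is absent. A simplified sketch of that rule (not the iree.build implementation; note that the stamp above is only written when it does not exist yet):

```python
# Simplified sketch of the caching rule used by this builder: re-download
# when the recorded artifact version is missing/stale or the file is absent.
import os


def should_fetch(file_path: str, stamp_path: str, version: str) -> bool:
    stamp_fresh = False
    if os.path.exists(stamp_path):
        with open(stamp_path) as s:
            stamp_fresh = s.read() == version
    return (not stamp_fresh) or (not os.path.exists(file_path))
```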


@entrypoint(description="Retrieves a set of SDXL configuration files.")
def sdxlconfig(
target=cl_arg(
"target",
default="gfx942",
help="IREE target architecture.",
),
model=cl_arg("model", type=str, default="sdxl", help="Model architecture"),
topology=cl_arg(
"topology",
type=str,
default="spx_single",
help="System topology configfile keyword",
),
):
ctx = executor.BuildContext.current()
update = needs_update(ctx)

model_config_filenames = [f"{model}_config_i8.json"]
model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET)
for f, url in model_config_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

topology_config_filenames = [f"topology_config_{topology}.txt"]
topology_config_urls = get_url_map(topology_config_filenames, SDXL_CONFIG_BUCKET)
for f, url in topology_config_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

flagfile_filenames = [f"{model}_flagfile_{target}.txt"]
flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET)
for f, url in flagfile_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

tuning_filenames = (
[f"attention_and_matmul_spec_{target}.mlir"] if target == "gfx942" else []
)
tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET)
for f, url in tuning_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)
filenames = [
*model_config_filenames,
*topology_config_filenames,
*flagfile_filenames,
*tuning_filenames,
]
return filenames


if __name__ == "__main__":
iree_build_main()
2 changes: 1 addition & 1 deletion shortfin/python/shortfin_apps/sd/components/generate.py
@@ -20,7 +20,7 @@
from .service import GenerateService
from .metrics import measure

-logger = logging.getLogger(__name__)
+logger = logging.getLogger("shortfin-sd.generate")


class GenerateImageProcess(sf.Process):
7 changes: 7 additions & 0 deletions shortfin/python/shortfin_apps/sd/components/io_struct.py
@@ -72,3 +72,10 @@ def post_init(self):
raise ValueError("The rid should be a list.")
if self.output_type is None:
self.output_type = ["base64"] * self.num_output_images
+        # Temporary restrictions
+        heights = [self.height] if not isinstance(self.height, list) else self.height
+        widths = [self.width] if not isinstance(self.width, list) else self.width
+        if any(dim != 1024 for dim in [*heights, *widths]):
+            raise ValueError(
+                "Currently, only 1024x1024 output image size is supported."
+            )
4 changes: 2 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/manager.py
@@ -10,7 +10,7 @@
import shortfin as sf
from shortfin.interop.support.device_setup import get_selected_devices

-logger = logging.getLogger(__name__)
+logger = logging.getLogger("shortfin-sd.manager")


class SystemManager:
@@ -25,7 +25,7 @@ def __init__(self, device="local-task", device_ids=None, async_allocs=True):
sb.visible_devices = sb.available_devices
sb.visible_devices = get_selected_devices(sb, device_ids)
self.ls = sb.create_system()
logger.info(f"Created local system with {self.ls.device_names} devices")
logging.info(f"Created local system with {self.ls.device_names} devices")
# TODO: Come up with an easier bootstrap thing than manually
# running a thread.
self.t = threading.Thread(target=lambda: self.ls.run(self.run()))
2 changes: 1 addition & 1 deletion shortfin/python/shortfin_apps/sd/components/messages.py
@@ -13,7 +13,7 @@

from .io_struct import GenerateReqInput

-logger = logging.getLogger(__name__)
+logger = logging.getLogger("shortfin-sd.messages")


class InferencePhase(Enum):
2 changes: 1 addition & 1 deletion shortfin/python/shortfin_apps/sd/components/metrics.py
@@ -10,7 +10,7 @@
from typing import Callable, Any
import functools

-logger = logging.getLogger(__name__)
+logger = logging.getLogger("shortfin-sd.metrics")


def measure(fn=None, type="exec", task=None, num_items=None, freq=1, label="items"):
37 changes: 26 additions & 11 deletions shortfin/python/shortfin_apps/sd/components/service.py
@@ -24,7 +24,8 @@
from .metrics import measure


-logger = logging.getLogger(__name__)
+logger = logging.getLogger("shortfin-sd.service")
+logger.setLevel(logging.DEBUG)

prog_isolations = {
"none": sf.ProgramIsolation.NONE,
@@ -119,26 +120,29 @@ def load_inference_parameters(

def start(self):
# Initialize programs.
# This can work if we only initialize one set of programs per service, as our programs
# in SDXL are stateless and
for component in self.inference_modules:
component_modules = [
sf.ProgramModule.parameter_provider(
self.sysman.ls, *self.inference_parameters.get(component, [])
),
*self.inference_modules[component],
]

for worker_idx, worker in enumerate(self.workers):
worker_devices = self.fibers[
worker_idx * (self.fibers_per_worker)
].raw_devices

+                logger.info(
+                    f"Loading inference program: {component}, worker index: {worker_idx}, device: {worker_devices}"
+                )
self.inference_programs[worker_idx][component] = sf.Program(
modules=component_modules,
devices=worker_devices,
isolation=self.prog_isolation,
trace_execution=self.trace_execution,
)
logger.info("Program loaded.")

for worker_idx, worker in enumerate(self.workers):
self.inference_functions[worker_idx]["encode"] = {}
for bs in self.model_params.clip_batch_sizes:
@@ -170,7 +174,6 @@ def start(self):
] = self.inference_programs[worker_idx]["vae"][
f"{self.model_params.vae_module_name}.decode"
]
-        # breakpoint()
self.batcher.launch()

def shutdown(self):
@@ -212,8 +215,8 @@ class BatcherProcess(sf.Process):
into batches.
"""

-    STROBE_SHORT_DELAY = 0.1
-    STROBE_LONG_DELAY = 0.25
+    STROBE_SHORT_DELAY = 0.5
+    STROBE_LONG_DELAY = 1

Reviewer commented:

How does this impact user experience?

Contributor (author) replied:

These are the increments at which the batcher checks for incoming requests.
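As a rough illustration of that mechanism (a standalone sketch with hypothetical names, not the actual BatcherProcess code): the batcher wakes on a short delay while requests are pending and a long delay when idle, flushing whatever has accumulated as one batch.

```python
# Illustrative sketch of a strobe-driven batcher: poll the queue, and when a
# wake-up finds pending work, flush it as a single batch. Larger delays trade
# a little latency for bigger batches.
import asyncio

STROBE_SHORT_DELAY = 0.5
STROBE_LONG_DELAY = 1.0


async def strobe_loop(queue: asyncio.Queue, flush) -> None:
    pending = []
    while True:
        delay = STROBE_SHORT_DELAY if pending else STROBE_LONG_DELAY
        try:
            pending.append(await asyncio.wait_for(queue.get(), timeout=delay))
        except asyncio.TimeoutError:
            if pending:
                await flush(pending)  # dispatch the accumulated batch
                pending = []
```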


def __init__(self, service: GenerateService):
super().__init__(fiber=service.fibers[0])
@@ -356,7 +359,6 @@ async def run(self):
logger.error("Executor process received disjoint batch.")
phase = req.phase
phases = self.exec_requests[0].phases

req_count = len(self.exec_requests)
device0 = self.service.fibers[self.fiber_index].device(0)
if phases[InferencePhase.PREPARE]["required"]:
@@ -424,8 +426,12 @@ async def _prepare(self, device, requests):
async def _encode(self, device, requests):
req_bs = len(requests)
entrypoints = self.service.inference_functions[self.worker_index]["encode"]
+        if req_bs not in list(entrypoints.keys()):
+            for request in requests:
+                await self._encode(device, [request])
+            return
        for bs, fn in entrypoints.items():
-            if bs >= req_bs:
+            if bs == req_bs:
                break
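This dispatch rule (repeated below for denoise and decode) only uses an entrypoint compiled for exactly the incoming batch size and otherwise falls back to one-request batches. A generic sketch of the pattern, with illustrative names:

```python
# Sketch: pick the function compiled for exactly this batch size; if none
# exists, process requests individually. `entrypoints` maps compiled batch
# sizes to async callables; all names here are illustrative.
async def dispatch(entrypoints: dict, requests: list) -> None:
    req_bs = len(requests)
    if req_bs not in entrypoints:
        # Fall back to batch size 1, assumed to always be compiled (as in the
        # service code above); otherwise this would recurse forever.
        for request in requests:
            await dispatch(entrypoints, [request])
        return
    await entrypoints[req_bs](requests)
```

Compared to the previous `bs >= req_bs` match, the exact match plus fallback keeps request batches aligned with the sizes that were actually compiled.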

# Prepare tokenized input ids for CLIP inference
Expand Down Expand Up @@ -462,6 +468,7 @@ async def _encode(self, device, requests):
fn,
"".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]),
)
+        await device
pe, te = await fn(*clip_inputs, fiber=self.fiber)

for i in range(req_bs):
Expand All @@ -477,8 +484,12 @@ async def _denoise(self, device, requests):
cfg_mult = 2 if self.service.model_params.cfg_mode else 1
# Produce denoised latents
entrypoints = self.service.inference_functions[self.worker_index]["denoise"]
+        if req_bs not in list(entrypoints.keys()):
+            for request in requests:
+                await self._denoise(device, [request])
+            return
        for bs, fns in entrypoints.items():
-            if bs >= req_bs:
+            if bs == req_bs:
                break

# Get shape of batched latents.
@@ -613,8 +624,12 @@ async def _decode(self, device, requests):
req_bs = len(requests)
# Decode latents to images
entrypoints = self.service.inference_functions[self.worker_index]["decode"]
+        if req_bs not in list(entrypoints.keys()):
+            for request in requests:
+                await self._decode(device, [request])
+            return
        for bs, fn in entrypoints.items():
-            if bs >= req_bs:
+            if bs == req_bs:
                break

latents_shape = [
@@ -29,6 +29,8 @@
" a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo"
],
"neg_prompt": [
Expand Down