Add --isolation Arg for LLM Server, Default to per_call (#445)

# Description
Related to issue #428 

When we run the Shortfin server, we currently set the isolation for
`sf.Program` invocation to `per_fiber`. However, we don't yet have a
data structure to manage available fibers, so the server errors out
when invoking `batch` requests, which surfaced in SGLang integration
testing. By setting `isolation` to `per_call`, we can handle batch
requests correctly, enabling more SGLang features, while we implement
the `FiberPool` as part of our `Radix`/`Shared Page Attention` todos.

This PR makes `--isolation` a CLI arg for the LLM server, similar to how
it's set up for the SD server, and defaults it to `PER_CALL`. This also
makes it easy to switch back and forth, or to switch the default back to
`per_fiber` down the road.
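
As a quick sanity check of which values the flag accepts (a minimal sketch, assuming shortfin is importable in the current environment; `server.py` builds its `choices` list the same way):

```python
# Print the lower-cased ProgramIsolation member names, i.e. the valid
# --isolation values. Based on the error message below, these should include
# none, per_fiber, and per_call.
import shortfin as sf

print(sorted(isolation.name.lower() for isolation in sf.ProgramIsolation))
```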

# Batch Errors

In SGLang, we have the option to send requests as a batch, allowing us
to execute multiple separate prompts in parallel:

## SGLang Frontend Code
```python
import sglang as sgl

# Assumes a backend has been configured beforehand,
# e.g. with sgl.set_default_backend(...).


@sgl.function
def multi_turn_question(s, question_1, question_2):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
    
def batch():
    states = multi_turn_question.run_batch(
        [
            {
                "question_1": "What is the capital of the United States?",
                "question_2": "List two local attractions.",
            },
            {
                "question_1": "What is the capital of France?",
                "question_2": "What is the population of this city?",
            },
        ]
    )

    for s in states:
        for m in s.messages():
            print(m["role"], m["content"])

        print()
    print()

print("\n========== batch ==========\n")
batch()
```

## Shortfin Error
When this code is invoked with `isolation` set to `per_fiber`, we hit a
concurrency error from attempting concurrent invocations of a `PER_FIBER`
program from the same fiber:
```bash
[2024-11-05 18:16:02.756] [error] [service.py:427] Fatal error in prefetch invocation
Traceback (most recent call last):
  File "/home/stbaione/repos/SHARK-Platform/shortfin/python/shortfin_apps/llm/components/service.py", line 408, in run
    (logits,) = await fn(*args, fiber=self.fiber)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Cannot make concurrent invocations of a PER_FIBER program from the same Fiber. This typically means that two 
invocations were attempted on the same program on the same fiber without an await. Consider fixing adding appropriate 
sequencing or switching to either PER_CALL or NONE isolation if appropriate for the use case. This exception can also occur if 
the first invocation to this Program failed, leaving no initialized Program for this fiber.
```
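
The failure mode is easiest to see in isolation: two invocations in flight on the same fiber with no `await` between them. Below is a minimal plain-`asyncio` analogy (a sketch, not shortfin code) contrasting the sequenced pattern that `PER_FIBER` expects with the concurrent pattern a batch produces:

```python
import asyncio


async def invoke(request_id: int) -> str:
    # Stand-in for `await fn(*args, fiber=self.fiber)` in service.py.
    await asyncio.sleep(0.01)
    return f"logits for request {request_id}"


async def sequenced() -> tuple[str, str]:
    # PER_FIBER-friendly: the second invocation starts only after the first finishes.
    first = await invoke(1)
    second = await invoke(2)
    return first, second


async def concurrent() -> list[str]:
    # What a batch effectively does: both invocations are in flight at once.
    # On a single fiber under PER_FIBER isolation this is the pattern that
    # raises the RuntimeError above; PER_CALL permits it.
    return await asyncio.gather(invoke(1), invoke(2))


if __name__ == "__main__":
    print(asyncio.run(sequenced()))
    print(asyncio.run(concurrent()))
```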

# Solution

By setting isolation to `per_call`, we're able to handle the batch
requests effectively (there are still some improvements to be made in
shortfin LLM completion, as the output below shows):

## SGLang Batch Invocation
```text
========== batch ==========

system You are a helpful assistant.
user What is the capital of the United States?
assistant Washington, D.C.
USER:What is the capital
user List two local attractions.
assistant List two

system You are a helpful assistant.
user What is the capital of France?
assistant Paris is the capital of France.
USER:!!!!
user What is the population of this city?
assistant Paris has
```

# Considerations

There was some concern about using `PER_CALL` due to the possibility of
stateful programs; however, we currently don't have any state that needs
to be shared across batches, and all batches will use different KV cache
lines.

We should revisit/implement the `FiberPool` specified in #428, but for
now, we can lump that into our `Radix`/`Shared Page Attention` todos,
enabling more SGLang features in the meantime.
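
For reference, one possible shape for that `FiberPool` (a hypothetical sketch, not the design from #428): check fibers in and out so that a `per_fiber` program never sees two concurrent invocations on the same fiber.

```python
import asyncio
from typing import Any, List


class FiberPool:
    """Hypothetical sketch: each in-flight invocation checks out a fiber and
    returns it when done, so no fiber ever runs two invocations at once."""

    def __init__(self, fibers: List[Any]):
        self._idle: asyncio.Queue = asyncio.Queue()
        for fiber in fibers:
            self._idle.put_nowait(fiber)

    async def acquire(self) -> Any:
        # Waits until a fiber is free, bounding concurrency by the pool size.
        return await self._idle.get()

    def release(self, fiber: Any) -> None:
        self._idle.put_nowait(fiber)
```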
stbaione authored Nov 7, 2024
1 parent 79fe7e2 commit d297718
Showing 2 changed files with 21 additions and 1 deletion.
8 changes: 8 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/service.py

@@ -19,6 +19,10 @@
 
 logger = logging.getLogger(__name__)
 
+PROG_ISOLATIONS = {
+    isolation.name.lower(): isolation for isolation in sf.ProgramIsolation
+}
+
 
 class GenerateService:
     """Top level service interface for generating text against a model."""
@@ -34,6 +38,7 @@ def __init__(
         sysman: SystemManager,
         tokenizer: Tokenizer,
         model_params: ModelParams,
+        program_isolation: str = "per_call",
     ):
         self.name = name
 
@@ -53,6 +58,8 @@ def __init__(
             devices=self.main_fiber.devices_dict.values(), model_params=model_params
         )
 
+        self.program_isolation = PROG_ISOLATIONS[program_isolation]
+
     def load_inference_module(self, vmfb_path: Path):
         self.inference_modules.append(sf.ProgramModule.load(self.sysman.ls, vmfb_path))
 
@@ -75,6 +82,7 @@ def start(self):
             + self.inference_modules,
             devices=self.sysman.ls.devices,
             trace_execution=False,
+            isolation=self.program_isolation,
         )
         # Resolve prefill entrypoints.
         self.prefill_functions = {}
14 changes: 13 additions & 1 deletion shortfin/python/shortfin_apps/llm/server.py

@@ -14,6 +14,7 @@
 import uvicorn.logging
 
 # Import first as it does dep checking and reporting.
+from shortfin import ProgramIsolation
 from shortfin.interop.fastapi import FastAPIResponder
 
 from contextlib import asynccontextmanager
@@ -82,7 +83,11 @@ def configure(args) -> SystemManager:
     tokenizer = Tokenizer.from_tokenizer_json_file(args.tokenizer_json)
     model_params = ModelParams.load_json(args.model_config)
     sm = GenerateService(
-        name="default", sysman=sysman, tokenizer=tokenizer, model_params=model_params
+        name="default",
+        sysman=sysman,
+        tokenizer=tokenizer,
+        model_params=model_params,
+        program_isolation=args.isolation,
     )
     sm.load_inference_module(args.vmfb)
     sm.load_inference_parameters(*args.parameters, parameter_scope="model")
@@ -135,6 +140,13 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
         default="local-task",
         help="Device to serve on; e.g. local-task, hip. Same options as `iree-run-module --device` ",
     )
+    parser.add_argument(
+        "--isolation",
+        type=str,
+        default="per_call",
+        choices=[isolation.name.lower() for isolation in ProgramIsolation],
+        help="Concurrency control -- How to isolate programs.",
+    )
     args = parser.parse_args(argv)
     global sysman
     sysman = configure(args)
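
For completeness, a standalone sketch of how the new flag parses (it mirrors the diff above rather than importing the server; the hard-coded `choices` list assumes the three `ProgramIsolation` members named in the error message earlier):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--isolation",
    type=str,
    default="per_call",
    # server.py derives this list from ProgramIsolation; hard-coded here.
    choices=["none", "per_fiber", "per_call"],
    help="Concurrency control -- How to isolate programs.",
)

print(parser.parse_args([]).isolation)                            # per_call (new default)
print(parser.parse_args(["--isolation", "per_fiber"]).isolation)  # per_fiber
```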
