Create Fiber Pools to Enable Batch Requests to Shortfin LLM Server #428
stbaione added a commit that referenced this issue on Nov 7, 2024:
# Description

Related to issue #428

When we run the Shortfin Server, we currently set isolation for `sf.Program` invocation to `per_fiber`. However, we don't currently have a data structure to manage available fibers. This causes the server to error out when invoking `batch` requests, which was found in SGLang integration testing.

By setting `isolation` to `per_call`, we can handle batch requests effectively, enabling more SGLang features, while we implement the `FiberPool` as part of our `Radix`/`Shared Page Attention` todos.

This makes `--isolation` a CLI arg to the LLM server, similar to how it's set up for the SD server, defaulting it to PER_CALL. This also makes it easy to switch back and forth, or to switch the default back to `per_fiber` down the road.

# Batch Errors

In SGLang, we have the option to send requests as a batch, allowing us to execute multiple separate prompts in parallel:

## SGLang Frontend Code

```python
import sglang as sgl


@sgl.function
def multi_turn_question(s, question_1, question_2):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))


def batch():
    # Each dict is a separate prompt; run_batch sends them concurrently.
    states = multi_turn_question.run_batch(
        [
            {
                "question_1": "What is the capital of the United States?",
                "question_2": "List two local attractions.",
            },
            {
                "question_1": "What is the capital of France?",
                "question_2": "What is the population of this city?",
            },
        ]
    )

    for s in states:
        for m in s.messages():
            print(m["role"], m["content"])
        print()
    print()


print("\n========== batch ==========\n")
batch()
```

## Shortfin Error

When this code is invoked with `isolation` set to `per_fiber`, we hit a concurrency error, `Cannot make concurrent invocations of a PER_FIBER program from the same Fiber`:

```bash
[2024-11-05 18:16:02.756] [error] [service.py:427] Fatal error in prefetch invocation
Traceback (most recent call last):
  File "/home/stbaione/repos/SHARK-Platform/shortfin/python/shortfin_apps/llm/components/service.py", line 408, in run
    (logits,) = await fn(*args, fiber=self.fiber)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Cannot make concurrent invocations of a PER_FIBER program from the same Fiber. This typically means that two invocations were attempted on the same program on the same fiber without an await. Consider fixing adding appropriate sequencing or switching to either PER_CALL or NONE isolation if appropriate for the use case. This exception can also occur if the first invocation to this Program failed, leaving no initialized Program for this fiber.
```

# Solution

By setting isolation to `per_call`, we're able to handle the batch requests effectively (there are still some improvements that can be made in shortfin LLM completion):

## SGLang Batch Invocation

```text
========== batch ==========

system You are a helpful assistant.
user What is the capital of the United States?
assistant Washington, D.C. USER:What is the capital
user List two local attractions.
assistant List two

system You are a helpful assistant.
user What is the capital of France?
assistant Paris is the capital of France. USER:!!!!
user What is the population of this city?
assistant Paris has
```

# Considerations

There was some concern about using PER_CALL due to the possibility of stateful programs. However, we currently don't have any state that needs to be shared across different batches, and all batches will use different KV cache lines.

We should revisit/implement the `FiberPool` specified in #428, but for now, we can lump that into our `Radix`/`Shared Page Attention` todos, enabling more SGLang features in the meantime.
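The commit message above makes `--isolation` a CLI arg of the LLM server that defaults to PER_CALL. As a rough, hypothetical sketch of that wiring (not the server's actual code), the following uses `argparse` with a local stand-in `ProgramIsolation` enum whose members mirror the modes named in the error message (NONE, PER_FIBER, PER_CALL). The enum, the lowercase choice strings, and `parse_isolation` are assumptions for illustration only.

```python
import argparse
from enum import Enum


# Hypothetical stand-in for shortfin's isolation modes; member names mirror
# the modes mentioned in the error message (NONE, PER_FIBER, PER_CALL).
class ProgramIsolation(Enum):
    NONE = "none"
    PER_FIBER = "per_fiber"
    PER_CALL = "per_call"


def parse_isolation(argv=None):
    """Hypothetical helper: parse --isolation and return the enum member."""
    parser = argparse.ArgumentParser(description="LLM server options (sketch)")
    parser.add_argument(
        "--isolation",
        type=str,
        default="per_call",  # PER_CALL default, as described in the commit
        choices=[mode.value for mode in ProgramIsolation],
        help="Concurrency isolation mode for program invocations.",
    )
    args = parser.parse_args(argv)
    # Map the CLI string onto the enum member that would configure the program.
    return ProgramIsolation(args.isolation)


if __name__ == "__main__":
    print(parse_isolation([]))
    print(parse_isolation(["--isolation", "per_fiber"]))
```

Keeping the mode as a CLI flag rather than a hard-coded value is what makes it cheap to flip the default back to `per_fiber` once a fiber pool exists.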
High-Level Summary
While enabling serving benchmark tests with SGLang, I found a bug in the Shortfin LLM Server with batch requests, caused by attempted concurrent invocations on the same fiber.
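To make the failure mode concrete, here is a toy `asyncio` model, assuming only what the error message states: a PER_FIBER program rejects a second invocation on a fiber while the first one is still in flight. `PerFiberProgram`, `run_batch`, and the string "fibers" are hypothetical and are not shortfin's implementation.

```python
import asyncio


class PerFiberProgram:
    """Toy model (hypothetical): at most one in-flight invocation per fiber."""

    def __init__(self):
        self._busy = set()  # fibers that currently have an invocation in flight

    async def invoke(self, request, fiber):
        if fiber in self._busy:
            raise RuntimeError(
                "Cannot make concurrent invocations of a PER_FIBER program "
                "from the same Fiber."
            )
        self._busy.add(fiber)
        try:
            await asyncio.sleep(0.01)  # stand-in for running the model
            return f"result for {request!r} on {fiber}"
        finally:
            self._busy.discard(fiber)


async def run_batch(program, requests, fibers):
    # Launch every request at once, as a batch request does.
    coros = [program.invoke(r, f) for r, f in zip(requests, fibers)]
    return await asyncio.gather(*coros, return_exceptions=True)


async def main():
    program = PerFiberProgram()
    requests = ["capital of the US", "capital of France"]
    # Both requests share one fiber: the second invocation collides.
    shared = await run_batch(program, requests, ["fiber0", "fiber0"])
    # Each request gets its own fiber: no collision.
    separate = await run_batch(program, requests, ["fiber0", "fiber1"])
    return shared, separate


if __name__ == "__main__":
    shared, separate = asyncio.run(main())
    print(shared)
    print(separate)
```

The first batch returns one result and one `RuntimeError`; the second returns two results. A per-request fiber avoids the collision without giving up PER_FIBER semantics.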
Relevant Error
Proposed Solution
We should maintain a pool of available fibers in shortfin: obtain an available fiber whenever one is needed, and return it to the pool once it is no longer needed (idle). This idea comes from this PR comment: #360 (comment)
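A minimal sketch of such a pool, assuming an `asyncio.Queue` as the idle list and plain strings standing in for fibers. The `FiberPool` name comes from the commit message, but its interface here (`acquire`, `release`, `borrow`) is hypothetical and does not reflect any real shortfin API. A request waits for an idle fiber instead of sharing a busy one, and the fiber is always returned, even if the invocation raises.

```python
import asyncio
from contextlib import asynccontextmanager


class FiberPool:
    """Hypothetical fiber pool: each fiber serves one request at a time."""

    def __init__(self, fibers):
        # Build the pool inside a running event loop (see main below).
        self._idle = asyncio.Queue()
        for fiber in fibers:
            self._idle.put_nowait(fiber)

    async def acquire(self):
        # Blocks until some fiber is idle.
        return await self._idle.get()

    def release(self, fiber):
        self._idle.put_nowait(fiber)

    @asynccontextmanager
    async def borrow(self):
        fiber = await self.acquire()
        try:
            yield fiber
        finally:
            self.release(fiber)  # return the fiber even if the request failed


async def handle_request(pool, request, in_flight, log):
    async with pool.borrow() as fiber:
        # Record whether another request already held this fiber.
        log.append((request, fiber, fiber in in_flight))
        in_flight.add(fiber)
        await asyncio.sleep(0.01)  # stand-in for awaiting the program invocation
        in_flight.discard(fiber)


async def main():
    pool = FiberPool(["fiber0", "fiber1"])
    in_flight, log = set(), []
    requests = [f"req{i}" for i in range(5)]
    await asyncio.gather(*(handle_request(pool, r, in_flight, log) for r in requests))
    return log


if __name__ == "__main__":
    for entry in asyncio.run(main()):
        print(entry)
```

With five requests and two fibers, at most two requests run at once and no fiber is ever shared by two in-flight requests. The rest queue up, which is the behavior PER_FIBER isolation needs.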
Reproduction Steps/Further Details
After starting a fresh shortfin server (in this case, for GPU):
python -m shortfin_apps.llm.server --tokenizer=/data/llama3.1/8b/tokenizer.json --model_config=../../export_mi300/config.json --vmfb=../../export_mi300/model.vmfb --parameters=/data/llama3.1/8b/llama8b_f16.irpa --device=hip
SGLang Code
We can send a batch request using SGLang with the following code:
Shortfin Error
Upon receiving this request, the Shortfin server fails with the following error:
PER_CALL Patch
Locally, if you set the Program isolation arg to `PER_CALL`, the requests run fine:
SGLang Result
The SGLang request functionally works:
However, setting isolation to `PER_CALL` is potentially dangerous, since the shortfin server is stateful. The correct solution seems to be a pool or simple `idle_list` data structure from which we obtain an available fiber when needed and to which we return it when no longer needed, as described above.