# Description

Related to issue #428.

When we run the Shortfin server, we currently set isolation for `sf.Program` invocation to `per_fiber`. However, we don't yet have a data structure to manage available fibers, which causes the server to error out when invoking `batch` requests; this was found during SGLang integration testing. By setting `isolation` to `per_call`, we can handle batch requests effectively, enabling more SGLang features, while we implement the `FiberPool` as part of our `Radix`/`Shared Page Attention` todos.

This makes `--isolation` a CLI arg to the LLM server, similar to how it's set up for the SD server, defaulting it to `PER_CALL`. This also makes it easy to switch back and forth, or to switch the default back to `per_fiber` down the road.

# Batch Errors

In SGLang, we have the option to send requests as a batch, allowing us to execute multiple separate prompts in parallel:

## SGLang Frontend Code

```python
@sgl.function
def multi_turn_question(s, question_1, question_2):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))


def batch():
    states = multi_turn_question.run_batch(
        [
            {
                "question_1": "What is the capital of the United States?",
                "question_2": "List two local attractions.",
            },
            {
                "question_1": "What is the capital of France?",
                "question_2": "What is the population of this city?",
            },
        ]
    )

    for s in states:
        for m in s.messages():
            print(m["role"], m["content"])
        print()
    print()


print("\n========== batch ==========\n")
batch()
```

## Shortfin Error

When this code is invoked with `isolation` set to `per_fiber`, we hit a concurrency error from `attempting concurrent invocations of a PER_FIBER program from the same fiber`:

```bash
[2024-11-05 18:16:02.756] [error] [service.py:427] Fatal error in prefetch invocation
Traceback (most recent call last):
  File "/home/stbaione/repos/SHARK-Platform/shortfin/python/shortfin_apps/llm/components/service.py", line 408, in run
    (logits,) = await fn(*args, fiber=self.fiber)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Cannot make concurrent invocations of a PER_FIBER program from the same Fiber. This typically means that two invocations were attempted on the same program on the same fiber without an await. Consider fixing adding appropriate sequencing or switching to either PER_CALL or NONE isolation if appropriate for the use case. This exception can also occur if the first invocation to this Program failed, leaving no initialized Program for this fiber.
```

# Solution

By setting isolation to `per_call`, we're able to handle the batch requests effectively (there are still some improvements that can be made to shortfin LLM completion):

## SGLang Batch Invocation

```text
========== batch ==========

system You are a helpful assistant.
user What is the capital of the United States?
assistant Washington, D.C. USER:What is the capital
user List two local attractions.
assistant List two

system You are a helpful assistant.
user What is the capital of France?
assistant Paris is the capital of France. USER:!!!!
user What is the population of this city?
assistant Paris has
```

# Considerations

There was some concern about using `PER_CALL` due to the possibility of stateful programs. However, we currently don't have any state that needs to be shared across different batches, and all batches will use different KV cache lines. We should still revisit/implement the `FiberPool` specified in #428, but for now we can lump that into our `Radix`/`Shared Page Attention` todos, enabling more SGLang features in the meantime.
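For reference, the new `--isolation` CLI arg can be sketched roughly as below. This is a minimal, self-contained illustration, not the actual server code: the `ProgramIsolation` enum here is a hypothetical stand-in for shortfin's isolation modes (the real enum lives in shortfin and may be named/structured differently), and `parse_isolation_arg` is an invented helper for demonstration.

```python
import argparse
from enum import Enum, auto


# Hypothetical stand-in for shortfin's program isolation modes;
# the actual type and its spelling in shortfin may differ.
class ProgramIsolation(Enum):
    NONE = auto()
    PER_FIBER = auto()
    PER_CALL = auto()


def parse_isolation_arg(argv):
    """Parse the --isolation flag, defaulting to per_call (per this PR)."""
    parser = argparse.ArgumentParser(description="LLM server arg sketch")
    parser.add_argument(
        "--isolation",
        choices=[m.name.lower() for m in ProgramIsolation],
        default="per_call",  # PER_CALL is the new default
        help="Isolation mode for sf.Program invocations",
    )
    args = parser.parse_args(argv)
    return ProgramIsolation[args.isolation.upper()]


# Default selection, and switching back to per_fiber from the CLI:
print(parse_isolation_arg([]).name)                            # PER_CALL
print(parse_isolation_arg(["--isolation", "per_fiber"]).name)  # PER_FIBER
```

Keeping the mode as a flag rather than a hard-coded constant is what makes it cheap to flip the default back to `per_fiber` once the `FiberPool` lands.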