Add prefix caching #581

Merged: 6 commits merged into main from prefix-cache-flash on Aug 21, 2024
Conversation

@tgaddair (Contributor) commented Aug 21, 2024

Usage:

lorax-launcher --prefix-caching ...

Note that this relies on FlashInfer, which is not yet baked into the prebuilt Docker image.
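
For example, a fuller launcher invocation might look like the line below (the model id is illustrative, not part of this PR):

lorax-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --prefix-caching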

Adapted from huggingface/text-generation-inference#2402

@tgaddair marked this pull request as ready for review August 21, 2024 21:27
@tgaddair merged commit 1d2b514 into main Aug 21, 2024
2 checks passed
@tgaddair deleted the prefix-cache-flash branch August 21, 2024 21:34
@prd-tuong-nguyen

@tgaddair hi, do you have any documentation related to this feature?

@tgaddair (Contributor, Author) commented Aug 23, 2024

@prd-tuong-nguyen not yet, but it's coming! For now, I would say this is an experimental feature that can be helpful when you want to ask multiple questions over the same large input (like a long document):
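
For example (a rough sketch; the local endpoint and port below are assumptions about your deployment), you could send the same long document with different questions, so the shared document prefix can be reused across requests:

# First question over the document
curl http://127.0.0.1:8080/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "<long document text>\n\nQuestion: What is the main conclusion?\nAnswer:", "parameters": {"max_new_tokens": 128}}'

# Second question: the document prefix is identical, so its cached KV state can be reused
curl http://127.0.0.1:8080/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "<long document text>\n\nQuestion: Who are the key people mentioned?\nAnswer:", "parameters": {"max_new_tokens": 128}}'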

Here's a blog post going over the technical details of how it works: https://flashinfer.ai/2024/02/02/cascade-inference.html

@prd-tuong-nguyen

@tgaddair Yeah, thank you. It's an awesome feature. I think it will save a significant amount of time during inference.

@prd-tuong-nguyen

Hi @tgaddair
I encountered the errors below while using the prefix-caching feature. Can you take a look?

  • With Open-Orca/Mistral-7B-OpenOrca:
{"timestamp":"2024-09-09T04:36:06.485791Z","level":"ERROR","fields":{"message":"interceptor.py:41 Method Warmup encountered an error.\nTraceback (most recent call last):\n  File \"/opt/conda/bin/lorax-server\", line 8, in <module>\n    sys.exit(app())\n  File \"/opt/conda/lib/python3.10/site-packages/typer/main.py\", line 311, in __call__\n    return get_command(self)(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/click/core.py\", line 1157, in __call__\n    return self.main(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/typer/core.py\", line 778, in main\n    return _main(\n  File \"/opt/conda/lib/python3.10/site-packages/typer/core.py\", line 216, in _main\n    rv = self.invoke(ctx)\n  File \"/opt/conda/lib/python3.10/site-packages/click/core.py\", line 1688, in invoke\n    return _process_result(sub_ctx.command.invoke(sub_ctx))\n  File \"/opt/conda/lib/python3.10/site-packages/click/core.py\", line 1434, in invoke\n    return ctx.invoke(self.callback, **ctx.params)\n  File \"/opt/conda/lib/python3.10/site-packages/click/core.py\", line 783, in invoke\n    return __callback(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/typer/main.py\", line 683, in wrapper\n    return callback(**use_params)  # type: ignore\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py\", line 87, in serve\n    server.serve(\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/server.py\", line 408, in serve\n    asyncio.run(\n  File \"/opt/conda/lib/python3.10/asyncio/runners.py\", line 44, in run\n    return loop.run_until_complete(main)\n  File \"/opt/conda/lib/python3.10/asyncio/base_events.py\", line 636, in run_until_complete\n    self.run_forever()\n  File \"/opt/conda/lib/python3.10/asyncio/base_events.py\", line 603, in run_forever\n    self._run_once()\n  File \"/opt/conda/lib/python3.10/asyncio/base_events.py\", line 1909, in _run_once\n    handle._run()\n  File \"/opt/conda/lib/python3.10/asyncio/events.py\", line 80, in _run\n    self._context.run(self._callback, *self._args)\n  File \"/opt/conda/lib/python3.10/site-packages/grpc_interceptor/server.py\", line 165, in invoke_intercept_method\n    return await self.intercept(\n> File \"/opt/conda/lib/python3.10/site-packages/lorax_server/interceptor.py\", line 38, in intercept\n    return await response\n  File \"/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py\", line 82, in _unary_interceptor\n    raise error\n  File \"/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py\", line 73, in _unary_interceptor\n    return await behavior(request_or_iterator, context)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/server.py\", line 86, in Warmup\n    max_supported_total_tokens = self.model.warmup(batch, request.max_new_tokens)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py\", line 895, in warmup\n    _, batch = self.generate_token(batch, is_warmup=True)\n  File \"/opt/conda/lib/python3.10/contextlib.py\", line 79, in inner\n    return func(*args, **kwds)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py\", line 1156, in generate_token\n    out, speculative_logits = self.forward(batch, adapter_data)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py\", line 1092, in forward\n    out = model.forward(\n  File 
\"/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_mistral_modeling.py\", line 600, in forward\n    hidden_states = self.model(\n  File \"/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1553, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n    return forward_call(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_mistral_modeling.py\", line 525, in forward\n    hidden_states, residual = layer(\n  File \"/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1553, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n    return forward_call(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_mistral_modeling.py\", line 456, in forward\n    attn_output = self.self_attn(\n  File \"/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1553, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n    return forward_call(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_mistral_modeling.py\", line 340, in forward\n    attn_output = flash_attn.attention(\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/utils/flash_attn.py\", line 131, in attention\n    assert window_size_left == -1, \"Windowing is not supported with flash infer when using kv cache\"\nAssertionError: Windowing is not supported with flash infer when using kv cache\n"},"target":"lorax_launcher"}
  • With Phi3:
{"timestamp":"2024-09-09T04:32:36.679869Z","level":"ERROR","fields":{"message":"interceptor.py:41 Method Warmup encountered an error.\nTraceback (most recent call last):\n  File \"/opt/conda/bin/lorax-server\", line 8, in <module>\n    sys.exit(app())\n  File \"/opt/conda/lib/python3.10/site-packages/typer/main.py\", line 311, in __call__\n    return get_command(self)(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/click/core.py\", line 1157, in __call__\n    return self.main(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/typer/core.py\", line 778, in main\n    return _main(\n  File \"/opt/conda/lib/python3.10/site-packages/typer/core.py\", line 216, in _main\n    rv = self.invoke(ctx)\n  File \"/opt/conda/lib/python3.10/site-packages/click/core.py\", line 1688, in invoke\n    return _process_result(sub_ctx.command.invoke(sub_ctx))\n  File \"/opt/conda/lib/python3.10/site-packages/click/core.py\", line 1434, in invoke\n    return ctx.invoke(self.callback, **ctx.params)\n  File \"/opt/conda/lib/python3.10/site-packages/click/core.py\", line 783, in invoke\n    return __callback(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/typer/main.py\", line 683, in wrapper\n    return callback(**use_params)  # type: ignore\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py\", line 87, in serve\n    server.serve(\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/server.py\", line 408, in serve\n    asyncio.run(\n  File \"/opt/conda/lib/python3.10/asyncio/runners.py\", line 44, in run\n    return loop.run_until_complete(main)\n  File \"/opt/conda/lib/python3.10/asyncio/base_events.py\", line 636, in run_until_complete\n    self.run_forever()\n  File \"/opt/conda/lib/python3.10/asyncio/base_events.py\", line 603, in run_forever\n    self._run_once()\n  File \"/opt/conda/lib/python3.10/asyncio/base_events.py\", line 1909, in _run_once\n    handle._run()\n  File \"/opt/conda/lib/python3.10/asyncio/events.py\", line 80, in _run\n    self._context.run(self._callback, *self._args)\n  File \"/opt/conda/lib/python3.10/site-packages/grpc_interceptor/server.py\", line 165, in invoke_intercept_method\n    return await self.intercept(\n> File \"/opt/conda/lib/python3.10/site-packages/lorax_server/interceptor.py\", line 38, in intercept\n    return await response\n  File \"/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py\", line 82, in _unary_interceptor\n    raise error\n  File \"/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py\", line 73, in _unary_interceptor\n    return await behavior(request_or_iterator, context)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/server.py\", line 86, in Warmup\n    max_supported_total_tokens = self.model.warmup(batch, request.max_new_tokens)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py\", line 895, in warmup\n    _, batch = self.generate_token(batch, is_warmup=True)\n  File \"/opt/conda/lib/python3.10/contextlib.py\", line 79, in inner\n    return func(*args, **kwds)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py\", line 1156, in generate_token\n    out, speculative_logits = self.forward(batch, adapter_data)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py\", line 1092, in forward\n    out = model.forward(\n  File 
\"/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_phi3_modeling.py\", line 507, in forward\n    hidden_states = self.model(\n  File \"/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1553, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n    return forward_call(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_phi3_modeling.py\", line 457, in forward\n    hidden_states, residual = layer(\n  File \"/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1553, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n    return forward_call(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_phi3_modeling.py\", line 390, in forward\n    attn_output = self.self_attn(\n  File \"/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1553, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n    return forward_call(*args, **kwargs)\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_phi3_modeling.py\", line 279, in forward\n    attn_output = flash_attn.attention(\n  File \"/opt/conda/lib/python3.10/site-packages/lorax_server/utils/flash_attn.py\", line 146, in attention\n    return prefill_with_paged_kv_state.get().forward(\n  File \"/opt/conda/lib/python3.10/site-packages/flashinfer/prefill.py\", line 914, in forward\n    out = self._wrapper.forward(\nRuntimeError: BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Tensor, bool, unsigned int, bool, int, float, float, float, float, bool)::<lambda()>::<lambda()>::<lambda()>::<lambda()>::<lambda()> failed to dispatch head_dim 96\n"},"target":"lorax_launcher"}

@tgaddair (Contributor, Author)

Thanks for reporting @prd-tuong-nguyen!

The second issue is known and being tracked by the FlashInfer team: flashinfer-ai/flashinfer#455

It looks like there's a workaround I can look into adding, though.

The first issue is new. Can you file an issue and I'll take a look?

@prd-tuong-nguyen

@tgaddair Thanks! Here is the issue: #599

@OlivierDehaene (Contributor)

Hey Predibase team, nice PR! It's cool to see that you're using some of TGI's latest features.

We appreciate the acknowledgements in the README, but we were wondering if you could add attribution somewhere (for example, a link to the original PR) when you adapt these features, or mention in the README that it's based on the latest TGI rather than 0.9.4!

Thanks a lot for your work on lorax!

@tgaddair (Contributor, Author)

Hey @OlivierDehaene, apologies for that. More than happy to reference specific PRs when pulling in upstream changes going forward!
