Skip to content

Commit

Permalink
feat: gradio integration (#5008)
Browse files Browse the repository at this point in the history
* initial commit

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

* fix for gradio 5

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

* fix: avoid mounting gradio app twice

Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
parano authored Oct 11, 2024
1 parent 045001c commit 68fc547
Show file tree
Hide file tree
Showing 12 changed files with 207 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,6 @@ mlruns/
.pdm-python
.python-version
.pdm-build/

# gradio flag
flagged/
25 changes: 25 additions & 0 deletions examples/gradio/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Quickstart

This quickstart demonstrates how to add a Gradio web UI to a BentoML service.

## Prerequisites

Python 3.9+ and `pip` installed. See the [Python downloads page](https://www.python.org/downloads/) to learn more.

## Get started

Perform the following steps to run this project and deploy it to BentoCloud.

1. Install the required dependencies:

```bash
pip install -r requirements.txt
```

2. Serve your model as an HTTP server. This starts a local server at [http://localhost:3000](http://localhost:3000/), making your model accessible as a web service.

```bash
bentoml serve .
```

3. Visit http://localhost:3000/ui for the Gradio UI. The BentoML APIs are available at http://localhost:3000.
16 changes: 16 additions & 0 deletions examples/gradio/bentofile.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
service: "service:Summarization"
labels:
project: quickstart
stage: dev
include:
- "service.py"
python:
packages:
- torch
- transformers
- gradio
- pydantic>=2.0
- fastapi
lock_packages: false
docker:
python_version: "3.10"
5 changes: 5 additions & 0 deletions examples/gradio/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
bentoml
torch
transformers
gradio
fastapi
44 changes: 44 additions & 0 deletions examples/gradio/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations # I001

import bentoml

with bentoml.importing():
import gradio as gr
import torch
from transformers import pipeline

# Sample article used as the default value of the Gradio input textbox.
# Note: the quoted phrase at the end is properly closed ('...Century.') —
# the original text was missing the closing single quote.
EXAMPLE_INPUT = (
    "Breaking News: In an astonishing turn of events, the small "
    "town of Willow Creek has been taken by storm as local resident Jerry Thompson's cat, "
    "Whiskers, performed what witnesses are calling a 'miraculous and gravity-defying leap.' "
    "Eyewitnesses report that Whiskers, an otherwise unremarkable tabby cat, jumped "
    "a record-breaking 20 feet into the air to catch a fly. The event, which took "
    "place in Thompson's backyard, is now being investigated by scientists for potential "
    "breaches in the laws of physics. Local authorities are considering a town festival "
    "to celebrate what is being hailed as 'The Leap of the Century.'"
)


def summarize_text(text: str) -> str:
    """Gradio callback: summarize a single piece of text.

    Wraps the text in a one-element batch, delegates to the running
    BentoML service's batchable ``summarize`` API, and unwraps the result.
    """
    service = bentoml.get_current_service()
    return service.summarize([text])[0]


# Gradio UI: one multi-line input textbox mapped to one summary textbox,
# backed by the `summarize_text` callback above.
input_textbox = gr.Textbox(lines=5, label="Enter Text", value=EXAMPLE_INPUT)
output_textbox = gr.Textbox(label="Summary Text")

io = gr.Interface(
    fn=summarize_text,
    inputs=[input_textbox],
    outputs=[output_textbox],
    title="Summarization",
    description="Enter text to get summarized text.",
)


@bentoml.service(resources={"cpu": "4"})
@bentoml.gradio.mount_gradio_app(io, path="/ui")
class Summarization:
    """Text-summarization service with the Gradio UI mounted at ``/ui``."""

    def __init__(self) -> None:
        # Prefer the GPU when one is available; the pipeline accepts a
        # device string ("cuda"/"cpu").
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        self.pipeline = pipeline("summarization", device=device)

    @bentoml.api(batchable=True)
    def summarize(self, texts: list[str]) -> list[str]:
        """Return one summary string per input text, in order."""
        outputs = self.pipeline(texts)
        return [output["summary_text"] for output in outputs]
2 changes: 1 addition & 1 deletion examples/quickstart/bentofile.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ python:
- torch
- transformers
docker:
python_version: 3.11
python_version: "3.10"
1 change: 1 addition & 0 deletions src/_bentoml_impl/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ async def create_instance(self, app: Starlette) -> None:
from ..client import RemoteProxy

self._service_instance = self.service()
self.service.gradio_app_startup_hook(max_concurrency=self.max_concurrency)
logger.info("Service %s initialized", self.service.name)
if deployment_url := os.getenv("BENTOCLOUD_DEPLOYMENT_URL"):
proxy = RemoteProxy(
Expand Down
2 changes: 2 additions & 0 deletions src/_bentoml_sdk/_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def make_fastapi_class_views(cls: type[Any], app: FastAPI) -> None:
if isinstance(route, (APIRoute, APIWebSocketRoute))
and route.endpoint in class_methods
]
if not api_routes:
return
# Modify these routes and mount it to a new APIRouter.
# We need to to this (instead of modifying in place) because we want to use
# the app.include_router to re-run the dependency analysis for each routes.
Expand Down
89 changes: 89 additions & 0 deletions src/_bentoml_sdk/gradio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from __future__ import annotations

import typing as t
from pathlib import Path

from bentoml.exceptions import MissingDependencyException

from .decorators import mount_asgi_app

R = t.TypeVar("R")

try:
import gradio as gr
from gradio import Blocks
except ImportError: # pragma: no cover
raise MissingDependencyException(
"""'gradio' is required by 'bentoml.gradio', run 'pip install -U gradio' to install gradio.""",
)


def mount_gradio_app(blocks: Blocks, path: str, name: str = "gradio_ui"):
    """Mount a Gradio app to a BentoML service.

    Args:
        blocks: The Gradio blocks to be mounted.
        path: The URL path to mount the Gradio app. Must start with ``/``.
        name: The name of the Gradio app.

    Example:
        Both of the following examples are allowed::

            @bentoml.gradio.mount_gradio_app(blocks)
            @bentoml.service()
            class MyService: ...

            @bentoml.service()
            @bentoml.gradio.mount_gradio_app(blocks)
            class MyService: ...
    """
    # Imported lazily to avoid a circular import at module load time.
    from _bentoml_impl import server
    from _bentoml_sdk.service import Service

    # Reuse the favicon shipped with the BentoML server's static assets.
    favicon_path = (
        Path(server.__file__).parent / "static_content" / "favicon-light-32x32.png"
    )
    assert path.startswith("/"), "Routed paths must start with '/'"
    # Normalize: drop any trailing slash so mount paths are consistent.
    path = path.rstrip("/")

    def decorator(obj: R) -> R:
        # Configure the blocks for embedded (non-dev) serving under `path`,
        # then build the ASGI app and mount it onto the decorated object.
        blocks.dev_mode = False
        blocks.show_error = True
        blocks.validate_queue_settings()
        blocks.root_path = path
        blocks.favicon_path = favicon_path
        gradio_app = gr.routes.App.create_app(blocks, app_kwargs={"root_path": path})
        mount_asgi_app(gradio_app, path=path, name=name)(obj)

        # @bentoml.service() decorator returns a wrapper instead of the original class
        # Check if the object is an instance of Service
        if isinstance(obj, Service):
            # For scenario:
            #
            #   @bentoml.gradio.mount_gradio_app(..)
            #   @bentoml.service(..)
            #   class MyService: ...
            #
            # If the Service instance is already created, mount the ASGI app immediately
            target = obj.inner
        else:
            # For scenario:
            #
            #   @bentoml.service(..)
            #   @bentoml.gradio.mount_gradio_app(..)
            #   class MyService: ...
            #
            # If the Service instance is not yet created, mark the Gradio app info
            # for later mounting during Service instance initialization
            target = obj

        # Store Gradio app information for ASGI app mounting and startup event callback
        # (consumed by Service.gradio_app_startup_hook).
        gradio_apps = getattr(target, "__bentoml_gradio_apps__", [])
        gradio_apps.append((gradio_app, path, name))
        setattr(target, "__bentoml_gradio_apps__", gradio_apps)

        return obj

    return decorator
15 changes: 15 additions & 0 deletions src/_bentoml_sdk/service/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,21 @@ def add_asgi_middleware(
) -> None:
self.middlewares.append((middleware_cls, options))

def gradio_app_startup_hook(self, max_concurrency: int) -> None:
    """Initialize any Gradio apps mounted on this service.

    For each app recorded by ``bentoml.gradio.mount_gradio_app``, configure
    the request queue with ``max_concurrency`` and fire the app's startup
    events, then remove the marker attribute so the apps are not
    initialized a second time.

    Args:
        max_concurrency: Default concurrency limit applied to each Gradio
            app's request queue.
    """
    gradio_apps = getattr(self.inner, "__bentoml_gradio_apps__", [])
    if not gradio_apps:
        return
    for gradio_app, path, _ in gradio_apps:
        # Lazy %-formatting: the message is only built if the log is emitted.
        logger.info("Initializing gradio app at: %s", path or "/")
        blocks = gradio_app.get_blocks()
        blocks.queue(default_concurrency_limit=max_concurrency)
        if hasattr(blocks, "startup_events"):
            # gradio < 5.0
            blocks.startup_events()
        else:
            # gradio >= 5.0
            blocks.run_startup_events()
    # Clear the marker so repeated calls are no-ops.
    delattr(self.inner, "__bentoml_gradio_apps__")

def __call__(self) -> T:
try:
instance = self.inner()
Expand Down
3 changes: 3 additions & 0 deletions src/bentoml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
pass

from . import diffusers_simple
from . import gradio
from . import ray
from . import triton
from ._internal.frameworks import detectron
Expand Down Expand Up @@ -183,6 +184,7 @@
# Integrations
triton = _LazyLoader("bentoml.triton", globals(), "bentoml.triton")
ray = _LazyLoader("bentoml.ray", globals(), "bentoml.ray")
gradio = _LazyLoader("bentoml.gradio", globals(), "bentoml.gradio")

io = _LazyLoader("bentoml.io", globals(), "bentoml.io")
batch = _LazyLoader("bentoml.batch", globals(), "bentoml.batch")
Expand Down Expand Up @@ -294,6 +296,7 @@ def __getattr__(name: str) -> Any:
"xgboost",
# integrations
"ray",
"gradio",
"cloud",
"deployment",
"triton",
Expand Down
3 changes: 3 additions & 0 deletions src/bentoml/gradio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Public entry point for BentoML's Gradio integration.

Re-exports ``mount_gradio_app`` from the internal SDK package so users can
write ``@bentoml.gradio.mount_gradio_app(...)``.
"""

from _bentoml_sdk.gradio import mount_gradio_app

__all__ = ["mount_gradio_app"]

0 comments on commit 68fc547

Please sign in to comment.