-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add setup_environment
to truss
#1188
base: main
Are you sure you want to change the base?
Changes from 19 commits
114c0f9
e3d97d1
b2ea022
8a6715d
3c0d4cb
97fd7a8
89ffea5
e54e739
84d5339
ebea840
b929f6e
8898caf
5ae6f34
e0b2c31
16a6639
e30b5c4
336fd95
3bae49a
23aa0e0
9e0c6ec
235bd53
c6013e8
df9d61b
fd5d479
1d0e9cb
7e283e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
import importlib | ||
import importlib.util | ||
import inspect | ||
import json | ||
import logging | ||
import os | ||
import pathlib | ||
|
@@ -35,7 +36,7 @@ | |
from common.schema import TrussSchema | ||
from opentelemetry import trace | ||
from pydantic import BaseModel | ||
from shared import serialization | ||
from shared import dynamic_config_resolver, serialization | ||
from shared.lazy_data_resolver import LazyDataResolver | ||
from shared.secrets_resolver import SecretsResolver | ||
|
||
|
@@ -53,6 +54,7 @@ | |
EXTENSION_CLASS_NAME = "Extension" | ||
EXTENSION_FILE_NAME = "extension" | ||
TRT_LLM_EXTENSION_NAME = "trt_llm" | ||
POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS = 30 | ||
|
||
|
||
@asynccontextmanager | ||
|
@@ -191,6 +193,7 @@ class ModelDescriptor: | |
predict: MethodDescriptor | ||
postprocess: Optional[MethodDescriptor] | ||
truss_schema: Optional[TrussSchema] | ||
setup_environment: Optional[MethodDescriptor] | ||
|
||
@cached_property | ||
def skip_input_parsing(self) -> bool: | ||
|
@@ -243,11 +246,19 @@ def from_model(cls, model) -> "ModelDescriptor": | |
else: | ||
return_annotation = inspect.signature(model.predict).return_annotation | ||
|
||
if hasattr(model, "setup_environment"): | ||
setup_environment = MethodDescriptor.from_method( | ||
model.setup_environment, "setup_environment" | ||
) | ||
else: | ||
setup_environment = None | ||
|
||
return cls( | ||
preprocess=preprocess, | ||
predict=predict, | ||
postprocess=postprocess, | ||
truss_schema=TrussSchema.from_signature(parameters, return_annotation), | ||
setup_environment=setup_environment, | ||
) | ||
|
||
|
||
|
@@ -259,6 +270,8 @@ class ModelWrapper: | |
_logger: logging.Logger | ||
_status: "ModelWrapper.Status" | ||
_predict_semaphore: Semaphore | ||
_poll_for_environment_updates_task: Optional[Any] | ||
spal1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_environment: Optional[dict] | ||
|
||
class Status(Enum): | ||
NOT_READY = 0 | ||
|
@@ -280,6 +293,8 @@ def __init__(self, config: Dict, tracer: sdk_trace.Tracer): | |
"predict_concurrency", DEFAULT_PREDICT_CONCURRENCY | ||
) | ||
) | ||
self._poll_for_environment_updates_task = None | ||
self._environment = None | ||
|
||
@property | ||
def _model(self) -> Any: | ||
|
@@ -419,6 +434,15 @@ def _load_impl(self): | |
|
||
self._maybe_model_descriptor = ModelDescriptor.from_model(self._model) | ||
|
||
if hasattr(self._model, "setup_environment"): | ||
spal1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
environment_str = dynamic_config_resolver.get_dynamic_config_value_sync( | ||
spal1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY | ||
) | ||
if environment_str: | ||
environment_json = json.loads(environment_str) | ||
self._model.setup_environment(environment_json) | ||
self._environment = environment_json | ||
|
||
if hasattr(self._model, "load"): | ||
retry( | ||
spal1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._model.load, | ||
|
@@ -428,6 +452,62 @@ def _load_impl(self): | |
gap_seconds=1.0, | ||
) | ||
|
||
def setup_polling_for_environment_updates(self): | ||
self._poll_for_environment_updates_task = asyncio.create_task( | ||
self.poll_for_environment_updates() | ||
) | ||
|
||
async def setup_environment(self, environment: dict): | ||
descriptor = self.model_descriptor.setup_environment | ||
if not descriptor: | ||
return | ||
self._logger.info( | ||
f"Executing model.setup_environment with new environment: {environment}" | ||
spal1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
if descriptor.is_async: | ||
return await self._model.setup_environment(environment) | ||
else: | ||
return await to_thread.run_sync(self._model.setup_environment, environment) | ||
|
||
async def poll_for_environment_updates(self) -> None: | ||
last_modified_time = None | ||
environment_config_filename = ( | ||
await dynamic_config_resolver.get_dynamic_config_file_path_async( | ||
dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY | ||
) | ||
) | ||
|
||
while True: | ||
# Wait for load to finish before checking for environment updates | ||
if not self.ready: | ||
await asyncio.sleep(POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS) | ||
spal1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
continue | ||
|
||
# Skip polling if no setup_environment implementation provided | ||
if not self.model_descriptor.setup_environment: | ||
self._logger.info("No model.setup_environment definition provided") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need to log here, if the user didn't provide a setup_environment method, I don't think this is the right way to communicate this to them. On the other hand, if they did provide a method, it might be good to log something indicating that we will poll for changes to the env. |
||
break | ||
|
||
if environment_config_filename.exists(): | ||
try: | ||
current_mtime = os.path.getmtime(environment_config_filename) | ||
if not last_modified_time or last_modified_time != current_mtime: | ||
environment_str = await dynamic_config_resolver.get_dynamic_config_value_async( | ||
dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY | ||
) | ||
if environment_str: | ||
last_modified_time = current_mtime | ||
environment_json = json.loads(environment_str) | ||
# Avoid rerunning `setup_environment` with the same environment | ||
if self._environment != environment_json: | ||
await self.setup_environment(environment_json) | ||
self._environment = environment_json | ||
except Exception as e: | ||
logging.error( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it shows up nicer if you do:
|
||
f"An error occurred while polling for environment updates: {e}" | ||
) | ||
await asyncio.sleep(POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS) | ||
|
||
async def preprocess( | ||
self, | ||
inputs: serialization.InputType, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,3 +17,4 @@ pyyaml==6.0.0 | |
requests==2.31.0 | ||
uvicorn==0.24.0 | ||
uvloop==0.19.0 | ||
aiofiles==24.1.0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,28 @@ | ||
from pathlib import Path | ||
from typing import Optional | ||
|
||
import aiofiles | ||
|
||
DYNAMIC_CONFIG_MOUNT_DIR = "/etc/b10_dynamic_config" | ||
ENVIRONMENT_DYNAMIC_CONFIG_KEY = "environment" | ||
|
||
|
||
def get_dynamic_config_value(key: str) -> Optional[str]: | ||
def get_dynamic_config_value_sync(key: str) -> Optional[str]: | ||
dynamic_config_path = Path(DYNAMIC_CONFIG_MOUNT_DIR) / key | ||
if dynamic_config_path.exists() and dynamic_config_path.is_file(): | ||
if dynamic_config_path.exists(): | ||
with dynamic_config_path.open() as dynamic_config_file: | ||
dynamic_config_value = dynamic_config_file.read() | ||
return dynamic_config_value | ||
return dynamic_config_file.read() | ||
return None | ||
|
||
|
||
async def get_dynamic_config_file_path_async(key: str): | ||
spal1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dynamic_config_path = Path(DYNAMIC_CONFIG_MOUNT_DIR) / key | ||
return dynamic_config_path | ||
|
||
|
||
async def get_dynamic_config_value_async(key: str) -> Optional[str]: | ||
dynamic_config_path = await get_dynamic_config_file_path_async(key) | ||
if dynamic_config_path.exists(): | ||
async with aiofiles.open(dynamic_config_path, "r") as dynamic_config_file: | ||
return await dynamic_config_file.read() | ||
return None |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
model_name: Test Loaf Failure | ||
model_name: Test Load Failure | ||
python_version: py39 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we use a less specific version here? (maybe "^24.1.0")