Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add setup_environment to truss #1188

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
495 changes: 275 additions & 220 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.44"
version = "0.9.45rc008"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand All @@ -27,6 +27,7 @@ packages = [
"Baseten" = "https://baseten.co"

[tool.poetry.dependencies]
aiofiles = "^24.1.0"
blake3 = "^0.3.3"
boto3 = "^1.34.85"
fastapi = ">=0.109.1"
Expand Down Expand Up @@ -96,6 +97,7 @@ pytest = "7.2.0"
pytest-cov = "^3.0.0"
types-PyYAML = "^6.0.12.12"
types-setuptools = "^69.0.0.0"
types-aiofiles = "^24.1.0.20240626"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use a less specific version here? (maybe "^24.1.0")


[tool.poetry.scripts]
truss = 'truss.cli:truss_cli'
Expand Down
2 changes: 1 addition & 1 deletion truss-chains/truss_chains/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def override_chainlet_to_service_metadata(
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
):
# Override predict_urls in chainlet_to_service ServiceDescriptors if dynamic_chainlet_config exists
dynamic_chainlet_config_str = dynamic_config_resolver.get_dynamic_config_value(
dynamic_chainlet_config_str = dynamic_config_resolver.get_dynamic_config_value_sync(
definitions.DYNAMIC_CHAINLET_CONFIG_KEY
)
if dynamic_chainlet_config_str:
Expand Down
12 changes: 12 additions & 0 deletions truss/local/local_config_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ def bptr_data_resolution_dir_path():
bptr_data_dir.mkdir(exist_ok=True, parents=True)
return bptr_data_dir

@staticmethod
def dynamic_config_path():
dynamic_config_dir = LocalConfigHandler.TRUSS_CONFIG_DIR / "b10_dynamic_config"
dynamic_config_dir.mkdir(exist_ok=True, parents=True)
return dynamic_config_dir

@staticmethod
def set_dynamic_config(key: str, value: str):
key_path = LocalConfigHandler.dynamic_config_path() / key
with key_path.open("w") as key_file:
key_file.write(value)

@staticmethod
def _signatures_dir_path():
return LocalConfigHandler.TRUSS_CONFIG_DIR / "signatures"
Expand Down
85 changes: 84 additions & 1 deletion truss/templates/server/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import importlib
import importlib.util
import inspect
import json
import logging
import os
import pathlib
Expand Down Expand Up @@ -35,7 +36,7 @@
from common.schema import TrussSchema
from opentelemetry import trace
from pydantic import BaseModel
from shared import serialization
from shared import dynamic_config_resolver, serialization
from shared.lazy_data_resolver import LazyDataResolver
from shared.secrets_resolver import SecretsResolver

Expand All @@ -53,6 +54,7 @@
EXTENSION_CLASS_NAME = "Extension"
EXTENSION_FILE_NAME = "extension"
TRT_LLM_EXTENSION_NAME = "trt_llm"
POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS = 30


@asynccontextmanager
Expand Down Expand Up @@ -191,6 +193,7 @@ class ModelDescriptor:
predict: MethodDescriptor
postprocess: Optional[MethodDescriptor]
truss_schema: Optional[TrussSchema]
setup_environment: Optional[MethodDescriptor]

@cached_property
def skip_input_parsing(self) -> bool:
Expand Down Expand Up @@ -243,11 +246,19 @@ def from_model(cls, model) -> "ModelDescriptor":
else:
return_annotation = inspect.signature(model.predict).return_annotation

if hasattr(model, "setup_environment"):
setup_environment = MethodDescriptor.from_method(
model.setup_environment, "setup_environment"
)
else:
setup_environment = None

return cls(
preprocess=preprocess,
predict=predict,
postprocess=postprocess,
truss_schema=TrussSchema.from_signature(parameters, return_annotation),
setup_environment=setup_environment,
)


Expand All @@ -259,6 +270,8 @@ class ModelWrapper:
_logger: logging.Logger
_status: "ModelWrapper.Status"
_predict_semaphore: Semaphore
_poll_for_environment_updates_task: Optional[asyncio.Task]
_environment: Optional[dict]

class Status(Enum):
NOT_READY = 0
Expand All @@ -280,6 +293,8 @@ def __init__(self, config: Dict, tracer: sdk_trace.Tracer):
"predict_concurrency", DEFAULT_PREDICT_CONCURRENCY
)
)
self._poll_for_environment_updates_task = None
self._environment = None

@property
def _model(self) -> Any:
Expand Down Expand Up @@ -419,6 +434,9 @@ def _load_impl(self):

self._maybe_model_descriptor = ModelDescriptor.from_model(self._model)

if self._maybe_model_descriptor.setup_environment:
self._initialize_environment_before_load()

if hasattr(self._model, "load"):
retry(
spal1 marked this conversation as resolved.
Show resolved Hide resolved
self._model.load,
Expand All @@ -428,6 +446,71 @@ def _load_impl(self):
gap_seconds=1.0,
)

def setup_polling_for_environment_updates(self):
self._poll_for_environment_updates_task = asyncio.create_task(
self.poll_for_environment_updates()
)

def _initialize_environment_before_load(self):
environment_str = dynamic_config_resolver.get_dynamic_config_value_sync(
dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
)
if environment_str:
environment_json = json.loads(environment_str)
self._model.setup_environment(environment_json)
self._environment = environment_json

async def setup_environment(self, environment: Optional[dict]):
descriptor = self.model_descriptor.setup_environment
if not descriptor:
return
self._logger.info(
f"Executing model.setup_environment with new environment: {environment}"
spal1 marked this conversation as resolved.
Show resolved Hide resolved
)
if descriptor.is_async:
return await self._model.setup_environment(environment)
else:
return await to_thread.run_sync(self._model.setup_environment, environment)

async def poll_for_environment_updates(self) -> None:
last_modified_time = None
environment_config_filename = (
dynamic_config_resolver.get_dynamic_config_file_path(
dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
)
)

while True:
# Give control back to the event loop while waiting for environment updates
await asyncio.sleep(POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS)

# Wait for load to finish before checking for environment updates
if not self.ready:
continue

# Skip polling if no setup_environment implementation provided
if not self.model_descriptor.setup_environment:
break

if environment_config_filename.exists():
try:
current_mtime = os.path.getmtime(environment_config_filename)
if not last_modified_time or last_modified_time != current_mtime:
environment_str = await dynamic_config_resolver.get_dynamic_config_value_async(
dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
)
if environment_str:
last_modified_time = current_mtime
environment_json = json.loads(environment_str)
# Avoid rerunning `setup_environment` with the same environment
if self._environment != environment_json:
await self.setup_environment(environment_json)
self._environment = environment_json
except Exception as e:
logging.error(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it shows up nicer if you do:

logging.error("blah blah", exc_info=e)

f"An error occurred while polling for environment updates: {e}"
)

async def preprocess(
self,
inputs: serialization.InputType,
Expand Down
1 change: 1 addition & 0 deletions truss/templates/server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ pyyaml==6.0.0
requests==2.31.0
uvicorn==0.24.0
uvloop==0.19.0
aiofiles==24.1.0
1 change: 1 addition & 0 deletions truss/templates/server/truss_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def on_startup(self):
if self._setup_json_logger:
setup_logging()
self._model.start_load_thread()
self._model.setup_polling_for_environment_updates()

def create_application(self):
app = FastAPI(
Expand Down
23 changes: 19 additions & 4 deletions truss/templates/shared/dynamic_config_resolver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
from pathlib import Path
from typing import Optional

import aiofiles

DYNAMIC_CONFIG_MOUNT_DIR = "/etc/b10_dynamic_config"
ENVIRONMENT_DYNAMIC_CONFIG_KEY = "environment"


def get_dynamic_config_value(key: str) -> Optional[str]:
def get_dynamic_config_value_sync(key: str) -> Optional[str]:
dynamic_config_path = Path(DYNAMIC_CONFIG_MOUNT_DIR) / key
if dynamic_config_path.exists() and dynamic_config_path.is_file():
if dynamic_config_path.exists():
with dynamic_config_path.open() as dynamic_config_file:
dynamic_config_value = dynamic_config_file.read()
return dynamic_config_value
return dynamic_config_file.read()
return None


def get_dynamic_config_file_path(key: str):
dynamic_config_path = Path(DYNAMIC_CONFIG_MOUNT_DIR) / key
return dynamic_config_path


async def get_dynamic_config_value_async(key: str) -> Optional[str]:
dynamic_config_path = get_dynamic_config_file_path(key)
if dynamic_config_path.exists():
async with aiofiles.open(dynamic_config_path, "r") as dynamic_config_file:
return await dynamic_config_file.read()
return None
2 changes: 1 addition & 1 deletion truss/test_data/model_load_failure_test/config.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
model_name: Test Loaf Failure
model_name: Test Load Failure
python_version: py39
102 changes: 97 additions & 5 deletions truss/tests/templates/core/server/test_dynamic_config_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json

import aiofiles
import pytest
from truss.templates.shared.dynamic_config_resolver import get_dynamic_config_value
from truss.templates.shared import dynamic_config_resolver

from truss_chains import definitions

Expand All @@ -18,17 +19,108 @@
"",
],
)
def test_get_dynamic_config_value(config, tmp_path, dynamic_config_mount_dir):
def test_get_dynamic_chainlet_config_value_sync(
config, tmp_path, dynamic_config_mount_dir
):
with (tmp_path / definitions.DYNAMIC_CHAINLET_CONFIG_KEY).open("w") as f:
f.write(json.dumps(config))
chainlet_service_config = get_dynamic_config_value(
chainlet_service_config = dynamic_config_resolver.get_dynamic_config_value_sync(
definitions.DYNAMIC_CHAINLET_CONFIG_KEY
)
assert json.loads(chainlet_service_config) == config


def test_get_missing_config_value(dynamic_config_mount_dir):
chainlet_service_config = get_dynamic_config_value(
@pytest.mark.parametrize(
"config",
[
{
"environment_name": "production",
"foo": "bar",
},
{},
"",
None,
],
)
def test_get_dynamic_config_environment_value_sync(
config, tmp_path, dynamic_config_mount_dir
):
with (tmp_path / dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY).open(
"w"
) as f:
f.write(json.dumps(config))
environment_str = dynamic_config_resolver.get_dynamic_config_value_sync(
dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
)
assert json.loads(environment_str) == config


def test_get_missing_config_value_sync(dynamic_config_mount_dir):
chainlet_service_config = dynamic_config_resolver.get_dynamic_config_value_sync(
definitions.DYNAMIC_CHAINLET_CONFIG_KEY
)
assert not chainlet_service_config


@pytest.mark.asyncio
@pytest.mark.parametrize(
"config",
[
{
"RandInt": {
"predict_url": "https://model-id.api.baseten.co/deployment/deployment-id/predict"
}
},
{},
"",
],
)
async def test_get_dynamic_chainlet_config_value_async(
config, tmp_path, dynamic_config_mount_dir
):
async with aiofiles.open(
tmp_path / definitions.DYNAMIC_CHAINLET_CONFIG_KEY, "w"
) as f:
await f.write(json.dumps(config))
chainlet_service_config = (
await dynamic_config_resolver.get_dynamic_config_value_async(
definitions.DYNAMIC_CHAINLET_CONFIG_KEY
)
)
assert json.loads(chainlet_service_config) == config


@pytest.mark.asyncio
@pytest.mark.parametrize(
"config",
[
{
"environment_name": "production",
"foo": "bar",
},
{},
"",
None,
],
)
async def test_get_dynamic_config_environment_value_async(
config, tmp_path, dynamic_config_mount_dir
):
async with aiofiles.open(
tmp_path / dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY, "w"
) as f:
await f.write(json.dumps(config))
environment_str = await dynamic_config_resolver.get_dynamic_config_value_async(
dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
)
assert json.loads(environment_str) == config


@pytest.mark.asyncio
async def test_get_missing_config_value_async(dynamic_config_mount_dir):
chainlet_service_config = (
await dynamic_config_resolver.get_dynamic_config_value_async(
definitions.DYNAMIC_CHAINLET_CONFIG_KEY
)
)
assert not chainlet_service_config
Loading
Loading