Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release 0.9.44 #1183

Merged
merged 2 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.43"
version = "0.9.44"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
3 changes: 2 additions & 1 deletion truss/remote/baseten/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import IO, List, Optional, Tuple

import truss
from truss.constants import PRODUCTION_ENVIRONMENT_NAME
from truss.remote.baseten import custom_types as b10_types
from truss.remote.baseten.api import BasetenApi
from truss.remote.baseten.error import ApiError
Expand Down Expand Up @@ -271,7 +272,7 @@ def create_truss_service(
return model_version_json["id"], model_version_json["version_id"]

if model_id is None:
if environment:
if environment and environment != PRODUCTION_ENVIRONMENT_NAME:
raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING)
model_version_json = api.create_model_from_truss(
model_name=model_name,
Expand Down
111 changes: 111 additions & 0 deletions truss/tests/remote/baseten/test_core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from tempfile import NamedTemporaryFile
from unittest.mock import MagicMock

import pytest
from truss.constants import PRODUCTION_ENVIRONMENT_NAME
from truss.remote.baseten import core
from truss.remote.baseten.api import BasetenApi
from truss.remote.baseten.core import create_truss_service
from truss.remote.baseten.error import ApiError


Expand Down Expand Up @@ -84,3 +87,111 @@ def test_get_prod_version_from_versions_error():
]
prod_version = core.get_prod_version_from_versions(versions)
assert prod_version is None


@pytest.mark.parametrize(
"environment",
[
None,
PRODUCTION_ENVIRONMENT_NAME,
],
)
def test_create_truss_service_handles_eligible_environment_values(environment):
api = MagicMock()
return_value = {
"id": "id",
"version_id": "model_version_id",
}
api.create_model_from_truss.return_value = return_value
model_id, model_version_id = create_truss_service(
api,
"model_name",
"s3_key",
"config",
is_trusted=False,
preserve_previous_prod_deployment=False,
is_draft=False,
model_id=None,
deployment_name="deployment_name",
environment=environment,
)
assert model_id == return_value["id"]
assert model_version_id == return_value["version_id"]
api.create_model_from_truss.assert_called_once()


@pytest.mark.parametrize(
"model_id",
[
"some_model_id",
None,
],
)
def test_create_truss_services_handles_is_draft(model_id):
api = MagicMock()
return_value = {
"id": "id",
"version_id": "model_version_id",
}
api.create_development_model_from_truss.return_value = return_value
model_id, model_version_id = create_truss_service(
api,
"model_name",
"s3_key",
"config",
is_trusted=False,
preserve_previous_prod_deployment=False,
is_draft=True,
model_id=model_id,
deployment_name="deployment_name",
)
assert model_id == return_value["id"]
assert model_version_id == return_value["version_id"]
api.create_development_model_from_truss.assert_called_once()


@pytest.mark.parametrize(
"inputs",
[
{
"environment": None,
"deployment_name": "some deployment",
"is_trusted": True,
"preserve_previous_prod_deployment": False,
},
{
"environment": PRODUCTION_ENVIRONMENT_NAME,
"deployment_name": None,
"is_trusted": True,
"preserve_previous_prod_deployment": False,
},
{
"environment": "staging",
"deployment_name": "some_deployment_name",
"is_trusted": False,
"preserve_previous_prod_deployment": True,
},
],
)
def test_create_truss_service_handles_existing_model(inputs):
api = MagicMock()
return_value = {
"id": "model_version_id",
}
api.create_model_version_from_truss.return_value = return_value
model_id, model_version_id = create_truss_service(
api,
"model_name",
"s3_key",
"config",
is_draft=False,
model_id="model_id",
**inputs,
)

assert model_id == "model_id"
assert model_version_id == return_value["id"]
api.create_model_version_from_truss.assert_called_once()
_, kwargs = api.create_model_version_from_truss.call_args
for k, v in inputs.items():
assert kwargs[k] == v
Loading