From 6e82f5e0a7abe2513accfce92458d38923d63a68 Mon Sep 17 00:00:00 2001 From: rcano-baseten Date: Thu, 10 Oct 2024 11:49:13 -0400 Subject: [PATCH 1/2] fix truss push --promote (#1181) * fix truss * pre commit hook * adds unit tests * ensure None environment is accepted --- truss/remote/baseten/core.py | 3 +- truss/tests/remote/baseten/test_core.py | 111 ++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 1 deletion(-) diff --git a/truss/remote/baseten/core.py b/truss/remote/baseten/core.py index a2f766586..e6df30b87 100644 --- a/truss/remote/baseten/core.py +++ b/truss/remote/baseten/core.py @@ -4,6 +4,7 @@ from typing import IO, List, Optional, Tuple import truss +from truss.constants import PRODUCTION_ENVIRONMENT_NAME from truss.remote.baseten import custom_types as b10_types from truss.remote.baseten.api import BasetenApi from truss.remote.baseten.error import ApiError @@ -271,7 +272,7 @@ def create_truss_service( return model_version_json["id"], model_version_json["version_id"] if model_id is None: - if environment: + if environment and environment != PRODUCTION_ENVIRONMENT_NAME: raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING) model_version_json = api.create_model_from_truss( model_name=model_name, diff --git a/truss/tests/remote/baseten/test_core.py b/truss/tests/remote/baseten/test_core.py index 21612f5de..5f3727ac4 100644 --- a/truss/tests/remote/baseten/test_core.py +++ b/truss/tests/remote/baseten/test_core.py @@ -1,8 +1,11 @@ from tempfile import NamedTemporaryFile from unittest.mock import MagicMock +import pytest +from truss.constants import PRODUCTION_ENVIRONMENT_NAME from truss.remote.baseten import core from truss.remote.baseten.api import BasetenApi +from truss.remote.baseten.core import create_truss_service from truss.remote.baseten.error import ApiError @@ -84,3 +87,111 @@ def test_get_prod_version_from_versions_error(): ] prod_version = core.get_prod_version_from_versions(versions) assert prod_version is None + + +@pytest.mark.parametrize( + "environment", + [ + None, + PRODUCTION_ENVIRONMENT_NAME, + ], +) +def test_create_truss_service_handles_eligible_environment_values(environment): + api = MagicMock() + return_value = { + "id": "id", + "version_id": "model_version_id", + } + api.create_model_from_truss.return_value = return_value + model_id, model_version_id = create_truss_service( + api, + "model_name", + "s3_key", + "config", + is_trusted=False, + preserve_previous_prod_deployment=False, + is_draft=False, + model_id=None, + deployment_name="deployment_name", + environment=environment, + ) + assert model_id == return_value["id"] + assert model_version_id == return_value["version_id"] + api.create_model_from_truss.assert_called_once() + + +@pytest.mark.parametrize( + "model_id", + [ + "some_model_id", + None, + ], +) +def test_create_truss_services_handles_is_draft(model_id): + api = MagicMock() + return_value = { + "id": "id", + "version_id": "model_version_id", + } + api.create_development_model_from_truss.return_value = return_value + model_id, model_version_id = create_truss_service( + api, + "model_name", + "s3_key", + "config", + is_trusted=False, + preserve_previous_prod_deployment=False, + is_draft=True, + model_id=model_id, + deployment_name="deployment_name", + ) + assert model_id == return_value["id"] + assert model_version_id == return_value["version_id"] + api.create_development_model_from_truss.assert_called_once() + + +@pytest.mark.parametrize( + "inputs", + [ + { + "environment": None, + "deployment_name": "some deployment", + "is_trusted": True, + "preserve_previous_prod_deployment": False, + }, + { + "environment": PRODUCTION_ENVIRONMENT_NAME, + "deployment_name": None, + "is_trusted": True, + "preserve_previous_prod_deployment": False, + }, + { + "environment": "staging", + "deployment_name": "some_deployment_name", + "is_trusted": False, + "preserve_previous_prod_deployment": True, + }, + ], +) +def test_create_truss_service_handles_existing_model(inputs): + api = MagicMock() + return_value = { + "id": "model_version_id", + } + api.create_model_version_from_truss.return_value = return_value + model_id, model_version_id = create_truss_service( + api, + "model_name", + "s3_key", + "config", + is_draft=False, + model_id="model_id", + **inputs, + ) + + assert model_id == "model_id" + assert model_version_id == return_value["id"] + api.create_model_version_from_truss.assert_called_once() + _, kwargs = api.create_model_version_from_truss.call_args + for k, v in inputs.items(): + assert kwargs[k] == v From c57f21e990c8b47316ff94cca68b29d74796afdb Mon Sep 17 00:00:00 2001 From: basetenbot <96544894+basetenbot@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:04:48 +0000 Subject: [PATCH 2/2] Bump version to 0.9.44 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bbd740e12..e6b31bd3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.43" +version = "0.9.44" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md"