From 026fb0a07ae121739407ba5d6aa305763e9d9cec Mon Sep 17 00:00:00 2001 From: Ken Foster Date: Thu, 7 Sep 2023 17:41:40 +0000 Subject: [PATCH] Update unit tests --- src/nebari_plugin_mlflow_aws/__about__.py | 2 +- src/nebari_plugin_mlflow_aws/__init__.py | 85 ++++++++++++----------- tests/unit/test_plugin.py | 70 ++++++++++--------- 3 files changed, 83 insertions(+), 74 deletions(-) diff --git a/src/nebari_plugin_mlflow_aws/__about__.py b/src/nebari_plugin_mlflow_aws/__about__.py index 3b93d0b..27fdca4 100644 --- a/src/nebari_plugin_mlflow_aws/__about__.py +++ b/src/nebari_plugin_mlflow_aws/__about__.py @@ -1 +1 @@ -__version__ = "0.0.2" +__version__ = "0.0.3" diff --git a/src/nebari_plugin_mlflow_aws/__init__.py b/src/nebari_plugin_mlflow_aws/__init__.py index f6a722a..7224bc1 100644 --- a/src/nebari_plugin_mlflow_aws/__init__.py +++ b/src/nebari_plugin_mlflow_aws/__init__.py @@ -33,10 +33,49 @@ class MlflowStage(NebariTerraformStage): def template_directory(self): return Path(inspect.getfile(self.__class__)).parent / "terraform" - def check(self, stage_outputs: Dict[str, Dict[str, Any]]) -> bool: + def _attempt_keycloak_connection( + keycloak_url, + username, + password, + master_realm_name, + client_id, + client_realm_name, + verify=False, + num_attempts=NUM_ATTEMPTS, + timeout=TIMEOUT, + ): from keycloak import KeycloakAdmin from keycloak.exceptions import KeycloakError - + + for i in range(num_attempts): + try: + realm_admin = KeycloakAdmin( + keycloak_url, + username=username, + password=password, + realm_name=master_realm_name, + client_id=client_id, + verify=verify, + ) + realm_admin.realm_name = client_realm_name # switch to nebari realm + c = realm_admin.get_client_id(CLIENT_NAME) # lookup client guid + existing_client = realm_admin.get_client(c) # query client info + if existing_client != None and existing_client["name"] == CLIENT_NAME: + print( + f"Attempt {i+1} succeeded connecting to keycloak and nebari client={CLIENT_NAME} exists" + ) + return True + else: + print( + f"Attempt {i+1} succeeded connecting to keycloak but nebari client={CLIENT_NAME} did not exist" + ) + except KeycloakError as e: + print(f"Attempt {i+1} failed connecting to keycloak {client_realm_name} realm -- {e}") + time.sleep(timeout) + return False + + def check(self, stage_outputs: Dict[str, Dict[str, Any]]) -> bool: + hello = "test" # TODO: Module requires EKS cluster is configured for IRSA. Need to confirm minimum Nebari version once this feature is part of a release. # TODO: Also should configure this module to require Nebari version in pyproject.toml? @@ -60,50 +99,12 @@ def check(self, stage_outputs: Dict[str, Dict[str, Any]]) -> bool: return False if not self.config.provider == ProviderEnum.aws: - raise KeyError("Plugin 'nebari_plugin_mlflow_aws' developed for 'aws' only. Detected provider is '{}'.".format(self.config.provider)) + raise KeyError("Plugin nebari_plugin_mlflow_aws developed for aws only. Detected provider is {}.".format(self.config.provider)) keycloak_config = self.get_keycloak_config(stage_outputs) - def _attempt_keycloak_connection( - keycloak_url, - username, - password, - master_realm_name, - client_id, - client_realm_name, - verify=False, - num_attempts=NUM_ATTEMPTS, - timeout=TIMEOUT, - ): - for i in range(num_attempts): - try: - realm_admin = KeycloakAdmin( - keycloak_url, - username=username, - password=password, - realm_name=master_realm_name, - client_id=client_id, - verify=verify, - ) - realm_admin.realm_name = client_realm_name # switch to nebari realm - c = realm_admin.get_client_id(CLIENT_NAME) # lookup client guid - existing_client = realm_admin.get_client(c) # query client info - if existing_client != None and existing_client["name"] == CLIENT_NAME: - print( - f"Attempt {i+1} succeeded connecting to keycloak and nebari client={CLIENT_NAME} exists" - ) - return True - else: - print( - f"Attempt {i+1} succeeded connecting to keycloak but nebari client={CLIENT_NAME} did not exist" - ) - except KeycloakError as e: - print(f"Attempt {i+1} failed connecting to keycloak {client_realm_name} realm -- {e}") - time.sleep(timeout) - return False - - if not _attempt_keycloak_connection( + if not self._attempt_keycloak_connection( keycloak_config["keycloak_url"], keycloak_config["username"], keycloak_config["password"], diff --git a/tests/unit/test_plugin.py b/tests/unit/test_plugin.py index 1fbd45a..8854428 100644 --- a/tests/unit/test_plugin.py +++ b/tests/unit/test_plugin.py @@ -7,10 +7,11 @@ class TestConfig(InputSchema): domain: str escaped_project_name: str = "" provider: str + mlflow: MlflowConfig = MlflowConfig() @pytest.fixture(autouse=True) def mock_keycloak_connection(monkeypatch): - monkeypatch.setattr("nebari_plugin_mlflow_aws.MlflowStage.check._attempt_keycloak_connection", lambda: True) + monkeypatch.setattr("nebari_plugin_mlflow_aws.MlflowStage._attempt_keycloak_connection", lambda *args, **kwargs: True) def test_ctor(): sut = MlflowStage(output_directory = None, config = None) @@ -23,7 +24,7 @@ def test_input_vars(): stage_outputs = get_stage_outputs() - #sut.check(stage_outputs) + sut.check(stage_outputs) result = sut.input_vars(stage_outputs) assert result["chart_name"] == "mlflow" assert result["project_name"] == "testprojectname" @@ -39,35 +40,42 @@ def test_input_vars(): assert result["cluster_oidc_issuer_url"] == "https://test-oidc-url.com" assert result["overrides"] == {} -#def test_incompatible_cloud(): -# TODO -# -#def test_default_namespace(): -# config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com") -# sut = MlflowStage(output_directory = None, config = config) -# -# stage_outputs = get_stage_outputs() -# result = sut.input_vars(stage_outputs) -# assert result["create_namespace"] == False -# assert result["namespace"] == "nebari-ns" -# -#def test_chart_namespace(): -# config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com", label_studio = MlflowStage(namespace = "label-studio-ns")) -# sut = MlflowStage(output_directory = None, config = config) -# -# stage_outputs = get_stage_outputs() -# result = sut.input_vars(stage_outputs) -# assert result["create_namespace"] == True -# assert result["namespace"] == "label-studio-ns" -# -#def test_chart_overrides(): -# config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com", label_studio = MlflowStage(values = { "foo": "bar" })) -# sut = MlflowStage(output_directory = None, config = config) -# -# stage_outputs = get_stage_outputs() -# result = sut.input_vars(stage_outputs) -# assert result["overrides"] == { "foo": "bar" } -# +def test_incompatible_cloud(): + with pytest.raises(KeyError) as e_info: + config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com", escaped_project_name="testprojectname", provider="gcp") + sut = MlflowStage(output_directory = None, config = config) + + stage_outputs = get_stage_outputs() + sut.check(stage_outputs) + + assert str(e_info.value) == "'Plugin nebari_plugin_mlflow_aws developed for aws only. Detected provider is gcp.'" + +def test_default_namespace(): + config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com", provider="aws") + sut = MlflowStage(output_directory = None, config = config) + + stage_outputs = get_stage_outputs() + result = sut.input_vars(stage_outputs) + assert result["create_namespace"] == False + assert result["namespace"] == "nebari-ns" + +def test_chart_namespace(): + config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com", provider="aws", mlflow = MlflowConfig(namespace = "mlflow-ns")) + sut = MlflowStage(output_directory = None, config = config) + + stage_outputs = get_stage_outputs() + result = sut.input_vars(stage_outputs) + assert result["create_namespace"] == True + assert result["namespace"] == "mlflow-ns" + +def test_chart_overrides(): + config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com", provider="aws", mlflow = MlflowConfig(values = { "foo": "bar" })) + sut = MlflowStage(output_directory = None, config = config) + + stage_outputs = get_stage_outputs() + result = sut.input_vars(stage_outputs) + assert result["overrides"] == { "foo": "bar" } + def get_stage_outputs(): return { "stages/02-infrastructure": {