Skip to content

Commit

Permalink
Update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kenafoster committed Sep 7, 2023
1 parent c91ca70 commit 026fb0a
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 74 deletions.
2 changes: 1 addition & 1 deletion src/nebari_plugin_mlflow_aws/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.2"
__version__ = "0.0.3"
85 changes: 43 additions & 42 deletions src/nebari_plugin_mlflow_aws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,49 @@ class MlflowStage(NebariTerraformStage):
def template_directory(self):
return Path(inspect.getfile(self.__class__)).parent / "terraform"

def check(self, stage_outputs: Dict[str, Dict[str, Any]]) -> bool:
def _attempt_keycloak_connection(
keycloak_url,
username,
password,
master_realm_name,
client_id,
client_realm_name,
verify=False,
num_attempts=NUM_ATTEMPTS,
timeout=TIMEOUT,
):
from keycloak import KeycloakAdmin
from keycloak.exceptions import KeycloakError


for i in range(num_attempts):
try:
realm_admin = KeycloakAdmin(
keycloak_url,
username=username,
password=password,
realm_name=master_realm_name,
client_id=client_id,
verify=verify,
)
realm_admin.realm_name = client_realm_name # switch to nebari realm
c = realm_admin.get_client_id(CLIENT_NAME) # lookup client guid
existing_client = realm_admin.get_client(c) # query client info
if existing_client != None and existing_client["name"] == CLIENT_NAME:
print(
f"Attempt {i+1} succeeded connecting to keycloak and nebari client={CLIENT_NAME} exists"
)
return True
else:
print(
f"Attempt {i+1} succeeded connecting to keycloak but nebari client={CLIENT_NAME} did not exist"
)
except KeycloakError as e:
print(f"Attempt {i+1} failed connecting to keycloak {client_realm_name} realm -- {e}")
time.sleep(timeout)
return False

def check(self, stage_outputs: Dict[str, Dict[str, Any]]) -> bool:

hello = "test"
# TODO: Module requires EKS cluster is configured for IRSA. Need to confirm minimum Nebari version once this feature is part of a release.
# TODO: Also should configure this module to require Nebari version in pyproject.toml?
Expand All @@ -60,50 +99,12 @@ def check(self, stage_outputs: Dict[str, Dict[str, Any]]) -> bool:
return False

if not self.config.provider == ProviderEnum.aws:
raise KeyError("Plugin 'nebari_plugin_mlflow_aws' developed for 'aws' only. Detected provider is '{}'.".format(self.config.provider))
raise KeyError("Plugin nebari_plugin_mlflow_aws developed for aws only. Detected provider is {}.".format(self.config.provider))


keycloak_config = self.get_keycloak_config(stage_outputs)

def _attempt_keycloak_connection(
keycloak_url,
username,
password,
master_realm_name,
client_id,
client_realm_name,
verify=False,
num_attempts=NUM_ATTEMPTS,
timeout=TIMEOUT,
):
for i in range(num_attempts):
try:
realm_admin = KeycloakAdmin(
keycloak_url,
username=username,
password=password,
realm_name=master_realm_name,
client_id=client_id,
verify=verify,
)
realm_admin.realm_name = client_realm_name # switch to nebari realm
c = realm_admin.get_client_id(CLIENT_NAME) # lookup client guid
existing_client = realm_admin.get_client(c) # query client info
if existing_client != None and existing_client["name"] == CLIENT_NAME:
print(
f"Attempt {i+1} succeeded connecting to keycloak and nebari client={CLIENT_NAME} exists"
)
return True
else:
print(
f"Attempt {i+1} succeeded connecting to keycloak but nebari client={CLIENT_NAME} did not exist"
)
except KeycloakError as e:
print(f"Attempt {i+1} failed connecting to keycloak {client_realm_name} realm -- {e}")
time.sleep(timeout)
return False

if not _attempt_keycloak_connection(
if not self._attempt_keycloak_connection(
keycloak_config["keycloak_url"],
keycloak_config["username"],
keycloak_config["password"],
Expand Down
70 changes: 39 additions & 31 deletions tests/unit/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ class TestConfig(InputSchema):
domain: str
escaped_project_name: str = ""
provider: str
mlflow: MlflowConfig = MlflowConfig()

@pytest.fixture(autouse=True)
def mock_keycloak_connection(monkeypatch):
monkeypatch.setattr("nebari_plugin_mlflow_aws.MlflowStage.check._attempt_keycloak_connection", lambda: True)
monkeypatch.setattr("nebari_plugin_mlflow_aws.MlflowStage._attempt_keycloak_connection", lambda *args, **kwargs: True)

def test_ctor():
sut = MlflowStage(output_directory = None, config = None)
Expand All @@ -23,7 +24,7 @@ def test_input_vars():


stage_outputs = get_stage_outputs()
#sut.check(stage_outputs)
sut.check(stage_outputs)
result = sut.input_vars(stage_outputs)
assert result["chart_name"] == "mlflow"
assert result["project_name"] == "testprojectname"
Expand All @@ -39,35 +40,42 @@ def test_input_vars():
assert result["cluster_oidc_issuer_url"] == "https://test-oidc-url.com"
assert result["overrides"] == {}

#def test_incompatible_cloud():
# TODO
#
#def test_default_namespace():
# config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com")
# sut = MlflowStage(output_directory = None, config = config)
#
# stage_outputs = get_stage_outputs()
# result = sut.input_vars(stage_outputs)
# assert result["create_namespace"] == False
# assert result["namespace"] == "nebari-ns"
#
#def test_chart_namespace():
# config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com", label_studio = MlflowStage(namespace = "label-studio-ns"))
# sut = MlflowStage(output_directory = None, config = config)
#
# stage_outputs = get_stage_outputs()
# result = sut.input_vars(stage_outputs)
# assert result["create_namespace"] == True
# assert result["namespace"] == "label-studio-ns"
#
#def test_chart_overrides():
# config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com", label_studio = MlflowStage(values = { "foo": "bar" }))
# sut = MlflowStage(output_directory = None, config = config)
#
# stage_outputs = get_stage_outputs()
# result = sut.input_vars(stage_outputs)
# assert result["overrides"] == { "foo": "bar" }
#
def test_incompatible_cloud():
with pytest.raises(KeyError) as e_info:
config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com", escaped_project_name="testprojectname", provider="gcp")
sut = MlflowStage(output_directory = None, config = config)

stage_outputs = get_stage_outputs()
sut.check(stage_outputs)

assert str(e_info.value) == "'Plugin nebari_plugin_mlflow_aws developed for aws only. Detected provider is gcp.'"

def test_default_namespace():
config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com", provider="aws")
sut = MlflowStage(output_directory = None, config = config)

stage_outputs = get_stage_outputs()
result = sut.input_vars(stage_outputs)
assert result["create_namespace"] == False
assert result["namespace"] == "nebari-ns"

def test_chart_namespace():
config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com", provider="aws", mlflow = MlflowConfig(namespace = "mlflow-ns"))
sut = MlflowStage(output_directory = None, config = config)

stage_outputs = get_stage_outputs()
result = sut.input_vars(stage_outputs)
assert result["create_namespace"] == True
assert result["namespace"] == "mlflow-ns"

def test_chart_overrides():
config = TestConfig(namespace = "nebari-ns", domain = "my-test-domain.com", provider="aws", mlflow = MlflowConfig(values = { "foo": "bar" }))
sut = MlflowStage(output_directory = None, config = config)

stage_outputs = get_stage_outputs()
result = sut.input_vars(stage_outputs)
assert result["overrides"] == { "foo": "bar" }

def get_stage_outputs():
return {
"stages/02-infrastructure": {
Expand Down

0 comments on commit 026fb0a

Please sign in to comment.