Skip to content

Commit

Permalink
change nebari config prefix and include name and overrides. remove un…
Browse files Browse the repository at this point in the history
…used OIDC var. prefix names for SA and app
  • Loading branch information
kenafoster committed Aug 30, 2023
1 parent 4e50f4b commit cbea668
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 35 deletions.
24 changes: 8 additions & 16 deletions src/nebari_plugin_mlflow_aws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
#TODO this only works for AWS. How to check

class MlflowConfig(Base):
name: Optional[str] = "mlflow"
namespace: Optional[str] = None
values: Optional[Dict[str, Any]] = {}

class InputSchema(Base):
ml_flow: MlflowConfig = MlflowConfig()
mlflow: MlflowConfig = MlflowConfig()

class MlflowStage(NebariTerraformStage):
name = "mlflow"
Expand All @@ -35,24 +37,14 @@ def check(self, stage_outputs: Dict[str, Dict[str, Any]]) -> bool:
from keycloak import KeycloakAdmin
from keycloak.exceptions import KeycloakError

try:
_ = stage_outputs["stages/02-infrastructure"]["node_group_iam_policy_name"]

except KeyError:
print(
"\nPrerequisite stage output(s) not found: stages/02-infrastructure"
)
return False

# TODO: Module requires EKS cluster is configured for IRSA. Need to confirm minimum Nebari version once this feature is part of a release.
# TODO: Also should configure this module to require Nebari version in pyproject.toml?
try:
_ = stage_outputs["stages/02-infrastructure"]["cluster_oidc_issuer_url"]["value"]
_ = stage_outputs["stages/02-infrastructure"]["oidc_provider_arn"]["value"]

except KeyError:
print(
"\nPrerequisite stage output(s) not found in stages/02-infrastructure: cluster_oidc_issuer_url, oidc_provider_arn. Please ensure Nebari version is at least XX."
"\nPrerequisite stage output(s) not found in stages/02-infrastructure: cluster_oidc_issuer_url. Please ensure Nebari version is at least XX."
)
return False

Expand Down Expand Up @@ -135,19 +127,19 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
try:
domain = stage_outputs["stages/04-kubernetes-ingress"]["domain"]
cluster_oidc_issuer_url = stage_outputs["stages/02-infrastructure"]["cluster_oidc_issuer_url"]["value"]
oidc_provider_arn = stage_outputs["stages/02-infrastructure"]["oidc_provider_arn"]["value"]

except KeyError:
raise Exception("Prerequisite stage output(s) not found: stages/04-kubernetes-ingress")

chart_ns = self.config.ml_flow.namespace
chart_ns = self.config.mlflow.namespace
create_ns = True
if chart_ns == None or chart_ns == "" or chart_ns == self.config.namespace:
chart_ns = self.config.namespace
create_ns = False

return {
"name": self.config.escaped_project_name,
"chart_name": self.config.mlflow.name,
"project_name": self.config.escaped_project_name,
"realm_id": keycloak_config["realm_id"],
"client_id": CLIENT_NAME,
"base_url": f"https://{keycloak_config['domain']}/mlflow",
Expand All @@ -162,7 +154,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
"namespace": chart_ns,
"ingress_host": domain,
"cluster_oidc_issuer_url": cluster_oidc_issuer_url,
"oidc_provider_arn": oidc_provider_arn
"overrides": self.config.mlflow.values

}

Expand Down
8 changes: 5 additions & 3 deletions src/nebari_plugin_mlflow_aws/terraform/main.tf
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
locals {
mlflow_sa_name = "mlflow-sa"
mlflow_sa_name = "${var.chart_name}-sa"
}

# --------------------------------------------------------------------------
Expand All @@ -13,7 +13,7 @@ resource "random_id" "bucket_name_suffix" {
}

resource "aws_s3_bucket" "artifact_storage" {
bucket = "${var.name}-mlflow-artifacts-${random_id.bucket_name_suffix.hex}"
bucket = "${var.project_name}-mlflow-artifacts-${random_id.bucket_name_suffix.hex}"
acl = "private"

versioning {
Expand All @@ -30,7 +30,7 @@ module "iam_assumable_role_admin" {
version = "~> 4.0"

create_role = true
role_name = "${var.name}-mlflow-irsa"
role_name = "${var.project_name}-mlflow-irsa"
provider_url = replace(var.cluster_oidc_issuer_url, "https://", "")
role_policy_arns = [aws_iam_policy.mlflow_s3.arn]
oidc_fully_qualified_subjects = ["system:serviceaccount:${var.namespace}:${local.mlflow_sa_name}"]
Expand Down Expand Up @@ -84,11 +84,13 @@ module "keycloak" {
module "mlflow" {
source = "./modules/mlflow"

chart_name = var.chart_name
create_namespace = var.create_namespace
ingress_host = var.ingress_host
mlflow_sa_name = local.mlflow_sa_name
mlflow_sa_iam_role_arn = module.iam_assumable_role_admin.iam_role_arn
namespace = var.namespace
s3_bucket_name = aws_s3_bucket.artifact_storage.id
keycloak_config = module.keycloak.config
overrides = var.overrides
}
5 changes: 3 additions & 2 deletions src/nebari_plugin_mlflow_aws/terraform/modules/mlflow/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ resource "random_password" "mlflow_postgres" {
}

resource "helm_release" "mlflow" {
name = "mlflow"
name = var.chart_name
chart = "${path.module}/chart"
namespace = var.create_namespace ? kubernetes_namespace.this[0].metadata[0].name : var.namespace

Expand Down Expand Up @@ -79,6 +79,7 @@ resource "helm_release" "mlflow" {
value = "3600"
}
]
})
}),
yamlencode(var.overrides),
]
}
10 changes: 10 additions & 0 deletions src/nebari_plugin_mlflow_aws/terraform/modules/mlflow/variables.tf
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
variable "chart_name" {
description = "Name for mlflow chart and its namespaced resources."
type = string
}

variable "create_namespace" {
type = bool
}
Expand Down Expand Up @@ -29,4 +34,9 @@ variable "mlflow_sa_name" {
variable "mlflow_sa_iam_role_arn" {
description = "ARN of IAM role for Mlflow SA to assume"
type = string
}

variable "overrides" {
type = map(any)
default = {}
}
33 changes: 19 additions & 14 deletions src/nebari_plugin_mlflow_aws/terraform/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,6 @@ variable "external_url" {
type = string
}

variable "name" {
description = "Prefix name to assign to Nebari resources"
type = string
}

variable "namespace" {
type = string
}

variable "valid_redirect_uris" {
description = "A list of valid URIs a browser is permitted to redirect to after a successful login or logout"
type = list(string)
Expand All @@ -55,15 +46,29 @@ variable "ingress_host" {
type = string
}

variable "chart_name" {
description = "Name for mlflow chart and its namespaced resources."
type = string
}

variable "project_name" {
description = "Project name to assign to Nebari resources"
type = string
}

variable "namespace" {
type = string
}

variable "overrides" {
type = map(any)
default = {}
}

# IRSA SETTINGS
# -----------------

variable "cluster_oidc_issuer_url" {
description = "The URL on the EKS cluster for the OpenID Connect identity provider"
type = string
}

variable "oidc_provider_arn" {
description = "The ARN of the OIDC Provider"
type = string
}

0 comments on commit cbea668

Please sign in to comment.