Skip to content

Commit

Permalink
Metadata optional, id_validator reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
shouhanzen committed Mar 6, 2024
1 parent 9bb95c2 commit 468188b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 23 deletions.
42 changes: 28 additions & 14 deletions src/dsmlp/app/id_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from dsmlp.plugin.logger import Logger
from dsmlp.app.types import *


class IDValidator(ComponentValidator):

def __init__(self, awsed: AwsedClient, logger: Logger) -> None:
self.awsed = awsed
self.logger = logger

def validate_pod(self, request: Request):
"""
Validate pods for namespaces with the 'k8s-sync' label
Expand All @@ -28,7 +29,8 @@ def validate_pod(self, request: Request):
# if 'k8s-sync' in namespace.labels:
user = self.awsed.describe_user(username)
if not user:
raise ValidationFailure(f"namespace: no AWSEd user found with username {username}")
raise ValidationFailure(
f"namespace: no AWSEd user found with username {username}")
allowed_uid = user.uid
allowed_courses = user.enrollments

Expand All @@ -39,15 +41,20 @@ def validate_pod(self, request: Request):

metadata = request.object.metadata
spec = request.object.spec
self.validate_course_enrollment(allowed_courses, metadata.labels)
self.validate_pod_security_context(allowed_uid, allowed_gids, spec.securityContext)

if metadata is not None and metadata.labels is not None:
self.validate_course_enrollment(allowed_courses, metadata.labels)

self.validate_pod_security_context(
allowed_uid, allowed_gids, spec.securityContext)
self.validate_containers(allowed_uid, allowed_gids, spec)

def validate_course_enrollment(self, allowed_courses: List[str], labels: Dict[str, str]):
if not 'dsmlp/course' in labels:
return
if not labels['dsmlp/course'] in allowed_courses:
raise ValidationFailure(f"metadata.labels: dsmlp/course must be in range {allowed_courses}")
raise ValidationFailure(
f"metadata.labels: dsmlp/course must be in range {allowed_courses}")

def validate_pod_security_context(
self,
Expand All @@ -59,18 +66,22 @@ def validate_pod_security_context(
return

if securityContext.runAsUser is not None and authorized_uid != securityContext.runAsUser:
raise ValidationFailure(f"spec.securityContext: uid must be in range [{authorized_uid}]")
raise ValidationFailure(
f"spec.securityContext: uid must be in range [{authorized_uid}]")

if securityContext.runAsGroup is not None and securityContext.runAsGroup not in allowed_teams:
raise ValidationFailure(f"spec.securityContext: gid must be in range {allowed_teams}")
raise ValidationFailure(
f"spec.securityContext: gid must be in range {allowed_teams}")

if securityContext.fsGroup is not None and securityContext.fsGroup not in allowed_teams:
raise ValidationFailure(f"spec.securityContext: gid must be in range {allowed_teams}")
raise ValidationFailure(
f"spec.securityContext: gid must be in range {allowed_teams}")

if securityContext.supplementalGroups is not None:
for sgroup in securityContext.supplementalGroups:
if not sgroup in allowed_teams:
raise ValidationFailure(f"spec.securityContext: gid must be in range {allowed_teams}")
raise ValidationFailure(
f"spec.securityContext: gid must be in range {allowed_teams}")

def validate_containers(
self,
Expand All @@ -81,8 +92,10 @@ def validate_containers(
"""
Validate the security context of containers and initContainers
"""
self.validate_security_contexts(authorized_uid, allowed_teams, spec.containers, "containers")
self.validate_security_contexts(authorized_uid, allowed_teams, spec.initContainers, "initContainers")
self.validate_security_contexts(
authorized_uid, allowed_teams, spec.containers, "containers")
self.validate_security_contexts(
authorized_uid, allowed_teams, spec.initContainers, "initContainers")

def validate_security_contexts(
self, authorized_uid: int, allowed_teams: List[int],
Expand All @@ -100,7 +113,8 @@ def validate_security_contexts(
if securityContext is None:
continue

self.validate_security_context(authorized_uid, allowed_teams, securityContext, f"{context}[{i}]")
self.validate_security_context(
authorized_uid, allowed_teams, securityContext, f"{context}[{i}]")

def validate_security_context(
self,
Expand Down Expand Up @@ -128,4 +142,4 @@ def admission_response(self, uid, allowed, message):
"message": message
}
}
}
}
16 changes: 8 additions & 8 deletions src/dsmlp/ext/kube.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@ def get_namespace(self, name: str) -> Namespace:
api = self.get_policy_api()
v1namespace: V1Namespace = api.read_namespace(name=name)
metadata: V1ObjectMeta = v1namespace.metadata

gpu_quota = 1
if metadata.annotations is not None and GPU_LIMIT_ANNOTATION in metadata.annotations:
if metadata is not None and metadata.annotations is not None and GPU_LIMIT_ANNOTATION in metadata.annotations:
gpu_quota = int(metadata.annotations[GPU_LIMIT_ANNOTATION])

return Namespace(
name=metadata.name,
labels=metadata.labels,
gpu_quota=gpu_quota)

def get_gpus_in_namespace(self, name: str) -> int:
api = self.get_policy_api()
V1Namespace: V1Namespace = api.read_namespace(name=name)
pods = api.list_namespaced_pod(namespace=name)

gpu_count = 0
for pod in pods.items:
for container in pod.spec.containers:
Expand All @@ -45,13 +45,13 @@ def get_gpus_in_namespace(self, name: str) -> int:
limit = int(container.resources.limits[GPU_LABEL])
except (KeyError, AttributeError, TypeError):
pass

gpu_count += max(requested, limit)

return gpu_count


# noinspection PyMethodMayBeStatic

def get_policy_api(self) -> CoreV1Api:
try:
config.load_incluster_config()
Expand Down
6 changes: 5 additions & 1 deletion tests/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def gen_request(gpu_req: int = 0, gpu_lim: int = 0, low_priority: bool = False,
if course is not None:
labels["dsmlp/course"] = course

metadata = None
if labels != {}:
metadata = ObjectMeta(labels=labels)

sec_context = None
if run_as_user is not None or run_as_group is not None or fs_group is not None or supplemental_groups is not None:
sec_context = PodSecurityContext(
Expand All @@ -52,7 +56,7 @@ def gen_request(gpu_req: int = 0, gpu_lim: int = 0, low_priority: bool = False,
uid=uid,
namespace=username,
object=Object(
metadata=ObjectMeta(labels=labels),
metadata=metadata,
spec=PodSpec(
containers=containers,
priorityClassName=p_class,
Expand Down

0 comments on commit 468188b

Please sign in to comment.