diff --git a/Makefile b/Makefile index 2d4d86050b..5470d8957f 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ help: .PHONY: install-piptools install-piptools: # pip 22.1 broke pip-tools: https://github.com/jazzband/pip-tools/issues/1617 - python -m pip install -U pip-tools setuptools wheel "pip>=22.0.3,!=22.1" + python3 -m pip install -U pip-tools setuptools wheel "pip>=22.0.3,!=22.1" .PHONY: update_boilerplate update_boilerplate: @@ -57,6 +57,7 @@ unit_test_codecov: unit_test: # Skip tensorflow tests and run them with the necessary env var set so that a working (albeit slower) # library is used to serialize/deserialize protobufs is used. + # Can use pytest --lf to only rerun previously failed tests. pytest -m "not sandbox_test" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/tensorflow ${CODECOV_OPTS} && \ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python pytest tests/flytekit/unit/extras/tensorflow ${CODECOV_OPTS} diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 53ff504d55..9de8e92bdd 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -349,32 +349,52 @@ def __getattr__(self, item: str) -> _GroupSecrets: """ return self._GroupSecrets(item, self) - def get(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: + def get( + self, + group: Optional[str] = None, + key: Optional[str] = None, + group_version: Optional[str] = None, + env_name: Optional[str] = None, + ) -> str: """ Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError """ - self.check_group_key(group) - env_var = self.get_secrets_env_var(group, key, group_version) - fpath = self.get_secrets_file(group, key, group_version) + self.check_env_name_key(env_name) + env_var = self.get_secrets_env_var(group, key, group_version, env_name) + fpath = None + if env_name is None: + fpath = self.get_secrets_file(group, key, group_version) v = os.environ.get(env_var) if v is not None: return v - if os.path.exists(fpath): + if fpath is not None and os.path.exists(fpath): with open(fpath, "r") as f: return f.read().strip() raise ValueError( - f"Unable to find secret for key {key} in group {group} " f"in Env Var:{env_var} and FilePath: {fpath}" + f"Unable to find secret for key {key} in group {group} or name {env_name} " + f"in Env Var: {env_var} or FilePath: {fpath}" ) - def get_secrets_env_var(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: + def get_secrets_env_var( + self, + group: Optional[str] = None, + key: Optional[str] = None, + group_version: Optional[str] = None, + env_name: Optional[str] = None, + ) -> str: """ Returns a string that matches the ENV Variable to look for the secrets """ + self.check_env_name_key(env_name) + if env_name is not None: + return env_name self.check_group_key(group) l = [k.upper() for k in filter(None, (group, group_version, key))] return f"{self._env_prefix}{'_'.join(l)}" - def get_secrets_file(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: + def get_secrets_file( + self, group: Optional[str] = None, key: Optional[str] = None, group_version: Optional[str] = None + ) -> str: """ Returns a path that matches the file to look for the secrets """ @@ -384,10 +404,15 @@ def get_secrets_file(self, group: str, key: Optional[str] = None, group_version: return os.path.join(self._base_dir, *l) @staticmethod - def check_group_key(group: str): + def check_group_key(group: Optional[str] = None): if group is None or group == "": raise ValueError("secrets group is a mandatory field.") + @staticmethod + def check_env_name_key(env_name: Optional[str] = None): + if env_name is not None and len(env_name) <= 0: + raise ValueError(f"Invalid env_name {env_name}") + @dataclass(frozen=True) class CompilationState(object): diff --git a/flytekit/models/security.py b/flytekit/models/security.py index 9af90a4b8a..ba19b5513e 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from typing import List, Optional @@ -16,7 +16,8 @@ class Secret(_common.FlyteIdlEntity): group is the Name of the secret. For example in kubernetes secrets is the name of the secret key is optional and can be an individual secret identifier within the secret For k8s this is required version is the version of the secret. This is an optional field - mount_requirement provides a hint to the system as to how the secret should be injected + mount_requirement provides a hint to the system as to how the secret should be injected. Soon to be deprecated. + mount_target provies a target for secret injection. Can be environment variable or file path This is an optional field. """ class MountType(Enum): @@ -35,10 +36,42 @@ class MountType(Enum): Caution: May not be supported in all environments """ + @dataclass + class MountEnvVar(_common.FlyteIdlEntity): + name: str + + def to_flyte_idl(self) -> _sec.Secret.MountEnvVar: + return _sec.Secret.MountEnvVar( + name=self.name, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object: _sec.Secret.MountEnvVar): + return cls( + name=pb2_object.name, + ) + + @dataclass + class MountFile(_common.FlyteIdlEntity): + path: str + + def to_flyte_idl(self) -> _sec.Secret.MountFile: + return _sec.Secret.MountFile( + path=self.path, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object: _sec.Secret.MountFile): + return cls( + path=pb2_object.path, + ) + group: str key: Optional[str] = None group_version: Optional[str] = None mount_requirement: MountType = MountType.ANY + env_var: Optional[MountEnvVar] = field(default_factory=lambda: None) + file: Optional[MountFile] = field(default_factory=lambda: None) def __post_init__(self): if self.group is None: @@ -50,6 +83,8 @@ def to_flyte_idl(self) -> _sec.Secret: group_version=self.group_version, key=self.key, mount_requirement=self.mount_requirement.value, + env_var=_sec.Secret.MountEnvVar(name=self.env_var.name) if self.env_var else None, + file=_sec.Secret.MountFile(path=self.file.path) if self.file else None, ) @classmethod @@ -59,6 +94,8 @@ def from_flyte_idl(cls, pb2_object: _sec.Secret) -> "Secret": group_version=pb2_object.group_version if pb2_object.group_version else None, key=pb2_object.key if pb2_object.key else None, mount_requirement=Secret.MountType(pb2_object.mount_requirement), + env_var=Secret.MountEnvVar.from_flyte_idl(pb2_object.env_var) if pb2_object.HasField("env_var") else None, + file=Secret.MountFile.from_flyte_idl(pb2_object.file) if pb2_object.HasField("file") else None, ) diff --git a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py index 8e8c464bd4..74d226572b 100644 --- a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py +++ b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py @@ -56,6 +56,7 @@ def _secret_to_dict(secret: Secret) -> typing.Dict[str, typing.Optional[str]]: "key": secret.key, "group_version": secret.group_version, "mount_requirement": secret.mount_requirement.value, + "mount_target": secret.mount_target, } def secret_connect_args_to_dicts(self) -> typing.Optional[typing.Dict[str, typing.Dict[str, typing.Optional[str]]]]: diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py index d5b07fe420..c60a8afc67 100644 --- a/tests/flytekit/common/parameterizers.py +++ b/tests/flytekit/common/parameterizers.py @@ -238,6 +238,9 @@ None, security.Secret(group="x", key="g"), security.Secret(group="x", key="y", mount_requirement=security.Secret.MountType.ANY), + security.Secret(group="x", key="y", env_var=security.Secret.MountEnvVar(name="z")), + security.Secret(group="x", key="y", file=security.Secret.MountFile(path="/z")), + security.Secret(group="x", key="y", group_version="1", mount_requirement=security.Secret.MountType.ENV_VAR), security.Secret(group="x", key="y", group_version="1", mount_requirement=security.Secret.MountType.FILE), ]