Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduces new "env_var" and "file" fields to Secret to allow specifying name/mountPath on injection #1726

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ help:
.PHONY: install-piptools
install-piptools:
# pip 22.1 broke pip-tools: https://github.com/jazzband/pip-tools/issues/1617
python -m pip install -U pip-tools setuptools wheel "pip>=22.0.3,!=22.1"
python3 -m pip install -U pip-tools setuptools wheel "pip>=22.0.3,!=22.1"

.PHONY: update_boilerplate
update_boilerplate:
Expand Down Expand Up @@ -57,6 +57,7 @@ unit_test_codecov:
unit_test:
# Skip tensorflow tests and run them with the necessary env var set so that a working (albeit slower)
# library is used to serialize/deserialize protobufs is used.
# Can use pytest --lf to only rerun previously failed tests.
pytest -m "not sandbox_test" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/tensorflow ${CODECOV_OPTS} && \
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python pytest tests/flytekit/unit/extras/tensorflow ${CODECOV_OPTS}

Expand Down
43 changes: 34 additions & 9 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,32 +349,52 @@ def __getattr__(self, item: str) -> _GroupSecrets:
"""
return self._GroupSecrets(item, self)

def get(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str:
def get(
self,
group: Optional[str] = None,
key: Optional[str] = None,
group_version: Optional[str] = None,
env_name: Optional[str] = None,
) -> str:
"""
Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError
"""
self.check_group_key(group)
env_var = self.get_secrets_env_var(group, key, group_version)
fpath = self.get_secrets_file(group, key, group_version)
self.check_env_name_key(env_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.check_env_name_key(env_name)
self.assert_env_name_key(env_name)

env_var = self.get_secrets_env_var(group, key, group_version, env_name)
fpath = None
if env_name is None:
fpath = self.get_secrets_file(group, key, group_version)
v = os.environ.get(env_var)
if v is not None:
return v
if os.path.exists(fpath):
if fpath is not None and os.path.exists(fpath):
with open(fpath, "r") as f:
return f.read().strip()
raise ValueError(
f"Unable to find secret for key {key} in group {group} " f"in Env Var:{env_var} and FilePath: {fpath}"
f"Unable to find secret for key {key} in group {group} or name {env_name} "
f"in Env Var: {env_var} or FilePath: {fpath}"
)

def get_secrets_env_var(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str:
def get_secrets_env_var(
self,
group: Optional[str] = None,
key: Optional[str] = None,
group_version: Optional[str] = None,
env_name: Optional[str] = None,
) -> str:
"""
Returns a string that matches the ENV Variable to look for the secrets
"""
self.check_env_name_key(env_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this redundant?

if env_name is not None:
return env_name
self.check_group_key(group)
l = [k.upper() for k in filter(None, (group, group_version, key))]
return f"{self._env_prefix}{'_'.join(l)}"

def get_secrets_file(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str:
def get_secrets_file(
self, group: Optional[str] = None, key: Optional[str] = None, group_version: Optional[str] = None
) -> str:
"""
Returns a path that matches the file to look for the secrets
"""
Expand All @@ -384,10 +404,15 @@ def get_secrets_file(self, group: str, key: Optional[str] = None, group_version:
return os.path.join(self._base_dir, *l)

@staticmethod
def check_group_key(group: str):
def check_group_key(group: Optional[str]):
gpgn marked this conversation as resolved.
Show resolved Hide resolved
if group is None or group == "":
raise ValueError("secrets group is a mandatory field.")

@staticmethod
def check_env_name_key(env_name: Optional[str]):
gpgn marked this conversation as resolved.
Show resolved Hide resolved
if env_name is not None and len(env_name) <= 0:
raise ValueError(f"Invalid env_name {env_name}")


@dataclass(frozen=True)
class CompilationState(object):
Expand Down
41 changes: 39 additions & 2 deletions flytekit/models/security.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

Expand All @@ -16,7 +16,8 @@ class Secret(_common.FlyteIdlEntity):
group is the Name of the secret. For example in kubernetes secrets is the name of the secret
key is optional and can be an individual secret identifier within the secret For k8s this is required
version is the version of the secret. This is an optional field
mount_requirement provides a hint to the system as to how the secret should be injected
mount_requirement provides a hint to the system as to how the secret should be injected. Soon to be deprecated.
mount_target provies a target for secret injection. Can be environment variable or file path This is an optional field.
"""

class MountType(Enum):
Expand All @@ -35,10 +36,42 @@ class MountType(Enum):
Caution: May not be supported in all environments
"""

@dataclass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually i would love to move away from these model classes. The Secret object will be a mixed bag, but I think that's okay. Could we just use the raw pb generated classes for the new fields?

(background: we wrote the model files before there were .pyi files, so nothing had type hints, field hints and it made things easier to use. but now we do have them with pyi files)

class MountEnvVar(_common.FlyteIdlEntity):
name: str

def to_flyte_idl(self) -> _sec.Secret.MountEnvVar:
return _sec.Secret.MountEnvVar(
name=self.name,
)

@classmethod
def from_flyte_idl(cls, pb2_object: _sec.Secret.MountEnvVar):
return cls(
name=pb2_object.name,
)

@dataclass
class MountFile(_common.FlyteIdlEntity):
path: str

def to_flyte_idl(self) -> _sec.Secret.MountFile:
return _sec.Secret.MountFile(
path=self.path,
)

@classmethod
def from_flyte_idl(cls, pb2_object: _sec.Secret.MountFile):
return cls(
path=pb2_object.path,
)

group: str
key: Optional[str] = None
group_version: Optional[str] = None
mount_requirement: MountType = MountType.ANY
env_var: Optional[MountEnvVar] = field(default_factory=lambda: None)
file: Optional[MountFile] = field(default_factory=lambda: None)

def __post_init__(self):
if self.group is None:
Expand All @@ -50,6 +83,8 @@ def to_flyte_idl(self) -> _sec.Secret:
group_version=self.group_version,
key=self.key,
mount_requirement=self.mount_requirement.value,
env_var=_sec.Secret.MountEnvVar(name=self.env_var.name) if self.env_var else None,
file=_sec.Secret.MountFile(path=self.file.path) if self.file else None,
)

@classmethod
Expand All @@ -59,6 +94,8 @@ def from_flyte_idl(cls, pb2_object: _sec.Secret) -> "Secret":
group_version=pb2_object.group_version if pb2_object.group_version else None,
key=pb2_object.key if pb2_object.key else None,
mount_requirement=Secret.MountType(pb2_object.mount_requirement),
env_var=Secret.MountEnvVar.from_flyte_idl(pb2_object.env_var) if pb2_object.HasField("env_var") else None,
file=Secret.MountFile.from_flyte_idl(pb2_object.file) if pb2_object.HasField("file") else None,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def _secret_to_dict(secret: Secret) -> typing.Dict[str, typing.Optional[str]]:
"key": secret.key,
"group_version": secret.group_version,
"mount_requirement": secret.mount_requirement.value,
"mount_target": secret.mount_target,
}

def secret_connect_args_to_dicts(self) -> typing.Optional[typing.Dict[str, typing.Dict[str, typing.Optional[str]]]]:
Expand Down
3 changes: 3 additions & 0 deletions tests/flytekit/common/parameterizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@
None,
security.Secret(group="x", key="g"),
security.Secret(group="x", key="y", mount_requirement=security.Secret.MountType.ANY),
security.Secret(group="x", key="y", env_var=security.Secret.MountEnvVar(name="z")),
security.Secret(group="x", key="y", file=security.Secret.MountFile(path="/z")),
security.Secret(group="x", key="y", group_version="1", mount_requirement=security.Secret.MountType.ENV_VAR),
security.Secret(group="x", key="y", group_version="1", mount_requirement=security.Secret.MountType.FILE),
]

Expand Down