Skip to content

Commit

Permalink
Produce apis to add a k8s cloud and tear down its models first
Browse files Browse the repository at this point in the history
  • Loading branch information
addyess committed Feb 29, 2024
1 parent 40be682 commit d987b2b
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 14 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ build
dist/
*.orig
report
.coverage
.coverage
juju-crashdump*
133 changes: 121 additions & 12 deletions pytest_operator/plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import base64
import contextlib
import dataclasses
import enum
Expand Down Expand Up @@ -43,6 +44,8 @@
import yaml
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from kubernetes.client import Configuration as K8sConfiguration
from juju.client import client
from juju.client.jujudata import FileJujuData
from juju.exceptions import DeadEntityException
from juju.errors import JujuError
Expand Down Expand Up @@ -399,6 +402,10 @@ class ModelInUseError(Exception):
"""Raise when trying to add a model alias which already exists."""


BundleOpt = TypeVar("BundleOpt", str, Path, "OpsTest.Bundle")
Timeout = TypeVar("Timeout", float, int)


@dataclasses.dataclass
class ModelState:
model: Model
Expand All @@ -408,14 +415,18 @@ class ModelState:
model_name: str
config: Optional[dict] = None
tmp_path: Optional[Path] = None
timeout: Optional[Timeout] = None

@property
def full_name(self) -> str:
return f"{self.controller_name}:{self.model_name}"


BundleOpt = TypeVar("BundleOpt", str, Path, "OpsTest.Bundle")
Timeout = TypeVar("Timeout", float, int)
@dataclasses.dataclass
class CloudState:
cloud_name: str
models: List[str] = dataclasses.field(default_factory=list)
timeout: Optional[Timeout] = None


class OpsTest:
Expand Down Expand Up @@ -510,6 +521,7 @@ def __init__(self, request, tmp_path_factory):
# use an OrderedDict so that the first model made is destroyed last.
self._current_alias = None
self._models: MutableMapping[str, ModelState] = OrderedDict()
self._clouds: MutableMapping[str, CloudState] = OrderedDict()

@contextlib.contextmanager
def model_context(self, alias: str) -> Generator[Model, None, None]:
Expand Down Expand Up @@ -597,14 +609,16 @@ def keep_model(self) -> bool:
current_state = self.current_alias and self._models.get(self.current_alias)
return current_state.keep if current_state else self._init_keep_model

def _generate_model_name(self) -> str:
def _generate_name(self, kind: str) -> str:
module_name = self.request.module.__name__.rpartition(".")[-1]
suffix = "".join(choices(ascii_lowercase + digits, k=4))
if kind != "model":
suffix = "-".join((kind, suffix))
return f"{module_name.replace('_', '-')}-{suffix}"

@cached_property
def default_model_name(self) -> str:
return self._generate_model_name()
return self._generate_name(kind="model")

async def run(
self,
Expand Down Expand Up @@ -670,23 +684,33 @@ async def _add_model(self, cloud_name, model_name, keep=False, **kwargs):
"""
controller = self._controller
controller_name = controller.controller_name
credential_name = None
timeout = None
if not cloud_name:
# if not provided, try the default cloud name
cloud_name = self._init_cloud_name
if not cloud_name:
# if not provided, use the controller's default cloud
cloud_name = await controller.get_cloud()
if ops_cloud := self._clouds.get(cloud_name):
credential_name = cloud_name
timeout = ops_cloud.timeout

model_full_name = f"{controller_name}:{model_name}"
log.info(f"Adding model {model_full_name} on cloud {cloud_name}")

model = await controller.add_model(model_name, cloud_name, **kwargs)
model = await controller.add_model(
model_name, cloud_name, credential_name=credential_name, **kwargs
)
# NB: This call to `juju models` is needed because libjuju's
# `add_model` doesn't update the models.yaml cache that the Juju
# CLI depends on with the model's UUID, which the CLI requires to
# connect. Calling `juju models` beforehand forces the CLI to
# update the cache from the controller.
await self.juju("models")
state = ModelState(model, keep, controller_name, cloud_name, model_name)
state = ModelState(
model, keep, controller_name, cloud_name, model_name, timeout=timeout
)
state.config = await model.get_config()
return state

Expand Down Expand Up @@ -820,11 +844,13 @@ async def track_model(
)
else:
cloud_name = cloud_name or self.cloud_name
model_name = model_name or self._generate_model_name()
model_name = model_name or self._generate_name(kind="model")
model_state = await self._add_model(
cloud_name, model_name, keep_val, **kwargs
)
self._models[alias] = model_state
if ops_cloud := self._clouds.get(cloud_name):
ops_cloud.models.append(alias)
return model_state.model

async def log_model(self):
Expand Down Expand Up @@ -886,6 +912,10 @@ async def forget_model(
if alias not in self.models:
raise ModelNotFoundError(f"{alias} not found")

model_state: ModelState = self._models[alias]
if timeout is None and model_state.timeout:
timeout = model_state.timeout

with self.model_context(alias) as model:
await self.log_model()
model_name = model.info.name
Expand All @@ -896,13 +926,26 @@ async def forget_model(
if not self.keep_model:
await self._reset(model, allow_failure, timeout=timeout)
await self._controller.destroy_model(
model_name, force=True, destroy_storage=destroy_storage
model_name,
force=True,
destroy_storage=destroy_storage,
max_wait=timeout,
)
await model.disconnect()

async def model_alive():
return model_name in await self._controller.list_models()

if timeout and await model_alive():
log.warning("Waiting for model %s to leave...", model_name)
while await model_alive():
asyncio.sleep(5)

# stop managing this model now
log.info(f"Forgetting {alias}...")
log.info(f"Forgetting model {alias}...")
self._models.pop(alias)
if ops_cloud := self._clouds.get(model_state.cloud_name):
ops_cloud.models.remove(alias)
if alias is self.current_alias:
self._current_alias = None

Expand Down Expand Up @@ -933,7 +976,9 @@ async def _destroy(entity_name: str, **kwargs):

try:
await model.block_until(
lambda: len(model.machines) == 0 and len(model.applications) == 0,
lambda: len(model.units) == 0
and len(model.machines) == 0
and len(model.applications) == 0,
timeout=timeout,
)
except asyncio.TimeoutError:
Expand All @@ -948,10 +993,15 @@ async def _destroy(entity_name: str, **kwargs):
log.info(f"Reset {model.info.name} completed successfully.")

async def _cleanup_models(self):
# remove clouds from most recently made, to first made
# each model in the cloud will be forgotten
for cloud in reversed(self._clouds):
await self.forget_cloud(cloud)

# remove models from most recently made, to first made
aliases = list(reversed(self._models.keys()))
for models in aliases:
await self.forget_model(models)
for model in aliases:
await self.forget_model(model)

await self._controller.disconnect()

Expand Down Expand Up @@ -1491,3 +1541,62 @@ def is_crash_dump_enabled(self) -> bool:
return True
else:
return False

### Add K8S
async def add_k8s(self, config: K8sConfiguration, **kwargs) -> str:
controller = self._controller
cloud_name = self._generate_name("k8s-cloud")
log.info(f"Adding k8s cloud {cloud_name}")

cloud_def = client.Cloud(
auth_types=[
"certificate",
"clientcertificate",
"oauth2",
"oauth2withcert",
"userpass",
],
ca_certificates=[Path(config.ssl_ca_cert).read_text()],
endpoint=config.host,
host_cloud_region="kubernetes/ops-test",
regions=[client.CloudRegion(endpoint=config.host, name="k8s")],
skip_tls_verify=not config.verify_ssl,
type_="kubernetes",
)

if config.cert_file and config.key_file:
auth_type = "clientcertificate"
attrs = dict(
ClientCertificateData=Path(config.cert_file).read_text(),
ClientKeyData=Path(config.key_file).read_text(),
)
elif token := config.api_key["authorization"]:
if token.startswith("Bearer "):
auth_type = "oauth2"
attrs = {"Token": token.split(" ")[1]}
elif token.startswith("Basic "):
auth_type, userpass = "userpass", token.split(" ")[1]
user, passwd = base64.b64decode(userpass).decode().split(":")
attrs = {"username": user, "password": passwd}
else:
raise ValueError("Failed to find credentials in authorization token")
else:
raise ValueError("Failed to find credentials in kubernetes.Configuration")

await controller.add_cloud(cloud_name, cloud_def)
await controller.add_credential(
cloud_name,
credential=client.CloudCredential(attrs, auth_type),
cloud=cloud_name,
)
self._clouds[cloud_name] = CloudState(cloud_name, timeout=5 * 60)
return cloud_name

async def forget_cloud(self, cloud_name: str):
if cloud_name not in self._clouds:
raise KeyError(f"{cloud_name} not in clouds")
for model in reversed(self._clouds[cloud_name].models):
await self.forget_model(model)
log.info(f"Forgetting cloud: {cloud_name}...")
await self._controller.remove_cloud(cloud_name)
del self._clouds[cloud_name]
15 changes: 15 additions & 0 deletions tests/integration/test_opstest_add_k8s.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# test that pytest operator supports adding a k8s to an existing controller
# This is a new k8s cloud created/managed by pytest-operator

from pytest_operator.plugin import OpsTest
from kubernetes import config as k8s_config
from kubernetes.client import Configuration

async def test_add_k8s_cloud(ops_test: OpsTest):
config = type.__call__(Configuration)
k8s_config.load_config(client_configuration=config)
k8s_cloud = await ops_test.add_k8s(config, skip_storage=True, storage_class=None)
k8s_model = await ops_test.track_model("secondary", cloud_name=k8s_cloud, keep=ops_test.ModelKeep.NEVER)
with ops_test.model_context("secondary"):
await k8s_model.deploy("coredns", trust=True)
await k8s_model.wait_for_idle(apps=["coredns"], status="active")
2 changes: 1 addition & 1 deletion tests/unit/test_pytest_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ async def test_fixture_set_up_automatic_model(

await ops_test._setup_model()
mock_juju.controller.add_model.assert_called_with(
model_name, "this-cloud", config=None
model_name, "this-cloud", credential_name=None, config=None
)
juju_cmd.assert_called_with(ops_test, "models")
assert ops_test.model == mock_juju.model
Expand Down

0 comments on commit d987b2b

Please sign in to comment.