Skip to content

Commit

Permalink
Merge pull request #190 from tonyandrewmeyer/remove-ops-testing-dep
Browse files Browse the repository at this point in the history
refactor: remove dependency on ops.testing
  • Loading branch information
tonyandrewmeyer authored Sep 19, 2024
2 parents 9013f19 + 3da5747 commit 2b968e5
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 33 deletions.
26 changes: 20 additions & 6 deletions scenario/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Type, Union, cast

import ops
import ops.testing

from scenario.errors import AlreadyEmittedError, ContextSetupError
from scenario.logger import logger as scenario_logger
Expand All @@ -28,8 +27,20 @@
)

if TYPE_CHECKING: # pragma: no cover
try:
from ops._private.harness import ExecArgs # type: ignore
except ImportError:
from ops.testing import ExecArgs # type: ignore

from scenario.ops_main_mock import Ops
from scenario.state import AnyJson, JujuLogLine, RelationBase, State, _EntityStatus
from scenario.state import (
AnyJson,
CharmType,
JujuLogLine,
RelationBase,
State,
_EntityStatus,
)

logger = scenario_logger.getChild("runtime")

Expand Down Expand Up @@ -426,7 +437,7 @@ def test_foo():

def __init__(
self,
charm_type: Type[ops.testing.CharmType],
charm_type: Type["CharmType"],
meta: Optional[Dict[str, Any]] = None,
*,
actions: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -491,7 +502,7 @@ def __init__(
self.charm_root = charm_root
self.juju_version = juju_version
if juju_version.split(".")[0] == "2":
logger.warn(
logger.warning(
"Juju 2.x is closed and unsupported. You may encounter inconsistencies.",
)

Expand All @@ -508,7 +519,7 @@ def __init__(
self.juju_log: List["JujuLogLine"] = []
self.app_status_history: List["_EntityStatus"] = []
self.unit_status_history: List["_EntityStatus"] = []
self.exec_history: Dict[str, List[ops.testing.ExecArgs]] = {}
self.exec_history: Dict[str, List["ExecArgs"]] = {}
self.workload_version_history: List[str] = []
self.removed_secret_revisions: List[int] = []
self.emitted_events: List[ops.EventBase] = []
Expand Down Expand Up @@ -644,7 +655,10 @@ def run(self, event: "_Event", state: "State") -> "State":
assert self._output_state is not None
if event.action:
if self._action_failure_message is not None:
raise ActionFailed(self._action_failure_message, self._output_state)
raise ActionFailed(
self._action_failure_message,
state=self._output_state,
)
return self._output_state

@contextmanager
Expand Down
51 changes: 35 additions & 16 deletions scenario/mocking.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#!/usr/bin/env python3
# Copyright 2023 Canonical Ltd.
# See LICENSE file for licensing details.

import datetime
import io
import shutil
from pathlib import Path
from typing import (
Expand All @@ -20,6 +22,12 @@
)

from ops import JujuVersion, pebble

try:
from ops._private.harness import ExecArgs, _TestingPebbleClient # type: ignore
except ImportError:
from ops.testing import ExecArgs, _TestingPebbleClient # type: ignore

from ops.model import CloudSpec as CloudSpec_Ops
from ops.model import ModelError
from ops.model import Port as Port_Ops
Expand All @@ -33,7 +41,6 @@
_ModelBackend,
)
from ops.pebble import Client, ExecError
from ops.testing import ExecArgs, _TestingPebbleClient

from scenario.errors import ActionMissingFromContextError
from scenario.logger import logger as scenario_logger
Expand Down Expand Up @@ -66,9 +73,9 @@ def __init__(
change_id: int,
args: ExecArgs,
return_code: int,
stdin: Optional[TextIO],
stdout: Optional[TextIO],
stderr: Optional[TextIO],
stdin: Optional[Union[TextIO, io.BytesIO]],
stdout: Optional[Union[TextIO, io.BytesIO]],
stderr: Optional[Union[TextIO, io.BytesIO]],
):
self._change_id = change_id
self._args = args
Expand Down Expand Up @@ -99,7 +106,12 @@ def wait_output(self):
stdout = self.stdout.read() if self.stdout is not None else None
stderr = self.stderr.read() if self.stderr is not None else None
if self._return_code != 0:
raise ExecError(list(self._args.command), self._return_code, stdout, stderr)
raise ExecError(
list(self._args.command),
self._return_code,
stdout, # type: ignore
stderr, # type: ignore
)
return stdout, stderr

def send_signal(self, sig: Union[int, str]): # noqa: U100
Expand Down Expand Up @@ -167,15 +179,18 @@ def get_pebble(self, socket_path: str) -> "Client":
# container not defined in state.
mounts = {}

return _MockPebbleClient(
socket_path=socket_path,
container_root=container_root,
mounts=mounts,
state=self._state,
event=self._event,
charm_spec=self._charm_spec,
context=self._context,
container_name=container_name,
return cast(
Client,
_MockPebbleClient(
socket_path=socket_path,
container_root=container_root,
mounts=mounts,
state=self._state,
event=self._event,
charm_spec=self._charm_spec,
context=self._context,
container_name=container_name,
),
)

def _get_relation_by_id(self, rel_id) -> "RelationBase":
Expand Down Expand Up @@ -616,7 +631,7 @@ def storage_add(self, name: str, count: int = 1):
)

if "/" in name:
# this error is raised by ops.testing but not by ops at runtime
# this error is raised by Harness but not by ops at runtime
raise ModelError('storage name cannot contain "/"')

self._context.requested_storages[name] = count
Expand Down Expand Up @@ -752,6 +767,10 @@ def __init__(

self._root = container_root

self._notices: Dict[Tuple[str, str], pebble.Notice] = {}
self._last_notice_id = 0
self._changes: Dict[str, pebble.Change] = {}

# load any existing notices and check information from the state
self._notices: Dict[Tuple[str, str], pebble.Notice] = {}
self._check_infos: Dict[str, pebble.CheckInfo] = {}
Expand Down Expand Up @@ -790,7 +809,7 @@ def _layers(self) -> Dict[str, pebble.Layer]:
def _service_status(self) -> Dict[str, pebble.ServiceStatus]:
return self._container.service_statuses

# Based on a method of the same name from ops.testing.
# Based on a method of the same name from Harness.
def _find_exec_handler(self, command) -> Optional["Exec"]:
handlers = {exec.command_prefix: exec for exec in self._container.execs}
# Start with the full command and, each loop iteration, drop the last
Expand Down
4 changes: 1 addition & 3 deletions scenario/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,8 @@
)

if TYPE_CHECKING: # pragma: no cover
from ops.testing import CharmType

from scenario.context import Context
from scenario.state import State, _CharmSpec, _Event
from scenario.state import CharmType, State, _CharmSpec, _Event

logger = scenario_logger.getChild("runtime")
STORED_STATE_REGEX = re.compile(
Expand Down
2 changes: 1 addition & 1 deletion scenario/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
class ActionFailed(Exception):
"""Raised at the end of the hook if the charm has called ``event.fail()``."""

def __init__(self, message: str, state: "State"):
def __init__(self, message: str, *, state: "State"):
self.message = message
self.state = state

Expand Down
4 changes: 1 addition & 3 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
from scenario.context import _DEFAULT_JUJU_VERSION, Context

if TYPE_CHECKING: # pragma: no cover
from ops.testing import CharmType

from scenario.state import State, _Event
from scenario.state import CharmType, State, _Event

_CT = TypeVar("_CT", bound=Type[CharmType])

Expand Down
5 changes: 1 addition & 4 deletions tests/test_charm_spec_autoload.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import importlib
import sys
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Type

import pytest
import yaml
from ops import CharmBase
from ops.testing import CharmType

from scenario import Context, Relation, State
from scenario.context import ContextSetupError
from scenario.state import MetadataNotFoundError, _CharmSpec
from scenario.state import CharmType, MetadataNotFoundError, _CharmSpec

CHARM = """
from ops import CharmBase
Expand Down

0 comments on commit 2b968e5

Please sign in to comment.