Skip to content

Commit

Permalink
Add annotations for returned data (#972)
Browse files Browse the repository at this point in the history
Start to add type annotations within the tests
  • Loading branch information
ogenstad authored Sep 20, 2024
1 parent 1c2a88f commit 8d8001e
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 11 deletions.
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,6 @@ max-returns = 11
"ANN002", # Missing type annotation for `*args`
"ANN003", # Missing type annotation for `**kwargs`
"ANN201", # Missing return type annotation for public function
"ANN202", # Missing return type annotation for private function
"ANN206", # Missing return type annotation for classmethod
"ARG001", # Unused function argument
"B007", # Loop control variable `host` not used within loop body
"C414", # Unnecessary `list` call within `sorted()`
Expand Down
12 changes: 8 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import List
from typing import Any, Dict, List, Type, TypeVar, Union

import pytest
import ruamel.yaml
Expand All @@ -18,14 +18,16 @@
from nornir.core.state import GlobalState
from nornir.core.task import AggregatedResult, Task

ElementType = TypeVar("ElementType", bound=Union[Group, Host])

global_data = GlobalState(dry_run=True)


def inventory_from_yaml():
dir_path = os.path.dirname(os.path.realpath(__file__))
yml = ruamel.yaml.YAML(typ="safe")

def get_connection_options(data):
def get_connection_options(data) -> Dict[str, ConnectionOptions]:
cp = {}
for cn, c in data.items():
cp[cn] = ConnectionOptions(
Expand All @@ -38,7 +40,7 @@ def get_connection_options(data):
)
return cp

def get_defaults():
def get_defaults() -> Defaults:
defaults_file = f"{dir_path}/inventory_data/defaults.yaml"
with open(defaults_file, "r") as f:
defaults_dict = yml.load(f)
Expand All @@ -55,7 +57,9 @@ def get_defaults():
),
)

def get_inventory_element(typ, data, name, defaults):
def get_inventory_element(
typ: Type[ElementType], data: Dict[str, Any], name: str, defaults: Union[Defaults, None]
) -> ElementType:
return typ(
name=name,
hostname=data.get("hostname"),
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_InitNornir.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_InitNornir_different_transform_function_by_string_with_bad_options(self

class TestLogging:
@classmethod
def cleanup(cls):
def cleanup(cls) -> None:
# this does not work as setup_method, because pytest injects
# _pytest.logging.LogCaptureHandler handler to the root logger
# and StreamHandler to _pytest.capture.EncodedFile to other loggers
Expand All @@ -183,7 +183,7 @@ def cleanup(cls):
logger_.setLevel(logging.NOTSET)

@classmethod
def teardown_class(cls):
def teardown_class(cls) -> None:
cls.cleanup()

def test_InitNornir_logging_defaults(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def validate_params(task, conn, params, nornir_config):

class Test:
@classmethod
def setup_class(cls):
def setup_class(cls) -> None:
ConnectionPluginRegister.deregister_all()
ConnectionPluginRegister.register("dummy", DummyConnectionPlugin)
ConnectionPluginRegister.register("dummy2", DummyConnectionPlugin)
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def test_filtering_func(self, inv):
)
assert long_names == ["dev1.group_1", "dev4.group_2", "dev6.group_3"]

def longer_than(dev, length):
def longer_than(dev, length) -> bool:
return len(dev["my_var"]) > length

long_names = sorted(list(inv.filter(filter_func=longer_than, length=20).hosts.keys()))
Expand Down
2 changes: 1 addition & 1 deletion tests/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def wrap_cli_test(output, save_output=False):
"""

@decorator
def run_test(func, *args, **kwargs):
def run_test(func, *args, **kwargs) -> None:
stdout = StringIO()
backup_stdout = sys.stdout
sys.stdout = stdout
Expand Down

0 comments on commit 8d8001e

Please sign in to comment.