Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
javitonino committed May 9, 2024
1 parent 7c7b72b commit 2e83ee7
Show file tree
Hide file tree
Showing 29 changed files with 203 additions and 149 deletions.
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ ignore_missing_imports = True
disable_error_code = arg-type, call-arg

[mypy-aioresponses]
ignore_missing_imports = True
ignore_missing_imports = True
9 changes: 7 additions & 2 deletions nucliadb/nucliadb/common/cluster/rebalance.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,13 @@ async def move_set_of_kb_resources(
for result in search_response.document.results:
resource_id = result.uuid
try:
async with datamanagers.with_transaction() as txn, locking.distributed_lock(
locking.RESOURCE_INDEX_LOCK.format(kbid=kbid, resource_id=resource_id)
async with (
datamanagers.with_transaction() as txn,
locking.distributed_lock(
locking.RESOURCE_INDEX_LOCK.format(
kbid=kbid, resource_id=resource_id
)
),
):
found_shard_id = await datamanagers.resources.get_resource_shard_id(
txn, kbid=kbid, rid=resource_id
Expand Down
2 changes: 1 addition & 1 deletion nucliadb/nucliadb/common/datamanagers/rollover.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def add_indexed(
kbid: str,
resource_id: str,
shard_id: str,
modification_time: int
modification_time: int,
) -> None:
to_index = KB_ROLLOVER_RESOURCES_TO_INDEX.format(kbid=kbid, resource=resource_id)
indexed = KB_ROLLOVER_RESOURCES_INDEXED.format(kbid=kbid, resource=resource_id)
Expand Down
7 changes: 4 additions & 3 deletions nucliadb/nucliadb/ingest/consumer/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,10 @@ async def subscription_worker(self, msg: Msg):
message_source = "<msg source not set>"
start = time.monotonic()

async with MessageProgressUpdater(
msg, nats_consumer_settings.nats_ack_wait * 0.66
), self.lock:
async with (
MessageProgressUpdater(msg, nats_consumer_settings.nats_ack_wait * 0.66),
self.lock,
):
logger.info(
f"Message processing: subject:{subject}, seqid: {seqid}, reply: {reply}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -672,8 +672,9 @@ async def test_ingest_autocommit_deadletter_marks_resource(
rid = str(uuid.uuid4())
message = make_message(kbid, rid)

with patch.object(processor, "notify_commit") as mock_notify, pytest.raises(
DeadletteredError
with (
patch.object(processor, "notify_commit") as mock_notify,
pytest.raises(DeadletteredError),
):
# cause an error to force deadletter handling
mock_notify.side_effect = Exception("test")
Expand Down
11 changes: 6 additions & 5 deletions nucliadb/nucliadb/ingest/tests/unit/consumer/test_auditing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ def shard_manager(reader):
nm = MagicMock()
node = MagicMock(reader=reader)
nm.get_shards_by_kbid = AsyncMock(return_value=[ShardObject()])
with patch(
"nucliadb.ingest.consumer.auditing.get_shard_manager", return_value=nm
), patch(
"nucliadb.ingest.consumer.auditing.choose_node",
return_value=(node, "shard_id"),
with (
patch("nucliadb.ingest.consumer.auditing.get_shard_manager", return_value=nm),
patch(
"nucliadb.ingest.consumer.auditing.choose_node",
return_value=(node, "shard_id"),
),
):
yield nm

Expand Down
20 changes: 12 additions & 8 deletions nucliadb/nucliadb/ingest/tests/unit/consumer/test_shard_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,18 @@ def shard_manager(reader):
shards = Shards(shards=[ShardObject(read_only=False)], actual=0)
sm.get_current_active_shard = AsyncMock(return_value=shards.shards[0])
sm.maybe_create_new_shard = AsyncMock()
with patch(
"nucliadb.ingest.consumer.shard_creator.get_shard_manager", return_value=sm
), patch(
"nucliadb.ingest.consumer.shard_creator.choose_node",
return_value=(node, "shard_id"),
), patch(
"nucliadb.ingest.consumer.shard_creator.locking.distributed_lock",
return_value=AsyncMock(),
with (
patch(
"nucliadb.ingest.consumer.shard_creator.get_shard_manager", return_value=sm
),
patch(
"nucliadb.ingest.consumer.shard_creator.choose_node",
return_value=(node, "shard_id"),
),
patch(
"nucliadb.ingest.consumer.shard_creator.locking.distributed_lock",
return_value=AsyncMock(),
),
):
yield sm

Expand Down
13 changes: 8 additions & 5 deletions nucliadb/nucliadb/ingest/tests/unit/service/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,11 +660,14 @@ async def test_Index(self, writer: WriterServicer):
request = writer_pb2.IndexResource(kbid="kbid", rid="rid")

txn = AsyncMock()
with patch(
"nucliadb.ingest.service.writer.get_partitioning"
) as get_partitioning, patch(
"nucliadb.ingest.service.writer.get_transaction_utility",
MagicMock(return_value=txn),
with (
patch(
"nucliadb.ingest.service.writer.get_partitioning"
) as get_partitioning,
patch(
"nucliadb.ingest.service.writer.get_transaction_utility",
MagicMock(return_value=txn),
),
):
resp = await writer.Index(request)

Expand Down
2 changes: 1 addition & 1 deletion nucliadb/nucliadb/search/api/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ def fastapi_query(param: ParamDefault, default: Optional[Any] = _NOT_SET, **kw)
le=param.le,
gt=param.gt,
max_length=param.max_items,
**kw
**kw,
)
6 changes: 4 additions & 2 deletions nucliadb/nucliadb/search/requesters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import asyncio
import json
from enum import Enum
from typing import Any, Optional, TypeVar, Union, overload, Sequence
from typing import Any, Optional, Sequence, TypeVar, Union, overload

from fastapi import HTTPException
from google.protobuf.json_format import MessageToDict
Expand Down Expand Up @@ -130,7 +130,9 @@ async def node_query(
pb_query: REQUEST_TYPE,
target_shard_replicas: Optional[list[str]] = None,
use_read_replica_nodes: bool = True,
) -> tuple[Sequence[Union[T, BaseException]], bool, list[tuple[AbstractIndexNode, str]]]:
) -> tuple[
Sequence[Union[T, BaseException]], bool, list[tuple[AbstractIndexNode, str]]
]:
use_read_replica_nodes = use_read_replica_nodes and has_feature(
const.Features.READ_REPLICA_SEARCHES, context={"kbid": kbid}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,11 @@ def _create_find_result(
@pytest.mark.asyncio
async def test_default_prompt_context(kb):
result_text = " ".join(["text"] * 10)
with patch("nucliadb.search.search.chat.prompt.get_read_only_transaction"), patch(
"nucliadb.search.search.chat.prompt.get_storage"
), patch("nucliadb.search.search.chat.prompt.KnowledgeBoxORM", return_value=kb):
with (
patch("nucliadb.search.search.chat.prompt.get_read_only_transaction"),
patch("nucliadb.search.search.chat.prompt.get_storage"),
patch("nucliadb.search.search.chat.prompt.KnowledgeBoxORM", return_value=kb),
):
context = chat_prompt.CappedPromptContext(max_size=int(1e6))
find_results = KnowledgeboxFindResults(
facets={},
Expand Down
5 changes: 3 additions & 2 deletions nucliadb/nucliadb/search/tests/unit/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@

@pytest.fixture(scope="function")
def run_fastapi_with_metrics():
with patch("nucliadb.search.run.run_fastapi_with_metrics") as mocked, patch(
"nucliadb.search.run.instrument_app"
with (
patch("nucliadb.search.run.run_fastapi_with_metrics") as mocked,
patch("nucliadb.search.run.instrument_app"),
):
yield mocked

Expand Down
16 changes: 8 additions & 8 deletions nucliadb/nucliadb/standalone/tests/unit/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@

@pytest.fixture(scope="function", autouse=True)
def mocked_deps():
with mock.patch("uvicorn.Server.run"), mock.patch(
"pydantic_argparse.ArgumentParser.parse_typed_args", return_value=Settings()
), mock.patch(
f"{STANDALONE_RUN}.get_latest_nucliadb", return_value="1.0.0"
), mock.patch(
"uvicorn.Server.startup"
), mock.patch(
f"{STANDALONE_RUN}.run_migrations"
with (
mock.patch("uvicorn.Server.run"),
mock.patch(
"pydantic_argparse.ArgumentParser.parse_typed_args", return_value=Settings()
),
mock.patch(f"{STANDALONE_RUN}.get_latest_nucliadb", return_value="1.0.0"),
mock.patch("uvicorn.Server.startup"),
mock.patch(f"{STANDALONE_RUN}.run_migrations"),
):
yield

Expand Down
9 changes: 6 additions & 3 deletions nucliadb/nucliadb/tests/migrations/test_migration_0018.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ async def test_migration_0018_global(maindb_driver: Driver):
execution_context = Mock()
execution_context.kv_driver = maindb_driver

with patch("nucliadb.ingest.orm.knowledgebox.get_storage", new=AsyncMock()), patch(
"nucliadb.ingest.orm.knowledgebox.get_shard_manager",
new=Mock(return_value=AsyncMock()),
with (
patch("nucliadb.ingest.orm.knowledgebox.get_storage", new=AsyncMock()),
patch(
"nucliadb.ingest.orm.knowledgebox.get_shard_manager",
new=Mock(return_value=AsyncMock()),
),
):
# setup some orphan /kbslugs keys and some real ones
async with maindb_driver.transaction() as txn:
Expand Down
18 changes: 10 additions & 8 deletions nucliadb/nucliadb/tests/unit/common/cluster/discovery/test_k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,16 @@ def writer_stub():
node_id="node_id", shard_count=1, available_disk=10, total_disk=10
)
)
with patch(
"nucliadb.common.cluster.discovery.base.nodewriter_pb2_grpc.NodeWriterStub",
return_value=writer_stub,
), patch(
"nucliadb.common.cluster.discovery.base.replication_pb2_grpc.ReplicationServiceStub",
return_value=writer_stub,
), patch(
"nucliadb.common.cluster.discovery.base.get_traced_grpc_channel"
with (
patch(
"nucliadb.common.cluster.discovery.base.nodewriter_pb2_grpc.NodeWriterStub",
return_value=writer_stub,
),
patch(
"nucliadb.common.cluster.discovery.base.replication_pb2_grpc.ReplicationServiceStub",
return_value=writer_stub,
),
patch("nucliadb.common.cluster.discovery.base.get_traced_grpc_channel"),
):
yield writer_stub

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@
@pytest.fixture
def cluster_settings():
settings = Settings()
with patch(
"nucliadb.common.cluster.standalone.service.cluster_settings", settings
), tempfile.TemporaryDirectory() as tmpdir:
with (
patch("nucliadb.common.cluster.standalone.service.cluster_settings", settings),
tempfile.TemporaryDirectory() as tmpdir,
):
settings.data_path = tmpdir
os.makedirs(os.path.join(tmpdir, "shards"))
yield settings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
@pytest.fixture
def cluster_settings():
settings = Settings()
with patch(
"nucliadb.common.cluster.standalone.utils.cluster_settings", settings
), tempfile.TemporaryDirectory() as tmpdir:
with (
patch("nucliadb.common.cluster.standalone.utils.cluster_settings", settings),
tempfile.TemporaryDirectory() as tmpdir,
):
settings.data_path = tmpdir
yield settings

Expand All @@ -52,8 +53,9 @@ def test_get_self_k8s_host(cluster_settings: Settings, monkeypatch):
monkeypatch.setenv("NUCLIADB_SERVICE_HOST", "host")
monkeypatch.setenv("HOSTNAME", "nucliadb-0")

with patch(
"nucliadb.common.cluster.standalone.grpc_node_binding.NodeWriter"
), patch("nucliadb.common.cluster.standalone.grpc_node_binding.NodeReader"):
with (
patch("nucliadb.common.cluster.standalone.grpc_node_binding.NodeWriter"),
patch("nucliadb.common.cluster.standalone.grpc_node_binding.NodeReader"),
):
# patch because loading settings validates address now
assert utils.get_self().address == "nucliadb-0.nucliadb"
16 changes: 10 additions & 6 deletions nucliadb/nucliadb/tests/unit/common/cluster/test_rollover.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,16 @@ async def _mock_indexed_keys(kbid):

mock.iter_indexed_keys = _mock_indexed_keys

with patch("nucliadb.common.cluster.rollover.datamanagers.rollover", mock), patch(
"nucliadb.common.cluster.rollover.datamanagers.with_transaction",
return_value=AsyncMock(),
), patch(
"nucliadb.ingest.consumer.shard_creator.locking.distributed_lock",
return_value=AsyncMock(),
with (
patch("nucliadb.common.cluster.rollover.datamanagers.rollover", mock),
patch(
"nucliadb.common.cluster.rollover.datamanagers.with_transaction",
return_value=AsyncMock(),
),
patch(
"nucliadb.ingest.consumer.shard_creator.locking.distributed_lock",
return_value=AsyncMock(),
),
):
yield mock

Expand Down
40 changes: 25 additions & 15 deletions nucliadb/nucliadb/tests/unit/common/maindb/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,46 +36,56 @@ def reset_driver_utils():
@pytest.mark.asyncio
async def test_setup_driver_redis():
mock = AsyncMock(initialized=False)
with patch.object(settings, "driver", "redis"), patch.object(
settings, "driver_redis_url", "driver_redis_url"
), patch("nucliadb.common.maindb.utils.RedisDriver", return_value=mock):
with (
patch.object(settings, "driver", "redis"),
patch.object(settings, "driver_redis_url", "driver_redis_url"),
patch("nucliadb.common.maindb.utils.RedisDriver", return_value=mock),
):
assert await setup_driver() == mock
mock.initialize.assert_awaited_once()


@pytest.mark.asyncio
async def test_setup_driver_tikv():
mock = AsyncMock(initialized=False)
with patch.object(settings, "driver", "tikv"), patch.object(
settings, "driver_tikv_url", "driver_tikv_url"
), patch("nucliadb.common.maindb.utils.TiKVDriver", return_value=mock):
with (
patch.object(settings, "driver", "tikv"),
patch.object(settings, "driver_tikv_url", "driver_tikv_url"),
patch("nucliadb.common.maindb.utils.TiKVDriver", return_value=mock),
):
assert await setup_driver() == mock
mock.initialize.assert_awaited_once()


@pytest.mark.asyncio
async def test_setup_driver_pg():
mock = AsyncMock(initialized=False)
with patch.object(settings, "driver", "pg"), patch.object(
settings, "driver_pg_url", "driver_pg_url"
), patch("nucliadb.common.maindb.utils.PGDriver", return_value=mock):
with (
patch.object(settings, "driver", "pg"),
patch.object(settings, "driver_pg_url", "driver_pg_url"),
patch("nucliadb.common.maindb.utils.PGDriver", return_value=mock),
):
assert await setup_driver() == mock
mock.initialize.assert_awaited_once()


@pytest.mark.asyncio
async def test_setup_driver_local():
mock = AsyncMock(initialized=False)
with patch.object(settings, "driver", "local"), patch.object(
settings, "driver_local_url", "driver_local_url"
), patch("nucliadb.common.maindb.utils.LocalDriver", return_value=mock):
with (
patch.object(settings, "driver", "local"),
patch.object(settings, "driver_local_url", "driver_local_url"),
patch("nucliadb.common.maindb.utils.LocalDriver", return_value=mock),
):
assert await setup_driver() == mock
mock.initialize.assert_awaited_once()


@pytest.mark.asyncio
async def test_setup_driver_error():
with patch.object(settings, "driver", "pg"), patch.object(
settings, "driver_pg_url", None
), pytest.raises(ConfigurationError):
with (
patch.object(settings, "driver", "pg"),
patch.object(settings, "driver_pg_url", None),
pytest.raises(ConfigurationError),
):
await setup_driver()
12 changes: 6 additions & 6 deletions nucliadb/nucliadb/tests/unit/http_clients/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def test_check_status():


def test_get_processing_api_url():
with mock.patch.object(
nuclia_settings, "nuclia_service_account", "sa"
), mock.patch.object(
nuclia_settings, "nuclia_zone", "nuclia_zone"
), mock.patch.object(
nuclia_settings, "nuclia_public_url", "https://{zone}.nuclia_public_url"
with (
mock.patch.object(nuclia_settings, "nuclia_service_account", "sa"),
mock.patch.object(nuclia_settings, "nuclia_zone", "nuclia_zone"),
mock.patch.object(
nuclia_settings, "nuclia_public_url", "https://{zone}.nuclia_public_url"
),
):
assert (
processing.get_processing_api_url()
Expand Down
Loading

0 comments on commit 2e83ee7

Please sign in to comment.