Skip to content

Commit

Permalink
Use new type annotation syntax (#1677)
Browse files Browse the repository at this point in the history
* Use new type annotation syntax

* Update setup.py for packages with support for 3.8

* Run pre-checks with all supported python versions
  • Loading branch information
jotare authored Dec 19, 2023
1 parent a2a9376 commit db85e1c
Show file tree
Hide file tree
Showing 70 changed files with 389 additions and 395 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nucliadb_dataset.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
pre-checks:
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
uses: ./.github/workflows/_component_prechecks.yml
with:
python_version: "${{ matrix.python-version }}"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/nucliadb_models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ jobs:
name: NucliaDBModelsTests
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
uses: ./.github/workflows/_component_prechecks.yml
with:
python_version: "${{ matrix.python-version }}"
component: "nucliadb_models"
component: "nucliadb_models"
2 changes: 1 addition & 1 deletion .github/workflows/nucliadb_sdk.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
pre-checks:
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
uses: ./.github/workflows/_component_prechecks.yml
with:
python_version: "${{ matrix.python-version }}"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/nucliadb_utils.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
pre-checks:
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
uses: ./.github/workflows/_component_prechecks.yml
with:
python_version: "${{ matrix.python-version }}"
Expand Down
6 changes: 3 additions & 3 deletions nucliadb/nucliadb/common/cluster/grpc_node_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
from typing import Any, Dict, List
from typing import Any

from nucliadb_protos.nodereader_pb2 import (
EdgeList,
Expand All @@ -39,7 +39,7 @@


class DummyWriterStub: # pragma: no cover
calls: Dict[str, List[Any]] = {}
calls: dict[str, list[Any]] = {}

async def NewShard(self, data): # pragma: no cover
self.calls.setdefault("NewShard", []).append(data)
Expand Down Expand Up @@ -90,7 +90,7 @@ async def GC(self, request: ShardId) -> EmptyResponse: # pragma: no cover


class DummyReaderStub: # pragma: no cover
calls: Dict[str, List[Any]] = {}
calls: dict[str, list[Any]] = {}

async def GetShard(self, data): # pragma: no cover
self.calls.setdefault("GetShard", []).append(data)
Expand Down
4 changes: 2 additions & 2 deletions nucliadb/nucliadb/common/maindb/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List, Optional
from typing import AsyncGenerator, Optional

DEFAULT_SCAN_LIMIT = 10
DEFAULT_BATCH_SCAN_LIMIT = 500
Expand Down Expand Up @@ -60,7 +60,7 @@ async def count(self, match: str) -> int:

class Driver:
initialized = False
_abort_tasks: List[asyncio.Task] = []
_abort_tasks: list[asyncio.Task] = []

async def initialize(self):
raise NotImplementedError()
Expand Down
8 changes: 4 additions & 4 deletions nucliadb/nucliadb/common/maindb/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#
import glob
import os
from typing import Dict, List, Optional
from typing import Optional

from nucliadb.common.maindb.driver import (
DEFAULT_BATCH_SCAN_LIMIT,
Expand All @@ -37,9 +37,9 @@


class LocalTransaction(Transaction):
modified_keys: Dict[str, bytes]
visited_keys: Dict[str, bytes]
deleted_keys: List[str]
modified_keys: dict[str, bytes]
visited_keys: dict[str, bytes]
deleted_keys: list[str]

def __init__(self, url: str, driver: Driver):
self.url = url
Expand Down
8 changes: 4 additions & 4 deletions nucliadb/nucliadb/common/maindb/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import annotations

import asyncio
from typing import Any, AsyncGenerator, List, Optional, Union
from typing import Any, AsyncGenerator, Optional, Union

import asyncpg
import backoff
Expand Down Expand Up @@ -69,7 +69,7 @@ async def delete(self, key: str) -> None:
async with self.lock:
await self.connection.execute("DELETE FROM resources WHERE key = $1", key)

async def batch_get(self, keys: List[str]) -> List[Optional[bytes]]:
async def batch_get(self, keys: list[str]) -> list[Optional[bytes]]:
async with self.lock:
records = {
record["key"]: record["value"]
Expand Down Expand Up @@ -146,7 +146,7 @@ async def commit(self):
self.open = False
await self.connection.close()

async def batch_get(self, keys: List[str]):
async def batch_get(self, keys: list[str]):
return await self.data_layer.batch_get(keys)

async def get(self, key: str) -> Optional[bytes]:
Expand Down Expand Up @@ -189,7 +189,7 @@ async def abort(self):
async def commit(self):
...

async def batch_get(self, keys: List[str]):
async def batch_get(self, keys: list[str]):
return await DataLayer(self.pool).batch_get(keys)

async def get(self, key: str) -> Optional[bytes]:
Expand Down
10 changes: 5 additions & 5 deletions nucliadb/nucliadb/common/maindb/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from nucliadb.common.maindb.driver import (
DEFAULT_BATCH_SCAN_LIMIT,
Expand All @@ -35,9 +35,9 @@


class RedisTransaction(Transaction):
modified_keys: Dict[str, bytes]
visited_keys: Dict[str, bytes]
deleted_keys: List[str]
modified_keys: dict[str, bytes]
visited_keys: dict[str, bytes]
deleted_keys: list[str]

def __init__(self, redis: Any, driver: Driver):
self.redis = redis
Expand Down Expand Up @@ -84,7 +84,7 @@ async def batch_get(self, keys: list[str]) -> list[Optional[bytes]]:
if len(keys) == 0:
return []

bytes_keys: List[bytes] = [x.encode() for x in keys]
bytes_keys: list[bytes] = [x.encode() for x in keys]
results = await self.redis.mget(bytes_keys)

for idx, key in enumerate(keys):
Expand Down
4 changes: 2 additions & 2 deletions nucliadb/nucliadb/common/maindb/tikv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import asyncio
import logging
from typing import Any, List, Optional
from typing import Any, Optional

import backoff

Expand Down Expand Up @@ -221,7 +221,7 @@ async def count(self, match: str) -> int:
class TiKVDriver(Driver):
tikv = None

def __init__(self, url: List[str]):
def __init__(self, url: list[str]):
if TiKV is False:
raise ImportError("TiKV is not installed")
self.url = url
Expand Down
4 changes: 2 additions & 2 deletions nucliadb/nucliadb/ingest/consumer/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#
import asyncio
import base64
from typing import List, Optional
from typing import Optional

import nats
import nats.errors
Expand Down Expand Up @@ -48,7 +48,7 @@ class PullWorker:
The processing pull endpoint is also described as the "processing proxy" at times.
"""

subscriptions: List[Subscription]
subscriptions: list[Subscription]

def __init__(
self,
Expand Down
10 changes: 5 additions & 5 deletions nucliadb/nucliadb/ingest/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import enum
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Optional, Type

from nucliadb_protos.resources_pb2 import (
CloudFile,
Expand Down Expand Up @@ -285,7 +285,7 @@ async def get_extracted_text_cf(self) -> Optional[CloudFile]:

async def set_vectors(
self, payload: ExtractedVectorsWrapper
) -> Tuple[Optional[VectorObject], bool, List[str]]:
) -> tuple[Optional[VectorObject], bool, list[str]]:
if self.type in SUBFIELDFIELDS:
try:
actual_payload: Optional[VectorObject] = await self.get_vectors(
Expand Down Expand Up @@ -341,7 +341,7 @@ async def get_vectors(self, force=False) -> Optional[VectorObject]:

async def set_user_vectors(
self, user_vectors: UserVectorsWrapper
) -> Tuple[UserVectorSet, Dict[str, UserVectorsList]]:
) -> tuple[UserVectorSet, dict[str, UserVectorsList]]:
try:
actual_payload: Optional[UserVectorSet] = await self.get_user_vectors(
force=True
Expand All @@ -351,7 +351,7 @@ async def set_user_vectors(

sf = self.get_storage_field(FieldTypes.USER_FIELD_VECTORS)

vectors_to_delete: Dict[str, UserVectorsList] = {}
vectors_to_delete: dict[str, UserVectorsList] = {}
if actual_payload is not None:
for vectorset, user_vector in user_vectors.vectors.vectors.items():
for key, vector in user_vector.vectors.items():
Expand Down Expand Up @@ -392,7 +392,7 @@ async def get_vectors_cf(self) -> Optional[CloudFile]:

async def set_field_metadata(
self, payload: FieldComputedMetadataWrapper
) -> Tuple[FieldComputedMetadata, List[str], Dict[str, List[str]]]:
) -> tuple[FieldComputedMetadata, list[str], dict[str, list[str]]]:
if self.type in SUBFIELDFIELDS:
try:
actual_payload: Optional[
Expand Down
6 changes: 3 additions & 3 deletions nucliadb/nucliadb/ingest/fields/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
import uuid
from typing import Any, Dict, Optional
from typing import Any, Optional

from nucliadb_protos.resources_pb2 import CloudFile
from nucliadb_protos.resources_pb2 import Conversation as PBConversation
Expand All @@ -39,7 +39,7 @@ class PageNotFound(Exception):
class Conversation(Field):
pbklass = PBConversation
type: str = "c"
value: Dict[int, PBConversation]
value: dict[int, PBConversation]
metadata: Optional[FieldConversation]

_created: bool = False
Expand All @@ -49,7 +49,7 @@ def __init__(
id: str,
resource: Any,
pb: Optional[Any] = None,
value: Optional[Dict[int, PBConversation]] = None,
value: Optional[dict[int, PBConversation]] = None,
):
super(Conversation, self).__init__(id, resource, pb, value)
self.value = {}
Expand Down
20 changes: 10 additions & 10 deletions nucliadb/nucliadb/ingest/orm/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

from typing import AsyncGenerator, Dict, List, Optional, Set, Tuple
from typing import AsyncGenerator, Optional

from nucliadb_protos.knowledgebox_pb2 import (
DeletedEntitiesGroups,
Expand Down Expand Up @@ -87,13 +87,13 @@ async def get_entities_group(self, group: str) -> Optional[EntitiesGroup]:
return None
return await self.get_entities_group_inner(group)

async def get_entities_groups(self) -> Dict[str, EntitiesGroup]:
async def get_entities_groups(self) -> dict[str, EntitiesGroup]:
groups = {}
async for group, eg in self.iterate_entities_groups(exclude_deleted=True):
groups[group] = eg
return groups

async def list_entities_groups(self) -> Dict[str, EntitiesGroupSummary]:
async def list_entities_groups(self) -> dict[str, EntitiesGroupSummary]:
groups = {}
async for group in self.iterate_entities_groups_names(exclude_deleted=True):
stored = await self.get_stored_entities_group(group)
Expand All @@ -107,7 +107,7 @@ async def list_entities_groups(self) -> Dict[str, EntitiesGroupSummary]:
groups[group] = EntitiesGroupSummary()
return groups

async def update_entities(self, group: str, entities: Dict[str, Entity]):
async def update_entities(self, group: str, entities: dict[str, Entity]):
"""Update entities on an entity group. New entities are appended and existing
are overwritten. Existing entities not appearing in `entities` are left
intact. Use `delete_entities` to delete them instead.
Expand Down Expand Up @@ -157,7 +157,7 @@ async def set_entities_group_metadata(

await self.store_entities_group(group, entities_group)

async def delete_entities(self, group: str, delete: List[str]):
async def delete_entities(self, group: str, delete: list[str]):
stored = await self.get_stored_entities_group(group)

stored = stored or EntitiesGroup()
Expand Down Expand Up @@ -229,8 +229,8 @@ async def do_entities_search(
eg = EntitiesGroup(entities=entities)
return eg

async def get_deleted_entities_groups(self) -> Set[str]:
deleted: Set[str] = set()
async def get_deleted_entities_groups(self) -> set[str]:
deleted: set[str] = set()
key = KB_DELETED_ENTITIES_GROUPS.format(kbid=self.kbid)
payload = await self.txn.get(key)
if payload:
Expand All @@ -252,7 +252,7 @@ async def entities_group_exists(self, group: str) -> bool:

async def iterate_entities_groups(
self, exclude_deleted: bool
) -> AsyncGenerator[Tuple[str, EntitiesGroup], None]:
) -> AsyncGenerator[tuple[str, EntitiesGroup], None]:
async for group in self.iterate_entities_groups_names(exclude_deleted):
eg = await self.get_entities_group_inner(group)
if eg is None:
Expand Down Expand Up @@ -284,7 +284,7 @@ async def iterate_entities_groups_names(
yield group
visited_groups.add(group)

async def get_indexed_entities_groups_names(self) -> Set[str]:
async def get_indexed_entities_groups_names(self) -> set[str]:
shard_manager = get_shard_manager()

async def query_indexed_entities_group_names(
Expand Down Expand Up @@ -367,7 +367,7 @@ def merge_entities_groups(indexed: EntitiesGroup, stored: EntitiesGroup):
`indexed` share entities. That's also true for common fields.
"""
merged_entities: Dict[str, Entity] = {}
merged_entities: dict[str, Entity] = {}
merged_entities.update(indexed.entities)
merged_entities.update(stored.entities)

Expand Down
8 changes: 4 additions & 4 deletions nucliadb/nucliadb/ingest/orm/knowledgebox.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
from datetime import datetime
from typing import AsyncGenerator, AsyncIterator, Optional, Sequence, Tuple
from typing import AsyncGenerator, AsyncIterator, Optional, Sequence
from uuid import uuid4

from grpc import StatusCode
Expand Down Expand Up @@ -161,7 +161,7 @@ async def get_kb_uuid(cls, txn: Transaction, slug: str) -> Optional[str]:
@classmethod
async def get_kbs(
cls, txn: Transaction, slug: str, count: int = -1
) -> AsyncIterator[Tuple[str, str]]:
) -> AsyncIterator[tuple[str, str]]:
async for key in txn.keys(KB_SLUGS.format(slug=slug), count=count):
slug = key.replace(KB_SLUGS_BASE, "")
uuid = await cls.get_kb_uuid(txn, slug)
Expand All @@ -179,7 +179,7 @@ async def create(
uuid: Optional[str] = None,
config: Optional[KnowledgeBoxConfig] = None,
release_channel: ReleaseChannel.ValueType = ReleaseChannel.STABLE,
) -> Tuple[str, bool]:
) -> tuple[str, bool]:
failed = False
exist = await cls.get_kb_uuid(txn, slug)
if exist:
Expand Down Expand Up @@ -272,7 +272,7 @@ async def update(

return uuid

async def iterate_kb_nodes(self) -> AsyncIterator[Tuple[AbstractIndexNode, str]]:
async def iterate_kb_nodes(self) -> AsyncIterator[tuple[AbstractIndexNode, str]]:
shards_obj = await self.data_manager.get_shards_object(self.kbid)

for shard in shards_obj.shards:
Expand Down
Loading

0 comments on commit db85e1c

Please sign in to comment.