diff --git a/src/ralph/backends/data/async_es.py b/src/ralph/backends/data/async_es.py index 191addd41..d156aa4e3 100644 --- a/src/ralph/backends/data/async_es.py +++ b/src/ralph/backends/data/async_es.py @@ -3,7 +3,7 @@ import logging from io import IOBase from itertools import chain -from typing import Iterable, Iterator, Optional, Union +from typing import Iterable, Iterator, Optional, TypeVar, Union from elasticsearch import ApiError, AsyncElasticsearch, TransportError from elasticsearch.helpers import BulkIndexError, async_streaming_bulk @@ -21,23 +21,24 @@ from ralph.utils import parse_bytes_to_dict, read_raw logger = logging.getLogger(__name__) +Settings = TypeVar("Settings", bound=ESDataBackendSettings) -class AsyncESDataBackend(BaseAsyncDataBackend, AsyncWritable, AsyncListable): +class AsyncESDataBackend( + BaseAsyncDataBackend[Settings, ESQuery], AsyncWritable, AsyncListable +): """Asynchronous Elasticsearch data backend.""" name = "async_es" - query_class = ESQuery - settings_class = ESDataBackendSettings - def __init__(self, settings: Optional[ESDataBackendSettings] = None): + def __init__(self, settings: Optional[Settings] = None): """Instantiate the asynchronous Elasticsearch client. Args: settings (ESDataBackendSettings or None): The data backend settings. If `settings` is `None`, a default settings instance is used instead. """ - self.settings = settings if settings else self.settings_class() + super().__init__(settings) self._client = None @property diff --git a/src/ralph/backends/data/async_mongo.py b/src/ralph/backends/data/async_mongo.py index 50109cef9..02272b336 100644 --- a/src/ralph/backends/data/async_mongo.py +++ b/src/ralph/backends/data/async_mongo.py @@ -4,7 +4,7 @@ import logging from io import IOBase from itertools import chain -from typing import Any, Dict, Iterable, Iterator, Optional, Union +from typing import Any, Dict, Iterable, Iterator, Optional, TypeVar, Union from bson.errors import BSONError from motor.motor_asyncio import AsyncIOMotorClient @@ -29,22 +29,25 @@ ) logger = logging.getLogger(__name__) +Settings = TypeVar("Settings", bound=MongoDataBackendSettings) -class AsyncMongoDataBackend(BaseAsyncDataBackend, AsyncWritable, AsyncListable): +class AsyncMongoDataBackend( + BaseAsyncDataBackend[Settings, MongoQuery], + AsyncWritable, + AsyncListable, +): """Async MongoDB data backend.""" name = "async_mongo" - query_class = MongoQuery - settings_class = MongoDataBackendSettings - def __init__(self, settings: Optional[MongoDataBackendSettings] = None): + def __init__(self, settings: Optional[Settings] = None): """Instantiate the asynchronous MongoDB client. Args: settings (MongoDataBackendSettings or None): The data backend settings. """ - self.settings = settings if settings else self.settings_class() + super().__init__(settings) self.client = AsyncIOMotorClient( self.settings.CONNECTION_URI, **self.settings.CLIENT_OPTIONS.dict() ) diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py index b505dbd45..fd60a0d6e 100644 --- a/src/ralph/backends/data/base.py +++ b/src/ralph/backends/data/base.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import Enum, unique from io import IOBase -from typing import Iterable, Iterator, Optional, Union +from typing import Any, Generic, Iterable, Iterator, Optional, Type, TypeVar, Union from pydantic import BaseModel, BaseSettings, ValidationError @@ -34,7 +34,7 @@ class Config: extra = "forbid" - query_string: Union[str, None] + query_string: Union[str, None] = None @unique @@ -144,25 +144,65 @@ def list( """ -class BaseDataBackend(ABC): +def get_backend_generic_argument(backend_class: Type, position: int) -> Optional[Type]: + """Return the generic argument of `backend_class` at specified `position`.""" + if not hasattr(backend_class, "__orig_bases__"): + return None + + bases = backend_class.__orig_bases__[0] + if not hasattr(bases, "__args__") or len(bases.__args__) < abs(position) + 1: + return None + + argument = bases.__args__[position] + if argument is Any: + return None + + if isinstance(argument, TypeVar): + return argument.__bound__ + + return argument + + +def set_backend_settings_class(backend_class: Type): + """Set `settings_class` attribute with `Config.env_prefix` for `backend_class`.""" + settings_class = get_backend_generic_argument(backend_class, 0) + if settings_class: + backend_class.settings_class = settings_class + + +def set_backend_query_class(backend_class: Type): + """Set `query_class` attribute for `backend_class`.""" + query_class = get_backend_generic_argument(backend_class, 1) + if query_class: + backend_class.query_class = query_class + + +Settings = TypeVar("Settings", bound=BaseDataBackendSettings) +Query = TypeVar("Query", bound=BaseQuery) + + +class BaseDataBackend(Generic[Settings, Query], ABC): """Base data backend interface.""" name = "base" - query_class = BaseQuery - settings_class = BaseDataBackendSettings + query_class: Type[Query] + settings_class: Type[Settings] - @abstractmethod - def __init__(self, settings: Optional[BaseDataBackendSettings] = None): + def __init_subclass__(cls, **kwargs): # noqa: D105 + super().__init_subclass__(**kwargs) + set_backend_settings_class(cls) + set_backend_query_class(cls) + + def __init__(self, settings: Optional[Settings] = None): """Instantiate the data backend. Args: - settings (BaseDataBackendSettings or None): The data backend settings. + settings (Settings or None): The data backend settings. If `settings` is `None`, a default settings instance is used instead. """ + self.settings: Settings = settings if settings else self.settings_class() - def validate_query( - self, query: Union[str, dict, BaseQuery, None] = None - ) -> BaseQuery: + def validate_query(self, query: Union[str, dict, Query, None] = None) -> Query: """Validate and transform the query.""" if query is None: query = self.query_class() @@ -203,7 +243,7 @@ def status(self) -> DataBackendStatus: def read( # noqa: PLR0913 self, *, - query: Optional[Union[str, BaseQuery]] = None, + query: Optional[Union[str, Query]] = None, target: Optional[str] = None, chunk_size: Optional[int] = None, raw_output: bool = False, @@ -212,7 +252,7 @@ def read( # noqa: PLR0913 """Read records matching the `query` in the `target` container and yield them. Args: - query: (str or BaseQuery): The query to select records to read. + query: (str or Query): The query to select records to read. target (str or None): The target container name. If `target` is `None`, a default value is used instead. chunk_size (int or None): The number of records or bytes to read in one @@ -324,21 +364,26 @@ async def list( """ -class BaseAsyncDataBackend(ABC): +class BaseAsyncDataBackend(Generic[Settings, Query], ABC): """Base async data backend interface.""" name = "base" - query_class = BaseQuery - settings_class = BaseDataBackendSettings + query_class: Type[Query] + settings_class: Type[Settings] - @abstractmethod - def __init__(self, settings: Optional[BaseDataBackendSettings] = None): + def __init_subclass__(cls, **kwargs): # noqa: D105 + super().__init_subclass__(**kwargs) + set_backend_settings_class(cls) + set_backend_query_class(cls) + + def __init__(self, settings: Optional[Settings] = None): """Instantiate the data backend. Args: - settings (BaseDataBackendSettings or None): The backend settings. + settings (Settings or None): The backend settings. If `settings` is `None`, a default settings instance is used instead. """ + self.settings: Settings = settings if settings else self.settings_class() def validate_query( self, query: Union[str, dict, BaseQuery, None] = None @@ -383,7 +428,7 @@ async def status(self) -> DataBackendStatus: async def read( # noqa: PLR0913 self, *, - query: Optional[Union[str, BaseQuery]] = None, + query: Optional[Union[str, Query]] = None, target: Optional[str] = None, chunk_size: Optional[int] = None, raw_output: bool = False, @@ -392,7 +437,7 @@ async def read( # noqa: PLR0913 """Read records matching the `query` in the `target` container and yield them. Args: - query: (str or BaseQuery): The query to select records to read. + query: (str or Query): The query to select records to read. target (str or None): The target container name. If `target` is `None`, a default value is used instead. chunk_size (int or None): The number of records or bytes to read in one diff --git a/src/ralph/backends/data/clickhouse.py b/src/ralph/backends/data/clickhouse.py index ea75a33b1..5373da0db 100755 --- a/src/ralph/backends/data/clickhouse.py +++ b/src/ralph/backends/data/clickhouse.py @@ -14,6 +14,7 @@ List, NamedTuple, Optional, + TypeVar, Union, ) from uuid import UUID, uuid4 @@ -108,22 +109,27 @@ class ClickHouseQuery(BaseClickHouseQuery): query_string: Union[Json[BaseClickHouseQuery], None] -class ClickHouseDataBackend(BaseDataBackend, Writable, Listable): +Settings = TypeVar("Settings", bound=ClickHouseDataBackendSettings) + + +class ClickHouseDataBackend( + BaseDataBackend[Settings, ClickHouseQuery], + Writable, + Listable, +): """ClickHouse database backend.""" name = "clickhouse" - query_class = ClickHouseQuery default_operation_type = BaseOperationType.CREATE - settings_class = ClickHouseDataBackendSettings - def __init__(self, settings: Optional[ClickHouseDataBackendSettings] = None): + def __init__(self, settings: Optional[Settings] = None): """Instantiate the ClickHouse configuration. Args: settings (ClickHouseDataBackendSettings or None): The ClickHouse data backend settings. """ - self.settings = settings if settings else self.settings_class() + super().__init__(settings) self.database = self.settings.DATABASE self.event_table_name = self.settings.EVENT_TABLE_NAME self.default_chunk_size = self.settings.DEFAULT_CHUNK_SIZE diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py index 03bb96b13..49dcb4253 100644 --- a/src/ralph/backends/data/es.py +++ b/src/ralph/backends/data/es.py @@ -4,7 +4,7 @@ from io import IOBase from itertools import chain from pathlib import Path -from typing import Iterable, Iterator, List, Literal, Optional, Union +from typing import Iterable, Iterator, List, Literal, Optional, TypeVar, Union from elasticsearch import ApiError, Elasticsearch, TransportError from elasticsearch.helpers import BulkIndexError, streaming_bulk @@ -111,21 +111,22 @@ class ESQuery(BaseQuery): track_total_hits: Literal[False] = False -class ESDataBackend(BaseDataBackend, Writable, Listable): +Settings = TypeVar("Settings", bound=ESDataBackendSettings) + + +class ESDataBackend(BaseDataBackend[Settings, ESQuery], Writable, Listable): """Elasticsearch data backend.""" name = "es" - query_class = ESQuery - settings_class = ESDataBackendSettings - def __init__(self, settings: Optional[ESDataBackendSettings] = None): + def __init__(self, settings: Optional[Settings] = None): """Instantiate the Elasticsearch data backend. Args: - settings (ESDataBackendSettings or None): The data backend settings. + settings (Settings or None): The data backend settings. If `settings` is `None`, a default settings instance is used instead. """ - self.settings = settings if settings else self.settings_class() + super().__init__(settings) self._client = None @property diff --git a/src/ralph/backends/data/fs.py b/src/ralph/backends/data/fs.py index 272cc4990..e3d160ee9 100644 --- a/src/ralph/backends/data/fs.py +++ b/src/ralph/backends/data/fs.py @@ -7,7 +7,7 @@ from io import IOBase from itertools import chain from pathlib import Path -from typing import IO, Iterable, Iterator, Optional, Union +from typing import IO, Iterable, Iterator, Optional, TypeVar, Union from uuid import uuid4 from ralph.backends.data.base import ( @@ -51,20 +51,28 @@ class Config(BaseSettingsConfig): LOCALE_ENCODING: str = "utf8" -class FSDataBackend(HistoryMixin, BaseDataBackend, Writable, Listable): +Settings = TypeVar("Settings", bound=FSDataBackendSettings) + + +class FSDataBackend( + BaseDataBackend[Settings, BaseQuery], + Writable, + Listable, + HistoryMixin, +): """FileSystem data backend.""" name = "fs" default_operation_type = BaseOperationType.CREATE - settings_class = FSDataBackendSettings - def __init__(self, settings: Optional[FSDataBackendSettings] = None): + def __init__(self, settings: Optional[Settings] = None): """Create the default target directory if it does not exist. Args: settings (FSDataBackendSettings or None): The data backend settings. If `settings` is `None`, a default settings instance is used instead. """ + super().__init__(settings) self.settings = settings if settings else self.settings_class() self.default_chunk_size = self.settings.DEFAULT_CHUNK_SIZE self.default_directory = self.settings.DEFAULT_DIRECTORY_PATH diff --git a/src/ralph/backends/data/ldp.py b/src/ralph/backends/data/ldp.py index 933e38fdf..adf194fe0 100644 --- a/src/ralph/backends/data/ldp.py +++ b/src/ralph/backends/data/ldp.py @@ -57,11 +57,14 @@ class Config(BaseSettingsConfig): SERVICE_NAME: Optional[str] = None -class LDPDataBackend(HistoryMixin, BaseDataBackend, Listable): +class LDPDataBackend( + BaseDataBackend[LDPDataBackendSettings, BaseQuery], + Listable, + HistoryMixin, +): """OVH LDP (Log Data Platform) data backend.""" name = "ldp" - settings_class = LDPDataBackendSettings def __init__(self, settings: Optional[LDPDataBackendSettings] = None): """Instantiate the OVH LDP client. @@ -70,7 +73,7 @@ def __init__(self, settings: Optional[LDPDataBackendSettings] = None): settings (LDPDataBackendSettings or None): The data backend settings. If `settings` is `None`, a default settings instance is used instead. """ - self.settings = settings if settings else self.settings_class() + super().__init__(settings) self.service_name = self.settings.SERVICE_NAME self.stream_id = self.settings.DEFAULT_STREAM_ID self.timeout = self.settings.REQUEST_TIMEOUT diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py index a9f266760..8f38aa64b 100644 --- a/src/ralph/backends/data/mongo.py +++ b/src/ralph/backends/data/mongo.py @@ -7,7 +7,7 @@ import struct from io import IOBase from itertools import chain -from typing import Generator, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Generator, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union from uuid import uuid4 from bson.errors import BSONError @@ -90,21 +90,22 @@ class MongoQuery(BaseMongoQuery): query_string: Union[Json[BaseMongoQuery], None] -class MongoDataBackend(BaseDataBackend, Writable, Listable): +Settings = TypeVar("Settings", bound=MongoDataBackendSettings) + + +class MongoDataBackend(BaseDataBackend[Settings, MongoQuery], Writable, Listable): """MongoDB data backend.""" name = "mongo" - query_class = MongoQuery - settings_class = MongoDataBackendSettings - def __init__(self, settings: Optional[MongoDataBackendSettings] = None): + def __init__(self, settings: Optional[Settings] = None): """Instantiate the MongoDB client. Args: settings (MongoDataBackendSettings or None): The data backend settings. If `settings` is `None`, a default settings instance is used instead. """ - self.settings = settings if settings else self.settings_class() + super().__init__(settings) self.client = MongoClient( self.settings.CONNECTION_URI, **self.settings.CLIENT_OPTIONS.dict() ) diff --git a/src/ralph/backends/data/s3.py b/src/ralph/backends/data/s3.py index 723d2ac4c..c04ca5b73 100644 --- a/src/ralph/backends/data/s3.py +++ b/src/ralph/backends/data/s3.py @@ -67,17 +67,17 @@ class Config(BaseSettingsConfig): LOCALE_ENCODING: str = "utf8" -class S3DataBackend(HistoryMixin, BaseDataBackend, Writable, Listable): +class S3DataBackend( + BaseDataBackend[S3DataBackendSettings, BaseQuery], Writable, Listable, HistoryMixin +): """S3 data backend.""" name = "s3" default_operation_type = BaseOperationType.CREATE - settings_class = S3DataBackendSettings def __init__(self, settings: Optional[S3DataBackendSettings] = None): """Instantiate the AWS S3 client.""" - self.settings = settings if settings else self.settings_class() - + super().__init__(settings) self.default_bucket_name = self.settings.DEFAULT_BUCKET_NAME self.default_chunk_size = self.settings.DEFAULT_CHUNK_SIZE self.locale_encoding = self.settings.LOCALE_ENCODING diff --git a/src/ralph/backends/data/swift.py b/src/ralph/backends/data/swift.py index 50d2e465e..77cc52817 100644 --- a/src/ralph/backends/data/swift.py +++ b/src/ralph/backends/data/swift.py @@ -64,17 +64,20 @@ class Config(BaseSettingsConfig): LOCALE_ENCODING: str = "utf8" -class SwiftDataBackend(HistoryMixin, BaseDataBackend, Writable, Listable): +class SwiftDataBackend( + BaseDataBackend[SwiftDataBackendSettings, BaseQuery], + HistoryMixin, + Writable, + Listable, +): """SWIFT data backend.""" name = "swift" default_operation_type = BaseOperationType.CREATE - settings_class = SwiftDataBackendSettings def __init__(self, settings: Optional[SwiftDataBackendSettings] = None): """Prepares the options for the SwiftService.""" - self.settings = settings if settings else self.settings_class() - + super().__init__(settings) self.default_container = self.settings.DEFAULT_CONTAINER self.locale_encoding = self.settings.LOCALE_ENCODING self._connection = None diff --git a/src/ralph/backends/lrs/async_es.py b/src/ralph/backends/lrs/async_es.py index e9f9ffc9f..f782c5a69 100644 --- a/src/ralph/backends/lrs/async_es.py +++ b/src/ralph/backends/lrs/async_es.py @@ -9,17 +9,15 @@ RalphStatementsQuery, StatementQueryResult, ) -from ralph.backends.lrs.es import ESLRSBackend +from ralph.backends.lrs.es import ESLRSBackend, ESLRSBackendSettings from ralph.exceptions import BackendException, BackendParameterException logger = logging.getLogger(__name__) -class AsyncESLRSBackend(BaseAsyncLRSBackend, AsyncESDataBackend): +class AsyncESLRSBackend(BaseAsyncLRSBackend[ESLRSBackendSettings], AsyncESDataBackend): """Asynchronous Elasticsearch LRS backend implementation.""" - settings_class = ESLRSBackend.settings_class - async def query_statements( self, params: RalphStatementsQuery ) -> StatementQueryResult: diff --git a/src/ralph/backends/lrs/async_mongo.py b/src/ralph/backends/lrs/async_mongo.py index 11bdae079..4fbcb21e3 100644 --- a/src/ralph/backends/lrs/async_mongo.py +++ b/src/ralph/backends/lrs/async_mongo.py @@ -10,17 +10,17 @@ RalphStatementsQuery, StatementQueryResult, ) -from ralph.backends.lrs.mongo import MongoLRSBackend +from ralph.backends.lrs.mongo import MongoLRSBackend, MongoLRSBackendSettings from ralph.exceptions import BackendException, BackendParameterException logger = logging.getLogger(__name__) -class AsyncMongoLRSBackend(BaseAsyncLRSBackend, AsyncMongoDataBackend): +class AsyncMongoLRSBackend( + BaseAsyncLRSBackend[MongoLRSBackendSettings], AsyncMongoDataBackend +): """Async MongoDB LRS backend implementation.""" - settings_class = MongoLRSBackend.settings_class - async def query_statements( self, params: RalphStatementsQuery ) -> StatementQueryResult: diff --git a/src/ralph/backends/lrs/base.py b/src/ralph/backends/lrs/base.py index 5bdf26a89..c46e89972 100644 --- a/src/ralph/backends/lrs/base.py +++ b/src/ralph/backends/lrs/base.py @@ -3,7 +3,7 @@ from abc import abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import Iterator, List, Literal, Optional, Union +from typing import Any, Iterator, List, Literal, Optional, TypeVar, Union from uuid import UUID from pydantic import BaseModel, Field, NonNegativeInt @@ -13,8 +13,8 @@ BaseDataBackend, BaseDataBackendSettings, BaseQuery, - BaseSettingsConfig, ) +from ralph.conf import BaseSettingsConfig from ralph.models.xapi.base.agents import BaseXapiAgent from ralph.models.xapi.base.common import IRI from ralph.models.xapi.base.groups import BaseXapiGroup @@ -84,10 +84,11 @@ class RalphStatementsQuery(LRSStatementsQuery): ignore_order: Optional[bool] -class BaseLRSBackend(BaseDataBackend): - """Base LRS backend interface.""" +Settings = TypeVar("Settings", bound=BaseLRSBackendSettings) + - settings_class = BaseLRSBackendSettings +class BaseLRSBackend(BaseDataBackend[Settings, Any]): + """Base LRS backend interface.""" @abstractmethod def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: @@ -98,11 +99,9 @@ def query_statements_by_ids(self, ids: List[str]) -> Iterator[dict]: """Yield statements with matching ids from the backend.""" -class BaseAsyncLRSBackend(BaseAsyncDataBackend): +class BaseAsyncLRSBackend(BaseAsyncDataBackend[Settings, Any]): """Base async LRS backend interface.""" - settings_class = BaseLRSBackendSettings - @abstractmethod async def query_statements( self, params: RalphStatementsQuery diff --git a/src/ralph/backends/lrs/clickhouse.py b/src/ralph/backends/lrs/clickhouse.py index 6e16ae822..49aa3babd 100644 --- a/src/ralph/backends/lrs/clickhouse.py +++ b/src/ralph/backends/lrs/clickhouse.py @@ -3,7 +3,6 @@ import logging from typing import Generator, Iterator, List -from ralph.backends.data.base import BaseSettingsConfig from ralph.backends.data.clickhouse import ( ClickHouseDataBackend, ClickHouseDataBackendSettings, @@ -15,6 +14,7 @@ RalphStatementsQuery, StatementQueryResult, ) +from ralph.conf import BaseSettingsConfig from ralph.exceptions import BackendException, BackendParameterException logger = logging.getLogger(__name__) @@ -37,11 +37,11 @@ class Config(BaseSettingsConfig): IDS_CHUNK_SIZE: int = 10000 -class ClickHouseLRSBackend(BaseLRSBackend, ClickHouseDataBackend): +class ClickHouseLRSBackend( + BaseLRSBackend[ClickHouseLRSBackendSettings], ClickHouseDataBackend +): """ClickHouse LRS backend implementation.""" - settings_class = ClickHouseLRSBackendSettings - def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" ch_params = params.dict(exclude_none=True) diff --git a/src/ralph/backends/lrs/es.py b/src/ralph/backends/lrs/es.py index 110f08fb8..5356f228d 100644 --- a/src/ralph/backends/lrs/es.py +++ b/src/ralph/backends/lrs/es.py @@ -3,7 +3,6 @@ import logging from typing import Iterator, List -from ralph.backends.data.base import BaseSettingsConfig from ralph.backends.data.es import ( ESDataBackend, ESDataBackendSettings, @@ -17,6 +16,7 @@ RalphStatementsQuery, StatementQueryResult, ) +from ralph.conf import BaseSettingsConfig from ralph.exceptions import BackendException, BackendParameterException logger = logging.getLogger(__name__) @@ -31,11 +31,9 @@ class Config(BaseSettingsConfig): env_prefix = "RALPH_BACKENDS__LRS__ES__" -class ESLRSBackend(BaseLRSBackend, ESDataBackend): +class ESLRSBackend(BaseLRSBackend[ESLRSBackendSettings], ESDataBackend): """Elasticsearch LRS backend implementation.""" - settings_class = ESLRSBackendSettings - def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: """Return the statements query payload using xAPI parameters.""" query = self.get_query(params=params) diff --git a/src/ralph/backends/lrs/fs.py b/src/ralph/backends/lrs/fs.py index 4a7adc922..c61f909a5 100644 --- a/src/ralph/backends/lrs/fs.py +++ b/src/ralph/backends/lrs/fs.py @@ -6,7 +6,7 @@ from typing import Iterable, List, Literal, Optional, Union from uuid import UUID -from ralph.backends.data.base import BaseOperationType, BaseSettingsConfig +from ralph.backends.data.base import BaseOperationType from ralph.backends.data.fs import FSDataBackend, FSDataBackendSettings from ralph.backends.lrs.base import ( AgentParameters, @@ -15,6 +15,7 @@ RalphStatementsQuery, StatementQueryResult, ) +from ralph.conf import BaseSettingsConfig logger = logging.getLogger(__name__) @@ -34,11 +35,9 @@ class Config(BaseSettingsConfig): DEFAULT_LRS_FILE: str = "fs_lrs.jsonl" -class FSLRSBackend(BaseLRSBackend, FSDataBackend): +class FSLRSBackend(BaseLRSBackend[FSLRSBackendSettings], FSDataBackend): """FileSystem LRS Backend.""" - settings_class = FSLRSBackendSettings - def write( # noqa: PLR0913 self, data: Union[IOBase, Iterable[bytes], Iterable[dict]], diff --git a/src/ralph/backends/lrs/mongo.py b/src/ralph/backends/lrs/mongo.py index 0f9a989f6..2b31f04b1 100644 --- a/src/ralph/backends/lrs/mongo.py +++ b/src/ralph/backends/lrs/mongo.py @@ -6,7 +6,6 @@ from bson.objectid import ObjectId from pymongo import ASCENDING, DESCENDING -from ralph.backends.data.base import BaseSettingsConfig from ralph.backends.data.mongo import ( MongoDataBackend, MongoDataBackendSettings, @@ -19,6 +18,7 @@ RalphStatementsQuery, StatementQueryResult, ) +from ralph.conf import BaseSettingsConfig from ralph.exceptions import BackendException, BackendParameterException logger = logging.getLogger(__name__) @@ -33,11 +33,9 @@ class Config(BaseSettingsConfig): env_prefix = "RALPH_BACKENDS__LRS__MONGO__" -class MongoLRSBackend(BaseLRSBackend, MongoDataBackend): +class MongoLRSBackend(BaseLRSBackend[MongoLRSBackendSettings], MongoDataBackend): """MongoDB LRS backend.""" - settings_class = MongoLRSBackendSettings - def query_statements(self, params: RalphStatementsQuery) -> StatementQueryResult: """Return the results of a statements query using xAPI parameters.""" query = self.get_query(params) diff --git a/tests/backends/data/test_base.py b/tests/backends/data/test_base.py index 0e6752514..40b7a510f 100644 --- a/tests/backends/data/test_base.py +++ b/tests/backends/data/test_base.py @@ -1,9 +1,16 @@ """Tests for the base data backend""" import logging +from typing import Any import pytest -from ralph.backends.data.base import BaseDataBackend, BaseQuery, enforce_query_checks +from ralph.backends.data.base import ( + BaseDataBackend, + BaseDataBackendSettings, + BaseQuery, + enforce_query_checks, + get_backend_generic_argument, +) from ralph.exceptions import BackendParameterException @@ -21,9 +28,6 @@ def test_backends_data_base_enforce_query_checks_with_valid_input(value, expecte class MockBaseDataBackend(BaseDataBackend): """A class mocking the base data backend class.""" - def __init__(self, settings=None): - """Instantiate the Mock data backend.""" - @enforce_query_checks def read(self, query=None): """Mock the base database read method.""" @@ -59,9 +63,6 @@ def test_backends_data_base_enforce_query_checks_with_invalid_input( class MockBaseDataBackend(BaseDataBackend): """A class mocking the base database class.""" - def __init__(self, settings=None): - """Instantiate the Mock data backend.""" - @enforce_query_checks def read(self, query=None): """Mock the base database read method.""" @@ -80,3 +81,32 @@ def close(self): error = error.replace("\\", "") assert ("ralph.backends.data.base", logging.ERROR, error) in caplog.record_tuples + + +def test_backends_data_base_get_backend_generic_argument(): + """Test the get_backend_generic_argument function.""" + + assert get_backend_generic_argument(BaseDataBackendSettings, 0) is None + assert get_backend_generic_argument(BaseDataBackend, -2) is None + assert get_backend_generic_argument(BaseDataBackend, -1) is BaseQuery + assert get_backend_generic_argument(BaseDataBackend, 0) is BaseDataBackendSettings + assert get_backend_generic_argument(BaseDataBackend, 1) is BaseQuery + assert get_backend_generic_argument(BaseDataBackend, 2) is None + + class DummySettings(BaseDataBackendSettings): + """Dummy Settings.""" + + class DummyQuery(BaseQuery): + """Dummy Query.""" + + class DummyBackend(BaseDataBackend[DummySettings, DummyQuery]): + """Dummy Backend.""" + + assert get_backend_generic_argument(DummyBackend, 0) is DummySettings + assert get_backend_generic_argument(DummyBackend, 1) is DummyQuery + + class DummyAnyBackend(BaseDataBackend[DummySettings, Any]): + """Dummy Any Backend.""" + + assert get_backend_generic_argument(DummyAnyBackend, 0) is DummySettings + assert get_backend_generic_argument(DummyAnyBackend, 1) is None