From ce70e71601941057b5acdb26fdd51b7144c984dc Mon Sep 17 00:00:00 2001 From: Daniel Roschka Date: Thu, 30 May 2024 10:16:38 +0200 Subject: [PATCH 1/2] Add a dictionary factory backed by MARISA-tries This adds an additional dictionary factory backed by MARISA-tries. This dictionary factory on average offers 20x lower memory usage and 100x faster initialization time, in exchange for reduced lemmatization and language detection performance. The first time loading a dictionary with the `TrieDictionaryFactory` requires more memory and will take a few seconds, as the trie-backed dictionary has to be generated on-the-fly from the pickled dict-based dictionary first. --- README.md | 55 +++++ requirements-dev.txt | 2 + setup.py | 1 + simplemma/strategies/__init__.py | 6 +- simplemma/strategies/dictionaries/__init__.py | 1 + .../dictionaries/trie_directory_factory.py | 123 ++++++++++ .../test_trie_dictionary_factory.py | 221 ++++++++++++++++++ 7 files changed, 408 insertions(+), 1 deletion(-) create mode 100644 simplemma/strategies/dictionaries/trie_directory_factory.py create mode 100644 tests/strategies/dictionaries/test_trie_dictionary_factory.py diff --git a/README.md b/README.md index 4e05e0d..841af6f 100644 --- a/README.md +++ b/README.md @@ -259,6 +259,61 @@ LANG_CACHE_SIZE = 5 # How many language dictionaries to keep in memory at once For more information see the [extended documentation](https://adbar.github.io/simplemma/). +### Reducing memory usage + +For situations where low memory usage and fast initialization time are +more important than lemmatization and language detection performance, +Simplemma ships another `DictionaryFactory`, which uses a trie as +underlying data structure instead of a Python dict. + +Using the `TrieDictionaryFactory` reduces memory usage on average by +20x and initialization time by 100x, but comes at the cost that +performance can be down 50% or even more compared to what Simplemma +otherwise achieves, depending on the specific usage. 
+ +To use the `TrieDictionaryFactory` you have to install Simplemma with +the `marisa-trie` extra dependency: + +``` +pip install simplemma[marisa-trie] +``` + +Then you have to create a custom strategy using the +`TrieDictionaryFactory` and use that for `Lemmatizer` and +`LanguageDetector` instances: + +``` python +>>> from simplemma import LanguageDetector, Lemmatizer +>>> from simplemma.strategies import DefaultStrategy +>>> from simplemma.strategies.dictionaries import TrieDictionaryFactory + +>>> lemmatization_strategy = DefaultStrategy(dictionary_factory=TrieDictionaryFactory()) + +>>> lemmatizer = Lemmatizer(lemmatization_strategy=lemmatization_strategy) +>>> lemmatizer.lemmatize('doughnuts', lang='en') +'doughnut' + +>>> language_detector = LanguageDetector('la', lemmatization_strategy=lemmatization_strategy) +>>> language_detector.proportion_in_target_languages("opera post physica posita (τὰ μετὰ τὰ φυσικά)") +0.5 +``` + +While memory usage and initialization time when using the +`TrieDictionaryFactory` are significantly lower compared to the +`DefaultDictionaryFactory`, that's only true if the trie dictionaries +are available on disk. That's not the case when using the +`TrieDictionaryFactory` for the first time, as Simplemma only ships +the dictionaries as Python dicts. The trie dictionaries have to be +generated once from the Python dicts. That happens on-the-fly when +using the `TrieDictionaryFactory` for the first time for a language and +will take a few seconds and use as much memory as loading the Python +dicts for the language requires. For further invocations the trie +dictionaries get cached on disk. + +If the computer that is supposed to run Simplemma doesn't have enough memory to +generate the trie dictionaries, they can also be generated on another +computer with the same CPU architecture and copied over to the cache +directory.
## Supported languages diff --git a/requirements-dev.txt b/requirements-dev.txt index 40ff760..7472f14 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,8 @@ black==24.4.2 flake8==7.0.0 +marisa_trie==1.2.0 mypy==1.10.0 +platformdirs==4.2.2 pytest==8.2.1 pytest-cov==5.0.0 types-requests==2.32.0.20240523 diff --git a/setup.py b/setup.py index e2523ee..23673a4 100644 --- a/setup.py +++ b/setup.py @@ -85,6 +85,7 @@ def get_version(package): ], description="A simple multilingual lemmatizer for Python.", install_requires=requirements, + extras_require={"marisa-trie": ["marisa-trie", "platformdirs"]}, license="MIT license", long_description=readme, # + '\n\n' + history, long_description_content_type="text/markdown", diff --git a/simplemma/strategies/__init__.py b/simplemma/strategies/__init__.py index 918bc86..89fcc51 100644 --- a/simplemma/strategies/__init__.py +++ b/simplemma/strategies/__init__.py @@ -2,7 +2,11 @@ from .affix_decomposition import AffixDecompositionStrategy from .default import DefaultStrategy -from .dictionaries import DefaultDictionaryFactory, DictionaryFactory +from .dictionaries import ( + DefaultDictionaryFactory, + DictionaryFactory, + TrieDictionaryFactory, +) from .dictionary_lookup import DictionaryLookupStrategy from .fallback.lemmatization_fallback_strategy import LemmatizationFallbackStrategy from .fallback.raise_error import RaiseErrorFallbackStrategy diff --git a/simplemma/strategies/dictionaries/__init__.py b/simplemma/strategies/dictionaries/__init__.py index cb6f1b7..b791fd5 100644 --- a/simplemma/strategies/dictionaries/__init__.py +++ b/simplemma/strategies/dictionaries/__init__.py @@ -1,3 +1,4 @@ """Dictionary-based lemmatization strategy.""" from .dictionary_factory import DefaultDictionaryFactory, DictionaryFactory +from .trie_directory_factory import TrieDictionaryFactory diff --git a/simplemma/strategies/dictionaries/trie_directory_factory.py b/simplemma/strategies/dictionaries/trie_directory_factory.py new 
file mode 100644 index 0000000..03ac4ab --- /dev/null +++ b/simplemma/strategies/dictionaries/trie_directory_factory.py @@ -0,0 +1,123 @@ +import logging +from collections.abc import MutableMapping +from functools import lru_cache +from pathlib import Path +from typing import ByteString, Dict, List, Optional, cast + +from marisa_trie import BytesTrie, HUGE_CACHE # type: ignore[import-not-found] +from platformdirs import user_cache_dir + +from simplemma import __version__ as SIMPLEMMA_VERSION +from simplemma.strategies.dictionaries.dictionary_factory import ( + DefaultDictionaryFactory, + DictionaryFactory, + SUPPORTED_LANGUAGES, +) + +logger = logging.getLogger(__name__) + + +class TrieWrapDict(MutableMapping): + """Wrapper around BytesTrie to make them behave like dicts.""" + + def __init__(self, trie: BytesTrie): + self._trie = trie + + def __getitem__(self, item): + return self._trie[item.decode()][0] + + def __setitem__(self, key, value): + raise NotImplementedError + + def __delitem__(self, key): + raise NotImplementedError + + def __iter__(self): + for key in self._trie.iterkeys(): + yield key.encode() + + def __len__(self): + return len(self._trie) + + +class TrieDictionaryFactory(DictionaryFactory): + """Memory optimized DictionaryFactory backed by MARISA-tries. + + This dictionary factory creates dictionaries, which are backed by a + MARISA-trie instead of a dict, to make them consume very little + memory compared to the DefaultDictionaryFactory. Trade-offs are that + lookup performance isn't as good as with dicts. + """ + + __slots__: List[str] = [] + + def __init__( + self, + cache_max_size: int = 8, + use_disk_cache: bool = True, + disk_cache_dir: Optional[str] = None, + ) -> None: + """Initialize the TrieDictionaryFactory. + + Args: + cache_max_size (int): The maximum number dictionaries to + keep in memory. Defaults to `8`. + use_disk_cache (bool): Whether to cache the tries on disk to + speed up loading time. Defaults to `True`. 
+ disk_cache_dir (Optional[str]): Path where the generated + tries should be stored in. Defaults to a Simplemma- + specific subdirectory of the user's cache directory. + """ + + if disk_cache_dir: + self._cache_dir = Path(disk_cache_dir) + else: + self._cache_dir = ( + Path(user_cache_dir("simplemma")) / "marisa_trie" / SIMPLEMMA_VERSION + ) + self._use_disk_cache = use_disk_cache + self._get_dictionary = lru_cache(maxsize=cache_max_size)( + self._get_dictionary_uncached + ) + + def _create_trie_from_pickled_dict(self, lang: str) -> BytesTrie: + """Create a trie from a pickled dictionary.""" + unpickled_dict = DefaultDictionaryFactory(cache_max_size=0).get_dictionary(lang) + return BytesTrie( + zip( + [key.decode() for key in unpickled_dict], # type: ignore[union-attr] + unpickled_dict.values(), + ), + cache_size=HUGE_CACHE, + ) + + def _write_trie_to_disk(self, lang: str, trie: BytesTrie) -> None: + """Persist the trie to disk for later usage. + + The persisted trie can be loaded by subsequent runs to speed up + loading times. + """ + logger.debug("Caching trie on disk. 
This might take a second.") + self._cache_dir.mkdir(parents=True, exist_ok=True) + + trie.save(self._cache_dir / f"{lang}.dic") + + def _get_dictionary_uncached(self, lang: str) -> Dict[ByteString, ByteString]: + """Get the dictionary for the given language.""" + if lang not in SUPPORTED_LANGUAGES: + raise ValueError(f"Unsupported language: {lang}") + + if self._use_disk_cache and (self._cache_dir / f"{lang}.dic").exists(): + trie = BytesTrie().load(str(self._cache_dir / f"{lang}.dic")) + else: + trie = self._create_trie_from_pickled_dict(lang) + if self._use_disk_cache: + self._write_trie_to_disk(lang, trie) + + return cast(dict, TrieWrapDict(trie)) + + def get_dictionary( + self, + lang: str, + ) -> Dict[ByteString, ByteString]: + return self._get_dictionary(lang) diff --git a/tests/strategies/dictionaries/test_trie_dictionary_factory.py b/tests/strategies/dictionaries/test_trie_dictionary_factory.py new file mode 100644 index 0000000..8cbd65a --- /dev/null +++ b/tests/strategies/dictionaries/test_trie_dictionary_factory.py @@ -0,0 +1,221 @@ +from collections.abc import ItemsView, KeysView +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import call, patch + +import pytest +from marisa_trie import BytesTrie # type: ignore[import-not-found] + +from simplemma.strategies.dictionaries.trie_directory_factory import TrieWrapDict +from simplemma.strategies import TrieDictionaryFactory + + +def test_exceptions() -> None: + # missing languages or faulty language codes + dictionary_factory = TrieDictionaryFactory(use_disk_cache=False) + with pytest.raises(ValueError): + dictionary_factory.get_dictionary(("abc")) + + +def test_dictionary_lru_cache() -> None: + iterations = 10 + dictionaries = TrieDictionaryFactory(use_disk_cache=False) + for _ in range(iterations): + dictionaries.get_dictionary("en") + dictionaries.get_dictionary("de") + assert dictionaries._get_dictionary.cache_info().misses == 2 + assert 
dictionaries._get_dictionary.cache_info().hits == (iterations - 1) * 2 + + +def test_max_lru_cache_size() -> None: + dictionaries = TrieDictionaryFactory(cache_max_size=3, use_disk_cache=False) + + for lang in ["de", "en", "en", "es", "fr", "it", "de"]: + dictionaries.get_dictionary(lang) + + assert dictionaries._get_dictionary.cache_info().misses == 6 + assert dictionaries._get_dictionary.cache_info().hits == 1 + + +def test_disabled_disk_cache() -> None: + with TemporaryDirectory() as tmp_dir: + tmp_dir_path = Path(tmp_dir) + dictionaries = TrieDictionaryFactory( + disk_cache_dir=tmp_dir, use_disk_cache=False + ) + dictionaries.get_dictionary("en") + dictionaries.get_dictionary("fr") + assert sorted(tmp_dir_path.iterdir()) == [] + + +def test_no_disk_cache() -> None: + with TemporaryDirectory() as tmp_dir: + tmp_dir_path = Path(tmp_dir) + dictionaries = TrieDictionaryFactory( + use_disk_cache=False, disk_cache_dir=tmp_dir + ) + + with patch.object( + TrieDictionaryFactory, + "_create_trie_from_pickled_dict", + wraps=dictionaries._create_trie_from_pickled_dict, + ) as create_trie_mock, patch.object( + TrieDictionaryFactory, + "_write_trie_to_disk", + wraps=dictionaries._write_trie_to_disk, + ) as write_trie_mock: + assert sorted(tmp_dir_path.iterdir()) == [] + + dictionaries.get_dictionary("en") + dictionaries.get_dictionary("fr") + + dictionaries.get_dictionary("en") + dictionaries.get_dictionary("fr") + + create_trie_mock.assert_has_calls([call("en"), call("fr")]) + write_trie_mock.assert_not_called() + + assert sorted(tmp_dir_path.iterdir()) == [] + + +def test_disk_cache() -> None: + with TemporaryDirectory() as tmp_dir: + tmp_dir_path = Path(tmp_dir) + dictionaries = TrieDictionaryFactory(disk_cache_dir=tmp_dir) + + with patch.object( + TrieDictionaryFactory, + "_create_trie_from_pickled_dict", + wraps=dictionaries._create_trie_from_pickled_dict, + ) as create_trie_mock, patch.object( + TrieDictionaryFactory, + "_write_trie_to_disk", + 
wraps=dictionaries._write_trie_to_disk, + ) as write_trie_mock: + assert sorted(tmp_dir_path.iterdir()) == [] + + # Initial cached trie files should be generated. + en_dictionary = dictionaries.get_dictionary("en") + fr_dictionary = dictionaries.get_dictionary("fr") + + create_trie_mock.assert_has_calls([call("en"), call("fr")]) + create_trie_mock.reset_mock() + write_trie_mock.assert_has_calls( + [ + call("en", en_dictionary._trie), # type: ignore[attr-defined] + call("fr", fr_dictionary._trie), # type: ignore[attr-defined] + ] + ) + write_trie_mock.reset_mock() + + assert sorted(tmp_dir_path.iterdir()) == [ + tmp_dir_path / "en.dic", + tmp_dir_path / "fr.dic", + ] + + # LRU cache should result in not checking for cached tries. + dictionaries.get_dictionary("en") + dictionaries.get_dictionary("fr") + + create_trie_mock.assert_not_called() + write_trie_mock.assert_not_called() + + dictionaries._get_dictionary.cache_clear() + + # Cached trie files should be checked, but not regenerated, + # as the LRU cache got emptied. + dictionaries.get_dictionary("en") + dictionaries.get_dictionary("fr") + + create_trie_mock.assert_not_called() + write_trie_mock.assert_not_called() + + assert sorted(tmp_dir_path.iterdir()) == [ + tmp_dir_path / "en.dic", + tmp_dir_path / "fr.dic", + ] + + +def test_corrupted_disk_cache() -> None: + with TemporaryDirectory() as tmp_dir: + tmp_dir_path = Path(tmp_dir) + dictionaries = TrieDictionaryFactory(disk_cache_dir=tmp_dir) + + with patch.object( + TrieDictionaryFactory, + "_create_trie_from_pickled_dict", + wraps=dictionaries._create_trie_from_pickled_dict, + ) as create_trie_mock, patch.object( + TrieDictionaryFactory, + "_write_trie_to_disk", + wraps=dictionaries._write_trie_to_disk, + ) as write_trie_mock: + assert sorted(tmp_dir_path.iterdir()) == [] + + # Initial cached trie file should be generated.
+ en_dictionary = dictionaries.get_dictionary("en") + + create_trie_mock.assert_has_calls([call("en")]) + create_trie_mock.reset_mock() + write_trie_mock.assert_has_calls( + [ + call("en", en_dictionary._trie), # type: ignore[attr-defined] + ] + ) + write_trie_mock.reset_mock() + + assert sorted(tmp_dir_path.iterdir()) == [ + tmp_dir_path / "en.dic", + ] + + with (tmp_dir_path / "en.dic").open("wb") as f: + f.write(b"corrupted trie dictionary") + dictionaries._get_dictionary.cache_clear() + + # Loading a corrupted trie file should raise an error instead of regenerating it. + with pytest.raises(RuntimeError): + dictionaries.get_dictionary("en") + + create_trie_mock.assert_not_called() + write_trie_mock.assert_not_called() + + assert sorted(tmp_dir_path.iterdir()) == [tmp_dir_path / "en.dic"] + + +def test_dictionary_working_as_a_dict() -> None: + dictionaries = TrieDictionaryFactory(use_disk_cache=False) + dictionary = dictionaries.get_dictionary("en") + + assert isinstance(dictionary, TrieWrapDict) + + assert (b"balconies" in dictionary) is True + assert (b"balconies123" in dictionary) is False + with pytest.raises(KeyError): + dictionary[b"balconies123"] + assert dictionary.get(b"balconies") == b"balcony" + + +def test_trie_wrap_dict(): + trie = BytesTrie( + zip(["houses", "balconies", "ponies"], [b"house", b"balcony", b"pony"]) + ) + wrapped_trie = TrieWrapDict(trie) + + assert (b"balconies" in wrapped_trie) is True + assert (b"balconies123" in wrapped_trie) is False + assert wrapped_trie[b"balconies"] == b"balcony" + with pytest.raises(KeyError): + wrapped_trie[b"balconies123"] + assert wrapped_trie.get(b"balconies") == b"balcony" + assert wrapped_trie.get(b"balconies123") is None + + assert isinstance(wrapped_trie.keys(), KeysView) + assert isinstance(wrapped_trie.items(), ItemsView) + assert len(wrapped_trie) == 3 + + with pytest.raises(NotImplementedError): + wrapped_trie["houses"] = b"teapot" + with pytest.raises(NotImplementedError): + del wrapped_trie["balconies"] + + assert [key for key
in wrapped_trie] == [b"balconies", b"houses", b"ponies"] From 81f08ba0821c936a41ec2c16115e37fd227deffe Mon Sep 17 00:00:00 2001 From: Daniel Roschka Date: Fri, 14 Jun 2024 13:47:26 +0200 Subject: [PATCH 2/2] Change dictionary format to use strings again This changes the format of the dictionary returned by `DictionaryFactory().get_dictionary()` from `Dict[ByteString, ByteString]` to `Mapping[str, str]` to accommodate alternative dictionary factory implementations better and to ease the dictionary handling again. This keeps the storage of pickled dictionaries with byte strings though, as they're smaller than when using strings. --- .../dictionaries/dictionary_factory.py | 34 ++++++++++++++----- .../dictionaries/trie_directory_factory.py | 16 ++++----- simplemma/strategies/dictionary_lookup.py | 13 ++----- .../strategies/greedy_dictionary_lookup.py | 4 +-- simplemma/utils.py | 8 ++--- .../test_trie_dictionary_factory.py | 22 ++++++------ tests/test_dictionary_pickler.py | 4 +-- tests/test_lemmatizer.py | 6 ++-- training/dictionary_pickler.py | 10 +++--- 9 files changed, 61 insertions(+), 56 deletions(-) diff --git a/simplemma/strategies/dictionaries/dictionary_factory.py b/simplemma/strategies/dictionaries/dictionary_factory.py index e163b97..50adad0 100644 --- a/simplemma/strategies/dictionaries/dictionary_factory.py +++ b/simplemma/strategies/dictionaries/dictionary_factory.py @@ -14,7 +14,7 @@ from functools import lru_cache from os import listdir, path from pathlib import Path -from typing import ByteString, Dict, Protocol +from typing import ByteString, Dict, Mapping, Protocol DATA_FOLDER = str(Path(__file__).parent / "data") SUPPORTED_LANGUAGES = [ @@ -62,7 +62,7 @@ class DictionaryFactory(Protocol): def get_dictionary( self, lang: str, - ) -> Dict[ByteString, ByteString]: + ) -> Mapping[str, str]: """ Get the dictionary for a specific language. @@ -70,7 +70,7 @@ def get_dictionary( lang (str): The language code.
Returns: - Dict[str, str]: The dictionary for the specified language. + Mapping[str, str]: The dictionary for the specified language. Raises: ValueError: If the specified language is not supported. @@ -78,6 +78,25 @@ def get_dictionary( raise NotImplementedError +class MappingStrToByteString(Mapping[str, str]): + """Wrapper around ByString dict to make them behave like str dict.""" + + __slots__ = ["_dict"] + + def __init__(self, dictionary: Dict[bytes, bytes]): + self._dict = dictionary + + def __getitem__(self, item: str): + return self._dict[item.encode()].decode() + + def __iter__(self): + for key in self._dict: + yield key.decode() + + def __len__(self): + return len(self._dict) + + class DefaultDictionaryFactory(DictionaryFactory): """ Default Dictionary Factory. @@ -86,7 +105,7 @@ class DefaultDictionaryFactory(DictionaryFactory): It provides functionality for loading and caching dictionaries from disk that are included in Simplemma. """ - __slots__ = ["_data", "_load_dictionary_from_disk"] + __slots__ = ["_load_dictionary_from_disk"] def __init__(self, cache_max_size: int = 8): """ @@ -96,7 +115,6 @@ def __init__(self, cache_max_size: int = 8): cache_max_size (int): The maximum size of the cache for loaded dictionaries. Defaults to `8`. """ - self._data: Dict[str, Dict[ByteString, ByteString]] = {} self._load_dictionary_from_disk = lru_cache(maxsize=cache_max_size)( _load_dictionary_from_disk ) @@ -104,7 +122,7 @@ def __init__(self, cache_max_size: int = 8): def get_dictionary( self, lang: str, - ) -> Dict[ByteString, ByteString]: + ) -> Mapping[str, str]: """ Get the dictionary for a specific language. @@ -112,11 +130,11 @@ def get_dictionary( lang (str): The language code. Returns: - Dict[str, str]: The dictionary for the specified language. + Mapping[str, str]: The dictionary for the specified language. Raises: ValueError: If the specified language is not supported. 
""" if lang not in SUPPORTED_LANGUAGES: raise ValueError(f"Unsupported language: {lang}") - return self._load_dictionary_from_disk(lang) + return MappingStrToByteString(self._load_dictionary_from_disk(lang)) diff --git a/simplemma/strategies/dictionaries/trie_directory_factory.py b/simplemma/strategies/dictionaries/trie_directory_factory.py index 03ac4ab..6d2c2ec 100644 --- a/simplemma/strategies/dictionaries/trie_directory_factory.py +++ b/simplemma/strategies/dictionaries/trie_directory_factory.py @@ -2,7 +2,7 @@ from collections.abc import MutableMapping from functools import lru_cache from pathlib import Path -from typing import ByteString, Dict, List, Optional, cast +from typing import List, Mapping, Optional from marisa_trie import BytesTrie, HUGE_CACHE # type: ignore[import-not-found] from platformdirs import user_cache_dir @@ -24,7 +24,7 @@ def __init__(self, trie: BytesTrie): self._trie = trie def __getitem__(self, item): - return self._trie[item.decode()][0] + return self._trie[item][0].decode() def __setitem__(self, key, value): raise NotImplementedError @@ -34,7 +34,7 @@ def __delitem__(self, key): def __iter__(self): for key in self._trie.iterkeys(): - yield key.encode() + yield key def __len__(self): return len(self._trie) @@ -85,8 +85,8 @@ def _create_trie_from_pickled_dict(self, lang: str) -> BytesTrie: unpickled_dict = DefaultDictionaryFactory(cache_max_size=0).get_dictionary(lang) return BytesTrie( zip( - [key.decode() for key in unpickled_dict], # type: ignore[union-attr] - unpickled_dict.values(), + unpickled_dict.keys(), + [value.encode() for value in unpickled_dict.values()], ), cache_size=HUGE_CACHE, ) @@ -102,7 +102,7 @@ def _write_trie_to_disk(self, lang: str, trie: BytesTrie) -> None: trie.save(self._cache_dir / f"{lang}.dic") - def _get_dictionary_uncached(self, lang: str) -> Dict[ByteString, ByteString]: + def _get_dictionary_uncached(self, lang: str) -> Mapping[str, str]: """Get the dictionary for the given language.""" if lang not in 
SUPPORTED_LANGUAGES: raise ValueError(f"Unsupported language: {lang}") @@ -114,10 +114,10 @@ def _get_dictionary_uncached(self, lang: str) -> Dict[ByteString, ByteString]: if self._use_disk_cache: self._write_trie_to_disk(lang, trie) - return cast(dict, TrieWrapDict(trie)) + return TrieWrapDict(trie) def get_dictionary( self, lang: str, - ) -> Dict[ByteString, ByteString]: + ) -> Mapping[str, str]: return self._get_dictionary(lang) diff --git a/simplemma/strategies/dictionary_lookup.py b/simplemma/strategies/dictionary_lookup.py index a98d365..9262477 100644 --- a/simplemma/strategies/dictionary_lookup.py +++ b/simplemma/strategies/dictionary_lookup.py @@ -3,7 +3,7 @@ It provides lemmatization using dictionary lookup. """ -from typing import ByteString, Dict, Optional +from typing import Optional from .dictionaries.dictionary_factory import DefaultDictionaryFactory, DictionaryFactory from .lemmatization_strategy import LemmatizationStrategy @@ -26,13 +26,6 @@ def __init__( """ self._dictionary_factory = dictionary_factory - def _get( - self, token: str, dictionary: Dict[ByteString, ByteString] - ) -> Optional[str]: - "Convenience function to handle bytestring to string conversion." - result = dictionary.get(token.encode("utf-8")) - return result.decode("utf-8") if result else None # type: ignore[union-attr] - def get_lemma(self, token: str, lang: str) -> Optional[str]: """ Get Lemma using Dictionary Lookup @@ -50,9 +43,9 @@ def get_lemma(self, token: str, lang: str) -> Optional[str]: """ # Search the language data, reverse case to extend coverage. dictionary = self._dictionary_factory.get_dictionary(lang) - result = self._get(token, dictionary) + result = dictionary.get(token) if result: return result # Try upper or lowercase. 
token = token.lower() if token[0].isupper() else token.capitalize() - return self._get(token, dictionary) + return dictionary.get(token) diff --git a/simplemma/strategies/greedy_dictionary_lookup.py b/simplemma/strategies/greedy_dictionary_lookup.py index ea372de..0915402 100644 --- a/simplemma/strategies/greedy_dictionary_lookup.py +++ b/simplemma/strategies/greedy_dictionary_lookup.py @@ -58,7 +58,7 @@ def get_lemma(self, token: str, lang: str) -> str: return token dictionary = self._dictionary_factory.get_dictionary(lang) - candidate = token.encode("utf-8") + candidate = token for _ in range(self._steps): if candidate not in dictionary: break @@ -73,4 +73,4 @@ def get_lemma(self, token: str, lang: str) -> str: candidate = new_candidate - return candidate.decode("utf-8") + return candidate diff --git a/simplemma/utils.py b/simplemma/utils.py index 57d47cb..1d81fa0 100644 --- a/simplemma/utils.py +++ b/simplemma/utils.py @@ -6,7 +6,7 @@ - [validate_lang_input][simplemma.utils.validate_lang_input]: Validates the language input and ensures it is a valid tuple. """ -from typing import ByteString, Tuple, Union +from typing import Tuple, Union def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]: @@ -31,9 +31,7 @@ def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]: return lang # type: ignore[return-value] -def levenshtein_dist( - first: Union[ByteString, str], second: Union[ByteString, str] -) -> int: +def levenshtein_dist(str1: str, str2: str) -> int: """ Calculate the Levenshtein distance between two strings. @@ -49,8 +47,6 @@ def levenshtein_dist( int: The Levenshtein distance between the two strings. 
""" - str1 = first.encode("utf-8") if isinstance(first, str) else first - str2 = second.encode("utf-8") if isinstance(second, str) else second # inspired by this noticeably faster code: # https://gist.github.com/p-hash/9e0f9904ce7947c133308fbe48fe032b if str1 == str2: diff --git a/tests/strategies/dictionaries/test_trie_dictionary_factory.py b/tests/strategies/dictionaries/test_trie_dictionary_factory.py index 8cbd65a..cbbfd86 100644 --- a/tests/strategies/dictionaries/test_trie_dictionary_factory.py +++ b/tests/strategies/dictionaries/test_trie_dictionary_factory.py @@ -188,11 +188,11 @@ def test_dictionary_working_as_a_dict() -> None: assert isinstance(dictionary, TrieWrapDict) - assert (b"balconies" in dictionary) is True - assert (b"balconies123" in dictionary) is False + assert ("balconies" in dictionary) is True + assert ("balconies123" in dictionary) is False with pytest.raises(KeyError): - dictionary[b"balconies123"] - assert dictionary.get(b"balconies") == b"balcony" + dictionary["balconies123"] + assert dictionary.get("balconies") == "balcony" def test_trie_wrap_dict(): @@ -201,21 +201,21 @@ def test_trie_wrap_dict(): ) wrapped_trie = TrieWrapDict(trie) - assert (b"balconies" in wrapped_trie) is True - assert (b"balconies123" in wrapped_trie) is False - assert wrapped_trie[b"balconies"] == b"balcony" + assert ("balconies" in wrapped_trie) is True + assert ("balconies123" in wrapped_trie) is False + assert wrapped_trie["balconies"] == "balcony" with pytest.raises(KeyError): wrapped_trie[b"balconies123"] - assert wrapped_trie.get(b"balconies") == b"balcony" - assert wrapped_trie.get(b"balconies123") is None + assert wrapped_trie.get("balconies") == "balcony" + assert wrapped_trie.get("balconies123") is None assert isinstance(wrapped_trie.keys(), KeysView) assert isinstance(wrapped_trie.items(), ItemsView) assert len(wrapped_trie) == 3 with pytest.raises(NotImplementedError): - wrapped_trie["houses"] = b"teapot" + wrapped_trie["houses"] = "teapot" with 
pytest.raises(NotImplementedError): del wrapped_trie["balconies"] - assert [key for key in wrapped_trie] == [b"balconies", b"houses", b"ponies"] + assert [key for key in wrapped_trie] == ["balconies", "houses", "ponies"] diff --git a/tests/test_dictionary_pickler.py b/tests/test_dictionary_pickler.py index 2fc806f..37136f2 100644 --- a/tests/test_dictionary_pickler.py +++ b/tests/test_dictionary_pickler.py @@ -26,9 +26,9 @@ def test_logic() -> None: # different order mydict = dictionary_pickler._read_dict(testfile, "es", silent=True) assert len(mydict) == 5 - assert mydict[b"closeones"] == b"closeone" + assert mydict["closeones"] == "closeone" item = sorted(mydict.keys(), reverse=True)[0] - assert item == b"valid-word" + assert item == "valid-word" # file I/O assert dictionary_pickler._determine_path("lists", "de").endswith("de.txt") diff --git a/tests/test_lemmatizer.py b/tests/test_lemmatizer.py index e911cf1..17a8e93 100644 --- a/tests/test_lemmatizer.py +++ b/tests/test_lemmatizer.py @@ -1,6 +1,6 @@ """Tests for `simplemma` package.""" -from typing import ByteString, Dict +from typing import Mapping import pytest @@ -17,8 +17,8 @@ class CustomDictionaryFactory(DictionaryFactory): def get_dictionary( self, lang: str, - ) -> Dict[ByteString, ByteString]: - return {b"testing": b"the test works!!"} + ) -> Mapping[str, str]: + return {"testing": "the test works!!"} assert ( Lemmatizer( diff --git a/training/dictionary_pickler.py b/training/dictionary_pickler.py index 15345d1..69f4692 100644 --- a/training/dictionary_pickler.py +++ b/training/dictionary_pickler.py @@ -10,7 +10,7 @@ import re from operator import itemgetter from pathlib import Path -from typing import ByteString, Dict, List, Optional +from typing import Dict, List, Optional import simplemma from simplemma.strategies.defaultrules import DEFAULT_RULES @@ -49,9 +49,7 @@ def _determine_path(listpath: str, langcode: str) -> str: return str(Path(__file__).parent / filename) -def _read_dict( - filepath: str, 
langcode: str, silent: bool -) -> Dict[ByteString, ByteString]: +def _read_dict(filepath: str, langcode: str, silent: bool) -> Dict[str, str]: mydict: Dict[str, str] = {} myadditions: List[str] = [] i: int = 0 @@ -122,12 +120,12 @@ def _read_dict( mydict[word] = word LOGGER.debug("%s %s", langcode, i) # sort and convert to bytestrings - return {k.encode("utf-8"): v.encode("utf-8") for k, v in sorted(mydict.items())} + return dict(sorted(mydict.items())) def _load_dict( langcode: str, listpath: str = "lists", silent: bool = True -) -> Dict[ByteString, ByteString]: +) -> Dict[str, str]: filepath = _determine_path(listpath, langcode) return _read_dict(filepath, langcode, silent)