-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a dictionary factory backed by MARISA-tries #133
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
black==24.4.2 | ||
flake8==7.0.0 | ||
marisa_trie==1.2.0 | ||
mypy==1.10.0 | ||
platformdirs==4.2.2 | ||
pytest==8.2.1 | ||
pytest-cov==5.0.0 | ||
types-requests==2.32.0.20240523 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
"""Dictionary-based lemmatization strategy.""" | ||
|
||
from .dictionary_factory import DefaultDictionaryFactory, DictionaryFactory | ||
from .trie_directory_factory import TrieDictionaryFactory |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import logging | ||
from collections.abc import MutableMapping | ||
from functools import lru_cache | ||
from pathlib import Path | ||
from typing import List, Mapping, Optional | ||
|
||
from marisa_trie import BytesTrie, HUGE_CACHE # type: ignore[import-not-found] | ||
from platformdirs import user_cache_dir | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While I'm not entirely sure, my guess is that mypy can't get any type information, because marisa_trie isn't Python code, but a C-extension and doesn't provide type information. |
||
|
||
from simplemma import __version__ as SIMPLEMMA_VERSION | ||
from simplemma.strategies.dictionaries.dictionary_factory import ( | ||
DefaultDictionaryFactory, | ||
DictionaryFactory, | ||
SUPPORTED_LANGUAGES, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class TrieWrapDict(MutableMapping): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we simply modify the Why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One of the constraints I gave myself was to implement this functionality without requiring changes to the existing code of Simplemma. That's why I didn't modify expected return types of the It's a |
||
"""Wrapper around BytesTrie to make them behave like dicts.""" | ||
|
||
def __init__(self, trie: BytesTrie): | ||
self._trie = trie | ||
|
||
def __getitem__(self, item): | ||
return self._trie[item][0].decode() | ||
|
||
def __setitem__(self, key, value): | ||
raise NotImplementedError | ||
|
||
def __delitem__(self, key): | ||
raise NotImplementedError | ||
|
||
def __iter__(self): | ||
for key in self._trie.iterkeys(): | ||
yield key | ||
|
||
def __len__(self): | ||
return len(self._trie) | ||
|
||
|
||
class TrieDictionaryFactory(DictionaryFactory): | ||
"""Memory optimized DictionaryFactory backed by MARISA-tries. | ||
|
||
This dictionary factory creates dictionaries that are backed by a | ||
MARISA-trie instead of a dict, so that they consume very little | ||
memory compared to the DefaultDictionaryFactory. The trade-off is that | ||
lookup performance isn't as good as with dicts. | ||
""" | ||
|
||
__slots__: List[str] = [] | ||
|
||
def __init__( | ||
self, | ||
cache_max_size: int = 8, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are doing double caching: in-memory and in-disk. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I completely agree. I just kept that parameter from |
||
use_disk_cache: bool = True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right now if |
||
disk_cache_dir: Optional[str] = None, | ||
) -> None: | ||
"""Initialize the TrieDictionaryFactory. | ||
|
||
Args: | ||
cache_max_size (int): The maximum number of dictionaries to | ||
keep in memory. Defaults to `8`. | ||
use_disk_cache (bool): Whether to cache the tries on disk to | ||
speed up loading time. Defaults to `True`. | ||
disk_cache_dir (Optional[str]): Directory in which the generated | ||
tries are stored. If falsy, defaults to a Simplemma- | ||
specific subdirectory of the user's cache directory. | ||
""" | ||
|
||
# The default location includes the Simplemma version as its last path component.
if disk_cache_dir: | ||
self._cache_dir = Path(disk_cache_dir) | ||
else: | ||
self._cache_dir = ( | ||
Path(user_cache_dir("simplemma")) / "marisa_trie" / SIMPLEMMA_VERSION | ||
) | ||
self._use_disk_cache = use_disk_cache | ||
# Wrap the uncached loader in an LRU cache owned by this instance.
self._get_dictionary = lru_cache(maxsize=cache_max_size)( | ||
self._get_dictionary_uncached | ||
) | ||
|
||
def _create_trie_from_pickled_dict(self, lang: str) -> BytesTrie:
    """Create a trie from a pickled dictionary.

    The dictionary for ``lang`` is obtained from a
    ``DefaultDictionaryFactory`` constructed with ``cache_max_size=0``
    and converted into a ``BytesTrie`` mapping each key to its encoded
    value.

    Args:
        lang: Language code of the dictionary to convert.

    Returns:
        A ``BytesTrie`` built with ``cache_size=HUGE_CACHE``.
    """
    unpickled_dict = DefaultDictionaryFactory(cache_max_size=0).get_dictionary(lang)
    # Stream (key, encoded value) pairs instead of first materializing a
    # list holding every encoded value, so building the trie does not keep
    # an extra full copy of the values in memory.
    return BytesTrie(
        ((key, value.encode()) for key, value in unpickled_dict.items()),
        cache_size=HUGE_CACHE,
    )
|
||
def _write_trie_to_disk(self, lang: str, trie: BytesTrie) -> None:
    """Save ``trie`` as ``<lang>.dic`` inside the cache directory.

    The cache directory, including missing parents, is created before
    saving. Subsequent runs can load the saved file to speed up loading
    times.

    Args:
        lang: Language code used as the file stem.
        trie: The trie to persist.
    """
    logger.debug("Caching trie on disk. This might take a second.")

    cache_dir = self._cache_dir
    cache_dir.mkdir(parents=True, exist_ok=True)

    target_file = cache_dir / f"{lang}.dic"
    trie.save(target_file)
|
||
def _get_dictionary_uncached(self, lang: str) -> Mapping[str, str]:
    """Build the trie-backed dictionary for ``lang`` without the in-memory cache.

    When disk caching is enabled and ``<lang>.dic`` exists in the cache
    directory, the trie is loaded from that file. Otherwise it is created
    from the pickled dictionary and, if disk caching is enabled, written
    to the cache directory.

    Args:
        lang: Language code of the requested dictionary.

    Returns:
        A ``TrieWrapDict`` wrapping the loaded or newly built trie.

    Raises:
        ValueError: If ``lang`` is not in ``SUPPORTED_LANGUAGES``.
    """
    if lang not in SUPPORTED_LANGUAGES:
        raise ValueError(f"Unsupported language: {lang}")

    disk_caching = self._use_disk_cache
    cached_file = self._cache_dir / f"{lang}.dic"

    # Fast path: reuse a trie persisted by an earlier run.
    if disk_caching and cached_file.exists():
        return TrieWrapDict(BytesTrie().load(str(cached_file)))

    trie = self._create_trie_from_pickled_dict(lang)
    if disk_caching:
        self._write_trie_to_disk(lang, trie)
    return TrieWrapDict(trie)
|
||
def get_dictionary(
    self,
    lang: str,
) -> Mapping[str, str]:
    """Return the dictionary for ``lang`` through the instance's cached loader.

    Args:
        lang: Language code of the requested dictionary.

    Returns:
        The mapping produced by ``self._get_dictionary`` for ``lang``.
    """
    cached_loader = self._get_dictionary
    return cached_loader(lang)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be helpful to explain where the cache is located on disk?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried not to document every little detail in the README to avoid blowing it up too much. IMO there should be a separate API-documentation for Simplemma to cover stuff like that. However, if it's desired, please let me know and I'll happily add it.