Add a dictionary factory backed by MARISA-tries #133

Merged: 2 commits, Jun 26, 2024

Changes from all commits
55 changes: 55 additions & 0 deletions README.md
@@ -259,6 +259,61 @@ LANG_CACHE_SIZE = 5 # How many language dictionaries to keep in memory at once
For more information see the
[extended documentation](https://adbar.github.io/simplemma/).

### Reducing memory usage

For situations where low memory usage and fast initialization time
are more important than lemmatization and language detection
performance, Simplemma ships an alternative `DictionaryFactory`,
which uses a trie as its underlying data structure instead of a
Python dict.

Using the `TrieDictionaryFactory` reduces memory usage on average by
a factor of 20 and initialization time by a factor of 100, but comes
at the cost of lookup performance, which can drop by 50% or more
compared to what Simplemma otherwise achieves, depending on the
specific usage.

To use the `TrieDictionaryFactory` you have to install Simplemma with
the `marisa-trie` extra dependency:

``` bash
pip install simplemma[marisa-trie]
```

Then you have to create a custom strategy using the
`TrieDictionaryFactory` and use that for `Lemmatizer` and
`LanguageDetector` instances:

``` python
>>> from simplemma import LanguageDetector, Lemmatizer
>>> from simplemma.strategies import DefaultStrategy
>>> from simplemma.strategies.dictionaries import TrieDictionaryFactory

>>> lemmatization_strategy = DefaultStrategy(dictionary_factory=TrieDictionaryFactory())

>>> lemmatizer = Lemmatizer(lemmatization_strategy=lemmatization_strategy)
>>> lemmatizer.lemmatize('doughnuts', lang='en')
'doughnut'

>>> language_detector = LanguageDetector('la', lemmatization_strategy=lemmatization_strategy)
>>> language_detector.proportion_in_target_languages("opera post physica posita (τὰ μετὰ τὰ φυσικά)")
0.5
```

While memory usage and initialization time with the
`TrieDictionaryFactory` are significantly lower than with the
`DefaultDictionaryFactory`, this only holds once the trie
dictionaries are available on disk. That's not the case the first
time the `TrieDictionaryFactory` is used, as Simplemma only ships
the dictionaries as Python dicts. The trie dictionary for a language
is therefore generated on the fly the first time that language is
used, which takes a few seconds and uses as much memory as loading
the Python dict for the language requires. For further invocations
the trie dictionaries are cached on disk.
Contributor:

Would it be helpful to explain where the cache is located on disk?

Contributor Author:

I tried not to document every little detail in the README to avoid blowing it up too much. IMO there should be separate API documentation for Simplemma to cover details like that. However, if it's desired, please let me know and I'll happily add it.


If the computer that is supposed to run Simplemma doesn't have
enough memory to generate the trie dictionaries, they can also be
generated on another computer with the same CPU architecture and
copied over to the cache directory.
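The load-or-generate flow described above can be sketched generically with the standard library alone (a simplified sketch, not Simplemma's actual implementation; the helper name `load_or_build` is hypothetical):

``` python
import pickle
import tempfile
from pathlib import Path

def load_or_build(lang: str, cache_dir: Path, build) -> dict:
    """Load a cached artifact from disk, or build and persist it on first use."""
    cache_file = cache_dir / f"{lang}.dic"
    if cache_file.exists():  # warm start: fast path, load from disk
        return pickle.loads(cache_file.read_bytes())
    artifact = build(lang)   # cold start: slow and memory-hungry, happens once
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_file.write_bytes(pickle.dumps(artifact))
    return artifact

# Usage: the first call builds and persists, the second loads from disk.
cache = Path(tempfile.mkdtemp())
calls = []
def build(lang):
    calls.append(lang)
    return {"doughnuts": "doughnut"}

d1 = load_or_build("en", cache, build)
d2 = load_or_build("en", cache, build)
assert d1 == d2 == {"doughnuts": "doughnut"}
assert calls == ["en"]  # the expensive build ran only once
```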

## Supported languages

2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -1,6 +1,8 @@
black==24.4.2
flake8==7.0.0
marisa_trie==1.2.0
mypy==1.10.0
platformdirs==4.2.2
pytest==8.2.1
pytest-cov==5.0.0
types-requests==2.32.0.20240523
1 change: 1 addition & 0 deletions setup.py
@@ -85,6 +85,7 @@ def get_version(package):
],
description="A simple multilingual lemmatizer for Python.",
install_requires=requirements,
extras_require={"marisa-trie": ["marisa-trie", "platformdirs"]},
license="MIT license",
long_description=readme, # + '\n\n' + history,
long_description_content_type="text/markdown",
6 changes: 5 additions & 1 deletion simplemma/strategies/__init__.py
@@ -2,7 +2,11 @@

from .affix_decomposition import AffixDecompositionStrategy
from .default import DefaultStrategy
from .dictionaries import DefaultDictionaryFactory, DictionaryFactory
from .dictionaries import (
DefaultDictionaryFactory,
DictionaryFactory,
TrieDictionaryFactory,
)
from .dictionary_lookup import DictionaryLookupStrategy
from .fallback.lemmatization_fallback_strategy import LemmatizationFallbackStrategy
from .fallback.raise_error import RaiseErrorFallbackStrategy
1 change: 1 addition & 0 deletions simplemma/strategies/dictionaries/__init__.py
@@ -1,3 +1,4 @@
"""Dictionary-based lemmatization strategy."""

from .dictionary_factory import DefaultDictionaryFactory, DictionaryFactory
from .trie_directory_factory import TrieDictionaryFactory
34 changes: 26 additions & 8 deletions simplemma/strategies/dictionaries/dictionary_factory.py
@@ -14,7 +14,7 @@
from functools import lru_cache
from os import listdir, path
from pathlib import Path
from typing import ByteString, Dict, Protocol
from typing import ByteString, Dict, Mapping, Protocol

DATA_FOLDER = str(Path(__file__).parent / "data")
SUPPORTED_LANGUAGES = [
@@ -62,22 +62,41 @@
def get_dictionary(
self,
lang: str,
) -> Dict[ByteString, ByteString]:
) -> Mapping[str, str]:
"""
Get the dictionary for a specific language.

Args:
lang (str): The language code.

Returns:
Dict[str, str]: The dictionary for the specified language.
Mapping[str, str]: The dictionary for the specified language.

Raises:
ValueError: If the specified language is not supported.
"""
raise NotImplementedError


class MappingStrToByteString(Mapping[str, str]):
"""Wrapper around a ByteString dict to make it behave like a str dict."""

__slots__ = ["_dict"]

def __init__(self, dictionary: Dict[bytes, bytes]):
self._dict = dictionary

def __getitem__(self, item: str):
return self._dict[item.encode()].decode()

def __iter__(self):
for key in self._dict:
yield key.decode()

def __len__(self):
return len(self._dict)

Codecov warning on line 97 in simplemma/strategies/dictionaries/dictionary_factory.py: Added line #L97 was not covered by tests.
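The wrapper's behavior can be demonstrated in isolation (the class body is reproduced from the diff above so the example is self-contained):

``` python
from collections.abc import Mapping
from typing import Dict

class MappingStrToByteString(Mapping):
    """Present a bytes-to-bytes dict as a read-only str-to-str mapping."""

    __slots__ = ["_dict"]

    def __init__(self, dictionary: Dict[bytes, bytes]):
        self._dict = dictionary

    def __getitem__(self, item: str) -> str:
        # Encode the lookup key, decode the stored value.
        return self._dict[item.encode()].decode()

    def __iter__(self):
        for key in self._dict:
            yield key.decode()

    def __len__(self) -> int:
        return len(self._dict)

raw = {b"doughnuts": b"doughnut"}
wrapped = MappingStrToByteString(raw)
assert wrapped["doughnuts"] == "doughnut"
assert list(wrapped) == ["doughnuts"]
assert wrapped.get("missing") is None  # Mapping supplies .get() for free
```

Because only `__getitem__`, `__iter__`, and `__len__` are implemented, `collections.abc.Mapping` derives `get`, `items`, `__contains__`, and the rest automatically.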


class DefaultDictionaryFactory(DictionaryFactory):
"""
Default Dictionary Factory.
@@ -86,7 +105,7 @@
It provides functionality for loading and caching dictionaries from disk that are included in Simplemma.
"""

__slots__ = ["_data", "_load_dictionary_from_disk"]
__slots__ = ["_load_dictionary_from_disk"]

def __init__(self, cache_max_size: int = 8):
"""
@@ -96,27 +115,26 @@
cache_max_size (int): The maximum size of the cache for loaded dictionaries.
Defaults to `8`.
"""
self._data: Dict[str, Dict[ByteString, ByteString]] = {}
self._load_dictionary_from_disk = lru_cache(maxsize=cache_max_size)(
_load_dictionary_from_disk
)

def get_dictionary(
self,
lang: str,
) -> Dict[ByteString, ByteString]:
) -> Mapping[str, str]:
"""
Get the dictionary for a specific language.

Args:
lang (str): The language code.

Returns:
Dict[str, str]: The dictionary for the specified language.
Mapping[str, str]: The dictionary for the specified language.

Raises:
ValueError: If the specified language is not supported.
"""
if lang not in SUPPORTED_LANGUAGES:
raise ValueError(f"Unsupported language: {lang}")
return self._load_dictionary_from_disk(lang)
return MappingStrToByteString(self._load_dictionary_from_disk(lang))
123 changes: 123 additions & 0 deletions simplemma/strategies/dictionaries/trie_directory_factory.py
@@ -0,0 +1,123 @@
import logging
from collections.abc import MutableMapping
from functools import lru_cache
from pathlib import Path
from typing import List, Mapping, Optional

from marisa_trie import BytesTrie, HUGE_CACHE # type: ignore[import-not-found]
from platformdirs import user_cache_dir
Collaborator:

platformdirs is also an external dependency, so why doesn't it need the type ignore like marisa_trie?

Contributor Author:

While I'm not entirely sure, my guess is that mypy can't get any type information for marisa_trie because it isn't Python code but a C extension, and it doesn't provide type information.


from simplemma import __version__ as SIMPLEMMA_VERSION
from simplemma.strategies.dictionaries.dictionary_factory import (
DefaultDictionaryFactory,
DictionaryFactory,
SUPPORTED_LANGUAGES,
)

logger = logging.getLogger(__name__)


class TrieWrapDict(MutableMapping):
Collaborator:

Should we simply modify the DictionaryFactory protocol to return a Mapping instead of a Dict, rather than having this wrapper?

Why MutableMapping instead of Mapping?

Contributor Author:

One of the constraints I gave myself was to implement this functionality without requiring changes to the existing code of Simplemma. That's why I didn't modify the expected return types of the DictionaryFactory, for example. Doing so would certainly simplify things, but that's not something I wanted to decide on my own.

It's a MutableMapping right now because a dict, which is what the DefaultDictionaryFactory uses, is one as well. No further reason beyond that.

"""Wrapper around a BytesTrie to make it behave like a dict."""

def __init__(self, trie: BytesTrie):
self._trie = trie

def __getitem__(self, item):
return self._trie[item][0].decode()

def __setitem__(self, key, value):
raise NotImplementedError

def __delitem__(self, key):
raise NotImplementedError

def __iter__(self):
for key in self._trie.iterkeys():
yield key

def __len__(self):
return len(self._trie)
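The wrapper's read-only, dict-like behavior can be exercised without installing marisa-trie by standing in a minimal stub for `BytesTrie` (the stub is an assumption for illustration only; the `TrieWrapDict` body matches the diff above):

``` python
from collections.abc import MutableMapping

class FakeBytesTrie:
    """Tiny stand-in for marisa_trie.BytesTrie: str keys map to lists of bytes."""
    def __init__(self, data):
        self._data = dict(data)
    def __getitem__(self, key):
        return self._data[key]
    def iterkeys(self):
        return iter(self._data)
    def __len__(self):
        return len(self._data)

class TrieWrapDict(MutableMapping):
    """Wrapper around a BytesTrie to make it behave like a dict."""
    def __init__(self, trie):
        self._trie = trie
    def __getitem__(self, item):
        # BytesTrie returns a list of byte values per key; take the first.
        return self._trie[item][0].decode()
    def __setitem__(self, key, value):
        raise NotImplementedError
    def __delitem__(self, key):
        raise NotImplementedError
    def __iter__(self):
        yield from self._trie.iterkeys()
    def __len__(self):
        return len(self._trie)

wrapped = TrieWrapDict(FakeBytesTrie({"doughnuts": [b"doughnut"]}))
assert wrapped["doughnuts"] == "doughnut"
assert "doughnuts" in wrapped  # __contains__ derived from __getitem__
try:
    wrapped["x"] = "y"  # mutation is deliberately unsupported
except NotImplementedError:
    pass
```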


class TrieDictionaryFactory(DictionaryFactory):
"""Memory optimized DictionaryFactory backed by MARISA-tries.

This dictionary factory creates dictionaries, which are backed by a
MARISA-trie instead of a dict, to make them consume very little
memory compared to the DefaultDictionaryFactory. Trade-offs are that
lookup performance isn't as good as with dicts.
"""

__slots__: List[str] = []

def __init__(
self,
cache_max_size: int = 8,
Collaborator:

We are doing double caching: in-memory and on-disk. These params make it unclear which one they are about.

Contributor Author:

I completely agree. I just kept that parameter from DefaultDictionaryFactory so TrieDictionaryFactory can be used as a drop-in replacement for it.

use_disk_cache: bool = True,
Collaborator:

use_disk_cache is not needed: if disk_cache_dir is None, then you don't use disk caching.

Contributor Author:

Right now, if disk_cache_dir is None, a subdirectory of the user's platform-specific cache directory is used to store the cache. So use_disk_cache distinguishes between disabling the disk cache and using the default cache location.

disk_cache_dir: Optional[str] = None,
) -> None:
"""Initialize the TrieDictionaryFactory.

Args:
cache_max_size (int): The maximum number of dictionaries to
keep in memory. Defaults to `8`.
use_disk_cache (bool): Whether to cache the tries on disk to
speed up loading time. Defaults to `True`.
disk_cache_dir (Optional[str]): Path where the generated
tries should be stored. Defaults to a Simplemma-specific
subdirectory of the user's cache directory.
"""

if disk_cache_dir:
self._cache_dir = Path(disk_cache_dir)
else:
self._cache_dir = (
Path(user_cache_dir("simplemma")) / "marisa_trie" / SIMPLEMMA_VERSION
)
self._use_disk_cache = use_disk_cache
self._get_dictionary = lru_cache(maxsize=cache_max_size)(
self._get_dictionary_uncached
)

def _create_trie_from_pickled_dict(self, lang: str) -> BytesTrie:
"""Create a trie from a pickled dictionary."""
unpickled_dict = DefaultDictionaryFactory(cache_max_size=0).get_dictionary(lang)
return BytesTrie(
zip(
unpickled_dict.keys(),
[value.encode() for value in unpickled_dict.values()],
),
cache_size=HUGE_CACHE,
)

def _write_trie_to_disk(self, lang: str, trie: BytesTrie) -> None:
"""Persist the trie to disk for later usage.

The persisted trie can be loaded by subsequent runs to speed up
loading times.
"""
logger.debug("Caching trie on disk. This might take a second.")
self._cache_dir.mkdir(parents=True, exist_ok=True)

trie.save(self._cache_dir / f"{lang}.dic")

def _get_dictionary_uncached(self, lang: str) -> Mapping[str, str]:
"""Get the dictionary for the given language."""
if lang not in SUPPORTED_LANGUAGES:
raise ValueError(f"Unsupported language: {lang}")

if self._use_disk_cache and (self._cache_dir / f"{lang}.dic").exists():
trie = BytesTrie().load(str(self._cache_dir / f"{lang}.dic"))
else:
trie = self._create_trie_from_pickled_dict(lang)
if self._use_disk_cache:
self._write_trie_to_disk(lang, trie)

return TrieWrapDict(trie)

def get_dictionary(
self,
lang: str,
) -> Mapping[str, str]:
return self._get_dictionary(lang)
13 changes: 3 additions & 10 deletions simplemma/strategies/dictionary_lookup.py
@@ -3,7 +3,7 @@
It provides lemmatization using dictionary lookup.
"""

from typing import ByteString, Dict, Optional
from typing import Optional

from .dictionaries.dictionary_factory import DefaultDictionaryFactory, DictionaryFactory
from .lemmatization_strategy import LemmatizationStrategy
@@ -26,13 +26,6 @@ def __init__(
"""
self._dictionary_factory = dictionary_factory

def _get(
self, token: str, dictionary: Dict[ByteString, ByteString]
) -> Optional[str]:
"Convenience function to handle bytestring to string conversion."
result = dictionary.get(token.encode("utf-8"))
return result.decode("utf-8") if result else None # type: ignore[union-attr]

def get_lemma(self, token: str, lang: str) -> Optional[str]:
"""
Get Lemma using Dictionary Lookup
@@ -50,9 +43,9 @@ def get_lemma(self, token: str, lang: str) -> Optional[str]:
"""
# Search the language data, reverse case to extend coverage.
dictionary = self._dictionary_factory.get_dictionary(lang)
result = self._get(token, dictionary)
result = dictionary.get(token)
if result:
return result
# Try upper or lowercase.
token = token.lower() if token[0].isupper() else token.capitalize()
return self._get(token, dictionary)
return dictionary.get(token)
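The reverse-case fallback that `get_lemma` now performs directly on the mapping can be illustrated with a plain dict (a sketch of the lookup logic only, not the full strategy):

``` python
from typing import Mapping, Optional

def lookup_with_case_fallback(token: str, dictionary: Mapping[str, str]) -> Optional[str]:
    """Try the token as-is, then with the opposite capitalization."""
    result = dictionary.get(token)
    if result:
        return result
    # Reverse the case of the first letter to extend coverage.
    token = token.lower() if token[0].isupper() else token.capitalize()
    return dictionary.get(token)

dictionary = {"doughnuts": "doughnut"}
assert lookup_with_case_fallback("Doughnuts", dictionary) == "doughnut"
assert lookup_with_case_fallback("doughnuts", dictionary) == "doughnut"
assert lookup_with_case_fallback("bagels", dictionary) is None
```

Because the factory now returns a `Mapping[str, str]`, the caller no longer needs the encode/decode helper the diff removes.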
4 changes: 2 additions & 2 deletions simplemma/strategies/greedy_dictionary_lookup.py
@@ -58,7 +58,7 @@ def get_lemma(self, token: str, lang: str) -> str:
return token

dictionary = self._dictionary_factory.get_dictionary(lang)
candidate = token.encode("utf-8")
candidate = token
for _ in range(self._steps):
if candidate not in dictionary:
break
@@ -73,4 +73,4 @@

candidate = new_candidate

return candidate.decode("utf-8")
return candidate
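The greedy loop follows dictionary lookups hop by hop until the candidate leaves the dictionary, stops changing, or the step budget is exhausted. A simplified sketch (the real strategy applies additional length and distance checks between hops, hidden in the collapsed part of the diff):

``` python
from typing import Mapping

def greedy_lemmatize(token: str, dictionary: Mapping[str, str], steps: int = 1) -> str:
    """Follow dictionary lookups for up to `steps` hops."""
    candidate = token
    for _ in range(steps):
        if candidate not in dictionary:
            break  # nothing more to follow
        new_candidate = dictionary[candidate]
        if new_candidate == candidate:
            break  # reached a fixed point
        candidate = new_candidate
    return candidate

chain = {"better": "good", "goods": "good"}
assert greedy_lemmatize("better", chain, steps=2) == "good"
assert greedy_lemmatize("unknown", chain) == "unknown"
```

With string-keyed mappings in place, the candidate stays a `str` throughout, which is exactly what lets the diff drop the `encode`/`decode` round-trip.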
8 changes: 2 additions & 6 deletions simplemma/utils.py
@@ -6,7 +6,7 @@
- [validate_lang_input][simplemma.utils.validate_lang_input]: Validates the language input and ensures it is a valid tuple.
"""

from typing import ByteString, Tuple, Union
from typing import Tuple, Union


def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]:
@@ -31,9 +31,7 @@ def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]:
return lang # type: ignore[return-value]


def levenshtein_dist(
first: Union[ByteString, str], second: Union[ByteString, str]
) -> int:
def levenshtein_dist(str1: str, str2: str) -> int:
"""
Calculate the Levenshtein distance between two strings.

@@ -49,8 +47,6 @@ def levenshtein_dist(str1: str, str2: str) -> int:
int: The Levenshtein distance between the two strings.

"""
str1 = first.encode("utf-8") if isinstance(first, str) else first
str2 = second.encode("utf-8") if isinstance(second, str) else second
# inspired by this noticeably faster code:
# https://gist.github.com/p-hash/9e0f9904ce7947c133308fbe48fe032b
if str1 == str2:
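The diff narrows `levenshtein_dist` to plain `str` arguments; the body of the function is collapsed above, but the distance itself is the classic two-row dynamic-programming computation, sketched here (an illustrative implementation, not necessarily the exact code in the file):

``` python
def levenshtein_dist(str1: str, str2: str) -> int:
    """Two-row dynamic-programming Levenshtein distance."""
    if str1 == str2:
        return 0
    # Keep the longer string in str1 so the row is as short as possible.
    if len(str1) < len(str2):
        str1, str2 = str2, str1
    previous = list(range(len(str2) + 1))
    for i, c1 in enumerate(str1, start=1):
        current = [i]
        for j, c2 in enumerate(str2, start=1):
            cost = 0 if c1 == c2 else 1
            current.append(min(previous[j] + 1,          # deletion
                               current[j - 1] + 1,       # insertion
                               previous[j - 1] + cost))  # substitution
        previous = current
    return previous[-1]

assert levenshtein_dist("kitten", "sitting") == 3
assert levenshtein_dist("doughnut", "doughnuts") == 1
```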