Skip to content

Commit

Permalink
tests: improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Jul 28, 2023
1 parent 1466ee0 commit e04bc4b
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 65 deletions.
49 changes: 9 additions & 40 deletions edsnlp/core/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from functools import partial

import catalogue
import inspect
import spacy
Expand Down Expand Up @@ -155,12 +153,17 @@ def curried(**kwargs):
func = catalogue._get(namespace)
return curried

available = self.get_available()
current_namespace = " -> ".join(self.namespace)
available_str = ", ".join(available) or "none"
available_str = ", ".join(self.get_available()) or "none"
raise catalogue.RegistryError(
f"Can't find '{name}' in registry {current_namespace}. "
f"Available names: {available_str}"
(
"Can't find '{name}' in registry {current_namespace}. "
"Available names: {available_str}"
).format(
name=name,
current_namespace=current_namespace,
available_str=available_str,
)
)

def register(
Expand Down Expand Up @@ -196,7 +199,6 @@ def register(
save_params = {"@factory": name}

def register(fn: catalogue.InFunc) -> catalogue.InFunc:

if len(accepted_arguments(fn, ["nlp", "name"])) < 2:
raise ValueError(
"Factory functions must accept nlp and name as arguments."
Expand Down Expand Up @@ -244,39 +246,6 @@ def invoke(validated_fn, kwargs):
return register(func) if func is not None else register


class TokenizerRegistry(Registry):
def get(self, name: str) -> Any:
"""
Get a tokenizer
"""

def curried(**kwargs):
return partial(func, **kwargs)

namespace = list(self.namespace) + [name]
spacy_namespace = ["spacy", "tokenizers", name]
if catalogue.check_exists(*namespace):
func = catalogue._get(namespace)
return curried
elif catalogue.check_exists(*spacy_namespace):
func = catalogue._get(spacy_namespace)
return curried

if self.entry_points:
self.get_entry_point(name)
if catalogue.check_exists(*namespace):
func = catalogue._get(namespace)
return curried

available = self.get_available()
current_namespace = " -> ".join(self.namespace)
available_str = ", ".join(available) or "none"
raise catalogue.RegistryError(
f"Can't find '{name}' in registry {current_namespace}. "
f"Available names: {available_str}"
)


class registry(RegistryCollection):
factory = factories = FactoryRegistry(("spacy", "factories"), entry_points=True)
misc = Registry(("spacy", "misc"), entry_points=True)
Expand Down
17 changes: 5 additions & 12 deletions edsnlp/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,17 @@
from edsnlp.utils.collections import get_deep_attr, set_deep_attr


def split_name(names):
_names = []
for part in names.split("."):
try:
_names.append(int(part))
except ValueError:
_names.append(part)
return _names


class ScheduledOptimizer(torch.optim.Optimizer):
def __init__(self, optim):
self.optim = optim
schedule_to_groups = defaultdict(lambda: [])
for group in self.optim.param_groups:
if "schedules" in group:
if not isinstance(group["schedules"], list):
group["schedules"] = [group["schedules"]]
group["schedules"] = (
group["schedules"]
if isinstance(group["schedules"], list)
else [group["schedules"]]
)
group["schedules"] = list(group["schedules"])
for schedule in group["schedules"]:
schedule_to_groups[schedule].append(group)
Expand Down
16 changes: 7 additions & 9 deletions edsnlp/utils/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ def __init__(self, seq: Iterable[Dict[str, Any]]):
def __iter__(self):
return batch_compress_dict(iter(self.seq))

def __getstate__(self):
return {"seq": self.seq}
# def __getstate__(self):
# return {"seq": self.seq}

def __setstate__(self, state):
self.seq = state["seq"]
self.flatten = None
# def __setstate__(self, state):
# self.seq = state["seq"]
# self.flatten = None

def __next__(self) -> Dict[str, List]:
exec_result = {}
Expand Down Expand Up @@ -128,19 +128,17 @@ def decompress_dict(seq: Union[Iterable[Dict[str, Any]], Dict[str, Any]]):
return res


def dedup(sequence, key=None):
def dedup(sequence, key):
"""
Deduplicate a sequence, keeping the last occurrence of each item.
Parameters
----------
sequence : Sequence
Sequence to deduplicate
key : Callable, optional
key : Callable
Key function to use for deduplication, by default None
"""
if key is None:
return list(dict.fromkeys(sequence))
return list({key(item): item for item in sequence}.values())


Expand Down
14 changes: 10 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,16 @@ omit-covered-files = false
# badge-format = "svg"

[tool.coverage]
exclude_lines = [
"raise NotImplementedError",
"def __repr__",
]
exclude_also = [
"def __repr__",
"raise AssertionError",
"raise NotImplementedError",
"raise .*Error",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
"class .*\\bProtocol\\):",
"@(abc\\.)?abstractmethod",
]

[tool.ruff]
exclude = [
Expand Down

0 comments on commit e04bc4b

Please sign in to comment.