diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d0a009382..b4961aae6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -154,7 +154,7 @@ jobs: strategy: fail-fast: true matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 @@ -172,6 +172,7 @@ jobs: - name: Install library run: | - pip install . + pip install . pytest + pytest tests/pipelines/test_pipelines.py # uv venv # uv pip install . diff --git a/changelog.md b/changelog.md index b4ec18b13..c88ce1d74 100644 --- a/changelog.md +++ b/changelog.md @@ -20,6 +20,7 @@ ### Changed - Rename `eds.measurements` to `eds.quantities` +- scikit-learn (used in `eds.endlines`) is no longer installed by default when installing `edsnlp[ml]` ## v0.13.0 diff --git a/edsnlp/pipes/core/endlines/endlines.py b/edsnlp/pipes/core/endlines/endlines.py index 42003017d..6d0deb275 100644 --- a/edsnlp/pipes/core/endlines/endlines.py +++ b/edsnlp/pipes/core/endlines/endlines.py @@ -22,6 +22,10 @@ class EndLinesMatcher(GenericMatcher): Behind the scenes, it uses a `endlinesmodel` instance, which is an unsupervised algorithm based on the work of [@zweigenbaum2016]. + !!! warning "Installation" + + To use this component, you need to install the `scikit-learn` library. + Training -------- ```python @@ -93,12 +97,12 @@ class EndLinesMatcher(GenericMatcher): Extensions ---------- - The `eds.endlines` pipeline declares one extension, on both `Span` and `Token` - objects. The `end_line` attribute is a boolean, set to `True` if the pipeline + The `eds.endlines` pipe declares one extension, on both `Span` and `Token` + objects. The `end_line` attribute is a boolean, set to `True` if the pipe predicts that the new line is an end line character. Otherwise, it is set to `False` if the new line is classified as a space. - The pipeline also sets the `excluded` custom attribute on newlines that are + The pipe also sets the `excluded` custom attribute on newlines that are classified as spaces. It lets downstream matchers skip excluded tokens (see [normalisation](/pipes/core/normalisation/)) for more detail. @@ -113,7 +117,7 @@ class EndLinesMatcher(GenericMatcher): Authors and citation -------------------- - The `eds.endlines` pipeline was developed by AP-HP's Data Science team based on + The `eds.endlines` pipe was developed by AP-HP's Data Science team based on the work of [@zweigenbaum2016]. ''' diff --git a/edsnlp/pipes/misc/quantities/quantities.py b/edsnlp/pipes/misc/quantities/quantities.py index fee085220..bc57d4d0c 100644 --- a/edsnlp/pipes/misc/quantities/quantities.py +++ b/edsnlp/pipes/misc/quantities/quantities.py @@ -612,7 +612,7 @@ def __init__( as_ents: bool = False, span_setter: Optional[SpanSetterArg] = None, use_tables: bool = True, - measurements: Union[str, List[Union[str, MsrConfig]], Dict[str, MsrConfig]] = None # deprecated # noqa: E501 + measurements: Optional[Union[str, List[Union[str, MsrConfig]], Dict[str, MsrConfig]]] = None, # deprecated # noqa: E501 ): if measurements: @@ -632,7 +632,7 @@ def __init__( "Skipping that step." ) - self.all_quantities = (quantities == "all") + self.all_quantities = quantities == "all" if self.all_quantities: quantities = [] @@ -659,9 +659,7 @@ def __init__( self.extract_ranges = extract_ranges self.range_patterns = range_patterns self.span_getter = ( - validate_span_getter(span_getter) - if span_getter is not None - else None + validate_span_getter(span_getter) if span_getter is not None else None ) self.merge_mode = merge_mode self.before_snippet_limit = before_snippet_limit @@ -676,10 +674,7 @@ def __init__( "ents": as_ents, "measurements": True, "quantities": True, - **{ - name: [name] - for name in self.measure_names.values() - } + **{name: [name] for name in self.measure_names.values()}, } super().__init__(nlp=nlp, name=name, span_setter=span_setter) @@ -1033,10 +1028,17 @@ def get_matches_before(i): table_pd = table._.to_pd_table(as_spans=True) # Find out the number's row for _, row in table_pd.iterrows(): - start_line = next((item.start for item in row - if item is not None), None) - end_line = next((item.end for item in reversed(row) - if item is not None), None) + start_line = next( + (item.start for item in row if item is not None), None + ) + end_line = next( + ( + item.end + for item in reversed(row) + if item is not None + ), + None, + ) if start_line is None: continue @@ -1136,10 +1138,7 @@ def is_within_row(x): else: ent.label_ = self.measure_names[dims] - ent._.set( - ent.label_, - SimpleQuantity(value, unit_norm, self.unit_registry) - ) + ent._.set(ent.label_, SimpleQuantity(value, unit_norm, self.unit_registry)) quantities.append(ent) @@ -1224,9 +1223,7 @@ def merge_quantities_in_ranges(self, quantities: List[Span]) -> List[Span]: ] if len(matching_patterns): try: - new_value = RangeQuantity.from_quantities( - last._.value, ent._.value - ) + new_value = RangeQuantity.from_quantities(last._.value, ent._.value) merged[-1] = last = last.doc[ last.start if matching_patterns[0][0] is None @@ -1296,7 +1293,8 @@ def __call__(self, doc): existing = ( list(get_spans(doc, self.span_getter)) if self.span_getter is not None - else ()) + else () + ) snippets = ( dict.fromkeys(ent.sent for ent in existing) if self.span_getter is not None diff --git a/pyproject.toml b/pyproject.toml index 7aa3491df..9f2bacf64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,7 +105,6 @@ ml = [ "safetensors>=0.3.0", "transformers>=4.0.0,<5.0.0", "accelerate>=0.20.3,<1.0.0", - "scikit-learn>=1.0.0", ] [project.urls] diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index fa09cc8e6..596266950 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -12,5 +12,5 @@ def test_import_all(): import edsnlp.pipes for name in dir(edsnlp.pipes): - if not name.startswith("_"): + if not name.startswith("_") and "endlines" not in name: getattr(edsnlp.pipes, name)