From b8cc3e3b9a2a96d559a4252e0c0aef03964c5c22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Mon, 27 May 2024 14:38:05 +0200 Subject: [PATCH] feat: improve eds.table matcher Co-Authored-By: Jacques Ung --- changelog.md | 4 + edsnlp/pipes/misc/tables/__init__.py | 1 - edsnlp/pipes/misc/tables/patterns.py | 6 +- edsnlp/pipes/misc/tables/tables.py | 180 ++++++++++++++++++++------- tests/pipelines/misc/test_tables.py | 44 ++++++- 5 files changed, 183 insertions(+), 52 deletions(-) diff --git a/changelog.md b/changelog.md index 2e01d9d59..678d92f35 100644 --- a/changelog.md +++ b/changelog.md @@ -24,6 +24,10 @@ - Added a new `eds.ner_overlap_scorer` to evaluate matches between two lists of entities, counting true when the dice overlap is above a given threshold - `edsnlp.load` now accepts EDS-NLP models from the huggingface hub 🤗 ! - New `python -m edsnlp.package` command to package a model for the huggingface hub or pypi-like registries +- Improve table detection in `eds.tables` and support new options in `table._.to_pd_table(...)`: + - `header=True` to use first row as header + - `index=True` to use first column as index + - `as_spans=True` to fill cells as document spans instead of strings ### Changed diff --git a/edsnlp/pipes/misc/tables/__init__.py b/edsnlp/pipes/misc/tables/__init__.py index 861ef7628..5437a3e5f 100644 --- a/edsnlp/pipes/misc/tables/__init__.py +++ b/edsnlp/pipes/misc/tables/__init__.py @@ -1,2 +1 @@ -from .patterns import regex, sep from .tables import TablesMatcher diff --git a/edsnlp/pipes/misc/tables/patterns.py b/edsnlp/pipes/misc/tables/patterns.py index b200aa2d7..919143f60 100644 --- a/edsnlp/pipes/misc/tables/patterns.py +++ b/edsnlp/pipes/misc/tables/patterns.py @@ -1,4 +1,2 @@ -sep = r"¦|\|" -regex = dict( - tables=rf"(\b.*{sep}.*\n)+", -) +sep = ["¦", "|"] +regex_template = [r"(?:{sep}?(?:[^{sep}\n]*{sep})+[^{sep}\n]*{sep}?\n)+"] diff --git a/edsnlp/pipes/misc/tables/tables.py 
b/edsnlp/pipes/misc/tables/tables.py index 0da6f65c3..ae57300d4 100644 --- a/edsnlp/pipes/misc/tables/tables.py +++ b/edsnlp/pipes/misc/tables/tables.py @@ -1,16 +1,18 @@ -from io import StringIO +import re from typing import Dict, Optional, Union import pandas as pd from spacy.tokens import Doc, Span from edsnlp.core import PipelineProtocol -from edsnlp.pipes.core.matcher.matcher import GenericMatcher +from edsnlp.matchers.phrase import EDSPhraseMatcher +from edsnlp.matchers.regex import RegexMatcher +from edsnlp.pipes.base import BaseComponent from edsnlp.pipes.misc.tables import patterns -from edsnlp.utils.filter import get_spans +from edsnlp.utils.typing import AsList -class TablesMatcher(GenericMatcher): +class TablesMatcher(BaseComponent): ''' The `eds.tables` matcher detects tables in a documents. @@ -70,7 +72,11 @@ class TablesMatcher(GenericMatcher): # VMP ¦fL ¦11.5 + ¦7.4-10.8 # Convert span to Pandas table - df = table._.to_pd_table() + df = table._.to_pd_table( + as_spans=False, # set True to set the table cells as spans instead of strings + header=False, # set True to use the first row as header + index=False, # set True to use the first column as index + ) type(df) # Out: pandas.core.frame.DataFrame ``` @@ -96,7 +102,7 @@ class TablesMatcher(GenericMatcher): Parameters ---------- nlp : PipelineProtocol - spaCy nlp pipeline to use for matching. + Pipeline object name: str Name of the component. 
tables_pattern : Optional[Dict[str, str]] @@ -120,41 +126,106 @@ class TablesMatcher(GenericMatcher): def __init__( self, nlp: PipelineProtocol, - name: str = "tables", + name: Optional[str] = "tables", *, - tables_pattern: Optional[Dict[str, str]] = None, - sep_pattern: Optional[str] = None, + tables_pattern: Optional[AsList[str]] = None, + sep_pattern: Optional[AsList[str]] = None, attr: Union[Dict[str, str], str] = "TEXT", ignore_excluded: bool = True, ): - if tables_pattern is None and sep_pattern is None: - self.tables_pattern = patterns.regex - self.sep = patterns.sep - elif tables_pattern is None or sep_pattern is None: - raise ValueError( - "Both tables_pattern and sep_pattern must be provided " - "for custom eds.table pipeline." - ) - else: - self.tables_pattern = tables_pattern - self.sep = sep_pattern - - super().__init__( - nlp=nlp, - name=name, - terms=None, - regex=self.tables_pattern, - attr=attr, - ignore_excluded=ignore_excluded, + super().__init__(nlp, name) + if tables_pattern is None: + tables_pattern = patterns.regex_template + + if sep_pattern is None: + sep_pattern = patterns.sep + + self.regex_matcher = RegexMatcher(attr=attr, ignore_excluded=ignore_excluded) + self.regex_matcher.add( + "table", + list( + dict.fromkeys( + template.format(sep=re.escape(sep)) + for sep in sep_pattern + for template in tables_pattern + ) + ), + ) + + self.term_matcher = EDSPhraseMatcher( + nlp.vocab, attr=attr, ignore_excluded=ignore_excluded + ) + self.term_matcher.build_patterns( + nlp, + { + "eol_pattern": "\n", + "sep_pattern": sep_pattern, + }, ) if not Span.has_extension("to_pd_table"): Span.set_extension("to_pd_table", method=self.to_pd_table) - self.set_extensions() + @classmethod + def set_extensions(cls) -> None: + """ + Set extensions for the tables pipeline. 
+ """ + + if not Span.has_extension("table"): + Span.set_extension("table", default=None) + + def get_table(self, table): + """ + Convert spans of tables to dictionaries + Parameters + ---------- + table : Span + + Returns + ------- + List[Span] + """ + + # We store each row in a list and store each of these lists + # in processed_table for post processing + # considering the self.col_names and self.row_names var + processed_table = [] + delimiters = [ + delimiter + for delimiter in self.term_matcher(table, as_spans=True) + if delimiter.start >= table.start and delimiter.end <= table.end + ] + + last = table.start + row = [] + # Parse the table to match each cell thanks to delimiters + for delimiter in delimiters: + row.append(table[last - table.start : delimiter.start - table.start]) + last = delimiter.end + + # End the actual row if there is an end of line + if delimiter.label_ == "eol_pattern": + processed_table.append(row) + row = [] + + # Remove first or last column in case the separator pattern is + # also used in the raw table to draw the outlines + max_len = max(len(row) for row in processed_table) + if all(row[0].start == row[0].end for row in processed_table): + processed_table = [row[1:] for row in processed_table] + if all( + row[-1].start == row[-1].end + for row in processed_table + if len(row) == max_len + ): + processed_table = [row[:-1] for row in processed_table] + + return processed_table def __call__(self, doc: Doc) -> Doc: - """Find spans that contain tables + """ + Find spans that contain tables Parameters ---------- @@ -164,21 +235,40 @@ def __call__(self, doc: Doc) -> Doc: ------- Doc """ - matches = self.process(doc) - tables = get_spans(matches, "tables") - # parsed = self.parse(tables=tables) + matches = list(self.regex_matcher(doc, as_spans=True)) + doc.spans["tables"] = matches + return doc - doc.spans["tables"] = tables + def to_pd_table( + self, + span, + as_spans=False, + header: bool = False, + index: bool = False, + ) ->
pd.DataFrame: + """ + Return pandas DataFrame - return doc + Parameters + ---------- + span : Span + The span containing the table + as_spans : bool + Whether to return the table cells as spans + header : bool + Whether the table has a header + index : bool + Whether the table has an index + """ + table = self.get_table(span) + if not as_spans: + table = [[str(cell) for cell in data] for data in table] - def to_pd_table(self, span) -> pd.DataFrame: - table_str_io = StringIO(span.text) - parsed = pd.read_csv( - table_str_io, - sep=self.sep, - engine="python", - header=None, - on_bad_lines="skip", - ) - return parsed + table = pd.DataFrame.from_records(table) + if header: + table.columns = [str(k) for k in table.iloc[0]] + table = table[1:] + if index: + table.index = [str(k) for k in table.iloc[:, 0]] + table = table.iloc[:, 1:] + return table diff --git a/tests/pipelines/misc/test_tables.py b/tests/pipelines/misc/test_tables.py index 69ff76f3d..d147c7981 100644 --- a/tests/pipelines/misc/test_tables.py +++ b/tests/pipelines/misc/test_tables.py @@ -1,3 +1,6 @@ +import pytest +from spacy.tokens.span import Span + TEXT = """ Le patientqsfqfdf bla bla bla Leucocytes ¦x10*9/L ¦4.97 ¦4.09-11 @@ -14,18 +17,55 @@ 2/2Pat : | | |Intitulé RCP + |Libellé | Unité | Valeur | Intervalle | + |Leucocytes |x10*9/L |4.97 | 4.09-11 | + |Hématies |x10*12/L|4.68 | 4.53-5.79 | + |Hémoglobine |g/dL |14.8 | 13.4-16.7 | + |Hématocrite ||44.2 | 39.2-48.6 | + |VGM |fL | 94.4 + | 79.6-94 | + |TCMH |pg |31.6 | + |CCMH |g/dL + |Plaquettes |x10*9/L |191 | 172-398 | + |VMP |fL |11.5 + | 7.4-10.8 | """ def test_tables(blank_nlp): + if blank_nlp.lang != "eds": + pytest.skip("Test only for eds language") blank_nlp.add_pipe("eds.normalizer") blank_nlp.add_pipe("eds.tables") doc = blank_nlp(TEXT) - assert len(doc.spans["tables"]) == 1 + assert len(doc.spans["tables"]) == 2 span = doc.spans["tables"][0] df = span._.to_pd_table() - assert df.iloc[5, 0] == "TCMH " + assert len(df.columns) == 4 + assert 
len(df) == 9 + assert str(df.iloc[5, 0]) == "TCMH" + + span = doc.spans["tables"][1] + df = span._.to_pd_table(header=True, index=True, as_spans=True) + print(df) + assert df.columns.tolist() == [ + "Unité", + "Valeur", + "Intervalle", + ] + assert df.index.tolist() == [ + "Leucocytes", + "Hématies", + "Hémoglobine", + "Hématocrite", + "VGM", + "TCMH", + "CCMH", + "Plaquettes", + "VMP", + ] + cell = df.loc["TCMH", "Valeur"] + assert isinstance(cell, Span) + assert cell.text == "31.6"