Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update: fixing bug raised around stanza model not having certain words in vocabulary, along with efforts to improve latency #115

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 108 additions & 56 deletions dialogy/plugins/text/list_search_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,51 @@
"""
.. _list_Search_plugin:

.. _ListSearchPlugin:

Fuzzy Search
--------------
We have often seen certain keywords gain significance in an SLU project. These keywords are
easy to extract via patterns and are often used to create entities. These patterns shouldn't need
frequent updates, and the parser should be able to handle ASR noise and multi-token issues.
The :ref:`ListSearchPlugin<ListSearchPlugin>`
helps with this task. It requires a pattern map, which we call :code:`fuzzy_dp_config`.

.. ipython::

In [1]: from dialogy.base import Input, Output
...: from dialogy.plugins import ListSearchPlugin
...: from dialogy.workflow import Workflow

In [2]: fuzzy_dp_config={
...: "en": {
...: "location": {
...: "delhi": "Delhi"
...: }
...: }
...: }

In [3]: l = ListSearchPlugin(
...: dest="output.entities",
...: fuzzy_threshold=0.4,
...: fuzzy_dp_config=fuzzy_dp_config)

In [4]: workflow = Workflow([l])

In [5]: _, output = workflow.run(Input(utterances="I live in deli"))

In [6]: output
Out[6]:
{'intents': [],
'entities': [{'range': {'start': 7, 'end': 14},
'body': 'in deli ',
'type': 'location',
'parsers': ['ListSearchPlugin', 'ListSearchPlugin'],
'score': 1.0,
'alternative_index': 0,
'value': 'Delhi',
'entity_type': 'location',
'_meta': {}}]}
Note
---------
Module needs refactor. We are currently keeping all strategies bundled as methods as opposed to SearchStrategyClasses.

Within dialogy, we extract entities using Duckling, Pattern lists and Spacy. We can ship individual plugins but at the
Expand All @@ -9,11 +54,11 @@
all other entities. So that their :code:`from_dict(...)` methods are pristine and involve no shape hacking.
"""
import re
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Pattern, Tuple

import stanza
from loguru import logger
from thefuzz import fuzz
from thefuzz import fuzz, process

from dialogy import constants as const
from dialogy.base import Guard, Input, Output, Plugin
Expand All @@ -26,6 +71,7 @@
Value = str
Score = float
MatchType = List[Tuple[Text, Label, Value, Span, Score]] # adding score for each entity
PatternList = List[Pattern[Any]]


class ListSearchPlugin(EntityScoringMixin, Plugin):
Expand All @@ -42,15 +88,12 @@ class ListSearchPlugin(EntityScoringMixin, Plugin):

:param style: One of ["regex", "spacy"]
:type style: Optional[str]
:param candidates: Required if style is "regex", this is a :code:`dict` that shows a mapping of entity
values and their patterns.
:type candidates: Optional[Dict[str, List[str]]]
:param spacy_nlp: Required if style is "spacy", this is a
`spacy model <https://spacy.io/usage/spacy-101#annotations-ner>`_.
:type spacy_nlp: Any
:param labels: Required if style is "spacy". If there is a need to extract only a few labels from all the other
`available labels <https://github.com/explosion/spaCy/issues/441#issuecomment-311804705>`_.
:type labels: Optional[List[str]]

:param fuzzy_dp_config: Required. A mapping of entity values to their corresponding matches.
:type fuzzy_dp_config: Dict[Any, Any]
:param fuzzy_threshold: Confidence threshold for entities; matches that do not score above this threshold are not returned.
:type fuzzy_threshold: Optional[float]

:param debug: A flag to set debugging on the plugin methods
:type debug: bool
"""
Expand Down Expand Up @@ -99,6 +142,8 @@ def __init__(
self.entity_types: Dict[Any, Any] = {}
self.nlp: Dict[Any, Any] = {}
self.fuzzy_threshold = fuzzy_threshold
self.entity_patterns: Dict[Any, Any] = {}
self.compiled_patterns: Dict[Any, Any] = {}

if self.style == const.FUZZY_DP:
self.fuzzy_init()
Expand All @@ -119,6 +164,16 @@ def fuzzy_init(self) -> None:
self.nlp[lang_code] = stanza.Pipeline(
lang=lang_code, tokenize_pretokenized=True
)
self.entity_patterns[lang_code] = {}
self.compiled_patterns[lang_code] = {}
for entity_type in self.entity_types[lang_code]:
self.entity_patterns[lang_code][entity_type] = list(
self.entity_dict[lang_code][entity_type].keys()
)
self.compiled_patterns[lang_code][entity_type] = [
re.compile(r"\b" + pattern + r"\b")
for pattern in self.entity_patterns[lang_code][entity_type]
]

def _search(self, transcripts: List[str], lang: str) -> List[MatchType]:
"""
Expand All @@ -141,14 +196,14 @@ def search_regex(
self,
query: str,
entity_type: str = "",
entity_patterns: List[str] = [""],
entity_patterns: PatternList = [re.compile(r"", re.UNICODE)],
match_dict: Dict[Any, Any] = {},
) -> Tuple[Text, Label, Value, Span, Score]:
max_length = 0
final_match = None

for pattern in entity_patterns:
result = re.search(pattern, query)
result = pattern.search(query)
if result:
match_value = match_dict[result.group()]
match_len = len(match_value)
Expand All @@ -169,41 +224,42 @@ def dp_search(
entity_patterns: List[str] = [""],
match_dict: Dict[Any, Any] = {},
) -> Tuple[Text, Label, Value, Span, Score]:

sentence = nlp(query).sentences[0]
value = ""
pos_tags = ["PROPN", "NOUN", "ADP"]
result_dict = {}
for word in sentence.words:
if word.upos in pos_tags:
if value == "":
span_start = word.start_char
span_end = word.end_char

"""
joining individual tokens that together are the real entity,
Since we are dealing with Multi-Word entities here

"""
value = value + str(word.text) + " "
if value != "":
for pattern in entity_patterns:
val = fuzz.ratio(pattern, value) / 100
if val > self.fuzzy_threshold:
match_value = match_dict[pattern]
result_dict[match_value] = val
if result_dict:
match_output = max(result_dict, key=lambda x: result_dict[x])
match_score = result_dict[match_output]

return (
value,
entity_type,
match_output,
(span_start, span_end),
match_score,
)
return (value, entity_type, "", (0, 0), 0.0)
try:
sentence = nlp(query).sentences[0]
value = ""
pos_tags = ["PROPN", "NOUN", "ADP"]

for word in sentence.words:
if word.upos in pos_tags:
if value == "":
span_start = word.start_char
span_end = word.end_char

"""
joining individual tokens that together are the real entity,
Since we are dealing with Multi-Word entities here

"""
value = value + str(word.text) + " "
if value != "":
matches = process.extractOne(value, entity_patterns)
match_output = match_dict[
matches[0]
] # extracting highest confidence match from tuple
match_score = (
matches[1] / 100
) # extracting the associated confidence score and scaling it down to (0,1) range.
if match_score > self.fuzzy_threshold:
return (
value,
entity_type,
match_output,
(span_start, span_end),
match_score,
)
return (value, entity_type, "", (0, 0), 0.0)
except KeyError:
return ("", entity_type, "", (0, 0), 0.0)

# new method based on experiments done during development of channel parser
def get_fuzzy_dp_search(self, transcript: str, lang: str = "") -> MatchType:
Expand All @@ -218,17 +274,13 @@ def get_fuzzy_dp_search(self, transcript: str, lang: str = "") -> MatchType:
match = []
query = transcript

entity_patterns = {}
entity_match_dict = {}
for entity_type in self.entity_types[lang]:
entity_patterns[entity_type] = list(
self.entity_dict[lang][entity_type].keys()
)
entity_match_dict[entity_type] = self.entity_dict[lang][entity_type]
match_entity = self.search_regex(
query,
entity_type,
entity_patterns[entity_type],
self.compiled_patterns[lang][entity_type],
entity_match_dict[entity_type],
)

Expand All @@ -237,7 +289,7 @@ def get_fuzzy_dp_search(self, transcript: str, lang: str = "") -> MatchType:
query,
self.nlp[lang],
entity_type,
entity_patterns[entity_type],
self.entity_patterns[lang][entity_type],
entity_match_dict[entity_type],
)
match.append(match_entity)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ def test_not_supported_lang():
l.get_entities(["........."], "te")


def test_entity_not_found():
def test_entity_not_found_and_keyerror():
l = ListSearchPlugin(
dest="output.entities",
fuzzy_threshold=0.4,
fuzzy_dp_config={"en": {"location": {"delhi": "Delhi"}}},
)
# testing entity not found
assert l.get_entities(["I live in punjab"], "en") == []
# testing KeyError handling
assert l.get_entities(["ramchandra k hathiyar"], "en") == []


@pytest.mark.parametrize("payload", load_tests("cases", __file__))
Expand Down