fix: training span qualifiers and multi task models
percevalw committed Oct 9, 2023
1 parent 4e5f2ed commit e19f23a
Showing 19 changed files with 513 additions and 308 deletions.
18 changes: 6 additions & 12 deletions edsnlp/connectors/brat.py
@@ -381,6 +381,8 @@ def brat2docs(self, nlp: PipelineProtocol, run_pipe=False) -> List[Doc]:
         else:
             gold_docs = (nlp.make_doc(t) for t in texts)
 
+        attr_map = dict(self.attr_map or {})
+
         for doc, doc_annotations in tqdm(
             zip(gold_docs, annotations),
             ascii=True,
@@ -399,13 +401,12 @@ def brat2docs(self, nlp: PipelineProtocol, run_pipe=False) -> List[Doc]:
                     if not Span.has_extension(dst):
                         Span.set_extension(dst, default=None)
 
-            encountered_attributes = set()
             for ent in doc_annotations["entities"]:
-                if self.attr_map is None:
+                if self.attr_map is None:  # attr_map unset by the user
                     for a in ent["attributes"]:
                         if not Span.has_extension(a["label"]):
                             Span.set_extension(a["label"], default=None)
-                        encountered_attributes.add(a["label"])
+                        attr_map[a["label"]] = a["label"]
 
                 for fragment in ent["fragments"]:
                     span = doc.char_span(
@@ -415,12 +416,8 @@ def brat2docs(self, nlp: PipelineProtocol, run_pipe=False) -> List[Doc]:
                         alignment_mode="expand",
                     )
                     for a in ent["attributes"]:
-                        if self.attr_map is None or a["label"] in self.attr_map:
-                            new_name = (
-                                a["label"]
-                                if self.attr_map is None
-                                else self.attr_map[a["label"]]
-                            )
+                        if a["label"] in attr_map:
+                            new_name = attr_map[a["label"]]
                             span._.set(
                                 new_name, a["value"] if a["value"] is not None else True
                             )
@@ -429,9 +426,6 @@ def brat2docs(self, nlp: PipelineProtocol, run_pipe=False) -> List[Doc]:
                     if self.span_groups is None or ent["label"] in self.span_groups:
                         span_groups[ent["label"]].append(span)
 
-            if self.attr_map is None:
-                self.attr_map = {k: k for k in encountered_attributes}
-
             if self.span_groups is None:
                 self.span_groups = sorted(span_groups.keys())
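With this change, brat2docs no longer mutates self.attr_map while iterating over documents: it builds a local attr_map that starts from the user-provided mapping and, when none is given, falls back to an identity mapping of every attribute label found in the BRAT annotations. A minimal sketch of that fallback logic (the resolve_attr_map helper and the simplified annotation dicts are illustrative, not the connector's actual code):

```python
from typing import Dict, List, Optional


def resolve_attr_map(
    attr_map: Optional[Dict[str, str]],
    entities: List[dict],
) -> Dict[str, str]:
    """If no attr_map is configured, map every encountered BRAT attribute
    label to itself; otherwise keep only the user-provided mapping."""
    resolved = dict(attr_map or {})
    if attr_map is None:
        for ent in entities:
            for attr in ent["attributes"]:
                resolved.setdefault(attr["label"], attr["label"])
    return resolved


# Toy annotations, for illustration only.
entities = [{"attributes": [{"label": "negation", "value": True}]}]
assert resolve_attr_map(None, entities) == {"negation": "negation"}
assert resolve_attr_map({"negation": "negated"}, entities) == {"negation": "negated"}
```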
27 changes: 13 additions & 14 deletions edsnlp/core/torch_component.py
@@ -46,8 +46,8 @@ def cached_preprocess(fn):
     def wrapped(self: "TorchComponent", doc: Doc):
         if not self.nlp or self.nlp._cache is None:
             return fn(self, doc)
-        cache_id = hash((id(self), "preprocess", id(doc)))
-        if not self.nlp._cache_is_writeonly and cache_id in self.nlp._cache:
+        cache_id = hash((self.name, "preprocess", id(doc)))
+        if cache_id in self.nlp._cache:
             return self.nlp._cache[cache_id]
         res = fn(self, doc)
         self.nlp._cache[cache_id] = res
@@ -61,10 +61,8 @@ def cached_preprocess_supervised(fn):
     def wrapped(self: "TorchComponent", doc: Doc):
         if not self.nlp or self.nlp._cache is None:
             return fn(self, doc)
-        cache_id = hash((id(self), "preprocess_supervised", id(doc)))
-        if not self.nlp._cache_is_writeonly and cache_id in self.nlp._cache.setdefault(
-            self, {}
-        ):
+        cache_id = hash((self.name, "preprocess_supervised", id(doc)))
+        if cache_id in self.nlp._cache:
             return self.nlp._cache[cache_id]
         res = fn(self, doc)
         self.nlp._cache[cache_id] = res
@@ -76,14 +74,13 @@ def wrapped(self: "TorchComponent", doc: Doc):
 def cached_collate(fn):
     @wraps(fn)
     def wrapped(self: "TorchComponent", batch: Dict):
-        cache_id = hash((id(self), "collate", hash_batch(batch)))
+        cache_id = (self.name, "collate", hash_batch(batch))
         if not self.nlp or self.nlp._cache is None or cache_id is None:
             return fn(self, batch)
-        if not self.nlp._cache_is_writeonly and cache_id in self.nlp._cache:
+        if cache_id in self.nlp._cache:
             return self.nlp._cache[cache_id]
         res = fn(self, batch)
         self.nlp._cache[cache_id] = res
-        res["cache_id"] = cache_id
         return res
 
     return wrapped
@@ -95,7 +92,7 @@ def wrapped(self: "TorchComponent", batch):
         # Convert args and kwargs to a dictionary matching fn signature
         if not self.nlp or self.nlp._cache is None:
             return fn(self, batch)
-        cache_id = (id(self), "forward", hash_batch(batch))
+        cache_id = (self.name, "forward", hash_batch(batch))
         if cache_id in self.nlp._cache:
             return self.nlp._cache[cache_id]
         res = fn(self, batch)
@@ -105,13 +102,13 @@ def wrapped(self: "TorchComponent", batch):
     return wrapped
 
 
-def cached_move_to_device(fn):
+def cached_batch_to_device(fn):
     @wraps(fn)
     def wrapped(self: "TorchComponent", batch, device):
         # Convert args and kwargs to a dictionary matching fn signature
         if not self.nlp or self.nlp._cache is None:
             return fn(self, batch, device)
-        cache_id = (id(self), "move_to_device", hash_batch(batch))
+        cache_id = (self.name, "batch_to_device", hash_batch(batch))
         if cache_id in self.nlp._cache:
             return self.nlp._cache[cache_id]
         res = fn(self, batch, device)
@@ -131,8 +128,10 @@ def __new__(mcs, name, bases, class_dict):
         )
         if "collate" in class_dict:
             class_dict["collate"] = cached_collate(class_dict["collate"])
-        if "move_to_device" in class_dict:
-            class_dict["move_to_device"] = cached_move_to_device(class_dict["collate"])
+        if "batch_to_device" in class_dict:
+            class_dict["batch_to_device"] = cached_batch_to_device(
+                class_dict["batch_to_device"]
+            )
         if "forward" in class_dict:
             class_dict["forward"] = cached_forward(class_dict["forward"])
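The caching decorators now key the shared pipeline cache on the component's name rather than on id(self), the write-only cache checks are dropped, move_to_device is renamed to batch_to_device, and the metaclass no longer wraps batch_to_device with class_dict["collate"] by mistake. A small self-contained sketch of the name-keyed caching pattern (a plain dict cache and a dummy component stand in for the real TorchComponent machinery):

```python
from functools import wraps


def cached(stage):
    """Cache results in a pipeline-level dict, keyed by (component name, stage, input id)."""

    def decorator(fn):
        @wraps(fn)
        def wrapped(self, doc):
            if self.cache is None:
                return fn(self, doc)
            cache_id = (self.name, stage, id(doc))
            if cache_id in self.cache:
                return self.cache[cache_id]
            res = fn(self, doc)
            self.cache[cache_id] = res
            return res

        return wrapped

    return decorator


class DummyComponent:
    def __init__(self, name, cache):
        self.name = name
        self.cache = cache
        self.calls = 0

    @cached("preprocess")
    def preprocess(self, doc):
        self.calls += 1
        return {"text": doc, "component": self.name}


shared_cache = {}
comp = DummyComponent("qualifier", shared_cache)
doc = "Pas de fracture."
comp.preprocess(doc)
comp.preprocess(doc)  # second call hits the cache; the wrapped body ran only once
assert comp.calls == 1
```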
2 changes: 1 addition & 1 deletion edsnlp/optimization.py
@@ -89,7 +89,7 @@ def __init__(
         max_value=None,
         start_value=0.0,
         path="lr",
-        warmup_rate=0.1,
+        warmup_rate=0.0,
     ):
         self.path = path
         self.start_value = start_value
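The only change here is the default warmup_rate of the learning-rate schedule, which goes from 0.1 to 0.0, so no warmup is applied unless requested explicitly. For illustration only (this is not the edsnlp implementation), a generic linear warmup-then-decay schedule where warmup_rate is the fraction of steps spent ramping up:

```python
def linear_schedule_value(step, total_steps, start_value=0.0, max_value=1.0, warmup_rate=0.0):
    """Illustrative linear warmup followed by linear decay; shapes and names are assumptions."""
    warmup_steps = int(warmup_rate * total_steps)
    if warmup_steps and step < warmup_steps:
        # Ramp up linearly from start_value to max_value.
        return start_value + (max_value - start_value) * step / warmup_steps
    # Decay linearly from max_value to 0 over the remaining steps.
    remaining = max(total_steps - warmup_steps, 1)
    progress = (step - warmup_steps) / remaining
    return max_value * (1 - progress)


# With the new default warmup_rate=0.0, the schedule starts directly at max_value.
assert linear_schedule_value(0, 100, warmup_rate=0.0) == 1.0
assert linear_schedule_value(0, 100, warmup_rate=0.1) == 0.0
```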
36 changes: 21 additions & 15 deletions edsnlp/pipelines/trainable/embeddings/span_pooler/span_pooler.py
@@ -105,35 +105,41 @@ def preprocess(self, doc: Doc) -> Dict[str, Any]:
         begins = []
         ends = []
 
-        for i, (embedded_span, target_ents) in enumerate(
-            zip(
-                embedded_spans,
-                align_spans(
-                    source=spans,
-                    target=embedded_spans,
-                ),
-            )
+        embedded_spans_to_idx = {span: i for i, span in enumerate(embedded_spans)}
+        for i, (span, embedding_spans) in enumerate(
+            zip(spans, align_spans(embedded_spans, spans))
         ):
-            start = embedded_span.start
-            sequence_idx.extend([i] * len(spans))
-            begins.extend([span.start - start for span in spans])
-            ends.extend([span.end - start for span in spans])
+            if len(embedding_spans) != 1:
+                raise Exception(
+                    f"Span {span} is not aligned to exactly one embedding span: "
+                    f"{embedding_spans}"
+                )
+            start = embedding_spans[0].start
+            sequence_idx.append(embedded_spans_to_idx[embedding_spans[0]])
+            begins.append(span.start - start)
+            ends.append(span.end - start)
         return {
             "embedding": self.embedding.preprocess(doc),
             "begins": begins,
             "ends": ends,
             "sequence_idx": sequence_idx,
+            "num_sequences": len(embedded_spans),
             "$spans": spans,
             "$embedded_spans": embedded_spans,
         }
 
     def collate(self, batch: Dict[str, Sequence[Any]]) -> SpanPoolerBatchInput:
+        sequence_idx = []
+        offset = 0
+        for indices, seq_length in zip(batch["sequence_idx"], batch["num_sequences"]):
+            sequence_idx.extend([offset + idx for idx in indices])
+            offset += seq_length
+
         collated: SpanPoolerBatchInput = {
             "embedding": self.embedding.collate(batch["embedding"]),
             "begins": torch.as_tensor([b for x in batch["begins"] for b in x]),
             "ends": torch.as_tensor([e for x in batch["ends"] for e in x]),
-            "sequence_idx": torch.as_tensor(
-                [e for x in batch["sequence_idx"] for e in x]
-            ),
+            "sequence_idx": torch.as_tensor(sequence_idx),
         }
         return collated
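preprocess now maps each pooled span to the single embedding span that contains it (raising if the alignment is not one-to-one) and records per-document sequence indices plus the number of embedded sequences, so that collate can shift those indices by a running offset when documents are concatenated into a batch. A small sketch of that offsetting step (flatten_sequence_idx is a hypothetical helper, not the component's code):

```python
def flatten_sequence_idx(per_doc_indices, per_doc_num_sequences):
    """Shift each document's 0-based sequence indices by the number of
    embedded sequences contributed by the documents before it."""
    flat, offset = [], 0
    for indices, num_sequences in zip(per_doc_indices, per_doc_num_sequences):
        flat.extend(offset + idx for idx in indices)
        offset += num_sequences
    return flat


# Two docs: the first holds 2 embedded sequences, the second holds 3.
assert flatten_sequence_idx([[0, 1, 1], [0, 2]], [2, 3]) == [0, 1, 1, 2, 4]
```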