From d6e470ffa026c0fdebf72af0c98f6ccc1af56e28 Mon Sep 17 00:00:00 2001
From: Taniguchi Yasufumi
Date: Sun, 23 Jul 2023 22:56:03 +0900
Subject: [PATCH] Update dependency (#39)

* Update pytorch-partial-tagger

* Use the Alignments class in pytorch-partial-tagger

* Fix an unnecessary alias
---
 pyproject.toml                    |  2 +-
 .../{tokenizer.py => collator.py} | 32 ++++++++-----------
 spacy_partial_tagger/pipeline.py  | 32 +++++++++----------
 spacy_partial_tagger/tagger.py    | 18 +++++------
 4 files changed, 38 insertions(+), 46 deletions(-)
 rename spacy_partial_tagger/{tokenizer.py => collator.py} (69%)

diff --git a/pyproject.toml b/pyproject.toml
index d1cc1c7..783ad82 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
     "torch<3.0.0,>=2.0.1",
     "spacy[transformers]<4.0.0,>=3.3.1",
     "spacy-alignments<1.0.0,>=0.8.5",
-    "pytorch-partial-tagger<1.0.0,>=0.1.12",
+    "pytorch-partial-tagger<1.0.0,>=0.1.14",
 ]
 
 dynamic = ["version"]
diff --git a/spacy_partial_tagger/tokenizer.py b/spacy_partial_tagger/collator.py
similarity index 69%
rename from spacy_partial_tagger/tokenizer.py
rename to spacy_partial_tagger/collator.py
index 0414c4f..ed8f919 100644
--- a/spacy_partial_tagger/tokenizer.py
+++ b/spacy_partial_tagger/collator.py
@@ -1,23 +1,17 @@
 from typing import Optional, Tuple
 
-from partial_tagger.data import Alignment, Span
-from partial_tagger.data.batch.text import (
-    BaseTokenizer,
-    TextBatch,
-    TransformerTokenizer,
-)
+from partial_tagger.data import Alignment, Alignments, Span
+from partial_tagger.data.collators import BaseCollator, Batch, TransformerCollator
 from transformers import AutoTokenizer
-from transformers.models.bert_japanese import (
-    BertJapaneseTokenizer as _BertJapaneseTokenizer,
-)
+from transformers.models.bert_japanese import BertJapaneseTokenizer
 
 from .util import get_alignments
 
 
-class BertJapaneseTokenizer(BaseTokenizer):
+class BertJapaneseCollator(BaseCollator):
     def __init__(
         self,
-        tokenizer: _BertJapaneseTokenizer,
+        tokenizer: BertJapaneseTokenizer,
         tokenizer_args: Optional[dict] = None,
     ):
         self.__tokenizer = tokenizer
@@ -29,7 +23,7 @@ def __init__(
         }
         self.__tokenizer_args["return_offsets_mapping"] = True
 
-    def __call__(self, texts: Tuple[str]) -> TextBatch:
+    def __call__(self, texts: Tuple[str]) -> Tuple[Batch, Alignments]:
         batch_encoding = self.__tokenizer(texts, **self.__tokenizer_args)
 
         pad_token_id = self.__tokenizer.pad_token_id
@@ -54,16 +48,16 @@ def __call__(self, texts: Tuple[str]) -> TextBatch:
 
             alignments.append(Alignment(text, char_spans, tuple(token_indices)))
 
-        return TextBatch(
-            tagger_inputs=batch_encoding, mask=mask, alignments=tuple(alignments)
+        return Batch(tagger_inputs=batch_encoding, mask=mask), Alignments(
+            tuple(alignments)
         )
 
 
-def get_tokenizer(
+def get_collator(
     transformer_model_name: str, tokenizer_args: Optional[dict] = None
-) -> BaseTokenizer:
+) -> BaseCollator:
     tokenizer = AutoTokenizer.from_pretrained(transformer_model_name)
-    if isinstance(tokenizer, _BertJapaneseTokenizer):
-        return BertJapaneseTokenizer(tokenizer, tokenizer_args)
+    if isinstance(tokenizer, BertJapaneseTokenizer):
+        return BertJapaneseCollator(tokenizer, tokenizer_args)
     else:
-        return TransformerTokenizer(tokenizer, tokenizer_args)
+        return TransformerCollator(tokenizer, tokenizer_args)
diff --git a/spacy_partial_tagger/pipeline.py b/spacy_partial_tagger/pipeline.py
index 0184a49..6245580 100644
--- a/spacy_partial_tagger/pipeline.py
+++ b/spacy_partial_tagger/pipeline.py
@@ -2,8 +2,7 @@
 
 import srsly
 import torch
-from partial_tagger.data import LabelSet
-from partial_tagger.data.batch.tag import TagsBatch
+from partial_tagger.data import Alignments, LabelSet
 from partial_tagger.training import compute_partially_supervised_loss
 from partial_tagger.utils import create_tag
 from spacy import util
@@ -51,14 +50,16 @@ def set_annotations(
         docs: List[Doc],
         tag_indices: Floats2d,
     ) -> None:
-        for doc, indices in zip(docs, tag_indices.tolist()):
-            indices = [index for index in indices if index != self.padding_index]
-            alignment = doc.user_data["alignment"]
+        alignments = Alignments(tuple(doc.user_data["alignment"] for doc in docs))
+        tags_batch = alignments.create_char_based_tags(
+            tag_indices.tolist(),
+            label_set=self.label_set,
+            padding_index=self.padding_index,
+        )
+
+        for doc, tags in zip(docs, tags_batch):
             ents = []
-            for tag in alignment.create_char_based_tags(
-                tag_indices=indices,
-                label_set=self.label_set,
-            ):
+            for tag in tags:
                 span = doc.char_span(tag.start, tag.start + tag.length, tag.label)
                 if span:
                     ents.append(span)
@@ -113,7 +114,7 @@ def get_loss(
         scores_pt = xp2torch(scores, requires_grad=True)
 
         char_based_tags = []
-        alignments = []
+        temp = []
         lengths = []
         for example in examples:
             tags = tuple(
@@ -124,14 +125,13 @@ def get_loss(
 
             alignment = example.x.user_data["alignment"]
             lengths.append(alignment.num_tokens)
-            alignments.append(alignment)
+            temp.append(alignment)
 
-        tags_batch = TagsBatch(
-            tags_batch=tuple(char_based_tags),
-            alignments=alignments,
+        alignments = Alignments(tuple(temp))
+        tag_bitmap = torch.tensor(
+            alignments.get_tag_bitmap(char_based_tags, self.label_set),
+            device=scores_pt.device,
         )
-        tags_batch.to(scores_pt.device)
-        tag_bitmap = tags_batch.get_tag_bitmap(self.label_set)
 
         max_length = max(lengths)
         mask = torch.tensor(
diff --git a/spacy_partial_tagger/tagger.py b/spacy_partial_tagger/tagger.py
index fcf34cb..332034e 100644
--- a/spacy_partial_tagger/tagger.py
+++ b/spacy_partial_tagger/tagger.py
@@ -2,7 +2,6 @@
 from typing import Any, Callable, List, Optional, Tuple, cast
 
 from partial_tagger.data import LabelSet
-from partial_tagger.data.batch.text import BaseTokenizer
 from spacy.tokens import Doc
 from spacy.util import registry
 from thinc.api import Model, get_torch_default_device, torch2xp, xp2torch
@@ -10,7 +9,8 @@
 from thinc.types import ArgsKwargs, Floats4d, Ints2d
 from thinc.util import convert_recursive, is_torch_array, is_xp_array
 
-from .tokenizer import get_tokenizer
+from spacy_partial_tagger.collator import get_collator
+
 from .util import create_tagger
 
 
@@ -42,19 +42,17 @@ def forward(
     X: List[Doc],
     is_train: bool,
 ) -> Tuple[Tuple[Floats4d, Ints2d], Callable]:
-    tokenizer: BaseTokenizer = model.attrs["tokenizer"]
-
-    text_batch = tokenizer(tuple(doc.text for doc in X))
+    collator = model.attrs["collator"]
+    batch, alignments = collator(tuple(doc.text for doc in X))
 
-    for doc, alignment in zip(X, text_batch.alignments):
+    for doc, alignment in zip(X, alignments.alignments):
         doc.user_data["alignment"] = alignment
 
     device = get_torch_default_device()
-    text_batch.to(device)
+    batch = batch.to(device)
 
     (log_potentials, tag_indices), backward = model.layers[0](
-        [text_batch.tagger_inputs, text_batch.mask],
-        is_train,
+        [batch.tagger_inputs, batch.mask], is_train
     )
 
     return (log_potentials, tag_indices), backward
@@ -74,7 +72,7 @@ def init(
    mixed_precision = model.attrs["mixed_precision"]
    grad_scaler = model.attrs["grad_scaler"]
 
-    model.attrs["tokenizer"] = get_tokenizer(transformer_model_name, tokenizer_args)
+    model.attrs["collator"] = get_collator(transformer_model_name, tokenizer_args)
     tagger = create_tagger(transformer_model_name, Y, padding_index)
 
     PyTorchWrapper = registry.get("layers", "PyTorchWrapper.v2")
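
For reference, a minimal usage sketch of the collator API this patch switches to; it is not part of the diff itself. It assumes pytorch-partial-tagger>=0.1.14 and transformers are installed, the checkpoint name is only a placeholder, and TransformerCollator is assumed to follow the same (Batch, Alignments) contract that forward() in tagger.py relies on:

    import torch

    from spacy_partial_tagger.collator import get_collator

    # get_collator() resolves the tokenizer via AutoTokenizer and returns a
    # BertJapaneseCollator for Japanese BERT tokenizers, otherwise a
    # TransformerCollator (placeholder checkpoint name below).
    collator = get_collator("bert-base-uncased")

    # Calling the collator now yields a (Batch, Alignments) pair instead of
    # the old TextBatch, mirroring how forward() in tagger.py unpacks it.
    batch, alignments = collator(("Tokyo is the capital of Japan.",))

    batch = batch.to(torch.device("cpu"))  # Batch.to() returns the moved batch
    tagger_inputs, mask = batch.tagger_inputs, batch.mask
    print(len(alignments.alignments))  # one Alignment per input text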