Update dependency (#39)
* Update pytorch-partial-tagger

* Use the Alignments class in pytorch-partial-tagger

* Fix an unnecessary alias
yasufumy authored Jul 23, 2023
1 parent 22df87c commit d6e470f
Showing 4 changed files with 38 additions and 46 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
"torch<3.0.0,>=2.0.1",
"spacy[transformers]<4.0.0,>=3.3.1",
"spacy-alignments<1.0.0,>=0.8.5",
"pytorch-partial-tagger<1.0.0,>=0.1.12",
"pytorch-partial-tagger<1.0.0,>=0.1.14",
]
dynamic = ["version"]

32 changes: 13 additions & 19 deletions spacy_partial_tagger/collator.py
@@ -1,23 +1,17 @@
 from typing import Optional, Tuple
 
-from partial_tagger.data import Alignment, Span
-from partial_tagger.data.batch.text import (
-    BaseTokenizer,
-    TextBatch,
-    TransformerTokenizer,
-)
+from partial_tagger.data import Alignment, Alignments, Span
+from partial_tagger.data.collators import BaseCollator, Batch, TransformerCollator
 from transformers import AutoTokenizer
-from transformers.models.bert_japanese import (
-    BertJapaneseTokenizer as _BertJapaneseTokenizer,
-)
+from transformers.models.bert_japanese import BertJapaneseTokenizer
 
 from .util import get_alignments


-class BertJapaneseTokenizer(BaseTokenizer):
+class BertJapaneseCollator(BaseCollator):
     def __init__(
         self,
-        tokenizer: _BertJapaneseTokenizer,
+        tokenizer: BertJapaneseTokenizer,
         tokenizer_args: Optional[dict] = None,
     ):
         self.__tokenizer = tokenizer
@@ -29,7 +23,7 @@ def __init__(
         }
         self.__tokenizer_args["return_offsets_mapping"] = True
 
-    def __call__(self, texts: Tuple[str]) -> TextBatch:
+    def __call__(self, texts: Tuple[str]) -> Tuple[Batch, Alignments]:
         batch_encoding = self.__tokenizer(texts, **self.__tokenizer_args)
 
         pad_token_id = self.__tokenizer.pad_token_id
@@ -54,16 +48,16 @@ def __call__(self, texts: Tuple[str]) -> TextBatch:

             alignments.append(Alignment(text, char_spans, tuple(token_indices)))
 
-        return TextBatch(
-            tagger_inputs=batch_encoding, mask=mask, alignments=tuple(alignments)
+        return Batch(tagger_inputs=batch_encoding, mask=mask), Alignments(
+            tuple(alignments)
         )


-def get_tokenizer(
+def get_collator(
     transformer_model_name: str, tokenizer_args: Optional[dict] = None
-) -> BaseTokenizer:
+) -> BaseCollator:
     tokenizer = AutoTokenizer.from_pretrained(transformer_model_name)
-    if isinstance(tokenizer, _BertJapaneseTokenizer):
-        return BertJapaneseTokenizer(tokenizer, tokenizer_args)
+    if isinstance(tokenizer, BertJapaneseTokenizer):
+        return BertJapaneseCollator(tokenizer, tokenizer_args)
     else:
-        return TransformerTokenizer(tokenizer, tokenizer_args)
+        return TransformerCollator(tokenizer, tokenizer_args)
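
For orientation, a minimal hedged sketch of how the renamed entry point is meant to be used after this change. The model name and example text are placeholders; only the calls visible in the diff (get_collator, the (Batch, Alignments) return value, tagger_inputs, mask, and .alignments) are assumed:

from spacy_partial_tagger.collator import get_collator

# Placeholder model name; any Hugging Face checkpoint name would do here.
collator = get_collator("bert-base-cased")

# The collator now returns a (Batch, Alignments) pair instead of a single
# TextBatch, matching the new pytorch-partial-tagger 0.1.14 API.
batch, alignments = collator(("Tokyo is the capital of Japan.",))

print(batch.tagger_inputs)         # transformer inputs consumed by the tagger
print(batch.mask)                  # padding mask for the batch
print(len(alignments.alignments))  # one Alignment per input text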
32 changes: 16 additions & 16 deletions spacy_partial_tagger/pipeline.py
@@ -2,8 +2,7 @@

 import srsly
 import torch
-from partial_tagger.data import LabelSet
-from partial_tagger.data.batch.tag import TagsBatch
+from partial_tagger.data import Alignments, LabelSet
 from partial_tagger.training import compute_partially_supervised_loss
 from partial_tagger.utils import create_tag
 from spacy import util
@@ -51,14 +50,16 @@ def set_annotations(
         docs: List[Doc],
         tag_indices: Floats2d,
     ) -> None:
-        for doc, indices in zip(docs, tag_indices.tolist()):
-            indices = [index for index in indices if index != self.padding_index]
-            alignment = doc.user_data["alignment"]
+        alignments = Alignments(tuple(doc.user_data["alignment"] for doc in docs))
+        tags_batch = alignments.create_char_based_tags(
+            tag_indices.tolist(),
+            label_set=self.label_set,
+            padding_index=self.padding_index,
+        )
+
+        for doc, tags in zip(docs, tags_batch):
             ents = []
-            for tag in alignment.create_char_based_tags(
-                tag_indices=indices,
-                label_set=self.label_set,
-            ):
+            for tag in tags:
                 span = doc.char_span(tag.start, tag.start + tag.length, tag.label)
                 if span:
                     ents.append(span)
@@ -113,7 +114,7 @@ def get_loss(
         scores_pt = xp2torch(scores, requires_grad=True)
 
         char_based_tags = []
-        alignments = []
+        temp = []
         lengths = []
         for example in examples:
             tags = tuple(
@@ -124,14 +125,13 @@

             alignment = example.x.user_data["alignment"]
             lengths.append(alignment.num_tokens)
-            alignments.append(alignment)
+            temp.append(alignment)
 
-        tags_batch = TagsBatch(
-            tags_batch=tuple(char_based_tags),
-            alignments=alignments,
+        alignments = Alignments(tuple(temp))
+        tag_bitmap = torch.tensor(
+            alignments.get_tag_bitmap(char_based_tags, self.label_set),
+            device=scores_pt.device,
         )
-        tags_batch.to(scores_pt.device)
-        tag_bitmap = tags_batch.get_tag_bitmap(self.label_set)
 
         max_length = max(lengths)
         mask = torch.tensor(
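
To summarize the pipeline change, a hedged sketch of the two new call patterns: batched char-based tag decoding in set_annotations and tag-bitmap construction in get_loss. The function names and arguments below are illustrative stand-ins for the pipe's attributes shown in the diff; only Alignments, create_char_based_tags, and get_tag_bitmap are taken from the change itself.

import torch
from partial_tagger.data import Alignments


def decode_char_based_tags(docs, tag_indices, label_set, padding_index):
    # Collect every per-doc Alignment into one Alignments object, then decode
    # all documents in a single call instead of looping per document.
    alignments = Alignments(tuple(doc.user_data["alignment"] for doc in docs))
    return alignments.create_char_based_tags(
        tag_indices.tolist(),
        label_set=label_set,
        padding_index=padding_index,
    )


def build_tag_bitmap(alignments, char_based_tags, label_set, device):
    # get_tag_bitmap replaces the removed TagsBatch helper; the result is
    # wrapped in a tensor explicitly and placed on the scores' device.
    return torch.tensor(
        alignments.get_tag_bitmap(char_based_tags, label_set),
        device=device,
    )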
18 changes: 8 additions & 10 deletions spacy_partial_tagger/tagger.py
@@ -2,15 +2,15 @@
 from typing import Any, Callable, List, Optional, Tuple, cast
 
 from partial_tagger.data import LabelSet
-from partial_tagger.data.batch.text import BaseTokenizer
 from spacy.tokens import Doc
 from spacy.util import registry
 from thinc.api import Model, get_torch_default_device, torch2xp, xp2torch
 from thinc.shims import PyTorchGradScaler, PyTorchShim
 from thinc.types import ArgsKwargs, Floats4d, Ints2d
 from thinc.util import convert_recursive, is_torch_array, is_xp_array
 
-from .tokenizer import get_tokenizer
+from spacy_partial_tagger.collator import get_collator
+
 from .util import create_tagger


@@ -42,19 +42,17 @@ def forward(
     X: List[Doc],
     is_train: bool,
 ) -> Tuple[Tuple[Floats4d, Ints2d], Callable]:
-    tokenizer: BaseTokenizer = model.attrs["tokenizer"]
-
-    text_batch = tokenizer(tuple(doc.text for doc in X))
+    collator = model.attrs["collator"]
+    batch, alignments = collator(tuple(doc.text for doc in X))
 
-    for doc, alignment in zip(X, text_batch.alignments):
+    for doc, alignment in zip(X, alignments.alignments):
         doc.user_data["alignment"] = alignment
 
     device = get_torch_default_device()
-    text_batch.to(device)
+    batch = batch.to(device)
 
     (log_potentials, tag_indices), backward = model.layers[0](
-        [text_batch.tagger_inputs, text_batch.mask],
-        is_train,
+        [batch.tagger_inputs, batch.mask], is_train
     )
 
     return (log_potentials, tag_indices), backward
@@ -74,7 +72,7 @@ def init(
     mixed_precision = model.attrs["mixed_precision"]
     grad_scaler = model.attrs["grad_scaler"]
 
-    model.attrs["tokenizer"] = get_tokenizer(transformer_model_name, tokenizer_args)
+    model.attrs["collator"] = get_collator(transformer_model_name, tokenizer_args)
 
     tagger = create_tagger(transformer_model_name, Y, padding_index)
     PyTorchWrapper = registry.get("layers", "PyTorchWrapper.v2")
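
Finally, a hedged sketch of how the renamed model attribute flows through the thinc forward pass; run_encoder is a hypothetical name and the surrounding model construction is elided:

from thinc.api import Model, get_torch_default_device


def run_encoder(model: Model, docs):
    # "collator" replaces the old "tokenizer" entry in model.attrs.
    collator = model.attrs["collator"]
    batch, alignments = collator(tuple(doc.text for doc in docs))

    # Each Doc keeps its own Alignment so set_annotations can decode later.
    for doc, alignment in zip(docs, alignments.alignments):
        doc.user_data["alignment"] = alignment

    # The diff reassigns the result of batch.to(...), which suggests it
    # returns a new Batch rather than mutating in place.
    batch = batch.to(get_torch_default_device())
    return batch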
