Data Augmentation Training (#2413)
* Add field streaming

* Fix test

* Fix dataset

* Fix datasets

* Fix lint

* Fix import

---------

Co-authored-by: Ferran Llamas <[email protected]>
Co-authored-by: Ferran Llamas <[email protected]>
3 people authored Aug 27, 2024
1 parent ae6ad6b commit 5d37539
Showing 15 changed files with 594 additions and 270 deletions.
56 changes: 0 additions & 56 deletions nucliadb/src/nucliadb/train/api/v1/check.py

This file was deleted.

3 changes: 3 additions & 0 deletions nucliadb/src/nucliadb/train/generator.py
@@ -25,6 +25,7 @@
 from nucliadb.train.generators.field_classifier import (
     field_classification_batch_generator,
 )
+from nucliadb.train.generators.field_streaming import field_streaming_batch_generator
 from nucliadb.train.generators.image_classifier import (
     image_classification_batch_generator,
 )
@@ -75,6 +76,8 @@ async def generate_train_data(kbid: str, shard: str, trainset: TrainSet):
 
     elif trainset.type == TaskType.QUESTION_ANSWER_STREAMING:
         batch_generator = question_answer_batch_generator(kbid, trainset, node, shard_replica_id)
+    elif trainset.type == TaskType.FIELD_STREAMING:
+        batch_generator = field_streaming_batch_generator(kbid, trainset, node, shard_replica_id)
 
     if batch_generator is None:
         raise HTTPException(
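For context, a minimal sketch (not part of the diff) of the trainset that reaches this new branch, assuming TaskType and TrainSet are the nucliadb_protos.dataset_pb2 messages used elsewhere in the commit; the filter value is purely illustrative:

from nucliadb_protos.dataset_pb2 import TaskType, TrainSet

trainset = TrainSet()
trainset.type = TaskType.FIELD_STREAMING
trainset.batch_size = 50
# Optional filtering; each entry becomes an index facet in the
# stream request built by generate_field_streaming_payloads below
trainset.filter.labels.append("labelset/label")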
156 changes: 156 additions & 0 deletions nucliadb/src/nucliadb/train/generators/field_streaming.py
@@ -0,0 +1,156 @@
# Copyright (C) 2021 Bosutech XXI S.L.
#
# nucliadb is offered under the AGPL v3.0 and as commercial software.
# For commercial licensing, contact us at [email protected].
#
# AGPL:
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

from typing import AsyncGenerator, Optional

from nucliadb.common.cluster.base import AbstractIndexNode
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
from nucliadb.train import logger
from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
from nucliadb_protos.dataset_pb2 import (
    FieldSplitData,
    FieldStreamingBatch,
    TrainSet,
)
from nucliadb_protos.nodereader_pb2 import StreamRequest
from nucliadb_protos.resources_pb2 import Basic, FieldComputedMetadata
from nucliadb_protos.utils_pb2 import ExtractedText


def field_streaming_batch_generator(
    kbid: str,
    trainset: TrainSet,
    node: AbstractIndexNode,
    shard_replica_id: str,
) -> AsyncGenerator[FieldStreamingBatch, None]:
    generator = generate_field_streaming_payloads(kbid, trainset, node, shard_replica_id)
    batch_generator = batchify(generator, trainset.batch_size, FieldStreamingBatch)
    return batch_generator


async def generate_field_streaming_payloads(
    kbid: str,
    trainset: TrainSet,
    node: AbstractIndexNode,
    shard_replica_id: str,
) -> AsyncGenerator[FieldSplitData, None]:
    # Build the stream request, translating every trainset filter into the
    # matching index facet (labels, paths, metadata, entities, fields and
    # processing status)
    request = StreamRequest()
    request.shard_id.id = shard_replica_id

    for label in trainset.filter.labels:
        request.filter.labels.append(f"/l/{label}")

    for path in trainset.filter.paths:
        request.filter.labels.append(f"/p/{path}")

    for metadata in trainset.filter.metadata:
        request.filter.labels.append(f"/m/{metadata}")

    for entity in trainset.filter.entities:
        request.filter.labels.append(f"/e/{entity}")

    for field in trainset.filter.fields:
        request.filter.labels.append(f"/f/{field}")

    for status in trainset.filter.status:
        request.filter.labels.append(f"/n/s/{status}")

    total = 0

    async for document_item in node.stream_get_fields(request):
        text_labels = []
        for label in document_item.labels:
            text_labels.append(label)

        total += 1

        # document_item.field is "/<field_type>/<field>" with an optional
        # trailing "/<split>" for split fields
        field_parts = document_item.field.split("/")
        if len(field_parts) == 3:
            _, field_type, field = field_parts
            split = "0"
        elif len(field_parts) == 4:
            _, field_type, field, split = field_parts
        else:
            raise Exception(f"Invalid field definition {document_item.field}")

        tl = FieldSplitData()
        rid = document_item.uuid
        tl.rid = rid
        tl.field = field
        tl.field_type = field_type
        tl.split = split
        extracted = await get_field_text(kbid, rid, field, field_type)
        if extracted is not None:
            tl.text.CopyFrom(extracted)

        metadata_obj = await get_field_metadata(kbid, rid, field, field_type)
        if metadata_obj is not None:
            tl.metadata.CopyFrom(metadata_obj)

        basic = await get_field_basic(kbid, rid, field, field_type)
        if basic is not None:
            tl.basic.CopyFrom(basic)

        tl.labels.extend(text_labels)

        yield tl


async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> Optional[ExtractedText]:
    orm_resource = await get_resource_from_cache_or_db(kbid, rid)

    if orm_resource is None:
        logger.error(f"{rid} does not exist on DB")
        return None

    field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
    field_obj = await orm_resource.get_field(field, field_type_int, load=False)
    extracted_text = await field_obj.get_extracted_text()

    return extracted_text


async def get_field_metadata(
    kbid: str, rid: str, field: str, field_type: str
) -> Optional[FieldComputedMetadata]:
    orm_resource = await get_resource_from_cache_or_db(kbid, rid)

    if orm_resource is None:
        logger.error(f"{rid} does not exist on DB")
        return None

    field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
    field_obj = await orm_resource.get_field(field, field_type_int, load=False)
    field_metadata = await field_obj.get_field_metadata()

    return field_metadata


async def get_field_basic(kbid: str, rid: str, field: str, field_type: str) -> Optional[Basic]:
    orm_resource = await get_resource_from_cache_or_db(kbid, rid)

    if orm_resource is None:
        logger.error(f"{rid} does not exist on DB")
        return None

    basic = await orm_resource.get_basic()

    return basic
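
A hedged usage sketch, not part of the commit: it consumes the generator above, assuming each FieldStreamingBatch exposes its FieldSplitData items through the repeated data field that batchify fills.

async def dump_field_batches(kbid, trainset, node, shard_replica_id):
    # field_streaming_batch_generator yields FieldStreamingBatch messages
    # holding up to trainset.batch_size FieldSplitData items each
    async for batch in field_streaming_batch_generator(kbid, trainset, node, shard_replica_id):
        for item in batch.data:
            print(item.rid, item.field_type, item.field, list(item.labels))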
8 changes: 5 additions & 3 deletions nucliadb/src/nucliadb/train/generators/utils.py
@@ -19,14 +19,14 @@
 #
 
 from contextvars import ContextVar
-from typing import Any, AsyncIterator, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, Optional, Type
 
 from nucliadb.common.maindb.utils import get_driver
 from nucliadb.ingest.orm.knowledgebox import KnowledgeBox as KnowledgeBoxORM
 from nucliadb.ingest.orm.resource import FIELD_TYPE_STR_TO_PB
 from nucliadb.ingest.orm.resource import Resource as ResourceORM
 from nucliadb.train import SERVICE_NAME, logger
-from nucliadb.train.types import TrainBatchType
+from nucliadb.train.types import T
 from nucliadb_utils.utilities import get_storage
 
 rcache: ContextVar[Optional[dict[str, ResourceORM]]] = ContextVar("rcache", default=None)
@@ -89,7 +89,9 @@ async def get_paragraph(kbid: str, paragraph_id: str) -> str:
     return splitted_text
 
 
-async def batchify(producer: AsyncIterator[Any], size: int, batch_klass: TrainBatchType):
+async def batchify(
+    producer: AsyncIterator[Any], size: int, batch_klass: Type[T]
+) -> AsyncGenerator[T, None]:
     # NOTE: we are supposing all protobuffers have a data field
     batch = []
     async for item in producer:
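The hunk above truncates batchify's body. Under the NOTE's assumption that every batch protobuf has a repeated data field, a plausible completion looks like this; it is a sketch of the buffering logic, not the verbatim implementation:

async def batchify(
    producer: AsyncIterator[Any], size: int, batch_klass: Type[T]
) -> AsyncGenerator[T, None]:
    # NOTE: we are supposing all protobuffers have a data field
    batch = []
    async for item in producer:
        batch.append(item)
        if len(batch) == size:
            # Flush a full batch into a fresh protobuf
            batch_pb = batch_klass()
            batch_pb.data.extend(batch)
            yield batch_pb
            batch = []
    if batch:
        # Flush the final, possibly smaller batch
        batch_pb = batch_klass()
        batch_pb.data.extend(batch)
        yield batch_pb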
23 changes: 13 additions & 10 deletions nucliadb/src/nucliadb/train/types.py
@@ -17,7 +17,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from typing import Union
+from typing import TypeVar, Union
 
 from nucliadb_protos import dataset_pb2 as dpb
 
@@ -29,14 +29,17 @@
     dpb.QuestionAnswerStreamingBatch,
     dpb.SentenceClassificationBatch,
     dpb.TokenClassificationBatch,
+    dpb.FieldStreamingBatch,
 ]
 
-TrainBatchType = Union[
-    type[dpb.FieldClassificationBatch],
-    type[dpb.ImageClassificationBatch],
-    type[dpb.ParagraphClassificationBatch],
-    type[dpb.ParagraphStreamingBatch],
-    type[dpb.QuestionAnswerStreamingBatch],
-    type[dpb.SentenceClassificationBatch],
-    type[dpb.TokenClassificationBatch],
-]
+T = TypeVar(
+    "T",
+    dpb.FieldClassificationBatch,
+    dpb.ImageClassificationBatch,
+    dpb.ParagraphClassificationBatch,
+    dpb.ParagraphStreamingBatch,
+    dpb.QuestionAnswerStreamingBatch,
+    dpb.SentenceClassificationBatch,
+    dpb.TokenClassificationBatch,
+    dpb.FieldStreamingBatch,
+)
2 changes: 1 addition & 1 deletion nucliadb/src/nucliadb/writer/api/v1/field.py
@@ -64,7 +64,7 @@
 )
 
 if TYPE_CHECKING:  # pragma: no cover
-    FIELD_TYPE_NAME_TO_FIELD_TYPE_MAP: dict[models.FieldTypeName, resources_pb2.FieldType.V]
+    FIELD_TYPE_NAME_TO_FIELD_TYPE_MAP: dict[models.FieldTypeName, resources_pb2.FieldType.ValueType]
 else:
     FIELD_TYPE_NAME_TO_FIELD_TYPE_MAP: dict[models.FieldTypeName, int]
 
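Background for this one-line fix: mypy-protobuf originally exposed an enum's value type in its stubs as FieldType.V and later renamed that alias to FieldType.ValueType, so the annotation change keeps the TYPE_CHECKING branch aligned with current stubs; at runtime enum values remain plain ints, as the else branch says. A small sketch (the FILE member is an assumption about resources_pb2):

from nucliadb_protos import resources_pb2

field_type: "resources_pb2.FieldType.ValueType" = resources_pb2.FieldType.FILE
assert isinstance(field_type, int)  # protobuf enum values are ints at runtime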