Commit dee9428: WIP

bloodbare committed Jan 23, 2024
1 parent aacdfab commit dee9428
Showing 14 changed files with 166 additions and 149 deletions.
2 changes: 1 addition & 1 deletion nucliadb_dataset/nucliadb_dataset/dataset.py
@@ -136,7 +136,7 @@ def __init__(
self.task_definition = task_definition
self.sdk = sdk
self.streamer = Streamer(
- self.trainset, reader_headers=sdk.headers, base_url=sdk.base_url
+ self.trainset, reader_headers=sdk.headers, base_url=sdk.base_url, kbid=kbid
)

self._set_schema(self.task_definition.schema)
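For context, `NucliaDBDataset` now threads the knowledge box id down to its `Streamer`, so constructing a dataset looks roughly like this, a minimal sketch with placeholder values (the `NucliaDB` client construction details may differ):

```python
from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
from nucliadb_sdk.v2.sdk import NucliaDB

from nucliadb_dataset.dataset import NucliaDBDataset

trainset = TrainSet()
trainset.type = TaskType.FIELD_CLASSIFICATION

sdk = NucliaDB(url="http://localhost:8080/api")  # placeholder endpoint

# kbid is now passed explicitly instead of being derived from a client wrapper.
dataset = NucliaDBDataset(
    trainset=trainset,
    kbid="my-kb-uuid",  # placeholder knowledge box id
    sdk=sdk,
    base_path="/tmp/cache",
)
```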
6 changes: 4 additions & 2 deletions nucliadb_dataset/nucliadb_dataset/export.py
@@ -29,6 +29,7 @@ class NucliaDatasetsExport:
def __init__(
self,
sdk: NucliaDB,
+ kbid: str,
datasets_url: str,
trainset: TrainSet,
cache_path: str,
@@ -38,7 +39,7 @@ def __init__(
self.trainset = trainset
self.sdk = sdk
self.nucliadb_dataset = NucliaDBDataset(
- trainset=trainset, client=sdk, base_path=cache_path
+ trainset=trainset, kbid=kbid, sdk=sdk, base_path=cache_path
)
self.apikey = apikey

@@ -71,12 +72,13 @@ class FileSystemExport:
def __init__(
self,
sdk: NucliaDB,
+ kbid: str,
trainset: TrainSet,
store_path: str,
):
self.sdk = sdk
self.nucliadb_dataset = NucliaDBDataset(
- trainset=trainset, client=sdk, base_path=store_path
+ trainset=trainset, kbid=kbid, sdk=sdk, base_path=store_path
)

def export(self):
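Both exporters now require the knowledge box id up front. A minimal sketch of the new `FileSystemExport` calling convention, assuming a local NucliaDB endpoint and a known kbid (both placeholders):

```python
from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
from nucliadb_sdk.v2.sdk import NucliaDB

from nucliadb_dataset.export import FileSystemExport

sdk = NucliaDB(url="http://localhost:8080/api")  # placeholder endpoint

trainset = TrainSet()
trainset.type = TaskType.FIELD_CLASSIFICATION
trainset.batch_size = 2

# kbid is a new required argument alongside the sdk client.
exporter = FileSystemExport(
    sdk=sdk,
    kbid="my-kb-uuid",  # placeholder knowledge box id
    trainset=trainset,
    store_path="/tmp/dataset",
)
exporter.export()
```

`NucliaDatasetsExport` gains the same `kbid` argument alongside its existing `datasets_url` and `apikey` parameters.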
2 changes: 1 addition & 1 deletion nucliadb_dataset/nucliadb_dataset/mapping.py
@@ -35,7 +35,7 @@


def bytes_to_batch(klass: Any):
- def func(batch: bytes, *args) -> Any:
+ def func(batch: bytes, *args, **kwargs) -> Any:
pb = klass()
pb.ParseFromString(batch)
return pb
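The added `**kwargs` makes the decoder tolerant of extra keyword arguments that callers may now forward (for instance a `kbid`). A toy sketch, assuming `FieldClassificationBatch` is one of the batch messages in `dataset_pb2`:

```python
from nucliadb_protos.dataset_pb2 import FieldClassificationBatch

from nucliadb_dataset.mapping import bytes_to_batch

decode = bytes_to_batch(FieldClassificationBatch)

# An empty serialized batch stands in for real streamed bytes.
raw = FieldClassificationBatch().SerializeToString()

# Extra keyword arguments are now accepted (and ignored) by the decoder.
batch = decode(raw, kbid="my-kb-uuid")
assert isinstance(batch, FieldClassificationBatch)
```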
9 changes: 7 additions & 2 deletions nucliadb_dataset/nucliadb_dataset/streamer.py
@@ -35,11 +35,16 @@ class Streamer:
resp: Optional[requests.Response]

def __init__(
- self, trainset: TrainSet, reader_headers: Dict[str, str], base_url: str
+ self,
+ trainset: TrainSet,
+ reader_headers: Dict[str, str],
+ base_url: str,
+ kbid: str,
):
self.reader_headers = reader_headers
self.base_url = base_url
self.trainset = trainset
+ self.kbid = kbid
self.resp = None

@property
@@ -50,7 +55,7 @@ def initialize(self, partition_id: str):
self.stream_session = requests.Session()
self.stream_session.headers.update(self.reader_headers)
self.resp = self.stream_session.post(
f"{self.base_url}/trainset/{partition_id}",
f"{self.base_url}/v1/kb/{self.kbid}/trainset/{partition_id}",
data=self.trainset.SerializeToString(),
stream=True,
)
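The trainset endpoint is now scoped under `/v1/kb/{kbid}`, so the `Streamer` needs the kbid to build its URL. A rough sketch of the updated construction (header values, endpoint, and ids are placeholders):

```python
from nucliadb_protos.dataset_pb2 import TaskType, TrainSet

from nucliadb_dataset.streamer import Streamer

trainset = TrainSet()
trainset.type = TaskType.PARAGRAPH_CLASSIFICATION

streamer = Streamer(
    trainset,
    reader_headers={"X-NUCLIADB-ROLES": "READER"},  # placeholder auth header
    base_url="http://localhost:8080/api",  # placeholder endpoint
    kbid="my-kb-uuid",  # placeholder knowledge box id
)

# POSTs to {base_url}/v1/kb/{kbid}/trainset/{partition_id} and streams the response.
streamer.initialize("partition-0")
```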
17 changes: 15 additions & 2 deletions nucliadb_dataset/nucliadb_dataset/tests/fixtures.py
@@ -19,12 +19,11 @@

import re
import tempfile
- from typing import AsyncIterator, Optional
+ from typing import AsyncIterator, Iterator, Optional

import docker
from nucliadb_models.common import FieldID, UserClassification
from nucliadb_models.entities import CreateEntitiesGroupPayload, Entity
- from nucliadb_models.extracted import FieldMetadata
from nucliadb_models.labels import Label, LabelSet, LabelSetKind
from nucliadb_models.metadata import TokenSplit, UserFieldMetadata, UserMetadata
from nucliadb_models.resource import KnowledgeBoxObj
@@ -33,6 +32,7 @@
from nucliadb_models.writer import CreateResourcePayload
from nucliadb_sdk.v2.sdk import NucliaDB # type: ignore
import pytest
+ import grpc
from grpc import aio # type: ignore
from nucliadb_protos.writer_pb2_grpc import WriterStub

@@ -90,6 +90,8 @@ def upload_data_field_classification(sdk: NucliaDB, kb: KnowledgeBoxObj):
),
)

+ return kb


@pytest.fixture(scope="function")
def upload_data_paragraph_classification(sdk: NucliaDB, kb: KnowledgeBoxObj):
@@ -156,6 +158,7 @@ def upload_data_paragraph_classification(sdk: NucliaDB, kb: KnowledgeBoxObj):
),
),
)
+ return kb


@pytest.fixture(scope="function")
@@ -250,6 +253,8 @@ def upload_data_token_classification(sdk: NucliaDB, kb: KnowledgeBoxObj):
),
)

+ return kb


@pytest.fixture(scope="function")
def text_editors_kb(sdk: NucliaDB, kb: KnowledgeBoxObj):
@@ -321,3 +326,11 @@ async def ingest_stub(nucliadb) -> AsyncIterator[WriterStub]:
stub = WriterStub(channel) # type: ignore
yield stub
await channel.close(grace=True)


+ @pytest.fixture
+ def ingest_stub_sync(nucliadb) -> Iterator[WriterStub]:
+     channel = grpc.insecure_channel(f"{nucliadb.host}:{nucliadb.grpc}")
+     stub = WriterStub(channel)  # type: ignore
+     yield stub
+     channel.close()

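The new `ingest_stub_sync` fixture mirrors the async `ingest_stub` over a blocking channel, which spares purely synchronous tests from running an event loop. A hypothetical test using it might look like:

```python
from nucliadb_protos.writer_pb2_grpc import WriterStub


def test_writer_reachable(ingest_stub_sync: WriterStub):
    # The fixture yields a ready stub over a blocking gRPC channel and
    # closes the channel itself after the test finishes.
    assert ingest_stub_sync is not None
```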
This file was deleted.

@@ -21,17 +21,18 @@
import re
import tempfile
from uuid import uuid4
+ from nucliadb_models.resource import KnowledgeBoxObj
+ from nucliadb_sdk.v2.sdk import NucliaDB

import pyarrow as pa # type: ignore
import pytest
from nucliadb_protos.dataset_pb2 import TaskType, TrainSet

from nucliadb_dataset.export import FileSystemExport, NucliaDatasetsExport
- from nucliadb_sdk.knowledgebox import KnowledgeBox


def test_filesystem_export(
- knowledgebox: KnowledgeBox, upload_data_field_classification
+ sdk: NucliaDB, upload_data_field_classification: KnowledgeBoxObj
):
trainset = TrainSet()
trainset.type = TaskType.FIELD_CLASSIFICATION
@@ -40,7 +41,8 @@ def test_filesystem_export(

with tempfile.TemporaryDirectory() as tmpdirname:
exporter = FileSystemExport(
- knowledgebox.client,
+ sdk=sdk,
+ kbid=upload_data_field_classification.uuid,
trainset=trainset,
store_path=tmpdirname,
)
@@ -55,8 +57,8 @@

def test_datasets_export(
mocked_datasets_url: str,
- knowledgebox: KnowledgeBox,
- upload_data_field_classification,
+ sdk: NucliaDB,
+ upload_data_field_classification: KnowledgeBoxObj,
):
trainset = TrainSet()
trainset.type = TaskType.FIELD_CLASSIFICATION
Expand All @@ -65,7 +67,8 @@ def test_datasets_export(

with tempfile.TemporaryDirectory() as tmpdirname:
exporter = NucliaDatasetsExport(
- client=knowledgebox.client,
+ sdk=sdk,
+ kbid=upload_data_field_classification.uuid,
datasets_url=mocked_datasets_url,
trainset=trainset,
cache_path=tmpdirname,
Expand Down
@@ -17,6 +17,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

+ from nucliadb_models.common import UserClassification
+ from nucliadb_models.metadata import UserMetadata
+ from nucliadb_models.resource import KnowledgeBoxObj
+ from nucliadb_models.text import TextField
+ from nucliadb_models.utils import FieldIdString, SlugString
+ from nucliadb_models.writer import CreateResourcePayload
+ from nucliadb_sdk.v2.sdk import NucliaDB
import pyarrow as pa # type: ignore
from nucliadb_protos.dataset_pb2 import TaskType, TrainSet

@@ -26,7 +33,7 @@


def test_field_classification_with_labels(
- knowledgebox: KnowledgeBox, upload_data_field_classification
+ sdk: NucliaDB, upload_data_field_classification: KnowledgeBoxObj
):
trainset = TrainSet()
trainset.type = TaskType.FIELD_CLASSIFICATION
@@ -41,32 +48,56 @@ def test_field_classification_with_labels(
trainset.filter.ClearField("labels")
trainset.filter.labels.extend(labels) # type: ignore

- partitions = export_dataset(knowledgebox, trainset)
+ partitions = export_dataset(
+     sdk=sdk, trainset=trainset, kb=upload_data_field_classification
+ )
assert len(partitions) == 1

loaded_array = partitions[0]
assert len(loaded_array) == expected


- def test_datascientist(knowledgebox: KnowledgeBox, temp_folder):
-     knowledgebox.upload(
-         text="I'm Ramon",
-         labels=["labelset/positive"],
+ def test_datascientist(sdk: NucliaDB, temp_folder, kb: KnowledgeBoxObj):
+     sdk.create_resource(
+         kbid=kb.uuid,
+         content=CreateResourcePayload(
+             texts={FieldIdString("text"): TextField(body="I'm Ramon")},
+             usermetadata=UserMetadata(
+                 classifications=[
+                     UserClassification(labelset="labelset", label="positive"),
+                 ]
+             ),
+         ),
+     )

-     knowledgebox.upload(
-         text="I'm not Ramon",
-         labels=["labelset/negative"],
+     sdk.create_resource(
+         kbid=kb.uuid,
+         content=CreateResourcePayload(
+             texts={FieldIdString("text"): TextField(body="I'm not Ramon")},
+             usermetadata=UserMetadata(
+                 classifications=[
+                     UserClassification(labelset="labelset", label="negative"),
+                 ]
+             ),
+         ),
+     )

-     knowledgebox.upload(
-         text="I'm Aleix",
-         labels=["labelset/smart"],
+     sdk.create_resource(
+         kbid=kb.uuid,
+         content=CreateResourcePayload(
+             texts={FieldIdString("text"): TextField(body="I'm Aleix")},
+             usermetadata=UserMetadata(
+                 classifications=[
+                     UserClassification(labelset="labelset", label="smart"),
+                 ]
+             ),
+         ),
+     )

arrow_filenames = download_all_partitions(
task="FIELD_CLASSIFICATION",
- knowledgebox=knowledgebox,
+ sdk=sdk,
+ slug=kb.slug,
path=temp_folder,
labels=["labelset"],
)
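The net effect for callers of `download_all_partitions` is that the `KnowledgeBox` wrapper is gone: pass the SDK client plus the knowledge box slug instead. A sketch under those assumptions (the import path and endpoint are guesses):

```python
from nucliadb_sdk.v2.sdk import NucliaDB

from nucliadb_dataset import download_all_partitions  # assumed import path

sdk = NucliaDB(url="http://localhost:8080/api")  # placeholder endpoint

arrow_filenames = download_all_partitions(
    task="FIELD_CLASSIFICATION",
    sdk=sdk,
    slug="my-kb-slug",  # knowledge box slug, not the old wrapper object
    path="/tmp/partitions",
    labels=["labelset"],
)
print(arrow_filenames)
```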

1 comment on commit dee9428

@github-actions (Contributor) commented:


Benchmark

Benchmark suite: nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error
Current (dee9428): 12818.497675129342 iter/sec (stddev: 1.7101399130881047e-7)
Previous (c67870a): 12887.24555746259 iter/sec (stddev: 2.385970996903907e-7)
Ratio: 1.01

This comment was automatically generated by a workflow using github-action-benchmark.
