Skip to content

Commit

Permalink
Merge pull request #11 from jina-ai/feat-when-condition
Browse files Browse the repository at this point in the history
feat: add filters
  • Loading branch information
numb3r3 authored May 6, 2022
2 parents d910e0e + 74a8a92 commit f1d61d4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
20 changes: 16 additions & 4 deletions executor/encoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__copyright__ = 'Copyright (c) 2022 Jina AI Limited. All rights reserved.'
__license__ = 'Apache-2.0'

from typing import Optional
from typing import Dict, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -61,6 +61,7 @@ def __init__(
input_shape: str = 'bnc',
device: str = 'cpu',
batch_size: int = 64,
filters: Optional[dict] = None,
**kwargs,
) -> None:
"""
Expand All @@ -72,6 +73,8 @@ def __init__(
:param input_shape: The shape of Input Point Cloud (b: batch, n: no of points, c: channels)
:param device: The device to use.
:param batch_size: The batch size to use.
:param filters: The filter condition that the documents need to fulfill before reaching the Executor.
The condition can be defined in the form of a `DocArray query condition <https://docarray.jina.ai/fundamentals/documentarray/find/#query-by-conditions>`
"""
super().__init__(**kwargs)

Expand Down Expand Up @@ -115,12 +118,21 @@ def __init__(

self._device = device
self._batch_size = batch_size
self._filters = filters

@requests
def encode(self, docs: DocumentArray, **_):
def encode(self, docs: 'DocumentArray', **_):
"""Encode docs."""
docs.apply(normalize)
docs.embed(
if docs is None:
return

if self._filters:
filtered_docs = docs.find(self._filters)
else:
filtered_docs = docs

filtered_docs.apply(normalize)
filtered_docs.embed(
self._model,
device=self._device,
batch_size=self._batch_size,
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,23 @@ def test_encoder(model_name):

assert docs[0].embedding is not None
assert docs[0].embedding.shape == (1024,)


def test_filter():
encoder = MeshDataEncoder(
pretrained_model=None,
default_model_name='pointconv',
filters={'embedding': {'$exists': False}},
)

docs = DocumentArray(Document(tensor=np.random.random((1024, 3))))

embedding = np.random.random((512,))
docs.append(Document(tensor=np.random.random((1024, 3)), embedding=embedding))

encoder.encode(docs)

assert docs[0].embedding.shape == (1024,)

assert docs[1].embedding is not None
assert docs[1].embedding.shape == (512,)

0 comments on commit f1d61d4

Please sign in to comment.