Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docarray/array/mixins/match.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, Callable, Tuple, TYPE_CHECKING
from typing import Optional, Union, Callable, Tuple, TYPE_CHECKING, Dict

if TYPE_CHECKING:
import numpy as np
Expand All @@ -20,6 +20,7 @@ def match(
metric_name: Optional[str] = None,
batch_size: Optional[int] = None,
exclude_self: bool = False,
filter: Optional[Dict] = None,
only_id: bool = False,
use_scipy: bool = False,
device: str = 'cpu',
Expand Down Expand Up @@ -50,6 +51,7 @@ def match(
elements. When `darray` is big, this can significantly speedup the computation.
:param exclude_self: if set, Documents in ``darray`` with same ``id`` as the left-hand values will not be
considered as matches.
:param filter: filter query used for pre-filtering
:param only_id: if set, then returning matches will only contain ``id``
:param use_scipy: if set, use ``scipy`` as the computation backend. Note, ``scipy`` does not support distance
on sparse matrix.
Expand All @@ -76,6 +78,7 @@ def match(
metric_name=metric_name,
batch_size=batch_size,
exclude_self=exclude_self,
filter=filter,
only_id=only_id,
use_scipy=use_scipy,
device=device,
Expand Down
2 changes: 1 addition & 1 deletion docs/advanced/document-store/annlite.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ The following configs can be set:

*You can check the default values in [the AnnLite source code](https://github.com/jina-ai/annlite/blob/main/annlite/core/index/hnsw/index.py)


(annlite-filter)=
## Vector search with filter

Search with `.find` can be restricted by user-defined filters.
Expand Down
2 changes: 1 addition & 1 deletion docs/advanced/document-store/qdrant.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ print(da.find(np.random.random(D), limit=10))
<DocumentArray (length=10) at 4917906896>
```


(qdrant-filter)=
## Vector search with filter

Search with `.find` can be restricted by user-defined filters. Such filters can be constructed following the guidelines
Expand Down
2 changes: 1 addition & 1 deletion docs/advanced/document-store/weaviate.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ print(results[0].text)
Persist Documents with Weaviate.
```


(weaviate-filter)=
## Vector search with filter

Search with `.find` can be restricted by user-defined filters. Such filters can be constructed following the guidelines
Expand Down
22 changes: 22 additions & 0 deletions docs/fundamentals/documentarray/matching.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,28 @@ Note that framework is auto-chosen based on the type of `.embeddings`. For examp

By default `A.match(B)` will copy the top-K matched Documents from B to `A.matches`. When these matches are big, copying them can be time-consuming. In this case, one can leverage `.match(..., only_id=True)` to keep only {attr}`~docarray.Document.id`.

### Pre-filtering

Both `match` and `find` support pre-filtering by passing a `filter` argument to the method.

Pre-filtering is an advanced approximate nearest neighbors feature that allows you to efficiently retrieve the nearest vectors
that satisfy the filtering condition.

In contrast, post-filtering is the naive approach in which you first retrieve the
nearest neighbors and then discard all the candidates that do not satisfy the filter condition.

````{admonition} Pre-filtering is not available for the in-memory backend
:class: caution
By default, a DocumentArray uses the in-memory backend, which does not support pre-filtering.
````

You can find examples of how to use pre-filtering here:

- {ref}`AnnLite <annlite-filter>`
- {ref}`Weaviate <weaviate-filter>`
- {ref}`Qdrant <qdrant-filter>`


### GPU support

Expand Down
119 changes: 118 additions & 1 deletion tests/unit/array/mixins/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from scipy.spatial.distance import cdist as scipy_cdist

from docarray import Document, DocumentArray
from docarray.array.storage.weaviate import WeaviateConfig
import operator


@pytest.fixture()
Expand Down Expand Up @@ -577,3 +577,120 @@ def test_match_ensure_scores_unique():
for m in query.matches:
assert m.scores['euclidean'].value >= previous_score
previous_score = m.scores['euclidean'].value


# Comparison callables shared by all operator tables below, always in the
# order: >=, >, <=, <, ==, !=.
_COMPARATORS = (
    operator.ge,
    operator.gt,
    operator.le,
    operator.lt,
    operator.eq,
    operator.ne,
)

# Operator keywords as spelled in AnnLite filter queries -> Python comparator.
numeric_operators_annlite = dict(
    zip(('$gte', '$gt', '$lte', '$lt', '$eq', '$neq'), _COMPARATORS)
)

# Operator keywords as spelled in Weaviate filter queries -> Python comparator.
numeric_operators_weaviate = dict(
    zip(
        (
            'GreaterThanEqual',
            'GreaterThan',
            'LessThanEqual',
            'LessThan',
            'Equal',
            'NotEqual',
        ),
        _COMPARATORS,
    )
)


# Operator keywords as spelled in Qdrant filter queries -> Python comparator.
numeric_operators_qdrant = dict(
    zip(('gte', 'gt', 'lte', 'lt', 'eq', 'neq'), _COMPARATORS)
)


@pytest.mark.parametrize(
    'storage,filter_gen,numeric_operators,operator',
    [
        # Weaviate: one case per operator keyword, filter expressed as a
        # path/operator/valueInt dict.
        *(
            (
                'weaviate',
                lambda op_name, threshold: {
                    'path': ['price'],
                    'operator': op_name,
                    'valueInt': threshold,
                },
                numeric_operators_weaviate,
                op_key,
            )
            for op_key in numeric_operators_weaviate
        ),
        # Qdrant: ordering comparisons are placed under the 'range' key.
        *(
            (
                'qdrant',
                lambda op_name, threshold: {
                    'must': [{'key': 'price', 'range': {op_name: threshold}}]
                },
                numeric_operators_qdrant,
                op_key,
            )
            for op_key in ('gte', 'gt', 'lte', 'lt')
        ),
        # Qdrant: (in)equality comparisons are placed under the 'value' key.
        *(
            (
                'qdrant',
                lambda op_name, threshold: {
                    'must': [{'key': 'price', 'value': {op_name: threshold}}]
                },
                numeric_operators_qdrant,
                op_key,
            )
            for op_key in ('eq', 'neq')
        ),
        # AnnLite: one case per operator keyword, filter keyed by field name.
        *(
            (
                'annlite',
                lambda op_name, threshold: {'price': {op_name: threshold}},
                numeric_operators_annlite,
                op_key,
            )
            for op_key in numeric_operators_annlite
        ),
    ],
)
def test_match_pre_filtering(
    storage, filter_gen, operator, numeric_operators, start_storage
):
    """Check that ``Document.match(..., filter=...)`` only returns matches
    that satisfy the filter.

    :param storage: name of the DocumentArray storage backend under test
    :param filter_gen: callable ``(operator, threshold) -> filter`` building
        the backend-specific filter query
    :param operator: backend-specific operator keyword; note that inside this
        test it shadows the ``operator`` module
    :param numeric_operators: mapping from operator keyword to the Python
        comparison callable used to verify the results
    :param start_storage: fixture defined outside this view (presumably
        starts the backend services -- confirm in conftest)
    """
    n_dim = 128
    index_config = {'n_dim': n_dim, 'columns': [('price', 'int')]}
    da = DocumentArray(storage=storage, config=index_config)

    # 50 Documents with random embeddings and ``price`` tags 0..49.
    da.extend(
        [
            Document(id=f'r{i}', embedding=np.random.rand(n_dim), tags={'price': i})
            for i in range(50)
        ]
    )

    for threshold in (10, 20, 30):
        query_filter = filter_gen(operator, threshold)

        query = Document(embedding=np.random.rand(n_dim))
        query.match(da, filter=query_filter)

        # Every returned match must fulfil the comparison that was requested.
        assert all(
            numeric_operators[operator](m.tags['price'], threshold)
            for m in query.matches
        )