diff --git a/docarray/array/mixins/match.py b/docarray/array/mixins/match.py index 8a6efaf878f..03a9fd3fe08 100644 --- a/docarray/array/mixins/match.py +++ b/docarray/array/mixins/match.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Callable, Tuple, TYPE_CHECKING +from typing import Optional, Union, Callable, Tuple, TYPE_CHECKING, Dict if TYPE_CHECKING: import numpy as np @@ -20,6 +20,7 @@ def match( metric_name: Optional[str] = None, batch_size: Optional[int] = None, exclude_self: bool = False, + filter: Optional[Dict] = None, only_id: bool = False, use_scipy: bool = False, device: str = 'cpu', @@ -50,6 +51,7 @@ def match( elements. When `darray` is big, this can significantly speedup the computation. :param exclude_self: if set, Documents in ``darray`` with same ``id`` as the left-hand values will not be considered as matches. + :param filter: filter query used for pre-filtering :param only_id: if set, then returning matches will only contain ``id`` :param use_scipy: if set, use ``scipy`` as the computation backend. Note, ``scipy`` does not support distance on sparse matrix. @@ -76,6 +78,7 @@ def match( metric_name=metric_name, batch_size=batch_size, exclude_self=exclude_self, + filter=filter, only_id=only_id, use_scipy=use_scipy, device=device, diff --git a/docs/advanced/document-store/annlite.md b/docs/advanced/document-store/annlite.md index e98e83b20f2..e6ebdb72a40 100644 --- a/docs/advanced/document-store/annlite.md +++ b/docs/advanced/document-store/annlite.md @@ -49,7 +49,7 @@ The following configs can be set: *You can check the default values in [the AnnLite source code](https://github.com/jina-ai/annlite/blob/main/annlite/core/index/hnsw/index.py) - +(annlite-filter)= ## Vector search with filter Search with `.find` can be restricted by user-defined filters. diff --git a/docs/advanced/document-store/qdrant.md b/docs/advanced/document-store/qdrant.md index 02cfedb3542..575a79a3491 100644 --- a/docs/advanced/document-store/qdrant.md +++ b/docs/advanced/document-store/qdrant.md @@ -133,7 +133,7 @@ print(da.find(np.random.random(D), limit=10)) ``` - +(qdrant-filter)= ## Vector search with filter Search with `.find` can be restricted by user-defined filters. Such filters can be constructed following the guidelines diff --git a/docs/advanced/document-store/weaviate.md b/docs/advanced/document-store/weaviate.md index a66169da677..26841d3620d 100644 --- a/docs/advanced/document-store/weaviate.md +++ b/docs/advanced/document-store/weaviate.md @@ -169,7 +169,7 @@ print(results[0].text) Persist Documents with Weaviate. ``` - +(weaviate-filter)= ## Vector search with filter Search with `.find` can be restricted by user-defined filters. Such filters can be constructed following the guidelines diff --git a/docs/fundamentals/documentarray/matching.md b/docs/fundamentals/documentarray/matching.md index b75d4969b7d..03e21f66769 100644 --- a/docs/fundamentals/documentarray/matching.md +++ b/docs/fundamentals/documentarray/matching.md @@ -155,6 +155,28 @@ Note that framework is auto-chosen based on the type of `.embeddings`. For examp By default `A.match(B)` will copy the top-K matched Documents from B to `A.matches`. When these matches are big, copying them can be time-consuming. In this case, one can leverage `.match(..., only_id=True)` to keep only {attr}`~docarray.Document.id`. +### Pre filtering + +Both `match` and `find` support pre-filtering by passing a `filter` argument to the method. + +Pre-filtering is an advanced approximate nearest neighbors feature that allows to efficiently retrieve the nearest vectors +that respect the filtering condition. + +In contrast, post-filtering in the naive approach where you first retrieve the +nearest neighbors and then discard all the candidates that do not respect the filter condition. + +````{admonition} Pre-filtering is not available for in-memory backend +:class: caution +By default a DocumentArray will use the in-memory backend which does not support pre-filtering +``` +```` + +You can find example on how to use the pre-filtering here: + +- {ref}`ANNLite ` +- {ref}`Weaviate ` +- {ref}`Qdrant ` + ### GPU support diff --git a/tests/unit/array/mixins/test_match.py b/tests/unit/array/mixins/test_match.py index 19d5faa909d..5408b8ab35e 100644 --- a/tests/unit/array/mixins/test_match.py +++ b/tests/unit/array/mixins/test_match.py @@ -10,7 +10,7 @@ from scipy.spatial.distance import cdist as scipy_cdist from docarray import Document, DocumentArray -from docarray.array.storage.weaviate import WeaviateConfig +import operator @pytest.fixture() @@ -577,3 +577,120 @@ def test_match_ensure_scores_unique(): for m in query.matches: assert m.scores['euclidean'].value >= previous_score previous_score = m.scores['euclidean'].value + + +numeric_operators_annlite = { + '$gte': operator.ge, + '$gt': operator.gt, + '$lte': operator.le, + '$lt': operator.lt, + '$eq': operator.eq, + '$neq': operator.ne, +} + +numeric_operators_weaviate = { + 'GreaterThanEqual': operator.ge, + 'GreaterThan': operator.gt, + 'LessThanEqual': operator.le, + 'LessThan': operator.lt, + 'Equal': operator.eq, + 'NotEqual': operator.ne, +} + + +numeric_operators_qdrant = { + 'gte': operator.ge, + 'gt': operator.gt, + 'lte': operator.le, + 'lt': operator.lt, + 'eq': operator.eq, + 'neq': operator.ne, +} + + +@pytest.mark.parametrize( + 'storage,filter_gen,numeric_operators,operator', + [ + *[ + tuple( + [ + 'weaviate', + lambda operator, threshold: { + 'path': ['price'], + 'operator': operator, + 'valueInt': threshold, + }, + numeric_operators_weaviate, + operator, + ] + ) + for operator in numeric_operators_weaviate.keys() + ], + *[ + tuple( + [ + 'qdrant', + lambda operator, threshold: { + 'must': [{'key': 'price', 'range': {operator: threshold}}] + }, + numeric_operators_qdrant, + operator, + ] + ) + for operator in ['gte', 'gt', 'lte', 'lt'] + ], + *[ + tuple( + [ + 'qdrant', + lambda operator, threshold: { + 'must': [{'key': 'price', 'value': {operator: threshold}}] + }, + numeric_operators_qdrant, + operator, + ] + ) + for operator in ['eq', 'neq'] + ], + *[ + tuple( + [ + 'annlite', + lambda operator, threshold: {'price': {operator: threshold}}, + numeric_operators_annlite, + operator, + ] + ) + for operator in numeric_operators_annlite.keys() + ], + ], +) +def test_match_pre_filtering( + storage, filter_gen, operator, numeric_operators, start_storage +): + n_dim = 128 + da = DocumentArray( + storage=storage, config={'n_dim': n_dim, 'columns': [('price', 'int')]} + ) + + da.extend( + [ + Document(id=f'r{i}', embedding=np.random.rand(n_dim), tags={'price': i}) + for i in range(50) + ] + ) + thresholds = [10, 20, 30] + + for threshold in thresholds: + + filter = filter_gen(operator, threshold) + + doc = Document(embedding=np.random.rand(n_dim)) + doc.match(da, filter=filter) + + assert all( + [ + numeric_operators[operator](r.tags['price'], threshold) + for r in doc.matches + ] + )