diff --git a/docarray/array/mixins/find.py b/docarray/array/mixins/find.py index 969fd57cba4..a79c056fe6d 100644 --- a/docarray/array/mixins/find.py +++ b/docarray/array/mixins/find.py @@ -87,22 +87,29 @@ def find(self: 'T', query: Dict, **kwargs) -> 'DocumentArray': def find( self: 'T', - query: Union['DocumentArray', 'Document', 'ArrayType', Dict, str, List[str]], + query: Union[ + 'DocumentArray', 'Document', 'ArrayType', Dict, str, List[str], None + ] = None, metric: Union[ str, Callable[['ArrayType', 'ArrayType'], 'np.ndarray'] ] = 'cosine', limit: Optional[Union[int, float]] = 20, metric_name: Optional[str] = None, exclude_self: bool = False, + filter: Optional[Dict] = None, only_id: bool = False, index: str = 'text', **kwargs, ) -> Union['DocumentArray', List['DocumentArray']]: """Returns matching Documents given an input query. If the query is a `DocumentArray`, `Document` or `ArrayType`, exhaustive or approximate nearest neighbor search - will be performed depending on whether the storage backend supports ANN. - If the query is a `dict` object, Documents will be filtered according to DocArray's query language and all - matching Documents that match the filter will be returned. + will be performed depending on whether the storage backend supports ANN. Furthermore, if filter is not None, + pre-filtering will be applied along with vector search. + If the query is a `dict` object or, query is None and filter is not None, Documents will be filtered and all + matching Documents that match the filter will be returned. In this case, query (if it's dict) or filter will be + used for filtering. The object must follow the backend-specific filter format if the backend supports filtering + or DocArray's query language format. In the latter case, filtering will be applied in the client side not the + backend side. If the query is a string or list of strings, a search by text will be performed if the backend supports indexing and searching text fields. If not, a `NotImplementedError` will be raised. @@ -112,6 +119,7 @@ def find( :param metric: the distance metric. :param exclude_self: if set, Documents in results with same ``id`` as the query values will not be considered as matches. This is only applied when the input query is Document or DocumentArray. + :param filter: filter query used for pre-filtering or filtering :param only_id: if set, then returning matches will only contain ``id`` :param index: if the query is a string, text search will be performed on the `index` field, otherwise, this parameter is ignored. By default, the Document `text` attribute will be used for search, @@ -125,21 +133,35 @@ def find( from ... import Document, DocumentArray if isinstance(query, dict): - return self._filter(query) - elif isinstance(query, (DocumentArray, Document)): - - if isinstance(query, Document): - query = DocumentArray(query) - - _query = query.embeddings + if filter is None: + return self._filter(query) + else: + raise ValueError( + 'filter and query cannot be both dict type, set only one for filtering' + ) + elif query is None: + if isinstance(filter, dict): + return self._filter(filter) + else: + raise ValueError('filter must be dict when query is None') elif isinstance(query, str) or ( isinstance(query, list) and isinstance(query[0], str) ): + if filter is not None: + raise ValueError('cannot use filter with text search') result = self._find_by_text(query, index=index, limit=limit, **kwargs) if isinstance(query, str): return result[0] else: return result + + # for all the rest, vector search will be performed + elif isinstance(query, (DocumentArray, Document)): + + if isinstance(query, Document): + query = DocumentArray(query) + + _query = query.embeddings else: _query = query @@ -169,6 +191,7 @@ def find( _result = self._find( _query, + filter=filter, **kwargs, ) @@ -227,7 +250,7 @@ def find( @abc.abstractmethod def _find( - self, query: 'ArrayType', limit: int, **kwargs + self, query: 'ArrayType', limit: int, filter: Optional[Dict] = None, **kwargs ) -> Tuple['np.ndarray', 'np.ndarray']: raise NotImplementedError diff --git a/docarray/array/storage/annlite/backend.py b/docarray/array/storage/annlite/backend.py index 82bee7e03cc..4523c057db0 100644 --- a/docarray/array/storage/annlite/backend.py +++ b/docarray/array/storage/annlite/backend.py @@ -5,6 +5,8 @@ Optional, TYPE_CHECKING, Iterable, + List, + Tuple, ) import numpy as np @@ -25,11 +27,14 @@ class AnnliteConfig: ef_construction: Optional[int] = None ef_search: Optional[int] = None max_connection: Optional[int] = None + columns: Optional[List[Tuple[str, str]]] = None class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend.""" + TYPE_MAP = {'str': 'TEXT', 'float': 'float', 'int': 'integer'} + def _map_embedding(self, embedding: 'ArrayType') -> 'ArrayType': if embedding is None: embedding = np.zeros(self.n_dim, dtype=np.float32) @@ -62,6 +67,15 @@ def _init_storage( self._config = config + if self._config.columns is None: + self._config.columns = [] + + for i in range(len(self._config.columns)): + self._config.columns[i] = ( + self._config.columns[i][0], + self._map_type(self._config.columns[i][1]), + ) + config = asdict(config) self.n_dim = config.pop('n_dim') diff --git a/docarray/array/storage/annlite/find.py b/docarray/array/storage/annlite/find.py index cf8824e7a0e..9084eaf35ff 100644 --- a/docarray/array/storage/annlite/find.py +++ b/docarray/array/storage/annlite/find.py @@ -3,6 +3,7 @@ Optional, TYPE_CHECKING, List, + Dict, ) if TYPE_CHECKING: @@ -16,6 +17,7 @@ def _find( query: 'np.ndarray', limit: Optional[Union[int, float]] = 20, only_id: bool = False, + filter: Optional[Dict] = None, **kwargs, ) -> List['DocumentArray']: """Returns approximate nearest neighbors given an input query. @@ -23,6 +25,7 @@ def _find( :param query: the query documents to search. :param limit: the number of results to get for each query document in search. :param only_id: if set, then returning matches will only contain ``id`` + :param filter: filter query used for pre-filtering :param kwargs: other kwargs. :return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. @@ -34,7 +37,7 @@ def _find( query = query.reshape(1, -1) _, match_docs = self._annlite._search_documents( - query, limit=limit, include_metadata=not only_id + query, limit=limit, filter=filter or {}, include_metadata=not only_id ) return match_docs diff --git a/docarray/array/storage/base/backend.py b/docarray/array/storage/base/backend.py index 1cde4ccf09d..14f32434aa1 100644 --- a/docarray/array/storage/base/backend.py +++ b/docarray/array/storage/base/backend.py @@ -7,6 +7,8 @@ class BaseBackendMixin(ABC): + TYPE_MAP: Dict + def _init_storage( self, _docs: Optional['DocumentArraySourceType'] = None, @@ -27,3 +29,6 @@ def _map_embedding(self, embedding: 'ArrayType') -> 'ArrayType': from ....math.ndarray import to_numpy_array return to_numpy_array(embedding) + + def _map_type(self, col_type: str) -> str: + return self.TYPE_MAP[col_type] diff --git a/docarray/array/storage/elastic/backend.py b/docarray/array/storage/elastic/backend.py index 819339c0b9d..dfc66ae010b 100644 --- a/docarray/array/storage/elastic/backend.py +++ b/docarray/array/storage/elastic/backend.py @@ -9,6 +9,7 @@ List, Iterable, Any, + Tuple, Mapping, ) @@ -41,6 +42,7 @@ class ElasticConfig: batch_size: int = 64 ef_construction: Optional[int] = None m: Optional[int] = None + columns: Optional[List[Tuple[str, str]]] = None class BackendMixin(BaseBackendMixin): diff --git a/docarray/array/storage/elastic/find.py b/docarray/array/storage/elastic/find.py index 8149135b448..95e04b7113b 100644 --- a/docarray/array/storage/elastic/find.py +++ b/docarray/array/storage/elastic/find.py @@ -4,6 +4,8 @@ Sequence, List, Union, + Optional, + Dict, ) import numpy as np @@ -99,16 +101,22 @@ def _find( self, query: 'ElasticArrayType', limit: int = 10, + filter: Optional[Dict] = None, **kwargs, ) -> List['DocumentArray']: """Returns approximate nearest neighbors given a batch of input queries. :param query: input supported to be stored in Elastic. This includes any from the list '[np.ndarray, tensorflow.Tensor, torch.Tensor, Sequence[float]]' :param limit: number of retrieved items + :param filter: filter query used for pre-filtering :return: DocumentArray containing the closest documents to the query if it is a single query, otherwise a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. """ + if filter is not None: + raise ValueError( + 'Filtered vector search is not supported for ElasticSearch backend' + ) query = np.array(query) num_rows, n_dim = ndarray.get_array_rows(query) if n_dim != 2: diff --git a/docarray/array/storage/memory/find.py b/docarray/array/storage/memory/find.py index a3f5c667da2..b66917ec649 100644 --- a/docarray/array/storage/memory/find.py +++ b/docarray/array/storage/memory/find.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Tuple, Callable, TYPE_CHECKING +from typing import Optional, Union, Tuple, Callable, TYPE_CHECKING, Dict import numpy as np @@ -27,6 +27,7 @@ def _find( use_scipy: bool = False, device: str = 'cpu', num_worker: Optional[int] = 1, + filter: Optional[Dict] = None, **kwargs, ) -> Tuple['np.ndarray', 'np.ndarray']: """Returns approximate nearest neighbors given a batch of input queries. @@ -47,11 +48,15 @@ def _find( .. note:: This argument is only effective when ``batch_size`` is set. - + :param filter: filter query used for pre-filtering :param kwargs: other kwargs. :return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. """ + if filter is not None: + raise ValueError( + 'Filtered vector search is not supported for In-Memory backend' + ) if batch_size is not None: if batch_size <= 0: diff --git a/docarray/array/storage/qdrant/backend.py b/docarray/array/storage/qdrant/backend.py index 6cf62357d5f..4c78915a707 100644 --- a/docarray/array/storage/qdrant/backend.py +++ b/docarray/array/storage/qdrant/backend.py @@ -7,6 +7,7 @@ Dict, Iterable, List, + Tuple, ) import numpy as np @@ -41,6 +42,7 @@ class QdrantConfig: ef_construct: Optional[int] = None full_scan_threshold: Optional[int] = None m: Optional[int] = None + columns: Optional[List[Tuple[str, str]]] = None class BackendMixin(BaseBackendMixin): @@ -86,6 +88,9 @@ def _init_storage( self._config = config self._persist = bool(self._config.collection_name) + if self._config.columns is None: + self._config.columns = [] + self._config.collection_name = ( self.__class__.__name__ + random_identity() if self._config.collection_name is None diff --git a/docarray/array/storage/qdrant/find.py b/docarray/array/storage/qdrant/find.py index e1c7c7b71b8..52f5b0cc11e 100644 --- a/docarray/array/storage/qdrant/find.py +++ b/docarray/array/storage/qdrant/find.py @@ -4,6 +4,8 @@ TypeVar, Sequence, List, + Dict, + Optional, ) from qdrant_client.http.models.models import Distance @@ -48,13 +50,15 @@ def serialize_config(self) -> dict: def distance(self) -> 'Distance': raise NotImplementedError() - def _find_similar_vectors(self, q: 'QdrantArrayType', limit=10): + def _find_similar_vectors( + self, q: 'QdrantArrayType', limit: int = 10, filter: Optional[Dict] = None + ): query_vector = self._map_embedding(q) search_result = self.client.search( self.collection_name, query_vector=query_vector, - query_filter=None, + query_filter=filter, search_params=None, top=limit, append_payload=['_serialized'], @@ -74,11 +78,16 @@ def _find_similar_vectors(self, q: 'QdrantArrayType', limit=10): return DocumentArray(docs) def _find( - self, query: 'QdrantArrayType', limit: int = 10, **kwargs + self, + query: 'QdrantArrayType', + limit: int = 10, + filter: Optional[Dict] = None, + **kwargs, ) -> List['DocumentArray']: """Returns approximate nearest neighbors given a batch of input queries. :param query: input supported to be used in Qdrant. :param limit: number of retrieved items + :param filter: filter query used for pre-filtering :return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. @@ -87,10 +96,10 @@ def _find( num_rows, _ = ndarray.get_array_rows(query) if num_rows == 1: - return [self._find_similar_vectors(query, limit=limit)] + return [self._find_similar_vectors(query, limit=limit, filter=filter)] else: closest_docs = [] for q in query: - da = self._find_similar_vectors(q, limit=limit) + da = self._find_similar_vectors(q, limit=limit, filter=filter) closest_docs.append(da) return closest_docs diff --git a/docarray/array/storage/qdrant/getsetdel.py b/docarray/array/storage/qdrant/getsetdel.py index 62f6bf7b293..fdc2f1b1069 100644 --- a/docarray/array/storage/qdrant/getsetdel.py +++ b/docarray/array/storage/qdrant/getsetdel.py @@ -65,9 +65,13 @@ def _qdrant_to_document(self, qdrant_record: dict) -> 'Document': ) def _document_to_qdrant(self, doc: 'Document') -> 'PointStruct': + extra_columns = {col: doc.tags.get(col) for col, _ in self._config.columns} + return PointStruct( id=self._map_id(doc.id), - payload=dict(_serialized=doc.to_base64(**self.serialization_config)), + payload=dict( + _serialized=doc.to_base64(**self.serialization_config), **extra_columns + ), vector=self._map_embedding(doc.embedding), ) diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 8574e2053d2..555ce5fc892 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -8,6 +8,8 @@ Optional, TYPE_CHECKING, Union, + List, + Tuple, ) from .helper import initialize_table diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index afbaf2795cf..c693c1cecd9 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -45,12 +45,15 @@ class WeaviateConfig: flat_search_cutoff: Optional[int] = None cleanup_interval_seconds: Optional[int] = None skip: Optional[bool] = None + columns: Optional[List[Tuple[str, str]]] = None distance: Optional[str] = None class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend.""" + TYPE_MAP = {'str': 'string', 'float': 'number', 'int': 'int'} + def _init_storage( self, _docs: Optional['DocumentArraySourceType'] = None, @@ -87,6 +90,9 @@ def _init_storage( ) self._config = config + if self._config.columns is None: + self._config.columns = [] + self._schemas = self._load_or_create_weaviate_schema() _REGISTRY[self.__class__.__name__][self._class_name].append(self) @@ -139,7 +145,7 @@ def _get_schema_by_name(self, cls_name: str) -> Dict: 'distance': self._config.distance, } - return { + base_classes = { 'classes': [ { 'class': cls_name, @@ -167,6 +173,15 @@ def _get_schema_by_name(self, cls_name: str) -> Dict: }, ] } + for col, coltype in self._config.columns: + new_property = { + 'dataType': [self._map_type(coltype)], + 'name': col, + 'indexInverted': True, + } + base_classes['classes'][0]['properties'].append(new_property) + + return base_classes def _load_or_create_weaviate_schema(self): """Create a new weaviate schema for this :class:`DocumentArrayWeaviate` object @@ -295,8 +310,13 @@ def _doc2weaviate_create_payload(self, value: 'Document'): :param value: document to create a payload for :return: the payload dictionary """ + extra_columns = {col: value.tags.get(col) for col, _ in self._config.columns} + return dict( - data_object={'_serialized': value.to_base64(**self._serialize_config)}, + data_object={ + '_serialized': value.to_base64(**self._serialize_config), + **extra_columns, + }, class_name=self._class_name, uuid=self._map_id(value.id), vector=self._map_embedding(value.embedding), diff --git a/docarray/array/storage/weaviate/find.py b/docarray/array/storage/weaviate/find.py index 3cdaa435bfe..3cf73477266 100644 --- a/docarray/array/storage/weaviate/find.py +++ b/docarray/array/storage/weaviate/find.py @@ -3,6 +3,8 @@ TypeVar, Sequence, List, + Dict, + Optional, ) import numpy as np @@ -27,26 +29,41 @@ class FindMixin: - def _find_similar_vectors(self, query: 'WeaviateArrayType', limit=10): + def _find_similar_vectors( + self, query: 'WeaviateArrayType', limit=10, filter: Optional[Dict] = None + ): query = to_numpy_array(query) is_all_zero = np.all(query == 0) if is_all_zero: query = query + EPSILON query_dict = {'vector': query} - results = ( - self._client.query.get( - self._class_name, - ['_serialized', '_additional {certainty}', '_additional {id}'], - ) + + query_builder = ( + self._client.query.get(self._class_name, '_serialized') + .with_additional(['id', 'certainty']) .with_limit(limit) .with_near_vector(query_dict) - .do() ) + + if filter: + query_builder = query_builder.with_where(filter) + + results = query_builder.do() + docs = [] + if 'errors' in results: + errors = '\n'.join(map(lambda error: error['message'], results['errors'])) + raise ValueError( + f'find failed, please check your filter query. Errors: \n{errors}' + ) + + found_results = ( + results.get('data', {}).get('Get', {}).get(self._class_name, []) or [] + ) # The serialized document is stored in results['data']['Get'][self._class_name] - for result in results.get('data', {}).get('Get', {}).get(self._class_name, []): + for result in found_results: doc = Document.from_base64(result['_serialized'], **self._serialize_config) certainty = result['_additional']['certainty'] @@ -64,11 +81,16 @@ def _find_similar_vectors(self, query: 'WeaviateArrayType', limit=10): return DocumentArray(docs) def _find( - self, query: 'WeaviateArrayType', limit: int = 10, **kwargs + self, + query: 'WeaviateArrayType', + limit: int = 10, + filter: Optional[Dict] = None, + **kwargs, ) -> List['DocumentArray']: """Returns approximate nearest neighbors given a batch of input queries. :param query: input supported to be stored in Weaviate. This includes any from the list '[np.ndarray, tensorflow.Tensor, torch.Tensor, Sequence[float]]' :param limit: number of retrieved items + :param filter: filter query used for pre-filtering :return: DocumentArray containing the closest documents to the query if it is a single query, otherwise a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. @@ -80,10 +102,10 @@ def _find( num_rows, _ = ndarray.get_array_rows(query) if num_rows == 1: - return [self._find_similar_vectors(query, limit=limit)] + return [self._find_similar_vectors(query, limit=limit, filter=filter)] else: closest_docs = [] for q in query: - da = self._find_similar_vectors(q, limit=limit) + da = self._find_similar_vectors(q, limit=limit, filter=filter) closest_docs.append(da) return closest_docs diff --git a/docs/advanced/document-store/annlite.md b/docs/advanced/document-store/annlite.md index f953dac1ab6..e98e83b20f2 100644 --- a/docs/advanced/document-store/annlite.md +++ b/docs/advanced/document-store/annlite.md @@ -48,3 +48,69 @@ The following configs can be set: | `max_connection` | The number of bi-directional links created for every new element during construction. | `None`, defaults to the default value in the AnnLite package* | *You can check the default values in [the AnnLite source code](https://github.com/jina-ai/annlite/blob/main/annlite/core/index/hnsw/index.py) + + +## Vector search with filter + +Search with `.find` can be restricted by user-defined filters. +Filters can be constructed following the guidelines provided in [the AnnLite source repository](https://github.com/jina-ai/annlite). + +### Example of `.find` with a filter + +Consider Documents with embeddings `[0,0,0]` up to ` [9,9,9]` where the document with embedding `[i,i,i]` +has as tag `price` with value `i`. We can create such example with the following code: + + +```python +from docarray import Document, DocumentArray +import numpy as np + +n_dim = 3 +metric = 'Euclidean' + +da = DocumentArray( + storage='annlite', + config={'n_dim': n_dim, 'columns': [('price', 'float')], 'metric': metric}, +) + +with da: + da.extend( + [ + Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) + for i in range(10) + ] + ) +``` + +Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that +prices must follow a filter. As an example, let's consider that retrieved documents must have `price` value lower +or equal than `max_price`. We can encode this information in annlite using `filter = {'price': {'$lte': max_price}}`. + +Then the search with the proposed filter can implemented and used with the following code: + +```python +max_price = 7 +n_limit = 4 + +np_query = np.ones(n_dim) * 8 +print(f'\nQuery vector: \t{np_query}') + +filter = {'price': {'$lte': max_price}} +results = da.find(np_query, filter=filter, limit=n_limit) +print('\nEmbeddings Nearest Neighbours with "price" at most 7:\n') +for embedding, price in zip(results.embeddings, results[:, 'tags__price']): + print(f'\tembedding={embedding},\t price={price}') +``` + +This would print: + +```bash +Query vector: [8. 8. 8.] + +Embeddings Nearest Neighbours with "price" at most 7: + + embedding=[7. 7. 7.], price=7 + embedding=[6. 6. 6.], price=6 + embedding=[5. 5. 5.], price=5 + embedding=[4. 4. 4.], price=4 + ``` diff --git a/docs/advanced/document-store/qdrant.md b/docs/advanced/document-store/qdrant.md index 5dc3987c912..02cfedb3542 100644 --- a/docs/advanced/document-store/qdrant.md +++ b/docs/advanced/document-store/qdrant.md @@ -132,3 +132,77 @@ print(da.find(np.random.random(D), limit=10)) ```bash ``` + + +## Vector search with filter + +Search with `.find` can be restricted by user-defined filters. Such filters can be constructed following the guidelines +in [Qdrant's Documentation](https://qdrant.tech/documentation/filtering/) + + +### Example of `.find` with a filter + + +Consider Documents with embeddings `[0,0,0]` up to ` [9,9,9]` where the document with embedding `[i,i,i]` +has as tag `price` with value `i`. We can create such example with the following code: + +```python +from docarray import Document, DocumentArray +import numpy as np + +n_dim = 3 +distance = 'euclidean' + +da = DocumentArray( + storage='qdrant', + config={'n_dim': n_dim, 'columns': [('price', 'float')], 'distance': distance}, +) + +print(f'\nDocumentArray distance: {distance}') + +with da: + da.extend( + [ + Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) + for i in range(10) + ] + ) + +print('\nIndexed Embeddings:\n') +for embedding, price in zip(da.embeddings, da[:, 'tags__price']): + print(f'\tembedding={embedding},\t price={price}') +``` + +Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that +prices must follow a filter. As an example, let's consider that retrieved documents must have `price` value lower +or equal than `max_price`. We can encode this information in annlite using `filter = {'price': {'$lte': max_price}}`. + +Then the search with the proposed filter can implemented and used with the following code: + +```python +max_price = 7 +n_limit = 4 + +np_query = np.ones(n_dim) * 8 +print(f'\nQuery vector: \t{np_query}') + +filter = {'must': [{'key': 'price', 'range': {'lte': max_price}}]} +results = da.find(np_query, filter=filter, limit=n_limit) + +print('\nEmbeddings Nearest Neighbours with "price" at most 7:\n') +for embedding, price in zip(results.embeddings, results[:, 'tags__price']): + print(f'\tembedding={embedding},\t price={price}') +``` + +This would print: + +``` +Query vector: [8. 8. 8.] + +Embeddings Nearest Neighbours with "price" at most 7: + + embedding=[7. 7. 7.], price=7 + embedding=[6. 6. 6.], price=6 + embedding=[5. 5. 5.], price=5 + embedding=[4. 4. 4.], price=4 +``` diff --git a/docs/advanced/document-store/weaviate.md b/docs/advanced/document-store/weaviate.md index 01b35ec5797..a66169da677 100644 --- a/docs/advanced/document-store/weaviate.md +++ b/docs/advanced/document-store/weaviate.md @@ -168,3 +168,71 @@ print(results[0].text) ```text Persist Documents with Weaviate. ``` + + +## Vector search with filter + +Search with `.find` can be restricted by user-defined filters. Such filters can be constructed following the guidelines +in [Weaviate's Documentation](https://weaviate.io/developers/weaviate/current/graphql-references/filters.html). + +### Example of `.find` with a filter + +Consider Documents with embeddings `[0,0,0]` up to ` [9,9,9]` where the document with embedding `[i,i,i]` +has as tag `price` with value `i`. We can create such example with the following code: + + +```python +from docarray import Document, DocumentArray +import numpy as np + +n_dim = 3 + +da = DocumentArray( + storage='weaviate', + config={'n_dim': n_dim, 'columns': [('price', 'int')], 'distance': 'l2-squared'}, +) + +with da: + da.extend( + [ + Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) + for i in range(10) + ] + ) + +print('\nIndexed Embeddings:\n') +for embedding, price in zip(da.embeddings, da[:, 'tags__price']): + print(f'\tembedding={embedding},\t price={price}') +``` + +Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that +prices must follow a filter. As an example, let's consider that retrieved documents must have `price` value lower +or equal than `max_price`. We can encode this information in weaviate using `filter = {'path': ['price'], 'operator': 'LowerThanEqual', 'valueInt': max_price}`. + +Then the search with the proposed filter can implemented and used with the following code: + +```python +max_price = 7 +n_limit = 4 + +np_query = np.ones(n_dim) * 8 +print(f'\nQuery vector: \t{np_query}') + +filter = {'path': ['price'], 'operator': 'LessThanEqual', 'valueInt': max_price} +results = da.find(np_query, filter=filter, limit=n_limit) + +print('\nEmbeddings Nearest Neighbours with "price" at most 7:\n') +for embedding, price in zip(results.embeddings, results[:, 'tags__price']): + print(f'\tembedding={embedding},\t price={price}') +``` + +This would print: + +```bash +Embeddings Nearest Neighbours with "price" at most 7: + + embedding=[7. 7. 7.], price=7 + embedding=[6. 6. 6.], price=6 + embedding=[5. 5. 5.], price=5 + embedding=[4. 4. 4.], price=4 + ``` diff --git a/tests/conftest.py b/tests/conftest.py index 09989e06156..22cbf372e20 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ def tmpfile(tmpdir): return tmpdir / tmpfile -@pytest.fixture(scope='session') +@pytest.fixture(scope='module') def start_storage(): os.system( f"docker-compose -f {compose_yml} --project-directory . up --build -d " diff --git a/tests/unit/array/mixins/test_exception.py b/tests/unit/array/mixins/test_exception.py index a7ad7fa1733..e611a563bd4 100644 --- a/tests/unit/array/mixins/test_exception.py +++ b/tests/unit/array/mixins/test_exception.py @@ -34,5 +34,5 @@ def test_embedding_ops_error(): with pytest.raises(ValueError, match='Did you forget to set'): db.find(da) da.embeddings = np.random.random([100, 256]) - with pytest.raises(ValueError, match='Did you forget to set'): + with pytest.raises(ValueError, match='filter must be dict when query is None'): da.find(None) diff --git a/tests/unit/array/mixins/test_filter.py b/tests/unit/array/mixins/test_filter.py index f95cedf8461..165a650a5cc 100644 --- a/tests/unit/array/mixins/test_filter.py +++ b/tests/unit/array/mixins/test_filter.py @@ -22,29 +22,35 @@ def test_empty_filter(docs): assert len(result) == 5 -def test_simple_filter(docs): - result = docs.find({'text': {'$eq': 'hello'}}) +@pytest.mark.parametrize('filter_api', [True, False]) +def test_simple_filter(docs, filter_api): + if filter_api: + method = lambda query: docs.find(filter=query) + else: + method = lambda query: docs.find(query) + + result = method({'text': {'$eq': 'hello'}}) assert len(result) == 1 assert result[0].text == 'hello' - result = docs.find({'tags__x': {'$gte': 0.5}}) + result = method({'tags__x': {'$gte': 0.5}}) assert len(result) == 1 assert result[0].tags['x'] == 0.8 - result = docs.find({'tags__name': {'$regex': '^h'}}) + result = method({'tags__name': {'$regex': '^h'}}) assert len(result) == 2 assert result[1].id == docs[1].id - result = docs.find({'text': {'$regex': '^h'}}) + result = method({'text': {'$regex': '^h'}}) assert len(result) == 1 assert result[0].id == docs[0].id - result = docs.find({'tags': {'$size': 2}}) + result = method({'tags': {'$size': 2}}) assert result[0].id == docs[2].id - result = docs.find({'text': {'$exists': True}}) + result = method({'text': {'$exists': True}}) assert len(result) == 2 - result = docs.find({'tensor': {'$exists': True}}) + result = method({'tensor': {'$exists': True}}) assert len(result) == 0 diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index f68dafb63b7..b77c2122b7c 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -1,8 +1,11 @@ +from itertools import product + import numpy as np import pytest from docarray import DocumentArray, Document from docarray.math import ndarray +import operator @pytest.mark.parametrize( @@ -192,3 +195,152 @@ def test_find_by_tag(storage, config, start_storage): assert isinstance(results, list) assert len(results) == 2 assert all([isinstance(result, DocumentArray) for result in results]) == True + + +numeric_operators_annlite = { + '$gte': operator.ge, + '$gt': operator.gt, + '$lte': operator.le, + '$lt': operator.lt, + '$eq': operator.eq, + '$neq': operator.ne, +} + +numeric_operators_weaviate = { + 'GreaterThanEqual': operator.ge, + 'GreaterThan': operator.gt, + 'LessThanEqual': operator.le, + 'LessThan': operator.lt, + 'Equal': operator.eq, + 'NotEqual': operator.ne, +} + + +numeric_operators_qdrant = { + 'gte': operator.ge, + 'gt': operator.gt, + 'lte': operator.le, + 'lt': operator.lt, + 'eq': operator.eq, + 'neq': operator.ne, +} + + +@pytest.mark.parametrize( + 'storage,filter_gen,numeric_operators,operator', + [ + *[ + tuple( + [ + 'weaviate', + lambda operator, threshold: { + 'path': ['price'], + 'operator': operator, + 'valueInt': threshold, + }, + numeric_operators_weaviate, + operator, + ] + ) + for operator in numeric_operators_weaviate.keys() + ], + *[ + tuple( + [ + 'qdrant', + lambda operator, threshold: { + 'must': [{'key': 'price', 'range': {operator: threshold}}] + }, + numeric_operators_qdrant, + operator, + ] + ) + for operator in ['gte', 'gt', 'lte', 'lt'] + ], + *[ + tuple( + [ + 'qdrant', + lambda operator, threshold: { + 'must': [{'key': 'price', 'value': {operator: threshold}}] + }, + numeric_operators_qdrant, + operator, + ] + ) + for operator in ['eq', 'neq'] + ], + *[ + tuple( + [ + 'annlite', + lambda operator, threshold: {'price': {operator: threshold}}, + numeric_operators_annlite, + operator, + ] + ) + for operator in numeric_operators_annlite.keys() + ], + ], +) +def test_search_pre_filtering( + storage, filter_gen, operator, numeric_operators, start_storage +): + n_dim = 128 + da = DocumentArray( + storage=storage, config={'n_dim': n_dim, 'columns': [('price', 'int')]} + ) + + da.extend( + [ + Document(id=f'r{i}', embedding=np.random.rand(n_dim), tags={'price': i}) + for i in range(50) + ] + ) + thresholds = [10, 20, 30] + + for threshold in thresholds: + + filter = filter_gen(operator, threshold) + + results = da.find(np.random.rand(n_dim), filter=filter) + + assert all( + [numeric_operators[operator](r.tags['price'], threshold) for r in results] + ) + + +def test_weaviate_filter_query(start_storage): + n_dim = 128 + da = DocumentArray( + storage='weaviate', config={'n_dim': n_dim, 'columns': [('price', 'int')]} + ) + + da.extend( + [ + Document(id=f'r{i}', embedding=np.random.rand(n_dim), tags={'price': i}) + for i in range(50) + ] + ) + + with pytest.raises(ValueError): + da.find(np.random.rand(n_dim), filter={'wrong': 'filter'}) + + +@pytest.mark.parametrize('storage', ['memory', 'elasticsearch']) +def test_unsupported_pre_filtering(storage, start_storage): + + n_dim = 128 + da = DocumentArray( + storage=storage, config={'n_dim': n_dim, 'columns': [('price', 'int')]} + ) + + da.extend( + [ + Document(id=f'r{i}', embedding=np.random.rand(n_dim), tags={'price': i}) + for i in range(50) + ] + ) + + with pytest.raises(ValueError): + da.find(np.random.rand(n_dim), filter={'price': {'$gte': 2}})