From 651ad8cbdef0eee3aef18c7ce38ac06742b95984 Mon Sep 17 00:00:00 2001 From: David Buchaca Date: Mon, 9 May 2022 14:54:18 +0200 Subject: [PATCH 01/37] feat: add annlite filter --- docarray/array/storage/annlite/backend.py | 3 +- docarray/array/storage/annlite/find.py | 5 +-- tests/unit/array/mixins/test_find.py | 37 +++++++++++++++++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/docarray/array/storage/annlite/backend.py b/docarray/array/storage/annlite/backend.py index 82bee7e03cc..fc450558bf4 100644 --- a/docarray/array/storage/annlite/backend.py +++ b/docarray/array/storage/annlite/backend.py @@ -4,7 +4,7 @@ Dict, Optional, TYPE_CHECKING, - Iterable, + Iterable, List, Tuple, ) import numpy as np @@ -25,6 +25,7 @@ class AnnliteConfig: ef_construction: Optional[int] = None ef_search: Optional[int] = None max_connection: Optional[int] = None + columns: Optional[List[Tuple[str, str]]] = None class BackendMixin(BaseBackendMixin): diff --git a/docarray/array/storage/annlite/find.py b/docarray/array/storage/annlite/find.py index cf8824e7a0e..55f072d72d1 100644 --- a/docarray/array/storage/annlite/find.py +++ b/docarray/array/storage/annlite/find.py @@ -2,7 +2,7 @@ Union, Optional, TYPE_CHECKING, - List, + List, Dict, ) if TYPE_CHECKING: @@ -16,6 +16,7 @@ def _find( query: 'np.ndarray', limit: Optional[Union[int, float]] = 20, only_id: bool = False, + filter: Dict = {}, **kwargs, ) -> List['DocumentArray']: """Returns approximate nearest neighbors given an input query. 
@@ -34,7 +35,7 @@ def _find( query = query.reshape(1, -1) _, match_docs = self._annlite._search_documents( - query, limit=limit, include_metadata=not only_id + query, limit=limit, filter=filter, include_metadata=not only_id ) return match_docs diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index f68dafb63b7..29ff44b04dc 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -3,6 +3,7 @@ from docarray import DocumentArray, Document from docarray.math import ndarray +import operator @pytest.mark.parametrize( @@ -192,3 +193,39 @@ def test_find_by_tag(storage, config, start_storage): assert isinstance(results, list) assert len(results) == 2 assert all([isinstance(result, DocumentArray) for result in results]) == True + + + +numeric_operators = { + '$gte': operator.ge, + '$gt': operator.gt, + '$lte': operator.le, + '$lt': operator.lt, + '$eq': operator.eq, + '$neq': operator.ne, +} + +@pytest.mark.parametrize('operator', list(numeric_operators.keys())) +def test_search_annlite_filter(tmpdir, operator): + + Nq = 5 + D = 128 + columns = [('price', float), ('category', str)] + da = DocumentArray(storage='annlite', config={'n_dim': D, 'columns': columns, 'data_path': str(tmpdir)}) + + X = np.random.random((Nq, D)).astype(np.float32) + query_da = DocumentArray([Document(embedding=X[i]) for i in range(Nq)]) + + thresholds = [20, 50, 100, 400] + + for threshold in thresholds: + da.find( + query_da, filter={'price': {operator: threshold}}, include_metadata=True + ) + for query in query_da: + assert all( + [ + numeric_operators[operator](m.tags['price'], threshold) + for m in query.matches + ] + ) \ No newline at end of file From 06be300faf2ab30b0116f38aefee708139a9779b Mon Sep 17 00:00:00 2001 From: David Buchaca Date: Mon, 9 May 2022 14:56:42 +0200 Subject: [PATCH 02/37] refactor: black style --- docarray/array/storage/annlite/backend.py | 4 +++- docarray/array/storage/annlite/find.py | 3 ++- 
tests/unit/array/mixins/test_find.py | 21 ++++++++++++--------- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/docarray/array/storage/annlite/backend.py b/docarray/array/storage/annlite/backend.py index fc450558bf4..0d5b7ebb98a 100644 --- a/docarray/array/storage/annlite/backend.py +++ b/docarray/array/storage/annlite/backend.py @@ -4,7 +4,9 @@ Dict, Optional, TYPE_CHECKING, - Iterable, List, Tuple, + Iterable, + List, + Tuple, ) import numpy as np diff --git a/docarray/array/storage/annlite/find.py b/docarray/array/storage/annlite/find.py index 55f072d72d1..49412203f13 100644 --- a/docarray/array/storage/annlite/find.py +++ b/docarray/array/storage/annlite/find.py @@ -2,7 +2,8 @@ Union, Optional, TYPE_CHECKING, - List, Dict, + List, + Dict, ) if TYPE_CHECKING: diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index 29ff44b04dc..8e6d34ef2be 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -195,23 +195,26 @@ def test_find_by_tag(storage, config, start_storage): assert all([isinstance(result, DocumentArray) for result in results]) == True - numeric_operators = { - '$gte': operator.ge, - '$gt': operator.gt, - '$lte': operator.le, - '$lt': operator.lt, - '$eq': operator.eq, - '$neq': operator.ne, + '$gte': operator.ge, + '$gt': operator.gt, + '$lte': operator.le, + '$lt': operator.lt, + '$eq': operator.eq, + '$neq': operator.ne, } + @pytest.mark.parametrize('operator', list(numeric_operators.keys())) def test_search_annlite_filter(tmpdir, operator): Nq = 5 D = 128 columns = [('price', float), ('category', str)] - da = DocumentArray(storage='annlite', config={'n_dim': D, 'columns': columns, 'data_path': str(tmpdir)}) + da = DocumentArray( + storage='annlite', + config={'n_dim': D, 'columns': columns, 'data_path': str(tmpdir)}, + ) X = np.random.random((Nq, D)).astype(np.float32) query_da = DocumentArray([Document(embedding=X[i]) for i in range(Nq)]) @@ -228,4 +231,4 @@ 
def test_search_annlite_filter(tmpdir, operator): numeric_operators[operator](m.tags['price'], threshold) for m in query.matches ] - ) \ No newline at end of file + ) From 2df4bc0d9772f490e17dc3c03ff9403c2942f5ff Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 10 May 2022 11:52:42 +0200 Subject: [PATCH 03/37] feat: add filter parameter to interface --- docarray/array/mixins/find.py | 24 +++++++++++++++--------- docarray/array/storage/annlite/find.py | 4 ++-- docarray/array/storage/elastic/find.py | 7 +++++++ docarray/array/storage/memory/find.py | 5 +++++ docarray/array/storage/qdrant/find.py | 8 +++++++- docarray/array/storage/weaviate/find.py | 8 +++++++- 6 files changed, 43 insertions(+), 13 deletions(-) diff --git a/docarray/array/mixins/find.py b/docarray/array/mixins/find.py index 969fd57cba4..c6286ae8f66 100644 --- a/docarray/array/mixins/find.py +++ b/docarray/array/mixins/find.py @@ -87,13 +87,16 @@ def find(self: 'T', query: Dict, **kwargs) -> 'DocumentArray': def find( self: 'T', - query: Union['DocumentArray', 'Document', 'ArrayType', Dict, str, List[str]], + query: Union[ + 'DocumentArray', 'Document', 'ArrayType', Dict, str, List[str], None + ] = None, metric: Union[ str, Callable[['ArrayType', 'ArrayType'], 'np.ndarray'] ] = 'cosine', limit: Optional[Union[int, float]] = 20, metric_name: Optional[str] = None, exclude_self: bool = False, + filter: Optional[Dict] = None, only_id: bool = False, index: str = 'text', **kwargs, @@ -124,14 +127,10 @@ def find( from ... 
import Document, DocumentArray - if isinstance(query, dict): + if isinstance(query, dict) and filter is None: return self._filter(query) - elif isinstance(query, (DocumentArray, Document)): - - if isinstance(query, Document): - query = DocumentArray(query) - - _query = query.embeddings + elif query is None and isinstance(filter, dict): + return self._filter(filter) elif isinstance(query, str) or ( isinstance(query, list) and isinstance(query[0], str) ): @@ -140,6 +139,12 @@ def find( return result[0] else: return result + elif isinstance(query, (DocumentArray, Document)): + + if isinstance(query, Document): + query = DocumentArray(query) + + _query = query.embeddings else: _query = query @@ -169,6 +174,7 @@ def find( _result = self._find( _query, + filter=filter, **kwargs, ) @@ -227,7 +233,7 @@ def find( @abc.abstractmethod def _find( - self, query: 'ArrayType', limit: int, **kwargs + self, query: 'ArrayType', limit: int, filter: Optional[Dict] = None, **kwargs ) -> Tuple['np.ndarray', 'np.ndarray']: raise NotImplementedError diff --git a/docarray/array/storage/annlite/find.py b/docarray/array/storage/annlite/find.py index 49412203f13..26576ab15c7 100644 --- a/docarray/array/storage/annlite/find.py +++ b/docarray/array/storage/annlite/find.py @@ -17,7 +17,7 @@ def _find( query: 'np.ndarray', limit: Optional[Union[int, float]] = 20, only_id: bool = False, - filter: Dict = {}, + filter: Optional[Dict] = None, **kwargs, ) -> List['DocumentArray']: """Returns approximate nearest neighbors given an input query. 
@@ -36,7 +36,7 @@ def _find( query = query.reshape(1, -1) _, match_docs = self._annlite._search_documents( - query, limit=limit, filter=filter, include_metadata=not only_id + query, limit=limit, filter=filter or {}, include_metadata=not only_id ) return match_docs diff --git a/docarray/array/storage/elastic/find.py b/docarray/array/storage/elastic/find.py index 8149135b448..0c920e5540f 100644 --- a/docarray/array/storage/elastic/find.py +++ b/docarray/array/storage/elastic/find.py @@ -4,6 +4,8 @@ Sequence, List, Union, + Optional, + Dict, ) import numpy as np @@ -99,6 +101,7 @@ def _find( self, query: 'ElasticArrayType', limit: int = 10, + filter: Optional[Dict] = None, **kwargs, ) -> List['DocumentArray']: """Returns approximate nearest neighbors given a batch of input queries. @@ -109,6 +112,10 @@ def _find( :return: DocumentArray containing the closest documents to the query if it is a single query, otherwise a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. """ + if filter is not None: + raise ValueError( + 'Filtered vector search is not supported for ElasticSearch backend' + ) query = np.array(query) num_rows, n_dim = ndarray.get_array_rows(query) if n_dim != 2: diff --git a/docarray/array/storage/memory/find.py b/docarray/array/storage/memory/find.py index a3f5c667da2..5ac1cc2a706 100644 --- a/docarray/array/storage/memory/find.py +++ b/docarray/array/storage/memory/find.py @@ -27,6 +27,7 @@ def _find( use_scipy: bool = False, device: str = 'cpu', num_worker: Optional[int] = 1, + filter: Optional[Dict] = None, **kwargs, ) -> Tuple['np.ndarray', 'np.ndarray']: """Returns approximate nearest neighbors given a batch of input queries. @@ -52,6 +53,10 @@ def _find( :return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. 
""" + if filter is not None: + raise ValueError( + 'Filtered vector search is not supported for In-Memory backend' + ) if batch_size is not None: if batch_size <= 0: diff --git a/docarray/array/storage/qdrant/find.py b/docarray/array/storage/qdrant/find.py index e1c7c7b71b8..db35e55a9c5 100644 --- a/docarray/array/storage/qdrant/find.py +++ b/docarray/array/storage/qdrant/find.py @@ -4,6 +4,8 @@ TypeVar, Sequence, List, + Dict, + Optional, ) from qdrant_client.http.models.models import Distance @@ -74,7 +76,11 @@ def _find_similar_vectors(self, q: 'QdrantArrayType', limit=10): return DocumentArray(docs) def _find( - self, query: 'QdrantArrayType', limit: int = 10, **kwargs + self, + query: 'QdrantArrayType', + limit: int = 10, + filter: Optional[Dict] = None, + **kwargs, ) -> List['DocumentArray']: """Returns approximate nearest neighbors given a batch of input queries. :param query: input supported to be used in Qdrant. diff --git a/docarray/array/storage/weaviate/find.py b/docarray/array/storage/weaviate/find.py index 3cdaa435bfe..610ae981f67 100644 --- a/docarray/array/storage/weaviate/find.py +++ b/docarray/array/storage/weaviate/find.py @@ -3,6 +3,8 @@ TypeVar, Sequence, List, + Dict, + Optional, ) import numpy as np @@ -64,7 +66,11 @@ def _find_similar_vectors(self, query: 'WeaviateArrayType', limit=10): return DocumentArray(docs) def _find( - self, query: 'WeaviateArrayType', limit: int = 10, **kwargs + self, + query: 'WeaviateArrayType', + limit: int = 10, + filter: Optional[Dict] = None, + **kwargs ) -> List['DocumentArray']: """Returns approximate nearest neighbors given a batch of input queries. :param query: input supported to be stored in Weaviate. 
This includes any from the list '[np.ndarray, tensorflow.Tensor, torch.Tensor, Sequence[float]]' From 4768f3eaeb4a4d8b5baecc433f14040bb90a1b37 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 10 May 2022 15:01:43 +0200 Subject: [PATCH 04/37] fix: linting --- docarray/array/storage/memory/find.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/storage/memory/find.py b/docarray/array/storage/memory/find.py index 5ac1cc2a706..45ad076aab1 100644 --- a/docarray/array/storage/memory/find.py +++ b/docarray/array/storage/memory/find.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Tuple, Callable, TYPE_CHECKING +from typing import Optional, Union, Tuple, Callable, TYPE_CHECKING, Dict import numpy as np From 035c8e9712ef82e9c0657c9b242c48771708a338 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Wed, 11 May 2022 16:21:27 +0200 Subject: [PATCH 05/37] feat: add columns to backend configs --- docarray/array/storage/elastic/backend.py | 1 + docarray/array/storage/qdrant/backend.py | 2 ++ docarray/array/storage/sqlite/backend.py | 3 +++ docarray/array/storage/weaviate/backend.py | 1 + tests/unit/array/mixins/test_find.py | 1 + 5 files changed, 8 insertions(+) diff --git a/docarray/array/storage/elastic/backend.py b/docarray/array/storage/elastic/backend.py index 6911e11f455..9e2800a6b64 100644 --- a/docarray/array/storage/elastic/backend.py +++ b/docarray/array/storage/elastic/backend.py @@ -38,6 +38,7 @@ class ElasticConfig: batch_size: int = 64 ef_construction: Optional[int] = None m: Optional[int] = None + columns: Optional[List[Tuple[str, str]]] = None class BackendMixin(BaseBackendMixin): diff --git a/docarray/array/storage/qdrant/backend.py b/docarray/array/storage/qdrant/backend.py index 6cf62357d5f..2bc3520b9b9 100644 --- a/docarray/array/storage/qdrant/backend.py +++ b/docarray/array/storage/qdrant/backend.py @@ -7,6 +7,7 @@ Dict, Iterable, List, + Tuple, ) import numpy as np @@ -41,6 +42,7 @@ class 
QdrantConfig: ef_construct: Optional[int] = None full_scan_threshold: Optional[int] = None m: Optional[int] = None + columns: Optional[List[Tuple[str, str]]] = None class BackendMixin(BaseBackendMixin): diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 8574e2053d2..8131af28054 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -8,6 +8,8 @@ Optional, TYPE_CHECKING, Union, + List, + Tuple, ) from .helper import initialize_table @@ -33,6 +35,7 @@ class SqliteConfig: conn_config: Dict = field(default_factory=dict) journal_mode: str = 'DELETE' synchronous: str = 'OFF' + columns: Optional[List[Tuple[str, str]]] = None class BackendMixin(BaseBackendMixin): diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index 59756f647f0..73ff79c2ffe 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -37,6 +37,7 @@ class WeaviateConfig: ef_construction: Optional[int] = None timeout_config: Optional[Tuple[int, int]] = None max_connections: Optional[int] = None + columns: Optional[List[Tuple[str, str]]] = None class BackendMixin(BaseBackendMixin): diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index 8e6d34ef2be..065354a4c28 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -226,6 +226,7 @@ def test_search_annlite_filter(tmpdir, operator): query_da, filter={'price': {operator: threshold}}, include_metadata=True ) for query in query_da: + assert len(query.matches) > 0 assert all( [ numeric_operators[operator](m.tags['price'], threshold) From 30741b8422d57039a7f7b00d2e8ed2f3a880a1a0 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Wed, 11 May 2022 16:41:08 +0200 Subject: [PATCH 06/37] fix: missing type hint --- docarray/array/storage/elastic/backend.py | 1 + 1 file changed, 1 
insertion(+) diff --git a/docarray/array/storage/elastic/backend.py b/docarray/array/storage/elastic/backend.py index 9e2800a6b64..9271ceaec8b 100644 --- a/docarray/array/storage/elastic/backend.py +++ b/docarray/array/storage/elastic/backend.py @@ -9,6 +9,7 @@ List, Iterable, Any, + Tuple, ) import numpy as np From 1f637038a6b2a4260e11f010f36a02a27328b2a5 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Mon, 16 May 2022 19:05:13 +0200 Subject: [PATCH 07/37] feat: add column to weaviate schema --- docarray/array/storage/weaviate/backend.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index 73ff79c2ffe..99e9d9961f7 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -123,7 +123,7 @@ def _get_schema_by_name(self, cls_name: str) -> Dict: 'maxConnections': self._config.max_connections, } - return { + base_classes = { 'classes': [ { 'class': cls_name, @@ -151,6 +151,22 @@ def _get_schema_by_name(self, cls_name: str) -> Dict: }, ] } + for col, coltype in self._config.columns: + new_class = { + 'class': cls_name + col[0].upper() + col[1:], + "vectorizer": "none", + 'vectorIndexConfig': {'skip': False}, + 'properties': [ + { + 'dataType': [coltype], + 'name': cls_name + col, + 'indexInverted': False, + }, + ], + } + base_classes['classes'].append(new_class) + + return base_classes def _load_or_create_weaviate_schema(self): """Create a new weaviate schema for this :class:`DocumentArrayWeaviate` object From 232a4198cf7c072164ac6cb8858b8593b3aacab9 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 17 May 2022 16:50:23 +0200 Subject: [PATCH 08/37] fix: add col type mapping for weaviate --- docarray/array/storage/base/backend.py | 5 +++++ docarray/array/storage/weaviate/backend.py | 6 ++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git 
a/docarray/array/storage/base/backend.py b/docarray/array/storage/base/backend.py index 1cde4ccf09d..14f32434aa1 100644 --- a/docarray/array/storage/base/backend.py +++ b/docarray/array/storage/base/backend.py @@ -7,6 +7,8 @@ class BaseBackendMixin(ABC): + TYPE_MAP: Dict + def _init_storage( self, _docs: Optional['DocumentArraySourceType'] = None, @@ -27,3 +29,6 @@ def _map_embedding(self, embedding: 'ArrayType') -> 'ArrayType': from ....math.ndarray import to_numpy_array return to_numpy_array(embedding) + + def _map_type(self, col_type: str) -> str: + return self.TYPE_MAP[col_type] diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index 99e9d9961f7..49d084b7ae6 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -43,6 +43,8 @@ class WeaviateConfig: class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend.""" + TYPE_MAP = {'str': 'string', 'float': 'number', 'int': 'int'} + def _init_storage( self, _docs: Optional['DocumentArraySourceType'] = None, @@ -155,10 +157,10 @@ def _get_schema_by_name(self, cls_name: str) -> Dict: new_class = { 'class': cls_name + col[0].upper() + col[1:], "vectorizer": "none", - 'vectorIndexConfig': {'skip': False}, + 'vectorIndexConfig': {'skip': True}, 'properties': [ { - 'dataType': [coltype], + 'dataType': [self._map_type(coltype)], 'name': cls_name + col, 'indexInverted': False, }, From d33ddf237bcb271a7e02622de7311ef6c11cc2a6 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 18 May 2022 10:37:06 +0200 Subject: [PATCH 09/37] fix: default to empty list if None --- docarray/array/storage/weaviate/backend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index 49d084b7ae6..296f142be65 100644 --- a/docarray/array/storage/weaviate/backend.py +++ 
b/docarray/array/storage/weaviate/backend.py @@ -81,6 +81,9 @@ def _init_storage( ) self._config = config + if self._config.columns is None: + self._config.columns = [] + self._schemas = self._load_or_create_weaviate_schema() _REGISTRY[self.__class__.__name__][self._class_name].append(self) From c1a3fdb88dedb6357bfa883a19f5115e03deefa3 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 18 May 2022 11:00:50 +0200 Subject: [PATCH 10/37] feat: set attributes in _set_doc_by_id for weaviate --- docarray/array/storage/weaviate/backend.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index 38012ff57db..81028b7bfa4 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -315,8 +315,13 @@ def _doc2weaviate_create_payload(self, value: 'Document'): :param value: document to create a payload for :return: the payload dictionary """ + extra_columns = {col: value.tags.get(col) for col, _ in self._config.columns} + return dict( - data_object={'_serialized': value.to_base64(**self._serialize_config)}, + data_object={ + '_serialized': value.to_base64(**self._serialize_config), + **extra_columns, + }, class_name=self._class_name, uuid=self._map_id(value.id), vector=self._map_embedding(value.embedding), From e7df1a0165911bf98efcf78c5e5c4cde57565d26 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 18 May 2022 11:19:50 +0200 Subject: [PATCH 11/37] fix: columns are added as properties not classes --- docarray/array/storage/weaviate/backend.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index 81028b7bfa4..045bedf7f55 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -172,19 +172,12 @@ def _get_schema_by_name(self, 
cls_name: str) -> Dict: ] } for col, coltype in self._config.columns: - new_class = { - 'class': cls_name + col[0].upper() + col[1:], - "vectorizer": "none", - 'vectorIndexConfig': {'skip': True}, - 'properties': [ - { - 'dataType': [self._map_type(coltype)], - 'name': cls_name + col, - 'indexInverted': False, - }, - ], + new_property = { + 'dataType': [self._map_type(coltype)], + 'name': col, + 'indexInverted': False, } - base_classes['classes'].append(new_class) + base_classes['classes'][0]['properties'].append(new_property) return base_classes From d249573f865eabe35c9efa16c589bdb899d41687 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 18 May 2022 11:20:17 +0200 Subject: [PATCH 12/37] feat: use filter in _find for weaviate --- docarray/array/storage/weaviate/find.py | 36 +++++++++++++++++-------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/docarray/array/storage/weaviate/find.py b/docarray/array/storage/weaviate/find.py index 610ae981f67..f52a14c60e9 100644 --- a/docarray/array/storage/weaviate/find.py +++ b/docarray/array/storage/weaviate/find.py @@ -29,22 +29,36 @@ class FindMixin: - def _find_similar_vectors(self, query: 'WeaviateArrayType', limit=10): + def _find_similar_vectors( + self, query: 'WeaviateArrayType', limit=10, filter: Optional[Dict] = None + ): query = to_numpy_array(query) is_all_zero = np.all(query == 0) if is_all_zero: query = query + EPSILON query_dict = {'vector': query} - results = ( - self._client.query.get( - self._class_name, - ['_serialized', '_additional {certainty}', '_additional {id}'], + if filter: + results = ( + self._client.query.get( + self._class_name, + ['_serialized', '_additional {certainty}', '_additional {id}'], + ) + .with_where(filter) + .with_limit(limit) + .with_near_vector(query_dict) + .do() + ) + else: + results = ( + self._client.query.get( + self._class_name, + ['_serialized', '_additional {certainty}', '_additional {id}'], + ) + .with_limit(limit) + 
.with_near_vector(query_dict) + .do() ) - .with_limit(limit) - .with_near_vector(query_dict) - .do() - ) docs = [] # The serialized document is stored in results['data']['Get'][self._class_name] @@ -86,10 +100,10 @@ def _find( num_rows, _ = ndarray.get_array_rows(query) if num_rows == 1: - return [self._find_similar_vectors(query, limit=limit)] + return [self._find_similar_vectors(query, limit=limit, filter=filter)] else: closest_docs = [] for q in query: - da = self._find_similar_vectors(q, limit=limit) + da = self._find_similar_vectors(q, limit=limit, filter=filter) closest_docs.append(da) return closest_docs From 2b3d91a6e5000fdd9f3d57816e84a1d959030d09 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 18 May 2022 13:36:13 +0200 Subject: [PATCH 13/37] fix: set indexInverted to True to enable filtering for weaviate --- docarray/array/storage/weaviate/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index 045bedf7f55..e7e53baa7fe 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -175,7 +175,7 @@ def _get_schema_by_name(self, cls_name: str) -> Dict: new_property = { 'dataType': [self._map_type(coltype)], 'name': col, - 'indexInverted': False, + 'indexInverted': True, } base_classes['classes'][0]['properties'].append(new_property) From a16fe2a25651fd405864cd5eaa30961326ff6234 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Thu, 19 May 2022 12:04:51 +0200 Subject: [PATCH 14/37] fix: weaviate error handling message --- docarray/array/storage/weaviate/find.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/docarray/array/storage/weaviate/find.py b/docarray/array/storage/weaviate/find.py index f52a14c60e9..7babb941184 100644 --- a/docarray/array/storage/weaviate/find.py +++ b/docarray/array/storage/weaviate/find.py @@ -61,8 +61,17 @@ def 
_find_similar_vectors( ) docs = [] + found_results = ( + results.get('data', {}).get('Get', {}).get(self._class_name, []) or [] + ) + if 'errors' in results: + errors = '\n'.join(map(lambda error: error['message'], results['errors'])) + raise ValueError( + f'find failed, please check your filter query. Errors: \n{errors}' + ) + # The serialized document is stored in results['data']['Get'][self._class_name] - for result in results.get('data', {}).get('Get', {}).get(self._class_name, []): + for result in found_results: doc = Document.from_base64(result['_serialized'], **self._serialize_config) certainty = result['_additional']['certainty'] @@ -84,7 +93,7 @@ def _find( query: 'WeaviateArrayType', limit: int = 10, filter: Optional[Dict] = None, - **kwargs + **kwargs, ) -> List['DocumentArray']: """Returns approximate nearest neighbors given a batch of input queries. :param query: input supported to be stored in Weaviate. This includes any from the list '[np.ndarray, tensorflow.Tensor, torch.Tensor, Sequence[float]]' From 64b5532231398e7d21ec85c1809708299a92bbeb Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Thu, 19 May 2022 12:05:52 +0200 Subject: [PATCH 15/37] feat: pre filtering for qdrant --- docarray/array/storage/qdrant/backend.py | 3 +++ docarray/array/storage/qdrant/find.py | 10 ++++++---- docarray/array/storage/qdrant/getsetdel.py | 6 +++++- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/docarray/array/storage/qdrant/backend.py b/docarray/array/storage/qdrant/backend.py index 2bc3520b9b9..4c78915a707 100644 --- a/docarray/array/storage/qdrant/backend.py +++ b/docarray/array/storage/qdrant/backend.py @@ -88,6 +88,9 @@ def _init_storage( self._config = config self._persist = bool(self._config.collection_name) + if self._config.columns is None: + self._config.columns = [] + self._config.collection_name = ( self.__class__.__name__ + random_identity() if self._config.collection_name is None diff --git a/docarray/array/storage/qdrant/find.py 
b/docarray/array/storage/qdrant/find.py index db35e55a9c5..2119632d441 100644 --- a/docarray/array/storage/qdrant/find.py +++ b/docarray/array/storage/qdrant/find.py @@ -50,13 +50,15 @@ def serialize_config(self) -> dict: def distance(self) -> 'Distance': raise NotImplementedError() - def _find_similar_vectors(self, q: 'QdrantArrayType', limit=10): + def _find_similar_vectors( + self, q: 'QdrantArrayType', limit: int = 10, filter: Optional[Dict] = None + ): query_vector = self._map_embedding(q) search_result = self.client.search( self.collection_name, query_vector=query_vector, - query_filter=None, + query_filter=filter, search_params=None, top=limit, append_payload=['_serialized'], @@ -93,10 +95,10 @@ def _find( num_rows, _ = ndarray.get_array_rows(query) if num_rows == 1: - return [self._find_similar_vectors(query, limit=limit)] + return [self._find_similar_vectors(query, limit=limit, filter=filter)] else: closest_docs = [] for q in query: - da = self._find_similar_vectors(q, limit=limit) + da = self._find_similar_vectors(q, limit=limit, filter=filter) closest_docs.append(da) return closest_docs diff --git a/docarray/array/storage/qdrant/getsetdel.py b/docarray/array/storage/qdrant/getsetdel.py index 62f6bf7b293..fdc2f1b1069 100644 --- a/docarray/array/storage/qdrant/getsetdel.py +++ b/docarray/array/storage/qdrant/getsetdel.py @@ -65,9 +65,13 @@ def _qdrant_to_document(self, qdrant_record: dict) -> 'Document': ) def _document_to_qdrant(self, doc: 'Document') -> 'PointStruct': + extra_columns = {col: doc.tags.get(col) for col, _ in self._config.columns} + return PointStruct( id=self._map_id(doc.id), - payload=dict(_serialized=doc.to_base64(**self.serialization_config)), + payload=dict( + _serialized=doc.to_base64(**self.serialization_config), **extra_columns + ), vector=self._map_embedding(doc.embedding), ) From 79b7db75a5338dde5467b3605f1220ed4f5c5da4 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Thu, 19 May 2022 13:14:27 +0200 Subject: [PATCH 16/37] 
test: pre filtering in weaviate and qdrant --- tests/unit/array/mixins/test_find.py | 91 ++++++++++++++++++++++++++-- 1 file changed, 86 insertions(+), 5 deletions(-) diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index 065354a4c28..3c49ce5d424 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -195,7 +195,7 @@ def test_find_by_tag(storage, config, start_storage): assert all([isinstance(result, DocumentArray) for result in results]) == True -numeric_operators = { +numeric_operators_annlite = { '$gte': operator.ge, '$gt': operator.gt, '$lte': operator.le, @@ -205,8 +205,8 @@ def test_find_by_tag(storage, config, start_storage): } -@pytest.mark.parametrize('operator', list(numeric_operators.keys())) -def test_search_annlite_filter(tmpdir, operator): +@pytest.mark.parametrize('operator', list(numeric_operators_annlite.keys())) +def test_search_annlite_pre_filtering(tmpdir, operator): Nq = 5 D = 128 @@ -226,10 +226,91 @@ def test_search_annlite_filter(tmpdir, operator): query_da, filter={'price': {operator: threshold}}, include_metadata=True ) for query in query_da: - assert len(query.matches) > 0 assert all( [ - numeric_operators[operator](m.tags['price'], threshold) + numeric_operators_annlite[operator](m.tags['price'], threshold) for m in query.matches ] ) + + +numeric_operators_weaviate = { + 'GreaterThanEqual': operator.ge, + 'GreaterThan': operator.gt, + 'LessThanEqual': operator.le, + 'LessThan': operator.lt, + 'Equal': operator.eq, + 'NotEqual': operator.ne, +} + + +@pytest.mark.parametrize('operator', list(numeric_operators_weaviate.keys())) +def test_search_weaviate_pre_filtering(operator): + + n_dim = 128 + da = DocumentArray( + storage='weaviate', config={'n_dim': n_dim, 'columns': [('price', 'int')]} + ) + + da.extend( + [ + Document(id=f'r{i}', embedding=np.random.rand(n_dim), tags={'price': i}) + for i in range(50) + ] + ) + thresholds = [10, 20, 30] + + for threshold 
in thresholds: + results = da.find( + np.random.rand(n_dim), + filter={"path": ["price"], "operator": operator, "valueInt": threshold}, + ) + assert all( + [ + numeric_operators_weaviate[operator](r.tags['price'], threshold) + for r in results + ] + ) + + +numeric_operators_qrant = { + 'gte': operator.ge, + 'gt': operator.gt, + 'lte': operator.le, + 'lt': operator.lt, + 'eq': operator.eq, + 'neq': operator.ne, +} + + +@pytest.mark.parametrize('operator', list(numeric_operators_qrant.keys())) +def test_search_qdrant_pre_filtering(operator): + + n_dim = 128 + da = DocumentArray( + storage='qdrant', config={'n_dim': n_dim, 'columns': [('price', 'int')]} + ) + + da.extend( + [ + Document(id=f'r{i}', embedding=np.random.rand(n_dim), tags={'price': i}) + for i in range(50) + ] + ) + thresholds = [10, 20, 30] + + for threshold in thresholds: + + if operator in ('eq', 'neq'): + filter = {'key': 'price', 'value': {operator: threshold}} + else: + filter = {'key': 'price', 'range': {operator: threshold}} + + results = da.find(np.random.rand(n_dim), filter={'must': [filter]}) + + assert all( + [ + numeric_operators_qrant[operator](r.tags['price'], threshold) + for r in results + ] + ) From c3b5aea0df0b9c9abd5fcf588655c3f93c91f136 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 19 May 2022 15:31:26 +0200 Subject: [PATCH 17/37] refactor: unify pre-filtering tests of all backends --- tests/unit/array/mixins/test_find.py | 140 +++++++++++++-------------- 1 file changed, 67 insertions(+), 73 deletions(-) diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index 3c49ce5d424..afcc3502042 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -1,3 +1,5 @@ +from itertools import product + import numpy as np import pytest @@ -204,36 +206,6 @@ def test_find_by_tag(storage, config, start_storage): '$neq': operator.ne, } - -@pytest.mark.parametrize('operator', 
list(numeric_operators_annlite.keys())) -def test_search_annlite_pre_filtering(tmpdir, operator): - - Nq = 5 - D = 128 - columns = [('price', float), ('category', str)] - da = DocumentArray( - storage='annlite', - config={'n_dim': D, 'columns': columns, 'data_path': str(tmpdir)}, - ) - - X = np.random.random((Nq, D)).astype(np.float32) - query_da = DocumentArray([Document(embedding=X[i]) for i in range(Nq)]) - - thresholds = [20, 50, 100, 400] - - for threshold in thresholds: - da.find( - query_da, filter={'price': {operator: threshold}}, include_metadata=True - ) - for query in query_da: - assert all( - [ - numeric_operators_annlite[operator](m.tags['price'], threshold) - for m in query.matches - ] - ) - - numeric_operators_weaviate = { 'GreaterThanEqual': operator.ge, 'GreaterThan': operator.gt, @@ -244,36 +216,7 @@ def test_search_annlite_pre_filtering(tmpdir, operator): } -@pytest.mark.parametrize('operator', list(numeric_operators_weaviate.keys())) -def test_search_weaviate_pre_filtering(operator): - - n_dim = 128 - da = DocumentArray( - storage='weaviate', config={'n_dim': n_dim, 'columns': [('price', 'int')]} - ) - - da.extend( - [ - Document(id=f'r{i}', embedding=np.random.rand(n_dim), tags={'price': i}) - for i in range(50) - ] - ) - thresholds = [10, 20, 30] - - for threshold in thresholds: - results = da.find( - np.random.rand(n_dim), - filter={"path": ["price"], "operator": operator, "valueInt": threshold}, - ) - assert all( - [ - numeric_operators_weaviate[operator](r.tags['price'], threshold) - for r in results - ] - ) - - -numeric_operators_qrant = { +numeric_operators_qdrant = { 'gte': operator.ge, 'gt': operator.gt, 'lte': operator.le, @@ -283,12 +226,69 @@ def test_search_weaviate_pre_filtering(operator): } -@pytest.mark.parametrize('operator', list(numeric_operators_qrant.keys())) -def test_search_qdrant_pre_filtering(operator): - +@pytest.mark.parametrize( + 'storage,filter_gen,numeric_operators,operator', + [ + *[ + tuple( + [ + 'weaviate', + 
lambda operator, threshold: { + 'path': ['price'], + 'operator': operator, + 'valueInt': threshold, + }, + numeric_operators_weaviate, + operator, + ] + ) + for operator in numeric_operators_weaviate.keys() + ], + *[ + tuple( + [ + 'qdrant', + lambda operator, threshold: { + 'must': [{'key': 'price', 'range': {operator: threshold}}] + }, + numeric_operators_qdrant, + operator, + ] + ) + for operator in ['gte', 'gt', 'lte', 'lt'] + ], + *[ + tuple( + [ + 'qdrant', + lambda operator, threshold: { + 'must': [{'key': 'price', 'value': {operator: threshold}}] + }, + numeric_operators_qdrant, + operator, + ] + ) + for operator in ['eq', 'neq'] + ], + *[ + tuple( + [ + 'annlite', + lambda operator, threshold: {'price': {operator: threshold}}, + numeric_operators_annlite, + operator, + ] + ) + for operator in numeric_operators_annlite.keys() + ], + ], +) +def test_search_pre_filtering( + storage, filter_gen, operator, numeric_operators, start_storage +): n_dim = 128 da = DocumentArray( - storage='qdrant', config={'n_dim': n_dim, 'columns': [('price', 'int')]} + storage=storage, config={'n_dim': n_dim, 'columns': [('price', 'int')]} ) da.extend( @@ -301,16 +301,10 @@ def test_search_qdrant_pre_filtering(operator): for threshold in thresholds: - if operator in ('eq', 'neq'): - filter = {'key': 'price', 'value': {operator: threshold}} - else: - filter = {'key': 'price', 'range': {operator: threshold}} + filter = filter_gen(operator, threshold) - results = da.find(np.random.rand(n_dim), filter={'must': [filter]}) + results = da.find(np.random.rand(n_dim), filter=filter) assert all( - [ - numeric_operators_qrant[operator](r.tags['price'], threshold) - for r in results - ] + [numeric_operators[operator](r.tags['price'], threshold) for r in results] ) From a6d5d30ad79ba28db06fe1c343e23d19088a9c1f Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 19 May 2022 16:51:52 +0200 Subject: [PATCH 18/37] fix: map col types in annlite --- 
docarray/array/storage/annlite/backend.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docarray/array/storage/annlite/backend.py b/docarray/array/storage/annlite/backend.py index 0d5b7ebb98a..2dd6560a757 100644 --- a/docarray/array/storage/annlite/backend.py +++ b/docarray/array/storage/annlite/backend.py @@ -33,6 +33,8 @@ class AnnliteConfig: class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend.""" + TYPE_MAP = {'str': 'TEXT', 'float': 'float', 'int': 'integer'} + def _map_embedding(self, embedding: 'ArrayType') -> 'ArrayType': if embedding is None: embedding = np.zeros(self.n_dim, dtype=np.float32) @@ -63,6 +65,13 @@ def _init_storage( config.data_path = TemporaryDirectory().name + if config.columns: + for i in range(len(config.columns)): + config.columns[i] = ( + config.columns[i][0], + self._map_type(config.columns[i][1]), + ) + self._config = config config = asdict(config) From a548c9feddbe99081e4090e7d40aa921b91565ce Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 20 May 2022 11:00:34 +0200 Subject: [PATCH 19/37] fix: start_storage at module level --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 09989e06156..22cbf372e20 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ def tmpfile(tmpdir): return tmpdir / tmpfile -@pytest.fixture(scope='session') +@pytest.fixture(scope='module') def start_storage(): os.system( f"docker-compose -f {compose_yml} --project-directory . 
up --build -d " From aacf3fe6637582c3fc4a0749b60a934bb2256b07 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 20 May 2022 11:05:17 +0200 Subject: [PATCH 20/37] test: cover both API usages --- tests/unit/array/mixins/test_filter.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/unit/array/mixins/test_filter.py b/tests/unit/array/mixins/test_filter.py index f95cedf8461..165a650a5cc 100644 --- a/tests/unit/array/mixins/test_filter.py +++ b/tests/unit/array/mixins/test_filter.py @@ -22,29 +22,35 @@ def test_empty_filter(docs): assert len(result) == 5 -def test_simple_filter(docs): - result = docs.find({'text': {'$eq': 'hello'}}) +@pytest.mark.parametrize('filter_api', [True, False]) +def test_simple_filter(docs, filter_api): + if filter_api: + method = lambda query: docs.find(filter=query) + else: + method = lambda query: docs.find(query) + + result = method({'text': {'$eq': 'hello'}}) assert len(result) == 1 assert result[0].text == 'hello' - result = docs.find({'tags__x': {'$gte': 0.5}}) + result = method({'tags__x': {'$gte': 0.5}}) assert len(result) == 1 assert result[0].tags['x'] == 0.8 - result = docs.find({'tags__name': {'$regex': '^h'}}) + result = method({'tags__name': {'$regex': '^h'}}) assert len(result) == 2 assert result[1].id == docs[1].id - result = docs.find({'text': {'$regex': '^h'}}) + result = method({'text': {'$regex': '^h'}}) assert len(result) == 1 assert result[0].id == docs[0].id - result = docs.find({'tags': {'$size': 2}}) + result = method({'tags': {'$size': 2}}) assert result[0].id == docs[2].id - result = docs.find({'text': {'$exists': True}}) + result = method({'text': {'$exists': True}}) assert len(result) == 2 - result = docs.find({'tensor': {'$exists': True}}) + result = method({'tensor': {'$exists': True}}) assert len(result) == 0 From 1421ee7c9a149ac9c5ad6060145432a494a96118 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 20 May 2022 11:07:46 +0200 Subject: 
[PATCH 21/37] chore: address review --- docarray/array/storage/annlite/backend.py | 16 +++++++++------- docarray/array/storage/sqlite/backend.py | 1 - 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/docarray/array/storage/annlite/backend.py b/docarray/array/storage/annlite/backend.py index 2dd6560a757..4523c057db0 100644 --- a/docarray/array/storage/annlite/backend.py +++ b/docarray/array/storage/annlite/backend.py @@ -65,15 +65,17 @@ def _init_storage( config.data_path = TemporaryDirectory().name - if config.columns: - for i in range(len(config.columns)): - config.columns[i] = ( - config.columns[i][0], - self._map_type(config.columns[i][1]), - ) - self._config = config + if self._config.columns is None: + self._config.columns = [] + + for i in range(len(self._config.columns)): + self._config.columns[i] = ( + self._config.columns[i][0], + self._map_type(self._config.columns[i][1]), + ) + config = asdict(config) self.n_dim = config.pop('n_dim') diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 8131af28054..555ce5fc892 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -35,7 +35,6 @@ class SqliteConfig: conn_config: Dict = field(default_factory=dict) journal_mode: str = 'DELETE' synchronous: str = 'OFF' - columns: Optional[List[Tuple[str, str]]] = None class BackendMixin(BaseBackendMixin): From 1986f82480e826618629a10716c3445c465b84af Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 20 May 2022 11:08:51 +0200 Subject: [PATCH 22/37] chore: address review --- docarray/array/storage/weaviate/find.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docarray/array/storage/weaviate/find.py b/docarray/array/storage/weaviate/find.py index 7babb941184..75e426926b5 100644 --- a/docarray/array/storage/weaviate/find.py +++ b/docarray/array/storage/weaviate/find.py @@ -60,16 +60,16 @@ def _find_similar_vectors( .do() ) docs 
= [] - - found_results = ( - results.get('data', {}).get('Get', {}).get(self._class_name, []) or [] - ) if 'errors' in results: errors = '\n'.join(map(lambda error: error['message'], results['errors'])) raise ValueError( f'find failed, please check your filter query. Errors: \n{errors}' ) + found_results = ( + results.get('data', {}).get('Get', {}).get(self._class_name, []) or [] + ) + # The serialized document is stored in results['data']['Get'][self._class_name] for result in found_results: doc = Document.from_base64(result['_serialized'], **self._serialize_config) From a497a85f4f33d16bdb48b884af20f655314658c8 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 20 May 2022 11:29:53 +0200 Subject: [PATCH 23/37] fix: cryptographic random generator for weaviate classnames --- docarray/array/storage/weaviate/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index e7e53baa7fe..6ed777edbe1 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -118,7 +118,7 @@ def _get_weaviate_class_name(self) -> str: :return: string representing the name of weaviate class/schema name of this :class:`DocumentArrayWeaviate` object """ - return ''.join([i for i in uuid.uuid1().hex if not i.isdigit()]).capitalize() + return f'Class{uuid.uuid4().hex}' def _get_schema_by_name(self, cls_name: str) -> Dict: """Return the schema dictionary object with the class name From 200b83c05c55622538101fb8c4a9e5946cce2b15 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Fri, 20 May 2022 14:25:58 +0200 Subject: [PATCH 24/37] refactor: make filter weaviate more readable --- docarray/array/storage/weaviate/find.py | 32 ++++++++++--------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/docarray/array/storage/weaviate/find.py b/docarray/array/storage/weaviate/find.py index 75e426926b5..b282d0038b8 
100644 --- a/docarray/array/storage/weaviate/find.py +++ b/docarray/array/storage/weaviate/find.py @@ -38,27 +38,19 @@ def _find_similar_vectors( query = query + EPSILON query_dict = {'vector': query} + + query_builder = ( + self._client.query.get(self._class_name, '_serialized') + .with_additional(['id', 'certainty']) + .with_limit(limit) + .with_near_vector(query_dict) + ) + if filter: - results = ( - self._client.query.get( - self._class_name, - ['_serialized', '_additional {certainty}', '_additional {id}'], - ) - .with_where(filter) - .with_limit(limit) - .with_near_vector(query_dict) - .do() - ) - else: - results = ( - self._client.query.get( - self._class_name, - ['_serialized', '_additional {certainty}', '_additional {id}'], - ) - .with_limit(limit) - .with_near_vector(query_dict) - .do() - ) + query_builder = query_builder.with_where(filter) + + results = query_builder.do() + docs = [] if 'errors' in results: errors = '\n'.join(map(lambda error: error['message'], results['errors'])) From b4991df8637aac6810ab2e4eed0d9c7654d66f64 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Mon, 23 May 2022 16:33:34 +0200 Subject: [PATCH 25/37] docs: showcase pre-filtering in annlite --- docs/advanced/document-store/annlite.md | 76 +++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/docs/advanced/document-store/annlite.md b/docs/advanced/document-store/annlite.md index f953dac1ab6..c6c5f3c982f 100644 --- a/docs/advanced/document-store/annlite.md +++ b/docs/advanced/document-store/annlite.md @@ -48,3 +48,79 @@ The following configs can be set: | `max_connection` | The number of bi-directional links created for every new element during construction. 
| `None`, defaults to the default value in the AnnLite package* | *You can check the default values in [the AnnLite source code](https://github.com/jina-ai/annlite/blob/main/annlite/core/index/hnsw/index.py) + + +## Search with condition + +Search with `.find` can be restricted by user-defined filters. Such filters that can be constructed using the following operators: + + +| Name | Description | Equivalent Python operator | +|-------------------|------------------------|----------------------------| +| `$gte` | Greater or equal to | `>=` | +| `$gt` | Greater than | `>` | +| `$lte` | Less or equal to | `<=` | +| `$lt` | Less than ) | `<` | +| `$eq` | Equal to | `==` | +| `$neq` | Not equal to | `!=` | + +Such filters can be constructing following the guidelines provided in [the AnnLite source repository](https://github.com/jina-ai/annlite). + +### Example of `.find` with a filter + +Consider Documents with embeddings `[0,0,0]` up to ` [9,9,9]` where the document with embedding `[i,i,i]` +has as tag `price` with value `i`. We can create such example with the following code: + + +```python +from docarray import Document, DocumentArray +import numpy as np + +n_dim = 3 +metric = 'Euclidean' + +da = DocumentArray( + storage='annlite', + config={'n_dim': n_dim, 'columns': [('price', 'float')], 'metric': metric}, +) + +da.extend( + [ + Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) + for i in range(10) + ] +) +``` + +Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that +prices must follow a filter. As an example, let us consider that retrieved documents must have `price` value lower +or equal than `max_price`. We can encode this information in annlite using `filter = {'price': {'$lte': max_price}}`. 
+ +Then the search with the proposed filter can be implemented and used with the following code: + +```python +max_price = 7 +n_limit = 4 + +np_query = np.ones(n_dim) * 8 +print(f'\nQuery vector: \t{np_query}') + +filter = {'price': {'$lte': max_price}} +results = da.find(np_query, filter=filter, limit=n_limit) +print('\nEmbeddings Nearest Neighbours with "price" at most 7:\n') +for embedding, price in zip(results.embeddings, results[:, 'tags__price']): + print(f'\tembedding={embedding},\t price={price}') +``` + +This would print + +```bash +Query vector: [8. 8. 8.] + +Embeddings Nearest Neighbours with "price" at most 7: + + embedding=[7. 7. 7.], price=7 + embedding=[6. 6. 6.], price=6 + embedding=[5. 5. 5.], price=5 + embedding=[4. 4. 4.], price=4 + ``` From d0e18e860fcc693e27275e856b939e11da25d5cd Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 24 May 2022 10:34:37 +0200 Subject: [PATCH 26/37] docs: document filter parameter --- docarray/array/mixins/find.py | 11 ++++++++--- docarray/array/storage/annlite/find.py | 1 + docarray/array/storage/elastic/find.py | 1 + docarray/array/storage/memory/find.py | 2 +- docarray/array/storage/qdrant/find.py | 1 + docarray/array/storage/weaviate/find.py | 1 + 6 files changed, 13 insertions(+), 4 deletions(-) diff --git a/docarray/array/mixins/find.py b/docarray/array/mixins/find.py index c6286ae8f66..8b6a0e42c92 100644 --- a/docarray/array/mixins/find.py +++ b/docarray/array/mixins/find.py @@ -103,9 +103,13 @@ def find( ) -> Union['DocumentArray', List['DocumentArray']]: """Returns matching Documents given an input query. If the query is a `DocumentArray`, `Document` or `ArrayType`, exhaustive or approximate nearest neighbor search - will be performed depending on whether the storage backend supports ANN. - If the query is a `dict` object, Documents will be filtered according to DocArray's query language and all - matching Documents that match the filter will be returned.
+ will be performed depending on whether the storage backend supports ANN. Furthermore, if filter is not None, + pre-filtering will be applied along with vector search. + If the query is a `dict` object or, query is None and filter is not None, Documents will be filtered and all + matching Documents that match the filter will be returned. In this case, query (if it's dict) or filter will be + used for filtering. The object must follow the backend-specific filter format if the backend supports filtering + or DocArray's query language format. In the latter case, filtering will be applied in the client side not the + backend side. If the query is a string or list of strings, a search by text will be performed if the backend supports indexing and searching text fields. If not, a `NotImplementedError` will be raised. @@ -115,6 +119,7 @@ def find( :param metric: the distance metric. :param exclude_self: if set, Documents in results with same ``id`` as the query values will not be considered as matches. This is only applied when the input query is Document or DocumentArray. + :param filter: filter query used for pre-filtering or filtering :param only_id: if set, then returning matches will only contain ``id`` :param index: if the query is a string, text search will be performed on the `index` field, otherwise, this parameter is ignored. By default, the Document `text` attribute will be used for search, diff --git a/docarray/array/storage/annlite/find.py b/docarray/array/storage/annlite/find.py index 26576ab15c7..9084eaf35ff 100644 --- a/docarray/array/storage/annlite/find.py +++ b/docarray/array/storage/annlite/find.py @@ -25,6 +25,7 @@ def _find( :param query: the query documents to search. :param limit: the number of results to get for each query document in search. :param only_id: if set, then returning matches will only contain ``id`` + :param filter: filter query used for pre-filtering :param kwargs: other kwargs. 
:return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. diff --git a/docarray/array/storage/elastic/find.py b/docarray/array/storage/elastic/find.py index 0c920e5540f..95e04b7113b 100644 --- a/docarray/array/storage/elastic/find.py +++ b/docarray/array/storage/elastic/find.py @@ -108,6 +108,7 @@ def _find( :param query: input supported to be stored in Elastic. This includes any from the list '[np.ndarray, tensorflow.Tensor, torch.Tensor, Sequence[float]]' :param limit: number of retrieved items + :param filter: filter query used for pre-filtering :return: DocumentArray containing the closest documents to the query if it is a single query, otherwise a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. diff --git a/docarray/array/storage/memory/find.py b/docarray/array/storage/memory/find.py index 45ad076aab1..b66917ec649 100644 --- a/docarray/array/storage/memory/find.py +++ b/docarray/array/storage/memory/find.py @@ -48,7 +48,7 @@ def _find( .. note:: This argument is only effective when ``batch_size`` is set. - + :param filter: filter query used for pre-filtering :param kwargs: other kwargs. :return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. diff --git a/docarray/array/storage/qdrant/find.py b/docarray/array/storage/qdrant/find.py index 2119632d441..52f5b0cc11e 100644 --- a/docarray/array/storage/qdrant/find.py +++ b/docarray/array/storage/qdrant/find.py @@ -87,6 +87,7 @@ def _find( """Returns approximate nearest neighbors given a batch of input queries. :param query: input supported to be used in Qdrant. :param limit: number of retrieved items + :param filter: filter query used for pre-filtering :return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. 
diff --git a/docarray/array/storage/weaviate/find.py b/docarray/array/storage/weaviate/find.py index b282d0038b8..3cf73477266 100644 --- a/docarray/array/storage/weaviate/find.py +++ b/docarray/array/storage/weaviate/find.py @@ -90,6 +90,7 @@ def _find( """Returns approximate nearest neighbors given a batch of input queries. :param query: input supported to be stored in Weaviate. This includes any from the list '[np.ndarray, tensorflow.Tensor, torch.Tensor, Sequence[float]]' :param limit: number of retrieved items + :param filter: filter query used for pre-filtering :return: DocumentArray containing the closest documents to the query if it is a single query, otherwise a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. From 4d65d70c9c21778b9901b6bbcdc1f0be4d7a168a Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 24 May 2022 11:13:42 +0200 Subject: [PATCH 27/37] refactor: refactor find type checking --- docarray/array/mixins/find.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/docarray/array/mixins/find.py b/docarray/array/mixins/find.py index 8b6a0e42c92..a79c056fe6d 100644 --- a/docarray/array/mixins/find.py +++ b/docarray/array/mixins/find.py @@ -132,18 +132,30 @@ def find( from ... 
import Document, DocumentArray - if isinstance(query, dict) and filter is None: - return self._filter(query) - elif query is None and isinstance(filter, dict): - return self._filter(filter) + if isinstance(query, dict): + if filter is None: + return self._filter(query) + else: + raise ValueError( + 'filter and query cannot be both dict type, set only one for filtering' + ) + elif query is None: + if isinstance(filter, dict): + return self._filter(filter) + else: + raise ValueError('filter must be dict when query is None') elif isinstance(query, str) or ( isinstance(query, list) and isinstance(query[0], str) ): + if filter is not None: + raise ValueError('cannot use filter with text search') result = self._find_by_text(query, index=index, limit=limit, **kwargs) if isinstance(query, str): return result[0] else: return result + + # for all the rest, vector search will be performed elif isinstance(query, (DocumentArray, Document)): if isinstance(query, Document): From d6bbb5a787a0693e086a5d2525ee5bc6ba8e3699 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Tue, 24 May 2022 11:14:16 +0200 Subject: [PATCH 28/37] docs: add qdrant filtering documentation --- docs/advanced/document-store/annlite.md | 4 +- docs/advanced/document-store/qdrant.md | 83 +++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/docs/advanced/document-store/annlite.md b/docs/advanced/document-store/annlite.md index c6c5f3c982f..7c6be0c1315 100644 --- a/docs/advanced/document-store/annlite.md +++ b/docs/advanced/document-store/annlite.md @@ -50,7 +50,7 @@ The following configs can be set: *You can check the default values in [the AnnLite source code](https://github.com/jina-ai/annlite/blob/main/annlite/core/index/hnsw/index.py) -## Search with condition +## Search with filter Search with `.find` can be restricted by user-defined filters. 
Such filters that can be constructed using the following operators: @@ -64,7 +64,7 @@ Search with `.find` can be restricted by user-defined filters. Such filters that | `$eq` | Equal to | `==` | | `$neq` | Not equal to | `!=` | -Such filters can be constructing following the guidelines provided in [the AnnLite source repository](https://github.com/jina-ai/annlite). +filters can be constructing following the guidelines provided in [the AnnLite source repository](https://github.com/jina-ai/annlite). ### Example of `.find` with a filter diff --git a/docs/advanced/document-store/qdrant.md b/docs/advanced/document-store/qdrant.md index 5dc3987c912..1315ebb8f1e 100644 --- a/docs/advanced/document-store/qdrant.md +++ b/docs/advanced/document-store/qdrant.md @@ -132,3 +132,86 @@ print(da.find(np.random.random(D), limit=10)) ```bash ``` + + +## Search with filter + +Search with `.find` can be restricted by user-defined filters. Such filters that can be constructed using the following operators: + + +| Name | Description | Equivalent Python operator | +|-------------|------------------------|----------------------------| +| `gte` | Greater or equal to | `>=` | +| `gt` | Greater than | `>` | +| `lte` | Less or equal to | `<=` | +| `lt` | Less than ) | `<` | +| `eq` | Equal to | `==` | +| `neq` | Not equal to | `!=` | + + +### Example of `.find` with a filter + + +Consider Documents with embeddings `[0,0,0]` up to ` [9,9,9]` where the document with embedding `[i,i,i]` +has as tag `price` with value `i`. 
We can create such an example with the following code: + +```python +from docarray import Document, DocumentArray +import numpy as np + +n_dim = 3 +distance = 'euclidean' + +da = DocumentArray( + storage='qdrant', + config={'n_dim': n_dim, 'columns': [('price', 'float')], 'distance': distance}, +) + +print(f'\nDocumentArray distance: {distance}') + +da.extend( + [ + Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) + for i in range(10) + ] +) + +print('\nIndexed Embeddings:\n') +for embedding, price in zip(da.embeddings, da[:, 'tags__price']): + print(f'\tembedding={embedding},\t price={price}') +``` + +Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that +prices must follow a filter. As an example, let us consider that retrieved documents must have `price` value lower +than or equal to `max_price`. We can encode this information in Qdrant using `filter = {'must': [{'key': 'price', 'range': {'lte': max_price}}]}`. + +Then the search with the proposed filter can be implemented and used with the following code: + +```python +max_price = 7 +n_limit = 4 + +np_query = np.ones(n_dim) * 8 +print(f'\nQuery vector: \t{np_query}') + +filter = {'must': [{'key': 'price', 'range': {'lte': max_price}}]} +results = da.find(np_query, filter=filter, limit=n_limit) + +print('\nEmbeddings Nearest Neighbours with "price" at most 7:\n') +for embedding, price in zip(results.embeddings, results[:, 'tags__price']): + print(f'\tembedding={embedding},\t price={price}') +``` + +This would print + + +``` +Query vector: [8. 8. 8.] + +Embeddings Nearest Neighbours with "price" at most 7: + + embedding=[7. 7. 7.], price=7 + embedding=[6. 6. 6.], price=6 + embedding=[5. 5. 5.], price=5 + embedding=[4. 4.
4.], price=4 +``` From cbfaafea60394d391626ff17c03ddf7f391a48b5 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Tue, 24 May 2022 11:26:10 +0200 Subject: [PATCH 29/37] docs: add filter weaviate --- docs/advanced/document-store/qdrant.md | 1 - docs/advanced/document-store/weaviate.md | 89 ++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/docs/advanced/document-store/qdrant.md b/docs/advanced/document-store/qdrant.md index 1315ebb8f1e..69fb1dc90d2 100644 --- a/docs/advanced/document-store/qdrant.md +++ b/docs/advanced/document-store/qdrant.md @@ -204,7 +204,6 @@ for embedding, price in zip(results.embeddings, results[:, 'tags__price']): This would print - ``` Query vector: [8. 8. 8.] diff --git a/docs/advanced/document-store/weaviate.md b/docs/advanced/document-store/weaviate.md index a1ea3afd426..3b2a60c4a44 100644 --- a/docs/advanced/document-store/weaviate.md +++ b/docs/advanced/document-store/weaviate.md @@ -167,3 +167,92 @@ print(results[0].text) ```text Persist Documents with Weaviate. ``` + + +## Search with filter + +Search with `.find` can be restricted by user-defined filters. Such filters that can be constructed using the following operators: + +| Name | Description | Equivalent Python operator | +|-------------------|------------------------|----------------------------| +| `GreaterThanEqual`| Greater or equal to | `>=` | +| `GreaterThan` | Greater than | `>` | +| `LessThanEqual` | Less or equal to | `<=` | +| `LessThan` | Less than ) | `<` | +| `Equal` | Equal to | `==` | +| `NotEqual` | Not equal to | `!=` | + + + +### Example of `.find` with a filter + +Consider Documents with embeddings `[0,0,0]` up to ` [9,9,9]` where the document with embedding `[i,i,i]` +has as tag `price` with value `i`. 
We can create such an example with the following code: + + +```python +from docarray import Document, DocumentArray +import numpy as np + +n_dim = 3 + +da = DocumentArray( + storage='weaviate', + config={ + 'n_dim': n_dim, + 'columns': [('price', 'float')], + #'distance':distance + }, +) + +da.extend( + [ + Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) + for i in range(10) + ] +) + +print('\nIndexed Embeddings:\n') +for embedding, price in zip(da.embeddings, da[:, 'tags__price']): + print(f'\tembedding={embedding},\t price={price}') +``` + +Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that +prices must follow a filter. As an example, let us consider that retrieved documents must have `price` value lower +than or equal to `max_price`. We can encode this information in Weaviate using `filter = {'path': ['price'], 'operator': 'LessThanEqual', 'valueInt': max_price}`. + +Then the search with the proposed filter can be implemented and used with the following code: + +```python +max_price = 7 +n_limit = 4 + +np_query = np.ones(n_dim) * 8 +print(f'\nQuery vector: \t{np_query}') + +filter = {'path': ['price'], 'operator': 'LessThanEqual', 'valueInt': max_price} +results = da.find(np_query, filter=filter, limit=n_limit) + +print('\nEmbeddings Nearest Neighbours with "price" at most 7:\n') +for embedding, price in zip(results.embeddings, results[:, 'tags__price']): + print(f'\tembedding={embedding},\t price={price}') +``` + +This would print + +```bash +Embeddings Nearest Neighbours with "price" at most 7: + + embedding=[3. 3. 3.], price=3 + embedding=[6. 6. 6.], price=6 + embedding=[9. 9. 9.], price=9 + embedding=[1. 1. 1.], price=1 + +Embeddings Nearest Neighbours without restriction: + [[3. 3. 3.] + [6. 6. 6.] + [1. 1. 1.] + [2. 2. 2.]] + ``` + +Note that currently Weaviate only supports the cosine distance.
\ No newline at end of file From d8d6664bf8189ad69b8a33c1d6545256b07e0672 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Tue, 24 May 2022 11:28:33 +0200 Subject: [PATCH 30/37] docs: use context manager in examples --- docs/advanced/document-store/annlite.md | 13 +++++++------ docs/advanced/document-store/qdrant.md | 13 +++++++------ docs/advanced/document-store/weaviate.md | 13 +++++++------ 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/docs/advanced/document-store/annlite.md b/docs/advanced/document-store/annlite.md index 7c6be0c1315..cab8f34568b 100644 --- a/docs/advanced/document-store/annlite.md +++ b/docs/advanced/document-store/annlite.md @@ -84,12 +84,13 @@ da = DocumentArray( config={'n_dim': n_dim, 'columns': [('price', 'float')], 'metric': metric}, ) -da.extend( - [ - Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) - for i in range(10) - ] -) +with da: + da.extend( + [ + Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) + for i in range(10) + ] + ) ``` Consider we want the nearest vectors to the embedding `[8. 8. 
8.]`, with the restriction that diff --git a/docs/advanced/document-store/qdrant.md b/docs/advanced/document-store/qdrant.md index 69fb1dc90d2..6332afc5617 100644 --- a/docs/advanced/document-store/qdrant.md +++ b/docs/advanced/document-store/qdrant.md @@ -169,12 +169,13 @@ da = DocumentArray( print(f'\nDocumentArray distance: {distance}') -da.extend( - [ - Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) - for i in range(10) - ] -) +with da: + da.extend( + [ + Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) + for i in range(10) + ] + ) print('\nIndexed Embeddings:\n') for embedding, price in zip(da.embeddings, da[:, 'tags__price']): diff --git a/docs/advanced/document-store/weaviate.md b/docs/advanced/document-store/weaviate.md index 3b2a60c4a44..82a358a06c5 100644 --- a/docs/advanced/document-store/weaviate.md +++ b/docs/advanced/document-store/weaviate.md @@ -205,12 +205,13 @@ da = DocumentArray( }, ) -da.extend( - [ - Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) - for i in range(10) - ] -) +with da: + da.extend( + [ + Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) + for i in range(10) + ] + ) print('\nIndexed Embeddings:\n') for embedding, price in zip(da.embeddings, da[:, 'tags__price']): From b2936bdf2920099fe681d100a667f1d777eb4b89 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 24 May 2022 11:31:44 +0200 Subject: [PATCH 31/37] test: cover wrong filter format for weaviate --- tests/unit/array/mixins/test_find.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index afcc3502042..cf157584ef0 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -308,3 +308,20 @@ def test_search_pre_filtering( assert all( [numeric_operators[operator](r.tags['price'], threshold) for r in results] ) + + +def 
test_weaviate_filter_query(start_storage): + n_dim = 128 + da = DocumentArray( + storage='weaviate', config={'n_dim': n_dim, 'columns': [('price', 'int')]} + ) + + da.extend( + [ + Document(id=f'r{i}', embedding=np.random.rand(n_dim), tags={'price': i}) + for i in range(50) + ] + ) + + with pytest.raises(ValueError): + da.find(np.random.rand(n_dim), filter={'wrong': 'filter'}) From 92c637a869e7ed654d013a8c64a53007732f6f92 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 24 May 2022 11:55:05 +0200 Subject: [PATCH 32/37] test: cover unsupported pre-filtering for memory and elastic --- tests/unit/array/mixins/test_find.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index cf157584ef0..b77c2122b7c 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -325,3 +325,22 @@ def test_weaviate_filter_query(start_storage): with pytest.raises(ValueError): da.find(np.random.rand(n_dim), filter={'wrong': 'filter'}) + + +@pytest.mark.parametrize('storage', ['memory', 'elasticsearch']) +def test_unsupported_pre_filtering(storage, start_storage): + + n_dim = 128 + da = DocumentArray( + storage=storage, config={'n_dim': n_dim, 'columns': [('price', 'int')]} + ) + + da.extend( + [ + Document(id=f'r{i}', embedding=np.random.rand(n_dim), tags={'price': i}) + for i in range(50) + ] + ) + + with pytest.raises(ValueError): + da.find(np.random.rand(n_dim), filter={'price': {'$gte': 2}}) From d368a588372201fd600cdb5d71aa510bbe8479ef Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 24 May 2022 15:38:25 +0200 Subject: [PATCH 33/37] test: update test_embedding_ops_error --- tests/unit/array/mixins/test_exception.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/array/mixins/test_exception.py b/tests/unit/array/mixins/test_exception.py index a7ad7fa1733..e611a563bd4 100644 --- 
a/tests/unit/array/mixins/test_exception.py +++ b/tests/unit/array/mixins/test_exception.py @@ -34,5 +34,5 @@ def test_embedding_ops_error(): with pytest.raises(ValueError, match='Did you forget to set'): db.find(da) da.embeddings = np.random.random([100, 256]) - with pytest.raises(ValueError, match='Did you forget to set'): + with pytest.raises(ValueError, match='filter must be dict when query is None'): da.find(None) From 7cde83b60a6add36d31fbd5bc71bb53ee3dca4c1 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Tue, 24 May 2022 15:48:23 +0200 Subject: [PATCH 34/37] docs: add weaviate filter reference --- docs/advanced/document-store/weaviate.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/advanced/document-store/weaviate.md b/docs/advanced/document-store/weaviate.md index 82a358a06c5..12e507e8e01 100644 --- a/docs/advanced/document-store/weaviate.md +++ b/docs/advanced/document-store/weaviate.md @@ -182,7 +182,8 @@ Search with `.find` can be restricted by user-defined filters. Such filters that | `Equal` | Equal to | `==` | | `NotEqual` | Not equal to | `!=` | - + filters can be constructing following the guidelines provided in [the weaviate documentation](https://weaviate.io/developers/weaviate/current/graphql-references/filters.html). 
+ ### Example of `.find` with a filter From 3d4d13f98af432e4499fb6089f2802e35841d7c8 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Tue, 24 May 2022 16:03:35 +0200 Subject: [PATCH 35/37] docs: fix typos --- docs/advanced/document-store/annlite.md | 6 +++--- docs/advanced/document-store/qdrant.md | 4 ++-- docs/advanced/document-store/weaviate.md | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/advanced/document-store/annlite.md b/docs/advanced/document-store/annlite.md index cab8f34568b..5c743fbd0bf 100644 --- a/docs/advanced/document-store/annlite.md +++ b/docs/advanced/document-store/annlite.md @@ -50,7 +50,7 @@ The following configs can be set: *You can check the default values in [the AnnLite source code](https://github.com/jina-ai/annlite/blob/main/annlite/core/index/hnsw/index.py) -## Search with filter +## Vector search with filter Search with `.find` can be restricted by user-defined filters. Such filters that can be constructed using the following operators: @@ -64,7 +64,7 @@ Search with `.find` can be restricted by user-defined filters. Such filters that | `$eq` | Equal to | `==` | | `$neq` | Not equal to | `!=` | -filters can be constructing following the guidelines provided in [the AnnLite source repository](https://github.com/jina-ai/annlite). +Filters can be constructed following the guidelines provided in [the AnnLite source repository](https://github.com/jina-ai/annlite). ### Example of `.find` with a filter @@ -94,7 +94,7 @@ with da: ``` Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that -prices must follow a filter. As an example, let us consider that retrieved documents must have `price` value lower +prices must follow a filter. As an example, let's consider that retrieved documents must have `price` value lower or equal than `max_price`. We can encode this information in annlite using `filter = {'price': {'$lte': max_price}}`. 
Then the search with the proposed filter can implemented and used with the following code: diff --git a/docs/advanced/document-store/qdrant.md b/docs/advanced/document-store/qdrant.md index 6332afc5617..696f65d9dcf 100644 --- a/docs/advanced/document-store/qdrant.md +++ b/docs/advanced/document-store/qdrant.md @@ -134,7 +134,7 @@ print(da.find(np.random.random(D), limit=10)) ``` -## Search with filter +## Vector search with filter Search with `.find` can be restricted by user-defined filters. Such filters that can be constructed using the following operators: @@ -183,7 +183,7 @@ for embedding, price in zip(da.embeddings, da[:, 'tags__price']): ``` Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that -prices must follow a filter. As an example, let us consider that retrieved documents must have `price` value lower +prices must follow a filter. As an example, let's consider that retrieved documents must have `price` value lower or equal than `max_price`. We can encode this information in annlite using `filter = {'price': {'$lte': max_price}}`. Then the search with the proposed filter can implemented and used with the following code: diff --git a/docs/advanced/document-store/weaviate.md b/docs/advanced/document-store/weaviate.md index ad1fdc12ec1..d58685de891 100644 --- a/docs/advanced/document-store/weaviate.md +++ b/docs/advanced/document-store/weaviate.md @@ -170,7 +170,7 @@ Persist Documents with Weaviate. ``` -## Search with filter +## Vector search with filter Search with `.find` can be restricted by user-defined filters. Such filters that can be constructed using the following operators: @@ -183,7 +183,7 @@ Search with `.find` can be restricted by user-defined filters. 
Such filters that | `Equal` | Equal to | `==` | | `NotEqual` | Not equal to | `!=` | - filters can be constructing following the guidelines provided in [the weaviate documentation](https://weaviate.io/developers/weaviate/current/graphql-references/filters.html). +Filters can be constructed following the guidelines provided in [the weaviate documentation](https://weaviate.io/developers/weaviate/current/graphql-references/filters.html). ### Example of `.find` with a filter @@ -221,8 +221,8 @@ for embedding, price in zip(da.embeddings, da[:, 'tags__price']): ``` Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that -prices must follow a filter. As an example, let us consider that retrieved documents must have `price` value lower -or equal than `max_price`. We can encode this information in annlite using `filter = {'price': {'$lte': max_price}}`. +prices must follow a filter. As an example, let's consider that retrieved documents must have `price` value lower +or equal than `max_price`. We can encode this information in weaviate using `filter = {'path': ['price'], 'operator': 'LowerThanEqual', 'valueInt': max_price}`. 
Then the search with the proposed filter can implemented and used with the following code: From d1dd27c95e8b2945bb98aabf683b104547fd06ef Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 24 May 2022 16:56:36 +0200 Subject: [PATCH 36/37] docs: reference backend filter query language definition --- docs/advanced/document-store/annlite.md | 15 ++------------- docs/advanced/document-store/qdrant.md | 15 +++------------ docs/advanced/document-store/weaviate.md | 19 +++---------------- 3 files changed, 8 insertions(+), 41 deletions(-) diff --git a/docs/advanced/document-store/annlite.md b/docs/advanced/document-store/annlite.md index 5c743fbd0bf..e98e83b20f2 100644 --- a/docs/advanced/document-store/annlite.md +++ b/docs/advanced/document-store/annlite.md @@ -52,18 +52,7 @@ The following configs can be set: ## Vector search with filter -Search with `.find` can be restricted by user-defined filters. Such filters that can be constructed using the following operators: - - -| Name | Description | Equivalent Python operator | -|-------------------|------------------------|----------------------------| -| `$gte` | Greater or equal to | `>=` | -| `$gt` | Greater than | `>` | -| `$lte` | Less or equal to | `<=` | -| `$lt` | Less than ) | `<` | -| `$eq` | Equal to | `==` | -| `$neq` | Not equal to | `!=` | - +Search with `.find` can be restricted by user-defined filters. Filters can be constructed following the guidelines provided in [the AnnLite source repository](https://github.com/jina-ai/annlite). ### Example of `.find` with a filter @@ -113,7 +102,7 @@ for embedding, price in zip(results.embeddings, results[:, 'tags__price']): print(f'\tembedding={embedding},\t price={price}') ``` -This would print +This would print: ```bash Query vector: [8. 8. 8.] 
diff --git a/docs/advanced/document-store/qdrant.md b/docs/advanced/document-store/qdrant.md index 696f65d9dcf..02cfedb3542 100644 --- a/docs/advanced/document-store/qdrant.md +++ b/docs/advanced/document-store/qdrant.md @@ -136,17 +136,8 @@ print(da.find(np.random.random(D), limit=10)) ## Vector search with filter -Search with `.find` can be restricted by user-defined filters. Such filters that can be constructed using the following operators: - - -| Name | Description | Equivalent Python operator | -|-------------|------------------------|----------------------------| -| `gte` | Greater or equal to | `>=` | -| `gt` | Greater than | `>` | -| `lte` | Less or equal to | `<=` | -| `lt` | Less than ) | `<` | -| `eq` | Equal to | `==` | -| `neq` | Not equal to | `!=` | +Search with `.find` can be restricted by user-defined filters. Such filters can be constructed following the guidelines +in [Qdrant's Documentation](https://qdrant.tech/documentation/filtering/) ### Example of `.find` with a filter @@ -203,7 +194,7 @@ for embedding, price in zip(results.embeddings, results[:, 'tags__price']): print(f'\tembedding={embedding},\t price={price}') ``` -This would print +This would print: ``` Query vector: [8. 8. 8.] diff --git a/docs/advanced/document-store/weaviate.md b/docs/advanced/document-store/weaviate.md index d58685de891..94b911fd460 100644 --- a/docs/advanced/document-store/weaviate.md +++ b/docs/advanced/document-store/weaviate.md @@ -172,19 +172,8 @@ Persist Documents with Weaviate. ## Vector search with filter -Search with `.find` can be restricted by user-defined filters. 
Such filters that can be constructed using the following operators: - -| Name | Description | Equivalent Python operator | -|-------------------|------------------------|----------------------------| -| `GreaterThanEqual`| Greater or equal to | `>=` | -| `GreaterThan` | Greater than | `>` | -| `LessThanEqual` | Less or equal to | `<=` | -| `LessThan` | Less than ) | `<` | -| `Equal` | Equal to | `==` | -| `NotEqual` | Not equal to | `!=` | - -Filters can be constructed following the guidelines provided in [the weaviate documentation](https://weaviate.io/developers/weaviate/current/graphql-references/filters.html). - +Search with `.find` can be restricted by user-defined filters. Such filters can be constructed following the guidelines +in [Weaviate's Documentation](https://weaviate.io/developers/weaviate/current/graphql-references/filters.html). ### Example of `.find` with a filter @@ -241,7 +230,7 @@ for embedding, price in zip(results.embeddings, results[:, 'tags__price']): print(f'\tembedding={embedding},\t price={price}') ``` -This would print +This would print: ```bash Embeddings Nearest Neighbours with "price" at most 7: @@ -257,5 +246,3 @@ Embeddings Nearest Neighbours without restriction: [1. 1. 1.] [2. 2. 2.]] ``` - -Note that currently Weaviate only supports the cosine distance. 
\ No newline at end of file From bbc1473e8422d3696f102d5864f7cf49c1f8b4bc Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 24 May 2022 17:28:53 +0200 Subject: [PATCH 37/37] docs: fix weaviate example --- docs/advanced/document-store/weaviate.md | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/docs/advanced/document-store/weaviate.md b/docs/advanced/document-store/weaviate.md index 94b911fd460..a66169da677 100644 --- a/docs/advanced/document-store/weaviate.md +++ b/docs/advanced/document-store/weaviate.md @@ -188,12 +188,8 @@ import numpy as np n_dim = 3 da = DocumentArray( - storage='qdrant', - config={ - 'n_dim': n_dim, - 'columns': [('price', 'float')], - #'distance':distance - }, + storage='weaviate', + config={'n_dim': n_dim, 'columns': [('price', 'int')], 'distance': 'l2-squared'}, ) with da: @@ -222,7 +218,7 @@ n_limit = 4 np_query = np.ones(n_dim) * 8 print(f'\nQuery vector: \t{np_query}') -filter = {'path': ['price'], 'operator': 'LowerThanEqual', 'valueInt': max_price} +filter = {'path': ['price'], 'operator': 'LessThanEqual', 'valueInt': max_price} results = da.find(np_query, filter=filter, limit=n_limit) print('\nEmbeddings Nearest Neighbours with "price" at most 7:\n') @@ -235,14 +231,8 @@ This would print: ```bash Embeddings Nearest Neighbours with "price" at most 7: - embedding=[3. 3. 3.], price=3 + embedding=[7. 7. 7.], price=7 embedding=[6. 6. 6.], price=6 - embedding=[9. 9. 9.], price=9 - embedding=[1. 1. 1.], price=1 - -Embeddings Nearest Neighbours without restriction: - [[3. 3. 3.] - [6. 6. 6.] - [1. 1. 1.] - [2. 2. 2.]] + embedding=[5. 5. 5.], price=5 + embedding=[4. 4. 4.], price=4 ```