diff --git a/docarray/array/storage/qdrant/backend.py b/docarray/array/storage/qdrant/backend.py index 13d27bb72ee..ef3c7abd463 100644 --- a/docarray/array/storage/qdrant/backend.py +++ b/docarray/array/storage/qdrant/backend.py @@ -14,6 +14,7 @@ import numpy as np from qdrant_client import QdrantClient +from qdrant_client.http import models from qdrant_client.http.models.models import ( Distance, CreateCollection, @@ -24,7 +25,7 @@ ) from docarray import Document -from docarray.array.storage.base.backend import BaseBackendMixin +from docarray.array.storage.base.backend import BaseBackendMixin, TypeMap from docarray.array.storage.qdrant.helper import DISTANCES from docarray.helper import dataclass_from_dict, random_identity from docarray.math.helper import EPSILON @@ -74,6 +75,15 @@ def distance(self) -> 'Distance': def _tmp_collection_name(cls) -> str: return uuid.uuid4().hex + TYPE_MAP = { + 'int': TypeMap(type='integer', converter=int), + 'float': TypeMap(type='float', converter=float), + 'bool': TypeMap(type='integer', converter=bool),  # 'int' is not a Qdrant payload schema type + 'str': TypeMap(type='keyword', converter=str), + 'text': TypeMap(type='text', converter=str), + 'geo': TypeMap(type='geo', converter=dict), + } + def _init_storage( self, docs: Optional['DocumentArraySourceType'] = None, @@ -172,6 +182,23 @@ def _initialize_qdrant_schema(self): hnsw_config=hnsw_config, ) + for col, coltype in self._config.columns.items(): + if coltype == 'text': + self.client.create_payload_index( + collection_name=self.collection_name, + field_name=col, + field_schema=models.TextIndexParams( + type="text", + tokenizer=models.TokenizerType.WORD, + ), + ) + else: + self.client.create_payload_index( + collection_name=self.collection_name, + field_name=col, + field_schema=self._map_type(coltype), + ) + def _collection_exists(self, collection_name): resp = self.client.get_collections() collections = [collection.name for collection in resp.collections] diff --git a/docarray/array/storage/qdrant/find.py
b/docarray/array/storage/qdrant/find.py index fd70b2cd590..5f307c2dbdb 100644 --- a/docarray/array/storage/qdrant/find.py +++ b/docarray/array/storage/qdrant/find.py @@ -4,7 +4,7 @@ from docarray import Document, DocumentArray from docarray.math import ndarray from docarray.score import NamedScore -from qdrant_client.http import models as rest +from qdrant_client.http import models from qdrant_client.http.models.models import Distance if TYPE_CHECKING: # pragma: no cover @@ -59,7 +59,7 @@ def _find_similar_vectors( query_filter=filter, search_params=None if not search_params - else rest.SearchParams(**search_params), + else models.SearchParams(**search_params), limit=limit, append_payload=['_serialized'], ) @@ -117,7 +117,7 @@ def _find_with_filter( ): list_of_points, _offset = self.client.scroll( collection_name=self.collection_name, - scroll_filter=rest.Filter(**filter), + scroll_filter=models.Filter(**filter), with_payload=True, limit=limit, ) diff --git a/docs/advanced/document-store/qdrant.md b/docs/advanced/document-store/qdrant.md index a4da3c6fccc..10993db1ab5 100644 --- a/docs/advanced/document-store/qdrant.md +++ b/docs/advanced/document-store/qdrant.md @@ -97,7 +97,6 @@ The following configs can be set: | `root_id` | Boolean flag indicating whether to store `root_id` in the tags of chunk level Documents | True | - *You can read more about the HNSW parameters and their default values [here](https://qdrant.tech/documentation/indexing/#vector-index) ## Minimum example @@ -150,8 +149,7 @@ print(da.find(np.random.random(D), limit=10)) (qdrant-filter)= ## Vector search with filter -Search with `.find` can be restricted by user-defined filters. Such filters can be constructed following the guidelines -in [Qdrant's Documentation](https://qdrant.tech/documentation/filtering/) +Search with `.find` can be restricted by user-defined filters. 
The supported tag types for filtering are `'int'`, `'float'`, `'bool'`, `'str'`, `'text'` and `'geo'` as in [Qdrant](https://qdrant.tech/documentation/payload/). Such filters can be constructed following the guidelines in [Qdrant's Documentation](https://qdrant.tech/documentation/filtering/). ### Example of `.find` with a filter @@ -276,4 +274,4 @@ Points with "price" at most 7: embedding=[7. 7. 7.], price=7 embedding=[1. 1. 1.], price=1 embedding=[2. 2. 2.], price=2 -``` \ No newline at end of file +``` diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index e8ea1e225ac..dae4518d0f4 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -663,39 +663,104 @@ def test_filtering( @pytest.mark.parametrize( - 'storage,filter_gen,numeric_operators,operator', + 'columns', [ - *[ - tuple( - [ - 'qdrant', - lambda operator, threshold: { - 'must': [{'key': 'price', 'match': {'value': threshold}}] + [ + ('price', 'float'), + ('category', 'str'), + ('info', 'text'), + ('location', 'geo'), + ], + {'price': 'float', 'category': 'str', 'info': 'text', 'location': 'geo'}, + ], +) +@pytest.mark.parametrize( + 'filter,checker', + [ + ( + { + 'must': [ + {"key": "category", "match": {"value": "Shoes"}}, + {"key": "price", "range": {"gte": 5.0}}, + ] + }, + lambda r: r.tags['category'] == "Shoes" and r.tags['price'] >= 5.0, + ), + ( + { + 'must_not': [ + {"key": "info", "match": {"text": "shoes"}}, + { + "key": "location", + "geo_radius": { + "center": {"lon": -98.17, "lat": 38.71}, + "radius": 500.0 * 1000, + }, }, - numeric_operators_qdrant, - 'eq', ] - ) - ], + }, + lambda r: r.tags['info'].find("shoes") == -1 + and ( + haversine_distances( + np.radians([  # haversine_distances expects [lat, lon] pairs in radians + [38.71, -98.17], + [r.tags['location']['lat'], r.tags['location']['lon']], + ]) + ) + * 6371 + )[0][1] + > 500.0, + ), + ( + { + 'should': [ + {"key": "info", "match": {"text": "shoes"}}, + {"key": "price", "range": {"gte": 5.0}}, + ] + }, + lambda r:
r.tags['info'].find("shoes") != -1 or r.tags['price'] >= 5.0, + ), ], ) -@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}]) -def test_qdrant_filter_function( - storage, filter_gen, operator, numeric_operators, start_storage, columns -): +def test_qdrant_filter_query(filter, checker, columns, start_storage): n_dim = 128 da = DocumentArray(storage='qdrant', config={'n_dim': n_dim, 'columns': columns}) - da.extend([Document(id=f'r{i}', tags={'price': i}) for i in range(50)]) - thresholds = [10, 20, 30] - for threshold in thresholds: - filter = filter_gen(operator, threshold) - results = da._filter(filter=filter) - assert len(results) > 0 + da.extend( + [ + Document( + id=f'r{i}', + embedding=np.random.rand(n_dim), + tags={ + 'price': i + 0.5, + 'category': 'Shoes', + 'info': f'shoes {i}', + 'location': {"lon": -98.17 + i, "lat": 38.93 + i}, + }, + ) + for i in range(10) + ] + ) - assert all( - [numeric_operators[operator](r.tags['price'], threshold) for r in results] - ) + da.extend( + [ + Document( + id=f'r{i+10}', + embedding=np.random.rand(n_dim), + tags={ + 'price': i + 0.5, + 'category': 'Jeans', + 'info': f'jeans {i}', + 'location': {"lon": -98.17 + i, "lat": 38.93 + i}, + }, + ) + for i in range(10) + ] + ) + + results = da.find(np.random.rand(n_dim), filter=filter) + assert len(results) > 0 + assert all([checker(r) for r in results]) @pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}])