Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion docarray/array/storage/qdrant/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models.models import (
Distance,
CreateCollection,
Expand All @@ -24,7 +25,7 @@
)

from docarray import Document
from docarray.array.storage.base.backend import BaseBackendMixin
from docarray.array.storage.base.backend import BaseBackendMixin, TypeMap
from docarray.array.storage.qdrant.helper import DISTANCES
from docarray.helper import dataclass_from_dict, random_identity
from docarray.math.helper import EPSILON
Expand Down Expand Up @@ -74,6 +75,15 @@ def distance(self) -> 'Distance':
def _tmp_collection_name(cls) -> str:
return uuid.uuid4().hex

# Maps each column type name accepted in the `columns` config to a TypeMap
# holding the Qdrant payload schema type name and the Python callable used to
# convert values of that column.
# NOTE(review): the `type` string is presumably what `_map_type` returns and
# what is passed as `field_schema` to `create_payload_index` - confirm.
TYPE_MAP = {
    'int': TypeMap(type='integer', converter=int),
    'float': TypeMap(type='float', converter=float),
    # Booleans are indexed with the integer schema. The schema name must be
    # 'integer' (as for the 'int' entry above); 'int' is not a Qdrant payload
    # schema type name and would be rejected when the index is created.
    'bool': TypeMap(type='integer', converter=bool),
    'str': TypeMap(type='keyword', converter=str),
    'text': TypeMap(type='text', converter=str),
    'geo': TypeMap(type='geo', converter=dict),
}

def _init_storage(
self,
docs: Optional['DocumentArraySourceType'] = None,
Expand Down Expand Up @@ -172,6 +182,23 @@ def _initialize_qdrant_schema(self):
hnsw_config=hnsw_config,
)

for col, coltype in self._config.columns.items():
if coltype == 'text':
self.client.create_payload_index(
collection_name=self.collection_name,
field_name=col,
field_schema=models.TextIndexParams(
type="text",
tokenizer=models.TokenizerType.WORD,
),
)
else:
self.client.create_payload_index(
collection_name=self.collection_name,
field_name=col,
field_schema=self._map_type(coltype),
)

def _collection_exists(self, collection_name):
resp = self.client.get_collections()
collections = [collection.name for collection in resp.collections]
Expand Down
6 changes: 3 additions & 3 deletions docarray/array/storage/qdrant/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from docarray import Document, DocumentArray
from docarray.math import ndarray
from docarray.score import NamedScore
from qdrant_client.http import models as rest
from qdrant_client.http import models
from qdrant_client.http.models.models import Distance

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -59,7 +59,7 @@ def _find_similar_vectors(
query_filter=filter,
search_params=None
if not search_params
else rest.SearchParams(**search_params),
else models.SearchParams(**search_params),
limit=limit,
append_payload=['_serialized'],
)
Expand Down Expand Up @@ -117,7 +117,7 @@ def _find_with_filter(
):
list_of_points, _offset = self.client.scroll(
collection_name=self.collection_name,
scroll_filter=rest.Filter(**filter),
scroll_filter=models.Filter(**filter),
with_payload=True,
limit=limit,
)
Expand Down
6 changes: 2 additions & 4 deletions docs/advanced/document-store/qdrant.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ The following configs can be set:
| `root_id` | Boolean flag indicating whether to store `root_id` in the tags of chunk level Documents | True |



*You can read more about the HNSW parameters and their default values [here](https://qdrant.tech/documentation/indexing/#vector-index)

## Minimum example
Expand Down Expand Up @@ -150,8 +149,7 @@ print(da.find(np.random.random(D), limit=10))
(qdrant-filter)=
## Vector search with filter

Search with `.find` can be restricted by user-defined filters. Such filters can be constructed following the guidelines
in [Qdrant's Documentation](https://qdrant.tech/documentation/filtering/)
Search with `.find` can be restricted by user-defined filters. The tag types supported for filtering are `'int'`, `'float'`, `'bool'`, `'str'`, `'text'` and `'geo'`, which correspond to the payload types described in the [Qdrant payload documentation](https://qdrant.tech/documentation/payload/). Such filters can be constructed following the guidelines in [Qdrant's Documentation](https://qdrant.tech/documentation/filtering/).


### Example of `.find` with a filter
Expand Down Expand Up @@ -276,4 +274,4 @@ Points with "price" at most 7:
embedding=[7. 7. 7.], price=7
embedding=[1. 1. 1.], price=1
embedding=[2. 2. 2.], price=2
```
```
113 changes: 89 additions & 24 deletions tests/unit/array/mixins/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,39 +663,104 @@ def test_filtering(


@pytest.mark.parametrize(
'storage,filter_gen,numeric_operators,operator',
'columns',
[
*[
tuple(
[
'qdrant',
lambda operator, threshold: {
'must': [{'key': 'price', 'match': {'value': threshold}}]
[
('price', 'float'),
('category', 'str'),
('info', 'text'),
('location', 'geo'),
],
{'price': 'float', 'category': 'str', 'info': 'text', 'location': 'geo'},
],
)
@pytest.mark.parametrize(
'filter,checker',
[
(
{
'must': [
{"key": "category", "match": {"value": "Shoes"}},
{"key": "price", "range": {"gte": 5.0}},
]
},
lambda r: r.tags['category'] == "Shoes" and r.tags['price'] >= 5.0,
),
(
{
'must_not': [
{"key": "info", "match": {"text": "shoes"}},
{
"key": "location",
"geo_radius": {
"center": {"lon": -98.17, "lat": 38.71},
"radius": 500.0 * 1000,
},
},
numeric_operators_qdrant,
'eq',
]
)
],
},
lambda r: r.tags['info'].find("shoes") == -1
and (
haversine_distances(
[
[-98.17, 38.71],
[r.tags['location']['lon'], r.tags['location']['lat']],
]
)
* 6371
)[0][1]
> 500.0,
),
(
{
'should': [
{"key": "info", "match": {"text": "shoes"}},
{"key": "price", "range": {"gte": 5.0}},
]
},
lambda r: r.tags['info'].find("shoes") != -1 or r.tags['price'] >= 5.0,
),
],
)
@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}])
def test_qdrant_filter_function(
storage, filter_gen, operator, numeric_operators, start_storage, columns
):
def test_qdrant_filter_query(filter, checker, columns, start_storage):
n_dim = 128
da = DocumentArray(storage='qdrant', config={'n_dim': n_dim, 'columns': columns})
da.extend([Document(id=f'r{i}', tags={'price': i}) for i in range(50)])
thresholds = [10, 20, 30]
for threshold in thresholds:
filter = filter_gen(operator, threshold)
results = da._filter(filter=filter)

assert len(results) > 0
da.extend(
[
Document(
id=f'r{i}',
embedding=np.random.rand(n_dim),
tags={
'price': i + 0.5,
'category': 'Shoes',
'info': f'shoes {i}',
'location': {"lon": -98.17 + i, "lat": 38.93 + i},
},
)
for i in range(10)
]
)

assert all(
[numeric_operators[operator](r.tags['price'], threshold) for r in results]
)
# Second batch: ten 'Jeans' documents with ids r10..r19, using the same price
# and location progression as the preceding batch of documents.
da.extend(
    [
        Document(
            id=f'r{i+10}',
            embedding=np.random.rand(n_dim),
            tags={
                'price': i + 0.5,
                'category': 'Jeans',
                # Must be an f-string so each document gets 'jeans 0',
                # 'jeans 1', ... instead of the literal text 'jeans {i}'.
                'info': f'jeans {i}',
                'location': {"lon": -98.17 + i, "lat": 38.93 + i},
            },
        )
        for i in range(10)
    ]
)

results = da.find(np.random.rand(n_dim), filter=filter)
assert len(results) > 0
assert all([checker(r) for r in results])


@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}])
Expand Down