Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions docarray/index/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Iterable,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Tuple,
Expand All @@ -30,7 +29,12 @@
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal._typing import is_tensor_union
from docarray.utils._internal.misc import import_library
from docarray.utils.find import FindResult, _FindResult
from docarray.utils.find import (
FindResult,
FindResultBatched,
_FindResult,
_FindResultBatched,
)

if TYPE_CHECKING:
import tensorflow as tf # type: ignore
Expand All @@ -47,16 +51,6 @@
TSchema = TypeVar('TSchema', bound=BaseDoc)


class FindResultBatched(NamedTuple):
documents: List[DocList]
scores: List[np.ndarray]


class _FindResultBatched(NamedTuple):
documents: Union[List[DocList], List[List[Dict[str, Any]]]]
scores: List[np.ndarray]


def _raise_not_composable(name):
def _inner(self, *args, **kwargs):
raise NotImplementedError(
Expand Down
17 changes: 5 additions & 12 deletions docarray/index/backends/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,12 @@

import docarray.typing
from docarray import BaseDoc
from docarray.index.abstract import (
BaseDocIndex,
_ColumnInfo,
_FindResultBatched,
_raise_not_composable,
)
from docarray.index.abstract import BaseDocIndex, _ColumnInfo, _raise_not_composable
from docarray.typing import AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.ndarray import NdArray
from docarray.utils._internal.misc import is_tf_available, is_torch_available
from docarray.utils.find import _FindResult
from docarray.utils.find import _FindResult, _FindResultBatched

TSchema = TypeVar('TSchema', bound=BaseDoc)
T = TypeVar('T', bound='ElasticDocIndex')
Expand Down Expand Up @@ -387,7 +382,7 @@ def _find_batched(
das, scores = zip(
*[self._format_response(resp) for resp in responses['responses']]
)
return _FindResultBatched(documents=list(das), scores=np.array(scores))
return _FindResultBatched(documents=list(das), scores=scores)

def _filter(
self,
Expand Down Expand Up @@ -445,9 +440,7 @@ def _text_search_batched(
das, scores = zip(
*[self._format_response(resp) for resp in responses['responses']]
)
return _FindResultBatched(
documents=list(das), scores=np.array(scores, dtype=object)
)
return _FindResultBatched(documents=list(das), scores=scores)

###############################################
# Helpers #
Expand Down Expand Up @@ -544,7 +537,7 @@ def _format_response(self, response: Any) -> Tuple[List[Dict], NdArray]:
docs.append(doc_dict)
scores.append(result['_score'])

return docs, parse_obj_as(NdArray, scores)
return docs, [parse_obj_as(NdArray, np.array(s)) for s in scores]

def _refresh(self, index_name: str):
self._client.indices.refresh(index=index_name)
Expand Down
3 changes: 1 addition & 2 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from docarray.index.abstract import (
BaseDocIndex,
_ColumnInfo,
_FindResultBatched,
_raise_not_composable,
_raise_not_supported,
)
Expand All @@ -33,7 +32,7 @@
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.misc import import_library, is_np_int
from docarray.utils.filter import filter_docs
from docarray.utils.find import _FindResult
from docarray.utils.find import _FindResult, _FindResultBatched

if TYPE_CHECKING:
import hnswlib
Expand Down
36 changes: 20 additions & 16 deletions docarray/index/backends/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,35 @@
import uuid
from dataclasses import dataclass, field
from typing import (
TypeVar,
Generic,
Optional,
cast,
Sequence,
Any,
Union,
List,
Dict,
Generator,
Type,
Generic,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)

import numpy as np
import qdrant_client
from grpc import RpcError # type: ignore[import]
from qdrant_client.conversions import common_types as types
from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse

import docarray.typing.id
from docarray import BaseDoc, DocList
from docarray.index.abstract import (
BaseDocIndex,
_FindResultBatched,
_ColumnInfo,
_FindResultBatched,
_raise_not_composable,
)

import qdrant_client
from qdrant_client.conversions import common_types as types
from qdrant_client.http import models as rest

from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.misc import torch_imported
Expand Down Expand Up @@ -391,7 +389,10 @@ def _find_batched(
for response in responses
],
scores=[
np.array([point.score for point in response]) for response in responses
NdArray._docarray_from_native(
np.array([point.score for point in response])
)
for response in responses
],
)

Expand Down Expand Up @@ -454,7 +455,10 @@ def _text_search_batched(
# semantic search over vectors. Thus, each document is scored with a value of 1
return _FindResultBatched(
documents=documents_batched,
scores=[np.ones(len(docs)) for docs in documents_batched],
scores=[
NdArray._docarray_from_native(np.ones(len(docs)))
for docs in documents_batched
],
)

def _build_point_from_row(self, row: Dict[str, Any]) -> rest.PointStruct:
Expand Down
42 changes: 28 additions & 14 deletions docarray/utils/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ class _FindResult(NamedTuple):
scores: AnyTensor


class FindResultBatched(NamedTuple):
documents: List[DocList]
scores: List[AnyTensor]


class _FindResultBatched(NamedTuple):
documents: Union[List[DocList], List[List[Dict[str, Any]]]]
scores: List[AnyTensor]


def find(
index: AnyDocArray,
query: Union[AnyTensor, BaseDoc],
Expand Down Expand Up @@ -95,15 +105,16 @@ class MyDocument(BaseDoc):
and the second element contains the corresponding scores.
"""
query = _extract_embedding_single(query, search_field)
return find_batched(
docs, scores = find_batched(
index=index,
query=query,
search_field=search_field,
metric=metric,
limit=limit,
device=device,
descending=descending,
)[0]
)
return FindResult(documents=docs[0], scores=scores[0])


def find_batched(
Expand All @@ -114,7 +125,7 @@ def find_batched(
limit: int = 10,
device: Optional[str] = None,
descending: Optional[bool] = None,
) -> List[FindResult]:
) -> FindResultBatched:
Comment thread
npitsillos marked this conversation as resolved.
"""
Find the closest Documents in the index to the queries.
Supports PyTorch and NumPy embeddings.
Expand Down Expand Up @@ -142,23 +153,23 @@ class MyDocument(BaseDoc):

# use DocList as query
query = DocList[MyDocument]([MyDocument(embedding=torch.rand(128)) for _ in range(3)])
results = find_batched(
docs, scores = find_batched(
index=index,
query=query,
search_field='embedding',
metric='cosine_sim',
)
top_matches, scores = results[0]
top_matches, scores = docs[0], scores[0]

# use tensor as query
query = torch.rand(3, 128)
results = find_batched(
docs, scores = find_batched(
index=index,
query=query,
search_field='embedding',
metric='cosine_sim',
)
top_matches, scores = results[0]
top_matches, scores = docs[0], scores[0]
```

---
Expand All @@ -176,8 +187,8 @@ class MyDocument(BaseDoc):
can be either `cpu` or a `cuda` device.
:param descending: sort the results in descending order.
Per default, this is chosen based on the `metric` argument.
:return: a list of named tuples of the form (DocList, AnyTensor),
where the first element contains the closes matches for each query,
:return: A named tuple of the form (List[DocList], List[AnyTensor]),
where the first element contains, for each query, the closest matches,
and the second element contains, for each query, the corresponding scores.
"""
if descending is None:
Expand All @@ -197,14 +208,17 @@ class MyDocument(BaseDoc):
dists, k=limit, device=device, descending=descending
)

results = []
for indices_per_query, scores_per_query in zip(top_indices, top_scores):
batched_docs: List[DocList] = []
scores = []
for _, (indices_per_query, scores_per_query) in enumerate(
zip(top_indices, top_scores)
):
docs_per_query: DocList = DocList([])
for idx in indices_per_query: # workaround until #930 is fixed
docs_per_query.append(index[idx])
docs_per_query = DocList(docs_per_query)
results.append(FindResult(scores=scores_per_query, documents=docs_per_query))
return results
batched_docs.append(DocList(docs_per_query))
scores.append(scores_per_query)
return FindResultBatched(documents=batched_docs, scores=scores)


def _extract_embedding_single(
Expand Down
Loading