From e4aa4a8f9a004ea4c41ba197397397084e50084f Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Mon, 20 Jun 2022 22:44:09 +0200 Subject: [PATCH] fix: customize metric fn expect no metric_name --- docarray/array/storage/memory/find.py | 2 +- docs/advanced/document-store/weaviate.md | 2 +- docs/fundamentals/documentarray/matching.md | 37 +++++---------------- tests/unit/array/mixins/test_find.py | 17 ++++++++++ 4 files changed, 27 insertions(+), 31 deletions(-) diff --git a/docarray/array/storage/memory/find.py b/docarray/array/storage/memory/find.py index b66917ec649..3a79902d5d3 100644 --- a/docarray/array/storage/memory/find.py +++ b/docarray/array/storage/memory/find.py @@ -67,7 +67,7 @@ def _find( batch_size = int(batch_size) if callable(metric): - cdist = metric + cdist = lambda *x: metric(*x[:2]) elif isinstance(metric, str): if use_scipy: from scipy.spatial.distance import cdist as cdist diff --git a/docs/advanced/document-store/weaviate.md b/docs/advanced/document-store/weaviate.md index fe3f2d97e92..0370a06ec7b 100644 --- a/docs/advanced/document-store/weaviate.md +++ b/docs/advanced/document-store/weaviate.md @@ -106,7 +106,7 @@ The following configs can be set: *You can read more about the HNSW parameters and their default values [here](https://weaviate.io/developers/weaviate/current/vector-index-plugins/hnsw.html#how-to-use-hnsw-and-parameters) -## Minimum Example +## Minimum example The following example shows how to use DocArray with Weaviate Document Store in order to index and search text Documents. diff --git a/docs/fundamentals/documentarray/matching.md b/docs/fundamentals/documentarray/matching.md index 03e21f66769..3a20a32f281 100644 --- a/docs/fundamentals/documentarray/matching.md +++ b/docs/fundamentals/documentarray/matching.md @@ -27,7 +27,7 @@ Though both `.find()` and `.match()` is about finding nearest neighbours of a gi In the sequel, we will use `.match()` to describe the features. But keep in mind that `.find()` should also work by simply switching the right and left-hand sides. -## Example +### Example The following example finds for each element in `da1` the three closest Documents from the elements in `da2` according to Euclidean distance. @@ -134,11 +134,11 @@ da2.find(da1, metric='euclidean', limit=3) or simply: ```python -da2.find(np.array( - [[0, 0, 0, 0, 1], - [1, 0, 0, 0, 0], - [1, 1, 1, 1, 0], - [1, 2, 2, 1, 0]]), metric='euclidean', limit=3) +da2.find( + np.array([[0, 0, 0, 0, 1], [1, 0, 0, 0, 0], [1, 1, 1, 1, 0], [1, 2, 2, 1, 0]]), + metric='euclidean', + limit=3, +) ``` The following metrics are supported: @@ -155,27 +155,6 @@ Note that framework is auto-chosen based on the type of `.embeddings`. For examp By default `A.match(B)` will copy the top-K matched Documents from B to `A.matches`. When these matches are big, copying them can be time-consuming. In this case, one can leverage `.match(..., only_id=True)` to keep only {attr}`~docarray.Document.id`. -### Pre filtering - -Both `match` and `find` support pre-filtering by passing a `filter` argument to the method. - -Pre-filtering is an advanced approximate nearest neighbors feature that allows to efficiently retrieve the nearest vectors -that respect the filtering condition. - -In contrast, post-filtering in the naive approach where you first retrieve the -nearest neighbors and then discard all the candidates that do not respect the filter condition. - -````{admonition} Pre-filtering is not available for in-memory backend -:class: caution -By default a DocumentArray will use the in-memory backend which does not support pre-filtering -``` -```` - -You can find example on how to use the pre-filtering here: - -- {ref}`ANNLite ` -- {ref}`Weaviate ` -- {ref}`Qdrant ` ### GPU support @@ -224,7 +203,7 @@ da2.embeddings = np.random.random([M, D]).astype(np.float32) ``` ```python -%timeit da1.match(da2, only_id=True) +da1.match(da2, only_id=True) ``` ```text @@ -243,7 +222,7 @@ da2.embeddings = torch.tensor(np.random.random([M, D]).astype(np.float32)) ``` ```python -%timeit da1.match(da2, device='cuda', batch_size=1_000, only_id=True) +da1.match(da2, device='cuda', batch_size=1_000, only_id=True) ``` ```text diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index d59efdd0bf5..c8ad6030cda 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -8,6 +8,23 @@ import operator +def test_customize_metric_fn(): + N, D = 4, 128 + da = DocumentArray.empty(N) + da.embeddings = np.random.random([N, D]) + + q = np.random.random([D]) + _, r1 = da.find(q)[:, ['scores__cosine__value', 'id']] + + from docarray.math.distance.numpy import cosine + + def inv_cosine(*args): + return -cosine(*args) + + _, r2 = da.find(q, metric=inv_cosine)[:, ['scores__inv_cosine__value', 'id']] + assert list(reversed(r1)) == r2 + + @pytest.mark.parametrize( 'storage, config', [