diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index 6048a5fec8..022ab8d818 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -619,21 +619,17 @@ def _search_and_filter( index = self._hnsw_indices[search_field] - def accept_all(id): - """Accepts all IDs.""" - return True - def accept_hashed_ids(id): """Accepts IDs that are in hashed_ids.""" return id in hashed_ids # type: ignore[operator] # Choose the appropriate filter function based on whether hashed_ids was provided - filter_function = accept_hashed_ids if hashed_ids else accept_all + extra_kwargs = {'filter': accept_hashed_ids} if hashed_ids else {} # If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit k = min(limit, len(hashed_ids)) if hashed_ids else limit - labels, distances = index.knn_query(queries, k=k, filter=filter_function) + labels, distances = index.knn_query(queries, k=k, **extra_kwargs) result_das = [ self._get_docs_sqlite_hashed_id(