def _collect_query_required_args(method_name: str, required_args: Set[str] = None):
    """
    Build a query-builder method that validates its keyword arguments and records the call.

    The returned function is meant to be assigned as a method on a query-builder
    class whose instances keep a `_queries` list and whose constructor accepts
    such a list as its first argument.

    :param method_name: The name under which the call is recorded, also used in error messages.
    :param required_args: Names of keyword arguments that must be present. `None` means
        that no keyword argument is required.
    :return: A function `inner(self, **kwargs)` that returns a *new* builder of the same
        type as `self`, whose queries are `self._queries` plus `(method_name, kwargs)`.
        `self` is not modified.
        `inner` raises ValueError if positional arguments are provided.
        `inner` raises ValueError if any required keyword argument is missing.
    """
    # Snapshot the names so that later mutation of the caller's set cannot
    # change the validation performed by the returned method.
    required = frozenset(required_args or ())

    def inner(self, *args, **kwargs):
        if args:
            raise ValueError(
                f"Positional arguments are not supported for "
                f"`{type(self)}.{method_name}`. "
                f"Use keyword arguments instead."
            )

        # Sorted so the error message is deterministic (set order is arbitrary).
        missing_args = sorted(required.difference(kwargs))
        if missing_args:
            raise ValueError(
                f"`{type(self)}.{method_name}` is missing required argument(s): {', '.join(missing_args)}"
            )

        # Builders are treated as immutable: concatenate into a fresh list.
        updated_query = self._queries + [(method_name, kwargs)]
        return type(self)(updated_query)

    return inner
Store. + + MongoDB Atlas provides full Text, Vector, and Hybrid Search + and can store structured data, text and vector indexes + in the same Collection (Index). + + Atlas provides efficient index and search on vector embeddings + using the Hierarchical Navigable Small Worlds (HNSW) algorithm. + + For documentation, see the following. + * Text Search: https://www.mongodb.com/docs/atlas/atlas-search/atlas-search-overview/ + * Vector Search: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/ + * Hybrid Search: https://www.mongodb.com/docs/atlas/atlas-vector-search/tutorials/reciprocal-rank-fusion/ + """ + def __init__(self, db_config=None, **kwargs): super().__init__(db_config=db_config, **kwargs) - self._logger = logging.getLogger(__name__) - self._create_indexes() - self._logger.info(f'{self.__class__.__name__} has been initialized') + logger.info(f'{self.__class__.__name__} has been initialized') @property - def _collection(self): - if self._is_subindex: - return self._db_config.index_name + def index_name(self): + """The name of the index/collection in the database. - if not self._schema: - raise ValueError( - 'A MongoDBAtlasDocumentIndex must be typed with a Document type.' - 'To do so, use the syntax: MongoDBAtlasDocumentIndex[DocumentType]' - ) + Note that in MongoDB Atlas, one has Collections (analogous to Tables), + which can have Search Indexes. They are distinct. + DocArray tends to consider them together. - return self._schema.__name__.lower() + The index_name can be set when initializing MongoDBAtlasDocumentIndex. + The easiest way is to pass index_name= as a kwarg. + Otherwise, a rational default uses the name of the DocumentTypes that it contains. 
+ """ - @property - def index_name(self): - """Return the name of the index in the database.""" - return self._collection + if self._db_config.index_name is not None: + return self._db_config.index_name + else: + # Create a reasonable default + if not self._schema: + raise ValueError( + 'A MongoDBAtlasDocumentIndex must be typed with a Document type.' + 'To do so, use the syntax: MongoDBAtlasDocumentIndex[DocumentType]' + ) + schema_name = self._schema.__name__.lower() + logger.debug(f"db_config.index_name was not set. Using {schema_name}") + return schema_name @property def _database_name(self): @@ -69,8 +103,9 @@ def _client(self): ) @property - def _doc_collection(self): - return self._client[self._database_name][self._collection] + def _collection(self): + """MongoDB Collection""" + return self._client[self._database_name][self.index_name] @staticmethod def _connect_to_mongodb_atlas(atlas_connection_uri: str): @@ -86,43 +121,182 @@ def _connect_to_mongodb_atlas(atlas_connection_uri: str): def _create_indexes(self): """Create a new index in the MongoDB database if it doesn't already exist.""" - self._logger.warning( - "Search Indexes in MongoDB Atlas must be created manually. " - "Currently, client-side creation of vector indexes is not allowed on free clusters." - "Please follow instructions in docs/API_reference/doc_index/backends/mongodb.md" - ) + + def _check_index_exists(self, index_name: str) -> bool: + """ + Check if an index exists in the MongoDB Atlas database. + + :param index_name: The name of the index. + :return: True if the index exists, False otherwise. + """ + + @dataclass + class Query: + """Dataclass describing a query.""" + + vector_fields: Optional[Dict[str, np.ndarray]] + filters: Optional[List[Any]] + text_searches: Optional[List[Any]] + limit: int class QueryBuilder(BaseDocIndex.QueryBuilder): - ... + """Compose complex queries containing vector search (find), text_search, and filters. 
def build(self, limit: int = 1, *args, **kwargs) -> Any:
    """Build a `Query` that can be passed to `execute_query`.

    Walks the recorded `(method name, kwargs)` pairs in `self._queries` and
    sorts them into three groups:

    * `find` calls: the `query` vectors are grouped by `search_field` and
      averaged per field with `np.average(..., axis=0)`.
    * `filter` calls: the recorded kwargs dict is kept as is.
    * every other recorded call is treated as a text search and its kwargs
      dict is kept as is.

    :param limit: stored on the resulting `Query` as its `limit`.
    :param args: accepted for interface compatibility; not used here.
    :param kwargs: accepted for interface compatibility; not used here.
    :return: a `MongoDBAtlasDocumentIndex.Query`.
    """
    vectors_by_field: Dict[str, List[Any]] = collections.defaultdict(list)
    filters: List[Any] = []
    text_searches: List[Any] = []

    # `call_kwargs` (not `kwargs`) so the method's own **kwargs is not clobbered.
    for method, call_kwargs in self._queries:
        if method == 'find':
            vectors_by_field[call_kwargs['search_field']].append(call_kwargs["query"])
        elif method == 'filter':
            filters.append(call_kwargs)
        else:
            text_searches.append(call_kwargs)

    # One averaged vector per search field.
    vector_fields = {
        name: np.average(vectors, axis=0)
        for name, vectors in vectors_by_field.items()
    }
    return MongoDBAtlasDocumentIndex.Query(
        vector_fields=vector_fields,
        filters=filters,
        text_searches=text_searches,
        limit=limit,
    )
def execute_query(
    self, query: Any, *args, score_breakdown=True, **kwargs
) -> Any:  # _FindResult:
    """Execute a Query on the database.

    Depending on the contents of `query`, one aggregation pipeline is built:

    * more than one search in total (vector + text): a hybrid pipeline from
      `self._hybrid_search`;
    * exactly one text search: text stage, `$match` on the filters, projection
      with `searchScore`, `$limit`;
    * exactly one vector search: vector stage (filters are handed to the stage),
      projection with `vectorSearchScore`;
    * no search at all: a single `$match` on the filters.

    :param query: the query to execute. The output of this Document index's `QueryBuilder.build()` method.
    :param args: accepted for interface compatibility; not used here.
    :param score_breakdown: if True, a hybrid search with at least one result returns a
        `HybridResult` that also lists every result field whose name contains "score".
    :param kwargs: accepted for interface compatibility; not used here.
    :return: a `HybridResult` (see `score_breakdown`), otherwise a `_FindResult`.
    :raises ValueError: if `query` is not a `MongoDBAtlasDocumentIndex.Query`.
    :raises NotImplementedError: if a hybrid query contains more than one vector field.
    """
    if not isinstance(query, MongoDBAtlasDocumentIndex.Query):
        # The `f` prefix is required for `{type(query)=}` to be interpolated.
        raise ValueError(
            f"Expected MongoDBAtlasDocumentIndex.Query. Found {type(query)=}. "
            "For native calls to MongoDBAtlasDocumentIndex, simply call filter()"
        )

    # NOTE(review): `vector_fields` has one entry per search field, so more than one
    # entry is rejected by the hybrid branch below — confirm the wording of this warning.
    if len(query.vector_fields) > 1:
        self._logger.warning(
            f"{len(query.vector_fields)} embedding vectors have been provided to the query. They will be averaged."
        )
    if len(query.text_searches) > 1:
        self._logger.warning(
            f"{len(query.text_searches)} text searches will be performed, and each receive a ranked score."
        )

    # collect filters
    filters: List[Dict[str, Any]] = [filter_['query'] for filter_ in query.filters]
    # `$and` must not be given an empty list, so fall back to a match-all stage.
    match_stage = {"$match": {"$and": filters} if filters else {}}

    # check if hybrid search is needed.
    hybrid = len(query.vector_fields) + len(query.text_searches) > 1
    if hybrid:
        if len(query.vector_fields) > 1:
            raise NotImplementedError(
                "Hybrid Search on multiple Vector Indexes has yet to be done."
            )
        pipeline = self._hybrid_search(
            query.vector_fields, query.text_searches, filters, query.limit
        )
    elif query.text_searches:
        # it is a simple text search, perhaps with filters.
        pipeline = [
            self._text_search_stage(**query.text_searches[0]),
            match_stage,
            {
                '$project': self._project_fields(
                    extra_fields={"score": {'$meta': 'searchScore'}}
                )
            },
            {"$limit": query.limit},
        ]
    elif query.vector_fields:
        # it is a simple vector search, perhaps with filters.
        # Not hybrid, so there is exactly one entry here.
        search_field, vector_query = next(iter(query.vector_fields.items()))
        pipeline = [
            self._vector_search_stage(
                query=vector_query,
                search_field=search_field,
                limit=query.limit,
                filters=filters,
            ),
            {
                '$project': self._project_fields(
                    extra_fields={"score": {'$meta': 'vectorSearchScore'}}
                )
            },
        ]
    else:
        # it is only a filter search.
        # NOTE(review): unlike the other branches, `query.limit` is not applied here — confirm intent.
        pipeline = [match_stage]

    with self._collection.aggregate(pipeline) as cursor:
        results, scores = self._mongo_to_docs(cursor)
    docs = self._dict_list_to_docarray(results)

    if hybrid and score_breakdown and results:
        # Separate name so the boolean `score_breakdown` argument is not overwritten.
        breakdown = collections.defaultdict(list)
        score_fields = [key for key in results[0] if "score" in key]
        for res in results:
            breakdown["id"].append(res["id"])
            for sf in score_fields:
                breakdown[sf].append(res[sf])
        logger.debug(breakdown)
        return HybridResult(
            documents=docs, scores=scores, score_breakdown=breakdown
        )

    return _FindResult(documents=docs, scores=scores)
@staticmethod
def _mongo_to_doc(mongo_doc: dict) -> Tuple[dict, Optional[float]]:
    """Convert one MongoDB document into a `(document, score)` pair.

    :param mongo_doc: a document as returned by MongoDB; must contain `_id`.
        It is not modified.
    :return: a tuple of
        * a shallow copy of `mongo_doc` in which the `_id` key is renamed to `id`;
          a `score` key, if present, is kept in the copy;
        * the value stored under `score`, or None if there is none.
    :raises KeyError: if `mongo_doc` has no `_id` key.
    """
    result = mongo_doc.copy()
    result["id"] = result.pop("_id")
    # `get`, not `pop`: the score stays available on the returned document.
    score = result.get("score")
    return result, score
+ + :param column_to_data: is a dictionary from column name to a generator + """ self._index_subindex(column_to_data) docs: List[Dict[str, Any]] = [] while True: @@ -226,11 +402,11 @@ def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): docs.append(mongo_doc) except StopIteration: break - self._doc_collection.insert_many(docs) + self._collection.insert_many(docs) def num_docs(self) -> int: """Return the number of indexed documents""" - return self._doc_collection.count_documents({}) + return self._collection.count_documents({}) @property def _is_index_empty(self) -> bool: @@ -246,7 +422,7 @@ def _del_items(self, doc_ids: Sequence[str]) -> None: :param doc_ids: ids to delete from the Document Store """ mg_filter = {"_id": {"$in": doc_ids}} - self._doc_collection.delete_many(mg_filter) + self._collection.delete_many(mg_filter) def _get_items( self, doc_ids: Sequence[str] @@ -258,29 +434,138 @@ def _get_items( :return: Sequence of Documents, sorted corresponding to the order of `doc_ids`. Duplicate `doc_ids` can be omitted in the output. 
""" mg_filter = {"_id": {"$in": doc_ids}} - docs = self._doc_collection.find(mg_filter) + docs = self._collection.find(mg_filter) docs, _ = self._mongo_to_docs(docs) if not docs: raise KeyError(f'No document with id {doc_ids} found') return docs - def _vector_stage_search( + def _reciprocal_rank_stage(self, search_field: str, score_field: str): + penalty = self._column_infos[search_field].config["penalty"] + projection_fields = { + key: f"$docs.{key}" for key in self._column_infos.keys() if key != "id" + } + projection_fields["_id"] = "$docs._id" + projection_fields[score_field] = 1 + + return [ + {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}}, + {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}}, + { + "$addFields": { + score_field: {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]} + } + }, + {'$project': projection_fields}, + ] + + def _add_stage_to_pipeline(self, pipeline: List[Any], stage: Dict[str, Any]): + if pipeline: + pipeline.append( + {"$unionWith": {"coll": self.index_name, "pipeline": stage}} + ) + else: + pipeline.extend(stage) + return pipeline + + def _final_stage(self, scores_fields, limit): + """Sum individual scores, sort, and apply limit.""" + doc_fields = self._column_infos.keys() + grouped_fields = { + key: {"$first": f"${key}"} for key in doc_fields if key != "_id" + } + best_score = {score: {'$max': f'${score}'} for score in scores_fields} + final_pipeline = [ + {"$group": {"_id": "$_id", **grouped_fields, **best_score}}, + { + "$project": { + **{doc_field: 1 for doc_field in doc_fields}, + **{score: {"$ifNull": [f"${score}", 0]} for score in scores_fields}, + } + }, + { + "$addFields": { + "score": {"$add": [f"${score}" for score in scores_fields]}, + } + }, + {"$sort": {"score": -1}}, + {"$limit": limit}, + ] + return final_pipeline + + @staticmethod + def _score_field(search_field: str, search_field_counts: Dict[str, int]): + score_field = f"{search_field}_score" + count = search_field_counts[search_field] + if count > 
def _hybrid_search(
    self,
    vector_queries: Dict[str, Any],
    text_queries: List[Dict[str, Any]],
    filters: List[Dict[str, Any]],
    limit: int,
):
    """Assemble one aggregation pipeline combining vector and text searches.

    For every vector query and every text query a sub-pipeline is built that
    ends with the stages from `self._reciprocal_rank_stage`, each writing its
    own score field. The sub-pipelines are combined with
    `self._add_stage_to_pipeline`, and the stages from `self._final_stage`
    are appended at the end.

    :param vector_queries: mapping of search field to query vector.
    :param text_queries: one kwargs dict per text search; each must contain
        `search_field` and is passed as keyword arguments to `self._text_search_stage`.
    :param filters: list of filter documents; handed to the vector search stage
        and applied as `$and` in a `$match` stage after each text search.
    :param limit: number of results requested from each search and from the final stage.
    :return: the combined aggregation pipeline (a list of stages).
    """
    hybrid_pipeline: List[Dict[str, Any]] = []  # combined aggregate pipeline
    # number of searches seen so far per search field; keeps score field names unique
    search_field_counts: Dict[str, int] = collections.defaultdict(int)
    score_fields: List[str] = []  # names given to scores of each search stage

    for search_field, query in vector_queries.items():
        search_field_counts[search_field] += 1
        vector_stage = self._vector_search_stage(
            query=query,
            search_field=search_field,
            limit=limit,
            filters=filters,
        )
        score_field = self._score_field(search_field, search_field_counts)
        score_fields.append(score_field)
        vector_pipeline = [
            vector_stage,
            *self._reciprocal_rank_stage(search_field, score_field),
        ]
        self._add_stage_to_pipeline(hybrid_pipeline, vector_pipeline)

    for text_query in text_queries:
        search_field = text_query["search_field"]
        search_field_counts[search_field] += 1
        text_stage = self._text_search_stage(**text_query)
        score_field = self._score_field(search_field, search_field_counts)
        score_fields.append(score_field)
        text_pipeline = [
            text_stage,
            # `$and` must not be given an empty list, hence the match-all fallback.
            {"$match": {"$and": filters} if filters else {}},
            {"$limit": limit},
            *self._reciprocal_rank_stage(search_field, score_field),
        ]
        self._add_stage_to_pipeline(hybrid_pipeline, text_pipeline)

    hybrid_pipeline += self._final_stage(score_fields, limit)
    return hybrid_pipeline
query.astype(np.float64).tolist() return { '$vectorSearch': { - 'index': index_name, + 'index': search_index_name, 'path': search_field, 'queryVector': query, 'numCandidates': min(limit * oversampling_factor, max_candidates), @@ -289,13 +574,7 @@ def _vector_stage_search( } } - def _filter_query( - self, - query: Any, - ) -> Dict[str, Any]: - return query - - def _text_stage_step( + def _text_search_stage( self, query: str, search_field: str, @@ -316,7 +595,7 @@ def _doc_exists(self, doc_id: str) -> bool: :param doc_id: The id of a document to check. :return: True if the document exists in the index, False otherwise. """ - doc = self._doc_collection.find_one({"_id": doc_id}) + doc = self._collection.find_one({"_id": doc_id}) return bool(doc) def _find( @@ -330,12 +609,12 @@ def _find( :param query: query vector for KNN/ANN search. Has single axis. :param limit: maximum number of documents to return per query :param search_field: name of the field to search on - :return: a named NamedTuple containing `documents` and `scores` + :return: a named tuple containing `documents` and `scores` """ # NOTE: in standard implementations, # `search_field` is equal to the column name to search on - vector_search_stage = self._vector_stage_search(query, search_field, limit) + vector_search_stage = self._vector_search_stage(query, search_field, limit) pipeline = [ vector_search_stage, @@ -346,7 +625,7 @@ def _find( }, ] - with self._doc_collection.aggregate(pipeline) as cursor: + with self._collection.aggregate(pipeline) as cursor: documents, scores = self._mongo_to_docs(cursor) return _FindResult(documents=documents, scores=scores) @@ -360,7 +639,7 @@ def _find_batched( Has shape (batch_size, vector_dim) :param limit: maximum number of documents to return :param search_field: name of the field to search on - :return: a named NamedTuple containing `documents` and `scores` + :return: a named tuple containing `documents` and `scores` """ docs, scores = [], [] for query in queries: @@ 
-433,7 +712,7 @@ def _filter( :param limit: maximum number of documents to return :return: a DocList containing the documents that match the filter query """ - with self._doc_collection.find(filter_query, limit=limit) as cursor: + with self._collection.find(filter_query, limit=limit) as cursor: return self._mongo_to_docs(cursor)[0] def _filter_batched( @@ -462,9 +741,9 @@ def _text_search( :param query: The text to search for :param limit: maximum number of documents to return :param search_field: name of the field to search on - :return: a named Tuple containing `documents` and `scores` + :return: a named tuple containing `documents` and `scores` """ - text_stage = self._text_stage_step(query=query, search_field=search_field) + text_stage = self._text_search_stage(query=query, search_field=search_field) pipeline = [ text_stage, @@ -476,7 +755,7 @@ def _text_search( {"$limit": limit}, ] - with self._doc_collection.aggregate(pipeline) as cursor: + with self._collection.aggregate(pipeline) as cursor: documents, scores = self._mongo_to_docs(cursor) return _FindResult(documents=documents, scores=scores) @@ -492,7 +771,7 @@ def _text_search_batched( :param queries: The texts to search for :param limit: maximum number of documents to return per query :param search_field: name of the field to search on - :return: a named Tuple containing `documents` and `scores` + :return: a named tuple containing `documents` and `scores` """ # NOTE: in standard implementations, # `search_field` is equal to the column name to search on @@ -511,7 +790,5 @@ def _filter_by_parent_id(self, id: str) -> Optional[List[str]]: :param id: the root document id to filter by :return: a list of ids of the subindex documents """ - with self._doc_collection.find( - {"parent_id": id}, projection={"_id": 1} - ) as cursor: + with self._collection.find({"parent_id": id}, projection={"_id": 1}) as cursor: return [doc["_id"] for doc in cursor] diff --git a/pyproject.toml b/pyproject.toml index 
26d1a04766..c908917161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,5 +165,6 @@ markers = [ "index: marks test using a document index", "benchmark: marks slow benchmarking tests", "elasticv8: marks test that run with ElasticSearch v8", - "jac: need to have access to jac cloud" + "jac: need to have access to jac cloud", + "atlas: mark tests using MongoDB Atlas", ] diff --git a/tests/index/mongo_atlas/__init__.py b/tests/index/mongo_atlas/__init__.py index 352060a305..360ba6ee1c 100644 --- a/tests/index/mongo_atlas/__init__.py +++ b/tests/index/mongo_atlas/__init__.py @@ -26,8 +26,7 @@ class NestedDoc(BaseDoc): class FlatSchema(BaseDoc): embedding1: NdArray = Field(dim=N_DIM, index_name="vector_index_1") - # the dim and N_DIM are setted different on propouse. to check the correct handling of n_dim - embedding2: NdArray[50] = Field(dim=N_DIM, index_name="vector_index_2") + embedding2: NdArray = Field(dim=N_DIM, index_name="vector_index_2") def assert_when_ready(callable: Callable, tries: int = 5, interval: float = 2): @@ -37,10 +36,10 @@ def assert_when_ready(callable: Callable, tries: int = 5, interval: float = 2): while True: try: callable() - except AssertionError: + except AssertionError as e: tries -= 1 if tries == 0: - raise + raise RuntimeError("Retries exhausted.") from e time.sleep(interval) else: return diff --git a/tests/index/mongo_atlas/conftest.py b/tests/index/mongo_atlas/conftest.py index 727fabb1f5..beb1276eed 100644 --- a/tests/index/mongo_atlas/conftest.py +++ b/tests/index/mongo_atlas/conftest.py @@ -1,3 +1,4 @@ +import logging import os import numpy as np @@ -19,7 +20,9 @@ def mongodb_index_config(): @pytest.fixture def simple_index(mongodb_index_config): - index = MongoDBAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) + index = MongoDBAtlasDocumentIndex[SimpleSchema]( + index_name="bespoke_name", **mongodb_index_config + ) return index @@ -30,8 +33,20 @@ def nested_index(mongodb_index_config): @pytest.fixture(scope='module') 
-def random_simple_documents(): - N_DIM = 10 +def n_dim(): + return 10 + + +@pytest.fixture(scope='module') +def embeddings(n_dim): + """A consistent, reasonable, mock of vector embeddings, in [-1, 1].""" + x = np.linspace(-np.pi, np.pi, n_dim) + y = np.arange(n_dim) + return np.sin(x[np.newaxis, :] + y[:, np.newaxis]) + + +@pytest.fixture(scope='module') +def random_simple_documents(n_dim, embeddings): docs_text = [ "Text processing with Python is a valuable skill for data analysis.", "Gardening tips for a beautiful backyard oasis.", @@ -45,37 +60,36 @@ def random_simple_documents(): "eleifend eros non, accumsan lectus. Curabitur porta auctor tellus at pharetra. Phasellus ut condimentum", ] return [ - SimpleSchema(embedding=np.random.rand(N_DIM), number=i, text=docs_text[i]) - for i in range(10) + SimpleSchema(embedding=embeddings[i], number=i, text=docs_text[i]) + for i in range(len(docs_text)) ] @pytest.fixture -def nested_documents(): - N_DIM = 10 +def nested_documents(n_dim): docs = [ NestedDoc( - d=SimpleDoc(embedding=np.random.rand(N_DIM)), - embedding=np.random.rand(N_DIM), + d=SimpleDoc(embedding=np.random.rand(n_dim)), + embedding=np.random.rand(n_dim), ) for _ in range(10) ] docs.append( NestedDoc( - d=SimpleDoc(embedding=np.zeros(N_DIM)), - embedding=np.ones(N_DIM), + d=SimpleDoc(embedding=np.zeros(n_dim)), + embedding=np.ones(n_dim), ) ) docs.append( NestedDoc( - d=SimpleDoc(embedding=np.ones(N_DIM)), - embedding=np.zeros(N_DIM), + d=SimpleDoc(embedding=np.ones(n_dim)), + embedding=np.zeros(n_dim), ) ) docs.append( NestedDoc( - d=SimpleDoc(embedding=np.zeros(N_DIM)), - embedding=np.ones(N_DIM), + d=SimpleDoc(embedding=np.zeros(n_dim)), + embedding=np.ones(n_dim), ) ) return docs @@ -86,10 +100,11 @@ def simple_index_with_docs(simple_index, random_simple_documents): """ Setup and teardown of simple_index. Accesses the underlying MongoDB collection directly. 
""" - simple_index._doc_collection.delete_many({}) + simple_index._collection.delete_many({}) + simple_index._logger.setLevel(logging.DEBUG) simple_index.index(random_simple_documents) yield simple_index, random_simple_documents - simple_index._doc_collection.delete_many({}) + simple_index._collection.delete_many({}) @pytest.fixture @@ -97,7 +112,7 @@ def nested_index_with_docs(nested_index, nested_documents): """ Setup and teardown of simple_index. Accesses the underlying MongoDB collection directly. """ - nested_index._doc_collection.delete_many({}) + nested_index._collection.delete_many({}) nested_index.index(nested_documents) yield nested_index, nested_documents - nested_index._doc_collection.delete_many({}) + nested_index._collection.delete_many({}) diff --git a/tests/index/mongo_atlas/test_find.py b/tests/index/mongo_atlas/test_find.py index aadfacb454..e9968b05dd 100644 --- a/tests/index/mongo_atlas/test_find.py +++ b/tests/index/mongo_atlas/test_find.py @@ -8,13 +8,11 @@ from . import NestedDoc, SimpleDoc, SimpleSchema, assert_when_ready -N_DIM = 10 - -def test_find_simple_schema(simple_index_with_docs): # noqa: F811 +def test_find_simple_schema(simple_index_with_docs, n_dim): # noqa: F811 simple_index, random_simple_documents = simple_index_with_docs # noqa: F811 - query = np.ones(N_DIM) + query = np.ones(n_dim) # Insert one doc that identically matches query's embedding expected_matching_document = SimpleSchema(embedding=query, text="other", number=10) @@ -29,8 +27,8 @@ def pred(): assert_when_ready(pred) -def test_find_empty_index(simple_index): # noqa: F811 - query = np.random.rand(N_DIM) +def test_find_empty_index(simple_index, n_dim): # noqa: F811 + query = np.random.rand(n_dim) def pred(): docs, scores = simple_index.find(query, search_field='embedding', limit=5) @@ -40,10 +38,10 @@ def pred(): assert_when_ready(pred) -def test_find_limit_larger_than_index(simple_index_with_docs): # noqa: F811 +def 
test_find_limit_larger_than_index(simple_index_with_docs, n_dim): # noqa: F811 simple_index, random_simple_documents = simple_index_with_docs # noqa: F811 - query = np.ones(N_DIM) + query = np.ones(n_dim) new_doc = SimpleSchema(embedding=query, text="other", number=10) simple_index.index(new_doc) @@ -56,29 +54,29 @@ def pred(): assert_when_ready(pred) -def test_find_flat_schema(mongodb_index_config): # noqa: F811 +def test_find_flat_schema(mongodb_index_config, n_dim): # noqa: F811 class FlatSchema(BaseDoc): - embedding1: NdArray = Field(dim=N_DIM, index_name="vector_index_1") - # the dim and N_DIM are setted different on propouse. to check the correct handling of n_dim - embedding2: NdArray[50] = Field(dim=N_DIM, index_name="vector_index_2") + embedding1: NdArray = Field(dim=n_dim, index_name="vector_index_1") + # the dim and n_dim are setted different on propouse. to check the correct handling of n_dim + embedding2: NdArray[50] = Field(dim=n_dim, index_name="vector_index_2") index = MongoDBAtlasDocumentIndex[FlatSchema](**mongodb_index_config) - index._doc_collection.delete_many({}) + index._collection.delete_many({}) index_docs = [ - FlatSchema(embedding1=np.random.rand(N_DIM), embedding2=np.random.rand(50)) + FlatSchema(embedding1=np.random.rand(n_dim), embedding2=np.random.rand(50)) for _ in range(10) ] - index_docs.append(FlatSchema(embedding1=np.zeros(N_DIM), embedding2=np.ones(50))) - index_docs.append(FlatSchema(embedding1=np.ones(N_DIM), embedding2=np.zeros(50))) + index_docs.append(FlatSchema(embedding1=np.zeros(n_dim), embedding2=np.ones(50))) + index_docs.append(FlatSchema(embedding1=np.ones(n_dim), embedding2=np.zeros(50))) index.index(index_docs) def pred1(): # find on embedding1 - query = np.ones(N_DIM) + query = np.ones(n_dim) docs, scores = index.find(query, search_field='embedding1', limit=5) assert len(docs) == 5 assert len(scores) == 5 @@ -116,10 +114,10 @@ def pred(): assert_when_ready(pred) -def 
test_find_nested_schema(nested_index_with_docs): # noqa: F811 +def test_find_nested_schema(nested_index_with_docs, n_dim): # noqa: F811 db, base_docs = nested_index_with_docs - query = NestedDoc(d=SimpleDoc(embedding=np.ones(N_DIM)), embedding=np.ones(N_DIM)) + query = NestedDoc(d=SimpleDoc(embedding=np.ones(n_dim)), embedding=np.ones(n_dim)) # find on root level def pred(): @@ -137,11 +135,11 @@ def pred(): assert_when_ready(pred) -def test_find_schema_without_index(mongodb_index_config): # noqa: F811 +def test_find_schema_without_index(mongodb_index_config, n_dim): # noqa: F811 class Schema(BaseDoc): - vec: NdArray = Field(dim=N_DIM) + vec: NdArray = Field(dim=n_dim) index = MongoDBAtlasDocumentIndex[Schema](**mongodb_index_config) - query = np.ones(N_DIM) + query = np.ones(n_dim) with pytest.raises(ValueError): index.find(query, search_field='vec', limit=2) diff --git a/tests/index/mongo_atlas/test_persist_data.py b/tests/index/mongo_atlas/test_persist_data.py index 62ff02348d..d170bfc22a 100644 --- a/tests/index/mongo_atlas/test_persist_data.py +++ b/tests/index/mongo_atlas/test_persist_data.py @@ -5,7 +5,7 @@ def test_persist(mongodb_index_config, random_simple_documents): # noqa: F811 index = MongoDBAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) - index._doc_collection.delete_many({}) + index._collection.delete_many({}) def cleaned_database(): assert index.num_docs() == 0 diff --git a/tests/index/mongo_atlas/test_query_builder.py b/tests/index/mongo_atlas/test_query_builder.py new file mode 100644 index 0000000000..3b103cec3d --- /dev/null +++ b/tests/index/mongo_atlas/test_query_builder.py @@ -0,0 +1,352 @@ +import numpy as np +import pytest + +from . 
import assert_when_ready + + +def test_missing_required_var_exceptions(simple_index): # noqa: F811 + """Ensure that exceptions are raised when required arguments are not provided.""" + + with pytest.raises(ValueError): + simple_index.build_query().find().build() + + with pytest.raises(ValueError): + simple_index.build_query().text_search().build() + + with pytest.raises(ValueError): + simple_index.build_query().filter().build() + + +def test_find_uses_provided_vector(simple_index): # noqa: F811 + query = ( + simple_index.build_query() + .find(query=np.ones(10), search_field='embedding') + .build(7) + ) + + query_vector = query.vector_fields.pop('embedding') + assert query.vector_fields == {} + assert np.allclose(query_vector, np.ones(10)) + assert query.filters == [] + assert query.limit == 7 + + +def test_multiple_find_returns_averaged_vector(simple_index, n_dim): # noqa: F811 + query = ( + simple_index.build_query() # type: ignore[attr-defined] + .find(query=np.ones(n_dim), search_field='embedding') + .find(query=np.zeros(n_dim), search_field='embedding') + .build(5) + ) + + assert len(query.vector_fields) == 1 + query_vector = query.vector_fields.pop('embedding') + assert query.vector_fields == {} + assert np.allclose(query_vector, np.array([0.5] * n_dim)) + assert query.filters == [] + assert query.limit == 5 + + +def test_filter_passes_filter(simple_index): # noqa: F811 + index = simple_index + + filter = {"number": {"$lt": 1}} + query = index.build_query().filter(query=filter).build(limit=11) # type: ignore[attr-defined] + + assert query.vector_fields == {} + assert query.filters == [{"query": filter}] + assert query.limit == 11 + + +def test_execute_query_find_filter(simple_index_with_docs, n_dim): # noqa: F811 + """Tests filters passed to vector search behave as expected""" + index, _ = simple_index_with_docs + + find_query = np.ones(n_dim) + filter_query1 = {"number": {"$lt": 8}} + filter_query2 = {"number": {"$gt": 5}} + + query = ( + index.build_query() 
# type: ignore[attr-defined] + .find(query=find_query, search_field='embedding') + .filter(query=filter_query1) + .filter(query=filter_query2) + .build(limit=5) + ) + + def trial(): + res = index.execute_query(query) + assert len(res.documents) == 2 + assert set(res.documents.number) == {6, 7} + + assert_when_ready(trial) + + +def test_execute_only_filter( + simple_index_with_docs, # noqa: F811 +): + index, _ = simple_index_with_docs + + filter_query1 = {"number": {"$lt": 8}} + filter_query2 = {"number": {"$gt": 5}} + + query = ( + index.build_query() # type: ignore[attr-defined] + .filter(query=filter_query1) + .filter(query=filter_query2) + .build(limit=5) + ) + + def trial(): + res = index.execute_query(query) + + assert len(res.documents) == 2 + assert set(res.documents.number) == {6, 7} + + assert_when_ready(trial) + + +def test_execute_text_search_with_filter( + simple_index_with_docs, # noqa: F811 +): + """Note: Text search returns only matching _, not limit.""" + index, _ = simple_index_with_docs + + filter_query1 = {"number": {"$eq": 0}} + + query = ( + index.build_query() # type: ignore[attr-defined] + .text_search(query="Python is a valuable skill", search_field='text') + .filter(query=filter_query1) + .build(limit=5) + ) + + def trial(): + res = index.execute_query(query) + + assert len(res.documents) == 1 + assert set(res.documents.number) == {0} + + assert_when_ready(trial) + + +def test_find( + simple_index_with_docs, + n_dim, # noqa: F811 +): + index, _ = simple_index_with_docs + limit = 3 + # Base Case: No filters, single text search, single vector search + query = ( + index.build_query() # type: ignore[attr-defined] + .find(query=np.ones(n_dim), search_field='embedding') + .build(limit=limit) + ) + + def trial(): + res = index.execute_query(query) + assert len(res.documents) == limit + assert res.documents.number == [5, 4, 6] + + assert_when_ready(trial) + + +def test_hybrid_search(simple_index_with_docs, n_dim): # noqa: F811 + find_query = 
np.ones(n_dim) + index, docs = simple_index_with_docs + n_docs = len(docs) + limit = n_docs + + # Base Case: No filters, single text search, single vector search + query = ( + index.build_query() # type: ignore[attr-defined] + .find(query=find_query, search_field='embedding') + .text_search(query="Python is a valuable skill", search_field='text') + .build(limit=limit) + ) + + def trial(): + res = index.execute_query(query) + assert len(res.documents) == limit + assert set(res.documents.number) == set(range(n_docs)) + + assert_when_ready(trial) + + # Now that we've successfully executed a query, we know that the search indexes have been built + # We no longer need to sleep and retry. Re-run to keep results + res_base = index.execute_query(query) + + # Case 2: Base plus a filter + filter_query1 = {"number": {"$gt": 0}} + + query = ( + index.build_query() # type: ignore[attr-defined] + .find(query=find_query, search_field='embedding') + .text_search(query="Python is a valuable skill", search_field='text') + .filter(query=filter_query1) + .build(limit=n_docs) + ) + + res = index.execute_query(query) + assert len(res.documents) == 9 + assert set(res.documents.number) == set(range(1, n_docs)) + + # Case 3: Base with, but matching, additional vector search component + # As we are using averaging to combine embedding vectors, this is a no-op + query = ( + index.build_query() # type: ignore[attr-defined] + .find(query=find_query, search_field='embedding') + .find(query=find_query, search_field='embedding') + .text_search(query="Python is a valuable skill", search_field='text') + .build(limit=n_docs) + ) + res3 = index.execute_query(query) + assert res3.documents.number == res_base.documents.number + + # Case 4: Base with, but perpendicular, additional vector search component + query = ( + index.build_query() # type: ignore[attr-defined] + # .find(query=find_query, search_field='embedding') + .find( + query=np.random.standard_normal(find_query.shape), 
search_field='embedding' + ) + .text_search(query="Python is a valuable skill", search_field='text') + .build(limit=n_docs) + ) + res4 = index.execute_query(query) + assert res4.documents.number != res_base.documents.number + + # Case 5: Multiple text searches + query = ( + index.build_query() # type: ignore[attr-defined] + .find(query=find_query, search_field='embedding') + .text_search(query="Python is a valuable skill", search_field='text') + .text_search(query="classical music compositions", search_field='text') + .build(limit=n_docs) + ) + res5 = index.execute_query(query) + assert res5.documents.number[:2] == [0, 3] + + # Case 6: Multiple text search with filters + query = ( + index.build_query() # type: ignore[attr-defined] + .find(query=find_query, search_field='embedding') + .filter(query={"number": {"$gt": 0}}) + .text_search(query="classical music compositions", search_field='text') + .text_search(query="Python is a valuable skill", search_field='text') + .build(limit=n_docs) + ) + res6 = index.execute_query(query) + assert res6.documents.number[0] == 3 + + +def test_hybrid_search_multiple_text(simple_index_with_docs, n_dim): # noqa: F811 + """Tests disambiguation of scores on multiple text searches on same field.""" + + index, _ = simple_index_with_docs + limit = 10 + query = ( + index.build_query() # type: ignore[attr-defined] + .text_search(query="classical music compositions", search_field='text') + .text_search(query="Python is a valuable skill", search_field='text') + .find(query=np.ones(n_dim), search_field='embedding') + .build(limit=limit) + ) + + def trial(): + res = index.execute_query(query, score_breakdown=True) + assert len(res.documents) == limit + assert res.documents.number == [0, 3, 5, 4, 6, 9, 7, 1, 2, 8] + + assert_when_ready(trial) + + +def test_hybrid_search_only_text(simple_index_with_docs): # noqa: F811 + """Query built with two text searches will be a Hybrid Search. + + It will return only two results. 
+ In our case, each text matches just one document, hence we will receive two results, each top ranked + """ + index, _ = simple_index_with_docs + limit = 10 + query = ( + index.build_query() # type: ignore[attr-defined] + .text_search(query="classical music compositions", search_field='text') + .text_search(query="Python is a valuable skill", search_field='text') + .build(limit=limit) + ) + + def trial(): + res = index.execute_query(query) + assert len(res.documents) != limit + # Instead, we find the number of documents containing one of these phrases + assert len(res.documents) == len(query.text_searches) + assert set(res.documents.number) == {0, 3} + assert set(res.scores) == {0.5, 0.5} + + assert_when_ready(trial) + + +def test_hybrid_search_only_vector(simple_index_with_docs, n_dim): # noqa: F811 + + limit = 3 + index, _ = simple_index_with_docs + query = ( + index.build_query() # type: ignore[attr-defined] + .find(query=np.ones(n_dim), search_field='embedding') + .find(query=np.zeros(n_dim), search_field='embedding') + .build(limit=limit) + ) + + def trial(): + res = index.execute_query(query) + assert len(res.documents) == limit + assert res.documents.number == [5, 4, 6] + + assert_when_ready(trial) + + +@pytest.mark.skip +def test_hybrid_search_vectors_with_different_fields( + mongodb_index_config, +): # noqa: F811 + """Hybrid Search involving queries to two different vector indexes. + + # TODO - To be added in an upcoming release. 
+ """ + + from docarray.index.backends.mongodb_atlas import MongoDBAtlasDocumentIndex + from tests.index.mongo_atlas import FlatSchema + + multi_index = MongoDBAtlasDocumentIndex[FlatSchema](**mongodb_index_config) + multi_index._collection.delete_many({}) + + n_dim = 25 + n_docs = 5 + data = [ + FlatSchema( + embedding1=np.random.standard_normal(n_dim), + embedding2=np.random.standard_normal(n_dim), + ) + for _ in range(n_docs) + ] + multi_index.index(data) + yield multi_index + multi_index._collection.delete_many({}) + + limit = 3 + query = ( + multi_index.build_query() # type: ignore[attr-defined] + .find(query=np.ones(n_dim), search_field='embedding1') + .find(query=np.zeros(n_dim), search_field='embedding2') + .build(limit=limit) + ) + + with pytest.raises(NotImplementedError): + + def trial(): + res = multi_index.execute_query(query) + assert len(res.documents) == limit + assert res.documents.number == [5, 4, 6] + + assert_when_ready(trial) diff --git a/tests/index/mongo_atlas/test_subindex.py b/tests/index/mongo_atlas/test_subindex.py index 82f8744221..71e99beca3 100644 --- a/tests/index/mongo_atlas/test_subindex.py +++ b/tests/index/mongo_atlas/test_subindex.py @@ -53,7 +53,7 @@ class MyDoc(BaseDoc): def clean_subindex(index): for subindex in index._subindices.values(): clean_subindex(subindex) - index._doc_collection.delete_many({}) + index._collection.delete_many({}) @pytest.fixture(scope='session') @@ -262,6 +262,4 @@ def test_subindex_del(index): def test_subindex_collections(mongodb_index_config): # noqa: F811 doc_index = MongoDBAtlasDocumentIndex[MetaCategoryDoc](**mongodb_index_config) - assert doc_index._subindices["paths"].index_name == 'metacategorydoc__paths' - assert doc_index._subindices["paths"]._collection == 'metacategorydoc__paths' diff --git a/tests/index/mongo_atlas/test_text_search.py b/tests/index/mongo_atlas/test_text_search.py index cbc6db8058..c480c218c7 100644 --- a/tests/index/mongo_atlas/test_text_search.py +++ 
b/tests/index/mongo_atlas/test_text_search.py @@ -9,7 +9,7 @@ def test_text_search(simple_index_with_docs): # noqa: F811 def pred(): docs, scores = simple_index.text_search( - query=query_string, search_field='text', limit=1 + query=query_string, search_field='text', limit=10 ) assert len(docs) == 1 assert docs[0].text == expected_text