From e3274502fbef67ecbfcb4213fda339320ed529e0 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Mon, 24 Jul 2023 08:34:10 +0200 Subject: [PATCH] feat: update for inmemory index Signed-off-by: jupyterjazz --- docarray/index/backends/in_memory.py | 51 ++++++++++++++----- tests/index/in_memory/test_index_get_del.py | 55 +++++++++++++++++++++ 2 files changed, 95 insertions(+), 11 deletions(-) create mode 100644 tests/index/in_memory/test_index_get_del.py diff --git a/docarray/index/backends/in_memory.py b/docarray/index/backends/in_memory.py index 313f353b1c..3cd879e30e 100644 --- a/docarray/index/backends/in_memory.py +++ b/docarray/index/backends/in_memory.py @@ -94,6 +94,7 @@ def __init__( )() self._embedding_map: Dict[str, Tuple[AnyTensor, Optional[List[int]]]] = {} + self._ids_to_positions: Dict[str, int] = {} def python_type_to_db_type(self, python_type: Type) -> Any: """Map python type to database type. @@ -163,7 +164,13 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs): """ # implementing the public option because conversion to column dict is not needed docs = self._validate_docs(docs) - self._docs.extend(docs) + ids_to_positions = self._get_ids_to_positions() + for doc in docs: + if doc.id in ids_to_positions: + self._docs[ids_to_positions[doc.id]] = doc + else: + self._docs.append(doc) + self._ids_to_positions[str(doc.id)] = len(self._ids_to_positions) # Add parent_id to all sub-index documents and store sub-index documents data_by_columns = self._get_col_value_dict(docs) @@ -216,6 +223,7 @@ def _del_items(self, doc_ids: Sequence[str]): indices.append(i) del self._docs[indices] + self._update_ids_to_positions() self._rebuild_embedding() def _ori_items(self, doc: BaseDoc) -> BaseDoc: @@ -259,15 +267,18 @@ def _get_items( """ out_docs = [] - for i, doc in enumerate(self._docs): - if doc.id in doc_ids: - if raw: - out_docs.append(doc) - else: - ori_doc = self._ori_items(doc) - schema_cls = cast(Type[BaseDoc], self.out_schema) - new_doc = schema_cls(**ori_doc.__dict__) - out_docs.append(new_doc) + ids_to_positions = self._get_ids_to_positions() + for doc_id in doc_ids: + if doc_id not in ids_to_positions: + continue + doc = self._docs[ids_to_positions[doc_id]] + if raw: + out_docs.append(doc) + else: + ori_doc = self._ori_items(doc) + schema_cls = cast(Type[BaseDoc], self.out_schema) + new_doc = schema_cls(**ori_doc.__dict__) + out_docs.append(new_doc) return out_docs @@ -461,7 +472,7 @@ def _text_search_batched( raise NotImplementedError(f'{type(self)} does not support text search.') def _doc_exists(self, doc_id: str) -> bool: - return any(doc.id == doc_id for doc in self._docs) + return doc_id in self._get_ids_to_positions() def persist(self, file: Optional[str] = None) -> None: """Persist InMemoryExactNNIndex into a binary file.""" @@ -500,3 +511,21 @@ def _get_root_doc_id(self, id: str, root: str, sub: str) -> str: id, fields[0], '__'.join(fields[1:]) ) return self._get_root_doc_id(cur_root_id, root, '') + + def _get_ids_to_positions(self) -> Dict[str, int]: + """ + Obtains a mapping between document IDs and their respective positions + within the DocList. If this mapping hasn't been initialized, it will be created. + + :return: A dictionary mapping each document ID to its corresponding position. + """ + if not self._ids_to_positions: + self._update_ids_to_positions() + return self._ids_to_positions + + def _update_ids_to_positions(self) -> None: + """ + Generates or updates the mapping between document IDs and their corresponding + positions within the DocList. + """ + self._ids_to_positions = {doc.id: pos for pos, doc in enumerate(self._docs)} diff --git a/tests/index/in_memory/test_index_get_del.py b/tests/index/in_memory/test_index_get_del.py new file mode 100644 index 0000000000..185579c154 --- /dev/null +++ b/tests/index/in_memory/test_index_get_del.py @@ -0,0 +1,55 @@ +import numpy as np + +from docarray import BaseDoc, DocList +from docarray.index import InMemoryExactNNIndex +from docarray.typing import NdArray + + +class SimpleDoc(BaseDoc): + embedding: NdArray[128] + text: str + + +def test_update_payload(): + docs = DocList[SimpleDoc]( + [SimpleDoc(embedding=np.random.rand(128), text=f'hey {i}') for i in range(100)] + ) + index = InMemoryExactNNIndex[SimpleDoc]() + index.index(docs) + + assert index.num_docs() == 100 + + for doc in docs: + doc.text += '_changed' + + index.index(docs) + assert index.num_docs() == 100 + + res = index.find(query=docs[0], search_field='embedding', limit=100) + assert len(res.documents) == 100 + for doc in res.documents: + assert '_changed' in doc.text + + +def test_update_embedding(): + docs = DocList[SimpleDoc]( + [SimpleDoc(embedding=np.random.rand(128), text=f'hey {i}') for i in range(100)] + ) + index = InMemoryExactNNIndex[SimpleDoc]() + index.index(docs) + assert index.num_docs() == 100 + + new_tensor = np.random.rand(128) + docs[0].embedding = new_tensor + + index.index(docs[0]) + assert index.num_docs() == 100 + + res = index.find(query=docs[0], search_field='embedding', limit=100) + assert len(res.documents) == 100 + found = False + for doc in res.documents: + if doc.id == docs[0].id: + found = True + assert (doc.embedding == new_tensor).all() + assert found