From fc34e9412e01c503642267eed1f7d78adca9cf4e Mon Sep 17 00:00:00 2001 From: RStar2022 Date: Wed, 26 Apr 2023 00:20:56 +0900 Subject: [PATCH 1/2] add len on DocIndex Signed-off-by: RStar2022 --- docarray/index/abstract.py | 3 +++ tests/index/base_classes/test_base_doc_store.py | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 17da44221d8..13f4837cd61 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -900,3 +900,6 @@ def _dict_list_to_docarray(self, dict_list: Sequence[Dict[str, Any]]) -> DocList doc_list = [self._convert_dict_to_doc(doc_dict, self._schema) for doc_dict in dict_list] # type: ignore docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema)) return docs_cls(doc_list) + + def __len__(self) -> int: + return self.num_docs() diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index e3c8c455bba..83b5dcd45d2 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -71,7 +71,7 @@ def python_type_to_db_type(self, x): return str _index = _identity - num_docs = _identity + num_docs = lambda n: 3 _del_items = _identity _get_items = _identity execute_query = _identity @@ -572,3 +572,9 @@ def test_validate_search_fields(): # 'ten' is not a valid field with pytest.raises(ValueError): index._validate_search_field('ten') + + +def test_len(): + store = DummyDocIndex[SimpleDoc]() + count = len(store) + assert count == 3 From a7b0df0e628f1bf96c6970743edbcc20d45a1ecb Mon Sep 17 00:00:00 2001 From: RStar2022 Date: Wed, 26 Apr 2023 22:40:28 +0900 Subject: [PATCH 2/2] add json-array Signed-off-by: RStar2022 --- docarray/array/doc_list/io.py | 15 ++++++++++++--- tests/units/array/test_array_save_load.py | 4 ++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index fdad272b94c..45258b2c69d 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -46,8 +46,8 @@ T = TypeVar('T', bound='IOMixinArray') T_doc = TypeVar('T_doc', bound=BaseDoc) -ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array'} -SINGLE_PROTOCOLS = {'pickle', 'protobuf'} +ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array', 'json-array'} +SINGLE_PROTOCOLS = {'pickle', 'protobuf', 'json'} ALLOWED_PROTOCOLS = ARRAY_PROTOCOLS.union(SINGLE_PROTOCOLS) ALLOWED_COMPRESSIONS = {'lz4', 'bz2', 'lzma', 'zlib', 'gzip'} @@ -180,6 +180,8 @@ def _write_bytes( f.write(self.to_protobuf().SerializePartialToString()) elif protocol == 'pickle-array': f.write(pickle.dumps(self)) + elif protocol == 'json-array': + f.write(self.to_json()) elif protocol in SINGLE_PROTOCOLS: f.write( b''.join( @@ -575,7 +577,11 @@ def _load_binary_all( else: d = fp.read() - if protocol is not None and protocol in ('pickle-array', 'protobuf-array'): + if protocol is not None and protocol in ( + 'pickle-array', + 'protobuf-array', + 'json-array', + ): if _get_compress_ctx(algorithm=compress) is not None: d = _decompress_bytes(d, algorithm=compress) compress = None @@ -590,6 +596,9 @@ def _load_binary_all( elif protocol is not None and protocol == 'pickle-array': return pickle.loads(d) + elif protocol is not None and protocol == 'json-array': + return cls.from_json(d) + # Binary format for streaming case else: from rich import filesize diff --git a/tests/units/array/test_array_save_load.py b/tests/units/array/test_array_save_load.py index 1a632673d15..a56ad13064a 100644 --- a/tests/units/array/test_array_save_load.py +++ b/tests/units/array/test_array_save_load.py @@ -16,7 +16,7 @@ class MyDoc(BaseDoc): @pytest.mark.slow @pytest.mark.parametrize( - 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle'] + 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle', 'json-array'] ) @pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) @pytest.mark.parametrize('show_progress', [False, True]) @@ -52,7 +52,7 @@ def test_array_save_load_binary(protocol, compress, tmp_path, show_progress): @pytest.mark.slow @pytest.mark.parametrize( - 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle'] + 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle', 'json-array'] ) @pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) @pytest.mark.parametrize('show_progress', [False, True])