From ea97d59f46b46a9d3dc6ff3c166509311b246406 Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 11:34:58 +0100 Subject: [PATCH 01/11] refactor: create io mixin Signed-off-by: samsja --- docarray/base_document/document.py | 96 +------------------ docarray/base_document/mixins/__init__.py | 4 +- .../base_document/mixins/{proto.py => io.py} | 96 ++++++++++++++++++- 3 files changed, 99 insertions(+), 97 deletions(-) rename docarray/base_document/mixins/{proto.py => io.py} (67%) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index f04e0093254..7bdce5a8cd2 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -1,26 +1,21 @@ import os -from typing import Type, Optional, TypeVar +from typing import Type import orjson from pydantic import BaseModel, Field, parse_obj_as from rich.console import Console -import pickle -import base64 from docarray.base_document.base_node import BaseNode from docarray.base_document.io.json import orjson_dumps, orjson_dumps_and_decode -from docarray.utils.compress import _compress_bytes, _decompress_bytes -from docarray.base_document.mixins import ProtoMixin, UpdateMixin +from docarray.base_document.mixins import IOMixin, UpdateMixin from docarray.typing import ID _console: Console = Console() -T = TypeVar('T', bound='BaseDocument') - -class BaseDocument(BaseModel, ProtoMixin, UpdateMixin, BaseNode): +class BaseDocument(BaseModel, IOMixin, UpdateMixin, BaseNode): """ - The base class for Document + The base class for Documents """ id: ID = Field(default_factory=lambda: parse_obj_as(ID, os.urandom(16).hex())) @@ -33,7 +28,7 @@ class Config: validate_assignment = True @classmethod - def _get_field_type(cls, field: str) -> Type['BaseDocument']: + def _get_field_type(cls, field: str) -> Type: """ Accessing the nested python Class define in the schema. 
Could be useful for reconstruction of Document in serialization/deserilization @@ -61,87 +56,6 @@ def schema_summary(cls) -> None: DocumentSummary.schema_summary(cls) - def __bytes__(self) -> bytes: - return self.to_bytes() - - def to_bytes( - self, protocol: str = 'protobuf', compress: Optional[str] = None - ) -> bytes: - """Serialize itself into bytes. - - For more Pythonic code, please use ``bytes(...)``. - - :param protocol: protocol to use. It can be 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :return: the binary serialization in bytes - """ - import pickle - - if protocol == 'pickle': - bstr = pickle.dumps(self) - elif protocol == 'protobuf': - bstr = self.to_protobuf().SerializePartialToString() - else: - raise ValueError( - f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.' - ) - return _compress_bytes(bstr, algorithm=compress) - - @classmethod - def from_bytes( - cls: Type[T], - data: bytes, - protocol: str = 'protobuf', - compress: Optional[str] = None, - ) -> T: - """Build Document object from binary bytes - - :param data: binary bytes - :param protocol: protocol to use. It can be 'pickle' or 'protobuf' - :param compress: compress method to use - :return: a Document object - """ - bstr = _decompress_bytes(data, algorithm=compress) - if protocol == 'pickle': - return pickle.loads(bstr) - elif protocol == 'protobuf': - from docarray.proto import DocumentProto - - pb_msg = DocumentProto() - pb_msg.ParseFromString(bstr) - return cls.from_protobuf(pb_msg) - else: - raise ValueError( - f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.' - ) - - def to_base64( - self, protocol: str = 'protobuf', compress: Optional[str] = None - ) -> str: - """Serialize a Document object into as base64 string - - :param protocol: protocol to use. 
It can be 'pickle' or 'protobuf' - :param compress: compress method to use - :return: a base64 encoded string - """ - return base64.b64encode(self.to_bytes(protocol, compress)).decode('utf-8') - - @classmethod - def from_base64( - cls: Type[T], - data: str, - protocol: str = 'pickle', - compress: Optional[str] = None, - ) -> T: - """Build Document object from binary bytes - - :param data: a base64 encoded string - :param protocol: protocol to use. It can be 'pickle' or 'protobuf' - :param compress: compress method to use - :return: a Document object - """ - return cls.from_bytes(base64.b64decode(data), protocol, compress) - def _ipython_display_(self): """Displays the object in IPython as a summary""" self.summary() diff --git a/docarray/base_document/mixins/__init__.py b/docarray/base_document/mixins/__init__.py index e4fdf7a6e7e..53b3242874a 100644 --- a/docarray/base_document/mixins/__init__.py +++ b/docarray/base_document/mixins/__init__.py @@ -1,4 +1,4 @@ -from docarray.base_document.mixins.proto import ProtoMixin +from docarray.base_document.mixins.io import IOMixin from docarray.base_document.mixins.update import UpdateMixin -__all__ = ['ProtoMixin', 'UpdateMixin'] +__all__ = ['IOMixin', 'UpdateMixin'] diff --git a/docarray/base_document/mixins/proto.py b/docarray/base_document/mixins/io.py similarity index 67% rename from docarray/base_document/mixins/proto.py rename to docarray/base_document/mixins/io.py index 10bb9fe991d..05806b35680 100644 --- a/docarray/base_document/mixins/proto.py +++ b/docarray/base_document/mixins/io.py @@ -1,8 +1,11 @@ +import base64 +import pickle from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, TypeVar from docarray.base_document.base_node import BaseNode from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS +from docarray.utils.compress import _compress_bytes, _decompress_bytes if 
TYPE_CHECKING: from pydantic.fields import ModelField @@ -10,17 +13,102 @@ from docarray.proto import DocumentProto, NodeProto -T = TypeVar('T', bound='ProtoMixin') +T = TypeVar('T', bound='IOMixin') -class ProtoMixin(Iterable): +class IOMixin: + """ + IOMixin to define all the bytes/protobuf/json related part of BaseDocument + """ + __fields__: Dict[str, 'ModelField'] @classmethod @abstractmethod - def _get_field_type(cls, field: str) -> Type['ProtoMixin']: + def _get_field_type(cls, field: str) -> Type: ... + def __bytes__(self) -> bytes: + return self.to_bytes() + + def to_bytes( + self, protocol: str = 'protobuf', compress: Optional[str] = None + ) -> bytes: + """Serialize itself into bytes. + + For more Pythonic code, please use ``bytes(...)``. + + :param protocol: protocol to use. It can be 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :return: the binary serialization in bytes + """ + import pickle + + if protocol == 'pickle': + bstr = pickle.dumps(self) + elif protocol == 'protobuf': + bstr = self.to_protobuf().SerializePartialToString() + else: + raise ValueError( + f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.' + ) + return _compress_bytes(bstr, algorithm=compress) + + @classmethod + def from_bytes( + cls: Type[T], + data: bytes, + protocol: str = 'protobuf', + compress: Optional[str] = None, + ) -> T: + """Build Document object from binary bytes + + :param data: binary bytes + :param protocol: protocol to use. It can be 'pickle' or 'protobuf' + :param compress: compress method to use + :return: a Document object + """ + bstr = _decompress_bytes(data, algorithm=compress) + if protocol == 'pickle': + return pickle.loads(bstr) + elif protocol == 'protobuf': + from docarray.proto import DocumentProto + + pb_msg = DocumentProto() + pb_msg.ParseFromString(bstr) + return cls.from_protobuf(pb_msg) + else: + raise ValueError( + f'protocol={protocol} is not supported. 
Can be only `protobuf` or pickle protocols 0-5.' + ) + + def to_base64( + self, protocol: str = 'protobuf', compress: Optional[str] = None + ) -> str: + """Serialize a Document object into as base64 string + + :param protocol: protocol to use. It can be 'pickle' or 'protobuf' + :param compress: compress method to use + :return: a base64 encoded string + """ + return base64.b64encode(self.to_bytes(protocol, compress)).decode('utf-8') + + @classmethod + def from_base64( + cls: Type[T], + data: str, + protocol: str = 'pickle', + compress: Optional[str] = None, + ) -> T: + """Build Document object from binary bytes + + :param data: a base64 encoded string + :param protocol: protocol to use. It can be 'pickle' or 'protobuf' + :param compress: compress method to use + :return: a Document object + """ + return cls.from_bytes(base64.b64decode(data), protocol, compress) + @classmethod def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: """create a Document from a protobuf message From 3d2e6563f6b842408ac39725d9aa4b8cbba74c05 Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 11:47:35 +0100 Subject: [PATCH 02/11] fix: fix mypy Signed-off-by: samsja --- docarray/array/array.py | 34 ++++++++++++++++------------- docarray/base_document/mixins/io.py | 14 ++++++++++-- docarray/utils/find.py | 4 ++-- 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/docarray/array/array.py b/docarray/array/array.py index b8bb580777c..87cecab589c 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -1,40 +1,39 @@ +import base64 +import io +import json +import os +import pathlib +import pickle from contextlib import contextmanager, nullcontext from functools import wraps from typing import ( TYPE_CHECKING, Any, + BinaryIO, Callable, + ContextManager, + Generator, Generic, Iterable, List, Optional, Sequence, + Tuple, Type, TypeVar, Union, cast, overload, - BinaryIO, - ContextManager, - Tuple, - Generator, ) import numpy as np -import json -import io 
-import os -import pickle -import pathlib -import base64 - from typing_inspect import is_union_type from docarray.array.abstract_array import AnyDocumentArray from docarray.base_document import AnyDocument, BaseDocument from docarray.typing import NdArray -from docarray.utils.misc import is_torch_available from docarray.utils.compress import _decompress_bytes, _get_compress_ctx +from docarray.utils.misc import is_torch_available if TYPE_CHECKING: from pydantic import BaseConfig @@ -268,7 +267,10 @@ def __setitem__(self: T, key: IndexIterType, value: Union[T, BaseDocument]): return self._set_by_mask(key_norm_, value_) elif isinstance(head, int): key_norm__ = cast(Iterable[int], key_norm) - return self._set_by_indices(key_norm__, value) + value_ = cast(Sequence[BaseDocument], value) # this is no strictly true + # set_by_mask requires value_ to have getitem which + # _normalize_index_item() ensures + return self._set_by_indices(key_norm__, value_) else: raise TypeError(f'Invalid type {type(head)} for indexing') @@ -566,6 +568,7 @@ def _write_bytes( f.write(pickle.dumps(self)) elif protocol in SINGLE_PROTOCOLS: from rich import filesize + from docarray.utils.progress_bar import _get_progressbar pbar, t = _get_progressbar( @@ -741,6 +744,7 @@ def _load_binary_all( # Binary format for streaming case else: from rich import filesize + from docarray.utils.progress_bar import _get_progressbar # 1 byte (uint8) @@ -797,10 +801,10 @@ def _load_binary_stream( :return: a generator of `Document` objects """ - from docarray import BaseDocument + from rich import filesize + from docarray import BaseDocument from docarray.utils.progress_bar import _get_progressbar - from rich import filesize with file_ctx as f: version_numdocs_lendoc0 = f.read(9) diff --git a/docarray/base_document/mixins/io.py b/docarray/base_document/mixins/io.py index 05806b35680..25c21c3ffac 100644 --- a/docarray/base_document/mixins/io.py +++ b/docarray/base_document/mixins/io.py @@ -1,7 +1,17 @@ import base64 
import pickle from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Optional, + Tuple, + Type, + TypeVar, +) from docarray.base_document.base_node import BaseNode from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS @@ -16,7 +26,7 @@ T = TypeVar('T', bound='IOMixin') -class IOMixin: +class IOMixin(Iterable[Tuple[str, Any]]): """ IOMixin to define all the bytes/protobuf/json related part of BaseDocument """ diff --git a/docarray/utils/find.py b/docarray/utils/find.py index 28898752df4..c20f183abb8 100644 --- a/docarray/utils/find.py +++ b/docarray/utils/find.py @@ -1,4 +1,4 @@ -from typing import List, NamedTuple, Optional, Type, Union +from typing import List, NamedTuple, Optional, Type, Union, cast from typing_inspect import is_union_type @@ -284,4 +284,4 @@ def _da_attr_type(da: AnyDocumentArray, attr: str) -> Type[AnyTensor]: f'but {field_type.__class__.__name__}' ) - return field_type + return cast(Type[AnyTensor], field_type) From cb30f82bd9b57982cba1169a97a96925b3d05188 Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 11:50:59 +0100 Subject: [PATCH 03/11] refactor: move array and array stacked to subfolder Signed-off-by: samsja --- docarray/__init__.py | 2 +- docarray/array/__init__.py | 4 +-- docarray/array/array/__init__.py | 0 docarray/array/{ => array}/array.py | 8 +++--- docarray/array/stacked/__init__.py | 0 docarray/array/{ => stacked}/array_stacked.py | 2 +- docarray/utils/filter.py | 27 ++++++++++++------- docarray/utils/find.py | 4 +-- .../torch/data/test_torch_dataset.py | 6 ++--- tests/units/array/test_array_proto.py | 2 +- 10 files changed, 31 insertions(+), 24 deletions(-) create mode 100644 docarray/array/array/__init__.py rename docarray/array/{ => array}/array.py (99%) create mode 100644 docarray/array/stacked/__init__.py rename docarray/array/{ => stacked}/array_stacked.py (99%) 
diff --git a/docarray/__init__.py b/docarray/__init__.py index bfc0842a846..9482eae3ebf 100644 --- a/docarray/__init__.py +++ b/docarray/__init__.py @@ -1,6 +1,6 @@ __version__ = '2023.01.18.alpha' -from docarray.array.array import DocumentArray +from docarray.array.array.array import DocumentArray from docarray.base_document.document import BaseDocument __all__ = [ diff --git a/docarray/array/__init__.py b/docarray/array/__init__.py index 7099e10f238..1b88646ebf1 100644 --- a/docarray/array/__init__.py +++ b/docarray/array/__init__.py @@ -1,4 +1,4 @@ -from docarray.array.array import DocumentArray -from docarray.array.array_stacked import DocumentArrayStacked +from docarray.array.array.array import DocumentArray +from docarray.array.stacked.array_stacked import DocumentArrayStacked __all__ = ['DocumentArray', 'DocumentArrayStacked'] diff --git a/docarray/array/array/__init__.py b/docarray/array/array/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/array/array.py b/docarray/array/array/array.py similarity index 99% rename from docarray/array/array.py rename to docarray/array/array/array.py index 87cecab589c..5ce98624fcd 100644 --- a/docarray/array/array.py +++ b/docarray/array/array/array.py @@ -39,7 +39,7 @@ from pydantic import BaseConfig from pydantic.fields import ModelField - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked from docarray.proto import DocumentArrayProto from docarray.typing import TorchTensor from docarray.typing.tensor.abstract_tensor import AbstractTensor @@ -448,7 +448,7 @@ def stacked_mode(self): ... 
""" - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked try: da_stacked = DocumentArrayStacked.__class_getitem__(self.document_type)( @@ -465,7 +465,7 @@ def stack(self) -> 'DocumentArrayStacked': Convert the DocumentArray into a DocumentArrayStacked. `Self` cannot be used afterwards """ - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked return DocumentArrayStacked.__class_getitem__(self.document_type)(self) @@ -476,7 +476,7 @@ def validate( field: 'ModelField', config: 'BaseConfig', ): - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked if isinstance(value, (cls, DocumentArrayStacked)): return value diff --git a/docarray/array/stacked/__init__.py b/docarray/array/stacked/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/array/array_stacked.py b/docarray/array/stacked/array_stacked.py similarity index 99% rename from docarray/array/array_stacked.py rename to docarray/array/stacked/array_stacked.py index 95795b6e6ff..cf2bf2accab 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/stacked/array_stacked.py @@ -15,7 +15,7 @@ ) from docarray.array.abstract_array import AnyDocumentArray -from docarray.array.array import DocumentArray +from docarray.array.array.array import DocumentArray from docarray.base_document import AnyDocument, BaseDocument from docarray.typing import NdArray from docarray.typing.tensor.abstract_tensor import AbstractTensor diff --git a/docarray/utils/filter.py b/docarray/utils/filter.py index 6d666a4a96f..97233016200 100644 --- a/docarray/utils/filter.py +++ b/docarray/utils/filter.py @@ -1,10 +1,8 @@ import json - -from typing import Union, Dict, List - +from typing import Dict, List, Union from docarray.array.abstract_array import AnyDocumentArray 
-from docarray.array.array import DocumentArray +from docarray.array.array.array import DocumentArray def filter( @@ -31,12 +29,21 @@ class MyDocument(BaseDocument): docs = DocumentArray[MyDocument]( - [MyDocument(caption='A tiger in the jungle', - image=Image(url='tigerphoto.png'), price=100), - MyDocument(caption='A swimming turtle', - image=Image(url='turtlepic.png'), price=50), - MyDocument(caption='A couple birdwatching with binoculars', - image=Image(url='binocularsphoto.png'), price=30)] + [ + MyDocument( + caption='A tiger in the jungle', + image=Image(url='tigerphoto.png'), + price=100, + ), + MyDocument( + caption='A swimming turtle', image=Image(url='turtlepic.png'), price=50 + ), + MyDocument( + caption='A couple birdwatching with binoculars', + image=Image(url='binocularsphoto.png'), + price=30, + ), + ] ) query = { '$and': { diff --git a/docarray/utils/find.py b/docarray/utils/find.py index c20f183abb8..60025ca9c19 100644 --- a/docarray/utils/find.py +++ b/docarray/utils/find.py @@ -3,8 +3,8 @@ from typing_inspect import is_union_type from docarray.array.abstract_array import AnyDocumentArray -from docarray.array.array import DocumentArray -from docarray.array.array_stacked import DocumentArrayStacked +from docarray.array.array.array import DocumentArray +from docarray.array.stacked.array_stacked import DocumentArrayStacked from docarray.base_document import BaseDocument from docarray.typing import AnyTensor from docarray.typing.tensor.abstract_tensor import AbstractTensor diff --git a/tests/integrations/torch/data/test_torch_dataset.py b/tests/integrations/torch/data/test_torch_dataset.py index d79e57f716a..c9f6c54a8fe 100644 --- a/tests/integrations/torch/data/test_torch_dataset.py +++ b/tests/integrations/torch/data/test_torch_dataset.py @@ -56,7 +56,7 @@ def test_torch_dataset(captions_da: DocumentArray[PairTextImage]): dataset, batch_size=BATCH_SIZE, collate_fn=dataset.collate_fn, shuffle=True ) - from docarray.array.array_stacked import 
DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked batch_lens = [] for batch in loader: @@ -135,7 +135,7 @@ def test_torch_dl_multiprocessing(captions_da: DocumentArray[PairTextImage]): multiprocessing_context='fork', ) - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked batch_lens = [] for batch in loader: @@ -163,7 +163,7 @@ def test_torch_dl_pin_memory(captions_da: DocumentArray[PairTextImage]): multiprocessing_context='fork', ) - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked batch_lens = [] for batch in loader: diff --git a/tests/units/array/test_array_proto.py b/tests/units/array/test_array_proto.py index 062062d6f2d..dd8de4014e2 100644 --- a/tests/units/array/test_array_proto.py +++ b/tests/units/array/test_array_proto.py @@ -2,7 +2,7 @@ import pytest from docarray import BaseDocument, DocumentArray -from docarray.array.array_stacked import DocumentArrayStacked +from docarray.array.stacked.array_stacked import DocumentArrayStacked from docarray.documents import Image, Text from docarray.typing import NdArray From 27bbbd1b189e9978448813393e048e8df6ed0805 Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 11:56:40 +0100 Subject: [PATCH 04/11] refactor: move document array io to mixin Signed-off-by: samsja --- docarray/array/array/array.py | 504 +-------------------------------- docarray/array/array/io.py | 516 ++++++++++++++++++++++++++++++++++ 2 files changed, 519 insertions(+), 501 deletions(-) create mode 100644 docarray/array/array/io.py diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py index 5ce98624fcd..2b3c16b8fb9 100644 --- a/docarray/array/array/array.py +++ b/docarray/array/array/array.py @@ -1,24 +1,15 @@ -import base64 import io -import json -import os -import pathlib -import pickle -from contextlib 
import contextmanager, nullcontext +from contextlib import contextmanager from functools import wraps from typing import ( TYPE_CHECKING, Any, - BinaryIO, Callable, - ContextManager, - Generator, Generic, Iterable, List, Optional, Sequence, - Tuple, Type, TypeVar, Union, @@ -30,9 +21,9 @@ from typing_inspect import is_union_type from docarray.array.abstract_array import AnyDocumentArray +from docarray.array.array.io import IOMixinArray from docarray.base_document import AnyDocument, BaseDocument from docarray.typing import NdArray -from docarray.utils.compress import _decompress_bytes, _get_compress_ctx from docarray.utils.misc import is_torch_available if TYPE_CHECKING: @@ -40,7 +31,6 @@ from pydantic.fields import ModelField from docarray.array.stacked.array_stacked import DocumentArrayStacked - from docarray.proto import DocumentArrayProto from docarray.typing import TorchTensor from docarray.typing.tensor.abstract_tensor import AbstractTensor @@ -48,42 +38,6 @@ T_doc = TypeVar('T_doc', bound=BaseDocument) IndexIterType = Union[slice, Iterable[int], Iterable[bool], None] -ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array'} -SINGLE_PROTOCOLS = {'pickle', 'protobuf'} -ALLOWED_PROTOCOLS = ARRAY_PROTOCOLS.union(SINGLE_PROTOCOLS) -ALLOWED_COMPRESSIONS = {'lz4', 'bz2', 'lzma', 'zlib', 'gzip'} - - -def _protocol_and_compress_from_file_path( - file_path: Union[pathlib.Path, str], - default_protocol: Optional[str] = None, - default_compress: Optional[str] = None, -) -> Tuple[Optional[str], Optional[str]]: - """Extract protocol and compression algorithm from a string, use defaults if not found. - :param file_path: path of a file. - :param default_protocol: default serialization protocol used in case not found. - :param default_compress: default compression method used in case not found. 
- Examples: - >>> _protocol_and_compress_from_file_path('./docarray_fashion_mnist.protobuf.gzip') - ('protobuf', 'gzip') - >>> _protocol_and_compress_from_file_path('/Documents/docarray_fashion_mnist.protobuf') - ('protobuf', None) - >>> _protocol_and_compress_from_file_path('/Documents/docarray_fashion_mnist.gzip') - (None, gzip) - """ - - protocol = default_protocol - compress = default_compress - - file_extensions = [e.replace('.', '') for e in pathlib.Path(file_path).suffixes] - for extension in file_extensions: - if extension in ALLOWED_PROTOCOLS: - protocol = extension - elif extension in ALLOWED_COMPRESSIONS: - compress = extension - - return protocol, compress - def _delegate_meth_to_data(meth_name: str) -> Callable: """ @@ -113,21 +67,7 @@ def _is_np_int(item: Any) -> bool: return False # this is unreachable, but mypy wants it -class _LazyRequestReader: - def __init__(self, r): - self._data = r.iter_content(chunk_size=1024 * 1024) - self.content = b'' - - def __getitem__(self, item: slice): - while len(self.content) < item.stop: - try: - self.content += next(self._data) - except StopIteration: - return self.content[item.start : -1 : item.step] - return self.content[item] - - -class DocumentArray(AnyDocumentArray, Generic[T_doc]): +class DocumentArray(IOMixinArray, AnyDocumentArray, Generic[T_doc]): """ DocumentArray is a container of Documents. 
@@ -493,441 +433,3 @@ def traverse_flat( flattened = AnyDocumentArray._flatten_one_level(nodes) return flattened - - # Methods to load from/to different formats - - @classmethod - def from_protobuf(cls: Type[T], pb_msg: 'DocumentArrayProto') -> T: - """create a Document from a protobuf message - :param pb_msg: The protobuf message from where to construct the DocumentArray - """ - return cls( - cls.document_type.from_protobuf(doc_proto) for doc_proto in pb_msg.docs - ) - - def to_protobuf(self) -> 'DocumentArrayProto': - """Convert DocumentArray into a Protobuf message""" - from docarray.proto import DocumentArrayProto - - da_proto = DocumentArrayProto() - for doc in self: - da_proto.docs.append(doc.to_protobuf()) - - return da_proto - - @classmethod - def from_bytes( - cls: Type[T], - data: bytes, - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - show_progress: bool = False, - ) -> T: - """Deserialize bytes into a DocumentArray. - - :param data: Bytes from which to deserialize - :param protocol: protocol that was used to serialize - :param compress: compress algorithm that was used to serialize - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :return: the deserialized DocumentArray - """ - return cls._load_binary_all( - file_ctx=nullcontext(data), - protocol=protocol, - compress=compress, - show_progress=show_progress, - ) - - def _write_bytes( - self, - bf: BinaryIO, - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - show_progress: bool = False, - ) -> None: - if protocol in ARRAY_PROTOCOLS: - compress_ctx = _get_compress_ctx(compress) - else: - # delegate the compression to per-doc compression - compress_ctx = None - - fc: ContextManager - if compress_ctx is None: - # if compress do not support streaming then postpone the compress - # into the for-loop - f, fc = bf, nullcontext() - else: - f = compress_ctx(bf) - fc = f - compress = None - - with fc: - if protocol == 
'protobuf-array': - f.write(self.to_protobuf().SerializePartialToString()) - elif protocol == 'pickle-array': - f.write(pickle.dumps(self)) - elif protocol in SINGLE_PROTOCOLS: - from rich import filesize - - from docarray.utils.progress_bar import _get_progressbar - - pbar, t = _get_progressbar( - 'Serializing', disable=not show_progress, total=len(self) - ) - - f.write(self._stream_header) - - with pbar: - _total_size = 0 - pbar.start_task(t) - for doc in self: - doc_bytes = doc.to_bytes(protocol=protocol, compress=compress) - len_doc_as_bytes = len(doc_bytes).to_bytes( - 4, 'big', signed=False - ) - all_bytes = len_doc_as_bytes + doc_bytes - f.write(all_bytes) - _total_size += len(all_bytes) - pbar.update( - t, - advance=1, - total_size=str(filesize.decimal(_total_size)), - ) - else: - raise ValueError( - f'protocol={protocol} is not supported. Can be only {ALLOWED_PROTOCOLS}.' - ) - - def to_bytes( - self, - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - file_ctx: Optional[BinaryIO] = None, - show_progress: bool = False, - ) -> Optional[bytes]: - """Serialize itself into bytes. - - For more Pythonic code, please use ``bytes(...)``. - - :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :param file_ctx: File or filename or serialized bytes where the data is stored. 
- :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :return: the binary serialization in bytes or None if file_ctx is passed where to store - """ - - with (file_ctx or io.BytesIO()) as bf: - self._write_bytes( - bf=bf, - protocol=protocol, - compress=compress, - show_progress=show_progress, - ) - if isinstance(bf, io.BytesIO): - return bf.getvalue() - - return None - - @classmethod - def from_base64( - cls: Type[T], - data: str, - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - show_progress: bool = False, - ) -> T: - """Deserialize base64 strings into a DocumentArray. - - :param data: Base64 string to deserialize - :param protocol: protocol that was used to serialize - :param compress: compress algorithm that was used to serialize - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :return: the deserialized DocumentArray - """ - return cls._load_binary_all( - file_ctx=nullcontext(base64.b64decode(data)), - protocol=protocol, - compress=compress, - show_progress=show_progress, - ) - - def to_base64( - self, - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - show_progress: bool = False, - ) -> str: - """Serialize itself into base64 encoded string. - - :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :return: the binary serialization in bytes or None if file_ctx is passed where to store - """ - with io.BytesIO() as bf: - self._write_bytes( - bf=bf, - compress=compress, - protocol=protocol, - show_progress=show_progress, - ) - return base64.b64encode(bf.getvalue()).decode('utf-8') - - @classmethod - def from_json( - cls: Type[T], - file: Union[str, bytes, bytearray], - ) -> T: - """Deserialize JSON strings or bytes into a DocumentArray. 
- - :param file: JSON object from where to deserialize a DocumentArray - :return: the deserialized DocumentArray - """ - json_docs = json.loads(file) - return cls([cls.document_type.parse_raw(v) for v in json_docs]) - - def to_json(self) -> str: - """Convert the object into a JSON string. Can be loaded via :meth:`.from_json`. - :return: JSON serialization of DocumentArray - """ - return json.dumps([doc.json() for doc in self]) - - # Methods to load from/to files in different formats - @property - def _stream_header(self) -> bytes: - # Binary format for streaming case - - # V1 DocArray streaming serialization format - # | 1 byte | 8 bytes | 4 bytes | variable | 4 bytes | variable ... - - # 1 byte (uint8) - version_byte = b'\x01' - # 8 bytes (uint64) - num_docs_as_bytes = len(self).to_bytes(8, 'big', signed=False) - return version_byte + num_docs_as_bytes - - @classmethod - def _load_binary_all( - cls: Type[T], - file_ctx: Union[ContextManager[io.BufferedReader], ContextManager[bytes]], - protocol: Optional[str], - compress: Optional[str], - show_progress: bool, - ): - """Read a `DocumentArray` object from a binary file - :param protocol: protocol to use. 
It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :return: a `DocumentArray` - """ - with file_ctx as fp: - if isinstance(fp, bytes): - d = fp - else: - d = fp.read() - - if protocol is not None and protocol in ('pickle-array', 'protobuf-array'): - if _get_compress_ctx(algorithm=compress) is not None: - d = _decompress_bytes(d, algorithm=compress) - compress = None - - if protocol is not None and protocol == 'protobuf-array': - from docarray.proto import DocumentArrayProto - - dap = DocumentArrayProto() - dap.ParseFromString(d) - - return cls.from_protobuf(dap) - elif protocol is not None and protocol == 'pickle-array': - return pickle.loads(d) - - # Binary format for streaming case - else: - from rich import filesize - - from docarray.utils.progress_bar import _get_progressbar - - # 1 byte (uint8) - # 8 bytes (uint64) - num_docs = int.from_bytes(d[1:9], 'big', signed=False) - - pbar, t = _get_progressbar( - 'Deserializing', disable=not show_progress, total=num_docs - ) - - # this 9 is version + num_docs bytes used - start_pos = 9 - docs = [] - with pbar: - _total_size = 0 - pbar.start_task(t) - - for _ in range(num_docs): - # 4 bytes (uint32) - len_current_doc_in_bytes = int.from_bytes( - d[start_pos : start_pos + 4], 'big', signed=False - ) - start_doc_pos = start_pos + 4 - end_doc_pos = start_doc_pos + len_current_doc_in_bytes - start_pos = end_doc_pos - - # variable length bytes doc - load_protocol: str = protocol or 'protobuf' - doc = cls.document_type.from_bytes( - d[start_doc_pos:end_doc_pos], - protocol=load_protocol, - compress=compress, - ) - docs.append(doc) - _total_size += len_current_doc_in_bytes - pbar.update( - t, advance=1, total_size=str(filesize.decimal(_total_size)) - ) - return cls(docs) - - @classmethod - def _load_binary_stream( - cls: Type[T], - file_ctx: 
ContextManager[io.BufferedReader], - protocol: Optional[str] = None, - compress: Optional[str] = None, - show_progress: bool = False, - ) -> Generator['BaseDocument', None, None]: - """Yield `Document` objects from a binary file - - :param protocol: protocol to use. It can be 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :return: a generator of `Document` objects - """ - - from rich import filesize - - from docarray import BaseDocument - from docarray.utils.progress_bar import _get_progressbar - - with file_ctx as f: - version_numdocs_lendoc0 = f.read(9) - # 1 byte (uint8) - # 8 bytes (uint64) - num_docs = int.from_bytes(version_numdocs_lendoc0[1:9], 'big', signed=False) - - pbar, t = _get_progressbar( - 'Deserializing', disable=not show_progress, total=num_docs - ) - - with pbar: - _total_size = 0 - pbar.start_task(t) - for _ in range(num_docs): - # 4 bytes (uint32) - len_current_doc_in_bytes = int.from_bytes( - f.read(4), 'big', signed=False - ) - _total_size += len_current_doc_in_bytes - load_protocol: str = protocol or 'protobuf' - yield BaseDocument.from_bytes( - f.read(len_current_doc_in_bytes), - protocol=load_protocol, - compress=compress, - ) - pbar.update( - t, advance=1, total_size=str(filesize.decimal(_total_size)) - ) - - @classmethod - def load_binary( - cls: Type[T], - file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader], - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - show_progress: bool = False, - streaming: bool = False, - ) -> Union[T, Generator['BaseDocument', None, None]]: - """Load array elements from a compressed binary file. - - :param file: File or filename or serialized bytes where the data is stored. - :param protocol: protocol to use. 
It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :param streaming: if `True` returns a generator over `Document` objects. - In case protocol is pickle the `Documents` are streamed from disk to save memory usage - :return: a DocumentArray object - - .. note:: - If `file` is `str` it can specify `protocol` and `compress` as file extensions. - This functionality assumes `file=file_name.$protocol.$compress` where `$protocol` and `$compress` refer to a - string interpolation of the respective `protocol` and `compress` methods. - For example if `file=my_docarray.protobuf.lz4` then the binary data will be loaded assuming `protocol=protobuf` - and `compress=lz4`. - """ - load_protocol: Optional[str] = protocol - load_compress: Optional[str] = compress - file_ctx: Union[nullcontext, io.BufferedReader] - if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)): - file_ctx = nullcontext(file) - # by checking path existence we allow file to be of type Path, LocalPath, PurePath and str - elif isinstance(file, (str, pathlib.Path)) and os.path.exists(file): - load_protocol, load_compress = _protocol_and_compress_from_file_path( - file, protocol, compress - ) - file_ctx = open(file, 'rb') - else: - raise FileNotFoundError(f'cannot find file {file}') - if streaming: - return cls._load_binary_stream( - file_ctx, - protocol=load_protocol, - compress=load_compress, - show_progress=show_progress, - ) - else: - return cls._load_binary_all( - file_ctx, load_protocol, load_compress, show_progress - ) - - def save_binary( - self, - file: Union[str, pathlib.Path], - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - show_progress: bool = False, - ) -> None: - """Save DocumentArray into a binary file. - - It will use the protocol to pick how to save the DocumentArray. 
- If used 'picke-array` and `protobuf-array` the DocumentArray will be stored - and compressed at complete level using `pickle` or `protobuf`. - When using `protobuf` or `pickle` as protocol each Document in DocumentArray - will be stored individually and this would make it available for streaming. - - :param file: File or filename to which the data is saved. - :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - - .. note:: - If `file` is `str` it can specify `protocol` and `compress` as file extensions. - This functionality assumes `file=file_name.$protocol.$compress` where `$protocol` and `$compress` refer to a - string interpolation of the respective `protocol` and `compress` methods. - For example if `file=my_docarray.protobuf.lz4` then the binary data will be created using `protocol=protobuf` - and `compress=lz4`. 
- """ - if isinstance(file, io.BufferedWriter): - file_ctx = nullcontext(file) - else: - _protocol, _compress = _protocol_and_compress_from_file_path(file) - - if _protocol is not None: - protocol = _protocol - if _compress is not None: - compress = _compress - - file_ctx = open(file, 'wb') - - self.to_bytes( - protocol=protocol, - compress=compress, - file_ctx=file_ctx, - show_progress=show_progress, - ) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py new file mode 100644 index 00000000000..bb3e5aff991 --- /dev/null +++ b/docarray/array/array/io.py @@ -0,0 +1,516 @@ +import base64 +import io +import json +import os +import pathlib +import pickle +from contextlib import nullcontext +from typing import ( + TYPE_CHECKING, + BinaryIO, + ContextManager, + Generator, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +from docarray.base_document import BaseDocument +from docarray.utils.compress import _decompress_bytes, _get_compress_ctx + +if TYPE_CHECKING: + + from docarray.proto import DocumentArrayProto + +T = TypeVar('T', bound='IOMixinArray') + + +ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array'} +SINGLE_PROTOCOLS = {'pickle', 'protobuf'} +ALLOWED_PROTOCOLS = ARRAY_PROTOCOLS.union(SINGLE_PROTOCOLS) +ALLOWED_COMPRESSIONS = {'lz4', 'bz2', 'lzma', 'zlib', 'gzip'} + + +def _protocol_and_compress_from_file_path( + file_path: Union[pathlib.Path, str], + default_protocol: Optional[str] = None, + default_compress: Optional[str] = None, +) -> Tuple[Optional[str], Optional[str]]: + """Extract protocol and compression algorithm from a string, use defaults if not found. + :param file_path: path of a file. + :param default_protocol: default serialization protocol used in case not found. + :param default_compress: default compression method used in case not found. 
+ Examples: + >>> _protocol_and_compress_from_file_path('./docarray_fashion_mnist.protobuf.gzip') + ('protobuf', 'gzip') + >>> _protocol_and_compress_from_file_path('/Documents/docarray_fashion_mnist.protobuf') + ('protobuf', None) + >>> _protocol_and_compress_from_file_path('/Documents/docarray_fashion_mnist.gzip') + (None, gzip) + """ + + protocol = default_protocol + compress = default_compress + + file_extensions = [e.replace('.', '') for e in pathlib.Path(file_path).suffixes] + for extension in file_extensions: + if extension in ALLOWED_PROTOCOLS: + protocol = extension + elif extension in ALLOWED_COMPRESSIONS: + compress = extension + + return protocol, compress + + +class _LazyRequestReader: + def __init__(self, r): + self._data = r.iter_content(chunk_size=1024 * 1024) + self.content = b'' + + def __getitem__(self, item: slice): + while len(self.content) < item.stop: + try: + self.content += next(self._data) + except StopIteration: + return self.content[item.start : -1 : item.step] + return self.content[item] + + +class IOMixinArray: + @classmethod + def from_protobuf(cls: Type[T], pb_msg: 'DocumentArrayProto') -> T: + """create a Document from a protobuf message + :param pb_msg: The protobuf message from where to construct the DocumentArray + """ + return cls( + cls.document_type.from_protobuf(doc_proto) for doc_proto in pb_msg.docs + ) + + def to_protobuf(self) -> 'DocumentArrayProto': + """Convert DocumentArray into a Protobuf message""" + from docarray.proto import DocumentArrayProto + + da_proto = DocumentArrayProto() + for doc in self: + da_proto.docs.append(doc.to_protobuf()) + + return da_proto + + @classmethod + def from_bytes( + cls: Type[T], + data: bytes, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + show_progress: bool = False, + ) -> T: + """Deserialize bytes into a DocumentArray. 
+ + :param data: Bytes from which to deserialize + :param protocol: protocol that was used to serialize + :param compress: compress algorithm that was used to serialize + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: the deserialized DocumentArray + """ + return cls._load_binary_all( + file_ctx=nullcontext(data), + protocol=protocol, + compress=compress, + show_progress=show_progress, + ) + + def _write_bytes( + self, + bf: BinaryIO, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + show_progress: bool = False, + ) -> None: + if protocol in ARRAY_PROTOCOLS: + compress_ctx = _get_compress_ctx(compress) + else: + # delegate the compression to per-doc compression + compress_ctx = None + + fc: ContextManager + if compress_ctx is None: + # if compress do not support streaming then postpone the compress + # into the for-loop + f, fc = bf, nullcontext() + else: + f = compress_ctx(bf) + fc = f + compress = None + + with fc: + if protocol == 'protobuf-array': + f.write(self.to_protobuf().SerializePartialToString()) + elif protocol == 'pickle-array': + f.write(pickle.dumps(self)) + elif protocol in SINGLE_PROTOCOLS: + from rich import filesize + + from docarray.utils.progress_bar import _get_progressbar + + pbar, t = _get_progressbar( + 'Serializing', disable=not show_progress, total=len(self) + ) + + f.write(self._stream_header) + + with pbar: + _total_size = 0 + pbar.start_task(t) + for doc in self: + doc_bytes = doc.to_bytes(protocol=protocol, compress=compress) + len_doc_as_bytes = len(doc_bytes).to_bytes( + 4, 'big', signed=False + ) + all_bytes = len_doc_as_bytes + doc_bytes + f.write(all_bytes) + _total_size += len(all_bytes) + pbar.update( + t, + advance=1, + total_size=str(filesize.decimal(_total_size)), + ) + else: + raise ValueError( + f'protocol={protocol} is not supported. Can be only {ALLOWED_PROTOCOLS}.' 
+ ) + + def to_bytes( + self, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + file_ctx: Optional[BinaryIO] = None, + show_progress: bool = False, + ) -> Optional[bytes]: + """Serialize itself into bytes. + + For more Pythonic code, please use ``bytes(...)``. + + :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :param file_ctx: File or filename or serialized bytes where the data is stored. + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: the binary serialization in bytes or None if file_ctx is passed where to store + """ + + with (file_ctx or io.BytesIO()) as bf: + self._write_bytes( + bf=bf, + protocol=protocol, + compress=compress, + show_progress=show_progress, + ) + if isinstance(bf, io.BytesIO): + return bf.getvalue() + + return None + + @classmethod + def from_base64( + cls: Type[T], + data: str, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + show_progress: bool = False, + ) -> T: + """Deserialize base64 strings into a DocumentArray. + + :param data: Base64 string to deserialize + :param protocol: protocol that was used to serialize + :param compress: compress algorithm that was used to serialize + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: the deserialized DocumentArray + """ + return cls._load_binary_all( + file_ctx=nullcontext(base64.b64decode(data)), + protocol=protocol, + compress=compress, + show_progress=show_progress, + ) + + def to_base64( + self, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + show_progress: bool = False, + ) -> str: + """Serialize itself into base64 encoded string. + + :param protocol: protocol to use. 
It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: the binary serialization in bytes or None if file_ctx is passed where to store + """ + with io.BytesIO() as bf: + self._write_bytes( + bf=bf, + compress=compress, + protocol=protocol, + show_progress=show_progress, + ) + return base64.b64encode(bf.getvalue()).decode('utf-8') + + @classmethod + def from_json( + cls: Type[T], + file: Union[str, bytes, bytearray], + ) -> T: + """Deserialize JSON strings or bytes into a DocumentArray. + + :param file: JSON object from where to deserialize a DocumentArray + :return: the deserialized DocumentArray + """ + json_docs = json.loads(file) + return cls([cls.document_type.parse_raw(v) for v in json_docs]) + + def to_json(self) -> str: + """Convert the object into a JSON string. Can be loaded via :meth:`.from_json`. + :return: JSON serialization of DocumentArray + """ + return json.dumps([doc.json() for doc in self]) + + # Methods to load from/to files in different formats + @property + def _stream_header(self) -> bytes: + # Binary format for streaming case + + # V1 DocArray streaming serialization format + # | 1 byte | 8 bytes | 4 bytes | variable | 4 bytes | variable ... + + # 1 byte (uint8) + version_byte = b'\x01' + # 8 bytes (uint64) + num_docs_as_bytes = len(self).to_bytes(8, 'big', signed=False) + return version_byte + num_docs_as_bytes + + @classmethod + def _load_binary_all( + cls: Type[T], + file_ctx: Union[ContextManager[io.BufferedReader], ContextManager[bytes]], + protocol: Optional[str], + compress: Optional[str], + show_progress: bool, + ): + """Read a `DocumentArray` object from a binary file + :param protocol: protocol to use. 
It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: a `DocumentArray` + """ + with file_ctx as fp: + if isinstance(fp, bytes): + d = fp + else: + d = fp.read() + + if protocol is not None and protocol in ('pickle-array', 'protobuf-array'): + if _get_compress_ctx(algorithm=compress) is not None: + d = _decompress_bytes(d, algorithm=compress) + compress = None + + if protocol is not None and protocol == 'protobuf-array': + from docarray.proto import DocumentArrayProto + + dap = DocumentArrayProto() + dap.ParseFromString(d) + + return cls.from_protobuf(dap) + elif protocol is not None and protocol == 'pickle-array': + return pickle.loads(d) + + # Binary format for streaming case + else: + from rich import filesize + + from docarray.utils.progress_bar import _get_progressbar + + # 1 byte (uint8) + # 8 bytes (uint64) + num_docs = int.from_bytes(d[1:9], 'big', signed=False) + + pbar, t = _get_progressbar( + 'Deserializing', disable=not show_progress, total=num_docs + ) + + # this 9 is version + num_docs bytes used + start_pos = 9 + docs = [] + with pbar: + _total_size = 0 + pbar.start_task(t) + + for _ in range(num_docs): + # 4 bytes (uint32) + len_current_doc_in_bytes = int.from_bytes( + d[start_pos : start_pos + 4], 'big', signed=False + ) + start_doc_pos = start_pos + 4 + end_doc_pos = start_doc_pos + len_current_doc_in_bytes + start_pos = end_doc_pos + + # variable length bytes doc + load_protocol: str = protocol or 'protobuf' + doc = cls.document_type.from_bytes( + d[start_doc_pos:end_doc_pos], + protocol=load_protocol, + compress=compress, + ) + docs.append(doc) + _total_size += len_current_doc_in_bytes + pbar.update( + t, advance=1, total_size=str(filesize.decimal(_total_size)) + ) + return cls(docs) + + @classmethod + def _load_binary_stream( + cls: Type[T], + file_ctx: 
ContextManager[io.BufferedReader], + protocol: Optional[str] = None, + compress: Optional[str] = None, + show_progress: bool = False, + ) -> Generator['BaseDocument', None, None]: + """Yield `Document` objects from a binary file + + :param protocol: protocol to use. It can be 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: a generator of `Document` objects + """ + + from rich import filesize + + from docarray import BaseDocument + from docarray.utils.progress_bar import _get_progressbar + + with file_ctx as f: + version_numdocs_lendoc0 = f.read(9) + # 1 byte (uint8) + # 8 bytes (uint64) + num_docs = int.from_bytes(version_numdocs_lendoc0[1:9], 'big', signed=False) + + pbar, t = _get_progressbar( + 'Deserializing', disable=not show_progress, total=num_docs + ) + + with pbar: + _total_size = 0 + pbar.start_task(t) + for _ in range(num_docs): + # 4 bytes (uint32) + len_current_doc_in_bytes = int.from_bytes( + f.read(4), 'big', signed=False + ) + _total_size += len_current_doc_in_bytes + load_protocol: str = protocol or 'protobuf' + yield BaseDocument.from_bytes( + f.read(len_current_doc_in_bytes), + protocol=load_protocol, + compress=compress, + ) + pbar.update( + t, advance=1, total_size=str(filesize.decimal(_total_size)) + ) + + @classmethod + def load_binary( + cls: Type[T], + file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader], + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + show_progress: bool = False, + streaming: bool = False, + ) -> Union[T, Generator['BaseDocument', None, None]]: + """Load array elements from a compressed binary file. + + :param file: File or filename or serialized bytes where the data is stored. + :param protocol: protocol to use. 
It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :param streaming: if `True` returns a generator over `Document` objects. + In case protocol is pickle the `Documents` are streamed from disk to save memory usage + :return: a DocumentArray object + + .. note:: + If `file` is `str` it can specify `protocol` and `compress` as file extensions. + This functionality assumes `file=file_name.$protocol.$compress` where `$protocol` and `$compress` refer to a + string interpolation of the respective `protocol` and `compress` methods. + For example if `file=my_docarray.protobuf.lz4` then the binary data will be loaded assuming `protocol=protobuf` + and `compress=lz4`. + """ + load_protocol: Optional[str] = protocol + load_compress: Optional[str] = compress + file_ctx: Union[nullcontext, io.BufferedReader] + if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)): + file_ctx = nullcontext(file) + # by checking path existence we allow file to be of type Path, LocalPath, PurePath and str + elif isinstance(file, (str, pathlib.Path)) and os.path.exists(file): + load_protocol, load_compress = _protocol_and_compress_from_file_path( + file, protocol, compress + ) + file_ctx = open(file, 'rb') + else: + raise FileNotFoundError(f'cannot find file {file}') + if streaming: + return cls._load_binary_stream( + file_ctx, + protocol=load_protocol, + compress=load_compress, + show_progress=show_progress, + ) + else: + return cls._load_binary_all( + file_ctx, load_protocol, load_compress, show_progress + ) + + def save_binary( + self, + file: Union[str, pathlib.Path], + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + show_progress: bool = False, + ) -> None: + """Save DocumentArray into a binary file. + + It will use the protocol to pick how to save the DocumentArray. 
+ If `pickle-array` or `protobuf-array` is used, the DocumentArray will be stored + and compressed as a whole using `pickle` or `protobuf`. + When using `protobuf` or `pickle` as protocol each Document in DocumentArray + will be stored individually and this would make it available for streaming. + + :param file: File or filename to which the data is saved. + :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + + .. note:: + If `file` is `str` it can specify `protocol` and `compress` as file extensions. + This functionality assumes `file=file_name.$protocol.$compress` where `$protocol` and `$compress` refer to a + string interpolation of the respective `protocol` and `compress` methods. + For example if `file=my_docarray.protobuf.lz4` then the binary data will be created using `protocol=protobuf` + and `compress=lz4`. 
+ """ + if isinstance(file, io.BufferedWriter): + file_ctx = nullcontext(file) + else: + _protocol, _compress = _protocol_and_compress_from_file_path(file) + + if _protocol is not None: + protocol = _protocol + if _compress is not None: + compress = _compress + + file_ctx = open(file, 'wb') + + self.to_bytes( + protocol=protocol, + compress=compress, + file_ctx=file_ctx, + show_progress=show_progress, + ) From 5e324cf5156b28f616f9d8234823d0b0eacaa417 Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 13:26:36 +0100 Subject: [PATCH 05/11] refactor: create io mixin for docuemnt array Signed-off-by: samsja --- docarray/array/array/array.py | 8 ++++++++ docarray/array/array/io.py | 15 ++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py index 2b3c16b8fb9..f39bae21efa 100644 --- a/docarray/array/array/array.py +++ b/docarray/array/array/array.py @@ -31,6 +31,7 @@ from pydantic.fields import ModelField from docarray.array.stacked.array_stacked import DocumentArrayStacked + from docarray.proto import DocumentArrayProto from docarray.typing import TorchTensor from docarray.typing.tensor.abstract_tensor import AbstractTensor @@ -433,3 +434,10 @@ def traverse_flat( flattened = AnyDocumentArray._flatten_one_level(nodes) return flattened + + @classmethod + def from_protobuf(cls: Type[T], pb_msg: 'DocumentArrayProto') -> T: + """create a Document from a protobuf message + :param pb_msg: The protobuf message from where to construct the DocumentArray + """ + return super().from_protobuf(pb_msg) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index bb3e5aff991..490d30c584d 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -4,13 +4,16 @@ import os import pathlib import pickle +from abc import abstractmethod from contextlib import nullcontext from typing import ( TYPE_CHECKING, BinaryIO, ContextManager, Generator, + Iterable, Optional, + 
Sized, Tuple, Type, TypeVar, @@ -78,7 +81,17 @@ def __getitem__(self, item: slice): return self.content[item] -class IOMixinArray: +class IOMixinArray(Iterable[BaseDocument], Sized): + + document_type: Type[BaseDocument] + + @abstractmethod + def __init__( + self, + docs: Optional[Iterable[BaseDocument]] = None, + ): + ... + @classmethod def from_protobuf(cls: Type[T], pb_msg: 'DocumentArrayProto') -> T: """create a Document from a protobuf message From c3482b8538eada9c8b9c152067d792e5bba1a37e Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 14:30:08 +0100 Subject: [PATCH 06/11] fix: fix sized problem Signed-off-by: samsja --- docarray/array/array/io.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 490d30c584d..b5a846102d2 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -13,7 +13,6 @@ Generator, Iterable, Optional, - Sized, Tuple, Type, TypeVar, @@ -81,10 +80,14 @@ def __getitem__(self, item: slice): return self.content[item] -class IOMixinArray(Iterable[BaseDocument], Sized): +class IOMixinArray(Iterable[BaseDocument]): document_type: Type[BaseDocument] + @abstractmethod + def __len__(self): + ... 
+ @abstractmethod def __init__( self, From 990f4ff340cece4c8fe3ec4c9dd49f10953976f6 Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 15:00:55 +0100 Subject: [PATCH 07/11] fix: make any array generic for type var --- docarray/array/abstract_array.py | 2 ++ tests/units/array/test_array.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index b0a4b6db70c..828c5e1085b 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -35,6 +35,8 @@ def __repr__(self): @classmethod def __class_getitem__(cls, item: Type[BaseDocument]): + if isinstance(item, TypeVar): + return super().__class_getitem__(item) if not issubclass(item, BaseDocument): raise ValueError( f'{cls.__name__}[item] item should be a Document not a {item} ' diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index d69514caab7..ec15f21afbc 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, TypeVar, Union import numpy as np import pytest @@ -294,3 +294,13 @@ def test_del_item(da): 'hello 8', 'hello 9', ] + + +def test_generic_type_var(): + T = TypeVar('T', bound=BaseDocument) + + def f(a: DocumentArray[T]) -> DocumentArray[T]: + return a + + a = DocumentArray() + f(a) From 57b2101de4b1f2e4c1c87b64a243d2c442a75fce Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 15:10:49 +0100 Subject: [PATCH 08/11] fix : allow string --- docarray/array/abstract_array.py | 7 ++++--- tests/units/array/test_array.py | 4 ++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 828c5e1085b..e9ebfb43314 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -9,6 +9,7 @@ Type, TypeVar, Union, + cast, ) from docarray.base_document 
import BaseDocument @@ -34,8 +35,8 @@ def __repr__(self): return f'<{self.__class__.__name__} (length={len(self)})>' @classmethod - def __class_getitem__(cls, item: Type[BaseDocument]): - if isinstance(item, TypeVar): + def __class_getitem__(cls, item: Union[Type[BaseDocument], TypeVar, str]): + if not isinstance(item, type): return super().__class_getitem__(item) if not issubclass(item, BaseDocument): raise ValueError( @@ -50,7 +51,7 @@ def __class_getitem__(cls, item: Type[BaseDocument]): global _DocumentArrayTyped class _DocumentArrayTyped(cls): # type: ignore - document_type: Type[BaseDocument] = item + document_type: Type[BaseDocument] = cast(Type[BaseDocument], item) for field in _DocumentArrayTyped.document_type.__fields__.keys(): diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index ec15f21afbc..7c0e9b329f8 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -302,5 +302,9 @@ def test_generic_type_var(): def f(a: DocumentArray[T]) -> DocumentArray[T]: return a + def g(a: DocumentArray['BaseDocument']) -> DocumentArray['BaseDocument']: + return a + a = DocumentArray() f(a) + g(a) From 0bded51419a96a9960c4a74d8d3411c812dc6979 Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 16:12:22 +0100 Subject: [PATCH 09/11] feat: add generic to document array stacked --- docarray/array/array/array.py | 3 +-- docarray/array/stacked/array_stacked.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py index f39bae21efa..a42e0d2dca1 100644 --- a/docarray/array/array/array.py +++ b/docarray/array/array/array.py @@ -5,7 +5,6 @@ TYPE_CHECKING, Any, Callable, - Generic, Iterable, List, Optional, @@ -68,7 +67,7 @@ def _is_np_int(item: Any) -> bool: return False # this is unreachable, but mypy wants it -class DocumentArray(IOMixinArray, AnyDocumentArray, Generic[T_doc]): +class DocumentArray(IOMixinArray, 
AnyDocumentArray[T_doc]): """ DocumentArray is a container of Documents. diff --git a/docarray/array/stacked/array_stacked.py b/docarray/array/stacked/array_stacked.py index cf2bf2accab..50ef017113a 100644 --- a/docarray/array/stacked/array_stacked.py +++ b/docarray/array/stacked/array_stacked.py @@ -42,11 +42,12 @@ else: TensorFlowTensor = None # type: ignore +T_doc = TypeVar('T_doc', bound=BaseDocument) T = TypeVar('T', bound='DocumentArrayStacked') IndexIterType = Union[slice, Iterable[int], Iterable[bool], None] -class DocumentArrayStacked(AnyDocumentArray): +class DocumentArrayStacked(AnyDocumentArray[T_doc]): """ DocumentArrayStacked is a container of Documents appropriates to perform computation that require batches of data (ex: matrix multiplication, distance @@ -70,7 +71,7 @@ class DocumentArrayStacked(AnyDocumentArray): def __init__( self: T, - docs: Optional[Union[DocumentArray, Iterable[BaseDocument]]] = None, + docs: Optional[Union[DocumentArray, Iterable[T_doc]]] = None, tensor_type: Type['AbstractTensor'] = NdArray, ): self._doc_columns: Dict[str, 'DocumentArrayStacked'] = {} @@ -80,7 +81,7 @@ def __init__( self.from_iterable_document(docs) def from_iterable_document( - self: T, docs: Optional[Union[DocumentArray, Iterable[BaseDocument]]] + self: T, docs: Optional[Union[DocumentArray, Iterable[T_doc]]] ): self._docs = ( docs @@ -254,7 +255,7 @@ def _set_array_attribute( setattr(self._docs, field, values) @overload - def __getitem__(self: T, item: int) -> BaseDocument: + def __getitem__(self: T, item: int) -> T_doc: ... 
@overload @@ -276,9 +277,7 @@ def __getitem__(self, item): setattr(doc, field, self._tensor_columns[field][item]) return doc - def __setitem__( - self: T, key: Union[int, IndexIterType], value: Union[T, BaseDocument] - ): + def __setitem__(self: T, key: Union[int, IndexIterType], value: Union[T, T_doc]): # multiple docs case if isinstance(key, (slice, Iterable)): return self._set_data_and_columns(key, value) @@ -476,7 +475,7 @@ def unstacked_mode(self): @classmethod def validate( cls: Type[T], - value: Union[T, Iterable[BaseDocument]], + value: Union[T, Iterable[T_doc]], field: 'ModelField', config: 'BaseConfig', ) -> T: From 4a477f4256e38a19ec1a20f35f115e11db3194ee Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 16:55:40 +0100 Subject: [PATCH 10/11] fix: fix generic class getitem --- docarray/array/abstract_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index e9ebfb43314..c12bb8e7ee4 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -37,7 +37,7 @@ def __repr__(self): @classmethod def __class_getitem__(cls, item: Union[Type[BaseDocument], TypeVar, str]): if not isinstance(item, type): - return super().__class_getitem__(item) + return Generic.__class_getitem__.__func__(cls, item) # type: ignore if not issubclass(item, BaseDocument): raise ValueError( f'{cls.__name__}[item] item should be a Document not a {item} ' From d07ccd074cf93dec37f665dd1be1112103ce5488 Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 17:00:17 +0100 Subject: [PATCH 11/11] fix: ad comment --- docarray/array/abstract_array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index c12bb8e7ee4..832bc251517 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -38,6 +38,7 @@ def __repr__(self): def __class_getitem__(cls, item: 
Union[Type[BaseDocument], TypeVar, str]): if not isinstance(item, type): return Generic.__class_getitem__.__func__(cls, item) # type: ignore + # this does nothing other than check that item is a valid TypeVar or str if not issubclass(item, BaseDocument): raise ValueError( f'{cls.__name__}[item] item should be a Document not a {item} '