From 784fc40cc7e7443635e05d661060184afb50e03c Mon Sep 17 00:00:00 2001 From: Ben Shaver Date: Mon, 23 Oct 2023 15:36:23 -0400 Subject: [PATCH 1/6] feat(types): added ProtocolType to give more constrained type to `protocol` parameters Signed-off-by: Ben Shaver --- docarray/utils/_internal/misc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docarray/utils/_internal/misc.py b/docarray/utils/_internal/misc.py index 5665f922fe..aada4d22ba 100644 --- a/docarray/utils/_internal/misc.py +++ b/docarray/utils/_internal/misc.py @@ -2,7 +2,7 @@ import os import re import types -from typing import Any, Optional +from typing import Any, Optional, Literal import numpy as np @@ -52,6 +52,7 @@ 'pymilvus': '"docarray[milvus]"', } +ProtocolType = Literal['protobuf', 'pickle'] def import_library( package: str, raise_error: bool = True From 9dc0efff5eb35ddcd403bc204f2aee05c1e68727 Mon Sep 17 00:00:00 2001 From: Ben Shaver Date: Mon, 23 Oct 2023 16:01:43 -0400 Subject: [PATCH 2/6] feat(types): expanded set of allowed protocol types Followed the ALLOWED_PROTOCOLS constant at https://github.com/docarray/docarray/blob/main/docarray/array/doc_list/io.py#L54C1-L54C1 Signed-off-by: Ben Shaver --- docarray/utils/_internal/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/utils/_internal/misc.py b/docarray/utils/_internal/misc.py index aada4d22ba..864100bbd5 100644 --- a/docarray/utils/_internal/misc.py +++ b/docarray/utils/_internal/misc.py @@ -52,7 +52,7 @@ 'pymilvus': '"docarray[milvus]"', } -ProtocolType = Literal['protobuf', 'pickle'] +ProtocolType = Literal['protobuf', 'pickle', 'json', 'json-array', 'protobuf-array', 'pickle-array'] def import_library( package: str, raise_error: bool = True From 1977708a01a5843929b15a39f2efe00efd8822a8 Mon Sep 17 00:00:00 2001 From: Ben Shaver Date: Mon, 23 Oct 2023 16:02:17 -0400 Subject: [PATCH 3/6] refactor(types): changed instances of protocol parameter to use new ProtocolType Signed-off-by: Ben Shaver --- docarray/array/doc_list/io.py | 38 +++++++++++++++++----------------- docarray/array/doc_vec/io.py | 7 ++++--- docarray/base_doc/mixins/io.py | 8 +++---- docarray/store/helpers.py | 10 +++++---- 4 files changed, 33 insertions(+), 30 deletions(-) diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index 90b645cdad..aa7b3fc2d9 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -36,7 +36,7 @@ _dict_to_access_paths, ) from docarray.utils._internal.compress import _decompress_bytes, _get_compress_ctx -from docarray.utils._internal.misc import import_library +from docarray.utils._internal.misc import import_library, ProtocolType if TYPE_CHECKING: import pandas as pd @@ -57,9 +57,9 @@ def _protocol_and_compress_from_file_path( file_path: Union[pathlib.Path, str], - default_protocol: Optional[str] = None, + default_protocol: Optional[ProtocolType] = None, default_compress: Optional[str] = None, -) -> Tuple[Optional[str], Optional[str]]: +) -> Tuple[Optional[ProtocolType], Optional[str]]: """Extract protocol and compression algorithm from a string, use defaults if not found. :param file_path: path of a file. :param default_protocol: default serialization protocol used in case not found. @@ -79,7 +79,7 @@ def _protocol_and_compress_from_file_path( file_extensions = [e.replace('.', '') for e in pathlib.Path(file_path).suffixes] for extension in file_extensions: if extension in ALLOWED_PROTOCOLS: - protocol = extension + protocol = cast(ProtocolType, extension) elif extension in ALLOWED_COMPRESSIONS: compress = extension @@ -135,7 +135,7 @@ def to_protobuf(self) -> 'DocListProto': def from_bytes( cls: Type[T], data: bytes, - protocol: str = 'protobuf-array', + protocol: ProtocolType = 'protobuf-array', compress: Optional[str] = None, show_progress: bool = False, ) -> T: @@ -157,7 +157,7 @@ def from_bytes( def _write_bytes( self, bf: BinaryIO, - protocol: str = 'protobuf-array', + protocol: ProtocolType = 'protobuf-array', compress: Optional[str] = None, show_progress: bool = False, ) -> None: @@ -201,7 +201,7 @@ def _write_bytes( def _to_binary_stream( self, - protocol: str = 'protobuf', + protocol: ProtocolType = 'protobuf', compress: Optional[str] = None, show_progress: bool = False, ) -> Iterator[bytes]: @@ -241,7 +241,7 @@ def _to_binary_stream( def to_bytes( self, - protocol: str = 'protobuf-array', + protocol: ProtocolType = 'protobuf-array', compress: Optional[str] = None, file_ctx: Optional[BinaryIO] = None, show_progress: bool = False, @@ -273,7 +273,7 @@ def to_bytes( def from_base64( cls: Type[T], data: str, - protocol: str = 'protobuf-array', + protocol: ProtocolType = 'protobuf-array', compress: Optional[str] = None, show_progress: bool = False, ) -> T: @@ -294,7 +294,7 @@ def from_base64( def to_base64( self, - protocol: str = 'protobuf-array', + protocol: ProtocolType = 'protobuf-array', compress: Optional[str] = None, show_progress: bool = False, ) -> str: @@ -576,7 +576,7 @@ def _get_proto_class(cls: Type[T]): def _load_binary_all( cls: Type[T], file_ctx: Union[ContextManager[io.BufferedReader], ContextManager[bytes]], - protocol: Optional[str], + protocol: Optional[ProtocolType], compress: Optional[str], show_progress: bool, tensor_type: Optional[Type['AbstractTensor']] = None, @@ -659,7 +659,7 @@ def _load_binary_all( start_pos = end_doc_pos # variable length bytes doc - load_protocol: str = protocol or 'protobuf' + load_protocol: ProtocolType = protocol or cast(ProtocolType, 'protobuf') doc = cls.doc_type.from_bytes( d[start_doc_pos:end_doc_pos], protocol=load_protocol, @@ -680,7 +680,7 @@ def _load_binary_all( def _load_binary_stream( cls: Type[T], file_ctx: ContextManager[io.BufferedReader], - protocol: str = 'protobuf', + protocol: ProtocolType = 'protobuf', compress: Optional[str] = None, show_progress: bool = False, ) -> Generator['T_doc', None, None]: @@ -728,7 +728,7 @@ def _load_binary_stream( len_current_doc_in_bytes = int.from_bytes( f.read(4), 'big', signed=False ) - load_protocol: str = protocol + load_protocol: ProtocolType = protocol yield cls.doc_type.from_bytes( f.read(len_current_doc_in_bytes), protocol=load_protocol, @@ -743,10 +743,10 @@ def _load_binary_stream( @staticmethod def _get_file_context( file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader], - protocol: str, + protocol: ProtocolType, compress: Optional[str] = None, - ) -> Tuple[Union[nullcontext, io.BufferedReader], Optional[str], Optional[str]]: - load_protocol: Optional[str] = protocol + ) -> Tuple[Union[nullcontext, io.BufferedReader], Optional[ProtocolType], Optional[str]]: + load_protocol: Optional[ProtocolType] = protocol load_compress: Optional[str] = compress file_ctx: Union[nullcontext, io.BufferedReader] if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)): @@ -765,7 +765,7 @@ def _get_file_context( def load_binary( cls: Type[T], file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader], - protocol: str = 'protobuf-array', + protocol: ProtocolType = 'protobuf-array', compress: Optional[str] = None, show_progress: bool = False, streaming: bool = False, @@ -814,7 +814,7 @@ def load_binary( def save_binary( self, file: Union[str, pathlib.Path], - protocol: str = 'protobuf-array', + protocol: ProtocolType = 'protobuf-array', compress: Optional[str] = None, show_progress: bool = False, ) -> None: diff --git a/docarray/array/doc_vec/io.py b/docarray/array/doc_vec/io.py index 3cf7630586..ad41d76308 100644 --- a/docarray/array/doc_vec/io.py +++ b/docarray/array/doc_vec/io.py @@ -31,6 +31,7 @@ from docarray.typing import NdArray from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal.pydantic import is_pydantic_v2 +from docarray.utils._internal.misc import ProtocolType if TYPE_CHECKING: import csv @@ -351,7 +352,7 @@ def from_csv( def from_base64( cls: Type[T], data: str, - protocol: str = 'protobuf-array', + protocol: ProtocolType = 'protobuf-array', compress: Optional[str] = None, show_progress: bool = False, tensor_type: Type['AbstractTensor'] = NdArray, @@ -377,7 +378,7 @@ def from_base64( def from_bytes( cls: Type[T], data: bytes, - protocol: str = 'protobuf-array', + protocol: ProtocolType = 'protobuf-array', compress: Optional[str] = None, show_progress: bool = False, tensor_type: Type['AbstractTensor'] = NdArray, @@ -454,7 +455,7 @@ class Person(BaseDoc): def load_binary( cls: Type[T], file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader], - protocol: str = 'protobuf-array', + protocol: ProtocolType = 'protobuf-array', compress: Optional[str] = None, show_progress: bool = False, streaming: bool = False, diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index cc4a3470d7..1a703cce58 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -36,7 +36,7 @@ from docarray.proto import DocProto, NodeProto from docarray.typing import TensorFlowTensor, TorchTensor - + from docarray.utils._internal.misc import ProtocolType else: tf = import_library('tensorflow', raise_error=False) @@ -150,7 +150,7 @@ def __bytes__(self) -> bytes: return self.to_bytes() def to_bytes( - self, protocol: str = 'protobuf', compress: Optional[str] = None + self, protocol: ProtocolType = 'protobuf', compress: Optional[str] = None ) -> bytes: """Serialize itself into bytes. @@ -177,7 +177,7 @@ def to_bytes( def from_bytes( cls: Type[T], data: bytes, - protocol: str = 'protobuf', + protocol: ProtocolType = 'protobuf', compress: Optional[str] = None, ) -> T: """Build Document object from binary bytes @@ -203,7 +203,7 @@ def from_bytes( ) def to_base64( - self, protocol: str = 'protobuf', compress: Optional[str] = None + self, protocol: ProtocolType = 'protobuf', compress: Optional[str] = None ) -> str: """Serialize a Document object into as base64 string diff --git a/docarray/store/helpers.py b/docarray/store/helpers.py index e2c4cf99a5..2d81a3f22c 100644 --- a/docarray/store/helpers.py +++ b/docarray/store/helpers.py @@ -13,6 +13,8 @@ import requests + from docarray.utils._internal.misc import ProtocolType + CACHING_REQUEST_READER_CHUNK_SIZE = 2**20 @@ -112,12 +114,12 @@ def raise_req_error(resp: 'requests.Response') -> NoReturn: class Streamable(Protocol): """A protocol for streamable objects.""" - def to_bytes(self, protocol: str, compress: Optional[str]) -> bytes: + def to_bytes(self, protocol: ProtocolType, compress: Optional[str]) -> bytes: ... @classmethod def from_bytes( - cls: Type[T_Elem], bytes: bytes, protocol: str, compress: Optional[str] + cls: Type[T_Elem], bytes: bytes, protocol: ProtocolType, compress: Optional[str] ) -> 'T_Elem': ... @@ -133,7 +135,7 @@ def close(self): def _to_binary_stream( iterator: Iterator['Streamable'], total: Optional[int] = None, - protocol: str = 'protobuf', + protocol: ProtocolType = 'protobuf', compress: Optional[str] = None, show_progress: bool = False, ) -> Iterator[bytes]: @@ -170,7 +172,7 @@ def _from_binary_stream( cls: Type[T], stream: ReadableBytes, total: Optional[int] = None, - protocol: str = 'protobuf', + protocol: ProtocolType = 'protobuf', compress: Optional[str] = None, show_progress: bool = False, ) -> Iterator['T']: From 61eb05f41dbaac96b9d64f7d3bdd6f8c762d0157 Mon Sep 17 00:00:00 2001 From: Ben Shaver Date: Mon, 23 Oct 2023 16:09:51 -0400 Subject: [PATCH 4/6] chore(fmt): auto-formatting Signed-off-by: Ben Shaver --- docarray/array/doc_list/io.py | 9 ++++++--- docarray/array/doc_vec/io.py | 1 - docarray/base_doc/mixins/io.py | 2 -- docarray/utils/_internal/misc.py | 5 ++++- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index aa7b3fc2d9..82d00197e2 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -383,7 +383,6 @@ def _from_csv_file( file: Union[StringIO, TextIOWrapper], dialect: Union[str, csv.Dialect], ) -> 'T': - rows = csv.DictReader(file, dialect=dialect) doc_type = cls.doc_type @@ -659,7 +658,9 @@ def _load_binary_all( start_pos = end_doc_pos # variable length bytes doc - load_protocol: ProtocolType = protocol or cast(ProtocolType, 'protobuf') + load_protocol: ProtocolType = protocol or cast( + ProtocolType, 'protobuf' + ) doc = cls.doc_type.from_bytes( d[start_doc_pos:end_doc_pos], protocol=load_protocol, @@ -745,7 +746,9 @@ def _get_file_context( file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader], protocol: ProtocolType, compress: Optional[str] = None, - ) -> Tuple[Union[nullcontext, io.BufferedReader], Optional[ProtocolType], Optional[str]]: + ) -> Tuple[ + Union[nullcontext, io.BufferedReader], Optional[ProtocolType], Optional[str] + ]: load_protocol: Optional[ProtocolType] = protocol load_compress: Optional[str] = compress file_ctx: Union[nullcontext, io.BufferedReader] diff --git a/docarray/array/doc_vec/io.py b/docarray/array/doc_vec/io.py index ad41d76308..dd7213252f 100644 --- a/docarray/array/doc_vec/io.py +++ b/docarray/array/doc_vec/io.py @@ -135,7 +135,6 @@ def _from_json_col_dict( json_columns: Dict[str, Any], tensor_type: Type[AbstractTensor] = NdArray, ) -> T: - tensor_cols = json_columns['tensor_columns'] doc_cols = json_columns['doc_columns'] docs_vec_cols = json_columns['docs_vec_columns'] diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 1a703cce58..11ea08aa10 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -329,7 +329,6 @@ def _get_content_from_node_proto( return_field = getattr(value, content_key) elif content_key in arg_to_container.keys(): - if field_name and field_name in cls._docarray_fields(): field_type = cls._get_field_inner_type(field_name) else: @@ -347,7 +346,6 @@ def _get_content_from_node_proto( deser_dict: Dict[str, Any] = dict() if field_name and field_name in cls._docarray_fields(): - if is_pydantic_v2: dict_args = get_args( cls._docarray_fields()[field_name].annotation diff --git a/docarray/utils/_internal/misc.py b/docarray/utils/_internal/misc.py index 864100bbd5..bb1e4ffe1d 100644 --- a/docarray/utils/_internal/misc.py +++ b/docarray/utils/_internal/misc.py @@ -52,7 +52,10 @@ 'pymilvus': '"docarray[milvus]"', } -ProtocolType = Literal['protobuf', 'pickle', 'json', 'json-array', 'protobuf-array', 'pickle-array'] +ProtocolType = Literal[ + 'protobuf', 'pickle', 'json', 'json-array', 'protobuf-array', 'pickle-array' +] + def import_library( package: str, raise_error: bool = True From d60ebe62d57d6c32f83e9fc2ea488a544f6cf8e5 Mon Sep 17 00:00:00 2001 From: Ben Shaver Date: Mon, 23 Oct 2023 16:47:49 -0400 Subject: [PATCH 5/6] fix(types): moving import ProtocolType from under TYPE_CHECKING box Signed-off-by: Ben Shaver --- docarray/base_doc/mixins/io.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 11ea08aa10..0f371d21ab 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -26,7 +26,7 @@ from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS from docarray.utils._internal._typing import safe_issubclass from docarray.utils._internal.compress import _compress_bytes, _decompress_bytes -from docarray.utils._internal.misc import import_library +from docarray.utils._internal.misc import ProtocolType, import_library from docarray.utils._internal.pydantic import is_pydantic_v2 if TYPE_CHECKING: @@ -36,7 +36,6 @@ from docarray.proto import DocProto, NodeProto from docarray.typing import TensorFlowTensor, TorchTensor - from docarray.utils._internal.misc import ProtocolType else: tf = import_library('tensorflow', raise_error=False) From bbb89394ef2cd0679e5d59c7557b2cf8c8c0b8ac Mon Sep 17 00:00:00 2001 From: Ben Shaver Date: Mon, 23 Oct 2023 16:55:20 -0400 Subject: [PATCH 6/6] fix(types): moving import of ProtocolType from under TYPE_CHECKING block Signed-off-by: Ben Shaver --- docarray/store/helpers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docarray/store/helpers.py b/docarray/store/helpers.py index 2d81a3f22c..24f28ac8ff 100644 --- a/docarray/store/helpers.py +++ b/docarray/store/helpers.py @@ -6,6 +6,7 @@ from rich import filesize from typing_extensions import TYPE_CHECKING, Protocol +from docarray.utils._internal.misc import ProtocolType from docarray.utils._internal.progress_bar import _get_progressbar if TYPE_CHECKING: @@ -13,8 +14,6 @@ import requests - from docarray.utils._internal.misc import ProtocolType - CACHING_REQUEST_READER_CHUNK_SIZE = 2**20