From 127758304ad8908e3d3bc24cf0dd22240834231c Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 27 Apr 2023 16:08:17 +0200 Subject: [PATCH 1/6] fix: fix error with List[Doc] nested proto Signed-off-by: samsja --- docarray/base_doc/mixins/io.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 9d19e4337bc..5882faa03ff 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -242,7 +242,10 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocProto') -> T: @classmethod def _get_content_from_node_proto( - cls, value: 'NodeProto', field_name: Optional[str] = None + cls, + value: 'NodeProto', + field_name: Optional[str] = None, + field_type: Optional[Type] = None, ) -> Any: """ load the proto data from a node proto @@ -251,6 +254,12 @@ def _get_content_from_node_proto( :param field_name: the name of the field :return: the loaded field """ + + if field_name is not None and field_type is not None: + raise ValueError("field_type and field_name cannot be both passed") + + field_type = field_type or cls._get_field_type(field_name) + content_type_dict = _PROTO_TYPE_NAME_TO_CLASS content_key = value.WhichOneof('content') @@ -265,11 +274,11 @@ def _get_content_from_node_proto( getattr(value, content_key) ) elif content_key == 'doc': - if field_name is None: + if field_type is None: raise ValueError( - 'field_name cannot be None when trying to deseriliaze a BaseDoc' + 'field_type cannot be None when trying to deserialize a BaseDoc' ) - return_field = cls._get_field_type(field_name).from_protobuf( + return_field = field_type.from_protobuf( getattr(value, content_key) ) # we get to the parent class elif content_key == 'doc_array': @@ -294,7 +303,9 @@ def _get_content_from_node_proto( elif content_key in arg_to_container.keys(): return_field = arg_to_container[content_key]( - cls._get_content_from_node_proto(node) + cls._get_content_from_node_proto( + node, field_type=cls.__fields__[field_name].type_ + ) for node in getattr(value, content_key).data ) From 68bc7d86b60af0db6053de1dea5e9dd2b952b761 Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 27 Apr 2023 16:13:23 +0200 Subject: [PATCH 2/6] fix: fix tests Signed-off-by: samsja --- docarray/base_doc/mixins/io.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 5882faa03ff..40abb6dca83 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -258,7 +258,9 @@ def _get_content_from_node_proto( if field_name is not None and field_type is not None: raise ValueError("field_type and field_name cannot be both passed") - field_type = field_type or cls._get_field_type(field_name) + field_type = ( + field_type or cls._get_field_type(field_name) if field_name else None + ) content_type_dict = _PROTO_TYPE_NAME_TO_CLASS @@ -302,10 +304,9 @@ def _get_content_from_node_proto( return_field = getattr(value, content_key) elif content_key in arg_to_container.keys(): + field_type = cls.__fields__[field_name].type_ if field_name else None return_field = arg_to_container[content_key]( - cls._get_content_from_node_proto( - node, field_type=cls.__fields__[field_name].type_ - ) + cls._get_content_from_node_proto(node, field_type=field_type) for node in getattr(value, content_key).data ) From 49987e12579020cf9a856dfc17886549196c2431 Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 27 Apr 2023 16:34:05 +0200 Subject: [PATCH 3/6] fix: add tests Signed-off-by: samsja --- docarray/base_doc/mixins/io.py | 6 ++--- .../document/proto/test_document_proto.py | 22 +++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 40abb6dca83..27b1e14ccb0 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -258,8 +258,8 @@ def _get_content_from_node_proto( if field_name is not None and field_type is not None: raise ValueError("field_type and field_name cannot be both passed") - field_type = ( - field_type or cls._get_field_type(field_name) if field_name else None + field_type = field_type or ( + cls._get_field_type(field_name) if field_name else None ) content_type_dict = _PROTO_TYPE_NAME_TO_CLASS @@ -286,7 +286,7 @@ def _get_content_from_node_proto( elif content_key == 'doc_array': if field_name is None: raise ValueError( - 'field_name cannot be None when trying to deseriliaze a BaseDoc' + 'field_name cannot be None when trying to deserialize a BaseDoc' ) return_field = cls._get_field_type_array(field_name).from_protobuf( getattr(value, content_key) diff --git a/tests/units/document/proto/test_document_proto.py b/tests/units/document/proto/test_document_proto.py index a95a0edec62..6d75baca2a7 100644 --- a/tests/units/document/proto/test_document_proto.py +++ b/tests/units/document/proto/test_document_proto.py @@ -304,3 +304,25 @@ def test_any_doc_proto(): pt = doc.to_protobuf() doc2 = AnyDoc.from_protobuf(pt) assert doc2.dict()['hello'] == 'world' + + +def test_nested_list(): + from typing import List + + from docarray import BaseDoc, DocList + from docarray.documents import TextDoc + + class TextDocWithId(TextDoc): + id: str + + class ResultTestDoc(BaseDoc): + matches: List[TextDocWithId] + + da = DocList[ResultTestDoc]( + [ + ResultTestDoc(matches=[TextDocWithId(id=f'{i}') for _ in range(10)]) + for i in range(10) + ] + ) + + DocList[ResultTestDoc].from_protobuf(da.to_protobuf()) From 425bafd69d984ca4babd3becfd75f4aae0b47095 Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 27 Apr 2023 16:55:52 +0200 Subject: [PATCH 4/6] fix: add marker test Signed-off-by: samsja --- tests/units/document/proto/test_document_proto.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/units/document/proto/test_document_proto.py b/tests/units/document/proto/test_document_proto.py index 6d75baca2a7..cb8944b9f88 100644 --- a/tests/units/document/proto/test_document_proto.py +++ b/tests/units/document/proto/test_document_proto.py @@ -306,6 +306,7 @@ def test_any_doc_proto(): assert doc2.dict()['hello'] == 'world' +@pytest.mark.proto def test_nested_list(): from typing import List From 8964012c65a3ca1891d7266b3e2cced856f79589 Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 27 Apr 2023 17:05:50 +0200 Subject: [PATCH 5/6] fix: fix dict nested Signed-off-by: samsja --- docarray/base_doc/mixins/io.py | 5 ++++- .../document/proto/test_document_proto.py | 21 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 27b1e14ccb0..3466763c5df 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -312,8 +312,11 @@ def _get_content_from_node_proto( elif content_key == 'dict': deser_dict: Dict[str, Any] = dict() + field_type = cls.__fields__[field_name].type_ if field_name else None for key_name, node in value.dict.data.items(): - deser_dict[key_name] = cls._get_content_from_node_proto(node) + deser_dict[key_name] = cls._get_content_from_node_proto( + node, field_type=field_type + ) return_field = deser_dict else: raise ValueError( diff --git a/tests/units/document/proto/test_document_proto.py b/tests/units/document/proto/test_document_proto.py index cb8944b9f88..7e14d9040e0 100644 --- a/tests/units/document/proto/test_document_proto.py +++ b/tests/units/document/proto/test_document_proto.py @@ -327,3 +327,24 @@ class ResultTestDoc(BaseDoc): ) DocList[ResultTestDoc].from_protobuf(da.to_protobuf()) + + +@pytest.mark.proto +def test_nested_dict_typed(): + from docarray import BaseDoc, DocList + from docarray.documents import TextDoc + + class TextDocWithId(TextDoc): + id: str + + class ResultTestDoc(BaseDoc): + matches: Dict[str, TextDocWithId] + + da = DocList[ResultTestDoc]( + [ + ResultTestDoc(matches={f'{i}': TextDocWithId(id=f'{i}') for _ in range(10)}) + for i in range(10) + ] + ) + + DocList[ResultTestDoc].from_protobuf(da.to_protobuf()) From 8b0b3b7f466241d8c9b60a26f63edab06a8e644f Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 27 Apr 2023 17:08:07 +0200 Subject: [PATCH 6/6] fix: add warning to proto with non str key Signed-off-by: samsja --- docarray/base_doc/mixins/io.py | 5 +++++ tests/units/document/proto/test_document_proto.py | 11 +++++++++++ 2 files changed, 16 insertions(+) diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 3466763c5df..4aa27f90900 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -101,6 +101,11 @@ def _type_to_protobuf(value: Any) -> 'NodeProto': data = {} for key, content in value.items(): + if not isinstance(key, str): + raise ValueError( + f'Protobuf only support string as key, but got {type(key)}' + ) + data[key] = _type_to_protobuf(content) struct = DictOfAnyProto(data=data) diff --git a/tests/units/document/proto/test_document_proto.py b/tests/units/document/proto/test_document_proto.py index 7e14d9040e0..80412b7c72a 100644 --- a/tests/units/document/proto/test_document_proto.py +++ b/tests/units/document/proto/test_document_proto.py @@ -220,6 +220,17 @@ class MyDoc(BaseDoc): MyDoc.from_protobuf(doc.to_protobuf()) +@pytest.mark.proto +def test_nested_dict_error(): + class MyDoc(BaseDoc): + data: Dict + + doc = MyDoc(data={0: (1, 2)}) + + with pytest.raises(ValueError, match="Protobuf only support string as key"): + doc.to_protobuf() + + @pytest.mark.proto def test_tuple_complex(): class MyDoc(BaseDoc):