Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions docarray/base_doc/mixins/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def _type_to_protobuf(value: Any) -> 'NodeProto':
data = {}

for key, content in value.items():
if not isinstance(key, str):
raise ValueError(
f'Protobuf only support string as key, but got {type(key)}'
)

data[key] = _type_to_protobuf(content)

struct = DictOfAnyProto(data=data)
Expand Down Expand Up @@ -242,7 +247,10 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocProto') -> T:

@classmethod
def _get_content_from_node_proto(
cls, value: 'NodeProto', field_name: Optional[str] = None
cls,
value: 'NodeProto',
field_name: Optional[str] = None,
field_type: Optional[Type] = None,
) -> Any:
"""
load the proto data from a node proto
Expand All @@ -251,6 +259,14 @@ def _get_content_from_node_proto(
:param field_name: the name of the field
:return: the loaded field
"""

if field_name is not None and field_type is not None:
raise ValueError("field_type and field_name cannot be both passed")

field_type = field_type or (
cls._get_field_type(field_name) if field_name else None
)

content_type_dict = _PROTO_TYPE_NAME_TO_CLASS

content_key = value.WhichOneof('content')
Expand All @@ -265,17 +281,17 @@ def _get_content_from_node_proto(
getattr(value, content_key)
)
elif content_key == 'doc':
if field_name is None:
if field_type is None:
raise ValueError(
'field_name cannot be None when trying to deseriliaze a BaseDoc'
'field_type cannot be None when trying to deserialize a BaseDoc'
)
return_field = cls._get_field_type(field_name).from_protobuf(
return_field = field_type.from_protobuf(
getattr(value, content_key)
) # we get to the parent class
elif content_key == 'doc_array':
if field_name is None:
raise ValueError(
'field_name cannot be None when trying to deseriliaze a BaseDoc'
'field_name cannot be None when trying to deserialize a BaseDoc'
)
return_field = cls._get_field_type_array(field_name).from_protobuf(
getattr(value, content_key)
Expand All @@ -293,15 +309,19 @@ def _get_content_from_node_proto(
return_field = getattr(value, content_key)

elif content_key in arg_to_container.keys():
field_type = cls.__fields__[field_name].type_ if field_name else None
return_field = arg_to_container[content_key](
cls._get_content_from_node_proto(node)
cls._get_content_from_node_proto(node, field_type=field_type)
for node in getattr(value, content_key).data
)

elif content_key == 'dict':
deser_dict: Dict[str, Any] = dict()
field_type = cls.__fields__[field_name].type_ if field_name else None
for key_name, node in value.dict.data.items():
deser_dict[key_name] = cls._get_content_from_node_proto(node)
deser_dict[key_name] = cls._get_content_from_node_proto(
node, field_type=field_type
)
return_field = deser_dict
else:
raise ValueError(
Expand Down
55 changes: 55 additions & 0 deletions tests/units/document/proto/test_document_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,17 @@ class MyDoc(BaseDoc):
MyDoc.from_protobuf(doc.to_protobuf())


@pytest.mark.proto
def test_nested_dict_error():
class MyDoc(BaseDoc):
data: Dict

doc = MyDoc(data={0: (1, 2)})

with pytest.raises(ValueError, match="Protobuf only support string as key"):
doc.to_protobuf()


@pytest.mark.proto
def test_tuple_complex():
class MyDoc(BaseDoc):
Expand Down Expand Up @@ -304,3 +315,47 @@ def test_any_doc_proto():
pt = doc.to_protobuf()
doc2 = AnyDoc.from_protobuf(pt)
assert doc2.dict()['hello'] == 'world'


@pytest.mark.proto
def test_nested_list():
from typing import List

from docarray import BaseDoc, DocList
from docarray.documents import TextDoc

class TextDocWithId(TextDoc):
id: str

class ResultTestDoc(BaseDoc):
matches: List[TextDocWithId]

da = DocList[ResultTestDoc](
[
ResultTestDoc(matches=[TextDocWithId(id=f'{i}') for _ in range(10)])
for i in range(10)
]
)

DocList[ResultTestDoc].from_protobuf(da.to_protobuf())


@pytest.mark.proto
def test_nested_dict_typed():
from docarray import BaseDoc, DocList
from docarray.documents import TextDoc

class TextDocWithId(TextDoc):
id: str

class ResultTestDoc(BaseDoc):
matches: Dict[str, TextDocWithId]

da = DocList[ResultTestDoc](
[
ResultTestDoc(matches={f'{i}': TextDocWithId(id=f'{i}') for _ in range(10)})
for i in range(10)
]
)

DocList[ResultTestDoc].from_protobuf(da.to_protobuf())