diff --git a/docarray/array/documentarray.py b/docarray/array/documentarray.py index 76ac731b4f8..635c3778fa7 100644 --- a/docarray/array/documentarray.py +++ b/docarray/array/documentarray.py @@ -1,8 +1,7 @@ from typing import Iterable, Type -from docarray.document import AnyDocument, BaseDocument +from docarray.document import AnyDocument, BaseDocument, BaseNode from docarray.document.abstract_document import AbstractDocument -from docarray.typing import BaseNode from .abstract_array import AbstractDocumentArray from .mixins import ProtoArrayMixin diff --git a/docarray/document/__init__.py b/docarray/document/__init__.py index eb73dccb05f..7b233fca4a7 100644 --- a/docarray/document/__init__.py +++ b/docarray/document/__init__.py @@ -1,4 +1,5 @@ from docarray.document.any_document import AnyDocument +from docarray.document.base_node import BaseNode from docarray.document.document import BaseDocument -__all__ = ['AnyDocument', 'BaseDocument'] +__all__ = ['AnyDocument', 'BaseDocument', 'BaseNode'] diff --git a/docarray/document/abstract_document.py b/docarray/document/abstract_document.py index 8f962d1b027..5d56016ee18 100644 --- a/docarray/document/abstract_document.py +++ b/docarray/document/abstract_document.py @@ -1,7 +1,16 @@ -from typing import Dict, Iterable +from abc import abstractmethod +from typing import TYPE_CHECKING, Dict, Iterable, Type from pydantic.fields import ModelField +if TYPE_CHECKING: + from docarray.document.mixins.proto import ProtoMixin + class AbstractDocument(Iterable): __fields__: Dict[str, ModelField] + + @classmethod + @abstractmethod + def _get_nested_document_class(cls, field: str) -> Type['ProtoMixin']: + ... 
diff --git a/docarray/document/document.py b/docarray/document/document.py index 342d41ec10a..f0eb363a624 100644 --- a/docarray/document/document.py +++ b/docarray/document/document.py @@ -1,11 +1,11 @@ import os -from typing import Union -from uuid import UUID +from typing import Type from pydantic import BaseModel, Field from docarray.document.abstract_document import AbstractDocument from docarray.document.base_node import BaseNode +from docarray.typing import ID from .mixins import ProtoMixin @@ -15,4 +15,14 @@ class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): The base class for Document """ - id: Union[int, str, UUID] = Field(default_factory=lambda: os.urandom(16).hex()) + id: ID = Field(default_factory=lambda: ID.validate(os.urandom(16).hex())) + + @classmethod + def _get_nested_document_class(cls, field: str) -> Type['BaseDocument']: + """ + Access the nested Python class defined in the schema. Could be useful for + reconstruction of a Document in serialization/deserialization + :param field: name of the field + :return: the type declared for that field in the schema + """ + return cls.__fields__[field].type_ diff --git a/docarray/document/mixins/proto.py b/docarray/document/mixins/proto.py index 0f4f5fc7a64..dc584bd7cd0 100644 --- a/docarray/document/mixins/proto.py +++ b/docarray/document/mixins/proto.py @@ -1,23 +1,14 @@ -from typing import Any, Dict, Type +from typing import Any, Dict -from docarray.proto import DocumentProto, NodeProto -from docarray.typing import Tensor +from pydantic.tools import parse_obj_as -from ..abstract_document import AbstractDocument -from ..base_node import BaseNode +from docarray.document.abstract_document import AbstractDocument +from docarray.document.base_node import BaseNode +from docarray.proto import DocumentProto, NodeProto +from docarray.typing import ID, AnyUrl, Embedding, ImageUrl, Tensor class ProtoMixin(AbstractDocument, BaseNode): - @classmethod - def _get_nested_document_class(cls, field: str) -> Type['ProtoMixin']: - """ - Accessing the 
nested python Class define in the schema. Could be useful for - reconstruction of Document in serialization/deserilization - :param field: name of the field - :return: - """ - return cls.__fields__[field].type_ - @classmethod def from_protobuf(cls, pb_msg: 'DocumentProto') -> 'ProtoMixin': """create a Document from a protobuf message""" @@ -30,8 +21,18 @@ def from_protobuf(cls, pb_msg: 'DocumentProto') -> 'ProtoMixin': content_type = value.WhichOneof('content') + # this if else statement need to be refactored it is too long + # the check should be delegated to the type level if content_type == 'tensor': - fields[field] = Tensor.read_ndarray(value.tensor) + fields[field] = Tensor._read_from_proto(value.tensor) + elif content_type == 'embedding': + fields[field] = Embedding._read_from_proto(value.embedding) + elif content_type == 'any_url': + fields[field] = parse_obj_as(AnyUrl, value.any_url) + elif content_type == 'image_url': + fields[field] = parse_obj_as(ImageUrl, value.image_url) + elif content_type == 'id': + fields[field] = parse_obj_as(ID, value.id) elif content_type == 'text': fields[field] = value.text elif content_type == 'nested': diff --git a/docarray/predefined_document/text.py b/docarray/predefined_document/text.py index a45ba0ab7fe..3ef2cd61d12 100644 --- a/docarray/predefined_document/text.py +++ b/docarray/predefined_document/text.py @@ -1,7 +1,7 @@ from typing import Optional from docarray.document import BaseDocument -from docarray.typing.ndarray import Embedding, Tensor +from docarray.typing.embedding import Embedding, Tensor class Text(BaseDocument): diff --git a/docarray/proto/docarray.proto b/docarray/proto/docarray.proto index aaa525236b0..54db5b576a4 100644 --- a/docarray/proto/docarray.proto +++ b/docarray/proto/docarray.proto @@ -46,6 +46,15 @@ message NodeProto { // a sub DocumentArray DocumentArrayProto chunks = 5; + + NdArrayProto embedding = 6; + + string any_url = 7; + + string image_url = 8; + + string id = 9; + } } diff --git 
a/docarray/proto/pb2/docarray_pb2.py b/docarray/proto/pb2/docarray_pb2.py index 1b95ddb2a55..6df55564dfb 100644 --- a/docarray/proto/pb2/docarray_pb2.py +++ b/docarray/proto/pb2/docarray_pb2.py @@ -2,10 +2,11 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # source: docarray.proto """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -13,26 +14,27 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xbb\x01\n\tNodeProto\x12\x0e\n\x04\x62lob\x18\x01 \x01(\x0cH\x00\x12(\n\x06tensor\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x0e\n\x04text\x18\x03 \x01(\tH\x00\x12)\n\x06nested\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12.\n\x06\x63hunks\x18\x05 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x42\t\n\x07\x63ontent\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProtob\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 
+ b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\x9e\x02\n\tNodeProto\x12\x0e\n\x04\x62lob\x18\x01 \x01(\x0cH\x00\x12(\n\x06tensor\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x0e\n\x04text\x18\x03 \x01(\tH\x00\x12)\n\x06nested\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12.\n\x06\x63hunks\x18\x05 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12+\n\tembedding\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x11\n\x07\x61ny_url\x18\x07 \x01(\tH\x00\x12\x13\n\timage_url\x18\x08 \x01(\tH\x00\x12\x0c\n\x02id\x18\t \x01(\tH\x00\x42\t\n\x07\x63ontent\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProtob\x06proto3' +) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'docarray_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _DOCUMENTPROTO_DATAENTRY._options = None - _DOCUMENTPROTO_DATAENTRY._serialized_options = b'8\001' - _DENSENDARRAYPROTO._serialized_start=58 - _DENSENDARRAYPROTO._serialized_end=123 - _NDARRAYPROTO._serialized_start=125 - _NDARRAYPROTO._serialized_end=228 - _NODEPROTO._serialized_start=231 - _NODEPROTO._serialized_end=418 - _DOCUMENTPROTO._serialized_start=421 - _DOCUMENTPROTO._serialized_end=551 - _DOCUMENTPROTO_DATAENTRY._serialized_start=487 - _DOCUMENTPROTO_DATAENTRY._serialized_end=551 - 
_DOCUMENTARRAYPROTO._serialized_start=553 - _DOCUMENTARRAYPROTO._serialized_end=612 + DESCRIPTOR._options = None + _DOCUMENTPROTO_DATAENTRY._options = None + _DOCUMENTPROTO_DATAENTRY._serialized_options = b'8\001' + _DENSENDARRAYPROTO._serialized_start = 58 + _DENSENDARRAYPROTO._serialized_end = 123 + _NDARRAYPROTO._serialized_start = 125 + _NDARRAYPROTO._serialized_end = 228 + _NODEPROTO._serialized_start = 231 + _NODEPROTO._serialized_end = 517 + _DOCUMENTPROTO._serialized_start = 520 + _DOCUMENTPROTO._serialized_end = 650 + _DOCUMENTPROTO_DATAENTRY._serialized_start = 586 + _DOCUMENTPROTO_DATAENTRY._serialized_end = 650 + _DOCUMENTARRAYPROTO._serialized_start = 652 + _DOCUMENTARRAYPROTO._serialized_end = 711 # @@protoc_insertion_point(module_scope) diff --git a/docarray/typing/__init__.py b/docarray/typing/__init__.py index 326e4113bfc..6d04abb0ac2 100644 --- a/docarray/typing/__init__.py +++ b/docarray/typing/__init__.py @@ -1,6 +1,6 @@ -from docarray.document.base_node import BaseNode +from docarray.typing.embedding import Embedding +from docarray.typing.id import ID +from docarray.typing.tensor import Tensor +from docarray.typing.url import AnyUrl, ImageUrl -from docarray.typing.ndarray import Embedding, Tensor -from docarray.typing.url import ImageUrl - -__all__ = ['Tensor', 'Embedding', 'BaseNode', 'ImageUrl'] +__all__ = ['Tensor', 'Embedding', 'ImageUrl', 'AnyUrl', 'ID'] diff --git a/docarray/typing/embedding.py b/docarray/typing/embedding.py new file mode 100644 index 00000000000..eb6caa3243a --- /dev/null +++ b/docarray/typing/embedding.py @@ -0,0 +1,18 @@ +from typing import TypeVar + +from docarray.proto import NodeProto +from docarray.typing.tensor import Tensor + +T = TypeVar('T', bound='Embedding') + + +class Embedding(Tensor): + def _to_node_protobuf(self: T, field: str = 'tensor') -> NodeProto: + """Convert Document into a NodeProto protobuf message. 
This function should + be called when the Document is nested into another Document that need to be + converted into a protobuf + :param field: field in which to store the content in the node proto + :return: the nested item protobuf message + """ + + return super()._to_node_protobuf(field='embedding') diff --git a/docarray/typing/id.py b/docarray/typing/id.py new file mode 100644 index 00000000000..a6b9d82c3fc --- /dev/null +++ b/docarray/typing/id.py @@ -0,0 +1,45 @@ +from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union +from uuid import UUID + +from docarray.document.base_node import BaseNode +from docarray.proto import NodeProto + +if TYPE_CHECKING: + from pydantic import BaseConfig + from pydantic.fields import ModelField + + +T = TypeVar('T', bound='ID') + + +class ID(str, BaseNode): + """ + Represent an unique ID + """ + + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate( + cls: Type[T], + value: Union[str, int, UUID], + field: Optional['ModelField'] = None, + config: Optional['BaseConfig'] = None, + ) -> T: + + try: + id: str = str(value) + return cls(id) + except Exception: + raise ValueError(f'Expected a str, int or UUID, got {type(value)}') + + def _to_node_protobuf(self) -> NodeProto: + """Convert an ID into a NodeProto message. 
This function should + be called when the self is nested into another Document that need to be + converted into a protobuf + + :return: the nested item protobuf message + """ + return NodeProto(id=self) diff --git a/docarray/typing/ndarray.py b/docarray/typing/ndarray.py deleted file mode 100644 index 196d52955aa..00000000000 --- a/docarray/typing/ndarray.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tensor import Tensor - -Embedding = Tensor diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index 32e14fb1362..c031207cdaa 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -40,20 +40,19 @@ def validate( def from_ndarray(cls: Type[T], value: np.ndarray) -> T: return value.view(cls) - def _to_node_protobuf(self: T) -> NodeProto: + def _to_node_protobuf(self: T, field: str = 'tensor') -> NodeProto: """Convert Document into a NodeProto protobuf message. This function should be called when the Document is nested into another Document that need to be converted into a protobuf - + :param field: field in which to store the content in the node proto :return: the nested item protobuf message """ nd_proto = NdArrayProto() - self.flush_ndarray(nd_proto, value=self) - NodeProto(tensor=nd_proto) - return NodeProto(tensor=nd_proto) + self._flush_tensor_to_proto(nd_proto, value=self) + return NodeProto(**{field: nd_proto}) @classmethod - def read_ndarray(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': + def _read_from_proto(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': """ read ndarray from a proto msg :param pb_msg: @@ -69,7 +68,7 @@ def read_ndarray(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': raise ValueError(f'proto message {pb_msg} cannot be cast to a Tensor') @staticmethod - def flush_ndarray(pb_msg: 'NdArrayProto', value: 'Tensor'): + def _flush_tensor_to_proto(pb_msg: 'NdArrayProto', value: 'Tensor'): pb_msg.dense.buffer = value.tobytes() pb_msg.dense.ClearField('shape') 
pb_msg.dense.shape.extend(list(value.shape)) diff --git a/docarray/typing/url/__init__.py b/docarray/typing/url/__init__.py index 75814c96fb5..f5a81b117fc 100644 --- a/docarray/typing/url/__init__.py +++ b/docarray/typing/url/__init__.py @@ -1,3 +1,4 @@ -from .image_url import ImageUrl +from docarray.typing.url.any_url import AnyUrl +from docarray.typing.url.image_url import ImageUrl -__all__ = ['ImageUrl'] +__all__ = ['ImageUrl', 'AnyUrl'] diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py index c8cf3ed872b..5bec625d476 100644 --- a/docarray/typing/url/any_url.py +++ b/docarray/typing/url/any_url.py @@ -12,4 +12,4 @@ def _to_node_protobuf(self) -> NodeProto: :return: the nested item protobuf message """ - return NodeProto(text=str(self)) + return NodeProto(any_url=str(self)) diff --git a/docarray/typing/url/image_url.py b/docarray/typing/url/image_url.py index 68d956fb87d..062419fa53f 100644 --- a/docarray/typing/url/image_url.py +++ b/docarray/typing/url/image_url.py @@ -1,9 +1,19 @@ import numpy as np -from .any_url import AnyUrl +from docarray.proto import NodeProto +from docarray.typing.url.any_url import AnyUrl class ImageUrl(AnyUrl): + def _to_node_protobuf(self) -> NodeProto: + """Convert Document into a NodeProto protobuf message. 
This function should + be called when the Document is nested into another Document that need to + be converted into a protobuf + + :return: the nested item protobuf message + """ + return NodeProto(image_url=str(self)) + def load(self) -> np.ndarray: """ transform the url in a image Tensor diff --git a/tests/integrations/typing/test_typing_proto.py b/tests/integrations/typing/test_typing_proto.py new file mode 100644 index 00000000000..45919dbb06e --- /dev/null +++ b/tests/integrations/typing/test_typing_proto.py @@ -0,0 +1,25 @@ +import numpy as np + +from docarray import Document +from docarray.document import AnyDocument +from docarray.typing import AnyUrl, Embedding, ImageUrl, Tensor + + +def test_proto_all_types(): + class Mymmdoc(Document): + tensor: Tensor + embedding: Embedding + any_url: AnyUrl + image_url: ImageUrl + + doc = Mymmdoc( + tensor=np.zeros((3, 224, 224)), + embedding=np.zeros((100, 1)), + any_url='http://jina.ai', + image_url='http://jina.ai', + ) + + new_doc = AnyDocument.from_protobuf(doc.to_protobuf()) + + for field, value in new_doc: + assert isinstance(value, doc._get_nested_document_class(field)) diff --git a/tests/units/document/proto/test_proto_based_object.py b/tests/units/document/proto/test_proto_based_object.py index c2fd61773c8..ac6f38949b2 100644 --- a/tests/units/document/proto/test_proto_based_object.py +++ b/tests/units/document/proto/test_proto_based_object.py @@ -16,9 +16,9 @@ def test_nested_optional_item_proto(): def test_ndarray(): nd_proto = NdArrayProto() original_tensor = np.zeros((3, 224, 224)) - Tensor.flush_ndarray(nd_proto, value=original_tensor) + Tensor._flush_tensor_to_proto(nd_proto, value=original_tensor) nested_item = NodeProto(tensor=nd_proto) - tensor = Tensor.read_ndarray(nested_item.tensor) + tensor = Tensor._read_from_proto(nested_item.tensor) assert (tensor == original_tensor).all() @@ -31,7 +31,7 @@ def test_document_proto_set(): nd_proto = NdArrayProto() original_tensor = np.zeros((3, 224, 224)) - 
Tensor.flush_ndarray(nd_proto, value=original_tensor) + Tensor._flush_tensor_to_proto(nd_proto, value=original_tensor) nested_item2 = NodeProto(tensor=nd_proto) diff --git a/tests/units/typing/__init__.py b/tests/units/typing/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/typing/test_embedding.py b/tests/units/typing/test_embedding.py new file mode 100644 index 00000000000..61d31abe079 --- /dev/null +++ b/tests/units/typing/test_embedding.py @@ -0,0 +1,11 @@ +import numpy as np +from pydantic.tools import parse_obj_as + +from docarray.typing import Embedding + + +def test_proto_embedding(): + + uri = parse_obj_as(Embedding, np.zeros((3, 224, 224))) + + uri._to_node_protobuf() diff --git a/tests/units/typing/test_id.py b/tests/units/typing/test_id.py new file mode 100644 index 00000000000..5c3476bc82a --- /dev/null +++ b/tests/units/typing/test_id.py @@ -0,0 +1,16 @@ +from uuid import UUID + +import pytest +from pydantic.tools import parse_obj_as + +from docarray.typing import ID + + +@pytest.mark.parametrize( + 'id', ['1234', 1234, UUID('cf57432e-809e-4353-adbd-9d5c0d733868')] +) +def test_id_validation(id): + + parsed_id = parse_obj_as(ID, id) + + assert parsed_id == str(id) diff --git a/tests/units/typing/test_tensor.py b/tests/units/typing/test_tensor.py new file mode 100644 index 00000000000..9bace9e8841 --- /dev/null +++ b/tests/units/typing/test_tensor.py @@ -0,0 +1,11 @@ +import numpy as np +from pydantic.tools import parse_obj_as + +from docarray.typing import Tensor + + +def test_proto_tensor(): + + uri = parse_obj_as(Tensor, np.zeros((3, 224, 224))) + + uri._to_node_protobuf() diff --git a/tests/units/typing/url/__init__.py b/tests/units/typing/url/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/typing/url/test_any_url.py b/tests/units/typing/url/test_any_url.py new file mode 100644 index 00000000000..ad593b58519 --- /dev/null +++ b/tests/units/typing/url/test_any_url.py @@ 
-0,0 +1,10 @@ +from pydantic.tools import parse_obj_as + +from docarray.typing import AnyUrl + + +def test_proto_any_url(): + """Check that an AnyUrl converts to a NodeProto without raising.""" + uri = parse_obj_as(AnyUrl, 'http://jina.ai/img.png') + + uri._to_node_protobuf() diff --git a/tests/units/typing/test_image_url.py b/tests/units/typing/url/test_image_url.py similarity index 57% rename from tests/units/typing/test_image_url.py rename to tests/units/typing/url/test_image_url.py index 74c511dba43..37fcf525d23 100644 --- a/tests/units/typing/test_image_url.py +++ b/tests/units/typing/url/test_image_url.py @@ -1,7 +1,7 @@ import numpy as np from pydantic.tools import parse_obj_as -from docarray.typing import ImageUrl, Tensor +from docarray.typing import ImageUrl def test_image_url(): @@ -10,3 +10,10 @@ def test_image_url(): tensor = uri.load() assert isinstance(tensor, np.ndarray) + + +def test_proto_image_url(): + + uri = parse_obj_as(ImageUrl, 'http://jina.ai/img.png') + + uri._to_node_protobuf()