Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions docarray/document/mixins/proto.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Any, Dict, Type, TypeVar

from pydantic.tools import parse_obj_as

from docarray.document.abstract_document import AbstractDocument
from docarray.document.base_node import BaseNode
from docarray.proto import DocumentProto, NodeProto
Expand All @@ -14,7 +12,6 @@ class ProtoMixin(AbstractDocument, BaseNode):
@classmethod
def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T:
"""create a Document from a protobuf message"""
from docarray import DocumentArray

fields: Dict[str, Any] = {}

Expand All @@ -25,25 +22,26 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T:

# this if else statement need to be refactored it is too long
# the check should be delegated to the type level
if content_type == 'tensor':
fields[field] = Tensor._read_from_proto(value.tensor)
elif content_type == 'torch_tensor':
fields[field] = TorchTensor._read_from_proto(value.torch_tensor)
elif content_type == 'embedding':
fields[field] = Embedding._read_from_proto(value.embedding)
elif content_type == 'any_url':
fields[field] = parse_obj_as(AnyUrl, value.any_url)
elif content_type == 'image_url':
fields[field] = parse_obj_as(ImageUrl, value.image_url)
elif content_type == 'id':
fields[field] = parse_obj_as(ID, value.id)
content_type_dict = dict(
tensor=Tensor,
torch_tensor=TorchTensor,
embedding=Embedding,
any_url=AnyUrl,
image_url=ImageUrl,
id=ID,
)
if content_type in content_type_dict:
fields[field] = content_type_dict[content_type].from_protobuf(
getattr(value, content_type)
)
elif content_type == 'text':
fields[field] = value.text
elif content_type == 'nested':
fields[field] = cls._get_nested_document_class(field).from_protobuf(
value.nested
) # we get to the parent class
elif content_type == 'chunks':
from docarray import DocumentArray

fields[field] = DocumentArray.from_protobuf(
value.chunks
Expand Down
2 changes: 1 addition & 1 deletion docarray/predefined_document/text.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from docarray.document import BaseDocument
from docarray.typing.embedding import Embedding, Tensor
from docarray.typing.tensor.embedding import Embedding, Tensor


class Text(BaseDocument):
Expand Down
11 changes: 9 additions & 2 deletions docarray/typing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from docarray.typing.embedding import Embedding
from docarray.typing.id import ID
from docarray.typing.tensor import Tensor, TorchTensor
from docarray.typing.tensor.embedding import Embedding
from docarray.typing.url import AnyUrl, ImageUrl

__all__ = ['Tensor', 'Embedding', 'ImageUrl', 'AnyUrl', 'ID', 'TorchTensor']
__all__ = [
'TorchTensor',
'Tensor',
'Embedding',
'ImageUrl',
'AnyUrl',
'ID',
]
19 changes: 13 additions & 6 deletions docarray/typing/id.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union
from typing import Optional, Type, TypeVar, Union
from uuid import UUID

from pydantic import BaseConfig, parse_obj_as
from pydantic.fields import ModelField

from docarray.document.base_node import BaseNode
from docarray.proto import NodeProto

if TYPE_CHECKING:
from pydantic import BaseConfig
from pydantic.fields import ModelField


T = TypeVar('T', bound='ID')


Expand Down Expand Up @@ -43,3 +41,12 @@ def _to_node_protobuf(self) -> NodeProto:
:return: the nested item protobuf message
"""
return NodeProto(id=self)

@classmethod
def from_protobuf(cls: Type[T], pb_msg: 'str') -> T:
"""
read ndarray from a proto msg
Comment thread
dongxiang123 marked this conversation as resolved.
:param pb_msg:
:return: a string
"""
return parse_obj_as(cls, pb_msg)
File renamed without changes.
2 changes: 1 addition & 1 deletion docarray/typing/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _to_node_protobuf(self: T, field: str = 'tensor') -> NodeProto:
return NodeProto(**{field: nd_proto})

@classmethod
def _read_from_proto(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T':
def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T':
"""
read ndarray from a proto msg
:param pb_msg:
Expand Down
2 changes: 1 addition & 1 deletion docarray/typing/tensor/torch_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _to_node_protobuf(self: T, field: str = 'torch_tensor') -> NodeProto:
return NodeProto(**{field: nd_proto})

@classmethod
def _read_from_proto(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T':
def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T':
"""
read ndarray from a proto msg
Comment thread
dongxiang123 marked this conversation as resolved.
:param pb_msg:
Expand Down
14 changes: 14 additions & 0 deletions docarray/typing/url/any_url.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import Type, TypeVar

from pydantic import AnyUrl as BaseAnyUrl
from pydantic import parse_obj_as

from docarray.document.base_node import BaseNode
from docarray.proto import NodeProto

T = TypeVar('T', bound='AnyUrl')


class AnyUrl(BaseAnyUrl, BaseNode):
def _to_node_protobuf(self) -> NodeProto:
Expand All @@ -13,3 +18,12 @@ def _to_node_protobuf(self) -> NodeProto:
:return: the nested item protobuf message
"""
return NodeProto(any_url=str(self))

@classmethod
def from_protobuf(cls: Type[T], pb_msg: 'str') -> T:
"""
read url from a proto msg
:param pb_msg:
:return: url
"""
return parse_obj_as(cls, pb_msg)
12 changes: 12 additions & 0 deletions tests/integrations/typing/test_anyurl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from docarray import Document
from docarray.typing import AnyUrl


def test_set_any_url():
class MyDocument(Document):
any_url: AnyUrl

d = MyDocument(any_url="https://jina.ai")

assert isinstance(d.any_url, AnyUrl)
assert d.any_url == "https://jina.ai"
15 changes: 15 additions & 0 deletions tests/integrations/typing/test_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy as np

from docarray import Document
from docarray.typing import Embedding


def test_set_embedding():
class MyDocument(Document):
embedding: Embedding

d = MyDocument(embedding=np.zeros((3, 224, 224)))

assert isinstance(d.embedding, Embedding)
assert isinstance(d.embedding, np.ndarray)
assert (d.embedding == np.zeros((3, 224, 224))).all()
12 changes: 12 additions & 0 deletions tests/integrations/typing/test_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from docarray import Document
from docarray.typing import ID


def test_set_id():
class MyDocument(Document):
id: ID

d = MyDocument(id="123")

assert isinstance(d.id, ID)
assert d.id == "123"
12 changes: 12 additions & 0 deletions tests/integrations/typing/test_image_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from docarray import Document
from docarray.typing import ImageUrl


def test_set_image_url():
class MyDocument(Document):
image_url: ImageUrl

d = MyDocument(image_url="https://jina.ai/img.png")

assert isinstance(d.image_url, ImageUrl)
assert d.image_url == "https://jina.ai/img.png"
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from docarray.typing import TorchTensor


def test_set_tensor():
def test_set_torch_tensor():
class MyDocument(Document):
tensor: TorchTensor

Expand Down
2 changes: 1 addition & 1 deletion tests/units/document/proto/test_proto_based_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_ndarray():
original_tensor = np.zeros((3, 224, 224))
Tensor._flush_tensor_to_proto(nd_proto, value=original_tensor)
nested_item = NodeProto(tensor=nd_proto)
tensor = Tensor._read_from_proto(nested_item.tensor)
tensor = Tensor.from_protobuf(nested_item.tensor)

assert (tensor == original_tensor).all()

Expand Down
Empty file.