Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions docarray/document/mixins/proto.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, Type

from docarray.proto import DocumentProto, NdArrayProto, NodeProto
from docarray.proto.io import flush_ndarray, read_ndarray
from docarray.typing import Tensor

from ..abstract_document import AbstractDocument
Expand Down Expand Up @@ -32,7 +31,7 @@ def from_protobuf(cls, pb_msg: 'DocumentProto') -> 'ProtoMixin':
content_type = value.WhichOneof('content')

if content_type == 'tensor':
fields[field] = read_ndarray(value.tensor)
fields[field] = Tensor.read_ndarray(value.tensor)
elif content_type == 'text':
fields[field] = value.text
elif content_type == 'nested':
Expand Down Expand Up @@ -64,12 +63,6 @@ def to_protobuf(self) -> 'DocumentProto':
if isinstance(value, BaseNode):
nested_item = value._to_nested_item_protobuf()

elif isinstance(value, Tensor):
nd_proto = NdArrayProto()
flush_ndarray(nd_proto, value=value)
NodeProto(tensor=nd_proto)
nested_item = NodeProto(tensor=nd_proto)

elif type(value) is str:
nested_item = NodeProto(text=value)

Expand Down
1 change: 0 additions & 1 deletion docarray/proto/io/__init__.py

This file was deleted.

27 changes: 0 additions & 27 deletions docarray/proto/io/tensor.py

This file was deleted.

2 changes: 1 addition & 1 deletion docarray/typing/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import numpy as np
from .tensor import Tensor

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use a global (absolute) import here rather than the relative `from .tensor import Tensor`.


Tensor = np.ndarray
Embedding = Tensor
1 change: 1 addition & 0 deletions docarray/typing/tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .tensor import Tensor
72 changes: 72 additions & 0 deletions docarray/typing/tensor/tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Union, TypeVar, Any, TYPE_CHECKING, Type, cast

import numpy as np
if TYPE_CHECKING:
from pydantic.fields import ModelField
from pydantic import BaseConfig, PydanticValueError

from docarray.document.base_node import BaseNode
from docarray.proto import DocumentProto, NdArrayProto, NodeProto

T = TypeVar('T', bound='Tensor')


class Tensor(np.ndarray, BaseNode):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does `Tensor` inherit from `np.ndarray`?

@classmethod
def __get_validators__(cls):
    """Yield the validators pydantic should run for this type.

    Validators run in the order they are yielded, and each one receives
    the value returned by the previous one. Only ``cls.validate`` is
    registered.
    """
    yield from (cls.validate,)

@classmethod
def validate(cls: Type[T], value: Union[T, Any], field: 'ModelField', config: 'BaseConfig') -> T:
    """Coerce ``value`` into an instance of this tensor class.

    This is the validator yielded by ``__get_validators__``.

    :param value: the raw value assigned to the field
    :param field: the pydantic field being validated (not used here)
    :param config: the pydantic model config (not used here)
    :return: ``value`` itself if it is already an instance of ``cls``,
        otherwise a view of the (converted) array typed as ``cls``
    :raises ValueError: if ``value`` cannot be converted to a numpy array
    """
    # Test the subclass first: every Tensor is also an np.ndarray, so
    # checking np.ndarray first would make this branch unreachable.
    if isinstance(value, cls):
        return cast(T, value)
    if isinstance(value, np.ndarray):
        return cls.from_ndarray(value)

    try:
        # np.asarray converts array-like data; np.ndarray(value) would
        # instead treat ``value`` as a *shape* and return uninitialised
        # memory.
        arr: np.ndarray = np.asarray(value)
    except Exception as err:
        # Keep ValueError as the single failure type, but chain the cause
        # so the underlying conversion error is not lost.
        raise ValueError(f'Expected a numpy.ndarray, got {type(value)}') from err
    return cls.from_ndarray(arr)

@classmethod
def from_ndarray(cls: Type[T], value: np.ndarray) -> T:
    """Return ``value`` re-typed as ``cls`` via ``ndarray.view``.

    :param value: the numpy array to wrap
    :return: the result of ``value.view(cls)``
    """
    as_tensor = value.view(cls)
    return as_tensor

def _to_nested_item_protobuf(self: T) -> 'NodeProto':
    """Convert this tensor into a nested item protobuf message.

    The tensor is written into a fresh ``NdArrayProto`` with
    ``flush_ndarray`` and wrapped in a ``NodeProto`` under its ``tensor``
    field.

    :return: the nested item protobuf message
    """
    nd_proto = NdArrayProto()
    self.flush_ndarray(nd_proto, value=self)
    # Build the node once; the original also constructed and discarded
    # an identical NodeProto just before returning.
    return NodeProto(tensor=nd_proto)

@classmethod
def read_ndarray(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T':
    """
    Read a tensor from a proto message.

    :param pb_msg: the message whose ``dense`` field carries the raw
        buffer, the shape and the dtype string
    :return: the decoded array, viewed as ``cls``
    :raises ValueError: if the message has neither a buffer nor a shape
    """
    source = pb_msg.dense
    shape = tuple(source.shape)
    if source.buffer:
        x = np.frombuffer(source.buffer, dtype=source.dtype)
        return cls.from_ndarray(x.reshape(shape))
    elif len(shape) > 0:
        # An empty buffer with a shape is what a zero-size array produces
        # when flushed. Rebuild it with the recorded dtype instead of
        # silently falling back to float64; ``or None`` keeps numpy's
        # default when the message carries no dtype.
        return cls.from_ndarray(np.zeros(shape, dtype=source.dtype or None))
    else:
        raise ValueError(f'proto message {pb_msg} cannot be cast to a Tensor')

@staticmethod
def flush_ndarray(pb_msg: 'NdArrayProto', value: 'Tensor'):
    """Write ``value`` into the ``dense`` field of ``pb_msg``.

    Stores the raw bytes, the shape (replacing any shape already present)
    and the dtype string of ``value``.

    :param pb_msg: the proto message to fill in place
    :param value: the array to serialise
    """
    dense = pb_msg.dense
    dense.buffer = value.tobytes()
    # Drop any previously stored shape before writing the new one.
    dense.ClearField('shape')
    dense.shape.extend([*value.shape])
    dense.dtype = value.dtype.str
4 changes: 1 addition & 3 deletions docarray/typing/url/image_url.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import numpy as np

from docarray.typing import Tensor

from .any_url import AnyUrl


class ImageUrl(AnyUrl):
def load(self) -> Tensor:
def load(self) -> np.ndarray:
"""
transform the url in a image Tensor

Expand Down
4 changes: 3 additions & 1 deletion tests/integrations/predefined_document/test_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from docarray import Image
from docarray.typing import Tensor

Expand All @@ -8,4 +10,4 @@ def test_image():

image.tensor = image.uri.load()

assert isinstance(image.tensor, Tensor)
assert isinstance(image.tensor, np.ndarray)
Empty file.
16 changes: 16 additions & 0 deletions tests/integrations/typing/tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import numpy as np

from docarray.typing import Tensor
from docarray import Document


def test_set_tensor():
    """A numpy array assigned to a ``Tensor`` field is stored as a Tensor."""
    expected = np.zeros((3, 224, 224))

    class MyDocument(Document):
        tensor: Tensor

    doc = MyDocument(tensor=np.zeros((3, 224, 224)))
    stored = doc.tensor

    # The value is a Tensor, is still a numpy array, and keeps its content.
    assert isinstance(stored, Tensor)
    assert isinstance(stored, np.ndarray)
    assert (stored == expected).all()
7 changes: 4 additions & 3 deletions tests/units/document/proto/test_document_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from docarray import DocumentArray
from docarray.document import BaseDocument
from docarray.typing import Tensor


def test_proto_simple():
Expand All @@ -17,7 +18,7 @@ class CustomDoc(BaseDocument):

def test_proto_ndarray():
class CustomDoc(BaseDocument):
tensor: np.ndarray
tensor: Tensor

tensor = np.zeros((3, 224, 224))
doc = CustomDoc(tensor=tensor)
Expand All @@ -29,7 +30,7 @@ class CustomDoc(BaseDocument):

def test_proto_with_nested_doc():
class CustomInnerDoc(BaseDocument):
tensor: np.ndarray
tensor: Tensor

class CustomDoc(BaseDocument):
text: str
Expand All @@ -42,7 +43,7 @@ class CustomDoc(BaseDocument):

def test_proto_with_chunks_doc():
class CustomInnerDoc(BaseDocument):
tensor: np.ndarray
tensor: Tensor

class CustomDoc(BaseDocument):
text: str
Expand Down
8 changes: 4 additions & 4 deletions tests/units/document/proto/test_proto_based_object.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from docarray.proto import DocumentProto, NdArrayProto, NodeProto
from docarray.proto.io import flush_ndarray, read_ndarray
from docarray.typing import Tensor


def test_nested_item_proto():
Expand All @@ -16,9 +16,9 @@ def test_nested_optional_item_proto():
def test_ndarray():
nd_proto = NdArrayProto()
original_tensor = np.zeros((3, 224, 224))
flush_ndarray(nd_proto, value=original_tensor)
Tensor.flush_ndarray(nd_proto, value=original_tensor)
nested_item = NodeProto(tensor=nd_proto)
tensor = read_ndarray(nested_item.tensor)
tensor = Tensor.read_ndarray(nested_item.tensor)

assert (tensor == original_tensor).all()

Expand All @@ -31,7 +31,7 @@ def test_document_proto_set():

nd_proto = NdArrayProto()
original_tensor = np.zeros((3, 224, 224))
flush_ndarray(nd_proto, value=original_tensor)
Tensor.flush_ndarray(nd_proto, value=original_tensor)

nested_item2 = NodeProto(tensor=nd_proto)

Expand Down
3 changes: 2 additions & 1 deletion tests/units/typing/test_image_url.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from pydantic.tools import parse_obj_as

from docarray.typing import ImageUrl, Tensor
Expand All @@ -8,4 +9,4 @@ def test_image_url():

tensor = uri.load()

assert isinstance(tensor, Tensor)
assert isinstance(tensor, np.ndarray)