From 941c0f9b78c60fba4599ef91d1de82381ae2a729 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 11 Nov 2022 16:10:58 +0100 Subject: [PATCH 1/2] feat: add tensor type for ndarray --- docarray/document/mixins/proto.py | 9 +-- docarray/proto/io/__init__.py | 1 - docarray/proto/io/tensor.py | 27 ------- docarray/typing/ndarray.py | 2 +- docarray/typing/tensor/__init__.py | 1 + docarray/typing/tensor/tensor.py | 74 +++++++++++++++++++ .../predefined_document/test_image.py | 4 +- tests/integrations/typing/__init__.py | 0 tests/integrations/typing/tensor.py | 16 ++++ .../document/proto/test_document_proto.py | 7 +- .../document/proto/test_proto_based_object.py | 8 +- tests/units/typing/test_image_url.py | 3 +- 12 files changed, 106 insertions(+), 46 deletions(-) delete mode 100644 docarray/proto/io/__init__.py delete mode 100644 docarray/proto/io/tensor.py create mode 100644 docarray/typing/tensor/__init__.py create mode 100644 docarray/typing/tensor/tensor.py create mode 100644 tests/integrations/typing/__init__.py create mode 100644 tests/integrations/typing/tensor.py diff --git a/docarray/document/mixins/proto.py b/docarray/document/mixins/proto.py index ddeb44d1d0d..6fc38fbe080 100644 --- a/docarray/document/mixins/proto.py +++ b/docarray/document/mixins/proto.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, Any, Dict, Type from docarray.proto import DocumentProto, NdArrayProto, NodeProto -from docarray.proto.io import flush_ndarray, read_ndarray from docarray.typing import Tensor from ..abstract_document import AbstractDocument @@ -32,7 +31,7 @@ def from_protobuf(cls, pb_msg: 'DocumentProto') -> 'ProtoMixin': content_type = value.WhichOneof('content') if content_type == 'tensor': - fields[field] = read_ndarray(value.tensor) + fields[field] = Tensor.read_ndarray(value.tensor) elif content_type == 'text': fields[field] = value.text elif content_type == 'nested': @@ -64,12 +63,6 @@ def to_protobuf(self) -> 'DocumentProto': if isinstance(value, BaseNode): nested_item = value._to_nested_item_protobuf() - elif isinstance(value, Tensor): - nd_proto = NdArrayProto() - flush_ndarray(nd_proto, value=value) - NodeProto(tensor=nd_proto) - nested_item = NodeProto(tensor=nd_proto) - elif type(value) is str: nested_item = NodeProto(text=value) diff --git a/docarray/proto/io/__init__.py b/docarray/proto/io/__init__.py deleted file mode 100644 index 0606654178d..00000000000 --- a/docarray/proto/io/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .tensor import flush_ndarray, read_ndarray diff --git a/docarray/proto/io/tensor.py b/docarray/proto/io/tensor.py deleted file mode 100644 index 921632b9532..00000000000 --- a/docarray/proto/io/tensor.py +++ /dev/null @@ -1,27 +0,0 @@ -import numpy as np - -from docarray.proto import NdArrayProto -from docarray.typing.ndarray import Tensor - - -def read_ndarray(pb_msg: 'NdArrayProto') -> 'Tensor': - """ - read ndarray from a proto msg - :param pb_msg: - :return: a numpy array - """ - source = pb_msg.dense - if source.buffer: - x = np.frombuffer(source.buffer, dtype=source.dtype) - return x.reshape(source.shape) - elif len(source.shape) > 0: - return np.zeros(source.shape) - else: - raise ValueError(f'proto message {pb_msg} cannot be cast to a Tensor') - - -def flush_ndarray(pb_msg: 'NdArrayProto', value: 'Tensor'): - pb_msg.dense.buffer = value.tobytes() - pb_msg.dense.ClearField('shape') - pb_msg.dense.shape.extend(list(value.shape)) - pb_msg.dense.dtype = value.dtype.str diff --git a/docarray/typing/ndarray.py b/docarray/typing/ndarray.py index 60f224b0dc3..841bb541830 100644 --- a/docarray/typing/ndarray.py +++ b/docarray/typing/ndarray.py @@ -1,4 +1,4 @@ import numpy as np +from .tensor import Tensor -Tensor = np.ndarray Embedding = Tensor diff --git a/docarray/typing/tensor/__init__.py b/docarray/typing/tensor/__init__.py new file mode 100644 index 00000000000..aa666c11ebc --- /dev/null +++ b/docarray/typing/tensor/__init__.py @@ -0,0 +1 @@ +from .tensor import Tensor \ No newline at end of file diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py new file mode 100644 index 00000000000..8a1c52ebac4 --- /dev/null +++ b/docarray/typing/tensor/tensor.py @@ -0,0 +1,74 @@ +from typing import Union, TypeVar, Any, TYPE_CHECKING + +import numpy as np +if TYPE_CHECKING: + from pydantic.fields import ModelField + from pydantic import BaseConfig + +from docarray.document.base_node import BaseNode +from docarray.proto import DocumentProto, NdArrayProto, NodeProto +from pydantic import ValidationError + +T = TypeVar('T', bound=np.ndarray) + + +class Tensor(np.ndarray, BaseNode): + @classmethod + def __get_validators__(cls): + # one or more validators may be yielded which will be called in the + # order to validate the input, each validator will receive as an input + # the value returned from the previous validator + yield cls.validate + + @classmethod + def validate(cls: T, value: Union[T, Any], field: 'ModelField', config: 'BaseConfig') -> T: + if isinstance(value, np.ndarray): + return cls.from_ndarray(value) + elif isinstance(value, Tensor): + return value + else: + try: + arr = np.ndarray(value) + return cls.from_ndarray(arr) + except Exception: + pass # handled below + raise ValidationError(f'Expected a numpy.ndarray, got {type(value)}') + + @classmethod + def from_ndarray(cls, value: np.ndarray) -> T: + return value.view(cls) + + + def _to_nested_item_protobuf(self) -> 'NodeProto': + """Convert Document into a nested item protobuf message. This function should be called when the Document + is nested into another Document that need to be converted into a protobuf + + :return: the nested item protobuf message + """ + nd_proto = NdArrayProto() + self.flush_ndarray(nd_proto, value=self) + NodeProto(tensor=nd_proto) + return NodeProto(tensor=nd_proto) + + @classmethod + def read_ndarray(cls, pb_msg: 'NdArrayProto') -> 'Tensor': + """ + read ndarray from a proto msg + :param pb_msg: + :return: a numpy array + """ + source = pb_msg.dense + if source.buffer: + x = np.frombuffer(source.buffer, dtype=source.dtype) + return cls.from_ndarray(x.reshape(source.shape)) + elif len(source.shape) > 0: + return cls.from_ndarray(np.zeros(source.shape)) + else: + raise ValueError(f'proto message {pb_msg} cannot be cast to a Tensor') + + @staticmethod + def flush_ndarray(pb_msg: 'NdArrayProto', value: 'Tensor'): + pb_msg.dense.buffer = value.tobytes() + pb_msg.dense.ClearField('shape') + pb_msg.dense.shape.extend(list(value.shape)) + pb_msg.dense.dtype = value.dtype.str \ No newline at end of file diff --git a/tests/integrations/predefined_document/test_image.py b/tests/integrations/predefined_document/test_image.py index 964935d2c12..e1de2b020a0 100644 --- a/tests/integrations/predefined_document/test_image.py +++ b/tests/integrations/predefined_document/test_image.py @@ -1,3 +1,5 @@ +import numpy as np + from docarray import Image from docarray.typing import Tensor @@ -8,4 +10,4 @@ def test_image(): image.tensor = image.uri.load() - assert isinstance(image.tensor, Tensor) + assert isinstance(image.tensor, np.ndarray) diff --git a/tests/integrations/typing/__init__.py b/tests/integrations/typing/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/integrations/typing/tensor.py b/tests/integrations/typing/tensor.py new file mode 100644 index 00000000000..0d17207e667 --- /dev/null +++ b/tests/integrations/typing/tensor.py @@ -0,0 +1,16 @@ +import numpy as np + +from docarray.typing import Tensor +from docarray import Document + + +def test_set_tensor(): + + class MyDocument(Document): + tensor: Tensor + + d = MyDocument(tensor=np.zeros((3, 224, 224))) + + assert isinstance(d.tensor, Tensor) + assert isinstance(d.tensor, np.ndarray) + assert (d.tensor == np.zeros((3, 224, 224))).all() diff --git a/tests/units/document/proto/test_document_proto.py b/tests/units/document/proto/test_document_proto.py index 4ae299d3347..f0918b5b721 100644 --- a/tests/units/document/proto/test_document_proto.py +++ b/tests/units/document/proto/test_document_proto.py @@ -4,6 +4,7 @@ from docarray import DocumentArray from docarray.document import BaseDocument +from docarray.typing import Tensor def test_proto_simple(): @@ -17,7 +18,7 @@ class CustomDoc(BaseDocument): def test_proto_ndarray(): class CustomDoc(BaseDocument): - tensor: np.ndarray + tensor: Tensor tensor = np.zeros((3, 224, 224)) doc = CustomDoc(tensor=tensor) @@ -29,7 +30,7 @@ class CustomDoc(BaseDocument): def test_proto_with_nested_doc(): class CustomInnerDoc(BaseDocument): - tensor: np.ndarray + tensor: Tensor class CustomDoc(BaseDocument): text: str @@ -42,7 +43,7 @@ class CustomDoc(BaseDocument): def test_proto_with_chunks_doc(): class CustomInnerDoc(BaseDocument): - tensor: np.ndarray + tensor: Tensor class CustomDoc(BaseDocument): text: str diff --git a/tests/units/document/proto/test_proto_based_object.py b/tests/units/document/proto/test_proto_based_object.py index 6f9acf77245..c2fd61773c8 100644 --- a/tests/units/document/proto/test_proto_based_object.py +++ b/tests/units/document/proto/test_proto_based_object.py @@ -1,7 +1,7 @@ import numpy as np from docarray.proto import DocumentProto, NdArrayProto, NodeProto -from docarray.proto.io import flush_ndarray, read_ndarray +from docarray.typing import Tensor def test_nested_item_proto(): @@ -16,9 +16,9 @@ def test_nested_optional_item_proto(): def test_ndarray(): nd_proto = NdArrayProto() original_tensor = np.zeros((3, 224, 224)) - flush_ndarray(nd_proto, value=original_tensor) + Tensor.flush_ndarray(nd_proto, value=original_tensor) nested_item = NodeProto(tensor=nd_proto) - tensor = read_ndarray(nested_item.tensor) + tensor = Tensor.read_ndarray(nested_item.tensor) assert (tensor == original_tensor).all() @@ -31,7 +31,7 @@ def test_document_proto_set(): nd_proto = NdArrayProto() original_tensor = np.zeros((3, 224, 224)) - flush_ndarray(nd_proto, value=original_tensor) + Tensor.flush_ndarray(nd_proto, value=original_tensor) nested_item2 = NodeProto(tensor=nd_proto) diff --git a/tests/units/typing/test_image_url.py b/tests/units/typing/test_image_url.py index ef21c8de4ee..74c511dba43 100644 --- a/tests/units/typing/test_image_url.py +++ b/tests/units/typing/test_image_url.py @@ -1,3 +1,4 @@ +import numpy as np from pydantic.tools import parse_obj_as from docarray.typing import ImageUrl, Tensor @@ -8,4 +9,4 @@ def test_image_url(): tensor = uri.load() - assert isinstance(tensor, Tensor) + assert isinstance(tensor, np.ndarray) From 4a0f7bfdd0c07cf26df2e1fd589fa83792b59ea8 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 11 Nov 2022 16:35:09 +0100 Subject: [PATCH 2/2] fix: fix mypy typing --- docarray/typing/tensor/tensor.py | 22 ++++++++++------------ docarray/typing/url/image_url.py | 4 +--- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index 8a1c52ebac4..179eaca3596 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -1,15 +1,14 @@ -from typing import Union, TypeVar, Any, TYPE_CHECKING +from typing import Union, TypeVar, Any, TYPE_CHECKING, Type, cast import numpy as np if TYPE_CHECKING: from pydantic.fields import ModelField - from pydantic import BaseConfig + from pydantic import BaseConfig, PydanticValueError from docarray.document.base_node import BaseNode from docarray.proto import DocumentProto, NdArrayProto, NodeProto -from pydantic import ValidationError -T = TypeVar('T', bound=np.ndarray) +T = TypeVar('T', bound='Tensor') class Tensor(np.ndarray, BaseNode): @@ -21,25 +20,24 @@ def __get_validators__(cls): yield cls.validate @classmethod - def validate(cls: T, value: Union[T, Any], field: 'ModelField', config: 'BaseConfig') -> T: + def validate(cls: Type[T], value: Union[T, Any], field: 'ModelField', config: 'BaseConfig') -> T: if isinstance(value, np.ndarray): return cls.from_ndarray(value) elif isinstance(value, Tensor): - return value + return cast(T, value) else: try: - arr = np.ndarray(value) + arr: np.ndarray = np.ndarray(value) return cls.from_ndarray(arr) except Exception: pass # handled below - raise ValidationError(f'Expected a numpy.ndarray, got {type(value)}') + raise ValueError(f'Expected a numpy.ndarray, got {type(value)}') @classmethod - def from_ndarray(cls, value: np.ndarray) -> T: + def from_ndarray(cls: Type[T], value: np.ndarray) -> T: return value.view(cls) - - def _to_nested_item_protobuf(self) -> 'NodeProto': + def _to_nested_item_protobuf(self: T) -> 'NodeProto': """Convert Document into a nested item protobuf message. This function should be called when the Document is nested into another Document that need to be converted into a protobuf @@ -51,7 +49,7 @@ def _to_nested_item_protobuf(self) -> 'NodeProto': return NodeProto(tensor=nd_proto) @classmethod - def read_ndarray(cls, pb_msg: 'NdArrayProto') -> 'Tensor': + def read_ndarray(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': """ read ndarray from a proto msg :param pb_msg: diff --git a/docarray/typing/url/image_url.py b/docarray/typing/url/image_url.py index 4e9338e13d5..68d956fb87d 100644 --- a/docarray/typing/url/image_url.py +++ b/docarray/typing/url/image_url.py @@ -1,12 +1,10 @@ import numpy as np -from docarray.typing import Tensor - from .any_url import AnyUrl class ImageUrl(AnyUrl): - def load(self) -> Tensor: + def load(self) -> np.ndarray: """ transform the url in a image Tensor