diff --git a/README.md b/README.md index dd2999672d0..b76423aab2a 100644 --- a/README.md +++ b/README.md @@ -36,15 +36,14 @@ This follow [Pydantic Model](https://pydantic-docs.helpmanual.io/usage/models/) It is similar to the dataclass from the (old) docarray - ```python -from docarray.typing import Tensor +from docarray.typing import NdArray import numpy as np class Banner(Document): text: str - image: Tensor + image: NdArray banner = Banner(text='DocArray is amazing', image=np.zeros((3, 224, 224))) diff --git a/docarray/document/mixins/proto.py b/docarray/document/mixins/proto.py index 78edd99c4f5..5b6f3205460 100644 --- a/docarray/document/mixins/proto.py +++ b/docarray/document/mixins/proto.py @@ -8,7 +8,7 @@ AnyUrl, Embedding, ImageUrl, - Tensor, + NdArray, TextUrl, TorchTensor, ) @@ -31,7 +31,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: # this if else statement need to be refactored it is too long # the check should be delegated to the type level content_type_dict = dict( - tensor=Tensor, + ndarray=NdArray, torch_tensor=TorchTensor, embedding=Embedding, any_url=AnyUrl, diff --git a/docarray/proto/docarray.proto b/docarray/proto/docarray.proto index 747cd7bf203..e3843c3d43f 100644 --- a/docarray/proto/docarray.proto +++ b/docarray/proto/docarray.proto @@ -36,7 +36,7 @@ message NodeProto { bytes blob = 1; // the ndarray of the image/audio/video document - NdArrayProto tensor = 2; + NdArrayProto ndarray = 2; // a text string text = 3; diff --git a/docarray/proto/pb2/docarray_pb2.py b/docarray/proto/pb2/docarray_pb2.py index f7813a5c4c3..1784b9cbe5e 100644 --- a/docarray/proto/pb2/docarray_pb2.py +++ b/docarray/proto/pb2/docarray_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 
\x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xe2\x02\n\tNodeProto\x12\x0e\n\x04\x62lob\x18\x01 \x01(\x0cH\x00\x12(\n\x06tensor\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x0e\n\x04text\x18\x03 \x01(\tH\x00\x12)\n\x06nested\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12.\n\x06\x63hunks\x18\x05 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12+\n\tembedding\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x11\n\x07\x61ny_url\x18\x07 \x01(\tH\x00\x12\x13\n\timage_url\x18\x08 \x01(\tH\x00\x12\x12\n\x08text_url\x18\t \x01(\tH\x00\x12\x0c\n\x02id\x18\n \x01(\tH\x00\x12.\n\x0ctorch_tensor\x18\x0b \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProtob\x06proto3' + b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xe3\x02\n\tNodeProto\x12\x0e\n\x04\x62lob\x18\x01 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x0e\n\x04text\x18\x03 \x01(\tH\x00\x12)\n\x06nested\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12.\n\x06\x63hunks\x18\x05 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12+\n\tembedding\x18\x06 
\x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x11\n\x07\x61ny_url\x18\x07 \x01(\tH\x00\x12\x13\n\timage_url\x18\x08 \x01(\tH\x00\x12\x12\n\x08text_url\x18\t \x01(\tH\x00\x12\x0c\n\x02id\x18\n \x01(\tH\x00\x12.\n\x0ctorch_tensor\x18\x0b \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProtob\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -30,11 +30,11 @@ _NDARRAYPROTO._serialized_start = 125 _NDARRAYPROTO._serialized_end = 228 _NODEPROTO._serialized_start = 231 - _NODEPROTO._serialized_end = 585 - _DOCUMENTPROTO._serialized_start = 588 - _DOCUMENTPROTO._serialized_end = 718 - _DOCUMENTPROTO_DATAENTRY._serialized_start = 654 - _DOCUMENTPROTO_DATAENTRY._serialized_end = 718 - _DOCUMENTARRAYPROTO._serialized_start = 720 - _DOCUMENTARRAYPROTO._serialized_end = 779 + _NODEPROTO._serialized_end = 586 + _DOCUMENTPROTO._serialized_start = 589 + _DOCUMENTPROTO._serialized_end = 719 + _DOCUMENTPROTO_DATAENTRY._serialized_start = 655 + _DOCUMENTPROTO_DATAENTRY._serialized_end = 719 + _DOCUMENTARRAYPROTO._serialized_start = 721 + _DOCUMENTARRAYPROTO._serialized_end = 780 # @@protoc_insertion_point(module_scope) diff --git a/docarray/typing/__init__.py b/docarray/typing/__init__.py index 4a56a7bd3d4..581c7ff4bd0 100644 --- a/docarray/typing/__init__.py +++ b/docarray/typing/__init__.py @@ -1,14 +1,15 @@ from docarray.typing.id import ID -from docarray.typing.tensor import Tensor, TorchTensor +from docarray.typing.tensor import NdArray, Tensor, TorchTensor from docarray.typing.tensor.embedding import Embedding from docarray.typing.url import AnyUrl, ImageUrl, TextUrl __all__ = [ 'TorchTensor', - 'Tensor', + 
'NdArray', 'Embedding', 'ImageUrl', 'TextUrl', 'AnyUrl', 'ID', + 'Tensor', ] diff --git a/docarray/typing/tensor/__init__.py b/docarray/typing/tensor/__init__.py index d195a58053d..5bb6553a9b8 100644 --- a/docarray/typing/tensor/__init__.py +++ b/docarray/typing/tensor/__init__.py @@ -1,4 +1,5 @@ +from docarray.typing.tensor.ndarray import NdArray from docarray.typing.tensor.tensor import Tensor from docarray.typing.tensor.torch_tensor import TorchTensor -__all__ = ['Tensor', 'TorchTensor'] +__all__ = ['NdArray', 'TorchTensor', 'Tensor'] diff --git a/docarray/typing/tensor/embedding.py b/docarray/typing/tensor/embedding.py index eb6caa3243a..a77db717098 100644 --- a/docarray/typing/tensor/embedding.py +++ b/docarray/typing/tensor/embedding.py @@ -1,12 +1,12 @@ from typing import TypeVar from docarray.proto import NodeProto -from docarray.typing.tensor import Tensor +from docarray.typing.tensor import NdArray T = TypeVar('T', bound='Embedding') -class Embedding(Tensor): +class Embedding(NdArray): def _to_node_protobuf(self: T, field: str = 'tensor') -> NodeProto: """Convert Document into a NodeProto protobuf message. 
This function should be called when the Document is nested into another Document that need to be diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py new file mode 100644 index 00000000000..2b19dba179e --- /dev/null +++ b/docarray/typing/tensor/ndarray.py @@ -0,0 +1,121 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, cast + +import numpy as np + +from docarray.typing.abstract_type import AbstractType + +if TYPE_CHECKING: + from pydantic.fields import ModelField + from pydantic import BaseConfig + +from docarray.proto import NdArrayProto, NodeProto + +T = TypeVar('T', bound='NdArray') + + +class NdArray(np.ndarray, AbstractType): + @classmethod + def __get_validators__(cls): + # one or more validators may be yielded which will be called in the + # order to validate the input, each validator will receive as an input + # the value returned from the previous validator + yield cls.validate + + @classmethod + def validate( + cls: Type[T], + value: Union[T, np.ndarray, List[Any], Tuple[Any], Any], + field: 'ModelField', + config: 'BaseConfig', + ) -> T: + if isinstance(value, np.ndarray): + return cls.from_ndarray(value) + elif isinstance(value, NdArray): + return cast(T, value) + elif isinstance(value, list) or isinstance(value, tuple): + try: + arr_from_list: np.ndarray = np.asarray(value) + return cls.from_ndarray(arr_from_list) + except Exception: + pass # handled below + else: + try: + arr: np.ndarray = np.ndarray(value) + return cls.from_ndarray(arr) + except Exception: + pass # handled below + raise ValueError(f'Expected a numpy.ndarray compatible type, got {type(value)}') + + @classmethod + def from_ndarray(cls: Type[T], value: np.ndarray) -> T: + return value.view(cls) + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + # this is needed to dump to json + field_schema.update(type='string', format='tensor') + + def _to_json_compatible(self) -> np.ndarray: + """ + 
Convert tensor into a json compatible object + :return: the tensor as a plain numpy ndarray + """ + return self.unwrap() + + def unwrap(self) -> np.ndarray: + """ + Return the original ndarray without any memory copy. + + The original view remains intact and is still a Document NdArray, + while the returned object is a pure np.ndarray; both objects share + the same underlying memory. + + EXAMPLE USAGE + .. code-block:: python + from docarray.typing import NdArray + import numpy as np + + t1 = NdArray.validate(np.zeros((3, 224, 224)), None, None) + # here t1 is a docarray NdArray + t2 = t1.unwrap() + # here t2 is a pure np.ndarray but t1 is still a Docarray NdArray + # But both share the same underlying memory + + + :return: a numpy ndarray + """ + return self.view(np.ndarray) + + def _to_node_protobuf(self: T, field: str = 'ndarray') -> NodeProto: + """Convert itself into a NodeProto protobuf message. This function should + be called when the Document is nested into another Document that needs to be + converted into a protobuf + :param field: field in which to store the content in the node proto + :return: the nested item protobuf message + """ + nd_proto = NdArrayProto() + self._flush_tensor_to_proto(nd_proto, value=self) + return NodeProto(**{field: nd_proto}) + + @classmethod + def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': + """ + read ndarray from a proto msg + :param pb_msg: + :return: a numpy array + """ + source = pb_msg.dense + if source.buffer: + x = np.frombuffer(source.buffer, dtype=source.dtype) + return cls.from_ndarray(x.reshape(source.shape)) + elif len(source.shape) > 0: + return cls.from_ndarray(np.zeros(source.shape)) + else: + raise ValueError(f'proto message {pb_msg} cannot be cast to a NdArray') + + @staticmethod + def _flush_tensor_to_proto(pb_msg: 'NdArrayProto', value: 'NdArray'): + pb_msg.dense.buffer = value.tobytes() + pb_msg.dense.ClearField('shape') + pb_msg.dense.shape.extend(list(value.shape)) + pb_msg.dense.dtype = 
value.dtype.str diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index 93a040a6302..5a9df3b23be 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -1,121 +1,6 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, cast +from typing import Union -import numpy as np +from docarray.typing.tensor.ndarray import NdArray +from docarray.typing.tensor.torch_tensor import TorchTensor -from docarray.typing.abstract_type import AbstractType - -if TYPE_CHECKING: - from pydantic.fields import ModelField - from pydantic import BaseConfig - -from docarray.proto import NdArrayProto, NodeProto - -T = TypeVar('T', bound='Tensor') - - -class Tensor(np.ndarray, AbstractType): - @classmethod - def __get_validators__(cls): - # one or more validators may be yielded which will be called in the - # order to validate the input, each validator will receive as an input - # the value returned from the previous validator - yield cls.validate - - @classmethod - def validate( - cls: Type[T], - value: Union[T, np.ndarray, List[Any], Tuple[Any], Any], - field: 'ModelField', - config: 'BaseConfig', - ) -> T: - if isinstance(value, np.ndarray): - return cls.from_ndarray(value) - elif isinstance(value, Tensor): - return cast(T, value) - elif isinstance(value, list) or isinstance(value, tuple): - try: - arr_from_list: np.ndarray = np.asarray(value) - return cls.from_ndarray(arr_from_list) - except Exception: - pass # handled below - else: - try: - arr: np.ndarray = np.ndarray(value) - return cls.from_ndarray(arr) - except Exception: - pass # handled below - raise ValueError(f'Expected a numpy.ndarray compatible type, got {type(value)}') - - @classmethod - def from_ndarray(cls: Type[T], value: np.ndarray) -> T: - return value.view(cls) - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - # this is needed to dump to json - field_schema.update(type='string', format='tensor') 
- - def _to_json_compatible(self) -> np.ndarray: - """ - Convert tensor into a json compatible object - :return: a list representation of the tensor - """ - return self.unwrap() - - def unwrap(self) -> np.ndarray: - """ - Return the original ndarray without any memory copy. - - The original view rest intact and is still a Document Tensor - but the return object is a pure np.ndarray but both object share - the same memory layout. - - EXAMPLE USAGE - .. code-block:: python - from docarray.typing import Tensor - import numpy as np - - t1 = Tensor.validate(np.zeros((3, 224, 224)), None, None) - # here t is a docarray Tensor - t2 = t.unwrap() - # here t2 is a pure np.ndarray but t1 is still a Docarray Tensor - # But both share the same underlying memory - - - :return: a numpy ndarray - """ - return self.view(np.ndarray) - - def _to_node_protobuf(self: T, field: str = 'tensor') -> NodeProto: - """Convert itself into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that need to be - converted into a protobuf - :param field: field in which to store the content in the node proto - :return: the nested item protobuf message - """ - nd_proto = NdArrayProto() - self._flush_tensor_to_proto(nd_proto, value=self) - return NodeProto(**{field: nd_proto}) - - @classmethod - def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': - """ - read ndarray from a proto msg - :param pb_msg: - :return: a numpy array - """ - source = pb_msg.dense - if source.buffer: - x = np.frombuffer(source.buffer, dtype=source.dtype) - return cls.from_ndarray(x.reshape(source.shape)) - elif len(source.shape) > 0: - return cls.from_ndarray(np.zeros(source.shape)) - else: - raise ValueError(f'proto message {pb_msg} cannot be cast to a Tensor') - - @staticmethod - def _flush_tensor_to_proto(pb_msg: 'NdArrayProto', value: 'Tensor'): - pb_msg.dense.buffer = value.tobytes() - pb_msg.dense.ClearField('shape') - 
pb_msg.dense.shape.extend(list(value.shape)) - pb_msg.dense.dtype = value.dtype.str +Tensor = Union[NdArray, TorchTensor] diff --git a/tests/integrations/array/test_array_proto.py b/tests/integrations/array/test_array_proto.py index 2ffef02755d..54eb4f1fafd 100644 --- a/tests/integrations/array/test_array_proto.py +++ b/tests/integrations/array/test_array_proto.py @@ -1,13 +1,13 @@ import numpy as np -from docarray import DocumentArray, Document, Image, Text -from docarray.typing import Tensor +from docarray import Document, DocumentArray, Image, Text +from docarray.typing import NdArray def test_simple_proto(): class CustomDoc(Document): text: str - tensor: Tensor + tensor: NdArray da = DocumentArray( [CustomDoc(text='hello', tensor=np.zeros((3, 224, 224))) for _ in range(10)] diff --git a/tests/integrations/document/test_proto.py b/tests/integrations/document/test_proto.py index c60cf9b1a7f..0092aabc0b7 100644 --- a/tests/integrations/document/test_proto.py +++ b/tests/integrations/document/test_proto.py @@ -1,6 +1,16 @@ import numpy as np +import torch -from docarray import DocumentArray, Document, Image, Text +from docarray import Document, Image, Text +from docarray.typing import ( + AnyUrl, + Embedding, + ImageUrl, + NdArray, + Tensor, + TextUrl, + TorchTensor, +) def test_multi_modal_doc_proto(): @@ -17,3 +27,40 @@ class MySUperDoc(Document): ) MyMultiModalDoc.from_protobuf(doc.to_protobuf()) + + +def test_all_types(): + class MyDoc(Document): + img_url: ImageUrl + txt_url: TextUrl + any_url: AnyUrl + torch_tensor: TorchTensor + np_array: NdArray + generic_nd_array: Tensor + generic_torch_tensor: Tensor + embedding: Embedding + + doc = MyDoc( + img_url='test.png', + txt_url='test.txt', + any_url='www.jina.ai', + torch_tensor=torch.zeros((3, 224, 224)), + np_array=np.zeros((3, 224, 224)), + generic_nd_array=np.zeros((3, 224, 224)), + generic_torch_tensor=torch.zeros((3, 224, 224)), + embedding=np.zeros((3, 224, 224)), + ) + doc = 
MyDoc.from_protobuf(doc.to_protobuf()) + + assert doc.img_url == 'test.png' + assert doc.txt_url == 'test.txt' + assert doc.any_url == 'www.jina.ai' + assert (doc.torch_tensor == torch.zeros((3, 224, 224))).all() + assert isinstance(doc.torch_tensor, torch.Tensor) + assert (doc.np_array == np.zeros((3, 224, 224))).all() + assert isinstance(doc.np_array, np.ndarray) + assert (doc.generic_nd_array == np.zeros((3, 224, 224))).all() + assert isinstance(doc.generic_nd_array, np.ndarray) + assert (doc.generic_torch_tensor == torch.zeros((3, 224, 224))).all() + assert isinstance(doc.generic_torch_tensor, torch.Tensor) + assert (doc.embedding == np.zeros((3, 224, 224))).all() diff --git a/tests/integrations/document/test_to_json.py b/tests/integrations/document/test_to_json.py index 2aa642d0589..81bfeb2db32 100644 --- a/tests/integrations/document/test_to_json.py +++ b/tests/integrations/document/test_to_json.py @@ -2,12 +2,12 @@ import torch from docarray.document import BaseDocument -from docarray.typing import AnyUrl, Tensor, TorchTensor +from docarray.typing import AnyUrl, NdArray, TorchTensor def test_to_json(): class Mmdoc(BaseDocument): - img: Tensor + img: NdArray url: AnyUrl txt: str torch_tensor: TorchTensor @@ -23,7 +23,7 @@ class Mmdoc(BaseDocument): def test_from_json(): class Mmdoc(BaseDocument): - img: Tensor + img: NdArray url: AnyUrl txt: str torch_tensor: TorchTensor diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index 0ebb6682518..0459a80b5a3 100644 --- a/tests/integrations/externals/test_fastapi.py +++ b/tests/integrations/externals/test_fastapi.py @@ -4,7 +4,7 @@ from httpx import AsyncClient from docarray import Document, Image, Text -from docarray.typing import Tensor +from docarray.typing import NdArray @pytest.mark.asyncio @@ -40,8 +40,8 @@ class InputDoc(Document): img: Image class OutputDoc(Document): - embedding_clip: Tensor - embedding_bert: Tensor + embedding_clip: NdArray + 
embedding_bert: NdArray input_doc = InputDoc(img=Image(tensor=np.zeros((3, 224, 224)))) @@ -70,8 +70,8 @@ class InputDoc(Document): text: str class OutputDoc(Document): - embedding_clip: Tensor - embedding_bert: Tensor + embedding_clip: NdArray + embedding_bert: NdArray input_doc = InputDoc(text='hello') diff --git a/tests/integrations/typing/test_ndarray.py b/tests/integrations/typing/test_ndarray.py new file mode 100644 index 00000000000..75d294def09 --- /dev/null +++ b/tests/integrations/typing/test_ndarray.py @@ -0,0 +1,15 @@ +import numpy as np + +from docarray import Document +from docarray.typing import NdArray + + +def test_set_tensor(): + class MyDocument(Document): + tensor: NdArray + + d = MyDocument(tensor=np.zeros((3, 224, 224))) + + assert isinstance(d.tensor, NdArray) + assert isinstance(d.tensor, np.ndarray) + assert (d.tensor == np.zeros((3, 224, 224))).all() diff --git a/tests/integrations/typing/test_tensor.py b/tests/integrations/typing/test_tensor.py index 880e1df65a8..23c2a86ef67 100644 --- a/tests/integrations/typing/test_tensor.py +++ b/tests/integrations/typing/test_tensor.py @@ -1,7 +1,8 @@ import numpy as np +import torch from docarray import Document -from docarray.typing import Tensor +from docarray.typing import NdArray, Tensor, TorchTensor def test_set_tensor(): @@ -10,6 +11,12 @@ class MyDocument(Document): d = MyDocument(tensor=np.zeros((3, 224, 224))) - assert isinstance(d.tensor, Tensor) + assert isinstance(d.tensor, NdArray) assert isinstance(d.tensor, np.ndarray) assert (d.tensor == np.zeros((3, 224, 224))).all() + + d = MyDocument(tensor=torch.zeros((3, 224, 224))) + + assert isinstance(d.tensor, TorchTensor) + assert isinstance(d.tensor, torch.Tensor) + assert (d.tensor == torch.zeros((3, 224, 224))).all() diff --git a/tests/integrations/typing/test_typing_proto.py b/tests/integrations/typing/test_typing_proto.py index e9290b8b081..e1c839b5d48 100644 --- a/tests/integrations/typing/test_typing_proto.py +++ 
b/tests/integrations/typing/test_typing_proto.py @@ -3,12 +3,12 @@ from docarray import Document from docarray.document import AnyDocument -from docarray.typing import AnyUrl, Embedding, ImageUrl, Tensor, TextUrl, TorchTensor +from docarray.typing import AnyUrl, Embedding, ImageUrl, NdArray, TextUrl, TorchTensor def test_proto_all_types(): class Mymmdoc(Document): - tensor: Tensor + tensor: NdArray torch_tensor: TorchTensor embedding: Embedding any_url: AnyUrl diff --git a/tests/units/array/test_mixins/test_attribute.py b/tests/units/array/test_mixins/test_attribute.py index 2f83110143e..1a168bcd0f0 100644 --- a/tests/units/array/test_mixins/test_attribute.py +++ b/tests/units/array/test_mixins/test_attribute.py @@ -2,13 +2,13 @@ from docarray.array import DocumentArray from docarray.document import BaseDocument -from docarray.typing import Tensor +from docarray.typing import NdArray def test_get_bulk_attributes_function(): class Mmdoc(BaseDocument): text: str - tensor: Tensor + tensor: NdArray N = 10 @@ -32,7 +32,7 @@ class Mmdoc(BaseDocument): def test_get_bulk_attributes(): class Mmdoc(BaseDocument): text: str - tensor: Tensor + tensor: NdArray N = 10 diff --git a/tests/units/document/proto/test_document_proto.py b/tests/units/document/proto/test_document_proto.py index a842b53eda6..37a3a02414d 100644 --- a/tests/units/document/proto/test_document_proto.py +++ b/tests/units/document/proto/test_document_proto.py @@ -5,7 +5,7 @@ from docarray import DocumentArray from docarray.document import BaseDocument -from docarray.typing import Tensor, TorchTensor +from docarray.typing import NdArray, TorchTensor def test_proto_simple(): @@ -19,7 +19,7 @@ class CustomDoc(BaseDocument): def test_proto_ndarray(): class CustomDoc(BaseDocument): - tensor: Tensor + tensor: NdArray tensor = np.zeros((3, 224, 224)) doc = CustomDoc(tensor=tensor) @@ -31,7 +31,7 @@ class CustomDoc(BaseDocument): def test_proto_with_nested_doc(): class CustomInnerDoc(BaseDocument): - tensor: Tensor + 
tensor: NdArray class CustomDoc(BaseDocument): text: str @@ -44,7 +44,7 @@ class CustomDoc(BaseDocument): def test_proto_with_chunks_doc(): class CustomInnerDoc(BaseDocument): - tensor: Tensor + tensor: NdArray class CustomDoc(BaseDocument): text: str diff --git a/tests/units/document/proto/test_proto_based_object.py b/tests/units/document/proto/test_proto_based_object.py index d4e3d47f504..d986e86086f 100644 --- a/tests/units/document/proto/test_proto_based_object.py +++ b/tests/units/document/proto/test_proto_based_object.py @@ -1,7 +1,7 @@ import numpy as np from docarray.proto import DocumentProto, NdArrayProto, NodeProto -from docarray.typing import Tensor +from docarray.typing import NdArray def test_nested_item_proto(): @@ -15,12 +15,12 @@ def test_nested_optional_item_proto(): def test_ndarray(): nd_proto = NdArrayProto() - original_tensor = np.zeros((3, 224, 224)) - Tensor._flush_tensor_to_proto(nd_proto, value=original_tensor) - nested_item = NodeProto(tensor=nd_proto) - tensor = Tensor.from_protobuf(nested_item.tensor) + original_ndarray = np.zeros((3, 224, 224)) + NdArray._flush_tensor_to_proto(nd_proto, value=original_ndarray) + nested_item = NodeProto(ndarray=nd_proto) + tensor = NdArray.from_protobuf(nested_item.ndarray) - assert (tensor == original_tensor).all() + assert (tensor == original_ndarray).all() def test_document_proto_set(): @@ -30,10 +30,10 @@ def test_document_proto_set(): nested_item1 = NodeProto(text='hello') nd_proto = NdArrayProto() - original_tensor = np.zeros((3, 224, 224)) - Tensor._flush_tensor_to_proto(nd_proto, value=original_tensor) + original_ndarray = np.zeros((3, 224, 224)) + NdArray._flush_tensor_to_proto(nd_proto, value=original_ndarray) - nested_item2 = NodeProto(tensor=nd_proto) + nested_item2 = NodeProto(ndarray=nd_proto) data['a'] = nested_item1 data['b'] = nested_item2 diff --git a/tests/units/document/test_any_document.py b/tests/units/document/test_any_document.py index bfa8dfb7e91..0f6a82e4e10 100644 --- 
a/tests/units/document/test_any_document.py +++ b/tests/units/document/test_any_document.py @@ -1,13 +1,13 @@ import numpy as np from docarray.document import AnyDocument, BaseDocument -from docarray.typing import Tensor +from docarray.typing import NdArray def test_any_doc(): class InnerDocument(BaseDocument): text: str - tensor: Tensor + tensor: NdArray class CustomDoc(BaseDocument): inner: InnerDocument diff --git a/tests/units/typing/tensor/test_tensor.py b/tests/units/typing/tensor/test_tensor.py index fbe5d72a50e..eb6b966f24d 100644 --- a/tests/units/typing/tensor/test_tensor.py +++ b/tests/units/typing/tensor/test_tensor.py @@ -3,33 +3,33 @@ from pydantic.tools import parse_obj_as, schema_json_of from docarray.document.io.json import orjson_dumps -from docarray.typing import Tensor +from docarray.typing import NdArray def test_proto_tensor(): - tensor = parse_obj_as(Tensor, np.zeros((3, 224, 224))) + tensor = parse_obj_as(NdArray, np.zeros((3, 224, 224))) tensor._to_node_protobuf() def test_from_list(): - tensor = parse_obj_as(Tensor, [[0.0, 0.0], [0.0, 0.0]]) + tensor = parse_obj_as(NdArray, [[0.0, 0.0], [0.0, 0.0]]) assert (tensor == np.zeros((2, 2))).all() def test_json_schema(): - schema_json_of(Tensor) + schema_json_of(NdArray) def test_dump_json(): - tensor = parse_obj_as(Tensor, np.zeros((3, 224, 224))) + tensor = parse_obj_as(NdArray, np.zeros((3, 224, 224))) orjson_dumps(tensor) def test_load_json(): - tensor = parse_obj_as(Tensor, np.zeros((2, 2))) + tensor = parse_obj_as(NdArray, np.zeros((2, 2))) json = orjson_dumps(tensor) print(json) @@ -40,10 +40,10 @@ def test_load_json(): def test_unwrap(): - tensor = parse_obj_as(Tensor, np.zeros((3, 224, 224))) + tensor = parse_obj_as(NdArray, np.zeros((3, 224, 224))) ndarray = tensor.unwrap() - assert not isinstance(ndarray, Tensor) + assert not isinstance(ndarray, NdArray) assert isinstance(ndarray, np.ndarray) - assert isinstance(tensor, Tensor) + assert isinstance(tensor, NdArray) assert (ndarray == 
np.zeros((3, 224, 224))).all()