diff --git a/docarray/array/array.py b/docarray/array/array.py
index 681bbe19e02..ba8f750463f 100644
--- a/docarray/array/array.py
+++ b/docarray/array/array.py
@@ -34,12 +34,13 @@ class _DocumenArrayTyped(DocumentArray):
 
         for field in _DocumenArrayTyped.document_type.__fields__.keys():
 
-            def _proprety_generator(val: str):
+            def _property_generator(val: str):
                 return property(lambda self: self._get_documents_attribute(val))
 
-            setattr(_DocumenArrayTyped, field, _proprety_generator(field))
+            setattr(_DocumenArrayTyped, field, _property_generator(field))
             # this generates property on the fly based on the schema of the item
 
-        _DocumenArrayTyped.__name__ = f'DocumentArray{item.__name__}'
+        _DocumenArrayTyped.__name__ = f'DocumentArray[{item.__name__}]'
+        _DocumenArrayTyped.__qualname__ = f'DocumentArray[{item.__name__}]'
 
         return _DocumenArrayTyped
diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py
new file mode 100644
index 00000000000..117e8fab4cb
--- /dev/null
+++ b/docarray/typing/tensor/abstract_tensor.py
@@ -0,0 +1,72 @@
+import abc
+from abc import ABC
+from typing import TYPE_CHECKING, Any, Generic, Tuple, Type, TypeVar
+
+from docarray.typing.abstract_type import AbstractType
+
+if TYPE_CHECKING:
+    from pydantic import BaseConfig
+    from pydantic.fields import ModelField
+
+T = TypeVar('T', bound='AbstractTensor')
+ShapeT = TypeVar('ShapeT')
+
+
+class AbstractTensor(AbstractType, Generic[ShapeT], ABC):
+
+    __parametrized_meta__ = type
+
+    @classmethod
+    @abc.abstractmethod
+    def __validate_shape__(cls, t: T, shape: Tuple[int, ...]) -> T:
+        """Every tensor has to implement this method in order to
+        enable syntax of the form Tensor[shape].
+        The intended behaviour is as follows:
+        - If the shape of `t` is equal to `shape`, return `t`.
+        - If the shape of `t` is not equal to `shape`,
+            but can be reshaped to `shape`, return `t` reshaped to `shape`.
+        - If the shape of `t` is not equal to `shape`
+            and cannot be reshaped to `shape`, raise a ValueError.
+
+        :param t: The tensor to validate.
+        :param shape: The shape to validate against.
+        :return: The validated tensor.
+        """
+        ...
+
+    @classmethod
+    def _create_parametrized_type(cls: Type[T], shape: Tuple[int, ...]):
+        shape_str = ', '.join([str(s) for s in shape])
+
+        class _ParametrizedTensor(
+            cls,  # type: ignore
+            metaclass=cls.__parametrized_meta__,  # type: ignore
+        ):
+            _docarray_target_shape = shape
+
+            @classmethod
+            def validate(
+                _cls,
+                value: Any,
+                field: 'ModelField',
+                config: 'BaseConfig',
+            ):
+                t = super().validate(value, field, config)
+                return _cls.__validate_shape__(t, _cls._docarray_target_shape)
+
+        # `__name__` assigned inside a class body only becomes a plain attribute;
+        # the class name itself has to be set on the created class object.
+        _ParametrizedTensor.__name__ = f'{cls.__name__}[{shape_str}]'
+        _ParametrizedTensor.__qualname__ = f'{cls.__qualname__}[{shape_str}]'
+
+        return _ParametrizedTensor
+
+    def __class_getitem__(cls, item):
+        if isinstance(item, int):
+            item = (item,)
+        try:
+            item = tuple(item)
+        except TypeError as err:
+            raise TypeError(f'{item} is not a valid tensor shape.') from err
+
+        return cls._create_parametrized_type(item)
diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py
index 2b19dba179e..1e0ee9174fd 100644
--- a/docarray/typing/tensor/ndarray.py
+++ b/docarray/typing/tensor/ndarray.py
@@ -1,8 +1,20 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, cast
+import warnings
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    List,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import numpy as np
 
-from docarray.typing.abstract_type import AbstractType
+from docarray.typing.tensor.abstract_tensor import AbstractTensor
 
 if TYPE_CHECKING:
     from pydantic.fields import ModelField
@@ -11,9 +23,52 @@
     from docarray.proto import NdArrayProto, NodeProto
 
 T = TypeVar('T', bound='NdArray')
+ShapeT = TypeVar('ShapeT')
 
 
-class NdArray(np.ndarray, AbstractType):
+class NdArray(AbstractTensor, np.ndarray, Generic[ShapeT]):
+    """
+    Subclass of np.ndarray, intended for use in a Document.
+    This enables (de)serialization from/to protobuf and json, data validation,
+    and coercion from compatible types like torch.Tensor.
+
+    This type can also be used in a parametrized way, specifying the shape of the array.
+
+    EXAMPLE USAGE
+
+    .. code-block:: python
+
+        from docarray import Document
+        from docarray.typing import NdArray
+        import numpy as np
+
+
+        class MyDoc(Document):
+            arr: NdArray
+            image_arr: NdArray[3, 224, 224]
+
+
+        # create a document with tensors
+        doc = MyDoc(
+            arr=np.zeros((128,)),
+            image_arr=np.zeros((3, 224, 224)),
+        )
+        assert doc.image_arr.shape == (3, 224, 224)
+
+        # automatic shape conversion
+        doc = MyDoc(
+            arr=np.zeros((128,)),
+            image_arr=np.zeros((224, 224, 3)),  # will reshape to (3, 224, 224)
+        )
+        assert doc.image_arr.shape == (3, 224, 224)
+
+        # !! The following will raise an error due to shape mismatch !!
+        doc = MyDoc(
+            arr=np.zeros((128,)),
+            image_arr=np.zeros((224, 224)),  # this will fail validation
+        )
+    """
+
     @classmethod
     def __get_validators__(cls):
         # one or more validators may be yielded which will be called in the
@@ -21,6 +76,24 @@ def __get_validators__(cls):
         # the value returned from the previous validator
         yield cls.validate
 
+    @classmethod
+    def __validate_shape__(cls, t: T, shape: Tuple[int, ...]) -> T:  # type: ignore
+        if t.shape == shape:
+            return t
+        else:
+            warnings.warn(
+                f'Tensor shape mismatch. Reshaping array '
+                f'of shape {t.shape} to shape {shape}'
+            )
+            try:
+                # np.reshape raises ValueError when the sizes do not match
+                value = cls.from_ndarray(np.reshape(t, shape))
+                return cast(T, value)
+            except ValueError as err:
+                raise ValueError(
+                    f'Cannot reshape array of shape {t.shape} to shape {shape}'
+                ) from err
+
     @classmethod
     def validate(
         cls: Type[T],
diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py
index ad95401f161..b564908be75 100644
--- a/docarray/typing/tensor/torch_tensor.py
+++ b/docarray/typing/tensor/torch_tensor.py
@@ -1,10 +1,11 @@
+import warnings
 from copy import copy
-from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, Generic, Tuple, Type, TypeVar, Union, cast
 
 import numpy as np
 import torch  # type: ignore
 
-from docarray.typing.abstract_type import AbstractType
+from docarray.typing.tensor.abstract_tensor import AbstractTensor
 
 if TYPE_CHECKING:
     from pydantic.fields import ModelField
@@ -15,6 +16,7 @@
     from docarray.proto import NdArrayProto, NodeProto
 
 T = TypeVar('T', bound='TorchTensor')
+ShapeT = TypeVar('ShapeT')
 
 torch_base = type(torch.Tensor)  # type: Any
 node_base = type(BaseNode)  # type: Any
@@ -24,9 +26,53 @@ class metaTorchAndNode(torch_base, node_base):
     pass
 
 
-class TorchTensor(AbstractType, torch.Tensor, metaclass=metaTorchAndNode):
+class TorchTensor(
+    AbstractTensor, torch.Tensor, Generic[ShapeT], metaclass=metaTorchAndNode
+):
     # Subclassing torch.Tensor following the advice from here:
     # https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor
+    """
+    Subclass of torch.Tensor, intended for use in a Document.
+    This enables (de)serialization from/to protobuf and json, data validation,
+    and coercion from compatible types like numpy.ndarray.
+
+    This type can also be used in a parametrized way,
+    specifying the shape of the tensor.
+
+    EXAMPLE USAGE
+
+    .. code-block:: python
+
+        from docarray import Document
+        from docarray.typing import TorchTensor
+        import torch
+
+
+        class MyDoc(Document):
+            tensor: TorchTensor
+            image_tensor: TorchTensor[3, 224, 224]
+
+
+        # create a document with tensors
+        doc = MyDoc(
+            tensor=torch.zeros(128),
+            image_tensor=torch.zeros(3, 224, 224),
+        )
+
+        # automatic shape conversion
+        doc = MyDoc(
+            tensor=torch.zeros(128),
+            image_tensor=torch.zeros(224, 224, 3),  # will reshape to (3, 224, 224)
+        )
+
+        # !! The following will raise an error due to shape mismatch !!
+        doc = MyDoc(
+            tensor=torch.zeros(128),
+            image_tensor=torch.zeros(224, 224),  # this will fail validation
+        )
+
+    """
+
     @classmethod
     def __get_validators__(cls):
         # one or more validators may be yielded which will be called in the
@@ -34,6 +80,26 @@ def __get_validators__(cls):
         # the value returned from the previous validator
         yield cls.validate
 
+    __parametrized_meta__ = metaTorchAndNode
+
+    @classmethod
+    def __validate_shape__(cls, t: T, shape: Tuple[int, ...]) -> T:  # type: ignore
+        if t.shape == shape:
+            return t
+        else:
+            warnings.warn(
+                f'Tensor shape mismatch. Reshaping tensor '
+                f'of shape {t.shape} to shape {shape}'
+            )
+            try:
+                # reshape, unlike view, also works on non-contiguous tensors
+                value = cls.from_native_torch_tensor(t.reshape(shape))
+                return cast(T, value)
+            except RuntimeError as err:
+                raise ValueError(
+                    f'Cannot reshape tensor of shape {t.shape} to shape {shape}'
+                ) from err
+
     @classmethod
     def validate(
         cls: Type[T],
diff --git a/tests/integrations/document/test_proto.py b/tests/integrations/document/test_proto.py
index 0092aabc0b7..98f05f09e15 100644
--- a/tests/integrations/document/test_proto.py
+++ b/tests/integrations/document/test_proto.py
@@ -35,7 +35,9 @@ class MyDoc(Document):
         txt_url: TextUrl
         any_url: AnyUrl
         torch_tensor: TorchTensor
+        torch_tensor_param: TorchTensor[224, 224, 3]
         np_array: NdArray
+        np_array_param: NdArray[224, 224, 3]
         generic_nd_array: Tensor
         generic_torch_tensor: Tensor
         embedding: Embedding
@@ -45,7 +47,9 @@ class MyDoc(Document):
         txt_url='test.txt',
         any_url='www.jina.ai',
         torch_tensor=torch.zeros((3, 224, 224)),
+        torch_tensor_param=torch.zeros((3, 224, 224)),
         np_array=np.zeros((3, 224, 224)),
+        np_array_param=np.zeros((3, 224, 224)),
         generic_nd_array=np.zeros((3, 224, 224)),
         generic_torch_tensor=torch.zeros((3, 224, 224)),
         embedding=np.zeros((3, 224, 224)),
@@ -55,12 +59,23 @@ class MyDoc(Document):
     assert doc.img_url == 'test.png'
     assert doc.txt_url == 'test.txt'
     assert doc.any_url == 'www.jina.ai'
+
     assert (doc.torch_tensor == torch.zeros((3, 224, 224))).all()
     assert isinstance(doc.torch_tensor, torch.Tensor)
+
+    assert (doc.torch_tensor_param == torch.zeros((224, 224, 3))).all()
+    assert isinstance(doc.torch_tensor_param, torch.Tensor)
+
     assert (doc.np_array == np.zeros((3, 224, 224))).all()
     assert isinstance(doc.np_array, np.ndarray)
+
+    assert (doc.np_array_param == np.zeros((224, 224, 3))).all()
+    assert isinstance(doc.np_array_param, np.ndarray)
+
     assert (doc.generic_nd_array == np.zeros((3, 224, 224))).all()
     assert isinstance(doc.generic_nd_array, np.ndarray)
+
     assert (doc.generic_torch_tensor == torch.zeros((3, 224, 224))).all()
     assert isinstance(doc.generic_torch_tensor, torch.Tensor)
+
     assert (doc.embedding == np.zeros((3, 224, 224))).all()
diff --git a/tests/units/typing/tensor/test_tensor.py b/tests/units/typing/tensor/test_tensor.py
index eb6b966f24d..d0ef58bc581 100644
--- a/tests/units/typing/tensor/test_tensor.py
+++ b/tests/units/typing/tensor/test_tensor.py
@@ -1,5 +1,6 @@
 import numpy as np
 import orjson
+import pytest
 from pydantic.tools import parse_obj_as, schema_json_of
 
 from docarray.document.io.json import orjson_dumps
@@ -47,3 +48,27 @@ def test_unwrap():
     assert isinstance(ndarray, np.ndarray)
     assert isinstance(tensor, NdArray)
     assert (ndarray == np.zeros((3, 224, 224))).all()
+
+
+def test_parametrized():
+    # correct shape, single axis
+    tensor = parse_obj_as(NdArray[128], np.zeros(128))
+    assert isinstance(tensor, NdArray)
+    assert isinstance(tensor, np.ndarray)
+    assert tensor.shape == (128,)
+
+    # correct shape, multiple axis
+    tensor = parse_obj_as(NdArray[3, 224, 224], np.zeros((3, 224, 224)))
+    assert isinstance(tensor, NdArray)
+    assert isinstance(tensor, np.ndarray)
+    assert tensor.shape == (3, 224, 224)
+
+    # wrong but reshapable shape
+    tensor = parse_obj_as(NdArray[3, 224, 224], np.zeros((224, 224, 3)))
+    assert isinstance(tensor, NdArray)
+    assert isinstance(tensor, np.ndarray)
+    assert tensor.shape == (3, 224, 224)
+
+    # wrong and not reshapable shape
+    with pytest.raises(ValueError):
+        parse_obj_as(NdArray[3, 224, 224], np.zeros((224, 224)))
diff --git a/tests/units/typing/tensor/test_torch_tensor.py b/tests/units/typing/tensor/test_torch_tensor.py
new file mode 100644
index 00000000000..ac064d5489d
--- /dev/null
+++ b/tests/units/typing/tensor/test_torch_tensor.py
@@ -0,0 +1,59 @@
+import pytest
+import torch
+from pydantic.tools import parse_obj_as, schema_json_of
+
+from docarray.document.io.json import orjson_dumps
+from docarray.typing import TorchTensor
+
+
+def test_proto_tensor():
+
+    tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224))
+
+    tensor._to_node_protobuf()
+
+
+def test_json_schema():
+    schema_json_of(TorchTensor)
+
+
+def test_dump_json():
+    tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224))
+    orjson_dumps(tensor)
+
+
+def test_unwrap():
+    tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224))
+    ndarray = tensor.unwrap()
+
+    assert not isinstance(ndarray, TorchTensor)
+    assert isinstance(tensor, TorchTensor)
+    assert isinstance(ndarray, torch.Tensor)
+
+    assert tensor.data_ptr() == ndarray.data_ptr()
+
+    assert (ndarray == torch.zeros(3, 224, 224)).all()
+
+
+def test_parametrized():
+    # correct shape, single axis
+    tensor = parse_obj_as(TorchTensor[128], torch.zeros(128))
+    assert isinstance(tensor, TorchTensor)
+    assert isinstance(tensor, torch.Tensor)
+    assert tensor.shape == (128,)
+
+    # correct shape, multiple axis
+    tensor = parse_obj_as(TorchTensor[3, 224, 224], torch.zeros(3, 224, 224))
+    assert isinstance(tensor, TorchTensor)
+    assert isinstance(tensor, torch.Tensor)
+    assert tensor.shape == (3, 224, 224)
+
+    # wrong but reshapable shape
+    tensor = parse_obj_as(TorchTensor[3, 224, 224], torch.zeros(224, 3, 224))
+    assert isinstance(tensor, TorchTensor)
+    assert isinstance(tensor, torch.Tensor)
+    assert tensor.shape == (3, 224, 224)
+
+    # wrong and not reshapable shape
+    with pytest.raises(ValueError):
+        parse_obj_as(TorchTensor[3, 224, 224], torch.zeros(224, 224))
diff --git a/tests/units/typing/test_torch_tensor.py b/tests/units/typing/test_torch_tensor.py
deleted file mode 100644
index 7d3081f86e3..00000000000
--- a/tests/units/typing/test_torch_tensor.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import torch
-from pydantic.tools import parse_obj_as, schema_json_of
-
-from docarray.document.io.json import orjson_dumps
-from docarray.typing import TorchTensor
-
-
-def test_proto_tensor():
-
-    tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224))
-
-    tensor._to_node_protobuf()
-
-
-def test_json_schema():
-    schema_json_of(TorchTensor)
-
-
-def test_dump_json():
-    tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224))
-    orjson_dumps(tensor)
-
-
-def test_unwrap():
-    tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224))
-    ndarray = tensor.unwrap()
-
-    assert not isinstance(ndarray, TorchTensor)
-    assert isinstance(tensor, TorchTensor)
-    assert isinstance(ndarray, torch.Tensor)
-
-    assert tensor.data_ptr() == ndarray.data_ptr()
-
-    assert (ndarray == torch.zeros(3, 224, 224)).all()