From 03daff1cbdb1d793017019d4d7e4a44e913d76fc Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 25 Nov 2022 12:12:31 +0100 Subject: [PATCH 1/9] feat: parametrized torch tensor type Signed-off-by: Johannes Messner --- docarray/typing/tensor/torch_tensor.py | 51 +++++++++++++++- .../units/typing/tensor/test_torch_tensor.py | 59 +++++++++++++++++++ tests/units/typing/test_torch_tensor.py | 34 ----------- 3 files changed, 108 insertions(+), 36 deletions(-) create mode 100644 tests/units/typing/tensor/test_torch_tensor.py delete mode 100644 tests/units/typing/test_torch_tensor.py diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index ad95401f161..2b53cc89761 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -1,5 +1,6 @@ +import warnings from copy import copy -from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Dict, Generic, Type, TypeVar, Union, cast import numpy as np import torch # type: ignore @@ -24,7 +25,12 @@ class metaTorchAndNode(torch_base, node_base): pass -class TorchTensor(AbstractType, torch.Tensor, metaclass=metaTorchAndNode): +ShapeT = TypeVar('ShapeT') + + +class TorchTensor( + AbstractType, torch.Tensor, Generic[ShapeT], metaclass=metaTorchAndNode +): # Subclassing torch.Tensor following the advice from here: # https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor @classmethod @@ -34,6 +40,47 @@ def __get_validators__(cls): # the value returned from the previous validator yield cls.validate + def __class_getitem__(cls, item): + if not isinstance(item, tuple): + if isinstance(item, int): + item = (item,) + else: + raise TypeError(f'{item} is not a valid tensor shape.') + + class _ParametrizedTorchTensor(TorchTensor, metaclass=metaTorchAndNode): + _docaray_shape = item + + @classmethod + def validate( + cls: Type[T], + value: Union[T, np.ndarray, Any], + field: 
'ModelField', + config: 'BaseConfig', + ) -> T: + t = super().validate(value, field, config) + if t.shape == cls._docaray_shape: + return t + else: + warnings.warn( + f'Tensor shape mismatch. Reshaping tensor ' + f'of shape {t.shape} to shape {cls._docaray_shape}' + ) + try: + return cls.from_native_torch_tensor( + torch.reshape(t, cls._docaray_shape) + ) + except RuntimeError: + raise ValueError( + f'Cannot reshape tensor of ' + f'shape {t.shape} to shape {cls._docaray_shape}' + ) + + # set class name + shape_str = ', '.join([str(s) for s in item]) + _ParametrizedTorchTensor.__name__ = f'TorchTensor[{shape_str}]' + _ParametrizedTorchTensor.__qualname__ = f'TorchTensor[{shape_str}]' + return _ParametrizedTorchTensor + @classmethod def validate( cls: Type[T], diff --git a/tests/units/typing/tensor/test_torch_tensor.py b/tests/units/typing/tensor/test_torch_tensor.py new file mode 100644 index 00000000000..ac064d5489d --- /dev/null +++ b/tests/units/typing/tensor/test_torch_tensor.py @@ -0,0 +1,59 @@ +import pytest +import torch +from pydantic.tools import parse_obj_as, schema_json_of + +from docarray.document.io.json import orjson_dumps +from docarray.typing import TorchTensor + + +def test_proto_tensor(): + + tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224)) + + tensor._to_node_protobuf() + + +def test_json_schema(): + schema_json_of(TorchTensor) + + +def test_dump_json(): + tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224)) + orjson_dumps(tensor) + + +def test_unwrap(): + tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224)) + ndarray = tensor.unwrap() + + assert not isinstance(ndarray, TorchTensor) + assert isinstance(tensor, TorchTensor) + assert isinstance(ndarray, torch.Tensor) + + assert tensor.data_ptr() == ndarray.data_ptr() + + assert (ndarray == torch.zeros(3, 224, 224)).all() + + +def test_parametrized(): + # correct shape, single axis + tensor = parse_obj_as(TorchTensor[128], torch.zeros(128)) + assert isinstance(tensor, 
TorchTensor) + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (128,) + + # correct shape, multiple axis + tensor = parse_obj_as(TorchTensor[3, 224, 224], torch.zeros(3, 224, 224)) + assert isinstance(tensor, TorchTensor) + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (3, 224, 224) + + # wrong but reshapable shape + tensor = parse_obj_as(TorchTensor[3, 224, 224], torch.zeros(224, 3, 224)) + assert isinstance(tensor, TorchTensor) + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (3, 224, 224) + + # wrong and not reshapable shape + with pytest.raises(ValueError): + parse_obj_as(TorchTensor[3, 224, 224], torch.zeros(224, 224)) diff --git a/tests/units/typing/test_torch_tensor.py b/tests/units/typing/test_torch_tensor.py deleted file mode 100644 index 7d3081f86e3..00000000000 --- a/tests/units/typing/test_torch_tensor.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -from pydantic.tools import parse_obj_as, schema_json_of - -from docarray.document.io.json import orjson_dumps -from docarray.typing import TorchTensor - - -def test_proto_tensor(): - - tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224)) - - tensor._to_node_protobuf() - - -def test_json_schema(): - schema_json_of(TorchTensor) - - -def test_dump_json(): - tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224)) - orjson_dumps(tensor) - - -def test_unwrap(): - tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224)) - ndarray = tensor.unwrap() - - assert not isinstance(ndarray, TorchTensor) - assert isinstance(tensor, TorchTensor) - assert isinstance(ndarray, torch.Tensor) - - assert tensor.data_ptr() == ndarray.data_ptr() - - assert (ndarray == torch.zeros(3, 224, 224)).all() From ceb1aa507d0f221a2925709a139cf17dcfcd63d9 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 25 Nov 2022 12:13:10 +0100 Subject: [PATCH 2/9] refactor: fix typo in method name Signed-off-by: Johannes Messner --- docarray/array/array.py | 7 ++++--- 1 file 
changed, 4 insertions(+), 3 deletions(-) diff --git a/docarray/array/array.py b/docarray/array/array.py index 681bbe19e02..ba8f750463f 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -34,12 +34,13 @@ class _DocumenArrayTyped(DocumentArray): for field in _DocumenArrayTyped.document_type.__fields__.keys(): - def _proprety_generator(val: str): + def _property_generator(val: str): return property(lambda self: self._get_documents_attribute(val)) - setattr(_DocumenArrayTyped, field, _proprety_generator(field)) + setattr(_DocumenArrayTyped, field, _property_generator(field)) # this generates property on the fly based on the schema of the item - _DocumenArrayTyped.__name__ = f'DocumentArray{item.__name__}' + _DocumenArrayTyped.__name__ = f'DocumentArray[{item.__name__}]' + _DocumenArrayTyped.__qualname__ = f'DocumentArray[{item.__name__}]' return _DocumenArrayTyped From 24aa9323e11b35012d582fd68fcff901bf0a921a Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 25 Nov 2022 15:03:43 +0100 Subject: [PATCH 3/9] refactor: make parametrized type implementation generic Signed-off-by: Johannes Messner --- docarray/typing/tensor/abstract_tensor.py | 57 ++++++++++++++++++++ docarray/typing/tensor/torch_tensor.py | 63 +++++++---------------- 2 files changed, 77 insertions(+), 43 deletions(-) create mode 100644 docarray/typing/tensor/abstract_tensor.py diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py new file mode 100644 index 00000000000..e6ea0318edc --- /dev/null +++ b/docarray/typing/tensor/abstract_tensor.py @@ -0,0 +1,57 @@ +from abc import ABC +from typing import TYPE_CHECKING, Any, Tuple, Type, TypeVar + +from docarray.typing.abstract_type import AbstractType + +if TYPE_CHECKING: + from pydantic import BaseConfig + from pydantic.fields import ModelField + +T = TypeVar('T', bound='AbstractTensor') + + +def _get_attr_from_superclasses(supers, attr): + for cls in supers: + if hasattr(cls, attr): + 
return getattr(cls, attr) + raise AttributeError(f'Cannot find attribute {attr} in mro') + + +@classmethod +def __validate_parametrized__( + cls: Type[T], + value: Any, + field: 'ModelField', + config: 'BaseConfig', +) -> T: + supers = cls.__mro__[1:] # superclasses of cls + _validate = _get_attr_from_superclasses(supers, 'validate') + t = _validate(value, field, config) + return cls.__validate_shape__(t, cls._docarray_target_shape) + + +class AbstractTensor(AbstractType, ABC): + + __parametrized_meta__ = type + + @classmethod + def __validate_shape__(cls, t: T, shape: Tuple[int]) -> T: + ... + + @classmethod + def _create_parametrized_type(cls, shape: Tuple[int]): + shape_str = ', '.join([str(s) for s in shape]) + return cls.__parametrized_meta__( + f'{cls.__name__}[{shape_str}]', + (cls,), + {'_docarray_target_shape': shape, 'validate': __validate_parametrized__}, + ) + + def __class_getitem__(cls, item): + if not isinstance(item, tuple): + if isinstance(item, int): + item = (item,) + else: + raise TypeError(f'{item} is not a valid tensor shape.') + + return cls._create_parametrized_type(item) diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 2b53cc89761..31d8447449a 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -1,11 +1,11 @@ import warnings from copy import copy -from typing import TYPE_CHECKING, Any, Dict, Generic, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Dict, Generic, Tuple, Type, TypeVar, Union, cast import numpy as np import torch # type: ignore -from docarray.typing.abstract_type import AbstractType +from docarray.typing.tensor.abstract_tensor import AbstractTensor if TYPE_CHECKING: from pydantic.fields import ModelField @@ -29,7 +29,7 @@ class metaTorchAndNode(torch_base, node_base): class TorchTensor( - AbstractType, torch.Tensor, Generic[ShapeT], metaclass=metaTorchAndNode + AbstractTensor, torch.Tensor, Generic[ShapeT], 
metaclass=metaTorchAndNode ): # Subclassing torch.Tensor following the advice from here: # https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor @@ -40,46 +40,23 @@ def __get_validators__(cls): # the value returned from the previous validator yield cls.validate - def __class_getitem__(cls, item): - if not isinstance(item, tuple): - if isinstance(item, int): - item = (item,) - else: - raise TypeError(f'{item} is not a valid tensor shape.') - - class _ParametrizedTorchTensor(TorchTensor, metaclass=metaTorchAndNode): - _docaray_shape = item - - @classmethod - def validate( - cls: Type[T], - value: Union[T, np.ndarray, Any], - field: 'ModelField', - config: 'BaseConfig', - ) -> T: - t = super().validate(value, field, config) - if t.shape == cls._docaray_shape: - return t - else: - warnings.warn( - f'Tensor shape mismatch. Reshaping tensor ' - f'of shape {t.shape} to shape {cls._docaray_shape}' - ) - try: - return cls.from_native_torch_tensor( - torch.reshape(t, cls._docaray_shape) - ) - except RuntimeError: - raise ValueError( - f'Cannot reshape tensor of ' - f'shape {t.shape} to shape {cls._docaray_shape}' - ) - - # set class name - shape_str = ', '.join([str(s) for s in item]) - _ParametrizedTorchTensor.__name__ = f'TorchTensor[{shape_str}]' - _ParametrizedTorchTensor.__qualname__ = f'TorchTensor[{shape_str}]' - return _ParametrizedTorchTensor + __parametrized_meta__ = metaTorchAndNode + + @classmethod + def __validate_shape__(cls, t: T, shape: Tuple[int]) -> T: + if t.shape == shape: + return t + else: + warnings.warn( + f'Tensor shape mismatch. 
Reshaping tensor ' + f'of shape {t.shape} to shape {shape}' + ) + try: + return cls.from_native_torch_tensor(torch.reshape(t, shape)) + except RuntimeError: + raise ValueError( + f'Cannot reshape tensor of ' f'shape {t.shape} to shape {shape}' + ) @classmethod def validate( From 0371a172fc6ed818dc89517e29f1d45ff1bafe38 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 25 Nov 2022 15:59:20 +0100 Subject: [PATCH 4/9] refactor: make abstract tensor clearer Signed-off-by: Johannes Messner --- docarray/typing/tensor/abstract_tensor.py | 67 +++++++++++++---------- docarray/typing/tensor/torch_tensor.py | 9 ++- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index e6ea0318edc..0565da501ed 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -1,5 +1,6 @@ +import abc from abc import ABC -from typing import TYPE_CHECKING, Any, Tuple, Type, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Tuple, Type, TypeVar from docarray.typing.abstract_type import AbstractType @@ -8,44 +9,54 @@ from pydantic.fields import ModelField T = TypeVar('T', bound='AbstractTensor') +ShapeT = TypeVar('ShapeT') -def _get_attr_from_superclasses(supers, attr): - for cls in supers: - if hasattr(cls, attr): - return getattr(cls, attr) - raise AttributeError(f'Cannot find attribute {attr} in mro') - - -@classmethod -def __validate_parametrized__( - cls: Type[T], - value: Any, - field: 'ModelField', - config: 'BaseConfig', -) -> T: - supers = cls.__mro__[1:] # superclasses of cls - _validate = _get_attr_from_superclasses(supers, 'validate') - t = _validate(value, field, config) - return cls.__validate_shape__(t, cls._docarray_target_shape) - - -class AbstractTensor(AbstractType, ABC): +class AbstractTensor(AbstractType, Generic[ShapeT], ABC): __parametrized_meta__ = type @classmethod + @abc.abstractmethod def __validate_shape__(cls, t: 
T, shape: Tuple[int]) -> T: + """Every tensor has to implement this method in order to + enable syntax of the form Tensor[shape]. + The intended behaviour is as follows: + - If the shape of `t` is equal to `shape`, return `t`. + - If the shape of `t` is not equal to `shape`, + but can be reshaped to `shape`, return `t` reshaped to `shape`. + - If the shape of `t` is not equal to `shape` + and cannot be reshaped to `shape`, raise a ValueError. + + :param t: The tensor to validate. + :param shape: The shape to validate against. + :return: The validated tensor. + """ ... @classmethod - def _create_parametrized_type(cls, shape: Tuple[int]): + def _create_parametrized_type(cls: Type[T], shape: Tuple[int]): shape_str = ', '.join([str(s) for s in shape]) - return cls.__parametrized_meta__( - f'{cls.__name__}[{shape_str}]', - (cls,), - {'_docarray_target_shape': shape, 'validate': __validate_parametrized__}, - ) + + class _ParametrizedTensor( + cls, # type: ignore + metaclass=cls.__parametrized_meta__, # type: ignore + ): + _docarray_target_shape = shape + __name__ = f'{cls.__name__}[{shape_str}]' + __qualname__ = f'{cls.__qualname__}[{shape_str}]' + + @classmethod + def validate( + _cls, + value: Any, + field: 'ModelField', + config: 'BaseConfig', + ): + t = super().validate(value, field, config) + return _cls.__validate_shape__(t, _cls._docarray_target_shape) + + return _ParametrizedTensor def __class_getitem__(cls, item): if not isinstance(item, tuple): diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 31d8447449a..862ca810c4c 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -16,6 +16,7 @@ from docarray.proto import NdArrayProto, NodeProto T = TypeVar('T', bound='TorchTensor') +ShapeT = TypeVar('ShapeT') torch_base = type(torch.Tensor) # type: Any node_base = type(BaseNode) # type: Any @@ -25,9 +26,6 @@ class metaTorchAndNode(torch_base, node_base): pass -ShapeT = 
TypeVar('ShapeT') - - class TorchTensor( AbstractTensor, torch.Tensor, Generic[ShapeT], metaclass=metaTorchAndNode ): @@ -43,7 +41,7 @@ def __get_validators__(cls): __parametrized_meta__ = metaTorchAndNode @classmethod - def __validate_shape__(cls, t: T, shape: Tuple[int]) -> T: + def __validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ignore if t.shape == shape: return t else: @@ -52,7 +50,8 @@ def __validate_shape__(cls, t: T, shape: Tuple[int]) -> T: f'of shape {t.shape} to shape {shape}' ) try: - return cls.from_native_torch_tensor(torch.reshape(t, shape)) + value = cls.from_native_torch_tensor(torch.reshape(t, shape)) + return cast(T, value) except RuntimeError: raise ValueError( f'Cannot reshape tensor of ' f'shape {t.shape} to shape {shape}' From 8429a14cd18b221e827a8d1dafafae9f157f70f3 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 25 Nov 2022 16:16:33 +0100 Subject: [PATCH 5/9] feat: parametrized ndarray type Signed-off-by: Johannes Messner --- docarray/typing/tensor/ndarray.py | 36 ++++++++++++++++++++++-- tests/units/typing/tensor/test_tensor.py | 25 ++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index 2b19dba179e..e1f32a95378 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -1,8 +1,20 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, cast +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Tuple, + Type, + TypeVar, + Union, + cast, +) import numpy as np -from docarray.typing.abstract_type import AbstractType +from docarray.typing.tensor.abstract_tensor import AbstractTensor if TYPE_CHECKING: from pydantic.fields import ModelField @@ -11,9 +23,10 @@ from docarray.proto import NdArrayProto, NodeProto T = TypeVar('T', bound='NdArray') +ShapeT = TypeVar('ShapeT') -class NdArray(np.ndarray, AbstractType): +class 
NdArray(AbstractTensor, np.ndarray, Generic[ShapeT]): @classmethod def __get_validators__(cls): # one or more validators may be yielded which will be called in the @@ -21,6 +34,23 @@ def __get_validators__(cls): # the value returned from the previous validator yield cls.validate + @classmethod + def __validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ignore + if t.shape == shape: + return t + else: + warnings.warn( + f'Tensor shape mismatch. Reshaping array ' + f'of shape {t.shape} to shape {shape}' + ) + try: + value = cls.from_ndarray(np.reshape(t, shape)) + return cast(T, value) + except RuntimeError: + raise ValueError( + f'Cannot reshape array of shape {t.shape} to shape {shape}' + ) + @classmethod def validate( cls: Type[T], diff --git a/tests/units/typing/tensor/test_tensor.py b/tests/units/typing/tensor/test_tensor.py index eb6b966f24d..d0ef58bc581 100644 --- a/tests/units/typing/tensor/test_tensor.py +++ b/tests/units/typing/tensor/test_tensor.py @@ -1,5 +1,6 @@ import numpy as np import orjson +import pytest from pydantic.tools import parse_obj_as, schema_json_of from docarray.document.io.json import orjson_dumps @@ -47,3 +48,27 @@ def test_unwrap(): assert isinstance(ndarray, np.ndarray) assert isinstance(tensor, NdArray) assert (ndarray == np.zeros((3, 224, 224))).all() + + +def test_parametrized(): + # correct shape, single axis + tensor = parse_obj_as(NdArray[128], np.zeros(128)) + assert isinstance(tensor, NdArray) + assert isinstance(tensor, np.ndarray) + assert tensor.shape == (128,) + + # correct shape, multiple axis + tensor = parse_obj_as(NdArray[3, 224, 224], np.zeros((3, 224, 224))) + assert isinstance(tensor, NdArray) + assert isinstance(tensor, np.ndarray) + assert tensor.shape == (3, 224, 224) + + # wrong but reshapable shape + tensor = parse_obj_as(NdArray[3, 224, 224], np.zeros((3, 224, 224))) + assert isinstance(tensor, NdArray) + assert isinstance(tensor, np.ndarray) + assert tensor.shape == (3, 224, 224) + + # wrong and not 
reshapable shape + with pytest.raises(ValueError): + parse_obj_as(NdArray[3, 224, 224], np.zeros((224, 224))) From 93dbc690697309fc8c9431c3564833d23a5db6b7 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 25 Nov 2022 16:32:46 +0100 Subject: [PATCH 6/9] test: add serialization tests for param tensors Signed-off-by: Johannes Messner --- tests/integrations/document/test_proto.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/integrations/document/test_proto.py b/tests/integrations/document/test_proto.py index 0092aabc0b7..98f05f09e15 100644 --- a/tests/integrations/document/test_proto.py +++ b/tests/integrations/document/test_proto.py @@ -35,7 +35,9 @@ class MyDoc(Document): txt_url: TextUrl any_url: AnyUrl torch_tensor: TorchTensor + torch_tensor_param: TorchTensor[224, 224, 3] np_array: NdArray + np_array_param: NdArray[224, 224, 3] generic_nd_array: Tensor generic_torch_tensor: Tensor embedding: Embedding @@ -45,7 +47,9 @@ class MyDoc(Document): txt_url='test.txt', any_url='www.jina.ai', torch_tensor=torch.zeros((3, 224, 224)), + torch_tensor_param=torch.zeros((3, 224, 224)), np_array=np.zeros((3, 224, 224)), + np_array_param=np.zeros((3, 224, 224)), generic_nd_array=np.zeros((3, 224, 224)), generic_torch_tensor=torch.zeros((3, 224, 224)), embedding=np.zeros((3, 224, 224)), @@ -55,12 +59,23 @@ class MyDoc(Document): assert doc.img_url == 'test.png' assert doc.txt_url == 'test.txt' assert doc.any_url == 'www.jina.ai' + assert (doc.torch_tensor == torch.zeros((3, 224, 224))).all() assert isinstance(doc.torch_tensor, torch.Tensor) + + assert (doc.torch_tensor_param == torch.zeros((224, 224, 3))).all() + assert isinstance(doc.torch_tensor_param, torch.Tensor) + assert (doc.np_array == np.zeros((3, 224, 224))).all() assert isinstance(doc.np_array, np.ndarray) + + assert (doc.np_array_param == np.zeros((224, 224, 3))).all() + assert isinstance(doc.np_array_param, np.ndarray) + assert (doc.generic_nd_array == np.zeros((3, 224, 
224))).all() assert isinstance(doc.generic_nd_array, np.ndarray) + assert (doc.generic_torch_tensor == torch.zeros((3, 224, 224))).all() assert isinstance(doc.generic_torch_tensor, torch.Tensor) + assert (doc.embedding == np.zeros((3, 224, 224))).all() From 3536cee215412d9c7914867637c95bbee45d0fd3 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 25 Nov 2022 16:55:08 +0100 Subject: [PATCH 7/9] docs: add docstrings for tensor types Signed-off-by: Johannes Messner --- docarray/typing/tensor/ndarray.py | 42 ++++++++++++++++++++++++++ docarray/typing/tensor/torch_tensor.py | 42 ++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index e1f32a95378..1e0ee9174fd 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -27,6 +27,48 @@ class NdArray(AbstractTensor, np.ndarray, Generic[ShapeT]): + """ + Subclass of np.ndarray, intended for use in a Document. + This enables (de)serialization from/to protobuf and json, data validation, + and coercion from compatible types like torch.Tensor. + + This type can also be used in a parametrized way, specifying the shape of the array. + + EXAMPLE USAGE + + .. code-block:: python + + from docarray import Document + from docarray.typing import NdArray + import numpy as np + + + class MyDoc(Document): + arr: NdArray + image_arr: NdArray[3, 224, 224] + + + # create a document with tensors + doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((3, 224, 224)), + ) + assert doc.image_arr.shape == (3, 224, 224) + + # automatic shape conversion + doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((224, 224, 3)), # will reshape to (3, 224, 224) + ) + assert doc.image_arr.shape == (3, 224, 224) + + # !! The following will raise an error due to shape mismatch !! 
+ doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((224, 224)), # this will fail validation + ) + """ + @classmethod def __get_validators__(cls): # one or more validators may be yielded which will be called in the diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 862ca810c4c..59716e2191c 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -31,6 +31,48 @@ class TorchTensor( ): # Subclassing torch.Tensor following the advice from here: # https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor + """ + Subclass of torch.Tensor, intended for use in a Document. + This enables (de)serialization from/to protobuf and json, data validation, + and coercion from compatible types like numpy.ndarray. + + This type can also be used in a parametrized way, + specifying the shape of the tensor. + + EXAMPLE USAGE + + .. code-block:: python + + from docarray import Document + from docarray.typing import TorchTensor + import torch + + + class MyDoc(Document): + tensor: TorchTensor + image_tensor: TorchTensor[3, 224, 224] + + + # create a document with tensors + doc = MyDoc( + tensor=torch.zeros(128), + image_tensor=torch.zeros(3, 224, 224), + ) + + # automatic shape conversion + doc = MyDoc( + tensor=torch.zeros(128), + image_tensor=torch.zeros(224, 224, 3), # will reshape to (3, 224, 224) + ) + + # !! The following will raise an error due to shape mismatch !! 
+ doc = MyDoc( + tensor=torch.zeros(128), + image_tensor=torch.zeros(224, 224), # this will fail validation + ) + + """ + @classmethod def __get_validators__(cls): # one or more validators may be yielded which will be called in the From dc9ad35dcddb5c67482cab3b810405d0f0c325aa Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 25 Nov 2022 17:16:53 +0100 Subject: [PATCH 8/9] fix: allow any kind of tensor shape Signed-off-by: Johannes Messner --- docarray/typing/tensor/abstract_tensor.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 0565da501ed..117e8fab4cb 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -59,10 +59,11 @@ def validate( return _ParametrizedTensor def __class_getitem__(cls, item): - if not isinstance(item, tuple): - if isinstance(item, int): - item = (item,) - else: - raise TypeError(f'{item} is not a valid tensor shape.') + if isinstance(item, int): + item = (item,) + try: + item = tuple(item) + except TypeError: + raise TypeError(f'{item} is not a valid tensor shape.') return cls._create_parametrized_type(item) From c26a37ad4ed42a27c97adbb86a9e9f488089bab2 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 25 Nov 2022 17:18:48 +0100 Subject: [PATCH 9/9] fix: use view instead of reshape Signed-off-by: Johannes Messner --- docarray/typing/tensor/torch_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 59716e2191c..b564908be75 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -92,7 +92,7 @@ def __validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ignore f'of shape {t.shape} to shape {shape}' ) try: - value = cls.from_native_torch_tensor(torch.reshape(t, shape)) + value = 
cls.from_native_torch_tensor(t.view(shape)) return cast(T, value) except RuntimeError: raise ValueError(