Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docarray/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ class _DocumenArrayTyped(DocumentArray):

for field in _DocumenArrayTyped.document_type.__fields__.keys():

def _proprety_generator(val: str):
def _property_generator(val: str):
return property(lambda self: self._get_documents_attribute(val))

setattr(_DocumenArrayTyped, field, _proprety_generator(field))
setattr(_DocumenArrayTyped, field, _property_generator(field))
# this generates property on the fly based on the schema of the item

_DocumenArrayTyped.__name__ = f'DocumentArray{item.__name__}'
_DocumenArrayTyped.__name__ = f'DocumentArray[{item.__name__}]'
_DocumenArrayTyped.__qualname__ = f'DocumentArray[{item.__name__}]'

return _DocumenArrayTyped
69 changes: 69 additions & 0 deletions docarray/typing/tensor/abstract_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import abc
from abc import ABC
from typing import TYPE_CHECKING, Any, Generic, Tuple, Type, TypeVar

from docarray.typing.abstract_type import AbstractType

if TYPE_CHECKING:
from pydantic import BaseConfig
from pydantic.fields import ModelField

T = TypeVar('T', bound='AbstractTensor')
ShapeT = TypeVar('ShapeT')


class AbstractTensor(AbstractType, Generic[ShapeT], ABC):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since AbstractType is inheried, we don't need ABC ?

@JohannesMessner JohannesMessner Nov 26, 2022

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually ABC is not inherited like other classes, due to it using a custom metaclass. I believe we need to inherit from ABC again here.

This is a quite advanced topic in Python that is usually not important day to day, but if you are interested you can read more here: https://peps.python.org/pep-3119/


__parametrized_meta__ = type

@classmethod
@abc.abstractmethod
def __validate_shape__(cls, t: T, shape: Tuple[int]) -> T:
"""Every tensor has to implement this method in order to
enbale syntax of the form Tensor[shape].
The intended behavoiour is as follows:
- If the shape of `t` is equal to `shape`, return `t`.
- If the shape of `t` is not equal to `shape`,
but can be reshaped to `shape`, return `t` reshaped to `shape`.
- If the shape of `t` is not equal to `shape`
and cannot be reshaped to `shape`, raise a ValueError.

:param t: The tensor to validate.
:param shape: The shape to validate against.
:return: The validated tensor.
"""
...

@classmethod
def _create_parametrized_type(cls: Type[T], shape: Tuple[int]):
shape_str = ', '.join([str(s) for s in shape])

class _ParametrizedTensor(
cls, # type: ignore
metaclass=cls.__parametrized_meta__, # type: ignore
):
_docarray_target_shape = shape
__name__ = f'{cls.__name__}[{shape_str}]'
__qualname__ = f'{cls.__qualname__}[{shape_str}]'

@classmethod
def validate(
_cls,
value: Any,
field: 'ModelField',
config: 'BaseConfig',
):
t = super().validate(value, field, config)
return _cls.__validate_shape__(t, _cls._docarray_target_shape)

return _ParametrizedTensor

def __class_getitem__(cls, item):
if isinstance(item, int):
item = (item,)
try:
item = tuple(item)
except TypeError:
raise TypeError(f'{item} is not a valid tensor shape.')

return cls._create_parametrized_type(item)
78 changes: 75 additions & 3 deletions docarray/typing/tensor/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, cast
import warnings
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Tuple,
Type,
TypeVar,
Union,
cast,
)

import numpy as np

from docarray.typing.abstract_type import AbstractType
from docarray.typing.tensor.abstract_tensor import AbstractTensor

if TYPE_CHECKING:
from pydantic.fields import ModelField
Expand All @@ -11,16 +23,76 @@
from docarray.proto import NdArrayProto, NodeProto

T = TypeVar('T', bound='NdArray')
ShapeT = TypeVar('ShapeT')


class NdArray(np.ndarray, AbstractType):
class NdArray(AbstractTensor, np.ndarray, Generic[ShapeT]):
"""
Subclass of np.ndarray, intended for use in a Document.
This enables (de)serialization from/to protobuf and json, data validation,
and coersion from compatible types like torch.Tensor.

This type can also be used in a parametrized way, specifying the shape of the array.

EXAMPLE USAGE

.. code-block:: python

from docarray import Document
from docarray.typing import NdArray
import numpy as np


class MyDoc(Document):
arr: NdArray
image_arr: NdArray[3, 224, 224]


# create a document with tensors
doc = MyDoc(
arr=np.zeros((128,)),
image_arr=np.zeros((3, 224, 224)),
)
assert doc.image_arr.shape == (3, 224, 224)

# automatic shape conversion
doc = MyDoc(
arr=np.zeros((128,)),
image_arr=np.zeros((224, 224, 3)), # will reshape to (3, 224, 224)
)
assert doc.image_arr.shape == (3, 224, 224)

# !! The following will raise an error due to shape mismatch !!
doc = MyDoc(
arr=np.zeros((128,)),
image_arr=np.zeros((224, 224)), # this will fail validation
)
"""

@classmethod
def __get_validators__(cls):
# one or more validators may be yielded which will be called in the
# order to validate the input, each validator will receive as an input
# the value returned from the previous validator
yield cls.validate

@classmethod
def __validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ignore
if t.shape == shape:
return t
else:
warnings.warn(
f'Tensor shape mismatch. Reshaping array '
f'of shape {t.shape} to shape {shape}'
)
try:
value = cls.from_ndarray(np.reshape(t, shape))
return cast(T, value)
except RuntimeError:
raise ValueError(
f'Cannot reshape array of shape {t.shape} to shape {shape}'
)

@classmethod
def validate(
cls: Type[T],
Expand Down
71 changes: 68 additions & 3 deletions docarray/typing/tensor/torch_tensor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import warnings
from copy import copy
from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Generic, Tuple, Type, TypeVar, Union, cast

import numpy as np
import torch # type: ignore

from docarray.typing.abstract_type import AbstractType
from docarray.typing.tensor.abstract_tensor import AbstractTensor

if TYPE_CHECKING:
from pydantic.fields import ModelField
Expand All @@ -15,6 +16,7 @@
from docarray.proto import NdArrayProto, NodeProto

T = TypeVar('T', bound='TorchTensor')
ShapeT = TypeVar('ShapeT')

torch_base = type(torch.Tensor) # type: Any
node_base = type(BaseNode) # type: Any
Expand All @@ -24,16 +26,79 @@ class metaTorchAndNode(torch_base, node_base):
pass


class TorchTensor(AbstractType, torch.Tensor, metaclass=metaTorchAndNode):
class TorchTensor(
AbstractTensor, torch.Tensor, Generic[ShapeT], metaclass=metaTorchAndNode
):
# Subclassing torch.Tensor following the advice from here:
# https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor
"""
Subclass of torch.Tensor, intended for use in a Document.
This enables (de)serialization from/to protobuf and json, data validation,
and coersion from compatible types like numpy.ndarray.

This type can also be used in a parametrized way,
specifying the shape of the tensor.

EXAMPLE USAGE

.. code-block:: python

from docarray import Document
from docarray.typing import TorchTensor
import torch


class MyDoc(Document):
tensor: TorchTensor
image_tensor: TorchTensor[3, 224, 224]


# create a document with tensors
doc = MyDoc(
tensor=torch.zeros(128),
image_tensor=torch.zeros(3, 224, 224),
)

# automatic shape conversion
doc = MyDoc(
tensor=torch.zeros(128),
image_tensor=torch.zeros(224, 224, 3), # will reshape to (3, 224, 224)
)

# !! The following will raise an error due to shape mismatch !!
doc = MyDoc(
tensor=torch.zeros(128),
image_tensor=torch.zeros(224, 224), # this will fail validation
)

"""

@classmethod
def __get_validators__(cls):
# one or more validators may be yielded which will be called in the
# order to validate the input, each validator will receive as an input
# the value returned from the previous validator
yield cls.validate

__parametrized_meta__ = metaTorchAndNode

@classmethod
def __validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ignore
if t.shape == shape:
return t
else:
warnings.warn(
f'Tensor shape mismatch. Reshaping tensor '
f'of shape {t.shape} to shape {shape}'
)
try:
value = cls.from_native_torch_tensor(t.view(shape))
return cast(T, value)
except RuntimeError:
raise ValueError(
f'Cannot reshape tensor of ' f'shape {t.shape} to shape {shape}'
)

@classmethod
def validate(
cls: Type[T],
Expand Down
15 changes: 15 additions & 0 deletions tests/integrations/document/test_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ class MyDoc(Document):
txt_url: TextUrl
any_url: AnyUrl
torch_tensor: TorchTensor
torch_tensor_param: TorchTensor[224, 224, 3]
np_array: NdArray
np_array_param: NdArray[224, 224, 3]
generic_nd_array: Tensor
generic_torch_tensor: Tensor
embedding: Embedding
Expand All @@ -45,7 +47,9 @@ class MyDoc(Document):
txt_url='test.txt',
any_url='www.jina.ai',
torch_tensor=torch.zeros((3, 224, 224)),
torch_tensor_param=torch.zeros((3, 224, 224)),
np_array=np.zeros((3, 224, 224)),
np_array_param=np.zeros((3, 224, 224)),
generic_nd_array=np.zeros((3, 224, 224)),
generic_torch_tensor=torch.zeros((3, 224, 224)),
embedding=np.zeros((3, 224, 224)),
Expand All @@ -55,12 +59,23 @@ class MyDoc(Document):
assert doc.img_url == 'test.png'
assert doc.txt_url == 'test.txt'
assert doc.any_url == 'www.jina.ai'

assert (doc.torch_tensor == torch.zeros((3, 224, 224))).all()
assert isinstance(doc.torch_tensor, torch.Tensor)

assert (doc.torch_tensor_param == torch.zeros((224, 224, 3))).all()
assert isinstance(doc.torch_tensor_param, torch.Tensor)

assert (doc.np_array == np.zeros((3, 224, 224))).all()
assert isinstance(doc.np_array, np.ndarray)

assert (doc.np_array_param == np.zeros((224, 224, 3))).all()
assert isinstance(doc.np_array_param, np.ndarray)

assert (doc.generic_nd_array == np.zeros((3, 224, 224))).all()
assert isinstance(doc.generic_nd_array, np.ndarray)

assert (doc.generic_torch_tensor == torch.zeros((3, 224, 224))).all()
assert isinstance(doc.generic_torch_tensor, torch.Tensor)

assert (doc.embedding == np.zeros((3, 224, 224))).all()
25 changes: 25 additions & 0 deletions tests/units/typing/tensor/test_tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import orjson
import pytest
from pydantic.tools import parse_obj_as, schema_json_of

from docarray.document.io.json import orjson_dumps
Expand Down Expand Up @@ -47,3 +48,27 @@ def test_unwrap():
assert isinstance(ndarray, np.ndarray)
assert isinstance(tensor, NdArray)
assert (ndarray == np.zeros((3, 224, 224))).all()


def test_parametrized():
# correct shape, single axis
tensor = parse_obj_as(NdArray[128], np.zeros(128))
assert isinstance(tensor, NdArray)
assert isinstance(tensor, np.ndarray)
assert tensor.shape == (128,)

# correct shape, multiple axis
tensor = parse_obj_as(NdArray[3, 224, 224], np.zeros((3, 224, 224)))
assert isinstance(tensor, NdArray)
assert isinstance(tensor, np.ndarray)
assert tensor.shape == (3, 224, 224)

# wrong but reshapable shape
tensor = parse_obj_as(NdArray[3, 224, 224], np.zeros((3, 224, 224)))
assert isinstance(tensor, NdArray)
assert isinstance(tensor, np.ndarray)
assert tensor.shape == (3, 224, 224)

# wrong and not reshapable shape
with pytest.raises(ValueError):
parse_obj_as(NdArray[3, 224, 224], np.zeros((224, 224)))
Loading