Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install poetry
poetry install -E common --without dev
poetry install --all-extras --without dev
- name: Test basic import
run: poetry run python -c 'from docarray import DocumentArray, Document'

Expand Down Expand Up @@ -110,7 +110,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install poetry
poetry install -E common
poetry install --all-extras

- name: Test
id: test
Expand Down
4 changes: 3 additions & 1 deletion docarray/document/mixins/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from docarray.document.abstract_document import AbstractDocument
from docarray.document.base_node import BaseNode
from docarray.proto import DocumentProto, NodeProto
from docarray.typing import ID, AnyUrl, Embedding, ImageUrl, Tensor
from docarray.typing import ID, AnyUrl, Embedding, ImageUrl, Tensor, TorchTensor


class ProtoMixin(AbstractDocument, BaseNode):
Expand All @@ -25,6 +25,8 @@ def from_protobuf(cls, pb_msg: 'DocumentProto') -> 'ProtoMixin':
# the check should be delegated to the type level
if content_type == 'tensor':
fields[field] = Tensor._read_from_proto(value.tensor)
elif content_type == 'torch_tensor':
fields[field] = TorchTensor._read_from_proto(value.torch_tensor)
elif content_type == 'embedding':
fields[field] = Embedding._read_from_proto(value.embedding)
elif content_type == 'any_url':
Expand Down
2 changes: 2 additions & 0 deletions docarray/proto/docarray.proto
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ message NodeProto {

string id = 9;

NdArrayProto torch_tensor = 10;

}

}
Expand Down
38 changes: 18 additions & 20 deletions docarray/proto/pb2/docarray_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions docarray/typing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from docarray.typing.embedding import Embedding
from docarray.typing.id import ID
from docarray.typing.tensor import Tensor
from docarray.typing.tensor import Tensor, TorchTensor
from docarray.typing.url import AnyUrl, ImageUrl

__all__ = ['Tensor', 'Embedding', 'ImageUrl', 'AnyUrl', 'ID']
__all__ = ['Tensor', 'Embedding', 'ImageUrl', 'AnyUrl', 'ID', 'TorchTensor']
3 changes: 2 additions & 1 deletion docarray/typing/tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from docarray.typing.tensor.tensor import Tensor
from docarray.typing.tensor.torch_tensor import TorchTensor

__all__ = ['Tensor']
__all__ = ['Tensor', 'TorchTensor']
106 changes: 106 additions & 0 deletions docarray/typing/tensor/torch_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import TYPE_CHECKING, Any, Type, TypeVar, Union, cast

import numpy as np
import torch # type: ignore

if TYPE_CHECKING:
from pydantic.fields import ModelField
from pydantic import BaseConfig
import numpy as np

from docarray.document.base_node import BaseNode
from docarray.proto import NdArrayProto, NodeProto

# Type variable used by the classmethods below so that they are typed as
# returning an instance of the class they are called on.
T = TypeVar('T', bound='TorchTensor')

# Metaclasses of the two base classes of TorchTensor. They are annotated as
# Any so that they can be used as base classes of the combined metaclass.
torch_base: Any = type(torch.Tensor)
node_base: Any = type(BaseNode)


class metaTorchAndNode(torch_base, node_base):
    """Metaclass deriving from both ``torch_base`` and ``node_base``.

    It adds no behaviour of its own; it only gives ``TorchTensor`` a single
    metaclass that is compatible with both of its base classes.
    """


class TorchTensor(torch.Tensor, BaseNode, metaclass=metaTorchAndNode):
    """A ``torch.Tensor`` subclass usable as a pydantic field type and as a
    protobuf-serializable node.

    Subclassing torch.Tensor following the advice from here:
    https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the validators pydantic should run for this type.

        One or more validators may be yielded; they are called in order and
        each one receives the value returned by the previous validator.
        """
        yield cls.validate

    @classmethod
    def validate(
        cls: Type[T],
        value: Union[T, np.ndarray, Any],
        field: 'ModelField',
        config: 'BaseConfig',
    ) -> T:
        """Coerce ``value`` into an instance of this class.

        :param value: a TorchTensor (returned as is), a native torch.Tensor
            (re-classed in place), or anything accepted by ``torch.tensor``
        :param field: the pydantic field being validated (unused)
        :param config: the pydantic config of the model (unused)
        :return: the value as a TorchTensor
        :raises ValueError: if ``value`` cannot be converted to a tensor
        """
        if isinstance(value, TorchTensor):
            return cast(T, value)
        if isinstance(value, torch.Tensor):
            return cls.from_native_torch_tensor(value)

        # Only the conversion itself is guarded; the original error is kept as
        # the cause instead of being silently discarded.
        try:
            arr: torch.Tensor = torch.tensor(value)
        except Exception as err:
            raise ValueError(f'Expected a torch.Tensor, got {type(value)}') from err
        return cls.from_native_torch_tensor(arr)

    @classmethod
    def from_native_torch_tensor(cls: Type[T], value: torch.Tensor) -> T:
        """Create a TorchTensor from a native torch.Tensor

        The tensor is not copied: its ``__class__`` is reassigned, so the
        object passed in and the object returned are the same object.

        :param value: the native torch.Tensor
        :return: a TorchTensor
        """
        value.__class__ = cls
        return cast(T, value)

    @classmethod
    def from_ndarray(cls: Type[T], value: np.ndarray) -> T:
        """Create a TorchTensor from a numpy array

        The conversion goes through ``torch.from_numpy``.

        :param value: the numpy array
        :return: a TorchTensor
        """
        return cls.from_native_torch_tensor(torch.from_numpy(value))

    def _to_node_protobuf(self: T, field: str = 'torch_tensor') -> NodeProto:
        """Convert this tensor into a NodeProto protobuf message. This function
        should be called when the tensor is nested into a Document that needs
        to be converted into a protobuf

        :param field: field in which to store the content in the node proto
        :return: the nested item protobuf message
        """
        nd_proto = NdArrayProto()
        self._flush_tensor_to_proto(nd_proto, value=self)
        return NodeProto(**{field: nd_proto})

    @classmethod
    def _read_from_proto(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T':
        """Read a tensor from a proto msg

        :param pb_msg: the NdArrayProto message to read from
        :return: a TorchTensor
        :raises ValueError: if the message has neither a buffer nor a shape
        """
        source = pb_msg.dense
        shape = tuple(source.shape)
        if source.buffer:
            # The proto buffer is immutable bytes; copying it into a bytearray
            # gives numpy (and therefore torch) writable memory to wrap, so
            # the resulting tensor can safely be modified in place.
            x = np.frombuffer(bytearray(source.buffer), dtype=source.dtype)
            return cls.from_ndarray(x.reshape(shape))
        elif len(shape) > 0:
            # No payload but a shape. Use the stored dtype when there is one
            # so that empty tensors keep their dtype; an unset dtype falls
            # back to numpy's default.
            return cls.from_ndarray(np.zeros(shape, dtype=source.dtype or None))
        else:
            raise ValueError(f'proto message {pb_msg} cannot be cast to a TorchTensor')

    @staticmethod
    def _flush_tensor_to_proto(pb_msg: 'NdArrayProto', value: 'TorchTensor'):
        """Write the content of ``value`` into ``pb_msg.dense``.

        The buffer, shape and dtype fields of the message are overwritten.

        :param pb_msg: the NdArrayProto message to fill
        :param value: the tensor to serialize
        """
        value_np = value.detach().cpu().numpy()
        pb_msg.dense.buffer = value_np.tobytes()
        # Reset the repeated field before extending so that a reused message
        # does not accumulate shape entries.
        pb_msg.dense.ClearField('shape')
        pb_msg.dense.shape.extend(list(value_np.shape))
        pb_msg.dense.dtype = value_np.dtype.str
Loading