
feat: parametrized tensors #850

Merged: samsja merged 9 commits into feat-rewrite-v2 from parametrized-tensors on Nov 28, 2022

Conversation

JohannesMessner (Member) commented Nov 25, 2022

Goals:

Have parametrized types for our tensor types, where the tensor shape can be specified directly in the type:

class MyDoc(Document):
    tensor: TorchTensor[3, 224, 224]

doc = MyDoc(tensor=torch.zeros(3, 224, 224))  # works
doc = MyDoc(tensor=torch.zeros(224, 224, 3))  # works by reshaping
doc = MyDoc(tensor=torch.zeros(224))  # fails validation

How it works:

The __class_getitem__() method, which is responsible for the [...] syntax, is defined generically in AbstractTensor.
As a result, concrete tensor types such as TorchTensor and NdArray only have to implement __validate_shape__(); the rest works automatically.
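
The mechanism described above can be sketched roughly as follows. This is a minimal, standard-library-only illustration and not the code from this PR: `AbstractTensorSketch`, `ToyArray`, `ToyTensor`, `__target_shape__`, and `validate` are hypothetical names, and creating a subclass per shape is an assumption about how the parametrization could be stored.

```python
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass
class ToyArray:
    """Hypothetical stand-in for a real tensor: flat data plus a shape."""

    data: List[float]
    shape: Tuple[int, ...]


class AbstractTensorSketch(ABC):
    """Hypothetical base class; the [...] handling lives here only."""

    # Empty tuple means "no shape constraint" in this sketch.
    __target_shape__: Tuple[int, ...] = ()

    @classmethod
    @abstractmethod
    def __validate_shape__(cls, value: Any, shape: Tuple[int, ...]) -> Any:
        """Return a value of the given shape, or raise ValueError."""

    def __class_getitem__(cls, item: Any) -> type:
        # Tensor[128] passes an int, Tensor[3, 224, 224] passes a tuple.
        shape = item if isinstance(item, tuple) else (item,)

        # Assumption: remember the shape on a freshly created subclass.
        class Parametrized(cls):  # type: ignore[valid-type, misc]
            __target_shape__ = shape

        Parametrized.__name__ = f'{cls.__name__}[{", ".join(map(str, shape))}]'
        return Parametrized

    @classmethod
    def validate(cls, value: Any) -> Any:
        if not cls.__target_shape__:
            return value
        return cls.__validate_shape__(value, cls.__target_shape__)


class ToyTensor(AbstractTensorSketch):
    """A concrete type only has to supply the shape check."""

    @classmethod
    def __validate_shape__(cls, value: ToyArray, shape: Tuple[int, ...]) -> ToyArray:
        if value.shape == shape:
            return value
        if math.prod(value.shape) == math.prod(shape):
            # Same number of elements: reinterpret with the target shape.
            return ToyArray(value.data, shape)
        raise ValueError(f'cannot fit shape {value.shape} into {shape}')


Image = ToyTensor[3, 2]
same = Image.validate(ToyArray([0.0] * 6, (3, 2)))      # accepted as-is
reshaped = Image.validate(ToyArray([0.0] * 6, (2, 3)))  # accepted by reshaping
```

In this sketch, a value whose element count does not match (for example a `ToyArray` of shape `(1,)`) makes `Image.validate` raise `ValueError`, which mirrors the "fails validation" case in the Goals example.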

Todo:

  • TorchTensor
  • NdArray
  • General solution that applies to all tensor types at once
  • Test with (de)serialization
  • Test with generic Tensor type -> this does not work, not supported
  • Document in docstring

Signed-off-by: Johannes Messner <messnerjo@gmail.com>
JohannesMessner changed the title from "Parametrized tensors" to "feat: arametrized tensors" Nov 25, 2022
JohannesMessner changed the title from "feat: arametrized tensors" to "feat: parametrized tensors" Nov 25, 2022
JohannesMessner marked this pull request as ready for review November 25, 2022 15:55
JohannesMessner requested review from dongxiang123 and samsja and removed request for samsja November 25, 2022 15:55
Comment thread docarray/typing/tensor/abstract_tensor.py Outdated
Comment thread docarray/typing/tensor/torch_tensor.py Outdated
ShapeT = TypeVar('ShapeT')


class AbstractTensor(AbstractType, Generic[ShapeT], ABC):

Contributor commented:

Since AbstractType is inherited, do we still need ABC?

JohannesMessner (Member, Author) replied Nov 26, 2022:

Actually ABC is not inherited like other classes, due to it using a custom metaclass. I believe we need to inherit from ABC again here.

This is quite an advanced topic in Python that rarely matters day to day, but if you are interested you can read more here: https://peps.python.org/pep-3119/

tensor = parse_obj_as(TorchTensor[128], torch.zeros(128))
assert isinstance(tensor, TorchTensor)
assert isinstance(tensor, torch.Tensor)
assert tensor.shape == (128,)

Contributor commented:

remove comma?

Author (Member) replied:

Without the comma, (128) would just be an int, but we want it to be a tuple.
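
A quick way to see the difference in plain Python (the variable names here are illustrative only):

```python
# Parentheses alone only group an expression; the trailing comma makes the tuple.
with_comma = (128,)
without_comma = (128)

print(type(with_comma).__name__)     # tuple
print(type(without_comma).__name__)  # int
print(with_comma == without_comma)   # False
```

Since `torch.Size` is a subclass of `tuple`, the shape in the assertion has to be compared against the one-element tuple rather than the bare integer.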

samsja merged commit 829e3f3 into feat-rewrite-v2 Nov 28, 2022
samsja deleted the parametrized-tensors branch November 28, 2022 08:07