-
Notifications
You must be signed in to change notification settings - Fork 244
feat: parametrized tensors #850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
03daff1
feat: parametrized torch tensor type
JohannesMessner ceb1aa5
refactor: fix typo in method name
JohannesMessner 24aa932
refactor: make parametrized type implementation generic
JohannesMessner 0371a17
refactor: make abstract tensor clearer
JohannesMessner 8429a14
feat: parametrized ndarray type
JohannesMessner 93dbc69
test: add serializatin tests for param tensors
JohannesMessner 3536cee
docs: add costrings for tensor types
JohannesMessner dc9ad35
fix: allow any kind of tensor shape
JohannesMessner c26a37a
fix: use view instead of reshape
JohannesMessner File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| import abc | ||
| from abc import ABC | ||
| from typing import TYPE_CHECKING, Any, Generic, Tuple, Type, TypeVar | ||
|
|
||
| from docarray.typing.abstract_type import AbstractType | ||
|
|
||
| if TYPE_CHECKING: | ||
| from pydantic import BaseConfig | ||
| from pydantic.fields import ModelField | ||
|
|
||
| T = TypeVar('T', bound='AbstractTensor') | ||
| ShapeT = TypeVar('ShapeT') | ||
|
|
||
|
|
||
| class AbstractTensor(AbstractType, Generic[ShapeT], ABC): | ||
|
|
||
| __parametrized_meta__ = type | ||
|
|
||
| @classmethod | ||
| @abc.abstractmethod | ||
| def __validate_shape__(cls, t: T, shape: Tuple[int]) -> T: | ||
| """Every tensor has to implement this method in order to | ||
| enbale syntax of the form Tensor[shape]. | ||
| The intended behavoiour is as follows: | ||
| - If the shape of `t` is equal to `shape`, return `t`. | ||
| - If the shape of `t` is not equal to `shape`, | ||
| but can be reshaped to `shape`, return `t` reshaped to `shape`. | ||
| - If the shape of `t` is not equal to `shape` | ||
| and cannot be reshaped to `shape`, raise a ValueError. | ||
|
|
||
| :param t: The tensor to validate. | ||
| :param shape: The shape to validate against. | ||
| :return: The validated tensor. | ||
| """ | ||
| ... | ||
|
|
||
| @classmethod | ||
| def _create_parametrized_type(cls: Type[T], shape: Tuple[int]): | ||
| shape_str = ', '.join([str(s) for s in shape]) | ||
|
|
||
| class _ParametrizedTensor( | ||
| cls, # type: ignore | ||
| metaclass=cls.__parametrized_meta__, # type: ignore | ||
| ): | ||
| _docarray_target_shape = shape | ||
| __name__ = f'{cls.__name__}[{shape_str}]' | ||
| __qualname__ = f'{cls.__qualname__}[{shape_str}]' | ||
|
|
||
| @classmethod | ||
| def validate( | ||
| _cls, | ||
| value: Any, | ||
| field: 'ModelField', | ||
| config: 'BaseConfig', | ||
| ): | ||
| t = super().validate(value, field, config) | ||
| return _cls.__validate_shape__(t, _cls._docarray_target_shape) | ||
|
|
||
| return _ParametrizedTensor | ||
|
|
||
| def __class_getitem__(cls, item): | ||
| if isinstance(item, int): | ||
| item = (item,) | ||
| try: | ||
| item = tuple(item) | ||
| except TypeError: | ||
| raise TypeError(f'{item} is not a valid tensor shape.') | ||
|
|
||
| return cls._create_parametrized_type(item) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since AbstractType is inheried, we don't need ABC ?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually
ABCis not inherited like other classes, due to it using a custom metaclass. I believe we need to inherit fromABCagain here.This is a quite advanced topic in Python that is usually not important day to day, but if you are interested you can read more here: https://peps.python.org/pep-3119/