diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index b0a4b6db70c..832bc251517 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -9,6 +9,7 @@ Type, TypeVar, Union, + cast, ) from docarray.base_document import BaseDocument @@ -34,7 +35,10 @@ def __repr__(self): return f'<{self.__class__.__name__} (length={len(self)})>' @classmethod - def __class_getitem__(cls, item: Type[BaseDocument]): + def __class_getitem__(cls, item: Union[Type[BaseDocument], TypeVar, str]): + if not isinstance(item, type): + return Generic.__class_getitem__.__func__(cls, item) # type: ignore + # this do nothing that checking that item is valid type var or str if not issubclass(item, BaseDocument): raise ValueError( f'{cls.__name__}[item] item should be a Document not a {item} ' @@ -48,7 +52,7 @@ def __class_getitem__(cls, item: Type[BaseDocument]): global _DocumentArrayTyped class _DocumentArrayTyped(cls): # type: ignore - document_type: Type[BaseDocument] = item + document_type: Type[BaseDocument] = cast(Type[BaseDocument], item) for field in _DocumentArrayTyped.document_type.__fields__.keys(): diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py index f39bae21efa..a42e0d2dca1 100644 --- a/docarray/array/array/array.py +++ b/docarray/array/array/array.py @@ -5,7 +5,6 @@ TYPE_CHECKING, Any, Callable, - Generic, Iterable, List, Optional, @@ -68,7 +67,7 @@ def _is_np_int(item: Any) -> bool: return False # this is unreachable, but mypy wants it -class DocumentArray(IOMixinArray, AnyDocumentArray, Generic[T_doc]): +class DocumentArray(IOMixinArray, AnyDocumentArray[T_doc]): """ DocumentArray is a container of Documents. diff --git a/docarray/array/stacked/array_stacked.py b/docarray/array/stacked/array_stacked.py index cf2bf2accab..50ef017113a 100644 --- a/docarray/array/stacked/array_stacked.py +++ b/docarray/array/stacked/array_stacked.py @@ -42,11 +42,12 @@ else: TensorFlowTensor = None # type: ignore +T_doc = TypeVar('T_doc', bound=BaseDocument) T = TypeVar('T', bound='DocumentArrayStacked') IndexIterType = Union[slice, Iterable[int], Iterable[bool], None] -class DocumentArrayStacked(AnyDocumentArray): +class DocumentArrayStacked(AnyDocumentArray[T_doc]): """ DocumentArrayStacked is a container of Documents appropriates to perform computation that require batches of data (ex: matrix multiplication, distance @@ -70,7 +71,7 @@ class DocumentArrayStacked(AnyDocumentArray): def __init__( self: T, - docs: Optional[Union[DocumentArray, Iterable[BaseDocument]]] = None, + docs: Optional[Union[DocumentArray, Iterable[T_doc]]] = None, tensor_type: Type['AbstractTensor'] = NdArray, ): self._doc_columns: Dict[str, 'DocumentArrayStacked'] = {} @@ -80,7 +81,7 @@ def __init__( self.from_iterable_document(docs) def from_iterable_document( - self: T, docs: Optional[Union[DocumentArray, Iterable[BaseDocument]]] + self: T, docs: Optional[Union[DocumentArray, Iterable[T_doc]]] ): self._docs = ( docs @@ -254,7 +255,7 @@ def _set_array_attribute( setattr(self._docs, field, values) @overload - def __getitem__(self: T, item: int) -> BaseDocument: + def __getitem__(self: T, item: int) -> T_doc: ... @overload @@ -276,9 +277,7 @@ def __getitem__(self, item): setattr(doc, field, self._tensor_columns[field][item]) return doc - def __setitem__( - self: T, key: Union[int, IndexIterType], value: Union[T, BaseDocument] - ): + def __setitem__(self: T, key: Union[int, IndexIterType], value: Union[T, T_doc]): # multiple docs case if isinstance(key, (slice, Iterable)): return self._set_data_and_columns(key, value) @@ -476,7 +475,7 @@ def unstacked_mode(self): @classmethod def validate( cls: Type[T], - value: Union[T, Iterable[BaseDocument]], + value: Union[T, Iterable[T_doc]], field: 'ModelField', config: 'BaseConfig', ) -> T: diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index d69514caab7..7c0e9b329f8 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, TypeVar, Union import numpy as np import pytest @@ -294,3 +294,17 @@ def test_del_item(da): 'hello 8', 'hello 9', ] + + +def test_generic_type_var(): + T = TypeVar('T', bound=BaseDocument) + + def f(a: DocumentArray[T]) -> DocumentArray[T]: + return a + + def g(a: DocumentArray['BaseDocument']) -> DocumentArray['BaseDocument']: + return a + + a = DocumentArray() + f(a) + g(a)