Skip to content
8 changes: 6 additions & 2 deletions docarray/array/abstract_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Type,
TypeVar,
Union,
cast,
)

from docarray.base_document import BaseDocument
Expand All @@ -34,7 +35,10 @@ def __repr__(self):
return f'<{self.__class__.__name__} (length={len(self)})>'

@classmethod
def __class_getitem__(cls, item: Type[BaseDocument]):
def __class_getitem__(cls, item: Union[Type[BaseDocument], TypeVar, str]):
if not isinstance(item, type):
return Generic.__class_getitem__.__func__(cls, item) # type: ignore

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How did you know this is what is required?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this even doing anything?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this even doing anything?

this is not doing anything else than checking the the item is valid Generic. I think it does not harm to keep it

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this even doing anything?

this is not doing anything else than checking the the item is valid Generic. I think it does not harm to keep it

Crashed in Python 3.12.0

    return Generic.__class_getitem__.__func__(cls, item)  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'builtin_function_or_method' object has no attribute '__func__'. Did you mean: '__doc__'?

# this do nothing that checking that item is valid type var or str
if not issubclass(item, BaseDocument):
raise ValueError(
f'{cls.__name__}[item] item should be a Document not a {item} '
Expand All @@ -48,7 +52,7 @@ def __class_getitem__(cls, item: Type[BaseDocument]):
global _DocumentArrayTyped

class _DocumentArrayTyped(cls): # type: ignore
document_type: Type[BaseDocument] = item
document_type: Type[BaseDocument] = cast(Type[BaseDocument], item)

for field in _DocumentArrayTyped.document_type.__fields__.keys():

Expand Down
3 changes: 1 addition & 2 deletions docarray/array/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
TYPE_CHECKING,
Any,
Callable,
Generic,
Iterable,
List,
Optional,
Expand Down Expand Up @@ -68,7 +67,7 @@ def _is_np_int(item: Any) -> bool:
return False # this is unreachable, but mypy wants it


class DocumentArray(IOMixinArray, AnyDocumentArray, Generic[T_doc]):
class DocumentArray(IOMixinArray, AnyDocumentArray[T_doc]):
"""
DocumentArray is a container of Documents.

Expand Down
15 changes: 7 additions & 8 deletions docarray/array/stacked/array_stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@
else:
TensorFlowTensor = None # type: ignore

T_doc = TypeVar('T_doc', bound=BaseDocument)
T = TypeVar('T', bound='DocumentArrayStacked')
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]


class DocumentArrayStacked(AnyDocumentArray):
class DocumentArrayStacked(AnyDocumentArray[T_doc]):
"""
DocumentArrayStacked is a container of Documents appropriates to perform
computation that require batches of data (ex: matrix multiplication, distance
Expand All @@ -70,7 +71,7 @@ class DocumentArrayStacked(AnyDocumentArray):

def __init__(
self: T,
docs: Optional[Union[DocumentArray, Iterable[BaseDocument]]] = None,
docs: Optional[Union[DocumentArray, Iterable[T_doc]]] = None,
tensor_type: Type['AbstractTensor'] = NdArray,
):
self._doc_columns: Dict[str, 'DocumentArrayStacked'] = {}
Expand All @@ -80,7 +81,7 @@ def __init__(
self.from_iterable_document(docs)

def from_iterable_document(
self: T, docs: Optional[Union[DocumentArray, Iterable[BaseDocument]]]
self: T, docs: Optional[Union[DocumentArray, Iterable[T_doc]]]
):
self._docs = (
docs
Expand Down Expand Up @@ -254,7 +255,7 @@ def _set_array_attribute(
setattr(self._docs, field, values)

@overload
def __getitem__(self: T, item: int) -> BaseDocument:
def __getitem__(self: T, item: int) -> T_doc:
...

@overload
Expand All @@ -276,9 +277,7 @@ def __getitem__(self, item):
setattr(doc, field, self._tensor_columns[field][item])
return doc

def __setitem__(
self: T, key: Union[int, IndexIterType], value: Union[T, BaseDocument]
):
def __setitem__(self: T, key: Union[int, IndexIterType], value: Union[T, T_doc]):
# multiple docs case
if isinstance(key, (slice, Iterable)):
return self._set_data_and_columns(key, value)
Expand Down Expand Up @@ -476,7 +475,7 @@ def unstacked_mode(self):
@classmethod
def validate(
cls: Type[T],
value: Union[T, Iterable[BaseDocument]],
value: Union[T, Iterable[T_doc]],
field: 'ModelField',
config: 'BaseConfig',
) -> T:
Expand Down
16 changes: 15 additions & 1 deletion tests/units/array/test_array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, TypeVar, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -294,3 +294,17 @@ def test_del_item(da):
'hello 8',
'hello 9',
]


def test_generic_type_var():
T = TypeVar('T', bound=BaseDocument)

def f(a: DocumentArray[T]) -> DocumentArray[T]:
return a

def g(a: DocumentArray['BaseDocument']) -> DocumentArray['BaseDocument']:
return a

a = DocumentArray()
f(a)
g(a)