From 904f1a7c35cdf8ddc79b49880081385b577c3af2 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 16 Nov 2022 14:21:11 +0100 Subject: [PATCH 1/4] feat: allow da bulk access to return da for document Signed-off-by: Sami Jaghouar --- docarray/array/abstract_array.py | 6 +++++- docarray/array/array.py | 6 +++--- docarray/array/mixins/attribute.py | 14 +++++++++++--- docarray/document/mixins/proto.py | 6 ++++-- tests/units/array/test_mixins/test_attribute.py | 16 ++++++++++++++++ 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 15a31527e43..383dc329302 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -10,5 +10,9 @@ class AbstractDocumentArray(Iterable): document_type: Type[BaseDocument] @abstractmethod - def __init__(self, docs: Iterable[AbstractDocument]): + def __init__(self, docs: Iterable[BaseDocument]): + ... + + @abstractmethod + def __class_getitem__(cls, item: Type[BaseDocument]) -> Type['AbstractDocument']: ... 
diff --git a/docarray/array/array.py b/docarray/array/array.py index d15f04a7c11..661c1b64a4f 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -7,7 +7,7 @@ class DocumentArray( - list, + list[AbstractDocument], ProtoArrayMixin, GetAttributeArrayMixin, AbstractDocumentArray, @@ -21,7 +21,7 @@ class DocumentArray( document_type: Type[BaseDocument] = AnyDocument - def __init__(self, docs: Iterable[AbstractDocument]): + def __init__(self, docs: Iterable[BaseDocument]): super().__init__(doc_ for doc_ in docs) def __class_getitem__(cls, item: Type[BaseDocument]): @@ -31,7 +31,7 @@ def __class_getitem__(cls, item: Type[BaseDocument]): ) class _DocumenArrayTyped(DocumentArray): - document_type = item + document_type: Type[BaseDocument] = item for field in _DocumenArrayTyped.document_type.__fields__.keys(): diff --git a/docarray/array/mixins/attribute.py b/docarray/array/mixins/attribute.py index b099aac0f4e..a22e172e6c9 100644 --- a/docarray/array/mixins/attribute.py +++ b/docarray/array/mixins/attribute.py @@ -1,12 +1,15 @@ -from typing import List +from typing import List, Union from docarray.array.abstract_array import AbstractDocumentArray +from docarray.document import BaseDocument class GetAttributeArrayMixin(AbstractDocumentArray): """Helpers that provide attributes getter in bulk""" - def _get_documents_attribute(self, field: str) -> List: + def _get_documents_attribute( + self, field: str + ) -> Union[List, AbstractDocumentArray]: """Return all values of the fields from all docs this array contains :param field: name of the fields to extract @@ -14,4 +17,9 @@ def _get_documents_attribute(self, field: str) -> List: in the array like container """ - return [getattr(doc, field) for doc in self] + field_type = self.__class__.document_type._get_nested_document_class(field) + + if issubclass(field_type, BaseDocument): + return self.__class__[field_type]((getattr(doc, field) for doc in self)) + else: + return [getattr(doc, field) for doc in self] 
diff --git a/docarray/document/mixins/proto.py b/docarray/document/mixins/proto.py index b601eda4a2f..aaa3e6157dc 100644 --- a/docarray/document/mixins/proto.py +++ b/docarray/document/mixins/proto.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Type, TypeVar from pydantic.tools import parse_obj_as @@ -7,10 +7,12 @@ from docarray.proto import DocumentProto, NodeProto from docarray.typing import ID, AnyUrl, Embedding, ImageUrl, Tensor, TorchTensor +T = TypeVar('T', bound='ProtoMixin') + class ProtoMixin(AbstractDocument, BaseNode): @classmethod - def from_protobuf(cls, pb_msg: 'DocumentProto') -> 'ProtoMixin': + def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: """create a Document from a protobuf message""" from docarray import DocumentArray diff --git a/tests/units/array/test_mixins/test_attribute.py b/tests/units/array/test_mixins/test_attribute.py index cf955702a09..2f83110143e 100644 --- a/tests/units/array/test_mixins/test_attribute.py +++ b/tests/units/array/test_mixins/test_attribute.py @@ -51,3 +51,19 @@ class Mmdoc(BaseDocument): assert len(texts) == N for i, text in enumerate(texts): assert text == f'hello{i}' + + +def test_get_bulk_attributes_document(): + class InnerDoc(BaseDocument): + text: str + + class Mmdoc(BaseDocument): + inner: InnerDoc + + N = 10 + + da = DocumentArray[Mmdoc]( + (Mmdoc(inner=InnerDoc(text=f'hello{i}')) for i in range(N)) + ) + + assert isinstance(da.inner, DocumentArray) From a9efdd94029b6439372ca935a91dbc5ee1cf0ab1 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 16 Nov 2022 15:27:27 +0100 Subject: [PATCH 2/4] fix: fix mypy type pb Signed-off-by: Sami Jaghouar --- docarray/array/abstract_array.py | 5 +++-- docarray/array/mixins/attribute.py | 6 +++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 383dc329302..4812d69d4ca 100644 --- a/docarray/array/abstract_array.py +++ 
b/docarray/array/abstract_array.py @@ -2,7 +2,6 @@ from typing import Iterable, Type from docarray.document import BaseDocument -from docarray.document.abstract_document import AbstractDocument class AbstractDocumentArray(Iterable): @@ -14,5 +13,7 @@ def __init__(self, docs: Iterable[BaseDocument]): ... @abstractmethod - def __class_getitem__(cls, item: Type[BaseDocument]) -> Type['AbstractDocument']: + def __class_getitem__( + cls, item: Type[BaseDocument] + ) -> Type['AbstractDocumentArray']: ... diff --git a/docarray/array/mixins/attribute.py b/docarray/array/mixins/attribute.py index a22e172e6c9..438f08916e7 100644 --- a/docarray/array/mixins/attribute.py +++ b/docarray/array/mixins/attribute.py @@ -20,6 +20,10 @@ def _get_documents_attribute( field_type = self.__class__.document_type._get_nested_document_class(field) if issubclass(field_type, BaseDocument): - return self.__class__[field_type]((getattr(doc, field) for doc in self)) + # calling __class_getitem__ ourselves is a hack otherwise mypy complains + # most likely a bug in mypy though + return self.__class__.__class_getitem__(field_type)( + (getattr(doc, field) for doc in self) + ) else: return [getattr(doc, field) for doc in self] From 6b5c1ddb39a5c2ffd88e1503d7935b3c5fa66ebc Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 16 Nov 2022 16:26:31 +0100 Subject: [PATCH 3/4] fix: add link to the mypy issue Signed-off-by: Sami Jaghouar --- docarray/array/mixins/attribute.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docarray/array/mixins/attribute.py b/docarray/array/mixins/attribute.py index 438f08916e7..712b5b6524e 100644 --- a/docarray/array/mixins/attribute.py +++ b/docarray/array/mixins/attribute.py @@ -22,6 +22,7 @@ def _get_documents_attribute( if issubclass(field_type, BaseDocument): # calling __class_getitem__ ourselves is a hack otherwise mypy complains # most likely a bug in mypy though + # bug reported here https://github.com/python/mypy/issues/14111 return
self.__class__.__class_getitem__(field_type)( (getattr(doc, field) for doc in self) ) From efa4513e3251d351aca1cd3754fbdd98cc1a160e Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 16 Nov 2022 16:30:04 +0100 Subject: [PATCH 4/4] fix: remove useless list type hint Signed-off-by: Sami Jaghouar --- docarray/array/array.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docarray/array/array.py b/docarray/array/array.py index 661c1b64a4f..681bbe19e02 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -3,11 +3,10 @@ from docarray.array.abstract_array import AbstractDocumentArray from docarray.array.mixins import GetAttributeArrayMixin, ProtoArrayMixin from docarray.document import AnyDocument, BaseDocument, BaseNode -from docarray.document.abstract_document import AbstractDocument class DocumentArray( - list[AbstractDocument], + list, ProtoArrayMixin, GetAttributeArrayMixin, AbstractDocumentArray,