diff --git a/docarray/array/documentarray.py b/docarray/array/documentarray.py
index 635c3778fa7..7963535c006 100644
--- a/docarray/array/documentarray.py
+++ b/docarray/array/documentarray.py
@@ -1,15 +1,15 @@
 from typing import Iterable, Type
 
+from docarray.array.abstract_array import AbstractDocumentArray
+from docarray.array.mixins import GetAttributeArrayMixin, ProtoArrayMixin
 from docarray.document import AnyDocument, BaseDocument, BaseNode
 from docarray.document.abstract_document import AbstractDocument
 
-from .abstract_array import AbstractDocumentArray
-from .mixins import ProtoArrayMixin
-
 
 class DocumentArray(
     list,
     ProtoArrayMixin,
+    GetAttributeArrayMixin,
     AbstractDocumentArray,
     BaseNode,
 ):
@@ -33,6 +33,16 @@ def __class_getitem__(cls, item: Type[BaseDocument]):
         class _DocumenArrayTyped(DocumentArray):
             document_type = item
 
+        # Expose every schema field as a read-only property on the typed array,
+        # so that e.g. `da.some_field` returns that field's value for each document.
+        def _property_generator(val: str):
+            # a factory binds `val` once per field; a bare lambda written in the
+            # loop would late-bind and make every property read the last field
+            return property(lambda self: self._get_documents_attribute(val))
+
+        for field in _DocumenArrayTyped.document_type.__fields__:
+            setattr(_DocumenArrayTyped, field, _property_generator(field))
+
         _DocumenArrayTyped.__name__ = f'DocumentArray{item.__name__}'
 
         return _DocumenArrayTyped
diff --git a/docarray/array/mixins/__init__.py b/docarray/array/mixins/__init__.py
index 107ea40271d..180f2b34adf 100644
--- a/docarray/array/mixins/__init__.py
+++ b/docarray/array/mixins/__init__.py
@@ -1,3 +1,4 @@
+from docarray.array.mixins.attribute import GetAttributeArrayMixin
 from docarray.array.mixins.proto import ProtoArrayMixin
 
-__all__ = ['ProtoArrayMixin']
+__all__ = ['ProtoArrayMixin', 'GetAttributeArrayMixin']
diff --git a/docarray/array/mixins/attribute.py b/docarray/array/mixins/attribute.py
new file mode 100644
index 00000000000..b099aac0f4e
--- /dev/null
+++ b/docarray/array/mixins/attribute.py
@@ -0,0 +1,17 @@
+from typing import List
+
+from docarray.array.abstract_array import AbstractDocumentArray
+
+
+class GetAttributeArrayMixin(AbstractDocumentArray):
+    """Helpers that provide bulk attribute getters"""
+
+    def _get_documents_attribute(self, field: str) -> List:
+        """Return the value of one field for every document in this array
+
+        :param field: name of the field to extract
+        :return: a list holding the value of `field` for each document,
+            in the array's iteration order
+        """
+
+        return [getattr(doc, field) for doc in self]
diff --git a/docarray/array/mixins/proto.py b/docarray/array/mixins/proto.py
index de9aee858e6..e53c0a85e34 100644
--- a/docarray/array/mixins/proto.py
+++ b/docarray/array/mixins/proto.py
@@ -1,9 +1,8 @@
 from typing import Type
 
+from docarray.array.abstract_array import AbstractDocumentArray
 from docarray.proto import DocumentArrayProto, NodeProto
 
-from ..abstract_array import AbstractDocumentArray
-
 
 class ProtoArrayMixin(AbstractDocumentArray):
     @classmethod
diff --git a/tests/units/array/test_mixins/test_attribute.py b/tests/units/array/test_mixins/test_attribute.py
new file mode 100644
index 00000000000..5eabfb8b7ef
--- /dev/null
+++ b/tests/units/array/test_mixins/test_attribute.py
@@ -0,0 +1,55 @@
+import numpy as np
+
+from docarray.array import DocumentArray
+from docarray.document import BaseDocument
+from docarray.typing import Tensor
+
+
+def test_get_bulk_attributes_function():
+    """Bulk access through the explicit `_get_documents_attribute` helper."""
+    class Mmdoc(BaseDocument):
+        text: str
+        tensor: Tensor
+
+    N = 10
+
+    da = DocumentArray[Mmdoc](
+        (Mmdoc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N))
+    )
+
+    tensors = da._get_documents_attribute('tensor')
+
+    assert len(tensors) == N
+    for tensor in tensors:
+        assert tensor.shape == (3, 224, 224)
+
+    texts = da._get_documents_attribute('text')
+
+    assert len(texts) == N
+    for i, text in enumerate(texts):
+        assert text == f'hello{i}'
+
+
+def test_get_bulk_attributes():
+    """Bulk access through the per-field properties generated on the typed array."""
+    class Mmdoc(BaseDocument):
+        text: str
+        tensor: Tensor
+
+    N = 10
+
+    da = DocumentArray[Mmdoc](
+        (Mmdoc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N))
+    )
+
+    tensors = da.tensor
+
+    assert len(tensors) == N
+    for tensor in tensors:
+        assert tensor.shape == (3, 224, 224)
+
+    texts = da.text
+
+    assert len(texts) == N
+    for i, text in enumerate(texts):
+        assert text == f'hello{i}'