From 2d79f9b81529edc797f310cf4b37a1937f3afc2b Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 15 Nov 2022 15:23:07 +0100 Subject: [PATCH 1/3] feat(da): add get attribute mixin to da Signed-off-by: Sami Jaghouar --- docarray/array/documentarray.py | 6 ++-- docarray/array/mixins/__init__.py | 3 +- docarray/array/mixins/attribute.py | 17 +++++++++++ docarray/array/mixins/proto.py | 3 +- .../units/array/test_mixins/test_attribute.py | 29 +++++++++++++++++++ 5 files changed, 52 insertions(+), 6 deletions(-) create mode 100644 docarray/array/mixins/attribute.py create mode 100644 tests/units/array/test_mixins/test_attribute.py diff --git a/docarray/array/documentarray.py b/docarray/array/documentarray.py index 635c3778fa7..8ead9b241e1 100644 --- a/docarray/array/documentarray.py +++ b/docarray/array/documentarray.py @@ -1,15 +1,15 @@ from typing import Iterable, Type +from docarray.array.abstract_array import AbstractDocumentArray +from docarray.array.mixins import GetAttributeArrayMixin, ProtoArrayMixin from docarray.document import AnyDocument, BaseDocument, BaseNode from docarray.document.abstract_document import AbstractDocument -from .abstract_array import AbstractDocumentArray -from .mixins import ProtoArrayMixin - class DocumentArray( list, ProtoArrayMixin, + GetAttributeArrayMixin, AbstractDocumentArray, BaseNode, ): diff --git a/docarray/array/mixins/__init__.py b/docarray/array/mixins/__init__.py index 107ea40271d..180f2b34adf 100644 --- a/docarray/array/mixins/__init__.py +++ b/docarray/array/mixins/__init__.py @@ -1,3 +1,4 @@ +from docarray.array.mixins.attribute import GetAttributeArrayMixin from docarray.array.mixins.proto import ProtoArrayMixin -__all__ = ['ProtoArrayMixin'] +__all__ = ['ProtoArrayMixin', 'GetAttributeArrayMixin'] diff --git a/docarray/array/mixins/attribute.py b/docarray/array/mixins/attribute.py new file mode 100644 index 00000000000..b099aac0f4e --- /dev/null +++ b/docarray/array/mixins/attribute.py @@ -0,0 +1,17 @@ +from typing 
import List + +from docarray.array.abstract_array import AbstractDocumentArray + + +class GetAttributeArrayMixin(AbstractDocumentArray): + """Helpers that provide bulk attribute getters""" + + def _get_documents_attribute(self, field: str) -> List: + """Return the value of a field for every doc this array contains + + :param field: name of the field to extract + :return: a list of the field value for each document + in the array-like container + """ + + return [getattr(doc, field) for doc in self] diff --git a/docarray/array/mixins/proto.py b/docarray/array/mixins/proto.py index de9aee858e6..e53c0a85e34 100644 --- a/docarray/array/mixins/proto.py +++ b/docarray/array/mixins/proto.py @@ -1,9 +1,8 @@ from typing import Type +from docarray.array.abstract_array import AbstractDocumentArray from docarray.proto import DocumentArrayProto, NodeProto -from ..abstract_array import AbstractDocumentArray - class ProtoArrayMixin(AbstractDocumentArray): @classmethod diff --git a/tests/units/array/test_mixins/test_attribute.py b/tests/units/array/test_mixins/test_attribute.py new file mode 100644 index 00000000000..9745997cf0c --- /dev/null +++ b/tests/units/array/test_mixins/test_attribute.py @@ -0,0 +1,29 @@ +import numpy as np + +from docarray.array import DocumentArray +from docarray.document import BaseDocument +from docarray.typing import Tensor + + +def test_get_bukl_attributes(): + class Mmdoc(BaseDocument): + text: str + tensor: Tensor + + N = 10 + + da = DocumentArray[Mmdoc]( + (Mmdoc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N)) + ) + + tensors = da._get_documents_attribute('tensor') + + assert len(tensors) == N + for tensor in tensors: + assert tensor.shape == (3, 224, 224) + + texts = da._get_documents_attribute('text') + + assert len(texts) == N + for i, text in enumerate(texts): + assert text == f'hello{i}' From 6401c2e53fc5481b7bfaea0ff589be835d28d256 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 15 Nov 2022 
18:10:43 +0100 Subject: [PATCH 2/3] feat(da): generate attribute on the fly at the da level Signed-off-by: Sami Jaghouar --- docarray/array/documentarray.py | 7 +++++ .../units/array/test_mixins/test_attribute.py | 26 ++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/docarray/array/documentarray.py b/docarray/array/documentarray.py index 8ead9b241e1..d2b5f0d2591 100644 --- a/docarray/array/documentarray.py +++ b/docarray/array/documentarray.py @@ -33,6 +33,13 @@ def __class_getitem__(cls, item: Type[BaseDocument]): class _DocumenArrayTyped(DocumentArray): document_type = item + for field in _DocumenArrayTyped.document_type.__fields__.keys(): + + def _proprety_generator(val: str): + return property(lambda self: self._get_documents_attribute(val)) + + setattr(_DocumenArrayTyped, field, _proprety_generator(field)) + _DocumenArrayTyped.__name__ = f'DocumentArray{item.__name__}' return _DocumenArrayTyped diff --git a/tests/units/array/test_mixins/test_attribute.py b/tests/units/array/test_mixins/test_attribute.py index 9745997cf0c..5eabfb8b7ef 100644 --- a/tests/units/array/test_mixins/test_attribute.py +++ b/tests/units/array/test_mixins/test_attribute.py @@ -5,7 +5,7 @@ from docarray.typing import Tensor -def test_get_bukl_attributes(): +def test_get_bulk_attributes(): class Mmdoc(BaseDocument): text: str tensor: Tensor @@ -27,3 +27,27 @@ class Mmdoc(BaseDocument): assert len(texts) == N for i, text in enumerate(texts): assert text == f'hello{i}' + + +def test_get_bulk_attributes_as_properties(): + class Mmdoc(BaseDocument): + text: str + tensor: Tensor + + N = 10 + + da = DocumentArray[Mmdoc]( + (Mmdoc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N)) + ) + + tensors = da.tensor + + assert len(tensors) == N + for tensor in tensors: + assert tensor.shape == (3, 224, 224) + + texts = da.text + + assert len(texts) == N + for i, text in enumerate(texts): + assert text == f'hello{i}' From d708e0e55e346b0656c748cfd733888856596952 Mon 
Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 15 Nov 2022 18:19:24 +0100 Subject: [PATCH 3/3] feat(da): add comment Signed-off-by: Sami Jaghouar --- docarray/array/documentarray.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docarray/array/documentarray.py b/docarray/array/documentarray.py index d2b5f0d2591..7963535c006 100644 --- a/docarray/array/documentarray.py +++ b/docarray/array/documentarray.py @@ -39,6 +39,7 @@ def _proprety_generator(val: str): return property(lambda self: self._get_documents_attribute(val)) setattr(_DocumenArrayTyped, field, _proprety_generator(field)) + # this generates a property on the fly for each field in the schema of the item _DocumenArrayTyped.__name__ = f'DocumentArray{item.__name__}'