Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions docarray/array/documentarray.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Iterable, Type

from docarray.array.abstract_array import AbstractDocumentArray
from docarray.array.mixins import GetAttributeArrayMixin, ProtoArrayMixin
from docarray.document import AnyDocument, BaseDocument, BaseNode
from docarray.document.abstract_document import AbstractDocument

from .abstract_array import AbstractDocumentArray
from .mixins import ProtoArrayMixin


class DocumentArray(
list,
ProtoArrayMixin,
GetAttributeArrayMixin,
AbstractDocumentArray,
BaseNode,
):
Expand All @@ -33,6 +33,14 @@ def __class_getitem__(cls, item: Type[BaseDocument]):
# Build a DocumentArray subclass whose `document_type` is the requested item.
class _DocumenArrayTyped(DocumentArray):
document_type = item

# Expose every field of the document schema as a getter-only property on the
# typed array: `array.<field>` returns `_get_documents_attribute('<field>')`.
for field in _DocumenArrayTyped.document_type.__fields__.keys():

# A factory function is used so that each property's lambda captures its own
# `val`; closing over the loop variable `field` directly would late-bind
# every property to the last field name.
def _proprety_generator(val: str):
return property(lambda self: self._get_documents_attribute(val))

setattr(_DocumenArrayTyped, field, _proprety_generator(field))
# the properties are generated on the fly from the schema of `item`

# Name the generated class after the document type it is bound to.
_DocumenArrayTyped.__name__ = f'DocumentArray{item.__name__}'

return _DocumenArrayTyped
3 changes: 2 additions & 1 deletion docarray/array/mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from docarray.array.mixins.attribute import GetAttributeArrayMixin
from docarray.array.mixins.proto import ProtoArrayMixin

# NOTE(review): `__all__` is assigned twice in a row; the second assignment
# overrides the first, so the effective value is the two-name list. The first
# line is presumably the pre-change side of the diff - confirm and drop it.
__all__ = ['ProtoArrayMixin']
__all__ = ['ProtoArrayMixin', 'GetAttributeArrayMixin']
17 changes: 17 additions & 0 deletions docarray/array/mixins/attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import List

from docarray.array.abstract_array import AbstractDocumentArray


class GetAttributeArrayMixin(AbstractDocumentArray):
    """Mixin adding bulk, per-field attribute access to array containers."""

    def _get_documents_attribute(self, field: str) -> List:
        """Collect one attribute from every document held by this array.

        :param field: name of the attribute to read on each document
        :return: a list holding the attribute value of each document,
            in the array's iteration order
        """
        values = []
        for document in self:
            values.append(getattr(document, field))
        return values
3 changes: 1 addition & 2 deletions docarray/array/mixins/proto.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Type

from docarray.array.abstract_array import AbstractDocumentArray
from docarray.proto import DocumentArrayProto, NodeProto

from ..abstract_array import AbstractDocumentArray


class ProtoArrayMixin(AbstractDocumentArray):
@classmethod
Expand Down
53 changes: 53 additions & 0 deletions tests/units/array/test_mixins/test_attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np

from docarray.array import DocumentArray
from docarray.document import BaseDocument
from docarray.typing import Tensor


def test_get_bulk_attributes_function():
    """Check `_get_documents_attribute` returns one value per document, in order.

    The name is deliberately distinct from `test_get_bulk_attributes`: a
    second function with the same name in this module would rebind it, and
    this test would then never be collected or run.
    """

    class Mmdoc(BaseDocument):
        text: str
        tensor: Tensor

    N = 10

    da = DocumentArray[Mmdoc](
        (Mmdoc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N))
    )

    # bulk access through the helper method, tensor field
    tensors = da._get_documents_attribute('tensor')

    assert len(tensors) == N
    for tensor in tensors:
        assert tensor.shape == (3, 224, 224)

    # bulk access through the helper method, text field
    texts = da._get_documents_attribute('text')

    assert len(texts) == N
    for i, text in enumerate(texts):
        assert text == f'hello{i}'


def test_get_bulk_attributes():
    """Check bulk field access through the `tensor` and `text` attributes."""

    class Mmdoc(BaseDocument):
        text: str
        tensor: Tensor

    N = 10
    expected_shape = (3, 224, 224)

    docs = (
        Mmdoc(text=f'hello{idx}', tensor=np.zeros((3, 224, 224)))
        for idx in range(N)
    )
    da = DocumentArray[Mmdoc](docs)

    # one tensor per document, each with the shape it was created with
    all_tensors = da.tensor
    assert len(all_tensors) == N
    assert all(t.shape == expected_shape for t in all_tensors)

    # one text per document, in insertion order
    all_texts = da.text
    assert len(all_texts) == N
    for position in range(N):
        assert all_texts[position] == f'hello{position}'