Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions docarray/array/abstract_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
from typing import Iterable, Type

from docarray.document import BaseDocument
from docarray.document.abstract_document import AbstractDocument


class AbstractDocumentArray(Iterable):

document_type: Type[BaseDocument]

@abstractmethod
def __init__(self, docs: Iterable[AbstractDocument]):
def __init__(self, docs: Iterable[BaseDocument]):
...

@abstractmethod
def __class_getitem__(
cls, item: Type[BaseDocument]
) -> Type['AbstractDocumentArray']:
...
5 changes: 2 additions & 3 deletions docarray/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from docarray.array.abstract_array import AbstractDocumentArray
from docarray.array.mixins import GetAttributeArrayMixin, ProtoArrayMixin
from docarray.document import AnyDocument, BaseDocument, BaseNode
from docarray.document.abstract_document import AbstractDocument


class DocumentArray(
Expand All @@ -21,7 +20,7 @@ class DocumentArray(

document_type: Type[BaseDocument] = AnyDocument

def __init__(self, docs: Iterable[AbstractDocument]):
def __init__(self, docs: Iterable[BaseDocument]):
super().__init__(doc_ for doc_ in docs)

def __class_getitem__(cls, item: Type[BaseDocument]):
Expand All @@ -31,7 +30,7 @@ def __class_getitem__(cls, item: Type[BaseDocument]):
)

class _DocumenArrayTyped(DocumentArray):
document_type = item
document_type: Type[BaseDocument] = item

for field in _DocumenArrayTyped.document_type.__fields__.keys():

Expand Down
19 changes: 16 additions & 3 deletions docarray/array/mixins/attribute.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,30 @@
from typing import List
from typing import List, Union

from docarray.array.abstract_array import AbstractDocumentArray
from docarray.document import BaseDocument


class GetAttributeArrayMixin(AbstractDocumentArray):
"""Helpers that provide attributes getter in bulk"""

def _get_documents_attribute(self, field: str) -> List:
def _get_documents_attribute(
self, field: str
) -> Union[List, AbstractDocumentArray]:
"""Return all values of the fields from all docs this array contains

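:param field: name of the field to extract
:return: a list with the field value of each document in the array-like
container, or, when the field type is a subclass of BaseDocument, an
instance of this array class parametrized by that document type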
"""

return [getattr(doc, field) for doc in self]
field_type = self.__class__.document_type._get_nested_document_class(field)

if issubclass(field_type, BaseDocument):
# Calling __class_getitem__ directly is a hack; otherwise mypy complains.
# This is most likely a bug in mypy, reported here:
# https://github.com/python/mypy/issues/14111
return self.__class__.__class_getitem__(field_type)(
(getattr(doc, field) for doc in self)
)
else:
return [getattr(doc, field) for doc in self]
6 changes: 4 additions & 2 deletions docarray/document/mixins/proto.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Type, TypeVar

from pydantic.tools import parse_obj_as

Expand All @@ -7,10 +7,12 @@
from docarray.proto import DocumentProto, NodeProto
from docarray.typing import ID, AnyUrl, Embedding, ImageUrl, Tensor, TorchTensor

T = TypeVar('T', bound='ProtoMixin')


class ProtoMixin(AbstractDocument, BaseNode):
@classmethod
def from_protobuf(cls, pb_msg: 'DocumentProto') -> 'ProtoMixin':
def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T:
"""create a Document from a protobuf message"""
from docarray import DocumentArray

Expand Down
16 changes: 16 additions & 0 deletions tests/units/array/test_mixins/test_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,19 @@ class Mmdoc(BaseDocument):
assert len(texts) == N
for i, text in enumerate(texts):
assert text == f'hello{i}'


def test_get_bulk_attributes_document():
    """Bulk access to a nested-document field yields a DocumentArray.

    Each ``Mmdoc`` wraps an ``InnerDoc``; reading ``da.inner`` must return a
    ``DocumentArray`` holding those inner documents in their original order.
    """

    class InnerDoc(BaseDocument):
        text: str

    class Mmdoc(BaseDocument):
        inner: InnerDoc

    N = 10

    da = DocumentArray[Mmdoc](
        (Mmdoc(inner=InnerDoc(text=f'hello{i}')) for i in range(N))
    )

    inner_docs = da.inner

    # A nested-document field must come back as an array, not a plain list.
    assert isinstance(inner_docs, DocumentArray)

    # The type check alone would pass for an empty or shuffled array, so
    # also verify that every inner document is present and in order.
    assert [doc.text for doc in inner_docs] == [f'hello{i}' for i in range(N)]