From 47f0837d402d1eb4e4c574dd6289f84ff7ccced6 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 19 Dec 2022 16:43:50 +0100 Subject: [PATCH 01/26] test: add test for traverse flat Signed-off-by: anna-charlotte --- .../units/array/test_mixins/test_traverse.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 tests/units/array/test_mixins/test_traverse.py diff --git a/tests/units/array/test_mixins/test_traverse.py b/tests/units/array/test_mixins/test_traverse.py new file mode 100644 index 00000000000..f8e917456d9 --- /dev/null +++ b/tests/units/array/test_mixins/test_traverse.py @@ -0,0 +1,80 @@ +from typing import Optional + +import pytest +import torch + +from docarray import Document, DocumentArray, Text +from docarray.typing import TorchTensor + +num_docs = 5 +num_sub_docs = 2 +num_sub_sub_docs = 3 + + +@pytest.fixture +def multi_model_docs(): + class SubSubDoc(Document): + sub_sub_text: Text + sub_sub_tensor: TorchTensor[2] + + class SubDoc(Document): + sub_text: Text + sub_da: DocumentArray[SubSubDoc] + + class MultiModalDoc(Document): + mm_text: Text + mm_tensor: Optional[TorchTensor[3, 2, 2]] + mm_da: DocumentArray[SubDoc] + + docs = DocumentArray[MultiModalDoc]( + [ + MultiModalDoc( + mm_text=Text(text=f'hello{i}'), + mm_da=[ + SubDoc( + sub_text=Text(text=f'sub_{i}_1'), + sub_da=DocumentArray[SubSubDoc]( + [ + SubSubDoc( + sub_sub_text=Text(text='subsub'), + sub_sub_tensor=torch.zeros(2), + ) + for _ in range(num_sub_sub_docs) + ] + ), + ) + for _ in range(num_sub_docs) + ], + ) + for i in range(num_docs) + ] + ) + + return docs + + +@pytest.mark.parametrize( + 'filter_fn', + [(lambda d: True), None], +) +@pytest.mark.parametrize( + 'traversal_path,len_result', + [ + ('mm_text', num_docs), # List of 5 Text + ('mm_text.text', num_docs), # List of 5 strings + ('mm_da', num_docs * num_sub_docs), # List of 5 * 2 SubDocs + ('mm_da.sub_text', num_docs * num_sub_docs), # List of 5 * 2 Text + ( + 'mm_da.sub_da', + num_docs * num_sub_docs * num_sub_sub_docs, + ), # List of 5 * 2 * 3 SubSubDoc + ( + 'mm_da.sub_da.sub_sub_text', + num_docs * num_sub_docs * num_sub_sub_docs, + ), # List of 5 * 2 * 3 Text + ], +) +def test_traverse_flat(multi_model_docs, traversal_path, len_result, filter_fn): + doc_req = multi_model_docs + ds = list(doc_req.traverse_flat(traversal_path)) + assert len(ds) == len_result From 434ab278390b774685c717646be2fc98760e5601 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 20 Dec 2022 14:03:03 +0100 Subject: [PATCH 02/26] test: update test for traverse flat Signed-off-by: anna-charlotte --- .../units/array/test_mixins/test_traverse.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/units/array/test_mixins/test_traverse.py b/tests/units/array/test_mixins/test_traverse.py index f8e917456d9..27a4205bc6c 100644 --- a/tests/units/array/test_mixins/test_traverse.py +++ b/tests/units/array/test_mixins/test_traverse.py @@ -55,26 +55,25 @@ class MultiModalDoc(Document): @pytest.mark.parametrize( 'filter_fn', - [(lambda d: True), None], + [(lambda d: True)], ) @pytest.mark.parametrize( - 'traversal_path,len_result', + 'access_path,len_result', [ - ('mm_text', num_docs), # List of 5 Text + ('mm_text', num_docs), # List of 5 Text objs ('mm_text.text', num_docs), # List of 5 strings - ('mm_da', num_docs * num_sub_docs), # List of 5 * 2 SubDocs - ('mm_da.sub_text', num_docs * num_sub_docs), # List of 5 * 2 Text + ('mm_da', num_docs * num_sub_docs), # List of 5 * 2 SubDoc objs + ('mm_da.sub_text', num_docs * num_sub_docs), # List of 5 * 2 Text objs ( 'mm_da.sub_da', num_docs * num_sub_docs * num_sub_sub_docs, - ), # List of 5 * 2 * 3 SubSubDoc + ), # List of 5 * 2 * 3 SubSubDoc objs ( 'mm_da.sub_da.sub_sub_text', num_docs * num_sub_docs * num_sub_sub_docs, - ), # List of 5 * 2 * 3 Text + ), # List of 5 * 2 * 3 Text objs ], ) -def test_traverse_flat(multi_model_docs, traversal_path, len_result, filter_fn): - doc_req = multi_model_docs - ds = list(doc_req.traverse_flat(traversal_path)) - assert len(ds) == len_result +def test_traverse_flat(multi_model_docs, access_path, len_result, filter_fn): + traversed = multi_model_docs.traverse_flat(access_path) + assert len(traversed) == len_result From e4f8486ffb5f8763175965d4213ddd83563e271b Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 20 Dec 2022 14:11:55 +0100 Subject: [PATCH 03/26] chore: change ruff line length from 88 to 100 Signed-off-by: anna-charlotte --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 948ecd324f4..b3a911849a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ skip_glob= ['docarray/proto/pb2/*'] [tool.ruff] exclude = ['docarray/proto/pb2/*', 'docs/*'] +line-length = 100 [tool.pytest.ini_options] markers = [ From 6cbacb4fa2dda2f1aec85ae950eb6a3100311a68 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 20 Dec 2022 14:12:50 +0100 Subject: [PATCH 04/26] feat: add traverse mixin with traverse flat Signed-off-by: anna-charlotte --- docarray/array/mixins/traverse.py | 109 ++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 docarray/array/mixins/traverse.py diff --git a/docarray/array/mixins/traverse.py b/docarray/array/mixins/traverse.py new file mode 100644 index 00000000000..52e7aaf13e5 --- /dev/null +++ b/docarray/array/mixins/traverse.py @@ -0,0 +1,109 @@ +from typing import TYPE_CHECKING, Any, List, Union + +from docarray.array.abstract_array import AbstractDocumentArray + +if TYPE_CHECKING: + from docarray import Document, DocumentArray + + +class TraverseMixin(AbstractDocumentArray): + """ + A mixin used for traversing :class:`DocumentArray`. + """ + + def traverse_flat( + self: AbstractDocumentArray, + access_path: str, + ) -> List[Any]: + """ + Return a List of the accessed objects when applying the access_path. If this res + Return a List of the accessed objects when applying the access_path. If this + results in a nested list or list of DocumentArrays, the list will be flattened + on the first level. The access path is a string that consists of attribute + names, concatenated and dot-seperated. It describes the path from the first + level to an arbitrary one, e.g. 'doc_attr_x.sub_doc_attr_x.sub_sub_doc_attr_z'. + + :param access_path: a string that represents the access path. + :return: list of the accessed objects, flattened if nested. + + EXAMPLE USAGE + .. code-block:: python + from docarray import Document, DocumentArray, Text + + + class Author(Document): + name: str + + + class Book(Document): + author: Author + content: Text + + + da = DocumentArray[Book]( + Book(author=Author(name='Ben'), content=Text(text=f'book_{i}')) for i in range(10) + ) + + books = da.traverse_flat(access_path='content') # list of 10 Text objs + authors = da.traverse_flat(access_path='author.name') # list of 10 strings + + If the resulting list is a nested list, it will be flattened: + + EXAMPLE USAGE + .. code-block:: python + from docarray import Document, DocumentArray + + + class Chapter(Document): + content: str + + + class Book(Document): + chapters: DocumentArray[Chapter] + + + da = DocumentArray[Book]( + Book( + chapters=DocumentArray[Chapter]( + [Chapter(content='some_content') for _ in range(3)] + ) + ) + for _ in range(10) + ) + + chapters = da.traverse_flat(access_path='chapters') # list of 30 strings + + """ + leaves = list(self._traverse(docs=self, access_path=access_path)) + return self._flatten(leaves) + + @staticmethod + def _traverse(docs: Union['Document', 'DocumentArray'], access_path: str): + if access_path: + path_attrs = access_path.split('.') + curr_attr = path_attrs[0] + path_attrs.pop(0) + + from docarray import Document + + if isinstance(docs, Document): + docs = [docs] + + for d in docs: + x = getattr(d, curr_attr) + yield from TraverseMixin._traverse(x, '.'.join(path_attrs)) + else: + yield docs + + @staticmethod + def _flatten(sequence) -> 'DocumentArray': + from docarray import DocumentArray + + res = [] + for seq in sequence: + if isinstance(seq, (list, DocumentArray)): + res += seq + else: + res.append(seq) + + return res From 1f00d9f113c0286cf4e980962a7e6783c3c7664e Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 20 Dec 2022 14:22:35 +0100 Subject: [PATCH 05/26] chore: revert pyproject toml ruff Signed-off-by: anna-charlotte --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b3a911849a5..948ecd324f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,6 @@ skip_glob= ['docarray/proto/pb2/*'] [tool.ruff] exclude = ['docarray/proto/pb2/*', 'docs/*'] -line-length = 100 [tool.pytest.ini_options] markers = [ From 9ed50c98bba6f0f118f9a6f9fd2042dc5f3fa4f4 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 20 Dec 2022 14:25:42 +0100 Subject: [PATCH 06/26] docs: fix docs in traverse mixin Signed-off-by: anna-charlotte --- docarray/array/mixins/traverse.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/array/mixins/traverse.py b/docarray/array/mixins/traverse.py index 52e7aaf13e5..939563f3704 100644 --- a/docarray/array/mixins/traverse.py +++ b/docarray/array/mixins/traverse.py @@ -16,7 +16,6 @@ def traverse_flat( access_path: str, ) -> List[Any]: """ - Return a List of the accessed objects when applying the access_path. If this res Return a List of the accessed objects when applying the access_path. If this results in a nested list or list of DocumentArrays, the list will be flattened on the first level. The access path is a string that consists of attribute @@ -41,7 +40,8 @@ class Book(Document): da = DocumentArray[Book]( - Book(author=Author(name='Ben'), content=Text(text=f'book_{i}')) for i in range(10) + Book(author=Author(name='Ben'), content=Text(text=f'book_{i}')) + for i in range(10) # noqa: E501 ) books = da.traverse_flat(access_path='content') # list of 10 Text objs From c651fc8e60e8f326fc4f45d6aadddbf7d34a8a42 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 20 Dec 2022 14:36:00 +0100 Subject: [PATCH 07/26] fix: mypy type hints Signed-off-by: anna-charlotte --- docarray/array/mixins/traverse.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docarray/array/mixins/traverse.py b/docarray/array/mixins/traverse.py index 939563f3704..8c6dc814cc5 100644 --- a/docarray/array/mixins/traverse.py +++ b/docarray/array/mixins/traverse.py @@ -12,7 +12,7 @@ class TraverseMixin(AbstractDocumentArray): """ def traverse_flat( - self: AbstractDocumentArray, + self: DocumentArray, access_path: str, ) -> List[Any]: """ @@ -78,7 +78,7 @@ class Book(Document): return self._flatten(leaves) @staticmethod - def _traverse(docs: Union['Document', 'DocumentArray'], access_path: str): + def _traverse(docs: Union[Document, DocumentArray], access_path: str): if access_path: path_attrs = access_path.split('.') curr_attr = path_attrs[0] @@ -96,10 +96,10 @@ def _traverse(docs: Union['Document', 'DocumentArray'], access_path: str): yield docs @staticmethod - def _flatten(sequence) -> 'DocumentArray': + def _flatten(sequence) -> List[Any]: from docarray import DocumentArray - res = [] + res: List[Any] = [] for seq in sequence: if isinstance(seq, (list, DocumentArray)): res += seq From a2a18b63ef8a36e8af1a661efcc9f12885985ec4 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 20 Dec 2022 14:40:13 +0100 Subject: [PATCH 08/26] fix: mypy type hints Signed-off-by: anna-charlotte --- docarray/array/mixins/traverse.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docarray/array/mixins/traverse.py b/docarray/array/mixins/traverse.py index 8c6dc814cc5..e4db3bbb2ac 100644 --- a/docarray/array/mixins/traverse.py +++ b/docarray/array/mixins/traverse.py @@ -45,6 +45,7 @@ class Book(Document): ) books = da.traverse_flat(access_path='content') # list of 10 Text objs + authors = da.traverse_flat(access_path='author.name') # list of 10 strings If the resulting list is a nested list, it will be flattened: @@ -74,8 +75,8 @@ class Book(Document): chapters = da.traverse_flat(access_path='chapters') # list of 30 strings """ - leaves = list(self._traverse(docs=self, access_path=access_path)) - return self._flatten(leaves) + leaves = list(TraverseMixin._traverse(docs=self, access_path=access_path)) + return TraverseMixin._flatten(leaves) @staticmethod def _traverse(docs: Union[Document, DocumentArray], access_path: str): From fd0f954ebe9a34fa2ef8ed97f29d8a2df01dbca3 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 20 Dec 2022 14:44:43 +0100 Subject: [PATCH 09/26] fix: mypy type hints Signed-off-by: anna-charlotte --- docarray/array/mixins/traverse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/mixins/traverse.py b/docarray/array/mixins/traverse.py index e4db3bbb2ac..1352bb237fe 100644 --- a/docarray/array/mixins/traverse.py +++ b/docarray/array/mixins/traverse.py @@ -88,7 +88,7 @@ def _traverse(docs: Union[Document, DocumentArray], access_path: str): from docarray import Document if isinstance(docs, Document): - docs = [docs] + docs: List[Document] = [docs] for d in docs: x = getattr(d, curr_attr) From 6a529696b114a61aad70d47485f0cdbfc46ff258 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 20 Dec 2022 15:14:06 +0100 Subject: [PATCH 10/26] fix: add allmixins Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 3 ++- docarray/array/mixins/__init__.py | 7 +++++++ docarray/array/mixins/traverse.py | 13 ++++++------- 3 files changed, 15 insertions(+), 8 deletions(-) create mode 100644 docarray/array/mixins/__init__.py diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index e0f0e1f510e..34949b0416c 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -1,6 +1,7 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Generic, List, Sequence, Type, TypeVar, Union +from docarray.array.mixins import AllMixins from docarray.document import BaseDocument from docarray.typing.abstract_type import AbstractType @@ -13,7 +14,7 @@ T_doc = TypeVar('T_doc', bound=BaseDocument) -class AnyDocumentArray(Sequence[BaseDocument], Generic[T_doc], AbstractType): +class AnyDocumentArray(AllMixins, Sequence[BaseDocument], Generic[T_doc], AbstractType): document_type: Type[BaseDocument] def __class_getitem__(cls, item: Type[BaseDocument]): diff --git a/docarray/array/mixins/__init__.py b/docarray/array/mixins/__init__.py new file mode 100644 index 00000000000..acff7ee41ce --- /dev/null +++ b/docarray/array/mixins/__init__.py @@ -0,0 +1,7 @@ +from docarray.array.mixins.traverse import TraverseMixin + + +class AllMixins(TraverseMixin): + """All plugins that can be used in :class:`DocumentArray`.""" + + ... diff --git a/docarray/array/mixins/traverse.py b/docarray/array/mixins/traverse.py index 1352bb237fe..af85bd0b240 100644 --- a/docarray/array/mixins/traverse.py +++ b/docarray/array/mixins/traverse.py @@ -1,12 +1,10 @@ from typing import TYPE_CHECKING, Any, List, Union -from docarray.array.abstract_array import AbstractDocumentArray - if TYPE_CHECKING: from docarray import Document, DocumentArray -class TraverseMixin(AbstractDocumentArray): +class TraverseMixin: """ A mixin used for traversing :class:`DocumentArray`. """ @@ -88,11 +86,12 @@ def _traverse(docs: Union[Document, DocumentArray], access_path: str): from docarray import Document if isinstance(docs, Document): - docs: List[Document] = [docs] - - for d in docs: - x = getattr(d, curr_attr) + x = getattr(docs, curr_attr) yield from TraverseMixin._traverse(x, '.'.join(path_attrs)) + else: + for d in docs: + x = getattr(d, curr_attr) + yield from TraverseMixin._traverse(x, '.'.join(path_attrs)) else: yield docs From 8c7e84a6b04c03e5916dad83804b6fc2af202a5a Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 20 Dec 2022 15:19:47 +0100 Subject: [PATCH 11/26] fix: mypy type hints Signed-off-by: anna-charlotte --- docarray/array/mixins/traverse.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docarray/array/mixins/traverse.py b/docarray/array/mixins/traverse.py index af85bd0b240..a6b1dc617ef 100644 --- a/docarray/array/mixins/traverse.py +++ b/docarray/array/mixins/traverse.py @@ -1,10 +1,12 @@ from typing import TYPE_CHECKING, Any, List, Union +from docarray.array.abstract_array import AnyDocumentArray + if TYPE_CHECKING: from docarray import Document, DocumentArray -class TraverseMixin: +class TraverseMixin(AnyDocumentArray): """ A mixin used for traversing :class:`DocumentArray`. """ From bf0e6174ac7d11366ef0835b0e31ba82c5ad980d Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 20 Dec 2022 15:25:16 +0100 Subject: [PATCH 12/26] fix: mypy type hints Signed-off-by: anna-charlotte --- docarray/array/mixins/traverse.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docarray/array/mixins/traverse.py b/docarray/array/mixins/traverse.py index a6b1dc617ef..af85bd0b240 100644 --- a/docarray/array/mixins/traverse.py +++ b/docarray/array/mixins/traverse.py @@ -1,12 +1,10 @@ from typing import TYPE_CHECKING, Any, List, Union -from docarray.array.abstract_array import AnyDocumentArray - if TYPE_CHECKING: from docarray import Document, DocumentArray -class TraverseMixin(AnyDocumentArray): +class TraverseMixin: """ A mixin used for traversing :class:`DocumentArray`. """ From fe3fd8facf2fdb0c3ac9b7a9b07b647b41e50fc0 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 20 Dec 2022 15:29:45 +0100 Subject: [PATCH 13/26] fix: mypy type hints Signed-off-by: anna-charlotte --- docarray/array/mixins/traverse.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/array/mixins/traverse.py b/docarray/array/mixins/traverse.py index af85bd0b240..6d26e661a22 100644 --- a/docarray/array/mixins/traverse.py +++ b/docarray/array/mixins/traverse.py @@ -10,7 +10,7 @@ class TraverseMixin: """ def traverse_flat( - self: DocumentArray, + self: 'DocumentArray', access_path: str, ) -> List[Any]: """ @@ -77,7 +77,7 @@ class Book(Document): return TraverseMixin._flatten(leaves) @staticmethod - def _traverse(docs: Union[Document, DocumentArray], access_path: str): + def _traverse(docs: Union['Document', 'DocumentArray'], access_path: str): if access_path: path_attrs = access_path.split('.') curr_attr = path_attrs[0] From 85c69bf873148454ff00c4cff4ed5978db8a1b00 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 21 Dec 2022 09:50:08 +0100 Subject: [PATCH 14/26] refactor: move traverse flat to abstract array Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 103 +++++++++++++++++++++++++++- docarray/array/mixins/__init__.py | 7 -- docarray/array/mixins/traverse.py | 109 ------------------------------ 3 files changed, 100 insertions(+), 119 deletions(-) delete mode 100644 docarray/array/mixins/__init__.py delete mode 100644 docarray/array/mixins/traverse.py diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 34949b0416c..c96f873427c 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -1,11 +1,11 @@ from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, List, Sequence, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generic, List, Sequence, Type, TypeVar, Union -from docarray.array.mixins import AllMixins from docarray.document import BaseDocument from docarray.typing.abstract_type import AbstractType if TYPE_CHECKING: + from docarray import Document from docarray.proto import DocumentArrayProto, NodeProto from docarray.typing import NdArray, TorchTensor @@ -14,7 +14,7 @@ T_doc = TypeVar('T_doc', bound=BaseDocument) -class AnyDocumentArray(AllMixins, Sequence[BaseDocument], Generic[T_doc], AbstractType): +class AnyDocumentArray(Sequence[BaseDocument], Generic[T_doc], AbstractType): document_type: Type[BaseDocument] def __class_getitem__(cls, item: Type[BaseDocument]): @@ -93,3 +93,100 @@ def _to_node_protobuf(self) -> 'NodeProto': from docarray.proto import NodeProto return NodeProto(chunks=self.to_protobuf()) + + def traverse_flat( + self: 'AnyDocumentArray', + access_path: str, + ) -> List[Any]: + """ + Return a List of the accessed objects when applying the access_path. If this + results in a nested list or list of DocumentArrays, the list will be flattened + on the first level. The access path is a string that consists of attribute + names, concatenated and dot-seperated. It describes the path from the first + level to an arbitrary one, e.g. 'doc_attr_x.sub_doc_attr_x.sub_sub_doc_attr_z'. + + :param access_path: a string that represents the access path. + :return: list of the accessed objects, flattened if nested. + + EXAMPLE USAGE + .. code-block:: python + from docarray import Document, DocumentArray, Text + + + class Author(Document): + name: str + + + class Book(Document): + author: Author + content: Text + + + da = DocumentArray[Book]( + Book(author=Author(name='Jenny'), content=Text(text=f'book_{i}')) + for i in range(10) # noqa: E501 + ) + + books = da.traverse_flat(access_path='content') # list of 10 Text objs + + authors = da.traverse_flat(access_path='author.name') # list of 10 strings + + If the resulting list is a nested list, it will be flattened: + + EXAMPLE USAGE + .. code-block:: python + from docarray import Document, DocumentArray + + + class Chapter(Document): + content: str + + + class Book(Document): + chapters: DocumentArray[Chapter] + + + da = DocumentArray[Book]( + Book( + chapters=DocumentArray[Chapter]( + [Chapter(content='some_content') for _ in range(3)] + ) + ) + for _ in range(10) + ) + + chapters = da.traverse_flat(access_path='chapters') # list of 30 strings + + """ + leaves = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) + return AnyDocumentArray._flatten(leaves) + + @staticmethod + def _traverse(node: Union['Document', 'AnyDocumentArray'], access_path: str): + if access_path: + path_attrs = access_path.split('.') + curr_attr = path_attrs[0] + path_attrs.pop(0) + + if isinstance(node, (AnyDocumentArray, list)): + for n in node: + x = getattr(n, curr_attr) + yield from AnyDocumentArray._traverse(x, '.'.join(path_attrs)) + else: + x = getattr(node, curr_attr) + yield from AnyDocumentArray._traverse(x, '.'.join(path_attrs)) + else: + yield node + + @staticmethod + def _flatten(sequence: List[Any]) -> List[Any]: + from docarray import DocumentArray + + res: List[Any] = [] + for seq in sequence: + if isinstance(seq, (list, DocumentArray)): + res += seq + else: + res.append(seq) + + return res diff --git a/docarray/array/mixins/__init__.py b/docarray/array/mixins/__init__.py deleted file mode 100644 index acff7ee41ce..00000000000 --- a/docarray/array/mixins/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from docarray.array.mixins.traverse import TraverseMixin - - -class AllMixins(TraverseMixin): - """All plugins that can be used in :class:`DocumentArray`.""" - - ... diff --git a/docarray/array/mixins/traverse.py b/docarray/array/mixins/traverse.py deleted file mode 100644 index 6d26e661a22..00000000000 --- a/docarray/array/mixins/traverse.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import TYPE_CHECKING, Any, List, Union - -if TYPE_CHECKING: - from docarray import Document, DocumentArray - - -class TraverseMixin: - """ - A mixin used for traversing :class:`DocumentArray`. - """ - - def traverse_flat( - self: 'DocumentArray', - access_path: str, - ) -> List[Any]: - """ - Return a List of the accessed objects when applying the access_path. If this - results in a nested list or list of DocumentArrays, the list will be flattened - on the first level. The access path is a string that consists of attribute - names, concatenated and dot-seperated. It describes the path from the first - level to an arbitrary one, e.g. 'doc_attr_x.sub_doc_attr_x.sub_sub_doc_attr_z'. - - :param access_path: a string that represents the access path. - :return: list of the accessed objects, flattened if nested. - - EXAMPLE USAGE - .. code-block:: python - from docarray import Document, DocumentArray, Text - - - class Author(Document): - name: str - - - class Book(Document): - author: Author - content: Text - - - da = DocumentArray[Book]( - Book(author=Author(name='Ben'), content=Text(text=f'book_{i}')) - for i in range(10) # noqa: E501 - ) - - books = da.traverse_flat(access_path='content') # list of 10 Text objs - - authors = da.traverse_flat(access_path='author.name') # list of 10 strings - - If the resulting list is a nested list, it will be flattened: - - EXAMPLE USAGE - .. code-block:: python - from docarray import Document, DocumentArray - - - class Chapter(Document): - content: str - - - class Book(Document): - chapters: DocumentArray[Chapter] - - - da = DocumentArray[Book]( - Book( - chapters=DocumentArray[Chapter]( - [Chapter(content='some_content') for _ in range(3)] - ) - ) - for _ in range(10) - ) - - chapters = da.traverse_flat(access_path='chapters') # list of 30 strings - - """ - leaves = list(TraverseMixin._traverse(docs=self, access_path=access_path)) - return TraverseMixin._flatten(leaves) - - @staticmethod - def _traverse(docs: Union['Document', 'DocumentArray'], access_path: str): - if access_path: - path_attrs = access_path.split('.') - curr_attr = path_attrs[0] - path_attrs.pop(0) - - from docarray import Document - - if isinstance(docs, Document): - x = getattr(docs, curr_attr) - yield from TraverseMixin._traverse(x, '.'.join(path_attrs)) - else: - for d in docs: - x = getattr(d, curr_attr) - yield from TraverseMixin._traverse(x, '.'.join(path_attrs)) - else: - yield docs - - @staticmethod - def _flatten(sequence) -> List[Any]: - from docarray import DocumentArray - - res: List[Any] = [] - for seq in sequence: - if isinstance(seq, (list, DocumentArray)): - res += seq - else: - res.append(seq) - - return res From e0d9f1ad22ab2046f7e5a987e506700ff3a60418 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 21 Dec 2022 12:20:53 +0100 Subject: [PATCH 15/26] fix: traverse flat for stack mode Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 22 +++++++++++----- .../array/{test_mixins => }/test_traverse.py | 26 +++++++++++++++---- 2 files changed, 37 insertions(+), 11 deletions(-) rename tests/units/array/{test_mixins => }/test_traverse.py (82%) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index c96f873427c..2594fc6667f 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -7,8 +7,7 @@ if TYPE_CHECKING: from docarray import Document from docarray.proto import DocumentArrayProto, NodeProto - from docarray.typing import NdArray, TorchTensor - + from docarray.typing import NdArray, Tensor, TorchTensor T = TypeVar('T', bound='AnyDocumentArray') T_doc = TypeVar('T_doc', bound=BaseDocument) @@ -97,7 +96,7 @@ def _to_node_protobuf(self) -> 'NodeProto': def traverse_flat( self: 'AnyDocumentArray', access_path: str, - ) -> List[Any]: + ) -> Union[List[Any], 'Tensor']: """ Return a List of the accessed objects when applying the access_path. If this results in a nested list or list of DocumentArrays, the list will be flattened @@ -158,8 +157,15 @@ class Book(Document): chapters = da.traverse_flat(access_path='chapters') # list of 30 strings """ - leaves = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) - return AnyDocumentArray._flatten(leaves) + nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) + flattened = AnyDocumentArray._flatten(nodes) + + from docarray.typing import Tensor + + if len(flattened) == 1 and isinstance(flattened[0], Tensor): + return flattened[0] + else: + return flattened @staticmethod def _traverse(node: Union['Document', 'AnyDocumentArray'], access_path: str): @@ -168,7 +174,11 @@ def _traverse(node: Union['Document', 'AnyDocumentArray'], access_path: str): curr_attr = path_attrs[0] path_attrs.pop(0) - if isinstance(node, (AnyDocumentArray, list)): + from docarray.array import DocumentArrayStacked + + if isinstance(node, (AnyDocumentArray, list)) and not isinstance( + node, DocumentArrayStacked + ): for n in node: x = getattr(n, curr_attr) yield from AnyDocumentArray._traverse(x, '.'.join(path_attrs)) diff --git a/tests/units/array/test_mixins/test_traverse.py b/tests/units/array/test_traverse.py similarity index 82% rename from tests/units/array/test_mixins/test_traverse.py rename to tests/units/array/test_traverse.py index 27a4205bc6c..4c74c427229 100644 --- a/tests/units/array/test_mixins/test_traverse.py +++ b/tests/units/array/test_traverse.py @@ -53,10 +53,6 @@ class MultiModalDoc(Document): return docs -@pytest.mark.parametrize( - 'filter_fn', - [(lambda d: True)], -) @pytest.mark.parametrize( 'access_path,len_result', [ @@ -74,6 +70,26 @@ class MultiModalDoc(Document): ), # List of 5 * 2 * 3 Text objs ], ) -def test_traverse_flat(multi_model_docs, access_path, len_result, filter_fn): +def test_traverse_flat(multi_model_docs, access_path, len_result): traversed = multi_model_docs.traverse_flat(access_path) assert len(traversed) == len_result + + +def test_traverse_stacked_da(): + class Image(Document): + tensor: TorchTensor[3, 224, 224] + + batch = DocumentArray[Image]( + [ + Image( + tensor=torch.zeros(3, 224, 224), + ) + for _ in range(2) + ] + ) + + batch_stacked = batch.stack() + tensors = batch_stacked.traverse_flat(access_path='tensor') + + assert tensors.shape == (2, 3, 224, 224) + assert isinstance(tensors, torch.Tensor) From 130dea2d699e88f2b4d8d65a5ebbcddec470993b Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 21 Dec 2022 12:26:53 +0100 Subject: [PATCH 16/26] docs: add docs for stacked da Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 2594fc6667f..972365cd787 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -156,6 +156,29 @@ class Book(Document): chapters = da.traverse_flat(access_path='chapters') # list of 30 strings + If your DocumentArray is in stacked mode and you want to access a field of + type Tensor, the stacked tensor will be returned instead of a list: + + EXAMPLE USAGE + .. code-block:: python + class Image(Document): + tensor: TorchTensor[3, 224, 224] + + + batch = DocumentArray[Image]( + [ + Image( + tensor=torch.zeros(3, 224, 224), + ) + for _ in range(2) + ] + ) + + batch_stacked = batch.stack() + tensors = batch_stacked.traverse_flat( + access_path='tensor' + ) # tensor of shape (2, 3, 224, 224) + """ nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) flattened = AnyDocumentArray._flatten(nodes) From 201a42d2a0692506dd74b7eafdd79a6aaa6522dd Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 21 Dec 2022 15:13:15 +0100 Subject: [PATCH 17/26] feat: traversal flat for stacked and unstacked in abstract array Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 34 +++++++++++++++----------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 972365cd787..ed719130b8c 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -5,7 +5,6 @@ from docarray.typing.abstract_type import AbstractType if TYPE_CHECKING: - from docarray import Document from docarray.proto import DocumentArrayProto, NodeProto from docarray.typing import NdArray, Tensor, TorchTensor @@ -156,7 +155,7 @@ class Book(Document): chapters = da.traverse_flat(access_path='chapters') # list of 30 strings - If your DocumentArray is in stacked mode and you want to access a field of + If your DocumentArray is in stacked mode, and you want to access a field of type Tensor, the stacked tensor will be returned instead of a list: EXAMPLE USAGE @@ -183,19 +182,22 @@ class Image(Document): nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) flattened = AnyDocumentArray._flatten(nodes) + from docarray.array import DocumentArrayStacked from docarray.typing import Tensor - if len(flattened) == 1 and isinstance(flattened[0], Tensor): + if ( + len(flattened) == 1 + and isinstance(flattened[0], Tensor) + and isinstance(self, DocumentArrayStacked) + ): return flattened[0] else: return flattened @staticmethod - def _traverse(node: Union['Document', 'AnyDocumentArray'], access_path: str): + def _traverse(node: Any, access_path: str): if access_path: - path_attrs = access_path.split('.') - curr_attr = path_attrs[0] - path_attrs.pop(0) + curr_attr, _, path_attrs = access_path.partition('.') from docarray.array import DocumentArrayStacked @@ -204,22 +206,18 @@ def _traverse(node: Union['Document', 'AnyDocumentArray'], access_path: str): ): for n in node: x = getattr(n, curr_attr) - yield from AnyDocumentArray._traverse(x, '.'.join(path_attrs)) + yield from AnyDocumentArray._traverse(x, path_attrs) else: x = getattr(node, curr_attr) - yield from AnyDocumentArray._traverse(x, '.'.join(path_attrs)) + yield from AnyDocumentArray._traverse(x, path_attrs) else: yield node @staticmethod - def _flatten(sequence: List[Any]) -> List[Any]: + def _flatten(sequence) -> List[Any]: from docarray import DocumentArray - res: List[Any] = [] - for seq in sequence: - if isinstance(seq, (list, DocumentArray)): - res += seq - else: - res.append(seq) - - return res + if isinstance(sequence[0], (list, DocumentArray)): + return [item for sublist in sequence for item in sublist] + else: + return sequence From af74fb74363d00ecb5f8a56cbf7d8091f40aa0b4 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 21 Dec 2022 16:08:14 +0100 Subject: [PATCH 18/26] feat: move traverse_flat to array and stacked array instead of abstract array Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 22 ++++------------------ docarray/array/array.py | 13 +++++++++++-- docarray/array/array_stacked.py | 17 ++++++++++++++++- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index ed719130b8c..53d66a20f8c 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -92,6 +92,7 @@ def _to_node_protobuf(self) -> 'NodeProto': return NodeProto(chunks=self.to_protobuf()) + @abstractmethod def traverse_flat( self: 'AnyDocumentArray', access_path: str, @@ -179,31 +180,16 @@ class Image(Document): ) # tensor of shape (2, 3, 224, 224) """ - nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) - flattened = AnyDocumentArray._flatten(nodes) - - from docarray.array import DocumentArrayStacked - from docarray.typing import Tensor - - if ( - len(flattened) == 1 - and isinstance(flattened[0], Tensor) - and isinstance(self, DocumentArrayStacked) - ): - return flattened[0] - else: - return flattened + ... @staticmethod def _traverse(node: Any, access_path: str): if access_path: curr_attr, _, path_attrs = access_path.partition('.') - from docarray.array import DocumentArrayStacked + from docarray.array import DocumentArray - if isinstance(node, (AnyDocumentArray, list)) and not isinstance( - node, DocumentArrayStacked - ): + if isinstance(node, (DocumentArray, list)): for n in node: x = getattr(n, curr_attr) yield from AnyDocumentArray._traverse(x, path_attrs) diff --git a/docarray/array/array.py b/docarray/array/array.py index a6132c06089..98999eed533 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from functools import wraps -from typing import TYPE_CHECKING, Callable, Iterable, List, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Type, TypeVar, Union from docarray.array.abstract_array import AnyDocumentArray from docarray.document import AnyDocument, BaseDocument @@ -11,7 +11,7 @@ from docarray.array.array_stacked import DocumentArrayStacked from docarray.proto import DocumentArrayProto - from docarray.typing import NdArray, TorchTensor + from docarray.typing import NdArray, Tensor, TorchTensor T = TypeVar('T', bound='DocumentArray') @@ -189,3 +189,12 @@ def validate( return cls(value) else: raise TypeError(f'Expecting an Iterable of {cls.document_type}') + + def traverse_flat( + self: 'DocumentArray', + access_path: str, + ) -> Union[List[Any], 'Tensor']: + nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) + flattened = AnyDocumentArray._flatten(nodes) + + return flattened diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 9aaba357e41..6cf22b03052 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from typing import ( TYPE_CHECKING, + Any, DefaultDict, Dict, Iterable, @@ -22,7 +23,7 @@ from pydantic.fields import ModelField from docarray.proto import DocumentArrayStackedProto - from docarray.typing import TorchTensor + from docarray.typing import Tensor, TorchTensor try: @@ -256,3 +257,17 @@ def validate( return cls(DocumentArray(value)) else: raise TypeError(f'Expecting an Iterable of {cls.document_type}') + + def traverse_flat( + self: 'AnyDocumentArray', + access_path: str, + ) -> Union[List[Any], 'Tensor']: + nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) + flattened = AnyDocumentArray._flatten(nodes) + + from docarray.typing import Tensor + + if len(flattened) == 1 and isinstance(flattened[0], Tensor): + return flattened[0] + else: + return flattened From cdb49f57b439a8035f55a1ba0d8746eb81abbb91 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Thu, 22 Dec 2022 18:00:12 +0100 Subject: [PATCH 19/26] fix: remove size for instance check Signed-off-by: anna-charlotte --- tests/units/array/test_traverse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/units/array/test_traverse.py b/tests/units/array/test_traverse.py index 4c74c427229..28db5eaf7b8 100644 --- a/tests/units/array/test_traverse.py +++ b/tests/units/array/test_traverse.py @@ -77,7 +77,7 @@ def test_traverse_flat(multi_model_docs, access_path, len_result): def test_traverse_stacked_da(): class Image(Document): - tensor: TorchTensor[3, 224, 224] + tensor: TorchTensor batch = DocumentArray[Image]( [ From c4038cb68b9232d57a71bce397ec0907c2b02a5c Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Thu, 22 Dec 2022 18:13:44 +0100 Subject: [PATCH 20/26] fix: change instance check Signed-off-by: anna-charlotte --- docarray/array/array_stacked.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 6cf22b03052..0153d908934 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -265,9 +265,7 @@ def traverse_flat( nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) flattened = AnyDocumentArray._flatten(nodes) - from docarray.typing import Tensor - - if len(flattened) == 1 and isinstance(flattened[0], Tensor): + if len(flattened) == 1 and isinstance(flattened[0], (NdArray, TorchTensor)): return flattened[0] else: return flattened From ee2a0e3b5b3be51aa1d7b9fdcf594396d7a2e751 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Fri, 23 Dec 2022 09:39:30 +0100 Subject: [PATCH 21/26] fix: mypy type hints Signed-off-by: anna-charlotte --- docarray/array/array.py | 4 ++-- docarray/array/array_stacked.py | 4 ++-- tests/units/array/test_traverse.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docarray/array/array.py b/docarray/array/array.py index 98999eed533..d65c9861023 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -11,7 +11,7 @@ from docarray.array.array_stacked import DocumentArrayStacked from docarray.proto import DocumentArrayProto - from docarray.typing import NdArray, Tensor, TorchTensor + from docarray.typing import NdArray, TorchTensor T = TypeVar('T', bound='DocumentArray') @@ -193,7 +193,7 @@ def validate( def traverse_flat( self: 'DocumentArray', access_path: str, - ) -> Union[List[Any], 'Tensor']: + ) -> Union[List[Any]]: nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) flattened = AnyDocumentArray._flatten(nodes) diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 0153d908934..76d135d9534 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -15,7 +15,7 @@ from docarray.array.abstract_array import AnyDocumentArray from docarray.array.array import DocumentArray from docarray.document import AnyDocument, BaseDocument -from docarray.typing import NdArray +from docarray.typing import NdArray, Tensor from docarray.typing.tensor.abstract_tensor import AbstractTensor if TYPE_CHECKING: @@ -265,7 +265,7 @@ def traverse_flat( nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) flattened = AnyDocumentArray._flatten(nodes) - if len(flattened) == 1 and isinstance(flattened[0], (NdArray, TorchTensor)): + if len(flattened) == 1 and isinstance(flattened[0], Tensor): return flattened[0] else: return flattened diff --git a/tests/units/array/test_traverse.py b/tests/units/array/test_traverse.py index 28db5eaf7b8..4c74c427229 100644 --- a/tests/units/array/test_traverse.py +++ b/tests/units/array/test_traverse.py @@ -77,7 +77,7 @@ def test_traverse_flat(multi_model_docs, access_path, len_result): def test_traverse_stacked_da(): class Image(Document): - tensor: TorchTensor + tensor: TorchTensor[3, 224, 224] batch = DocumentArray[Image]( [ From aab4c6a858c2ede97bbbe0b5f8c4c0de4ce4262c Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Fri, 23 Dec 2022 09:49:09 +0100 Subject: [PATCH 22/26] fix: torchtensor and ndarray instead of only tensor Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 4 ++-- docarray/array/array_stacked.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 53d66a20f8c..ddc50f5b782 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from docarray.proto import DocumentArrayProto, NodeProto - from docarray.typing import NdArray, Tensor, TorchTensor + from docarray.typing import NdArray, TorchTensor T = TypeVar('T', bound='AnyDocumentArray') T_doc = TypeVar('T_doc', bound=BaseDocument) @@ -96,7 +96,7 @@ def _to_node_protobuf(self) -> 'NodeProto': def traverse_flat( self: 'AnyDocumentArray', access_path: str, - ) -> Union[List[Any], 'Tensor']: + ) -> Union[List[Any], 'NdArray', 'TorchTensor']: """ Return a List of the accessed objects when applying the access_path. If this results in a nested list or list of DocumentArrays, the list will be flattened diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 76d135d9534..f85be7cb785 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -15,7 +15,7 @@ from docarray.array.abstract_array import AnyDocumentArray from docarray.array.array import DocumentArray from docarray.document import AnyDocument, BaseDocument -from docarray.typing import NdArray, Tensor +from docarray.typing import NdArray from docarray.typing.tensor.abstract_tensor import AbstractTensor if TYPE_CHECKING: @@ -23,7 +23,7 @@ from pydantic.fields import ModelField from docarray.proto import DocumentArrayStackedProto - from docarray.typing import Tensor, TorchTensor + from docarray.typing import TorchTensor try: @@ -261,11 +261,11 @@ def validate( def traverse_flat( self: 'AnyDocumentArray', access_path: str, - ) -> Union[List[Any], 'Tensor']: + ) -> Union[List[Any], 'TorchTensor', 'NdArray']: nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) flattened = AnyDocumentArray._flatten(nodes) - if len(flattened) == 1 and isinstance(flattened[0], Tensor): + if len(flattened) == 1 and isinstance(flattened[0], (NdArray, TorchTensor)): return flattened[0] else: return flattened From 710b927898ff8426a57d9eda567d7267eeaee39e Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Fri, 23 Dec 2022 11:23:50 +0100 Subject: [PATCH 23/26] fix: add type hint and len check in flatten Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 8 ++++---- docarray/array/array.py | 7 ++++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index ddc50f5b782..19166f639f9 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -200,10 +200,10 @@ def _traverse(node: Any, access_path: str): yield node @staticmethod - def _flatten(sequence) -> List[Any]: + def _flatten(sequence: List[Any]) -> List[Any]: from docarray import DocumentArray - if isinstance(sequence[0], (list, DocumentArray)): - return [item for sublist in sequence for item in sublist] - else: + if len(sequence) == 0 or not isinstance(sequence[0], (list, DocumentArray)): return sequence + else: + return [item for sublist in sequence for item in sublist] diff --git a/docarray/array/array.py b/docarray/array/array.py index d65c9861023..e7dbe2a1237 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -2,6 +2,8 @@ from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Type, TypeVar, Union +import numpy as np + from docarray.array.abstract_array import AnyDocumentArray from docarray.document import AnyDocument, BaseDocument @@ -77,7 +79,10 @@ def __len__(self): return len(self._data) def __getitem__(self, item): - return self._data[item] + if isinstance(item, (int, np.generic)): + return self._data[item] + elif isinstance(item, str): + pass def __iter__(self): return iter(self._data) From fe8bf3ccd51eecb342b0797bac51c4641d25b49c Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Fri, 23 Dec 2022 11:42:55 +0100 Subject: [PATCH 24/26] refactor: rename flatten to flatten one level Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 2 +- docarray/array/array.py | 2 +- docarray/array/array_stacked.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 19166f639f9..0ea5bbefea9 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -200,7 +200,7 @@ def _traverse(node: Any, access_path: str): yield node @staticmethod - def _flatten(sequence: List[Any]) -> List[Any]: + def _flatten_one_level(sequence: List[Any]) -> List[Any]: from docarray import DocumentArray if len(sequence) == 0 or not isinstance(sequence[0], (list, DocumentArray)): diff --git a/docarray/array/array.py b/docarray/array/array.py index e7dbe2a1237..4e955b62f87 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -200,6 +200,6 @@ def traverse_flat( access_path: str, ) -> Union[List[Any]]: nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) - flattened = AnyDocumentArray._flatten(nodes) + flattened = AnyDocumentArray._flatten_one_level(nodes) return flattened diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index f85be7cb785..2690c337bd4 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -263,7 +263,7 @@ def traverse_flat( access_path: str, ) -> Union[List[Any], 'TorchTensor', 'NdArray']: nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) - flattened = AnyDocumentArray._flatten(nodes) + flattened = AnyDocumentArray._flatten_one_level(nodes) if len(flattened) == 1 and isinstance(flattened[0], (NdArray, TorchTensor)): return flattened[0] From baa1c318f886f9c85ad3881ed127b34bee48d308 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Fri, 23 Dec 2022 11:44:37 +0100 Subject: [PATCH 25/26] test: add test for flatten one level Signed-off-by: anna-charlotte --- tests/units/array/test_traverse.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/units/array/test_traverse.py b/tests/units/array/test_traverse.py index 4c74c427229..3b9286f06ea 100644 --- a/tests/units/array/test_traverse.py +++ b/tests/units/array/test_traverse.py @@ -4,6 +4,7 @@ import torch from docarray import Document, DocumentArray, Text +from docarray.array.abstract_array import AnyDocumentArray from docarray.typing import TorchTensor num_docs = 5 @@ -93,3 +94,24 @@ class Image(Document): assert tensors.shape == (2, 3, 224, 224) assert isinstance(tensors, torch.Tensor) + + +@pytest.mark.parametrize( + 'input_list,output_list', + [ + ([1, 2, 3], [1, 2, 3]), + ([[1], [2], [3]], [1, 2, 3]), + ([[[1]], [[2]], [[3]]], [[1], [2], [3]]), + ], +) +def test_flatten_one_level(input_list, output_list): + flattened = AnyDocumentArray._flatten_one_level(sequence=input_list) + assert flattened == output_list + + +def test_flatten_one_level_list_of_da(): + doc = Document() + input_list = [DocumentArray([doc, doc, doc])] + + flattened = AnyDocumentArray._flatten_one_level(sequence=input_list) + assert flattened == [doc, doc, doc] From 459c4dd6c4d7ed5c0d38e66e5384a2a7841893f1 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Fri, 23 Dec 2022 11:46:33 +0100 Subject: [PATCH 26/26] fix: revert change in array get item Signed-off-by: anna-charlotte --- docarray/array/array.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/docarray/array/array.py b/docarray/array/array.py index 4e955b62f87..c9c4fc15f58 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -2,8 +2,6 @@ from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Type, TypeVar, Union -import numpy as np - from docarray.array.abstract_array import AnyDocumentArray from docarray.document import AnyDocument, BaseDocument @@ -79,10 +77,7 @@ def __len__(self): return len(self._data) def __getitem__(self, item): - if isinstance(item, (int, np.generic)): - return self._data[item] - elif isinstance(item, str): - pass + return self._data[item] def __iter__(self): return iter(self._data)