From dce8c6ba5df9d8d35b36ba86819e1cf6da6a528e Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Fri, 17 Feb 2023 10:58:46 +0100 Subject: [PATCH 01/19] feat: load from and to csv Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 93 ++++++++++++++++++- tests/toydata/docs_nested.csv | 4 + tests/units/array/test_array_from_to_csv.py | 98 +++++++++++++++++++++ 3 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 tests/toydata/docs_nested.csv create mode 100644 tests/units/array/test_array_from_to_csv.py diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index b5a846102d2..061e1702c7d 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -1,4 +1,5 @@ import base64 +import csv import io import json import os @@ -6,6 +7,7 @@ import pickle from abc import abstractmethod from contextlib import nullcontext +from itertools import compress from typing import ( TYPE_CHECKING, BinaryIO, @@ -16,11 +18,19 @@ Tuple, Type, TypeVar, - Union, + Union, Dict, Any, ) from docarray.base_document import BaseDocument + +import numpy as np +from typing_inspect import is_union_type + +from docarray.array.abstract_array import AnyDocumentArray +from docarray.base_document import AnyDocument, BaseDocument +from docarray.typing import NdArray from docarray.utils.compress import _decompress_bytes, _get_compress_ctx +from docarray.utils.misc import is_torch_available if TYPE_CHECKING: @@ -291,6 +301,62 @@ def to_json(self) -> str: """ return json.dumps([doc.json() for doc in self]) + @classmethod + def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> T: + """ + Load a DocumentArray from a csv file. + + :param file_path: path to csv file to load DocumentArray from. + :param encoding: encoding used to read the csv file. Defaults to 'utf-8'. + :return: DocumentArray + """ + from docarray import DocumentArray + + doc_type = cls.document_type + if doc_type == AnyDocument: + raise TypeError( + "There is no document schema defined. " + "To load from csv, please specify the DocumentArray's document type." + ) + + da = DocumentArray[doc_type]() + with open(file_path, 'r', encoding=encoding) as fp: + lines = csv.DictReader(fp, dialect='excel') + fields = lines.fieldnames + valid = [_assert_schema(doc_type, field) for field in fields] + if not all(valid): + raise ValueError( + f'Fields provided in the csv file do not match the schema of the DocumentArray\'s ' + f'document type ({doc_type.__name__}): {list(compress(fields, [not v for v in valid]))}' + ) + + for line in lines: + doc_dict = {} + for field, value in line.items(): + print(f"field, value = {field, value}") + if value in ['', 'None']: + value = None + doc_dict.update(access_path_to_dict(access_path=field, value=value)) + da.append(doc_type.parse_obj(doc_dict)) + + return da + + def to_csv(self, file_path: str) -> None: + """ + Save a DocumentArray to a csv file. + + :param file_path: path to a csv file. + """ + fields = self.document_type.__fields__ + with open(file_path, 'w') as csv_file: + writer = csv.DictWriter(csv_file, fieldnames=fields) + writer.writeheader() + + for doc in self: + doc_dict = doc.dict() + writer.writerow(doc_dict) + + # Methods to load from/to files in different formats @property def _stream_header(self) -> bytes: @@ -530,3 +596,28 @@ def save_binary( file_ctx=file_ctx, show_progress=show_progress, ) + + + +def _assert_schema(doc: Type['BaseDocument'], field_name: str) -> bool: + field, _, remaining = field_name.partition('.') + if len(remaining) == 0: + return field_name in doc.__fields__.keys() + else: + valid_field = field in doc.__fields__.keys() + if not valid_field: + return False + else: + d = doc._get_field_type(field) + if not issubclass(d, BaseDocument): + return False + else: + return _assert_schema(d, remaining) + + +def access_path_to_dict(access_path: str, value) -> Dict[str, Any]: + fields = access_path.split('.') + for field in reversed(fields): + result = {field: value} + value = result + return result diff --git a/tests/toydata/docs_nested.csv b/tests/toydata/docs_nested.csv new file mode 100644 index 00000000000..d530ff50041 --- /dev/null +++ b/tests/toydata/docs_nested.csv @@ -0,0 +1,4 @@ +count,text,image,image2.url +000,hello 0,image_0.png,image_10.png +111,hello 1,image_1.png,image_11.png +222,hello 2,image_2.png, \ No newline at end of file diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py new file mode 100644 index 00000000000..f102157181d --- /dev/null +++ b/tests/units/array/test_array_from_to_csv.py @@ -0,0 +1,98 @@ +import os +from typing import Optional + +import pytest + +from docarray import BaseDocument, DocumentArray +from docarray.array.array.io import _assert_schema +from docarray.documents import Image +from tests import TOYDATA_DIR + + +class MyDoc(BaseDocument): + count: int + text: str + + +class MyDocNested(MyDoc): + image: Image + image2: Image + + +def test_to_csv(tmpdir): + da = DocumentArray[MyDocNested]( + [ + MyDocNested( + text='hello', image=Image(url='aux.png'), image2=Image(url='aux.png') + ), + MyDocNested(text='hello world', image=Image(), image2=Image()), + ] + ) + tmp_file = str(tmpdir / 'tmp.csv') + da.to_csv(tmp_file) + assert os.path.isfile(tmp_file) + + +def test_from_csv_nested(): + da = DocumentArray[MyDocNested].from_csv( + file_path=str(TOYDATA_DIR / 'docs_nested.csv') + ) + assert len(da) == 3 + + for i, doc in enumerate(da): + assert doc.count.__class__ == int + assert doc.count == int(f'{i}{i}{i}') + + assert doc.text.__class__ == str + assert doc.text == f'hello {i}' + + assert doc.image.__class__ == Image + assert doc.image.tensor is None + assert doc.image.embedding is None + assert doc.image.bytes is None + + assert doc.image2.__class__ == Image + assert doc.image2.tensor is None + assert doc.image2.embedding is None + assert doc.image2.bytes is None + + assert da[0].image2.url == 'image_10.png' + assert da[1].image2.url == 'image_11.png' + assert da[2].image2.url is None + + +@pytest.fixture() +def nested_doc(): + class Inner(BaseDocument): + img: Optional[Image] + + class Middle(BaseDocument): + img: Optional[Image] + inner: Optional[Inner] + + class Outer(BaseDocument): + img: Optional[Image] + middle: Optional[Middle] + + doc = Outer(img=Image(), middle=Middle(img=Image(), inner=Inner(img=Image()))) + return doc + + +def test_from_csv_without_schema_raise_exception(): + with pytest.raises(TypeError, match='no document schema defined'): + DocumentArray.from_csv(file_path=str(TOYDATA_DIR / 'docs_nested.csv')) + + +def test_from_csv_with_wrong_schema_raise_exception(nested_doc): + with pytest.raises(ValueError, match=r'.*Outer.*embedding.*text.*image.*'): + DocumentArray[nested_doc.__class__].from_csv( + file_path=str(TOYDATA_DIR / 'docs_nested.csv') + ) + + +def test_assert_schema(nested_doc): + assert _assert_schema(nested_doc.__class__, 'img') + assert _assert_schema(nested_doc.__class__, 'middle.img') + assert _assert_schema(nested_doc.__class__, 'middle.inner.img') + assert _assert_schema(nested_doc.__class__, 'middle') + assert not _assert_schema(nested_doc.__class__, 'inner') From 37892a9b97b1636d9ecd6155c7bd9eaaa2dc99f0 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 20 Feb 2023 10:04:12 +0100 Subject: [PATCH 02/19] fix: from to csv Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 061e1702c7d..02061936f10 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -7,11 +7,14 @@ import pickle from abc import abstractmethod from contextlib import nullcontext +from contextlib import contextmanager, nullcontext +from functools import wraps from itertools import compress from typing import ( TYPE_CHECKING, BinaryIO, ContextManager, + Dict, Generator, Iterable, Optional, @@ -29,12 +32,14 @@ from docarray.array.abstract_array import AnyDocumentArray from docarray.base_document import AnyDocument, BaseDocument from docarray.typing import NdArray +from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils.compress import _decompress_bytes, _get_compress_ctx from docarray.utils.misc import is_torch_available if TYPE_CHECKING: from docarray.proto import DocumentArrayProto + from docarray.typing import TorchTensor T = TypeVar('T', bound='IOMixinArray') @@ -356,7 +361,6 @@ def to_csv(self, file_path: str) -> None: doc_dict = doc.dict() writer.writerow(doc_dict) - # Methods to load from/to files in different formats @property def _stream_header(self) -> bytes: @@ -598,7 +602,6 @@ def save_binary( ) - def _assert_schema(doc: Type['BaseDocument'], field_name: str) -> bool: field, _, remaining = field_name.partition('.') if len(remaining) == 0: From e1e90cfc67b39edac6f6c7cec52fec455fa0d701 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 20 Feb 2023 11:14:11 +0100 Subject: [PATCH 03/19] feat: add access path to dict Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 31 +++++++------ tests/units/array/test_array_from_to_csv.py | 46 ++++++++++++++++++-- tests/units/array/test_array_from_to_json.py | 7 +-- 3 files changed, 61 insertions(+), 23 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 02061936f10..97ea90045b5 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -7,11 +7,10 @@ import pickle from abc import abstractmethod from contextlib import nullcontext -from contextlib import contextmanager, nullcontext -from functools import wraps from itertools import compress from typing import ( TYPE_CHECKING, + Any, BinaryIO, ContextManager, Dict, @@ -21,25 +20,15 @@ Tuple, Type, TypeVar, - Union, Dict, Any, + Union, ) -from docarray.base_document import BaseDocument - -import numpy as np -from typing_inspect import is_union_type - -from docarray.array.abstract_array import AnyDocumentArray from docarray.base_document import AnyDocument, BaseDocument -from docarray.typing import NdArray -from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils.compress import _decompress_bytes, _get_compress_ctx -from docarray.utils.misc import is_torch_available if TYPE_CHECKING: from docarray.proto import DocumentArrayProto - from docarray.typing import TorchTensor T = TypeVar('T', bound='IOMixinArray') @@ -353,12 +342,13 @@ def to_csv(self, file_path: str) -> None: :param file_path: path to a csv file. """ fields = self.document_type.__fields__ + with open(file_path, 'w') as csv_file: writer = csv.DictWriter(csv_file, fieldnames=fields) writer.writeheader() for doc in self: - doc_dict = doc.dict() + doc_dict = dict_to_access_paths(doc.dict()) writer.writerow(doc_dict) # Methods to load from/to files in different formats @@ -624,3 +614,16 @@ def access_path_to_dict(access_path: str, value) -> Dict[str, Any]: result = {field: value} value = result return result + + +def dict_to_access_paths(d: dict) -> Dict[str, Any]: + result = {} + for k, v in d.items(): + if isinstance(v, dict): + v = dict_to_access_paths(v) + for nested_k, nested_v in v.items(): + new_key = '.'.join([k, nested_k]) + result[new_key] = nested_v + else: + result[k] = v + return result diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index f102157181d..a8adb93ecc4 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -4,7 +4,7 @@ import pytest from docarray import BaseDocument, DocumentArray -from docarray.array.array.io import _assert_schema +from docarray.array.array.io import _assert_schema, dict_to_access_paths from docarray.documents import Image from tests import TOYDATA_DIR @@ -23,12 +23,27 @@ def test_to_csv(tmpdir): da = DocumentArray[MyDocNested]( [ MyDocNested( - text='hello', image=Image(url='aux.png'), image2=Image(url='aux.png') + count=0, + text='hello', + image=Image(url='aux.png'), + image2=Image(url='aux.png'), ), - MyDocNested(text='hello world', image=Image(), image2=Image()), + MyDocNested(count=2, text='hello world', image=Image(), image2=Image()), ] ) - tmp_file = str(tmpdir / 'tmp.csv') + tmp_file = '/Users/charlottegerhaher/Desktop/jina-ai/docarray_v2/docarray/tests/toydata/tmp.csv' # str(tmpdir / 'tmp.csv') + da.to_csv(tmp_file) + assert os.path.isfile(tmp_file) + + +def test_to_csv_(tmpdir): + da = DocumentArray[MyDoc]( + [ + MyDoc(count=0, text='hello'), + MyDoc(count=2, text='hello world'), + ] + ) + tmp_file = '/Users/charlottegerhaher/Desktop/jina-ai/docarray_v2/docarray/tests/toydata/tmp.csv' # str(tmpdir / 'tmp.csv') da.to_csv(tmp_file) assert os.path.isfile(tmp_file) @@ -96,3 +111,26 @@ def test_assert_schema(nested_doc): assert _assert_schema(nested_doc.__class__, 'middle.inner.img') assert _assert_schema(nested_doc.__class__, 'middle') assert not _assert_schema(nested_doc.__class__, 'inner') + + +def test_dict_to_access_paths(): + d = { + 'a0': {'b0': {'c0': 0}, 'b1': {'c0': 1}}, + 'a1': {'b0': {'c0': 2, 'c1': 3}, 'b1': 4}, + } + # d = { + # 'b0': {'c0': 2, 'c1': 3}, 'b1': 4, + # } + casted = dict_to_access_paths(d) + # assert casted == { + # 'b0.c0': 2, + # 'b0.c1': 3, + # 'b1': 4, + # } + assert casted == { + 'a0.b0.c0': 0, + 'a0.b1.c0': 1, + 'a1.b0.c0': 2, + 'a1.b0.c1': 3, + 'a1.b1': 4, + } diff --git a/tests/units/array/test_array_from_to_json.py b/tests/units/array/test_array_from_to_json.py index 79b98c43700..b46066156e5 100644 --- a/tests/units/array/test_array_from_to_json.py +++ b/tests/units/array/test_array_from_to_json.py @@ -1,9 +1,6 @@ -import pytest - -from docarray import BaseDocument -from docarray.typing import NdArray +from docarray import BaseDocument, DocumentArray from docarray.documents import Image -from docarray import DocumentArray +from docarray.typing import NdArray class MyDoc(BaseDocument): From 7bd7f1fa8bf223a233655a048840641bef56cd0d Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 20 Feb 2023 13:26:56 +0100 Subject: [PATCH 04/19] fix: from to csv Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 12 ++++--- docarray/base_document/mixins/io.py | 23 +++++++++++++ tests/units/array/test_array_from_to_csv.py | 37 +++++++++++++-------- 3 files changed, 54 insertions(+), 18 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 97ea90045b5..ffe3f1a9fab 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: + from docarray import DocumentArray from docarray.proto import DocumentArrayProto T = TypeVar('T', bound='IOMixinArray') @@ -296,7 +297,7 @@ def to_json(self) -> str: return json.dumps([doc.json() for doc in self]) @classmethod - def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> T: + def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': """ Load a DocumentArray from a csv file. @@ -306,7 +307,7 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> T: """ from docarray import DocumentArray - doc_type = cls.document_type + doc_type: Type[BaseDocument] = cls.document_type if doc_type == AnyDocument: raise TypeError( "There is no document schema defined. " @@ -317,6 +318,9 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> T: with open(file_path, 'r', encoding=encoding) as fp: lines = csv.DictReader(fp, dialect='excel') fields = lines.fieldnames + if fields is None: + raise TypeError("No field names are given.") + valid = [_assert_schema(doc_type, field) for field in fields] if not all(valid): raise ValueError( @@ -327,7 +331,6 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> T: for line in lines: doc_dict = {} for field, value in line.items(): - print(f"field, value = {field, value}") if value in ['', 'None']: value = None doc_dict.update(access_path_to_dict(access_path=field, value=value)) @@ -338,10 +341,9 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> T: def to_csv(self, file_path: str) -> None: """ Save a DocumentArray to a csv file. - :param file_path: path to a csv file. """ - fields = self.document_type.__fields__ + fields = self.document_type._get_access_paths() with open(file_path, 'w') as csv_file: writer = csv.DictWriter(csv_file, fieldnames=fields) diff --git a/docarray/base_document/mixins/io.py b/docarray/base_document/mixins/io.py index 25c21c3ffac..e3658eedcd4 100644 --- a/docarray/base_document/mixins/io.py +++ b/docarray/base_document/mixins/io.py @@ -7,12 +7,15 @@ Callable, Dict, Iterable, + List, Optional, Tuple, Type, TypeVar, ) +from typing_inspect import is_union_type + from docarray.base_document.base_node import BaseNode from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS from docarray.utils.compress import _compress_bytes, _decompress_bytes @@ -291,3 +294,23 @@ def _to_node_protobuf(self) -> 'NodeProto': :return: the nested item protobuf message """ return NodeProto(document=self.to_protobuf()) + + @classmethod + def _get_access_paths(cls) -> List[str]: + """ + Get dot-separated access paths of all fields, including nested ones. + + :return: list of all access paths + """ + from docarray import BaseDocument + + paths = [] + for field in cls.__fields__.keys(): + field_type = cls._get_field_type(field) + if not is_union_type(field_type) and issubclass(field_type, BaseDocument): + sub_paths = field_type._get_access_paths() + for path in sub_paths: + paths.append(f'{field}.{path}') + else: + paths.append(field) + return paths diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index a8adb93ecc4..0769282ae4b 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -10,7 +10,7 @@ class MyDoc(BaseDocument): - count: int + count: Optional[int] text: str @@ -28,7 +28,7 @@ def test_to_csv(tmpdir): image=Image(url='aux.png'), image2=Image(url='aux.png'), ), - MyDocNested(count=2, text='hello world', image=Image(), image2=Image()), + MyDocNested(text='hello world', image=Image(), image2=Image()), ] ) tmp_file = '/Users/charlottegerhaher/Desktop/jina-ai/docarray_v2/docarray/tests/toydata/tmp.csv' # str(tmpdir / 'tmp.csv') @@ -40,7 +40,7 @@ def test_to_csv_(tmpdir): da = DocumentArray[MyDoc]( [ MyDoc(count=0, text='hello'), - MyDoc(count=2, text='hello world'), + MyDoc(text='hello world'), ] ) tmp_file = '/Users/charlottegerhaher/Desktop/jina-ai/docarray_v2/docarray/tests/toydata/tmp.csv' # str(tmpdir / 'tmp.csv') @@ -93,15 +93,34 @@ class Outer(BaseDocument): return doc +def test_fields_to_access_paths(): + class Painting(BaseDocument): + title: str + img: Image + + access_paths = Painting._get_access_paths() + assert access_paths == [ + 'id', + 'title', + 'img.id', + 'img.url', + 'img.tensor', + 'img.embedding', + 'img.bytes', + ] + + def test_from_csv_without_schema_raise_exception(): with pytest.raises(TypeError, match='no document schema defined'): DocumentArray.from_csv(file_path=str(TOYDATA_DIR / 'docs_nested.csv')) def test_from_csv_with_wrong_schema_raise_exception(nested_doc): - with pytest.raises(ValueError, match=r'.*Outer.*embedding.*text.*image.*'): + with pytest.raises( + ValueError, match='Fields provided in the csv file do not match the schema' + ): DocumentArray[nested_doc.__class__].from_csv( - file_path=str(TOYDATA_DIR / 'docs_nested.csv') + file_path=str(TOYDATA_DIR / 'docs.csv') ) @@ -118,15 +137,7 @@ def test_dict_to_access_paths(): 'a0': {'b0': {'c0': 0}, 'b1': {'c0': 1}}, 'a1': {'b0': {'c0': 2, 'c1': 3}, 'b1': 4}, } - # d = { - # 'b0': {'c0': 2, 'c1': 3}, 'b1': 4, - # } casted = dict_to_access_paths(d) - # assert casted == { - # 'b0.c0': 2, - # 'b0.c1': 3, - # 'b1': 4, - # } assert casted == { 'a0.b0.c0': 0, 'a0.b1.c0': 1, From 05e19ba1f669c0607abea026ab8f32e3cc509ef6 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 20 Feb 2023 13:40:23 +0100 Subject: [PATCH 05/19] fix: clean up Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index ffe3f1a9fab..998d86e88c6 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -299,7 +299,8 @@ def to_json(self) -> str: @classmethod def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': """ - Load a DocumentArray from a csv file. + Load a DocumentArray from a csv file following the schema defines in the + :attr:`~docarray.DocumentArray.document_type` attribute. :param file_path: path to csv file to load DocumentArray from. :param encoding: encoding used to read the csv file. Defaults to 'utf-8'. @@ -310,8 +311,8 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': doc_type: Type[BaseDocument] = cls.document_type if doc_type == AnyDocument: raise TypeError( - "There is no document schema defined. " - "To load from csv, please specify the DocumentArray's document type." + 'There is no document schema defined. ' + 'To load from csv, please specify the DocumentArray\'s document type.' ) da = DocumentArray[doc_type]() @@ -331,9 +332,12 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': for line in lines: doc_dict = {} for field, value in line.items(): - if value in ['', 'None']: - value = None - doc_dict.update(access_path_to_dict(access_path=field, value=value)) + doc_dict.update( + access_path_to_dict( + access_path=field, + value=value if value not in ['', 'None'] else None, + ) + ) da.append(doc_type.parse_obj(doc_dict)) return da From 5b0c7607206fa97fd7d79e13d2c3551250ceeeb4 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 20 Feb 2023 13:48:55 +0100 Subject: [PATCH 06/19] docs: add docstring and update tmpdir in test Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 2 ++ tests/units/array/test_array_from_to_csv.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 998d86e88c6..13f1299b03d 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -301,6 +301,8 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': """ Load a DocumentArray from a csv file following the schema defines in the :attr:`~docarray.DocumentArray.document_type` attribute. + The column names have to match the fields of the document type. For nested fields + dot-separated access paths are expected, such as `'image.url'`. :param file_path: path to csv file to load DocumentArray from. :param encoding: encoding used to read the csv file. Defaults to 'utf-8'. diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index 0769282ae4b..67874d00bb7 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -31,7 +31,7 @@ def test_to_csv(tmpdir): MyDocNested(text='hello world', image=Image(), image2=Image()), ] ) - tmp_file = '/Users/charlottegerhaher/Desktop/jina-ai/docarray_v2/docarray/tests/toydata/tmp.csv' # str(tmpdir / 'tmp.csv') + tmp_file = str(tmpdir / 'tmp.csv') da.to_csv(tmp_file) assert os.path.isfile(tmp_file) @@ -43,7 +43,7 @@ def test_to_csv_(tmpdir): MyDoc(text='hello world'), ] ) - tmp_file = '/Users/charlottegerhaher/Desktop/jina-ai/docarray_v2/docarray/tests/toydata/tmp.csv' # str(tmpdir / 'tmp.csv') + tmp_file = str(tmpdir / 'tmp.csv') da.to_csv(tmp_file) assert os.path.isfile(tmp_file) From 5babdae0e4e71cc0afb641cd98291f8cc1bba8d2 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 20 Feb 2023 15:06:09 +0100 Subject: [PATCH 07/19] fix: merge nested dicts Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 29 +++++++++++++++++---- tests/toydata/docs_nested.csv | 2 +- tests/units/array/test_array_from_to_csv.py | 23 +++++++++++++--- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 13f1299b03d..9350db742eb 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -334,12 +334,12 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': for line in lines: doc_dict = {} for field, value in line.items(): - doc_dict.update( - access_path_to_dict( - access_path=field, - value=value if value not in ['', 'None'] else None, - ) + x = access_path_to_dict( + access_path=field, + value=value if value not in ['', 'None'] else None, ) + doc_dict = merge_nested_dicts(d1=doc_dict, d2=x) + da.append(doc_type.parse_obj(doc_dict)) return da @@ -635,3 +635,22 @@ def dict_to_access_paths(d: dict) -> Dict[str, Any]: else: result[k] = v return result + + +def merge_nested_dicts(d1: Dict[Any, Any], d2: Dict[Any, Any]) -> None: + """ + Merge two dictionaries, while considering shared nested keys. + + :param d1: first dict + :param d2: second dict + :return: merged dict + """ + import copy + + result = copy.deepcopy(d1) + for k, v in d2.items(): + if k not in result.keys(): + result[k] = v + else: + result[k] = merge_nested_dicts(d1[k], d2[k]) + return result diff --git a/tests/toydata/docs_nested.csv b/tests/toydata/docs_nested.csv index d530ff50041..f414972aeb0 100644 --- a/tests/toydata/docs_nested.csv +++ b/tests/toydata/docs_nested.csv @@ -1,4 +1,4 @@ count,text,image,image2.url 000,hello 0,image_0.png,image_10.png -111,hello 1,image_1.png,image_11.png +111,hello 1,image_1.png,None 222,hello 2,image_2.png, \ No newline at end of file diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index 67874d00bb7..1c29a298856 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -4,7 +4,11 @@ import pytest from docarray import BaseDocument, DocumentArray -from docarray.array.array.io import _assert_schema, dict_to_access_paths +from docarray.array.array.io import ( + _assert_schema, + dict_to_access_paths, + merge_nested_dicts, +) from docarray.documents import Image from tests import TOYDATA_DIR @@ -19,7 +23,7 @@ class MyDocNested(MyDoc): image2: Image -def test_to_csv(tmpdir): +def test_to_from_csv(tmpdir): da = DocumentArray[MyDocNested]( [ MyDocNested( @@ -35,6 +39,10 @@ def test_to_csv(tmpdir): da.to_csv(tmp_file) assert os.path.isfile(tmp_file) + da_from = DocumentArray[MyDocNested].from_csv(tmp_file) + for doc1, doc2 in zip(da, da_from): + assert doc1 == doc2 + def test_to_csv_(tmpdir): da = DocumentArray[MyDoc]( @@ -55,6 +63,7 @@ def test_from_csv_nested(): assert len(da) == 3 for i, doc in enumerate(da): + print(f"doc.count = {doc.count}") assert doc.count.__class__ == int assert doc.count == int(f'{i}{i}{i}') @@ -72,7 +81,7 @@ def test_from_csv_nested(): assert doc.image2.bytes is None assert da[0].image2.url == 'image_10.png' - assert da[1].image2.url == 'image_11.png' + assert da[1].image2.url is None assert da[2].image2.url is None @@ -145,3 +154,11 @@ def test_dict_to_access_paths(): 'a1.b0.c1': 3, 'a1.b1': 4, } + + +def test_update_nested_dict(): + d1 = {'text': 'hello', 'image': {'tensor': None}} + d2 = {'image': {'url': 'some.png'}} + + merged = merge_nested_dicts(d1, d2) + assert merged == {'text': 'hello', 'image': {'tensor': None, 'url': 'some.png'}} From d6b29c5e54b39f817012b410440c1668094ffd1c Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 20 Feb 2023 15:28:24 +0100 Subject: [PATCH 08/19] fix: clean up Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 46 ++++++++++++--------- tests/units/array/test_array_from_to_csv.py | 18 ++++---- 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 9350db742eb..7a8ac0a56d5 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -299,10 +299,10 @@ def to_json(self) -> str: @classmethod def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': """ - Load a DocumentArray from a csv file following the schema defines in the + Load a DocumentArray from a csv file following the schema defined in the :attr:`~docarray.DocumentArray.document_type` attribute. - The column names have to match the fields of the document type. For nested fields - dot-separated access paths are expected, such as `'image.url'`. + The column names have to match the fields of the document type. + For nested fields use dot-separated access paths, such as 'image.url'. :param file_path: path to csv file to load DocumentArray from. :param encoding: encoding used to read the csv file. Defaults to 'utf-8'. @@ -324,7 +324,7 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': if fields is None: raise TypeError("No field names are given.") - valid = [_assert_schema(doc_type, field) for field in fields] + valid = [is_access_path_valid(doc_type, field) for field in fields] if not all(valid): raise ValueError( f'Fields provided in the csv file do not match the schema of the DocumentArray\'s ' @@ -334,11 +334,11 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': for line in lines: doc_dict = {} for field, value in line.items(): - x = access_path_to_dict( + field2val = access_path_to_dict( access_path=field, value=value if value not in ['', 'None'] else None, ) - doc_dict = merge_nested_dicts(d1=doc_dict, d2=x) + update_nested_dicts(to_update=doc_dict, update_with=field2val) da.append(doc_type.parse_obj(doc_dict)) @@ -600,7 +600,7 @@ def save_binary( ) -def _assert_schema(doc: Type['BaseDocument'], field_name: str) -> bool: +def is_access_path_valid(doc: Type['BaseDocument'], field_name: str) -> bool: field, _, remaining = field_name.partition('.') if len(remaining) == 0: return field_name in doc.__fields__.keys() @@ -613,7 +613,7 @@ def _assert_schema(doc: Type['BaseDocument'], field_name: str) -> bool: if not issubclass(d, BaseDocument): return False else: - return _assert_schema(d, remaining) + return is_access_path_valid(d, remaining) def access_path_to_dict(access_path: str, value) -> Dict[str, Any]: @@ -637,20 +637,26 @@ def dict_to_access_paths(d: dict) -> Dict[str, Any]: return result -def merge_nested_dicts(d1: Dict[Any, Any], d2: Dict[Any, Any]) -> None: +def update_nested_dicts(to_update: Dict[Any, Any], update_with: Dict[Any, Any]) -> None: """ - Merge two dictionaries, while considering shared nested keys. + Update a dict with another one, while considering shared nested keys. - :param d1: first dict - :param d2: second dict + EXAMPLE USAGE: + + .. code-block:: python + + d1 = {'image': {'tensor': None}, 'title': 'hello'} + d2 = {'image': {'url': 'some.png'}} + + update_nested_dicts(d1, d2) + assert d1 == {'image': {'tensor': None, 'url': 'some.png'}, 'title': 'hello'} + + :param to_update: dict that should be updated + :param update_with: dict to update with :return: merged dict """ - import copy - - result = copy.deepcopy(d1) - for k, v in d2.items(): - if k not in result.keys(): - result[k] = v + for k, v in update_with.items(): + if k not in to_update.keys(): + to_update[k] = v else: - result[k] = merge_nested_dicts(d1[k], d2[k]) - return result + update_nested_dicts(to_update[k], update_with[k]) diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index 1c29a298856..73f1ceb9be8 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -5,9 +5,9 @@ from docarray import BaseDocument, DocumentArray from docarray.array.array.io import ( - _assert_schema, dict_to_access_paths, - merge_nested_dicts, + is_access_path_valid, + update_nested_dicts, ) from docarray.documents import Image from tests import TOYDATA_DIR @@ -134,11 +134,11 @@ def test_from_csv_with_wrong_schema_raise_exception(nested_doc): def test_assert_schema(nested_doc): - assert _assert_schema(nested_doc.__class__, 'img') - assert _assert_schema(nested_doc.__class__, 'middle.img') - assert _assert_schema(nested_doc.__class__, 'middle.inner.img') - assert _assert_schema(nested_doc.__class__, 'middle') - assert not _assert_schema(nested_doc.__class__, 'inner') + assert is_access_path_valid(nested_doc.__class__, 'img') + assert is_access_path_valid(nested_doc.__class__, 'middle.img') + assert is_access_path_valid(nested_doc.__class__, 'middle.inner.img') + assert is_access_path_valid(nested_doc.__class__, 'middle') + assert not is_access_path_valid(nested_doc.__class__, 'inner') def test_dict_to_access_paths(): @@ -160,5 +160,5 @@ def test_update_nested_dict(): d1 = {'text': 'hello', 'image': {'tensor': None}} d2 = {'image': {'url': 'some.png'}} - merged = merge_nested_dicts(d1, d2) - assert merged == {'text': 'hello', 'image': {'tensor': None, 'url': 'some.png'}} + update_nested_dicts(d1, d2) + assert d1 == {'text': 'hello', 'image': {'tensor': None, 'url': 'some.png'}} From 9149e57cad74273212db654f3d79bd451b70442c Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 20 Feb 2023 15:44:35 +0100 Subject: [PATCH 09/19] fix: clean up Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 45 +++++++++++++++------ tests/units/array/test_array_from_to_csv.py | 8 ++-- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 7a8ac0a56d5..f047a7d6f7c 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -317,7 +317,7 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': 'To load from csv, please specify the DocumentArray\'s document type.' ) - da = DocumentArray[doc_type]() + da = DocumentArray[doc_type]() # type: ignore with open(file_path, 'r', encoding=encoding) as fp: lines = csv.DictReader(fp, dialect='excel') fields = lines.fieldnames @@ -332,13 +332,13 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': ) for line in lines: - doc_dict = {} + doc_dict: Dict[Any, Any] = {} for field, value in line.items(): - field2val = access_path_to_dict( + field2val = _access_path_to_dict( access_path=field, value=value if value not in ['', 'None'] else None, ) - update_nested_dicts(to_update=doc_dict, update_with=field2val) + _update_nested_dicts(to_update=doc_dict, update_with=field2val) da.append(doc_type.parse_obj(doc_dict)) @@ -356,7 +356,7 @@ def to_csv(self, file_path: str) -> None: writer.writeheader() for doc in self: - doc_dict = dict_to_access_paths(doc.dict()) + doc_dict = _dict_to_access_paths(doc.dict()) writer.writerow(doc_dict) # Methods to load from/to files in different formats @@ -600,10 +600,13 @@ def save_binary( ) -def is_access_path_valid(doc: Type['BaseDocument'], field_name: str) -> bool: - field, _, remaining = field_name.partition('.') +def is_access_path_valid(doc: Type['BaseDocument'], access_path: str) -> bool: + """ + Check if a given access path is a valid path for a given Document class. + """ + field, _, remaining = access_path.partition('.') if len(remaining) == 0: - return field_name in doc.__fields__.keys() + return access_path in doc.__fields__.keys() else: valid_field = field in doc.__fields__.keys() if not valid_field: @@ -616,7 +619,14 @@ def is_access_path_valid(doc: Type['BaseDocument'], field_name: str) -> bool: return is_access_path_valid(d, remaining) -def access_path_to_dict(access_path: str, value) -> Dict[str, Any]: +def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]: + """ + Convert an access path and its value to a (potentially) nested dict. + + EXAMPLE USAGE + .. code-block:: python + assert access_path_to_dict('image.url', 'some.url') == {'image': {'url': 'some.url'}} + """ fields = access_path.split('.') for field in reversed(fields): result = {field: value} @@ -624,11 +634,18 @@ def access_path_to_dict(access_path: str, value) -> Dict[str, Any]: return result -def dict_to_access_paths(d: dict) -> Dict[str, Any]: +def _dict_to_access_paths(d: dict) -> Dict[str, Any]: + """ + Convert a (nested) dict to a Dict[access_path, value]. + + EXAMPLE USAGE + .. code-block:: python + assert dict_to_access_paths({'image': {'url': 'some.url'}}) == {'image.url', 'some.url'} + """ result = {} for k, v in d.items(): if isinstance(v, dict): - v = dict_to_access_paths(v) + v = _dict_to_access_paths(v) for nested_k, nested_v in v.items(): new_key = '.'.join([k, nested_k]) result[new_key] = nested_v @@ -637,7 +654,9 @@ def dict_to_access_paths(d: dict) -> Dict[str, Any]: return result -def update_nested_dicts(to_update: Dict[Any, Any], update_with: Dict[Any, Any]) -> None: +def _update_nested_dicts( + to_update: Dict[Any, Any], update_with: Dict[Any, Any] +) -> None: """ Update a dict with another one, while considering shared nested keys. @@ -659,4 +678,4 @@ def update_nested_dicts(to_update: Dict[Any, Any], update_with: Dict[Any, Any]) if k not in to_update.keys(): to_update[k] = v else: - update_nested_dicts(to_update[k], update_with[k]) + _update_nested_dicts(to_update[k], update_with[k]) diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index 73f1ceb9be8..f3958b6de2b 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -5,9 +5,9 @@ from docarray import BaseDocument, DocumentArray from docarray.array.array.io import ( - dict_to_access_paths, + _dict_to_access_paths, + _update_nested_dicts, is_access_path_valid, - update_nested_dicts, ) from docarray.documents import Image from tests import TOYDATA_DIR @@ -146,7 +146,7 @@ def test_dict_to_access_paths(): 'a0': {'b0': {'c0': 0}, 'b1': {'c0': 1}}, 'a1': {'b0': {'c0': 2, 'c1': 3}, 'b1': 4}, } - casted = dict_to_access_paths(d) + casted = _dict_to_access_paths(d) assert casted == { 'a0.b0.c0': 0, 'a0.b1.c0': 1, @@ -160,5 +160,5 @@ def test_update_nested_dict(): d1 = {'text': 'hello', 'image': {'tensor': None}} d2 = {'image': {'url': 'some.png'}} - update_nested_dicts(d1, d2) + _update_nested_dicts(d1, d2) assert d1 == {'text': 'hello', 'image': {'tensor': None, 'url': 'some.png'}} From 3d980b73dae784d3e35d1626b95c32fa2bfd7ab1 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 20 Feb 2023 15:49:04 +0100 Subject: [PATCH 10/19] test: update test Signed-off-by: anna-charlotte --- tests/units/array/test_array_from_to_csv.py | 45 ++++++++------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index f3958b6de2b..6f58992decd 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -44,18 +44,6 @@ def test_to_from_csv(tmpdir): assert doc1 == doc2 -def test_to_csv_(tmpdir): - da = DocumentArray[MyDoc]( - [ - MyDoc(count=0, text='hello'), - MyDoc(text='hello world'), - ] - ) - tmp_file = str(tmpdir / 'tmp.csv') - da.to_csv(tmp_file) - assert os.path.isfile(tmp_file) - - def test_from_csv_nested(): da = DocumentArray[MyDocNested].from_csv( file_path=str(TOYDATA_DIR / 'docs_nested.csv') @@ -102,7 +90,21 @@ class Outer(BaseDocument): return doc -def test_fields_to_access_paths(): +def test_from_csv_without_schema_raise_exception(): + with pytest.raises(TypeError, match='no document schema defined'): + DocumentArray.from_csv(file_path=str(TOYDATA_DIR / 'docs_nested.csv')) + + +def test_from_csv_with_wrong_schema_raise_exception(nested_doc): + with pytest.raises( + ValueError, match='Fields provided in the csv file do not match the schema' + ): + DocumentArray[nested_doc.__class__].from_csv( + file_path=str(TOYDATA_DIR / 'docs.csv') + ) + + +def test_get_access_paths(): class Painting(BaseDocument): title: str img: Image @@ -119,26 +121,13 @@ class Painting(BaseDocument): ] -def test_from_csv_without_schema_raise_exception(): - with pytest.raises(TypeError, match='no document schema defined'): - DocumentArray.from_csv(file_path=str(TOYDATA_DIR / 'docs_nested.csv')) - - -def test_from_csv_with_wrong_schema_raise_exception(nested_doc): - with pytest.raises( - ValueError, match='Fields provided in the csv file do not match the schema' - ): - DocumentArray[nested_doc.__class__].from_csv( - file_path=str(TOYDATA_DIR / 'docs.csv') - ) - - -def test_assert_schema(nested_doc): +def test_is_access_path_valid(nested_doc): assert is_access_path_valid(nested_doc.__class__, 'img') assert is_access_path_valid(nested_doc.__class__, 'middle.img') assert is_access_path_valid(nested_doc.__class__, 'middle.inner.img') assert is_access_path_valid(nested_doc.__class__, 'middle') assert not is_access_path_valid(nested_doc.__class__, 'inner') + assert not is_access_path_valid(nested_doc.__class__, 'some.other.path') def test_dict_to_access_paths(): From 516ffbbbe133d8faf3cec8b4e5601fa133dff9c6 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 20 Feb 2023 16:49:08 +0100 Subject: [PATCH 11/19] fix: apply samis suggestion from code review Signed-off-by: anna-charlotte --- tests/units/array/test_array_from_to_csv.py | 29 ++++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index 6f58992decd..d22f4a45d64 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -13,39 +13,42 @@ from tests import TOYDATA_DIR -class MyDoc(BaseDocument): - count: Optional[int] - text: str +@pytest.fixture() +def nested_doc_cls(): + class MyDoc(BaseDocument): + count: Optional[int] + text: str + class MyDocNested(MyDoc): + image: Image + image2: Image -class MyDocNested(MyDoc): - image: Image - image2: Image + return MyDocNested -def test_to_from_csv(tmpdir): - da = DocumentArray[MyDocNested]( +def test_to_from_csv(tmpdir, nested_doc_cls): + da = DocumentArray[nested_doc_cls]( [ - MyDocNested( + nested_doc_cls( count=0, text='hello', image=Image(url='aux.png'), image2=Image(url='aux.png'), ), - MyDocNested(text='hello world', image=Image(), image2=Image()), + nested_doc_cls(text='hello world', image=Image(), image2=Image()), ] ) tmp_file = str(tmpdir / 'tmp.csv') da.to_csv(tmp_file) assert os.path.isfile(tmp_file) - da_from = DocumentArray[MyDocNested].from_csv(tmp_file) + da_from = DocumentArray[nested_doc_cls].from_csv(tmp_file) for doc1, doc2 in zip(da, da_from): assert doc1 == doc2 -def test_from_csv_nested(): - da = DocumentArray[MyDocNested].from_csv( +def test_from_csv_nested(nested_doc_cls): + da = DocumentArray[nested_doc_cls].from_csv( file_path=str(TOYDATA_DIR / 'docs_nested.csv') ) assert len(da) == 3 From a0d9711e8d99dec1a022826e12452b59e3f143d7 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 21 Feb 2023 09:12:15 +0100 Subject: [PATCH 12/19] fix: apply suggestions from code review wrt access paths Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 13 +++++----- docarray/base_document/mixins/io.py | 4 +-- tests/toydata/docs_nested.csv | 2 +- tests/units/array/test_array_from_to_csv.py | 28 ++++++++++----------- 4 files changed, 24 insertions(+), 23 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index f047a7d6f7c..197c127b859 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -302,7 +302,7 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': Load a DocumentArray from a csv file following the schema defined in the :attr:`~docarray.DocumentArray.document_type` attribute. The column names have to match the fields of the document type. - For nested fields use dot-separated access paths, such as 'image.url'. + For nested fields use dot-separated access paths, such as 'image__url'. :param file_path: path to csv file to load DocumentArray from. :param encoding: encoding used to read the csv file. Defaults to 'utf-8'. @@ -604,7 +604,7 @@ def is_access_path_valid(doc: Type['BaseDocument'], access_path: str) -> bool: """ Check if a given access path is a valid path for a given Document class. """ - field, _, remaining = access_path.partition('.') + field, _, remaining = access_path.partition('__') if len(remaining) == 0: return access_path in doc.__fields__.keys() else: @@ -625,9 +625,9 @@ def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]: EXAMPLE USAGE .. code-block:: python - assert access_path_to_dict('image.url', 'some.url') == {'image': {'url': 'some.url'}} + assert access_path_to_dict('image__url', 'img.png') == {'image': {'url': 'img.png'}} """ - fields = access_path.split('.') + fields = access_path.split('__') for field in reversed(fields): result = {field: value} value = result @@ -637,17 +637,18 @@ def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]: def _dict_to_access_paths(d: dict) -> Dict[str, Any]: """ Convert a (nested) dict to a Dict[access_path, value]. + Access paths are defines as a path of field(s) separated by "__". EXAMPLE USAGE .. code-block:: python - assert dict_to_access_paths({'image': {'url': 'some.url'}}) == {'image.url', 'some.url'} + assert dict_to_access_paths({'image': {'url': 'img.png'}}) == {'image__url', 'img.png'} """ result = {} for k, v in d.items(): if isinstance(v, dict): v = _dict_to_access_paths(v) for nested_k, nested_v in v.items(): - new_key = '.'.join([k, nested_k]) + new_key = '__'.join([k, nested_k]) result[new_key] = nested_v else: result[k] = v diff --git a/docarray/base_document/mixins/io.py b/docarray/base_document/mixins/io.py index e3658eedcd4..fdad8648f06 100644 --- a/docarray/base_document/mixins/io.py +++ b/docarray/base_document/mixins/io.py @@ -298,7 +298,7 @@ def _to_node_protobuf(self) -> 'NodeProto': @classmethod def _get_access_paths(cls) -> List[str]: """ - Get dot-separated access paths of all fields, including nested ones. + Get "__"-separated access paths of all fields, including nested ones. :return: list of all access paths """ @@ -310,7 +310,7 @@ def _get_access_paths(cls) -> List[str]: if not is_union_type(field_type) and issubclass(field_type, BaseDocument): sub_paths = field_type._get_access_paths() for path in sub_paths: - paths.append(f'{field}.{path}') + paths.append(f'{field}__{path}') else: paths.append(field) return paths diff --git a/tests/toydata/docs_nested.csv b/tests/toydata/docs_nested.csv index f414972aeb0..7b857870244 100644 --- a/tests/toydata/docs_nested.csv +++ b/tests/toydata/docs_nested.csv @@ -1,4 +1,4 @@ -count,text,image,image2.url +count,text,image,image2__url 000,hello 0,image_0.png,image_10.png 111,hello 1,image_1.png,None 222,hello 2,image_2.png, \ No newline at end of file diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index d22f4a45d64..cfbc6e38be2 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -54,7 +54,6 @@ def test_from_csv_nested(nested_doc_cls): assert len(da) == 3 for i, doc in enumerate(da): - print(f"doc.count = {doc.count}") assert doc.count.__class__ == int assert doc.count == int(f'{i}{i}{i}') @@ -116,21 +115,22 @@ class Painting(BaseDocument): assert access_paths == [ 'id', 'title', - 'img.id', - 'img.url', - 'img.tensor', - 'img.embedding', - 'img.bytes', + 'img__id', + 'img__url', + 'img__tensor', + 'img__embedding', + 'img__bytes', ] def test_is_access_path_valid(nested_doc): assert is_access_path_valid(nested_doc.__class__, 'img') - assert is_access_path_valid(nested_doc.__class__, 'middle.img') - assert is_access_path_valid(nested_doc.__class__, 'middle.inner.img') + assert is_access_path_valid(nested_doc.__class__, 'middle__img') + assert is_access_path_valid(nested_doc.__class__, 'middle__inner__img') assert is_access_path_valid(nested_doc.__class__, 'middle') assert not is_access_path_valid(nested_doc.__class__, 'inner') - assert not is_access_path_valid(nested_doc.__class__, 'some.other.path') + assert not is_access_path_valid(nested_doc.__class__, 'some__other__path') + assert not is_access_path_valid(nested_doc.__class__, 'middle.inner') def test_dict_to_access_paths(): @@ -140,11 +140,11 @@ def test_dict_to_access_paths(): } casted = _dict_to_access_paths(d) assert casted == { - 'a0.b0.c0': 0, - 'a0.b1.c0': 1, - 'a1.b0.c0': 2, - 'a1.b0.c1': 3, - 'a1.b1': 4, + 'a0__b0__c0': 0, + 'a0__b1__c0': 1, + 'a1__b0__c0': 2, + 'a1__b0__c1': 3, + 'a1__b1': 4, } From c9005b14b960c65145faf54520cc039ab20d0a09 Mon Sep 17 00:00:00 2001 From: Charlotte Gerhaher Date: Tue, 21 Feb 2023 10:24:20 +0100 Subject: [PATCH 13/19] fix: apply johannes suggestion Co-authored-by: Johannes Messner <44071807+JohannesMessner@users.noreply.github.com> Signed-off-by: Charlotte Gerhaher --- docarray/array/array/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 197c127b859..2e11b4b8e68 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -301,7 +301,7 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': """ Load a DocumentArray from a csv file following the schema defined in the :attr:`~docarray.DocumentArray.document_type` attribute. - The column names have to match the fields of the document type. + The column names have to match the field names of the Document type. For nested fields use dot-separated access paths, such as 'image__url'. :param file_path: path to csv file to load DocumentArray from. From 9fe58f5f44671646bdc62263ce2362b1e966b8df Mon Sep 17 00:00:00 2001 From: Charlotte Gerhaher Date: Tue, 21 Feb 2023 10:26:17 +0100 Subject: [PATCH 14/19] fix: apply johannes suggestion Co-authored-by: Johannes Messner <44071807+JohannesMessner@users.noreply.github.com> Signed-off-by: Charlotte Gerhaher --- docarray/array/array/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 2e11b4b8e68..4c115cc36d4 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -314,7 +314,7 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': if doc_type == AnyDocument: raise TypeError( 'There is no document schema defined. ' - 'To load from csv, please specify the DocumentArray\'s document type.' + 'To load from csv, please specify the DocumentArray\'s Document type using `DocumentArray[MyDoc]`.' ) da = DocumentArray[doc_type]() # type: ignore From 90867bbb9d7ce746b41daf9b35cda5c8568156d7 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 21 Feb 2023 11:13:30 +0100 Subject: [PATCH 15/19] fix: apply suggestions from code review Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 4c115cc36d4..e2d8712d314 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -16,6 +16,7 @@ Dict, Generator, Iterable, + List, Optional, Tuple, Type, @@ -33,7 +34,6 @@ T = TypeVar('T', bound='IOMixinArray') - ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array'} SINGLE_PROTOCOLS = {'pickle', 'protobuf'} ALLOWED_PROTOCOLS = ARRAY_PROTOCOLS.union(SINGLE_PROTOCOLS) @@ -297,7 +297,12 @@ def to_json(self) -> str: return json.dumps([doc.json() for doc in self]) @classmethod - def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': + def from_csv( + cls, + file_path: str, + encoding: str = 'utf-8', + dialect: Union[str, csv.Dialect] = 'excel', + ) -> 'DocumentArray': """ Load a DocumentArray from a csv file following the schema defined in the :attr:`~docarray.DocumentArray.document_type` attribute. @@ -306,21 +311,27 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': :param file_path: path to csv file to load DocumentArray from. :param encoding: encoding used to read the csv file. Defaults to 'utf-8'. + :param dialect: defines separator and how to handle whitespaces etc. + Can be a csv.Dialect instance or one string of: + 'excel' (for comma seperated values), + 'excel-tab' (for tab separated values), + 'unix' (for csv file generated on UNIX systems). :return: DocumentArray """ from docarray import DocumentArray - doc_type: Type[BaseDocument] = cls.document_type + doc_type = cls.document_type if doc_type == AnyDocument: raise TypeError( 'There is no document schema defined. ' 'To load from csv, please specify the DocumentArray\'s Document type using `DocumentArray[MyDoc]`.' ) - da = DocumentArray[doc_type]() # type: ignore + da = DocumentArray.__class_getitem__(doc_type)() with open(file_path, 'r', encoding=encoding) as fp: - lines = csv.DictReader(fp, dialect='excel') - fields = lines.fieldnames + rows: csv.DictReader = csv.DictReader(fp, dialect=dialect) + fields: Optional[List[str]] = rows.fieldnames + if fields is None: raise TypeError("No field names are given.") @@ -331,9 +342,9 @@ def from_csv(cls, file_path: str, encoding: str = 'utf-8') -> 'DocumentArray': f'document type ({doc_type.__name__}): {list(compress(fields, [not v for v in valid]))}' ) - for line in lines: + for row in rows: doc_dict: Dict[Any, Any] = {} - for field, value in line.items(): + for field, value in row.items(): field2val = _access_path_to_dict( access_path=field, value=value if value not in ['', 'None'] else None, From 1a395da5b36f3798f70b0b9da8465d1038f2e380 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 21 Feb 2023 11:43:19 +0100 Subject: [PATCH 16/19] fix: apply suggestions from code review Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 42 ++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index e2d8712d314..ad4f3f49fda 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -16,8 +16,8 @@ Dict, Generator, Iterable, - List, Optional, + Sequence, Tuple, Type, TypeVar, @@ -306,8 +306,12 @@ def from_csv( """ Load a DocumentArray from a csv file following the schema defined in the :attr:`~docarray.DocumentArray.document_type` attribute. - The column names have to match the field names of the Document type. - For nested fields use dot-separated access paths, such as 'image__url'. + Every access_path2val of the csv file will be mapped to one document in the array. + The column names (defined in the first access_path2val) have to match the field names + of the Document type. + For nested field_names use "__"-separated access paths, such as 'image__url'. + + List-like field_names (including DocumentArray) are not supported :param file_path: path to csv file to load DocumentArray from. :param encoding: encoding used to read the csv file. Defaults to 'utf-8'. @@ -330,23 +334,23 @@ def from_csv( da = DocumentArray.__class_getitem__(doc_type)() with open(file_path, 'r', encoding=encoding) as fp: rows: csv.DictReader = csv.DictReader(fp, dialect=dialect) - fields: Optional[List[str]] = rows.fieldnames + field_names: Optional[Sequence[Any]] = rows.fieldnames - if fields is None: + if field_names is None: raise TypeError("No field names are given.") - valid = [is_access_path_valid(doc_type, field) for field in fields] + valid = [is_access_path_valid(doc_type, field) for field in field_names] if not all(valid): raise ValueError( f'Fields provided in the csv file do not match the schema of the DocumentArray\'s ' - f'document type ({doc_type.__name__}): {list(compress(fields, [not v for v in valid]))}' + f'document type ({doc_type.__name__}): {list(compress(field_names, [not v for v in valid]))}' ) - for row in rows: + for access_path2val in rows: doc_dict: Dict[Any, Any] = {} - for field, value in row.items(): + for access_path, value in access_path2val.items(): field2val = _access_path_to_dict( - access_path=field, + access_path=access_path, value=value if value not in ['', 'None'] else None, ) _update_nested_dicts(to_update=doc_dict, update_with=field2val) @@ -355,15 +359,23 @@ def from_csv( return da - def to_csv(self, file_path: str) -> None: + def to_csv( + self, file_path: str, dialect: Union[str, csv.Dialect] = 'excel' + ) -> None: """ Save a DocumentArray to a csv file. + The field names will be stored in the first row. The Documents information will be stored in one row each. :param file_path: path to a csv file. + :param dialect: defines separator and how to handle whitespaces etc. + Can be a csv.Dialect instance or one string of: + 'excel' (for comma seperated values), + 'excel-tab' (for tab separated values), + 'unix' (for csv file generated on UNIX systems). """ fields = self.document_type._get_access_paths() with open(file_path, 'w') as csv_file: - writer = csv.DictWriter(csv_file, fieldnames=fields) + writer = csv.DictWriter(csv_file, fieldnames=fields, dialect=dialect) writer.writeheader() for doc in self: @@ -613,7 +625,7 @@ def save_binary( def is_access_path_valid(doc: Type['BaseDocument'], access_path: str) -> bool: """ - Check if a given access path is a valid path for a given Document class. + Check if a given access path ("__"-separated) is a valid path for a given Document class. """ field, _, remaining = access_path.partition('__') if len(remaining) == 0: @@ -632,7 +644,7 @@ def is_access_path_valid(doc: Type['BaseDocument'], access_path: str) -> bool: def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]: """ - Convert an access path and its value to a (potentially) nested dict. + Convert an access path ("__"-separated) and its value to a (potentially) nested dict. EXAMPLE USAGE .. code-block:: python @@ -648,7 +660,7 @@ def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]: def _dict_to_access_paths(d: dict) -> Dict[str, Any]: """ Convert a (nested) dict to a Dict[access_path, value]. - Access paths are defines as a path of field(s) separated by "__". + Access paths are defined as a path of field(s) separated by "__". EXAMPLE USAGE .. code-block:: python From 00a9ea7274048b0bd9842a3032c3d18d25b1f259 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 21 Feb 2023 11:59:53 +0100 Subject: [PATCH 17/19] fix: typos Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index ad4f3f49fda..9eff890477d 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -306,12 +306,12 @@ def from_csv( """ Load a DocumentArray from a csv file following the schema defined in the :attr:`~docarray.DocumentArray.document_type` attribute. - Every access_path2val of the csv file will be mapped to one document in the array. - The column names (defined in the first access_path2val) have to match the field names + Every row of the csv file will be mapped to one document in the array. + The column names (defined in the first row) have to match the field names of the Document type. - For nested field_names use "__"-separated access paths, such as 'image__url'. + For nested fields use "__"-separated access paths, such as 'image__url'. - List-like field_names (including DocumentArray) are not supported + List-like fields (including field of type DocumentArray) are not supported. :param file_path: path to csv file to load DocumentArray from. :param encoding: encoding used to read the csv file. Defaults to 'utf-8'. @@ -333,7 +333,7 @@ def from_csv( da = DocumentArray.__class_getitem__(doc_type)() with open(file_path, 'r', encoding=encoding) as fp: - rows: csv.DictReader = csv.DictReader(fp, dialect=dialect) + rows = csv.DictReader(fp, dialect=dialect) field_names: Optional[Sequence[Any]] = rows.fieldnames if field_names is None: @@ -364,7 +364,11 @@ def to_csv( ) -> None: """ Save a DocumentArray to a csv file. - The field names will be stored in the first row. The Documents information will be stored in one row each. + The field names will be stored in the first row. Each row corresponds to the + information of one Document. + Columns for nested fields will be named after the "__"-seperated access paths, + such as `"image__url"` for `image.url`. + :param file_path: path to a csv file. :param dialect: defines separator and how to handle whitespaces etc. Can be a csv.Dialect instance or one string of: From e06e5330d8e005190568e5d11ea04820a55a393b Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 21 Feb 2023 16:22:07 +0100 Subject: [PATCH 18/19] refactor: move helper functions to helper file Signed-off-by: anna-charlotte --- docarray/array/array/io.py | 88 ++------------------- docarray/helper.py | 88 +++++++++++++++++++++ tests/units/array/test_array_from_to_csv.py | 55 ------------- tests/units/test_helper.py | 85 ++++++++++++++++++++ 4 files changed, 179 insertions(+), 137 deletions(-) create mode 100644 docarray/helper.py create mode 100644 tests/units/test_helper.py diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 9eff890477d..6a88f55f273 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -25,6 +25,12 @@ ) from docarray.base_document import AnyDocument, BaseDocument +from docarray.helper import ( + _access_path_to_dict, + _dict_to_access_paths, + _update_nested_dicts, + is_access_path_valid, +) from docarray.utils.compress import _decompress_bytes, _get_compress_ctx if TYPE_CHECKING: @@ -625,85 +631,3 @@ def save_binary( file_ctx=file_ctx, show_progress=show_progress, ) - - -def is_access_path_valid(doc: Type['BaseDocument'], access_path: str) -> bool: - """ - Check if a given access path ("__"-separated) is a valid path for a given Document class. - """ - field, _, remaining = access_path.partition('__') - if len(remaining) == 0: - return access_path in doc.__fields__.keys() - else: - valid_field = field in doc.__fields__.keys() - if not valid_field: - return False - else: - d = doc._get_field_type(field) - if not issubclass(d, BaseDocument): - return False - else: - return is_access_path_valid(d, remaining) - - -def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]: - """ - Convert an access path ("__"-separated) and its value to a (potentially) nested dict. - - EXAMPLE USAGE - .. code-block:: python - assert access_path_to_dict('image__url', 'img.png') == {'image': {'url': 'img.png'}} - """ - fields = access_path.split('__') - for field in reversed(fields): - result = {field: value} - value = result - return result - - -def _dict_to_access_paths(d: dict) -> Dict[str, Any]: - """ - Convert a (nested) dict to a Dict[access_path, value]. - Access paths are defined as a path of field(s) separated by "__". - - EXAMPLE USAGE - .. code-block:: python - assert dict_to_access_paths({'image': {'url': 'img.png'}}) == {'image__url', 'img.png'} - """ - result = {} - for k, v in d.items(): - if isinstance(v, dict): - v = _dict_to_access_paths(v) - for nested_k, nested_v in v.items(): - new_key = '__'.join([k, nested_k]) - result[new_key] = nested_v - else: - result[k] = v - return result - - -def _update_nested_dicts( - to_update: Dict[Any, Any], update_with: Dict[Any, Any] -) -> None: - """ - Update a dict with another one, while considering shared nested keys. - - EXAMPLE USAGE: - - .. code-block:: python - - d1 = {'image': {'tensor': None}, 'title': 'hello'} - d2 = {'image': {'url': 'some.png'}} - - update_nested_dicts(d1, d2) - assert d1 == {'image': {'tensor': None, 'url': 'some.png'}, 'title': 'hello'} - - :param to_update: dict that should be updated - :param update_with: dict to update with - :return: merged dict - """ - for k, v in update_with.items(): - if k not in to_update.keys(): - to_update[k] = v - else: - _update_nested_dicts(to_update[k], update_with[k]) diff --git a/docarray/helper.py b/docarray/helper.py new file mode 100644 index 00000000000..e34aa2678ae --- /dev/null +++ b/docarray/helper.py @@ -0,0 +1,88 @@ +from typing import TYPE_CHECKING, Any, Dict, Type + +if TYPE_CHECKING: + from docarray import BaseDocument + + +def is_access_path_valid(doc: Type['BaseDocument'], access_path: str) -> bool: + """ + Check if a given access path ("__"-separated) is a valid path for a given Document class. + """ + from docarray import BaseDocument + + field, _, remaining = access_path.partition('__') + if len(remaining) == 0: + return access_path in doc.__fields__.keys() + else: + valid_field = field in doc.__fields__.keys() + if not valid_field: + return False + else: + d = doc._get_field_type(field) + if not issubclass(d, BaseDocument): + return False + else: + return is_access_path_valid(d, remaining) + + +def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]: + """ + Convert an access path ("__"-separated) and its value to a (potentially) nested dict. + + EXAMPLE USAGE + .. code-block:: python + assert access_path_to_dict('image__url', 'img.png') == {'image': {'url': 'img.png'}} + """ + fields = access_path.split('__') + for field in reversed(fields): + result = {field: value} + value = result + return result + + +def _dict_to_access_paths(d: dict) -> Dict[str, Any]: + """ + Convert a (nested) dict to a Dict[access_path, value]. + Access paths are defined as a path of field(s) separated by "__". + + EXAMPLE USAGE + .. code-block:: python + assert dict_to_access_paths({'image': {'url': 'img.png'}}) == {'image__url', 'img.png'} + """ + result = {} + for k, v in d.items(): + if isinstance(v, dict): + v = _dict_to_access_paths(v) + for nested_k, nested_v in v.items(): + new_key = '__'.join([k, nested_k]) + result[new_key] = nested_v + else: + result[k] = v + return result + + +def _update_nested_dicts( + to_update: Dict[Any, Any], update_with: Dict[Any, Any] +) -> None: + """ + Update a dict with another one, while considering shared nested keys. + + EXAMPLE USAGE: + + .. code-block:: python + + d1 = {'image': {'tensor': None}, 'title': 'hello'} + d2 = {'image': {'url': 'some.png'}} + + update_nested_dicts(d1, d2) + assert d1 == {'image': {'tensor': None, 'url': 'some.png'}, 'title': 'hello'} + + :param to_update: dict that should be updated + :param update_with: dict to update with + :return: merged dict + """ + for k, v in update_with.items(): + if k not in to_update.keys(): + to_update[k] = v + else: + _update_nested_dicts(to_update[k], update_with[k]) diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index cfbc6e38be2..d1a6f326116 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -4,11 +4,6 @@ import pytest from docarray import BaseDocument, DocumentArray -from docarray.array.array.io import ( - _dict_to_access_paths, - _update_nested_dicts, - is_access_path_valid, -) from docarray.documents import Image from tests import TOYDATA_DIR @@ -104,53 +99,3 @@ def test_from_csv_with_wrong_schema_raise_exception(nested_doc): DocumentArray[nested_doc.__class__].from_csv( file_path=str(TOYDATA_DIR / 'docs.csv') ) - - -def test_get_access_paths(): - class Painting(BaseDocument): - title: str - img: Image - - access_paths = Painting._get_access_paths() - assert access_paths == [ - 'id', - 'title', - 'img__id', - 'img__url', - 'img__tensor', - 'img__embedding', - 'img__bytes', - ] - - -def test_is_access_path_valid(nested_doc): - assert is_access_path_valid(nested_doc.__class__, 'img') - assert is_access_path_valid(nested_doc.__class__, 'middle__img') - assert is_access_path_valid(nested_doc.__class__, 'middle__inner__img') - assert is_access_path_valid(nested_doc.__class__, 'middle') - assert not is_access_path_valid(nested_doc.__class__, 'inner') - assert not is_access_path_valid(nested_doc.__class__, 'some__other__path') - assert not is_access_path_valid(nested_doc.__class__, 'middle.inner') - - -def test_dict_to_access_paths(): - d = { - 'a0': {'b0': {'c0': 0}, 'b1': {'c0': 1}}, - 'a1': {'b0': {'c0': 2, 'c1': 3}, 'b1': 4}, - } - casted = _dict_to_access_paths(d) - assert casted == { - 'a0__b0__c0': 0, - 'a0__b1__c0': 1, - 'a1__b0__c0': 2, - 'a1__b0__c1': 3, - 'a1__b1': 4, - } - - -def test_update_nested_dict(): - d1 = {'text': 'hello', 'image': {'tensor': None}} - d2 = {'image': {'url': 'some.png'}} - - _update_nested_dicts(d1, d2) - assert d1 == {'text': 'hello', 'image': {'tensor': None, 'url': 'some.png'}} diff --git a/tests/units/test_helper.py b/tests/units/test_helper.py new file mode 100644 index 00000000000..06c1a3d5628 --- /dev/null +++ b/tests/units/test_helper.py @@ -0,0 +1,85 @@ +from typing import Optional + +import pytest + +from docarray import BaseDocument +from docarray.documents import Image +from docarray.helper import ( + _access_path_to_dict, + _dict_to_access_paths, + _update_nested_dicts, + is_access_path_valid, +) + + +@pytest.fixture() +def nested_doc_cls(): + class MyDoc(BaseDocument): + count: Optional[int] + text: str + + class MyDocNested(MyDoc): + image: Image + image2: Image + + return MyDocNested + + +def test_is_access_path_valid(nested_doc): + assert is_access_path_valid(nested_doc.__class__, 'img') + assert is_access_path_valid(nested_doc.__class__, 'middle__img') + assert is_access_path_valid(nested_doc.__class__, 'middle__inner__img') + assert is_access_path_valid(nested_doc.__class__, 'middle') + + +def test_is_access_path_not_valid(nested_doc): + assert not is_access_path_valid(nested_doc.__class__, 'inner') + assert not is_access_path_valid(nested_doc.__class__, 'some__other__path') + assert not is_access_path_valid(nested_doc.__class__, 'middle.inner') + + +def test_get_access_paths(): + class Painting(BaseDocument): + title: str + img: Image + + access_paths = Painting._get_access_paths() + assert access_paths == [ + 'id', + 'title', + 'img__id', + 'img__url', + 'img__tensor', + 'img__embedding', + 'img__bytes', + ] + + +def test_dict_to_access_paths(): + d = { + 'a0': {'b0': {'c0': 0}, 'b1': {'c0': 1}}, + 'a1': {'b0': {'c0': 2, 'c1': 3}, 'b1': 4}, + } + casted = _dict_to_access_paths(d) + assert casted == { + 'a0__b0__c0': 0, + 'a0__b1__c0': 1, + 'a1__b0__c0': 2, + 'a1__b0__c1': 3, + 'a1__b1': 4, + } + + +def test_access_path_to_dict(): + access_path = 'a__b__c__d__e' + value = 1 + result = {'a': {'b': {'c': {'d': {'e': value}}}}} + assert _access_path_to_dict(access_path, value) == result + + +def test_update_nested_dict(): + d1 = {'text': 'hello', 'image': {'tensor': None}} + d2 = {'image': {'url': 'some.png'}} + + _update_nested_dicts(d1, d2) + assert d1 == {'text': 'hello', 'image': {'tensor': None, 'url': 'some.png'}} From c8e4cf8cfa266dcb6ca2b2c87b10f3a9601bccb0 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 21 Feb 2023 16:33:45 +0100 Subject: [PATCH 19/19] test: fix fixture Signed-off-by: anna-charlotte --- tests/units/test_helper.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tests/units/test_helper.py b/tests/units/test_helper.py index 06c1a3d5628..e6730866b7b 100644 --- a/tests/units/test_helper.py +++ b/tests/units/test_helper.py @@ -13,16 +13,20 @@ @pytest.fixture() -def nested_doc_cls(): - class MyDoc(BaseDocument): - count: Optional[int] - text: str +def nested_doc(): + class Inner(BaseDocument): + img: Optional[Image] - class MyDocNested(MyDoc): - image: Image - image2: Image + class Middle(BaseDocument): + img: Optional[Image] + inner: Optional[Inner] - return MyDocNested + class Outer(BaseDocument): + img: Optional[Image] + middle: Optional[Middle] + + doc = Outer(img=Image(), middle=Middle(img=Image(), inner=Inner(img=Image()))) + return doc def test_is_access_path_valid(nested_doc):