diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index b5a846102d2..6a88f55f273 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -1,4 +1,5 @@ import base64 +import csv import io import json import os @@ -6,29 +7,39 @@ import pickle from abc import abstractmethod from contextlib import nullcontext +from itertools import compress from typing import ( TYPE_CHECKING, + Any, BinaryIO, ContextManager, + Dict, Generator, Iterable, Optional, + Sequence, Tuple, Type, TypeVar, Union, ) -from docarray.base_document import BaseDocument +from docarray.base_document import AnyDocument, BaseDocument +from docarray.helper import ( + _access_path_to_dict, + _dict_to_access_paths, + _update_nested_dicts, + is_access_path_valid, +) from docarray.utils.compress import _decompress_bytes, _get_compress_ctx if TYPE_CHECKING: + from docarray import DocumentArray from docarray.proto import DocumentArrayProto T = TypeVar('T', bound='IOMixinArray') - ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array'} SINGLE_PROTOCOLS = {'pickle', 'protobuf'} ALLOWED_PROTOCOLS = ARRAY_PROTOCOLS.union(SINGLE_PROTOCOLS) @@ -291,6 +302,96 @@ def to_json(self) -> str: """ return json.dumps([doc.json() for doc in self]) + @classmethod + def from_csv( + cls, + file_path: str, + encoding: str = 'utf-8', + dialect: Union[str, csv.Dialect] = 'excel', + ) -> 'DocumentArray': + """ + Load a DocumentArray from a csv file following the schema defined in the + :attr:`~docarray.DocumentArray.document_type` attribute. + Every row of the csv file will be mapped to one document in the array. + The column names (defined in the first row) have to match the field names + of the Document type. + For nested fields use "__"-separated access paths, such as 'image__url'. + + List-like fields (including field of type DocumentArray) are not supported. + + :param file_path: path to csv file to load DocumentArray from. + :param encoding: encoding used to read the csv file. Defaults to 'utf-8'. + :param dialect: defines separator and how to handle whitespaces etc. + Can be a csv.Dialect instance or one string of: + 'excel' (for comma seperated values), + 'excel-tab' (for tab separated values), + 'unix' (for csv file generated on UNIX systems). + :return: DocumentArray + """ + from docarray import DocumentArray + + doc_type = cls.document_type + if doc_type == AnyDocument: + raise TypeError( + 'There is no document schema defined. ' + 'To load from csv, please specify the DocumentArray\'s Document type using `DocumentArray[MyDoc]`.' + ) + + da = DocumentArray.__class_getitem__(doc_type)() + with open(file_path, 'r', encoding=encoding) as fp: + rows = csv.DictReader(fp, dialect=dialect) + field_names: Optional[Sequence[Any]] = rows.fieldnames + + if field_names is None: + raise TypeError("No field names are given.") + + valid = [is_access_path_valid(doc_type, field) for field in field_names] + if not all(valid): + raise ValueError( + f'Fields provided in the csv file do not match the schema of the DocumentArray\'s ' + f'document type ({doc_type.__name__}): {list(compress(field_names, [not v for v in valid]))}' + ) + + for access_path2val in rows: + doc_dict: Dict[Any, Any] = {} + for access_path, value in access_path2val.items(): + field2val = _access_path_to_dict( + access_path=access_path, + value=value if value not in ['', 'None'] else None, + ) + _update_nested_dicts(to_update=doc_dict, update_with=field2val) + + da.append(doc_type.parse_obj(doc_dict)) + + return da + + def to_csv( + self, file_path: str, dialect: Union[str, csv.Dialect] = 'excel' + ) -> None: + """ + Save a DocumentArray to a csv file. + The field names will be stored in the first row. Each row corresponds to the + information of one Document. + Columns for nested fields will be named after the "__"-seperated access paths, + such as `"image__url"` for `image.url`. + + :param file_path: path to a csv file. + :param dialect: defines separator and how to handle whitespaces etc. + Can be a csv.Dialect instance or one string of: + 'excel' (for comma seperated values), + 'excel-tab' (for tab separated values), + 'unix' (for csv file generated on UNIX systems). + """ + fields = self.document_type._get_access_paths() + + with open(file_path, 'w') as csv_file: + writer = csv.DictWriter(csv_file, fieldnames=fields, dialect=dialect) + writer.writeheader() + + for doc in self: + doc_dict = _dict_to_access_paths(doc.dict()) + writer.writerow(doc_dict) + # Methods to load from/to files in different formats @property def _stream_header(self) -> bytes: diff --git a/docarray/base_document/mixins/io.py b/docarray/base_document/mixins/io.py index 25c21c3ffac..fdad8648f06 100644 --- a/docarray/base_document/mixins/io.py +++ b/docarray/base_document/mixins/io.py @@ -7,12 +7,15 @@ Callable, Dict, Iterable, + List, Optional, Tuple, Type, TypeVar, ) +from typing_inspect import is_union_type + from docarray.base_document.base_node import BaseNode from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS from docarray.utils.compress import _compress_bytes, _decompress_bytes @@ -291,3 +294,23 @@ def _to_node_protobuf(self) -> 'NodeProto': :return: the nested item protobuf message """ return NodeProto(document=self.to_protobuf()) + + @classmethod + def _get_access_paths(cls) -> List[str]: + """ + Get "__"-separated access paths of all fields, including nested ones. + + :return: list of all access paths + """ + from docarray import BaseDocument + + paths = [] + for field in cls.__fields__.keys(): + field_type = cls._get_field_type(field) + if not is_union_type(field_type) and issubclass(field_type, BaseDocument): + sub_paths = field_type._get_access_paths() + for path in sub_paths: + paths.append(f'{field}__{path}') + else: + paths.append(field) + return paths diff --git a/docarray/helper.py b/docarray/helper.py new file mode 100644 index 00000000000..e34aa2678ae --- /dev/null +++ b/docarray/helper.py @@ -0,0 +1,88 @@ +from typing import TYPE_CHECKING, Any, Dict, Type + +if TYPE_CHECKING: + from docarray import BaseDocument + + +def is_access_path_valid(doc: Type['BaseDocument'], access_path: str) -> bool: + """ + Check if a given access path ("__"-separated) is a valid path for a given Document class. + """ + from docarray import BaseDocument + + field, _, remaining = access_path.partition('__') + if len(remaining) == 0: + return access_path in doc.__fields__.keys() + else: + valid_field = field in doc.__fields__.keys() + if not valid_field: + return False + else: + d = doc._get_field_type(field) + if not issubclass(d, BaseDocument): + return False + else: + return is_access_path_valid(d, remaining) + + +def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]: + """ + Convert an access path ("__"-separated) and its value to a (potentially) nested dict. + + EXAMPLE USAGE + .. code-block:: python + assert access_path_to_dict('image__url', 'img.png') == {'image': {'url': 'img.png'}} + """ + fields = access_path.split('__') + for field in reversed(fields): + result = {field: value} + value = result + return result + + +def _dict_to_access_paths(d: dict) -> Dict[str, Any]: + """ + Convert a (nested) dict to a Dict[access_path, value]. + Access paths are defined as a path of field(s) separated by "__". + + EXAMPLE USAGE + .. code-block:: python + assert dict_to_access_paths({'image': {'url': 'img.png'}}) == {'image__url', 'img.png'} + """ + result = {} + for k, v in d.items(): + if isinstance(v, dict): + v = _dict_to_access_paths(v) + for nested_k, nested_v in v.items(): + new_key = '__'.join([k, nested_k]) + result[new_key] = nested_v + else: + result[k] = v + return result + + +def _update_nested_dicts( + to_update: Dict[Any, Any], update_with: Dict[Any, Any] +) -> None: + """ + Update a dict with another one, while considering shared nested keys. + + EXAMPLE USAGE: + + .. code-block:: python + + d1 = {'image': {'tensor': None}, 'title': 'hello'} + d2 = {'image': {'url': 'some.png'}} + + update_nested_dicts(d1, d2) + assert d1 == {'image': {'tensor': None, 'url': 'some.png'}, 'title': 'hello'} + + :param to_update: dict that should be updated + :param update_with: dict to update with + :return: merged dict + """ + for k, v in update_with.items(): + if k not in to_update.keys(): + to_update[k] = v + else: + _update_nested_dicts(to_update[k], update_with[k]) diff --git a/tests/toydata/docs_nested.csv b/tests/toydata/docs_nested.csv new file mode 100644 index 00000000000..7b857870244 --- /dev/null +++ b/tests/toydata/docs_nested.csv @@ -0,0 +1,4 @@ +count,text,image,image2__url +000,hello 0,image_0.png,image_10.png +111,hello 1,image_1.png,None +222,hello 2,image_2.png, \ No newline at end of file diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py new file mode 100644 index 00000000000..d1a6f326116 --- /dev/null +++ b/tests/units/array/test_array_from_to_csv.py @@ -0,0 +1,101 @@ +import os +from typing import Optional + +import pytest + +from docarray import BaseDocument, DocumentArray +from docarray.documents import Image +from tests import TOYDATA_DIR + + +@pytest.fixture() +def nested_doc_cls(): + class MyDoc(BaseDocument): + count: Optional[int] + text: str + + class MyDocNested(MyDoc): + image: Image + image2: Image + + return MyDocNested + + +def test_to_from_csv(tmpdir, nested_doc_cls): + da = DocumentArray[nested_doc_cls]( + [ + nested_doc_cls( + count=0, + text='hello', + image=Image(url='aux.png'), + image2=Image(url='aux.png'), + ), + nested_doc_cls(text='hello world', image=Image(), image2=Image()), + ] + ) + tmp_file = str(tmpdir / 'tmp.csv') + da.to_csv(tmp_file) + assert os.path.isfile(tmp_file) + + da_from = DocumentArray[nested_doc_cls].from_csv(tmp_file) + for doc1, doc2 in zip(da, da_from): + assert doc1 == doc2 + + +def test_from_csv_nested(nested_doc_cls): + da = DocumentArray[nested_doc_cls].from_csv( + file_path=str(TOYDATA_DIR / 'docs_nested.csv') + ) + assert len(da) == 3 + + for i, doc in enumerate(da): + assert doc.count.__class__ == int + assert doc.count == int(f'{i}{i}{i}') + + assert doc.text.__class__ == str + assert doc.text == f'hello {i}' + + assert doc.image.__class__ == Image + assert doc.image.tensor is None + assert doc.image.embedding is None + assert doc.image.bytes is None + + assert doc.image2.__class__ == Image + assert doc.image2.tensor is None + assert doc.image2.embedding is None + assert doc.image2.bytes is None + + assert da[0].image2.url == 'image_10.png' + assert da[1].image2.url is None + assert da[2].image2.url is None + + +@pytest.fixture() +def nested_doc(): + class Inner(BaseDocument): + img: Optional[Image] + + class Middle(BaseDocument): + img: Optional[Image] + inner: Optional[Inner] + + class Outer(BaseDocument): + img: Optional[Image] + middle: Optional[Middle] + + doc = Outer(img=Image(), middle=Middle(img=Image(), inner=Inner(img=Image()))) + return doc + + +def test_from_csv_without_schema_raise_exception(): + with pytest.raises(TypeError, match='no document schema defined'): + DocumentArray.from_csv(file_path=str(TOYDATA_DIR / 'docs_nested.csv')) + + +def test_from_csv_with_wrong_schema_raise_exception(nested_doc): + with pytest.raises( + ValueError, match='Fields provided in the csv file do not match the schema' + ): + DocumentArray[nested_doc.__class__].from_csv( + file_path=str(TOYDATA_DIR / 'docs.csv') + ) diff --git a/tests/units/array/test_array_from_to_json.py b/tests/units/array/test_array_from_to_json.py index 79b98c43700..b46066156e5 100644 --- a/tests/units/array/test_array_from_to_json.py +++ b/tests/units/array/test_array_from_to_json.py @@ -1,9 +1,6 @@ -import pytest - -from docarray import BaseDocument -from docarray.typing import NdArray +from docarray import BaseDocument, DocumentArray from docarray.documents import Image -from docarray import DocumentArray +from docarray.typing import NdArray class MyDoc(BaseDocument): diff --git a/tests/units/test_helper.py b/tests/units/test_helper.py new file mode 100644 index 00000000000..e6730866b7b --- /dev/null +++ b/tests/units/test_helper.py @@ -0,0 +1,89 @@ +from typing import Optional + +import pytest + +from docarray import BaseDocument +from docarray.documents import Image +from docarray.helper import ( + _access_path_to_dict, + _dict_to_access_paths, + _update_nested_dicts, + is_access_path_valid, +) + + +@pytest.fixture() +def nested_doc(): + class Inner(BaseDocument): + img: Optional[Image] + + class Middle(BaseDocument): + img: Optional[Image] + inner: Optional[Inner] + + class Outer(BaseDocument): + img: Optional[Image] + middle: Optional[Middle] + + doc = Outer(img=Image(), middle=Middle(img=Image(), inner=Inner(img=Image()))) + return doc + + +def test_is_access_path_valid(nested_doc): + assert is_access_path_valid(nested_doc.__class__, 'img') + assert is_access_path_valid(nested_doc.__class__, 'middle__img') + assert is_access_path_valid(nested_doc.__class__, 'middle__inner__img') + assert is_access_path_valid(nested_doc.__class__, 'middle') + + +def test_is_access_path_not_valid(nested_doc): + assert not is_access_path_valid(nested_doc.__class__, 'inner') + assert not is_access_path_valid(nested_doc.__class__, 'some__other__path') + assert not is_access_path_valid(nested_doc.__class__, 'middle.inner') + + +def test_get_access_paths(): + class Painting(BaseDocument): + title: str + img: Image + + access_paths = Painting._get_access_paths() + assert access_paths == [ + 'id', + 'title', + 'img__id', + 'img__url', + 'img__tensor', + 'img__embedding', + 'img__bytes', + ] + + +def test_dict_to_access_paths(): + d = { + 'a0': {'b0': {'c0': 0}, 'b1': {'c0': 1}}, + 'a1': {'b0': {'c0': 2, 'c1': 3}, 'b1': 4}, + } + casted = _dict_to_access_paths(d) + assert casted == { + 'a0__b0__c0': 0, + 'a0__b1__c0': 1, + 'a1__b0__c0': 2, + 'a1__b0__c1': 3, + 'a1__b1': 4, + } + + +def test_access_path_to_dict(): + access_path = 'a__b__c__d__e' + value = 1 + result = {'a': {'b': {'c': {'d': {'e': value}}}}} + assert _access_path_to_dict(access_path, value) == result + + +def test_update_nested_dict(): + d1 = {'text': 'hello', 'image': {'tensor': None}} + d2 = {'image': {'url': 'some.png'}} + + _update_nested_dicts(d1, d2) + assert d1 == {'text': 'hello', 'image': {'tensor': None, 'url': 'some.png'}}