diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index e5c630622c..4154f3248a 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -51,6 +51,8 @@ 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None' ) + from pydantic import ConfigDict + _console: Console = Console() @@ -71,10 +73,14 @@ class BaseDocWithoutId(BaseModel, IOMixin, UpdateMixin, BaseNode): if is_pydantic_v2: - class Config: - validate_assignment = True - _load_extra_fields_from_protobuf = False - json_encoders = {AbstractTensor: lambda x: x} + class ConfigDocArray(ConfigDict): + _load_extra_fields_from_protobuf: bool + + model_config = ConfigDocArray( + validate_assignment=True, + _load_extra_fields_from_protobuf=False, + json_encoders={AbstractTensor: lambda x: x}, + ) else: diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index f9e1f37c63..958897555c 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -238,10 +238,14 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocProto') -> T: """ fields: Dict[str, Any] = {} - + load_extra_field = ( + cls.model_config['_load_extra_fields_from_protobuf'] + if is_pydantic_v2 + else cls.Config._load_extra_fields_from_protobuf + ) for field_name in pb_msg.data: if ( - not (cls.Config._load_extra_fields_from_protobuf) + not (load_extra_field) and field_name not in cls._docarray_fields().keys() ): continue # optimization we don't even load the data if the key does not diff --git a/docs/user_guide/representing/first_step.md b/docs/user_guide/representing/first_step.md index f2a6bcae2d..114e03cd54 100644 --- a/docs/user_guide/representing/first_step.md +++ b/docs/user_guide/representing/first_step.md @@ -119,18 +119,32 @@ This representation can be used to [send](../sending/first_step.md) or [store](. ## Setting a Pydantic `Config` class -Documents support setting a `Config` [like any other Pydantic `BaseModel`](https://docs.pydantic.dev/latest/usage/model_config/). +Documents support setting a custom `configuration` [like any other Pydantic `BaseModel`](https://docs.pydantic.dev/latest/api/config/). -However, if you set a config, you should inherit from the `BaseDoc` config class: +Here is an example to extend the Config of a Document dependong on which version of Pydantic you are using. -```python -from docarray import BaseDoc -class MyDoc(BaseDoc): - class Config(BaseDoc.Config): - arbitrary_types_allowed = True # just an example setting -``` +=== "Pydantic v1" + ```python + from docarray import BaseDoc + + + class MyDoc(BaseDoc): + class Config(BaseDoc.Config): + arbitrary_types_allowed = True # just an example setting + ``` + +=== "Pydantic v2" + ```python + from docarray import BaseDoc + + + class MyDoc(BaseDoc): + model_config = BaseDoc.ConfigDocArray.ConfigDict( + arbitrary_types_allowed=True + ) # just an example setting + ``` See also: @@ -138,3 +152,4 @@ See also: * API reference for the [BaseDoc][docarray.base_doc.doc.BaseDoc] class * The [Storing](../storing/first_step.md) section on how to store your data * The [Sending](../sending/first_step.md) section on how to send your data + diff --git a/tests/units/document/test_any_document.py b/tests/units/document/test_any_document.py index c55be1ff58..7a235b45fe 100644 --- a/tests/units/document/test_any_document.py +++ b/tests/units/document/test_any_document.py @@ -2,14 +2,10 @@ import numpy as np import pytest -from orjson import orjson from docarray import DocList from docarray.base_doc import AnyDoc, BaseDoc -from docarray.base_doc.io.json import orjson_dumps_and_decode from docarray.typing import NdArray -from docarray.typing.tensor.abstract_tensor import AbstractTensor -from docarray.utils._internal.pydantic import is_pydantic_v2 def test_any_doc(): @@ -94,21 +90,3 @@ class DocTest(BaseDoc): assert isinstance(d.ld[0], dict) assert d.ld[0]['text'] == 'I am inner' assert d.ld[0]['t'] == {'a': 'b'} - - -@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now") -def test_subclass_config(): - class MyDoc(BaseDoc): - x: str - - class Config(BaseDoc.Config): - arbitrary_types_allowed = True # just an example setting - - assert MyDoc.Config.json_loads == orjson.loads - assert MyDoc.Config.json_dumps == orjson_dumps_and_decode - assert ( - MyDoc.Config.json_encoders[AbstractTensor](3) == 3 - ) # dirty check that it is identity - assert MyDoc.Config.validate_assignment - assert not MyDoc.Config._load_extra_fields_from_protobuf - assert MyDoc.Config.arbitrary_types_allowed diff --git a/tests/units/document/test_base_document.py b/tests/units/document/test_base_document.py index 09c9a0660a..2bd80af376 100644 --- a/tests/units/document/test_base_document.py +++ b/tests/units/document/test_base_document.py @@ -1,11 +1,15 @@ from typing import Any, List, Optional, Tuple import numpy as np +import orjson import pytest from docarray import DocList, DocVec from docarray.base_doc.doc import BaseDoc +from docarray.base_doc.io.json import orjson_dumps_and_decode from docarray.typing import NdArray +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils._internal.pydantic import is_pydantic_v2 def test_base_document_init(): @@ -146,3 +150,40 @@ class MyDoc(BaseDoc): field_type = MyDoc._get_field_inner_type("tuple_") assert field_type == Any + + +@pytest.mark.skipif( + is_pydantic_v2, reason="syntax only working with pydantic v1 for now" +) +def test_subclass_config(): + class MyDoc(BaseDoc): + x: str + + class Config(BaseDoc.Config): + arbitrary_types_allowed = True # just an example setting + + assert MyDoc.Config.json_loads == orjson.loads + assert MyDoc.Config.json_dumps == orjson_dumps_and_decode + assert ( + MyDoc.Config.json_encoders[AbstractTensor](3) == 3 + ) # dirty check that it is identity + assert MyDoc.Config.validate_assignment + assert not MyDoc.Config._load_extra_fields_from_protobuf + assert MyDoc.Config.arbitrary_types_allowed + + +@pytest.mark.skipif(not (is_pydantic_v2), reason="syntax only working with pydantic v2") +def test_subclass_config_v2(): + class MyDoc(BaseDoc): + x: str + + model_config = BaseDoc.ConfigDocArray( + arbitrary_types_allowed=True + ) # just an example setting + + assert ( + MyDoc.model_config['json_encoders'][AbstractTensor](3) == 3 + ) # dirty check that it is identity + assert MyDoc.model_config['validate_assignment'] + assert not MyDoc.model_config['_load_extra_fields_from_protobuf'] + assert MyDoc.model_config['arbitrary_types_allowed']