diff --git a/.github/workflows/add_license.yml b/.github/workflows/add_license.yml
index 6c497e19d2b..9c63c711a46 100644
--- a/.github/workflows/add_license.yml
+++ b/.github/workflows/add_license.yml
@@ -15,7 +15,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v3
with:
- python-version: 3.10
+ python-version: "3.10"
- name: Run add_license.sh and check for changes
id: add_license
diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml
index a1aae08ec9b..e0a14b5252c 100644
--- a/.github/workflows/cd.yml
+++ b/.github/workflows/cd.yml
@@ -21,7 +21,7 @@ jobs:
- name: Pre-release (.devN)
run: |
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- pip install poetry
+ pip install poetry==1.7.1
./scripts/release.sh
env:
PYPI_USERNAME: ${{ secrets.TWINE_USERNAME }}
@@ -35,20 +35,16 @@ jobs:
steps:
- uses: actions/checkout@v3
with:
- fetch-depth: 0
-
- - name: Get changed files
- id: changed-files-specific
- uses: tj-actions/changed-files@v41
- with:
- files: |
- README.md
+ fetch-depth: 2
- name: Check if README is modified
id: step_output
- if: steps.changed-files-specific.outputs.any_changed == 'true'
run: |
- echo "readme_changed=true" >> $GITHUB_OUTPUT
+ if git diff --name-only HEAD^ HEAD | grep -q "README.md"; then
+ echo "readme_changed=true" >> $GITHUB_OUTPUT
+ else
+ echo "readme_changed=false" >> $GITHUB_OUTPUT
+ fi
publish-docarray-org:
needs: check-readme-modification
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b8c4added62..07c32d0b873 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -25,7 +25,7 @@ jobs:
- name: Lint with ruff
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install
# stop the build if there are Python syntax errors or undefined names
@@ -44,7 +44,7 @@ jobs:
- name: check black
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --only dev
poetry run black --check .
@@ -62,7 +62,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --without dev
poetry run pip install tensorflow==2.12.0
poetry run pip install jax
@@ -106,7 +106,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --all-extras
poetry run pip install elasticsearch==8.6.2
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
@@ -119,7 +119,7 @@ jobs:
- name: Test
id: test
run: |
- poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py
+ poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml -v -s ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py
echo "flag it as docarray for codeoverage"
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
@@ -156,7 +156,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --all-extras
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
poetry run pip install protobuf==3.20.0 # we check that we support 3.19
@@ -167,7 +167,7 @@ jobs:
- name: Test
id: test
run: |
- poetry run pytest -m 'proto' --cov=docarray --cov-report=xml tests
+ poetry run pytest -m 'proto' --cov=docarray --cov-report=xml -v -s tests
echo "flag it as docarray for codeoverage"
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
@@ -204,7 +204,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --all-extras
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
poetry run pip install protobuf==3.20.0
@@ -217,7 +217,7 @@ jobs:
- name: Test
id: test
run: |
- poetry run pytest -m 'index and not elasticv8' --cov=docarray --cov-report=xml tests/index/${{ matrix.db_test_folder }}
+ poetry run pytest -m 'index and not elasticv8' --cov=docarray --cov-report=xml -v -s tests/index/${{ matrix.db_test_folder }}
echo "flag it as docarray for codeoverage"
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
@@ -253,7 +253,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --all-extras
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
poetry run pip install protobuf==3.20.0
@@ -267,7 +267,7 @@ jobs:
- name: Test
id: test
run: |
- poetry run pytest -m 'index and elasticv8' --cov=docarray --cov-report=xml tests
+ poetry run pytest -m 'index and elasticv8' --cov=docarray --cov-report=xml -v -s tests
echo "flag it as docarray for codeoverage"
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
@@ -302,7 +302,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --all-extras
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
poetry run pip install protobuf==3.20.0
@@ -316,7 +316,7 @@ jobs:
- name: Test
id: test
run: |
- poetry run pytest -m 'tensorflow' --cov=docarray --cov-report=xml tests
+ poetry run pytest -m 'tensorflow' --cov=docarray --cov-report=xml -v -s tests
echo "flag it as docarray for codeoverage"
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
@@ -351,7 +351,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --all-extras
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
poetry run pip uninstall -y torch
@@ -362,7 +362,7 @@ jobs:
- name: Test
id: test
run: |
- poetry run pytest -m 'jax' --cov=docarray --cov-report=xml tests
+ poetry run pytest -m 'jax' --cov=docarray --cov-report=xml -v -s tests
echo "flag it as docarray for codeoverage"
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
@@ -398,7 +398,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --all-extras
poetry run pip uninstall -y torch
poetry run pip install torch
@@ -406,7 +406,7 @@ jobs:
- name: Test
id: test
run: |
- poetry run pytest -m 'benchmark' --cov=docarray --cov-report=xml tests
+ poetry run pytest -m 'benchmark' --cov=docarray --cov-report=xml -v -s tests
echo "flag it as docarray for codeoverage"
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
diff --git a/.github/workflows/ci_only_pr.yml b/.github/workflows/ci_only_pr.yml
index 1e8d3f9694f..9d040e72b62 100644
--- a/.github/workflows/ci_only_pr.yml
+++ b/.github/workflows/ci_only_pr.yml
@@ -43,7 +43,7 @@ jobs:
run: |
npm i -g netlify-cli
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
python -m poetry config virtualenvs.create false && python -m poetry install --no-interaction --no-ansi --all-extras
cd docs
diff --git a/.github/workflows/force-release.yml b/.github/workflows/force-release.yml
index 3037e791081..3ad1af18ced 100644
--- a/.github/workflows/force-release.yml
+++ b/.github/workflows/force-release.yml
@@ -40,7 +40,7 @@ jobs:
- run: |
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
npm install git-release-notes
- pip install poetry
+ python -m pip install poetry==1.7.1
./scripts/release.sh final "${{ github.event.inputs.release_reason }}" "${{github.actor}}"
env:
TWINE_USERNAME: __token__
diff --git a/.github/workflows/uncaped.yml b/.github/workflows/uncaped.yml
index e1cbafb6d44..ccb56bc2497 100644
--- a/.github/workflows/uncaped.yml
+++ b/.github/workflows/uncaped.yml
@@ -21,7 +21,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
rm poetry.lock
poetry install --all-extras
poetry run pip install elasticsearch==8.6.2
diff --git a/CHANGELOG.md b/CHANGELOG.md
index f0620722888..48f2dedcd93 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@
+
## Release Note (`0.30.0`)
@@ -746,3 +747,54 @@
- [[```8de3e175```](https://github.com/jina-ai/docarray/commit/8de3e1757bdb23b509ad2630219c3c26605308f0)] __-__ refactor test of the torchtensor (#1837) (*Naymul Islam*)
- [[```d5d928b8```](https://github.com/jina-ai/docarray/commit/d5d928b82f36a3279277c07bed44fd22bb0bba34)] __-__ __version__: the next version will be 0.39.2 (*Jina Dev Bot*)
+
+## Release Note (`0.40.1`)
+
+> Release time: 2025-03-21 08:34:40
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ Joan Fontanals, Emmanuel Ferdman, Casey Clements, YuXuan Tay, dependabot[bot], James Brown, Jina Dev Bot, 🙇
+
+
+### 🐞 Bug fixes
+
+ - [[```d98acb71```](https://github.com/jina-ai/docarray/commit/d98acb716e0c336a817f65b62d428ab13cf8ac42)] __-__ fix DocList schema when using Pydantic V2 (#1876) (*Joan Fontanals*)
+ - [[```83ebef60```](https://github.com/jina-ai/docarray/commit/83ebef6087e868517681e59877008f80f1e7f113)] __-__ update license location (#1911) (*Emmanuel Ferdman*)
+ - [[```8f4ba7cd```](https://github.com/jina-ai/docarray/commit/8f4ba7cdf177f3e4ecc838eef659496d6038af03)] __-__ use docker compose (#1905) (*YuXuan Tay*)
+ - [[```febbdc42```](https://github.com/jina-ai/docarray/commit/febbdc4291c4af7ad2058d7feebf6a3169de93e9)] __-__ fix float in dynamic Document creation (#1877) (*Joan Fontanals*)
+ - [[```7c1e18ef```](https://github.com/jina-ai/docarray/commit/7c1e18ef01b09ef3d864b200248c875d0d9ced29)] __-__ fix create pure python class iteratively (#1867) (*Joan Fontanals*)
+
+### 📗 Documentation
+
+ - [[```e4665e91```](https://github.com/jina-ai/docarray/commit/e4665e91b37f97a4a18a80399431d624db8ca453)] __-__ move hint about schemas to common docindex section (#1868) (*Joan Fontanals*)
+ - [[```8da50c92```](https://github.com/jina-ai/docarray/commit/8da50c927c24b981867650399f64d4930bd7c574)] __-__ add code review to contributing.md (#1853) (*Joan Fontanals*)
+
+### 🏁 Unit Test and CICD
+
+ - [[```a162a4b0```](https://github.com/jina-ai/docarray/commit/a162a4b09f4ad8e8c5c117c0c0101541af4c00a1)] __-__ fix release procedure (#1922) (*Joan Fontanals*)
+ - [[```82d7cee7```](https://github.com/jina-ai/docarray/commit/82d7cee71ccdd4d5874985aef0567631424b5bfd)] __-__ fix some ci (#1893) (*Joan Fontanals*)
+ - [[```791e4a04```](https://github.com/jina-ai/docarray/commit/791e4a0473afe9d9bde87733074eef0ce217d198)] __-__ update release procedure (#1869) (*Joan Fontanals*)
+ - [[```aa15b9ef```](https://github.com/jina-ai/docarray/commit/aa15b9eff2f5293849e83291d79bf519994c3503)] __-__ add license (#1861) (*Joan Fontanals*)
+
+### 🍹 Other Improvements
+
+ - [[```b5696b22```](https://github.com/jina-ai/docarray/commit/b5696b227161f087fa32834dcd6c2d212cf82c0e)] __-__ fix poetry in ci (#1921) (*Joan Fontanals*)
+ - [[```d3358105```](https://github.com/jina-ai/docarray/commit/d3358105db645418c3cebfc6acb0f353127364aa)] __-__ update pyproject version (#1919) (*Joan Fontanals*)
+ - [[```40cf2962```](https://github.com/jina-ai/docarray/commit/40cf29622b29be1f32595e26876593bb5f1e03be)] __-__ MongoDB Atlas: Two line change to make our CI builds green (#1910) (*Casey Clements*)
+ - [[```75e0033a```](https://github.com/jina-ai/docarray/commit/75e0033a361a31280709899e94d6f5e14ff4b8ae)] __-__ __deps__: bump setuptools from 65.5.1 to 70.0.0 (#1899) (*dependabot[bot]*)
+ - [[```75a743c9```](https://github.com/jina-ai/docarray/commit/75a743c99dc549eaf4c3ffe01086d09a8f3f3e44)] __-__ __deps-dev__: bump tornado from 6.2 to 6.4.1 (#1894) (*dependabot[bot]*)
+ - [[```f3fa7c23```](https://github.com/jina-ai/docarray/commit/f3fa7c2376da2449e98aff159167bf41467d610c)] __-__ __deps__: bump pydantic from 1.10.8 to 1.10.13 (#1884) (*dependabot[bot]*)
+ - [[```46d50828```](https://github.com/jina-ai/docarray/commit/46d5082844602689de97c904af7c8139980711ed)] __-__ __deps__: bump urllib3 from 1.26.14 to 1.26.19 (#1896) (*dependabot[bot]*)
+ - [[```f0f4236e```](https://github.com/jina-ai/docarray/commit/f0f4236ebf75528e6c5344dc75328ce9cf56cae9)] __-__ __deps__: bump zipp from 3.10.0 to 3.19.1 (#1898) (*dependabot[bot]*)
+ - [[```d65d27ce```](https://github.com/jina-ai/docarray/commit/d65d27ce37f5e7c930b7792fd665ac4da9c6398d)] __-__ __deps__: bump certifi from 2022.9.24 to 2024.7.4 (#1897) (*dependabot[bot]*)
+ - [[```b8b62173```](https://github.com/jina-ai/docarray/commit/b8b621735dbe16c188bf8c1c03cb3f1a22076ae8)] __-__ __deps__: bump authlib from 1.2.0 to 1.3.1 (#1895) (*dependabot[bot]*)
+ - [[```6a972d1c```](https://github.com/jina-ai/docarray/commit/6a972d1c0dcf6d0c2816dea14df37e0039945542)] __-__ __deps__: bump qdrant-client from 1.4.0 to 1.9.0 (#1892) (*dependabot[bot]*)
+ - [[```f71a5e6a```](https://github.com/jina-ai/docarray/commit/f71a5e6af58b77fdeb15ba27abd0b7d40b84fd09)] __-__ __deps__: bump cryptography from 40.0.1 to 42.0.4 (#1872) (*dependabot[bot]*)
+ - [[```065aab44```](https://github.com/jina-ai/docarray/commit/065aab441cd71635ee3711ad862240e967ca3da6)] __-__ __deps__: bump orjson from 3.8.2 to 3.9.15 (#1873) (*dependabot[bot]*)
+ - [[```caf97135```](https://github.com/jina-ai/docarray/commit/caf9713502791a8fbbf0aa53b3ca2db126f18df7)] __-__ add license notice to every file (#1860) (*Joan Fontanals*)
+ - [[```50376358```](https://github.com/jina-ai/docarray/commit/50376358163005e66a76cd0cb40217fd7a4f1252)] __-__ __deps-dev__: bump jupyterlab from 3.5.0 to 3.6.7 (#1848) (*dependabot[bot]*)
+ - [[```104b403b```](https://github.com/jina-ai/docarray/commit/104b403b2b61a485e2cc032a357f46e7dc8044fe)] __-__ __deps__: bump tj-actions/changed-files from 34 to 41 in /.github/workflows (#1844) (*dependabot[bot]*)
+ - [[```f9426a29```](https://github.com/jina-ai/docarray/commit/f9426a29b29580beae8805d2556b4a94ff493edc)] __-__ __version__: the next version will be 0.40.1 (*Jina Dev Bot*)
+
diff --git a/README.md b/README.md
index 06acc4f516a..1c4e27f989d 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@
> The README you're currently viewing is for DocArray>0.30, which introduces some significant changes from DocArray 0.21. If you wish to continue using the older DocArray <=0.21, ensure you install it via `pip install docarray==0.21`. Refer to its [codebase](https://github.com/docarray/docarray/tree/v0.21.0), [documentation](https://docarray.jina.ai), and [its hot-fixes branch](https://github.com/docarray/docarray/tree/docarray-v1-fixes) for more information.
-DocArray is a Python library expertly crafted for the [representation](#represent), [transmission](#send), [storage](#store), and [retrieval](#retrieve) of multimodal data. Tailored for the development of multimodal AI applications, its design guarantees seamless integration with the extensive Python and machine learning ecosystems. As of January 2022, DocArray is openly distributed under the [Apache License 2.0](https://github.com/docarray/docarray/blob/main/LICENSE) and currently enjoys the status of a sandbox project within the [LF AI & Data Foundation](https://lfaidata.foundation/).
+DocArray is a Python library expertly crafted for the [representation](#represent), [transmission](#send), [storage](#store), and [retrieval](#retrieve) of multimodal data. Tailored for the development of multimodal AI applications, its design guarantees seamless integration with the extensive Python and machine learning ecosystems. As of January 2022, DocArray is openly distributed under the [Apache License 2.0](https://github.com/docarray/docarray/blob/main/LICENSE.md) and currently enjoys the status of a sandbox project within the [LF AI & Data Foundation](https://lfaidata.foundation/).
diff --git a/docarray/__init__.py b/docarray/__init__.py
index 6ce3f9eb90f..20b08ba1735 100644
--- a/docarray/__init__.py
+++ b/docarray/__init__.py
@@ -13,13 +13,67 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '0.40.1'
+__version__ = '0.40.2'
import logging
from docarray.array import DocList, DocVec
from docarray.base_doc.doc import BaseDoc
from docarray.utils._internal.misc import _get_path_from_docarray_root_level
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+
+def unpickle_doclist(doc_type, b):
+ return DocList[doc_type].from_bytes(b, protocol="protobuf")
+
+
+def unpickle_docvec(doc_type, tensor_type, b):
+ return DocVec[doc_type].from_bytes(b, protocol="protobuf", tensor_type=tensor_type)
+
+
+if is_pydantic_v2:
+ # Register the pickle functions
+ def register_serializers():
+ import copyreg
+ from functools import partial
+
+ unpickle_doc_fn = partial(BaseDoc.from_bytes, protocol="protobuf")
+
+ def pickle_doc(doc):
+ b = doc.to_bytes(protocol='protobuf')
+ return unpickle_doc_fn, (doc.__class__, b)
+
+ # Register BaseDoc serialization
+ copyreg.pickle(BaseDoc, pickle_doc)
+
+ # For DocList, we need to hook into __reduce__ since it's a generic
+
+ def pickle_doclist(doc_list):
+ b = doc_list.to_bytes(protocol='protobuf')
+ doc_type = doc_list.doc_type
+ return unpickle_doclist, (doc_type, b)
+
+ # Replace DocList.__reduce__ with a method that returns the correct format
+ def doclist_reduce(self):
+ return pickle_doclist(self)
+
+ DocList.__reduce__ = doclist_reduce
+
+ # For DocVec, we need to hook into __reduce__ since it's a generic
+
+ def pickle_docvec(doc_vec):
+ b = doc_vec.to_bytes(protocol='protobuf')
+ doc_type = doc_vec.doc_type
+ tensor_type = doc_vec.tensor_type
+ return unpickle_docvec, (doc_type, tensor_type, b)
+
+ # Replace DocList.__reduce__ with a method that returns the correct format
+ def docvec_reduce(self):
+ return pickle_docvec(self)
+
+ DocVec.__reduce__ = docvec_reduce
+
+ register_serializers()
__all__ = ['BaseDoc', 'DocList', 'DocVec']
diff --git a/docarray/array/any_array.py b/docarray/array/any_array.py
index 50c47cf4ec4..0c29e54ae82 100644
--- a/docarray/array/any_array.py
+++ b/docarray/array/any_array.py
@@ -25,6 +25,7 @@
from docarray.exceptions.exceptions import UnusableObjectError
from docarray.typing.abstract_type import AbstractType
from docarray.utils._internal._typing import change_cls_name, safe_issubclass
+from docarray.utils._internal.pydantic import is_pydantic_v2
if TYPE_CHECKING:
from docarray.proto import DocListProto, NodeProto
@@ -73,8 +74,19 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
# Promote to global scope so multiprocessing can pickle it
global _DocArrayTyped
- class _DocArrayTyped(cls): # type: ignore
- doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item)
+ if not is_pydantic_v2:
+
+ class _DocArrayTyped(cls): # type: ignore
+ doc_type: Type[BaseDocWithoutId] = cast(
+ Type[BaseDocWithoutId], item
+ )
+
+ else:
+
+ class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore
+ doc_type: Type[BaseDocWithoutId] = cast(
+ Type[BaseDocWithoutId], item
+ )
for field in _DocArrayTyped.doc_type._docarray_fields().keys():
@@ -99,14 +111,24 @@ def _setter(self, value):
setattr(_DocArrayTyped, field, _property_generator(field))
# this generates property on the fly based on the schema of the item
- # The global scope and qualname need to refer to this class a unique name.
- # Otherwise, creating another _DocArrayTyped will overwrite this one.
- change_cls_name(
- _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
- )
-
- cls.__typed_da__[cls][item] = _DocArrayTyped
+ # # The global scope and qualname need to refer to this class a unique name.
+ # # Otherwise, creating another _DocArrayTyped will overwrite this one.
+ if not is_pydantic_v2:
+ change_cls_name(
+ _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
+ )
+ cls.__typed_da__[cls][item] = _DocArrayTyped
+ else:
+ change_cls_name(_DocArrayTyped, f'{cls.__name__}', globals())
+ if sys.version_info < (3, 12):
+ cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(
+ _DocArrayTyped, item
+ ) # type: ignore
+ # this do nothing that checking that item is valid type var or str
+ # Keep the approach in #1147 to be compatible with lower versions of Python.
+ else:
+ cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item) # type: ignore
return cls.__typed_da__[cls][item]
@overload
diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py
index c21cf934132..49236199153 100644
--- a/docarray/array/doc_list/doc_list.py
+++ b/docarray/array/doc_list/doc_list.py
@@ -12,6 +12,7 @@
Union,
cast,
overload,
+ Callable,
)
from pydantic import parse_obj_as
@@ -28,7 +29,6 @@
from docarray.utils._internal.pydantic import is_pydantic_v2
if is_pydantic_v2:
- from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from docarray.utils._internal._typing import safe_issubclass
@@ -45,10 +45,7 @@
class DocList(
- ListAdvancedIndexing[T_doc],
- PushPullMixin,
- IOMixinDocList,
- AnyDocArray[T_doc],
+ ListAdvancedIndexing[T_doc], PushPullMixin, IOMixinDocList, AnyDocArray[T_doc]
):
"""
DocList is a container of Documents.
@@ -357,8 +354,20 @@ def __repr__(self):
@classmethod
def __get_pydantic_core_schema__(
- cls, _source_type: Any, _handler: GetCoreSchemaHandler
+ cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
- return core_schema.general_plain_validator_function(
- cls.validate,
+ instance_schema = core_schema.is_instance_schema(cls)
+ args = getattr(source, '__args__', None)
+ if args:
+ sequence_t_schema = handler(Sequence[args[0]])
+ else:
+ sequence_t_schema = handler(Sequence)
+
+ def validate_fn(v, info):
+ # input has already been validated
+ return cls(v, validate_input_docs=False)
+
+ non_instance_schema = core_schema.with_info_after_validator_function(
+ validate_fn, sequence_t_schema
)
+ return core_schema.union_schema([instance_schema, non_instance_schema])
diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py
index 82d00197e26..3acb66bf6e8 100644
--- a/docarray/array/doc_list/io.py
+++ b/docarray/array/doc_list/io.py
@@ -256,7 +256,6 @@ def to_bytes(
:param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
:return: the binary serialization in bytes or None if file_ctx is passed where to store
"""
-
with file_ctx or io.BytesIO() as bf:
self._write_bytes(
bf=bf,
diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py
index 9d515cfd96f..0cc462f173d 100644
--- a/docarray/array/doc_vec/doc_vec.py
+++ b/docarray/array/doc_vec/doc_vec.py
@@ -198,7 +198,7 @@ def _check_doc_field_not_none(field_name, doc):
if safe_issubclass(tensor.__class__, tensor_type):
field_type = tensor_type
- if isinstance(field_type, type):
+ if isinstance(field_type, type) or safe_issubclass(field_type, AnyDocArray):
if tf_available and safe_issubclass(field_type, TensorFlowTensor):
# tf.Tensor does not allow item assignment, therefore the
# optimized way
@@ -335,7 +335,9 @@ def _docarray_validate(
return cast(T, value.to_doc_vec())
else:
raise ValueError(f'DocVec[value.doc_type] is not compatible with {cls}')
- elif isinstance(value, DocList.__class_getitem__(cls.doc_type)):
+ elif not is_pydantic_v2 and isinstance(
+ value, DocList.__class_getitem__(cls.doc_type)
+ ):
return cast(T, value.to_doc_vec())
elif isinstance(value, Sequence):
return cls(value)
diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py
index 4d45f1369a8..e880504bc05 100644
--- a/docarray/base_doc/doc.py
+++ b/docarray/base_doc/doc.py
@@ -326,8 +326,13 @@ def _exclude_doclist(
from docarray.array.any_array import AnyDocArray
type_ = self._get_field_annotation(field)
- if isinstance(type_, type) and issubclass(type_, AnyDocArray):
- doclist_exclude_fields.append(field)
+ if is_pydantic_v2:
+ # Conservative when touching pydantic v1 logic
+ if safe_issubclass(type_, AnyDocArray):
+ doclist_exclude_fields.append(field)
+ else:
+ if isinstance(type_, type) and safe_issubclass(type_, AnyDocArray):
+ doclist_exclude_fields.append(field)
original_exclude = exclude
if exclude is None:
@@ -480,7 +485,6 @@ def model_dump( # type: ignore
warnings: bool = True,
) -> Dict[str, Any]:
def _model_dump(doc):
-
(
exclude_,
original_exclude,
diff --git a/docarray/base_doc/mixins/update.py b/docarray/base_doc/mixins/update.py
index 721f8225ebb..7ce596ce1aa 100644
--- a/docarray/base_doc/mixins/update.py
+++ b/docarray/base_doc/mixins/update.py
@@ -110,9 +110,7 @@ def _group_fields(doc: 'UpdateMixin') -> _FieldGroups:
if field_name not in FORBIDDEN_FIELDS_TO_UPDATE:
field_type = doc._get_field_annotation(field_name)
- if isinstance(field_type, type) and safe_issubclass(
- field_type, DocList
- ):
+ if safe_issubclass(field_type, DocList):
nested_docarray_fields.append(field_name)
else:
origin = get_origin(field_type)
diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py
index c008fa29de0..a335f85e32a 100644
--- a/docarray/index/backends/elastic.py
+++ b/docarray/index/backends/elastic.py
@@ -352,12 +352,12 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
dict: 'object',
}
- for type in elastic_py_types.keys():
- if safe_issubclass(python_type, type):
+ for t in elastic_py_types.keys():
+ if safe_issubclass(python_type, t):
self._logger.info(
- f'Mapped Python type {python_type} to database type "{elastic_py_types[type]}"'
+ f'Mapped Python type {python_type} to database type "{elastic_py_types[t]}"'
)
- return elastic_py_types[type]
+ return elastic_py_types[t]
err_msg = f'Unsupported column type for {type(self)}: {python_type}'
self._logger.error(err_msg)
diff --git a/docarray/index/backends/epsilla.py b/docarray/index/backends/epsilla.py
index 83c171daed0..0392e9d010e 100644
--- a/docarray/index/backends/epsilla.py
+++ b/docarray/index/backends/epsilla.py
@@ -100,8 +100,8 @@ def __init__(self, db_config=None, **kwargs):
def _validate_column_info(self):
vector_columns = []
for info in self._column_infos.values():
- for type in [list, np.ndarray, AbstractTensor]:
- if safe_issubclass(info.docarray_type, type) and info.config.get(
+ for t in [list, np.ndarray, AbstractTensor]:
+ if safe_issubclass(info.docarray_type, t) and info.config.get(
'is_embedding', False
):
# check that dimension is present
diff --git a/docarray/index/backends/helper.py b/docarray/index/backends/helper.py
index 268f623ab18..5582dbba866 100644
--- a/docarray/index/backends/helper.py
+++ b/docarray/index/backends/helper.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Tuple, Type, cast
+from typing import Any, Dict, List, Tuple, Type, cast, Set
from docarray import BaseDoc, DocList
from docarray.index.abstract import BaseDocIndex
@@ -20,6 +20,43 @@ def inner(self, *args, **kwargs):
return inner
+def _collect_query_required_args(method_name: str, required_args: Set[str] = None):
+ """
+ Returns a function that ensures required keyword arguments are provided.
+
+ :param method_name: The name of the method for which the required arguments are being checked.
+ :type method_name: str
+ :param required_args: A set containing the names of required keyword arguments. Defaults to None.
+ :type required_args: Optional[Set[str]]
+ :return: A function that checks for required keyword arguments before executing the specified method.
+ Raises ValueError if positional arguments are provided.
+ Raises TypeError if any required keyword argument is missing.
+ :rtype: Callable
+ """
+
+ if required_args is None:
+ required_args = set()
+
+ def inner(self, *args, **kwargs):
+ if args:
+ raise ValueError(
+ f"Positional arguments are not supported for "
+ f"`{type(self)}.{method_name}`. "
+ f"Use keyword arguments instead."
+ )
+
+ missing_args = required_args - set(kwargs.keys())
+ if missing_args:
+ raise ValueError(
+ f"`{type(self)}.{method_name}` is missing required argument(s): {', '.join(missing_args)}"
+ )
+
+ updated_query = self._queries + [(method_name, kwargs)]
+ return type(self)(updated_query)
+
+ return inner
+
+
def _execute_find_and_filter_query(
doc_index: BaseDocIndex, query: List[Tuple[str, Dict]], reverse_order: bool = False
) -> FindResult:
diff --git a/docarray/index/backends/milvus.py b/docarray/index/backends/milvus.py
index 609eee1ec8b..e84baac7210 100644
--- a/docarray/index/backends/milvus.py
+++ b/docarray/index/backends/milvus.py
@@ -192,7 +192,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
AbstractTensor: DataType.FLOAT_VECTOR,
}
- if issubclass(python_type, ID):
+ if safe_issubclass(python_type, ID):
return DataType.VARCHAR
for py_type, db_type in type_map.items():
@@ -665,7 +665,7 @@ def find_batched(
if search_field:
if '__' in search_field:
fields = search_field.split('__')
- if issubclass(self._schema._get_field_annotation(fields[0]), AnyDocArray): # type: ignore
+ if safe_issubclass(self._schema._get_field_annotation(fields[0]), AnyDocArray): # type: ignore
return self._subindices[fields[0]].find_batched(
queries,
search_field='__'.join(fields[1:]),
diff --git a/docarray/index/backends/mongodb_atlas.py b/docarray/index/backends/mongodb_atlas.py
index caaa82742f8..f1ccdec02d2 100644
--- a/docarray/index/backends/mongodb_atlas.py
+++ b/docarray/index/backends/mongodb_atlas.py
@@ -1,62 +1,96 @@
import collections
import logging
-from collections import defaultdict
from dataclasses import dataclass, field
from functools import cached_property
-
from typing import (
Any,
Dict,
Generator,
Generic,
List,
+ NamedTuple,
Optional,
Sequence,
+ Tuple,
Type,
TypeVar,
Union,
- Tuple,
)
import bson
import numpy as np
from pymongo import MongoClient
-from docarray import BaseDoc, DocList
+from docarray import BaseDoc, DocList, handler
from docarray.index.abstract import BaseDocIndex, _raise_not_composable
+from docarray.index.backends.helper import _collect_query_required_args
+from docarray.typing import AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal._typing import safe_issubclass
from docarray.utils.find import _FindResult, _FindResultBatched
+logger = logging.getLogger(__name__)
+logger.addHandler(handler)
+
+
MAX_CANDIDATES = 10_000
OVERSAMPLING_FACTOR = 10
TSchema = TypeVar('TSchema', bound=BaseDoc)
+class HybridResult(NamedTuple):
+ """Adds breakdown of scores into vector and text components."""
+
+ documents: Union[DocList, List[Dict[str, Any]]]
+ scores: AnyTensor
+ score_breakdown: Dict[str, List[Any]]
+
+
class MongoDBAtlasDocumentIndex(BaseDocIndex, Generic[TSchema]):
+ """DocumentIndex backed by MongoDB Atlas Vector Store.
+
+ MongoDB Atlas provides full Text, Vector, and Hybrid Search
+ and can store structured data, text and vector indexes
+ in the same Collection (Index).
+
+ Atlas provides efficient index and search on vector embeddings
+ using the Hierarchical Navigable Small Worlds (HNSW) algorithm.
+
+ For documentation, see the following.
+ * Text Search: https://www.mongodb.com/docs/atlas/atlas-search/atlas-search-overview/
+ * Vector Search: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/
+ * Hybrid Search: https://www.mongodb.com/docs/atlas/atlas-vector-search/tutorials/reciprocal-rank-fusion/
+ """
+
def __init__(self, db_config=None, **kwargs):
super().__init__(db_config=db_config, **kwargs)
- self._logger = logging.getLogger(__name__)
- self._create_indexes()
- self._logger.info(f'{self.__class__.__name__} has been initialized')
+ logger.info(f'{self.__class__.__name__} has been initialized')
@property
- def _collection(self):
- if self._is_subindex:
- return self._db_config.index_name
+ def index_name(self):
+ """The name of the index/collection in the database.
- if not self._schema:
- raise ValueError(
- 'A MongoDBAtlasDocumentIndex must be typed with a Document type.'
- 'To do so, use the syntax: MongoDBAtlasDocumentIndex[DocumentType]'
- )
+ Note that in MongoDB Atlas, one has Collections (analogous to Tables),
+ which can have Search Indexes. They are distinct.
+ DocArray tends to consider them together.
- return self._schema.__name__.lower()
+ The index_name can be set when initializing MongoDBAtlasDocumentIndex.
+ The easiest way is to pass index_name= as a kwarg.
+ Otherwise, a rational default uses the name of the DocumentTypes that it contains.
+ """
- @property
- def index_name(self):
- """Return the name of the index in the database."""
- return self._collection
+ if self._db_config.index_name is not None:
+ return self._db_config.index_name
+ else:
+ # Create a reasonable default
+ if not self._schema:
+ raise ValueError(
+ 'A MongoDBAtlasDocumentIndex must be typed with a Document type.'
+ 'To do so, use the syntax: MongoDBAtlasDocumentIndex[DocumentType]'
+ )
+ schema_name = self._schema.__name__.lower()
+ logger.debug(f"db_config.index_name was not set. Using {schema_name}")
+ return schema_name
@property
def _database_name(self):
@@ -69,8 +103,9 @@ def _client(self):
)
@property
- def _doc_collection(self):
- return self._client[self._database_name][self._collection]
+ def _collection(self):
+ """MongoDB Collection"""
+ return self._client[self._database_name][self.index_name]
@staticmethod
def _connect_to_mongodb_atlas(atlas_connection_uri: str):
@@ -86,43 +121,182 @@ def _connect_to_mongodb_atlas(atlas_connection_uri: str):
def _create_indexes(self):
"""Create a new index in the MongoDB database if it doesn't already exist."""
- self._logger.warning(
- "Search Indexes in MongoDB Atlas must be created manually. "
- "Currently, client-side creation of vector indexes is not allowed on free clusters."
- "Please follow instructions in docs/API_reference/doc_index/backends/mongodb.md"
- )
+
+ def _check_index_exists(self, index_name: str) -> bool:
+ """
+ Check if an index exists in the MongoDB Atlas database.
+
+ :param index_name: The name of the index.
+ :return: True if the index exists, False otherwise.
+ """
+
+ @dataclass
+ class Query:
+ """Dataclass describing a query."""
+
+ vector_fields: Optional[Dict[str, np.ndarray]]
+ filters: Optional[List[Any]]
+ text_searches: Optional[List[Any]]
+ limit: int
class QueryBuilder(BaseDocIndex.QueryBuilder):
- ...
+ """Compose complex queries containing vector search (find), text_search, and filters.
+
+ Arguments to `find` are vectors of embeddings, text_search expects strings,
+ and filters expect dicts of MongoDB Query Language (MDB).
+
+
+ NOTE: When doing Hybrid Search, pay close attention to the interpretation and use of inputs,
+ particularly when multiple calls are made of the same method (find, text_search, filter).
+ * find (Vector Search): Embedding vectors will be averaged. The penalty/weight defined in DBConfig will not change.
+ * text_search: Individual searches are performed, each with the same penalty/weight.
+ * filter: Within Vector Search, performs efficient k-NN filtering with the Lucene engine
+ """
+
+ def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None):
+ super().__init__()
+ # list of tuples (method name, kwargs)
+ self._queries: List[Tuple[str, Dict]] = query or []
+
+ def build(self, limit: int = 1, *args, **kwargs) -> Any:
+ """Build a `Query` that can be passed to `execute_query`."""
+ search_fields: Dict[str, np.ndarray] = collections.defaultdict(list)
+ filters: List[Any] = []
+ text_searches: List[Any] = []
+ for method, kwargs in self._queries:
+ if method == 'find':
+ search_field = kwargs['search_field']
+ search_fields[search_field].append(kwargs["query"])
+
+ elif method == 'filter':
+ filters.append(kwargs)
+ else:
+ text_searches.append(kwargs)
+
+ vector_fields = {
+ field: np.average(vectors, axis=0)
+ for field, vectors in search_fields.items()
+ }
+ return MongoDBAtlasDocumentIndex.Query(
+ vector_fields=vector_fields,
+ filters=filters,
+ text_searches=text_searches,
+ limit=limit,
+ )
+
+ find = _collect_query_required_args('find', {'search_field', 'query'})
+ filter = _collect_query_required_args('filter', {'query'})
+ text_search = _collect_query_required_args(
+ 'text_search', {'search_field', 'query'}
+ )
- find = _raise_not_composable('find')
- filter = _raise_not_composable('filter')
- text_search = _raise_not_composable('text_search')
find_batched = _raise_not_composable('find_batched')
filter_batched = _raise_not_composable('filter_batched')
text_search_batched = _raise_not_composable('text_search_batched')
- def execute_query(self, query: Any, *args, **kwargs) -> _FindResult:
- """
- Execute a query on the database.
- Can take two kinds of inputs:
- 1. A native query of the underlying database. This is meant as a passthrough so that you
- can enjoy any functionality that is not available through the Document index API.
- 2. The output of this Document index' `QueryBuilder.build()` method.
- :param query: the query to execute
+ def execute_query(
+ self, query: Any, *args, score_breakdown=True, **kwargs
+ ) -> Any: # _FindResult:
+ """Execute a Query on the database.
+
+ :param query: the query to execute. The output of this Document index's `QueryBuilder.build()` method.
:param args: positional arguments to pass to the query
+ :param score_breakdown: Will provide breakdown of scores into text and vector components for Hybrid Searches.
:param kwargs: keyword arguments to pass to the query
:return: the result of the query
"""
- ...
+ if not isinstance(query, MongoDBAtlasDocumentIndex.Query):
+ raise ValueError(
+ "Expected MongoDBAtlasDocumentIndex.Query. Found {type(query)=}."
+ "For native calls to MongoDBAtlasDocumentIndex, simply call filter()"
+ )
+
+ if len(query.vector_fields) > 1:
+ self._logger.warning(
+ f"{len(query.vector_fields)} embedding vectors have been provided to the query. They will be averaged."
+ )
+ if len(query.text_searches) > 1:
+ self._logger.warning(
+ f"{len(query.text_searches)} text searches will be performed, and each receive a ranked score."
+ )
+
+ # collect filters
+ filters: List[Dict[str, Any]] = []
+ for filter_ in query.filters:
+ filters.append(filter_['query'])
+
+ # check if hybrid search is needed.
+ hybrid = len(query.vector_fields) + len(query.text_searches) > 1
+ if hybrid:
+ if len(query.vector_fields) > 1:
+ raise NotImplementedError(
+ "Hybrid Search on multiple Vector Indexes has yet to be done."
+ )
+ pipeline = self._hybrid_search(
+ query.vector_fields, query.text_searches, filters, query.limit
+ )
+ else:
+ if query.text_searches:
+ # it is a simple text search, perhaps with filters.
+ text_stage = self._text_search_stage(**query.text_searches[0])
+ pipeline = [
+ text_stage,
+ {"$match": {"$and": filters} if filters else {}},
+ {
+ '$project': self._project_fields(
+ extra_fields={"score": {'$meta': 'searchScore'}}
+ )
+ },
+ {"$limit": query.limit},
+ ]
+ elif query.vector_fields:
+ # it is a simple vector search, perhaps with filters.
+ assert (
+ len(query.vector_fields) == 1
+ ), "Query contains more than one vector_field."
+ field, vector_query = list(query.vector_fields.items())[0]
+ pipeline = [
+ self._vector_search_stage(
+ query=vector_query,
+ search_field=field,
+ limit=query.limit,
+ filters=filters,
+ ),
+ {
+ '$project': self._project_fields(
+ extra_fields={"score": {'$meta': 'vectorSearchScore'}}
+ )
+ },
+ ]
+ # it is only a filter search.
+ else:
+ pipeline = [{"$match": {"$and": filters}}]
+
+ with self._collection.aggregate(pipeline) as cursor:
+ results, scores = self._mongo_to_docs(cursor)
+ docs = self._dict_list_to_docarray(results)
+
+ if hybrid and score_breakdown and results:
+ score_breakdown = collections.defaultdict(list)
+ score_fields = [key for key in results[0] if "score" in key]
+ for res in results:
+ score_breakdown["id"].append(res["id"])
+ for sf in score_fields:
+ score_breakdown[sf].append(res[sf])
+ logger.debug(score_breakdown)
+ return HybridResult(
+ documents=docs, scores=scores, score_breakdown=score_breakdown
+ )
+
+ return _FindResult(documents=docs, scores=scores)
@dataclass
class DBConfig(BaseDocIndex.DBConfig):
mongo_connection_uri: str = 'localhost'
index_name: Optional[str] = None
- database_name: Optional[str] = "db"
+ database_name: Optional[str] = "default"
default_column_config: Dict[Type, Dict[str, Any]] = field(
- default_factory=lambda: defaultdict(
+ default_factory=lambda: collections.defaultdict(
dict,
{
bson.BSONARR: {
@@ -131,13 +305,13 @@ class DBConfig(BaseDocIndex.DBConfig):
'max_candidates': MAX_CANDIDATES,
'indexed': False,
'index_name': None,
- 'penalty': 1,
+ 'penalty': 5,
},
bson.BSONSTR: {
'indexed': False,
'index_name': None,
'operator': 'phrase',
- 'penalty': 10,
+ 'penalty': 1,
},
},
)
@@ -145,7 +319,7 @@ class DBConfig(BaseDocIndex.DBConfig):
@dataclass
class RuntimeConfig(BaseDocIndex.RuntimeConfig):
- pass
+ ...
def python_type_to_db_type(self, python_type: Type) -> Any:
"""Map python type to database type.
@@ -186,16 +360,14 @@ def _docs_to_mongo(self, docs):
return [self._doc_to_mongo(doc) for doc in docs]
@staticmethod
- def _mongo_to_doc(mongo_doc: dict) -> Tuple[dict, float]:
+ def _mongo_to_doc(mongo_doc: dict) -> dict:
result = mongo_doc.copy()
result["id"] = result.pop("_id")
- score = result.pop("score", None)
+ score = result.get("score", None)
return result, score
@staticmethod
- def _mongo_to_docs(
- mongo_docs: Generator[Dict, None, None]
- ) -> Tuple[List[dict], List[float]]:
+ def _mongo_to_docs(mongo_docs: Generator[Dict, None, None]) -> List[dict]:
docs = []
scores = []
for mongo_doc in mongo_docs:
@@ -212,11 +384,15 @@ def _get_max_candidates(self, search_field: str) -> int:
return self._column_infos[search_field].config["max_candidates"]
def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
- """index a document into the store"""
- # `column_to_data` is a dictionary from column name to a generator
- # that yields the data for that column.
- # If you want to work directly on documents, you can implement index() instead
- # If you implement index(), _index() only needs a dummy implementation.
+ """Add and Index Documents to the datastore
+
+ The input format is aimed towards column vectors, which is not
+ the natural fit for MongoDB Collections, but we have chosen
+ not to override BaseDocIndex.index as it provides valuable validation.
+ This may change in the future.
+
+ :param column_to_data: is a dictionary from column name to a generator
+ """
self._index_subindex(column_to_data)
docs: List[Dict[str, Any]] = []
while True:
@@ -226,11 +402,11 @@ def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
docs.append(mongo_doc)
except StopIteration:
break
- self._doc_collection.insert_many(docs)
+ self._collection.insert_many(docs)
def num_docs(self) -> int:
"""Return the number of indexed documents"""
- return self._doc_collection.count_documents({})
+ return self._collection.count_documents({})
@property
def _is_index_empty(self) -> bool:
@@ -246,7 +422,7 @@ def _del_items(self, doc_ids: Sequence[str]) -> None:
:param doc_ids: ids to delete from the Document Store
"""
mg_filter = {"_id": {"$in": doc_ids}}
- self._doc_collection.delete_many(mg_filter)
+ self._collection.delete_many(mg_filter)
def _get_items(
self, doc_ids: Sequence[str]
@@ -258,44 +434,149 @@ def _get_items(
:return: Sequence of Documents, sorted corresponding to the order of `doc_ids`. Duplicate `doc_ids` can be omitted in the output.
"""
mg_filter = {"_id": {"$in": doc_ids}}
- docs = self._doc_collection.find(mg_filter)
+ docs = self._collection.find(mg_filter)
docs, _ = self._mongo_to_docs(docs)
if not docs:
raise KeyError(f'No document with id {doc_ids} found')
return docs
- def _vector_stage_search(
+ def _reciprocal_rank_stage(self, search_field: str, score_field: str):
+ penalty = self._column_infos[search_field].config["penalty"]
+ projection_fields = {
+ key: f"$docs.{key}" for key in self._column_infos.keys() if key != "id"
+ }
+ projection_fields["_id"] = "$docs._id"
+ projection_fields[score_field] = 1
+
+ return [
+ {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
+ {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
+ {
+ "$addFields": {
+ score_field: {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]}
+ }
+ },
+ {'$project': projection_fields},
+ ]
+
+ def _add_stage_to_pipeline(self, pipeline: List[Any], stage: Dict[str, Any]):
+ if pipeline:
+ pipeline.append(
+ {"$unionWith": {"coll": self.index_name, "pipeline": stage}}
+ )
+ else:
+ pipeline.extend(stage)
+ return pipeline
+
+ def _final_stage(self, scores_fields, limit):
+ """Sum individual scores, sort, and apply limit."""
+ doc_fields = self._column_infos.keys()
+ grouped_fields = {
+ key: {"$first": f"${key}"} for key in doc_fields if key != "_id"
+ }
+ best_score = {score: {'$max': f'${score}'} for score in scores_fields}
+ final_pipeline = [
+ {"$group": {"_id": "$_id", **grouped_fields, **best_score}},
+ {
+ "$project": {
+ **{doc_field: 1 for doc_field in doc_fields},
+ **{score: {"$ifNull": [f"${score}", 0]} for score in scores_fields},
+ }
+ },
+ {
+ "$addFields": {
+ "score": {"$add": [f"${score}" for score in scores_fields]},
+ }
+ },
+ {"$sort": {"score": -1}},
+ {"$limit": limit},
+ ]
+ return final_pipeline
+
+ @staticmethod
+ def _score_field(search_field: str, search_field_counts: Dict[str, int]):
+ score_field = f"{search_field}_score"
+ count = search_field_counts[search_field]
+ if count > 1:
+ score_field += str(count)
+ return score_field
+
+ def _hybrid_search(
+ self,
+ vector_queries: Dict[str, Any],
+ text_queries: List[Dict[str, Any]],
+ filters: Dict[str, Any],
+ limit: int,
+ ):
+ hybrid_pipeline = [] # combined aggregate pipeline
+ search_field_counts = collections.defaultdict(
+ int
+ ) # stores count of calls on same search field
+ score_fields = [] # names given to scores of each search stage
+ for search_field, query in vector_queries.items():
+ search_field_counts[search_field] += 1
+ vector_stage = self._vector_search_stage(
+ query=query,
+ search_field=search_field,
+ limit=limit,
+ filters=filters,
+ )
+ score_field = self._score_field(search_field, search_field_counts)
+ score_fields.append(score_field)
+ vector_pipeline = [
+ vector_stage,
+ *self._reciprocal_rank_stage(search_field, score_field),
+ ]
+ self._add_stage_to_pipeline(hybrid_pipeline, vector_pipeline)
+
+ for kwargs in text_queries:
+ search_field_counts[kwargs["search_field"]] += 1
+ text_stage = self._text_search_stage(**kwargs)
+ search_field = kwargs["search_field"]
+ score_field = self._score_field(search_field, search_field_counts)
+ score_fields.append(score_field)
+ reciprocal_rank_stage = self._reciprocal_rank_stage(
+ search_field, score_field
+ )
+ text_pipeline = [
+ text_stage,
+ {"$match": {"$and": filters} if filters else {}},
+ {"$limit": limit},
+ *reciprocal_rank_stage,
+ ]
+ self._add_stage_to_pipeline(hybrid_pipeline, text_pipeline)
+
+ hybrid_pipeline += self._final_stage(score_fields, limit)
+ return hybrid_pipeline
+
+ def _vector_search_stage(
self,
query: np.ndarray,
search_field: str,
limit: int,
- filters: List[Dict[str, Any]] = [],
+ filters: List[Dict[str, Any]] = None,
) -> Dict[str, Any]:
- index_name = self._get_column_db_index(search_field)
+ search_index_name = self._get_column_db_index(search_field)
oversampling_factor = self._get_oversampling_factor(search_field)
max_candidates = self._get_max_candidates(search_field)
query = query.astype(np.float64).tolist()
- return {
+ stage = {
'$vectorSearch': {
- 'index': index_name,
+ 'index': search_index_name,
'path': search_field,
'queryVector': query,
'numCandidates': min(limit * oversampling_factor, max_candidates),
'limit': limit,
- 'filter': {"$and": filters} if filters else None,
}
}
+ if filters:
+ stage['$vectorSearch']['filter'] = {"$and": filters}
+ return stage
- def _filter_query(
- self,
- query: Any,
- ) -> Dict[str, Any]:
- return query
-
- def _text_stage_step(
+ def _text_search_stage(
self,
query: str,
search_field: str,
@@ -316,7 +597,7 @@ def _doc_exists(self, doc_id: str) -> bool:
:param doc_id: The id of a document to check.
:return: True if the document exists in the index, False otherwise.
"""
- doc = self._doc_collection.find_one({"_id": doc_id})
+ doc = self._collection.find_one({"_id": doc_id})
return bool(doc)
def _find(
@@ -330,12 +611,12 @@ def _find(
:param query: query vector for KNN/ANN search. Has single axis.
:param limit: maximum number of documents to return per query
:param search_field: name of the field to search on
- :return: a named NamedTuple containing `documents` and `scores`
+ :return: a named tuple containing `documents` and `scores`
"""
# NOTE: in standard implementations,
# `search_field` is equal to the column name to search on
- vector_search_stage = self._vector_stage_search(query, search_field, limit)
+ vector_search_stage = self._vector_search_stage(query, search_field, limit)
pipeline = [
vector_search_stage,
@@ -346,7 +627,7 @@ def _find(
},
]
- with self._doc_collection.aggregate(pipeline) as cursor:
+ with self._collection.aggregate(pipeline) as cursor:
documents, scores = self._mongo_to_docs(cursor)
return _FindResult(documents=documents, scores=scores)
@@ -360,7 +641,7 @@ def _find_batched(
Has shape (batch_size, vector_dim)
:param limit: maximum number of documents to return
:param search_field: name of the field to search on
- :return: a named NamedTuple containing `documents` and `scores`
+ :return: a named tuple containing `documents` and `scores`
"""
docs, scores = [], []
for query in queries:
@@ -433,7 +714,7 @@ def _filter(
:param limit: maximum number of documents to return
:return: a DocList containing the documents that match the filter query
"""
- with self._doc_collection.find(filter_query, limit=limit) as cursor:
+ with self._collection.find(filter_query, limit=limit) as cursor:
return self._mongo_to_docs(cursor)[0]
def _filter_batched(
@@ -462,9 +743,9 @@ def _text_search(
:param query: The text to search for
:param limit: maximum number of documents to return
:param search_field: name of the field to search on
- :return: a named Tuple containing `documents` and `scores`
+ :return: a named tuple containing `documents` and `scores`
"""
- text_stage = self._text_stage_step(query=query, search_field=search_field)
+ text_stage = self._text_search_stage(query=query, search_field=search_field)
pipeline = [
text_stage,
@@ -476,7 +757,7 @@ def _text_search(
{"$limit": limit},
]
- with self._doc_collection.aggregate(pipeline) as cursor:
+ with self._collection.aggregate(pipeline) as cursor:
documents, scores = self._mongo_to_docs(cursor)
return _FindResult(documents=documents, scores=scores)
@@ -492,7 +773,7 @@ def _text_search_batched(
:param queries: The texts to search for
:param limit: maximum number of documents to return per query
:param search_field: name of the field to search on
- :return: a named Tuple containing `documents` and `scores`
+ :return: a named tuple containing `documents` and `scores`
"""
# NOTE: in standard implementations,
# `search_field` is equal to the column name to search on
@@ -511,7 +792,5 @@ def _filter_by_parent_id(self, id: str) -> Optional[List[str]]:
:param id: the root document id to filter by
:return: a list of ids of the subindex documents
"""
- with self._doc_collection.find(
- {"parent_id": id}, projection={"_id": 1}
- ) as cursor:
+ with self._collection.find({"parent_id": id}, projection={"_id": 1}) as cursor:
return [doc["_id"] for doc in cursor]
diff --git a/docarray/typing/bytes/base_bytes.py b/docarray/typing/bytes/base_bytes.py
index 4c336ae6940..8a944031b4e 100644
--- a/docarray/typing/bytes/base_bytes.py
+++ b/docarray/typing/bytes/base_bytes.py
@@ -62,7 +62,7 @@ def _to_node_protobuf(self: T) -> 'NodeProto':
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: 'GetCoreSchemaHandler'
) -> 'core_schema.CoreSchema':
- return core_schema.general_after_validator_function(
+ return core_schema.with_info_after_validator_function(
cls.validate,
core_schema.bytes_schema(),
)
diff --git a/docarray/typing/id.py b/docarray/typing/id.py
index c06951eaef7..3e3fdd37ae4 100644
--- a/docarray/typing/id.py
+++ b/docarray/typing/id.py
@@ -77,7 +77,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'str') -> T:
def __get_pydantic_core_schema__(
cls, source: Type[Any], handler: 'GetCoreSchemaHandler'
) -> core_schema.CoreSchema:
- return core_schema.general_plain_validator_function(
+ return core_schema.with_info_plain_validator_function(
cls.validate,
)
diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py
index 994fe42cc85..e7e4fbe7056 100644
--- a/docarray/typing/tensor/abstract_tensor.py
+++ b/docarray/typing/tensor/abstract_tensor.py
@@ -395,10 +395,10 @@ def _docarray_to_ndarray(self) -> np.ndarray:
def __get_pydantic_core_schema__(
cls, _source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
- return core_schema.general_plain_validator_function(
+ return core_schema.with_info_plain_validator_function(
cls.validate,
serialization=core_schema.plain_serializer_function_ser_schema(
- function=orjson_dumps,
+ function=lambda x: x._docarray_to_ndarray().tolist(),
return_schema=handler.generate_schema(bytes),
when_used="json-unless-none",
),
diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py
index ddd17915132..b7c5d71f835 100644
--- a/docarray/typing/url/any_url.py
+++ b/docarray/typing/url/any_url.py
@@ -56,7 +56,7 @@ def _docarray_validate(
def __get_pydantic_core_schema__(
cls, source: Type[Any], handler: Optional['GetCoreSchemaHandler'] = None
) -> core_schema.CoreSchema:
- return core_schema.general_after_validator_function(
+ return core_schema.with_info_after_validator_function(
cls._docarray_validate,
core_schema.str_schema(),
)
diff --git a/docarray/utils/_internal/_typing.py b/docarray/utils/_internal/_typing.py
index 83e350a0602..3c2bd89a8e5 100644
--- a/docarray/utils/_internal/_typing.py
+++ b/docarray/utils/_internal/_typing.py
@@ -61,11 +61,15 @@ def safe_issubclass(x: type, a_tuple: type) -> bool:
:return: A boolean value - 'True' if 'x' is a subclass of 'A_tuple', 'False' otherwise.
Note that if the origin of 'x' is a list or tuple, the function immediately returns 'False'.
"""
+ origin = get_origin(x)
+ if origin: # If x is a generic type like DocList[SomeDoc], get its origin
+ x = origin
if (
- (get_origin(x) in (list, tuple, dict, set, Union))
+ (origin in (list, tuple, dict, set, Union))
or is_typevar(x)
or (type(x) == ForwardRef)
or is_typevar(x)
):
return False
- return issubclass(x, a_tuple)
+
+ return isinstance(x, type) and issubclass(x, a_tuple)
diff --git a/docarray/utils/create_dynamic_doc_class.py b/docarray/utils/create_dynamic_doc_class.py
index 744fea58c3e..c82a7c89487 100644
--- a/docarray/utils/create_dynamic_doc_class.py
+++ b/docarray/utils/create_dynamic_doc_class.py
@@ -54,8 +54,9 @@ class MyDoc(BaseDoc):
fields: Dict[str, Any] = {}
import copy
- fields_copy = copy.deepcopy(model.__fields__)
- annotations_copy = copy.deepcopy(model.__annotations__)
+ copy_model = copy.deepcopy(model)
+ fields_copy = copy_model.__fields__
+ annotations_copy = copy_model.__annotations__
for field_name, field in annotations_copy.items():
if field_name not in fields_copy:
continue
@@ -65,7 +66,7 @@ class MyDoc(BaseDoc):
else:
field_info = fields_copy[field_name].field_info
try:
- if safe_issubclass(field, DocList):
+ if safe_issubclass(field, DocList) and not is_pydantic_v2:
t: Any = field.doc_type
t_aux = create_pure_python_type_model(t)
fields[field_name] = (List[t_aux], field_info)
@@ -74,13 +75,14 @@ class MyDoc(BaseDoc):
except TypeError:
fields[field_name] = (field, field_info)
- return create_model(model.__name__, __base__=model, __doc__=model.__doc__, **fields)
+ return create_model(
+ copy_model.__name__, __base__=copy_model, __doc__=copy_model.__doc__, **fields
+ )
def _get_field_annotation_from_schema(
field_schema: Dict[str, Any],
field_name: str,
- root_schema: Dict[str, Any],
cached_models: Dict[str, Any],
is_tensor: bool = False,
num_recursions: int = 0,
@@ -90,7 +92,6 @@ def _get_field_annotation_from_schema(
Private method used to extract the corresponding field type from the schema.
:param field_schema: The schema from which to extract the type
:param field_name: The name of the field to be created
- :param root_schema: The schema of the root object, important to get references
:param cached_models: Parameter used when this method is called recursively to reuse partial nested classes.
:param is_tensor: Boolean used to tell between tensor and list
:param num_recursions: Number of recursions to properly handle nested types (Dict, List, etc ..)
@@ -110,7 +111,7 @@ def _get_field_annotation_from_schema(
ref_name = obj_ref.split('/')[-1]
any_of_types.append(
create_base_doc_from_schema(
- root_schema['definitions'][ref_name],
+ definitions[ref_name],
ref_name,
cached_models=cached_models,
definitions=definitions,
@@ -121,7 +122,6 @@ def _get_field_annotation_from_schema(
_get_field_annotation_from_schema(
any_of_schema,
field_name,
- root_schema=root_schema,
cached_models=cached_models,
is_tensor=tensor_shape is not None,
num_recursions=0,
@@ -160,7 +160,10 @@ def _get_field_annotation_from_schema(
doc_type: Any
if 'additionalProperties' in field_schema: # handle Dictionaries
additional_props = field_schema['additionalProperties']
- if additional_props.get('type') == 'object':
+ if (
+ isinstance(additional_props, dict)
+ and additional_props.get('type') == 'object'
+ ):
doc_type = create_base_doc_from_schema(
additional_props, field_name, cached_models=cached_models
)
@@ -201,7 +204,6 @@ def _get_field_annotation_from_schema(
ret = _get_field_annotation_from_schema(
field_schema=field_schema.get('items', {}),
field_name=field_name,
- root_schema=root_schema,
cached_models=cached_models,
is_tensor=tensor_shape is not None,
num_recursions=num_recursions + 1,
@@ -262,6 +264,24 @@ class MyDoc(BaseDoc):
:param definitions: Parameter used when this method is called recursively to reuse root definitions of other schemas.
:return: A BaseDoc class dynamically created following the `schema`.
"""
+
+ def clean_refs(value):
+ """Recursively remove $ref keys and #/$defs values from a data structure."""
+ if isinstance(value, dict):
+ # Create a new dictionary without $ref keys and without values containing #/$defs
+ cleaned_dict = {}
+ for k, v in value.items():
+ if k == '$ref':
+ continue
+ cleaned_dict[k] = clean_refs(v)
+ return cleaned_dict
+ elif isinstance(value, list):
+ # Process each item in the list
+ return [clean_refs(item) for item in value]
+ else:
+ # Return primitive values as-is
+ return value
+
if not definitions:
definitions = (
schema.get('definitions', {}) if not is_pydantic_v2 else schema.get('$defs')
@@ -275,10 +295,10 @@ class MyDoc(BaseDoc):
for field_name, field_schema in schema.get('properties', {}).items():
if field_name == 'id':
has_id = True
+ # Get the field type
field_type = _get_field_annotation_from_schema(
field_schema=field_schema,
field_name=field_name,
- root_schema=schema,
cached_models=cached_models,
is_tensor=False,
num_recursions=0,
@@ -294,10 +314,22 @@ class MyDoc(BaseDoc):
field_kwargs = {}
field_json_schema_extra = {}
for k, v in field_schema.items():
+ if field_name == 'id':
+ # Skip default_factory for Optional fields and use None
+ field_kwargs['default'] = None
if k in FieldInfo.__slots__:
field_kwargs[k] = v
else:
- field_json_schema_extra[k] = v
+ if k != '$ref':
+ if isinstance(v, dict):
+ cleaned_v = clean_refs(v)
+ if (
+ cleaned_v
+ ): # Only add if there's something left after cleaning
+ field_json_schema_extra[k] = cleaned_v
+ else:
+ field_json_schema_extra[k] = v
+
fields[field_name] = (
field_type,
FieldInfo(
diff --git a/docs/_versions.json b/docs/_versions.json
index b7c4791e91d..f318a2796a0 100644
--- a/docs/_versions.json
+++ b/docs/_versions.json
@@ -1 +1 @@
-[{"version": "v0.40.0"}, {"version": "v0.39.1"}, {"version": "v0.39.0"}, {"version": "v0.38.0"}, {"version": "v0.37.1"}, {"version": "v0.37.0"}, {"version": "v0.36.0"}, {"version": "v0.35.0"}, {"version": "v0.34.0"}, {"version": "v0.33.0"}, {"version": "v0.32.1"}, {"version": "v0.32.0"}, {"version": "v0.31.1"}, {"version": "v0.31.0"}, {"version": "v0.30.0"}, {"version": "v0.21.0"}, {"version": "v0.20.1"}, {"version": "v0.20.0"}, {"version": "v0.19.0"}, {"version": "v0.18.1"}, {"version": "v0.18.0"}, {"version": "v0.17.0"}, {"version": "v0.16.5"}, {"version": "v0.16.4"}, {"version": "v0.16.3"}, {"version": "v0.16.2"}, {"version": "v0.16.1"}, {"version": "v0.16.0"}, {"version": "v0.15.4"}, {"version": "v0.15.3"}, {"version": "v0.15.2"}, {"version": "v0.15.1"}, {"version": "v0.15.0"}, {"version": "v0.14.11"}, {"version": "v0.14.10"}, {"version": "v0.14.9"}, {"version": "v0.14.8"}, {"version": "v0.14.7"}, {"version": "v0.14.6"}, {"version": "v0.14.5"}, {"version": "v0.14.4"}, {"version": "v0.14.3"}, {"version": "v0.14.2"}, {"version": "v0.14.1"}, {"version": "v0.14.0"}, {"version": "v0.13.33"}, {"version": "v0.13.0"}, {"version": "v0.12.9"}, {"version": "v0.12.0"}, {"version": "v0.11.3"}, {"version": "v0.11.2"}, {"version": "v0.11.1"}, {"version": "v0.11.0"}, {"version": "v0.10.5"}, {"version": "v0.10.4"}, {"version": "v0.10.3"}, {"version": "v0.10.2"}, {"version": "v0.10.1"}, {"version": "v0.10.0"}]
\ No newline at end of file
+[{"version": "v0.40.1"}, {"version": "v0.40.0"}, {"version": "v0.39.1"}, {"version": "v0.39.0"}, {"version": "v0.38.0"}, {"version": "v0.37.1"}, {"version": "v0.37.0"}, {"version": "v0.36.0"}, {"version": "v0.35.0"}, {"version": "v0.34.0"}, {"version": "v0.33.0"}, {"version": "v0.32.1"}, {"version": "v0.32.0"}, {"version": "v0.31.1"}, {"version": "v0.31.0"}, {"version": "v0.30.0"}, {"version": "v0.21.0"}, {"version": "v0.20.1"}, {"version": "v0.20.0"}, {"version": "v0.19.0"}, {"version": "v0.18.1"}, {"version": "v0.18.0"}, {"version": "v0.17.0"}, {"version": "v0.16.5"}, {"version": "v0.16.4"}, {"version": "v0.16.3"}, {"version": "v0.16.2"}, {"version": "v0.16.1"}, {"version": "v0.16.0"}, {"version": "v0.15.4"}, {"version": "v0.15.3"}, {"version": "v0.15.2"}, {"version": "v0.15.1"}, {"version": "v0.15.0"}, {"version": "v0.14.11"}, {"version": "v0.14.10"}, {"version": "v0.14.9"}, {"version": "v0.14.8"}, {"version": "v0.14.7"}, {"version": "v0.14.6"}, {"version": "v0.14.5"}, {"version": "v0.14.4"}, {"version": "v0.14.3"}, {"version": "v0.14.2"}, {"version": "v0.14.1"}, {"version": "v0.14.0"}, {"version": "v0.13.33"}, {"version": "v0.13.0"}, {"version": "v0.12.9"}, {"version": "v0.12.0"}, {"version": "v0.11.3"}, {"version": "v0.11.2"}, {"version": "v0.11.1"}, {"version": "v0.11.0"}, {"version": "v0.10.5"}, {"version": "v0.10.4"}, {"version": "v0.10.3"}, {"version": "v0.10.2"}, {"version": "v0.10.1"}, {"version": "v0.10.0"}]
\ No newline at end of file
diff --git a/docs/user_guide/storing/doc_store/store_s3.md b/docs/user_guide/storing/doc_store/store_s3.md
index c4e0878133b..cd26f1a358d 100644
--- a/docs/user_guide/storing/doc_store/store_s3.md
+++ b/docs/user_guide/storing/doc_store/store_s3.md
@@ -12,7 +12,7 @@ When you want to use your [`DocList`][docarray.DocList] in another place, you ca
## Push & pull
To use the store [`DocList`][docarray.DocList] on S3, you need to pass an S3 path to the function starting with `'s3://'`.
-In the following demo, we use `MinIO` as a local S3 service. You could use the following docker-compose file to start the service in a Docker container.
+In the following demo, we use `MinIO` as a local S3 service. You could use the following docker compose file to start the service in a Docker container.
```yaml
version: "3"
@@ -26,7 +26,7 @@ services:
```
Save the above file as `docker-compose.yml` and run the following line in the same folder as the file.
```cmd
-docker-compose up
+docker compose up
```
```python
diff --git a/docs/user_guide/storing/index_elastic.md b/docs/user_guide/storing/index_elastic.md
index f05ef0e5cbc..89a104fefa6 100644
--- a/docs/user_guide/storing/index_elastic.md
+++ b/docs/user_guide/storing/index_elastic.md
@@ -45,13 +45,17 @@ from docarray.index import ElasticDocIndex # or ElasticV7DocIndex
from docarray.typing import NdArray
import numpy as np
+
# Define the document schema.
class MyDoc(BaseDoc):
- title: str
+ title: str
embedding: NdArray[128]
+
# Create dummy documents.
-docs = DocList[MyDoc](MyDoc(title=f'title #{i}', embedding=np.random.rand(128)) for i in range(10))
+docs = DocList[MyDoc](
+ MyDoc(title=f'title #{i}', embedding=np.random.rand(128)) for i in range(10)
+)
# Initialize a new ElasticDocIndex instance and add the documents to the index.
doc_index = ElasticDocIndex[MyDoc](index_name='my_index')
@@ -67,7 +71,7 @@ retrieved_docs = doc_index.find(query, search_field='embedding', limit=10)
## Initialize
-You can use docker-compose to create a local Elasticsearch service with the following `docker-compose.yml`.
+You can use docker compose to create a local Elasticsearch service with the following `docker-compose.yml`.
```yaml
version: "3.3"
@@ -91,7 +95,7 @@ networks:
Run the following command in the folder of the above `docker-compose.yml` to start the service:
```bash
-docker-compose up
+docker compose up
```
### Schema definition
@@ -225,9 +229,7 @@ You can also search for multiple documents at once, in a batch, using the [`find
```python
# create some query Documents
- queries = DocList[SimpleDoc](
- SimpleDoc(tensor=np.random.rand(128)) for i in range(3)
- )
+ queries = DocList[SimpleDoc](SimpleDoc(tensor=np.random.rand(128)) for i in range(3))
# find similar documents
matches, scores = doc_index.find_batched(queries, search_field='tensor', limit=5)
diff --git a/docs/user_guide/storing/index_milvus.md b/docs/user_guide/storing/index_milvus.md
index 4cf9c91c7d5..18431902cec 100644
--- a/docs/user_guide/storing/index_milvus.md
+++ b/docs/user_guide/storing/index_milvus.md
@@ -27,13 +27,17 @@ from docarray.typing import NdArray
from pydantic import Field
import numpy as np
+
# Define the document schema.
class MyDoc(BaseDoc):
- title: str
+ title: str
embedding: NdArray[128] = Field(is_embedding=True)
+
# Create dummy documents.
-docs = DocList[MyDoc](MyDoc(title=f'title #{i}', embedding=np.random.rand(128)) for i in range(10))
+docs = DocList[MyDoc](
+ MyDoc(title=f'title #{i}', embedding=np.random.rand(128)) for i in range(10)
+)
# Initialize a new MilvusDocumentIndex instance and add the documents to the index.
doc_index = MilvusDocumentIndex[MyDoc](index_name='tmp_index_1')
@@ -55,7 +59,7 @@ wget https://github.com/milvus-io/milvus/releases/download/v2.2.11/milvus-standa
And start Milvus by running:
```shell
-sudo docker-compose up -d
+sudo docker compose up -d
```
Learn more on [Milvus documentation](https://milvus.io/docs/install_standalone-docker.md).
@@ -142,10 +146,12 @@ Now that you have a Document Index, you can add data to it, using the [`index()`
import numpy as np
from docarray import DocList
+
class MyDoc(BaseDoc):
- title: str
+ title: str
embedding: NdArray[128] = Field(is_embedding=True)
+
doc_index = MilvusDocumentIndex[MyDoc](index_name='tmp_index_5')
# create some random data
@@ -273,7 +279,9 @@ class Book(BaseDoc):
embedding: NdArray[10] = Field(is_embedding=True)
-books = DocList[Book]([Book(price=i * 10, embedding=np.random.rand(10)) for i in range(10)])
+books = DocList[Book](
+ [Book(price=i * 10, embedding=np.random.rand(10)) for i in range(10)]
+)
book_index = MilvusDocumentIndex[Book](index_name='tmp_index_6')
book_index.index(books)
@@ -312,8 +320,11 @@ class SimpleSchema(BaseDoc):
price: int
embedding: NdArray[128] = Field(is_embedding=True)
+
# Create dummy documents.
-docs = DocList[SimpleSchema](SimpleSchema(price=i, embedding=np.random.rand(128)) for i in range(10))
+docs = DocList[SimpleSchema](
+ SimpleSchema(price=i, embedding=np.random.rand(128)) for i in range(10)
+)
doc_index = MilvusDocumentIndex[SimpleSchema](index_name='tmp_index_7')
doc_index.index(docs)
@@ -407,7 +418,9 @@ You can pass any of the above as keyword arguments to the `__init__()` method or
```python
class SimpleDoc(BaseDoc):
- tensor: NdArray[128] = Field(is_embedding=True, index_type='IVF_FLAT', metric_type='L2')
+ tensor: NdArray[128] = Field(
+ is_embedding=True, index_type='IVF_FLAT', metric_type='L2'
+ )
doc_index = MilvusDocumentIndex[SimpleDoc](index_name='tmp_index_10')
diff --git a/docs/user_guide/storing/index_qdrant.md b/docs/user_guide/storing/index_qdrant.md
index 71770e45982..3d34b472a0c 100644
--- a/docs/user_guide/storing/index_qdrant.md
+++ b/docs/user_guide/storing/index_qdrant.md
@@ -22,13 +22,17 @@ from docarray.index import QdrantDocumentIndex
from docarray.typing import NdArray
import numpy as np
+
# Define the document schema.
class MyDoc(BaseDoc):
- title: str
+ title: str
embedding: NdArray[128]
+
# Create dummy documents.
-docs = DocList[MyDoc](MyDoc(title=f'title #{i}', embedding=np.random.rand(128)) for i in range(10))
+docs = DocList[MyDoc](
+ MyDoc(title=f'title #{i}', embedding=np.random.rand(128)) for i in range(10)
+)
# Initialize a new QdrantDocumentIndex instance and add the documents to the index.
doc_index = QdrantDocumentIndex[MyDoc](host='localhost')
@@ -46,7 +50,7 @@ You can initialize [QdrantDocumentIndex][docarray.index.backends.qdrant.QdrantDo
**Connecting to a local Qdrant instance running as a Docker container**
-You can use docker-compose to create a local Qdrant service with the following `docker-compose.yml`.
+You can use docker compose to create a local Qdrant service with the following `docker-compose.yml`.
```yaml
version: '3.8'
@@ -66,7 +70,7 @@ services:
Run the following command in the folder of the above `docker-compose.yml` to start the service:
```bash
-docker-compose up
+docker compose up
```
Next, you can create a [QdrantDocumentIndex][docarray.index.backends.qdrant.QdrantDocumentIndex] instance using:
@@ -89,7 +93,7 @@ doc_index = QdrantDocumentIndex[MyDoc](qdrant_config)
**Connecting to Qdrant Cloud service**
```python
qdrant_config = QdrantDocumentIndex.DBConfig(
- "https://YOUR-CLUSTER-URL.aws.cloud.qdrant.io",
+ "https://YOUR-CLUSTER-URL.aws.cloud.qdrant.io",
api_key="",
)
doc_index = QdrantDocumentIndex[MyDoc](qdrant_config)
@@ -317,9 +321,7 @@ book_index = QdrantDocumentIndex[Book]()
book_index.index(books)
# filter for books that are cheaper than 29 dollars
-query = rest.Filter(
- must=[rest.FieldCondition(key='price', range=rest.Range(lt=29))]
- )
+query = rest.Filter(must=[rest.FieldCondition(key='price', range=rest.Range(lt=29))])
cheap_books = book_index.filter(filter_query=query)
assert len(cheap_books) == 3
@@ -372,7 +374,9 @@ class SimpleDoc(BaseDoc):
doc_index = QdrantDocumentIndex[SimpleDoc](host='localhost')
index_docs = [
- SimpleDoc(id=f'{i}', tens=np.ones(10) * i, num=int(i / 2), text=f'Lorem ipsum {int(i/2)}')
+ SimpleDoc(
+ id=f'{i}', tens=np.ones(10) * i, num=int(i / 2), text=f'Lorem ipsum {int(i/2)}'
+ )
for i in range(10)
]
doc_index.index(index_docs)
@@ -380,16 +384,16 @@ doc_index.index(index_docs)
find_query = np.ones(10)
text_search_query = 'ipsum 1'
filter_query = rest.Filter(
- must=[
- rest.FieldCondition(
- key='num',
- range=rest.Range(
- gte=1,
- lt=5,
- ),
- )
- ]
- )
+ must=[
+ rest.FieldCondition(
+ key='num',
+ range=rest.Range(
+ gte=1,
+ lt=5,
+ ),
+ )
+ ]
+)
query = (
doc_index.build_query()
@@ -437,6 +441,8 @@ import numpy as np
from docarray import BaseDoc, DocList
from docarray.typing import NdArray
from docarray.index import QdrantDocumentIndex
+
+
class MyDoc(BaseDoc):
text: str
embedding: NdArray[128]
@@ -445,7 +451,12 @@ class MyDoc(BaseDoc):
Now, we can instantiate our Index and add some data:
```python
docs = DocList[MyDoc](
- [MyDoc(embedding=np.random.rand(10), text=f'I am the first version of Document {i}') for i in range(100)]
+ [
+ MyDoc(
+ embedding=np.random.rand(10), text=f'I am the first version of Document {i}'
+ )
+ for i in range(100)
+ ]
)
index = QdrantDocumentIndex[MyDoc]()
index.index(docs)
diff --git a/docs/user_guide/storing/index_weaviate.md b/docs/user_guide/storing/index_weaviate.md
index 029c86de377..d1d86d03f2e 100644
--- a/docs/user_guide/storing/index_weaviate.md
+++ b/docs/user_guide/storing/index_weaviate.md
@@ -27,13 +27,17 @@ from docarray.typing import NdArray
from pydantic import Field
import numpy as np
+
# Define the document schema.
class MyDoc(BaseDoc):
- title: str
+ title: str
embedding: NdArray[128] = Field(is_embedding=True)
+
# Create dummy documents.
-docs = DocList[MyDoc](MyDoc(title=f'title #{i}', embedding=np.random.rand(128)) for i in range(10))
+docs = DocList[MyDoc](
+ MyDoc(title=f'title #{i}', embedding=np.random.rand(128)) for i in range(10)
+)
# Initialize a new WeaviateDocumentIndex instance and add the documents to the index.
doc_index = WeaviateDocumentIndex[MyDoc]()
@@ -59,7 +63,7 @@ There are multiple ways to start a Weaviate instance, depending on your use case
| ----- | ----- | ----- | ----- |
| **Weaviate Cloud Services (WCS)** | Development and production | Limited | **Recommended for most users** |
| **Embedded Weaviate** | Experimentation | Limited | Experimental (as of Apr 2023) |
-| **Docker-Compose** | Development | Yes | **Recommended for development + customizability** |
+| **Docker Compose** | Development | Yes | **Recommended for development + customizability** |
| **Kubernetes** | Production | Yes | |
### Instantiation instructions
@@ -70,7 +74,7 @@ Go to the [WCS console](https://console.weaviate.cloud) and create an instance u
Weaviate instances on WCS come pre-configured, so no further configuration is required.
-**Docker-Compose (self-managed)**
+**Docker Compose (self-managed)**
Get a configuration file (`docker-compose.yaml`). You can build it using [this interface](https://weaviate.io/developers/weaviate/installation/docker-compose), or download it directly with:
@@ -84,12 +88,12 @@ Where `v` is the actual version, such as `v1.18.3`.
curl -o docker-compose.yml "https://configuration.weaviate.io/v2/docker-compose/docker-compose.yml?modules=standalone&runtime=docker-compose&weaviate_version=v1.18.3"
```
-**Start up Weaviate with Docker-Compose**
+**Start up Weaviate with Docker Compose**
Then you can start up Weaviate by running from a shell:
```shell
-docker-compose up -d
+docker compose up -d
```
**Shut down Weaviate**
@@ -97,7 +101,7 @@ docker-compose up -d
Then you can shut down Weaviate by running from a shell:
```shell
-docker-compose down
+docker compose down
```
**Notes**
@@ -107,7 +111,7 @@ Unless data persistence or backups are set up, shutting down the Docker instance
See documentation on [Persistent volume](https://weaviate.io/developers/weaviate/installation/docker-compose#persistent-volume) and [Backups](https://weaviate.io/developers/weaviate/configuration/backups) to prevent this if persistence is desired.
```bash
-docker-compose up -d
+docker compose up -d
```
**Embedded Weaviate (from the application)**
@@ -192,9 +196,7 @@ dbconfig = WeaviateDocumentIndex.DBConfig(
### Create an instance
Let's connect to a local Weaviate service and instantiate a `WeaviateDocumentIndex` instance:
```python
-dbconfig = WeaviateDocumentIndex.DBConfig(
- host="http://localhost:8080"
-)
+dbconfig = WeaviateDocumentIndex.DBConfig(host="http://localhost:8080")
doc_index = WeaviateDocumentIndex[MyDoc](db_config=dbconfig)
```
@@ -378,10 +380,10 @@ the [`find()`][docarray.index.abstract.BaseDocIndex.find] method:
embedding=np.array([1, 2]),
file=np.random.rand(100),
)
-
+
# find similar documents
matches, scores = doc_index.find(query, limit=5)
-
+
print(f"{matches=}")
print(f"{matches.text=}")
print(f"{scores=}")
@@ -428,10 +430,10 @@ You can also search for multiple documents at once, in a batch, using the [`find
)
for i in range(3)
)
-
+
# find similar documents
matches, scores = doc_index.find_batched(queries, limit=5)
-
+
print(f"{matches=}")
print(f"{matches[0].text=}")
print(f"{scores=}")
@@ -481,7 +483,9 @@ class Book(BaseDoc):
embedding: NdArray[10] = Field(is_embedding=True)
-books = DocList[Book]([Book(price=i * 10, embedding=np.random.rand(10)) for i in range(10)])
+books = DocList[Book](
+ [Book(price=i * 10, embedding=np.random.rand(10)) for i in range(10)]
+)
book_index = WeaviateDocumentIndex[Book](index_name='tmp_index')
book_index.index(books)
@@ -602,7 +606,7 @@ del doc_index[ids[1:]] # del by list of ids
**WCS instances come pre-configured**, and as such additional settings are not configurable outside of those chosen at creation, such as whether to enable authentication.
-For other cases, such as **Docker-Compose deployment**, its settings can be modified through the configuration file, such as the `docker-compose.yaml` file.
+For other cases, such as **Docker Compose deployment**, its settings can be modified through the configuration file, such as the `docker-compose.yaml` file.
Some of the more commonly used settings include:
diff --git a/poetry.lock b/poetry.lock
index 9980ec66271..4e185af1575 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "aiofiles"
@@ -284,17 +284,17 @@ tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy
[[package]]
name = "authlib"
-version = "1.2.0"
+version = "1.3.1"
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
optional = true
-python-versions = "*"
+python-versions = ">=3.8"
files = [
- {file = "Authlib-1.2.0-py2.py3-none-any.whl", hash = "sha256:4ddf4fd6cfa75c9a460b361d4bd9dac71ffda0be879dbe4292a02e92349ad55a"},
- {file = "Authlib-1.2.0.tar.gz", hash = "sha256:4fa3e80883a5915ef9f5bc28630564bc4ed5b5af39812a3ff130ec76bd631e9d"},
+ {file = "Authlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:d35800b973099bbadc49b42b256ecb80041ad56b7fe1216a362c7943c088f377"},
+ {file = "authlib-1.3.1.tar.gz", hash = "sha256:7ae843f03c06c5c0debd63c9db91f9fda64fa62a42a77419fa15fbb7e7a58917"},
]
[package.dependencies]
-cryptography = ">=3.2"
+cryptography = "*"
[[package]]
name = "av"
@@ -531,13 +531,13 @@ files = [
[[package]]
name = "certifi"
-version = "2022.9.24"
+version = "2024.7.4"
description = "Python package for providing Mozilla's CA Bundle."
optional = false
python-versions = ">=3.6"
files = [
- {file = "certifi-2022.9.24-py3-none-any.whl", hash = "sha256:90c1a32f1d68f940488354e36370f6cca89f0f106db09518524c88d6ed83f382"},
- {file = "certifi-2022.9.24.tar.gz", hash = "sha256:0d9c601124e5a6ba9712dbc60d9c53c21e34f5f641fe83002317394311bdce14"},
+ {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"},
+ {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"},
]
[[package]]
@@ -3478,47 +3478,47 @@ files = [
[[package]]
name = "pydantic"
-version = "1.10.8"
+version = "1.10.13"
description = "Data validation and settings management using python type hints"
optional = false
python-versions = ">=3.7"
files = [
- {file = "pydantic-1.10.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1243d28e9b05003a89d72e7915fdb26ffd1d39bdd39b00b7dbe4afae4b557f9d"},
- {file = "pydantic-1.10.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c0ab53b609c11dfc0c060d94335993cc2b95b2150e25583bec37a49b2d6c6c3f"},
- {file = "pydantic-1.10.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9613fadad06b4f3bc5db2653ce2f22e0de84a7c6c293909b48f6ed37b83c61f"},
- {file = "pydantic-1.10.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:df7800cb1984d8f6e249351139667a8c50a379009271ee6236138a22a0c0f319"},
- {file = "pydantic-1.10.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:0c6fafa0965b539d7aab0a673a046466d23b86e4b0e8019d25fd53f4df62c277"},
- {file = "pydantic-1.10.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e82d4566fcd527eae8b244fa952d99f2ca3172b7e97add0b43e2d97ee77f81ab"},
- {file = "pydantic-1.10.8-cp310-cp310-win_amd64.whl", hash = "sha256:ab523c31e22943713d80d8d342d23b6f6ac4b792a1e54064a8d0cf78fd64e800"},
- {file = "pydantic-1.10.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:666bdf6066bf6dbc107b30d034615d2627e2121506c555f73f90b54a463d1f33"},
- {file = "pydantic-1.10.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:35db5301b82e8661fa9c505c800d0990bc14e9f36f98932bb1d248c0ac5cada5"},
- {file = "pydantic-1.10.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f90c1e29f447557e9e26afb1c4dbf8768a10cc676e3781b6a577841ade126b85"},
- {file = "pydantic-1.10.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:93e766b4a8226e0708ef243e843105bf124e21331694367f95f4e3b4a92bbb3f"},
- {file = "pydantic-1.10.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:88f195f582851e8db960b4a94c3e3ad25692c1c1539e2552f3df7a9e972ef60e"},
- {file = "pydantic-1.10.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:34d327c81e68a1ecb52fe9c8d50c8a9b3e90d3c8ad991bfc8f953fb477d42fb4"},
- {file = "pydantic-1.10.8-cp311-cp311-win_amd64.whl", hash = "sha256:d532bf00f381bd6bc62cabc7d1372096b75a33bc197a312b03f5838b4fb84edd"},
- {file = "pydantic-1.10.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7d5b8641c24886d764a74ec541d2fc2c7fb19f6da2a4001e6d580ba4a38f7878"},
- {file = "pydantic-1.10.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7b1f6cb446470b7ddf86c2e57cd119a24959af2b01e552f60705910663af09a4"},
- {file = "pydantic-1.10.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c33b60054b2136aef8cf190cd4c52a3daa20b2263917c49adad20eaf381e823b"},
- {file = "pydantic-1.10.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:1952526ba40b220b912cdc43c1c32bcf4a58e3f192fa313ee665916b26befb68"},
- {file = "pydantic-1.10.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:bb14388ec45a7a0dc429e87def6396f9e73c8c77818c927b6a60706603d5f2ea"},
- {file = "pydantic-1.10.8-cp37-cp37m-win_amd64.whl", hash = "sha256:16f8c3e33af1e9bb16c7a91fc7d5fa9fe27298e9f299cff6cb744d89d573d62c"},
- {file = "pydantic-1.10.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1ced8375969673929809d7f36ad322934c35de4af3b5e5b09ec967c21f9f7887"},
- {file = "pydantic-1.10.8-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:93e6bcfccbd831894a6a434b0aeb1947f9e70b7468f274154d03d71fabb1d7c6"},
- {file = "pydantic-1.10.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:191ba419b605f897ede9892f6c56fb182f40a15d309ef0142212200a10af4c18"},
- {file = "pydantic-1.10.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:052d8654cb65174d6f9490cc9b9a200083a82cf5c3c5d3985db765757eb3b375"},
- {file = "pydantic-1.10.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ceb6a23bf1ba4b837d0cfe378329ad3f351b5897c8d4914ce95b85fba96da5a1"},
- {file = "pydantic-1.10.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6f2e754d5566f050954727c77f094e01793bcb5725b663bf628fa6743a5a9108"},
- {file = "pydantic-1.10.8-cp38-cp38-win_amd64.whl", hash = "sha256:6a82d6cda82258efca32b40040228ecf43a548671cb174a1e81477195ed3ed56"},
- {file = "pydantic-1.10.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e59417ba8a17265e632af99cc5f35ec309de5980c440c255ab1ca3ae96a3e0e"},
- {file = "pydantic-1.10.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:84d80219c3f8d4cad44575e18404099c76851bc924ce5ab1c4c8bb5e2a2227d0"},
- {file = "pydantic-1.10.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e4148e635994d57d834be1182a44bdb07dd867fa3c2d1b37002000646cc5459"},
- {file = "pydantic-1.10.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12f7b0bf8553e310e530e9f3a2f5734c68699f42218bf3568ef49cd9b0e44df4"},
- {file = "pydantic-1.10.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:42aa0c4b5c3025483240a25b09f3c09a189481ddda2ea3a831a9d25f444e03c1"},
- {file = "pydantic-1.10.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:17aef11cc1b997f9d574b91909fed40761e13fac438d72b81f902226a69dac01"},
- {file = "pydantic-1.10.8-cp39-cp39-win_amd64.whl", hash = "sha256:66a703d1983c675a6e0fed8953b0971c44dba48a929a2000a493c3772eb61a5a"},
- {file = "pydantic-1.10.8-py3-none-any.whl", hash = "sha256:7456eb22ed9aaa24ff3e7b4757da20d9e5ce2a81018c1b3ebd81a0b88a18f3b2"},
- {file = "pydantic-1.10.8.tar.gz", hash = "sha256:1410275520dfa70effadf4c21811d755e7ef9bb1f1d077a21958153a92c8d9ca"},
+ {file = "pydantic-1.10.13-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:efff03cc7a4f29d9009d1c96ceb1e7a70a65cfe86e89d34e4a5f2ab1e5693737"},
+ {file = "pydantic-1.10.13-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3ecea2b9d80e5333303eeb77e180b90e95eea8f765d08c3d278cd56b00345d01"},
+ {file = "pydantic-1.10.13-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1740068fd8e2ef6eb27a20e5651df000978edce6da6803c2bef0bc74540f9548"},
+ {file = "pydantic-1.10.13-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84bafe2e60b5e78bc64a2941b4c071a4b7404c5c907f5f5a99b0139781e69ed8"},
+ {file = "pydantic-1.10.13-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bc0898c12f8e9c97f6cd44c0ed70d55749eaf783716896960b4ecce2edfd2d69"},
+ {file = "pydantic-1.10.13-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:654db58ae399fe6434e55325a2c3e959836bd17a6f6a0b6ca8107ea0571d2e17"},
+ {file = "pydantic-1.10.13-cp310-cp310-win_amd64.whl", hash = "sha256:75ac15385a3534d887a99c713aa3da88a30fbd6204a5cd0dc4dab3d770b9bd2f"},
+ {file = "pydantic-1.10.13-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c553f6a156deb868ba38a23cf0df886c63492e9257f60a79c0fd8e7173537653"},
+ {file = "pydantic-1.10.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5e08865bc6464df8c7d61439ef4439829e3ab62ab1669cddea8dd00cd74b9ffe"},
+ {file = "pydantic-1.10.13-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e31647d85a2013d926ce60b84f9dd5300d44535a9941fe825dc349ae1f760df9"},
+ {file = "pydantic-1.10.13-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:210ce042e8f6f7c01168b2d84d4c9eb2b009fe7bf572c2266e235edf14bacd80"},
+ {file = "pydantic-1.10.13-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8ae5dd6b721459bfa30805f4c25880e0dd78fc5b5879f9f7a692196ddcb5a580"},
+ {file = "pydantic-1.10.13-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f8e81fc5fb17dae698f52bdd1c4f18b6ca674d7068242b2aff075f588301bbb0"},
+ {file = "pydantic-1.10.13-cp311-cp311-win_amd64.whl", hash = "sha256:61d9dce220447fb74f45e73d7ff3b530e25db30192ad8d425166d43c5deb6df0"},
+ {file = "pydantic-1.10.13-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4b03e42ec20286f052490423682016fd80fda830d8e4119f8ab13ec7464c0132"},
+ {file = "pydantic-1.10.13-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f59ef915cac80275245824e9d771ee939133be38215555e9dc90c6cb148aaeb5"},
+ {file = "pydantic-1.10.13-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a1f9f747851338933942db7af7b6ee8268568ef2ed86c4185c6ef4402e80ba8"},
+ {file = "pydantic-1.10.13-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:97cce3ae7341f7620a0ba5ef6cf043975cd9d2b81f3aa5f4ea37928269bc1b87"},
+ {file = "pydantic-1.10.13-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:854223752ba81e3abf663d685f105c64150873cc6f5d0c01d3e3220bcff7d36f"},
+ {file = "pydantic-1.10.13-cp37-cp37m-win_amd64.whl", hash = "sha256:b97c1fac8c49be29486df85968682b0afa77e1b809aff74b83081cc115e52f33"},
+ {file = "pydantic-1.10.13-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c958d053453a1c4b1c2062b05cd42d9d5c8eb67537b8d5a7e3c3032943ecd261"},
+ {file = "pydantic-1.10.13-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c5370a7edaac06daee3af1c8b1192e305bc102abcbf2a92374b5bc793818599"},
+ {file = "pydantic-1.10.13-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d6f6e7305244bddb4414ba7094ce910560c907bdfa3501e9db1a7fd7eaea127"},
+ {file = "pydantic-1.10.13-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d3a3c792a58e1622667a2837512099eac62490cdfd63bd407993aaf200a4cf1f"},
+ {file = "pydantic-1.10.13-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:c636925f38b8db208e09d344c7aa4f29a86bb9947495dd6b6d376ad10334fb78"},
+ {file = "pydantic-1.10.13-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:678bcf5591b63cc917100dc50ab6caebe597ac67e8c9ccb75e698f66038ea953"},
+ {file = "pydantic-1.10.13-cp38-cp38-win_amd64.whl", hash = "sha256:6cf25c1a65c27923a17b3da28a0bdb99f62ee04230c931d83e888012851f4e7f"},
+ {file = "pydantic-1.10.13-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8ef467901d7a41fa0ca6db9ae3ec0021e3f657ce2c208e98cd511f3161c762c6"},
+ {file = "pydantic-1.10.13-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:968ac42970f57b8344ee08837b62f6ee6f53c33f603547a55571c954a4225691"},
+ {file = "pydantic-1.10.13-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9849f031cf8a2f0a928fe885e5a04b08006d6d41876b8bbd2fc68a18f9f2e3fd"},
+ {file = "pydantic-1.10.13-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56e3ff861c3b9c6857579de282ce8baabf443f42ffba355bf070770ed63e11e1"},
+ {file = "pydantic-1.10.13-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f00790179497767aae6bcdc36355792c79e7bbb20b145ff449700eb076c5f96"},
+ {file = "pydantic-1.10.13-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:75b297827b59bc229cac1a23a2f7a4ac0031068e5be0ce385be1462e7e17a35d"},
+ {file = "pydantic-1.10.13-cp39-cp39-win_amd64.whl", hash = "sha256:e70ca129d2053fb8b728ee7d1af8e553a928d7e301a311094b8a0501adc8763d"},
+ {file = "pydantic-1.10.13-py3-none-any.whl", hash = "sha256:b87326822e71bd5f313e7d3bfdc77ac3247035ac10b0c0618bd99dcf95b1e687"},
+ {file = "pydantic-1.10.13.tar.gz", hash = "sha256:32c8b48dcd3b2ac4e78b0ba4af3a2c2eb6048cb75202f0ea7b34feb740efc340"},
]
[package.dependencies]
@@ -4070,23 +4070,26 @@ py = {version = "*", markers = "implementation_name == \"pypy\""}
[[package]]
name = "qdrant-client"
-version = "1.4.0"
+version = "1.9.0"
description = "Client library for the Qdrant vector search engine"
optional = true
-python-versions = ">=3.7,<3.12"
+python-versions = ">=3.8"
files = [
- {file = "qdrant_client-1.4.0-py3-none-any.whl", hash = "sha256:2f9e563955b5163da98016f2ed38d9aea5058576c7c5844e9aa205d28155f56d"},
- {file = "qdrant_client-1.4.0.tar.gz", hash = "sha256:2e54f5a80eb1e7e67f4603b76365af4817af15fb3d0c0f44de4fd93afbbe5537"},
+ {file = "qdrant_client-1.9.0-py3-none-any.whl", hash = "sha256:ee02893eab1f642481b1ac1e38eb68ec30bab0f673bef7cc05c19fa5d2cbf43e"},
+ {file = "qdrant_client-1.9.0.tar.gz", hash = "sha256:7b1792f616651a6f0a76312f945c13d088e9451726795b82ce0350f7df3b7981"},
]
[package.dependencies]
grpcio = ">=1.41.0"
grpcio-tools = ">=1.41.0"
-httpx = {version = ">=0.14.0", extras = ["http2"]}
-numpy = {version = ">=1.21", markers = "python_version >= \"3.8\""}
+httpx = {version = ">=0.20.0", extras = ["http2"]}
+numpy = {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}
portalocker = ">=2.7.0,<3.0.0"
pydantic = ">=1.10.8"
-urllib3 = ">=1.26.14,<2.0.0"
+urllib3 = ">=1.26.14,<3"
+
+[package.extras]
+fastembed = ["fastembed (==0.2.6)"]
[[package]]
name = "redis"
@@ -4496,19 +4499,18 @@ tornado = ["tornado (>=5)"]
[[package]]
name = "setuptools"
-version = "65.5.1"
+version = "70.0.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "setuptools-65.5.1-py3-none-any.whl", hash = "sha256:d0b9a8433464d5800cbe05094acf5c6d52a91bfac9b52bcfc4d41382be5d5d31"},
- {file = "setuptools-65.5.1.tar.gz", hash = "sha256:e197a19aa8ec9722928f2206f8de752def0e4c9fc6953527360d1c36d94ddb2f"},
+ {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"},
+ {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"},
]
[package.extras]
-docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
-testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
-testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
+testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
[[package]]
name = "shapely"
@@ -4768,22 +4770,22 @@ opt-einsum = ["opt-einsum (>=3.3)"]
[[package]]
name = "tornado"
-version = "6.2"
+version = "6.4.1"
description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
optional = false
-python-versions = ">= 3.7"
+python-versions = ">=3.8"
files = [
- {file = "tornado-6.2-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:20f638fd8cc85f3cbae3c732326e96addff0a15e22d80f049e00121651e82e72"},
- {file = "tornado-6.2-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:87dcafae3e884462f90c90ecc200defe5e580a7fbbb4365eda7c7c1eb809ebc9"},
- {file = "tornado-6.2-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba09ef14ca9893954244fd872798b4ccb2367c165946ce2dd7376aebdde8e3ac"},
- {file = "tornado-6.2-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8150f721c101abdef99073bf66d3903e292d851bee51910839831caba341a75"},
- {file = "tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3a2f5999215a3a06a4fc218026cd84c61b8b2b40ac5296a6db1f1451ef04c1e"},
- {file = "tornado-6.2-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:5f8c52d219d4995388119af7ccaa0bcec289535747620116a58d830e7c25d8a8"},
- {file = "tornado-6.2-cp37-abi3-musllinux_1_1_i686.whl", hash = "sha256:6fdfabffd8dfcb6cf887428849d30cf19a3ea34c2c248461e1f7d718ad30b66b"},
- {file = "tornado-6.2-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:1d54d13ab8414ed44de07efecb97d4ef7c39f7438cf5e976ccd356bebb1b5fca"},
- {file = "tornado-6.2-cp37-abi3-win32.whl", hash = "sha256:5c87076709343557ef8032934ce5f637dbb552efa7b21d08e89ae7619ed0eb23"},
- {file = "tornado-6.2-cp37-abi3-win_amd64.whl", hash = "sha256:e5f923aa6a47e133d1cf87d60700889d7eae68988704e20c75fb2d65677a8e4b"},
- {file = "tornado-6.2.tar.gz", hash = "sha256:9b630419bde84ec666bfd7ea0a4cb2a8a651c2d5cccdbdd1972a0c859dfc3c13"},
+ {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:163b0aafc8e23d8cdc3c9dfb24c5368af84a81e3364745ccb4427669bf84aec8"},
+ {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6d5ce3437e18a2b66fbadb183c1d3364fb03f2be71299e7d10dbeeb69f4b2a14"},
+ {file = "tornado-6.4.1-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e20b9113cd7293f164dc46fffb13535266e713cdb87bd2d15ddb336e96cfc4"},
+ {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ae50a504a740365267b2a8d1a90c9fbc86b780a39170feca9bcc1787ff80842"},
+ {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:613bf4ddf5c7a95509218b149b555621497a6cc0d46ac341b30bd9ec19eac7f3"},
+ {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:25486eb223babe3eed4b8aecbac33b37e3dd6d776bc730ca14e1bf93888b979f"},
+ {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:454db8a7ecfcf2ff6042dde58404164d969b6f5d58b926da15e6b23817950fc4"},
+ {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a02a08cc7a9314b006f653ce40483b9b3c12cda222d6a46d4ac63bb6c9057698"},
+ {file = "tornado-6.4.1-cp38-abi3-win32.whl", hash = "sha256:d9a566c40b89757c9aa8e6f032bcdb8ca8795d7c1a9762910c722b1635c9de4d"},
+ {file = "tornado-6.4.1-cp38-abi3-win_amd64.whl", hash = "sha256:b24b8982ed444378d7f21d563f4180a2de31ced9d8d84443907a0a64da2072e7"},
+ {file = "tornado-6.4.1.tar.gz", hash = "sha256:92d3ab53183d8c50f8204a51e6f91d18a15d5ef261e84d452800d4ff6fc504e9"},
]
[[package]]
@@ -5066,17 +5068,17 @@ dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake
[[package]]
name = "urllib3"
-version = "1.26.14"
+version = "1.26.19"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
files = [
- {file = "urllib3-1.26.14-py2.py3-none-any.whl", hash = "sha256:75edcdc2f7d85b137124a6c3c9fc3933cdeaa12ecb9a6a959f22797a0feca7e1"},
- {file = "urllib3-1.26.14.tar.gz", hash = "sha256:076907bf8fd355cde77728471316625a4d2f7e713c125f51953bb5b3eecf4f72"},
+ {file = "urllib3-1.26.19-py2.py3-none-any.whl", hash = "sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3"},
+ {file = "urllib3-1.26.19.tar.gz", hash = "sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429"},
]
[package.extras]
-brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
+brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
@@ -5559,18 +5561,18 @@ test = ["mypy", "pre-commit", "pytest", "pytest-asyncio", "websockets (>=10.0)"]
[[package]]
name = "zipp"
-version = "3.10.0"
+version = "3.19.1"
description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "zipp-3.10.0-py3-none-any.whl", hash = "sha256:4fcb6f278987a6605757302a6e40e896257570d11c51628968ccb2a47e80c6c1"},
- {file = "zipp-3.10.0.tar.gz", hash = "sha256:7a7262fd930bd3e36c50b9a64897aec3fafff3dfdeec9623ae22b40e93f99bb8"},
+ {file = "zipp-3.19.1-py3-none-any.whl", hash = "sha256:2828e64edb5386ea6a52e7ba7cdb17bb30a73a858f5eb6eb93d8d36f5ea26091"},
+ {file = "zipp-3.19.1.tar.gz", hash = "sha256:35427f6d5594f4acf82d25541438348c26736fa9b3afa2754bcd63cdb99d8e8f"},
]
[package.extras]
-docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)"]
-testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
[extras]
audio = ["pydub"]
diff --git a/pyproject.toml b/pyproject.toml
index 26d1a047666..efbfcb4fbbf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "docarray"
-version = '0.40.0'
+version = '0.41.0'
description='The data structure for multimodal data'
readme = 'README.md'
authors=['DocArray']
@@ -165,5 +165,6 @@ markers = [
"index: marks test using a document index",
"benchmark: marks slow benchmarking tests",
"elasticv8: marks test that run with ElasticSearch v8",
- "jac: need to have access to jac cloud"
+ "jac: need to have access to jac cloud",
+ "atlas: mark tests using MongoDB Atlas",
]
diff --git a/scripts/release.sh b/scripts/release.sh
index 03f492674b5..f63e07282fd 100755
--- a/scripts/release.sh
+++ b/scripts/release.sh
@@ -46,7 +46,7 @@ function clean_build {
function pub_pypi {
clean_build
- poetry config http-basic.pypi $PYPI_USERNAME $PYPI_PASSWORD
+ poetry config http-basic.pypi $TWINE_USERNAME $TWINE_PASSWORD
poetry publish --build
clean_build
}
diff --git a/tests/benchmark_tests/test_map.py b/tests/benchmark_tests/test_map.py
index e5c664a408b..2fc7b09496e 100644
--- a/tests/benchmark_tests/test_map.py
+++ b/tests/benchmark_tests/test_map.py
@@ -29,9 +29,9 @@ def test_map_docs_multiprocessing():
if os.cpu_count() > 1:
def time_multiprocessing(num_workers: int) -> float:
- n_docs = 5
+ n_docs = 10
rng = np.random.RandomState(0)
- matrices = [rng.random(size=(1000, 1000)) for _ in range(n_docs)]
+ matrices = [rng.random(size=(100, 100)) for _ in range(n_docs)]
da = DocList[MyMatrix]([MyMatrix(matrix=m) for m in matrices])
start_time = time()
list(
@@ -65,7 +65,7 @@ def test_map_docs_batched_multiprocessing():
def time_multiprocessing(num_workers: int) -> float:
n_docs = 16
rng = np.random.RandomState(0)
- matrices = [rng.random(size=(1000, 1000)) for _ in range(n_docs)]
+ matrices = [rng.random(size=(100, 100)) for _ in range(n_docs)]
da = DocList[MyMatrix]([MyMatrix(matrix=m) for m in matrices])
start_time = time()
list(
diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py
index faf146df6f1..73379694284 100644
--- a/tests/index/base_classes/test_base_doc_store.py
+++ b/tests/index/base_classes/test_base_doc_store.py
@@ -13,6 +13,7 @@
from docarray.typing import ID, ImageBytes, ImageUrl, NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.misc import torch_imported
+from docarray.utils._internal._typing import safe_issubclass
pytestmark = pytest.mark.index
@@ -54,7 +55,7 @@ class DummyDocIndex(BaseDocIndex):
def __init__(self, db_config=None, **kwargs):
super().__init__(db_config=db_config, **kwargs)
for col_name, col in self._column_infos.items():
- if issubclass(col.docarray_type, AnyDocArray):
+ if safe_issubclass(col.docarray_type, AnyDocArray):
sub_db_config = copy.deepcopy(self._db_config)
self._subindices[col_name] = self.__class__[col.docarray_type.doc_type](
db_config=sub_db_config, subindex=True
@@ -159,7 +160,7 @@ def test_create_columns():
assert index._column_infos['id'].n_dim is None
assert index._column_infos['id'].config['hi'] == 'there'
- assert issubclass(index._column_infos['tens'].docarray_type, AbstractTensor)
+ assert safe_issubclass(index._column_infos['tens'].docarray_type, AbstractTensor)
assert index._column_infos['tens'].db_type == str
assert index._column_infos['tens'].n_dim == 10
assert index._column_infos['tens'].config == {'dim': 1000, 'hi': 'there'}
@@ -173,12 +174,16 @@ def test_create_columns():
assert index._column_infos['id'].n_dim is None
assert index._column_infos['id'].config['hi'] == 'there'
- assert issubclass(index._column_infos['tens_one'].docarray_type, AbstractTensor)
+ assert safe_issubclass(
+ index._column_infos['tens_one'].docarray_type, AbstractTensor
+ )
assert index._column_infos['tens_one'].db_type == str
assert index._column_infos['tens_one'].n_dim is None
assert index._column_infos['tens_one'].config == {'dim': 10, 'hi': 'there'}
- assert issubclass(index._column_infos['tens_two'].docarray_type, AbstractTensor)
+ assert safe_issubclass(
+ index._column_infos['tens_two'].docarray_type, AbstractTensor
+ )
assert index._column_infos['tens_two'].db_type == str
assert index._column_infos['tens_two'].n_dim is None
assert index._column_infos['tens_two'].config == {'dim': 50, 'hi': 'there'}
@@ -192,7 +197,7 @@ def test_create_columns():
assert index._column_infos['id'].n_dim is None
assert index._column_infos['id'].config['hi'] == 'there'
- assert issubclass(index._column_infos['d__tens'].docarray_type, AbstractTensor)
+ assert safe_issubclass(index._column_infos['d__tens'].docarray_type, AbstractTensor)
assert index._column_infos['d__tens'].db_type == str
assert index._column_infos['d__tens'].n_dim == 10
assert index._column_infos['d__tens'].config == {'dim': 1000, 'hi': 'there'}
@@ -206,7 +211,7 @@ def test_create_columns():
'parent_id',
]
- assert issubclass(index._column_infos['d'].docarray_type, AnyDocArray)
+ assert safe_issubclass(index._column_infos['d'].docarray_type, AnyDocArray)
assert index._column_infos['d'].db_type is None
assert index._column_infos['d'].n_dim is None
assert index._column_infos['d'].config == {}
@@ -216,7 +221,7 @@ def test_create_columns():
assert index._subindices['d']._column_infos['id'].n_dim is None
assert index._subindices['d']._column_infos['id'].config['hi'] == 'there'
- assert issubclass(
+ assert safe_issubclass(
index._subindices['d']._column_infos['tens'].docarray_type, AbstractTensor
)
assert index._subindices['d']._column_infos['tens'].db_type == str
@@ -245,7 +250,7 @@ def test_create_columns():
'parent_id',
]
- assert issubclass(
+ assert safe_issubclass(
index._subindices['d_root']._column_infos['d'].docarray_type, AnyDocArray
)
assert index._subindices['d_root']._column_infos['d'].db_type is None
@@ -266,7 +271,7 @@ def test_create_columns():
index._subindices['d_root']._subindices['d']._column_infos['id'].config['hi']
== 'there'
)
- assert issubclass(
+ assert safe_issubclass(
index._subindices['d_root']
._subindices['d']
._column_infos['tens']
@@ -461,11 +466,16 @@ class OtherNestedDoc(NestedDoc):
# SIMPLE
index = DummyDocIndex[SimpleDoc]()
in_list = [SimpleDoc(tens=np.random.random((10,)))]
- assert isinstance(index._validate_docs(in_list), DocList[BaseDoc])
+ assert isinstance(index._validate_docs(in_list), DocList)
+ for d in index._validate_docs(in_list):
+ assert isinstance(d, BaseDoc)
+
in_da = DocList[SimpleDoc](in_list)
assert index._validate_docs(in_da) == in_da
in_other_list = [OtherSimpleDoc(tens=np.random.random((10,)))]
- assert isinstance(index._validate_docs(in_other_list), DocList[BaseDoc])
+ assert isinstance(index._validate_docs(in_other_list), DocList)
+ for d in index._validate_docs(in_other_list):
+ assert isinstance(d, BaseDoc)
in_other_da = DocList[OtherSimpleDoc](in_other_list)
assert index._validate_docs(in_other_da) == in_other_da
@@ -494,7 +504,9 @@ class OtherNestedDoc(NestedDoc):
in_list = [
FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,)))
]
- assert isinstance(index._validate_docs(in_list), DocList[BaseDoc])
+ assert isinstance(index._validate_docs(in_list), DocList)
+ for d in index._validate_docs(in_list):
+ assert isinstance(d, BaseDoc)
in_da = DocList[FlatDoc](
[FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,)))]
)
@@ -502,7 +514,9 @@ class OtherNestedDoc(NestedDoc):
in_other_list = [
OtherFlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,)))
]
- assert isinstance(index._validate_docs(in_other_list), DocList[BaseDoc])
+ assert isinstance(index._validate_docs(in_other_list), DocList)
+ for d in index._validate_docs(in_other_list):
+ assert isinstance(d, BaseDoc)
in_other_da = DocList[OtherFlatDoc](
[
OtherFlatDoc(
@@ -521,11 +535,15 @@ class OtherNestedDoc(NestedDoc):
# NESTED
index = DummyDocIndex[NestedDoc]()
in_list = [NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))]
- assert isinstance(index._validate_docs(in_list), DocList[BaseDoc])
+ assert isinstance(index._validate_docs(in_list), DocList)
+ for d in index._validate_docs(in_list):
+ assert isinstance(d, BaseDoc)
in_da = DocList[NestedDoc]([NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))])
assert index._validate_docs(in_da) == in_da
in_other_list = [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))]
- assert isinstance(index._validate_docs(in_other_list), DocList[BaseDoc])
+ assert isinstance(index._validate_docs(in_other_list), DocList)
+ for d in index._validate_docs(in_other_list):
+ assert isinstance(d, BaseDoc)
in_other_da = DocList[OtherNestedDoc](
[OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))]
)
@@ -552,7 +570,9 @@ class TensorUnionDoc(BaseDoc):
# OPTIONAL
index = DummyDocIndex[SimpleDoc]()
in_list = [OptionalDoc(tens=np.random.random((10,)))]
- assert isinstance(index._validate_docs(in_list), DocList[BaseDoc])
+ assert isinstance(index._validate_docs(in_list), DocList)
+ for d in index._validate_docs(in_list):
+ assert isinstance(d, BaseDoc)
in_da = DocList[OptionalDoc](in_list)
assert index._validate_docs(in_da) == in_da
@@ -562,9 +582,13 @@ class TensorUnionDoc(BaseDoc):
# MIXED UNION
index = DummyDocIndex[SimpleDoc]()
in_list = [MixedUnionDoc(tens=np.random.random((10,)))]
- assert isinstance(index._validate_docs(in_list), DocList[BaseDoc])
+ assert isinstance(index._validate_docs(in_list), DocList)
+ for d in index._validate_docs(in_list):
+ assert isinstance(d, BaseDoc)
in_da = DocList[MixedUnionDoc](in_list)
- assert isinstance(index._validate_docs(in_da), DocList[BaseDoc])
+ assert isinstance(index._validate_docs(in_da), DocList)
+ for d in index._validate_docs(in_da):
+ assert isinstance(d, BaseDoc)
with pytest.raises(ValueError):
index._validate_docs([MixedUnionDoc(tens='hello')])
@@ -572,13 +596,17 @@ class TensorUnionDoc(BaseDoc):
# TENSOR UNION
index = DummyDocIndex[TensorUnionDoc]()
in_list = [SimpleDoc(tens=np.random.random((10,)))]
- assert isinstance(index._validate_docs(in_list), DocList[BaseDoc])
+ assert isinstance(index._validate_docs(in_list), DocList)
+ for d in index._validate_docs(in_list):
+ assert isinstance(d, BaseDoc)
in_da = DocList[SimpleDoc](in_list)
assert index._validate_docs(in_da) == in_da
index = DummyDocIndex[SimpleDoc]()
in_list = [TensorUnionDoc(tens=np.random.random((10,)))]
- assert isinstance(index._validate_docs(in_list), DocList[BaseDoc])
+ assert isinstance(index._validate_docs(in_list), DocList)
+ for d in index._validate_docs(in_list):
+ assert isinstance(d, BaseDoc)
in_da = DocList[TensorUnionDoc](in_list)
assert index._validate_docs(in_da) == in_da
diff --git a/tests/index/elastic/fixture.py b/tests/index/elastic/fixture.py
index d81a91c8931..fddce16d695 100644
--- a/tests/index/elastic/fixture.py
+++ b/tests/index/elastic/fixture.py
@@ -28,32 +28,32 @@
pytestmark = [pytest.mark.slow, pytest.mark.index]
cur_dir = os.path.dirname(os.path.abspath(__file__))
-compose_yml_v7 = os.path.abspath(os.path.join(cur_dir, 'v7/docker-compose.yml'))
-compose_yml_v8 = os.path.abspath(os.path.join(cur_dir, 'v8/docker-compose.yml'))
+compose_yml_v7 = os.path.abspath(os.path.join(cur_dir, "v7/docker-compose.yml"))
+compose_yml_v8 = os.path.abspath(os.path.join(cur_dir, "v8/docker-compose.yml"))
-@pytest.fixture(scope='module', autouse=True)
+@pytest.fixture(scope="module", autouse=True)
def start_storage_v7():
- os.system(f"docker-compose -f {compose_yml_v7} up -d --remove-orphans")
+ os.system(f"docker compose -f {compose_yml_v7} up -d --remove-orphans")
_wait_for_es()
yield
- os.system(f"docker-compose -f {compose_yml_v7} down --remove-orphans")
+ os.system(f"docker compose -f {compose_yml_v7} down --remove-orphans")
-@pytest.fixture(scope='module', autouse=True)
+@pytest.fixture(scope="module", autouse=True)
def start_storage_v8():
- os.system(f"docker-compose -f {compose_yml_v8} up -d --remove-orphans")
+ os.system(f"docker compose -f {compose_yml_v8} up -d --remove-orphans")
_wait_for_es()
yield
- os.system(f"docker-compose -f {compose_yml_v8} down --remove-orphans")
+ os.system(f"docker compose -f {compose_yml_v8} down --remove-orphans")
def _wait_for_es():
from elasticsearch import Elasticsearch
- es = Elasticsearch(hosts='http://localhost:9200/')
+ es = Elasticsearch(hosts="http://localhost:9200/")
while not es.ping():
time.sleep(0.5)
@@ -79,12 +79,12 @@ class MyImageDoc(ImageDoc):
embedding: NdArray = Field(dims=128)
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def ten_simple_docs():
return [SimpleDoc(tens=np.random.randn(10)) for _ in range(10)]
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def ten_flat_docs():
return [
FlatDoc(tens_one=np.random.randn(10), tens_two=np.random.randn(50))
@@ -92,12 +92,12 @@ def ten_flat_docs():
]
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def ten_nested_docs():
return [NestedDoc(d=SimpleDoc(tens=np.random.randn(10))) for _ in range(10)]
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def ten_deep_nested_docs():
return [
DeepNestedDoc(d=NestedDoc(d=SimpleDoc(tens=np.random.randn(10))))
@@ -105,6 +105,6 @@ def ten_deep_nested_docs():
]
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def tmp_index_name():
return uuid.uuid4().hex
diff --git a/tests/index/mongo_atlas/__init__.py b/tests/index/mongo_atlas/__init__.py
index 352060a3056..305bebe1edb 100644
--- a/tests/index/mongo_atlas/__init__.py
+++ b/tests/index/mongo_atlas/__init__.py
@@ -26,21 +26,20 @@ class NestedDoc(BaseDoc):
class FlatSchema(BaseDoc):
embedding1: NdArray = Field(dim=N_DIM, index_name="vector_index_1")
- # the dim and N_DIM are setted different on propouse. to check the correct handling of n_dim
- embedding2: NdArray[50] = Field(dim=N_DIM, index_name="vector_index_2")
+ embedding2: NdArray = Field(dim=N_DIM, index_name="vector_index_2")
-def assert_when_ready(callable: Callable, tries: int = 5, interval: float = 2):
+def assert_when_ready(callable: Callable, tries: int = 10, interval: float = 2):
"""
Retry callable to account for time taken to change data on the cluster
"""
while True:
try:
callable()
- except AssertionError:
+ except AssertionError as e:
tries -= 1
if tries == 0:
- raise
+ raise RuntimeError("Retries exhausted.") from e
time.sleep(interval)
else:
return
diff --git a/tests/index/mongo_atlas/conftest.py b/tests/index/mongo_atlas/conftest.py
index 727fabb1f5d..beb1276eed6 100644
--- a/tests/index/mongo_atlas/conftest.py
+++ b/tests/index/mongo_atlas/conftest.py
@@ -1,3 +1,4 @@
+import logging
import os
import numpy as np
@@ -19,7 +20,9 @@ def mongodb_index_config():
@pytest.fixture
def simple_index(mongodb_index_config):
- index = MongoDBAtlasDocumentIndex[SimpleSchema](**mongodb_index_config)
+ index = MongoDBAtlasDocumentIndex[SimpleSchema](
+ index_name="bespoke_name", **mongodb_index_config
+ )
return index
@@ -30,8 +33,20 @@ def nested_index(mongodb_index_config):
@pytest.fixture(scope='module')
-def random_simple_documents():
- N_DIM = 10
+def n_dim():
+ return 10
+
+
+@pytest.fixture(scope='module')
+def embeddings(n_dim):
+ """A consistent, reasonable, mock of vector embeddings, in [-1, 1]."""
+ x = np.linspace(-np.pi, np.pi, n_dim)
+ y = np.arange(n_dim)
+ return np.sin(x[np.newaxis, :] + y[:, np.newaxis])
+
+
+@pytest.fixture(scope='module')
+def random_simple_documents(n_dim, embeddings):
docs_text = [
"Text processing with Python is a valuable skill for data analysis.",
"Gardening tips for a beautiful backyard oasis.",
@@ -45,37 +60,36 @@ def random_simple_documents():
"eleifend eros non, accumsan lectus. Curabitur porta auctor tellus at pharetra. Phasellus ut condimentum",
]
return [
- SimpleSchema(embedding=np.random.rand(N_DIM), number=i, text=docs_text[i])
- for i in range(10)
+ SimpleSchema(embedding=embeddings[i], number=i, text=docs_text[i])
+ for i in range(len(docs_text))
]
@pytest.fixture
-def nested_documents():
- N_DIM = 10
+def nested_documents(n_dim):
docs = [
NestedDoc(
- d=SimpleDoc(embedding=np.random.rand(N_DIM)),
- embedding=np.random.rand(N_DIM),
+ d=SimpleDoc(embedding=np.random.rand(n_dim)),
+ embedding=np.random.rand(n_dim),
)
for _ in range(10)
]
docs.append(
NestedDoc(
- d=SimpleDoc(embedding=np.zeros(N_DIM)),
- embedding=np.ones(N_DIM),
+ d=SimpleDoc(embedding=np.zeros(n_dim)),
+ embedding=np.ones(n_dim),
)
)
docs.append(
NestedDoc(
- d=SimpleDoc(embedding=np.ones(N_DIM)),
- embedding=np.zeros(N_DIM),
+ d=SimpleDoc(embedding=np.ones(n_dim)),
+ embedding=np.zeros(n_dim),
)
)
docs.append(
NestedDoc(
- d=SimpleDoc(embedding=np.zeros(N_DIM)),
- embedding=np.ones(N_DIM),
+ d=SimpleDoc(embedding=np.zeros(n_dim)),
+ embedding=np.ones(n_dim),
)
)
return docs
@@ -86,10 +100,11 @@ def simple_index_with_docs(simple_index, random_simple_documents):
"""
Setup and teardown of simple_index. Accesses the underlying MongoDB collection directly.
"""
- simple_index._doc_collection.delete_many({})
+ simple_index._collection.delete_many({})
+ simple_index._logger.setLevel(logging.DEBUG)
simple_index.index(random_simple_documents)
yield simple_index, random_simple_documents
- simple_index._doc_collection.delete_many({})
+ simple_index._collection.delete_many({})
@pytest.fixture
@@ -97,7 +112,7 @@ def nested_index_with_docs(nested_index, nested_documents):
"""
Setup and teardown of simple_index. Accesses the underlying MongoDB collection directly.
"""
- nested_index._doc_collection.delete_many({})
+ nested_index._collection.delete_many({})
nested_index.index(nested_documents)
yield nested_index, nested_documents
- nested_index._doc_collection.delete_many({})
+ nested_index._collection.delete_many({})
diff --git a/tests/index/mongo_atlas/test_find.py b/tests/index/mongo_atlas/test_find.py
index aadfacb4544..e9968b05dd2 100644
--- a/tests/index/mongo_atlas/test_find.py
+++ b/tests/index/mongo_atlas/test_find.py
@@ -8,13 +8,11 @@
from . import NestedDoc, SimpleDoc, SimpleSchema, assert_when_ready
-N_DIM = 10
-
-def test_find_simple_schema(simple_index_with_docs): # noqa: F811
+def test_find_simple_schema(simple_index_with_docs, n_dim): # noqa: F811
simple_index, random_simple_documents = simple_index_with_docs # noqa: F811
- query = np.ones(N_DIM)
+ query = np.ones(n_dim)
# Insert one doc that identically matches query's embedding
expected_matching_document = SimpleSchema(embedding=query, text="other", number=10)
@@ -29,8 +27,8 @@ def pred():
assert_when_ready(pred)
-def test_find_empty_index(simple_index): # noqa: F811
- query = np.random.rand(N_DIM)
+def test_find_empty_index(simple_index, n_dim): # noqa: F811
+ query = np.random.rand(n_dim)
def pred():
docs, scores = simple_index.find(query, search_field='embedding', limit=5)
@@ -40,10 +38,10 @@ def pred():
assert_when_ready(pred)
-def test_find_limit_larger_than_index(simple_index_with_docs): # noqa: F811
+def test_find_limit_larger_than_index(simple_index_with_docs, n_dim): # noqa: F811
simple_index, random_simple_documents = simple_index_with_docs # noqa: F811
- query = np.ones(N_DIM)
+ query = np.ones(n_dim)
new_doc = SimpleSchema(embedding=query, text="other", number=10)
simple_index.index(new_doc)
@@ -56,29 +54,29 @@ def pred():
assert_when_ready(pred)
-def test_find_flat_schema(mongodb_index_config): # noqa: F811
+def test_find_flat_schema(mongodb_index_config, n_dim): # noqa: F811
class FlatSchema(BaseDoc):
- embedding1: NdArray = Field(dim=N_DIM, index_name="vector_index_1")
- # the dim and N_DIM are setted different on propouse. to check the correct handling of n_dim
- embedding2: NdArray[50] = Field(dim=N_DIM, index_name="vector_index_2")
+ embedding1: NdArray = Field(dim=n_dim, index_name="vector_index_1")
+ # the dim and n_dim are setted different on propouse. to check the correct handling of n_dim
+ embedding2: NdArray[50] = Field(dim=n_dim, index_name="vector_index_2")
index = MongoDBAtlasDocumentIndex[FlatSchema](**mongodb_index_config)
- index._doc_collection.delete_many({})
+ index._collection.delete_many({})
index_docs = [
- FlatSchema(embedding1=np.random.rand(N_DIM), embedding2=np.random.rand(50))
+ FlatSchema(embedding1=np.random.rand(n_dim), embedding2=np.random.rand(50))
for _ in range(10)
]
- index_docs.append(FlatSchema(embedding1=np.zeros(N_DIM), embedding2=np.ones(50)))
- index_docs.append(FlatSchema(embedding1=np.ones(N_DIM), embedding2=np.zeros(50)))
+ index_docs.append(FlatSchema(embedding1=np.zeros(n_dim), embedding2=np.ones(50)))
+ index_docs.append(FlatSchema(embedding1=np.ones(n_dim), embedding2=np.zeros(50)))
index.index(index_docs)
def pred1():
# find on embedding1
- query = np.ones(N_DIM)
+ query = np.ones(n_dim)
docs, scores = index.find(query, search_field='embedding1', limit=5)
assert len(docs) == 5
assert len(scores) == 5
@@ -116,10 +114,10 @@ def pred():
assert_when_ready(pred)
-def test_find_nested_schema(nested_index_with_docs): # noqa: F811
+def test_find_nested_schema(nested_index_with_docs, n_dim): # noqa: F811
db, base_docs = nested_index_with_docs
- query = NestedDoc(d=SimpleDoc(embedding=np.ones(N_DIM)), embedding=np.ones(N_DIM))
+ query = NestedDoc(d=SimpleDoc(embedding=np.ones(n_dim)), embedding=np.ones(n_dim))
# find on root level
def pred():
@@ -137,11 +135,11 @@ def pred():
assert_when_ready(pred)
-def test_find_schema_without_index(mongodb_index_config): # noqa: F811
+def test_find_schema_without_index(mongodb_index_config, n_dim): # noqa: F811
class Schema(BaseDoc):
- vec: NdArray = Field(dim=N_DIM)
+ vec: NdArray = Field(dim=n_dim)
index = MongoDBAtlasDocumentIndex[Schema](**mongodb_index_config)
- query = np.ones(N_DIM)
+ query = np.ones(n_dim)
with pytest.raises(ValueError):
index.find(query, search_field='vec', limit=2)
diff --git a/tests/index/mongo_atlas/test_persist_data.py b/tests/index/mongo_atlas/test_persist_data.py
index 62ff02348d5..d170bfc22a8 100644
--- a/tests/index/mongo_atlas/test_persist_data.py
+++ b/tests/index/mongo_atlas/test_persist_data.py
@@ -5,7 +5,7 @@
def test_persist(mongodb_index_config, random_simple_documents): # noqa: F811
index = MongoDBAtlasDocumentIndex[SimpleSchema](**mongodb_index_config)
- index._doc_collection.delete_many({})
+ index._collection.delete_many({})
def cleaned_database():
assert index.num_docs() == 0
diff --git a/tests/index/mongo_atlas/test_query_builder.py b/tests/index/mongo_atlas/test_query_builder.py
new file mode 100644
index 00000000000..3b103cec3d9
--- /dev/null
+++ b/tests/index/mongo_atlas/test_query_builder.py
@@ -0,0 +1,352 @@
+import numpy as np
+import pytest
+
+from . import assert_when_ready
+
+
+def test_missing_required_var_exceptions(simple_index): # noqa: F811
+ """Ensure that exceptions are raised when required arguments are not provided."""
+
+ with pytest.raises(ValueError):
+ simple_index.build_query().find().build()
+
+ with pytest.raises(ValueError):
+ simple_index.build_query().text_search().build()
+
+ with pytest.raises(ValueError):
+ simple_index.build_query().filter().build()
+
+
+def test_find_uses_provided_vector(simple_index): # noqa: F811
+ query = (
+ simple_index.build_query()
+ .find(query=np.ones(10), search_field='embedding')
+ .build(7)
+ )
+
+ query_vector = query.vector_fields.pop('embedding')
+ assert query.vector_fields == {}
+ assert np.allclose(query_vector, np.ones(10))
+ assert query.filters == []
+ assert query.limit == 7
+
+
+def test_multiple_find_returns_averaged_vector(simple_index, n_dim): # noqa: F811
+ query = (
+ simple_index.build_query() # type: ignore[attr-defined]
+ .find(query=np.ones(n_dim), search_field='embedding')
+ .find(query=np.zeros(n_dim), search_field='embedding')
+ .build(5)
+ )
+
+ assert len(query.vector_fields) == 1
+ query_vector = query.vector_fields.pop('embedding')
+ assert query.vector_fields == {}
+ assert np.allclose(query_vector, np.array([0.5] * n_dim))
+ assert query.filters == []
+ assert query.limit == 5
+
+
+def test_filter_passes_filter(simple_index): # noqa: F811
+ index = simple_index
+
+ filter = {"number": {"$lt": 1}}
+ query = index.build_query().filter(query=filter).build(limit=11) # type: ignore[attr-defined]
+
+ assert query.vector_fields == {}
+ assert query.filters == [{"query": filter}]
+ assert query.limit == 11
+
+
+def test_execute_query_find_filter(simple_index_with_docs, n_dim): # noqa: F811
+ """Tests filters passed to vector search behave as expected"""
+ index, _ = simple_index_with_docs
+
+ find_query = np.ones(n_dim)
+ filter_query1 = {"number": {"$lt": 8}}
+ filter_query2 = {"number": {"$gt": 5}}
+
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ .find(query=find_query, search_field='embedding')
+ .filter(query=filter_query1)
+ .filter(query=filter_query2)
+ .build(limit=5)
+ )
+
+ def trial():
+ res = index.execute_query(query)
+ assert len(res.documents) == 2
+ assert set(res.documents.number) == {6, 7}
+
+ assert_when_ready(trial)
+
+
+def test_execute_only_filter(
+ simple_index_with_docs, # noqa: F811
+):
+ index, _ = simple_index_with_docs
+
+ filter_query1 = {"number": {"$lt": 8}}
+ filter_query2 = {"number": {"$gt": 5}}
+
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ .filter(query=filter_query1)
+ .filter(query=filter_query2)
+ .build(limit=5)
+ )
+
+ def trial():
+ res = index.execute_query(query)
+
+ assert len(res.documents) == 2
+ assert set(res.documents.number) == {6, 7}
+
+ assert_when_ready(trial)
+
+
+def test_execute_text_search_with_filter(
+ simple_index_with_docs, # noqa: F811
+):
+ """Note: Text search returns only matching _, not limit."""
+ index, _ = simple_index_with_docs
+
+ filter_query1 = {"number": {"$eq": 0}}
+
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ .text_search(query="Python is a valuable skill", search_field='text')
+ .filter(query=filter_query1)
+ .build(limit=5)
+ )
+
+ def trial():
+ res = index.execute_query(query)
+
+ assert len(res.documents) == 1
+ assert set(res.documents.number) == {0}
+
+ assert_when_ready(trial)
+
+
+def test_find(
+ simple_index_with_docs,
+ n_dim, # noqa: F811
+):
+ index, _ = simple_index_with_docs
+ limit = 3
+ # Base Case: No filters, single text search, single vector search
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ .find(query=np.ones(n_dim), search_field='embedding')
+ .build(limit=limit)
+ )
+
+ def trial():
+ res = index.execute_query(query)
+ assert len(res.documents) == limit
+ assert res.documents.number == [5, 4, 6]
+
+ assert_when_ready(trial)
+
+
+def test_hybrid_search(simple_index_with_docs, n_dim): # noqa: F811
+ find_query = np.ones(n_dim)
+ index, docs = simple_index_with_docs
+ n_docs = len(docs)
+ limit = n_docs
+
+ # Base Case: No filters, single text search, single vector search
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ .find(query=find_query, search_field='embedding')
+ .text_search(query="Python is a valuable skill", search_field='text')
+ .build(limit=limit)
+ )
+
+ def trial():
+ res = index.execute_query(query)
+ assert len(res.documents) == limit
+ assert set(res.documents.number) == set(range(n_docs))
+
+ assert_when_ready(trial)
+
+ # Now that we've successfully executed a query, we know that the search indexes have been built
+ # We no longer need to sleep and retry. Re-run to keep results
+ res_base = index.execute_query(query)
+
+ # Case 2: Base plus a filter
+ filter_query1 = {"number": {"$gt": 0}}
+
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ .find(query=find_query, search_field='embedding')
+ .text_search(query="Python is a valuable skill", search_field='text')
+ .filter(query=filter_query1)
+ .build(limit=n_docs)
+ )
+
+ res = index.execute_query(query)
+ assert len(res.documents) == 9
+ assert set(res.documents.number) == set(range(1, n_docs))
+
+ # Case 3: Base with, but matching, additional vector search component
+ # As we are using averaging to combine embedding vectors, this is a no-op
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ .find(query=find_query, search_field='embedding')
+ .find(query=find_query, search_field='embedding')
+ .text_search(query="Python is a valuable skill", search_field='text')
+ .build(limit=n_docs)
+ )
+ res3 = index.execute_query(query)
+ assert res3.documents.number == res_base.documents.number
+
+ # Case 4: Base with, but perpendicular, additional vector search component
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ # .find(query=find_query, search_field='embedding')
+ .find(
+ query=np.random.standard_normal(find_query.shape), search_field='embedding'
+ )
+ .text_search(query="Python is a valuable skill", search_field='text')
+ .build(limit=n_docs)
+ )
+ res4 = index.execute_query(query)
+ assert res4.documents.number != res_base.documents.number
+
+ # Case 5: Multiple text searches
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ .find(query=find_query, search_field='embedding')
+ .text_search(query="Python is a valuable skill", search_field='text')
+ .text_search(query="classical music compositions", search_field='text')
+ .build(limit=n_docs)
+ )
+ res5 = index.execute_query(query)
+ assert res5.documents.number[:2] == [0, 3]
+
+ # Case 6: Multiple text search with filters
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ .find(query=find_query, search_field='embedding')
+ .filter(query={"number": {"$gt": 0}})
+ .text_search(query="classical music compositions", search_field='text')
+ .text_search(query="Python is a valuable skill", search_field='text')
+ .build(limit=n_docs)
+ )
+ res6 = index.execute_query(query)
+ assert res6.documents.number[0] == 3
+
+
+def test_hybrid_search_multiple_text(simple_index_with_docs, n_dim): # noqa: F811
+ """Tests disambiguation of scores on multiple text searches on same field."""
+
+ index, _ = simple_index_with_docs
+ limit = 10
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ .text_search(query="classical music compositions", search_field='text')
+ .text_search(query="Python is a valuable skill", search_field='text')
+ .find(query=np.ones(n_dim), search_field='embedding')
+ .build(limit=limit)
+ )
+
+ def trial():
+ res = index.execute_query(query, score_breakdown=True)
+ assert len(res.documents) == limit
+ assert res.documents.number == [0, 3, 5, 4, 6, 9, 7, 1, 2, 8]
+
+ assert_when_ready(trial)
+
+
+def test_hybrid_search_only_text(simple_index_with_docs): # noqa: F811
+ """Query built with two text searches will be a Hybrid Search.
+
+ It will return only two results.
+ In our case, each text matches just one document, hence we will receive two results, each top ranked
+ """
+ index, _ = simple_index_with_docs
+ limit = 10
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ .text_search(query="classical music compositions", search_field='text')
+ .text_search(query="Python is a valuable skill", search_field='text')
+ .build(limit=limit)
+ )
+
+ def trial():
+ res = index.execute_query(query)
+ assert len(res.documents) != limit
+ # Instead, we find the number of documents containing one of these phrases
+ assert len(res.documents) == len(query.text_searches)
+ assert set(res.documents.number) == {0, 3}
+ assert set(res.scores) == {0.5, 0.5}
+
+ assert_when_ready(trial)
+
+
+def test_hybrid_search_only_vector(simple_index_with_docs, n_dim): # noqa: F811
+
+ limit = 3
+ index, _ = simple_index_with_docs
+ query = (
+ index.build_query() # type: ignore[attr-defined]
+ .find(query=np.ones(n_dim), search_field='embedding')
+ .find(query=np.zeros(n_dim), search_field='embedding')
+ .build(limit=limit)
+ )
+
+ def trial():
+ res = index.execute_query(query)
+ assert len(res.documents) == limit
+ assert res.documents.number == [5, 4, 6]
+
+ assert_when_ready(trial)
+
+
+@pytest.mark.skip
+def test_hybrid_search_vectors_with_different_fields(
+ mongodb_index_config,
+): # noqa: F811
+ """Hybrid Search involving queries to two different vector indexes.
+
+ # TODO - To be added in an upcoming release.
+ """
+
+ from docarray.index.backends.mongodb_atlas import MongoDBAtlasDocumentIndex
+ from tests.index.mongo_atlas import FlatSchema
+
+ multi_index = MongoDBAtlasDocumentIndex[FlatSchema](**mongodb_index_config)
+ multi_index._collection.delete_many({})
+
+ n_dim = 25
+ n_docs = 5
+ data = [
+ FlatSchema(
+ embedding1=np.random.standard_normal(n_dim),
+ embedding2=np.random.standard_normal(n_dim),
+ )
+ for _ in range(n_docs)
+ ]
+ multi_index.index(data)
+ yield multi_index
+ multi_index._collection.delete_many({})
+
+ limit = 3
+ query = (
+ multi_index.build_query() # type: ignore[attr-defined]
+ .find(query=np.ones(n_dim), search_field='embedding1')
+ .find(query=np.zeros(n_dim), search_field='embedding2')
+ .build(limit=limit)
+ )
+
+ with pytest.raises(NotImplementedError):
+
+ def trial():
+ res = multi_index.execute_query(query)
+ assert len(res.documents) == limit
+ assert res.documents.number == [5, 4, 6]
+
+ assert_when_ready(trial)
diff --git a/tests/index/mongo_atlas/test_subindex.py b/tests/index/mongo_atlas/test_subindex.py
index 82f8744221e..71e99beca33 100644
--- a/tests/index/mongo_atlas/test_subindex.py
+++ b/tests/index/mongo_atlas/test_subindex.py
@@ -53,7 +53,7 @@ class MyDoc(BaseDoc):
def clean_subindex(index):
for subindex in index._subindices.values():
clean_subindex(subindex)
- index._doc_collection.delete_many({})
+ index._collection.delete_many({})
@pytest.fixture(scope='session')
@@ -262,6 +262,4 @@ def test_subindex_del(index):
def test_subindex_collections(mongodb_index_config): # noqa: F811
doc_index = MongoDBAtlasDocumentIndex[MetaCategoryDoc](**mongodb_index_config)
-
assert doc_index._subindices["paths"].index_name == 'metacategorydoc__paths'
- assert doc_index._subindices["paths"]._collection == 'metacategorydoc__paths'
diff --git a/tests/index/mongo_atlas/test_text_search.py b/tests/index/mongo_atlas/test_text_search.py
index cbc6db80580..c480c218c7f 100644
--- a/tests/index/mongo_atlas/test_text_search.py
+++ b/tests/index/mongo_atlas/test_text_search.py
@@ -9,7 +9,7 @@ def test_text_search(simple_index_with_docs): # noqa: F811
def pred():
docs, scores = simple_index.text_search(
- query=query_string, search_field='text', limit=1
+ query=query_string, search_field='text', limit=10
)
assert len(docs) == 1
assert docs[0].text == expected_text
diff --git a/tests/index/qdrant/fixtures.py b/tests/index/qdrant/fixtures.py
index cf599fe0cd1..ccb725a7744 100644
--- a/tests/index/qdrant/fixtures.py
+++ b/tests/index/qdrant/fixtures.py
@@ -23,19 +23,19 @@
from docarray.index import QdrantDocumentIndex
cur_dir = os.path.dirname(os.path.abspath(__file__))
-qdrant_yml = os.path.abspath(os.path.join(cur_dir, 'docker-compose.yml'))
+qdrant_yml = os.path.abspath(os.path.join(cur_dir, "docker-compose.yml"))
-@pytest.fixture(scope='session', autouse=True)
+@pytest.fixture(scope="session", autouse=True)
def start_storage():
- os.system(f"docker-compose -f {qdrant_yml} up -d --remove-orphans")
+ os.system(f"docker compose -f {qdrant_yml} up -d --remove-orphans")
time.sleep(1)
yield
- os.system(f"docker-compose -f {qdrant_yml} down --remove-orphans")
+ os.system(f"docker compose -f {qdrant_yml} down --remove-orphans")
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def tmp_collection_name():
return uuid.uuid4().hex
@@ -43,7 +43,7 @@ def tmp_collection_name():
@pytest.fixture
def qdrant() -> qdrant_client.QdrantClient:
"""This fixture takes care of removing the collection before each test case"""
- client = qdrant_client.QdrantClient(path='/tmp/qdrant-local')
+ client = qdrant_client.QdrantClient(path="/tmp/qdrant-local")
for collection in client.get_collections().collections:
client.delete_collection(collection.name)
return client
diff --git a/tests/index/weaviate/fixture_weaviate.py b/tests/index/weaviate/fixture_weaviate.py
index 3699673746e..4358f46b5dd 100644
--- a/tests/index/weaviate/fixture_weaviate.py
+++ b/tests/index/weaviate/fixture_weaviate.py
@@ -24,16 +24,16 @@
cur_dir = os.path.dirname(os.path.abspath(__file__))
-weaviate_yml = os.path.abspath(os.path.join(cur_dir, 'docker-compose.yml'))
+weaviate_yml = os.path.abspath(os.path.join(cur_dir, "docker-compose.yml"))
-@pytest.fixture(scope='session', autouse=True)
+@pytest.fixture(scope="session", autouse=True)
def start_storage():
- os.system(f"docker-compose -f {weaviate_yml} up -d --remove-orphans")
+ os.system(f"docker compose -f {weaviate_yml} up -d --remove-orphans")
_wait_for_weaviate()
yield
- os.system(f"docker-compose -f {weaviate_yml} down --remove-orphans")
+ os.system(f"docker compose -f {weaviate_yml} down --remove-orphans")
def _wait_for_weaviate():
diff --git a/tests/integrations/array/test_optional_doc_vec.py b/tests/integrations/array/test_optional_doc_vec.py
index bb793152d3d..dd77c66762b 100644
--- a/tests/integrations/array/test_optional_doc_vec.py
+++ b/tests/integrations/array/test_optional_doc_vec.py
@@ -20,7 +20,8 @@ class Image(BaseDoc):
docs.features = [Features(tensor=np.random.random([100])) for _ in range(10)]
print(docs.features) #
- assert isinstance(docs.features, DocVec[Features])
+ assert isinstance(docs.features, DocVec)
+ assert isinstance(docs.features[0], Features)
docs.features.tensor = np.ones((10, 100))
diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py
index 02967a07cd0..c5ef1868219 100644
--- a/tests/integrations/externals/test_fastapi.py
+++ b/tests/integrations/externals/test_fastapi.py
@@ -1,5 +1,5 @@
-from typing import List
-
+from typing import Any, Dict, List, Optional, Union, ClassVar
+import json
import numpy as np
import pytest
from fastapi import FastAPI
@@ -8,7 +8,9 @@
from docarray import BaseDoc, DocList
from docarray.base_doc import DocArrayResponse
from docarray.documents import ImageDoc, TextDoc
-from docarray.typing import NdArray
+from docarray.typing import NdArray, AnyTensor, ImageUrl
+
+from docarray.utils._internal.pydantic import is_pydantic_v2
@pytest.mark.asyncio
@@ -135,3 +137,256 @@ async def func(fastapi_docs: List[ImageDoc]) -> List[ImageDoc]:
docs = DocList[ImageDoc].from_json(response.content.decode())
assert len(docs) == 2
assert docs[0].tensor.shape == (3, 224, 224)
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+ not is_pydantic_v2, reason='Behavior is only available for Pydantic V2'
+)
+async def test_doclist_directly():
+ from fastapi import Body
+
+ doc = ImageDoc(tensor=np.zeros((3, 224, 224)), url='url')
+ docs = DocList[ImageDoc]([doc, doc])
+
+ app = FastAPI()
+
+ @app.post("/doc/", response_class=DocArrayResponse)
+ async def func_embed_false(
+ fastapi_docs: DocList[ImageDoc] = Body(embed=False),
+ ) -> DocList[ImageDoc]:
+ return fastapi_docs
+
+ @app.post("/doc_default/", response_class=DocArrayResponse)
+ async def func_default(fastapi_docs: DocList[ImageDoc]) -> DocList[ImageDoc]:
+ return fastapi_docs
+
+ @app.post("/doc_embed/", response_class=DocArrayResponse)
+ async def func_embed_true(
+ fastapi_docs: DocList[ImageDoc] = Body(embed=True),
+ ) -> DocList[ImageDoc]:
+ return fastapi_docs
+
+ async with AsyncClient(app=app, base_url="http://test") as ac:
+ response = await ac.post("/doc/", data=docs.to_json())
+ response_default = await ac.post("/doc_default/", data=docs.to_json())
+ embed_content_json = {'fastapi_docs': json.loads(docs.to_json())}
+ response_embed = await ac.post(
+ "/doc_embed/",
+ json=embed_content_json,
+ )
+ resp_doc = await ac.get("/docs")
+ resp_redoc = await ac.get("/redoc")
+
+ assert response.status_code == 200
+ assert response_default.status_code == 200
+ assert response_embed.status_code == 200
+ assert resp_doc.status_code == 200
+ assert resp_redoc.status_code == 200
+
+ docs = DocList[ImageDoc].from_json(response.content.decode())
+ assert len(docs) == 2
+ assert docs[0].tensor.shape == (3, 224, 224)
+
+ docs_default = DocList[ImageDoc].from_json(response_default.content.decode())
+ assert len(docs_default) == 2
+ assert docs_default[0].tensor.shape == (3, 224, 224)
+
+ docs_embed = DocList[ImageDoc].from_json(response_embed.content.decode())
+ assert len(docs_embed) == 2
+ assert docs_embed[0].tensor.shape == (3, 224, 224)
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+ not is_pydantic_v2, reason='Behavior is only available for Pydantic V2'
+)
+async def test_doclist_complex_schema():
+ from fastapi import Body
+
+ class Nested2Doc(BaseDoc):
+ value: str
+ classvar: ClassVar[str] = 'classvar2'
+
+ class Nested1Doc(BaseDoc):
+ nested: Nested2Doc
+ classvar: ClassVar[str] = 'classvar1'
+
+ class CustomDoc(BaseDoc):
+ tensor: Optional[AnyTensor] = None
+ url: ImageUrl
+ num: float = 0.5
+ num_num: List[float] = [1.5, 2.5]
+ lll: List[List[List[int]]] = [[[5]]]
+ fff: List[List[List[float]]] = [[[5.2]]]
+ single_text: TextDoc
+ texts: DocList[TextDoc]
+ d: Dict[str, str] = {'a': 'b'}
+ di: Optional[Dict[str, int]] = None
+ u: Union[str, int]
+ lu: List[Union[str, int]] = [0, 1, 2]
+ tags: Optional[Dict[str, Any]] = None
+ nested: Nested1Doc
+ embedding: NdArray
+ classvar: ClassVar[str] = 'classvar'
+
+ docs = DocList[CustomDoc](
+ [
+ CustomDoc(
+ num=3.5,
+ num_num=[4.5, 5.5],
+ url='photo.jpg',
+ lll=[[[40]]],
+ fff=[[[40.2]]],
+ d={'b': 'a'},
+ texts=DocList[TextDoc]([TextDoc(text='hey ha', embedding=np.zeros(3))]),
+ single_text=TextDoc(text='single hey ha', embedding=np.zeros(2)),
+ u='a',
+ lu=[3, 4],
+ embedding=np.random.random((1, 4)),
+ nested=Nested1Doc(nested=Nested2Doc(value='hello world')),
+ )
+ ]
+ )
+
+ app = FastAPI()
+
+ @app.post("/doc/", response_class=DocArrayResponse)
+ async def func_embed_false(
+ fastapi_docs: DocList[CustomDoc] = Body(embed=False),
+ ) -> DocList[CustomDoc]:
+ for doc in fastapi_docs:
+ doc.tensor = np.zeros((10, 10, 10))
+ doc.di = {'a': 2}
+
+ return fastapi_docs
+
+ @app.post("/doc_default/", response_class=DocArrayResponse)
+ async def func_default(fastapi_docs: DocList[CustomDoc]) -> DocList[CustomDoc]:
+ for doc in fastapi_docs:
+ doc.tensor = np.zeros((10, 10, 10))
+ doc.di = {'a': 2}
+ return fastapi_docs
+
+ @app.post("/doc_embed/", response_class=DocArrayResponse)
+ async def func_embed_true(
+ fastapi_docs: DocList[CustomDoc] = Body(embed=True),
+ ) -> DocList[CustomDoc]:
+ for doc in fastapi_docs:
+ doc.tensor = np.zeros((10, 10, 10))
+ doc.di = {'a': 2}
+ return fastapi_docs
+
+ async with AsyncClient(app=app, base_url="http://test") as ac:
+ response = await ac.post("/doc/", data=docs.to_json())
+ response_default = await ac.post("/doc_default/", data=docs.to_json())
+ embed_content_json = {'fastapi_docs': json.loads(docs.to_json())}
+ response_embed = await ac.post(
+ "/doc_embed/",
+ json=embed_content_json,
+ )
+ resp_doc = await ac.get("/docs")
+ resp_redoc = await ac.get("/redoc")
+
+ assert response.status_code == 200
+ assert response_default.status_code == 200
+ assert response_embed.status_code == 200
+ assert resp_doc.status_code == 200
+ assert resp_redoc.status_code == 200
+
+ resp_json = json.loads(response_default.content.decode())
+ assert isinstance(resp_json[0]["tensor"], list)
+ assert isinstance(resp_json[0]["embedding"], list)
+ assert isinstance(resp_json[0]["texts"][0]["embedding"], list)
+
+ docs_response = DocList[CustomDoc].from_json(response.content.decode())
+ assert len(docs_response) == 1
+ assert docs_response[0].url == 'photo.jpg'
+ assert docs_response[0].num == 3.5
+ assert docs_response[0].num_num == [4.5, 5.5]
+ assert docs_response[0].lll == [[[40]]]
+ assert docs_response[0].lu == [3, 4]
+ assert docs_response[0].fff == [[[40.2]]]
+ assert docs_response[0].di == {'a': 2}
+ assert docs_response[0].d == {'b': 'a'}
+ assert len(docs_response[0].texts) == 1
+ assert docs_response[0].texts[0].text == 'hey ha'
+ assert docs_response[0].texts[0].embedding.shape == (3,)
+ assert docs_response[0].tensor.shape == (10, 10, 10)
+ assert docs_response[0].u == 'a'
+ assert docs_response[0].single_text.text == 'single hey ha'
+ assert docs_response[0].single_text.embedding.shape == (2,)
+
+ docs_default = DocList[CustomDoc].from_json(response_default.content.decode())
+ assert len(docs_default) == 1
+ assert docs_default[0].url == 'photo.jpg'
+ assert docs_default[0].num == 3.5
+ assert docs_default[0].num_num == [4.5, 5.5]
+ assert docs_default[0].lll == [[[40]]]
+ assert docs_default[0].lu == [3, 4]
+ assert docs_default[0].fff == [[[40.2]]]
+ assert docs_default[0].di == {'a': 2}
+ assert docs_default[0].d == {'b': 'a'}
+ assert len(docs_default[0].texts) == 1
+ assert docs_default[0].texts[0].text == 'hey ha'
+ assert docs_default[0].texts[0].embedding.shape == (3,)
+ assert docs_default[0].tensor.shape == (10, 10, 10)
+ assert docs_default[0].u == 'a'
+ assert docs_default[0].single_text.text == 'single hey ha'
+ assert docs_default[0].single_text.embedding.shape == (2,)
+
+ docs_embed = DocList[CustomDoc].from_json(response_embed.content.decode())
+ assert len(docs_embed) == 1
+ assert docs_embed[0].url == 'photo.jpg'
+ assert docs_embed[0].num == 3.5
+ assert docs_embed[0].num_num == [4.5, 5.5]
+ assert docs_embed[0].lll == [[[40]]]
+ assert docs_embed[0].lu == [3, 4]
+ assert docs_embed[0].fff == [[[40.2]]]
+ assert docs_embed[0].di == {'a': 2}
+ assert docs_embed[0].d == {'b': 'a'}
+ assert len(docs_embed[0].texts) == 1
+ assert docs_embed[0].texts[0].text == 'hey ha'
+ assert docs_embed[0].texts[0].embedding.shape == (3,)
+ assert docs_embed[0].tensor.shape == (10, 10, 10)
+ assert docs_embed[0].u == 'a'
+ assert docs_embed[0].single_text.text == 'single hey ha'
+ assert docs_embed[0].single_text.embedding.shape == (2,)
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+ not is_pydantic_v2, reason='Behavior is only available for Pydantic V2'
+)
+async def test_simple_directly():
+ app = FastAPI()
+
+ @app.post("/doc_list/", response_class=DocArrayResponse)
+ async def func_doc_list(fastapi_docs: DocList[TextDoc]) -> DocList[TextDoc]:
+ return fastapi_docs
+
+ @app.post("/doc_single/", response_class=DocArrayResponse)
+ async def func_doc_single(fastapi_doc: TextDoc) -> TextDoc:
+ return fastapi_doc
+
+ async with AsyncClient(app=app, base_url="http://test") as ac:
+ response_doc_list = await ac.post(
+ "/doc_list/", data=json.dumps([{"text": "text"}])
+ )
+ response_single = await ac.post(
+ "/doc_single/", data=json.dumps({"text": "text"})
+ )
+ resp_doc = await ac.get("/docs")
+ resp_redoc = await ac.get("/redoc")
+
+ assert response_doc_list.status_code == 200
+ assert response_single.status_code == 200
+ assert resp_doc.status_code == 200
+ assert resp_redoc.status_code == 200
+
+ docs = DocList[TextDoc].from_json(response_doc_list.content.decode())
+ assert len(docs) == 1
+ assert docs[0].text == 'text'
+
+ doc = TextDoc.from_json(response_single.content.decode())
+ assert doc == 'text'
diff --git a/tests/integrations/store/test_file.py b/tests/integrations/store/test_file.py
index 4cc3a9108cb..e51a61e1407 100644
--- a/tests/integrations/store/test_file.py
+++ b/tests/integrations/store/test_file.py
@@ -181,6 +181,7 @@ def test_list_and_delete(tmp_path: Path):
), 'Deleting a non-existent DA should return False'
+@pytest.mark.skip(reason='Skip it!')
def test_concurrent_push_pull(tmp_path: Path):
# Push to DA that is being pulled should not mess up the pull
namespace_dir = tmp_path
@@ -212,6 +213,7 @@ def _task(choice: str):
p.map(_task, ['pull', 'push', 'pull'])
+@pytest.mark.skip(reason='Skip it!')
@pytest.mark.slow
def test_concurrent_push(tmp_path: Path):
# Double push should fail the second push
diff --git a/tests/integrations/store/test_s3.py b/tests/integrations/store/test_s3.py
index 22105a0ce43..62e0126ea39 100644
--- a/tests/integrations/store/test_s3.py
+++ b/tests/integrations/store/test_s3.py
@@ -12,7 +12,7 @@
DA_LEN: int = 2**10
TOLERANCE_RATIO = 0.5 # Percentage of difference allowed in stream vs non-stream test
-BUCKET: str = 'da-pushpull'
+BUCKET: str = "da-pushpull"
RANDOM: str = uuid.uuid4().hex[:8]
pytestmark = [pytest.mark.s3]
@@ -22,16 +22,16 @@
def minio_container():
file_dir = os.path.dirname(__file__)
os.system(
- f"docker-compose -f {os.path.join(file_dir, 'docker-compose.yml')} up -d --remove-orphans minio"
+ f"docker compose -f {os.path.join(file_dir, 'docker-compose.yml')} up -d --remove-orphans minio"
)
time.sleep(1)
yield
os.system(
- f"docker-compose -f {os.path.join(file_dir, 'docker-compose.yml')} down --remove-orphans"
+ f"docker compose -f {os.path.join(file_dir, 'docker-compose.yml')} down --remove-orphans"
)
-@pytest.fixture(scope='session', autouse=True)
+@pytest.fixture(scope="session", autouse=True)
def testing_bucket(minio_container):
import boto3
from botocore.client import Config
@@ -59,7 +59,7 @@ def testing_bucket(minio_container):
Config(signature_version="s3v4"),
)
# make a bucket
- s3 = boto3.resource('s3')
+ s3 = boto3.resource("s3")
s3.create_bucket(Bucket=BUCKET)
yield
@@ -67,14 +67,15 @@ def testing_bucket(minio_container):
s3.Bucket(BUCKET).delete()
+@pytest.mark.skip(reason="Skip it!")
@pytest.mark.slow
def test_pushpull_correct(capsys):
- namespace_dir = f'{BUCKET}/test{RANDOM}/pushpull-correct'
+ namespace_dir = f"{BUCKET}/test{RANDOM}/pushpull-correct"
da1 = get_test_da(DA_LEN)
# Verbose
- da1.push(f's3://{namespace_dir}/meow', show_progress=True)
- da2 = DocList[TextDoc].pull(f's3://{namespace_dir}/meow', show_progress=True)
+ da1.push(f"s3://{namespace_dir}/meow", show_progress=True)
+ da2 = DocList[TextDoc].pull(f"s3://{namespace_dir}/meow", show_progress=True)
assert len(da1) == len(da2)
assert all(d1.id == d2.id for d1, d2 in zip(da1, da2))
assert all(d1.text == d2.text for d1, d2 in zip(da1, da2))
@@ -84,8 +85,8 @@ def test_pushpull_correct(capsys):
assert len(captured.err) == 0
# Quiet
- da2.push(f's3://{namespace_dir}/meow')
- da1 = DocList[TextDoc].pull(f's3://{namespace_dir}/meow')
+ da2.push(f"s3://{namespace_dir}/meow")
+ da1 = DocList[TextDoc].pull(f"s3://{namespace_dir}/meow")
assert len(da1) == len(da2)
assert all(d1.id == d2.id for d1, d2 in zip(da1, da2))
assert all(d1.text == d2.text for d1, d2 in zip(da1, da2))
@@ -95,17 +96,18 @@ def test_pushpull_correct(capsys):
assert len(captured.err) == 0
+@pytest.mark.skip(reason="Skip it!")
@pytest.mark.slow
def test_pushpull_stream_correct(capsys):
- namespace_dir = f'{BUCKET}/test{RANDOM}/pushpull-stream-correct'
+ namespace_dir = f"{BUCKET}/test{RANDOM}/pushpull-stream-correct"
da1 = get_test_da(DA_LEN)
# Verbosity and correctness
DocList[TextDoc].push_stream(
- iter(da1), f's3://{namespace_dir}/meow', show_progress=True
+ iter(da1), f"s3://{namespace_dir}/meow", show_progress=True
)
doc_stream2 = DocList[TextDoc].pull_stream(
- f's3://{namespace_dir}/meow', show_progress=True
+ f"s3://{namespace_dir}/meow", show_progress=True
)
assert all(d1.id == d2.id for d1, d2 in zip(da1, doc_stream2))
@@ -118,10 +120,10 @@ def test_pushpull_stream_correct(capsys):
# Quiet and chained
doc_stream = DocList[TextDoc].pull_stream(
- f's3://{namespace_dir}/meow', show_progress=False
+ f"s3://{namespace_dir}/meow", show_progress=False
)
DocList[TextDoc].push_stream(
- doc_stream, f's3://{namespace_dir}/meow2', show_progress=False
+ doc_stream, f"s3://{namespace_dir}/meow2", show_progress=False
)
captured = capsys.readouterr()
@@ -130,17 +132,18 @@ def test_pushpull_stream_correct(capsys):
# for some reason this test is failing with pydantic v2
+@pytest.mark.skip(reason="Skip it!")
@pytest.mark.slow
def test_pull_stream_vs_pull_full():
- namespace_dir = f'{BUCKET}/test{RANDOM}/pull-stream-vs-pull-full'
+ namespace_dir = f"{BUCKET}/test{RANDOM}/pull-stream-vs-pull-full"
DocList[TextDoc].push_stream(
gen_text_docs(DA_LEN * 1),
- f's3://{namespace_dir}/meow-short',
+ f"s3://{namespace_dir}/meow-short",
show_progress=False,
)
DocList[TextDoc].push_stream(
gen_text_docs(DA_LEN * 4),
- f's3://{namespace_dir}/meow-long',
+ f"s3://{namespace_dir}/meow-long",
show_progress=False,
)
@@ -155,104 +158,106 @@ def get_total_full(url: str):
return sum(len(d.text) for d in DocList[TextDoc].pull(url, show_progress=False))
# A warmup is needed to get accurate memory usage comparison
- _ = get_total_stream(f's3://{namespace_dir}/meow-short')
+ _ = get_total_stream(f"s3://{namespace_dir}/meow-short")
short_total_stream, (_, short_stream_peak) = get_total_stream(
- f's3://{namespace_dir}/meow-short'
+ f"s3://{namespace_dir}/meow-short"
)
long_total_stream, (_, long_stream_peak) = get_total_stream(
- f's3://{namespace_dir}/meow-long'
+ f"s3://{namespace_dir}/meow-long"
)
- _ = get_total_full(f's3://{namespace_dir}/meow-short')
+ _ = get_total_full(f"s3://{namespace_dir}/meow-short")
short_total_full, (_, short_full_peak) = get_total_full(
- f's3://{namespace_dir}/meow-short'
+ f"s3://{namespace_dir}/meow-short"
)
long_total_full, (_, long_full_peak) = get_total_full(
- f's3://{namespace_dir}/meow-long'
+ f"s3://{namespace_dir}/meow-long"
)
assert (
short_total_stream == short_total_full
- ), 'Streamed and non-streamed pull should have similar statistics'
+ ), "Streamed and non-streamed pull should have similar statistics"
assert (
long_total_stream == long_total_full
- ), 'Streamed and non-streamed pull should have similar statistics'
+ ), "Streamed and non-streamed pull should have similar statistics"
assert (
abs(long_stream_peak - short_stream_peak) / short_stream_peak < TOLERANCE_RATIO
- ), 'Streamed memory usage should not be dependent on the size of the data'
+ ), "Streamed memory usage should not be dependent on the size of the data"
assert (
abs(long_full_peak - short_full_peak) / short_full_peak > TOLERANCE_RATIO
- ), 'Full pull memory usage should be dependent on the size of the data'
+ ), "Full pull memory usage should be dependent on the size of the data"
+@pytest.mark.skip(reason="Skip it!")
@pytest.mark.slow
def test_list_and_delete():
- namespace_dir = f'{BUCKET}/test{RANDOM}/list-and-delete'
+ namespace_dir = f"{BUCKET}/test{RANDOM}/list-and-delete"
da_names = S3DocStore.list(namespace_dir, show_table=False)
assert len(da_names) == 0
DocList[TextDoc].push_stream(
- gen_text_docs(DA_LEN), f's3://{namespace_dir}/meow', show_progress=False
+ gen_text_docs(DA_LEN), f"s3://{namespace_dir}/meow", show_progress=False
)
- da_names = S3DocStore.list(f'{namespace_dir}', show_table=False)
- assert set(da_names) == {'meow'}
+ da_names = S3DocStore.list(f"{namespace_dir}", show_table=False)
+ assert set(da_names) == {"meow"}
DocList[TextDoc].push_stream(
- gen_text_docs(DA_LEN), f's3://{namespace_dir}/woof', show_progress=False
+ gen_text_docs(DA_LEN), f"s3://{namespace_dir}/woof", show_progress=False
)
- da_names = S3DocStore.list(f'{namespace_dir}', show_table=False)
- assert set(da_names) == {'meow', 'woof'}
+ da_names = S3DocStore.list(f"{namespace_dir}", show_table=False)
+ assert set(da_names) == {"meow", "woof"}
assert S3DocStore.delete(
- f'{namespace_dir}/meow'
- ), 'Deleting an existing DA should return True'
+ f"{namespace_dir}/meow"
+ ), "Deleting an existing DA should return True"
da_names = S3DocStore.list(namespace_dir, show_table=False)
- assert set(da_names) == {'woof'}
+ assert set(da_names) == {"woof"}
with pytest.raises(
ValueError
): # Deleting a non-existent DA without safety should raise an error
- S3DocStore.delete(f'{namespace_dir}/meow', missing_ok=False)
+ S3DocStore.delete(f"{namespace_dir}/meow", missing_ok=False)
assert not S3DocStore.delete(
- f'{namespace_dir}/meow', missing_ok=True
- ), 'Deleting a non-existent DA should return False'
+ f"{namespace_dir}/meow", missing_ok=True
+ ), "Deleting a non-existent DA should return False"
+@pytest.mark.skip(reason="Skip it!")
@pytest.mark.slow
def test_concurrent_push_pull():
# Push to DA that is being pulled should not mess up the pull
- namespace_dir = f'{BUCKET}/test{RANDOM}/concurrent-push-pull'
+ namespace_dir = f"{BUCKET}/test{RANDOM}/concurrent-push-pull"
DocList[TextDoc].push_stream(
gen_text_docs(DA_LEN),
- f's3://{namespace_dir}/da0',
+ f"s3://{namespace_dir}/da0",
show_progress=False,
)
global _task
def _task(choice: str):
- if choice == 'push':
+ if choice == "push":
DocList[TextDoc].push_stream(
gen_text_docs(DA_LEN),
- f's3://{namespace_dir}/da0',
+ f"s3://{namespace_dir}/da0",
show_progress=False,
)
- elif choice == 'pull':
+ elif choice == "pull":
pull_len = sum(
- 1 for _ in DocList[TextDoc].pull_stream(f's3://{namespace_dir}/da0')
+ 1 for _ in DocList[TextDoc].pull_stream(f"s3://{namespace_dir}/da0")
)
assert pull_len == DA_LEN
else:
- raise ValueError(f'Unknown choice {choice}')
+ raise ValueError(f"Unknown choice {choice}")
- with mp.get_context('fork').Pool(3) as p:
- p.map(_task, ['pull', 'push', 'pull'])
+ with mp.get_context("fork").Pool(3) as p:
+ p.map(_task, ["pull", "push", "pull"])
-@pytest.mark.skip(reason='Not Applicable')
+@pytest.mark.skip(reason="Not Applicable")
def test_concurrent_push():
"""
Amazon S3 does not support object locking for concurrent writers.
diff --git a/tests/integrations/torch/data/test_torch_dataset.py b/tests/integrations/torch/data/test_torch_dataset.py
index f358f1c16b8..5d8236a70b3 100644
--- a/tests/integrations/torch/data/test_torch_dataset.py
+++ b/tests/integrations/torch/data/test_torch_dataset.py
@@ -60,7 +60,9 @@ def test_torch_dataset(captions_da: DocList[PairTextImage]):
batch_lens = []
for batch in loader:
- assert isinstance(batch, DocVec[PairTextImage])
+ assert isinstance(batch, DocVec)
+ for d in batch:
+ assert isinstance(d, PairTextImage)
batch_lens.append(len(batch))
assert all(x == BATCH_SIZE for x in batch_lens[:-1])
@@ -140,7 +142,9 @@ def test_torch_dl_multiprocessing(captions_da: DocList[PairTextImage]):
batch_lens = []
for batch in loader:
- assert isinstance(batch, DocVec[PairTextImage])
+ assert isinstance(batch, DocVec)
+ for d in batch:
+ assert isinstance(d, PairTextImage)
batch_lens.append(len(batch))
assert all(x == BATCH_SIZE for x in batch_lens[:-1])
diff --git a/tests/units/array/stack/storage/test_storage.py b/tests/units/array/stack/storage/test_storage.py
index 01c1b68a165..b91585d3737 100644
--- a/tests/units/array/stack/storage/test_storage.py
+++ b/tests/units/array/stack/storage/test_storage.py
@@ -26,8 +26,9 @@ class MyDoc(BaseDoc):
for name in storage.any_columns['name']:
assert name == 'hello'
inner_docs = storage.doc_columns['doc']
- assert isinstance(inner_docs, DocVec[InnerDoc])
+ assert isinstance(inner_docs, DocVec)
for i, doc in enumerate(inner_docs):
+ assert isinstance(doc, InnerDoc)
assert doc.price == i
diff --git a/tests/units/array/stack/test_array_stacked.py b/tests/units/array/stack/test_array_stacked.py
index 2a3790da1d3..b1b385840dd 100644
--- a/tests/units/array/stack/test_array_stacked.py
+++ b/tests/units/array/stack/test_array_stacked.py
@@ -504,7 +504,9 @@ class ImageDoc(BaseDoc):
da = parse_obj_as(DocVec[ImageDoc], batch)
- assert isinstance(da, DocVec[ImageDoc])
+ assert isinstance(da, DocVec)
+ for d in da:
+ assert isinstance(d, ImageDoc)
def test_validation_column_tensor(batch):
@@ -536,14 +538,18 @@ def test_validation_column_doc(batch_nested_doc):
batch, Doc, Inner = batch_nested_doc
batch.inner = DocList[Inner]([Inner(hello='hello') for _ in range(10)])
- assert isinstance(batch.inner, DocVec[Inner])
+ assert isinstance(batch.inner, DocVec)
+ for d in batch.inner:
+ assert isinstance(d, Inner)
def test_validation_list_doc(batch_nested_doc):
batch, Doc, Inner = batch_nested_doc
batch.inner = [Inner(hello='hello') for _ in range(10)]
- assert isinstance(batch.inner, DocVec[Inner])
+ assert isinstance(batch.inner, DocVec)
+ for d in batch.inner:
+ assert isinstance(d, Inner)
def test_validation_col_doc_fail(batch_nested_doc):
diff --git a/tests/units/array/stack/test_proto.py b/tests/units/array/stack/test_proto.py
index 8c559826b80..d46766cde30 100644
--- a/tests/units/array/stack/test_proto.py
+++ b/tests/units/array/stack/test_proto.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
from typing import Dict, Optional, Union
import numpy as np
@@ -245,6 +246,7 @@ class MyDoc(BaseDoc):
assert da_after._storage.any_columns['d'] == [None, None]
+@pytest.mark.skipif('GITHUB_WORKFLOW' in os.environ, reason='Flaky in Github')
@pytest.mark.proto
@pytest.mark.parametrize('tensor_type', [NdArray, TorchTensor])
def test_proto_tensor_type(tensor_type):
diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py
index 1d93fb6b78c..8e51cc1c37e 100644
--- a/tests/units/array/test_array.py
+++ b/tests/units/array/test_array.py
@@ -486,6 +486,8 @@ def test_validate_list_dict():
dict(url=f'http://url.com/foo_{i}.png', tensor=NdArray(i)) for i in [2, 0, 1]
]
+ # docs = DocList[Image]([Image(url=image['url'], tensor=image['tensor']) for image in images])
+
docs = parse_obj_as(DocList[Image], images)
assert docs.url == [
@@ -520,5 +522,3 @@ def test_not_double_subcriptable():
with pytest.raises(TypeError) as excinfo:
da = DocList[TextDoc][TextDoc]
assert da is None
-
- assert 'not subscriptable' in str(excinfo.value)
diff --git a/tests/units/array/test_array_from_to_bytes.py b/tests/units/array/test_array_from_to_bytes.py
index abc31cb4ac7..0ab952ce4a7 100644
--- a/tests/units/array/test_array_from_to_bytes.py
+++ b/tests/units/array/test_array_from_to_bytes.py
@@ -43,11 +43,11 @@ def test_from_to_bytes(protocol, compress, show_progress, array_cls):
@pytest.mark.parametrize(
- 'protocol', ['protobuf'] # ['pickle-array', 'protobuf-array', 'protobuf', 'pickle']
+ 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle']
)
-@pytest.mark.parametrize('compress', ['lz4']) # , 'bz2', 'lzma', 'zlib', 'gzip', None])
-@pytest.mark.parametrize('show_progress', [False]) # [False, True])
-@pytest.mark.parametrize('array_cls', [DocVec]) # [DocList, DocVec])
+@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
+@pytest.mark.parametrize('show_progress', [False, True]) # [False, True])
+@pytest.mark.parametrize('array_cls', [DocList, DocVec])
def test_from_to_base64(protocol, compress, show_progress, array_cls):
da = array_cls[MyDoc](
[
@@ -75,27 +75,35 @@ def test_from_to_base64(protocol, compress, show_progress, array_cls):
# test_from_to_base64('protobuf', 'lz4', False, DocVec)
+class MyTensorTypeDocNdArray(BaseDoc):
+ embedding: NdArray
+ text: str
+ image: ImageDoc
-@pytest.mark.parametrize('tensor_type', [NdArray, TorchTensor])
-@pytest.mark.parametrize('protocol', ['protobuf-array', 'pickle-array'])
-def test_from_to_base64_tensor_type(tensor_type, protocol):
- class MyDoc(BaseDoc):
- embedding: tensor_type
- text: str
- image: ImageDoc
+class MyTensorTypeDocTorchTensor(BaseDoc):
+ embedding: TorchTensor
+ text: str
+ image: ImageDoc
- da = DocVec[MyDoc](
+
+@pytest.mark.parametrize(
+ 'doc_type, tensor_type',
+ [(MyTensorTypeDocNdArray, NdArray), (MyTensorTypeDocTorchTensor, TorchTensor)],
+)
+@pytest.mark.parametrize('protocol', ['protobuf-array', 'pickle-array'])
+def test_from_to_base64_tensor_type(doc_type, tensor_type, protocol):
+ da = DocVec[doc_type](
[
- MyDoc(
+ doc_type(
embedding=[1, 2, 3, 4, 5], text='hello', image=ImageDoc(url='aux.png')
),
- MyDoc(embedding=[5, 4, 3, 2, 1], text='hello world', image=ImageDoc()),
+ doc_type(embedding=[5, 4, 3, 2, 1], text='hello world', image=ImageDoc()),
],
tensor_type=tensor_type,
)
bytes_da = da.to_base64(protocol=protocol)
- da2 = DocVec[MyDoc].from_base64(
+ da2 = DocVec[doc_type].from_base64(
bytes_da, tensor_type=tensor_type, protocol=protocol
)
assert da2.tensor_type == tensor_type
diff --git a/tests/units/array/test_doclist_schema.py b/tests/units/array/test_doclist_schema.py
new file mode 100644
index 00000000000..02a5f562807
--- /dev/null
+++ b/tests/units/array/test_doclist_schema.py
@@ -0,0 +1,22 @@
+import pytest
+from docarray import BaseDoc, DocList
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+
+@pytest.mark.skipif(not is_pydantic_v2, reason='Feature only available for Pydantic V2')
+def test_schema_nested():
+ # check issue https://github.com/docarray/docarray/issues/1521
+
+ class Doc1Test(BaseDoc):
+ aux: str
+
+ class DocDocTest(BaseDoc):
+ docs: DocList[Doc1Test]
+
+ assert 'Doc1Test' in DocDocTest.schema()['$defs']
+ d = DocDocTest(docs=DocList[Doc1Test]([Doc1Test(aux='aux')]))
+
+ assert isinstance(d.docs, DocList)
+ for dd in d.docs:
+ assert isinstance(dd, Doc1Test)
+ assert d.docs.aux == ['aux']
diff --git a/tests/units/document/test_doc_wo_id.py b/tests/units/document/test_doc_wo_id.py
index ffda3ceec4f..4e2a8bba118 100644
--- a/tests/units/document/test_doc_wo_id.py
+++ b/tests/units/document/test_doc_wo_id.py
@@ -23,4 +23,9 @@ class A(BaseDocWithoutId):
cls_doc_list = DocList[A]
- assert isinstance(cls_doc_list, type)
+ da = cls_doc_list([A(text='hey here')])
+
+ assert isinstance(da, DocList)
+ for d in da:
+ assert isinstance(d, A)
+ assert not hasattr(d, 'id')
diff --git a/tests/units/typing/da/test_relations.py b/tests/units/typing/da/test_relations.py
index f583abef2ec..cadac712f5a 100644
--- a/tests/units/typing/da/test_relations.py
+++ b/tests/units/typing/da/test_relations.py
@@ -13,9 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+import pytest
from docarray import BaseDoc, DocList
+from docarray.utils._internal.pydantic import is_pydantic_v2
+@pytest.mark.skipif(
+ is_pydantic_v2,
+ reason="Subscripted generics cannot be used with class and instance checks",
+)
def test_instance_and_equivalence():
class MyDoc(BaseDoc):
text: str
@@ -28,6 +35,10 @@ class MyDoc(BaseDoc):
assert isinstance(docs, DocList[MyDoc])
+@pytest.mark.skipif(
+ is_pydantic_v2,
+ reason="Subscripted generics cannot be used with class and instance checks",
+)
def test_subclassing():
class MyDoc(BaseDoc):
text: str
diff --git a/tests/units/util/test_create_dynamic_code_class.py b/tests/units/util/test_create_dynamic_code_class.py
index eba25911c4f..b7df497816d 100644
--- a/tests/units/util/test_create_dynamic_code_class.py
+++ b/tests/units/util/test_create_dynamic_code_class.py
@@ -45,6 +45,7 @@ class CustomDoc(BaseDoc):
new_custom_doc_model = create_base_doc_from_schema(
CustomDocCopy.schema(), 'CustomDoc', {}
)
+ print(f'new_custom_doc_model {new_custom_doc_model.schema()}')
original_custom_docs = DocList[CustomDoc](
[
@@ -131,6 +132,7 @@ class TextDocWithId(BaseDoc):
new_textdoc_with_id_model = create_base_doc_from_schema(
TextDocWithIdCopy.schema(), 'TextDocWithId', {}
)
+ print(f'new_textdoc_with_id_model {new_textdoc_with_id_model.schema()}')
original_text_doc_with_id = DocList[TextDocWithId](
[TextDocWithId(ia=f'ID {i}') for i in range(10)]
@@ -207,6 +209,7 @@ class CustomDoc(BaseDoc):
new_custom_doc_model = create_base_doc_from_schema(
CustomDocCopy.schema(), 'CustomDoc'
)
+ print(f'new_custom_doc_model {new_custom_doc_model.schema()}')
original_custom_docs = DocList[CustomDoc]()
if transformation == 'proto':
@@ -232,6 +235,7 @@ class TextDocWithId(BaseDoc):
new_textdoc_with_id_model = create_base_doc_from_schema(
TextDocWithIdCopy.schema(), 'TextDocWithId', {}
)
+ print(f'new_textdoc_with_id_model {new_textdoc_with_id_model.schema()}')
original_text_doc_with_id = DocList[TextDocWithId]()
if transformation == 'proto':
@@ -255,6 +259,9 @@ class ResultTestDoc(BaseDoc):
new_result_test_doc_with_id_model = create_base_doc_from_schema(
ResultTestDocCopy.schema(), 'ResultTestDoc', {}
)
+ print(
+ f'new_result_test_doc_with_id_model {new_result_test_doc_with_id_model.schema()}'
+ )
result_test_docs = DocList[ResultTestDoc]()
if transformation == 'proto':
@@ -309,9 +316,10 @@ class SearchResult(BaseDoc):
models_created_by_name = {}
SearchResult_aux = create_pure_python_type_model(SearchResult)
- _ = create_base_doc_from_schema(
+ m = create_base_doc_from_schema(
SearchResult_aux.schema(), 'SearchResult', models_created_by_name
)
+ print(f'm {m.schema()}')
QuoteFile_reconstructed_in_gateway_from_Search_results = models_created_by_name[
'QuoteFile'
]
@@ -323,3 +331,28 @@ class SearchResult(BaseDoc):
QuoteFile_reconstructed_in_gateway_from_Search_results(id='0', texts=textlist)
)
assert reconstructed_in_gateway_from_Search_results.texts[0].text == 'hey'
+
+
+def test_id_optional():
+ from docarray import BaseDoc
+ import json
+
+ class MyTextDoc(BaseDoc):
+ text: str
+ opt: Optional[str] = None
+
+ MyTextDoc_aux = create_pure_python_type_model(MyTextDoc)
+ td = create_base_doc_from_schema(MyTextDoc_aux.schema(), 'MyTextDoc')
+ print(f'{td.schema()}')
+ direct = MyTextDoc.from_json(json.dumps({"text": "text"}))
+ aux = MyTextDoc_aux.from_json(json.dumps({"text": "text"}))
+ indirect = td.from_json(json.dumps({"text": "text"}))
+ assert direct.text == 'text'
+ assert aux.text == 'text'
+ assert indirect.text == 'text'
+ direct = MyTextDoc(text='hey')
+ aux = MyTextDoc_aux(text='hey')
+ indirect = td(text='hey')
+ assert direct.text == 'hey'
+ assert aux.text == 'hey'
+ assert indirect.text == 'hey'
diff --git a/tests/units/util/test_map.py b/tests/units/util/test_map.py
index 3b9f102d928..65dd3c17389 100644
--- a/tests/units/util/test_map.py
+++ b/tests/units/util/test_map.py
@@ -96,4 +96,6 @@ def test_map_docs_batched(n_docs, batch_size, backend):
assert isinstance(it, Generator)
for batch in it:
- assert isinstance(batch, DocList[MyImage])
+ assert isinstance(batch, DocList)
+ for d in batch:
+ assert isinstance(d, MyImage)