diff --git a/.github/ISSUE_TEMPLATE/bug-v1-deprecated.yml b/.github/ISSUE_TEMPLATE/bug-v1-deprecated.yml
new file mode 100644
index 00000000000..5824e100c37
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug-v1-deprecated.yml
@@ -0,0 +1,66 @@
+name: 🐛 DocArray <=0.21 Bug (0.1.0 - 0.20.1) (Deprecated Version)
+description: Report a bug or unexpected behavior in DocArray version prior to v2 (0.21.1)
+labels: [bug V1, unconfirmed]
+
+body:
+ - type: markdown
+ attributes:
+ value: Thank you for contributing to DocArray! 🙌
+
+ - type: markdown
+ attributes:
+ value: "credits: This issue template is heavily inspired by [pydantic template](https://github.com/pydantic/pydantic/tree/main/.github/ISSUE_TEMPLATE)"
+
+
+ - type: checkboxes
+ id: checks
+ attributes:
+ label: Initial Checks
+ description: |
+ Just a few checks to make sure you need to create a bug report.
+ options:
+ - label: I have read and followed [the docs](https://docs.docarray.org/) and still think this is a bug
+ required: true
+
+ - type: textarea
+ id: description
+ attributes:
+ label: Description
+ description: |
+ Please explain what you're seeing and what you would expect to see.
+
+ Please provide as much detail as possible to make understanding and solving your problem as quick as possible. 🙏
+ validations:
+ required: true
+
+ - type: textarea
+ id: example
+ attributes:
+ label: Example Code
+ description: >
+ If applicable, please add a self-contained,
+ [minimal, reproducible, example](https://stackoverflow.com/help/minimal-reproducible-example)
+ demonstrating the bug.
+
+ placeholder: |
+ import docarray
+
+ ...
+ render: Python
+
+ - type: textarea
+ id: version
+ attributes:
+ label: Python, DocArray & OS Version
+ description: |
+ Which version of Python & DocArray are you using, and which Operating System?
+
+ Please run the following command and copy the output below:
+
+ ```bash
+ python -c "import docarray; print(docarray.__version__);"
+ ```
+
+ render: Text
+ validations:
+ required: true
diff --git a/.github/ISSUE_TEMPLATE/bug-v2.yml b/.github/ISSUE_TEMPLATE/bug-v2.yml
new file mode 100644
index 00000000000..bacad9a2e32
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug-v2.yml
@@ -0,0 +1,80 @@
+name: 🐛 DocArray Bug (>=0.30.0 )
+description: Report a bug or unexpected behavior in DocArray (v2) above (0.30.0)
+labels: [bug V2, unconfirmed]
+
+body:
+ - type: markdown
+ attributes:
+ value: Thank you for contributing to DocArray! 🙌
+
+ - type: markdown
+ attributes:
+ value: "credits: This issue template is heavily inspired by [pydantic template](https://github.com/pydantic/pydantic/tree/main/.github/ISSUE_TEMPLATE)"
+
+
+ - type: checkboxes
+ id: checks
+ attributes:
+ label: Initial Checks
+ description: |
+ Just a few checks to make sure you need to create a bug report.
+ options:
+ - label: I have read and followed [the docs](https://docs.docarray.org/) and still think this is a bug
+ required: true
+
+ - type: textarea
+ id: description
+ attributes:
+ label: Description
+ description: |
+ Please explain what you're seeing and what you would expect to see.
+
+ Please provide as much detail as possible to make understanding and solving your problem as quick as possible. 🙏
+ validations:
+ required: true
+
+ - type: textarea
+ id: example
+ attributes:
+ label: Example Code
+ description: >
+ If applicable, please add a self-contained,
+ [minimal, reproducible, example](https://stackoverflow.com/help/minimal-reproducible-example)
+ demonstrating the bug.
+
+ placeholder: |
+ import docarray
+
+ ...
+ render: Python
+
+ - type: textarea
+ id: version
+ attributes:
+ label: Python, DocArray & OS Version
+ description: |
+ Which version of Python & DocArray are you using, and which Operating System?
+
+ Please run the following command and copy the output below:
+
+ ```bash
+ python -c "import docarray; print(docarray.__version__);"
+ ```
+
+ render: Text
+ validations:
+ required: true
+
+ - type: checkboxes
+ id: affected-components
+ attributes:
+ label: Affected Components
+ description: Which of the following parts of docarray does this bug affect?
+ # keep this list in sync with feature_request.yml
+ options:
+ - label: '[Vector Database / Index](https://docs.docarray.org/user_guide/storing/docindex/)'
+ - label: '[Representing](https://docs.docarray.org/user_guide/representing/first_step)'
+ - label: '[Sending](https://docs.docarray.org/user_guide/sending/first_step/)'
+ - label: '[storing](https://docs.docarray.org/user_guide/storing/first_step/)'
+ - label: '[multi modal data type](https://docs.docarray.org/data_types/first_steps/)'
+
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 00000000000..994cba23f53
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,8 @@
+blank_issues_enabled: true
+contact_links:
+ - name: 🤔 Ask a Question
+ url: 'https://github.com/docarray/docarray/discussions/new?category=question'
+ about: Ask a question about how to use docarray using github discussions
+ - name: 🤔 Ask a Question in discord
+ url: 'https://discord.com/invite/WaMp6PVPgR'
+ about: Or in our discord channel
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
new file mode 100644
index 00000000000..0f4f36a9543
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -0,0 +1,54 @@
+name: 🚀 DocArray Feature request
+description: |
+ Suggest a new feature for DocArray
+
+labels: [feature request]
+
+body:
+ - type: markdown
+ attributes:
+ value: Thank you for contributing to docarray! ✊
+
+ - type: markdown
+ attributes:
+ value: "credits: This issue template is heavily inspired by [pydantic template](https://github.com/pydantic/pydantic/tree/main/.github/ISSUE_TEMPLATE)"
+
+ - type: checkboxes
+ id: searched
+ attributes:
+ label: Initial Checks
+ description: |
+ Just a few checks to make sure you need to create a feature request.
+
+ options:
+ - label: I have searched Google & GitHub for similar requests and couldn't find anything
+ required: true
+ - label: I have read and followed [the docs](https://docs.docarray.org) and still think this feature is missing
+ required: true
+
+ - type: textarea
+ id: description
+ attributes:
+ label: Description
+ description: |
+ Please give as much detail as possible about the feature you would like to suggest. 🙏
+
+ You might like to add:
+ * A demo of how code might look when using the feature
+ * Your use case(s) for the feature
+ * Why the feature should be added to DocArray (as opposed to another library or just implemented in your code)
+ validations:
+ required: true
+
+ - type: checkboxes
+ id: affected-components
+ attributes:
+ label: Affected Components
+ description: Which of the following parts of DocArray does this feature affect?
+ # keep this list in sync with bug-v2.yml
+ options:
+ - label: '[Vector Database / Index](https://docs.docarray.org/user_guide/storing/docindex/)'
+ - label: '[Representing](https://docs.docarray.org/user_guide/representing/first_step)'
+ - label: '[Sending](https://docs.docarray.org/user_guide/sending/first_step/)'
+ - label: '[storing](https://docs.docarray.org/user_guide/storing/first_step/)'
+ - label: '[multi modal data type](https://docs.docarray.org/data_types/first_steps/)'
diff --git a/.github/codecov.yml b/.github/codecov.yml
new file mode 100644
index 00000000000..c0e14561232
--- /dev/null
+++ b/.github/codecov.yml
@@ -0,0 +1,17 @@
+codecov:
+ # https://docs.codecov.io/docs/comparing-commits
+ allow_coverage_offsets: true
+coverage:
+ status:
+ project:
+ default:
+ informational: true
+ target: auto # auto compares coverage to the previous base commit
+ flags:
+ - docarray
+ comment:
+ layout: "reach, diff, flags, files"
+ behavior: default
+ require_changes: false # if true: only post the comment if coverage changes
+ branches: # branch names that can post comment
+ - "master"
diff --git a/.github/workflows/add_license.yml b/.github/workflows/add_license.yml
new file mode 100644
index 00000000000..9c63c711a46
--- /dev/null
+++ b/.github/workflows/add_license.yml
@@ -0,0 +1,51 @@
+name: Add License to Python Files
+
+on:
+ push:
+ branches:
+ - main
+
+jobs:
+ add-license:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v2
+
+ - name: Set up Python
+ uses: actions/setup-python@v3
+ with:
+ python-version: "3.10"
+
+ - name: Run add_license.sh and check for changes
+ id: add_license
+ run: |
+ chmod +x scripts/add_license.sh
+ CHANGES=$(git status --porcelain)
+ ./scripts/add_license.sh
+ NEW_CHANGES=$(git status --porcelain)
+ echo "changes=${NEW_CHANGES:+true}" >> "$GITHUB_OUTPUT" # single-line flag ("true" or empty); the set-output command is deprecated
+
+ - name: Commit changes if there are modifications
+ run: |
+ if [[ -n "${{ steps.add_license.outputs.changes }}" ]]; then
+ git config --local user.email "dev-bot@jina.ai"
+ git config --local user.name "Jina Dev Bot"
+ git add .
+ git commit -m "chore: add license to Python files"
+ git push
+ else
+ echo "No changes detected, skipping commit."
+ fi
+ if: steps.add_license.outputs.changes != ''
+
+ - name: Create Pull Request
+ uses: peter-evans/create-pull-request@v3
+ with:
+ title: "Add license to Python files"
+ branch: "add-license"
+ commit-message: "chore: add license to Python files"
+ base: "main"
+ labels: "auto-merge"
+ token: ${{ secrets.JINA_DEV_BOT }}
+ if: steps.add_license.outputs.changes != ''
\ No newline at end of file
diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml
index 76c1bd87c60..e0a14b5252c 100644
--- a/.github/workflows/cd.yml
+++ b/.github/workflows/cd.yml
@@ -21,7 +21,7 @@ jobs:
- name: Pre-release (.devN)
run: |
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- pip install poetry
+ pip install poetry==1.7.1
./scripts/release.sh
env:
PYPI_USERNAME: ${{ secrets.TWINE_USERNAME }}
@@ -35,20 +35,16 @@ jobs:
steps:
- uses: actions/checkout@v3
with:
- fetch-depth: 0
-
- - name: Get changed files
- id: changed-files-specific
- uses: tj-actions/changed-files@v34
- with:
- files: |
- README.md
+ fetch-depth: 2
- name: Check if README is modified
id: step_output
- if: steps.changed-files-specific.outputs.any_changed == 'true'
run: |
- echo "readme_changed=true" >> $GITHUB_OUTPUT
+ if git diff --name-only HEAD^ HEAD | grep -qxF "README.md"; then # whole-line literal match: only the root README counts
+ echo "readme_changed=true" >> $GITHUB_OUTPUT
+ else
+ echo "readme_changed=false" >> $GITHUB_OUTPUT
+ fi
publish-docarray-org:
needs: check-readme-modification
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index a5f55139f06..07c32d0b873 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -18,14 +18,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2.5.0
- - name: Set up Python 3.7
+ - name: Set up Python 3.8
uses: actions/setup-python@v4
with:
- python-version: 3.7
+ python-version: 3.8
- name: Lint with ruff
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install
# stop the build if there are Python syntax errors or undefined names
@@ -37,14 +37,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2.5.0
- - name: Set up Python 3.7
+ - name: Set up Python 3.8
uses: actions/setup-python@v4
with:
- python-version: 3.7
+ python-version: 3.8
- name: check black
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --only dev
poetry run black --check .
@@ -55,43 +55,47 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2.5.0
- - name: Set up Python 3.7
+ - name: Set up Python 3.8
uses: actions/setup-python@v4
with:
- python-version: 3.7
+ python-version: 3.8
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --without dev
- poetry run pip install tensorflow==2.11.0
+ poetry run pip install tensorflow==2.12.0
+ poetry run pip install jax
+ poetry run pip uninstall -y torch
+ poetry run pip install torch
- name: Test basic import
run: poetry run python -c 'from docarray import DocList, BaseDoc'
-
- check-mypy:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v2.5.0
- - name: Set up Python 3.7
- uses: actions/setup-python@v4
- with:
- python-version: 3.7
- - name: check mypy
- run: |
- python -m pip install --upgrade pip
- python -m pip install poetry
- poetry install --all-extras
- poetry run mypy docarray
+ # it is time to say bye bye to mypy because of the way we handle support of pydantic v1 and v2
+ # check-mypy:
+ # runs-on: ubuntu-latest
+ # steps:
+ # - uses: actions/checkout@v2.5.0
+ # - name: Set up Python 3.8
+ # uses: actions/setup-python@v4
+ # with:
+ # python-version: 3.8
+ # - name: check mypy
+ # run: |
+ # python -m pip install --upgrade pip
+ # python -m pip install poetry
+ # poetry install --all-extras
+ # poetry run mypy docarray
docarray-test:
- needs: [lint-ruff, check-black, import-test]
+ needs: [import-test]
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- python-version: [3.7]
+ python-version: [3.9]
+ pydantic-version: ["pydantic-v2", "pydantic-v1"]
test-path: [tests/integrations, tests/units, tests/documentation]
steps:
- uses: actions/checkout@v2.5.0
@@ -102,41 +106,47 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --all-extras
poetry run pip install elasticsearch==8.6.2
+ ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
+ poetry run pip uninstall -y torch
+ poetry run pip install torch
+ poetry run pip install numpy==1.26.1
sudo apt-get update
sudo apt-get install --no-install-recommends ffmpeg
-
+
- name: Test
id: test
run: |
- poetry run pytest -m "not (tensorflow or benchmark or index)" ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py
+ poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml -v -s ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py
+ echo "flag it as docarray for codeoverage"
+ echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
env:
JINA_AUTH_TOKEN: "${{ secrets.JINA_AUTH_TOKEN }}"
-# - name: Check codecov file
-# id: check_files
-# uses: andstor/file-existence-action@v1
-# with:
-# files: "coverage.xml"
-# - name: Upload coverage from test to Codecov
-# uses: codecov/codecov-action@v3.1.1
-# if: steps.check_files.outputs.files_exists == 'true' && ${{ matrix.python-version }} == '3.7'
-# with:
-# file: coverage.xml
-# flags: ${{ steps.test.outputs.codecov_flag }}
-# fail_ci_if_error: false
-# token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
-
+ - name: Check codecov file
+ id: check_files
+ uses: andstor/file-existence-action@v1
+ with:
+ files: "coverage.xml"
+ - name: Upload coverage from test to Codecov
+ uses: codecov/codecov-action@v3.1.1
+ if: steps.check_files.outputs.files_exists == 'true' && matrix.python-version == '3.9' # plain expression (a mixed-in template made it always truthy); this job's matrix runs 3.9
+ with:
+ file: coverage.xml
+ name: benchmark-test-codecov
+ flags: ${{ steps.test.outputs.codecov_flag }}
+ fail_ci_if_error: false
- docarray-test-jac:
- needs: [lint-ruff, check-black, import-test]
+ docarray-test-proto3:
+ needs: [import-test]
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- python-version: [3.7]
+ python-version: [3.8]
+ pydantic-version: ["pydantic-v2", "pydantic-v1"]
steps:
- uses: actions/checkout@v2.5.0
- name: Set up Python ${{ matrix.python-version }}
@@ -146,29 +156,45 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
- poetry install --all-extras
- poetry run pip install elasticsearch==8.6.2
+ python -m pip install poetry==1.7.1
+ poetry install --all-extras
+ ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
+ poetry run pip install protobuf==3.20.0 # we check that we support protobuf 3.20
+ poetry run pip uninstall -y torch
+ poetry run pip install torch
sudo apt-get update
sudo apt-get install --no-install-recommends ffmpeg
-
- name: Test
id: test
run: |
- poetry run pytest -m "not (tensorflow or benchmark or index)" tests/integrations/store/test_jac.py
+ poetry run pytest -m 'proto' --cov=docarray --cov-report=xml -v -s tests
+ echo "flag it as docarray for codeoverage"
+ echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
- env:
- JINA_AUTH_TOKEN: "${{ secrets.JINA_AUTH_TOKEN }}"
+ - name: Check codecov file
+ id: check_files
+ uses: andstor/file-existence-action@v1
+ with:
+ files: "coverage.xml"
+ - name: Upload coverage from test to Codecov
+ uses: codecov/codecov-action@v3.1.1
+ if: steps.check_files.outputs.files_exists == 'true' && matrix.python-version == '3.8' # plain expression; a mixed-in template made it always truthy
+ with:
+ file: coverage.xml
+ name: benchmark-test-codecov
+ flags: ${{ steps.test.outputs.codecov_flag }}
+ fail_ci_if_error: false
- docarray-test-uncaped: # do test without using poetry lock. This does not block ci passing
- needs: [lint-ruff, check-black, import-test]
+ docarray-doc-index:
+ needs: [import-test]
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- python-version: [3.7]
- test-path: [tests/integrations, tests/units, tests/documentation]
+ python-version: [3.8]
+ db_test_folder: [base_classes, elastic, epsilla, hnswlib, qdrant, weaviate, redis, milvus]
+ pydantic-version: ["pydantic-v2", "pydantic-v1"]
steps:
- uses: actions/checkout@v2.5.0
- name: Set up Python ${{ matrix.python-version }}
@@ -178,29 +204,46 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
- rm poetry.lock
+ python -m pip install poetry==1.7.1
poetry install --all-extras
- poetry run pip install elasticsearch==8.6.2
+ ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
+ poetry run pip install protobuf==3.20.0
+ poetry run pip install tensorflow==2.12.0
+ poetry run pip uninstall -y torch
+ poetry run pip install torch
sudo apt-get update
sudo apt-get install --no-install-recommends ffmpeg
- name: Test
id: test
run: |
- poetry run pytest -m "not (tensorflow or benchmark or index)" ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py
+ poetry run pytest -m 'index and not elasticv8' --cov=docarray --cov-report=xml -v -s tests/index/${{ matrix.db_test_folder }}
+ echo "flag it as docarray for codeoverage"
+ echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
- env:
- JINA_AUTH_TOKEN: "${{ secrets.JINA_AUTH_TOKEN }}"
+ - name: Check codecov file
+ id: check_files
+ uses: andstor/file-existence-action@v1
+ with:
+ files: "coverage.xml"
+ - name: Upload coverage from test to Codecov
+ uses: codecov/codecov-action@v3.1.1
+ if: steps.check_files.outputs.files_exists == 'true' && matrix.python-version == '3.8' # plain expression; a mixed-in template made it always truthy
+ with:
+ file: coverage.xml
+ name: benchmark-test-codecov
+ flags: ${{ steps.test.outputs.codecov_flag }}
+ fail_ci_if_error: false
- docarray-test-proto3:
- needs: [lint-ruff, check-black, import-test]
+ docarray-elastic-v8:
+ needs: [import-test]
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- python-version: [3.7]
+ python-version: [3.8]
+ pydantic-version: ["pydantic-v2", "pydantic-v1"]
steps:
- uses: actions/checkout@v2.5.0
- name: Set up Python ${{ matrix.python-version }}
@@ -210,27 +253,46 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --all-extras
- poetry run pip install protobuf==3.19.0 # we check that we support 3.19
+ ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
+ poetry run pip install protobuf==3.20.0
+ poetry run pip install tensorflow==2.12.0
+ poetry run pip install elasticsearch==8.6.2
+ poetry run pip uninstall -y torch
+ poetry run pip install torch
sudo apt-get update
sudo apt-get install --no-install-recommends ffmpeg
- name: Test
id: test
run: |
- poetry run pytest -m 'proto' tests
+ poetry run pytest -m 'index and elasticv8' --cov=docarray --cov-report=xml -v -s tests
+ echo "flag it as docarray for codeoverage"
+ echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
+ - name: Check codecov file
+ id: check_files
+ uses: andstor/file-existence-action@v1
+ with:
+ files: "coverage.xml"
+ - name: Upload coverage from test to Codecov
+ uses: codecov/codecov-action@v3.1.1
+ if: steps.check_files.outputs.files_exists == 'true' && matrix.python-version == '3.8' # plain expression; a mixed-in template made it always truthy
+ with:
+ file: coverage.xml
+ name: benchmark-test-codecov
+ flags: ${{ steps.test.outputs.codecov_flag }}
+ fail_ci_if_error: false
-
- docarray-doc-index:
- needs: [lint-ruff, check-black, import-test]
+ docarray-test-tensorflow:
+ needs: [import-test]
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- python-version: [3.7]
- db_test_folder: [base_classes, elastic, hnswlib, qdrant, weaviate]
+ python-version: [3.8]
+ pydantic-version: ["pydantic-v2", "pydantic-v1"]
steps:
- uses: actions/checkout@v2.5.0
- name: Set up Python ${{ matrix.python-version }}
@@ -240,27 +302,46 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --all-extras
- poetry run pip install protobuf==3.19.0
- poetry run pip install tensorflow==2.11.0
+ ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
+ poetry run pip install protobuf==3.20.0
+ poetry run pip install google-auth==2.23.0
+ poetry run pip install tensorflow==2.12.0
+ poetry run pip uninstall -y torch
+ poetry run pip install torch
sudo apt-get update
sudo apt-get install --no-install-recommends ffmpeg
- name: Test
id: test
run: |
- poetry run pytest -m 'index and not elasticv8' tests/index/${{ matrix.db_test_folder }}
+ poetry run pytest -m 'tensorflow' --cov=docarray --cov-report=xml -v -s tests
+ echo "flag it as docarray for codeoverage"
+ echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
+ - name: Check codecov file
+ id: check_files
+ uses: andstor/file-existence-action@v1
+ with:
+ files: "coverage.xml"
+ - name: Upload coverage from test to Codecov
+ uses: codecov/codecov-action@v3.1.1
+ if: steps.check_files.outputs.files_exists == 'true' && matrix.python-version == '3.8' # plain expression; a mixed-in template made it always truthy
+ with:
+ file: coverage.xml
+ name: benchmark-test-codecov
+ flags: ${{ steps.test.outputs.codecov_flag }}
+ fail_ci_if_error: false
-
- docarray-elastic-v8:
- needs: [lint-ruff, check-black, import-test]
+ docarray-test-jax:
+ needs: [import-test]
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- python-version: [3.7]
+ python-version: [3.8]
+ pydantic-version: ["pydantic-v2", "pydantic-v1"]
steps:
- uses: actions/checkout@v2.5.0
- name: Set up Python ${{ matrix.python-version }}
@@ -270,56 +351,44 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --all-extras
- poetry run pip install protobuf==3.19.0
- poetry run pip install tensorflow==2.11.0
- poetry run pip install elasticsearch==8.6.2
- sudo apt-get update
- sudo apt-get install --no-install-recommends ffmpeg
+ ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
+ poetry run pip uninstall -y torch
+ poetry run pip install torch
+ poetry run pip install jaxlib
+ poetry run pip install jax
- name: Test
id: test
run: |
- poetry run pytest -m 'index and elasticv8' tests
+ poetry run pytest -m 'jax' --cov=docarray --cov-report=xml -v -s tests
+ echo "flag it as docarray for codeoverage"
+ echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
-
- docarray-test-tensorflow:
- needs: [lint-ruff, check-black, import-test]
- runs-on: ubuntu-latest
- strategy:
- fail-fast: false
- matrix:
- python-version: [3.7]
- steps:
- - uses: actions/checkout@v2.5.0
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
+ - name: Check codecov file
+ id: check_files
+ uses: andstor/file-existence-action@v1
with:
- python-version: ${{ matrix.python-version }}
- - name: Prepare environment
- run: |
- python -m pip install --upgrade pip
- python -m pip install poetry
- poetry install --all-extras
- poetry run pip install protobuf==3.19.0
- poetry run pip install tensorflow==2.11.0
- sudo apt-get update
- sudo apt-get install --no-install-recommends ffmpeg
+ files: "coverage.xml"
+ - name: Upload coverage from test to Codecov
+ uses: codecov/codecov-action@v3.1.1
+ if: steps.check_files.outputs.files_exists == 'true' && matrix.python-version == '3.8' # plain expression; a mixed-in template made it always truthy
+ with:
+ file: coverage.xml
+ name: benchmark-test-codecov
+ flags: ${{ steps.test.outputs.codecov_flag }}
+ fail_ci_if_error: false
+
- - name: Test
- id: test
- run: |
- poetry run pytest -m 'tensorflow' tests
- timeout-minutes: 30
docarray-test-benchmarks:
- needs: [lint-ruff, check-black, import-test]
+ needs: [import-test]
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- python-version: [3.7]
+ python-version: [3.8]
steps:
- uses: actions/checkout@v2.5.0
- name: Set up Python ${{ matrix.python-version }}
@@ -329,18 +398,35 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
poetry install --all-extras
+ poetry run pip uninstall -y torch
+ poetry run pip install torch
- name: Test
id: test
run: |
- poetry run pytest -m 'benchmark' tests
+ poetry run pytest -m 'benchmark' --cov=docarray --cov-report=xml -v -s tests
+ echo "flag it as docarray for codeoverage"
+ echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
+ - name: Check codecov file
+ id: check_files
+ uses: andstor/file-existence-action@v1
+ with:
+ files: "coverage.xml"
+ - name: Upload coverage from test to Codecov
+ uses: codecov/codecov-action@v3.1.1
+ if: steps.check_files.outputs.files_exists == 'true' && matrix.python-version == '3.8' # plain expression; a mixed-in template made it always truthy
+ with:
+ file: coverage.xml
+ name: benchmark-test-codecov
+ flags: ${{ steps.test.outputs.codecov_flag }}
+ fail_ci_if_error: false
- # just for blocking the merge until all parallel core-test are successful
+ # just for blocking the merge until all parallel tests are successful
success-all-test:
- needs: [docarray-test, docarray-test-proto3, docarray-doc-index, docarray-elastic-v8, docarray-test-tensorflow, docarray-test-benchmarks, import-test, check-black, check-mypy, lint-ruff]
+ needs: [docarray-test, docarray-test-proto3, docarray-doc-index, docarray-elastic-v8, docarray-test-tensorflow, docarray-test-jax, docarray-test-benchmarks, import-test, check-black, lint-ruff]
if: always()
runs-on: ubuntu-latest
steps:
diff --git a/.github/workflows/ci_only_pr.yml b/.github/workflows/ci_only_pr.yml
index adac73476a0..9d040e72b62 100644
--- a/.github/workflows/ci_only_pr.yml
+++ b/.github/workflows/ci_only_pr.yml
@@ -35,7 +35,7 @@ jobs:
- uses: actions/checkout@v2.5.0
- uses: actions/setup-python@v4
with:
- python-version: 3.7
+ python-version: 3.8
- uses: actions/setup-node@v2
with:
node-version: '14'
@@ -43,7 +43,7 @@ jobs:
run: |
npm i -g netlify-cli
python -m pip install --upgrade pip
- python -m pip install poetry
+ python -m pip install poetry==1.7.1
python -m poetry config virtualenvs.create false && python -m poetry install --no-interaction --no-ansi --all-extras
cd docs
diff --git a/.github/workflows/force-docs-build.yml b/.github/workflows/force-docs-build.yml
index 1c85125dcbd..171a963ade5 100644
--- a/.github/workflows/force-docs-build.yml
+++ b/.github/workflows/force-docs-build.yml
@@ -57,7 +57,7 @@ jobs:
fetch-depth: 1
- uses: actions/setup-python@v4
with:
- python-version: '3.7'
+ python-version: '3.8'
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
diff --git a/.github/workflows/force-release.yml b/.github/workflows/force-release.yml
index b3b81cce7fc..3ad1af18ced 100644
--- a/.github/workflows/force-release.yml
+++ b/.github/workflows/force-release.yml
@@ -36,15 +36,15 @@ jobs:
# submodules: true
- uses: actions/setup-python@v4
with:
- python-version: 3.7
+ python-version: 3.8
- run: |
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
npm install git-release-notes
- pip install poetry
+ python -m pip install poetry==1.7.1
./scripts/release.sh final "${{ github.event.inputs.release_reason }}" "${{github.actor}}"
env:
- PYPI_USERNAME: ${{ secrets.TWINE_USERNAME }}
- PYPI_PASSWORD: ${{ secrets.TWINE_PASSWORD }}
+ TWINE_USERNAME: __token__
+ TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
JINA_SLACK_WEBHOOK: ${{ secrets.JINA_SLACK_WEBHOOK }}
- if: failure()
run: echo "nothing to release"
diff --git a/.github/workflows/tag.yml b/.github/workflows/tag.yml
index cd658781be6..b4efcc4e2fb 100644
--- a/.github/workflows/tag.yml
+++ b/.github/workflows/tag.yml
@@ -32,7 +32,7 @@ jobs:
ref: 'main'
- uses: actions/setup-python@v4
with:
- python-version: 3.7
+ python-version: 3.8
- run: |
python scripts/get-last-release-note.py
- name: Create Release
diff --git a/.github/workflows/uncaped.yml b/.github/workflows/uncaped.yml
new file mode 100644
index 00000000000..ccb56bc2497
--- /dev/null
+++ b/.github/workflows/uncaped.yml
@@ -0,0 +1,37 @@
+name: Uncaped
+
+on:
+ schedule:
+ - cron: '0 0,1 * * *' # Run at midnight, 1 AM UTC
+
+jobs:
+ docarray-test-uncaped:
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: [3.8]
+ test-path: [tests/integrations, tests/units, tests/documentation]
+ steps:
+ - uses: actions/checkout@v2.5.0
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Prepare environment
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install poetry==1.7.1
+ rm poetry.lock
+ poetry install --all-extras
+ poetry run pip install elasticsearch==8.6.2
+ sudo apt-get update
+ sudo apt-get install --no-install-recommends ffmpeg
+
+ - name: Test
+ id: test
+ run: |
+ poetry run pytest -m "not (tensorflow or benchmark or index)" ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py
+ timeout-minutes: 30
+ env:
+ JINA_AUTH_TOKEN: "${{ secrets.JINA_AUTH_TOKEN }}"
diff --git a/.gitignore b/.gitignore
index a0c35405804..c467cc7b2b3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -151,4 +151,6 @@ output/
.pytest-kind
.kube
-*.ipynb
\ No newline at end of file
+*.ipynb
+
+.python-version
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9df8e8a06d2..23993cc072a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -21,7 +21,7 @@ repos:
exclude: ^(docarray/proto/pb/docarray_pb2.py|docarray/proto/pb/docarray_pb2.py|docs/|docarray/resources/)
- repo: https://github.com/charliermarsh/ruff-pre-commit
- rev: v0.0.243
+ rev: v0.0.250
hooks:
- id: ruff
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5be7221dd15..48f2dedcd93 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,19 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
## Release Note (`0.30.0`)
@@ -243,3 +256,545 @@
- [[```94cccbb5```](https://github.com/jina-ai/docarray/commit/94cccbb5fb296e7b2639ecec12c0f163f06546d6)] __-__ fix cd for documentation (#1509) (*samsja*)
- [[```04b930f4```](https://github.com/jina-ai/docarray/commit/04b930f4751a44188b6cb531c1cb4d2f5c43ae02)] __-__ __version__: the next version will be 0.31.1 (*Jina Dev Bot*)
+
+## Release Note (`0.32.0`)
+
+> Release time: 2023-05-16 11:29:36
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ samsja, Saba Sturua, Anne Yang, Zhaofeng Miao, Mohammad Kalim Akram, Kacper Łukawski, Joan Fontanals, Johannes Messner, IyadhKhalfallah, Jina Dev Bot, 🙇
+
+
+### 🆕 New Features
+
+ - [[```db4cc1a7```](https://github.com/jina-ai/docarray/commit/db4cc1a770f29368683e6b2be094e4062972f3ca)] __-__ save and load inmemory index (#1534) (*Saba Sturua*)
+ - [[```44977b75```](https://github.com/jina-ai/docarray/commit/44977b75ee9c96a272637ebf21a14285b1c6aeab)] __-__ subindex for document index (#1428) (*Anne Yang*)
+ - [[```096f6449```](https://github.com/jina-ai/docarray/commit/096f644980e63391e3ca7bfc8486b877b181b07f)] __-__ search_field should be optional in hybrid text search (#1516) (*Anne Yang*)
+ - [[```e9fc57af```](https://github.com/jina-ai/docarray/commit/e9fc57af6b7fb0a70534971a8f6a2062be444f94)] __-__ openapi tensor shapes (#1510) (*Johannes Messner*)
+
+### 🐞 Bug fixes
+
+ - [[```2d3bcd2e```](https://github.com/jina-ai/docarray/commit/2d3bcd2ef2886e938cd1d47a7cebb5bc5cb14b34)] __-__ check if filepath exists for inmemory index (#1537) (*Saba Sturua*)
+ - [[```1f2dcea6```](https://github.com/jina-ai/docarray/commit/1f2dcea6e91140de9dc84b251604197dfbae4e53)] __-__ add empty judgement to index search (#1533) (*Anne Yang*)
+ - [[```8dd050f5```](https://github.com/jina-ai/docarray/commit/8dd050f554f44f0f79aa5b8ce19948848940a283)] __-__ detach the torch tensors (#1526) (*Mohammad Kalim Akram*)
+ - [[```bfebad6b```](https://github.com/jina-ai/docarray/commit/bfebad6be30acd24bd4b9ff92c30522fc803c87c)] __-__ DocVec display (#1522) (*Anne Yang*)
+ - [[```cdde136c```](https://github.com/jina-ai/docarray/commit/cdde136c8002a9a114945648e6d4e135b15a14cd)] __-__ docs link (#1518) (*IyadhKhalfallah*)
+
+### 📗 Documentation
+
+ - [[```02afeb9f```](https://github.com/jina-ai/docarray/commit/02afeb9f6be25406beb2f6480fde8954729a272b)] __-__ __store_jac__: remove wrong info (#1531) (*Zhaofeng Miao*)
+ - [[```bc7f7253```](https://github.com/jina-ai/docarray/commit/bc7f72533cbe2af18fd99ea2bfe393ec4157e2de)] __-__ fix link to documentation in readme (#1525) (*Joan Fontanals*)
+ - [[```0ef6772d```](https://github.com/jina-ai/docarray/commit/0ef6772d35e2a4747e0fd7c39384424a79f8ec62)] __-__ flatten structure (#1520) (*Johannes Messner*)
+
+### 🍹 Other Improvements
+
+ - [[```9657d6fb```](https://github.com/jina-ai/docarray/commit/9657d6fbd69cc0f3686da9b056989fe8ce23cd00)] __-__ bump to 0.32.0 (#1541) (*samsja*)
+ - [[```9705431b```](https://github.com/jina-ai/docarray/commit/9705431b8937a090924fe875f5550dc926802315)] __-__ Add more Qdrant examples (#1527) (*Kacper Łukawski*)
+ - [[```445a72fa```](https://github.com/jina-ai/docarray/commit/445a72fa2a32c30de801f4d3203866d85de23990)] __-__ add protobuf in the hnsw extra (#1524) (*Saba Sturua*)
+ - [[```20fdcd27```](https://github.com/jina-ai/docarray/commit/20fdcd27b786f7fea9ab4bc4ee5bbbd244ba66f0)] __-__ __version__: the next version will be 0.31.2 (*Jina Dev Bot*)
+
+
+## Release Note (`0.32.1`)
+
+> Release time: 2023-05-26 14:50:34
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ Joan Fontanals, maxwelljin, Johannes Messner, aman-exp-infy, Saba Sturua, Jina Dev Bot, 🙇
+
+
+### 🆕 New Features
+
+ - [[```8651e6e8```](https://github.com/jina-ai/docarray/commit/8651e6e88cef3f66c6a6eeca0531e26e2b4ca18d)] __-__ logs added for es8 index (#1551) (*aman-exp-infy*)
+
+### 🐞 Bug fixes
+
+ - [[```5d41c13c```](https://github.com/jina-ai/docarray/commit/5d41c13c96de299ac8035fd09d3bdd32dc518036)] __-__ fix None embedding exact nn search (#1575) (*Joan Fontanals*)
+ - [[```7a7a83a5```](https://github.com/jina-ai/docarray/commit/7a7a83a5d7526b8840e4b98c966cdbc635280bbc)] __-__ support list in document class (#1557) (#1569) (*maxwelljin*)
+ - [[```40549f4a```](https://github.com/jina-ai/docarray/commit/40549f4aeacde1522fb6a3406c98b8dbd14e0858)] __-__ fix anydoc deserialization (#1571) (*Joan Fontanals*)
+ - [[```44317570```](https://github.com/jina-ai/docarray/commit/44317570395380bcdff8d7de9e815b42460f5b9c)] __-__ dict method for document view (#1559) (*Johannes Messner*)
+
+### 🧼 Code Refactoring
+
+ - [[```0bcc956d```](https://github.com/jina-ai/docarray/commit/0bcc956da6d9d4971ee0b92f69fa776d7aae24f1)] __-__ uncaped tests as a nightly job (#1540) (*Saba Sturua*)
+
+### 📗 Documentation
+
+ - [[```0e6aa3b6```](https://github.com/jina-ai/docarray/commit/0e6aa3b6f43d1762500af96b646e538af44be1b5)] __-__ update doc building guide (#1566) (*Johannes Messner*)
+ - [[```a01a0554```](https://github.com/jina-ai/docarray/commit/a01a05542d17264b8a164bec783633658deeedb8)] __-__ explain the state of doclist in fastapi (#1546) (*Johannes Messner*)
+
+### 🍹 Other Improvements
+
+ - [[```8a2e92a3```](https://github.com/jina-ai/docarray/commit/8a2e92a3f94efc77d90e0747c246bdcf2ce72dfd)] __-__ update pyproject.toml (#1581) (*Joan Fontanals*)
+ - [[```9b5cbeda```](https://github.com/jina-ai/docarray/commit/9b5cbedaa43ea392b985e0ad293839523ce57030)] __-__ __version__: the next version will be 0.32.1 (*Jina Dev Bot*)
+
+
+## Release Note (`0.33.0`)
+
+> Release time: 2023-06-06 14:05:56
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ Joan Fontanals, Saba Sturua, samsja, maxwelljin, Mohammad Kalim Akram, Jina Dev Bot, 🙇
+
+
+### 🆕 New Features
+
+ - [[```110f714f```](https://github.com/jina-ai/docarray/commit/110f714fda40689c4c743bc825bd1e017a739d9d)] __-__ avoid stack embedding for every search (#1586) (*maxwelljin*)
+ - [[```5e74fcca```](https://github.com/jina-ai/docarray/commit/5e74fcca1ef0ef03f823b888b8585c3a0177144e)] __-__ tensor coercion (#1588) (*samsja*)
+
+### 🐞 Bug fixes
+
+ - [[```f7371b48```](https://github.com/jina-ai/docarray/commit/f7371b48df7e93de4808a9cbfcf0c89420a12129)] __-__ filter limits (#1618) (*Saba Sturua*)
+ - [[```6903c773```](https://github.com/jina-ai/docarray/commit/6903c773490459a002fad2751dd3238b735e8d5e)] __-__ hnswlib must be able to search with limit more than num docs (#1611) (*Joan Fontanals*)
+ - [[```e24be8d6```](https://github.com/jina-ai/docarray/commit/e24be8d6eff48baef440d184171a0c2e3356d0bf)] __-__ allow update on HNSWLibIndex (#1604) (*Joan Fontanals*)
+ - [[```3cef708c```](https://github.com/jina-ai/docarray/commit/3cef708cce24aaa017fb2a5af5aed33bc029df09)] __-__ dynamically resize internal index to adapt to increasing number of docs (#1602) (*Joan Fontanals*)
+ - [[```a5c90064```](https://github.com/jina-ai/docarray/commit/a5c90064cb2da0120ee6bdfcec613ef8f1447596)] __-__ fix simple usage of HNSWLib (#1596) (*Joan Fontanals*)
+ - [[```88414ce2```](https://github.com/jina-ai/docarray/commit/88414ce25ec1808565f6f46d9051a08646f765b4)] __-__ fix InMemoryExactNN index initialization with nested DocList (#1582) (*Joan Fontanals*)
+ - [[```5d0e24c9```](https://github.com/jina-ai/docarray/commit/5d0e24c94457e04f3e902b6ecd33d4e607c7636e)] __-__ fix summary of Doc with list (#1595) (*Joan Fontanals*)
+ - [[```f765d9f4```](https://github.com/jina-ai/docarray/commit/f765d9f4f01089d0ab361cecd771efc97970ce92)] __-__ solve issues caused by issubclass (#1594) (*maxwelljin*)
+ - [[```5ee87876```](https://github.com/jina-ai/docarray/commit/5ee878763cbc35f95d26a0fc3842211d8add3e16)] __-__ make example payload a string and not bytes (#1587) (*Joan Fontanals*)
+
+### 🧼 Code Refactoring
+
+ - [[```f9e504ef```](https://github.com/jina-ai/docarray/commit/f9e504efbc7229dd5d29d4fecef7f9d0bfb3dbc9)] __-__ minor changes in weaviate (#1621) (*Saba Sturua*)
+ - [[```4eec5599```](https://github.com/jina-ai/docarray/commit/4eec5599bf2905a191e4cc5614a1813e91f4c02f)] __-__ make AnyTensor a class (#1552) (*Mohammad Kalim Akram*)
+
+### 📗 Documentation
+
+ - [[```29c2d23a```](https://github.com/jina-ai/docarray/commit/29c2d23a618704e9ed13108c754e09c6ef053a93)] __-__ add forward declaration steps to example to avoid pickling error (#1615) (*Joan Fontanals*)
+ - [[```5e6bf755```](https://github.com/jina-ai/docarray/commit/5e6bf7550bf2395244633c4a19b7d812b7d6fe9d)] __-__ fix n_dim to dim in docs (#1610) (*Joan Fontanals*)
+ - [[```de8c654b```](https://github.com/jina-ai/docarray/commit/de8c654bd2f4c5465943c974520021e39ff07ab4)] __-__ add in memory to documentation as list of supported vector index (#1607) (*Joan Fontanals*)
+ - [[```1e41b5c5```](https://github.com/jina-ai/docarray/commit/1e41b5c59e4f2c7d14de1619141ed35898bbc815)] __-__ add a tensor section to docs (#1576) (*samsja*)
+
+### 🍹 Other Improvements
+
+ - [[```68194f49```](https://github.com/jina-ai/docarray/commit/68194f492a84ecfb61ffda1b669debe156a24a37)] __-__ update version to 0.33 (#1626) (*Joan Fontanals*)
+ - [[```ac2e417e```](https://github.com/jina-ai/docarray/commit/ac2e417e9fc23ac06ebed515de0b0688827c145a)] __-__ fix issue template (#1624) (*samsja*)
+ - [[```e1777144```](https://github.com/jina-ai/docarray/commit/e177714491caaa28dd1990db52ce3359416b8ab0)] __-__ add a better looking issue template (#1623) (*samsja*)
+ - [[```692584d6```](https://github.com/jina-ai/docarray/commit/692584d6b8a2b9c1f1d6a869ecf7a0114e7e6c5c)] __-__ simplify find batched (#1598) (*Joan Fontanals*)
+ - [[```91350882```](https://github.com/jina-ai/docarray/commit/91350882817cc6ed0f24aa02a6f14e7fe182fb9c)] __-__ __version__: the next version will be 0.32.2 (*Jina Dev Bot*)
+
+
+## Release Note (`0.34.0`)
+
+> Release time: 2023-06-21 08:15:43
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ Joan Fontanals, Johannes Messner, Saba Sturua, samsja, maxwelljin, Shukri, Nikolas Pitsillos, Joan Fontanals Martinez, maxwelljin2, Kacper Łukawski, Aman Agarwal, Jina Dev Bot, 🙇
+
+
+### 🆕 New Features
+
+ - [[```eb3f8570```](https://github.com/jina-ai/docarray/commit/eb3f8570da5b1e23e21e3fe50ab0a30f136f7940)] __-__ tensor type for protobuf deserialization (#1645) (*Johannes Messner*)
+ - [[```a6fdd80c```](https://github.com/jina-ai/docarray/commit/a6fdd80c69d8c23660113ad240d82167448e39f6)] __-__ sub-document support for indexer (*maxwelljin2*)
+ - [[```78892703```](https://github.com/jina-ai/docarray/commit/788927034da7efc734c2cbc23ba6854dd245c3cb)] __-__ contain func for qdrant (*maxwelljin2*)
+ - [[```74a683c0```](https://github.com/jina-ai/docarray/commit/74a683c04b07872646ccd6f067ae82f44ea7e370)] __-__ contain func for weaviate (*maxwelljin2*)
+ - [[```6ca3aa6e```](https://github.com/jina-ai/docarray/commit/6ca3aa6eb70afc9f23b69ecf1b75b760d43614fa)] __-__ contain func for elastic (*maxwelljin2*)
+ - [[```66b0f716```](https://github.com/jina-ai/docarray/commit/66b0f716a3e3cf92efe40a4346a2ccaf49897a0e)] __-__ check contain in indexer (*maxwelljin2*)
+ - [[```2c123535```](https://github.com/jina-ai/docarray/commit/2c123535c2d150c6f120aad4d58df3cc6798a1c4)] __-__ support subindex on ExactNNSearch (#1617) (*maxwelljin*)
+
+### 🐞 Bug fixes
+
+ - [[```c3c8061f```](https://github.com/jina-ai/docarray/commit/c3c8061f3e22e50fb08404b254660006802f42a0)] __-__ docvec equality if tensors are involved (#1663) (*Johannes Messner*)
+ - [[```0c27fef6```](https://github.com/jina-ai/docarray/commit/0c27fef603970e22dc1010fd2b18aa0af834ef9e)] __-__ bugs when serialize union type (#1655) (*maxwelljin*)
+ - [[```dc96e38a```](https://github.com/jina-ai/docarray/commit/dc96e38a0446d36bb6c7d6f88a9209265032bb3c)] __-__ pass limit as integer (#1657) (*Joan Fontanals*)
+ - [[```7e211a94```](https://github.com/jina-ai/docarray/commit/7e211a940e4a390b059ee9de5acb4afa78c93909)] __-__ pass limit as integer (#1656) (*Joan Fontanals*)
+ - [[```c3db7553```](https://github.com/jina-ai/docarray/commit/c3db75538bb9d70e35b249100b6c9c7372804e4b)] __-__ update text search to match client's new sig (#1654) (*Shukri*)
+ - [[```4e7e262a```](https://github.com/jina-ai/docarray/commit/4e7e262ab7394becab33cb688a066c6c62dae79c)] __-__ doc vec equality (#1641) (*Nikolas Pitsillos*)
+ - [[```eae44954```](https://github.com/jina-ai/docarray/commit/eae449542c41ef39b853fa1bd3d51ebd77f56e10)] __-__ default column config should be DBConfig and not RuntimeConfig (#1648) (*Joan Fontanals*)
+ - [[```d13c8c45```](https://github.com/jina-ai/docarray/commit/d13c8c450fadf4e5e3094a5b6c843e83c68734a4)] __-__ move default_column_config to DBConfig (*Joan Fontanals Martinez*)
+ - [[```cd3efc6f```](https://github.com/jina-ai/docarray/commit/cd3efc6fbe68f23d1b961ce95f6d62bb26dc8141)] __-__ summary of legacy document (*maxwelljin*)
+ - [[```c13739b8```](https://github.com/jina-ai/docarray/commit/c13739b80532fbcbb1b8257a5e21f08976af160b)] __-__ remove get documents method (*maxwelljin2*)
+ - [[```7c807d4f```](https://github.com/jina-ai/docarray/commit/7c807d4fa8224e1fa90548bbc5e2f44907031d80)] __-__ remove get all documents method (*maxwelljin2*)
+ - [[```00794486```](https://github.com/jina-ai/docarray/commit/00794486336b5bc7a852970b085e11d497de057f)] __-__ mypy issues (*maxwelljin2*)
+ - [[```c8356813```](https://github.com/jina-ai/docarray/commit/c8356813acc81c5b4ce591d9ce2345b713081b08)] __-__ protobuf (de)ser for docvec (#1639) (*Johannes Messner*)
+ - [[```f36c6211```](https://github.com/jina-ai/docarray/commit/f36c621104f64d4e88aeb6673e1a8ba34c3472d1)] __-__ find_and_filter for inmemory (#1642) (*Saba Sturua*)
+ - [[```1abdfce0```](https://github.com/jina-ai/docarray/commit/1abdfce0eca9230bc6f75759c6279f224be23ade)] __-__ legacy document issues (*maxwelljin2*)
+ - [[```b856b0b3```](https://github.com/jina-ai/docarray/commit/b856b0b3f4ccda505acc092bebecfaa00ac3fd83)] __-__ __qdrant__: working with external Qdrant collections #1630 (#1632) (*Kacper Łukawski*)
+ - [[```693f877d```](https://github.com/jina-ai/docarray/commit/693f877d7e1e5921ec69e7dbb4a41f984a14d46d)] __-__ DocList and DocVec are now coerced to each other correctly (#1568) (*Aman Agarwal*)
+ - [[```65afa9a1```](https://github.com/jina-ai/docarray/commit/65afa9a14c6075238aeb95f62620c93ae46aa9ca)] __-__ fix update with tensors (#1628) (*Joan Fontanals*)
+
+### 🧼 Code Refactoring
+
+ - [[```69dc861b```](https://github.com/jina-ai/docarray/commit/69dc861bf857c4f54d4ffc66da5160d570b4bb54)] __-__ implementation of InMemoryExactNNIndex follows DBConfig way (#1649) (*Joan Fontanals*)
+
+### 📗 Documentation
+
+ - [[```4e6bf49b```](https://github.com/jina-ai/docarray/commit/4e6bf49b82daae82ed25d510dc6d22f9f2e5b473)] __-__ coming from langchain (#1660) (*Saba Sturua*)
+ - [[```e870eb88```](https://github.com/jina-ai/docarray/commit/e870eb8824624f4690edc9a976665d791c5d1135)] __-__ enhance DocVec section (#1658) (*maxwelljin*)
+ - [[```eedd83ce```](https://github.com/jina-ai/docarray/commit/eedd83ce249941493a23bc32fc862ba7353d732c)] __-__ qdrant in memory usage (#1634) (*Saba Sturua*)
+
+### 🍹 Other Improvements
+
+ - [[```dc7b681e```](https://github.com/jina-ai/docarray/commit/dc7b681e1701f41fb500308fbcc154f8d09e3a1f)] __-__ upgrade version to 0.34.0 (#1664) (*Joan Fontanals*)
+ - [[```deb892f1```](https://github.com/jina-ai/docarray/commit/deb892f16200c1180c7a025a11a22f87d7006bec)] __-__ fix link on pypi (#1662) (*samsja*)
+ - [[```7f91e217```](https://github.com/jina-ai/docarray/commit/7f91e21737eacad0d4c7aaaec2445f5eaab3a7f7)] __-__ remove useless file (#1650) (*samsja*)
+ - [[```67a328f4```](https://github.com/jina-ai/docarray/commit/67a328f444777db1f36aa99dbf879e42bd28517a)] __-__ Revert "fix: move default_column_config to DBConfig" (*Joan Fontanals Martinez*)
+ - [[```adc48180```](https://github.com/jina-ai/docarray/commit/adc481807ca02721952501479ce4f2b209c6e62c)] __-__ drop python 3.7 (#1644) (*samsja*)
+ - [[```e66bf106```](https://github.com/jina-ai/docarray/commit/e66bf1060cb020023948498df4f5266c3b23324d)] __-__ __version__: the next version will be 0.33.1 (*Jina Dev Bot*)
+
+
+## Release Note (`0.35.0`)
+
+> Release time: 2023-07-03 11:53:25
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ Joan Fontanals, Johannes Messner, Saba Sturua, Han Xiao, Jina Dev Bot, 🙇
+
+
+### 🆕 New Features
+
+ - [[```8f25887d```](https://github.com/jina-ai/docarray/commit/8f25887d13f27338a99199ebd85462a4d6764615)] __-__ i/o for DocVec (#1562) (*Johannes Messner*)
+ - [[```e0e5cd8c```](https://github.com/jina-ai/docarray/commit/e0e5cd8ceacc9da8450094f591287d597cd7b0af)] __-__ validate file formats in url (#1606) (#1669) (*Saba Sturua*)
+ - [[```a7643414```](https://github.com/jina-ai/docarray/commit/a7643414da05e1f55198836646580965a49314d2)] __-__ add method to create BaseDoc from schema (#1667) (*Joan Fontanals*)
+
+### 🐞 Bug fixes
+
+ - [[```bcb60ca6```](https://github.com/jina-ai/docarray/commit/bcb60ca66738dc27ce04769e754133a4e9b0e173)] __-__ better error message when docvec is unusable (#1675) (*Johannes Messner*)
+
+### 📗 Documentation
+
+ - [[```b6eaa94c```](https://github.com/jina-ai/docarray/commit/b6eaa94cc1853c261e5a7967a3634f017fc41968)] __-__ fix a reference in readme (#1674) (*Saba Sturua*)
+
+### 🏁 Unit Test and CICD
+
+ - [[```b65b385d```](https://github.com/jina-ai/docarray/commit/b65b385d36d740afb5218a3de7c258617a2e51ca)] __-__ pin pydantic version (#1682) (*Joan Fontanals*)
+
+### 🍹 Other Improvements
+
+ - [[```3f089e52```](https://github.com/jina-ai/docarray/commit/3f089e5237c84e2ada367e30820d96018a7954d0)] __-__ update version to 0.35.0 (#1684) (*Joan Fontanals*)
+ - [[```3fc6ecb7```](https://github.com/jina-ai/docarray/commit/3fc6ecb71bdc0095f2c405c17492debcc3d8412d)] __-__ fix docarray v1v2 terms (#1668) (*Han Xiao*)
+ - [[```f507a5f7```](https://github.com/jina-ai/docarray/commit/f507a5f72548a5235e60f15dbcee2c35930c60c1)] __-__ __version__: the next version will be 0.34.1 (*Jina Dev Bot*)
+
+
+## Release Note (`0.36.0`)
+
+> Release time: 2023-07-18 14:43:28
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ Joan Fontanals, Saba Sturua, Aman Agarwal, Shukri, samsja, Puneeth K, Jina Dev Bot, 🙇
+
+
+### 🆕 New Features
+
+ - [[```b306c80b```](https://github.com/jina-ai/docarray/commit/b306c80b334a1d1b2bc865d53d7e9733f27445f5)] __-__ add JAX as Computation Backend (#1646) (*Aman Agarwal*)
+ - [[```069aa3aa```](https://github.com/jina-ai/docarray/commit/069aa3aa2d2eae3a1a0dca574e266a33b1edf9c9)] __-__ support redis (#1550) (*Saba Sturua*)
+
+### 🐞 Bug fixes
+
+ - [[```15e3ed69```](https://github.com/jina-ai/docarray/commit/15e3ed6905025ba3490607eb9659c1cfe7600160)] __-__ weaviate handles lowercase index names (#1711) (*Saba Sturua*)
+ - [[```c5664016```](https://github.com/jina-ai/docarray/commit/c56640160d54ccfd2e699f7f65103672cf77f32b)] __-__ slow hnsw by caching num docs (#1706) (*Saba Sturua*)
+ - [[```d2e18580```](https://github.com/jina-ai/docarray/commit/d2e1858078049217b16db0d05bc6a02be3043934)] __-__ qdrant unable to see index_name (#1705) (*Saba Sturua*)
+ - [[```94a479eb```](https://github.com/jina-ai/docarray/commit/94a479eb1e1bf5e5715f61767992061f61003115)] __-__ fix search in memory with AnyEmbedding (#1696) (*Joan Fontanals*)
+ - [[```62ad22aa```](https://github.com/jina-ai/docarray/commit/62ad22aa8ae3617b9464b904cd33b3115d011781)] __-__ use safe_issubclass everywhere (#1691) (*Joan Fontanals*)
+ - [[```f6ce2833```](https://github.com/jina-ai/docarray/commit/f6ce2833886468e03b8eafec222be7cef3fe62e2)] __-__ avoid converting doclists in the base index (#1685) (*Saba Sturua*)
+
+### 🧼 Code Refactoring
+
+ - [[```0ea68467```](https://github.com/jina-ai/docarray/commit/0ea6846783a1450dc92e4ce181b430f02e32df10)] __-__ contains method in the base class (#1701) (*Saba Sturua*)
+ - [[```0a1da307```](https://github.com/jina-ai/docarray/commit/0a1da3071e2f7dbcd655c2243732a2a07c95f01f)] __-__ more robust method to detect duplicate index (#1651) (*Shukri*)
+
+### 📗 Documentation
+
+ - [[```5089bdae```](https://github.com/jina-ai/docarray/commit/5089bdaea955f77c31495535bf99da37b85edb3b)] __-__ add docs for dict() method (#1643) (*Puneeth K*)
+
+### 🏁 Unit Test and CICD
+
+ - [[```e0afb5e7```](https://github.com/jina-ai/docarray/commit/e0afb5e723a7a2f3a1346eec554c7183868b98e5)] __-__ do not require black for tests more (#1694) (*Joan Fontanals*)
+ - [[```0dd49538```](https://github.com/jina-ai/docarray/commit/0dd4953866faff685173ac5b6871279d545b2a50)] __-__ do not require black for tests (#1693) (*Joan Fontanals*)
+
+### 🍹 Other Improvements
+
+ - [[```ddc73e19```](https://github.com/jina-ai/docarray/commit/ddc73e19024e2c63071fc17792bcf616b6931b0a)] __-__ upgrade version in pyproject (#1712) (*Joan Fontanals*)
+ - [[```528adfc8```](https://github.com/jina-ai/docarray/commit/528adfc8f3b09fc6f3b9d65b31ca256ef34a819f)] __-__ upgrade version to 0.36 (#1710) (*Joan Fontanals*)
+ - [[```a3f6998a```](https://github.com/jina-ai/docarray/commit/a3f6998a9427bc8d23bab1c4ddccd69dec220c8f)] __-__ remove one of the codecov badges (#1700) (*Joan Fontanals*)
+ - [[```b364ae1a```](https://github.com/jina-ai/docarray/commit/b364ae1ae8daff4890d3fddde88ed4fe4c7e3a7c)] __-__ add codecov (#1699) (*Joan Fontanals*)
+ - [[```64bbf14a```](https://github.com/jina-ai/docarray/commit/64bbf14a8d8854b95ec1c9f90ffa8c8b8a04515b)] __-__ add code of conduct (#1688) (*samsja*)
+ - [[```d2655238```](https://github.com/jina-ai/docarray/commit/d2655238858a7838ca4787187aa9491d4a769e02)] __-__ __version__: the next version will be 0.35.1 (*Jina Dev Bot*)
+
+
+## Release Note (`0.37.0`)
+
+> Release time: 2023-08-03 03:11:16
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ Joan Fontanals, Saba Sturua, Johannes Messner, Jina Dev Bot, 🙇
+
+
+### 🆕 New Features
+
+ - [[```31c2bb9c```](https://github.com/jina-ai/docarray/commit/31c2bb9c00c2cea9e148d112c1d6226d7f6c19b9)] __-__ add description and example to ID field of BaseDoc (#1737) (*Joan Fontanals*)
+ - [[```efeab90d```](https://github.com/jina-ai/docarray/commit/efeab90d3840f94b15e7767a07be0f617cb8387c)] __-__ tensor_type for all DocVec serializations (#1679) (*Johannes Messner*)
+ - [[```00e980dc```](https://github.com/jina-ai/docarray/commit/00e980dcfc3872b7b833184169a777527387016b)] __-__ filtering in hnsw (#1718) (*Saba Sturua*)
+ - [[```7ad70bfc```](https://github.com/jina-ai/docarray/commit/7ad70bfc751841aee5f8747c681a259e2363cbe8)] __-__ update for inmemory index (#1724) (*Saba Sturua*)
+ - [[```007f1131```](https://github.com/jina-ai/docarray/commit/007f1131844975d812a040c85cc21b6fa19366bd)] __-__ support milvus (#1681) (*Saba Sturua*)
+ - [[```c96707a1```](https://github.com/jina-ai/docarray/commit/c96707a133e21b7810aa57da78fb2a49b448a41a)] __-__ InMemoryExactNNIndex pre filtering (#1713) (*Saba Sturua*)
+
+### 🐞 Bug fixes
+
+ - [[```d2c82d49```](https://github.com/jina-ai/docarray/commit/d2c82d49e3a92e5d5ba29d1e4ce9b31435c73f95)] __-__ tensor equals type raises exception (#1739) (*Johannes Messner*)
+ - [[```87ec19f8```](https://github.com/jina-ai/docarray/commit/87ec19f83827cb2bc1c56087de3ef05d6bcd8e02)] __-__ add description and title to dynamic class (#1734) (*Joan Fontanals*)
+ - [[```896c20be```](https://github.com/jina-ai/docarray/commit/896c20be0c32c9dc9136f2eea7bdbb8e5cf2da0e)] __-__ create more info from dynamic (#1733) (*Joan Fontanals*)
+ - [[```0e130100```](https://github.com/jina-ai/docarray/commit/0e1301006403c59d99cd7c8ae77e6a7bef837838)] __-__ fix call to unsafe issubclass (#1731) (*Joan Fontanals*)
+ - [[```4cd58500```](https://github.com/jina-ai/docarray/commit/4cd5850062af4515bc2aebd2b1727372a49867dc)] __-__ collection and index name in qdrant (#1723) (*Joan Fontanals*)
+ - [[```304a4e9b```](https://github.com/jina-ai/docarray/commit/304a4e9b61a9eab7584cd3859a83663ddb3227ef)] __-__ fix deepcopy torchtensor (#1720) (*Joan Fontanals*)
+
+### 🧼 Code Refactoring
+
+ - [[```a643f6ad```](https://github.com/jina-ai/docarray/commit/a643f6adada89dd1e7e4ddf0f92c3e27fb51a23b)] __-__ hnswlib performance (#1727) (*Joan Fontanals*)
+ - [[```19aec21a```](https://github.com/jina-ai/docarray/commit/19aec21aa043cbe3556744f871297ad9d171ba50)] __-__ do not recompute every time num_docs (#1729) (*Joan Fontanals*)
+
+### 📗 Documentation
+
+ - [[```7c10295c```](https://github.com/jina-ai/docarray/commit/7c10295c964df273483bb0391ceeee35f57c9b28)] __-__ make document indices self-contained (#1678) (*Saba Sturua*)
+
+### 🏁 Unit Test and CICD
+
+ - [[```7be038c8```](https://github.com/jina-ai/docarray/commit/7be038c8bcf48c77e45b7d2654b10b563603cd32)] __-__ refactor test to be independent (#1738) (*Joan Fontanals*)
+ - [[```24c00cc8```](https://github.com/jina-ai/docarray/commit/24c00cc8b7cb85f1e2ef3dea76df3382380e5c99)] __-__ refactor hnswlib test subindex (#1732) (*Joan Fontanals*)
+
+### 🍹 Other Improvements
+
+ - [[```77b4dc1f```](https://github.com/jina-ai/docarray/commit/77b4dc1f1c24c1552b01489725205b7ebd55311c)] __-__ update version (#1743) (*Joan Fontanals*)
+ - [[```3be6f2b9```](https://github.com/jina-ai/docarray/commit/3be6f2b9eca79e154ca445524b6bf32ff5910fbc)] __-__ avoid extra debugging (#1730) (*Joan Fontanals*)
+ - [[```24143a1f```](https://github.com/jina-ai/docarray/commit/24143a1f7a3fbcbf33523b983eec8efd8817ebc4)] __-__ refactor filter in hnswlib (#1728) (*Joan Fontanals*)
+ - [[```410665ad```](https://github.com/jina-ai/docarray/commit/410665ad9ae59c0687c780bdfb2a65d8e8097f8a)] __-__ add JAX to README (#1722) (*Joan Fontanals*)
+ - [[```2a866aea```](https://github.com/jina-ai/docarray/commit/2a866aea9f2779ea59b6ed26177c65bbaee33d3a)] __-__ add link to roadmap in readme (#1715) (*Joan Fontanals*)
+ - [[```68b0c5b8```](https://github.com/jina-ai/docarray/commit/68b0c5b86c572f7f7f2fc3a4b838d7bd484ba77e)] __-__ __version__: the next version will be 0.36.1 (*Jina Dev Bot*)
+
+
+## Release Note (`0.37.1`)
+
+> Release time: 2023-08-22 14:09:53
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ samsja, AlaeddineAbdessalem, TERBOUCHE Hacene, Joan Fontanals, Jina Dev Bot, 🙇
+
+
+### 🐞 Bug fixes
+
+ - [[```0ad18a63```](https://github.com/jina-ai/docarray/commit/0ad18a63e6bb1073c080eb3ac304bd90439e878b)] __-__ bump version (#1757) (*samsja*)
+ - [[```46c5dfd0```](https://github.com/jina-ai/docarray/commit/46c5dfd0aa1f10723c99a3e9a39c099dc08710e8)] __-__ relax the schema check in update mixin (#1755) (*AlaeddineAbdessalem*)
+ - [[```6c771125```](https://github.com/jina-ai/docarray/commit/6c771125e5e278c5c176a35ccc619e90098c7339)] __-__ __qdrant__: fix non-class type fields #1748 (#1752) (*TERBOUCHE Hacene*)
+ - [[```adb0d014```](https://github.com/jina-ai/docarray/commit/adb0d0141032c72cf2412a8b998c78cd9a920a9e)] __-__ fix dynamic class creation with doubly nested schemas (#1747) (*AlaeddineAbdessalem*)
+ - [[```691d939e```](https://github.com/jina-ai/docarray/commit/691d939e8021a2589c5d5106ec1270179618599d)] __-__ fix readme test (#1746) (*samsja*)
+
+### 📗 Documentation
+
+ - [[```a39c4f98```](https://github.com/jina-ai/docarray/commit/a39c4f982331d6ef3145127797b1bcedc8e05248)] __-__ update readme (#1744) (*Joan Fontanals*)
+
+### 🍹 Other Improvements
+
+ - [[```bd3d8f03```](https://github.com/jina-ai/docarray/commit/bd3d8f0354c56559c3ae8a30f06b566ed7945f6e)] __-__ __version__: the next version will be 0.37.1 (*Jina Dev Bot*)
+
+
+## Release Note (`0.38.0`)
+
+> Release time: 2023-09-07 13:40:16
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ Joan Fontanals, Johannes Messner, samsja, AlaeddineAbdessalem, Jina Dev Bot, 🙇
+
+
+### 🐞 Bug fixes
+
+ - [[```fb174560```](https://github.com/jina-ai/docarray/commit/fb174560aad3b2d554c14cefeef940f8030dfdd2)] __-__ skip doc attributes in __annotations__ but not in __fields__ (#1777) (*Joan Fontanals*)
+ - [[```3dc525f4```](https://github.com/jina-ai/docarray/commit/3dc525f46d8a8771b7f18070beae0c0758371dd6)] __-__ make DocList.to_json() return str instead of bytes (#1769) (*Johannes Messner*)
+ - [[```2af8a0c6```](https://github.com/jina-ai/docarray/commit/2af8a0c60213a46f2b86e4c418bb5d3ef732051c)] __-__ casting in reduce before appending (#1758) (*AlaeddineAbdessalem*)
+
+### 🧼 Code Refactoring
+
+ - [[```08ca686d```](https://github.com/jina-ai/docarray/commit/08ca686dd397b103e6330cd479ad4519126f90b6)] __-__ use safe_issubclass (#1778) (*Joan Fontanals*)
+
+### 📗 Documentation
+
+ - [[```189ff637```](https://github.com/jina-ai/docarray/commit/189ff637e790c59dfea1af3447666b34c9fb9fdf)] __-__ explain how to set document config (#1773) (*Johannes Messner*)
+ - [[```cd4854c9```](https://github.com/jina-ai/docarray/commit/cd4854c9b9e89abc5537be70a5a79d9a8ea47782)] __-__ add workaround for torch compile (#1754) (*Johannes Messner*)
+ - [[```587ab5b3```](https://github.com/jina-ai/docarray/commit/587ab5b39160bdeda1b86d3c09d2296c443cd42e)] __-__ add note about pickling dynamically created doc class (#1763) (*Joan Fontanals*)
+ - [[```61bf9c7a```](https://github.com/jina-ai/docarray/commit/61bf9c7a88a033e0551287fac5f1260fa4d355bf)] __-__ improve filtering docstrings (#1762) (*Joan Fontanals*)
+
+### 🍹 Other Improvements
+
+ - [[```7ec88b46```](https://github.com/jina-ai/docarray/commit/7ec88b46e44b52d0400d1495e56387f367aabd2f)] __-__ update minor (#1781) (*Joan Fontanals*)
+ - [[```cc2339db```](https://github.com/jina-ai/docarray/commit/cc2339db44626e622f3b2f354d0c5f8d8a0b20ea)] __-__ remove pydantic ref from issue template (#1767) (*samsja*)
+ - [[```d5cb02fb```](https://github.com/jina-ai/docarray/commit/d5cb02fbd5cc7392fb92f30c1e7ea436507eb892)] __-__ __version__: the next version will be 0.37.2 (*Jina Dev Bot*)
+
+
+## Release Note (`0.39.0`)
+
+> Release time: 2023-10-02 13:06:02
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ Joan Fontanals, samsja, lvzi, Puneeth K, Jina Dev Bot, 🙇
+
+
+### 🆕 New Features
+
+ - [[```83d2236a```](https://github.com/jina-ai/docarray/commit/83d2236a356bd108e686abae20a06c7fbc12899f)] __-__ enable dynamic doc with Pydantic v2 (#1795) (*Joan Fontanals*)
+ - [[```2a1cc9e4```](https://github.com/jina-ai/docarray/commit/2a1cc9e4c975cef50760c8a459d4be30a4da4116)] __-__ add BaseDocWithoutId (#1803) (*samsja*)
+ - [[```8fba9e45```](https://github.com/jina-ai/docarray/commit/8fba9e45f995f2d65efbea2c837f2f70dfe3e858)] __-__ remove JAC (#1791) (*Joan Fontanals*)
+ - [[```715252a7```](https://github.com/jina-ai/docarray/commit/715252a72177e83cb81736106f34d0ab2960ce56)] __-__ hybrid pydantic support for both v1 and v2 (#1652) (*samsja*)
+
+### 🐞 Bug fixes
+
+ - [[```c2b08fa5```](https://github.com/jina-ai/docarray/commit/c2b08fa5cd9fab30fc4a1fe61d1373e943912fe7)] __-__ docstring tests with pydantic v2 (#1816) (*samsja*)
+ - [[```3da3603b```](https://github.com/jina-ai/docarray/commit/3da3603b6a0f69016504fcbb2cd303d68cf764ec)] __-__ allow config extension in pydantic v2 (#1814) (*samsja*)
+ - [[```4a1bc26a```](https://github.com/jina-ai/docarray/commit/4a1bc26a15dd02bbefa5607519f992a283bae975)] __-__ allow nested model dump via docvec (#1808) (*samsja*)
+ - [[```26d776dd```](https://github.com/jina-ai/docarray/commit/26d776dd4b49ee17d92b9c97b0af9ff4ab5a4cdc)] __-__ validate before (#1806) (*samsja*)
+ - [[```7209b784```](https://github.com/jina-ai/docarray/commit/7209b7849a2f9afd8cb98e01f80591d30ba28ec7)] __-__ fix double subscriptable error (#1800) (*Joan Fontanals*)
+ - [[```2937e253```](https://github.com/jina-ai/docarray/commit/2937e253f8a946872a310b93de5c706279eb5adb)] __-__ make DocList compatible with BaseDocWithoutId (#1805) (*samsja*)
+ - [[```0148e99c```](https://github.com/jina-ai/docarray/commit/0148e99c47ae9e1fc02c5bfc2ee717afa0633748)] __-__ milvus connection para missing (#1802) (*lvzi*)
+ - [[```2f3b85e3```](https://github.com/jina-ai/docarray/commit/2f3b85e333446cfa9b8c4877c4ccf9ae49cae660)] __-__ raise exception when type of DocList is object (#1794) (*Puneeth K*)
+
+### 🧼 Code Refactoring
+
+ - [[```3718a747```](https://github.com/jina-ai/docarray/commit/3718a747c806b87e331903b73747d864ace42725)] __-__ add is_index_empty API (#1801) (*Joan Fontanals*)
+
+### 📗 Documentation
+
+ - [[```061bd81a```](https://github.com/jina-ai/docarray/commit/061bd81aa3307f13281c1c51b0ef79761cd7c5a5)] __-__ fix documentation for pydantic v2 (#1815) (*samsja*)
+ - [[```d0b99909```](https://github.com/jina-ai/docarray/commit/d0b99909052d1640670764015e4e385319fdc776)] __-__ adding field descriptions to predefined mesh 3D document (#1789) (*Puneeth K*)
+ - [[```18d3afce```](https://github.com/jina-ai/docarray/commit/18d3afceb1a9206c1b6f9184e7d720f4f09510c9)] __-__ adding field descriptions to predefined point cloud 3D document (#1792) (*Puneeth K*)
+ - [[```4ef49394```](https://github.com/jina-ai/docarray/commit/4ef493943d1ee73613b51800980239a30fe5ae73)] __-__ adding field descriptions to predefined video document (#1775) (*Puneeth K*)
+ - [[```68cc1423```](https://github.com/jina-ai/docarray/commit/68cc1423058c18d27f13cae3bd307fba5d6e9aaa)] __-__ adding field descriptions to predefined text document (#1770) (*Puneeth K*)
+ - [[```441db26d```](https://github.com/jina-ai/docarray/commit/441db26dc3e58d72b44595131b42aa411c98774d)] __-__ adding field descriptions to predefined image document (#1772) (*Puneeth K*)
+ - [[```35d2138c```](https://github.com/jina-ai/docarray/commit/35d2138c2d7d246b6de87de829dd870d05fc54bf)] __-__ adding field descriptions to predefined audio document (#1774) (*Puneeth K*)
+
+### 🏁 Unit Test and CICD
+
+ - [[```9a6b1e64```](https://github.com/jina-ai/docarray/commit/9a6b1e646e1ce5c97b81413f83b0b5a12a2c4732)] __-__ move the pydantic check inside test (#1812) (*Joan Fontanals*)
+ - [[```92de15e6```](https://github.com/jina-ai/docarray/commit/92de15e6a6b42d4afe34b744e0c4f1ae581861bc)] __-__ remove skip of s3 (#1811) (*Joan Fontanals*)
+ - [[```bfac0939```](https://github.com/jina-ai/docarray/commit/bfac09399a514073bcde6c72535942623ecc0e84)] __-__ remove skips (#1809) (*Joan Fontanals*)
+ - [[```dce39075```](https://github.com/jina-ai/docarray/commit/dce3907560b987b9c5fc956c7ccf0cf38405d798)] __-__ fix test (#1807) (*Joan Fontanals*)
+ - [[```8f32866e```](https://github.com/jina-ai/docarray/commit/8f32866e1cd3aa139a88cfca2beaa502656dc76b)] __-__ remove skipif for pydantic (#1796) (*Joan Fontanals*)
+
+### 🍹 Other Improvements
+
+ - [[```7693cf7c```](https://github.com/jina-ai/docarray/commit/7693cf7c27b94d616bae37a183cc9c734bc285ee)] __-__ update version to 0.39.0 (#1818) (*Joan Fontanals*)
+ - [[```a4fdb77d```](https://github.com/jina-ai/docarray/commit/a4fdb77db92af2e49b8a9680950439f9ca5c1870)] __-__ fix failing test (#1793) (*Joan Fontanals*)
+ - [[```805a9825```](https://github.com/jina-ai/docarray/commit/805a9825fd59848bb205461e9da71934395c0768)] __-__ __version__: the next version will be 0.38.1 (*Jina Dev Bot*)
+
+
+## Release Note (`0.39.1`)
+
+> Release time: 2023-10-23 08:56:38
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ Joan Fontanals, Johannes Messner, dependabot[bot], Jina Dev Bot, 🙇
+
+
+### 🐞 Bug fixes
+
+ - [[```98d1f1f0```](https://github.com/jina-ai/docarray/commit/98d1f1f0a61851883c9598ce17d8f58594939c26)] __-__ from_dataframe with numpy==1.26.1 and type handling in python 3.9 (#1823) (*Johannes Messner*)
+
+### 🍹 Other Improvements
+
+ - [[```6094854a```](https://github.com/jina-ai/docarray/commit/6094854a5c113d58e55b98f778624766ac1c82f6)] __-__ update version before patch release (#1826) (*Joan Fontanals*)
+ - [[```7479f59a```](https://github.com/jina-ai/docarray/commit/7479f59a69616256cf61679a5a3246f376c22af0)] __-__ __deps__: bump pillow from 9.3.0 to 10.0.1 (#1819) (*dependabot[bot]*)
+ - [[```08bfa9cf```](https://github.com/jina-ai/docarray/commit/08bfa9cfae4d23bed2cd794f67fc5581a0f33133)] __-__ __version__: the next version will be 0.39.1 (*Jina Dev Bot*)
+
+
+## Release Note (`0.40.0`)
+
+> Release time: 2023-12-22 12:12:15
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ 954, Joan Fontanals, Tony Yang, Naymul Islam, Ben Shaver, Jina Dev Bot, 🙇
+
+
+### 🆕 New Features
+
+ - [[```ff00b604```](https://github.com/jina-ai/docarray/commit/ff00b6049f5f50bae4786f310907424b45791104)] __-__ __index__: add epsilla connector (#1835) (*Tony Yang*)
+ - [[```522811f4```](https://github.com/jina-ai/docarray/commit/522811f4b47e1c0f30fe13bb84c7625e349d0656)] __-__ use literal in type hints (#1827) (*Ben Shaver*)
+
+### 🐞 Bug fixes
+
+ - [[```1f86e263```](https://github.com/jina-ai/docarray/commit/1f86e263effaeab61f9c9e42becd37622595cd96)] __-__ error type hints in Python3.12 (#1147) (#1840) (*954*)
+ - [[```21e107bd```](https://github.com/jina-ai/docarray/commit/21e107bdaaae319c728c141a076d44738b7ec32e)] __-__ fix issue serializing deserializing complex schemas (#1836) (*Joan Fontanals*)
+ - [[```3cfa0b8f```](https://github.com/jina-ai/docarray/commit/3cfa0b8ff877d95cef0637f7f177499f0a9c6cfd)] __-__ fix storage issue in torchtensor class (#1833) (*Naymul Islam*)
+
+### 📗 Documentation
+
+ - [[```a2421a6a```](https://github.com/jina-ai/docarray/commit/a2421a6a86e4e42a10771e7070be7932caeb1d33)] __-__ __epsilla__: add epsilla integration guide (#1838) (*Tony Yang*)
+ - [[```82918fe7```](https://github.com/jina-ai/docarray/commit/82918fe7b6207ac112e096f88cccc71d80fc0afe)] __-__ fix sign commit command in docs (#1834) (*Naymul Islam*)
+
+### 🍹 Other Improvements
+
+ - [[```0e183ff0```](https://github.com/jina-ai/docarray/commit/0e183ff0d48555b56fa34989513ac4fb53135626)] __-__ upgrade version (#1841) (*Joan Fontanals*)
+ - [[```8de3e175```](https://github.com/jina-ai/docarray/commit/8de3e1757bdb23b509ad2630219c3c26605308f0)] __-__ refactor test of the torchtensor (#1837) (*Naymul Islam*)
+ - [[```d5d928b8```](https://github.com/jina-ai/docarray/commit/d5d928b82f36a3279277c07bed44fd22bb0bba34)] __-__ __version__: the next version will be 0.39.2 (*Jina Dev Bot*)
+
+
+## Release Note (`0.40.1`)
+
+> Release time: 2025-03-21 08:34:40
+
+
+
+🙇 We'd like to thank all contributors for this new release! In particular,
+ Joan Fontanals, Emmanuel Ferdman, Casey Clements, YuXuan Tay, dependabot[bot], James Brown, Jina Dev Bot, 🙇
+
+
+### 🐞 Bug fixes
+
+ - [[```d98acb71```](https://github.com/jina-ai/docarray/commit/d98acb716e0c336a817f65b62d428ab13cf8ac42)] __-__ fix DocList schema when using Pydantic V2 (#1876) (*Joan Fontanals*)
+ - [[```83ebef60```](https://github.com/jina-ai/docarray/commit/83ebef6087e868517681e59877008f80f1e7f113)] __-__ update license location (#1911) (*Emmanuel Ferdman*)
+ - [[```8f4ba7cd```](https://github.com/jina-ai/docarray/commit/8f4ba7cdf177f3e4ecc838eef659496d6038af03)] __-__ use docker compose (#1905) (*YuXuan Tay*)
+ - [[```febbdc42```](https://github.com/jina-ai/docarray/commit/febbdc4291c4af7ad2058d7feebf6a3169de93e9)] __-__ fix float in dynamic Document creation (#1877) (*Joan Fontanals*)
+ - [[```7c1e18ef```](https://github.com/jina-ai/docarray/commit/7c1e18ef01b09ef3d864b200248c875d0d9ced29)] __-__ fix create pure python class iteratively (#1867) (*Joan Fontanals*)
+
+### 📗 Documentation
+
+ - [[```e4665e91```](https://github.com/jina-ai/docarray/commit/e4665e91b37f97a4a18a80399431d624db8ca453)] __-__ move hint about schemas to common docindex section (#1868) (*Joan Fontanals*)
+ - [[```8da50c92```](https://github.com/jina-ai/docarray/commit/8da50c927c24b981867650399f64d4930bd7c574)] __-__ add code review to contributing.md (#1853) (*Joan Fontanals*)
+
+### 🏁 Unit Test and CICD
+
+ - [[```a162a4b0```](https://github.com/jina-ai/docarray/commit/a162a4b09f4ad8e8c5c117c0c0101541af4c00a1)] __-__ fix release procedure (#1922) (*Joan Fontanals*)
+ - [[```82d7cee7```](https://github.com/jina-ai/docarray/commit/82d7cee71ccdd4d5874985aef0567631424b5bfd)] __-__ fix some ci (#1893) (*Joan Fontanals*)
+ - [[```791e4a04```](https://github.com/jina-ai/docarray/commit/791e4a0473afe9d9bde87733074eef0ce217d198)] __-__ update release procedure (#1869) (*Joan Fontanals*)
+ - [[```aa15b9ef```](https://github.com/jina-ai/docarray/commit/aa15b9eff2f5293849e83291d79bf519994c3503)] __-__ add license (#1861) (*Joan Fontanals*)
+
+### 🍹 Other Improvements
+
+ - [[```b5696b22```](https://github.com/jina-ai/docarray/commit/b5696b227161f087fa32834dcd6c2d212cf82c0e)] __-__ fix poetry in ci (#1921) (*Joan Fontanals*)
+ - [[```d3358105```](https://github.com/jina-ai/docarray/commit/d3358105db645418c3cebfc6acb0f353127364aa)] __-__ update pyproject version (#1919) (*Joan Fontanals*)
+ - [[```40cf2962```](https://github.com/jina-ai/docarray/commit/40cf29622b29be1f32595e26876593bb5f1e03be)] __-__ MongoDB Atlas: Two line change to make our CI builds green (#1910) (*Casey Clements*)
+ - [[```75e0033a```](https://github.com/jina-ai/docarray/commit/75e0033a361a31280709899e94d6f5e14ff4b8ae)] __-__ __deps__: bump setuptools from 65.5.1 to 70.0.0 (#1899) (*dependabot[bot]*)
+ - [[```75a743c9```](https://github.com/jina-ai/docarray/commit/75a743c99dc549eaf4c3ffe01086d09a8f3f3e44)] __-__ __deps-dev__: bump tornado from 6.2 to 6.4.1 (#1894) (*dependabot[bot]*)
+ - [[```f3fa7c23```](https://github.com/jina-ai/docarray/commit/f3fa7c2376da2449e98aff159167bf41467d610c)] __-__ __deps__: bump pydantic from 1.10.8 to 1.10.13 (#1884) (*dependabot[bot]*)
+ - [[```46d50828```](https://github.com/jina-ai/docarray/commit/46d5082844602689de97c904af7c8139980711ed)] __-__ __deps__: bump urllib3 from 1.26.14 to 1.26.19 (#1896) (*dependabot[bot]*)
+ - [[```f0f4236e```](https://github.com/jina-ai/docarray/commit/f0f4236ebf75528e6c5344dc75328ce9cf56cae9)] __-__ __deps__: bump zipp from 3.10.0 to 3.19.1 (#1898) (*dependabot[bot]*)
+ - [[```d65d27ce```](https://github.com/jina-ai/docarray/commit/d65d27ce37f5e7c930b7792fd665ac4da9c6398d)] __-__ __deps__: bump certifi from 2022.9.24 to 2024.7.4 (#1897) (*dependabot[bot]*)
+ - [[```b8b62173```](https://github.com/jina-ai/docarray/commit/b8b621735dbe16c188bf8c1c03cb3f1a22076ae8)] __-__ __deps__: bump authlib from 1.2.0 to 1.3.1 (#1895) (*dependabot[bot]*)
+ - [[```6a972d1c```](https://github.com/jina-ai/docarray/commit/6a972d1c0dcf6d0c2816dea14df37e0039945542)] __-__ __deps__: bump qdrant-client from 1.4.0 to 1.9.0 (#1892) (*dependabot[bot]*)
+ - [[```f71a5e6a```](https://github.com/jina-ai/docarray/commit/f71a5e6af58b77fdeb15ba27abd0b7d40b84fd09)] __-__ __deps__: bump cryptography from 40.0.1 to 42.0.4 (#1872) (*dependabot[bot]*)
+ - [[```065aab44```](https://github.com/jina-ai/docarray/commit/065aab441cd71635ee3711ad862240e967ca3da6)] __-__ __deps__: bump orjson from 3.8.2 to 3.9.15 (#1873) (*dependabot[bot]*)
+ - [[```caf97135```](https://github.com/jina-ai/docarray/commit/caf9713502791a8fbbf0aa53b3ca2db126f18df7)] __-__ add license notice to every file (#1860) (*Joan Fontanals*)
+ - [[```50376358```](https://github.com/jina-ai/docarray/commit/50376358163005e66a76cd0cb40217fd7a4f1252)] __-__ __deps-dev__: bump jupyterlab from 3.5.0 to 3.6.7 (#1848) (*dependabot[bot]*)
+ - [[```104b403b```](https://github.com/jina-ai/docarray/commit/104b403b2b61a485e2cc032a357f46e7dc8044fe)] __-__ __deps__: bump tj-actions/changed-files from 34 to 41 in /.github/workflows (#1844) (*dependabot[bot]*)
+ - [[```f9426a29```](https://github.com/jina-ai/docarray/commit/f9426a29b29580beae8805d2556b4a94ff493edc)] __-__ __version__: the next version will be 0.40.1 (*Jina Dev Bot*)
+
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 00000000000..25103e63317
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,128 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, religion, or sexual identity
+and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+ and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the
+ overall community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or
+ advances of any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email
+ address, without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+docarray@jina.ai.
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series
+of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or
+permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within
+the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.0, available at
+https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla's code of conduct
+enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index c0ae7d78cc6..4c153ae2c54 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -14,6 +14,7 @@ In this guide, we're going to go through the steps for each kind of contribution
- [➕ Adding a dependency](#adding-a-dependency)
- [💥 Testing DocArray Locally and on CI](#-testing-docarray-locally-and-on-ci)
- [📖 Contributing Documentation](#-contributing-documentation)
+- [✅ Code Review](#-code-review)
- [🙏 Thank You](#-thank-you)
@@ -210,7 +211,7 @@ Commits need to be signed. Indeed, the DocArray repo enforces the [Developer Cer
To sign your commits you need to [use the `-s` argument](https://docs.github.com/en/authentication/managing-commit-signature-verification/signing-commits) when committing:
```
-git commit -m -s 'feat: add a new feature'
+git commit -S -m 'feat: add a new feature'
```
#### What if I mess up?
@@ -321,6 +322,16 @@ Good docs make developers happy, and we love happy developers! We've got a few d
* Tutorials/examples
* Docstrings in Python functions in RST format - generated by Sphinx
+## ✅ Code Review
+
+Reviewing Pull Requests is also a great way to contribute to the project. When doing code review, please be mindful of the author and the effort they are putting into the contribution. Look for and suggest improvements without disparaging or insulting the author. Provide actionable feedback and explain your reasoning.
+
+* Try to check that the guidelines specified in this document are followed.
+
+* Check that new tests cover the new or changed features introduced by the Pull Request.
+
+* Check that documentation changes follow the standards of quality and describe the features clearly.
+
### Documentation guidelines
1. Decide if your page is a **user guide or a how-to**, like in the `Data Types` section. Make sure it fits its section.
@@ -337,29 +348,32 @@ Good docs make developers happy, and we love happy developers! We've got a few d
### Building documentation on your local machine
-#### Requirements
-
-* Python 3
-* [jq](https://stedolan.github.io/jq/download/)
#### Steps to build locally
+First, install the documentation dependencies:
+```
+poetry install --with docs
+```
+
+Note: if you need any extras (proto, database, ...), you need to specify those as well.
+
+Then build the documentation:
```bash
cd docs
-pip install -r requirements.txt
-export NUM_RELEASES=10
-bash makedoc.sh
+./makedoc.sh
```
The docs website will be generated in `site`.
To serve it, run:
```bash
-mkdocs serve
-python -m http.server
+cd ..
+poetry run mkdocs serve
```
You can now see docs website on [http://localhost:8000](http://localhost:8000) on your browser.
+Note: You may have to change the port from 8000 to something else if you already have a server running on that port.
## 🙏 Thank you
diff --git a/README.md b/README.md
index c3df22decb4..1c4e27f989d 100644
--- a/README.md
+++ b/README.md
@@ -6,54 +6,63 @@
-
+
-> ⬆️ **DocArray v2**: This readme is for the second version of DocArray (starting at 0.30). If you want to use the older
-> version (prior to 0.30) check out the [docarray-v1-fixes](https://github.com/docarray/docarray/tree/docarray-v1-fixes) branch
+> **Note**
+> The README you're currently viewing is for DocArray >=0.30, which introduces some significant changes from DocArray 0.21. If you wish to continue using the older DocArray <=0.21, ensure you install it via `pip install docarray==0.21`. Refer to its [codebase](https://github.com/docarray/docarray/tree/v0.21.0), [documentation](https://docarray.jina.ai), and [its hot-fixes branch](https://github.com/docarray/docarray/tree/docarray-v1-fixes) for more information.
-DocArray is a library for **representing, sending and storing multi-modal data**, perfect for **Machine Learning applications**.
-With DocArray you can:
+DocArray is a Python library expertly crafted for the [representation](#represent), [transmission](#send), [storage](#store), and [retrieval](#retrieve) of multimodal data. Tailored for the development of multimodal AI applications, its design guarantees seamless integration with the extensive Python and machine learning ecosystems. As of January 2022, DocArray is openly distributed under the [Apache License 2.0](https://github.com/docarray/docarray/blob/main/LICENSE.md) and currently enjoys the status of a sandbox project within the [LF AI & Data Foundation](https://lfaidata.foundation/).
-1. [**Represent data**](#represent)
-2. [**Send data**](#send)
-3. [**Store data**](#store)
-DocArray handles your data while integrating seamlessly with the rest of your **Python and ML ecosystem**:
-- :fire: Native compatibility for **[NumPy](https://github.com/numpy/numpy)**, **[PyTorch](https://github.com/pytorch/pytorch)** and **[TensorFlow](https://github.com/tensorflow/tensorflow)**, including for **model training use cases**
-- :zap: Built on **[Pydantic](https://github.com/pydantic/pydantic)** and out-of-the-box compatible with **[FastAPI](https://github.com/tiangolo/fastapi/)** and **[Jina](https://github.com/jina-ai/jina/)**
-- :package: Support for vector databases like **[Weaviate](https://weaviate.io/), [Qdrant](https://qdrant.tech/), [ElasticSearch](https://www.elastic.co/de/elasticsearch/)** and **[HNSWLib](https://github.com/nmslib/hnswlib)**
-- :chains: Send data as JSON over **HTTP** or as **[Protobuf](https://protobuf.dev/)** over **[gRPC](https://grpc.io/)**
+- :fire: Offers native support for **[NumPy](https://github.com/numpy/numpy)**, **[PyTorch](https://github.com/pytorch/pytorch)**, **[TensorFlow](https://github.com/tensorflow/tensorflow)**, and **[JAX](https://github.com/google/jax)**, catering specifically to **model training scenarios**.
+- :zap: Based on **[Pydantic](https://github.com/pydantic/pydantic)**, and instantly compatible with web and microservice frameworks like **[FastAPI](https://github.com/tiangolo/fastapi/)** and **[Jina](https://github.com/jina-ai/jina/)**.
+- :package: Provides support for vector databases such as **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/de/elasticsearch/)**, **[Redis](https://redis.io/)**, **[Mongo Atlas](https://www.mongodb.com/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**.
+- :chains: Allows data transmission as JSON over **HTTP** or as **[Protobuf](https://protobuf.dev/)** over **[gRPC](https://grpc.io/)**.
-> :bulb: **Where are you coming from?** Based on your use case and background, there are different ways to understand DocArray:
->
-> - [Coming from pure PyTorch or TensorFlow](#coming-from-pytorch)
-> - [Coming from Pydantic](#coming-from-pydantic)
-> - [Coming from FastAPI](#coming-from-fastapi)
-> - [Coming from a vector database](#coming-from-vector-database)
+## Installation
+
+To install DocArray from the CLI, run the following command:
+
+```shell
+pip install -U docarray
+```
+
+> **Note**
+> To use DocArray <=0.21, make sure you install via `pip install docarray==0.21` and check out its [codebase](https://github.com/docarray/docarray/tree/v0.21.0) and [docs](https://docarray.jina.ai) and [its hot-fixes branch](https://github.com/docarray/docarray/tree/docarray-v1-fixes).
+
+## Get Started
+New to DocArray? Depending on your use case and background, there are multiple ways to learn about DocArray:
+
+- [Coming from pure PyTorch or TensorFlow](#coming-from-pytorch)
+- [Coming from Pydantic](#coming-from-pydantic)
+- [Coming from FastAPI](#coming-from-fastapi)
+- [Coming from Jina](#coming-from-jina)
+- [Coming from a vector database](#coming-from-a-vector-database)
+- [Coming from Langchain](#coming-from-langchain)
-DocArray was released under the open-source [Apache License 2.0](https://github.com/docarray/docarray/blob/main/LICENSE) in January 2022. It is currently a sandbox project under [LF AI & Data Foundation](https://lfaidata.foundation/).
## Represent
-DocArray allows you to **represent your data**, in an ML-native way.
+DocArray empowers you to **represent your data** in a manner that is inherently attuned to machine learning.
+
+This is particularly beneficial for various scenarios:
-This is useful for different use cases:
+- :running: You are **training a model**: You're dealing with tensors of varying shapes and sizes, each signifying different elements. You desire a method to logically organize them.
+- :cloud: You are **serving a model**: Let's say through FastAPI, and you wish to define your API endpoints precisely.
+- :card_index_dividers: You are **parsing data**: Perhaps for future deployment in your machine learning or data science projects.
-- :running: You are **training a model**: There are tensors of different shapes and sizes flying around, representing different _things_, and you want to keep a straight head about them.
-- :cloud: You are **serving a model**: For example through FastAPI, and you want to specify your API endpoints.
-- :card_index_dividers: You are **parsing data**: For later use in your ML or data science applications.
+> :bulb: **Familiar with Pydantic?** You'll be pleased to learn
+> that DocArray is not only constructed atop Pydantic but also maintains complete compatibility with it!
+> Furthermore, we have a [specific section](#coming-from-pydantic) dedicated to your needs!
-> :bulb: **Coming from Pydantic?** You should be happy to hear
-> that DocArray is built on top of, and is fully compatible with, Pydantic!
-> Also, we have a [dedicated section](#coming-from-pydantic) just for you!
+In essence, DocArray facilitates data representation in a way that mirrors Python dataclasses, with machine learning being an integral component:
-Put simply, DocArray lets you represent your data in a dataclass-like way, with ML as a first class citizen:
```python
from docarray import BaseDoc
@@ -102,14 +111,12 @@ class MyDocument(BaseDoc):
image_url: ImageUrl # could also be VideoUrl, AudioUrl, etc.
image_tensor: Optional[
TorchTensor[1704, 2272, 3]
- ] # could also be NdArray or TensorflowTensor
- embedding: Optional[TorchTensor]
+ ] = None # could also be NdArray or TensorflowTensor
+ embedding: Optional[TorchTensor] = None
```
So not only can you define the types of your data, you can even **specify the shape of your tensors!**
-Once you have your model in the form of a document, you can work with it!
-
```python
# Create a document
doc = MyDocument(
@@ -120,6 +127,7 @@ doc = MyDocument(
# Load image tensor from URL
doc.image_tensor = doc.image_url.load()
+
# Compute embedding with any model of your choice
def clip_image_encoder(image_tensor: TorchTensor) -> TorchTensor: # dummy function
return torch.rand(512)
@@ -256,21 +264,22 @@ assert isinstance(dl_2, DocList)
## Send
-DocArray allows you to **send your data** in an ML-native way.
+DocArray facilitates the **transmission of your data** in a manner inherently compatible with machine learning.
+
+This includes native support for **Protobuf and gRPC**, along with **HTTP** and serialization to JSON, JSONSchema, Base64, and Bytes.
-This means there is native support for **Protobuf and gRPC**, on top of **HTTP** and serialization to JSON, JSONSchema, Base64, and Bytes.
+This feature proves beneficial for several scenarios:
-This is useful for different use cases:
+- :cloud: You are **serving a model**, perhaps through frameworks like **[Jina](https://github.com/jina-ai/jina/)** or **[FastAPI](https://github.com/tiangolo/fastapi/)**
+- :spider_web: You are **distributing your model** across multiple machines and need an efficient means of transmitting your data between nodes
+- :gear: You are architecting a **microservice** environment and require a method for data transmission between microservices
-- :cloud: You are **serving a model**, for example through **[Jina](https://github.com/jina-ai/jina/)** or **[FastAPI](https://github.com/tiangolo/fastapi/)**
-- :spider_web: You are **distributing your model** across machines and need to send your data between nodes
-- :gear: You are building a **microservice** architecture and need to send your data between microservices
+> :bulb: **Are you familiar with FastAPI?** You'll be delighted to learn
+> that DocArray maintains full compatibility with FastAPI!
+> Plus, we have a [dedicated section](#coming-from-fastapi) specifically for you!
-> :bulb: **Coming from FastAPI?** You should be happy to hear
-> that DocArray is fully compatible with FastAPI!
-> Also, we have a [dedicated section](#coming-from-fastapi) just for you!
+When it comes to data transmission, serialization is a crucial step. Let's delve into how DocArray streamlines this process:
-Whenever you want to send your data, you need to serialize it, so let's take a look at how that works with DocArray:
```python
from docarray import BaseDoc
@@ -301,22 +310,18 @@ doc_4 = MyDocument.from_bytes(bytes_)
doc_5 = MyDocument.parse_raw(json)
```
-Of course, serialization is not all you need. So check out how DocArray integrates with FastAPI and Jina.
+Of course, serialization is not all you need. So check out how DocArray integrates with **[Jina](https://github.com/jina-ai/jina/)** and **[FastAPI](https://github.com/tiangolo/fastapi/)**.
## Store
-Once you've modelled your data, and maybe sent it around, usually you want to **store it** somewhere.
-DocArray has you covered!
+After modeling and possibly distributing your data, you'll typically want to **store it** somewhere. That's where DocArray steps in!
-**Document Stores** let you, well, store your Documents, locally or remotely, all with the same user interface:
+**Document Stores** provide a seamless way to, as the name suggests, store your Documents. Be it locally or remotely, you can do it all through the same user interface:
-- :cd: **On disk** as a file in your local file system
+- :cd: **On disk**, as a file in your local filesystem
- :bucket: On **[AWS S3](https://aws.amazon.com/de/s3/)**
- :cloud: On **[Jina AI Cloud](https://cloud.jina.ai/)**
-
- See Document Store usage
-
The Document Store interface lets you push and pull Documents to and from multiple data sources, all with the same user interface.
For example, let's see how that works with on-disk storage:
@@ -334,7 +339,8 @@ docs.push('file://simple_docs')
docs_pull = DocList[SimpleDoc].pull('file://simple_docs')
```
-
+
+## Retrieve
**Document Indexes** let you index your Documents in a **vector database** for efficient similarity-based retrieval.
@@ -344,10 +350,7 @@ This is useful for:
- :mag: **Neural search** applications
- :bulb: **Recommender systems**
-Currently, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come!
-
-
- See Document Index usage
+Currently, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, **[Mongo Atlas](https://www.mongodb.com/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come!
The Document Index interface lets you index and retrieve Documents from multiple vector databases, all with the same user interface.
@@ -391,18 +394,20 @@ query = dl[0]
results, scores = index.find(query, limit=10, search_field='embedding')
```
-
+---
+
+## Learn DocArray
Depending on your background and use case, there are different ways for you to understand DocArray.
-## Coming from old DocArray
+### Coming from DocArray <=0.21
Click to expand
If you are using DocArray version 0.30.0 or lower, you will be familiar with its [dataclass API](https://docarray.jina.ai/fundamentals/dataclass/).
-_DocArray v2 is that idea, taken seriously._ Every document is created through a dataclass-like interface,
+_DocArray >=0.30 is that idea, taken seriously._ Every document is created through a dataclass-like interface,
courtesy of [Pydantic](https://pydantic-docs.helpmanual.io/usage/models/).
This gives the following advantages:
@@ -416,11 +421,11 @@ They are now called **Document Indexes** and offer the following improvements (s
- **Production-ready:** The new Document Indexes are a much thinner wrapper around the various vector DB libraries, making them more robust and easier to maintain
- **Increased flexibility:** We strive to support any configuration or setting that you could perform through the DB's first-party client
-For now, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come.
+For now, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, **[Mongo Atlas](https://www.mongodb.com/)**, Exact Nearest Neighbour search and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come.
-## Coming from Pydantic
+### Coming from Pydantic
Click to expand
@@ -497,7 +502,7 @@ except Exception as e:
-## Coming from PyTorch
+### Coming from PyTorch
Click to expand
@@ -511,7 +516,7 @@ It offers you several advantages:
- **Go directly to deployment**, by re-using your data model as a [FastAPI](https://fastapi.tiangolo.com/) or [Jina](https://github.com/jina-ai/jina) API schema
- Connect model components between **microservices**, using Protobuf and gRPC
-DocArray can be used directly inside ML models to handle and represent multi-modal data.
+DocArray can be used directly inside ML models to handle and represent multimodal data.
This allows you to reason about your data using DocArray's abstractions deep inside of `nn.Module`,
and provides a FastAPI-compatible schema that eases the transition between model training and model serving.
@@ -520,7 +525,6 @@ To see the effect of this, let's first observe a vanilla PyTorch implementation
```python
import torch
from torch import nn
-import torch
def encoder(x):
@@ -609,7 +613,7 @@ schema definition (see [below](#coming-from-fastapi)). Everything is handled in
-## Coming from TensorFlow
+### Coming from TensorFlow
Click to expand
@@ -619,7 +623,7 @@ Like the [PyTorch approach](#coming-from-pytorch), you can also use DocArray wit
First off, to use DocArray with TensorFlow we first need to install it as follows:
```
-pip install tensorflow==2.11.0
+pip install tensorflow==2.12.0
pip install protobuf==3.19.0
```
@@ -639,8 +643,8 @@ import tensorflow as tf
class Podcast(BaseDoc):
- audio_tensor: Optional[AudioTensorFlowTensor]
- embedding: Optional[AudioTensorFlowTensor]
+ audio_tensor: Optional[AudioTensorFlowTensor] = None
+ embedding: Optional[AudioTensorFlowTensor] = None
class MyPodcastModel(tf.keras.Model):
@@ -657,7 +661,7 @@ class MyPodcastModel(tf.keras.Model):
-## Coming from FastAPI
+### Coming from FastAPI
Click to expand
@@ -675,16 +679,15 @@ And to seal the deal, let us show you how easily documents slot into your FastAP
```python
import numpy as np
from fastapi import FastAPI
-from httpx import AsyncClient
-
+from docarray.base_doc import DocArrayResponse
from docarray import BaseDoc
from docarray.documents import ImageDoc
-from docarray.typing import NdArray
-from docarray.base_doc import DocArrayResponse
+from docarray.typing import NdArray, ImageTensor
class InputDoc(BaseDoc):
img: ImageDoc
+ text: str
class OutputDoc(BaseDoc):
@@ -692,31 +695,100 @@ class OutputDoc(BaseDoc):
embedding_bert: NdArray
-input_doc = InputDoc(img=ImageDoc(tensor=np.zeros((3, 224, 224))))
-
app = FastAPI()
-@app.post("/doc/", response_model=OutputDoc, response_class=DocArrayResponse)
+def model_img(img: ImageTensor) -> NdArray:
+ return np.zeros((100, 1))
+
+
+def model_text(text: str) -> NdArray:
+ return np.zeros((100, 1))
+
+
+@app.post("/embed/", response_model=OutputDoc, response_class=DocArrayResponse)
async def create_item(doc: InputDoc) -> OutputDoc:
- ## call my fancy model to generate the embeddings
doc = OutputDoc(
- embedding_clip=np.zeros((100, 1)), embedding_bert=np.zeros((100, 1))
+ embedding_clip=model_img(doc.img.tensor), embedding_bert=model_text(doc.text)
)
return doc
+input_doc = InputDoc(text='', img=ImageDoc(tensor=np.random.random((3, 224, 224))))
+
async with AsyncClient(app=app, base_url="http://test") as ac:
- response = await ac.post("/doc/", data=input_doc.json())
- resp_doc = await ac.get("/docs")
- resp_redoc = await ac.get("/redoc")
+ response = await ac.post("/embed/", data=input_doc.json())
```
Just like a vanilla Pydantic model!
-## Coming from a vector database
+### Coming from Jina
+
+
+ Click to expand
+
+Jina has adopted DocArray as its library for representing and serializing Documents.
+
+Jina lets you serve models and services that are built with DocArray, so you can deploy and scale these applications
+while making full use of DocArray's serialization capabilities.
+
+```python
+import numpy as np
+from jina import Deployment, Executor, requests
+from docarray import BaseDoc, DocList
+from docarray.documents import ImageDoc
+from docarray.typing import NdArray, ImageTensor
+
+
+class InputDoc(BaseDoc):
+ img: ImageDoc
+ text: str
+
+
+class OutputDoc(BaseDoc):
+ embedding_clip: NdArray
+ embedding_bert: NdArray
+
+
+def model_img(img: ImageTensor) -> NdArray:
+ return np.zeros((100, 1))
+
+
+def model_text(text: str) -> NdArray:
+ return np.zeros((100, 1))
+
+
+class MyEmbeddingExecutor(Executor):
+ @requests(on='/embed')
+ def encode(self, docs: DocList[InputDoc], **kwargs) -> DocList[OutputDoc]:
+ ret = DocList[OutputDoc]()
+ for doc in docs:
+ output = OutputDoc(
+ embedding_clip=model_img(doc.img.tensor),
+ embedding_bert=model_text(doc.text),
+ )
+ ret.append(output)
+ return ret
+
+
+with Deployment(
+ protocols=['grpc', 'http'], ports=[12345, 12346], uses=MyEmbeddingExecutor
+) as dep:
+ resp = dep.post(
+ on='/embed',
+ inputs=DocList[InputDoc](
+ [InputDoc(text='', img=ImageDoc(tensor=np.random.random((3, 224, 224))))]
+ ),
+ return_type=DocList[OutputDoc],
+ )
+ print(resp)
+```
+
+
+
+### Coming from a vector database
Click to expand
@@ -768,31 +840,111 @@ Currently, DocArray supports the following vector databases:
- [Weaviate](https://www.weaviate.io/)
- [Qdrant](https://qdrant.tech/)
- [Elasticsearch](https://www.elastic.co/elasticsearch/) v8 and v7
-- [HNSWlib](https://github.com/nmslib/hnswlib) as a local-first alternative
+- [Redis](https://redis.io/)
+- [Milvus](https://milvus.io)
+- `InMemoryExactNNIndex` as a local alternative with exact kNN search.
+- [HNSWlib](https://github.com/nmslib/hnswlib) as a local-first ANN alternative
+- [Mongo Atlas](https://www.mongodb.com/)
An integration of [OpenSearch](https://opensearch.org/) is currently in progress.
-Legacy versions of DocArray also support [Redis](https://redis.io/) and [Milvus](https://milvus.io/), but these are not yet supported in the current version.
-
Of course this is only one of the things that DocArray can do, so we encourage you to check out the rest of this readme!
-## Installation
+### Coming from Langchain
-To install DocArray from the CLI, run the following command:
+
+ Click to expand
+With DocArray, you can connect external data to LLMs through Langchain. DocArray gives you the freedom to establish
+flexible document schemas and choose from different backends for document storage.
+After creating your document index, you can connect it to your Langchain app using [DocArrayRetriever](https://python.langchain.com/docs/modules/data_connection/retrievers/integrations/docarray_retriever).
+
+Install Langchain via:
```shell
-pip install -U docarray
+pip install langchain
+```
+
+1. Define a schema and create documents:
+```python
+from docarray import BaseDoc, DocList
+from docarray.typing import NdArray
+from langchain.embeddings.openai import OpenAIEmbeddings
+
+embeddings = OpenAIEmbeddings()
+
+
+# Define a document schema
+class MovieDoc(BaseDoc):
+ title: str
+ description: str
+ year: int
+ embedding: NdArray[1536]
+
+
+movies = [
+ {"title": "#1 title", "description": "#1 description", "year": 1999},
+ {"title": "#2 title", "description": "#2 description", "year": 2001},
+]
+
+# Embed `description` and create documents
+docs = DocList[MovieDoc](
+ MovieDoc(embedding=embeddings.embed_query(movie["description"]), **movie)
+ for movie in movies
+)
+```
+
+2. Initialize a document index using any supported backend:
+```python
+from docarray.index import (
+ InMemoryExactNNIndex,
+ HnswDocumentIndex,
+ WeaviateDocumentIndex,
+ QdrantDocumentIndex,
+ ElasticDocIndex,
+ RedisDocumentIndex,
+ MongoDBAtlasDocumentIndex,
+)
+
+# Select a suitable backend and initialize it with data
+db = InMemoryExactNNIndex[MovieDoc](docs)
```
+3. Finally, initialize a retriever and integrate it into your chain!
+```python
+from langchain.chat_models import ChatOpenAI
+from langchain.chains import ConversationalRetrievalChain
+from langchain.retrievers import DocArrayRetriever
+
+
+# Create a retriever
+retriever = DocArrayRetriever(
+ index=db,
+ embeddings=embeddings,
+ search_field="embedding",
+ content_field="description",
+)
+
+# Use the retriever in your chain
+model = ChatOpenAI()
+qa = ConversationalRetrievalChain.from_llm(model, retriever=retriever)
+```
+
+Alternatively, you can use built-in vector stores. Langchain supports two vector stores: [DocArrayInMemorySearch](https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/docarray_in_memory) and [DocArrayHnswSearch](https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/docarray_hnsw).
+Both are user-friendly and are best suited to small to medium-sized datasets.
+
+
+
+
## See also
-- [Documentation](https://docarray.jina.ai)
+- [Documentation](https://docs.docarray.org)
+- [DocArray<=0.21 documentation](https://docarray.jina.ai/)
- [Join our Discord server](https://discord.gg/WaMp6PVPgR)
- [Donation to Linux Foundation AI&Data blog post](https://jina.ai/news/donate-docarray-lf-for-inclusive-standard-multimodal-data-model/)
-- ["Legacy" DocArray github page](https://github.com/docarray/docarray/tree/docarray-v1-fixes)
-- ["Legacy" DocArray documentation](https://docarray.jina.ai/)
+- [Roadmap](https://github.com/docarray/docarray/issues/1714)
> DocArray is a trademark of LF AI Projects, LLC
+>
diff --git a/docarray/__init__.py b/docarray/__init__.py
index 82c69839e57..20b08ba1735 100644
--- a/docarray/__init__.py
+++ b/docarray/__init__.py
@@ -1,10 +1,79 @@
-__version__ = '0.31.2'
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+__version__ = '0.40.2'
import logging
from docarray.array import DocList, DocVec
from docarray.base_doc.doc import BaseDoc
from docarray.utils._internal.misc import _get_path_from_docarray_root_level
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+
+def unpickle_doclist(doc_type, b):
+ return DocList[doc_type].from_bytes(b, protocol="protobuf")
+
+
+def unpickle_docvec(doc_type, tensor_type, b):
+ return DocVec[doc_type].from_bytes(b, protocol="protobuf", tensor_type=tensor_type)
+
+
+if is_pydantic_v2:
+ # Register the pickle functions
+ def register_serializers():
+ import copyreg
+ from functools import partial
+
+ unpickle_doc_fn = partial(BaseDoc.from_bytes, protocol="protobuf")
+
+ def pickle_doc(doc):
+ b = doc.to_bytes(protocol='protobuf')
+ return unpickle_doc_fn, (doc.__class__, b)
+
+ # Register BaseDoc serialization
+ copyreg.pickle(BaseDoc, pickle_doc)
+
+ # For DocList, we need to hook into __reduce__ since it's a generic
+
+ def pickle_doclist(doc_list):
+ b = doc_list.to_bytes(protocol='protobuf')
+ doc_type = doc_list.doc_type
+ return unpickle_doclist, (doc_type, b)
+
+ # Replace DocList.__reduce__ with a method that returns the correct format
+ def doclist_reduce(self):
+ return pickle_doclist(self)
+
+ DocList.__reduce__ = doclist_reduce
+
+ # For DocVec, we need to hook into __reduce__ since it's a generic
+
+ def pickle_docvec(doc_vec):
+ b = doc_vec.to_bytes(protocol='protobuf')
+ doc_type = doc_vec.doc_type
+ tensor_type = doc_vec.tensor_type
+ return unpickle_docvec, (doc_type, tensor_type, b)
+
+ # Replace DocVec.__reduce__ with a method that returns the correct format
+ def docvec_reduce(self):
+ return pickle_docvec(self)
+
+ DocVec.__reduce__ = docvec_reduce
+
+ register_serializers()
__all__ = ['BaseDoc', 'DocList', 'DocVec']
diff --git a/docarray/array/__init__.py b/docarray/array/__init__.py
index 16e1274c1e3..8fc423c13f0 100644
--- a/docarray/array/__init__.py
+++ b/docarray/array/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.array.any_array import AnyDocArray
from docarray.array.doc_list.doc_list import DocList
from docarray.array.doc_vec.doc_vec import DocVec
diff --git a/docarray/array/any_array.py b/docarray/array/any_array.py
index 612fba7f42e..0c29e54ae82 100644
--- a/docarray/array/any_array.py
+++ b/docarray/array/any_array.py
@@ -1,3 +1,4 @@
+import sys
import random
from abc import abstractmethod
from typing import (
@@ -19,33 +20,49 @@
import numpy as np
-from docarray.base_doc import BaseDoc
+from docarray.base_doc.doc import BaseDocWithoutId
from docarray.display.document_array_summary import DocArraySummary
+from docarray.exceptions.exceptions import UnusableObjectError
from docarray.typing.abstract_type import AbstractType
-from docarray.utils._internal._typing import change_cls_name
+from docarray.utils._internal._typing import change_cls_name, safe_issubclass
+from docarray.utils._internal.pydantic import is_pydantic_v2
if TYPE_CHECKING:
from docarray.proto import DocListProto, NodeProto
from docarray.typing.tensor.abstract_tensor import AbstractTensor
+if sys.version_info >= (3, 12):
+ from types import GenericAlias
+
T = TypeVar('T', bound='AnyDocArray')
-T_doc = TypeVar('T_doc', bound=BaseDoc)
+T_doc = TypeVar('T_doc', bound=BaseDocWithoutId)
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]
+UNUSABLE_ERROR_MSG = (
+ 'This {cls} instance is in an unusable state. \n'
+ 'The most common cause of this is converting a DocVec to a DocList. '
+ 'After you call `doc_vec.to_doc_list()`, `doc_vec` cannot be used anymore. '
+ 'Instead, you should do `doc_list = doc_vec.to_doc_list()` and only use `doc_list`.'
+)
+
class AnyDocArray(Sequence[T_doc], Generic[T_doc], AbstractType):
- doc_type: Type[BaseDoc]
- __typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDoc], Type]] = {}
+ doc_type: Type[BaseDocWithoutId]
+ __typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDocWithoutId], Type]] = {}
def __repr__(self):
return f'<{self.__class__.__name__} (length={len(self)})>'
@classmethod
- def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
+ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
if not isinstance(item, type):
- return Generic.__class_getitem__.__func__(cls, item) # type: ignore
- # this do nothing that checking that item is valid type var or str
- if not issubclass(item, BaseDoc):
+ if sys.version_info < (3, 12):
+ return Generic.__class_getitem__.__func__(cls, item) # type: ignore
+ # this does nothing other than check that item is a valid TypeVar or str
+ # Keep the approach in #1147 to be compatible with lower versions of Python.
+ else:
+ return GenericAlias(cls, item) # type: ignore
+ if not safe_issubclass(item, BaseDocWithoutId):
raise ValueError(
f'{cls.__name__}[item] item should be a Document not a {item} '
)
@@ -57,16 +74,35 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
# Promote to global scope so multiprocessing can pickle it
global _DocArrayTyped
- class _DocArrayTyped(cls): # type: ignore
- doc_type: Type[BaseDoc] = cast(Type[BaseDoc], item)
+ if not is_pydantic_v2:
+
+ class _DocArrayTyped(cls): # type: ignore
+ doc_type: Type[BaseDocWithoutId] = cast(
+ Type[BaseDocWithoutId], item
+ )
+
+ else:
+
+ class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore
+ doc_type: Type[BaseDocWithoutId] = cast(
+ Type[BaseDocWithoutId], item
+ )
- for field in _DocArrayTyped.doc_type.__fields__.keys():
+ for field in _DocArrayTyped.doc_type._docarray_fields().keys():
def _property_generator(val: str):
def _getter(self):
+ if getattr(self, '_is_unusable', False):
+ raise UnusableObjectError(
+ UNUSABLE_ERROR_MSG.format(cls=cls.__name__)
+ )
return self._get_data_column(val)
def _setter(self, value):
+ if getattr(self, '_is_unusable', False):
+ raise UnusableObjectError(
+ UNUSABLE_ERROR_MSG.format(cls=cls.__name__)
+ )
self._set_data_column(val, value)
# need docstring for the property
@@ -75,14 +111,24 @@ def _setter(self, value):
setattr(_DocArrayTyped, field, _property_generator(field))
# this generates property on the fly based on the schema of the item
- # The global scope and qualname need to refer to this class a unique name.
- # Otherwise, creating another _DocArrayTyped will overwrite this one.
- change_cls_name(
- _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
- )
-
- cls.__typed_da__[cls][item] = _DocArrayTyped
+ # The global scope and qualname need to refer to this class by a unique name.
+ # Otherwise, creating another _DocArrayTyped will overwrite this one.
+ if not is_pydantic_v2:
+ change_cls_name(
+ _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
+ )
+ cls.__typed_da__[cls][item] = _DocArrayTyped
+ else:
+ change_cls_name(_DocArrayTyped, f'{cls.__name__}', globals())
+ if sys.version_info < (3, 12):
+ cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(
+ _DocArrayTyped, item
+ ) # type: ignore
+ # this does nothing other than check that item is a valid TypeVar or str
+ # Keep the approach in #1147 to be compatible with lower versions of Python.
+ else:
+ cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item) # type: ignore
return cls.__typed_da__[cls][item]
@overload
diff --git a/docarray/array/doc_list/__init__.py b/docarray/array/doc_list/__init__.py
index e69de29bb2d..74f8f7582cd 100644
--- a/docarray/array/doc_list/__init__.py
+++ b/docarray/array/doc_list/__init__.py
@@ -0,0 +1,15 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py
index 29a617e2c5f..49236199153 100644
--- a/docarray/array/doc_list/doc_list.py
+++ b/docarray/array/doc_list/doc_list.py
@@ -10,23 +10,30 @@
Type,
TypeVar,
Union,
+ cast,
overload,
+ Callable,
)
from pydantic import parse_obj_as
from typing_extensions import SupportsIndex
-from typing_inspect import is_union_type
+from typing_inspect import is_typevar, is_union_type
from docarray.array.any_array import AnyDocArray
-from docarray.array.doc_list.io import IOMixinArray
+from docarray.array.doc_list.io import IOMixinDocList
from docarray.array.doc_list.pushpull import PushPullMixin
from docarray.array.list_advance_indexing import IndexIterType, ListAdvancedIndexing
-from docarray.base_doc import AnyDoc, BaseDoc
+from docarray.base_doc import AnyDoc
+from docarray.base_doc.doc import BaseDocWithoutId
from docarray.typing import NdArray
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+if is_pydantic_v2:
+ from pydantic_core import core_schema
+
+from docarray.utils._internal._typing import safe_issubclass
if TYPE_CHECKING:
- from pydantic import BaseConfig
- from pydantic.fields import ModelField
from docarray.array.doc_vec.doc_vec import DocVec
from docarray.proto import DocListProto
@@ -34,14 +41,11 @@
from docarray.typing.tensor.abstract_tensor import AbstractTensor
T = TypeVar('T', bound='DocList')
-T_doc = TypeVar('T_doc', bound=BaseDoc)
+T_doc = TypeVar('T_doc', bound=BaseDocWithoutId)
class DocList(
- ListAdvancedIndexing[T_doc],
- PushPullMixin,
- IOMixinArray,
- AnyDocArray[T_doc],
+ ListAdvancedIndexing[T_doc], PushPullMixin, IOMixinDocList, AnyDocArray[T_doc]
):
"""
DocList is a container of Documents.
@@ -61,7 +65,7 @@ class DocList(
class Image(BaseDoc):
- tensor: Optional[NdArray[100]]
+ tensor: Optional[NdArray[100]] = None
url: ImageUrl
@@ -114,7 +118,7 @@ class Image(BaseDoc):
"""
- doc_type: Type[BaseDoc] = AnyDoc
+ doc_type: Type[BaseDocWithoutId] = AnyDoc
def __init__(
self,
@@ -157,7 +161,9 @@ def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]:
def _validate_one_doc(self, doc: T_doc) -> T_doc:
"""Validate if a Document is compatible with this `DocList`"""
- if not issubclass(self.doc_type, AnyDoc) and not isinstance(doc, self.doc_type):
+ if not safe_issubclass(self.doc_type, AnyDoc) and not isinstance(
+ doc, self.doc_type
+ ):
raise ValueError(f'{doc} is not a {self.doc_type}')
return doc
@@ -211,13 +217,17 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
:return: Returns a list of the field value for each document
in the doc_list like container
"""
- field_type = self.__class__.doc_type._get_field_type(field)
+ field_type = self.__class__.doc_type._get_field_annotation(field)
+ field_info = self.__class__.doc_type._docarray_fields()[field]
+ is_field_required = (
+ field_info.is_required() if is_pydantic_v2 else field_info.required
+ )
if (
not is_union_type(field_type)
- and self.__class__.doc_type.__fields__[field].required
+ and is_field_required
and isinstance(field_type, type)
- and issubclass(field_type, BaseDoc)
+ and safe_issubclass(field_type, BaseDocWithoutId)
):
# calling __class_getitem__ ourselves is a hack otherwise mypy complain
# most likely a bug in mypy though
@@ -259,16 +269,24 @@ def to_doc_vec(
return DocVec.__class_getitem__(self.doc_type)(self, tensor_type=tensor_type)
@classmethod
- def validate(
+ def _docarray_validate(
cls: Type[T],
- value: Union[T, Iterable[BaseDoc]],
- field: 'ModelField',
- config: 'BaseConfig',
+ value: Union[T, Iterable[BaseDocWithoutId]],
):
from docarray.array.doc_vec.doc_vec import DocVec
- if isinstance(value, (cls, DocVec)):
+ if isinstance(value, cls):
return value
+ elif isinstance(value, DocVec):
+ if (
+ safe_issubclass(value.doc_type, cls.doc_type)
+ or value.doc_type == cls.doc_type
+ ):
+ return cast(T, value.to_doc_list())
+ else:
+ raise ValueError(
+ f'DocList[value.doc_type] is not compatible with {cls}'
+ )
elif isinstance(value, cls):
return cls(value)
elif isinstance(value, Iterable):
@@ -295,6 +313,12 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocListProto') -> T:
"""
return super().from_protobuf(pb_msg)
+ @classmethod
+ def _get_proto_class(cls: Type[T]):
+ from docarray.proto import DocListProto
+
+ return DocListProto
+
@overload
def __getitem__(self, item: SupportsIndex) -> T_doc:
...
@@ -307,12 +331,43 @@ def __getitem__(self, item):
return super().__getitem__(item)
@classmethod
- def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
+ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
+ if cls.doc_type != AnyDoc:
+ raise TypeError(f'{cls} object is not subscriptable')
- if isinstance(item, type) and issubclass(item, BaseDoc):
+ if isinstance(item, type) and safe_issubclass(item, BaseDocWithoutId):
return AnyDocArray.__class_getitem__.__func__(cls, item) # type: ignore
- else:
- return super().__class_getitem__(item)
+ if (
+ isinstance(item, object)
+ and not is_typevar(item)
+ and not isinstance(item, str)
+ and item is not Any
+ ):
+ raise TypeError('Expecting a type, got object instead')
+
+ return super().__class_getitem__(item)
def __repr__(self):
return AnyDocArray.__repr__(self) # type: ignore
+
+ if is_pydantic_v2:
+
+ @classmethod
+ def __get_pydantic_core_schema__(
+ cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema]
+ ) -> core_schema.CoreSchema:
+ instance_schema = core_schema.is_instance_schema(cls)
+ args = getattr(source, '__args__', None)
+ if args:
+ sequence_t_schema = handler(Sequence[args[0]])
+ else:
+ sequence_t_schema = handler(Sequence)
+
+ def validate_fn(v, info):
+ # input has already been validated
+ return cls(v, validate_input_docs=False)
+
+ non_instance_schema = core_schema.with_info_after_validator_function(
+ validate_fn, sequence_t_schema
+ )
+ return core_schema.union_schema([instance_schema, non_instance_schema])
diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py
index c2b531c2550..3acb66bf6e8 100644
--- a/docarray/array/doc_list/io.py
+++ b/docarray/array/doc_list/io.py
@@ -23,6 +23,7 @@
Type,
TypeVar,
Union,
+ cast,
)
import orjson
@@ -35,15 +36,17 @@
_dict_to_access_paths,
)
from docarray.utils._internal.compress import _decompress_bytes, _get_compress_ctx
-from docarray.utils._internal.misc import import_library
+from docarray.utils._internal.misc import import_library, ProtocolType
if TYPE_CHECKING:
import pandas as pd
- from docarray import DocList
+ from docarray.array.doc_vec.doc_vec import DocVec
+ from docarray.array.doc_vec.io import IOMixinDocVec
from docarray.proto import DocListProto
+ from docarray.typing.tensor.abstract_tensor import AbstractTensor
-T = TypeVar('T', bound='IOMixinArray')
+T = TypeVar('T', bound='IOMixinDocList')
T_doc = TypeVar('T_doc', bound=BaseDoc)
ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array', 'json-array'}
@@ -54,9 +57,9 @@
def _protocol_and_compress_from_file_path(
file_path: Union[pathlib.Path, str],
- default_protocol: Optional[str] = None,
+ default_protocol: Optional[ProtocolType] = None,
default_compress: Optional[str] = None,
-) -> Tuple[Optional[str], Optional[str]]:
+) -> Tuple[Optional[ProtocolType], Optional[str]]:
"""Extract protocol and compression algorithm from a string, use defaults if not found.
:param file_path: path of a file.
:param default_protocol: default serialization protocol used in case not found.
@@ -76,7 +79,7 @@ def _protocol_and_compress_from_file_path(
file_extensions = [e.replace('.', '') for e in pathlib.Path(file_path).suffixes]
for extension in file_extensions:
if extension in ALLOWED_PROTOCOLS:
- protocol = extension
+ protocol = cast(ProtocolType, extension)
elif extension in ALLOWED_COMPRESSIONS:
compress = extension
@@ -97,7 +100,7 @@ def __getitem__(self, item: slice):
return self.content[item]
-class IOMixinArray(Iterable[T_doc]):
+class IOMixinDocList(Iterable[T_doc]):
doc_type: Type[T_doc]
@abstractmethod
@@ -132,7 +135,7 @@ def to_protobuf(self) -> 'DocListProto':
def from_bytes(
cls: Type[T],
data: bytes,
- protocol: str = 'protobuf-array',
+ protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> T:
@@ -154,7 +157,7 @@ def from_bytes(
def _write_bytes(
self,
bf: BinaryIO,
- protocol: str = 'protobuf-array',
+ protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> None:
@@ -180,7 +183,7 @@ def _write_bytes(
elif protocol == 'pickle-array':
f.write(pickle.dumps(self))
elif protocol == 'json-array':
- f.write(self.to_json())
+ f.write(self.to_json().encode())
elif protocol in SINGLE_PROTOCOLS:
f.write(
b''.join(
@@ -198,7 +201,7 @@ def _write_bytes(
def _to_binary_stream(
self,
- protocol: str = 'protobuf',
+ protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Iterator[bytes]:
@@ -238,7 +241,7 @@ def _to_binary_stream(
def to_bytes(
self,
- protocol: str = 'protobuf-array',
+ protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
file_ctx: Optional[BinaryIO] = None,
show_progress: bool = False,
@@ -253,7 +256,6 @@ def to_bytes(
:param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
:return: the binary serialization in bytes or None if file_ctx is passed where to store
"""
-
with file_ctx or io.BytesIO() as bf:
self._write_bytes(
bf=bf,
@@ -270,7 +272,7 @@ def to_bytes(
def from_base64(
cls: Type[T],
data: str,
- protocol: str = 'protobuf-array',
+ protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> T:
@@ -291,7 +293,7 @@ def from_base64(
def to_base64(
self,
- protocol: str = 'protobuf-array',
+ protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> str:
@@ -324,19 +326,19 @@ def from_json(
json_docs = orjson.loads(file)
return cls([cls.doc_type(**v) for v in json_docs])
- def to_json(self) -> bytes:
+ def to_json(self) -> str:
"""Convert the object into JSON bytes. Can be loaded via `.from_json`.
:return: JSON serialization of `DocList`
"""
- return orjson_dumps(self)
+ return orjson_dumps(self).decode('UTF-8')
@classmethod
def from_csv(
- cls,
+ cls: Type['T'],
file_path: str,
encoding: str = 'utf-8',
dialect: Union[str, csv.Dialect] = 'excel',
- ) -> 'DocList':
+ ) -> 'T':
"""
Load a DocList from a csv file following the schema defined in the
[`.doc_type`][docarray.DocList] attribute.
@@ -358,10 +360,10 @@ def from_csv(
:return: `DocList` object
"""
- if cls.doc_type == AnyDoc:
+ if cls.doc_type == AnyDoc or cls.doc_type == BaseDoc:
raise TypeError(
'There is no document schema defined. '
- 'Please specify the DocList\'s Document type using `DocList[MyDoc]`.'
+ f'Please specify the {cls}\'s Document type using `{cls}[MyDoc]`.'
)
if file_path.startswith('http'):
@@ -376,14 +378,14 @@ def from_csv(
@classmethod
def _from_csv_file(
- cls, file: Union[StringIO, TextIOWrapper], dialect: Union[str, csv.Dialect]
- ) -> 'DocList':
- from docarray import DocList
-
+ cls: Type['T'],
+ file: Union[StringIO, TextIOWrapper],
+ dialect: Union[str, csv.Dialect],
+ ) -> 'T':
rows = csv.DictReader(file, dialect=dialect)
doc_type = cls.doc_type
- docs = DocList.__class_getitem__(doc_type)()
+ docs = []
field_names: List[str] = (
[] if rows.fieldnames is None else [str(f) for f in rows.fieldnames]
@@ -405,7 +407,7 @@ def _from_csv_file(
doc_dict: Dict[Any, Any] = _access_path_dict_to_nested_dict(access_path2val)
docs.append(doc_type.parse_obj(doc_dict))
- return docs
+ return cls(docs)
def to_csv(
self, file_path: str, dialect: Union[str, csv.Dialect] = 'excel'
@@ -426,11 +428,11 @@ def to_csv(
`'unix'` (for csv file generated on UNIX systems).
"""
- if self.doc_type == AnyDoc:
+ if self.doc_type == AnyDoc or self.doc_type == BaseDoc:
raise TypeError(
- 'DocList must be homogeneous to be converted to a csv.'
+ f'{type(self)} must be homogeneous to be converted to a csv. '
'There is no document schema defined. '
- 'Please specify the DocList\'s Document type using `DocList[MyDoc]`.'
+ f'Please specify the {type(self)}\'s Document type using `{type(self)}[MyDoc]`.'
)
fields = self.doc_type._get_access_paths()
@@ -443,7 +445,7 @@ def to_csv(
writer.writerow(doc_dict)
@classmethod
- def from_dataframe(cls, df: 'pd.DataFrame') -> 'DocList':
+ def from_dataframe(cls: Type['T'], df: 'pd.DataFrame') -> 'T':
"""
Load a `DocList` from a `pandas.DataFrame` following the schema
defined in the [`.doc_type`][docarray.DocList] attribute.
@@ -486,10 +488,10 @@ class Person(BaseDoc):
"""
from docarray import DocList
- if cls.doc_type == AnyDoc:
+ if cls.doc_type == AnyDoc or cls.doc_type == BaseDoc:
raise TypeError(
'There is no document schema defined. '
- 'Please specify the DocList\'s Document type using `DocList[MyDoc]`.'
+ f'Please specify the {cls}\'s Document type using `{cls}[MyDoc]`.'
)
doc_type = cls.doc_type
@@ -555,7 +557,7 @@ def _stream_header(self) -> bytes:
# Binary format for streaming case
# V2 DocList streaming serialization format
- # | 1 byte | 8 bytes | 4 bytes | variable(docarray v2) | 4 bytes | variable(docarray v2) ...
+ # | 1 byte | 8 bytes | 4 bytes | variable(DocArray >=0.30) | 4 bytes | variable(DocArray >=0.30) ...
# 1 byte (uint8)
version_byte = b'\x02'
@@ -563,18 +565,25 @@ def _stream_header(self) -> bytes:
num_docs_as_bytes = len(self).to_bytes(8, 'big', signed=False)
return version_byte + num_docs_as_bytes
+ @classmethod
+ @abstractmethod
+ def _get_proto_class(cls: Type[T]):
+ ...
+
@classmethod
def _load_binary_all(
cls: Type[T],
file_ctx: Union[ContextManager[io.BufferedReader], ContextManager[bytes]],
- protocol: Optional[str],
+ protocol: Optional[ProtocolType],
compress: Optional[str],
show_progress: bool,
+ tensor_type: Optional[Type['AbstractTensor']] = None,
):
"""Read a `DocList` object from a binary file
:param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf'
:param compress: compress algorithm to use between `lz4`, `bz2`, `lzma`, `zlib`, `gzip`
:param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
+ :param tensor_type: only relevant for DocVec; tensor_type of the DocVec
:return: a `DocList`
"""
with file_ctx as fp:
@@ -593,17 +602,23 @@ def _load_binary_all(
compress = None
if protocol is not None and protocol == 'protobuf-array':
- from docarray.proto import DocListProto
-
- dap = DocListProto()
- dap.ParseFromString(d)
+ proto = cls._get_proto_class()()
+ proto.ParseFromString(d)
- return cls.from_protobuf(dap)
+ if tensor_type is not None:
+ cls_ = cast('IOMixinDocVec', cls)
+ return cls_.from_protobuf(proto, tensor_type=tensor_type)
+ else:
+ return cls.from_protobuf(proto)
elif protocol is not None and protocol == 'pickle-array':
return pickle.loads(d)
elif protocol is not None and protocol == 'json-array':
- return cls.from_json(d)
+ if tensor_type is not None:
+ cls_ = cast('IOMixinDocVec', cls)
+ return cls_.from_json(d, tensor_type=tensor_type)
+ else:
+ return cls.from_json(d)
# Binary format for streaming case
else:
@@ -642,7 +657,9 @@ def _load_binary_all(
start_pos = end_doc_pos
# variable length bytes doc
- load_protocol: str = protocol or 'protobuf'
+ load_protocol: ProtocolType = protocol or cast(
+ ProtocolType, 'protobuf'
+ )
doc = cls.doc_type.from_bytes(
d[start_doc_pos:end_doc_pos],
protocol=load_protocol,
@@ -653,13 +670,17 @@ def _load_binary_all(
pbar.update(
t, advance=1, total_size=str(filesize.decimal(_total_size))
)
+ if tensor_type is not None:
+ cls__ = cast(Type['DocVec'], cls)
+ # mypy doesn't realize that cls__ is callable
+ return cls__(docs, tensor_type=tensor_type) # type: ignore
return cls(docs)
@classmethod
def _load_binary_stream(
cls: Type[T],
file_ctx: ContextManager[io.BufferedReader],
- protocol: str = 'protobuf',
+ protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Generator['T_doc', None, None]:
@@ -707,7 +728,7 @@ def _load_binary_stream(
len_current_doc_in_bytes = int.from_bytes(
f.read(4), 'big', signed=False
)
- load_protocol: str = protocol
+ load_protocol: ProtocolType = protocol
yield cls.doc_type.from_bytes(
f.read(len_current_doc_in_bytes),
protocol=load_protocol,
@@ -719,11 +740,34 @@ def _load_binary_stream(
t, advance=1, total_size=str(filesize.decimal(_total_size))
)
+ @staticmethod
+ def _get_file_context(
+ file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
+ protocol: ProtocolType,
+ compress: Optional[str] = None,
+ ) -> Tuple[
+ Union[nullcontext, io.BufferedReader], Optional[ProtocolType], Optional[str]
+ ]:
+ load_protocol: Optional[ProtocolType] = protocol
+ load_compress: Optional[str] = compress
+ file_ctx: Union[nullcontext, io.BufferedReader]
+ if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)):
+ file_ctx = nullcontext(file)
+ # by checking path existence we allow file to be of type Path, LocalPath, PurePath and str
+ elif isinstance(file, (str, pathlib.Path)) and os.path.exists(file):
+ load_protocol, load_compress = _protocol_and_compress_from_file_path(
+ file, protocol, compress
+ )
+ file_ctx = open(file, 'rb')
+ else:
+ raise FileNotFoundError(f'cannot find file {file}')
+ return file_ctx, load_protocol, load_compress
+
@classmethod
def load_binary(
cls: Type[T],
file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
- protocol: str = 'protobuf-array',
+ protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
streaming: bool = False,
@@ -748,19 +792,9 @@ def load_binary(
:return: a `DocList` object
"""
- load_protocol: Optional[str] = protocol
- load_compress: Optional[str] = compress
- file_ctx: Union[nullcontext, io.BufferedReader]
- if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)):
- file_ctx = nullcontext(file)
- # by checking path existence we allow file to be of type Path, LocalPath, PurePath and str
- elif isinstance(file, (str, pathlib.Path)) and os.path.exists(file):
- load_protocol, load_compress = _protocol_and_compress_from_file_path(
- file, protocol, compress
- )
- file_ctx = open(file, 'rb')
- else:
- raise FileNotFoundError(f'cannot find file {file}')
+ file_ctx, load_protocol, load_compress = cls._get_file_context(
+ file, protocol, compress
+ )
if streaming:
if load_protocol not in SINGLE_PROTOCOLS:
raise ValueError(
@@ -782,7 +816,7 @@ def load_binary(
def save_binary(
self,
file: Union[str, pathlib.Path],
- protocol: str = 'protobuf-array',
+ protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> None:
diff --git a/docarray/array/doc_list/pushpull.py b/docarray/array/doc_list/pushpull.py
index 2bfe6764061..5784610633b 100644
--- a/docarray/array/doc_list/pushpull.py
+++ b/docarray/array/doc_list/pushpull.py
@@ -5,7 +5,6 @@
Dict,
Iterable,
Iterator,
- Optional,
Tuple,
Type,
TypeVar,
@@ -15,7 +14,7 @@
from typing_extensions import Literal
from typing_inspect import get_args
-PUSH_PULL_PROTOCOL = Literal['jac', 's3', 'file']
+PUSH_PULL_PROTOCOL = Literal['s3', 'file']
SUPPORTED_PUSH_PULL_PROTOCOLS = get_args(PUSH_PULL_PROTOCOL)
if TYPE_CHECKING: # pragma: no cover
@@ -55,18 +54,13 @@ def get_pushpull_backend(
"""
Get the backend for the given protocol.
- :param protocol: the protocol to use, e.g. 'jac', 'file', 's3'
+ :param protocol: the protocol to use, e.g. 'file', 's3'
:return: the backend class
"""
if protocol in cls.__backends__:
return cls.__backends__[protocol]
- if protocol == 'jac':
- from docarray.store.jac import JACDocStore
-
- cls.__backends__[protocol] = JACDocStore
- logging.debug('Loaded Jina AI Cloud backend')
- elif protocol == 'file':
+ if protocol == 'file':
from docarray.store.file import FileDocStore
cls.__backends__[protocol] = FileDocStore
@@ -84,22 +78,18 @@ def get_pushpull_backend(
def push(
self,
url: str,
- public: bool = True,
show_progress: bool = False,
- branding: Optional[Dict] = None,
+ **kwargs,
) -> Dict:
"""Push this `DocList` object to the specified url.
:param url: url specifying the protocol and save name of the `DocList`. Should be of the form ``protocol://namespace/name``. e.g. ``s3://bucket/path/to/namespace/name``, ``file:///path/to/folder/name``
- :param public: Only used by ``jac`` protocol. If true, anyone can pull a `DocList` if they know its name.
- Setting this to false will restrict access to only the creator.
:param show_progress: If true, a progress bar will be displayed.
- :param branding: Only used by ``jac`` protocol. A dictionary of branding information to be sent to Jina AI Cloud. {"icon": "emoji", "background": "#fff"}
"""
logging.info(f'Pushing {len(self)} docs to {url}')
protocol, name = self.__class__.resolve_url(url)
return self.__class__.get_pushpull_backend(protocol).push(
- self, name, public, show_progress, branding # type: ignore
+ self, name, show_progress # type: ignore
)
@classmethod
@@ -107,23 +97,17 @@ def push_stream(
cls: Type[SelfPushPullMixin],
docs: Iterator['BaseDoc'],
url: str,
- public: bool = True,
show_progress: bool = False,
- branding: Optional[Dict] = None,
) -> Dict:
"""Push a stream of documents to the specified url.
:param docs: a stream of documents
:param url: url specifying the protocol and save name of the `DocList`. Should be of the form ``protocol://namespace/name``. e.g. ``s3://bucket/path/to/namespace/name``, ``file:///path/to/folder/name``
- :param public: Only used by ``jac`` protocol. If true, anyone can pull a `DocList` if they know its name.
:param show_progress: If true, a progress bar will be displayed.
- :param branding: Only used by ``jac`` protocol. A dictionary of branding information to be sent to Jina AI Cloud. {"icon": "emoji", "background": "#fff"}
"""
logging.info(f'Pushing stream to {url}')
protocol, name = cls.resolve_url(url)
- return cls.get_pushpull_backend(protocol).push_stream(
- docs, name, public, show_progress, branding
- )
+ return cls.get_pushpull_backend(protocol).push_stream(docs, name, show_progress)
@classmethod
def pull(
diff --git a/docarray/array/doc_vec/__init__.py b/docarray/array/doc_vec/__init__.py
index e69de29bb2d..74f8f7582cd 100644
--- a/docarray/array/doc_vec/__init__.py
+++ b/docarray/array/doc_vec/__init__.py
@@ -0,0 +1,15 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/docarray/array/doc_vec/column_storage.py b/docarray/array/doc_vec/column_storage.py
index cd55fb63ea7..ea5da4291fa 100644
--- a/docarray/array/doc_vec/column_storage.py
+++ b/docarray/array/doc_vec/column_storage.py
@@ -3,12 +3,15 @@
TYPE_CHECKING,
Any,
Dict,
+ ItemsView,
Iterable,
MutableMapping,
+ NamedTuple,
Optional,
Type,
TypeVar,
Union,
+ ValuesView,
)
from docarray.array.list_advance_indexing import ListAdvancedIndexing
@@ -24,6 +27,13 @@
T = TypeVar('T', bound='ColumnStorage')
+class ColumnsJsonCompatible(NamedTuple):
+ tensor_columns: Dict[str, Any]
+ doc_columns: Dict[str, Any]
+ docs_vec_columns: Dict[str, Any]
+ any_columns: Dict[str, Any]
+
+
class ColumnStorage:
"""
ColumnStorage is a container to store the columns of the
@@ -89,6 +99,48 @@ def __getitem__(self: T, item: IndexIterType) -> T:
self.tensor_type,
)
+ def columns_json_compatible(self) -> ColumnsJsonCompatible:
+ tens_cols = {
+ key: value._docarray_to_json_compatible() if value is not None else value
+ for key, value in self.tensor_columns.items()
+ }
+ doc_cols = {
+ key: value._docarray_to_json_compatible() if value is not None else value
+ for key, value in self.doc_columns.items()
+ }
+ doc_vec_cols = {
+ key: [vec._docarray_to_json_compatible() for vec in value]
+ if value is not None
+ else value
+ for key, value in self.docs_vec_columns.items()
+ }
+ return ColumnsJsonCompatible(
+ tens_cols, doc_cols, doc_vec_cols, self.any_columns
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, ColumnStorage):
+ return False
+ if self.tensor_type != other.tensor_type:
+ return False
+ for col_map_self, col_map_other in zip(self.columns.maps, other.columns.maps):
+ if col_map_self.keys() != col_map_other.keys():
+ return False
+ for key_self in col_map_self.keys():
+ if key_self == 'id':
+ continue
+
+ val1, val2 = col_map_self[key_self], col_map_other[key_self]
+ if isinstance(val1, AbstractTensor):
+ values_are_equal = val1.get_comp_backend().equal(val1, val2)
+ elif isinstance(val2, AbstractTensor):
+ values_are_equal = val2.get_comp_backend().equal(val1, val2)
+ else:
+ values_are_equal = val1 == val2
+ if not values_are_equal:
+ return False
+ return True
+
class ColumnStorageView(dict, MutableMapping[str, Any]):
index: int
@@ -121,6 +173,11 @@ def __getitem__(self, name: str) -> Any:
return None
return col[self.index]
+ def __reduce__(self):
+ # implementing __reduce__ to solve a pickle issue when subclassing dict
+ # see here: https://stackoverflow.com/questions/21144845/how-can-i-unpickle-a-subclass-of-dict-that-validates-with-setitem-in-pytho
+ return (ColumnStorageView, (self.index, self.storage))
+
def __setitem__(self, name, value) -> None:
if self.storage.columns[name] is None:
raise ValueError(
@@ -140,3 +197,29 @@ def __iter__(self):
def __len__(self):
return len(self.storage.columns)
+
+ def _local_dict(self):
+ """The storage.columns dictionary with every value at position self.index"""
+
+ return {key: self[key] for key in self.storage.columns.keys()}
+
+ def keys(self):
+ return self.storage.columns.keys()
+
+ # type ignore because return type dict_values is private and we cannot use it.
+ # context: https://github.com/python/typing/discussions/1033
+ def values(self) -> ValuesView: # type: ignore
+ return ValuesView(self._local_dict())
+
+ # type ignore because return type dict_items is private and we cannot use it.
+ # context: https://github.com/python/typing/discussions/1033
+ def items(self) -> ItemsView: # type: ignore
+ return ItemsView(self._local_dict())
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Return a dictionary with the same keys as the storage.columns
+ and the values at position self.index.
+ Warning: modification on the dict will not be reflected on the storage.
+ """
+ return {key: self[key] for key in self.storage.columns.keys()}
diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py
index e27f6882fe9..0cc462f173d 100644
--- a/docarray/array/doc_vec/doc_vec.py
+++ b/docarray/array/doc_vec/doc_vec.py
@@ -1,6 +1,5 @@
from collections import ChainMap
from typing import (
- TYPE_CHECKING,
Any,
Dict,
Iterable,
@@ -17,23 +16,29 @@
overload,
)
-from pydantic import BaseConfig, parse_obj_as
+from pydantic import parse_obj_as
+from typing_inspect import typingGenericAlias
from docarray.array.any_array import AnyDocArray
from docarray.array.doc_list.doc_list import DocList
from docarray.array.doc_vec.column_storage import ColumnStorage, ColumnStorageView
+from docarray.array.doc_vec.io import IOMixinDocVec
from docarray.array.list_advance_indexing import ListAdvancedIndexing
from docarray.base_doc import AnyDoc, BaseDoc
-from docarray.base_doc.mixins.io import _type_to_protobuf
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
-from docarray.utils._internal._typing import is_tensor_union
-from docarray.utils._internal.misc import is_tf_available, is_torch_available
+from docarray.utils._internal.pydantic import is_pydantic_v2
-if TYPE_CHECKING:
- from pydantic.fields import ModelField
+if is_pydantic_v2:
+ from pydantic import GetCoreSchemaHandler
+ from pydantic_core import core_schema
- from docarray.proto import DocVecProto
+from docarray.utils._internal._typing import is_tensor_union, safe_issubclass
+from docarray.utils._internal.misc import (
+ is_jax_available,
+ is_tf_available,
+ is_torch_available,
+)
torch_available = is_torch_available()
if torch_available:
@@ -49,12 +54,23 @@
else:
TensorFlowTensor = None # type: ignore
+jnp_available = is_jax_available()
+if jnp_available:
+ import jax.numpy as jnp # type: ignore
+
+ from docarray.typing import JaxArray # noqa: F401
+else:
+ JaxArray = None # type: ignore
+
T_doc = TypeVar('T_doc', bound=BaseDoc)
T = TypeVar('T', bound='DocVec')
+T_io_mixin = TypeVar('T_io_mixin', bound='IOMixinDocVec')
+
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]
-class DocVec(AnyDocArray[T_doc]):
+# type ignore because from_protobuf has a different signature
+class DocVec(IOMixinDocVec, AnyDocArray[T_doc]): # type: ignore
"""
DocVec is a container of Documents appropriates to perform
computation that require batches of data (ex: matrix multiplication, distance
@@ -99,7 +115,7 @@ class DocVec(AnyDocArray[T_doc]):
AnyTensor or Union of NdArray and TorchTensor
"""
- doc_type: Type[T_doc]
+ doc_type: Type[T_doc] = BaseDoc # type: ignore
def __init__(
self: T,
@@ -107,12 +123,17 @@ def __init__(
tensor_type: Type['AbstractTensor'] = NdArray,
):
- if not hasattr(self, 'doc_type') or self.doc_type == AnyDoc:
+ if (
+ not hasattr(self, 'doc_type')
+ or self.doc_type == AnyDoc
+ or self.doc_type == BaseDoc
+ ):
raise TypeError(
f'{self.__class__.__name__} does not precise a doc_type. You probably should do'
f'docs = DocVec[MyDoc](docs) instead of DocVec(docs)'
)
self.tensor_type = tensor_type
+ self._is_unusable = False
tensor_columns: Dict[str, Optional[AbstractTensor]] = dict()
doc_columns: Dict[str, Optional['DocVec']] = dict()
@@ -127,12 +148,15 @@ def __init__(
else DocList.__class_getitem__(self.doc_type)(docs)
)
- for field_name, field in self.doc_type.__fields__.items():
+ for field_name, field in self.doc_type._docarray_fields().items():
# here we iterate over the field of the docs schema, and we collect the data
# from each document and put them in the corresponding column
- field_type = self.doc_type._get_field_type(field_name)
+ field_type: Type = self.doc_type._get_field_annotation(field_name)
- is_field_required = self.doc_type.__fields__[field_name].required
+ field_info = self.doc_type._docarray_fields()[field_name]
+ is_field_required = (
+ field_info.is_required() if is_pydantic_v2 else field_info.required
+ )
first_doc_is_none = getattr(docs[0], field_name) is None
@@ -148,7 +172,9 @@ def _verify_optional_field_of_docs(docs):
for i, doc in enumerate(docs):
if getattr(doc, field_name) is not None:
raise ValueError(
- f'Field {field_name} is put to None for the first doc. This mean that all of the other docs should have this field set to None as well. This is not the case for {doc} at index {i}'
+ f'Field {field_name} is put to None for the first doc. This mean that '
+ f'all of the other docs should have this field set to None as well. '
+ f'This is not the case for {doc} at index {i}'
)
def _check_doc_field_not_none(field_name, doc):
@@ -159,9 +185,21 @@ def _check_doc_field_not_none(field_name, doc):
if is_tensor_union(field_type):
field_type = tensor_type
-
- if isinstance(field_type, type):
- if tf_available and issubclass(field_type, TensorFlowTensor):
+ # all generic tensor types such as AnyTensor, ImageTensor, etc. are subclasses of AbstractTensor.
+ # Perform check only if the field_type is not an alias and is a subclass of AbstractTensor
+ elif not isinstance(field_type, typingGenericAlias) and safe_issubclass(
+ field_type, AbstractTensor
+ ):
+ # check if the tensor associated with the field_name in the document is a subclass of the tensor_type
+ # e.g. if the field_type is AnyTensor but the type(docs[0][field_name]) is ImageTensor,
+ # then we change the field_type to ImageTensor, since AnyTensor is a union of all the tensor types
+ # and does not override any methods of specific tensor types
+ tensor = getattr(docs[0], field_name)
+ if safe_issubclass(tensor.__class__, tensor_type):
+ field_type = tensor_type
+
+ if isinstance(field_type, type) or safe_issubclass(field_type, AnyDocArray):
+ if tf_available and safe_issubclass(field_type, TensorFlowTensor):
# tf.Tensor does not allow item assignment, therefore the
# optimized way
# of initializing an empty array and assigning values to it
@@ -180,8 +218,21 @@ def _check_doc_field_not_none(field_name, doc):
stacked: tf.Tensor = tf.stack(tf_stack)
tensor_columns[field_name] = TensorFlowTensor(stacked)
+ elif jnp_available and safe_issubclass(field_type, JaxArray):
+ if first_doc_is_none:
+ _verify_optional_field_of_docs(docs)
+ tensor_columns[field_name] = None
+ else:
+ tf_stack = []
+ for i, doc in enumerate(docs):
+ val = getattr(doc, field_name)
+ _check_doc_field_not_none(field_name, doc)
+ tf_stack.append(val.tensor)
+
+ jax_stacked: jnp.ndarray = jnp.stack(tf_stack)
+ tensor_columns[field_name] = JaxArray(jax_stacked)
- elif issubclass(field_type, AbstractTensor):
+ elif safe_issubclass(field_type, AbstractTensor):
if first_doc_is_none:
_verify_optional_field_of_docs(docs)
tensor_columns[field_name] = None
@@ -209,7 +260,7 @@ def _check_doc_field_not_none(field_name, doc):
val = getattr(doc, field_name)
cast(AbstractTensor, tensor_columns[field_name])[i] = val
- elif issubclass(field_type, BaseDoc):
+ elif safe_issubclass(field_type, BaseDoc):
if first_doc_is_none:
_verify_optional_field_of_docs(docs)
doc_columns[field_name] = None
@@ -225,10 +276,10 @@ def _check_doc_field_not_none(field_name, doc):
tensor_type=self.tensor_type
)
- elif issubclass(field_type, AnyDocArray):
+ elif safe_issubclass(field_type, AnyDocArray):
if first_doc_is_none:
_verify_optional_field_of_docs(docs)
- doc_columns[field_name] = None
+ docs_vec_columns[field_name] = None
else:
docs_list = list()
for doc in docs:
@@ -270,15 +321,23 @@ def from_columns_storage(cls: Type[T], storage: ColumnStorage) -> T:
return docs
@classmethod
- def validate(
+ def _docarray_validate(
cls: Type[T],
value: Union[T, Iterable[T_doc]],
- field: 'ModelField',
- config: 'BaseConfig',
) -> T:
if isinstance(value, cls):
return value
- elif isinstance(value, DocList.__class_getitem__(cls.doc_type)):
+ elif isinstance(value, DocList):
+ if (
+ safe_issubclass(value.doc_type, cls.doc_type)
+ or value.doc_type == cls.doc_type
+ ):
+ return cast(T, value.to_doc_vec())
+ else:
+ raise ValueError(f'DocVec[{value.doc_type}] is not compatible with {cls}')
+ elif not is_pydantic_v2 and isinstance(
+ value, DocList.__class_getitem__(cls.doc_type)
+ ):
return cast(T, value.to_doc_vec())
elif isinstance(value, Sequence):
return cls(value)
@@ -391,7 +450,7 @@ def _set_data_and_columns(
# set data and prepare columns
processed_value: T
if isinstance(value, DocList):
- if not issubclass(value.doc_type, self.doc_type):
+ if not safe_issubclass(value.doc_type, self.doc_type):
raise TypeError(
f'{value} schema : {value.doc_type} is not compatible with '
f'this DocVec schema : {self.doc_type}'
@@ -401,7 +460,7 @@ def _set_data_and_columns(
) # we need to copy data here
elif isinstance(value, DocVec):
- if not issubclass(value.doc_type, self.doc_type):
+ if not safe_issubclass(value.doc_type, self.doc_type):
raise TypeError(
f'{value} schema : {value.doc_type} is not compatible with '
f'this DocVec schema : {self.doc_type}'
@@ -457,7 +516,7 @@ def _set_data_column(
if col is not None:
validation_class = col.__unparametrizedcls__ or col.__class__
else:
- validation_class = self.doc_type.__fields__[field].type_
+ validation_class = self.doc_type._get_field_annotation(field)
# TODO shape check should be handle by the tensor validation
@@ -466,7 +525,9 @@ def _set_data_column(
elif field in self._storage.doc_columns.keys():
values_ = parse_obj_as(
- DocVec.__class_getitem__(self.doc_type._get_field_type(field)),
+ DocVec.__class_getitem__(
+ self.doc_type._get_field_annotation(field)
+ ),
values,
)
self._storage.doc_columns[field] = values_
@@ -505,67 +566,30 @@ def __iter__(self):
def __len__(self):
return len(self._storage)
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, DocVec):
+ return False
+ if self.doc_type != other.doc_type:
+ return False
+ if self.tensor_type != other.tensor_type:
+ return False
+ if self._storage != other._storage:
+ return False
+ return True
+
####################
# IO related #
####################
@classmethod
- def from_protobuf(cls: Type[T], pb_msg: 'DocVecProto') -> T:
- """create a Document from a protobuf message"""
- storage = ColumnStorage(
- pb_msg.tensor_columns,
- pb_msg.doc_columns,
- pb_msg.docs_vec_columns,
- pb_msg.any_columns,
- )
-
- return cls.from_columns_storage(storage)
-
- def to_protobuf(self) -> 'DocVecProto':
- """Convert DocVec into a Protobuf message"""
- from docarray.proto import (
- DocListProto,
- DocVecProto,
- ListOfAnyProto,
- ListOfDocArrayProto,
- NdArrayProto,
- )
+ def _get_proto_class(cls: Type[T]):
+ from docarray.proto import DocVecProto
- da_proto = DocListProto()
- for doc in self:
- da_proto.docs.append(doc.to_protobuf())
+ return DocVecProto
- doc_columns_proto: Dict[str, DocVecProto] = dict()
- tensor_columns_proto: Dict[str, NdArrayProto] = dict()
- da_columns_proto: Dict[str, ListOfDocArrayProto] = dict()
- any_columns_proto: Dict[str, ListOfAnyProto] = dict()
-
- for field, col_doc in self._storage.doc_columns.items():
- doc_columns_proto[field] = (
- col_doc.to_protobuf() if col_doc is not None else None
- )
- for field, col_tens in self._storage.tensor_columns.items():
- tensor_columns_proto[field] = (
- col_tens.to_protobuf() if col_tens is not None else None
- )
- for field, col_da in self._storage.docs_vec_columns.items():
- list_proto = ListOfDocArrayProto()
- if col_da:
- for docs in col_da:
- list_proto.data.append(docs.to_protobuf())
- da_columns_proto[field] = list_proto
- for field, col_any in self._storage.any_columns.items():
- list_proto = ListOfAnyProto()
- for data in col_any:
- list_proto.data.append(_type_to_protobuf(data))
- any_columns_proto[field] = list_proto
-
- return DocVecProto(
- doc_columns=doc_columns_proto,
- tensor_columns=tensor_columns_proto,
- docs_vec_columns=da_columns_proto,
- any_columns=any_columns_proto,
- )
+ def _docarray_to_json_compatible(self) -> Dict[str, Dict[str, Any]]:
+ tup = self._storage.columns_json_compatible()
+ return tup._asdict()
def to_doc_list(self: T) -> DocList[T_doc]:
"""Convert DocVec into a DocList.
@@ -582,7 +606,6 @@ def to_doc_list(self: T) -> DocList[T_doc]:
unstacked_doc_column[field] = doc_col.to_doc_list() if doc_col else None
for field, da_col in self._storage.docs_vec_columns.items():
-
unstacked_da_column[field] = (
[docs.to_doc_list() for docs in da_col] if da_col else None
)
@@ -613,7 +636,14 @@ def to_doc_list(self: T) -> DocList[T_doc]:
del self._storage
- return DocList.__class_getitem__(self.doc_type).construct(docs)
+ doc_type = self.doc_type
+
+ # Setting _is_unusable will raise an Exception if someone interacts with this instance from here on.
+ # I don't like relying on this state, but we can't override the getattr/setattr directly:
+ # https://stackoverflow.com/questions/10376604/overriding-special-methods-on-an-instance
+ self._is_unusable = True
+
+ return DocList.__class_getitem__(doc_type).construct(docs)
def traverse_flat(
self,
@@ -628,3 +658,18 @@ def traverse_flat(
return flattened[0]
else:
return flattened
+
+ @classmethod
+ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
+ # call implementation in AnyDocArray
+ return super(IOMixinDocVec, cls).__class_getitem__(item)
+
+ if is_pydantic_v2:
+
+ @classmethod
+ def __get_pydantic_core_schema__(
+ cls, _source_type: Any, _handler: GetCoreSchemaHandler
+ ) -> core_schema.CoreSchema:
+ return core_schema.general_plain_validator_function(
+ cls.validate,
+ )
diff --git a/docarray/array/doc_vec/io.py b/docarray/array/doc_vec/io.py
new file mode 100644
index 00000000000..dd7213252fa
--- /dev/null
+++ b/docarray/array/doc_vec/io.py
@@ -0,0 +1,507 @@
+import base64
+import io
+import pathlib
+from abc import abstractmethod
+from contextlib import nullcontext
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generator,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
+
+import numpy as np
+import orjson
+from pydantic import parse_obj_as
+
+from docarray.array.doc_list.io import (
+ SINGLE_PROTOCOLS,
+ IOMixinDocList,
+ _LazyRequestReader,
+)
+from docarray.array.doc_vec.column_storage import ColumnStorage
+from docarray.array.list_advance_indexing import ListAdvancedIndexing
+from docarray.base_doc import BaseDoc
+from docarray.base_doc.mixins.io import _type_to_protobuf
+from docarray.typing import NdArray
+from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from docarray.utils._internal.pydantic import is_pydantic_v2
+from docarray.utils._internal.misc import ProtocolType
+
+if TYPE_CHECKING:
+ import csv
+
+ import pandas as pd
+
+ from docarray.array.doc_vec.doc_vec import DocVec
+ from docarray.proto import (
+ DocVecProto,
+ ListOfDocArrayProto,
+ ListOfDocVecProto,
+ NdArrayProto,
+ )
+
+
+T = TypeVar('T', bound='IOMixinDocVec')
+T_doc = TypeVar('T_doc', bound=BaseDoc)
+
+NONE_NDARRAY_PROTO_SHAPE = (0,)
+NONE_NDARRAY_PROTO_DTYPE = 'None'
+
+
+def _none_ndarray_proto() -> 'NdArrayProto':
+ from docarray.proto import NdArrayProto
+
+ zeros_arr = parse_obj_as(NdArray, np.zeros(NONE_NDARRAY_PROTO_SHAPE))
+ nd_proto = NdArrayProto()
+ nd_proto.dense.buffer = zeros_arr.tobytes()
+ nd_proto.dense.ClearField('shape')
+ nd_proto.dense.shape.extend(list(zeros_arr.shape))
+ nd_proto.dense.dtype = NONE_NDARRAY_PROTO_DTYPE
+
+ return nd_proto
+
+
+def _none_docvec_proto() -> 'DocVecProto':
+ from docarray.proto import DocVecProto
+
+ return DocVecProto()
+
+
+def _none_list_of_docvec_proto() -> 'ListOfDocVecProto':
+ from docarray.proto import ListOfDocVecProto  # imported lazily, like the sibling _none_* helpers
+
+ return ListOfDocVecProto()  # empty message: the placeholder `to_protobuf` stores for an empty/None column
+
+
+def _is_none_ndarray_proto(proto: 'NdArrayProto') -> bool:
+ return (
+ proto.dense.shape == list(NONE_NDARRAY_PROTO_SHAPE)
+ and proto.dense.dtype == NONE_NDARRAY_PROTO_DTYPE
+ )
+
+
+def _is_none_docvec_proto(proto: 'DocVecProto') -> bool:
+ return (
+ proto.tensor_columns == {}
+ and proto.doc_columns == {}
+ and proto.docs_vec_columns == {}
+ and proto.any_columns == {}
+ )
+
+
+def _is_none_list_of_docvec_proto(proto: 'ListOfDocVecProto') -> bool:
+ from docarray.proto import ListOfDocVecProto
+
+ return isinstance(proto, ListOfDocVecProto) and len(proto.data) == 0
+
+
+class IOMixinDocVec(IOMixinDocList):
+ @classmethod
+ @abstractmethod
+ def from_columns_storage(cls: Type[T], storage: ColumnStorage) -> T:
+ ...
+
+ @classmethod
+ @abstractmethod
+ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
+ ...
+
+ @classmethod
+ def from_json(
+ cls: Type[T],
+ file: Union[str, bytes, bytearray],
+ tensor_type: Type[AbstractTensor] = NdArray,
+ ) -> T:
+ """Deserialize JSON strings or bytes into a `DocVec`.
+
+ :param file: JSON object from where to deserialize a `DocVec`
+ :param tensor_type: the tensor type to use for the tensor columns.
+ Could be NdArray, TorchTensor, or TensorFlowTensor. Defaults to NdArray.
+ All tensors of the output DocVec will be of this type.
+ :return: the deserialized `DocVec`
+ """
+ json_columns = orjson.loads(file)
+ return cls._from_json_col_dict(json_columns, tensor_type=tensor_type)
+
+ @classmethod
+ def _from_json_col_dict(
+ cls: Type[T],
+ json_columns: Dict[str, Any],
+ tensor_type: Type[AbstractTensor] = NdArray,
+ ) -> T:
+ tensor_cols = json_columns['tensor_columns']
+ doc_cols = json_columns['doc_columns']
+ docs_vec_cols = json_columns['docs_vec_columns']
+ any_cols = json_columns['any_columns']
+
+ for key, col in tensor_cols.items():
+ if col is not None:
+ tensor_cols[key] = parse_obj_as(tensor_type, col)
+ else:
+ tensor_cols[key] = None
+
+ for key, col in doc_cols.items():
+ if col is not None:
+ col_doc_type = cls.doc_type._get_field_annotation(key)
+ doc_cols[key] = cls.__class_getitem__(col_doc_type)._from_json_col_dict(
+ col, tensor_type=tensor_type
+ )
+ else:
+ doc_cols[key] = None
+
+ for key, col in docs_vec_cols.items():
+ if col is not None:
+ col_doc_type = cls.doc_type._get_field_annotation(key).doc_type
+ col_ = ListAdvancedIndexing(
+ cls.__class_getitem__(col_doc_type)._from_json_col_dict(
+ vec, tensor_type=tensor_type
+ )
+ for vec in col
+ )
+ docs_vec_cols[key] = col_
+ else:
+ docs_vec_cols[key] = None
+
+ for key, col in any_cols.items():
+ if col is not None:
+ col_type = cls.doc_type._get_field_annotation(key)
+
+ field_required = (
+ cls.doc_type._docarray_fields()[key].is_required()
+ if is_pydantic_v2
+ else cls.doc_type._docarray_fields()[key].required
+ )
+
+ col_type = col_type if field_required else Optional[col_type]
+ col_ = ListAdvancedIndexing(parse_obj_as(col_type, val) for val in col)
+ any_cols[key] = col_
+ else:
+ any_cols[key] = None
+
+ return cls.from_columns_storage(
+ ColumnStorage(
+ tensor_cols, doc_cols, docs_vec_cols, any_cols, tensor_type=tensor_type
+ )
+ )
+
+ @classmethod
+ def from_protobuf(
+ cls: Type[T], pb_msg: 'DocVecProto', tensor_type: Type[AbstractTensor] = NdArray
+ ) -> T:
+ """create a DocVec from a protobuf message
+ :param pb_msg: the protobuf message to deserialize
+ :param tensor_type: the tensor type to use for the tensor columns.
+ Could be NdArray, TorchTensor, or TensorFlowTensor. Defaults to NdArray.
+ All tensors of the output DocVec will be of this type.
+ :return: The deserialized DocVec
+ """
+ tensor_columns: Dict[str, Optional[AbstractTensor]] = {}
+ doc_columns: Dict[str, Optional['DocVec']] = {}
+ docs_vec_columns: Dict[str, Optional[ListAdvancedIndexing['DocVec']]] = {}
+ any_columns: Dict[str, ListAdvancedIndexing] = {}
+
+ for tens_col_name, tens_col_proto in pb_msg.tensor_columns.items():
+ if _is_none_ndarray_proto(tens_col_proto):
+ # handle values that were None before serialization
+ tensor_columns[tens_col_name] = None
+ else:
+ tensor_columns[tens_col_name] = tensor_type.from_protobuf(
+ tens_col_proto
+ )
+
+ for doc_col_name, doc_col_proto in pb_msg.doc_columns.items():
+ if _is_none_docvec_proto(doc_col_proto):
+ # handle values that were None before serialization
+ doc_columns[doc_col_name] = None
+ else:
+ col_doc_type: Type = cls.doc_type._get_field_annotation(doc_col_name)
+ doc_columns[doc_col_name] = cls.__class_getitem__(
+ col_doc_type
+ ).from_protobuf(doc_col_proto, tensor_type=tensor_type)
+
+ for docs_vec_col_name, docs_vec_col_proto in pb_msg.docs_vec_columns.items():
+ vec_list: Optional[ListAdvancedIndexing]
+ if _is_none_list_of_docvec_proto(docs_vec_col_proto):
+ # handle values that were None before serialization
+ vec_list = None
+ else:
+ vec_list = ListAdvancedIndexing()
+ for doc_list_proto in docs_vec_col_proto.data:
+ col_doc_type = cls.doc_type._get_field_annotation(
+ docs_vec_col_name
+ ).doc_type
+ vec_list.append(
+ cls.__class_getitem__(col_doc_type).from_protobuf(
+ doc_list_proto, tensor_type=tensor_type
+ )
+ )
+ docs_vec_columns[docs_vec_col_name] = vec_list
+
+ for any_col_name, any_col_proto in pb_msg.any_columns.items():
+ any_column: ListAdvancedIndexing = ListAdvancedIndexing()
+ for node_proto in any_col_proto.data:
+ content = cls.doc_type._get_content_from_node_proto(
+ node_proto, any_col_name
+ )
+ any_column.append(content)
+ any_columns[any_col_name] = any_column
+
+ storage = ColumnStorage(
+ tensor_columns=tensor_columns,
+ doc_columns=doc_columns,
+ docs_vec_columns=docs_vec_columns,
+ any_columns=any_columns,
+ tensor_type=tensor_type,
+ )
+
+ return cls.from_columns_storage(storage)
+
+ def to_protobuf(self) -> 'DocVecProto':
+ """Convert DocVec into a Protobuf message"""
+ from docarray.proto import (
+ DocVecProto,
+ ListOfAnyProto,
+ ListOfDocArrayProto,
+ ListOfDocVecProto,
+ NdArrayProto,
+ )
+
+ self_ = cast('DocVec', self)
+
+ doc_columns_proto: Dict[str, DocVecProto] = dict()
+ tensor_columns_proto: Dict[str, NdArrayProto] = dict()
+ da_columns_proto: Dict[str, ListOfDocArrayProto] = dict()
+ any_columns_proto: Dict[str, ListOfAnyProto] = dict()
+
+ for field, col_doc in self_._storage.doc_columns.items():
+ if col_doc is None:
+ # put dummy empty DocVecProto for serialization
+ doc_columns_proto[field] = _none_docvec_proto()
+ else:
+ doc_columns_proto[field] = col_doc.to_protobuf()
+ for field, col_tens in self_._storage.tensor_columns.items():
+ if col_tens is None:
+ # put dummy empty NdArrayProto for serialization
+ tensor_columns_proto[field] = _none_ndarray_proto()
+ else:
+ tensor_columns_proto[field] = (
+ col_tens.to_protobuf() if col_tens is not None else None
+ )
+ for field, col_da in self_._storage.docs_vec_columns.items():
+ list_proto = ListOfDocVecProto()
+ if col_da:
+ for docs in col_da:
+ list_proto.data.append(docs.to_protobuf())
+ else:
+ # put dummy empty ListOfDocVecProto for serialization
+ list_proto = _none_list_of_docvec_proto()
+ da_columns_proto[field] = list_proto
+ for field, col_any in self_._storage.any_columns.items():
+ list_proto = ListOfAnyProto()
+ for data in col_any:
+ list_proto.data.append(_type_to_protobuf(data))
+ any_columns_proto[field] = list_proto
+
+ return DocVecProto(
+ doc_columns=doc_columns_proto,
+ tensor_columns=tensor_columns_proto,
+ docs_vec_columns=da_columns_proto,
+ any_columns=any_columns_proto,
+ )
+
+ def to_csv(
+ self, file_path: str, dialect: Union[str, 'csv.Dialect'] = 'excel'
+ ) -> None:
+ """
+ DocVec does not support `.to_csv()`. This is because CSV is a row-based format
+ while DocVec has a column-based data layout.
+ To overcome this, do: `doc_vec.to_doc_list().to_csv(...)`.
+ """
+ raise NotImplementedError(
+ f'{type(self)} does not support `.to_csv()`. This is because CSV is a row-based format '  # trailing space: adjacent literals are concatenated as-is
+ f'while {type(self)} has a column-based data layout. '
+ f'To overcome this, do: `doc_vec.to_doc_list().to_csv(...)`.'
+ )
+
+ @classmethod
+ def from_csv(
+ cls: Type['T'],
+ file_path: str,
+ encoding: str = 'utf-8',
+ dialect: Union[str, 'csv.Dialect'] = 'excel',
+ ) -> 'T':
+ """
+ DocVec does not support `.from_csv()`. This is because CSV is a row-based format
+ while DocVec has a column-based data layout.
+ To overcome this, do: `DocList[MyDoc].from_csv(...).to_doc_vec()`.
+ """
+ raise NotImplementedError(
+ f'{cls} does not support `.from_csv()`. This is because CSV is a row-based format while '  # trailing space: adjacent literals are concatenated as-is
+ f'{cls} has a column-based data layout. '
+ f'To overcome this, do: `DocList[MyDoc].from_csv(...).to_doc_vec()`.'
+ )
+
+ @classmethod
+ def from_base64(
+ cls: Type[T],
+ data: str,
+ protocol: ProtocolType = 'protobuf-array',
+ compress: Optional[str] = None,
+ show_progress: bool = False,
+ tensor_type: Type['AbstractTensor'] = NdArray,
+ ) -> T:
+ """Deserialize base64 strings into a `DocVec`.
+
+ :param data: Base64 string to deserialize
+ :param protocol: protocol that was used to serialize
+ :param compress: compress algorithm that was used to serialize between `lz4`, `bz2`, `lzma`, `zlib`, `gzip`
+ :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
+ :param tensor_type: the tensor type of the resulting DocVec
+ :return: the deserialized `DocVec`
+ """
+ return cls._load_binary_all(
+ file_ctx=nullcontext(base64.b64decode(data)),
+ protocol=protocol,
+ compress=compress,
+ show_progress=show_progress,
+ tensor_type=tensor_type,
+ )
+
+ @classmethod
+ def from_bytes(
+ cls: Type[T],
+ data: bytes,
+ protocol: ProtocolType = 'protobuf-array',
+ compress: Optional[str] = None,
+ show_progress: bool = False,
+ tensor_type: Type['AbstractTensor'] = NdArray,
+ ) -> T:
+ """Deserialize bytes into a `DocVec`.
+
+ :param data: Bytes from which to deserialize
+ :param protocol: protocol that was used to serialize
+ :param compress: compression algorithm that was used to serialize between `lz4`, `bz2`, `lzma`, `zlib`, `gzip`
+ :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
+ :param tensor_type: the tensor type of the resulting DocVec
+ :return: the deserialized `DocVec`
+ """
+ return cls._load_binary_all(
+ file_ctx=nullcontext(data),
+ protocol=protocol,
+ compress=compress,
+ show_progress=show_progress,
+ tensor_type=tensor_type,
+ )
+
+ @classmethod
+ def from_dataframe(
+ cls: Type['T'],
+ df: 'pd.DataFrame',
+ tensor_type: Type['AbstractTensor'] = NdArray,
+ ) -> 'T':
+ """
+ Load a `DocVec` from a `pandas.DataFrame` following the schema
+ defined in the [`.doc_type`][docarray.DocVec] attribute.
+ Every row of the dataframe will be mapped to one Document in the doc_vec.
+ The column names of the dataframe have to match the field names of the
+ Document type.
+ For nested fields use "__"-separated access paths as column names,
+ such as `'image__url'`.
+
+ List-like fields (including field of type DocList) are not supported.
+
+ ---
+
+ ```python
+ import pandas as pd
+
+ from docarray import BaseDoc, DocVec
+
+
+ class Person(BaseDoc):
+ name: str
+ follower: int
+
+
+ df = pd.DataFrame(
+ data=[['Maria', 12345], ['Jake', 54321]], columns=['name', 'follower']
+ )
+
+ docs = DocVec[Person].from_dataframe(df)
+
+ assert docs.name == ['Maria', 'Jake']
+ assert docs.follower == [12345, 54321]
+ ```
+
+ ---
+
+ :param df: `pandas.DataFrame` to extract Document's information from
+ :param tensor_type: the tensor type of the resulting DocVec
+ :return: `DocVec` where each Document contains the information of one
+ corresponding row of the `pandas.DataFrame`.
+ """
+ # type ignore could be avoided by simply putting this implementation in the DocVec class
+ # but leaving it here for code separation
+ return cls(super().from_dataframe(df), tensor_type=tensor_type) # type: ignore
+
+ @classmethod
+ def load_binary(
+ cls: Type[T],
+ file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
+ protocol: ProtocolType = 'protobuf-array',
+ compress: Optional[str] = None,
+ show_progress: bool = False,
+ streaming: bool = False,
+ tensor_type: Type['AbstractTensor'] = NdArray,
+ ) -> Union[T, Generator['T_doc', None, None]]:
+ """Load doc_vec elements from a compressed binary file.
+
+ In case protocol is pickle the `Documents` are streamed from disk to save memory usage
+
+ !!! note
+ If `file` is `str` it can specify `protocol` and `compress` as file extensions.
+ This functionality assumes `file=file_name.$protocol.$compress` where `$protocol` and `$compress` refer to a
+ string interpolation of the respective `protocol` and `compress` methods.
+ For example if `file=my_docarray.protobuf.lz4` then the binary data will be loaded assuming `protocol=protobuf`
+ and `compress=lz4`.
+
+ :param file: File or filename or serialized bytes where the data is stored.
+ :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf'
+ :param compress: compress algorithm to use between `lz4`, `bz2`, `lzma`, `zlib`, `gzip`
+ :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
+ :param streaming: if `True` returns a generator over `Document` objects.
+ :param tensor_type: the tensor type of the resulting DocVec
+
+ :return: a `DocVec` object
+
+ """
+ file_ctx, load_protocol, load_compress = cls._get_file_context(
+ file, protocol, compress
+ )
+ if streaming:
+ if load_protocol not in SINGLE_PROTOCOLS:
+ raise ValueError(
+ f'`streaming` is only available when using {" or ".join(map(lambda x: f"`{x}`", SINGLE_PROTOCOLS))} as protocol, '
+ f'got {load_protocol}'
+ )
+ else:
+ return cls._load_binary_stream(
+ file_ctx,
+ protocol=load_protocol,
+ compress=load_compress,
+ show_progress=show_progress,
+ )
+ else:
+ return cls._load_binary_all(
+ file_ctx,
+ load_protocol,
+ load_compress,
+ show_progress,
+ tensor_type=tensor_type,
+ )
diff --git a/docarray/array/list_advance_indexing.py b/docarray/array/list_advance_indexing.py
index bcf966e6454..c3d80ad2f6c 100644
--- a/docarray/array/list_advance_indexing.py
+++ b/docarray/array/list_advance_indexing.py
@@ -1,5 +1,4 @@
from typing import (
- TYPE_CHECKING,
Any,
Iterable,
List,
@@ -14,7 +13,25 @@
import numpy as np
from typing_extensions import SupportsIndex
-from docarray.utils._internal.misc import import_library
+from docarray.utils._internal.misc import (
+ is_jax_available,
+ is_tf_available,
+ is_torch_available,
+)
+
+torch_available = is_torch_available()
+if torch_available:
+ import torch
+tf_available = is_tf_available()
+if tf_available:
+ import tensorflow as tf # type: ignore
+
+ from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor
+jax_available = is_jax_available()
+if jax_available:
+ import jax.numpy as jnp
+
+ from docarray.typing.tensor.jaxarray import JaxArray
T_item = TypeVar('T_item')
T = TypeVar('T', bound='ListAdvancedIndexing')
@@ -75,17 +92,26 @@ def _normalize_index_item(
return item.tolist()
# torch index types
- if TYPE_CHECKING:
- import torch
- else:
- torch = import_library('torch', raise_error=True)
-
- allowed_torch_dtypes = [
- torch.bool,
- torch.int64,
- ]
- if isinstance(item, torch.Tensor) and (item.dtype in allowed_torch_dtypes):
- return item.tolist()
+ if torch_available:
+
+ allowed_torch_dtypes = [
+ torch.bool,
+ torch.int64,
+ ]
+ if isinstance(item, torch.Tensor) and (item.dtype in allowed_torch_dtypes):
+ return item.tolist()
+
+ if tf_available:
+ if isinstance(item, tf.Tensor):
+ return item.numpy().tolist()
+ if isinstance(item, TensorFlowTensor):
+ return item.tensor.numpy().tolist()
+
+ if jax_available:
+ if isinstance(item, jnp.ndarray):
+ return item.__array__().tolist()
+ if isinstance(item, JaxArray):
+ return item.tensor.__array__().tolist()
return item
diff --git a/docarray/base_doc/__init__.py b/docarray/base_doc/__init__.py
index 47e01c1c662..1c3a3cf7924 100644
--- a/docarray/base_doc/__init__.py
+++ b/docarray/base_doc/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.base_doc.any_doc import AnyDoc
from docarray.base_doc.base_node import BaseNode
from docarray.base_doc.doc import BaseDoc
diff --git a/docarray/base_doc/any_doc.py b/docarray/base_doc/any_doc.py
index e04c256f8bb..3a7be2cb125 100644
--- a/docarray/base_doc/any_doc.py
+++ b/docarray/base_doc/any_doc.py
@@ -1,5 +1,7 @@
from typing import Type
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
from .doc import BaseDoc
@@ -17,7 +19,7 @@ def __init__(self, **kwargs):
self.__dict__.update(kwargs)
@classmethod
- def _get_field_type(cls, field: str) -> Type['BaseDoc']:
+ def _get_field_annotation(cls, field: str) -> Type['BaseDoc']:
"""
Accessing the nested python Class define in the schema.
Could be useful for reconstruction of Document in
@@ -28,7 +30,14 @@ def _get_field_type(cls, field: str) -> Type['BaseDoc']:
return AnyDoc
@classmethod
- def _get_field_type_array(cls, field: str) -> Type:
+ def _get_field_annotation_array(cls, field: str) -> Type:
from docarray import DocList
return DocList
+
+ if is_pydantic_v2:
+
+ def dict(self, *args, **kwargs):
+ raise NotImplementedError(
+ "dict() method is not implemented for pydantic v2. Now pydantic requires a schema to dump the dict, but AnyDoc is schemaless"
+ )
diff --git a/docarray/base_doc/base_node.py b/docarray/base_doc/base_node.py
index 7cbb76c9e98..16a64bea599 100644
--- a/docarray/base_doc/base_node.py
+++ b/docarray/base_doc/base_node.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, TypeVar, Optional, Type
diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py
index f69187d35e0..e880504bc05 100644
--- a/docarray/base_doc/doc.py
+++ b/docarray/base_doc/doc.py
@@ -7,6 +7,7 @@
Callable,
Dict,
List,
+ Literal,
Mapping,
Optional,
Tuple,
@@ -18,8 +19,16 @@
)
import orjson
+import typing_extensions
from pydantic import BaseModel, Field
-from pydantic.main import ROOT_KEY
+from pydantic.fields import FieldInfo
+from typing_inspect import get_args, is_optional_type
+
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+if not is_pydantic_v2:
+ from pydantic.main import ROOT_KEY
+
from rich.console import Console
from docarray.base_doc.base_node import BaseNode
@@ -27,6 +36,7 @@
from docarray.base_doc.mixins import IOMixin, UpdateMixin
from docarray.typing import ID
from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from docarray.utils._internal._typing import safe_issubclass
if TYPE_CHECKING:
from pydantic import Protocol
@@ -35,6 +45,15 @@
from docarray.array.doc_vec.column_storage import ColumnStorageView
+if is_pydantic_v2:
+
+ IncEx: typing_extensions.TypeAlias = (
+ 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None'
+ )
+
+ from pydantic import ConfigDict
+
+
_console: Console = Console()
T = TypeVar('T', bound='BaseDoc')
@@ -44,68 +63,171 @@
ExcludeType = Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]
-class BaseDoc(BaseModel, IOMixin, UpdateMixin, BaseNode):
+class BaseDocWithoutId(BaseModel, IOMixin, UpdateMixin, BaseNode):
+ """
+ BaseDocWithoutId is the class behind BaseDoc, it should not be used directly unless you know what you are doing.
+ It is basically a BaseDoc without the ID field.
+ !!! warning
+ This class cannot be used with DocumentIndex. Only BaseDoc is compatible
"""
- BaseDoc is the base class for all Documents. This class should be subclassed
- to create new Document types with a specific schema.
- The schema of a Document is defined by the fields of the class.
+ if is_pydantic_v2:
- Example:
- ```python
- from docarray import BaseDoc
- from docarray.typing import NdArray, ImageUrl
- import numpy as np
+ class ConfigDocArray(ConfigDict):
+ _load_extra_fields_from_protobuf: bool
+ model_config = ConfigDocArray(
+ validate_assignment=True,
+ _load_extra_fields_from_protobuf=False,
+ json_encoders={AbstractTensor: lambda x: x},
+ )
- class MyDoc(BaseDoc):
- embedding: NdArray[512]
- image: ImageUrl
+ else:
+ class Config:
+ json_loads = orjson.loads
+ json_dumps = orjson_dumps_and_decode
+ # `DocArrayResponse` is able to handle tensors by itself.
+ # Therefore, we stop FastAPI from doing any transformations
+ # on tensors by setting an identity function as a custom encoder.
+ json_encoders = {AbstractTensor: lambda x: x}
- doc = MyDoc(embedding=np.zeros(512), image='https://example.com/image.jpg')
- ```
+ validate_assignment = True
+ _load_extra_fields_from_protobuf = False
+ if is_pydantic_v2:
- BaseDoc is a subclass of [pydantic.BaseModel](
- https://docs.pydantic.dev/usage/models/) and can be used in a similar way.
- """
+ ## pydantic v2 handles views and shallow copies a bit differently. We need to update different fields
+
+ @classmethod
+ def from_view(cls: Type[T], storage_view: 'ColumnStorageView') -> T:
+ doc = cls.__new__(cls)
+
+ object.__setattr__(doc, '__dict__', storage_view)
+ object.__setattr__(doc, '__pydantic_fields_set__', set(storage_view.keys()))
+ object.__setattr__(doc, '__pydantic_extra__', {})
+
+ if cls.__pydantic_post_init__:
+ doc.model_post_init(None)
+ else:
+ # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist
+ # Since it doesn't, that means that `__pydantic_private__` should be set to None
+ object.__setattr__(doc, '__pydantic_private__', None)
+
+ return doc
+
+ @classmethod
+ def _shallow_copy(cls: Type[T], doc_to_copy: T) -> T:
+ """
+ perform a shallow copy, the new doc shares the same data with the original doc
+ """
+ doc = cls.__new__(cls)
+
+ object.__setattr__(doc, '__dict__', doc_to_copy.__dict__)
+ object.__setattr__(
+ doc, '__pydantic_fields_set__', doc_to_copy.__pydantic_fields_set__
+ )
+ object.__setattr__(doc, '__pydantic_extra__', {})
+
+ if cls.__pydantic_post_init__:
+ doc.model_post_init(None)
+ else:
+ # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist
+ # Since it doesn't, that means that `__pydantic_private__` should be set to None
+ object.__setattr__(doc, '__pydantic_private__', None)
+
+ return doc
- id: Optional[ID] = Field(default_factory=lambda: ID(os.urandom(16).hex()))
+ else:
- class Config:
- json_loads = orjson.loads
- json_dumps = orjson_dumps_and_decode
- # `DocArrayResponse` is able to handle tensors by itself.
- # Therefore, we stop FastAPI from doing any transformations
- # on tensors by setting an identity function as a custom encoder.
- json_encoders = {AbstractTensor: lambda x: x}
+ @classmethod
+ def from_view(cls: Type[T], storage_view: 'ColumnStorageView') -> T:
+ doc = cls.__new__(cls)
+ object.__setattr__(doc, '__dict__', storage_view)
+ object.__setattr__(doc, '__fields_set__', set(storage_view.keys()))
- validate_assignment = True
- _load_extra_fields_from_protobuf = False
+ doc._init_private_attributes()
+ return doc
+
+ @classmethod
+ def _shallow_copy(cls: Type[T], doc_to_copy: T) -> T:
+ """
+ perform a shallow copy, the new doc shares the same data with the original doc
+ """
+ doc = cls.__new__(cls)
+ object.__setattr__(doc, '__dict__', doc_to_copy.__dict__)
+ object.__setattr__(doc, '__fields_set__', set(doc_to_copy.__fields_set__))
+
+ doc._init_private_attributes()
+ return doc
@classmethod
- def from_view(cls: Type[T], storage_view: 'ColumnStorageView') -> T:
- doc = cls.__new__(cls)
- object.__setattr__(doc, '__dict__', storage_view)
- object.__setattr__(doc, '__fields_set__', set(storage_view.keys()))
+ def _docarray_fields(cls) -> Dict[str, FieldInfo]:
+ """
+ Returns a dictionary of all fields of this document.
+ """
+ if is_pydantic_v2:
+ return cls.model_fields
+ else:
+ return cls.__fields__
- doc._init_private_attributes()
- return doc
+ @classmethod
+ def _get_field_annotation(cls, field: str) -> Type:
+ """
+ Accessing annotation associated with the field in the schema
+ :param field: name of the field
+ :return:
+ """
+
+ if is_pydantic_v2:
+ annotation = cls._docarray_fields()[field].annotation
+
+ if is_optional_type(
+ annotation
+ ): # this is equivalent to `outer_type_` in pydantic v1
+ return get_args(annotation)[0]
+ else:
+ return annotation
+ else:
+ return cls._docarray_fields()[field].outer_type_
@classmethod
- def _get_field_type(cls, field: str) -> Type:
+ def _get_field_inner_type(cls, field: str) -> Type:
"""
- Accessing the nested python Class define in the schema. Could be useful for
- reconstruction of Document in serialization/deserilization
+ Accessing the type associated with the field in the schema
:param field: name of the field
:return:
"""
- return cls.__fields__[field].outer_type_
+
+ if is_pydantic_v2:
+ annotation = cls._docarray_fields()[field].annotation
+
+ if is_optional_type(
+ annotation
+ ): # this is equivalent to `outer_type_` in pydantic v1
+ return get_args(annotation)[0]
+ elif annotation == Tuple:
+ if len(get_args(annotation)) == 0:
+ return Any
+ else:
+ return get_args(annotation)[0]
+ else:
+ return annotation
+ else:
+ return cls._docarray_fields()[field].type_
def __str__(self) -> str:
+ content: Any = None
+ if self.is_view():
+ attr_str = ", ".join(
+ f"{field}={self.__getattr__(field)}" for field in self.__dict__.keys()
+ )
+ content = f"{self.__class__.__name__}({attr_str})"
+ else:
+ content = self
+
with _console.capture() as capture:
- _console.print(self)
+ _console.print(content)
return capture.get().strip()
@@ -132,7 +254,7 @@ def is_view(self) -> bool:
return isinstance(self.__dict__, ColumnStorageView)
def __getattr__(self, item) -> Any:
- if item in self.__fields__.keys():
+ if item in self._docarray_fields().keys():
return self.__dict__[item]
else:
return super().__getattribute__(item)
@@ -154,10 +276,10 @@ def __eq__(self, other) -> bool:
if not isinstance(other, BaseDoc):
return False
- if self.__fields__.keys() != other.__fields__.keys():
+ if self._docarray_fields().keys() != other._docarray_fields().keys():
return False
- for field_name in self.__fields__:
+ for field_name in self._docarray_fields():
value1 = getattr(self, field_name)
value2 = getattr(other, field_name)
@@ -193,73 +315,213 @@ def _docarray_to_json_compatible(self) -> Dict:
"""
return self.dict()
- ########################################################################################################################################################
- ### this section is just for documentation purposes will be removed later once
- # https://github.com/mkdocstrings/griffe/issues/138 is fixed ##############
- ########################################################################################################################################################
-
- def json(
- self,
- *,
- include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
- exclude: ExcludeType = None,
- by_alias: bool = False,
- skip_defaults: Optional[bool] = None,
- exclude_unset: bool = False,
- exclude_defaults: bool = False,
- exclude_none: bool = False,
- encoder: Optional[Callable[[Any], Any]] = None,
- models_as_dict: bool = True,
- **dumps_kwargs: Any,
- ) -> str:
+ def _exclude_doclist(
+ self, exclude: ExcludeType
+ ) -> Tuple[ExcludeType, ExcludeType, List[str]]:
"""
- Generate a JSON representation of the model, `include` and `exclude`
- arguments as per `dict()`.
-
- `encoder` is an optional function to supply as `default` to json.dumps(),
- other arguments as per `json.dumps()`.
+        This function excludes the doclist fields from the dump. It is used in the model dump function because we give special treatment to DocList during serialization, and therefore we want pydantic to ignore these fields and let us handle them.
"""
- exclude, original_exclude, doclist_exclude_fields = self._exclude_doclist(
- exclude=exclude
+ doclist_exclude_fields = []
+ for field in self._docarray_fields().keys():
+ from docarray.array.any_array import AnyDocArray
+
+ type_ = self._get_field_annotation(field)
+ if is_pydantic_v2:
+ # Conservative when touching pydantic v1 logic
+ if safe_issubclass(type_, AnyDocArray):
+ doclist_exclude_fields.append(field)
+ else:
+ if isinstance(type_, type) and safe_issubclass(type_, AnyDocArray):
+ doclist_exclude_fields.append(field)
+
+ original_exclude = exclude
+ if exclude is None:
+ exclude = set(doclist_exclude_fields)
+ elif isinstance(exclude, AbstractSet):
+ exclude = set([*exclude, *doclist_exclude_fields])
+ elif isinstance(exclude, Mapping):
+ exclude = dict(**exclude)
+ exclude.update({field: ... for field in doclist_exclude_fields})
+
+ return (
+ exclude,
+ original_exclude,
+ doclist_exclude_fields,
)
- # this is copy from pydantic code
- if skip_defaults is not None:
- warnings.warn(
- f'{self.__class__.__name__}.json(): "skip_defaults" is deprecated and replaced by "exclude_unset"',
- DeprecationWarning,
+ if not is_pydantic_v2:
+
+ def json(
+ self,
+ *,
+ include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
+ exclude: ExcludeType = None,
+ by_alias: bool = False,
+ skip_defaults: Optional[bool] = None,
+ exclude_unset: bool = False,
+ exclude_defaults: bool = False,
+ exclude_none: bool = False,
+ encoder: Optional[Callable[[Any], Any]] = None,
+ models_as_dict: bool = True,
+ **dumps_kwargs: Any,
+ ) -> str:
+ """
+ Generate a JSON representation of the model, `include` and `exclude`
+ arguments as per `dict()`.
+
+ `encoder` is an optional function to supply as `default` to json.dumps(),
+ other arguments as per `json.dumps()`.
+ """
+ exclude, original_exclude, doclist_exclude_fields = self._exclude_docarray(
+ exclude=exclude
)
- exclude_unset = skip_defaults
- encoder = cast(Callable[[Any], Any], encoder or self.__json_encoder__)
-
- # We don't directly call `self.dict()`, which does exactly this with `to_dict=True`
- # because we want to be able to keep raw `BaseModel` instances and not as `dict`.
- # This allows users to write custom JSON encoders for given `BaseModel` classes.
- data = dict(
- self._iter(
- to_dict=models_as_dict,
- by_alias=by_alias,
+
+ # this is copy from pydantic code
+ if skip_defaults is not None:
+ warnings.warn(
+ f'{self.__class__.__name__}.json(): "skip_defaults" is deprecated and replaced by "exclude_unset"',
+ DeprecationWarning,
+ )
+ exclude_unset = skip_defaults
+ encoder = cast(Callable[[Any], Any], encoder or self.__json_encoder__)
+
+ # We don't directly call `self.dict()`, which does exactly this with `to_dict=True`
+ # because we want to be able to keep raw `BaseModel` instances and not as `dict`.
+ # This allows users to write custom JSON encoders for given `BaseModel` classes.
+ data = dict(
+ self._iter(
+ to_dict=models_as_dict,
+ by_alias=by_alias,
+ include=include,
+ exclude=exclude,
+ exclude_unset=exclude_unset,
+ exclude_defaults=exclude_defaults,
+ exclude_none=exclude_none,
+ )
+ )
+
+ # this is the custom part to deal with DocList
+ for field in doclist_exclude_fields:
+ # we need to do this because pydantic will not recognize DocList correctly
+ original_exclude = original_exclude or {}
+ if field not in original_exclude:
+ data[field] = getattr(
+ self, field
+ ) # here we need to keep doclist as doclist otherwise if a user want to have a special json config it will not work
+
+ # this is copy from pydantic code
+ if self.__custom_root_type__:
+ data = data[ROOT_KEY]
+ return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)
+
+ def dict(
+ self,
+ *,
+ include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
+ exclude: ExcludeType = None,
+ by_alias: bool = False,
+ skip_defaults: Optional[bool] = None,
+ exclude_unset: bool = False,
+ exclude_defaults: bool = False,
+ exclude_none: bool = False,
+ ) -> 'DictStrAny':
+ """
+ Generate a dictionary representation of the model, optionally specifying
+ which fields to include or exclude.
+
+ """
+ exclude, original_exclude, doclist_exclude_fields = self._exclude_doclist(
+ exclude=exclude
+ )
+
+ data = super().dict(
include=include,
exclude=exclude,
+ by_alias=by_alias,
+ skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
- )
- # this is the custom part to deal with DocList
- for field in doclist_exclude_fields:
- # we need to do this because pydantic will not recognize DocList correctly
- original_exclude = original_exclude or {}
- if field not in original_exclude:
- data[field] = getattr(
- self, field
- ) # here we need to keep doclist as doclist otherwise if a user want to have a special json config it will not work
-
- # this is copy from pydantic code
- if self.__custom_root_type__:
- data = data[ROOT_KEY]
- return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)
+ for field in doclist_exclude_fields:
+ # we need to do this because pydantic will not recognize DocList correctly
+ original_exclude = original_exclude or {}
+ if field not in original_exclude:
+ val = getattr(self, field)
+ data[field] = (
+ [doc.dict() for doc in val] if val is not None else None
+ )
+
+ return data
+
+ else:
+
+ def _copy_view_pydantic_v2(self: T) -> T:
+ """
+ perform a deep copy, the new doc has its own data
+ """
+ data = {}
+ for key, value in self.__dict__.to_dict().items():
+ if isinstance(value, BaseDocWithoutId):
+ data[key] = value._copy_view_pydantic_v2()
+ else:
+ data[key] = value
+
+ doc = self.__class__.model_construct(**data)
+ return doc
+
+        def model_dump(  # type: ignore
+            self,
+            *,
+            mode: Union[Literal['json', 'python'], str] = 'python',
+            include: IncEx = None,
+            exclude: IncEx = None,
+            by_alias: bool = False,
+            exclude_unset: bool = False,
+            exclude_defaults: bool = False,
+            exclude_none: bool = False,
+            round_trip: bool = False,
+            warnings: bool = True,
+        ) -> Dict[str, Any]:
+            """
+            Dump the doc to a dict; fields reported by `_exclude_doclist` are dumped here, not by pydantic.
+            """
+            def _model_dump(doc):
+                excluded = self._exclude_doclist(exclude=exclude)
+                exclude_, original_exclude, doclist_exclude_fields = excluded
+                # let pydantic dump everything except the doclist fields
+                data = doc.model_dump(
+                    mode=mode,
+                    include=include,
+                    exclude=exclude_,
+                    by_alias=by_alias,
+                    exclude_unset=exclude_unset,
+                    exclude_defaults=exclude_defaults,
+                    exclude_none=exclude_none,
+                    round_trip=round_trip,
+                    warnings=warnings,
+                )
+                # add the doclist fields back unless the caller excluded them
+                for field in doclist_exclude_fields:
+                    # we need to do this because pydantic will not recognize DocList correctly
+                    original_exclude = original_exclude or {}
+                    if field not in original_exclude:
+                        val = getattr(self, field)
+                        data[field] = (
+                            [d.model_dump() for d in val] if val is not None else None
+                        )
+
+                return data
+
+            if self.is_view():
+                ## for some reason, using ColumnStorageView to dump the data does not work with
+                ## pydantic v2, so we need to create a new doc and dump it
+
+                new_doc = self._copy_view_pydantic_v2()
+                return _model_dump(new_doc)
+            else:
+                return _model_dump(super())
@no_type_check
@classmethod
@@ -281,7 +543,7 @@ def parse_raw(
:param allow_pickle: allow pickle protocol
:return: a document
"""
- return super(BaseDoc, cls).parse_raw(
+ return super(BaseDocWithoutId, cls).parse_raw(
b,
content_type=content_type,
encoding=encoding,
@@ -289,70 +551,66 @@ def parse_raw(
allow_pickle=allow_pickle,
)
- def dict(
- self,
- *,
- include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
- exclude: ExcludeType = None,
- by_alias: bool = False,
- skip_defaults: Optional[bool] = None,
- exclude_unset: bool = False,
- exclude_defaults: bool = False,
- exclude_none: bool = False,
- ) -> 'DictStrAny':
- """
- Generate a dictionary representation of the model, optionally specifying
- which fields to include or exclude.
-
- """
-
- exclude, original_exclude, doclist_exclude_fields = self._exclude_doclist(
- exclude=exclude
- )
-
- data = super().dict(
- include=include,
- exclude=exclude,
- by_alias=by_alias,
- skip_defaults=skip_defaults,
- exclude_unset=exclude_unset,
- exclude_defaults=exclude_defaults,
- exclude_none=exclude_none,
- )
-
- for field in doclist_exclude_fields:
- # we need to do this because pydantic will not recognize DocList correctly
- original_exclude = original_exclude or {}
- if field not in original_exclude:
- val = getattr(self, field)
- data[field] = [doc.dict() for doc in val] if val is not None else None
-
- return data
-
- def _exclude_doclist(
+ def _exclude_docarray(
self, exclude: ExcludeType
) -> Tuple[ExcludeType, ExcludeType, List[str]]:
- doclist_exclude_fields = []
+ docarray_exclude_fields = []
for field in self.__fields__.keys():
- from docarray import DocList
+ from docarray import DocList, DocVec
- type_ = self._get_field_type(field)
- if isinstance(type_, type) and issubclass(type_, DocList):
- doclist_exclude_fields.append(field)
+ type_ = self._get_field_annotation(field)
+ if isinstance(type_, type) and (
+ safe_issubclass(type_, DocList) or safe_issubclass(type_, DocVec)
+ ):
+ docarray_exclude_fields.append(field)
original_exclude = exclude
if exclude is None:
- exclude = set(doclist_exclude_fields)
+ exclude = set(docarray_exclude_fields)
elif isinstance(exclude, AbstractSet):
- exclude = set([*exclude, *doclist_exclude_fields])
+ exclude = set([*exclude, *docarray_exclude_fields])
elif isinstance(exclude, Mapping):
exclude = dict(**exclude)
- exclude.update({field: ... for field in doclist_exclude_fields})
+ exclude.update({field: ... for field in docarray_exclude_fields})
return (
exclude,
original_exclude,
- doclist_exclude_fields,
+ docarray_exclude_fields,
)
- to_json = json
+ to_json = BaseModel.model_dump_json if is_pydantic_v2 else json
+
+
+class BaseDoc(BaseDocWithoutId):
+ """
+ BaseDoc is the base class for all Documents. This class should be subclassed
+ to create new Document types with a specific schema.
+
+ The schema of a Document is defined by the fields of the class.
+
+ Example:
+ ```python
+ from docarray import BaseDoc
+ from docarray.typing import NdArray, ImageUrl
+ import numpy as np
+
+
+ class MyDoc(BaseDoc):
+ embedding: NdArray[512]
+ image: ImageUrl
+
+
+ doc = MyDoc(embedding=np.zeros(512), image='https://example.com/image.jpg')
+ ```
+
+
+ BaseDoc is a subclass of [pydantic.BaseModel](
+ https://docs.pydantic.dev/usage/models/) and can be used in a similar way.
+ """
+
+ id: Optional[ID] = Field(
+ description='The ID of the BaseDoc. This is useful for indexing in vector stores. If not set by user, it will automatically be assigned a random value',
+ default_factory=lambda: ID(os.urandom(16).hex()),
+ example=os.urandom(16).hex(),
+ )
diff --git a/docarray/base_doc/docarray_response.py b/docarray/base_doc/docarray_response.py
index a9f807ab6b4..8f00ffdbf56 100644
--- a/docarray/base_doc/docarray_response.py
+++ b/docarray/base_doc/docarray_response.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TYPE_CHECKING, Any
from docarray.base_doc.io.json import orjson_dumps
diff --git a/docarray/base_doc/io/__init__.py b/docarray/base_doc/io/__init__.py
index e69de29bb2d..74f8f7582cd 100644
--- a/docarray/base_doc/io/__init__.py
+++ b/docarray/base_doc/io/__init__.py
@@ -0,0 +1,15 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/docarray/base_doc/io/json.py b/docarray/base_doc/io/json.py
index 27468b2b61c..d644c2f194e 100644
--- a/docarray/base_doc/io/json.py
+++ b/docarray/base_doc/io/json.py
@@ -1,5 +1,17 @@
+from typing import Any, Callable, Dict, Type
+
import orjson
-from pydantic.json import ENCODERS_BY_TYPE
+
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+if not is_pydantic_v2:
+ from pydantic.json import ENCODERS_BY_TYPE
+else:
+ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
+ bytes: lambda o: o.decode(),
+ frozenset: list,
+ set: list,
+ }
def _default_orjson(obj):
@@ -25,5 +37,5 @@ def orjson_dumps(v, *, default=None) -> bytes:
def orjson_dumps_and_decode(v, *, default=None) -> str:
- # dumps to bytes using orjson
+ # dumps to str using orjson
return orjson_dumps(v, default=default).decode()
diff --git a/docarray/base_doc/mixins/__init__.py b/docarray/base_doc/mixins/__init__.py
index bfa675df9a1..dcf5766aa25 100644
--- a/docarray/base_doc/mixins/__init__.py
+++ b/docarray/base_doc/mixins/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.base_doc.mixins.io import IOMixin
from docarray.base_doc.mixins.update import UpdateMixin
diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py
index 4aa27f90900..3121c45c445 100644
--- a/docarray/base_doc/mixins/io.py
+++ b/docarray/base_doc/mixins/io.py
@@ -8,28 +8,35 @@
Dict,
Iterable,
List,
+ Literal,
Optional,
Tuple,
Type,
TypeVar,
+ Union,
)
+from typing import _GenericAlias as GenericAlias
+from typing import get_origin
import numpy as np
-from typing_inspect import is_union_type
+from typing_inspect import get_args, is_union_type
from docarray.base_doc.base_node import BaseNode
from docarray.typing import NdArray
from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS
+from docarray.utils._internal._typing import safe_issubclass
from docarray.utils._internal.compress import _compress_bytes, _decompress_bytes
-from docarray.utils._internal.misc import import_library
+from docarray.utils._internal.misc import ProtocolType, import_library
+from docarray.utils._internal.pydantic import is_pydantic_v2
if TYPE_CHECKING:
import tensorflow as tf # type: ignore
import torch
- from pydantic.fields import ModelField
+ from pydantic.fields import FieldInfo
from docarray.proto import DocProto, NodeProto
from docarray.typing import TensorFlowTensor, TorchTensor
+
else:
tf = import_library('tensorflow', raise_error=False)
if tf is not None:
@@ -124,25 +131,25 @@ class IOMixin(Iterable[Tuple[str, Any]]):
IOMixin to define all the bytes/protobuf/json related part of BaseDoc
"""
- __fields__: Dict[str, 'ModelField']
+ _docarray_fields: Dict[str, 'FieldInfo']
class Config:
_load_extra_fields_from_protobuf: bool
@classmethod
@abstractmethod
- def _get_field_type(cls, field: str) -> Type:
+ def _get_field_annotation(cls, field: str) -> Type:
...
@classmethod
- def _get_field_type_array(cls, field: str) -> Type:
- return cls._get_field_type(field)
+ def _get_field_annotation_array(cls, field: str) -> Type:
+ return cls._get_field_annotation(field)
def __bytes__(self) -> bytes:
return self.to_bytes()
def to_bytes(
- self, protocol: str = 'protobuf', compress: Optional[str] = None
+ self, protocol: ProtocolType = 'protobuf', compress: Optional[str] = None
) -> bytes:
"""Serialize itself into bytes.
@@ -169,7 +176,7 @@ def to_bytes(
def from_bytes(
cls: Type[T],
data: bytes,
- protocol: str = 'protobuf',
+ protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
) -> T:
"""Build Document object from binary bytes
@@ -195,7 +202,7 @@ def from_bytes(
)
def to_base64(
- self, protocol: str = 'protobuf', compress: Optional[str] = None
+ self, protocol: ProtocolType = 'protobuf', compress: Optional[str] = None
) -> str:
"""Serialize a Document object into as base64 string
@@ -209,7 +216,7 @@ def to_base64(
def from_base64(
cls: Type[T],
data: str,
- protocol: str = 'pickle',
+ protocol: Literal['pickle', 'protobuf'] = 'pickle',
compress: Optional[str] = None,
) -> T:
"""Build Document object from binary bytes
@@ -230,11 +237,15 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocProto') -> T:
"""
fields: Dict[str, Any] = {}
-
+ load_extra_field = (
+ cls.model_config['_load_extra_fields_from_protobuf']
+ if is_pydantic_v2
+ else cls.Config._load_extra_fields_from_protobuf
+ )
for field_name in pb_msg.data:
if (
- not (cls.Config._load_extra_fields_from_protobuf)
- and field_name not in cls.__fields__.keys()
+ not (load_extra_field)
+ and field_name not in cls._docarray_fields().keys()
):
continue # optimization we don't even load the data if the key does not
# match any field in the cls or in the mapping
@@ -259,12 +270,11 @@ def _get_content_from_node_proto(
:param field_name: the name of the field
:return: the loaded field
"""
-
if field_name is not None and field_type is not None:
raise ValueError("field_type and field_name cannot be both passed")
field_type = field_type or (
- cls._get_field_type(field_name) if field_name else None
+ cls._get_field_annotation(field_name) if field_name else None
)
content_type_dict = _PROTO_TYPE_NAME_TO_CLASS
@@ -275,7 +285,6 @@ def _get_content_from_node_proto(
)
return_field: Any
-
if docarray_type in content_type_dict:
return_field = content_type_dict[docarray_type].from_protobuf(
getattr(value, content_key)
@@ -285,17 +294,31 @@ def _get_content_from_node_proto(
raise ValueError(
'field_type cannot be None when trying to deserialize a BaseDoc'
)
- return_field = field_type.from_protobuf(
- getattr(value, content_key)
- ) # we get to the parent class
+ try:
+ return_field = field_type.from_protobuf(
+ getattr(value, content_key)
+ ) # we get to the parent class
+ except Exception:
+ if get_origin(field_type) is Union:
+ raise ValueError(
+ 'Union type is not supported for proto deserialization. Please use JSON serialization instead'
+ )
+ raise ValueError(
+ f'{field_type} is not supported for proto deserialization'
+ )
elif content_key == 'doc_array':
- if field_name is None:
+ if field_type is not None and field_name is None:
+ return_field = field_type.from_protobuf(getattr(value, content_key))
+ elif field_name is not None:
+ return_field = cls._get_field_annotation_array(
+ field_name
+ ).from_protobuf(
+ getattr(value, content_key)
+ ) # we get to the parent class
+ else:
raise ValueError(
- 'field_name cannot be None when trying to deserialize a BaseDoc'
+ 'field_name and field_type cannot be None when trying to deserialize a DocArray'
)
- return_field = cls._get_field_type_array(field_name).from_protobuf(
- getattr(value, content_key)
- ) # we get to the parent class
elif content_key is None:
return_field = None
elif docarray_type is None:
@@ -309,7 +332,12 @@ def _get_content_from_node_proto(
return_field = getattr(value, content_key)
elif content_key in arg_to_container.keys():
- field_type = cls.__fields__[field_name].type_ if field_name else None
+ if field_name and field_name in cls._docarray_fields():
+ field_type = cls._get_field_inner_type(field_name)
+
+ if isinstance(field_type, GenericAlias):
+ field_type = get_args(field_type)[0]
+
return_field = arg_to_container[content_key](
cls._get_content_from_node_proto(node, field_type=field_type)
for node in getattr(value, content_key).data
@@ -317,7 +345,22 @@ def _get_content_from_node_proto(
elif content_key == 'dict':
deser_dict: Dict[str, Any] = dict()
- field_type = cls.__fields__[field_name].type_ if field_name else None
+
+ if field_name and field_name in cls._docarray_fields():
+ if is_pydantic_v2:
+ dict_args = get_args(
+ cls._docarray_fields()[field_name].annotation
+ )
+ if len(dict_args) < 2:
+ field_type = Any
+ else:
+ field_type = dict_args[1]
+ else:
+ field_type = cls._docarray_fields()[field_name].type_
+
+ else:
+ field_type = None
+
for key_name, node in value.dict.data.items():
deser_dict[key_name] = cls._get_content_from_node_proto(
node, field_type=field_type
@@ -364,14 +407,14 @@ def to_protobuf(self: T) -> 'DocProto':
return DocProto(data=data)
def _to_node_protobuf(self) -> 'NodeProto':
- from docarray.proto import NodeProto
-
"""Convert Document into a NodeProto protobuf message. This function should be
called when the Document is nest into another Document that need to be
converted into a protobuf
:return: the nested item protobuf message
"""
+ from docarray.proto import NodeProto
+
return NodeProto(doc=self.to_protobuf())
@classmethod
@@ -384,12 +427,27 @@ def _get_access_paths(cls) -> List[str]:
from docarray import BaseDoc
paths = []
- for field in cls.__fields__.keys():
- field_type = cls._get_field_type(field)
- if not is_union_type(field_type) and issubclass(field_type, BaseDoc):
+ for field in cls._docarray_fields().keys():
+ field_type = cls._get_field_annotation(field)
+ if not is_union_type(field_type) and safe_issubclass(field_type, BaseDoc):
sub_paths = field_type._get_access_paths()
for path in sub_paths:
paths.append(f'{field}__{path}')
else:
paths.append(field)
return paths
+
+ @classmethod
+ def from_json(
+ cls: Type[T],
+ data: str,
+ ) -> T:
+ """Build Document object from json data
+ :return: a Document object
+ """
+ # TODO: add tests
+
+ if is_pydantic_v2:
+ return cls.model_validate_json(data)
+ else:
+ return cls.parse_raw(data)
diff --git a/docarray/base_doc/mixins/update.py b/docarray/base_doc/mixins/update.py
index e463bcb6af1..7ce596ce1aa 100644
--- a/docarray/base_doc/mixins/update.py
+++ b/docarray/base_doc/mixins/update.py
@@ -3,6 +3,8 @@
from typing_inspect import get_origin
+from docarray.utils._internal._typing import safe_issubclass
+
T = TypeVar('T', bound='UpdateMixin')
if TYPE_CHECKING:
@@ -10,14 +12,14 @@
class UpdateMixin:
- __fields__: Dict[str, 'ModelField']
+ _docarray_fields: Dict[str, 'ModelField']
def _get_string_for_regex_filter(self):
return str(self)
@classmethod
@abstractmethod
- def _get_field_type(cls, field: str) -> Type['UpdateMixin']:
+ def _get_field_annotation(cls, field: str) -> Type['UpdateMixin']:
...
def update(self, other: T):
@@ -68,10 +70,10 @@ class MyDocument(BaseDoc):
---
:param other: The Document with which to update the contents of this
"""
- if type(self) != type(other):
+ if not _similar_schemas(self, other):
raise Exception(
f'Update operation can only be applied to '
- f'Documents of the same type. '
+ f'Documents of the same schema. '
f'Trying to update Document of type '
f'{type(self)} with Document of type '
f'{type(other)}'
@@ -104,11 +106,11 @@ def _group_fields(doc: 'UpdateMixin') -> _FieldGroups:
nested_docs_fields: List[str] = []
nested_docarray_fields: List[str] = []
- for field_name, field in doc.__fields__.items():
+ for field_name, field in doc._docarray_fields().items():
if field_name not in FORBIDDEN_FIELDS_TO_UPDATE:
- field_type = doc._get_field_type(field_name)
+ field_type = doc._get_field_annotation(field_name)
- if isinstance(field_type, type) and issubclass(field_type, DocList):
+ if safe_issubclass(field_type, DocList):
nested_docarray_fields.append(field_name)
else:
origin = get_origin(field_type)
@@ -120,7 +122,7 @@ def _group_fields(doc: 'UpdateMixin') -> _FieldGroups:
dict_fields.append(field_name)
else:
v = getattr(doc, field_name)
- if v:
+ if v is not None:
if isinstance(v, UpdateMixin):
nested_docs_fields.append(field_name)
else:
@@ -185,3 +187,7 @@ def _group_fields(doc: 'UpdateMixin') -> _FieldGroups:
elif dict1 is not None and dict2 is not None:
dict1.update(dict2)
setattr(self, field, dict1)
+
+
+def _similar_schemas(model1, model2):
+ return model1.__annotations__ == model2.__annotations__
diff --git a/docarray/computation/__init__.py b/docarray/computation/__init__.py
index 570505565c6..06ddb5ea287 100644
--- a/docarray/computation/__init__.py
+++ b/docarray/computation/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.computation.abstract_comp_backend import AbstractComputationalBackend
__all__ = ['AbstractComputationalBackend']
diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py
index 8e2be24cbfb..01e9663a4d0 100644
--- a/docarray/computation/abstract_comp_backend.py
+++ b/docarray/computation/abstract_comp_backend.py
@@ -1,13 +1,13 @@
import typing
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypeVar, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypeVar, Union, Iterable
if TYPE_CHECKING:
import numpy as np
# In practice all of the below will be the same type
TTensor = TypeVar('TTensor')
-TTensorRetrieval = TypeVar('TTensorRetrieval')
+TTensorRetrieval = TypeVar('TTensorRetrieval', bound=Iterable)
TTensorMetrics = TypeVar('TTensorMetrics')
@@ -157,6 +157,19 @@ def minmax_normalize(
"""
...
+ @classmethod
+ @abstractmethod
+ def equal(cls, tensor1: 'TTensor', tensor2: 'TTensor') -> bool:
+ """
+ Check if two tensors are equal.
+
+ :param tensor1: the first tensor
+ :param tensor2: the second tensor
+ :return: True if two tensors are equal, False otherwise.
+ If one or more of the inputs is not a tensor of this framework, return False.
+ """
+ ...
+
class Retrieval(ABC, typing.Generic[TTensorRetrieval]):
"""
Abstract class for retrieval and ranking functionalities
diff --git a/docarray/computation/jax_backend.py b/docarray/computation/jax_backend.py
new file mode 100644
index 00000000000..f571c79b701
--- /dev/null
+++ b/docarray/computation/jax_backend.py
@@ -0,0 +1,336 @@
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+
+import numpy as np
+
+from docarray.computation.abstract_comp_backend import AbstractComputationalBackend
+from docarray.computation.abstract_numpy_based_backend import AbstractNumpyBasedBackend
+from docarray.typing import JaxArray
+from docarray.utils._internal.misc import import_library
+
+if TYPE_CHECKING:
+ import jax
+ import jax.numpy as jnp
+else:
+ jax = import_library('jax', raise_error=True)
+ jnp = jax.numpy
+
+
+def _expand_if_single_axis(*matrices: jnp.ndarray) -> List[jnp.ndarray]:
+    """Promote every single-axis input to a one-row matrix.
+
+    Inputs whose shape has exactly one axis receive a new leading axis
+    (dim 0); all other inputs are passed through untouched. This lets callers
+    treat every output as a matrix rather than a vector.
+
+    :param matrices: arrays to inspect
+    :return: list with one entry per input, in the same order as given
+    """
+    return [
+        jnp.expand_dims(mat, axis=0) if len(mat.shape) == 1 else mat
+        for mat in matrices
+    ]
+
+
+def _expand_if_scalar(arr: jnp.ndarray) -> jnp.ndarray:
+    """Turn a 0-d (scalar) array into a 1-element array; return others as is."""
+    if len(arr.shape) != 0:
+        return arr
+    # avoid scalar output
+    return jnp.expand_dims(arr, axis=0)
+
+
+def norm_left(t: jnp.ndarray) -> JaxArray:
+    """Wrap a raw jax array in a `JaxArray`, passing it as `tensor`."""
+    wrapped = JaxArray(tensor=t)
+    return wrapped
+
+
+def norm_right(t: JaxArray) -> jnp.ndarray:
+    """Unwrap a `JaxArray`: return what is stored in its `tensor` attribute."""
+    raw = t.tensor
+    return raw
+
+
+class JaxCompBackend(AbstractNumpyBasedBackend):
+ """
+ Computational backend for Jax.
+ """
+
+ _module = jnp
+ _cast_output: Callable = norm_left
+ _get_tensor: Callable = norm_right
+
+    @classmethod
+    def to_device(cls, tensor: 'JaxArray', device: str) -> 'JaxArray':
+        """Move the tensor to the specified device.
+
+        :param tensor: the tensor to move
+        :param device: platform name passed to `jax.devices`
+            (e.g. 'cpu' or 'gpu')
+        :return: `tensor` itself if `cls.device(tensor)` already equals
+            `device`; otherwise a new JaxArray whose data was placed on the
+            first device that `jax.devices(device)` reports
+        """
+        if cls.device(tensor) == device:
+            return tensor
+
+        # `jax.devices` returns a *list* of devices for the platform, while
+        # `jax.device_put` needs a single target when given a single array,
+        # so pick the first device instead of forwarding the whole list.
+        target_device = jax.devices(device)[0]
+        return cls._cast_output(
+            jax.device_put(cls._get_tensor(tensor), target_device)
+        )
+
+    @classmethod
+    def device(cls, tensor: 'JaxArray') -> Optional[str]:
+        """Return the platform name of the device holding the tensor."""
+        raw = cls._get_tensor(tensor)
+        holder = raw.device()
+        return holder.platform
+
+    @classmethod
+    def to_numpy(cls, array: 'JaxArray') -> 'np.ndarray':
+        """Convert the wrapped jax array to numpy by calling its `__array__`."""
+        raw = cls._get_tensor(array)
+        return raw.__array__()
+
+    @classmethod
+    def none_value(cls) -> Any:
+        """Return the JAX stand-in for None, which is `jnp.nan`."""
+        placeholder = jnp.nan
+        return placeholder
+
+    @classmethod
+    def detach(cls, tensor: 'JaxArray') -> 'JaxArray':
+        """
+        Return the tensor detached from gradient computation.
+
+        The wrapped array is passed through `jax.lax.stop_gradient` and the
+        result is wrapped again.
+
+        :param tensor: tensor to be detached
+        :return: a detached tensor with the same data.
+        """
+        raw = cls._get_tensor(tensor)
+        detached = jax.lax.stop_gradient(raw)
+        return cls._cast_output(detached)
+
+    @classmethod
+    def dtype(cls, tensor: 'JaxArray') -> jnp.dtype:
+        """Get the data type of the tensor, reported as the dtype's `name`."""
+        return cls._get_tensor(tensor).dtype.name
+
+    @classmethod
+    def minmax_normalize(
+        cls,
+        tensor: 'JaxArray',
+        t_range: Tuple = (0, 1),
+        x_range: Optional[Tuple] = None,
+        eps: float = 1e-7,
+    ) -> 'JaxArray':
+        """
+        Rescale the values of `tensor` into the target range `t_range`.
+
+        The minimum and maximum are taken along the last axis, so a 2D input
+        is normalized row by row. The computation runs in float32 and the
+        result is converted back to the dtype of the input.
+
+        !!! note
+
+            - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1;
+            - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value
+              of the data to 0.
+
+        :param tensor: the data to be normalized
+        :param t_range: a tuple represents the target range.
+        :param x_range: a tuple (min, max) to use instead of the min and max
+            found in the data.
+        :param eps: a small jitter to avoid dividing by zero
+        :return: normalized data in `t_range`
+        """
+        raw = cls._get_tensor(tensor)
+        low, high = t_range
+
+        data = jnp.asarray(raw, jnp.float32)
+
+        if x_range:
+            data_min, data_max = x_range[0], x_range[1]
+        else:
+            data_min = jnp.min(data, axis=-1, keepdims=True)
+            data_max = jnp.max(data, axis=-1, keepdims=True)
+
+        scaled = (high - low) * (data - data_min) / (data_max - data_min + eps) + low
+
+        # jnp.clip needs (lower, upper) even when the target range is reversed
+        if low < high:
+            bounds = (low, high)
+        else:
+            bounds = (high, low)
+        clipped = jnp.clip(scaled, *bounds)
+
+        return cls._cast_output(jnp.asarray(clipped, raw.dtype))
+
+    @classmethod
+    def equal(cls, tensor1: 'JaxArray', tensor2: 'JaxArray') -> bool:
+        """
+        Check if two tensors are equal.
+
+        :param tensor1: the first tensor
+        :param tensor2: the second tensor
+        :return: truthy if both tensors have the same shape and all elements
+            compare equal, falsy otherwise.
+            If one or more of the inputs does not carry a jax array in its
+            `tensor` attribute, return False.
+        """
+        t1, t2 = getattr(tensor1, 'tensor', None), getattr(tensor2, 'tensor', None)
+        if isinstance(t1, jnp.ndarray) and isinstance(t2, jnp.ndarray):
+            # The shape check comes first because `jnp.equal` broadcasts.
+            # The element-wise comparison must be between t1 and t2; comparing
+            # t1 with itself would report any same-shaped tensors as equal.
+            return t1.shape == t2.shape and jnp.all(jnp.equal(t1, t2))  # type: ignore
+        return False
+
+ class Retrieval(AbstractComputationalBackend.Retrieval[JaxArray]):
+ """
+ Abstract class for retrieval and ranking functionalities
+ """
+
+        @staticmethod
+        def top_k(
+            values: 'JaxArray',
+            k: int,
+            descending: bool = False,
+            device: Optional[str] = None,
+        ) -> Tuple['JaxArray', 'JaxArray']:
+            """
+            Rank `values` row by row and keep the first `k` entries per row.
+
+            By default the k smallest values are returned, smallest first;
+            with `descending=True` the k largest values are returned instead.
+
+            :param values: Jax tensor of values to rank.
+                Should be of shape (n_queries, n_values_per_query).
+                Inputs of shape (n_values_per_query,) will be expanded
+                to (1, n_values_per_query).
+            :param k: number of values to retrieve
+            :param descending: retrieve largest values instead of smallest values
+            :param device: if given, `values` is first moved there through
+                `JaxCompBackend.to_device`
+            :return: Tuple containing the retrieved values, and their indices.
+                Both are of shape (n_queries, k)
+            """
+            backend = JaxCompBackend
+            if device is not None:
+                values = backend.to_device(values, device)
+
+            scores: jnp.ndarray = backend._get_tensor(values)
+            if len(scores.shape) == 1:
+                scores = jnp.expand_dims(scores, axis=0)
+
+            # ranking the negated scores ascending == ranking scores descending
+            if descending:
+                scores = -scores
+
+            if k < scores.shape[1]:
+                # select k candidates per row first, then sort only those
+                candidate_idx = scores.argpartition(kth=k, axis=1)[:, :k]
+                candidates = jnp.take_along_axis(scores, candidate_idx, axis=1)
+                order = candidates.argsort(axis=1)
+                idx = jnp.take_along_axis(candidate_idx, order, axis=1)
+                scores = jnp.take_along_axis(candidates, order, axis=1)
+            else:
+                # k spans the whole row: a plain full sort is sufficient
+                idx = scores.argsort(axis=1)[:, :k]
+                scores = jnp.take_along_axis(scores, idx, axis=1)
+
+            if descending:
+                # undo the sign flip applied before ranking
+                scores = -scores
+
+            return backend._cast_output(scores), backend._cast_output(idx)
+
+ class Metrics(AbstractComputationalBackend.Metrics[JaxArray]):
+ """
+ Abstract base class for metrics (distances and similarities).
+ """
+
+        @staticmethod
+        def cosine_sim(
+            x_mat: 'JaxArray',
+            y_mat: 'JaxArray',
+            eps: float = 1e-7,
+            device: Optional[str] = None,
+        ) -> 'JaxArray':
+            """Pairwise cosine similarities between all vectors in x_mat and y_mat.
+
+            Single-axis inputs are treated as one-row matrices. The result is
+            clipped to [-1, 1] and squeezed; a scalar result is returned as a
+            1-element array.
+
+            :param x_mat: tensor of shape (n_vectors, n_dim), where n_vectors is the
+                number of vectors and n_dim is the number of dimensions of each example.
+            :param y_mat: tensor of shape (n_vectors, n_dim), where n_vectors is the
+                number of vectors and n_dim is the number of dimensions of each example.
+            :param eps: a small jitter added to both numerator and denominator
+                to avoid dividing by zero
+            :param device: accepted for interface compatibility; this
+                implementation does not use it
+            :return: JaxArray containing all pairwise cosine similarities.
+                The index [i_x, i_y] contains the cosine similarity between
+                x_mat[i_x] and y_mat[i_y].
+            """
+            backend = JaxCompBackend
+            x_rows, y_rows = _expand_if_single_axis(
+                backend._get_tensor(x_mat), backend._get_tensor(y_mat)
+            )
+
+            dot_products = jnp.dot(x_rows, y_rows.T) + eps
+            norm_products = (
+                jnp.outer(
+                    jnp.linalg.norm(x_rows, axis=1),
+                    jnp.linalg.norm(y_rows, axis=1),
+                )
+                + eps
+            )
+
+            sims = jnp.clip(dot_products / norm_products, -1, 1).squeeze()
+            return backend._cast_output(_expand_if_scalar(sims))
+
+        @classmethod
+        def euclidean_dist(
+            cls, x_mat: JaxArray, y_mat: JaxArray, device: Optional[str] = None
+        ) -> JaxArray:
+            """Pairwise Euclidian distances between all vectors in x_mat and y_mat.
+
+            Computed as the element-wise square root of `sqeuclidean_dist`.
+
+            :param x_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is
+                the number of vectors and n_dim is the number of dimensions of each
+                example.
+            :param y_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is
+                the number of vectors and n_dim is the number of dimensions of each
+                example.
+            :param device: Not supported for this backend; the value is ignored
+            :return: JaxArray containing all pairwise euclidian distances.
+                The index [i_x, i_y] contains the euclidian distance between
+                x_mat[i_x] and y_mat[i_y].
+            """
+            backend = JaxCompBackend
+
+            # make sure both operands are matrices before delegating
+            x_rows, y_rows = _expand_if_single_axis(
+                backend._get_tensor(x_mat), backend._get_tensor(y_mat)
+            )
+
+            squared: JaxArray = cls.sqeuclidean_dist(
+                backend._cast_output(x_rows), backend._cast_output(y_rows)
+            )
+            roots = jnp.sqrt(backend._get_tensor(squared)).squeeze()
+
+            return backend._cast_output(_expand_if_scalar(roots))
+
+        @staticmethod
+        def sqeuclidean_dist(
+            x_mat: JaxArray,
+            y_mat: JaxArray,
+            device: Optional[str] = None,
+        ) -> JaxArray:
+            """Pairwise Squared Euclidian distances between all vectors in
+            x_mat and y_mat.
+
+            :param x_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is
+                the number of vectors and n_dim is the number of dimensions of each
+                example.
+            :param y_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is
+                the number of vectors and n_dim is the number of dimensions of each
+                example.
+            :param device: Not supported for this backend; the value is ignored
+            :return: JaxArray containing all pairwise Squared Euclidian distances.
+                The index [i_x, i_y] contains the Squared Euclidian distance
+                between x_mat[i_x] and y_mat[i_y].
+            """
+            comp_be = JaxCompBackend
+            x_mat_jax: jnp.ndarray = comp_be._get_tensor(x_mat)
+            y_mat_jax: jnp.ndarray = comp_be._get_tensor(y_mat)
+            eps: float = 1e-7  # avoid problems with numerical inaccuracies
+
+            x_mat_jax, y_mat_jax = _expand_if_single_axis(x_mat_jax, y_mat_jax)
+
+            # expanded form: sum(y^2) + sum(x^2) - 2 * x.y
+            dists = (
+                jnp.sum(y_mat_jax**2, axis=1)
+                + jnp.sum(x_mat_jax**2, axis=1)[:, jnp.newaxis]
+                - 2 * jnp.dot(x_mat_jax, y_mat_jax.T)
+            ).squeeze()
+
+            # remove numerical artifacts: tiny negative values become 0.
+            # Use the jax op (not numpy's) so the mask stays a jax array.
+            dists = jnp.where(jnp.logical_and(dists < 0, dists > -eps), 0, dists)
+            dists = _expand_if_scalar(dists)
+            return comp_be._cast_output(dists)
diff --git a/docarray/computation/numpy_backend.py b/docarray/computation/numpy_backend.py
index 30d50cc0174..913f42d429e 100644
--- a/docarray/computation/numpy_backend.py
+++ b/docarray/computation/numpy_backend.py
@@ -111,6 +111,21 @@ def minmax_normalize(
return np.clip(r, *((a, b) if a < b else (b, a)))
+    @classmethod
+    def equal(cls, tensor1: 'np.ndarray', tensor2: 'np.ndarray') -> bool:
+        """
+        Tell whether two numpy arrays are equal according to `np.array_equal`.
+
+        :param tensor1: the first array
+        :param tensor2: the second array
+        :return: True if two arrays are equal, False otherwise.
+            If one or more of the inputs is not an ndarray, return False.
+        """
+        if not isinstance(tensor1, np.ndarray):
+            return False
+        if not isinstance(tensor2, np.ndarray):
+            return False
+        return np.array_equal(tensor1, tensor2)
+
class Retrieval(AbstractComputationalBackend.Retrieval[np.ndarray]):
"""
Abstract class for retrieval and ranking functionalities
diff --git a/docarray/computation/tensorflow_backend.py b/docarray/computation/tensorflow_backend.py
index fc963cdb48b..27609b737e1 100644
--- a/docarray/computation/tensorflow_backend.py
+++ b/docarray/computation/tensorflow_backend.py
@@ -121,6 +121,22 @@ def minmax_normalize(
normalized = tnp.clip(i, *((a, b) if a < b else (b, a)))
return cls._cast_output(tf.cast(normalized, tensor.tensor.dtype))
+    @classmethod
+    def equal(cls, tensor1: 'TensorFlowTensor', tensor2: 'TensorFlowTensor') -> bool:
+        """
+        Check if two tensors are equal.
+
+        :param tensor1: the first tensor
+        :param tensor2: the second tensor
+        :return: truthy if both tensors have the same shape and all elements
+            compare equal, falsy otherwise.
+            If one or more of the inputs is not a TensorFlowTensor, return False.
+        """
+        t1, t2 = getattr(tensor1, 'tensor', None), getattr(tensor2, 'tensor', None)
+        if tf.is_tensor(t1) and tf.is_tensor(t2):
+            # mypy doesn't know that tf.is_tensor implies that t1, t2 are not None.
+            # The element-wise comparison must be between t1 and t2; comparing
+            # t1 with itself would report any same-shaped tensors as equal.
+            return t1.shape == t2.shape and tf.math.reduce_all(tf.equal(t1, t2))  # type: ignore
+        return False
+
class Retrieval(AbstractComputationalBackend.Retrieval[TensorFlowTensor]):
"""
Abstract class for retrieval and ranking functionalities
diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py
index be6d4ea03fd..97f0abbb3b5 100644
--- a/docarray/computation/torch_backend.py
+++ b/docarray/computation/torch_backend.py
@@ -113,6 +113,21 @@ def reshape(cls, tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tenso
"""
return tensor.reshape(shape)
+    @classmethod
+    def equal(cls, tensor1: 'torch.Tensor', tensor2: 'torch.Tensor') -> bool:
+        """
+        Tell whether two torch tensors are equal according to `torch.equal`.
+
+        :param tensor1: the first tensor
+        :param tensor2: the second tensor
+        :return: True if two tensors are equal, False otherwise.
+            If one or more of the inputs is not a torch.Tensor, return False.
+        """
+        if not isinstance(tensor1, torch.Tensor):
+            return False
+        if not isinstance(tensor2, torch.Tensor):
+            return False
+        return torch.equal(tensor1, tensor2)
+
@classmethod
def detach(cls, tensor: 'torch.Tensor') -> 'torch.Tensor':
"""
diff --git a/docarray/data/__init__.py b/docarray/data/__init__.py
index 69da35e8c57..1ffabbcbd11 100644
--- a/docarray/data/__init__.py
+++ b/docarray/data/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.data.torch_dataset import MultiModalDataset
__all__ = ['MultiModalDataset']
diff --git a/docarray/data/torch_dataset.py b/docarray/data/torch_dataset.py
index f174326c2a1..8e2f3afa490 100644
--- a/docarray/data/torch_dataset.py
+++ b/docarray/data/torch_dataset.py
@@ -4,7 +4,7 @@
from docarray import BaseDoc, DocList, DocVec
from docarray.typing import TorchTensor
-from docarray.utils._internal._typing import change_cls_name
+from docarray.utils._internal._typing import change_cls_name, safe_issubclass
T_doc = TypeVar('T_doc', bound=BaseDoc)
@@ -141,7 +141,7 @@ def collate_fn(cls, batch: List[T_doc]):
@classmethod
def __class_getitem__(cls, item: Type[BaseDoc]) -> Type['MultiModalDataset']:
- if not issubclass(item, BaseDoc):
+ if not safe_issubclass(item, BaseDoc):
raise ValueError(
f'{cls.__name__}[item] item should be a Document not a {item} '
)
diff --git a/docarray/display/__init__.py b/docarray/display/__init__.py
index e69de29bb2d..74f8f7582cd 100644
--- a/docarray/display/__init__.py
+++ b/docarray/display/__init__.py
@@ -0,0 +1,15 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/docarray/display/document_summary.py b/docarray/display/document_summary.py
index 349829b6e0e..265236a8d35 100644
--- a/docarray/display/document_summary.py
+++ b/docarray/display/document_summary.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Type, Union
+from typing import Any, List, Optional, Type, Union, get_args
from rich.highlighter import RegexHighlighter
from rich.theme import Theme
@@ -10,6 +10,7 @@
from docarray.display.tensor_display import TensorDisplay
from docarray.typing import ID
from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from docarray.utils._internal._typing import safe_issubclass
if TYPE_CHECKING:
from rich.console import Console, ConsoleOptions, RenderResult
@@ -49,7 +50,11 @@ def schema_summary(cls: Type['BaseDoc']) -> None:
console.print(panel)
@staticmethod
- def _get_schema(cls: Type['BaseDoc'], doc_name: Optional[str] = None) -> Tree:
+ def _get_schema(
+ cls: Type['BaseDoc'],
+ doc_name: Optional[str] = None,
+ recursion_list: Optional[List] = None,
+ ) -> Tree:
"""Get Documents schema as a rich.tree.Tree object."""
import re
@@ -57,15 +62,20 @@ def _get_schema(cls: Type['BaseDoc'], doc_name: Optional[str] = None) -> Tree:
from docarray import BaseDoc, DocList
+ if recursion_list is None:
+ recursion_list = []
+
+ if cls in recursion_list:
+ return Tree(cls.__name__)
+ else:
+ recursion_list.append(cls)
+
root = cls.__name__ if doc_name is None else f'{doc_name}: {cls.__name__}'
tree = Tree(root, highlight=True)
- for field_name, value in cls.__fields__.items():
+ for field_name, value in cls._docarray_fields().items():
if field_name != 'id':
- field_type = value.type_
- if not value.required:
- field_type = Optional[field_type]
-
+ field_type = value.annotation
field_cls = str(field_type).replace('[', '\[')
field_cls = re.sub('|[a-zA-Z_]*[.]', '', field_cls)
@@ -73,21 +83,37 @@ def _get_schema(cls: Type['BaseDoc'], doc_name: Optional[str] = None) -> Tree:
if is_union_type(field_type) or is_optional_type(field_type):
sub_tree = Tree(node_name, highlight=True)
- for arg in field_type.__args__:
- if issubclass(arg, BaseDoc):
- sub_tree.add(DocumentSummary._get_schema(cls=arg))
- elif issubclass(arg, DocList):
- sub_tree.add(DocumentSummary._get_schema(cls=arg.doc_type))
+ for arg in get_args(field_type):
+ if safe_issubclass(arg, BaseDoc):
+ sub_tree.add(
+ DocumentSummary._get_schema(
+ cls=arg, recursion_list=recursion_list
+ )
+ )
+ elif safe_issubclass(arg, DocList):
+ sub_tree.add(
+ DocumentSummary._get_schema(
+ cls=arg.doc_type, recursion_list=recursion_list
+ )
+ )
tree.add(sub_tree)
- elif issubclass(field_type, BaseDoc):
+ elif safe_issubclass(field_type, BaseDoc):
tree.add(
- DocumentSummary._get_schema(cls=field_type, doc_name=field_name)
+ DocumentSummary._get_schema(
+ cls=field_type,
+ doc_name=field_name,
+ recursion_list=recursion_list,
+ )
)
- elif issubclass(field_type, DocList):
+ elif safe_issubclass(field_type, DocList):
sub_tree = Tree(node_name, highlight=True)
- sub_tree.add(DocumentSummary._get_schema(cls=field_type.doc_type))
+ sub_tree.add(
+ DocumentSummary._get_schema(
+ cls=field_type.doc_type, recursion_list=recursion_list
+ )
+ )
tree.add(sub_tree)
else:
diff --git a/docarray/display/tensor_display.py b/docarray/display/tensor_display.py
index 1bf884b518f..c0f41aea6a2 100644
--- a/docarray/display/tensor_display.py
+++ b/docarray/display/tensor_display.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing_extensions import TYPE_CHECKING
if TYPE_CHECKING:
diff --git a/docarray/documents/__init__.py b/docarray/documents/__init__.py
index aba89edd172..5de8a33597f 100644
--- a/docarray/documents/__init__.py
+++ b/docarray/documents/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.documents.audio import AudioDoc
from docarray.documents.image import ImageDoc
from docarray.documents.mesh import Mesh3D, VerticesAndFaces
diff --git a/docarray/documents/audio.py b/docarray/documents/audio.py
index fd746a2dfe5..5571a6e42f0 100644
--- a/docarray/documents/audio.py
+++ b/docarray/documents/audio.py
@@ -1,6 +1,7 @@
-from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union
import numpy as np
+from pydantic import Field
from docarray.base_doc import BaseDoc
from docarray.typing import AnyEmbedding, AudioUrl
@@ -8,6 +9,10 @@
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.audio.audio_tensor import AudioTensor
from docarray.utils._internal.misc import import_library
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+if is_pydantic_v2:
+ from pydantic import model_validator
if TYPE_CHECKING:
import tensorflow as tf # type: ignore
@@ -55,7 +60,7 @@ class AudioDoc(BaseDoc):
# extend it
class MyAudio(AudioDoc):
- name: Optional[TextDoc]
+ name: Optional[TextDoc] = None
audio = MyAudio(
@@ -94,24 +99,55 @@ class MultiModalDoc(BaseDoc):
```
"""
- url: Optional[AudioUrl]
- tensor: Optional[AudioTensor]
- embedding: Optional[AnyEmbedding]
- bytes_: Optional[AudioBytes]
- frame_rate: Optional[int]
+    # All fields are optional and default to None; the `description` and
+    # `example` values are metadata handed to pydantic's `Field`.
+    url: Optional[AudioUrl] = Field(
+        description='The url to a (potentially remote) audio file that can be loaded',
+        example='https://github.com/docarray/docarray/blob/main/tests/toydata/hello.mp3?raw=true',
+        default=None,
+    )
+    tensor: Optional[AudioTensor] = Field(
+        description='Tensor object of the audio which can be specified to one of `AudioNdArray`, `AudioTorchTensor`, `AudioTensorFlowTensor`',
+        default=None,
+    )
+    embedding: Optional[AnyEmbedding] = Field(
+        description='Store an embedding: a vector representation of the audio.',
+        example=[0, 1, 0],
+        default=None,
+    )
+    bytes_: Optional[AudioBytes] = Field(
+        description='Bytes representation of the audio',
+        default=None,
+    )
+    frame_rate: Optional[int] = Field(
+        description='An integer representing the frame rate of the audio.',
+        example=24,
+        default=None,
+    )
@classmethod
- def validate(
- cls: Type[T],
- value: Union[str, AbstractTensor, Any],
- ) -> T:
+    def _validate(cls, value) -> Dict[str, Any]:
+        """Coerce shorthand inputs into a field mapping.
+
+        A `str` becomes `{'url': value}`; an `AbstractTensor`, `np.ndarray`,
+        `torch.Tensor` or `tf.Tensor` becomes `{'tensor': value}`. Any other
+        value is returned unchanged.
+        """
         if isinstance(value, str):
-            value = cls(url=value)
+            value = dict(url=value)
         elif isinstance(value, (AbstractTensor, np.ndarray)) or (
             torch is not None
             and isinstance(value, torch.Tensor)
             or (tf is not None and isinstance(value, tf.Tensor))
         ):
-            value = cls(tensor=value)
+            value = dict(tensor=value)
+
+        return value
+
+    # Expose `_validate` through the hook matching the installed pydantic
+    # version: a 'before' model validator on v2, `validate` otherwise.
+    if is_pydantic_v2:
+
+        @model_validator(mode='before')
+        @classmethod
+        def validate_model_before(cls, value):
+            """Run `_validate` on the raw input."""
+            return cls._validate(value)
+
+    else:
-        return super().validate(value)
+        @classmethod
+        def validate(
+            cls: Type[T],
+            value: Union[str, AbstractTensor, Any],
+        ) -> T:
+            """Run `_validate` on the input, then defer to the parent `validate`."""
+            return super().validate(cls._validate(value))
diff --git a/docarray/documents/helper.py b/docarray/documents/helper.py
index 039ada7ae71..6f34f0386bd 100644
--- a/docarray/documents/helper.py
+++ b/docarray/documents/helper.py
@@ -1,10 +1,24 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, TypeVar
-from pydantic import create_model, create_model_from_typeddict
+from pydantic import create_model
+
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+if not is_pydantic_v2:
+ from pydantic import create_model_from_typeddict
+else:
+
+    def create_model_from_typeddict(*args, **kwargs):
+        """Stand-in defined when pydantic v2 is in use; always raises NotImplementedError."""
+        message = "This function is not compatible with pydantic v2 anymore"
+        raise NotImplementedError(message)
+
+
from pydantic.config import BaseConfig
from typing_extensions import TypedDict
from docarray import BaseDoc
+from docarray.utils._internal._typing import safe_issubclass
if TYPE_CHECKING:
from pydantic.typing import AnyClassMethod
@@ -26,6 +40,12 @@ def create_doc(
"""
Dynamically create a subclass of BaseDoc. This is a wrapper around pydantic's create_model.
+ !!! note
+ To pickle a dynamically created BaseDoc subclass:
+
+ - the class must be defined globally
+ - it must provide `__module__`
+
```python
from docarray.documents import Audio
from docarray.documents.helper import create_doc
@@ -38,8 +58,8 @@ def create_doc(
tensor=(AudioNdArray, ...),
)
- assert issubclass(MyAudio, BaseDoc)
- assert issubclass(MyAudio, Audio)
+ assert safe_issubclass(MyAudio, BaseDoc)
+ assert safe_issubclass(MyAudio, Audio)
```
:param __model_name: name of the created model
@@ -54,7 +74,7 @@ def create_doc(
:return: the new Document class
"""
- if not issubclass(__base__, BaseDoc):
+ if not safe_issubclass(__base__, BaseDoc):
raise ValueError(f'{type(__base__)} is not a BaseDoc or its subclass')
doc = create_model(
@@ -96,8 +116,8 @@ class MyAudio(TypedDict):
Doc = create_doc_from_typeddict(MyAudio, __base__=Audio)
- assert issubclass(Doc, BaseDoc)
- assert issubclass(Doc, Audio)
+ assert safe_issubclass(Doc, BaseDoc)
+ assert safe_issubclass(Doc, Audio)
```
---
@@ -108,7 +128,7 @@ class MyAudio(TypedDict):
"""
if '__base__' in kwargs:
- if not issubclass(kwargs['__base__'], BaseDoc):
+ if not safe_issubclass(kwargs['__base__'], BaseDoc):
raise ValueError(f'{kwargs["__base__"]} is not a BaseDoc or its subclass')
else:
kwargs['__base__'] = BaseDoc
@@ -136,7 +156,7 @@ def create_doc_from_dict(model_name: str, data_dict: Dict[str, Any]) -> Type['T_
MyDoc = create_doc_from_dict(model_name='MyDoc', data_dict=data_dict)
- assert issubclass(MyDoc, BaseDoc)
+ assert safe_issubclass(MyDoc, BaseDoc)
```
---
diff --git a/docarray/documents/image.py b/docarray/documents/image.py
index e0072b622ab..1b98a235f20 100644
--- a/docarray/documents/image.py
+++ b/docarray/documents/image.py
@@ -1,12 +1,17 @@
-from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union
import numpy as np
+from pydantic import Field
from docarray.base_doc import BaseDoc
from docarray.typing import AnyEmbedding, ImageBytes, ImageUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.image.image_tensor import ImageTensor
from docarray.utils._internal.misc import import_library
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+if is_pydantic_v2:
+ from pydantic import model_validator
if TYPE_CHECKING:
import tensorflow as tf # type: ignore
@@ -53,7 +58,7 @@ class ImageDoc(BaseDoc):
# extend it
class MyImage(ImageDoc):
- second_embedding: Optional[AnyEmbedding]
+ second_embedding: Optional[AnyEmbedding] = None
image = MyImage(
@@ -92,25 +97,52 @@ class MultiModalDoc(BaseDoc):
```
"""
- url: Optional[ImageUrl]
- tensor: Optional[ImageTensor]
- embedding: Optional[AnyEmbedding]
- bytes_: Optional[ImageBytes]
+    # All fields are optional and default to None; the `description` and
+    # `example` values are metadata handed to pydantic's `Field`.
+    url: Optional[ImageUrl] = Field(
+        description='URL to a (potentially remote) image file that needs to be loaded',
+        example='https://github.com/docarray/docarray/blob/main/tests/toydata/image-data/apple.png?raw=true',
+        default=None,
+    )
+    tensor: Optional[ImageTensor] = Field(
+        description='Tensor object of the image which can be specified to one of `ImageNdArray`, `ImageTorchTensor`, `ImageTensorFlowTensor`.',
+        default=None,
+    )
+    embedding: Optional[AnyEmbedding] = Field(
+        description='Store an embedding: a vector representation of the image.',
+        example=[1, 0, 1],
+        default=None,
+    )
+    bytes_: Optional[ImageBytes] = Field(
+        description='Bytes object of the image which is an instance of `ImageBytes`.',
+        default=None,
+    )
@classmethod
- def validate(
- cls: Type[T],
- value: Union[str, AbstractTensor, Any],
- ) -> T:
+    def _validate(cls, value) -> Dict[str, Any]:
+        """Coerce shorthand inputs into a field mapping.
+
+        A `str` becomes `{'url': value}`; an `AbstractTensor`, `np.ndarray`,
+        `torch.Tensor` or `tf.Tensor` becomes `{'tensor': value}`; `bytes`
+        become `{'bytes_': value}`. Any other value is returned unchanged.
+        """
         if isinstance(value, str):
-            value = cls(url=value)
+            value = dict(url=value)
         elif (
             isinstance(value, (AbstractTensor, np.ndarray))
             or (torch is not None and isinstance(value, torch.Tensor))
             or (tf is not None and isinstance(value, tf.Tensor))
         ):
-            value = cls(tensor=value)
+            value = dict(tensor=value)
         elif isinstance(value, bytes):
-            value = cls(byte=value)
+            # the field is declared as `bytes_`; a `byte` key matches no field
+            value = dict(bytes_=value)
+
+        return value
+
+    # Expose `_validate` through the hook matching the installed pydantic
+    # version: a 'before' model validator on v2, `validate` otherwise.
+    if is_pydantic_v2:
+
+        @model_validator(mode='before')
+        @classmethod
+        def validate_model_before(cls, value):
+            """Run `_validate` on the raw input."""
+            return cls._validate(value)
+
+    else:
-        return super().validate(value)
+        @classmethod
+        def validate(
+            cls: Type[T],
+            value: Union[str, AbstractTensor, Any],
+        ) -> T:
+            """Run `_validate` on the input, then defer to the parent `validate`."""
+            return super().validate(cls._validate(value))
diff --git a/docarray/documents/legacy/__init__.py b/docarray/documents/legacy/__init__.py
index 61cb9c485c1..0e092cf6c57 100644
--- a/docarray/documents/legacy/__init__.py
+++ b/docarray/documents/legacy/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.documents.legacy.legacy_document import LegacyDocument
__all__ = ['LegacyDocument']
diff --git a/docarray/documents/legacy/legacy_document.py b/docarray/documents/legacy/legacy_document.py
index eea42f1d93e..dc77f10d0b4 100644
--- a/docarray/documents/legacy/legacy_document.py
+++ b/docarray/documents/legacy/legacy_document.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from __future__ import annotations
from typing import Any, Dict, Optional
@@ -8,10 +23,10 @@
class LegacyDocument(BaseDoc):
"""
- This Document is the LegacyDocument. It follows the same schema as in DocArray v1.
+ This Document is the LegacyDocument. It follows the same schema as in DocArray <=0.21.
It can be useful to start migrating a codebase from v1 to v2.
- Nevertheless, the API is not totally compatible with DocArray v1 `Document`.
+ Nevertheless, the API is not totally compatible with DocArray <=0.21 `Document`.
Indeed, none of the method associated with `Document` are present. Only the schema
of the data is similar.
@@ -34,12 +49,12 @@ class LegacyDocument(BaseDoc):
"""
- tensor: Optional[AnyTensor]
- chunks: Optional[DocList[LegacyDocument]]
- matches: Optional[DocList[LegacyDocument]]
- blob: Optional[bytes]
- text: Optional[str]
- url: Optional[str]
- embedding: Optional[AnyEmbedding]
+ tensor: Optional[AnyTensor] = None
+ chunks: Optional[DocList[LegacyDocument]] = None
+ matches: Optional[DocList[LegacyDocument]] = None
+ blob: Optional[bytes] = None
+ text: Optional[str] = None
+ url: Optional[str] = None
+ embedding: Optional[AnyEmbedding] = None
tags: Dict[str, Any] = dict()
- scores: Optional[Dict[str, Any]]
+ scores: Optional[Dict[str, Any]] = None
diff --git a/docarray/documents/mesh/__init__.py b/docarray/documents/mesh/__init__.py
index 15ba1fdab10..a07ac3fc6f8 100644
--- a/docarray/documents/mesh/__init__.py
+++ b/docarray/documents/mesh/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.documents.mesh.mesh_3d import Mesh3D
from docarray.documents.mesh.vertices_and_faces import VerticesAndFaces
diff --git a/docarray/documents/mesh/mesh_3d.py b/docarray/documents/mesh/mesh_3d.py
index 82d93f73456..e9ff863b2ca 100644
--- a/docarray/documents/mesh/mesh_3d.py
+++ b/docarray/documents/mesh/mesh_3d.py
@@ -1,9 +1,15 @@
from typing import Any, Optional, Type, TypeVar, Union
+from pydantic import Field
+
from docarray.base_doc import BaseDoc
from docarray.documents.mesh.vertices_and_faces import VerticesAndFaces
from docarray.typing.tensor.embedding import AnyEmbedding
from docarray.typing.url.url_3d.mesh_url import Mesh3DUrl
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+if is_pydantic_v2:
+ from pydantic import model_validator
T = TypeVar('T', bound='Mesh3D')
@@ -54,7 +60,7 @@ class Mesh3D(BaseDoc):
# extend it
class MyMesh3D(Mesh3D):
- name: Optional[str]
+ name: Optional[str] = None
mesh = MyMesh3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj')
@@ -103,16 +109,41 @@ class MultiModalDoc(BaseDoc):
"""
- url: Optional[Mesh3DUrl]
- tensors: Optional[VerticesAndFaces]
- embedding: Optional[AnyEmbedding]
- bytes_: Optional[bytes]
-
- @classmethod
- def validate(
- cls: Type[T],
- value: Union[str, Any],
- ) -> T:
- if isinstance(value, str):
- value = cls(url=value)
- return super().validate(value)
+ url: Optional[Mesh3DUrl] = Field(
+ description='URL to a file containing 3D mesh information. Can be remote (web) URL, or a local file path.',
+ example='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj',
+ default=None,
+ )
+ tensors: Optional[VerticesAndFaces] = Field(
+ description='A tensor object of 3D mesh of type `VerticesAndFaces`.',
+ example=[[0, 1, 1], [1, 0, 1], [1, 1, 0]],
+ default=None,
+ )
+ embedding: Optional[AnyEmbedding] = Field(
+ description='Store an embedding: a vector representation of the 3D mesh.',
+ example=[1, 0, 1], default=None,  # sample value is schema metadata only; the field is unset by default
+ )
+ bytes_: Optional[bytes] = Field(
+ description='Bytes representation of 3D mesh.',
+ default=None,
+ )
+
+ if is_pydantic_v2:
+
+ @model_validator(mode='before')
+ @classmethod
+ def validate_model_before(cls, value):
+ if isinstance(value, str):
+ return {'url': value}
+ return value
+
+ else:
+
+ @classmethod
+ def validate(
+ cls: Type[T],
+ value: Union[str, Any],
+ ) -> T:
+ if isinstance(value, str):
+ value = cls(url=value)
+ return super().validate(value)
diff --git a/docarray/documents/mesh/vertices_and_faces.py b/docarray/documents/mesh/vertices_and_faces.py
index 758f0acc6b0..05cfea86e34 100644
--- a/docarray/documents/mesh/vertices_and_faces.py
+++ b/docarray/documents/mesh/vertices_and_faces.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TYPE_CHECKING, Any, Type, TypeVar, Union
from docarray.base_doc import BaseDoc
@@ -23,7 +38,7 @@ class VerticesAndFaces(BaseDoc):
faces: AnyTensor
@classmethod
- def validate(
+ def _docarray_validate(
cls: Type[T],
value: Union[str, Any],
) -> T:
diff --git a/docarray/documents/point_cloud/__init__.py b/docarray/documents/point_cloud/__init__.py
index 27a9defeb87..67013333e17 100644
--- a/docarray/documents/point_cloud/__init__.py
+++ b/docarray/documents/point_cloud/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.documents.point_cloud.point_cloud_3d import PointCloud3D
from docarray.documents.point_cloud.points_and_colors import PointsAndColors
diff --git a/docarray/documents/point_cloud/point_cloud_3d.py b/docarray/documents/point_cloud/point_cloud_3d.py
index 8a1963be69f..cd3ad2f268b 100644
--- a/docarray/documents/point_cloud/point_cloud_3d.py
+++ b/docarray/documents/point_cloud/point_cloud_3d.py
@@ -1,12 +1,17 @@
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
import numpy as np
+from pydantic import Field
from docarray.base_doc import BaseDoc
from docarray.documents.point_cloud.points_and_colors import PointsAndColors
from docarray.typing import AnyEmbedding, PointCloud3DUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.misc import import_library
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+if is_pydantic_v2:
+ from pydantic import model_validator
if TYPE_CHECKING:
import tensorflow as tf # type: ignore
@@ -57,7 +62,7 @@ class PointCloud3D(BaseDoc):
# extend it
class MyPointCloud3D(PointCloud3D):
- second_embedding: Optional[AnyEmbedding]
+ second_embedding: Optional[AnyEmbedding] = None
pc = MyPointCloud3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj')
@@ -107,23 +112,51 @@ class MultiModalDoc(BaseDoc):
```
"""
- url: Optional[PointCloud3DUrl]
- tensors: Optional[PointsAndColors]
- embedding: Optional[AnyEmbedding]
- bytes_: Optional[bytes]
+ url: Optional[PointCloud3DUrl] = Field(
+ description='URL to a file containing point cloud information. Can be remote (web) URL, or a local file path.',
+ example='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj',
+ default=None,
+ )
+ tensors: Optional[PointsAndColors] = Field(
+ description='A tensor object of 3D point cloud of type `PointsAndColors`.',
+ example=[[0, 0, 1], [1, 0, 1], [0, 1, 1]],
+ default=None,
+ )
+ embedding: Optional[AnyEmbedding] = Field(
+ description='Store an embedding: a vector representation of 3D point cloud.',
+ example=[1, 1, 1],
+ default=None,
+ )
+ bytes_: Optional[bytes] = Field(
+ description='Bytes representation of 3D point cloud.',
+ default=None,
+ )
@classmethod
- def validate(
- cls: Type[T],
- value: Union[str, AbstractTensor, Any],
- ) -> T:
+ def _validate(cls, value: Union[str, AbstractTensor, Any]) -> Any:
if isinstance(value, str):
- value = cls(url=value)
+ value = {'url': value}
elif isinstance(value, (AbstractTensor, np.ndarray)) or (
torch is not None
and isinstance(value, torch.Tensor)
or (tf is not None and isinstance(value, tf.Tensor))
):
- value = cls(tensors=PointsAndColors(points=value))
+ value = {'tensors': PointsAndColors(points=value)}
+
+ return value
+
+ if is_pydantic_v2:
+
+ @model_validator(mode='before')
+ @classmethod
+ def validate_model_before(cls, value):
+ return cls._validate(value)
+
+ else:
- return super().validate(value)
+ @classmethod
+ def validate(
+ cls: Type[T],
+ value: Union[str, AbstractTensor, Any],
+ ) -> T:
+ return super().validate(cls._validate(value))
diff --git a/docarray/documents/point_cloud/points_and_colors.py b/docarray/documents/point_cloud/points_and_colors.py
index 89475d3d9cd..69d184c0a10 100644
--- a/docarray/documents/point_cloud/points_and_colors.py
+++ b/docarray/documents/point_cloud/points_and_colors.py
@@ -31,7 +31,7 @@ class PointsAndColors(BaseDoc):
"""
points: AnyTensor
- colors: Optional[AnyTensor]
+ colors: Optional[AnyTensor] = None
@classmethod
def validate(
diff --git a/docarray/documents/text.py b/docarray/documents/text.py
index c6e6645f4e1..101b3c3d3f7 100644
--- a/docarray/documents/text.py
+++ b/docarray/documents/text.py
@@ -1,8 +1,14 @@
from typing import Any, Optional, Type, TypeVar, Union
+from pydantic import Field
+
from docarray.base_doc import BaseDoc
from docarray.typing import TextUrl
from docarray.typing.tensor.embedding import AnyEmbedding
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+if is_pydantic_v2:
+ from pydantic import model_validator
T = TypeVar('T', bound='TextDoc')
@@ -24,7 +30,7 @@ class TextDoc(BaseDoc):
from docarray.documents import TextDoc
# use it directly
- txt_doc = TextDoc(url='http://www.jina.ai/')
+ txt_doc = TextDoc(url='https://www.gutenberg.org/files/1065/1065-0.txt')
txt_doc.text = txt_doc.url.load()
# model = MyEmbeddingModel()
# txt_doc.embedding = model(txt_doc.text)
@@ -48,10 +54,10 @@ class TextDoc(BaseDoc):
# extend it
class MyText(TextDoc):
- second_embedding: Optional[AnyEmbedding]
+ second_embedding: Optional[AnyEmbedding] = None
- txt_doc = MyText(url='http://www.jina.ai/')
+ txt_doc = MyText(url='https://www.gutenberg.org/files/1065/1065-0.txt')
txt_doc.text = txt_doc.url.load()
# model = MyEmbeddingModel()
# txt_doc.embedding = model(txt_doc.text)
@@ -93,8 +99,8 @@ class MultiModalDoc(BaseDoc):
```python
from docarray.documents import TextDoc
- doc = TextDoc(text='This is the main text', url='exampleurl.com')
- doc2 = TextDoc(text='This is the main text', url='exampleurl.com')
+ doc = TextDoc(text='This is the main text', url='exampleurl.com/file')
+ doc2 = TextDoc(text='This is the main text', url='exampleurl.com/file')
doc == 'This is the main text' # True
doc == doc2 # True
@@ -102,24 +108,51 @@ class MultiModalDoc(BaseDoc):
"""
- text: Optional[str]
- url: Optional[TextUrl]
- embedding: Optional[AnyEmbedding]
- bytes_: Optional[bytes]
+ text: Optional[str] = Field(
+ description='The text content stored in the document',
+ example='This is an example text content of the document',
+ default=None,
+ )
+ url: Optional[TextUrl] = Field(
+ description='URL to a (potentially remote) text file that can be loaded',
+ example='https://www.w3.org/History/19921103-hypertext/hypertext/README.html',
+ default=None,
+ )
+ embedding: Optional[AnyEmbedding] = Field(
+ description='Store an embedding: a vector representation of the text',
+ example=[1, 0, 1],
+ default=None,
+ )
+ bytes_: Optional[bytes] = Field(
+ description='Bytes representation of the text',
+ default=None,
+ )
def __init__(self, text: Optional[str] = None, **kwargs):
if 'text' not in kwargs:
kwargs['text'] = text
super().__init__(**kwargs)
- @classmethod
- def validate(
- cls: Type[T],
- value: Union[str, Any],
- ) -> T:
- if isinstance(value, str):
- value = cls(text=value)
- return super().validate(value)
+ if is_pydantic_v2:
+
+ @model_validator(mode='before')
+ @classmethod
+ def validate_model_before(cls, values):
+ if isinstance(values, str):
+ return {'text': values}
+ else:
+ return values
+
+ else:
+
+ @classmethod
+ def validate(
+ cls: Type[T],
+ value: Union[str, Any],
+ ) -> T:
+ if isinstance(value, str):
+ value = cls(text=value)
+ return super().validate(value)
def __eq__(self, other: Any) -> bool:
if isinstance(other, str):
diff --git a/docarray/documents/video.py b/docarray/documents/video.py
index fad4a0e843a..23965d26daf 100644
--- a/docarray/documents/video.py
+++ b/docarray/documents/video.py
@@ -1,6 +1,7 @@
-from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union
import numpy as np
+from pydantic import Field
from docarray.base_doc import BaseDoc
from docarray.documents import AudioDoc
@@ -9,6 +10,10 @@
from docarray.typing.tensor.video.video_tensor import VideoTensor
from docarray.typing.url.video_url import VideoUrl
from docarray.utils._internal.misc import import_library
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+if is_pydantic_v2:
+ from pydantic import model_validator
if TYPE_CHECKING:
import tensorflow as tf # type: ignore
@@ -37,13 +42,16 @@ class VideoDoc(BaseDoc):
You can use this Document directly:
```python
- from docarray.documents import VideoDoc
+ from docarray.documents import VideoDoc, AudioDoc
# use it directly
vid = VideoDoc(
url='https://github.com/docarray/docarray/blob/main/tests/toydata/mov_bbb.mp4?raw=true'
)
- vid.tensor, vid.audio.tensor, vid.key_frame_indices = vid.url.load()
+ tensor, audio_tensor, key_frame_indices = vid.url.load()
+ vid.tensor = tensor
+ vid.audio = AudioDoc(tensor=audio_tensor)
+ vid.key_frame_indices = key_frame_indices
# model = MyEmbeddingModel()
# vid.embedding = model(vid.tensor)
```
@@ -58,7 +66,7 @@ class VideoDoc(BaseDoc):
# extend it
class MyVideo(VideoDoc):
- name: Optional[TextDoc]
+ name: Optional[TextDoc] = None
video = MyVideo(
@@ -97,25 +105,59 @@ class MultiModalDoc(BaseDoc):
```
"""
- url: Optional[VideoUrl]
- audio: Optional[AudioDoc] = AudioDoc()
- tensor: Optional[VideoTensor]
- key_frame_indices: Optional[AnyTensor]
- embedding: Optional[AnyEmbedding]
- bytes_: Optional[VideoBytes]
+ url: Optional[VideoUrl] = Field(
+ description='URL to a (potentially remote) video file that needs to be loaded',
+ example='https://github.com/docarray/docarray/blob/main/tests/toydata/mov_bbb.mp4?raw=true',
+ default=None,
+ )
+ audio: Optional[AudioDoc] = Field(
+ description='Audio document associated with the video',
+ default=None,
+ )
+ tensor: Optional[VideoTensor] = Field(
+ description='Tensor object representing the video which be specified to one of `VideoNdArray`, `VideoTorchTensor`, `VideoTensorFlowTensor`',
+ default=None,
+ )
+ key_frame_indices: Optional[AnyTensor] = Field(
+ description='List of all the key frames in the video',
+ example=[0, 1, 2, 3, 4],
+ default=None,
+ )
+ embedding: Optional[AnyEmbedding] = Field(
+ description='Store an embedding: a vector representation of the video',
+ example=[1, 0, 1],
+ default=None,
+ )
+ bytes_: Optional[VideoBytes] = Field(
+ description='Bytes representation of the video',
+ default=None,
+ )
@classmethod
- def validate(
- cls: Type[T],
- value: Union[str, AbstractTensor, Any],
- ) -> T:
+ def _validate(cls, value) -> Dict[str, Any]:
if isinstance(value, str):
- value = cls(url=value)
+ value = dict(url=value)
elif isinstance(value, (AbstractTensor, np.ndarray)) or (
torch is not None
and isinstance(value, torch.Tensor)
or (tf is not None and isinstance(value, tf.Tensor))
):
- value = cls(tensor=value)
+ value = dict(tensor=value)
+
+ return value
+
+ if is_pydantic_v2:
+
+ @model_validator(mode='before')
+ @classmethod
+ def validate_model_before(cls, value):
+ return cls._validate(value)
+
+ else:
- return super().validate(value)
+ @classmethod
+ def validate(
+ cls: Type[T],
+ value: Union[str, AbstractTensor, Any],
+ ) -> T:
+ return super().validate(cls._validate(value))
diff --git a/docarray/exceptions/__init__.py b/docarray/exceptions/__init__.py
new file mode 100644
index 00000000000..74f8f7582cd
--- /dev/null
+++ b/docarray/exceptions/__init__.py
@@ -0,0 +1,15 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/docarray/exceptions/exceptions.py b/docarray/exceptions/exceptions.py
new file mode 100644
index 00000000000..c6d975cd1ed
--- /dev/null
+++ b/docarray/exceptions/exceptions.py
@@ -0,0 +1,17 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+class UnusableObjectError(NotImplementedError):
+ ...
diff --git a/docarray/helper.py b/docarray/helper.py
index 21469ca2acd..34b0c2bfd40 100644
--- a/docarray/helper.py
+++ b/docarray/helper.py
@@ -15,6 +15,24 @@
Union,
)
+import numpy as np
+
+from docarray.utils._internal._typing import safe_issubclass
+from docarray.utils._internal.misc import (
+ is_jax_available,
+ is_tf_available,
+ is_torch_available,
+)
+
+if is_torch_available():
+ import torch
+
+if is_jax_available():
+ import jax
+
+if is_tf_available():
+ import tensorflow as tf
+
if TYPE_CHECKING:
from docarray import BaseDoc
@@ -24,7 +42,7 @@ def _is_access_path_valid(doc_type: Type['BaseDoc'], access_path: str) -> bool:
Check if a given access path ("__"-separated) is a valid path for a given Document class.
"""
- field_type = _get_field_type_by_access_path(doc_type, access_path)
+ field_type = _get_field_annotation_by_access_path(doc_type, access_path)
return field_type is not None
@@ -52,6 +70,35 @@ def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]:
return result
+def _is_none_like(val: Any) -> bool:
+ """
+ :param val: any value
+ :return: True iff `val` is equal to `None`, `'None'` or `''` (always False for tensor-like values)
+ """
+ # Convoluted implementation, but fixes https://github.com/docarray/docarray/issues/1821
+
+ # tensor-like types can have unexpected (= broadcast) `==`/`in` semantics,
+ # so treat separately
+ is_np_arr = isinstance(val, np.ndarray)
+ if is_np_arr:
+ return False
+
+ is_torch_tens = is_torch_available() and isinstance(val, torch.Tensor)
+ if is_torch_tens:
+ return False
+
+ is_tf_tens = is_tf_available() and isinstance(val, tf.Tensor)
+ if is_tf_tens:
+ return False
+
+ is_jax_arr = is_jax_available() and isinstance(val, jax.numpy.ndarray)
+ if is_jax_arr:
+ return False
+
+ # "normal" case
+ return val in ['', 'None', None]
+
+
def _access_path_dict_to_nested_dict(access_path2val: Dict[str, Any]) -> Dict[Any, Any]:
"""
Convert a dict, where the keys are access paths ("__"-separated) to a nested dictionary.
@@ -74,7 +121,7 @@ def _access_path_dict_to_nested_dict(access_path2val: Dict[str, Any]) -> Dict[An
for access_path, value in access_path2val.items():
field2val = _access_path_to_dict(
access_path=access_path,
- value=value if value not in ['', 'None'] else None,
+ value=None if _is_none_like(value) else value,
)
_update_nested_dicts(to_update=nested_dict, update_with=field2val)
return nested_dict
@@ -127,7 +174,7 @@ def _update_nested_dicts(
_update_nested_dicts(to_update[k], update_with[k])
-def _get_field_type_by_access_path(
+def _get_field_annotation_by_access_path(
doc_type: Type['BaseDoc'], access_path: str
) -> Optional[Type]:
"""
@@ -140,17 +187,17 @@ def _get_field_type_by_access_path(
from docarray import BaseDoc, DocList
field, _, remaining = access_path.partition('__')
- field_valid = field in doc_type.__fields__.keys()
+ field_valid = field in doc_type._docarray_fields().keys()
if field_valid:
if len(remaining) == 0:
- return doc_type._get_field_type(field)
+ return doc_type._get_field_annotation(field)
else:
- d = doc_type._get_field_type(field)
- if issubclass(d, DocList):
- return _get_field_type_by_access_path(d.doc_type, remaining)
- elif issubclass(d, BaseDoc):
- return _get_field_type_by_access_path(d, remaining)
+ d = doc_type._get_field_annotation(field)
+ if safe_issubclass(d, DocList):
+ return _get_field_annotation_by_access_path(d.doc_type, remaining)
+ elif safe_issubclass(d, BaseDoc):
+ return _get_field_annotation_by_access_path(d, remaining)
else:
return None
else:
@@ -240,3 +287,7 @@ def _iter_file_extensions(ps):
num_docs += 1
if size is not None and num_docs >= size:
break
+
+
+def _shallow_copy_doc(doc):
+ return doc.__class__._shallow_copy(doc)
diff --git a/docarray/index/__init__.py b/docarray/index/__init__.py
index 9e4dbde474a..aa20ff5db82 100644
--- a/docarray/index/__init__.py
+++ b/docarray/index/__init__.py
@@ -10,11 +10,27 @@
if TYPE_CHECKING:
from docarray.index.backends.elastic import ElasticDocIndex # noqa: F401
from docarray.index.backends.elasticv7 import ElasticV7DocIndex # noqa: F401
+ from docarray.index.backends.epsilla import EpsillaDocumentIndex # noqa: F401
from docarray.index.backends.hnswlib import HnswDocumentIndex # noqa: F401
+ from docarray.index.backends.milvus import MilvusDocumentIndex # noqa: F401
+ from docarray.index.backends.mongodb_atlas import ( # noqa: F401
+ MongoDBAtlasDocumentIndex,
+ )
from docarray.index.backends.qdrant import QdrantDocumentIndex # noqa: F401
+ from docarray.index.backends.redis import RedisDocumentIndex # noqa: F401
from docarray.index.backends.weaviate import WeaviateDocumentIndex # noqa: F401
-__all__ = ['InMemoryExactNNIndex']
+__all__ = [
+ 'InMemoryExactNNIndex',
+ 'ElasticDocIndex',
+ 'ElasticV7DocIndex',
+ 'EpsillaDocumentIndex',
+ 'QdrantDocumentIndex',
+ 'WeaviateDocumentIndex',
+ 'RedisDocumentIndex',
+ 'MilvusDocumentIndex',
+ 'MongoDBAtlasDocumentIndex',
+]
def __getattr__(name: str):
@@ -28,12 +44,24 @@ def __getattr__(name: str):
elif name == 'ElasticV7DocIndex':
import_library('elasticsearch', raise_error=True)
import docarray.index.backends.elasticv7 as lib
+ elif name == 'EpsillaDocumentIndex':
+ import_library('pyepsilla', raise_error=True)
+ import docarray.index.backends.epsilla as lib
elif name == 'QdrantDocumentIndex':
import_library('qdrant_client', raise_error=True)
import docarray.index.backends.qdrant as lib
elif name == 'WeaviateDocumentIndex':
import_library('weaviate', raise_error=True)
import docarray.index.backends.weaviate as lib
+ elif name == 'MilvusDocumentIndex':
+ import_library('pymilvus', raise_error=True)
+ import docarray.index.backends.milvus as lib
+ elif name == 'RedisDocumentIndex':
+ import_library('redis', raise_error=True)
+ import docarray.index.backends.redis as lib
+ elif name == 'MongoDBAtlasDocumentIndex':
+ import_library('pymongo', raise_error=True)
+ import docarray.index.backends.mongodb_atlas as lib
else:
raise ImportError(
f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\''
diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py
index 59fad087ca4..142d7e61c8e 100644
--- a/docarray/index/abstract.py
+++ b/docarray/index/abstract.py
@@ -1,3 +1,4 @@
+import copy
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, replace
@@ -25,13 +26,15 @@
from docarray import BaseDoc, DocList
from docarray.array.any_array import AnyDocArray
-from docarray.typing import AnyTensor
+from docarray.typing import ID, AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
-from docarray.utils._internal._typing import is_tensor_union
+from docarray.utils._internal._typing import is_tensor_union, safe_issubclass
from docarray.utils._internal.misc import import_library
+from docarray.utils._internal.pydantic import is_pydantic_v2
from docarray.utils.find import (
FindResult,
FindResultBatched,
+ SubindexFindResult,
_FindResult,
_FindResultBatched,
)
@@ -85,12 +88,20 @@ class BaseDocIndex(ABC, Generic[TSchema]):
# for subclasses this is filled automatically
_schema: Optional[Type[BaseDoc]] = None
- def __init__(self, db_config=None, **kwargs):
+ def __init__(self, db_config=None, subindex: bool = False, **kwargs):
if self._schema is None:
raise ValueError(
'A DocumentIndex must be typed with a Document type.'
'To do so, use the syntax: DocumentIndex[DocumentType]'
)
+ if subindex:
+
+ class _NewSchema(self._schema): # type: ignore
+ parent_id: Optional[ID] = None
+
+ self._ori_schema = self._schema
+ self._schema = cast(Type[BaseDoc], _NewSchema)
+
self._logger = logging.getLogger('docarray')
self._db_config = db_config or self.DBConfig(**kwargs)
if not isinstance(self._db_config, self.DBConfig):
@@ -101,6 +112,9 @@ def __init__(self, db_config=None, **kwargs):
self._column_infos: Dict[str, _ColumnInfo] = self._create_column_infos(
self._schema
)
+ self._is_subindex = subindex
+ self._subindices: Dict[str, BaseDocIndex] = {}
+ self._init_subindex()
###############################################
# Inner classes for query builder and configs #
@@ -116,6 +130,8 @@ def build(self, *args, **kwargs) -> Any:
"""
...
+ # TODO support subindex in QueryBuilder
+
# the methods below need to be implemented by subclasses
# If, in your subclass, one of these is not usable in a query builder, but
# can be called directly on the DocumentIndex, use `_raise_not_composable`.
@@ -129,10 +145,7 @@ def build(self, *args, **kwargs) -> Any:
@dataclass
class DBConfig(ABC):
- ...
-
- @dataclass
- class RuntimeConfig(ABC):
+ index_name: Optional[str] = None
# default configurations for every column type
# a dictionary from a column type (DB specific) to a dictionary
# of default configurations for that type
@@ -141,6 +154,15 @@ class RuntimeConfig(ABC):
# Example: `default_column_config['VARCHAR'] = {'length': 255}`
default_column_config: Dict[Type, Dict[str, Any]] = field(default_factory=dict)
+ @dataclass
+ class RuntimeConfig(ABC):
+ pass
+
+ @property
+ def index_name(self):
+ """Return the name of the index in the database."""
+ ...
+
#####################################
# Abstract methods #
# Subclasses must implement these #
@@ -171,6 +193,14 @@ def num_docs(self) -> int:
"""Return the number of indexed documents"""
...
+ @property
+ def _is_index_empty(self) -> bool:
+ """
+ Check if index is empty by comparing the number of documents to zero.
+ :return: True if the index is empty, False otherwise.
+ """
+ return self.num_docs() == 0
+
@abstractmethod
def _del_items(self, doc_ids: Sequence[str]):
"""Delete Documents from the index.
@@ -209,6 +239,16 @@ def execute_query(self, query: Any, *args, **kwargs) -> Any:
"""
...
+ @abstractmethod
+ def _doc_exists(self, doc_id: str) -> bool:
+ """
+ Checks if a given document exists in the index.
+
+ :param doc_id: The id of a document to check.
+ :return: True if the document exists in the index, False otherwise.
+ """
+ ...
+
@abstractmethod
def _find(
self,
@@ -329,12 +369,24 @@ def __getitem__(
key = [key]
else:
return_singleton = False
+
# retrieve data
doc_sequence = self._get_items(key)
+
# check data
if len(doc_sequence) == 0:
raise KeyError(f'No document with id {key} found')
+ # retrieve nested data
+ for field_name, type_, _ in self._flatten_schema(
+ cast(Type[BaseDoc], self._schema)
+ ):
+ if safe_issubclass(type_, AnyDocArray) and isinstance(
+ doc_sequence[0], Dict
+ ):
+ for doc in doc_sequence:
+ self._get_subindex_doclist(doc, field_name) # type: ignore
+
# cast output
if isinstance(doc_sequence, DocList):
out_docs: DocList[TSchema] = doc_sequence
@@ -355,8 +407,36 @@ def __delitem__(self, key: Union[str, Sequence[str]]):
self._logger.info(f'Deleting documents with id(s) {key} from the index')
if isinstance(key, str):
key = [key]
+
+ # delete nested data
+ for field_name, type_, _ in self._flatten_schema(
+ cast(Type[BaseDoc], self._schema)
+ ):
+ if safe_issubclass(type_, AnyDocArray):
+ for doc_id in key:
+ nested_docs_id = self._subindices[field_name]._filter_by_parent_id(
+ doc_id
+ )
+ if nested_docs_id:
+ del self._subindices[field_name][nested_docs_id]
+ # delete data
self._del_items(key)
+ def __contains__(self, item: BaseDoc) -> bool:
+ """
+ Checks if a given document exists in the index.
+
+ :param item: The document to check.
+ It must be an instance of BaseDoc or its subclass.
+ :return: True if the document exists in the index, False otherwise.
+ """
+ if safe_issubclass(type(item), BaseDoc):
+ return self._doc_exists(str(item.id))
+ else:
+ raise TypeError(
+ f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
+ )
+
def configure(self, runtime_config=None, **kwargs):
"""
Configure the DocumentIndex.
@@ -390,6 +470,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
n_docs = 1 if isinstance(docs, BaseDoc) else len(docs)
self._logger.debug(f'Indexing {n_docs} documents')
docs_validated = self._validate_docs(docs)
+ self._update_subindex_data(docs_validated)
data_by_columns = self._get_col_value_dict(docs_validated)
self._index(data_by_columns, **kwargs)
@@ -412,6 +493,7 @@ def find(
:return: a named tuple containing `documents` and `scores`
"""
self._logger.debug(f'Executing `find` for search field {search_field}')
+
self._validate_search_field(search_field)
if isinstance(query, BaseDoc):
query_vec = self._get_values_by_column([query], search_field)[0]
@@ -427,6 +509,43 @@ def find(
return FindResult(documents=docs, scores=scores)
+ def find_subindex(
+ self,
+ query: Union[AnyTensor, BaseDoc],
+ subindex: str = '',
+ search_field: str = '',
+ limit: int = 10,
+ **kwargs,
+ ) -> SubindexFindResult:
+ """Find documents in subindex level.
+
+ The search itself is delegated to `_find_subdocs`; afterwards the root
+ document of every hit is looked up so both can be returned together.
+
+ :param query: query vector for KNN/ANN search.
+ Can be either a tensor-like (np.array, torch.Tensor, etc.)
+ with a single axis, or a Document
+ :param subindex: name of the subindex to search on
+ (split on `__` below, so `a__b` addresses subindex `b` nested inside subindex `a`)
+ :param search_field: name of the field to search on
+ :param limit: maximum number of documents to return
+ :param kwargs: forwarded unchanged to `_find_subdocs`
+ :return: a named tuple containing root docs, subindex docs and scores
+ """
+ self._logger.debug(f'Executing `find_subindex` for search field {search_field}')
+
+ sub_docs, scores = self._find_subdocs(
+ query, subindex=subindex, search_field=search_field, limit=limit, **kwargs
+ )
+
+ # first path component is the direct subindex of this index, the rest is the nested path
+ fields = subindex.split('__')
+ root_ids = [
+ self._get_root_doc_id(doc.id, fields[0], '__'.join(fields[1:]))
+ for doc in sub_docs
+ ]
+ # one root document is fetched per hit, in hit order, so a root that owns
+ # several hits is appended once for each of them
+ root_docs = DocList[self._schema]() # type: ignore
+ for id in root_ids:
+ root_docs.append(self[id])
+
+ return SubindexFindResult(
+ root_documents=root_docs, sub_documents=sub_docs, scores=scores # type: ignore
+ )
+
def find_batched(
self,
queries: Union[AnyTensor, DocList],
@@ -447,6 +566,18 @@ def find_batched(
:return: a named tuple containing `documents` and `scores`
"""
self._logger.debug(f'Executing `find_batched` for search field {search_field}')
+
+ if search_field:
+ if '__' in search_field:
+ fields = search_field.split('__')
+ if safe_issubclass(self._schema._get_field_annotation(fields[0]), AnyDocArray): # type: ignore
+ return self._subindices[fields[0]].find_batched(
+ queries,
+ search_field='__'.join(fields[1:]),
+ limit=limit,
+ **kwargs,
+ )
+
self._validate_search_field(search_field)
if isinstance(queries, Sequence):
query_vec_list = self._get_values_by_column(queries, search_field)
@@ -459,8 +590,11 @@ def find_batched(
da_list, scores = self._find_batched(
query_vec_np, search_field=search_field, limit=limit, **kwargs
)
-
- if len(da_list) > 0 and isinstance(da_list[0], List):
+ if (
+ len(da_list) > 0
+ and isinstance(da_list[0], List)
+ and not isinstance(da_list[0], DocList)
+ ):
da_list = [self._dict_list_to_docarray(docs) for docs in da_list]
return FindResultBatched(documents=da_list, scores=scores) # type: ignore
@@ -480,11 +614,38 @@ def filter(
self._logger.debug(f'Executing `filter` for the query {filter_query}')
docs = self._filter(filter_query, limit=limit, **kwargs)
- if isinstance(docs, List):
+ if isinstance(docs, List) and not isinstance(docs, DocList):
docs = self._dict_list_to_docarray(docs)
return docs
+ def filter_subindex(
+ self,
+ filter_query: Any,
+ subindex: str,
+ limit: int = 10,
+ **kwargs,
+ ) -> DocList:
+ """Find documents in subindex level based on a filter query
+
+ :param filter_query: the DB specific filter query to execute
+ :param subindex: name of the subindex to search on
+ (a `__`-separated path addresses a subindex nested inside another subindex)
+ :param limit: maximum number of documents to return
+ :param kwargs: forwarded unchanged to the subindex's `filter` / `filter_subindex`
+ :return: a DocList containing the subindex level documents that match the filter query
+ """
+ self._logger.debug(
+ f'Executing `filter` for the query {filter_query} in subindex {subindex}'
+ )
+ if '__' in subindex:
+ # nested path: hand the remainder of the path to the first-level subindex
+ fields = subindex.split('__')
+ return self._subindices[fields[0]].filter_subindex(
+ filter_query, '__'.join(fields[1:]), limit=limit, **kwargs
+ )
+ else:
+ # single component: run the regular `filter` on that subindex
+ return self._subindices[subindex].filter(
+ filter_query, limit=limit, **kwargs
+ )
+
def filter_batched(
self,
filter_queries: Any,
@@ -531,7 +692,7 @@ def text_search(
query_text, search_field=search_field, limit=limit, **kwargs
)
- if isinstance(docs, List):
+ if isinstance(docs, List) and not isinstance(docs, DocList):
docs = self._dict_list_to_docarray(docs)
return FindResult(documents=docs, scores=scores)
@@ -572,6 +733,14 @@ def text_search_batched(
da_list_ = cast(List[DocList], da_list)
return FindResultBatched(documents=da_list_, scores=scores)
+ def _filter_by_parent_id(self, id: str) -> Optional[List[str]]:
+ """Filter the ids of the subindex documents given id of root document.
+
+ This base implementation performs no lookup and always returns None;
+ the callers in this class test the result for truthiness, so None is
+ handled the same way as "no nested documents".
+
+ :param id: the root document id to filter by
+ :return: a list of ids of the subindex documents
+ """
+ return None
+
##########################################################
# Helper methods #
# These might be useful in your subclass implementation #
@@ -635,6 +804,28 @@ def _col_gen(col_name: str):
return {col_name: _col_gen(col_name) for col_name in self._column_infos}
+ def _update_subindex_data(
+ self,
+ docs: DocList[BaseDoc],
+ ):
+ """
+ Add `parent_id` to all sublevel documents.
+
+ For every AnyDocArray-typed field of the schema, each nested document is
+ replaced inside its parent's list by an instance of the matching
+ subindex's schema whose `parent_id` is the parent's id. `docs` is
+ modified in place; nothing is returned.
+
+ :param docs: The document(s) to update the `parent_id` for
+ """
+ for field_name, type_, _ in self._flatten_schema(
+ cast(Type[BaseDoc], self._schema)
+ ):
+ if safe_issubclass(type_, AnyDocArray):
+ for doc in docs:
+ _list = getattr(doc, field_name)
+ for i, nested_doc in enumerate(_list):
+ # rebuild the nested doc as the subindex's schema type from the original's `__dict__`
+ nested_doc = self._subindices[field_name]._schema( # type: ignore
+ **nested_doc.__dict__
+ )
+ nested_doc.parent_id = doc.id
+ # write the converted instance back into the parent's list
+ _list[i] = nested_doc
+
##################################################
# Behind-the-scenes magic #
# Subclasses should not need to implement these #
@@ -644,7 +835,7 @@ def __class_getitem__(cls, item: Type[TSchema]):
# do nothing
# enables use in static contexts with type vars, e.g. as type annotation
return Generic.__class_getitem__.__func__(cls, item)
- if not issubclass(item, BaseDoc):
+ if not safe_issubclass(item, BaseDoc):
raise ValueError(
f'{cls.__name__}[item] `item` should be a Document not a {item} '
)
@@ -677,8 +868,8 @@ def _flatten_schema(
:return: A list of column names, types, and fields
"""
names_types_fields: List[Tuple[str, Type, 'ModelField']] = []
- for field_name, field_ in schema.__fields__.items():
- t_ = schema._get_field_type(field_name)
+ for field_name, field_ in schema._docarray_fields().items():
+ t_ = schema._get_field_annotation(field_name)
inner_prefix = name_prefix + field_name + '__'
if is_union_type(t_):
@@ -694,7 +885,7 @@ def _flatten_schema(
# treat as if it was a single non-optional type
for t_arg in union_args:
if t_arg is not type(None):
- if issubclass(t_arg, BaseDoc):
+ if safe_issubclass(t_arg, BaseDoc):
names_types_fields.extend(
cls._flatten_schema(t_arg, name_prefix=inner_prefix)
)
@@ -706,11 +897,11 @@ def _flatten_schema(
raise ValueError(
f'Union type {t_} is not supported. Only Union of subclasses of AbstractTensor or Union[type, None] are supported.'
)
- elif issubclass(t_, BaseDoc):
+ elif safe_issubclass(t_, BaseDoc):
names_types_fields.extend(
cls._flatten_schema(t_, name_prefix=inner_prefix)
)
- elif issubclass(t_, AbstractTensor):
+ elif safe_issubclass(t_, AbstractTensor):
names_types_fields.append(
(name_prefix + field_name, AbstractTensor, field_)
)
@@ -728,43 +919,63 @@ def _create_column_infos(self, schema: Type[BaseDoc]) -> Dict[str, _ColumnInfo]:
column_infos: Dict[str, _ColumnInfo] = dict()
for field_name, type_, field_ in self._flatten_schema(schema):
# Union types are handle in _flatten_schema
- if issubclass(type_, AnyDocArray):
- raise ValueError(
- 'Indexing field of DocList type (=subindex)' 'is not yet supported.'
+ if safe_issubclass(type_, AnyDocArray):
+ column_infos[field_name] = _ColumnInfo(
+ docarray_type=type_, db_type=None, config=dict(), n_dim=None
)
else:
column_infos[field_name] = self._create_single_column(field_, type_)
+
return column_infos
def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo:
- custom_config = field.field_info.extra
+ custom_config = (
+ field.json_schema_extra if is_pydantic_v2 else field.field_info.extra
+ )
+ if custom_config is None:
+ custom_config = dict()
+
if 'col_type' in custom_config.keys():
db_type = custom_config['col_type']
custom_config.pop('col_type')
- if db_type not in self._runtime_config.default_column_config.keys():
+ if db_type not in self._db_config.default_column_config.keys():
raise ValueError(
f'The given col_type is not a valid db type: {db_type}'
)
else:
db_type = self.python_type_to_db_type(type_)
- config = self._runtime_config.default_column_config[db_type].copy()
+ config = self._db_config.default_column_config[db_type].copy()
config.update(custom_config)
# parse n_dim from parametrized tensor type
+
+ field_type = field.annotation if is_pydantic_v2 else field.type_
if (
- hasattr(field.type_, '__docarray_target_shape__')
- and field.type_.__docarray_target_shape__
+ hasattr(field_type, '__docarray_target_shape__')
+ and field_type.__docarray_target_shape__
):
- if len(field.type_.__docarray_target_shape__) == 1:
- n_dim = field.type_.__docarray_target_shape__[0]
+ if len(field_type.__docarray_target_shape__) == 1:
+ n_dim = field_type.__docarray_target_shape__[0]
else:
- n_dim = field.type_.__docarray_target_shape__
+ n_dim = field_type.__docarray_target_shape__
else:
n_dim = None
return _ColumnInfo(
docarray_type=type_, db_type=db_type, config=config, n_dim=n_dim
)
+ def _init_subindex(
+ self,
+ ):
+ """Initialize subindices if any column is subclass of AnyDocArray.
+
+ One index of the same class as `self` is created per such column and
+ stored in `self._subindices` under the column name.
+ """
+ for col_name, col in self._column_infos.items():
+ if safe_issubclass(col.docarray_type, AnyDocArray):
+ # deep copy so that renaming the index below does not touch this index's own config
+ sub_db_config = copy.deepcopy(self._db_config)
+ sub_db_config.index_name = f'{self.index_name}__{col_name}'
+ # the subindex is parametrized with the column's document type and flagged via `subindex=True`
+ self._subindices[col_name] = self.__class__[col.docarray_type.doc_type]( # type: ignore
+ db_config=sub_db_config, subindex=True
+ )
+
def _validate_docs(
self, docs: Union[BaseDoc, Sequence[BaseDoc]]
) -> DocList[BaseDoc]:
@@ -799,7 +1010,7 @@ def _validate_docs(
# see schema translation ideas in the design doc
names_compatible = reference_names == input_names
types_compatible = all(
- (issubclass(t2, t1))
+ (safe_issubclass(t2, t1))
for (t1, t2) in zip(reference_types, input_types)
)
if names_compatible and types_compatible:
@@ -809,12 +1020,15 @@ def _validate_docs(
for i in range(len(docs)):
# validate the data
try:
- out_docs.append(cast(Type[BaseDoc], self._schema).parse_obj(docs[i]))
- except (ValueError, ValidationError):
+ out_docs.append(
+ cast(Type[BaseDoc], self._schema).parse_obj(dict(docs[i]))
+ )
+ except (ValueError, ValidationError) as e:
raise ValueError(
'The schema of the input Documents is not compatible with the schema of the Document Index.'
' Ensure that the field names of your data match the field names of the Document Index schema,'
' and that the types of your data match the types of the Document Index schema.'
+ f'original error {e}'
)
return DocList[BaseDoc].construct(out_docs)
@@ -864,7 +1078,7 @@ def _to_numpy(self, val: Any, allow_passthrough=False) -> Any:
raise ValueError(f'Unsupported input type for {type(self)}: {type(val)}')
def _convert_dict_to_doc(
- self, doc_dict: Dict[str, Any], schema: Type[BaseDoc]
+ self, doc_dict: Dict[str, Any], schema: Type[BaseDoc], inner=False
) -> BaseDoc:
"""
Convert a dict to a Document object.
@@ -873,14 +1087,18 @@ def _convert_dict_to_doc(
:param schema: The schema of the Document object
:return: A Document object
"""
- for field_name, _ in schema.__fields__.items():
- t_ = schema._get_field_type(field_name)
+ for field_name, _ in schema._docarray_fields().items():
+ t_ = schema._get_field_annotation(field_name)
+
+ if not is_union_type(t_) and safe_issubclass(t_, AnyDocArray):
+ self._get_subindex_doclist(doc_dict, field_name)
+
if is_optional_type(t_):
for t_arg in get_args(t_):
if t_arg is not type(None):
t_ = t_arg
- if not is_union_type(t_) and issubclass(t_, BaseDoc):
+ if not is_union_type(t_) and safe_issubclass(t_, BaseDoc):
inner_dict = {}
fields = [
@@ -890,17 +1108,126 @@ def _convert_dict_to_doc(
nested_name = key[len(f'{field_name}__') :]
inner_dict[nested_name] = doc_dict.pop(key)
- doc_dict[field_name] = self._convert_dict_to_doc(inner_dict, t_)
+ doc_dict[field_name] = self._convert_dict_to_doc(
+ inner_dict, t_, inner=True
+ )
- schema_cls = cast(Type[BaseDoc], schema)
- return schema_cls(**doc_dict)
+ if self._is_subindex and not inner:
+ doc_dict.pop('parent_id', None)
+ schema_cls = cast(Type[BaseDoc], self._ori_schema)
+ else:
+ schema_cls = cast(Type[BaseDoc], schema)
+ doc = schema_cls(**doc_dict)
+ return doc
def _dict_list_to_docarray(self, dict_list: Sequence[Dict[str, Any]]) -> DocList:
"""Convert a list of docs in dict type to a DocList of the schema type."""
-
doc_list = [self._convert_dict_to_doc(doc_dict, self._schema) for doc_dict in dict_list] # type: ignore
- docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))
+ if self._is_subindex:
+ docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self._ori_schema))
+ else:
+ docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))
return docs_cls(doc_list)
def __len__(self) -> int:
return self.num_docs()
+
+ def _index_subindex(self, column_to_data: Dict[str, Generator[Any, None, None]]):
+ """Index subindex documents in the corresponding subindex.
+
+ `column_to_data` is modified in place: every AnyDocArray-typed column is
+ consumed, indexed into its subindex and then removed from the mapping.
+
+ :param column_to_data: A dictionary from column name to a generator
+ """
+ for col_name, col in self._column_infos.items():
+ if safe_issubclass(col.docarray_type, AnyDocArray):
+ # flatten the per-parent lists into a single list of nested docs
+ docs = [
+ doc for doc_list in column_to_data[col_name] for doc in doc_list
+ ]
+ self._subindices[col_name].index(docs)
+ # drop the column so the remaining mapping holds only non-subindex columns
+ column_to_data.pop(col_name, None)
+
+ def _get_subindex_doclist(self, doc: Dict[str, Any], field_name: str):
+ """Get subindex Documents from the index and assign them to `field_name`.
+
+ `doc` is modified in place. Nothing happens when `field_name` is already
+ a key of `doc`, and the key stays absent when the parent-id lookup
+ yields no nested ids.
+
+ :param doc: a dictionary mapping from column name to value
+ :param field_name: field name of the subindex Documents
+ """
+ if field_name not in doc.keys():
+ parent_id = doc['id']
+ nested_docs_id = self._subindices[field_name]._filter_by_parent_id(
+ parent_id
+ )
+ # falsy result (None or empty list) means there is nothing to attach
+ if nested_docs_id:
+ doc[field_name] = self._subindices[field_name].__getitem__(
+ nested_docs_id
+ )
+
+    def _find_subdocs(
+        self,
+        query: Union[AnyTensor, BaseDoc],
+        subindex: str = '',
+        search_field: str = '',
+        limit: int = 10,
+        **kwargs,
+    ) -> FindResult:
+        """Find documents in the subindex and return subindex docs and scores.
+
+        :param query: query vector or Document, passed through to the subindex's `find`
+        :param subindex: `__`-separated path of subindex names, e.g. `a__b` for
+            subindex `b` nested inside subindex `a` of this index
+        :param search_field: name of the field to search on inside the target subindex
+        :param limit: maximum number of documents to return
+        :param kwargs: forwarded unchanged to the subindex's `find`
+        :return: the result of `find` executed on the addressed subindex
+        :raises ValueError: if `subindex` is empty or its first component is not
+            an AnyDocArray-typed field of this index's schema
+        """
+        fields = subindex.split('__')
+        if not subindex or not safe_issubclass(
+            self._schema._get_field_annotation(fields[0]), AnyDocArray  # type: ignore
+        ):
+            raise ValueError(f'subindex {subindex} is not valid')
+
+        if len(fields) == 1:
+            return self._subindices[fields[0]].find(
+                query, search_field=search_field, limit=limit, **kwargs
+            )
+
+        # Re-join the remaining path with the same `__` separator it is split on
+        # above; any other separator corrupts the field names one level down.
+        return self._subindices[fields[0]]._find_subdocs(
+            query,
+            subindex='__'.join(fields[1:]),
+            search_field=search_field,
+            limit=limit,
+            **kwargs,
+        )
+
+ def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
+ """Get the root_id given the id of a subindex Document and the root and subindex name
+
+ :param id: id of the subindex Document
+ :param root: root index name (used as the key into `self._subindices`)
+ :param sub: subindex name (`__`-separated path below `root`; empty when the
+ Document lives directly in the `root` subindex)
+ :return: the root_id of the Document
+ """
+ subindex = self._subindices[root]
+
+ if not sub:
+ # the Document is stored directly in `subindex`: its `parent_id` is the answer
+ sub_doc = subindex._get_items([id])
+ # the fetched item is read either as a dict or as an object with a `parent_id` attribute
+ parent_id = (
+ sub_doc[0]['parent_id']
+ if isinstance(sub_doc[0], dict)
+ else sub_doc[0].parent_id
+ )
+ return parent_id
+ else:
+ # deeper nesting: let `subindex` resolve the id of its own document that
+ # contains the target, then look up that document's parent in this index
+ fields = sub.split('__')
+ cur_root_id = subindex._get_root_doc_id(
+ id, fields[0], '__'.join(fields[1:])
+ )
+ return self._get_root_doc_id(cur_root_id, root, '')
+
+ def subindex_contains(self, item: BaseDoc) -> bool:
+ """Checks if a given BaseDoc item is contained in the index or any of its subindices.
+
+ :param item: the given BaseDoc
+ :return: if the given BaseDoc item is contained in the index/subindices
+ :raises TypeError: if `item` is not an instance of BaseDoc or its subclass
+ (only reached when `self._is_index_empty` is falsy, see below)
+ """
+ # NOTE(review): this early return runs before the type check, so a non-BaseDoc
+ # `item` yields False instead of TypeError here — confirm this is intended
+ if self._is_index_empty:
+ return False
+
+ if safe_issubclass(type(item), BaseDoc):
+ # check this index first, then recurse into every subindex
+ return self.__contains__(item) or any(
+ index.subindex_contains(item) for index in self._subindices.values()
+ )
+ else:
+ raise TypeError(
+ f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
+ )
diff --git a/docarray/index/backends/__init__.py b/docarray/index/backends/__init__.py
index e69de29bb2d..74f8f7582cd 100644
--- a/docarray/index/backends/__init__.py
+++ b/docarray/index/backends/__init__.py
@@ -0,0 +1,15 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py
index c4da3ad5e0d..a335f85e32a 100644
--- a/docarray/index/backends/elastic.py
+++ b/docarray/index/backends/elastic.py
@@ -25,10 +25,12 @@
import docarray.typing
from docarray import BaseDoc
+from docarray.array.any_array import AnyDocArray
from docarray.index.abstract import BaseDocIndex, _ColumnInfo, _raise_not_composable
from docarray.typing import AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.ndarray import NdArray
+from docarray.utils._internal._typing import safe_issubclass
from docarray.utils._internal.misc import import_library
from docarray.utils.find import _FindResult, _FindResultBatched
@@ -37,8 +39,9 @@
ELASTIC_PY_VEC_TYPES: List[Any] = [list, tuple, np.ndarray, AbstractTensor]
-
if TYPE_CHECKING:
+ import tensorflow as tf # type: ignore
+ import torch
from elastic_transport import NodeConfig
from elasticsearch import Elasticsearch
from elasticsearch.helpers import parallel_bulk
@@ -53,7 +56,6 @@
torch = import_library('torch', raise_error=False)
tf = import_library('tensorflow', raise_error=False)
-
if torch is not None:
ELASTIC_PY_VEC_TYPES.append(torch.Tensor)
@@ -65,6 +67,9 @@
class ElasticDocIndex(BaseDocIndex, Generic[TSchema]):
+ _index_vector_params: Optional[Tuple[str]] = ('dims', 'similarity', 'index')
+ _index_vector_options: Optional[Tuple[str]] = ('m', 'ef_construction')
+
def __init__(self, db_config=None, **kwargs):
"""Initialize ElasticDocIndex"""
super().__init__(db_config=db_config, **kwargs)
@@ -77,11 +82,9 @@ def __init__(self, db_config=None, **kwargs):
hosts=self._db_config.hosts,
**self._db_config.es_config,
)
+ self._logger.debug('ElasticSearch client has been created')
# ElasticSearh index setup
- self._index_vector_params = ('dims', 'similarity', 'index')
- self._index_vector_options = ('m', 'ef_construction')
-
mappings: Dict[str, Any] = {
'dynamic': True,
'_source': {'enabled': 'true'},
@@ -89,7 +92,11 @@ def __init__(self, db_config=None, **kwargs):
}
mappings.update(self._db_config.index_mappings)
+ self._logger.debug('Mappings have been updated with db_config.index_mappings')
+
for col_name, col in self._column_infos.items():
+ if safe_issubclass(col.docarray_type, AnyDocArray):
+ continue
if col.db_type == 'dense_vector' and (
not col.n_dim and col.config['dims'] < 0
):
@@ -99,17 +106,21 @@ def __init__(self, db_config=None, **kwargs):
continue
mappings['properties'][col_name] = self._create_index_mapping(col)
+ self._logger.debug(f'Index mapping created for column {col_name}')
- # print(mappings['properties'])
if self._client.indices.exists(index=self.index_name):
self._client_put_mapping(mappings)
+ self._logger.debug(f'Put mapping for index {self.index_name}')
else:
self._client_create(mappings)
+ self._logger.debug(f'Created new index {self.index_name} with mappings')
if len(self._db_config.index_settings):
self._client_put_settings(self._db_config.index_settings)
+ self._logger.debug('Updated index settings')
self._refresh(self.index_name)
+ self._logger.debug(f'Refreshed index {self.index_name}')
@property
def index_name(self):
@@ -117,12 +128,15 @@ def index_name(self):
self._schema.__name__.lower() if self._schema is not None else None
)
if default_index_name is None:
- raise ValueError(
- 'A ElasticDocIndex must be typed with a Document type.'
- 'To do so, use the syntax: ElasticDocIndex[DocumentType]'
+ err_msg = (
+ 'A ElasticDocIndex must be typed with a Document type.To do so, use the syntax: '
+ 'ElasticDocIndex[DocumentType] '
)
- return self._db_config.index_name or default_index_name
+ self._logger.error(err_msg)
+ raise ValueError(err_msg)
+ index_name = self._db_config.index_name or default_index_name
+ return index_name
###############################################
# Inner classes for query builder and configs #
@@ -137,6 +151,10 @@ def __init__(self, outer_instance, **kwargs):
def build(self, *args, **kwargs) -> Any:
"""Build the elastic search query object."""
+ self._outer_instance._logger.debug(
+ 'Building the Elastic Search query object'
+ )
+
if len(self._query['query']) == 0:
del self._query['query']
elif 'knn' in self._query:
@@ -161,6 +179,8 @@ def find(
:param num_candidates: number of candidates
:return: self
"""
+ self._outer_instance._logger.debug('Executing find query')
+
self._outer_instance._validate_search_field(search_field)
if isinstance(query, BaseDoc):
query_vec = BaseDocIndex._get_values_by_column([query], search_field)[0]
@@ -185,6 +205,8 @@ def filter(self, query: Dict[str, Any], limit: int = 10):
:param limit: maximum number of documents to return
:return: self
"""
+ self._outer_instance._logger.debug('Executing filter query')
+
self._query['size'] = limit
self._query['query']['bool']['filter'].append(query)
return self
@@ -197,6 +219,8 @@ def text_search(self, query: str, search_field: str = 'text', limit: int = 10):
:param limit: maximum number of documents to find
:return: self
"""
+ self._outer_instance._logger.debug('Executing text search query')
+
self._outer_instance._validate_search_field(search_field)
self._query['size'] = limit
self._query['query']['bool']['must'].append(
@@ -227,13 +251,7 @@ class DBConfig(BaseDocIndex.DBConfig):
es_config: Dict[str, Any] = field(default_factory=dict)
index_settings: Dict[str, Any] = field(default_factory=dict)
index_mappings: Dict[str, Any] = field(default_factory=dict)
-
- @dataclass
- class RuntimeConfig(BaseDocIndex.RuntimeConfig):
- """Dataclass that contains all "dynamic" configurations of ElasticDocIndex."""
-
default_column_config: Dict[Any, Dict[str, Any]] = field(default_factory=dict)
- chunk_size: int = 500
def __post_init__(self):
self.default_column_config = {
@@ -284,6 +302,7 @@ def __post_init__(self):
def dense_vector_config(self):
"""Get the dense vector config."""
+
config = {
'dims': -1,
'index': True,
@@ -295,6 +314,12 @@ def dense_vector_config(self):
return config
+ @dataclass
+ class RuntimeConfig(BaseDocIndex.RuntimeConfig):
+ """Dataclass that contains all "dynamic" configurations of ElasticDocIndex."""
+
+ chunk_size: int = 500
+
###############################################
# Implementation of abstract methods #
###############################################
@@ -307,8 +332,13 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
:return: the corresponding database column type,
or None if ``python_type`` is not supported.
"""
+ self._logger.debug(f'Mapping Python type {python_type} to database type')
+
for allowed_type in ELASTIC_PY_VEC_TYPES:
- if issubclass(python_type, allowed_type):
+ if safe_issubclass(python_type, allowed_type):
+ self._logger.info(
+ f'Mapped Python type {python_type} to database type "dense_vector"'
+ )
return 'dense_vector'
elastic_py_types = {
@@ -322,11 +352,16 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
dict: 'object',
}
- for type in elastic_py_types.keys():
- if issubclass(python_type, type):
- return elastic_py_types[type]
+ for t in elastic_py_types.keys():
+ if safe_issubclass(python_type, t):
+ self._logger.info(
+ f'Mapped Python type {python_type} to database type "{elastic_py_types[t]}"'
+ )
+ return elastic_py_types[t]
- raise ValueError(f'Unsupported column type for {type(self)}: {python_type}')
+ err_msg = f'Unsupported column type for {type(self)}: {python_type}'
+ self._logger.error(err_msg)
+ raise ValueError(err_msg)
def _index(
self,
@@ -334,6 +369,8 @@ def _index(
refresh: bool = True,
chunk_size: Optional[int] = None,
):
+ self._index_subindex(column_to_data)
+
data = self._transpose_col_value_dict(column_to_data)
requests = []
@@ -343,6 +380,8 @@ def _index(
'_id': row['id'],
}
for col_name, col in self._column_infos.items():
+ if safe_issubclass(col.docarray_type, AnyDocArray):
+ continue
if col.db_type == 'dense_vector' and np.all(row[col_name] == 0):
row[col_name] = row[col_name] + 1.0e-9
if row[col_name] is None:
@@ -353,14 +392,17 @@ def _index(
_, warning_info = self._send_requests(requests, chunk_size)
for info in warning_info:
warnings.warn(str(info))
+ self._logger.warning('Warning: %s', str(info))
if refresh:
+ self._logger.debug('Refreshing the index')
self._refresh(self.index_name)
def num_docs(self) -> int:
"""
Get the number of documents.
"""
+ self._logger.debug('Getting the number of documents in the index')
return self._client.count(index=self.index_name)['count']
def _del_items(
@@ -383,7 +425,7 @@ def _del_items(
self._refresh(self.index_name)
- def _get_items(self, doc_ids: Sequence[str]) -> Sequence[TSchema]:
+ def _get_items(self, doc_ids: Sequence[str]) -> Sequence[Dict[str, Any]]:
accumulated_docs = []
accumulated_docs_id_not_found = []
@@ -417,10 +459,14 @@ def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any:
:param kwargs: keyword arguments to pass to the query
:return: the result of the query
"""
+ self._logger.debug(f'Executing query: {query}')
+
if args or kwargs:
- raise ValueError(
+ err_msg = (
f'args and kwargs not supported for `execute_query` on {type(self)}'
)
+ self._logger.error(err_msg)
+ raise ValueError(err_msg)
resp = self._client.search(index=self.index_name, **query)
docs, scores = self._format_response(resp)
@@ -515,24 +561,34 @@ def _text_search_batched(
)
return _FindResultBatched(documents=list(das), scores=scores)
+ def _filter_by_parent_id(self, id: str) -> List[str]:
+ """Return the ids of the documents in this index whose `parent_id` equals `id`.
+
+ :param id: the parent document id to filter by
+ :return: list of matching document ids, taken from the `id` field of each hit
+ """
+ # only the `id` field is requested; the document source is not returned
+ # NOTE(review): no `size` is passed here — if `_client_search` forwards to the
+ # Elasticsearch search API as-is, the server default page size would cap the
+ # number of ids returned; confirm and pass an explicit size if so
+ resp = self._client_search(
+ query={'term': {'parent_id': id}}, fields=['id'], _source=False
+ )
+ ids = [hit['fields']['id'][0] for hit in resp['hits']['hits']]
+ return ids
+
###############################################
# Helpers #
###############################################
- def _create_index_mapping(self, col: '_ColumnInfo') -> Dict[str, Any]:
+ @classmethod
+ def _create_index_mapping(cls, col: '_ColumnInfo') -> Dict[str, Any]:
"""Create a new HNSW index for a column, and initialize it."""
index = {'type': col.config['type'] if 'type' in col.config else col.db_type}
if col.db_type == 'dense_vector':
- for k in self._index_vector_params:
- index[k] = col.config[k]
+ if cls._index_vector_params is not None:
+ for k in cls._index_vector_params:
+ index[k] = col.config[k]
if col.n_dim:
index['dims'] = col.n_dim
- index['index_options'] = dict(
- (k, col.config[k]) for k in self._index_vector_options
- )
- index['index_options']['type'] = 'hnsw'
+ if cls._index_vector_options is not None:
+ index['index_options'] = dict(
+ (k, col.config[k]) for k in cls._index_vector_options
+ )
+ index['index_options']['type'] = 'hnsw'
return index
def _send_requests(
@@ -568,7 +624,7 @@ def _form_search_body(
num_candidates: Optional[int] = None,
) -> Dict[str, Any]:
if not num_candidates:
- num_candidates = self._runtime_config.default_column_config['dense_vector'][
+ num_candidates = self._db_config.default_column_config['dense_vector'][
'num_candidates'
]
body = {
@@ -615,6 +671,12 @@ def _format_response(self, response: Any) -> Tuple[List[Dict], List[Any]]:
def _refresh(self, index_name: str):
self._client.indices.refresh(index=index_name)
+ def _doc_exists(self, doc_id: str) -> bool:
+ """Check whether a document with the given id is present in the index.
+
+ :param doc_id: the id of the document to look up
+ :return: False for an empty id, otherwise the `found` flag of the
+ single entry in the multi-get response
+ """
+ # an empty id is answered locally, without issuing a request
+ if len(doc_id) == 0:
+ return False
+ ret = self._client_mget([doc_id])
+ return ret["docs"][0]["found"]
+
###############################################
# API Wrappers #
###############################################
diff --git a/docarray/index/backends/elasticv7.py b/docarray/index/backends/elasticv7.py
index 1782e921f62..6ff428b0436 100644
--- a/docarray/index/backends/elasticv7.py
+++ b/docarray/index/backends/elasticv7.py
@@ -1,13 +1,13 @@
import warnings
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
+from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union, Tuple
import numpy as np
from pydantic import parse_obj_as
from docarray import BaseDoc
from docarray.index import ElasticDocIndex
-from docarray.index.abstract import BaseDocIndex, _ColumnInfo
+from docarray.index.abstract import BaseDocIndex
from docarray.typing import AnyTensor
from docarray.typing.tensor.ndarray import NdArray
from docarray.utils.find import _FindResult
@@ -17,6 +17,9 @@
class ElasticV7DocIndex(ElasticDocIndex):
+ _index_vector_params: Optional[Tuple[str]] = ('dims',)
+ _index_vector_options: Optional[Tuple[str]] = None
+
def __init__(self, db_config=None, **kwargs):
"""Initialize ElasticV7DocIndex"""
from elasticsearch import __version__ as __es__version__
@@ -90,12 +93,14 @@ class DBConfig(ElasticDocIndex.DBConfig):
hosts: Union[str, List[str], None] = 'http://localhost:9200' # type: ignore
+ def dense_vector_config(self):
+ """Get the dense vector config: a dict holding only `dims`, set to 128."""
+ return {'dims': 128}
+
@dataclass
class RuntimeConfig(ElasticDocIndex.RuntimeConfig):
"""Dataclass that contains all "dynamic" configurations of ElasticDocIndex."""
- def dense_vector_config(self):
- return {'dims': 128}
+ pass
###############################################
# Implementation of abstract methods #
@@ -128,19 +133,6 @@ def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any:
# Helpers #
###############################################
- # ElasticSearch helpers
- def _create_index_mapping(self, col: '_ColumnInfo') -> Dict[str, Any]:
- """Create a new HNSW index for a column, and initialize it."""
-
- index = col.config.copy()
- if 'type' not in index:
- index['type'] = col.db_type
-
- if col.db_type == 'dense_vector' and col.n_dim:
- index['dims'] = col.n_dim
-
- return index
-
def _form_search_body(self, query: np.ndarray, limit: int, search_field: str = '') -> Dict[str, Any]: # type: ignore
body = {
'size': limit,
diff --git a/docarray/index/backends/epsilla.py b/docarray/index/backends/epsilla.py
new file mode 100644
index 00000000000..0392e9d010e
--- /dev/null
+++ b/docarray/index/backends/epsilla.py
@@ -0,0 +1,531 @@
+import copy
+from dataclasses import dataclass, field
+from http import HTTPStatus
+from typing import (
+ Any,
+ Dict,
+ Generator,
+ Generic,
+ List,
+ Optional,
+ Sequence,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
+
+import numpy as np
+from pyepsilla import cloud, vectordb
+
+from docarray import BaseDoc, DocList
+from docarray.index.abstract import (
+ BaseDocIndex,
+ _FindResultBatched,
+ _raise_not_composable,
+ _raise_not_supported,
+)
+from docarray.typing import ID, NdArray
+from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from docarray.utils._internal._typing import safe_issubclass
+from docarray.utils.find import _FindResult
+
+TSchema = TypeVar('TSchema', bound=BaseDoc)
+
+
+class EpsillaDocumentIndex(BaseDocIndex, Generic[TSchema]):
+ def __init__(self, db_config=None, **kwargs):
+     """Validate the configuration and connect to a self-hosted or cloud Epsilla database.
+
+     :param db_config: optional `DBConfig`; forwarded together with `kwargs`
+         to the base class initializer.
+     :raises ValueError: if `validate_config` or `_validate_column_info`
+         rejects the configuration or the schema.
+     :raises IOError: if the database cannot be loaded or its tables cannot
+         be listed.
+     """
+     # will set _db_config from args / kwargs
+     super().__init__(db_config=db_config, **kwargs)
+
+     self._db_config: EpsillaDocumentIndex.DBConfig = cast(
+         EpsillaDocumentIndex.DBConfig, self._db_config
+     )
+     self._db_config.validate_config()
+     self._validate_column_info()
+
+     # an unset (or empty) table name falls back to the schema's class name
+     self._table_name = self._db_config.table_name or self._schema.__name__
+
+     if not self._db_config.is_self_hosted:
+         self._client = cloud.Client(
+             project_id=self._db_config.cloud_project_id,
+             api_key=self._db_config.api_key,
+         )
+         self._db = self._client.vectordb(self._db_config.cloud_db_id)
+
+         status_code, response = self._db.list_tables()
+         if status_code != HTTPStatus.OK:
+             raise IOError(
+                 f"Failed to list tables. "
+                 f"Error code: {status_code}. Error message: {response}."
+             )
+
+         # Epsilla cloud requires table to be created in the web UI before inserting data
+         # It does not support creating tables from Python client yet.
+         return
+
+     self._db = vectordb.Client(
+         protocol=self._db_config.protocol,
+         host=self._db_config.host,
+         port=self._db_config.port,
+     )
+     status_code, response = self._db.load_db(
+         db_name=self._db_config.db_name,
+         db_path=self._db_config.db_path,
+     )
+
+     # CONFLICT is tolerated and only logged; any other non-OK status is fatal
+     if status_code == HTTPStatus.CONFLICT:
+         self._logger.info(f'{self._db_config.db_name} already loaded.')
+     elif status_code != HTTPStatus.OK:
+         raise IOError(
+             f"Failed to load database {self._db_config.db_name}. "
+             f"Error code: {status_code}. Error message: {response}."
+         )
+     self._db.use_db(self._db_config.db_name)
+
+     status_code, response = self._db.list_tables()
+     if status_code != HTTPStatus.OK:
+         raise IOError(
+             f"Failed to list tables. "
+             f"Error code: {status_code}. Error message: {response}."
+         )
+
+     # create the table on first use
+     if self._table_name not in response["result"]:
+         self._create_table_self_hosted()
+
+ def _validate_column_info(self):
+     """Ensure the schema declares exactly one embedding column with a known dimension.
+
+     A column counts as an embedding column when its type is a subclass of
+     `list`, `np.ndarray` or `AbstractTensor` and its config has a truthy
+     `is_embedding` entry.
+
+     :raises ValueError: if an embedding column has neither `n_dim` nor a
+         `dim` config entry, if there is no embedding column, or if there is
+         more than one.
+     """
+     vector_types = (list, np.ndarray, AbstractTensor)
+     embedding_columns = []
+     for column in self._column_infos.values():
+         is_vector = any(
+             safe_issubclass(column.docarray_type, vector_type)
+             for vector_type in vector_types
+         )
+         if not (is_vector and column.config.get('is_embedding', False)):
+             continue
+
+         # check that dimension is present
+         if column.n_dim is None and column.config.get('dim', None) is None:
+             raise ValueError("The dimension information is missing")
+         embedding_columns.append(column.docarray_type)
+
+     if not embedding_columns:
+         raise ValueError(
+             "Unable to find any vector columns. Please make sure that at least one "
+             "column is of a vector type with the is_embedding=True attribute specified."
+         )
+     if len(embedding_columns) > 1:
+         raise ValueError("Specifying multiple vector fields is not supported.")
+
+ def _create_table_self_hosted(self):
+     """Create the table in the self-hosted database from `_column_infos`.
+
+     Columns with a known dimension (`n_dim`, or a `dim` config entry) are
+     declared with a `dimensions` entry; all other columns carry a
+     `primaryKey` flag instead.
+
+     :raises IOError: if the create-table request does not return HTTP 200.
+     """
+     id_columns = [
+         name
+         for name, info in self._column_infos.items()
+         if info.docarray_type == ID
+     ]
+
+     # when there is a nested schema, we may have multiple "ID" fields. The number
+     # of "__" separators in the name tells how deeply nested a field is; keep the
+     # least nested one (the first such column on ties)
+     if len(id_columns) > 1:
+         id_columns = [min(id_columns, key=lambda name: name.count("__"))]
+
+     table_fields = []
+     for name, info in self._column_infos.items():
+         field_spec = {'name': name, 'dataType': info.db_type}
+         dim = info.n_dim if info.n_dim is not None else info.config.get('dim', None)
+         if dim is None:
+             field_spec['primaryKey'] = name in id_columns
+         else:
+             field_spec['dimensions'] = dim
+         table_fields.append(field_spec)
+
+     status_code, response = self._db.create_table(
+         table_name=self._table_name,
+         table_fields=table_fields,
+     )
+     if status_code != HTTPStatus.OK:
+         raise IOError(
+             f"Failed to create table {self._table_name}. "
+             f"Error code: {status_code}. Error message: {response}."
+         )
+
+ @dataclass
+ class Query:
+ """Dataclass describing a query."""
+
+ # name of the column the vector search runs against; QueryBuilder leaves it None until `find` is called
+ vector_field: Optional[str]
+ # query vector; QueryBuilder.build passes None when no vector query was added
+ vector_query: Optional[NdArray]
+ # filter expression, handed to the find / filter helpers unchanged
+ filter: Optional[str]
+ # limit passed on to the underlying find / filter call
+ limit: int
+
+ class QueryBuilder(BaseDocIndex.QueryBuilder):
+     """Collects vector queries and a filter and turns them into a `Query`.
+
+     `find` and `filter` do not modify the builder they are called on; each
+     returns a new `QueryBuilder` carrying the updated state.
+     """
+
+     def __init__(
+         self,
+         vector_search_field: Optional[str] = None,
+         vector_queries: Optional[List[NdArray]] = None,
+         filter: Optional[str] = None,
+     ):
+         self._vector_search_field: Optional[str] = vector_search_field
+         self._vector_queries: List[NdArray] = vector_queries or []
+         self._filter: Optional[str] = filter
+
+     def find(self, query: NdArray, search_field: str = ''):
+         """Return a new builder with `query` added to the vector queries.
+
+         :raises ValueError: if an earlier `find` used a different search field.
+         """
+         earlier_field = self._vector_search_field
+         if earlier_field and earlier_field != search_field:
+             raise ValueError(
+                 f'Trying to call .find for search_field = {search_field}, but '
+                 f'previously {earlier_field} was used. Only a single '
+                 f'field might be used in chained calls.'
+             )
+         return EpsillaDocumentIndex.QueryBuilder(
+             vector_search_field=search_field,
+             vector_queries=[*self._vector_queries, query],
+             filter=self._filter,
+         )
+
+     def filter(self, filter_query: str):  # type: ignore[override]
+         """Return a new builder whose filter is `filter_query` (replacing any earlier filter)."""
+         return EpsillaDocumentIndex.QueryBuilder(
+             vector_search_field=self._vector_search_field,
+             vector_queries=self._vector_queries,
+             filter=filter_query,
+         )
+
+     def build(self, limit: int) -> Any:
+         """Produce the `Query` for the collected state.
+
+         If there are multiple vector queries applied, they are averaged so
+         that semantic search is performed on a single vector instead.
+         """
+         averaged = (
+             np.average(self._vector_queries, axis=0)
+             if self._vector_queries
+             else None
+         )
+         return EpsillaDocumentIndex.Query(
+             vector_field=self._vector_search_field,
+             vector_query=averaged,
+             filter=self._filter,
+             limit=limit,
+         )
+
+     find_batched = _raise_not_composable('find_batched')
+     filter_batched = _raise_not_composable('filter_batched')
+     text_search = _raise_not_supported('text_search')
+     text_search_batched = _raise_not_supported('text_search_batched')
+
+ @dataclass
+ class DBConfig(BaseDocIndex.DBConfig):
+     """Static configuration for EpsillaDocumentIndex"""
+
+     # when left unset, the index falls back to the schema type name
+     table_name: Optional[str] = None
+
+     # selects between the self-hosted (True) and the cloud (False) settings below
+     is_self_hosted: bool = False
+
+     # settings required by the self-hosted version
+     protocol: Optional[str] = None
+     host: Optional[str] = None
+     port: Optional[int] = 8888
+     db_path: Optional[str] = None
+     db_name: Optional[str] = None
+
+     # settings required by the cloud version
+     cloud_project_id: Optional[str] = None
+     cloud_db_id: Optional[str] = None
+     api_key: Optional[str] = None
+
+     # one empty per-column config for every listed column type
+     default_column_config: Dict[Any, Dict[str, Any]] = field(
+         default_factory=lambda: {
+             type_name: {}
+             for type_name in (
+                 'TINYINT',
+                 'SMALLINT',
+                 'INT',
+                 'BIGINT',
+                 'FLOAT',
+                 'DOUBLE',
+                 'STRING',
+                 'BOOL',
+                 'JSON',
+                 'VECTOR_FLOAT',
+             )
+         }
+     )
+
+     def validate_config(self):
+         """Run the validation that matches the selected deployment mode."""
+         validator = (
+             self.validate_self_hosted_config
+             if self.is_self_hosted
+             else self.validate_cloud_config
+         )
+         validator()
+
+     def validate_self_hosted_config(self):
+         """Raise ValueError naming every self-hosted setting that is None."""
+         required = ("protocol", "host", "port", "db_path", "db_name")
+         absent = [name for name in required if getattr(self, name, None) is None]
+         if absent:
+             raise ValueError(
+                 f"Missing required attributes for self-hosted version: {', '.join(absent)}"
+             )
+
+     def validate_cloud_config(self):
+         """Raise ValueError naming every cloud setting that is None."""
+         required = ("cloud_project_id", "cloud_db_id", "api_key")
+         absent = [name for name in required if getattr(self, name, None) is None]
+         if absent:
+             raise ValueError(
+                 f"Missing required attributes for cloud version: {', '.join(absent)}"
+             )
+
+ @dataclass
+ class RuntimeConfig(BaseDocIndex.RuntimeConfig):
+ """Runtime configuration for EpsillaDocumentIndex; adds no fields of its own."""
+
+ # No dynamic config used
+ pass
+
+ @property
+ def collection_name(self):
+     """Name of the Epsilla table backing this index.
+
+     Resolved the same way `__init__` resolves the table name: the
+     configured `table_name` when it is set, otherwise the schema's class
+     name. Returning the raw config value would yield None whenever the
+     name was left at its default.
+     """
+     return self._db_config.table_name or self._schema.__name__
+
+ @property
+ def index_name(self):
+ """Alias for `collection_name`."""
+ return self.collection_name
+
+ def python_type_to_db_type(self, python_type: Type) -> str:
+     """Map a Python type to the name of the Epsilla column type used to store it.
+
+     :param python_type: the type of a schema field.
+     :return: `'VECTOR_FLOAT'` for subclasses of `list`, `np.ndarray` or
+         `AbstractTensor`; otherwise the entry of the first matching type in
+         the scalar type table.
+     :raises ValueError: if no mapping exists for `python_type`.
+     """
+     # AbstractTensor does not have n_dims, which is required by Epsilla
+     # Use NdArray instead
+     for allowed_type in [list, np.ndarray, AbstractTensor]:
+         if safe_issubclass(python_type, allowed_type):
+             return 'VECTOR_FLOAT'
+
+     # The first matching entry wins, so a subclass must be listed before its
+     # base class: bool is a subclass of int and would otherwise be stored as
+     # BIGINT, leaving the BOOL entry unreachable.
+     py_type_map = {
+         ID: 'STRING',
+         str: 'STRING',
+         bytes: 'STRING',
+         bool: 'BOOL',
+         int: 'BIGINT',
+         float: 'FLOAT',
+     }
+
+     for py_type, epsilla_type in py_type_map.items():
+         if safe_issubclass(python_type, py_type):
+             return epsilla_type
+
+     raise ValueError(f'Unsupported column type for {type(self)}: {python_type}')
+
+ def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
+     """Insert the given column data into the Epsilla table.
+
+     :param column_to_data: mapping from column name to a generator of that
+         column's values.
+     :raises IOError: if the insert request does not return HTTP 200.
+     """
+     self._index_subindex(column_to_data)
+
+     # array values are converted to plain lists; everything else is sent as-is
+     records = [
+         {
+             name: (
+                 value.tolist()
+                 if isinstance(value, (NdArray, np.ndarray))
+                 else value
+             )
+             for name, value in row.items()
+         }
+         for row in self._transpose_col_value_dict(column_to_data)
+     ]
+
+     status_code, response = self._db.insert(
+         table_name=self._table_name, records=records
+     )
+
+     if status_code != HTTPStatus.OK:
+         raise IOError(
+             f"Failed to insert documents. "
+             f"Error code: {status_code}. Error message: {response}."
+         )
+
+ def num_docs(self) -> int:
+ """Not implemented for this index; always raises NotImplementedError."""
+ raise NotImplementedError
+
+ @property
+ def _is_index_empty(self) -> bool:
+ """
+ Report whether the index is empty. This override never inspects the index.
+ :return: always False.
+ """
+ # Overriding this method to always return False because Epsilla does not have a count API for num_docs
+ return False
+
+ def _del_items(self, doc_ids: Sequence[str]):
+     """Delete the documents with the given ids from the table.
+
+     :param doc_ids: ids of the documents to delete; sent as primary keys.
+     :return: the `message` entry of the Epsilla response.
+     :raises IOError: if the delete request does not return HTTP 200.
+     """
+     status_code, response = self._db.delete(
+         table_name=self._table_name,
+         primary_keys=list(doc_ids),
+     )
+     if status_code != HTTPStatus.OK:
+         # the message names the operation that actually failed (delete, not get)
+         raise IOError(
+             f"Failed to delete documents with ids {doc_ids}. "
+             f"Error code: {status_code}. Error message: {response}."
+         )
+     return response['message']
+
+ def _get_items(
+     self, doc_ids: Sequence[str]
+ ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]:
+     """Fetch the records with the given ids from the table.
+
+     :param doc_ids: ids of the documents to fetch; sent as primary keys.
+     :return: the `result` entry of the Epsilla response.
+     :raises IOError: if the get request does not return HTTP 200.
+     """
+     primary_keys = list(doc_ids)
+     status_code, response = self._db.get(
+         table_name=self._table_name,
+         primary_keys=primary_keys,
+     )
+     if status_code == HTTPStatus.OK:
+         return response['result']
+
+     raise IOError(
+         f"Failed to get documents with ids {doc_ids}. "
+         f"Error code: {status_code}. Error message: {response}."
+     )
+
+ def execute_query(self, query: Query) -> DocList:
+     """Run a query produced by `QueryBuilder.build`.
+
+     A query without a vector is answered by a filter-only lookup; otherwise
+     a filtered vector search is executed for the single query vector.
+
+     :param query: the query to execute.
+     :return: the matching documents, converted with `_dict_list_to_docarray`.
+     """
+     if query.vector_query is None:
+         filtered = self._filter(
+             filter_query=query.filter,
+             limit=query.limit,
+         )
+         return self._dict_list_to_docarray(filtered)
+
+     # the batched helper expects a leading batch axis
+     batched = self._find_with_filter_batched(
+         queries=np.expand_dims(query.vector_query, axis=0),
+         filter=query.filter,
+         limit=query.limit,
+         search_field=query.vector_field,
+     )
+     return self._dict_list_to_docarray(batched.documents[0])
+
+ def _doc_exists(self, doc_id: str) -> bool:
+ """Return True if looking up `doc_id` through `_get_items` yields at least one record."""
+ return len(self._get_items([doc_id])) > 0
+
+ def _find(
+     self,
+     query: np.ndarray,
+     limit: int,
+     search_field: str = '',
+ ) -> _FindResult:
+     """Run a vector search for a single query by delegating to `_find_batched`.
+
+     :param query: the query vector.
+     :param limit: limit forwarded to the batched search.
+     :param search_field: name of the column to search.
+     :return: documents and scores of the single query.
+     """
+     # add a leading batch axis, then unwrap the only entry of the batched result
+     batched = self._find_batched(
+         queries=np.expand_dims(query, axis=0),
+         limit=limit,
+         search_field=search_field,
+     )
+     return _FindResult(documents=batched.documents[0], scores=batched.scores[0])
+
+ def _find_batched(
+ self,
+ queries: np.ndarray,
+ limit: int,
+ search_field: str = '',
+ ) -> _FindResultBatched:
+ """Run a vector search per query by delegating to `_find_with_filter_batched` without a filter."""
+ return self._find_with_filter_batched(
+ queries=queries, limit=limit, search_field=search_field
+ )
+
+ def _find_with_filter_batched(
+     self,
+     queries: np.ndarray,
+     limit: int,
+     search_field: str,
+     filter: Optional[str] = None,
+ ) -> _FindResultBatched:
+     """Run one Epsilla vector query per entry of `queries`, optionally filtered.
+
+     :param queries: the query vectors; iterated along the first axis.
+     :param limit: limit sent with every query.
+     :param search_field: name of the column to search; must be non-empty.
+     :param filter: optional filter expression; None is sent as an empty string.
+     :return: per query, the result records (without their `@distance` entry)
+         and the distances as scores.
+     :raises ValueError: if `search_field` is empty or None.
+     :raises IOError: if a query request does not return HTTP 200.
+     """
+     # `execute_query` forwards an Optional[str] here, so reject None as well as ''
+     if not search_field:
+         raise ValueError(
+             'EpsillaDocumentIndex requires a search_field to be specified.'
+         )
+
+     documents_per_query = []
+     scores_per_query = []
+     for query in queries:
+         status_code, response = self._db.query(
+             table_name=self._table_name,
+             query_field=search_field,
+             limit=limit,
+             filter=filter if filter is not None else '',
+             query_vector=query.tolist(),
+             with_distance=True,
+         )
+
+         if status_code != HTTPStatus.OK:
+             raise IOError(
+                 f"Failed to find documents with query {query}. "
+                 f"Error code: {status_code}. Error message: {response}."
+             )
+
+         results = response['result']
+         scores_per_query.append(
+             NdArray._docarray_from_native(
+                 np.array([result['@distance'] for result in results])
+             )
+         )
+         # the distance is reported as the score, so drop it from a copy of each record
+         documents_per_query.append(
+             [
+                 {key: value for key, value in result.items() if key != '@distance'}
+                 for result in results
+             ]
+         )
+
+     return _FindResultBatched(
+         documents=documents_per_query,
+         scores=scores_per_query,
+     )
+
+ def _filter(
+     self,
+     filter_query: str,
+     limit: int,
+ ) -> Union[DocList, List[Dict]]:
+     """Run a single filter query by delegating to `_filter_batched`.
+
+     :param filter_query: the filter expression.
+     :param limit: limit forwarded to the batched call.
+     :return: the result for the single filter query.
+     """
+     # wrap the query in a one-element batch and unwrap its only result
+     return self._filter_batched(filter_queries=[filter_query], limit=limit)[0]
+
+ def _filter_batched(
+     self,
+     filter_queries: Sequence[str],
+     limit: int,
+ ) -> Union[List[DocList], List[List[Dict]]]:
+     """Run one Epsilla get request per filter expression.
+
+     :param filter_queries: the filter expressions, one request each. This is
+         a sequence of filters; a bare string would be iterated character by
+         character.
+     :param limit: limit sent with every request.
+     :return: the `result` entry of each response, in the order of
+         `filter_queries`.
+     :raises IOError: if a request does not return HTTP 200.
+     """
+     responses = []
+     for filter_query in filter_queries:
+         status_code, response = self._db.get(
+             table_name=self._table_name,
+             limit=limit,
+             filter=filter_query,
+         )
+
+         if status_code != HTTPStatus.OK:
+             raise IOError(
+                 f"Failed to find documents with filter {filter_query}. "
+                 f"Error code: {status_code}. Error message: {response}."
+             )
+
+         responses.append(response['result'])
+
+     return responses
+
+ def _text_search(
+ self,
+ query: str,
+ limit: int,
+ search_field: str = '',
+ ) -> _FindResult:
+ """Text search is not supported by this index; always raises NotImplementedError."""
+ raise NotImplementedError(f'{type(self)} does not support text search.')
+
+ def _text_search_batched(
+ self,
+ queries: Sequence[str],
+ limit: int,
+ search_field: str = '',
+ ) -> _FindResultBatched:
+ """Batched text search is not supported by this index; always raises NotImplementedError."""
+ raise NotImplementedError(f'{type(self)} does not support text search.')
diff --git a/docarray/index/backends/helper.py b/docarray/index/backends/helper.py
index e8739fdfcb4..5582dbba866 100644
--- a/docarray/index/backends/helper.py
+++ b/docarray/index/backends/helper.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Tuple, Type, cast
+from typing import Any, Dict, List, Tuple, Type, cast, Set
from docarray import BaseDoc, DocList
from docarray.index.abstract import BaseDocIndex
@@ -20,17 +20,64 @@ def inner(self, *args, **kwargs):
return inner
+def _collect_query_required_args(method_name: str, required_args: Set[str] = None):
+    """
+    Returns a function that ensures required keyword arguments are provided.
+
+    :param method_name: The name of the method for which the required arguments are being checked.
+    :type method_name: str
+    :param required_args: A set containing the names of required keyword arguments. Defaults to None,
+        which means no argument is required.
+    :type required_args: Optional[Set[str]]
+    :return: A function that records the call as `(method_name, kwargs)` on a new
+        query builder of the same type, after checking its arguments.
+        Raises ValueError if positional arguments are provided.
+        Raises ValueError if any required keyword argument is missing.
+    :rtype: Callable
+    """
+
+    if required_args is None:
+        required_args = set()
+
+    def inner(self, *args, **kwargs):
+        if args:
+            raise ValueError(
+                f"Positional arguments are not supported for "
+                f"`{type(self)}.{method_name}`. "
+                f"Use keyword arguments instead."
+            )
+
+        missing_args = required_args - set(kwargs.keys())
+        if missing_args:
+            # sets are unordered; sort so the message is the same on every run
+            raise ValueError(
+                f"`{type(self)}.{method_name}` is missing required argument(s): {', '.join(sorted(missing_args))}"
+            )
+
+        # build a new builder instead of mutating the current one
+        updated_query = self._queries + [(method_name, kwargs)]
+        return type(self)(updated_query)
+
+    return inner
+
+
def _execute_find_and_filter_query(
- doc_index: BaseDocIndex, query: List[Tuple[str, Dict]]
+ doc_index: BaseDocIndex, query: List[Tuple[str, Dict]], reverse_order: bool = False
) -> FindResult:
"""
Executes all find calls from query first using `doc_index.find()`,
and filtering queries after that using DocArray's `filter_docs()`.
Text search is not supported.
+
+ :param doc_index: Document index instance.
+ Either InMemoryExactNNIndex or HnswDocumentIndex.
+ :param query: Dictionary containing search and filtering configuration.
+ :param reverse_order: Flag indicating whether to sort in descending order.
+ If set to False (default), the sorting will be in ascending order.
+ This option is necessary because, depending on the index, lower scores
+ can correspond to better matches, and vice versa.
+ :return: Sorted documents and their corresponding scores.
"""
docs_found = DocList.__class_getitem__(cast(Type[BaseDoc], doc_index._schema))([])
filter_conditions = []
+ filter_limit = None
doc_to_score: Dict[BaseDoc, Any] = {}
for op, op_kwargs in query:
if op == 'find':
@@ -39,6 +86,7 @@ def _execute_find_and_filter_query(
doc_to_score.update(zip(docs.__getattribute__('id'), scores))
elif op == 'filter':
filter_conditions.append(op_kwargs['filter_query'])
+ filter_limit = op_kwargs.get('limit')
else:
raise ValueError(f'Query operation is not supported: {op}')
@@ -48,11 +96,14 @@ def _execute_find_and_filter_query(
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], doc_index._schema))
docs_filtered = docs_cls(filter_docs(docs_filtered, cond))
+ if filter_limit:
+ docs_filtered = docs_filtered[:filter_limit]
+
doc_index._logger.debug(f'{len(docs_filtered)} results found')
docs_and_scores = zip(
docs_filtered, (doc_to_score[doc.id] for doc in docs_filtered)
)
- docs_sorted = sorted(docs_and_scores, key=lambda x: x[1])
+ docs_sorted = sorted(docs_and_scores, key=lambda x: x[1], reverse=reverse_order)
out_docs, out_scores = zip(*docs_sorted)
return FindResult(documents=out_docs, scores=out_scores)
diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py
index 756f80f78e4..e542711e0ca 100644
--- a/docarray/index/backends/hnswlib.py
+++ b/docarray/index/backends/hnswlib.py
@@ -1,16 +1,20 @@
+import glob
import hashlib
import os
import sqlite3
+from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Dict,
+ Generator,
Generic,
List,
Optional,
Sequence,
+ Set,
Tuple,
Type,
TypeVar,
@@ -21,21 +25,21 @@
import numpy as np
from docarray import BaseDoc, DocList
+from docarray.array.any_array import AnyDocArray
from docarray.index.abstract import (
BaseDocIndex,
_ColumnInfo,
_raise_not_composable,
_raise_not_supported,
)
-from docarray.index.backends.helper import (
- _collect_query_args,
- _execute_find_and_filter_query,
-)
+from docarray.index.backends.helper import _collect_query_args
from docarray.proto import DocProto
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.ndarray import NdArray
+from docarray.utils._internal._typing import safe_issubclass
from docarray.utils._internal.misc import import_library, is_np_int
-from docarray.utils.find import _FindResult, _FindResultBatched
+from docarray.utils.filter import filter_docs
+from docarray.utils.find import FindResult, _FindResult, _FindResultBatched
if TYPE_CHECKING:
import hnswlib
@@ -59,19 +63,32 @@
HNSWLIB_PY_VEC_TYPES.append(tf.Tensor)
HNSWLIB_PY_VEC_TYPES.append(TensorFlowTensor)
-
TSchema = TypeVar('TSchema', bound=BaseDoc)
T = TypeVar('T', bound='HnswDocumentIndex')
+OPERATOR_MAPPING = {
+ '$eq': '=',
+ '$neq': '!=',
+ '$lt': '<',
+ '$lte': '<=',
+ '$gt': '>',
+ '$gte': '>=',
+}
+
class HnswDocumentIndex(BaseDocIndex, Generic[TSchema]):
def __init__(self, db_config=None, **kwargs):
"""Initialize HnswDocumentIndex"""
+ if db_config is not None and getattr(db_config, 'index_name'):
+ db_config.work_dir = db_config.index_name.replace("__", "/")
+
super().__init__(db_config=db_config, **kwargs)
self._db_config = cast(HnswDocumentIndex.DBConfig, self._db_config)
self._work_dir = self._db_config.work_dir
self._logger.debug(f'Working directory set to {self._work_dir}')
- load_existing = os.path.exists(self._work_dir) and os.listdir(self._work_dir)
+ load_existing = os.path.exists(self._work_dir) and glob.glob(
+ f'{self._work_dir}/*.bin'
+ )
Path(self._work_dir).mkdir(parents=True, exist_ok=True)
# HNSWLib setup
@@ -89,8 +106,14 @@ def __init__(self, db_config=None, **kwargs):
if col.config
}
self._hnsw_indices = {}
+ sub_docs_exist = False
+ cosine_metric_index_exist = False
for col_name, col in self._column_infos.items():
- if not col.config:
+ if '__' in col_name:
+ sub_docs_exist = True
+ if safe_issubclass(col.docarray_type, AnyDocArray):
+ continue
+ if not col.config or 'dim' not in col.config:
# non-tensor type; don't create an index
continue
if not load_existing and (
@@ -107,17 +130,35 @@ def __init__(self, db_config=None, **kwargs):
else:
self._hnsw_indices[col_name] = self._create_index(col_name, col)
self._logger.info(f'Created a new index for column `{col_name}`')
+ if self._hnsw_indices[col_name].space == 'cosine':
+ cosine_metric_index_exist = True
+ self._apply_optim_no_embedding_in_sqlite = (
+ not sub_docs_exist and not cosine_metric_index_exist
+ ) # optimization consisting in not serializing embeddings to SQLite because they are expensive to send and they can be reconstructed from the HNSW index itself.
# SQLite setup
self._sqlite_db_path = os.path.join(self._work_dir, 'docs_sqlite.db')
self._logger.debug(f'DB path set to {self._sqlite_db_path}')
self._sqlite_conn = sqlite3.connect(self._sqlite_db_path)
self._logger.info('Connection to DB has been established')
self._sqlite_cursor = self._sqlite_conn.cursor()
+ self._column_names: List[str] = []
self._create_docs_table()
self._sqlite_conn.commit()
+ self._num_docs = 0 # recompute again when needed
self._logger.info(f'{self.__class__.__name__} has been initialized')
+ @property
+ def index_name(self):
+ """Name of this index, which is the configured working directory."""
+ return self._db_config.work_dir # type: ignore
+
+ @property
+ def out_schema(self) -> Type[BaseDoc]:
+     """Schema of the documents this index returns.
+
+     For a subindex this is `_ori_schema`; otherwise it is `_schema`.
+     """
+     return (
+         self._ori_schema
+         if self._is_subindex
+         else cast(Type[BaseDoc], self._schema)
+     )
+
###############################################
# Inner classes for query builder and configs #
###############################################
@@ -140,31 +181,33 @@ def build(self, *args, **kwargs) -> Any:
@dataclass
class DBConfig(BaseDocIndex.DBConfig):
- """Dataclass that contains all "static" configurations of WeaviateDocumentIndex."""
+ """Dataclass that contains all "static" configurations of HnswDocumentIndex."""
work_dir: str = '.'
+ default_column_config: Dict[Type, Dict[str, Any]] = field(
+ default_factory=lambda: defaultdict(
+ dict,
+ {
+ np.ndarray: {
+ 'dim': -1,
+ 'index': True, # if False, don't index at all
+ 'space': 'l2', # 'l2', 'ip', 'cosine'
+ 'max_elements': 1024,
+ 'ef_construction': 200,
+ 'ef': 10,
+ 'M': 16,
+ 'allow_replace_deleted': True,
+ 'num_threads': 1,
+ },
+ },
+ )
+ )
@dataclass
class RuntimeConfig(BaseDocIndex.RuntimeConfig):
- """Dataclass that contains all "dynamic" configurations of WeaviateDocumentIndex."""
+ """Dataclass that contains all "dynamic" configurations of HnswDocumentIndex."""
- default_column_config: Dict[Type, Dict[str, Any]] = field(
- default_factory=lambda: {
- np.ndarray: {
- 'dim': -1,
- 'index': True, # if False, don't index at all
- 'space': 'l2', # 'l2', 'ip', 'cosine'
- 'max_elements': 1024,
- 'ef_construction': 200,
- 'ef': 10,
- 'M': 16,
- 'allow_replace_deleted': True,
- 'num_threads': 1,
- },
- # `None` is not a Type, but we allow it here anyway
- None: {}, # type: ignore
- }
- )
+ pass
###############################################
# Implementation of abstract methods #
@@ -179,14 +222,47 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
or None if ``python_type`` is not supported.
"""
for allowed_type in HNSWLIB_PY_VEC_TYPES:
- if issubclass(python_type, allowed_type):
+ if safe_issubclass(python_type, allowed_type):
return np.ndarray
+ # types allowed for filtering
+ type_map = {
+ int: 'INTEGER',
+ float: 'REAL',
+ str: 'TEXT',
+ }
+ for py_type, sqlite_type in type_map.items():
+ if safe_issubclass(python_type, py_type):
+ return sqlite_type
+
return None # all types allowed, but no db type needed
- def _index(self, column_data_dic, **kwargs):
+ def _index(
+ self,
+ column_to_data: Dict[str, Generator[Any, None, None]],
+ docs_validated: Sequence[BaseDoc] = [],
+ ):
+ self._index_subindex(column_to_data)
+
# not needed, we implement `index` directly
- ...
+ hashed_ids = tuple(self._to_hashed_id(doc.id) for doc in docs_validated)
+ # indexing into HNSWLib and SQLite sequentially
+ # could be improved by processing in parallel
+ for col_name, index in self._hnsw_indices.items():
+ data = column_to_data[col_name]
+ data_np = [self._to_numpy(arr) for arr in data]
+ data_stacked = np.stack(data_np)
+ num_docs_to_index = len(hashed_ids)
+ index_max_elements = index.get_max_elements()
+ current_elements = index.get_current_count()
+ if current_elements + num_docs_to_index > index_max_elements:
+ new_capacity = max(
+ index_max_elements, current_elements + num_docs_to_index
+ )
+ self._logger.info(f'Resizing the index to {new_capacity}')
+ index.resize_index(new_capacity)
+ index.add_items(data_stacked, ids=hashed_ids)
+ index.save_index(self._hnsw_locations[col_name])
def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
"""Index Documents into the index.
@@ -206,19 +282,12 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
n_docs = 1 if isinstance(docs, BaseDoc) else len(docs)
self._logger.debug(f'Indexing {n_docs} documents')
docs_validated = self._validate_docs(docs)
+ self._update_subindex_data(docs_validated)
data_by_columns = self._get_col_value_dict(docs_validated)
- hashed_ids = tuple(self._to_hashed_id(doc.id) for doc in docs_validated)
- # indexing into HNSWLib and SQLite sequentially
- # could be improved by processing in parallel
- for col_name, index in self._hnsw_indices.items():
- data = data_by_columns[col_name]
- data_np = [self._to_numpy(arr) for arr in data]
- data_stacked = np.stack(data_np)
- index.add_items(data_stacked, ids=hashed_ids)
- index.save_index(self._hnsw_locations[col_name])
-
+ self._index(data_by_columns, docs_validated, **kwargs)
self._send_docs_to_sqlite(docs_validated)
self._sqlite_conn.commit()
+ self._num_docs = 0 # recompute again when needed
def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
"""
@@ -239,11 +308,8 @@ def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
raise ValueError(
f'args and kwargs not supported for `execute_query` on {type(self)}'
)
- find_res = _execute_find_and_filter_query(
- doc_index=self,
- query=query,
- )
- return find_res
+
+ return self._execute_find_and_filter_query(query)
def _find_batched(
self,
@@ -251,15 +317,9 @@ def _find_batched(
limit: int,
search_field: str = '',
) -> _FindResultBatched:
- index = self._hnsw_indices[search_field]
- labels, distances = index.knn_query(queries, k=limit)
- result_das = [
- self._get_docs_sqlite_hashed_id(
- ids_per_query.tolist(),
- )
- for ids_per_query in labels
- ]
- return _FindResultBatched(documents=result_das, scores=distances)
+ return self._search_and_filter(
+ queries=queries, limit=limit, search_field=search_field
+ )
def _find(
self, query: np.ndarray, limit: int, search_field: str = ''
@@ -277,11 +337,20 @@ def _filter(
filter_query: Any,
limit: int,
) -> DocList:
- raise NotImplementedError(
- f'{type(self)} does not support filter-only queries.'
- f' To perform post-filtering on a query, use'
- f' `build_query()` and `execute_query()`.'
- )
+ rows = self._execute_filter(filter_query=filter_query, limit=limit)
+ hashed_ids = [doc_id for doc_id, _ in rows]
+ embeddings: OrderedDict[str, list] = OrderedDict()
+ for col_name, index in self._hnsw_indices.items():
+ embeddings[col_name] = index.get_items(hashed_ids)
+
+ docs = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema))()
+ for i, row in enumerate(rows):
+ reconstruct_embeddings = {}
+ for col_name in embeddings.keys():
+ reconstruct_embeddings[col_name] = embeddings[col_name][i]
+ docs.append(self._doc_from_bytes(row[1], reconstruct_embeddings))
+
+ return docs
def _filter_batched(
self,
@@ -312,6 +381,15 @@ def _text_search_batched(
def _del_items(self, doc_ids: Sequence[str]):
# delete from the indices
+ for field_name, type_, _ in self._flatten_schema(
+ cast(Type[BaseDoc], self._schema)
+ ):
+ if safe_issubclass(type_, AnyDocArray):
+ for id in doc_ids:
+ doc = self.__getitem__(id)
+ sub_ids = [sub_doc.id for sub_doc in getattr(doc, field_name)]
+ del self._subindices[field_name][sub_ids]
+
try:
for doc_id in doc_ids:
id_ = self._to_hashed_id(doc_id)
@@ -322,18 +400,34 @@ def _del_items(self, doc_ids: Sequence[str]):
self._delete_docs_from_sqlite(doc_ids)
self._sqlite_conn.commit()
+ self._num_docs = 0 # recompute again when needed
+
+ def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSchema]:
+ """Get Documents from the hnswlib index, by `id`.
+ If no document is found, a KeyError is raised.
- def _get_items(self, doc_ids: Sequence[str]) -> Sequence[TSchema]:
- out_docs = self._get_docs_sqlite_doc_id(doc_ids)
+ :param doc_ids: ids to get from the Document index
+ :param out: return the documents in the original schema(True) or inner schema(False) for subindex
+ :return: Sequence of Documents, sorted corresponding to the order of `doc_ids`. Duplicate `doc_ids` can be omitted in the output.
+ """
+ out_docs = self._get_docs_sqlite_doc_id(doc_ids, out)
if len(out_docs) == 0:
raise KeyError(f'No document with id {doc_ids} found')
return out_docs
+ def _doc_exists(self, doc_id: str) -> bool:
+     """Return whether the SQLite `docs` table has a row for the given document id.
+
+     :param doc_id: the document id; it is hashed with `_to_hashed_id` before
+         the lookup.
+     :return: True if a row with the hashed id exists, False otherwise.
+     """
+     hash_id = self._to_hashed_id(doc_id)
+     # Bind the id as a parameter instead of formatting it into the SQL text,
+     # and select a constant so the stored blob is not read just to test existence.
+     self._sqlite_cursor.execute(
+         'SELECT 1 FROM docs WHERE doc_id = ? LIMIT 1', (hash_id,)
+     )
+     return self._sqlite_cursor.fetchone() is not None
+
def num_docs(self) -> int:
"""
Get the number of documents.
"""
- return self._get_num_docs_sqlite()
+ if self._num_docs == 0:
+ self._num_docs = self._get_num_docs_sqlite()
+ return self._num_docs
###############################################
# Helpers #
@@ -353,9 +447,7 @@ def _to_hashed_id(doc_id: Optional[str]) -> int:
def _load_index(self, col_name: str, col: '_ColumnInfo') -> hnswlib.Index:
"""Load an existing HNSW index from disk."""
index = self._create_index_class(col)
- index.load_index(
- self._hnsw_locations[col_name], max_elements=col.config['max_elements']
- )
+ index.load_index(self._hnsw_locations[col_name])
return index
# HNSWLib helpers
@@ -380,34 +472,79 @@ def _create_index(self, col_name: str, col: '_ColumnInfo') -> hnswlib.Index:
# SQLite helpers
def _create_docs_table(self):
- self._sqlite_cursor.execute(
- 'CREATE TABLE IF NOT EXISTS docs (doc_id INTEGER PRIMARY KEY, data BLOB)'
- )
+ columns: List[Tuple[str, str]] = []
+ for col, info in self._column_infos.items():
+ if (
+ col == 'id'
+ or '__' in col
+ or not info.db_type
+ or info.db_type == np.ndarray
+ ):
+ continue
+ columns.append((col, info.db_type))
+
+ columns_str = ', '.join(f'{name} {type}' for name, type in columns)
+ if columns_str:
+ columns_str = ', ' + columns_str
+
+ query = f'CREATE TABLE IF NOT EXISTS docs (doc_id INTEGER PRIMARY KEY, data BLOB{columns_str})'
+ self._sqlite_cursor.execute(query)
def _send_docs_to_sqlite(self, docs: Sequence[BaseDoc]):
+ # Generate the IDs
ids = (self._to_hashed_id(doc.id) for doc in docs)
- self._sqlite_cursor.executemany(
- 'INSERT INTO docs VALUES (?, ?)',
- ((id_, self._doc_to_bytes(doc)) for id_, doc in zip(ids, docs)),
+
+ column_names = self._get_column_names()
+ # Construct the field names and placeholders for the SQL query
+ all_fields = ', '.join(column_names)
+ placeholders = ', '.join(['?'] * len(column_names))
+
+ # Prepare the SQL statement
+ query = f'INSERT OR REPLACE INTO docs ({all_fields}) VALUES ({placeholders})'
+
+ # Prepare the data for insertion
+ data_to_insert = (
+ (id_, self._doc_to_bytes(doc))
+ + tuple(getattr(doc, field) for field in column_names[2:])
+ for id_, doc in zip(ids, docs)
)
- def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int]):
+ # Execute the query
+ self._sqlite_cursor.executemany(query, data_to_insert)
+
+ def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int], out: bool = True):
for id_ in univ_ids:
# I hope this protects from injection attacks
# properly binding with '?' doesn't work for some reason
assert isinstance(id_, int) or is_np_int(id_)
sql_id_list = '(' + ', '.join(str(id_) for id_ in univ_ids) + ')'
self._sqlite_cursor.execute(
- 'SELECT data FROM docs WHERE doc_id IN %s' % sql_id_list,
+ 'SELECT doc_id, data FROM docs WHERE doc_id IN %s' % sql_id_list,
)
- rows = self._sqlite_cursor.fetchall()
- docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))
- return docs_cls([self._doc_from_bytes(row[0]) for row in rows])
+ rows = (
+ self._sqlite_cursor.fetchall()
+ ) # doc_ids do not come back in the same order
+ embeddings: OrderedDict[str, list] = OrderedDict()
+ for col_name, index in self._hnsw_indices.items():
+ embeddings[col_name] = index.get_items([row[0] for row in rows])
+
+ schema = self.out_schema if out else self._schema
+ docs = DocList.__class_getitem__(cast(Type[BaseDoc], schema))()
+ for i, (_, data_bytes) in enumerate(rows):
+ reconstruct_embeddings = {}
+ for col_name in embeddings.keys():
+ reconstruct_embeddings[col_name] = embeddings[col_name][i]
+ docs.append(self._doc_from_bytes(data_bytes, reconstruct_embeddings, out))
+
+ return docs
- def _get_docs_sqlite_doc_id(self, doc_ids: Sequence[str]) -> DocList[TSchema]:
+ def _get_docs_sqlite_doc_id(
+ self, doc_ids: Sequence[str], out: bool = True
+ ) -> DocList[TSchema]:
hashed_ids = tuple(self._to_hashed_id(id_) for id_ in doc_ids)
- docs_unsorted = self._get_docs_sqlite_unsorted(hashed_ids)
- docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))
+ docs_unsorted = self._get_docs_sqlite_unsorted(hashed_ids, out)
+ schema = self.out_schema if out else self._schema
+ docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], schema))
return docs_cls(sorted(docs_unsorted, key=lambda doc: doc_ids.index(doc.id)))
def _get_docs_sqlite_hashed_id(self, hashed_ids: Sequence[int]) -> DocList:
@@ -416,7 +553,7 @@ def _get_docs_sqlite_hashed_id(self, hashed_ids: Sequence[int]) -> DocList:
def _in_position(doc):
return hashed_ids.index(self._to_hashed_id(doc.id))
- docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))
+ docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema))
return docs_cls(sorted(docs_unsorted, key=_in_position))
def _delete_docs_from_sqlite(self, doc_ids: Sequence[Union[str, int]]):
@@ -434,8 +571,351 @@ def _get_num_docs_sqlite(self) -> int:
# serialization helpers
def _doc_to_bytes(self, doc: BaseDoc) -> bytes:
- return doc.to_protobuf().SerializeToString()
+        """Serialize a Document to protobuf bytes for storage in SQLite.
+
+        :param doc: the Document to serialize
+        :return: the serialized protobuf message
+        """
+        pb = doc.to_protobuf()
+        if self._apply_optim_no_embedding_in_sqlite:
+            # The entries of HNSW-indexed columns are cleared before serializing;
+            # `_doc_from_bytes` merges them back in from `reconstruct_embeddings`.
+            for col_name in self._hnsw_indices:
+                pb.data[col_name].Clear()
+        return pb.SerializeToString()
+
+    def _doc_from_bytes(
+        self, data: bytes, reconstruct_embeddings: Dict, out: bool = True
+    ) -> BaseDoc:
+        """Rebuild a Document from the protobuf bytes stored in SQLite.
+
+        :param data: serialized protobuf message of the Document
+        :param reconstruct_embeddings: mapping of column name to the vector that is
+            merged back into the message when embeddings are not kept in SQLite
+        :param out: build with `out_schema` (True) or with the inner `_schema` (False)
+        :return: the reconstructed Document
+        """
+        if out:
+            schema_cls = cast(Type[BaseDoc], self.out_schema)
+        else:
+            schema_cls = cast(Type[BaseDoc], self._schema)
+
+        # Work on the protobuf message first: building the Document directly may fail
+        # at validation because the embedding may not be Optional.
+        message = DocProto.FromString(data)
+        if self._apply_optim_no_embedding_in_sqlite:
+            for column, vector in reconstruct_embeddings.items():
+                tensor_type = schema_cls._get_field_annotation(column)
+                tensor = tensor_type._docarray_from_ndarray(np.array(vector))
+                message.data[column].MergeFrom(tensor._to_node_protobuf())
+
+        return schema_cls.from_protobuf(message)
+
+    def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
+        """Resolve the root_id for a subindex Document (hnswlib backend).
+
+        :param id: id of the subindex Document
+        :param root: root index name
+        :param sub: subindex name
+        :return: the root_id of the Document
+        """
+        subindex = self._subindices[root]
+
+        if sub:
+            # Resolve the deeper levels inside the subindex first, then look up the
+            # parent of the Document found there.
+            first_field, _, remaining = sub.partition('__')
+            cur_root_id = subindex._get_root_doc_id(id, first_field, remaining)
+            return self._get_root_doc_id(cur_root_id, root, '')
+
+        sub_doc = subindex._get_items([id], out=False)  # type: ignore
+        first = sub_doc[0]
+        if isinstance(first, dict):
+            return first['parent_id']
+        return first.parent_id
+
+    def _get_column_names(self) -> List[str]:
+        """
+        Returns the column names of the 'docs' table in the SQLite database.
+        The names are read once via `PRAGMA table_info` and then served from
+        `self._column_names`.
+
+        :return: A list of strings, where each string is a column name.
+        """
+        if self._column_names:
+            return self._column_names
+
+        cursor = self._sqlite_cursor
+        cursor.execute('PRAGMA table_info(docs)')
+        names = []
+        for row in cursor.fetchall():
+            # index 1 of each `table_info` row is the column name
+            names.append(row[1])
+        self._column_names = names
+        return self._column_names
+
+    def _search_and_filter(
+        self,
+        queries: np.ndarray,
+        limit: int,
+        search_field: str = '',
+        hashed_ids: Optional[Set[int]] = None,
+    ) -> _FindResultBatched:
+        """
+        Executes a search and filter operation on the database.
+
+        :param queries: A numpy array of queries.
+        :param limit: The maximum number of results to return.
+        :param search_field: The field to search in.
+        :param hashed_ids: A set of hashed IDs to filter the results with. `None`
+            applies no restriction; an empty set yields an empty result.
+        :return: An instance of _FindResultBatched, containing the matching
+            documents and their corresponding scores.
+        :raises ValueError: If `search_field` is not present in the HNSW indices.
+        """
+        # An explicitly empty set of allowed ids cannot match anything
+        if hashed_ids is not None and len(hashed_ids) == 0:
+            return _FindResultBatched(documents=[], scores=[])  # type: ignore
+
+        # Ensure the search field is in the HNSW indices
+        if search_field not in self._hnsw_indices:
+            raise ValueError(
+                f'Search field {search_field} is not present in the HNSW indices'
+            )
+
+        def accept_hashed_ids(id):
+            """Accepts IDs that are in hashed_ids."""
+            return id in hashed_ids  # type: ignore[operator]
+
+        extra_kwargs = {'filter': accept_hashed_ids} if hashed_ids else {}
+
+        # If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
+        k = min(limit, len(hashed_ids)) if hashed_ids else limit
+        index = self._hnsw_indices[search_field]
+
+        try:
+            labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
+        except RuntimeError:
+            # Retry once with k capped at the number of stored documents
+            k = min(k, self.num_docs())
+            labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
+
+        result_das = [
+            self._get_docs_sqlite_hashed_id(ids_per_query.tolist())
+            for ids_per_query in labels
+        ]
+        return _FindResultBatched(documents=result_das, scores=distances)
+
+    @classmethod
+    def _build_filter_query(
+        cls, query: Union[Dict, str], param_values: List[Any]
+    ) -> str:
+        """
+        Translates a nested filter dict into a SQL condition string.
+
+        Comparison values are appended to `param_values` and represented by a `?`
+        placeholder; field names are inserted into the SQL text as given.
+
+        :param query: Query for filtering.
+        :param param_values: A list to store the parameters for the query.
+        :return: A string representing a SQL filter query.
+        :raises ValueError: If the query, a field condition or an operator is invalid.
+        """
+        if not isinstance(query, dict):
+            raise ValueError('Invalid query')
+
+        if len(query) != 1:
+            raise ValueError('Each nested dict must have exactly one key')
+
+        ((key, value),) = query.items()
+
+        if key == '$not':
+            # Negate the nested query
+            return f'NOT {cls._build_filter_query(value, param_values)}'
+
+        if key in ('$and', '$or'):
+            # Join the sub-queries with the SQL keyword named by the key
+            joiner = f' {key[1:].upper()} '
+            parts = [cls._build_filter_query(sub, param_values) for sub in value]
+            return f'({joiner.join(parts)})'
+
+        # Any other key is treated as a field name with a single-operator condition
+        if not isinstance(value, dict) or len(value) != 1:
+            raise ValueError(f'Invalid condition for field {key}')
+        ((operator_key, operator_value),) = value.items()
+
+        if operator_key == "$exists":
+            return f'{key} IS NOT NULL' if operator_value else f'{key} IS NULL'
+
+        if operator_key not in OPERATOR_MAPPING:
+            raise ValueError(f"Invalid operator {operator_key}")
+
+        sql_operator = OPERATOR_MAPPING[operator_key]
+        param_values.append(operator_value)
+        return f'{key} {sql_operator} ?'
+
+    def _execute_filter(
+        self,
+        filter_query: Any,
+        limit: int,
+    ) -> List[Tuple[str, bytes]]:
+        """
+        Runs a filter query against the `docs` table.
+
+        :param filter_query: Query for filtering.
+        :param limit: Maximum number of rows to be fetched.
+        :return: A list of `(doc_id, data)` rows fetched from the database.
+        """
+        bound_values: List[Any] = []
+        where_clause = self._build_filter_query(filter_query, bound_values)
+        statement = f'SELECT doc_id, data FROM docs WHERE {where_clause} LIMIT {limit}'
+        cursor = self._sqlite_cursor.execute(statement, bound_values)
+        return cursor.fetchall()
+
+ def _execute_find_and_filter_query(
+ self, query: List[Tuple[str, Dict]]
+ ) -> FindResult:
+ """
+ Executes a query to find and filter documents.
+
+ Operations are processed in the order given. 'filter' operations seen before a
+ 'find' has been executed are collected as pre-filters; 'filter' operations seen
+ afterwards are collected as post-filters. Any other operation name is rejected.
+
+ :param query: A list of operations and their corresponding arguments.
+ :return: A FindResult object containing filtered documents and their scores.
+ :raises ValueError: If an operation other than 'find' or 'filter' is given.
+ """
+ # Dictionary to store the score of each document
+ # NOTE(review): the keys written below come from the documents' `id` attribute,
+ # not BaseDoc instances as the annotation suggests.
+ doc_to_score: Dict[BaseDoc, Any] = {}
+
+ # Pre- and post-filter conditions
+ pre_filters: Dict[str, Dict] = {}
+ post_filters: Dict[str, Dict] = {}
+
+ # Define filter limits
+ # Both limits start at the number of indexed documents and are only lowered by
+ # `_update_filter_conditions`.
+ pre_filter_limit = self.num_docs()
+ post_filter_limit = self.num_docs()
+
+ find_executed: bool = False
+
+ # Document list with output schema
+ out_docs: DocList = DocList[self.out_schema]() # type: ignore[name-defined]
+
+ for op, op_kwargs in query:
+ if op == 'find':
+ hashed_ids: Optional[Set[str]] = None
+ if pre_filters:
+ # Resolve the pre-filters to the set of doc_ids the search is restricted to
+ hashed_ids = self._pre_filtering(pre_filters, pre_filter_limit)
+
+ query_vector = self._get_vector_for_query_builder(op_kwargs)
+ # Perform search and filter if hashed_ids returned by pre-filtering is not empty
+ if not (pre_filters and not hashed_ids):
+ # Returns batched output, so we need to get the first lists
+ out_docs, scores = self._search_and_filter( # type: ignore[assignment]
+ queries=query_vector,
+ limit=op_kwargs.get('limit', self.num_docs()),
+ search_field=op_kwargs['search_field'],
+ hashed_ids=hashed_ids,
+ )
+ out_docs = DocList[self.out_schema](out_docs[0]) # type: ignore[name-defined]
+ doc_to_score.update(zip(out_docs.__getattribute__('id'), scores[0]))
+ find_executed = True
+ elif op == 'filter':
+ if find_executed:
+ post_filters, post_filter_limit = self._update_filter_conditions(
+ post_filters, op_kwargs, post_filter_limit
+ )
+ else:
+ pre_filters, pre_filter_limit = self._update_filter_conditions(
+ pre_filters, op_kwargs, pre_filter_limit
+ )
+ else:
+ raise ValueError(f'Query operation is not supported: {op}')
+
+ # Post-filters are applied once, after all operations have been processed
+ if post_filters:
+ out_docs = self._post_filtering(
+ out_docs, post_filters, post_filter_limit, find_executed
+ )
+
+ return self._prepare_out_docs(out_docs, doc_to_score)
+
+    def _update_filter_conditions(
+        self, filter_conditions: Dict, operation_args: Dict, filter_limit: int
+    ) -> Tuple[Dict, int]:
+        """
+        Combines the accumulated filter conditions with those of a new filter
+        operation and tightens the filter limit.
+
+        :param filter_conditions: Current filter conditions.
+        :param operation_args: Arguments of the operation to be executed.
+        :param filter_limit: Current filter limit.
+        :return: Updated filter conditions and filter limit.
+        """
+        new_conditions = operation_args['filter_query']
+        # `$and` takes a list of sub-queries: `_build_filter_query` iterates over its
+        # value and requires every item to be a dict. Merging both dicts into one
+        # would also silently drop a condition whenever they share a key.
+        updated_filter_conditions = (
+            {'$and': [filter_conditions, new_conditions]}
+            if filter_conditions
+            else new_conditions
+        )
+        # The limit can only shrink: keep the smaller of the current and requested one
+        updated_filter_limit = min(
+            filter_limit, operation_args.get('limit', filter_limit)
+        )
+        return updated_filter_conditions, updated_filter_limit
+
+    def _pre_filtering(
+        self, pre_filters: Dict[str, Dict], pre_filter_limit: int
+    ) -> Set[str]:
+        """
+        Runs the pre-filters and collects the ids of the matching rows.
+
+        :param pre_filters: Filter conditions.
+        :param pre_filter_limit: Limit for the filtering.
+        :return: A set of hashed IDs from the filtered rows.
+        """
+        matches = self._execute_filter(
+            filter_query=pre_filters, limit=pre_filter_limit
+        )
+        # Each row is a (doc_id, data) pair; only the id is kept
+        return {hashed_id for hashed_id, _ in matches}
+
+    def _get_vector_for_query_builder(self, find_args: Dict[str, Any]) -> np.ndarray:
+        """
+        Builds the batched query vector for a 'find' operation.
+
+        :param find_args: Arguments for the 'find' operation.
+        :return: A numpy array representing the query vector.
+        """
+        query = find_args['query']
+        if isinstance(query, BaseDoc):
+            # A Document was passed: take the value of its search field as the query
+            query = self._get_values_by_column([query], find_args['search_field'])[0]
+
+        # Prepend an axis of size one to the converted vector
+        return np.expand_dims(self._to_numpy(query), axis=0)
+
+    def _post_filtering(
+        self,
+        out_docs: DocList,
+        post_filters: Dict[str, Dict],
+        post_filter_limit: int,
+        find_executed: bool,
+    ) -> DocList:
+        """
+        Applies the post-filter conditions to the result of the query.
+
+        :param out_docs: The documents found by the 'find' operation.
+        :param post_filters: The post-filter conditions.
+        :param post_filter_limit: Limit for the post-filtering.
+        :param find_executed: Whether 'find' operation was executed.
+        :return: Filtered documents as per the post-filter conditions.
+        """
+        if find_executed:
+            # Filter the found documents in memory
+            docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema))
+            filtered = docs_cls(filter_docs(out_docs, post_filters))
+        else:
+            # No 'find' was executed: filter the whole index instead
+            filtered = self.filter(post_filters, limit=self.num_docs())
+
+        return filtered[:post_filter_limit] if post_filters else filtered
+
+ def _prepare_out_docs(
+ self, out_docs: DocList, doc_to_score: Dict[BaseDoc, Any]
+ ) -> FindResult:
+ """
+ Prepares output documents with their scores.
+
+ :param out_docs: The documents to be output.
+ :param doc_to_score: Mapping of documents to their scores.
+ :return: FindResult object with documents and their scores.
+ """
+ if out_docs:
+ # If the "find" operation isn't called through the query builder,
+ # all returned scores will be 0
+ docs_and_scores = zip(
+ out_docs, (doc_to_score.get(doc.id, 0) for doc in out_docs)
+ )
+ docs_sorted = sorted(docs_and_scores, key=lambda x: x[1])
+ out_docs, out_scores = zip(*docs_sorted)
+ else:
+ out_docs, out_scores = [], [] # type: ignore[assignment]
- def _doc_from_bytes(self, data: bytes) -> BaseDoc:
- schema_cls = cast(Type[BaseDoc], self._schema)
- return schema_cls.from_protobuf(DocProto.FromString(data))
+ return FindResult(documents=out_docs, scores=out_scores)
diff --git a/docarray/index/backends/in_memory.py b/docarray/index/backends/in_memory.py
index 400b60fc4c5..c2d699b1873 100644
--- a/docarray/index/backends/in_memory.py
+++ b/docarray/index/backends/in_memory.py
@@ -1,3 +1,4 @@
+import os
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (
@@ -18,17 +19,18 @@
import numpy as np
from docarray import BaseDoc, DocList
+from docarray.array.any_array import AnyDocArray
+from docarray.helper import _shallow_copy_doc
from docarray.index.abstract import BaseDocIndex, _raise_not_supported
-from docarray.index.backends.helper import (
- _collect_query_args,
- _execute_find_and_filter_query,
-)
+from docarray.index.backends.helper import _collect_query_args
from docarray.typing import AnyTensor, NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from docarray.utils._internal._typing import safe_issubclass
from docarray.utils.filter import filter_docs
from docarray.utils.find import (
FindResult,
FindResultBatched,
+ _extract_embeddings,
_FindResult,
_FindResultBatched,
find,
@@ -39,15 +41,60 @@
class InMemoryExactNNIndex(BaseDocIndex, Generic[TSchema]):
- def __init__(self, docs: Optional[DocList] = None, **kwargs):
+ def __init__(
+ self,
+ docs: Optional[DocList] = None,
+ db_config=None,
+ **kwargs,
+ ):
"""Initialize InMemoryExactNNIndex"""
- super().__init__(db_config=None, **kwargs)
+ super().__init__(db_config=db_config, **kwargs)
self._runtime_config = self.RuntimeConfig()
- self._docs = (
- docs
- if docs is not None
- else DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))()
- )
+ self._db_config = cast(InMemoryExactNNIndex.DBConfig, self._db_config)
+ self._index_file_path = self._db_config.index_file_path
+
+ if docs and self._index_file_path:
+ raise ValueError(
+ 'Initialize `InMemoryExactNNIndex` with either `docs` or '
+ '`index_file_path`, not both. Provide `docs` for a fresh index, or '
+ '`index_file_path` to use an existing file.'
+ )
+
+ if self._index_file_path:
+ if os.path.exists(self._index_file_path):
+ self._logger.info(
+ f'Loading index from a binary file: {self._index_file_path}'
+ )
+ self._docs = DocList.__class_getitem__(
+ cast(Type[BaseDoc], self._schema)
+ ).load_binary(file=self._index_file_path)
+
+ data_by_columns = self._get_col_value_dict(self._docs)
+ self._update_subindex_data(self._docs)
+ self._index_subindex(data_by_columns)
+
+ else:
+ self._logger.warning(
+ f'Index file does not exist: {self._index_file_path}. '
+ f'Initializing empty InMemoryExactNNIndex.'
+ )
+ self._docs = DocList.__class_getitem__(
+ cast(Type[BaseDoc], self._schema)
+ )()
+ else:
+ if docs:
+ self._logger.info('Docs provided. Initializing with provided docs.')
+ self._docs = docs
+ else:
+ self._logger.info(
+ 'No docs or index file provided. Initializing empty InMemoryExactNNIndex.'
+ )
+ self._docs = DocList.__class_getitem__(
+ cast(Type[BaseDoc], self._schema)
+ )()
+
+ self._embedding_map: Dict[str, Tuple[AnyTensor, Optional[List[int]]]] = {}
+ self._ids_to_positions: Dict[str, int] = {}
def python_type_to_db_type(self, python_type: Type) -> Any:
"""Map python type to database type.
@@ -59,6 +106,13 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
"""
return python_type
+    @property
+    def out_schema(self) -> Type[BaseDoc]:
+        """Schema of returned Documents: the original schema (without the parent_id
+        of the new_schema type) for a subindex, otherwise the index schema."""
+        if not self._is_subindex:
+            return cast(Type[BaseDoc], self._schema)
+        return self._ori_schema
+
class QueryBuilder(BaseDocIndex.QueryBuilder):
def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None):
super().__init__()
@@ -80,12 +134,7 @@ def build(self, *args, **kwargs) -> Any:
class DBConfig(BaseDocIndex.DBConfig):
"""Dataclass that contains all "static" configurations of InMemoryExactNNIndex."""
- pass
-
- @dataclass
- class RuntimeConfig(BaseDocIndex.RuntimeConfig):
- """Dataclass that contains all "dynamic" configurations of InMemoryExactNNIndex."""
-
+ index_file_path: Optional[str] = None
default_column_config: Dict[Type, Dict[str, Any]] = field(
default_factory=lambda: defaultdict(
dict,
@@ -95,6 +144,12 @@ class RuntimeConfig(BaseDocIndex.RuntimeConfig):
)
)
+ @dataclass
+ class RuntimeConfig(BaseDocIndex.RuntimeConfig):
+ """Dataclass that contains all "dynamic" configurations of InMemoryExactNNIndex."""
+
+ pass
+
def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
"""index Documents into the index.
@@ -109,7 +164,20 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
"""
# implementing the public option because conversion to column dict is not needed
docs = self._validate_docs(docs)
- self._docs.extend(docs)
+ ids_to_positions = self._get_ids_to_positions()
+ for doc in docs:
+ if doc.id in ids_to_positions:
+ self._docs[ids_to_positions[doc.id]] = doc
+ else:
+ self._docs.append(doc)
+ self._ids_to_positions[str(doc.id)] = len(self._ids_to_positions)
+
+ # Add parent_id to all sub-index documents and store sub-index documents
+ data_by_columns = self._get_col_value_dict(docs)
+ self._update_subindex_data(docs)
+ self._index_subindex(data_by_columns)
+
+ self._rebuild_embedding()
def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
raise NotImplementedError
@@ -120,33 +188,99 @@ def num_docs(self) -> int:
"""
return len(self._docs)
+    def _rebuild_embedding(self):
+        """
+        Refreshes '_embedding_map', the dictionary mapping fields to their
+        pre-stacked embeddings, so embeddings do not have to be stacked repeatedly.
+
+        The map is reset when the index is empty; otherwise every field already in
+        the map is re-extracted from the current documents.
+        """
+        if self._is_index_empty:
+            self._embedding_map = dict()
+            return
+
+        for field_ in list(self._embedding_map):
+            self._embedding_map[field_] = _extract_embeddings(self._docs, field_)
+
def _del_items(self, doc_ids: Sequence[str]):
"""Delete Documents from the index.
:param doc_ids: ids to delete from the Document Store
"""
+ for field_, type_, _ in self._flatten_schema(cast(Type[BaseDoc], self._schema)):
+ if safe_issubclass(type_, AnyDocArray):
+ for id in doc_ids:
+ doc_ = self._get_items([id])
+ if len(doc_) == 0:
+ raise KeyError(
+ f"The document (id = '{id}') does not exist in the ExactNNIndexer."
+ )
+ sub_ids = [sub_doc.id for sub_doc in getattr(doc_[0], field_)]
+ del self._subindices[field_][sub_ids]
+
indices = []
for i, doc in enumerate(self._docs):
if doc.id in doc_ids:
indices.append(i)
del self._docs[indices]
+ self._update_ids_to_positions()
+ self._rebuild_embedding()
+
+ def _ori_items(self, doc: BaseDoc) -> BaseDoc:
+ """
+ The Indexer's backend stores parent_id to support nested data. However,
+ this method enables us to retrieve the original items in their original
+ type, which is what the user interacts with.
+
+ Nested documents held in AnyDocArray fields are re-created with the
+ `_ori_schema` of the matching subindex, recursively.
+
+ :param doc: The input document in New_Schema format from the Indexer's backend.
+ :return: The input document with its original schema.
+ """
+
+ # NOTE(review): the nested lists are read from this shallow copy; whether they
+ # are shared with `doc` depends on `_shallow_copy_doc` - confirm.
+ ori_doc = _shallow_copy_doc(doc)
+ for field_name, type_, _ in self._flatten_schema(
+ cast(Type[BaseDoc], self.out_schema)
+ ):
+ # Only fields holding a list of nested documents need converting
+ if safe_issubclass(type_, AnyDocArray):
+ _list = getattr(ori_doc, field_name)
+ for i, nested_doc in enumerate(_list):
+ sub_indexer: InMemoryExactNNIndex = cast(
+ InMemoryExactNNIndex, self._subindices[field_name]
+ )
+ # Re-create the nested document with the subindex's original schema
+ nested_doc = self._subindices[field_name]._ori_schema(
+ **nested_doc.__dict__
+ )
+
+ # Let the subindex convert any deeper nesting levels
+ _list[i] = sub_indexer._ori_items(nested_doc)
+
+ return ori_doc
def _get_items(
- self, doc_ids: Sequence[str]
+ self, doc_ids: Sequence[str], raw: bool = False
) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]:
"""Get Documents from the index, by `id`.
If no document is found, a KeyError is raised.
:param doc_ids: ids to get from the Document index
+ :param raw: if raw, output the new_schema type (with parent id)
:return: Sequence of Documents, sorted corresponding to the order of `doc_ids`.
Duplicate `doc_ids` can be omitted in the output.
"""
- indices = []
- for i, doc in enumerate(self._docs):
- if doc.id in doc_ids:
- indices.append(i)
- return self._docs[indices]
+
+ out_docs = []
+ ids_to_positions = self._get_ids_to_positions()
+ for doc_id in doc_ids:
+ if doc_id not in ids_to_positions:
+ continue
+ doc = self._docs[ids_to_positions[doc_id]]
+ if raw:
+ out_docs.append(doc)
+ else:
+ ori_doc = self._ori_items(doc)
+ schema_cls = cast(Type[BaseDoc], self.out_schema)
+ new_doc = schema_cls(**ori_doc.__dict__)
+ out_docs.append(new_doc)
+
+ return out_docs
def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
"""
@@ -167,11 +301,44 @@ def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
raise ValueError(
f'args and kwargs not supported for `execute_query` on {type(self)}'
)
- find_res = _execute_find_and_filter_query(
- doc_index=self,
- query=query,
- )
- return find_res
+ return self._find_and_filter(query)
+
+    def _find_and_filter(self, query: List[Tuple[str, Dict]]) -> FindResult:
+        """
+        The function executes search operations such as 'find' and 'filter' in the order
+        they appear in the query. The 'find' operation performs a vector similarity search.
+        The 'filter' operation filters out documents based on a filter query.
+        The documents are finally sorted by score in descending order; documents that
+        were never scored (queries without a 'find') get a score of 0.
+
+        :param query: The query to execute.
+        :return: A tuple of retrieved documents and their scores.
+        :raises ValueError: If an operation other than 'find' or 'filter' is given.
+        """
+        out_docs = self._docs
+        # Scores are keyed by document id
+        doc_to_score: Dict[str, Any] = {}
+        for op, op_kwargs in query:
+            if op == 'find':
+                search_field = op_kwargs['search_field']
+                out_docs, scores = find(
+                    index=out_docs,
+                    query=op_kwargs['query'],
+                    search_field=search_field,
+                    limit=op_kwargs.get('limit', len(out_docs)),
+                    metric=self._column_infos[search_field].config['space'],
+                )
+                doc_to_score.update(zip(out_docs.id, scores))
+            elif op == 'filter':
+                out_docs = filter_docs(out_docs, op_kwargs['filter_query'])
+                if 'limit' in op_kwargs:
+                    out_docs = out_docs[: op_kwargs['limit']]
+            else:
+                raise ValueError(f'Query operation is not supported: {op}')
+
+        if len(out_docs) == 0:
+            # `zip(*[])` yields nothing and cannot be unpacked into two names
+            return FindResult(documents=[], scores=[])  # type: ignore
+
+        # Sort on the score alone so that tied scores never fall back to comparing
+        # Documents; `.get` keeps filter-only queries from raising KeyError.
+        scored = [(doc_to_score.get(doc.id, 0), doc) for doc in out_docs]
+        scored.sort(key=lambda pair: pair[0], reverse=True)
+        out_scores, out_docs = zip(*scored)
+
+        return FindResult(documents=out_docs, scores=out_scores)
def find(
self,
@@ -193,6 +360,10 @@ def find(
"""
self._logger.debug(f'Executing `find` for search field {search_field}')
self._validate_search_field(search_field)
+
+ if self._is_index_empty:
+ return FindResult(documents=[], scores=[]) # type: ignore
+
config = self._column_infos[search_field].config
docs, scores = find(
@@ -201,10 +372,19 @@ def find(
search_field=search_field,
limit=limit,
metric=config['space'],
+ cache=self._embedding_map,
)
- docs_with_schema = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))(
- docs
- )
+
+ docs_ = []
+ for doc in docs:
+ ori_doc = self._ori_items(doc)
+ schema_cls = cast(Type[BaseDoc], self.out_schema)
+ docs_.append(schema_cls(**ori_doc.__dict__))
+
+ docs_with_schema = DocList.__class_getitem__(
+ cast(Type[BaseDoc], self.out_schema)
+ )(docs_)
+
return FindResult(documents=docs_with_schema, scores=scores)
def _find(
@@ -233,6 +413,10 @@ def find_batched(
"""
self._logger.debug(f'Executing `find_batched` for search field {search_field}')
self._validate_search_field(search_field)
+
+ if self._is_index_empty:
+ return FindResultBatched(documents=[], scores=[]) # type: ignore
+
config = self._column_infos[search_field].config
find_res = find_batched(
@@ -241,6 +425,7 @@ def find_batched(
search_field=search_field,
limit=limit,
metric=config['space'],
+ cache=self._embedding_map,
)
return find_res
@@ -265,7 +450,7 @@ def filter(
"""
self._logger.debug(f'Executing `filter` for the query {filter_query}')
- docs = filter_docs(docs=self._docs, query=filter_query)
+ docs = filter_docs(docs=self._docs, query=filter_query)[:limit]
return cast(DocList, docs)
def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]:
@@ -285,3 +470,62 @@ def _text_search_batched(
self, queries: Sequence[str], limit: int, search_field: str = ''
) -> _FindResultBatched:
raise NotImplementedError(f'{type(self)} does not support text search.')
+
+    def _doc_exists(self, doc_id: str) -> bool:
+        """Tell whether `doc_id` is a key of the id-to-position mapping."""
+        known_ids = self._get_ids_to_positions()
+        return doc_id in known_ids
+
+    def persist(self, file: Optional[str] = None) -> None:
+        """Persist InMemoryExactNNIndex into a binary file.
+
+        The `index_file_path` of the index is used when set, otherwise `file`,
+        otherwise a default file name.
+
+        :param file: path to write to when the index has no `index_file_path`
+        """
+        DEFAULT_INDEX_FILE_PATH = 'in_memory_index.bin'
+        target = self._index_file_path or file
+        if target is None:
+            self._logger.warning(
+                f'persisting index to {DEFAULT_INDEX_FILE_PATH} because no `index_file_path` has been used inside DBConfig and no `file` has been passed as argument'
+            )
+        if not target:
+            target = DEFAULT_INDEX_FILE_PATH
+        self._docs.save_binary(file=target)
+
+    def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
+        """Resolve the root_id for a subindex Document.
+
+        :param id: id of the subindex Document
+        :param root: root index name
+        :param sub: subindex name
+        :return: the root_id of the Document
+        """
+        subindex: InMemoryExactNNIndex = cast(
+            InMemoryExactNNIndex, self._subindices[root]
+        )
+
+        if sub:
+            # Resolve the deeper levels inside the subindex first, then look up the
+            # parent of the Document found there.
+            first_field, _, remaining = sub.partition('__')
+            cur_root_id = subindex._get_root_doc_id(id, first_field, remaining)
+            return self._get_root_doc_id(cur_root_id, root, '')
+
+        # raw=True returns the stored document, which carries the parent id
+        sub_doc = subindex._get_items([id], raw=True)
+        first = sub_doc[0]
+        if isinstance(first, dict):
+            return first['parent_id']
+        return first.parent_id
+
+    def _get_ids_to_positions(self) -> Dict[str, int]:
+        """
+        Returns the mapping between document IDs and their positions within the
+        DocList, building it first when it is still empty.
+
+        :return: A dictionary mapping each document ID to its corresponding position.
+        """
+        if self._ids_to_positions:
+            return self._ids_to_positions
+
+        self._update_ids_to_positions()
+        return self._ids_to_positions
+
+    def _update_ids_to_positions(self) -> None:
+        """
+        Rebuilds the mapping between document IDs and their positions within the
+        DocList from the current documents.
+        """
+        mapping = {}
+        for pos, doc in enumerate(self._docs):
+            mapping[doc.id] = pos
+        self._ids_to_positions = mapping
diff --git a/docarray/index/backends/milvus.py b/docarray/index/backends/milvus.py
new file mode 100644
index 00000000000..e84baac7210
--- /dev/null
+++ b/docarray/index/backends/milvus.py
@@ -0,0 +1,842 @@
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generator,
+ Generic,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
+
+import numpy as np
+
+from docarray import BaseDoc, DocList
+from docarray.array.any_array import AnyDocArray
+from docarray.index.abstract import (
+ BaseDocIndex,
+ _raise_not_composable,
+ _raise_not_supported,
+)
+from docarray.index.backends.helper import _collect_query_args
+from docarray.typing import AnyTensor, NdArray
+from docarray.typing.id import ID
+from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from docarray.utils._internal._typing import safe_issubclass
+from docarray.utils.find import (
+ FindResult,
+ FindResultBatched,
+ _FindResult,
+ _FindResultBatched,
+)
+
+if TYPE_CHECKING:
+ from pymilvus import ( # type: ignore[import]
+ Collection,
+ CollectionSchema,
+ DataType,
+ FieldSchema,
+ Hits,
+ connections,
+ utility,
+ )
+else:
+ from pymilvus import (
+ Collection,
+ CollectionSchema,
+ DataType,
+ FieldSchema,
+ Hits,
+ connections,
+ utility,
+ )
+
+MAX_LEN = 65_535 # Maximum length that Milvus allows for a VARCHAR field
+# Distance metrics accepted by `_build_index` (compared against the upper-cased 'space' config)
+VALID_METRICS = ['L2', 'IP']
+# Index types accepted by `_build_index` (compared against the upper-cased 'index_type' config)
+VALID_INDEX_TYPES = [
+ 'FLAT',
+ 'IVF_FLAT',
+ 'IVF_SQ8',
+ 'IVF_PQ',
+ 'HNSW',
+ 'ANNOY',
+ 'DISKANN',
+]
+
+# Type variable for the Document schema an index is parametrized with
+TSchema = TypeVar('TSchema', bound=BaseDoc)
+
+
+class MilvusDocumentIndex(BaseDocIndex, Generic[TSchema]):
+ def __init__(self, db_config=None, **kwargs):
+ """Initialize MilvusDocumentIndex.
+
+ Connects to Milvus using the host/port/credentials of the DBConfig, validates
+ the schema's columns, creates or loads the backing collection, makes sure
+ the vector field has an index and finally loads the collection.
+
+ :param db_config: configuration object forwarded to the base class initializer
+ :param kwargs: extra keyword arguments forwarded to the base class initializer
+ """
+ super().__init__(db_config=db_config, **kwargs)
+ # Narrow the generic config attributes to the Milvus-specific dataclasses.
+ self._db_config: MilvusDocumentIndex.DBConfig = cast(
+ MilvusDocumentIndex.DBConfig, self._db_config
+ )
+ self._runtime_config: MilvusDocumentIndex.RuntimeConfig = cast(
+ MilvusDocumentIndex.RuntimeConfig, self._runtime_config
+ )
+
+ # NOTE(review): the return value of connections.connect is kept as the client;
+ # confirm what pymilvus actually returns here.
+ self._client = connections.connect(
+ db_name="default",
+ host=self._db_config.host,
+ port=self._db_config.port,
+ user=self._db_config.user,
+ password=self._db_config.password,
+ token=self._db_config.token,
+ )
+
+ # Order matters: _create_or_load_collection reads self._field_name and
+ # _build_index reads self._collection.
+ self._validate_columns()
+ self._field_name = self._get_vector_field_name()
+ self._collection = self._create_or_load_collection()
+ self._build_index()
+ self._collection.load()
+ self._logger.info(f'{self.__class__.__name__} has been initialized')
+
+ @dataclass
+ class DBConfig(BaseDocIndex.DBConfig):
+ """Dataclass that contains all "static" configurations of MilvusDocumentIndex.
+
+ :param index_name: The name of the index in the Milvus database. If not provided, default index name will be used.
+ :param collection_description: Description of the collection in the database.
+ :param host: Hostname of the server where the database resides. Default is 'localhost'.
+ :param port: Port number used to connect to the database. Default is 19530.
+ :param user: User for the database. Can be an empty string if no user is required.
+ :param password: Password for the specified user. Can be an empty string if no password is required.
+ :param token: Token for secure connection. Can be an empty string if no token is required.
+ :param consistency_level: The level of consistency for the database session. Default is 'Session'.
+ :param search_params: Dictionary containing parameters for search operations,
+ default has a single key 'params' with 'nprobe' set to 10.
+ :param serialize_config: Dictionary containing configuration for serialization,
+ default is {'protocol': 'protobuf'}.
+ :param default_column_config: Dictionary that defines the default configuration
+ for each data type column.
+ """
+
+ index_name: Optional[str] = None
+ collection_description: str = ""
+ host: str = "localhost"
+ port: int = 19530
+ user: Optional[str] = ""
+ password: Optional[str] = ""
+ token: Optional[str] = ""
+ consistency_level: str = 'Session'
+ # Passed as `param` to Collection.search by the find methods.
+ search_params: Dict = field(
+ default_factory=lambda: {
+ "params": {"nprobe": 10},
+ }
+ )
+ # Forwarded as keyword arguments to to_base64 / from_base64.
+ serialize_config: Dict = field(default_factory=lambda: {"protocol": "protobuf"})
+ # defaultdict(dict): a type without an explicit entry gets an empty config.
+ # NOTE(review): annotated with `Type` keys but keyed by a pymilvus DataType member.
+ default_column_config: Dict[Type, Dict[str, Any]] = field(
+ default_factory=lambda: defaultdict(
+ dict,
+ {
+ # `_build_index` reads 'index_type', 'metric_type' and 'params' from the
+ # vector column's config; presumably these defaults end up there via the
+ # base class - confirm.
+ DataType.FLOAT_VECTOR: {
+ 'index_type': 'IVF_FLAT',
+ 'metric_type': 'L2',
+ 'params': {"nlist": 1024},
+ },
+ },
+ )
+ )
+
+ @dataclass
+ class RuntimeConfig(BaseDocIndex.RuntimeConfig):
+ """Dataclass that contains all "dynamic" configurations of MilvusDocumentIndex.
+
+ :param batch_size: Batch size for index/get/del.
+ """
+
+ # Number of documents / ids handled per Milvus call in index, _get_items and _del_items.
+ batch_size: int = 100
+
+ class QueryBuilder(BaseDocIndex.QueryBuilder):
+ """Collects query components as (method name, kwargs) tuples.
+
+ `build` returns the collected list unchanged; `execute_query` of the index
+ consumes it and expects exactly one 'find' and one 'filter' component.
+ """
+
+ def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None):
+ super().__init__()
+ # list of tuples (method name, kwargs)
+ self._queries: List[Tuple[str, Dict]] = query or []
+
+ def build(self, *args, **kwargs) -> Any:
+ """Build the query object."""
+ return self._queries
+
+ # The helpers below come from docarray; presumably `_collect_query_args` records
+ # the call into `_queries` while the `_raise_*` helpers reject the call - confirm.
+ find = _collect_query_args('find')
+ filter = _collect_query_args('filter')
+ text_search = _raise_not_supported('text_search')
+ find_batched = _raise_not_composable('find_batched')
+ filter_batched = _raise_not_composable('filter_batched')
+ text_search_batched = _raise_not_supported('text_search_batched')
+
+ def python_type_to_db_type(self, python_type: Type) -> Any:
+ """Map python type to database type.
+ Takes any python type and returns the corresponding database column type.
+
+ :param python_type: a python type.
+ :return: the corresponding database column type, or None if ``python_type``
+ is not supported.
+ """
+ type_map = {
+ int: DataType.INT64,
+ float: DataType.FLOAT,
+ str: DataType.VARCHAR,
+ bytes: DataType.VARCHAR,
+ np.ndarray: DataType.FLOAT_VECTOR,
+ list: DataType.FLOAT_VECTOR,
+ AnyTensor: DataType.FLOAT_VECTOR,
+ AbstractTensor: DataType.FLOAT_VECTOR,
+ }
+
+ if safe_issubclass(python_type, ID):
+ return DataType.VARCHAR
+
+ for py_type, db_type in type_map.items():
+ if safe_issubclass(python_type, py_type):
+ return db_type
+
+ raise ValueError(f'Unsupported column type for {type(self)}: {python_type}')
+
+ def _create_or_load_collection(self) -> Collection:
+ """
+ This function initializes or retrieves a Milvus collection with a specified schema,
+ storing documents as serialized data and using the document's ID as the collection's ID
+ , while inheriting other schema properties from the indexer's schema.
+
+ !!! note
+ Milvus framework currently only supports a single vector column, and only one vector
+ column can store in the schema (others are stored in the serialized data)
+ """
+
+ if not utility.has_collection(self.index_name):
+ fields = [
+ FieldSchema(
+ name="serialized",
+ dtype=DataType.VARCHAR,
+ max_length=MAX_LEN,
+ ),
+ FieldSchema(
+ name="id",
+ dtype=DataType.VARCHAR,
+ is_primary=True,
+ max_length=MAX_LEN,
+ ),
+ ]
+ for column_name, info in self._column_infos.items():
+ if (
+ column_name != 'id'
+ and not (
+ info.db_type == DataType.FLOAT_VECTOR
+ and column_name
+ != self._field_name # Only store one vector field as a column
+ )
+ and not safe_issubclass(info.docarray_type, AnyDocArray)
+ ):
+ field_dict: Dict[str, Any] = {}
+ if info.db_type == DataType.VARCHAR:
+ field_dict = {'max_length': MAX_LEN}
+ elif info.db_type == DataType.FLOAT_VECTOR:
+ field_dict = {'dim': info.n_dim or info.config.get('dim')}
+
+ fields.append(
+ FieldSchema(
+ name=column_name,
+ dtype=info.db_type,
+ is_primary=False,
+ **field_dict,
+ )
+ )
+
+ self._logger.info("Collection has been created")
+ return Collection(
+ name=self.index_name,
+ schema=CollectionSchema(
+ fields=fields,
+ description=self._db_config.collection_description,
+ ),
+ using='default',
+ )
+
+ return Collection(self.index_name)
+
+ def _validate_columns(self):
+ """
+ Validates whether the data schema includes at least one vector column used
+ for embedding (as required by Milvus), and ensures that dimension information
+ is specified for that column.
+ """
+ vector_columns = sum(
+ safe_issubclass(info.docarray_type, AbstractTensor)
+ and info.config.get('is_embedding', False)
+ for info in self._column_infos.values()
+ )
+ if vector_columns == 0:
+ raise ValueError(
+ "Unable to find any vector columns. Please make sure that at least one "
+ "column is of a vector type with the is_embedding=True attribute specified."
+ )
+ elif vector_columns > 1:
+ raise ValueError("Specifying multiple vector fields is not supported.")
+
+ for column, info in self._column_infos.items():
+ if info.config.get('is_embedding') and (
+ not info.n_dim and not info.config.get('dim')
+ ):
+ raise ValueError(
+ f"The dimension information is missing for the column '{column}', which is of vector type."
+ )
+
+ @property
+ def index_name(self):
+ default_index_name = (
+ self._schema.__name__.lower() if self._schema is not None else None
+ )
+ if default_index_name is None:
+ err_msg = (
+ 'A MilvusDocumentIndex must be typed with a Document type. '
+ 'To do so, use the syntax: MilvusDocumentIndex[DocumentType]'
+ )
+
+ self._logger.error(err_msg)
+ raise ValueError(err_msg)
+ index_name = self._db_config.index_name or default_index_name
+ self._logger.debug(f'Retrieved index name: {index_name}')
+ return index_name
+
+ @property
+ def out_schema(self) -> Type[BaseDoc]:
+ """Return the real schema of the index."""
+ if self._is_subindex:
+ return self._ori_schema
+ return cast(Type[BaseDoc], self._schema)
+
+ def _build_index(self):
+ """
+ Sets up an index configuration for a specific column index, which is
+ required by the Milvus backend.
+ """
+
+ existing_indices = [index.field_name for index in self._collection.indexes]
+ if self._field_name in existing_indices:
+ return
+
+ index_type = self._column_infos[self._field_name].config['index_type'].upper()
+ if index_type not in VALID_INDEX_TYPES:
+ raise ValueError(
+ f"Invalid index type '{index_type}' provided. "
+ f"Must be one of: {', '.join(VALID_INDEX_TYPES)}"
+ )
+ metric_type = (
+ self._column_infos[self._field_name].config.get('space', '').upper()
+ )
+ if metric_type not in VALID_METRICS:
+ self._logger.warning(
+ f"Invalid or no distance metric '{metric_type}' was provided. "
+ f"Should be one of: {', '.join(VALID_INDEX_TYPES)}. "
+ f"Default distance metric will be used."
+ )
+ metric_type = self._column_infos[self._field_name].config['metric_type']
+
+ index = {
+ "index_type": index_type,
+ "metric_type": metric_type,
+ "params": self._column_infos[self._field_name].config['params'],
+ }
+
+ self._collection.create_index(self._field_name, index)
+ self._logger.info(
+ f"Index for the field '{self._field_name}' has been successfully created"
+ )
+
+ def _get_vector_field_name(self):
+ for column, info in self._column_infos.items():
+ if info.db_type == DataType.FLOAT_VECTOR and info.config.get(
+ 'is_embedding'
+ ):
+ return column
+ return ''
+
+ @staticmethod
+ def _get_batches(docs, batch_size):
+ """Yield successive batch_size batches from docs."""
+ for i in range(0, len(docs), batch_size):
+ yield docs[i : i + batch_size]
+
+ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
+ """Index Documents into the index.
+
+ Documents are validated, nested subindex data is indexed first, and the
+ Documents are then inserted into the Milvus collection in batches of
+ `RuntimeConfig.batch_size`, followed by a single flush.
+
+ !!! note
+ Passing a sequence of Documents that is not a DocList
+ (such as a List of Docs) comes at a performance penalty.
+ This is because the Index needs to check compatibility between itself and
+ the data. With a DocList as input this is a single check; for other inputs
+ compatibility needs to be checked for every Document individually.
+
+ :param docs: Documents to index.
+ :param kwargs: accepted for interface compatibility; not used in this method.
+ """
+
+ n_docs = 1 if isinstance(docs, BaseDoc) else len(docs)
+ self._logger.debug(f'Indexing {n_docs} documents')
+ docs = self._validate_docs(docs)
+ self._update_subindex_data(docs)
+ data_by_columns = self._get_col_value_dict(docs)
+ self._index_subindex(data_by_columns)
+
+ # Map every collection field name to its position in the schema so column
+ # values can be placed in the slot Milvus expects.
+ positions: Dict[str, int] = {
+ info.name: num for num, info in enumerate(self._collection.schema.fields)
+ }
+
+ for batch in self._get_batches(
+ docs, batch_size=self._runtime_config.batch_size
+ ):
+ # One value list per schema field (column-wise insert).
+ entities: List[List[Any]] = [
+ [] for _ in range(len(self._collection.schema))
+ ]
+ for doc in batch:
+ # "serialized" will always be in the first position
+ entities[0].append(doc.to_base64(**self._db_config.serialize_config))
+ for schema_field in self._collection.schema.fields:
+ if schema_field.name == 'serialized':
+ continue
+ column_value = self._get_values_by_column([doc], schema_field.name)[
+ 0
+ ]
+ # Vector values go through _map_embedding before insertion.
+ if schema_field.dtype == DataType.FLOAT_VECTOR:
+ column_value = self._map_embedding(column_value)
+
+ entities[positions[schema_field.name]].append(column_value)
+ self._collection.insert(entities)
+
+ self._collection.flush()
+ self._logger.info(f"{len(docs)} documents has been indexed")
+
+ def _filter_by_parent_id(self, id: str) -> Optional[List[str]]:
+ """Filter the ids of the subindex documents given id of root document.
+
+ :param id: the root document id to filter by
+ :return: a list of ids of the subindex documents
+ """
+ docs = self._filter(filter_query=f"parent_id == '{id}'", limit=self.num_docs())
+ return [doc.id for doc in docs] # type: ignore[union-attr]
+
+ def num_docs(self) -> int:
+ """
+ Get the number of documents.
+
+ !!! note
+ Cannot use Milvus' num_entities method because it's not precise
+ especially after delete ops (#15201 issue in Milvus)
+ """
+
+ self._collection.load()
+
+ result = self._collection.query(
+ expr=self._always_true_expr("id"),
+ offset=0,
+ output_fields=["serialized"],
+ )
+
+ return len(result)
+
+ def _get_items(
+ self, doc_ids: Sequence[str]
+ ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]:
+ """Get Documents from the index, by `id`.
+ If no document is found, a KeyError is raised.
+
+ :param doc_ids: ids to get from the Document index
+ :param raw: if raw, output the new_schema type (with parent id)
+ :return: Sequence of Documents, sorted corresponding to the order of `doc_ids`.
+ Duplicate `doc_ids` can be omitted in the output.
+ """
+
+ self._collection.load()
+ results: List[Dict] = []
+ for batch in self._get_batches(
+ doc_ids, batch_size=self._runtime_config.batch_size
+ ):
+ results.extend(
+ self._collection.query(
+ expr="id in " + str([id for id in batch]),
+ offset=0,
+ output_fields=["serialized"],
+ consistency_level=self._db_config.consistency_level,
+ )
+ )
+
+ self._collection.release()
+
+ return self._docs_from_query_response(results)
+
+ def _del_items(self, doc_ids: Sequence[str]):
+ """Delete Documents from the index.
+
+ :param doc_ids: ids to delete from the Document Store
+ """
+ self._collection.load()
+ for batch in self._get_batches(
+ doc_ids, batch_size=self._runtime_config.batch_size
+ ):
+ self._collection.delete(
+ expr="id in " + str([id for id in batch]),
+ consistency_level=self._db_config.consistency_level,
+ )
+ self._logger.info(f"{len(doc_ids)} documents has been deleted")
+
+ def _filter(
+ self,
+ filter_query: Any,
+ limit: int,
+ ) -> Union[DocList, List[Dict]]:
+ """
+ Filters the index based on the given filter query.
+
+ :param filter_query: The filter condition.
+ :param limit: The maximum number of results to return.
+ :return: Filter results.
+ """
+
+ self._collection.load()
+
+ result = self._collection.query(
+ expr=filter_query,
+ offset=0,
+ limit=min(limit, self.num_docs()),
+ output_fields=["serialized"],
+ )
+
+ self._collection.release()
+
+ return self._docs_from_query_response(result)
+
+ def _filter_batched(
+ self,
+ filter_queries: Any,
+ limit: int,
+ ) -> Union[List[DocList], List[List[Dict]]]:
+ """
+ Filters the index based on the given batch of filter queries.
+
+ :param filter_queries: The filter conditions.
+ :param limit: The maximum number of results to return for each filter query.
+ :return: Filter results.
+ """
+ return [
+ self._filter(filter_query=query, limit=limit) for query in filter_queries
+ ]
+
+ def _text_search(
+ self,
+ query: str,
+ limit: int,
+ search_field: str = '',
+ ) -> _FindResult:
+ raise NotImplementedError(f'{type(self)} does not support text search.')
+
+ def _text_search_batched(
+ self,
+ queries: Sequence[str],
+ limit: int,
+ search_field: str = '',
+ ) -> _FindResultBatched:
+ raise NotImplementedError(f'{type(self)} does not support text search.')
+
+ def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
+ """index a document into the store
+
+ Not implemented for this backend: the public `index` method of this class
+ inserts documents into the collection itself.
+
+ :param column_to_data: mapping from column name to a generator of values (unused).
+ :raises NotImplementedError: always.
+ """
+ raise NotImplementedError()
+
+ def find(
+ self,
+ query: Union[AnyTensor, BaseDoc],
+ search_field: str = '',
+ limit: int = 10,
+ **kwargs,
+ ) -> FindResult:
+ """Find documents in the index using nearest neighbor search.
+
+ :param query: query vector for KNN/ANN search.
+ Can be either a tensor-like (np.array, torch.Tensor, etc.)
+ with a single axis, or a Document
+ :param search_field: name of the field to search on.
+ Documents in the index are retrieved based on this similarity
+ of this field to the query.
+ :param limit: maximum number of documents to return
+ :return: a named tuple containing `documents` and `scores`
+ """
+ self._logger.debug(f'Executing `find` for search field {search_field}')
+ if search_field != '':
+ raise ValueError(
+ 'Argument search_field is not supported for MilvusDocumentIndex.'
+ 'Set search_field to an empty string to proceed.'
+ )
+
+ search_field = self._field_name
+ if isinstance(query, BaseDoc):
+ query_vec = self._get_values_by_column([query], search_field)[0]
+ else:
+ query_vec = query
+ query_vec_np = self._to_numpy(query_vec)
+ docs, scores = self._find(
+ query_vec_np, search_field=search_field, limit=limit, **kwargs
+ )
+
+ if isinstance(docs, List) and not isinstance(docs, DocList):
+ docs = self._dict_list_to_docarray(docs)
+
+ return FindResult(documents=docs, scores=scores)
+
+ def _find(
+ self,
+ query: np.ndarray,
+ limit: int,
+ search_field: str = '',
+ ) -> _FindResult:
+ """
+ Conducts a search on the index.
+
+ :param query: The vector query to search.
+ :param limit: The maximum number of results to return.
+ :param search_field: The field to search the query.
+ :return: Search results.
+ """
+
+ return self._hybrid_search(query=query, limit=limit, search_field=search_field)
+
+ def _hybrid_search(
+ self,
+ query: np.ndarray,
+ limit: int,
+ search_field: str = '',
+ expr: Optional[str] = None,
+ ):
+ """
+ Conducts a hybrid search on the index.
+
+ :param query: The vector query to search.
+ :param limit: The maximum number of results to return.
+ :param search_field: The field to search the query.
+ :param expr: Boolean expression used for filtering.
+ :return: Search results.
+ """
+ self._collection.load()
+
+ results = self._collection.search(
+ data=[query],
+ anns_field=search_field,
+ param=self._db_config.search_params,
+ limit=limit,
+ offset=0,
+ expr=expr,
+ output_fields=["serialized"],
+ consistency_level=self._db_config.consistency_level,
+ )
+
+ self._collection.release()
+
+ results = next(iter(results), None) # Only consider the first element
+
+ return self._docs_from_find_response(results)
+
+ def find_batched(
+ self,
+ queries: Union[AnyTensor, DocList],
+ search_field: str = '',
+ limit: int = 10,
+ **kwargs,
+ ) -> FindResultBatched:
+ """Find documents in the index using nearest neighbor search.
+
+ :param queries: query vector for KNN/ANN search.
+ Can be either a tensor-like (np.array, torch.Tensor, etc.) with a,
+ or a DocList.
+ If a tensor-like is passed, it should have shape (batch_size, vector_dim)
+ :param search_field: name of the field to search on.
+ Documents in the index are retrieved based on this similarity
+ of this field to the query.
+ :param limit: maximum number of documents to return per query
+ :return: a named tuple containing `documents` and `scores`
+ """
+ self._logger.debug(f'Executing `find_batched` for search field {search_field}')
+
+ if search_field:
+ if '__' in search_field:
+ fields = search_field.split('__')
+ if safe_issubclass(self._schema._get_field_annotation(fields[0]), AnyDocArray): # type: ignore
+ return self._subindices[fields[0]].find_batched(
+ queries,
+ search_field='__'.join(fields[1:]),
+ limit=limit,
+ **kwargs,
+ )
+ if search_field != '':
+ raise ValueError(
+ 'Argument search_field is not supported for MilvusDocumentIndex.'
+ 'Set search_field to an empty string to proceed.'
+ )
+ search_field = self._field_name
+ if isinstance(queries, Sequence):
+ query_vec_list = self._get_values_by_column(queries, search_field)
+ query_vec_np = np.stack(
+ tuple(self._to_numpy(query_vec) for query_vec in query_vec_list)
+ )
+ else:
+ query_vec_np = self._to_numpy(queries)
+
+ da_list, scores = self._find_batched(
+ query_vec_np, search_field=search_field, limit=limit, **kwargs
+ )
+ if (
+ len(da_list) > 0
+ and isinstance(da_list[0], List)
+ and not isinstance(da_list[0], DocList)
+ ):
+ da_list = [self._dict_list_to_docarray(docs) for docs in da_list]
+
+ return FindResultBatched(documents=da_list, scores=scores) # type: ignore
+
+ def _find_batched(
+ self,
+ queries: np.ndarray,
+ limit: int,
+ search_field: str = '',
+ ) -> _FindResultBatched:
+ """
+ Conducts a batched search on the index.
+
+ :param queries: The queries to search.
+ :param limit: The maximum number of results to return for each query.
+ :param search_field: The field to search the queries.
+ :return: Search results.
+ """
+
+ self._collection.load()
+
+ results = self._collection.search(
+ data=queries,
+ anns_field=self._field_name,
+ param=self._db_config.search_params,
+ limit=limit,
+ expr=None,
+ output_fields=["serialized"],
+ consistency_level=self._db_config.consistency_level,
+ )
+
+ self._collection.release()
+
+ documents, scores = zip(
+ *[self._docs_from_find_response(result) for result in results]
+ )
+
+ return _FindResultBatched(
+ documents=list(documents),
+ scores=list(scores),
+ )
+
+ def execute_query(self, query: Any, *args, **kwargs) -> Any:
+ """
+ Executes a hybrid query on the index.
+
+ :param query: Query to execute on the index.
+ :return: Query results.
+ """
+ components: Dict[str, List[Dict[str, Any]]] = {}
+ for component, value in query:
+ if component not in components:
+ components[component] = []
+ components[component].append(value)
+
+ if (
+ len(components) != 2
+ or len(components.get('find', [])) != 1
+ or len(components.get('filter', [])) != 1
+ ):
+ raise ValueError(
+ 'The query must contain exactly one "find" and "filter" components.'
+ )
+
+ expr = components['filter'][0]['filter_query']
+ query = components['find'][0]['query']
+ limit = (
+ components['find'][0].get('limit')
+ or components['filter'][0].get('limit')
+ or 10
+ )
+ docs, scores = self._hybrid_search(
+ query=query,
+ expr=expr,
+ search_field=self._field_name,
+ limit=limit,
+ )
+ if isinstance(docs, List) and not isinstance(docs, DocList):
+ docs = self._dict_list_to_docarray(docs)
+
+ return FindResult(documents=docs, scores=scores)
+
+ def _docs_from_query_response(self, result: Sequence[Dict]) -> DocList[Any]:
+ return DocList[self._schema]( # type: ignore
+ [
+ self._schema.from_base64( # type: ignore
+ result[i]["serialized"], **self._db_config.serialize_config
+ )
+ for i in range(len(result))
+ ]
+ )
+
+ def _docs_from_find_response(self, result: Hits) -> _FindResult:
+ scores: NdArray = NdArray._docarray_from_native(
+ np.array([hit.score for hit in result])
+ )
+
+ return _FindResult(
+ documents=DocList[self.out_schema]( # type: ignore
+ [
+ self.out_schema.from_base64(
+ hit.entity.get('serialized'), **self._db_config.serialize_config
+ )
+ for hit in result
+ ]
+ ),
+ scores=scores,
+ )
+
+ def _always_true_expr(self, primary_key: str) -> str:
+ """
+ Returns a Milvus expression that is always true, thus allowing for the retrieval of all entries in a Collection.
+ Assumes that the primary key is of type DataType.VARCHAR
+
+ :param primary_key: the name of the primary key
+ :return: a Milvus expression that is always true for that primary key
+ """
+ return f'({primary_key} in ["1"]) or ({primary_key} not in ["1"])'
+
+ def _map_embedding(self, embedding: AnyTensor) -> np.ndarray:
+ """
+ Milvus exclusively supports one-dimensional vectors. If multi-dimensional
+ vectors are provided, they will be automatically flattened to ensure compatibility.
+
+ :param embedding: The original raw embedding, which can be in the form of a TensorFlow or PyTorch tensor.
+ :return embedding: A one-dimensional numpy array representing the flattened version of the original embedding.
+ """
+ if embedding is None:
+ raise ValueError(
+ "Embedding is None. Each document must have a valid embedding."
+ )
+
+ embedding = self._to_numpy(embedding)
+ if embedding.ndim > 1:
+ embedding = np.asarray(embedding).squeeze() # type: ignore
+
+ return embedding
+
+ def _doc_exists(self, doc_id: str) -> bool:
+ result = self._collection.query(
+ expr="id in " + str([doc_id]),
+ offset=0,
+ output_fields=["serialized"],
+ )
+
+ return len(result) > 0
diff --git a/docarray/index/backends/mongodb_atlas.py b/docarray/index/backends/mongodb_atlas.py
new file mode 100644
index 00000000000..f1ccdec02d2
--- /dev/null
+++ b/docarray/index/backends/mongodb_atlas.py
@@ -0,0 +1,796 @@
+import collections
+import logging
+from dataclasses import dataclass, field
+from functools import cached_property
+from typing import (
+ Any,
+ Dict,
+ Generator,
+ Generic,
+ List,
+ NamedTuple,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+import bson
+import numpy as np
+from pymongo import MongoClient
+
+from docarray import BaseDoc, DocList, handler
+from docarray.index.abstract import BaseDocIndex, _raise_not_composable
+from docarray.index.backends.helper import _collect_query_required_args
+from docarray.typing import AnyTensor
+from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from docarray.utils._internal._typing import safe_issubclass
+from docarray.utils.find import _FindResult, _FindResultBatched
+
+# Module-level logger; the `handler` imported from docarray is attached to it.
+logger = logging.getLogger(__name__)
+logger.addHandler(handler)
+
+
+# NOTE(review): the use of the two constants below lies outside this chunk;
+# presumably they bound/scale the candidate count of Atlas vector search - confirm.
+MAX_CANDIDATES = 10_000
+OVERSAMPLING_FACTOR = 10
+# Type variable for the Document schema an index is parametrized with
+TSchema = TypeVar('TSchema', bound=BaseDoc)
+
+
+class HybridResult(NamedTuple):
+ """Adds breakdown of scores into vector and text components."""
+
+ # Matched documents, either as a DocList or as a list of raw dictionaries.
+ documents: Union[DocList, List[Dict[str, Any]]]
+ # Scores belonging to `documents`.
+ scores: AnyTensor
+ # Score lists keyed by component name; `execute_query` starts filling this from
+ # result keys containing "score" (the rest is outside this chunk - confirm details).
+ score_breakdown: Dict[str, List[Any]]
+
+
+class MongoDBAtlasDocumentIndex(BaseDocIndex, Generic[TSchema]):
+ """DocumentIndex backed by MongoDB Atlas Vector Store.
+
+ MongoDB Atlas provides full Text, Vector, and Hybrid Search
+ and can store structured data, text and vector indexes
+ in the same Collection (Index).
+
+ Atlas provides efficient index and search on vector embeddings
+ using the Hierarchical Navigable Small Worlds (HNSW) algorithm.
+
+ For documentation, see the following.
+ * Text Search: https://www.mongodb.com/docs/atlas/atlas-search/atlas-search-overview/
+ * Vector Search: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/
+ * Hybrid Search: https://www.mongodb.com/docs/atlas/atlas-vector-search/tutorials/reciprocal-rank-fusion/
+ """
+
+ def __init__(self, db_config=None, **kwargs):
+ super().__init__(db_config=db_config, **kwargs)
+ logger.info(f'{self.__class__.__name__} has been initialized')
+
+ @property
+ def index_name(self):
+ """The name of the index/collection in the database.
+
+ Note that in MongoDB Atlas, one has Collections (analogous to Tables),
+ which can have Search Indexes. They are distinct.
+ DocArray tends to consider them together.
+
+ The index_name can be set when initializing MongoDBAtlasDocumentIndex.
+ The easiest way is to pass index_name= as a kwarg.
+ Otherwise, a rational default uses the name of the DocumentTypes that it contains.
+ """
+
+ if self._db_config.index_name is not None:
+ return self._db_config.index_name
+ else:
+ # Create a reasonable default
+ if not self._schema:
+ raise ValueError(
+ 'A MongoDBAtlasDocumentIndex must be typed with a Document type.'
+ 'To do so, use the syntax: MongoDBAtlasDocumentIndex[DocumentType]'
+ )
+ schema_name = self._schema.__name__.lower()
+ logger.debug(f"db_config.index_name was not set. Using {schema_name}")
+ return schema_name
+
+ @property
+ def _database_name(self):
+ return self._db_config.database_name
+
+ @cached_property
+ def _client(self):
+ return self._connect_to_mongodb_atlas(
+ atlas_connection_uri=self._db_config.mongo_connection_uri
+ )
+
+ @property
+ def _collection(self):
+ """MongoDB Collection"""
+ return self._client[self._database_name][self.index_name]
+
+ @staticmethod
+ def _connect_to_mongodb_atlas(atlas_connection_uri: str):
+ """
+ Establish a connection to MongoDB Atlas.
+ """
+
+ client = MongoClient(
+ atlas_connection_uri,
+ # driver=DriverInfo(name="docarray", version=version("docarray"))
+ )
+ return client
+
+ def _create_indexes(self):
+ """Create a new index in the MongoDB database if it doesn't already exist.
+
+ NOTE(review): the body consists of this docstring only, so calling the
+ method currently does nothing and returns None.
+ """
+
+ def _check_index_exists(self, index_name: str) -> bool:
+ """
+ Check if an index exists in the MongoDB Atlas database.
+
+ NOTE(review): the body consists of this docstring only, so the method
+ currently returns None for every input instead of the documented bool.
+
+ :param index_name: The name of the index.
+ :return: True if the index exists, False otherwise.
+ """
+
+ @dataclass
+ class Query:
+ """Dataclass describing a query.
+
+ Instances are produced by `QueryBuilder.build()` and consumed by `execute_query`.
+ """
+
+ # Search field name -> query vector (QueryBuilder.build averages repeated `find` vectors).
+ vector_fields: Optional[Dict[str, np.ndarray]]
+ # kwargs of every recorded `filter` call; `execute_query` reads their 'query' entry.
+ filters: Optional[List[Any]]
+ # kwargs of every recorded `text_search` call.
+ text_searches: Optional[List[Any]]
+ # Maximum number of results to return.
+ limit: int
+
+ class QueryBuilder(BaseDocIndex.QueryBuilder):
+ """Compose complex queries containing vector search (find), text_search, and filters.
+
+ Arguments to `find` are vectors of embeddings, text_search expects strings,
+ and filters expect dicts of MongoDB Query Language (MDB).
+
+
+ NOTE: When doing Hybrid Search, pay close attention to the interpretation and use of inputs,
+ particularly when multiple calls are made of the same method (find, text_search, filter).
+ * find (Vector Search): Embedding vectors will be averaged. The penalty/weight defined in DBConfig will not change.
+ * text_search: Individual searches are performed, each with the same penalty/weight.
+ * filter: Within Vector Search, performs efficient k-NN filtering with the Lucene engine
+ """
+
+ def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None):
+ super().__init__()
+ # list of tuples (method name, kwargs)
+ self._queries: List[Tuple[str, Dict]] = query or []
+
+ def build(self, limit: int = 1, *args, **kwargs) -> Any:
+ """Build a `Query` that can be passed to `execute_query`."""
+ search_fields: Dict[str, np.ndarray] = collections.defaultdict(list)
+ filters: List[Any] = []
+ text_searches: List[Any] = []
+ for method, kwargs in self._queries:
+ if method == 'find':
+ search_field = kwargs['search_field']
+ search_fields[search_field].append(kwargs["query"])
+
+ elif method == 'filter':
+ filters.append(kwargs)
+ else:
+ text_searches.append(kwargs)
+
+ vector_fields = {
+ field: np.average(vectors, axis=0)
+ for field, vectors in search_fields.items()
+ }
+ return MongoDBAtlasDocumentIndex.Query(
+ vector_fields=vector_fields,
+ filters=filters,
+ text_searches=text_searches,
+ limit=limit,
+ )
+
+ find = _collect_query_required_args('find', {'search_field', 'query'})
+ filter = _collect_query_required_args('filter', {'query'})
+ text_search = _collect_query_required_args(
+ 'text_search', {'search_field', 'query'}
+ )
+
+ find_batched = _raise_not_composable('find_batched')
+ filter_batched = _raise_not_composable('filter_batched')
+ text_search_batched = _raise_not_composable('text_search_batched')
+
+ def execute_query(
+ self, query: Any, *args, score_breakdown=True, **kwargs
+ ) -> Any: # _FindResult:
+ """Execute a Query on the database.
+
+ :param query: the query to execute. The output of this Document index's `QueryBuilder.build()` method.
+ :param args: positional arguments to pass to the query
+ :param score_breakdown: Will provide breakdown of scores into text and vector components for Hybrid Searches.
+ :param kwargs: keyword arguments to pass to the query
+ :return: the result of the query
+ """
+ if not isinstance(query, MongoDBAtlasDocumentIndex.Query):
+ raise ValueError(
+ "Expected MongoDBAtlasDocumentIndex.Query. Found {type(query)=}."
+ "For native calls to MongoDBAtlasDocumentIndex, simply call filter()"
+ )
+
+ if len(query.vector_fields) > 1:
+ self._logger.warning(
+ f"{len(query.vector_fields)} embedding vectors have been provided to the query. They will be averaged."
+ )
+ if len(query.text_searches) > 1:
+ self._logger.warning(
+ f"{len(query.text_searches)} text searches will be performed, and each receive a ranked score."
+ )
+
+ # collect filters
+ filters: List[Dict[str, Any]] = []
+ for filter_ in query.filters:
+ filters.append(filter_['query'])
+
+ # check if hybrid search is needed.
+ hybrid = len(query.vector_fields) + len(query.text_searches) > 1
+ if hybrid:
+ if len(query.vector_fields) > 1:
+ raise NotImplementedError(
+ "Hybrid Search on multiple Vector Indexes has yet to be done."
+ )
+ pipeline = self._hybrid_search(
+ query.vector_fields, query.text_searches, filters, query.limit
+ )
+ else:
+ if query.text_searches:
+ # it is a simple text search, perhaps with filters.
+ text_stage = self._text_search_stage(**query.text_searches[0])
+ pipeline = [
+ text_stage,
+ {"$match": {"$and": filters} if filters else {}},
+ {
+ '$project': self._project_fields(
+ extra_fields={"score": {'$meta': 'searchScore'}}
+ )
+ },
+ {"$limit": query.limit},
+ ]
+ elif query.vector_fields:
+ # it is a simple vector search, perhaps with filters.
+ assert (
+ len(query.vector_fields) == 1
+ ), "Query contains more than one vector_field."
+ field, vector_query = list(query.vector_fields.items())[0]
+ pipeline = [
+ self._vector_search_stage(
+ query=vector_query,
+ search_field=field,
+ limit=query.limit,
+ filters=filters,
+ ),
+ {
+ '$project': self._project_fields(
+ extra_fields={"score": {'$meta': 'vectorSearchScore'}}
+ )
+ },
+ ]
+ # it is only a filter search.
+ else:
+ pipeline = [{"$match": {"$and": filters}}]
+
+ with self._collection.aggregate(pipeline) as cursor:
+ results, scores = self._mongo_to_docs(cursor)
+ docs = self._dict_list_to_docarray(results)
+
+ if hybrid and score_breakdown and results:
+ score_breakdown = collections.defaultdict(list)
+ score_fields = [key for key in results[0] if "score" in key]
+ for res in results:
+ score_breakdown["id"].append(res["id"])
+ for sf in score_fields:
+ score_breakdown[sf].append(res[sf])
+ logger.debug(score_breakdown)
+ return HybridResult(
+ documents=docs, scores=scores, score_breakdown=score_breakdown
+ )
+
+ return _FindResult(documents=docs, scores=scores)
+
+ # Static configuration of the MongoDB Atlas document index.
+ @dataclass
+ class DBConfig(BaseDocIndex.DBConfig):
+ # NOTE(review): presumably the connection string handed to the MongoDB client;
+ # the client construction is outside this section — confirm.
+ mongo_connection_uri: str = 'localhost'
+ index_name: Optional[str] = None
+ database_name: Optional[str] = "default"
+ # Per-db-type column defaults. Keys read elsewhere in this class:
+ # - 'oversample_factor' / 'max_candidates': bound `numCandidates` in
+ # `_vector_search_stage`
+ # - 'penalty': added to the rank in `_reciprocal_rank_stage`
+ # - 'operator': the `$search` operator used by `_text_search_stage`
+ # - 'index_name': looked up by `_get_column_db_index`, which raises for
+ # vector/text columns when it is not a string
+ default_column_config: Dict[Type, Dict[str, Any]] = field(
+ default_factory=lambda: collections.defaultdict(
+ dict,
+ {
+ bson.BSONARR: {
+ 'distance': 'COSINE',
+ 'oversample_factor': OVERSAMPLING_FACTOR,
+ 'max_candidates': MAX_CANDIDATES,
+ 'indexed': False,
+ 'index_name': None,
+ 'penalty': 5,
+ },
+ bson.BSONSTR: {
+ 'indexed': False,
+ 'index_name': None,
+ 'operator': 'phrase',
+ 'penalty': 1,
+ },
+ },
+ )
+ )
+
+ # Declares no fields of its own; everything is inherited from
+ # `BaseDocIndex.RuntimeConfig`.
+ @dataclass
+ class RuntimeConfig(BaseDocIndex.RuntimeConfig):
+ ...
+
+ def python_type_to_db_type(self, python_type: Type) -> Any:
+ """Map python type to database type.
+ Takes any python type and returns the corresponding database column type.
+
+ :param python_type: a python type.
+ :return: the corresponding database column type,
+ or None if ``python_type`` is not supported.
+ """
+
+ type_map = {
+ int: bson.BSONNUM,
+ float: bson.BSONDEC,
+ collections.OrderedDict: bson.BSONOBJ,
+ str: bson.BSONSTR,
+ bytes: bson.BSONBIN,
+ dict: bson.BSONOBJ,
+ np.ndarray: bson.BSONARR,
+ AbstractTensor: bson.BSONARR,
+ }
+
+ for py_type, mongo_types in type_map.items():
+ if safe_issubclass(python_type, py_type):
+ return mongo_types
+ raise ValueError(f'Unsupported column type for {type(self)}: {python_type}')
+
+ def _doc_to_mongo(self, doc):
+ result = doc.copy()
+
+ for name in result:
+ if self._column_infos[name].db_type == bson.BSONARR:
+ result[name] = list(result[name])
+
+ result["_id"] = result.pop("id")
+ return result
+
+ def _docs_to_mongo(self, docs):
+ return [self._doc_to_mongo(doc) for doc in docs]
+
+ @staticmethod
+ def _mongo_to_doc(mongo_doc: dict) -> dict:
+ result = mongo_doc.copy()
+ result["id"] = result.pop("_id")
+ score = result.get("score", None)
+ return result, score
+
+ @staticmethod
+ def _mongo_to_docs(mongo_docs: Generator[Dict, None, None]) -> List[dict]:
+ docs = []
+ scores = []
+ for mongo_doc in mongo_docs:
+ doc, score = MongoDBAtlasDocumentIndex._mongo_to_doc(mongo_doc)
+ docs.append(doc)
+ scores.append(score)
+
+ return docs, scores
+
+ def _get_oversampling_factor(self, search_field: str) -> int:
+ return self._column_infos[search_field].config["oversample_factor"]
+
+ def _get_max_candidates(self, search_field: str) -> int:
+ return self._column_infos[search_field].config["max_candidates"]
+
+ def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
+ """Add and Index Documents to the datastore
+
+ The input format is aimed towards column vectors, which is not
+ the natural fit for MongoDB Collections, but we have chosen
+ not to override BaseDocIndex.index as it provides valuable validation.
+ This may change in the future.
+
+ :param column_to_data: is a dictionary from column name to a generator
+ """
+ self._index_subindex(column_to_data)
+ docs: List[Dict[str, Any]] = []
+ while True:
+ try:
+ doc = {key: next(column_to_data[key]) for key in column_to_data}
+ mongo_doc = self._doc_to_mongo(doc)
+ docs.append(mongo_doc)
+ except StopIteration:
+ break
+ self._collection.insert_many(docs)
+
+ def num_docs(self) -> int:
+ """Return the number of indexed documents"""
+ return self._collection.count_documents({})
+
+ @property
+ def _is_index_empty(self) -> bool:
+ """
+ Check if index is empty by comparing the number of documents to zero.
+ :return: True if the index is empty, False otherwise.
+ """
+ return self.num_docs() == 0
+
+ def _del_items(self, doc_ids: Sequence[str]) -> None:
+ """Delete Documents from the index.
+
+ :param doc_ids: ids to delete from the Document Store
+ """
+ mg_filter = {"_id": {"$in": doc_ids}}
+ self._collection.delete_many(mg_filter)
+
+ def _get_items(
+ self, doc_ids: Sequence[str]
+ ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]:
+ """Get Documents from the index, by `id`.
+ If no document is found, a KeyError is raised.
+
+ :param doc_ids: ids to get from the Document index
+ :return: Sequence of Documents, sorted corresponding to the order of `doc_ids`. Duplicate `doc_ids` can be omitted in the output.
+ """
+ mg_filter = {"_id": {"$in": doc_ids}}
+ docs = self._collection.find(mg_filter)
+ docs, _ = self._mongo_to_docs(docs)
+
+ if not docs:
+ raise KeyError(f'No document with id {doc_ids} found')
+ return docs
+
+ def _reciprocal_rank_stage(self, search_field: str, score_field: str):
+ penalty = self._column_infos[search_field].config["penalty"]
+ projection_fields = {
+ key: f"$docs.{key}" for key in self._column_infos.keys() if key != "id"
+ }
+ projection_fields["_id"] = "$docs._id"
+ projection_fields[score_field] = 1
+
+ return [
+ {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
+ {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
+ {
+ "$addFields": {
+ score_field: {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]}
+ }
+ },
+ {'$project': projection_fields},
+ ]
+
+ def _add_stage_to_pipeline(self, pipeline: List[Any], stage: Dict[str, Any]):
+ if pipeline:
+ pipeline.append(
+ {"$unionWith": {"coll": self.index_name, "pipeline": stage}}
+ )
+ else:
+ pipeline.extend(stage)
+ return pipeline
+
+ def _final_stage(self, scores_fields, limit):
+ """Sum individual scores, sort, and apply limit."""
+ doc_fields = self._column_infos.keys()
+ grouped_fields = {
+ key: {"$first": f"${key}"} for key in doc_fields if key != "_id"
+ }
+ best_score = {score: {'$max': f'${score}'} for score in scores_fields}
+ final_pipeline = [
+ {"$group": {"_id": "$_id", **grouped_fields, **best_score}},
+ {
+ "$project": {
+ **{doc_field: 1 for doc_field in doc_fields},
+ **{score: {"$ifNull": [f"${score}", 0]} for score in scores_fields},
+ }
+ },
+ {
+ "$addFields": {
+ "score": {"$add": [f"${score}" for score in scores_fields]},
+ }
+ },
+ {"$sort": {"score": -1}},
+ {"$limit": limit},
+ ]
+ return final_pipeline
+
+ @staticmethod
+ def _score_field(search_field: str, search_field_counts: Dict[str, int]):
+ score_field = f"{search_field}_score"
+ count = search_field_counts[search_field]
+ if count > 1:
+ score_field += str(count)
+ return score_field
+
+ def _hybrid_search(
+ self,
+ vector_queries: Dict[str, Any],
+ text_queries: List[Dict[str, Any]],
+ filters: Dict[str, Any],
+ limit: int,
+ ):
+ hybrid_pipeline = [] # combined aggregate pipeline
+ search_field_counts = collections.defaultdict(
+ int
+ ) # stores count of calls on same search field
+ score_fields = [] # names given to scores of each search stage
+ for search_field, query in vector_queries.items():
+ search_field_counts[search_field] += 1
+ vector_stage = self._vector_search_stage(
+ query=query,
+ search_field=search_field,
+ limit=limit,
+ filters=filters,
+ )
+ score_field = self._score_field(search_field, search_field_counts)
+ score_fields.append(score_field)
+ vector_pipeline = [
+ vector_stage,
+ *self._reciprocal_rank_stage(search_field, score_field),
+ ]
+ self._add_stage_to_pipeline(hybrid_pipeline, vector_pipeline)
+
+ for kwargs in text_queries:
+ search_field_counts[kwargs["search_field"]] += 1
+ text_stage = self._text_search_stage(**kwargs)
+ search_field = kwargs["search_field"]
+ score_field = self._score_field(search_field, search_field_counts)
+ score_fields.append(score_field)
+ reciprocal_rank_stage = self._reciprocal_rank_stage(
+ search_field, score_field
+ )
+ text_pipeline = [
+ text_stage,
+ {"$match": {"$and": filters} if filters else {}},
+ {"$limit": limit},
+ *reciprocal_rank_stage,
+ ]
+ self._add_stage_to_pipeline(hybrid_pipeline, text_pipeline)
+
+ hybrid_pipeline += self._final_stage(score_fields, limit)
+ return hybrid_pipeline
+
+ def _vector_search_stage(
+ self,
+ query: np.ndarray,
+ search_field: str,
+ limit: int,
+ filters: List[Dict[str, Any]] = None,
+ ) -> Dict[str, Any]:
+
+ search_index_name = self._get_column_db_index(search_field)
+ oversampling_factor = self._get_oversampling_factor(search_field)
+ max_candidates = self._get_max_candidates(search_field)
+ query = query.astype(np.float64).tolist()
+
+ stage = {
+ '$vectorSearch': {
+ 'index': search_index_name,
+ 'path': search_field,
+ 'queryVector': query,
+ 'numCandidates': min(limit * oversampling_factor, max_candidates),
+ 'limit': limit,
+ }
+ }
+ if filters:
+ stage['$vectorSearch']['filter'] = {"$and": filters}
+ return stage
+
+ def _text_search_stage(
+ self,
+ query: str,
+ search_field: str,
+ ) -> Dict[str, Any]:
+ operator = self._column_infos[search_field].config["operator"]
+ index = self._get_column_db_index(search_field)
+ return {
+ "$search": {
+ "index": index,
+ operator: {"query": query, "path": search_field},
+ }
+ }
+
+ def _doc_exists(self, doc_id: str) -> bool:
+ """
+ Checks if a given document exists in the index.
+
+ :param doc_id: The id of a document to check.
+ :return: True if the document exists in the index, False otherwise.
+ """
+ doc = self._collection.find_one({"_id": doc_id})
+ return bool(doc)
+
+ def _find(
+ self,
+ query: np.ndarray,
+ limit: int,
+ search_field: str = '',
+ ) -> _FindResult:
+ """Find documents in the index
+
+ :param query: query vector for KNN/ANN search. Has single axis.
+ :param limit: maximum number of documents to return per query
+ :param search_field: name of the field to search on
+ :return: a named tuple containing `documents` and `scores`
+ """
+ # NOTE: in standard implementations,
+ # `search_field` is equal to the column name to search on
+
+ vector_search_stage = self._vector_search_stage(query, search_field, limit)
+
+ pipeline = [
+ vector_search_stage,
+ {
+ '$project': self._project_fields(
+ extra_fields={"score": {'$meta': 'vectorSearchScore'}}
+ )
+ },
+ ]
+
+ with self._collection.aggregate(pipeline) as cursor:
+ documents, scores = self._mongo_to_docs(cursor)
+
+ return _FindResult(documents=documents, scores=scores)
+
+ def _find_batched(
+ self, queries: np.ndarray, limit: int, search_field: str = ''
+ ) -> _FindResultBatched:
+ """Find documents in the index
+
+ :param queries: query vectors for KNN/ANN search.
+ Has shape (batch_size, vector_dim)
+ :param limit: maximum number of documents to return
+ :param search_field: name of the field to search on
+ :return: a named tuple containing `documents` and `scores`
+ """
+ docs, scores = [], []
+ for query in queries:
+ results = self._find(query=query, search_field=search_field, limit=limit)
+ docs.append(results.documents)
+ scores.append(results.scores)
+
+ return _FindResultBatched(documents=docs, scores=scores)
+
+ def _get_column_db_index(self, column_name: str) -> Optional[str]:
+ """
+ Retrieve the index name associated with the specified column name.
+
+ Parameters:
+ column_name (str): The name of the column.
+
+ Returns:
+ Optional[str]: The index name associated with the specified column name, or None if not found.
+ """
+ index_name = self._column_infos[column_name].config.get("index_name")
+
+ is_vector_index = safe_issubclass(
+ self._column_infos[column_name].docarray_type, AbstractTensor
+ )
+ is_text_index = safe_issubclass(
+ self._column_infos[column_name].docarray_type, str
+ )
+
+ if index_name is None or not isinstance(index_name, str):
+ if is_vector_index:
+ raise ValueError(
+ f'The column {column_name} for MongoDBAtlasDocumentIndex should be associated '
+ 'with an Atlas Vector Index.'
+ )
+ elif is_text_index:
+ raise ValueError(
+ f'The column {column_name} for MongoDBAtlasDocumentIndex should be associated '
+ 'with an Atlas Index.'
+ )
+ if not (is_vector_index or is_text_index):
+ raise ValueError(
+ f'The column {column_name} for MongoDBAtlasDocumentIndex cannot be associated to an index'
+ )
+
+ return index_name
+
+ def _project_fields(self, extra_fields: Dict[str, Any] = None) -> dict:
+ """
+ Create a projection dictionary to include all fields defined in the column information.
+
+ Returns:
+ dict: A dictionary where each field key from the column information is mapped to the value 1,
+ indicating that the field should be included in the projection.
+ """
+
+ fields = {key: 1 for key in self._column_infos.keys() if key != "id"}
+ fields["_id"] = 1
+ if extra_fields:
+ fields.update(extra_fields)
+ return fields
+
+ def _filter(
+ self,
+ filter_query: Any,
+ limit: int,
+ ) -> Union[DocList, List[Dict]]:
+ """Find documents in the index based on a filter query
+
+ :param filter_query: the DB specific filter query to execute
+ :param limit: maximum number of documents to return
+ :return: a DocList containing the documents that match the filter query
+ """
+ with self._collection.find(filter_query, limit=limit) as cursor:
+ return self._mongo_to_docs(cursor)[0]
+
+ def _filter_batched(
+ self,
+ filter_queries: Any,
+ limit: int,
+ ) -> Union[List[DocList], List[List[Dict]]]:
+ """Find documents in the index based on multiple filter queries.
+ Each query is considered individually, and results are returned per query.
+
+ :param filter_queries: the DB specific filter queries to execute
+ :param limit: maximum number of documents to return per query
+ :return: List of DocLists containing the documents that match the filter
+ queries
+ """
+ return [self._filter(query, limit) for query in filter_queries]
+
+ def _text_search(
+ self,
+ query: str,
+ limit: int,
+ search_field: str = '',
+ ) -> _FindResult:
+ """Find documents in the index based on a text search query
+
+ :param query: The text to search for
+ :param limit: maximum number of documents to return
+ :param search_field: name of the field to search on
+ :return: a named tuple containing `documents` and `scores`
+ """
+ text_stage = self._text_search_stage(query=query, search_field=search_field)
+
+ pipeline = [
+ text_stage,
+ {
+ '$project': self._project_fields(
+ extra_fields={'score': {'$meta': 'searchScore'}}
+ )
+ },
+ {"$limit": limit},
+ ]
+
+ with self._collection.aggregate(pipeline) as cursor:
+ documents, scores = self._mongo_to_docs(cursor)
+
+ return _FindResult(documents=documents, scores=scores)
+
+ def _text_search_batched(
+ self,
+ queries: Sequence[str],
+ limit: int,
+ search_field: str = '',
+ ) -> _FindResultBatched:
+ """Find documents in the index based on a text search query
+
+ :param queries: The texts to search for
+ :param limit: maximum number of documents to return per query
+ :param search_field: name of the field to search on
+ :return: a named tuple containing `documents` and `scores`
+ """
+ # NOTE: in standard implementations,
+ # `search_field` is equal to the column name to search on
+ documents, scores = [], []
+ for query in queries:
+ results = self._text_search(
+ query=query, search_field=search_field, limit=limit
+ )
+ documents.append(results.documents)
+ scores.append(results.scores)
+ return _FindResultBatched(documents=documents, scores=scores)
+
+ def _filter_by_parent_id(self, id: str) -> Optional[List[str]]:
+ """Filter the ids of the subindex documents given id of root document.
+
+ :param id: the root document id to filter by
+ :return: a list of ids of the subindex documents
+ """
+ with self._collection.find({"parent_id": id}, projection={"_id": 1}) as cursor:
+ return [doc["_id"] for doc in cursor]
diff --git a/docarray/index/backends/qdrant.py b/docarray/index/backends/qdrant.py
index e2e503d593c..6f1330f9eab 100644
--- a/docarray/index/backends/qdrant.py
+++ b/docarray/index/backends/qdrant.py
@@ -21,6 +21,7 @@
import docarray.typing.id
from docarray import BaseDoc, DocList
+from docarray.array.any_array import AnyDocArray
from docarray.index.abstract import (
BaseDocIndex,
_ColumnInfo,
@@ -29,6 +30,7 @@
)
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from docarray.utils._internal._typing import safe_issubclass
from docarray.utils._internal.misc import import_library, torch_imported
from docarray.utils.find import _FindResult
@@ -65,6 +67,10 @@ class QdrantDocumentIndex(BaseDocIndex, Generic[TSchema]):
def __init__(self, db_config=None, **kwargs):
"""Initialize QdrantDocumentIndex"""
+ if db_config is not None and getattr(
+ db_config, 'index_name'
+ ): # this is needed for subindices
+ db_config.collection_name = db_config.index_name
super().__init__(db_config=db_config, **kwargs)
self._db_config: QdrantDocumentIndex.DBConfig = cast(
QdrantDocumentIndex.DBConfig, self._db_config
@@ -98,6 +104,10 @@ def collection_name(self):
return self._db_config.collection_name or default_collection_name
+ @property
+ def index_name(self):
+ return self.collection_name
+
@dataclass
class Query:
"""Dataclass describing a query."""
@@ -232,11 +242,6 @@ class DBConfig(BaseDocIndex.DBConfig):
optimizers_config: Optional[types.OptimizersConfigDiff] = None
wal_config: Optional[types.WalConfigDiff] = None
quantization_config: Optional[types.QuantizationConfig] = None
-
- @dataclass
- class RuntimeConfig(BaseDocIndex.RuntimeConfig):
- """Dataclass that contains all "dynamic" configurations of QdrantDocumentIndex."""
-
default_column_config: Dict[Type, Dict[str, Any]] = field(
default_factory=lambda: {
'id': {}, # type: ignore[dict-item]
@@ -245,6 +250,18 @@ class RuntimeConfig(BaseDocIndex.RuntimeConfig):
}
)
+ def __post_init__(self):
+ if self.collection_name is None and self.index_name is not None:
+ self.collection_name = self.index_name
+ if self.index_name is None and self.collection_name is not None:
+ self.index_name = self.collection_name
+
+ @dataclass
+ class RuntimeConfig(BaseDocIndex.RuntimeConfig):
+ """Dataclass that contains all "dynamic" configurations of QdrantDocumentIndex."""
+
+ pass
+
def python_type_to_db_type(self, python_type: Type) -> Any:
"""Map python type to database type.
Takes any python type and returns the corresponding database column type.
@@ -252,10 +269,10 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
:param python_type: a python type.
:return: the corresponding database column type.
"""
- if any(issubclass(python_type, vt) for vt in QDRANT_PY_VECTOR_TYPES):
+ if any(safe_issubclass(python_type, vt) for vt in QDRANT_PY_VECTOR_TYPES):
return 'vector'
- if issubclass(python_type, docarray.typing.id.ID):
+ if safe_issubclass(python_type, docarray.typing.id.ID):
return 'id'
return 'payload'
@@ -264,11 +281,14 @@ def _initialize_collection(self):
try:
self._client.get_collection(self.collection_name)
except (UnexpectedResponse, RpcError, ValueError):
- vectors_config = {
- column_name: self._to_qdrant_vector_params(column_info)
- for column_name, column_info in self._column_infos.items()
- if column_info.db_type == 'vector'
- }
+ vectors_config = {}
+
+ for column_name, column_info in self._column_infos.items():
+ if column_info.db_type == 'vector':
+ vectors_config[column_name] = self._to_qdrant_vector_params(
+ column_info
+ )
+
self._client.create_collection(
collection_name=self.collection_name,
vectors_config=vectors_config,
@@ -288,6 +308,8 @@ def _initialize_collection(self):
)
def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
+ self._index_subindex(column_to_data)
+
rows = self._transpose_col_value_dict(column_to_data)
# TODO: add batching the documents to avoid timeouts
points = [self._build_point_from_row(row) for row in rows]
@@ -302,6 +324,17 @@ def num_docs(self) -> int:
"""
return self._client.count(collection_name=self.collection_name).count
+ def _doc_exists(self, doc_id: str) -> bool:
+ response, _ = self._client.scroll(
+ collection_name=self.index_name,
+ scroll_filter=rest.Filter(
+ must=[
+ rest.HasIdCondition(has_id=[self._to_qdrant_id(doc_id)]),
+ ],
+ ),
+ )
+ return len(response) > 0
+
def _del_items(self, doc_ids: Sequence[str]):
items = self._get_items(doc_ids)
if len(items) < len(doc_ids):
@@ -332,7 +365,10 @@ def _get_items(
with_payload=True,
with_vectors=True,
)
- return [self._convert_to_doc(point) for point in response]
+ return sorted(
+ [self._convert_to_doc(point) for point in response],
+ key=lambda x: doc_ids.index(x['id']),
+ )
def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocList:
"""
@@ -532,11 +568,29 @@ def _text_search_batched(
],
)
+ def _filter_by_parent_id(self, id: str) -> Optional[List[str]]:
+ response, _ = self._client.scroll(
+ collection_name=self.collection_name, # type: ignore
+ scroll_filter=rest.Filter(
+ must=[
+ rest.FieldCondition(
+ key='parent_id', match=rest.MatchValue(value=id)
+ )
+ ]
+ ),
+ with_payload=rest.PayloadSelectorInclude(include=['id']),
+ )
+
+ ids = [point.payload['id'] for point in response] # type: ignore
+ return ids
+
def _build_point_from_row(self, row: Dict[str, Any]) -> rest.PointStruct:
point_id = self._to_qdrant_id(row.get('id'))
vectors: Dict[str, List[float]] = {}
payload: Dict[str, Any] = {'__generated_vectors': []}
for column_name, column_info in self._column_infos.items():
+ if safe_issubclass(column_info.docarray_type, AnyDocArray):
+ continue
if column_info.db_type in ['id', 'payload']:
payload[column_name] = row.get(column_name)
continue
@@ -578,7 +632,11 @@ def _convert_to_doc(
self, point: Union[rest.ScoredPoint, rest.Record]
) -> Dict[str, Any]:
document = cast(Dict[str, Any], point.payload)
- generated_vectors = document.pop('__generated_vectors')
+ generated_vectors = (
+ document.pop('__generated_vectors')
+ if '__generated_vectors' in document
+ else []
+ )
vectors = point.vector if point.vector else dict()
if not isinstance(vectors, dict):
vectors = {'__default__': vectors}
diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py
new file mode 100644
index 00000000000..2a338b424aa
--- /dev/null
+++ b/docarray/index/backends/redis.py
@@ -0,0 +1,606 @@
+from collections import defaultdict
+from typing import (
+ TypeVar,
+ Generic,
+ Optional,
+ List,
+ Dict,
+ Any,
+ Sequence,
+ Union,
+ Generator,
+ Type,
+ cast,
+ TYPE_CHECKING,
+ Iterator,
+ Mapping,
+ Tuple,
+)
+from dataclasses import dataclass, field
+
+import json
+import numpy as np
+from numpy import ndarray
+
+from docarray.array import AnyDocArray
+from docarray.index.backends.helper import _collect_query_args
+from docarray import BaseDoc, DocList
+from docarray.index.abstract import (
+ BaseDocIndex,
+ _raise_not_composable,
+)
+from docarray.typing import NdArray
+from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from docarray.utils._internal._typing import safe_issubclass
+from docarray.utils._internal.misc import import_library
+from docarray.utils.find import _FindResultBatched, _FindResult, FindResult
+
+if TYPE_CHECKING:
+ import redis
+ from redis.commands.search.query import Query
+ from redis.commands.search.field import ( # type: ignore[import]
+ NumericField,
+ TextField,
+ VectorField,
+ TagField,
+ )
+ from redis.commands.search.indexDefinition import IndexDefinition, IndexType # type: ignore[import]
+else:
+ redis = import_library('redis')
+
+ from redis.commands.search.field import (
+ NumericField,
+ TextField,
+ VectorField,
+ TagField,
+ )
+ from redis.commands.search.indexDefinition import IndexDefinition, IndexType
+ from redis.commands.search.query import Query
+
+# Schema type parameter of the index, bounded by `BaseDoc`.
+TSchema = TypeVar('TSchema', bound=BaseDoc)
+
+# Distance metrics accepted for vector columns (validated in `_create_index`).
+VALID_DISTANCES = ['L2', 'IP', 'COSINE']
+# Vector index algorithms accepted by `_create_index`.
+VALID_ALGORITHMS = ['FLAT', 'HNSW']
+# Text scorers accepted by `DBConfig.__post_init__` (compared in upper case).
+VALID_TEXT_SCORERS = [
+ 'BM25',
+ 'TFIDF',
+ 'TFIDF.DOCNORM',
+ 'DISMAX',
+ 'DOCSCORE',
+ 'HAMMING',
+]
+
+
+class RedisDocumentIndex(BaseDocIndex, Generic[TSchema]):
+ def __init__(self, db_config=None, **kwargs):
+ """Initialize RedisDocumentIndex"""
+ super().__init__(db_config=db_config, **kwargs)
+ self._db_config = cast(RedisDocumentIndex.DBConfig, self._db_config)
+
+ self._runtime_config: RedisDocumentIndex.RuntimeConfig = cast(
+ RedisDocumentIndex.RuntimeConfig, self._runtime_config
+ )
+ self._prefix = self.index_name + ':'
+ self._text_scorer = self._db_config.text_scorer
+ # initialize Redis client
+ self._client = redis.Redis(
+ host=self._db_config.host,
+ port=self._db_config.port,
+ username=self._db_config.username,
+ password=self._db_config.password,
+ decode_responses=False,
+ )
+ self._create_index()
+ self._logger.info(f'{self.__class__.__name__} has been initialized')
+
+ def _create_index(self) -> None:
+ """Create a new index in the Redis database if it doesn't already exist.
+
+ One schema field is built per column: vector columns get a vector field
+ configured from the column config, `id` and `parent_id` become tag fields,
+ and every other column uses its own db type. Columns holding nested document
+ arrays are left out of the schema.
+
+ :raises ValueError: if a vector column has a distance metric or an algorithm
+ outside `VALID_DISTANCES` / `VALID_ALGORITHMS`
+ """
+ if not self._check_index_exists(self.index_name):
+ schema = []
+ for column, info in self._column_infos.items():
+ if safe_issubclass(info.docarray_type, AnyDocArray):
+ continue
+ elif info.db_type == VectorField:
+ # 'space' takes precedence over 'distance' when both are configured.
+ space = info.config.get('space') or info.config.get('distance')
+ if not space or space.upper() not in VALID_DISTANCES:
+ raise ValueError(
+ f"Invalid distance metric '{space}' provided. "
+ f"Must be one of: {', '.join(VALID_DISTANCES)}"
+ )
+ space = space.upper()
+ attributes = {
+ 'TYPE': 'FLOAT32',
+ 'DIM': info.n_dim or info.config.get('dim'),
+ 'DISTANCE_METRIC': space,
+ 'EF_CONSTRUCTION': info.config['ef_construction'],
+ 'EF_RUNTIME': info.config['ef_runtime'],
+ 'M': info.config['m'],
+ 'INITIAL_CAP': info.config['initial_cap'],
+ }
+ # Drop every attribute whose value is falsy (None, 0, ...), so only
+ # explicitly configured options are sent.
+ attributes = {
+ name: value for name, value in attributes.items() if value
+ }
+ algorithm = info.config['algorithm'].upper()
+ if algorithm not in VALID_ALGORITHMS:
+ raise ValueError(
+ f"Invalid algorithm '{algorithm}' provided. "
+ f"Must be one of: {', '.join(VALID_ALGORITHMS)}"
+ )
+ schema.append(
+ info.db_type(
+ '$.' + column,
+ algorithm=algorithm,
+ attributes=attributes,
+ as_name=column,
+ )
+ )
+ elif column in ['id', 'parent_id']:
+ schema.append(TagField('$.' + column, as_name=column))
+ else:
+ schema.append(info.db_type('$.' + column, as_name=column))
+
+ # Create Redis Index
+ self._client.ft(self.index_name).create_index(
+ schema,
+ definition=IndexDefinition(
+ prefix=[self._prefix], index_type=IndexType.JSON
+ ),
+ )
+
+ self._logger.info(f'index {self.index_name} has been created')
+ else:
+ self._logger.info(f'connected to existing {self.index_name} index')
+
+ def _check_index_exists(self, index_name: str) -> bool:
+ """
+ Check if an index exists in the Redis database.
+
+ :param index_name: The name of the index.
+ :return: True if the index exists, False otherwise.
+ """
+ try:
+ self._client.ft(index_name).info()
+ except: # noqa: E722
+ self._logger.info(f'Index {index_name} does not exist')
+ return False
+ self._logger.info(f'Index {index_name} already exists')
+ return True
+
+ @property
+ def index_name(self):
+ default_index_name = (
+ self._schema.__name__.lower() if self._schema is not None else None
+ )
+ if default_index_name is None:
+ err_msg = (
+ 'A RedisDocumentIndex must be typed with a Document type. '
+ 'To do so, use the syntax: RedisDocumentIndex[DocumentType]'
+ )
+
+ self._logger.error(err_msg)
+ raise ValueError(err_msg)
+ index_name = self._db_config.index_name or default_index_name
+ self._logger.debug(f'Retrieved index name: {index_name}')
+ return index_name
+
+ @property
+ def out_schema(self) -> Type[BaseDoc]:
+ """Return the real schema of the index."""
+ if self._is_subindex:
+ return self._ori_schema
+ return cast(Type[BaseDoc], self._schema)
+
+ class QueryBuilder(BaseDocIndex.QueryBuilder):
+ def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None):
+ super().__init__()
+ # list of tuples (method name, kwargs)
+ self._queries: List[Tuple[str, Dict]] = query or []
+
+ def build(self, *args, **kwargs) -> Any:
+ """Build the query object."""
+ # The recorded list itself is the query object; it is returned uncopied.
+ return self._queries
+
+ # Only `find` and `filter` are composable here.
+ # NOTE(review): `_collect_query_args` / `_raise_not_composable` are imported
+ # helpers; presumably the former records the call in `self._queries` and the
+ # latter makes the call raise — confirm.
+ find = _collect_query_args('find')
+ filter = _collect_query_args('filter')
+ text_search = _raise_not_composable('text_search')
+ find_batched = _raise_not_composable('find_batched')
+ filter_batched = _raise_not_composable('filter_batched')
+ text_search_batched = _raise_not_composable('text_search_batched')
+
+ @dataclass
+ class DBConfig(BaseDocIndex.DBConfig):
+ """Dataclass that contains all "static" configurations of RedisDocumentIndex.
+
+ :param host: The host address for the Redis server. Default is 'localhost'.
+ :param port: The port number for the Redis server. Default is 6379.
+ :param index_name: The name of the index in the Redis database.
+ If not provided, default index name will be used.
+ :param username: The username for the Redis server. Default is None.
+ :param password: The password for the Redis server. Default is None.
+ :param text_scorer: The method for scoring text during text search.
+ Default is 'BM25'.
+ :param default_column_config: Default configuration for columns.
+ """
+
+ host: str = 'localhost'
+ port: int = 6379
+ index_name: Optional[str] = None
+ username: Optional[str] = None
+ password: Optional[str] = None
+ text_scorer: str = field(default='BM25')
+ # Vector columns default to a FLAT index with COSINE distance; the options
+ # left at None are dropped by `_create_index` before the field is created.
+ default_column_config: Dict[Type, Dict[str, Any]] = field(
+ default_factory=lambda: defaultdict(
+ dict,
+ {
+ VectorField: {
+ 'algorithm': 'FLAT',
+ 'distance': 'COSINE',
+ 'ef_construction': None,
+ 'm': None,
+ 'ef_runtime': None,
+ 'initial_cap': None,
+ },
+ },
+ )
+ )
+
+ def __post_init__(self):
+ # Normalise to upper case so the scorer name is matched case-insensitively.
+ self.text_scorer = self.text_scorer.upper()
+
+ if self.text_scorer not in VALID_TEXT_SCORERS:
+ raise ValueError(
+ f"Invalid text scorer '{self.text_scorer}' provided. "
+ f"Must be one of: {', '.join(VALID_TEXT_SCORERS)}"
+ )
+
+ @dataclass
+ class RuntimeConfig(BaseDocIndex.RuntimeConfig):
+ """Dataclass that contains all "dynamic" configurations of RedisDocumentIndex.
+
+ :param batch_size: Batch size for index/get/del.
+ """
+
+ # Passed to `_generate_items` by `_index` as the number of rows per batch.
+ batch_size: int = 100
+
+    def python_type_to_db_type(self, python_type: Type) -> Any:
+        """
+        Resolve the Redis field class used to store a column of a Python type.
+
+        :param python_type: Python type of the column.
+        :return: The Redis field class (numeric, text or vector) for that type.
+        :raises ValueError: If ``python_type`` is not a subclass of any
+            supported type.
+        """
+        # Candidates are checked in order; the first match wins.
+        candidates = (
+            (int, NumericField),
+            (float, NumericField),
+            (str, TextField),
+            (bytes, TextField),
+            (np.ndarray, VectorField),
+            (list, VectorField),
+            (AbstractTensor, VectorField),
+        )
+        matching = (
+            field_cls
+            for base, field_cls in candidates
+            if safe_issubclass(python_type, base)
+        )
+        redis_type = next(matching, None)
+        if redis_type is None:
+            raise ValueError(
+                f'Unsupported column type for {type(self)}: {python_type}'
+            )
+        return redis_type
+
+    @staticmethod
+    def _generate_items(
+        column_to_data: Dict[str, Generator[Any, None, None]],
+        batch_size: int,
+    ) -> Iterator[List[Dict[str, Any]]]:
+        """
+        Turn per-column generators into batches of row dictionaries.
+
+        One value is pulled from every column generator per row. Tensors and
+        numpy arrays are stored as plain lists, and ``None`` values are left
+        out of the row. Generation stops as soon as the ``id`` column yields
+        a falsy value.
+
+        :param column_to_data: Mapping from column name to a generator of the
+            values of that column.
+        :param batch_size: Number of rows per yielded batch.
+
+        :yield: Lists of row dictionaries of length ``batch_size``; the last
+            batch may be shorter.
+        """
+        names = list(column_to_data.keys())
+        iterators = [iter(column_to_data[col]) for col in names]
+        pending: List[Dict[str, Any]] = []
+
+        while True:
+            row = {}
+            for col, values in zip(names, iterators):
+                value = next(values, None)
+
+                # A missing id means the input is exhausted: flush and stop.
+                if col == 'id' and not value:
+                    if pending:
+                        yield pending
+                    return
+
+                if value is None:
+                    continue
+                if isinstance(value, AbstractTensor):
+                    row[col] = value._docarray_to_ndarray().tolist()
+                elif isinstance(value, ndarray):
+                    row[col] = value.astype(np.float32).tolist()
+                else:
+                    row[col] = value
+
+            pending.append(row)
+            if len(pending) == batch_size:
+                yield pending
+                pending = []
+
+    def _index(
+        self, column_to_data: Dict[str, Generator[Any, None, None]]
+    ) -> List[str]:
+        """
+        Write the given column data to Redis as JSON documents.
+
+        :param column_to_data: Mapping from column name to a generator of the
+            values of that column.
+        :return: The Redis keys (index prefix + document id) that were written.
+        """
+        self._index_subindex(column_to_data)
+        batch_size = self._runtime_config.batch_size
+        written_keys: List[str] = []
+        for rows in self._generate_items(column_to_data, batch_size):
+            triplets = []
+            for row in rows:
+                key = self._prefix + row['id']
+                written_keys.append(key)
+                # Each triplet is (key, JSON path, value); '$' is the root.
+                triplets.append((key, '$', row))
+            self._client.json().mset(triplets)  # type: ignore[attr-defined]
+
+        return written_keys
+
+    def num_docs(self) -> int:
+        """
+        Report how many documents the search index currently holds.
+
+        :return: The ``num_docs`` value of the index info, as an integer.
+        """
+        index_info = self._client.ft(self.index_name).info()
+        return int(index_info['num_docs'])
+
+    def _del_items(self, doc_ids: Sequence[str]) -> None:
+        """
+        Remove the documents with the given ids from the index.
+
+        Ids that are not present in Redis are ignored.
+
+        :param doc_ids: Ids of the documents to delete.
+        """
+        existing_keys = []
+        for doc_id in doc_ids:
+            if self._doc_exists(doc_id):
+                existing_keys.append(self._prefix + doc_id)
+        if not existing_keys:
+            return
+
+        chunk_size = self._runtime_config.batch_size
+        for keys in self._generate_batches(existing_keys, batch_size=chunk_size):
+            self._client.delete(*keys)
+
+    def _doc_exists(self, doc_id: str) -> bool:
+        """
+        Tell whether a document with the given id is stored in Redis.
+
+        :param doc_id: Id of the document (without the index prefix).
+        :return: True if the prefixed key exists, False otherwise.
+        """
+        key = self._prefix + doc_id
+        return bool(self._client.exists(key))
+
+    @staticmethod
+    def _generate_batches(data, batch_size):
+        """Yield consecutive slices of ``data`` with at most ``batch_size`` items each."""
+        offsets = range(0, len(data), batch_size)
+        yield from (data[lo : lo + batch_size] for lo in offsets)
+
+    def _get_items(
+        self, doc_ids: Sequence[str]
+    ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]:
+        """
+        Load the documents with the given ids from Redis.
+
+        Ids without a stored document are skipped.
+
+        :param doc_ids: Ids of the documents to fetch.
+        :return: The stored documents as dictionaries; an empty list if
+            ``doc_ids`` is empty.
+        :raises KeyError: If none of the requested ids could be found.
+        """
+        if not doc_ids:
+            return []
+
+        found: List[Dict[str, Any]] = []
+        chunk_size = self._runtime_config.batch_size
+        for id_chunk in self._generate_batches(doc_ids, batch_size=chunk_size):
+            keys = [self._prefix + doc_id for doc_id in id_chunk]
+            for hit in self._client.json().mget(keys, '$'):
+                # Falsy entries (missing keys) are dropped; for a hit, the
+                # first element of the returned list is the document.
+                if hit:
+                    found.append(hit[0])
+
+        if found:
+            return found
+        raise KeyError(f'No document with id {doc_ids} found')
+
+    def execute_query(self, query: Any, *args: Any, **kwargs: Any) -> Any:
+        """
+        Run a hybrid (vector search + filter) query against the index.
+
+        :param query: Iterable of ``(component, value)`` pairs. It must hold
+            exactly one ``find`` and exactly one ``filter`` component.
+        :return: A ``FindResult`` with the matching documents and scores.
+        :raises ValueError: If the query does not consist of exactly one
+            ``find`` and one ``filter`` component.
+        """
+        grouped: Dict[str, List[Dict[str, Any]]] = {}
+        for component, value in query:
+            grouped.setdefault(component, []).append(value)
+
+        find_parts = grouped.get('find', [])
+        filter_parts = grouped.get('filter', [])
+        well_formed = (
+            len(grouped) == 2 and len(find_parts) == 1 and len(filter_parts) == 1
+        )
+        if not well_formed:
+            raise ValueError(
+                'The query must contain exactly one "find" and "filter" components.'
+            )
+
+        find_spec = find_parts[0]
+        filter_spec = filter_parts[0]
+        filter_query = filter_spec['filter_query']
+        vector = find_spec['query']
+        search_field = find_spec['search_field']
+        # The limit of the find component wins over the filter's; default 10.
+        limit = find_spec.get('limit') or filter_spec.get('limit') or 10
+
+        docs, scores = self._hybrid_search(
+            query=vector,
+            filter_query=filter_query,
+            search_field=search_field,
+            limit=limit,
+        )
+        return FindResult(
+            documents=self._dict_list_to_docarray(docs), scores=scores
+        )
+
+    def _hybrid_search(
+        self, query: np.ndarray, filter_query: str, search_field: str, limit: int
+    ) -> _FindResult:
+        """
+        Run a KNN vector search restricted by a filter expression.
+
+        :param query: The query vector; it is sent to Redis as float32 bytes.
+        :param filter_query: The filter expression placed before the KNN clause.
+        :param search_field: The vector field to search on.
+        :param limit: The maximum number of results to return.
+        :return: The matching documents as dictionaries with their
+            ``vector_score`` values.
+        """
+        knn_expression = (
+            f'{filter_query}=>[KNN {limit} @{search_field} $vec AS vector_score]'
+        )
+        redis_query = Query(knn_expression)
+        redis_query = redis_query.sort_by('vector_score')
+        redis_query = redis_query.paging(0, limit)
+        redis_query = redis_query.dialect(2)
+
+        params: Mapping[str, bytes] = {
+            'vec': np.array(query, dtype=np.float32).tobytes()
+        }
+        search = self._client.ft(self.index_name).search
+        hits = search(redis_query, params).docs  # type: ignore[arg-type]
+
+        raw_scores = [hit['vector_score'] for hit in hits]
+        scores: NdArray = NdArray._docarray_from_native(np.array(raw_scores))
+
+        # Each hit carries the stored document as a JSON string.
+        docs = [json.loads(hit.json) for hit in hits]
+        return _FindResult(documents=docs, scores=scores)
+
+    def _find(
+        self, query: np.ndarray, limit: int, search_field: str = ''
+    ) -> _FindResult:
+        """
+        Run a vector search on the index.
+
+        Delegates to ``_hybrid_search`` with the filter expression ``'*'``.
+
+        :param query: The query vector.
+        :param limit: The maximum number of results to return.
+        :param search_field: The vector field to search on.
+        :return: Search results.
+        """
+        wildcard = '*'
+        return self._hybrid_search(
+            query=query,
+            filter_query=wildcard,
+            search_field=search_field,
+            limit=limit,
+        )
+
+    def _find_batched(
+        self, queries: np.ndarray, limit: int, search_field: str = ''
+    ) -> _FindResultBatched:
+        """
+        Run one vector search per query and collect the results.
+
+        :param queries: The query vectors; each one is searched separately.
+        :param limit: The maximum number of results to return for each query.
+        :param search_field: The vector field to search on.
+        :return: Documents and scores, one entry per query.
+        """
+        per_query = [
+            self._find(query=vector, search_field=search_field, limit=limit)
+            for vector in queries
+        ]
+        return _FindResultBatched(
+            documents=[result.documents for result in per_query],
+            scores=[result.scores for result in per_query],
+        )
+
+    def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]:
+        """
+        Return the documents that match a filter expression.
+
+        :param filter_query: The filter expression passed to ``Query``.
+        :param limit: The maximum number of results to return.
+        :return: The matching documents as dictionaries.
+        """
+        paged_query = Query(filter_query).paging(0, limit)
+        hits = self._client.ft(index_name=self.index_name).search(paged_query).docs
+        # Each hit carries the stored document as a JSON string.
+        return [json.loads(hit.json) for hit in hits]
+
+    def _filter_batched(
+        self, filter_queries: Any, limit: int
+    ) -> Union[List[DocList], List[List[Dict]]]:
+        """
+        Apply several filter expressions, one after the other.
+
+        :param filter_queries: The filter expressions.
+        :param limit: The maximum number of results to return for each filter.
+        :return: One list of matching documents per filter expression.
+        """
+        return [
+            self._filter(filter_query=expression, limit=limit)
+            for expression in filter_queries
+        ]
+
+    def _filter_by_parent_id(self, id: str) -> Optional[List[str]]:
+        """Collect the ids of the subindex documents that belong to a root document.
+
+        :param id: the id of the root document
+        :return: the ids of the documents whose ``parent_id`` equals ``id``
+        """
+        parent_filter = f'@parent_id:{{{id}}}'
+        # Use the index size as limit so that no match is cut off.
+        children = self._filter(filter_query=parent_filter, limit=self.num_docs())
+        return [child['id'] for child in children]
+
+    def _text_search(
+        self, query: str, limit: int, search_field: str = ''
+    ) -> _FindResult:
+        """
+        Conducts a text-based search on the index.
+
+        The whitespace-separated terms of ``query`` are combined with ``|``,
+        and results are scored with the configured text scorer.
+
+        :param query: The text to search for.
+        :param limit: The maximum number of results to return.
+        :param search_field: The text field to search on.
+        :return: The matching documents as dictionaries with their scores.
+        """
+        # split() without a separator drops the empty tokens that leading,
+        # trailing or repeated whitespace would otherwise produce, so the
+        # joined expression never contains empty alternatives like 'a||b'.
+        terms = query.split()
+        query_str = '|'.join(terms)
+        q = (
+            Query(f'@{search_field}:{query_str}')
+            .scorer(self._text_scorer)
+            .with_scores()
+            .paging(0, limit)
+        )
+
+        results = self._client.ft(index_name=self.index_name).search(q).docs
+
+        scores: NdArray = NdArray._docarray_from_native(
+            np.array([document['score'] for document in results])
+        )
+
+        # Each hit carries the stored document as a JSON string.
+        docs = [json.loads(doc.json) for doc in results]
+
+        return _FindResult(documents=docs, scores=scores)
+
+    def _text_search_batched(
+        self, queries: Sequence[str], limit: int, search_field: str = ''
+    ) -> _FindResultBatched:
+        """
+        Run one text search per query and collect the results.
+
+        :param queries: The texts to search for; each one is searched separately.
+        :param limit: The maximum number of results to return for each query.
+        :param search_field: The text field to search on.
+        :return: Documents and scores, one entry per query.
+        """
+        per_query = [
+            self._text_search(query=text, search_field=search_field, limit=limit)
+            for text in queries
+        ]
+        return _FindResultBatched(
+            documents=[result.documents for result in per_query],
+            scores=[result.scores for result in per_query],
+        )
diff --git a/docarray/index/backends/weaviate.py b/docarray/index/backends/weaviate.py
index 5179f8cb588..13eb6893753 100644
--- a/docarray/index/backends/weaviate.py
+++ b/docarray/index/backends/weaviate.py
@@ -26,10 +26,12 @@
import docarray
from docarray import BaseDoc, DocList
+from docarray.array.any_array import AnyDocArray
from docarray.index.abstract import BaseDocIndex, FindResultBatched, _FindResultBatched
from docarray.typing import AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.ndarray import NdArray
+from docarray.utils._internal._typing import safe_issubclass
from docarray.utils._internal.misc import import_library
from docarray.utils.find import FindResult, _FindResult
@@ -129,6 +131,7 @@ def _set_properties(self) -> None:
field_overwrites.get(k, k)
for k, v in self._column_infos.items()
if v.config.get('is_embedding', False) is False
+ and not safe_issubclass(v.docarray_type, AnyDocArray)
]
def _validate_columns(self) -> None:
@@ -199,6 +202,8 @@ def _create_schema(self) -> None:
for column_name, column_info in column_infos.items():
# in weaviate, we do not create a property for the doc's embeddings
+ if safe_issubclass(column_info.docarray_type, AnyDocArray):
+ continue
if column_name == self.embedding_column:
continue
if column_info.db_type == 'blob':
@@ -220,9 +225,7 @@ def _create_schema(self) -> None:
schema["properties"] = properties
schema["class"] = self.index_name
- # TODO: Use exists() instead of contains() when available
- # see https://github.com/weaviate/weaviate-python-client/issues/232
- if self._client.schema.contains(schema):
+ if self._client.schema.exists(self.index_name):
logging.warning(
f"Found index {self.index_name} with schema {schema}. Will reuse existing schema."
)
@@ -240,11 +243,6 @@ class DBConfig(BaseDocIndex.DBConfig):
scopes: List[str] = field(default_factory=lambda: ["offline_access"])
auth_api_key: Optional[str] = None
embedded_options: Optional[EmbeddedOptions] = None
-
- @dataclass
- class RuntimeConfig(BaseDocIndex.RuntimeConfig):
- """Dataclass that contains all "dynamic" configurations of WeaviateDocumentIndex."""
-
default_column_config: Dict[Any, Dict[str, Any]] = field(
default_factory=lambda: {
np.ndarray: {},
@@ -259,6 +257,20 @@ class RuntimeConfig(BaseDocIndex.RuntimeConfig):
}
)
+        def __post_init__(self):
+            # Weaviate stores index names with an upper-case first letter, so
+            # the configured name is normalised the same way to avoid lookup
+            # errors. str.capitalize() is not used because it would also
+            # change the remaining characters of the name.
+            name = self.index_name
+            if name:
+                self.index_name = name[0].upper() + name[1:]
+            else:
+                self.index_name = None
+
+ @dataclass
+ class RuntimeConfig(BaseDocIndex.RuntimeConfig):
+ """Dataclass that contains all "dynamic" configurations of WeaviateDocumentIndex."""
+
batch_config: Dict[str, Any] = field(
default_factory=lambda: DEFAULT_BATCH_CONFIG
)
@@ -353,7 +365,7 @@ def find(
query_vec_np, search_field=search_field, limit=limit, **kwargs
)
- if isinstance(docs, List):
+ if isinstance(docs, List) and not isinstance(docs, DocList):
docs = self._dict_list_to_docarray(docs)
return FindResult(documents=docs, scores=scores)
@@ -384,7 +396,7 @@ def _find(
index_name = self.index_name
if search_field:
logging.warning(
- 'Argument search_field is not supported for WeaviateDocumentIndex. Ignoring.'
+ 'The search_field argument is not supported for the WeaviateDocumentIndex and will be ignored.'
)
near_vector: Dict[str, Any] = {
"vector": query,
@@ -431,7 +443,7 @@ def find_batched(
queries: Union[AnyTensor, DocList],
search_field: str = '',
limit: int = 10,
- **kwargs,
+ **kwargs: Any,
) -> FindResultBatched:
"""Find documents in the index using nearest neighbor search.
@@ -564,6 +576,8 @@ def _parse_weaviate_result(self, result: Dict) -> Dict:
return result
def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
+ self._index_subindex(column_to_data)
+
docs = self._transpose_col_value_dict(column_to_data)
index_name = self.index_name
@@ -593,7 +607,7 @@ def _text_search(
results = (
self._client.query.get(index_name, self.properties)
- .with_bm25(bm25)
+ .with_bm25(**bm25)
.with_limit(limit)
.with_additional(["score", "vector"])
.do()
@@ -614,7 +628,7 @@ def _text_search_batched(
q = (
self._client.query.get(self.index_name, self.properties)
- .with_bm25(bm25)
+ .with_bm25(**bm25)
.with_limit(limit)
.with_additional(["score", "vector"])
.with_alias(f'query_{i}')
@@ -690,7 +704,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
or None if ``python_type`` is not supported.
"""
for allowed_type in WEAVIATE_PY_VEC_TYPES:
- if issubclass(python_type, allowed_type):
+ if safe_issubclass(python_type, allowed_type):
return 'number[]'
py_weaviate_type_map = {
@@ -704,7 +718,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
}
for py_type, weaviate_type in py_weaviate_type_map.items():
- if issubclass(python_type, py_type):
+ if safe_issubclass(python_type, py_type):
return weaviate_type
raise ValueError(f'Unsupported column type for {type(self)}: {python_type}')
@@ -741,6 +755,36 @@ def _convert_nonembedding_array_to_list(self, doc):
if doc[column] is not None:
doc[column] = doc[column].tolist()
+    def _filter_by_parent_id(self, id: str) -> Optional[List[str]]:
+        """Collect the ids of the subindex documents that belong to a root document.
+
+        :param id: the id of the root document
+        :return: the ``docarrayid`` values of the documents whose
+            ``parent_id`` equals ``id``; an empty list if nothing matches
+        """
+        # Query the same class name the schema is created under and that
+        # _doc_exists uses, rather than the raw config value.
+        index_name = self.index_name
+        results = (
+            self._client.query.get(index_name, ['docarrayid'])
+            .with_where(
+                {'path': ['parent_id'], 'operator': 'Equal', 'valueString': f'{id}'}
+            )
+            .do()
+        )
+
+        # Guard against a missing result list, as _doc_exists does.
+        matches = results['data']['Get'][index_name] or []
+        return [res['docarrayid'] for res in matches]
+
+    def _doc_exists(self, doc_id: str) -> bool:
+        """Tell whether a document with the given id is stored in the index.
+
+        :param doc_id: the id of the document
+        :return: True if at least one object with this ``docarrayid`` is found
+        """
+        where_clause = {
+            "path": ['docarrayid'],
+            "operator": "Equal",
+            "valueString": f'{doc_id}',
+        }
+        getter = self._client.query.get(self.index_name, ['docarrayid'])
+        response = getter.with_where(where_clause).do()
+        matches = response["data"]["Get"][self.index_name]
+        if matches is None:
+            return False
+        return len(matches) > 0
+
class QueryBuilder(BaseDocIndex.QueryBuilder):
def __init__(self, document_index):
self._queries = [
@@ -749,7 +793,7 @@ def __init__(self, document_index):
)
]
- def build(self) -> Any:
+ def build(self, *args, **kwargs) -> Any:
"""Build the query object."""
num_queries = len(self._queries)
@@ -810,6 +854,7 @@ def find(
query,
score_name: Literal["certainty", "distance"] = "certainty",
score_threshold: Optional[float] = None,
+ **kwargs,
) -> Any:
"""
Find k-nearest neighbors of the query.
@@ -819,6 +864,11 @@ def find(
:param score_threshold: the threshold of the score
:return: self
"""
+ if kwargs.get('search_field'):
+ logging.warning(
+ 'The search_field argument is not supported for the WeaviateDocumentIndex and will be ignored.'
+ )
+
near_vector = {
"vector": query,
}
@@ -862,7 +912,7 @@ def find_batched(
return self
- def filter(self, where_filter) -> Any:
+ def filter(self, where_filter: Any) -> Any:
"""Find documents in the index based on a filter query
:param where_filter: a filter
:return: self
@@ -891,18 +941,22 @@ def filter_batched(self, filters) -> Any:
return self
- def text_search(self, query, search_field) -> Any:
+ def text_search(self, query: str, search_field: Optional[str] = None) -> Any:
"""Find documents in the index based on a text search query
:param query: The text to search for
:param search_field: name of the field to search on
:return: self
"""
- bm25 = {"query": query, "properties": [search_field]}
+ bm25: Dict[str, Any] = {"query": query}
+ if search_field:
+ bm25["properties"] = [search_field]
self._queries[0] = self._queries[0].with_bm25(**bm25)
return self
- def text_search_batched(self, queries, search_field) -> Any:
+ def text_search_batched(
+ self, queries: Sequence[str], search_field: Optional[str] = None
+ ) -> Any:
"""Find documents in the index based on a text search query
:param queries: The texts to search for
@@ -915,7 +969,9 @@ def text_search_batched(self, queries, search_field) -> Any:
new_queries = []
for query, clause in zip(adj_queries, adj_clauses):
- bm25 = {"query": clause, "properties": [search_field]}
+ bm25 = {"query": clause}
+ if search_field:
+ bm25["properties"] = [search_field]
new_queries.append(query.with_bm25(**bm25))
self._queries = new_queries
diff --git a/docarray/proto/__init__.py b/docarray/proto/__init__.py
index b1a201b6e2f..faa1cdffe8f 100644
--- a/docarray/proto/__init__.py
+++ b/docarray/proto/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TYPE_CHECKING
from docarray.utils._internal.misc import import_library
@@ -17,6 +32,7 @@
DocVecProto,
ListOfAnyProto,
ListOfDocArrayProto,
+ ListOfDocVecProto,
NdArrayProto,
NodeProto,
)
@@ -28,6 +44,7 @@
DocVecProto,
ListOfAnyProto,
ListOfDocArrayProto,
+ ListOfDocVecProto,
NdArrayProto,
NodeProto,
)
@@ -40,6 +57,7 @@
'DocVecProto',
'DocListProto',
'ListOfDocArrayProto',
+ 'ListOfDocVecProto',
'ListOfAnyProto',
'DictOfAnyProto',
]
diff --git a/docarray/proto/docarray.proto b/docarray/proto/docarray.proto
index 19a33ccbc22..a73451bac1b 100644
--- a/docarray/proto/docarray.proto
+++ b/docarray/proto/docarray.proto
@@ -100,9 +100,13 @@ message ListOfDocArrayProto {
repeated DocListProto data = 1;
}
+message ListOfDocVecProto {
+ repeated DocVecProto data = 1;
+}
+
message DocVecProto{
map tensor_columns = 1; // a dict of document columns
map doc_columns = 2; // a dict of tensor columns
- map docs_vec_columns = 3; // a dict of document array columns
+ map docs_vec_columns = 3; // a dict of document array columns
map any_columns = 4; // a dict of any columns. Used for the rest of the data
}
\ No newline at end of file
diff --git a/docarray/proto/pb/docarray_pb2.py b/docarray/proto/pb/docarray_pb2.py
index 8ff91a9f5e8..0cd5b334a18 100644
--- a/docarray/proto/pb/docarray_pb2.py
+++ b/docarray/proto/pb/docarray_pb2.py
@@ -14,7 +14,7 @@
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"Z\n\x0cKeyValuePair\x12#\n\x03key\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Value\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value\";\n\x10GenericDictValue\x12\'\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x16.docarray.KeyValuePair\"\xb1\x03\n\tNodeProto\x12\x0e\n\x04text\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x05H\x00\x12\x0f\n\x05\x66loat\x18\x03 \x01(\x01H\x00\x12\x11\n\x07\x62oolean\x18\x04 \x01(\x08H\x00\x12\x0e\n\x04\x62lob\x18\x05 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12!\n\x03\x64oc\x18\x07 \x01(\x0b\x32\x12.docarray.DocProtoH\x00\x12+\n\tdoc_array\x18\x08 \x01(\x0b\x32\x16.docarray.DocListProtoH\x00\x12(\n\x04list\x18\t \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12\'\n\x03set\x18\n \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12)\n\x05tuple\x18\x0b \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12(\n\x04\x64ict\x18\x0c \x01(\x0b\x32\x18.docarray.DictOfAnyProtoH\x00\x12\x0e\n\x04type\x18\r \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"x\n\x08\x44ocProto\x12*\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1c.docarray.DocProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"\x84\x01\n\x0e\x44ictOfAnyProto\x12\x30\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\".docarray.DictOfAnyProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 
\x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"3\n\x0eListOfAnyProto\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.docarray.NodeProto\"0\n\x0c\x44ocListProto\x12 \n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x12.docarray.DocProto\";\n\x13ListOfDocArrayProto\x12$\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x16.docarray.DocListProto\"\xc7\x04\n\x0b\x44ocVecProto\x12@\n\x0etensor_columns\x18\x01 \x03(\x0b\x32(.docarray.DocVecProto.TensorColumnsEntry\x12:\n\x0b\x64oc_columns\x18\x02 \x03(\x0b\x32%.docarray.DocVecProto.DocColumnsEntry\x12\x43\n\x10\x64ocs_vec_columns\x18\x03 \x03(\x0b\x32).docarray.DocVecProto.DocsVecColumnsEntry\x12:\n\x0b\x61ny_columns\x18\x04 \x03(\x0b\x32%.docarray.DocVecProto.AnyColumnsEntry\x1aL\n\x12TensorColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProto:\x02\x38\x01\x1aH\n\x0f\x44ocColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.docarray.DocVecProto:\x02\x38\x01\x1aT\n\x13\x44ocsVecColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.docarray.ListOfDocArrayProto:\x02\x38\x01\x1aK\n\x0f\x41nyColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.docarray.ListOfAnyProto:\x02\x38\x01\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"Z\n\x0cKeyValuePair\x12#\n\x03key\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Value\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value\";\n\x10GenericDictValue\x12\'\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x16.docarray.KeyValuePair\"\xb1\x03\n\tNodeProto\x12\x0e\n\x04text\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x05H\x00\x12\x0f\n\x05\x66loat\x18\x03 \x01(\x01H\x00\x12\x11\n\x07\x62oolean\x18\x04 \x01(\x08H\x00\x12\x0e\n\x04\x62lob\x18\x05 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12!\n\x03\x64oc\x18\x07 \x01(\x0b\x32\x12.docarray.DocProtoH\x00\x12+\n\tdoc_array\x18\x08 \x01(\x0b\x32\x16.docarray.DocListProtoH\x00\x12(\n\x04list\x18\t \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12\'\n\x03set\x18\n \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12)\n\x05tuple\x18\x0b \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12(\n\x04\x64ict\x18\x0c \x01(\x0b\x32\x18.docarray.DictOfAnyProtoH\x00\x12\x0e\n\x04type\x18\r \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"x\n\x08\x44ocProto\x12*\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1c.docarray.DocProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"\x84\x01\n\x0e\x44ictOfAnyProto\x12\x30\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\".docarray.DictOfAnyProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 
\x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"3\n\x0eListOfAnyProto\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.docarray.NodeProto\"0\n\x0c\x44ocListProto\x12 \n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x12.docarray.DocProto\";\n\x13ListOfDocArrayProto\x12$\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x16.docarray.DocListProto\"8\n\x11ListOfDocVecProto\x12#\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x15.docarray.DocVecProto\"\xc5\x04\n\x0b\x44ocVecProto\x12@\n\x0etensor_columns\x18\x01 \x03(\x0b\x32(.docarray.DocVecProto.TensorColumnsEntry\x12:\n\x0b\x64oc_columns\x18\x02 \x03(\x0b\x32%.docarray.DocVecProto.DocColumnsEntry\x12\x43\n\x10\x64ocs_vec_columns\x18\x03 \x03(\x0b\x32).docarray.DocVecProto.DocsVecColumnsEntry\x12:\n\x0b\x61ny_columns\x18\x04 \x03(\x0b\x32%.docarray.DocVecProto.AnyColumnsEntry\x1aL\n\x12TensorColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProto:\x02\x38\x01\x1aH\n\x0f\x44ocColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.docarray.DocVecProto:\x02\x38\x01\x1aR\n\x13\x44ocsVecColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.docarray.ListOfDocVecProto:\x02\x38\x01\x1aK\n\x0f\x41nyColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.docarray.ListOfAnyProto:\x02\x38\x01\x62\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'docarray_pb2', globals())
@@ -57,14 +57,16 @@
_DOCLISTPROTO._serialized_end=1177
_LISTOFDOCARRAYPROTO._serialized_start=1179
_LISTOFDOCARRAYPROTO._serialized_end=1238
- _DOCVECPROTO._serialized_start=1241
- _DOCVECPROTO._serialized_end=1824
- _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_start=1511
- _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_end=1587
- _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_start=1589
- _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_end=1661
- _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_start=1663
- _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_end=1747
- _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_start=1749
- _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_end=1824
+ _LISTOFDOCVECPROTO._serialized_start=1240
+ _LISTOFDOCVECPROTO._serialized_end=1296
+ _DOCVECPROTO._serialized_start=1299
+ _DOCVECPROTO._serialized_end=1880
+ _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_start=1569
+ _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_end=1645
+ _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_start=1647
+ _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_end=1719
+ _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_start=1721
+ _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_end=1803
+ _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_start=1805
+ _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_end=1880
# @@protoc_insertion_point(module_scope)
diff --git a/docarray/proto/pb2/docarray_pb2.py b/docarray/proto/pb2/docarray_pb2.py
index 9fbbbadf342..e178c8c3f9d 100644
--- a/docarray/proto/pb2/docarray_pb2.py
+++ b/docarray/proto/pb2/docarray_pb2.py
@@ -16,7 +16,7 @@
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"Z\n\x0cKeyValuePair\x12#\n\x03key\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Value\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value\";\n\x10GenericDictValue\x12\'\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x16.docarray.KeyValuePair\"\xb1\x03\n\tNodeProto\x12\x0e\n\x04text\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x05H\x00\x12\x0f\n\x05\x66loat\x18\x03 \x01(\x01H\x00\x12\x11\n\x07\x62oolean\x18\x04 \x01(\x08H\x00\x12\x0e\n\x04\x62lob\x18\x05 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12!\n\x03\x64oc\x18\x07 \x01(\x0b\x32\x12.docarray.DocProtoH\x00\x12+\n\tdoc_array\x18\x08 \x01(\x0b\x32\x16.docarray.DocListProtoH\x00\x12(\n\x04list\x18\t \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12\'\n\x03set\x18\n \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12)\n\x05tuple\x18\x0b \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12(\n\x04\x64ict\x18\x0c \x01(\x0b\x32\x18.docarray.DictOfAnyProtoH\x00\x12\x0e\n\x04type\x18\r \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"x\n\x08\x44ocProto\x12*\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1c.docarray.DocProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"\x84\x01\n\x0e\x44ictOfAnyProto\x12\x30\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\".docarray.DictOfAnyProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"3\n\x0eListOfAnyProto\x12!\n\x04\x64\x61ta\x18\x01 
\x03(\x0b\x32\x13.docarray.NodeProto\"0\n\x0c\x44ocListProto\x12 \n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x12.docarray.DocProto\";\n\x13ListOfDocArrayProto\x12$\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x16.docarray.DocListProto\"\xc7\x04\n\x0b\x44ocVecProto\x12@\n\x0etensor_columns\x18\x01 \x03(\x0b\x32(.docarray.DocVecProto.TensorColumnsEntry\x12:\n\x0b\x64oc_columns\x18\x02 \x03(\x0b\x32%.docarray.DocVecProto.DocColumnsEntry\x12\x43\n\x10\x64ocs_vec_columns\x18\x03 \x03(\x0b\x32).docarray.DocVecProto.DocsVecColumnsEntry\x12:\n\x0b\x61ny_columns\x18\x04 \x03(\x0b\x32%.docarray.DocVecProto.AnyColumnsEntry\x1aL\n\x12TensorColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProto:\x02\x38\x01\x1aH\n\x0f\x44ocColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.docarray.DocVecProto:\x02\x38\x01\x1aT\n\x13\x44ocsVecColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.docarray.ListOfDocArrayProto:\x02\x38\x01\x1aK\n\x0f\x41nyColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.docarray.ListOfAnyProto:\x02\x38\x01\x62\x06proto3'
+ b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"Z\n\x0cKeyValuePair\x12#\n\x03key\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Value\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value\";\n\x10GenericDictValue\x12\'\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x16.docarray.KeyValuePair\"\xb1\x03\n\tNodeProto\x12\x0e\n\x04text\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x05H\x00\x12\x0f\n\x05\x66loat\x18\x03 \x01(\x01H\x00\x12\x11\n\x07\x62oolean\x18\x04 \x01(\x08H\x00\x12\x0e\n\x04\x62lob\x18\x05 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12!\n\x03\x64oc\x18\x07 \x01(\x0b\x32\x12.docarray.DocProtoH\x00\x12+\n\tdoc_array\x18\x08 \x01(\x0b\x32\x16.docarray.DocListProtoH\x00\x12(\n\x04list\x18\t \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12\'\n\x03set\x18\n \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12)\n\x05tuple\x18\x0b \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12(\n\x04\x64ict\x18\x0c \x01(\x0b\x32\x18.docarray.DictOfAnyProtoH\x00\x12\x0e\n\x04type\x18\r \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"x\n\x08\x44ocProto\x12*\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1c.docarray.DocProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"\x84\x01\n\x0e\x44ictOfAnyProto\x12\x30\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\".docarray.DictOfAnyProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"3\n\x0eListOfAnyProto\x12!\n\x04\x64\x61ta\x18\x01 
\x03(\x0b\x32\x13.docarray.NodeProto\"0\n\x0c\x44ocListProto\x12 \n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x12.docarray.DocProto\";\n\x13ListOfDocArrayProto\x12$\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x16.docarray.DocListProto\"8\n\x11ListOfDocVecProto\x12#\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x15.docarray.DocVecProto\"\xc5\x04\n\x0b\x44ocVecProto\x12@\n\x0etensor_columns\x18\x01 \x03(\x0b\x32(.docarray.DocVecProto.TensorColumnsEntry\x12:\n\x0b\x64oc_columns\x18\x02 \x03(\x0b\x32%.docarray.DocVecProto.DocColumnsEntry\x12\x43\n\x10\x64ocs_vec_columns\x18\x03 \x03(\x0b\x32).docarray.DocVecProto.DocsVecColumnsEntry\x12:\n\x0b\x61ny_columns\x18\x04 \x03(\x0b\x32%.docarray.DocVecProto.AnyColumnsEntry\x1aL\n\x12TensorColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProto:\x02\x38\x01\x1aH\n\x0f\x44ocColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.docarray.DocVecProto:\x02\x38\x01\x1aR\n\x13\x44ocsVecColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.docarray.ListOfDocVecProto:\x02\x38\x01\x1aK\n\x0f\x41nyColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.docarray.ListOfAnyProto:\x02\x38\x01\x62\x06proto3'
)
@@ -32,6 +32,7 @@
_LISTOFANYPROTO = DESCRIPTOR.message_types_by_name['ListOfAnyProto']
_DOCLISTPROTO = DESCRIPTOR.message_types_by_name['DocListProto']
_LISTOFDOCARRAYPROTO = DESCRIPTOR.message_types_by_name['ListOfDocArrayProto']
+_LISTOFDOCVECPROTO = DESCRIPTOR.message_types_by_name['ListOfDocVecProto']
_DOCVECPROTO = DESCRIPTOR.message_types_by_name['DocVecProto']
_DOCVECPROTO_TENSORCOLUMNSENTRY = _DOCVECPROTO.nested_types_by_name[
'TensorColumnsEntry'
@@ -171,6 +172,17 @@
)
_sym_db.RegisterMessage(ListOfDocArrayProto)
+ListOfDocVecProto = _reflection.GeneratedProtocolMessageType(
+ 'ListOfDocVecProto',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _LISTOFDOCVECPROTO,
+ '__module__': 'docarray_pb2'
+ # @@protoc_insertion_point(class_scope:docarray.ListOfDocVecProto)
+ },
+)
+_sym_db.RegisterMessage(ListOfDocVecProto)
+
DocVecProto = _reflection.GeneratedProtocolMessageType(
'DocVecProto',
(_message.Message,),
@@ -261,14 +273,16 @@
_DOCLISTPROTO._serialized_end = 1177
_LISTOFDOCARRAYPROTO._serialized_start = 1179
_LISTOFDOCARRAYPROTO._serialized_end = 1238
- _DOCVECPROTO._serialized_start = 1241
- _DOCVECPROTO._serialized_end = 1824
- _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_start = 1511
- _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_end = 1587
- _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_start = 1589
- _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_end = 1661
- _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_start = 1663
- _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_end = 1747
- _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_start = 1749
- _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_end = 1824
+ _LISTOFDOCVECPROTO._serialized_start = 1240
+ _LISTOFDOCVECPROTO._serialized_end = 1296
+ _DOCVECPROTO._serialized_start = 1299
+ _DOCVECPROTO._serialized_end = 1880
+ _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_start = 1569
+ _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_end = 1645
+ _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_start = 1647
+ _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_end = 1719
+ _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_start = 1721
+ _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_end = 1803
+ _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_start = 1805
+ _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_end = 1880
# @@protoc_insertion_point(module_scope)
diff --git a/docarray/store/__init__.py b/docarray/store/__init__.py
index 9547db27c3e..42e7025ce85 100644
--- a/docarray/store/__init__.py
+++ b/docarray/store/__init__.py
@@ -8,7 +8,6 @@
)
if TYPE_CHECKING:
- from docarray.store.jac import JACDocStore # noqa: F401
from docarray.store.s3 import S3DocStore # noqa: F401
__all__ = ['FileDocStore']
@@ -16,10 +15,7 @@
def __getattr__(name: str):
lib: types.ModuleType
- if name == 'JACDocStore':
- import_library('hubble', raise_error=True)
- import docarray.store.jac as lib
- elif name == 'S3DocStore':
+ if name == 'S3DocStore':
import_library('smart_open', raise_error=True)
import_library('botocore', raise_error=True)
import_library('boto3', raise_error=True)
diff --git a/docarray/store/abstract_doc_store.py b/docarray/store/abstract_doc_store.py
index df7788f584a..e95c014d38e 100644
--- a/docarray/store/abstract_doc_store.py
+++ b/docarray/store/abstract_doc_store.py
@@ -1,5 +1,20 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from abc import ABC, abstractmethod
-from typing import Dict, Iterator, List, Optional, Type
+from typing import Dict, Iterator, List, Type
from typing_extensions import TYPE_CHECKING
@@ -35,17 +50,13 @@ def delete(name: str, missing_ok: bool) -> bool:
def push(
docs: 'DocList',
name: str,
- public: bool,
show_progress: bool,
- branding: Optional[Dict],
) -> Dict:
"""Push this DocList to the specified name.
:param docs: The DocList to push
:param name: The name to push to
- :param public: Whether the DocList should be publicly accessible
:param show_progress: If true, a progress bar will be displayed.
- :param branding: Branding information to be stored with the DocList
"""
...
@@ -54,17 +65,13 @@ def push(
def push_stream(
docs: Iterator['BaseDoc'],
url: str,
- public: bool = True,
show_progress: bool = False,
- branding: Optional[Dict] = None,
) -> Dict:
"""Push a stream of documents to the specified name.
:param docs: a stream of documents
:param url: The name to push to
- :param public: Whether the DocList should be publicly accessible
:param show_progress: If true, a progress bar will be displayed.
- :param branding: Branding information to be stored with the DocList
"""
...
diff --git a/docarray/store/exceptions.py b/docarray/store/exceptions.py
index 9caf0d8a167..52809621337 100644
--- a/docarray/store/exceptions.py
+++ b/docarray/store/exceptions.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
class ConcurrentPushException(Exception):
"""Exception raised when a concurrent push is detected."""
diff --git a/docarray/store/file.py b/docarray/store/file.py
index 6c46c3ab615..b728b21460d 100644
--- a/docarray/store/file.py
+++ b/docarray/store/file.py
@@ -1,6 +1,21 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import logging
from pathlib import Path
-from typing import Dict, Iterator, List, Optional, Type, TypeVar
+from typing import Dict, Iterator, List, Type, TypeVar
from typing_extensions import TYPE_CHECKING
@@ -98,40 +113,29 @@ def push(
cls: Type[SelfFileDocStore],
docs: 'DocList',
name: str,
- public: bool,
show_progress: bool,
- branding: Optional[Dict],
) -> Dict:
"""Push this [`DocList`][docarray.DocList] object to the specified file path.
:param docs: The `DocList` to push.
:param name: The file path to push to.
- :param public: Not used by the ``file`` protocol.
:param show_progress: If true, a progress bar will be displayed.
- :param branding: Not used by the ``file`` protocol.
"""
- return cls.push_stream(iter(docs), name, public, show_progress, branding)
+ return cls.push_stream(iter(docs), name, show_progress)
@classmethod
def push_stream(
cls: Type[SelfFileDocStore],
docs: Iterator['BaseDoc'],
name: str,
- public: bool = True,
show_progress: bool = False,
- branding: Optional[Dict] = None,
) -> Dict:
"""Push a stream of documents to the specified file path.
:param docs: a stream of documents
:param name: The file path to push to.
- :param public: Not used by the ``file`` protocol.
:param show_progress: If true, a progress bar will be displayed.
- :param branding: Not used by the ``file`` protocol.
"""
- if branding is not None:
- logging.warning('branding is not supported for "file" protocol')
-
source = _to_binary_stream(
docs, protocol='protobuf', compress='gzip', show_progress=show_progress
)
diff --git a/docarray/store/helpers.py b/docarray/store/helpers.py
index 35c6fd2c6ae..24f28ac8ff4 100644
--- a/docarray/store/helpers.py
+++ b/docarray/store/helpers.py
@@ -6,6 +6,7 @@
from rich import filesize
from typing_extensions import TYPE_CHECKING, Protocol
+from docarray.utils._internal.misc import ProtocolType
from docarray.utils._internal.progress_bar import _get_progressbar
if TYPE_CHECKING:
@@ -112,12 +113,12 @@ def raise_req_error(resp: 'requests.Response') -> NoReturn:
class Streamable(Protocol):
"""A protocol for streamable objects."""
- def to_bytes(self, protocol: str, compress: Optional[str]) -> bytes:
+ def to_bytes(self, protocol: ProtocolType, compress: Optional[str]) -> bytes:
...
@classmethod
def from_bytes(
- cls: Type[T_Elem], bytes: bytes, protocol: str, compress: Optional[str]
+ cls: Type[T_Elem], bytes: bytes, protocol: ProtocolType, compress: Optional[str]
) -> 'T_Elem':
...
@@ -133,7 +134,7 @@ def close(self):
def _to_binary_stream(
iterator: Iterator['Streamable'],
total: Optional[int] = None,
- protocol: str = 'protobuf',
+ protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Iterator[bytes]:
@@ -170,36 +171,40 @@ def _from_binary_stream(
cls: Type[T],
stream: ReadableBytes,
total: Optional[int] = None,
- protocol: str = 'protobuf',
+ protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Iterator['T']:
- if show_progress:
- pbar, t = _get_progressbar(
- 'Deserializing', disable=not show_progress, total=total
- )
- else:
- pbar = nullcontext()
-
- with pbar:
+ try:
if show_progress:
- _total_size = 0
- pbar.start_task(t)
- while True:
- len_bytes = stream.read(4)
- if len(len_bytes) < 4:
- raise ValueError('Unexpected end of stream')
- len_item = int.from_bytes(len_bytes, 'big', signed=False)
- if len_item == 0:
- break
- item_bytes = stream.read(len_item)
- if len(item_bytes) < len_item:
- raise ValueError('Unexpected end of stream')
- item = cls.from_bytes(item_bytes, protocol=protocol, compress=compress)
-
- yield item
+ pbar, t = _get_progressbar(
+ 'Deserializing', disable=not show_progress, total=total
+ )
+ else:
+ pbar = nullcontext()
+ with pbar:
if show_progress:
- _total_size += len_item + 4
- pbar.update(t, advance=1, total_size=str(filesize.decimal(_total_size)))
+ _total_size = 0
+ pbar.start_task(t)
+ while True:
+ len_bytes = stream.read(4)
+ if len(len_bytes) < 4:
+ raise ValueError('Unexpected end of stream')
+ len_item = int.from_bytes(len_bytes, 'big', signed=False)
+ if len_item == 0:
+ break
+ item_bytes = stream.read(len_item)
+ if len(item_bytes) < len_item:
+ raise ValueError('Unexpected end of stream')
+ item = cls.from_bytes(item_bytes, protocol=protocol, compress=compress)
+
+ yield item
+
+ if show_progress:
+ _total_size += len_item + 4
+ pbar.update(
+ t, advance=1, total_size=str(filesize.decimal(_total_size))
+ )
+ finally:
stream.close()
diff --git a/docarray/store/jac.py b/docarray/store/jac.py
deleted file mode 100644
index 2ca4920194f..00000000000
--- a/docarray/store/jac.py
+++ /dev/null
@@ -1,370 +0,0 @@
-import json
-import logging
-import os
-from pathlib import Path
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- Iterator,
- List,
- Optional,
- Type,
- TypeVar,
- Union,
-)
-
-from docarray.store.abstract_doc_store import AbstractDocStore
-from docarray.store.helpers import (
- _BufferedCachingRequestReader,
- get_version_info,
- raise_req_error,
-)
-from docarray.utils._internal.cache import _get_cache_path
-from docarray.utils._internal.misc import import_library
-
-if TYPE_CHECKING: # pragma: no cover
- import io
-
- from docarray import BaseDoc, DocList
-
-if TYPE_CHECKING:
- import hubble
- from hubble import Client as HubbleClient
- from hubble.client.endpoints import EndpointsV2
-else:
- hubble = import_library('hubble', raise_error=True)
- HubbleClient = hubble.Client
- EndpointsV2 = hubble.client.endpoints.EndpointsV2
-
-
-def _get_length_from_summary(summary: List[Dict]) -> Optional[int]:
- """Get the length from summary."""
- for item in summary:
- if 'Length' == item['name']:
- return item['value']
- raise ValueError('Length not found in summary')
-
-
-def _get_raw_summary(self: 'DocList') -> List[Dict[str, Any]]:
- items: List[Dict[str, Any]] = [
- dict(
- name='Type',
- value=self.__class__.__name__,
- description='The type of the DocList',
- ),
- dict(
- name='Length',
- value=len(self),
- description='The length of the DocList',
- ),
- dict(
- name='Homogenous Documents',
- value=True,
- description='Whether all documents are of the same structure, attributes',
- ),
- dict(
- name='Fields',
- value=tuple(self[0].__class__.__fields__.keys()),
- description='The fields of the Document',
- ),
- dict(
- name='Multimodal dataclass',
- value=True,
- description='Whether all documents are multimodal',
- ),
- ]
-
- return items
-
-
-SelfJACDocStore = TypeVar('SelfJACDocStore', bound='JACDocStore')
-
-
-class JACDocStore(AbstractDocStore):
- """Class to push and pull [`DocList`][docarray.DocList] to and from Jina AI Cloud."""
-
- @staticmethod
- @hubble.login_required
- def list(namespace: str = '', show_table: bool = False) -> List[str]:
- """List all available arrays in the cloud.
-
- :param namespace: Not supported for Jina AI Cloud.
- :param show_table: if true, show the table of the arrays.
- :returns: List of available DocList's names.
- """
- if len(namespace) > 0:
- logging.warning('Namespace is not supported for Jina AI Cloud.')
- from rich import print
-
- result = []
- from rich import box
- from rich.table import Table
-
- resp = HubbleClient(jsonify=True).list_artifacts(
- filter={'type': 'documentArray'},
- sort={'createdAt': 1},
- pageSize=10000,
- )
-
- table = Table(
- title=f'You have {resp["meta"]["total"]} DocList on the cloud',
- box=box.SIMPLE,
- highlight=True,
- )
- table.add_column('Name')
- table.add_column('Length')
- table.add_column('Access')
- table.add_column('Created at', justify='center')
- table.add_column('Updated at', justify='center')
-
- for docs in resp['data']:
- result.append(docs['name'])
-
- table.add_row(
- docs['name'],
- str(_get_length_from_summary(docs['metaData'].get('summary', []))),
- docs['visibility'],
- docs['createdAt'],
- docs['updatedAt'],
- )
-
- if show_table:
- print(table)
- return result
-
- @staticmethod
- @hubble.login_required
- def delete(name: str, missing_ok: bool = True) -> bool:
- """
- Delete a [`DocList`][docarray.DocList] from the cloud.
- :param name: the name of the DocList to delete.
- :param missing_ok: if true, do not raise an error if the DocList does not exist.
- :return: True if the DocList was deleted, False if it did not exist.
- """
- try:
- HubbleClient(jsonify=True).delete_artifact(name=name)
- except hubble.excepts.RequestedEntityNotFoundError:
- if missing_ok:
- return False
- else:
- raise
- return True
-
- @staticmethod
- @hubble.login_required
- def push(
- docs: 'DocList',
- name: str,
- public: bool = True,
- show_progress: bool = False,
- branding: Optional[Dict] = None,
- ) -> Dict:
- """Push this [`DocList`][docarray.DocList] object to Jina AI Cloud
-
- !!! note
- - Push with the same ``name`` will override the existing content.
- - Kinda like a public clipboard where everyone can override anyone's content.
- So to make your content survive longer, you may want to use longer & more complicated name.
- - The lifetime of the content is not promised atm, could be a day, could be a week. Do not use it for
- persistence. Only use this full temporary transmission/storage/clipboard.
-
- :param docs: The `DocList` to push.
- :param name: A name that can later be used to retrieve this `DocList`.
- :param public: By default, anyone can pull a `DocList` if they know its name.
- Setting this to false will restrict access to only the creator.
- :param show_progress: If true, a progress bar will be displayed.
- :param branding: A dictionary of branding information to be sent to Jina Cloud. e.g. {"icon": "emoji", "background": "#fff"}
- """
- import requests
- import urllib3
-
- delimiter = os.urandom(32)
-
- data, ctype = urllib3.filepost.encode_multipart_formdata(
- {
- 'file': (
- 'DocumentArray',
- delimiter,
- ),
- 'name': name,
- 'type': 'documentArray',
- 'public': public,
- 'metaData': json.dumps(
- {
- 'summary': _get_raw_summary(docs),
- 'branding': branding,
- 'version': get_version_info(),
- },
- sort_keys=True,
- ),
- }
- )
-
- headers = {
- 'Content-Type': ctype,
- }
-
- auth_token = hubble.get_token()
- if auth_token:
- headers['Authorization'] = f'token {auth_token}'
-
- _head, _tail = data.split(delimiter)
-
- def gen():
- yield _head
- binary_stream = docs._to_binary_stream(
- protocol='protobuf', compress='gzip', show_progress=show_progress
- )
- while True:
- try:
- yield next(binary_stream)
- except StopIteration:
- break
- yield _tail
-
- response = requests.post(
- HubbleClient()._base_url + EndpointsV2.upload_artifact,
- data=gen(),
- headers=headers,
- )
-
- if response.ok:
- return response.json()['data']
- else:
- if response.status_code >= 400 and 'readableMessage' in response.json():
- response.reason = response.json()['readableMessage']
- raise_req_error(response)
-
- @classmethod
- @hubble.login_required
- def push_stream(
- cls: Type[SelfJACDocStore],
- docs: Iterator['BaseDoc'],
- name: str,
- public: bool = True,
- show_progress: bool = False,
- branding: Optional[Dict] = None,
- ) -> Dict:
- """Push a stream of documents to Jina AI Cloud
-
- !!! note
- - Push with the same ``name`` will override the existing content.
- - Kinda like a public clipboard where everyone can override anyone's content.
- So to make your content survive longer, you may want to use longer & more complicated name.
- - The lifetime of the content is not promised atm, could be a day, could be a week. Do not use it for
- persistence. Only use this full temporary transmission/storage/clipboard.
-
- :param docs: a stream of documents
- :param name: A name that can later be used to retrieve this `DocList`.
- :param public: By default, anyone can pull a `DocList` if they know its name.
- Setting this to false will restrict access to only the creator.
- :param show_progress: If true, a progress bar will be displayed.
- :param branding: A dictionary of branding information to be sent to Jina Cloud. e.g. {"icon": "emoji", "background": "#fff"}
- """
- from docarray import DocList
-
- # This is a temporary solution to push a stream of documents
- # The memory footprint is not ideal
- # But it must be done this way for now because Hubble expects to know the length of the DocList
- # before it starts receiving the documents
- first_doc = next(docs)
- _docs = DocList[first_doc.__class__]([first_doc]) # type: ignore
- for doc in docs:
- _docs.append(doc)
- return cls.push(_docs, name, public, show_progress, branding)
-
- @staticmethod
- @hubble.login_required
- def pull(
- cls: Type['DocList'],
- name: str,
- show_progress: bool = False,
- local_cache: bool = True,
- ) -> 'DocList':
- """Pull a [`DocList`][docarray.DocList] from Jina AI Cloud to local.
-
- :param name: the upload name set during `.push`
- :param show_progress: if true, display a progress bar.
- :param local_cache: store the downloaded DocList to local folder
- :return: a [`DocList`][docarray.DocList] object
- """
- from docarray import DocList
-
- return DocList[cls.doc_type]( # type: ignore
- JACDocStore.pull_stream(cls, name, show_progress, local_cache)
- )
-
- @staticmethod
- @hubble.login_required
- def pull_stream(
- cls: Type['DocList'],
- name: str,
- show_progress: bool = False,
- local_cache: bool = False,
- ) -> Iterator['BaseDoc']:
- """Pull a [`DocList`][docarray.DocList] from Jina AI Cloud to local.
-
- :param name: the upload name set during `.push`
- :param show_progress: if true, display a progress bar.
- :param local_cache: store the downloaded DocList to local folder
- :return: An iterator of Documents
- """
- import requests
-
- headers = {}
-
- auth_token = hubble.get_token()
-
- if auth_token:
- headers['Authorization'] = f'token {auth_token}'
-
- url = HubbleClient()._base_url + EndpointsV2.download_artifact + f'?name={name}'
- response = requests.get(url, headers=headers)
-
- if response.ok:
- url = response.json()['data']['download']
- else:
- response.raise_for_status()
-
- with requests.get(
- url,
- stream=True,
- ) as r:
- from contextlib import nullcontext
-
- r.raise_for_status()
- save_name = name.replace('/', '_')
-
- tmp_cache_file = Path(f'/tmp/{save_name}.docs')
- _source: Union[
- _BufferedCachingRequestReader, io.BufferedReader
- ] = _BufferedCachingRequestReader(r, tmp_cache_file)
-
- cache_file = _get_cache_path() / f'{save_name}.docs'
- if local_cache and cache_file.exists():
- _cache_len = cache_file.stat().st_size
- if _cache_len == int(r.headers['Content-length']):
- if show_progress:
- print(f'Loading from local cache {cache_file}')
- _source = open(cache_file, 'rb')
- r.close()
-
- docs = cls._load_binary_stream(
- nullcontext(_source), # type: ignore
- protocol='protobuf',
- compress='gzip',
- show_progress=show_progress,
- )
- try:
- while True:
- yield next(docs)
- except StopIteration:
- pass
-
- if local_cache:
- if isinstance(_source, _BufferedCachingRequestReader):
- Path(_get_cache_path()).mkdir(parents=True, exist_ok=True)
- tmp_cache_file.rename(cache_file)
- else:
- _source.close()
diff --git a/docarray/store/s3.py b/docarray/store/s3.py
index 2ebb864fc8d..5b2e4ae6f4b 100644
--- a/docarray/store/s3.py
+++ b/docarray/store/s3.py
@@ -121,39 +121,28 @@ def push(
cls: Type[SelfS3DocStore],
docs: 'DocList',
name: str,
- public: bool = False,
show_progress: bool = False,
- branding: Optional[Dict] = None,
) -> Dict:
"""Push this [`DocList`][docarray.DocList] object to the specified bucket and key.
:param docs: The `DocList` to push.
:param name: The bucket and key to push to. e.g. my_bucket/my_key
- :param public: Not used by the ``s3`` protocol.
:param show_progress: If true, a progress bar will be displayed.
- :param branding: Not used by the ``s3`` protocol.
"""
- return cls.push_stream(iter(docs), name, public, show_progress, branding)
+ return cls.push_stream(iter(docs), name, show_progress)
@staticmethod
def push_stream(
docs: Iterator['BaseDoc'],
name: str,
- public: bool = True,
show_progress: bool = False,
- branding: Optional[Dict] = None,
) -> Dict:
"""Push a stream of documents to the specified bucket and key.
:param docs: a stream of documents
:param name: The bucket and key to push to. e.g. my_bucket/my_key
- :param public: Not used by the ``s3`` protocol.
:param show_progress: If true, a progress bar will be displayed.
- :param branding: Not used by the ``s3`` protocol.
"""
- if branding is not None:
- logging.warning("Branding is not supported for S3 push")
-
bucket, name = name.split('/', 1)
binary_stream = _to_binary_stream(
docs, protocol='pickle', compress=None, show_progress=show_progress
diff --git a/docarray/typing/__init__.py b/docarray/typing/__init__.py
index 5fdb578ad04..ed7e1d7b9d2 100644
--- a/docarray/typing/__init__.py
+++ b/docarray/typing/__init__.py
@@ -24,12 +24,20 @@
if TYPE_CHECKING:
from docarray.typing.tensor import TensorFlowTensor # noqa: F401
- from docarray.typing.tensor import TorchEmbedding, TorchTensor # noqa: F401
+ from docarray.typing.tensor import ( # noqa: F401
+ JaxArray,
+ JaxArrayEmbedding,
+ TorchEmbedding,
+ TorchTensor,
+ )
+ from docarray.typing.tensor.audio import AudioJaxArray # noqa: F401
from docarray.typing.tensor.audio import AudioTensorFlowTensor # noqa: F401
from docarray.typing.tensor.audio import AudioTorchTensor # noqa: F401
from docarray.typing.tensor.embedding import TensorFlowEmbedding # noqa: F401
+ from docarray.typing.tensor.image import ImageJaxArray # noqa: F401
from docarray.typing.tensor.image import ImageTensorFlowTensor # noqa: F401
from docarray.typing.tensor.image import ImageTorchTensor # noqa: F401
+ from docarray.typing.tensor.video import VideoJaxArray # noqa: F401
from docarray.typing.tensor.video import VideoTensorFlowTensor # noqa: F401
from docarray.typing.tensor.video import VideoTorchTensor # noqa: F401
@@ -73,6 +81,15 @@
'AudioTensorFlowTensor',
'VideoTensorFlowTensor',
]
+
+_jax_tensors = [
+ 'JaxArray',
+ 'JaxArrayEmbedding',
+ 'VideoJaxArray',
+ 'AudioJaxArray',
+ 'ImageJaxArray',
+]
+
__all_test__ = __all__ + _torch_tensors
@@ -81,6 +98,8 @@ def __getattr__(name: str):
import_library('torch', raise_error=True)
elif name in _tf_tensors:
import_library('tensorflow', raise_error=True)
+ elif name in _jax_tensors:
+ import_library('jax', raise_error=True)
else:
raise ImportError(
f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\''
diff --git a/docarray/typing/abstract_type.py b/docarray/typing/abstract_type.py
index 3193116db08..2aa009d4e6a 100644
--- a/docarray/typing/abstract_type.py
+++ b/docarray/typing/abstract_type.py
@@ -1,8 +1,27 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from abc import abstractmethod
-from typing import Any, Type, TypeVar
+from typing import TYPE_CHECKING, Any, Type, TypeVar
-from pydantic import BaseConfig
-from pydantic.fields import ModelField
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+if TYPE_CHECKING:
+ if is_pydantic_v2:
+ from pydantic import GetCoreSchemaHandler
+ from pydantic_core import core_schema
from docarray.base_doc.base_node import BaseNode
@@ -16,10 +35,29 @@ def __get_validators__(cls):
@classmethod
@abstractmethod
- def validate(
- cls: Type[T],
- value: Any,
- field: 'ModelField',
- config: 'BaseConfig',
- ) -> T:
+ def _docarray_validate(cls: Type[T], value: Any) -> T:
...
+
+ if is_pydantic_v2:
+
+ @classmethod
+ def validate(cls: Type[T], value: Any, _: Any) -> T:
+ return cls._docarray_validate(value)
+
+ else:
+
+ @classmethod
+ def validate(
+ cls: Type[T],
+ value: Any,
+ ) -> T:
+ return cls._docarray_validate(value)
+
+ if is_pydantic_v2:
+
+ @classmethod
+ @abstractmethod
+ def __get_pydantic_core_schema__(
+ cls, _source_type: Any, _handler: 'GetCoreSchemaHandler'
+ ) -> 'core_schema.CoreSchema':
+ ...
diff --git a/docarray/typing/bytes/__init__.py b/docarray/typing/bytes/__init__.py
index 2cf8524bcc0..015f3243759 100644
--- a/docarray/typing/bytes/__init__.py
+++ b/docarray/typing/bytes/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.typing.bytes.audio_bytes import AudioBytes
from docarray.typing.bytes.image_bytes import ImageBytes
from docarray.typing.bytes.video_bytes import VideoBytes
diff --git a/docarray/typing/bytes/audio_bytes.py b/docarray/typing/bytes/audio_bytes.py
index 23c6f49a4d0..4747231be2c 100644
--- a/docarray/typing/bytes/audio_bytes.py
+++ b/docarray/typing/bytes/audio_bytes.py
@@ -1,48 +1,23 @@
import io
-from typing import TYPE_CHECKING, Any, Tuple, Type, TypeVar
+from typing import Tuple, TypeVar
import numpy as np
from pydantic import parse_obj_as
-from pydantic.validators import bytes_validator
-from docarray.typing.abstract_type import AbstractType
+from docarray.typing.bytes.base_bytes import BaseBytes
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.audio import AudioNdArray
from docarray.utils._internal.misc import import_library
-if TYPE_CHECKING:
- from pydantic.fields import BaseConfig, ModelField
-
- from docarray.proto import NodeProto
-
T = TypeVar('T', bound='AudioBytes')
@_register_proto(proto_type_name='audio_bytes')
-class AudioBytes(bytes, AbstractType):
+class AudioBytes(BaseBytes):
"""
Bytes that store an audio and that can be load into an Audio tensor
"""
- @classmethod
- def validate(
- cls: Type[T],
- value: Any,
- field: 'ModelField',
- config: 'BaseConfig',
- ) -> T:
- value = bytes_validator(value)
- return cls(value)
-
- @classmethod
- def from_protobuf(cls: Type[T], pb_msg: T) -> T:
- return parse_obj_as(cls, pb_msg)
-
- def _to_node_protobuf(self: T) -> 'NodeProto':
- from docarray.proto import NodeProto
-
- return NodeProto(blob=self, type=self._proto_type_name)
-
def load(self) -> Tuple[AudioNdArray, int]:
"""
Load the Audio from the [`AudioBytes`][docarray.typing.AudioBytes] into an
@@ -58,9 +33,9 @@ def load(self) -> Tuple[AudioNdArray, int]:
class MyAudio(BaseDoc):
url: AudioUrl
- tensor: Optional[AudioNdArray]
- bytes_: Optional[AudioBytes]
- frame_rate: Optional[float]
+ tensor: Optional[AudioNdArray] = None
+ bytes_: Optional[AudioBytes] = None
+ frame_rate: Optional[float] = None
doc = MyAudio(url='https://www.kozco.com/tech/piano2.wav')
diff --git a/docarray/typing/bytes/base_bytes.py b/docarray/typing/bytes/base_bytes.py
new file mode 100644
index 00000000000..8a944031b4e
--- /dev/null
+++ b/docarray/typing/bytes/base_bytes.py
@@ -0,0 +1,68 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import abstractmethod
+from typing import TYPE_CHECKING, Any, Type, TypeVar
+
+from pydantic import parse_obj_as
+
+from docarray.typing.abstract_type import AbstractType
+from docarray.utils._internal.pydantic import bytes_validator, is_pydantic_v2
+
+if is_pydantic_v2:
+ from pydantic_core import core_schema
+
+if TYPE_CHECKING:
+ from docarray.proto import NodeProto
+
+ if is_pydantic_v2:
+ from pydantic import GetCoreSchemaHandler
+
+T = TypeVar('T', bound='BaseBytes')
+
+
+class BaseBytes(bytes, AbstractType):
+ """
+ Bytes type for docarray
+ """
+
+ @classmethod
+ def _docarray_validate(
+ cls: Type[T],
+ value: Any,
+ ) -> T:
+ value = bytes_validator(value)
+ return cls(value)
+
+ @classmethod
+ def from_protobuf(cls: Type[T], pb_msg: T) -> T:
+ return parse_obj_as(cls, pb_msg)
+
+ def _to_node_protobuf(self: T) -> 'NodeProto':
+ from docarray.proto import NodeProto
+
+ return NodeProto(blob=self, type=self._proto_type_name)
+
+ if is_pydantic_v2:
+
+ @classmethod
+ @abstractmethod
+ def __get_pydantic_core_schema__(
+ cls, _source_type: Any, _handler: 'GetCoreSchemaHandler'
+ ) -> 'core_schema.CoreSchema':
+ return core_schema.with_info_after_validator_function(
+ cls.validate,
+ core_schema.bytes_schema(),
+ )
diff --git a/docarray/typing/bytes/image_bytes.py b/docarray/typing/bytes/image_bytes.py
index a456a493ccb..a2a847ef8ed 100644
--- a/docarray/typing/bytes/image_bytes.py
+++ b/docarray/typing/bytes/image_bytes.py
@@ -1,49 +1,27 @@
from io import BytesIO
-from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar
+from typing import TYPE_CHECKING, Optional, Tuple, TypeVar
import numpy as np
from pydantic import parse_obj_as
-from pydantic.validators import bytes_validator
-from docarray.typing.abstract_type import AbstractType
+from docarray.typing.bytes.base_bytes import BaseBytes
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.image.image_ndarray import ImageNdArray
from docarray.utils._internal.misc import import_library
if TYPE_CHECKING:
from PIL import Image as PILImage
- from pydantic.fields import BaseConfig, ModelField
- from docarray.proto import NodeProto
T = TypeVar('T', bound='ImageBytes')
@_register_proto(proto_type_name='image_bytes')
-class ImageBytes(bytes, AbstractType):
+class ImageBytes(BaseBytes):
"""
Bytes that store an image and that can be load into an image tensor
"""
- @classmethod
- def validate(
- cls: Type[T],
- value: Any,
- field: 'ModelField',
- config: 'BaseConfig',
- ) -> T:
- value = bytes_validator(value)
- return cls(value)
-
- @classmethod
- def from_protobuf(cls: Type[T], pb_msg: T) -> T:
- return parse_obj_as(cls, pb_msg)
-
- def _to_node_protobuf(self: T) -> 'NodeProto':
- from docarray.proto import NodeProto
-
- return NodeProto(blob=self, type=self._proto_type_name)
-
def load_pil(
self,
) -> 'PILImage.Image':
diff --git a/docarray/typing/bytes/video_bytes.py b/docarray/typing/bytes/video_bytes.py
index 720326fdbc1..a1003046720 100644
--- a/docarray/typing/bytes/video_bytes.py
+++ b/docarray/typing/bytes/video_bytes.py
@@ -1,20 +1,14 @@
from io import BytesIO
-from typing import TYPE_CHECKING, Any, List, NamedTuple, Type, TypeVar
+from typing import TYPE_CHECKING, List, NamedTuple, TypeVar
import numpy as np
from pydantic import parse_obj_as
-from pydantic.validators import bytes_validator
-from docarray.typing.abstract_type import AbstractType
+from docarray.typing.bytes.base_bytes import BaseBytes
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor import AudioNdArray, NdArray, VideoNdArray
from docarray.utils._internal.misc import import_library
-if TYPE_CHECKING:
- from pydantic.fields import BaseConfig, ModelField
-
- from docarray.proto import NodeProto
-
T = TypeVar('T', bound='VideoBytes')
@@ -25,30 +19,11 @@ class VideoLoadResult(NamedTuple):
@_register_proto(proto_type_name='video_bytes')
-class VideoBytes(bytes, AbstractType):
+class VideoBytes(BaseBytes):
"""
Bytes that store a video and that can be load into a video tensor
"""
- @classmethod
- def validate(
- cls: Type[T],
- value: Any,
- field: 'ModelField',
- config: 'BaseConfig',
- ) -> T:
- value = bytes_validator(value)
- return cls(value)
-
- @classmethod
- def from_protobuf(cls: Type[T], pb_msg: T) -> T:
- return parse_obj_as(cls, pb_msg)
-
- def _to_node_protobuf(self: T) -> 'NodeProto':
- from docarray.proto import NodeProto
-
- return NodeProto(blob=self, type=self._proto_type_name)
-
def load(self, **kwargs) -> VideoLoadResult:
"""
Load the video from the bytes into a VideoLoadResult object consisting of:
diff --git a/docarray/typing/id.py b/docarray/typing/id.py
index dd4b0db08e0..3e3fdd37ae4 100644
--- a/docarray/typing/id.py
+++ b/docarray/typing/id.py
@@ -1,16 +1,36 @@
-from typing import TYPE_CHECKING, Type, TypeVar, Union
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Any, Type, TypeVar, Union
from uuid import UUID
-from pydantic import BaseConfig, parse_obj_as
-from pydantic.fields import ModelField
+from pydantic import parse_obj_as
from docarray.typing.proto_register import _register_proto
+from docarray.utils._internal.pydantic import is_pydantic_v2
if TYPE_CHECKING:
from docarray.proto import NodeProto
from docarray.typing.abstract_type import AbstractType
+if is_pydantic_v2:
+ from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
+ from pydantic.json_schema import JsonSchemaValue
+ from pydantic_core import core_schema
+
T = TypeVar('T', bound='ID')
@@ -21,15 +41,9 @@ class ID(str, AbstractType):
"""
@classmethod
- def __get_validators__(cls):
- yield cls.validate
-
- @classmethod
- def validate(
+ def _docarray_validate(
cls: Type[T],
value: Union[str, int, UUID],
- field: 'ModelField',
- config: 'BaseConfig',
) -> T:
try:
id: str = str(value)
@@ -56,3 +70,21 @@ def from_protobuf(cls: Type[T], pb_msg: 'str') -> T:
:return: a string
"""
return parse_obj_as(cls, pb_msg)
+
+ if is_pydantic_v2:
+
+ @classmethod
+ def __get_pydantic_core_schema__(
+ cls, source: Type[Any], handler: 'GetCoreSchemaHandler'
+ ) -> core_schema.CoreSchema:
+ return core_schema.with_info_plain_validator_function(
+ cls.validate,
+ )
+
+ @classmethod
+ def __get_pydantic_json_schema__(
+ cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
+ ) -> JsonSchemaValue:
+ field_schema: dict[str, Any] = {}
+ field_schema.update(type='string')
+ return field_schema
diff --git a/docarray/typing/proto_register.py b/docarray/typing/proto_register.py
index 700fe744ad8..0839039f4e6 100644
--- a/docarray/typing/proto_register.py
+++ b/docarray/typing/proto_register.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import Callable, Dict, Type, TypeVar
from docarray.typing.abstract_type import AbstractType
diff --git a/docarray/typing/tensor/__init__.py b/docarray/typing/tensor/__init__.py
index 4c4077f3cdb..2da7f5939ec 100644
--- a/docarray/typing/tensor/__init__.py
+++ b/docarray/typing/tensor/__init__.py
@@ -14,14 +14,19 @@
)
if TYPE_CHECKING:
+ from docarray.typing.tensor.audio import AudioJaxArray # noqa: F401
from docarray.typing.tensor.audio import AudioTensorFlowTensor # noqa: F401
from docarray.typing.tensor.audio import AudioTorchTensor # noqa: F401
+ from docarray.typing.tensor.embedding import JaxArrayEmbedding # noqa: F401
from docarray.typing.tensor.embedding import TensorFlowEmbedding # noqa: F401
from docarray.typing.tensor.embedding import TorchEmbedding # noqa: F401
+ from docarray.typing.tensor.image import ImageJaxArray # noqa: F401
from docarray.typing.tensor.image import ImageTensorFlowTensor # noqa: F401
from docarray.typing.tensor.image import ImageTorchTensor # noqa: F401
+ from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401
from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401
from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401
+ from docarray.typing.tensor.video import VideoJaxArray # noqa: F401
from docarray.typing.tensor.video import VideoTensorFlowTensor # noqa: F401
from docarray.typing.tensor.video import VideoTorchTensor # noqa: F401
@@ -42,19 +47,23 @@ def __getattr__(name: str):
import_library('torch', raise_error=True)
elif 'TensorFlow' in name:
import_library('tensorflow', raise_error=True)
+ elif 'Jax' in name:
+ import_library('jax', raise_error=True)
lib: types.ModuleType
if name == 'TorchTensor':
import docarray.typing.tensor.torch_tensor as lib
elif name == 'TensorFlowTensor':
import docarray.typing.tensor.tensorflow_tensor as lib
- elif name in ['TorchEmbedding', 'TensorFlowEmbedding']:
+ elif name == 'JaxArray':
+ import docarray.typing.tensor.jaxarray as lib
+ elif name in ['TorchEmbedding', 'TensorFlowEmbedding', 'JaxArrayEmbedding']:
import docarray.typing.tensor.embedding as lib
- elif name in ['ImageTorchTensor', 'ImageTensorFlowTensor']:
+ elif name in ['ImageTorchTensor', 'ImageTensorFlowTensor', 'ImageJaxArray']:
import docarray.typing.tensor.image as lib
- elif name in ['AudioTorchTensor', 'AudioTensorFlowTensor']:
+ elif name in ['AudioTorchTensor', 'AudioTensorFlowTensor', 'AudioJaxArray']:
import docarray.typing.tensor.audio as lib
- elif name in ['VideoTorchTensor', 'VideoTensorFlowTensor']:
+ elif name in ['VideoTorchTensor', 'VideoTensorFlowTensor', 'VideoJaxArray']:
import docarray.typing.tensor.video as lib
else:
raise ImportError(
diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py
index 166c4539ba0..e7e4fbe7056 100644
--- a/docarray/typing/tensor/abstract_tensor.py
+++ b/docarray/typing/tensor/abstract_tensor.py
@@ -23,11 +23,14 @@
from docarray.base_doc.io.json import orjson_dumps
from docarray.computation import AbstractComputationalBackend
from docarray.typing.abstract_type import AbstractType
+from docarray.utils._internal._typing import safe_issubclass
+from docarray.utils._internal.pydantic import is_pydantic_v2
-if TYPE_CHECKING:
- from pydantic import BaseConfig
- from pydantic.fields import ModelField
+if is_pydantic_v2:
+ from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
+ from pydantic_core import CoreSchema, core_schema
+if TYPE_CHECKING:
from docarray.proto import NdArrayProto, NodeProto
T = TypeVar('T', bound='AbstractTensor')
@@ -44,7 +47,7 @@ class _ParametrizedMeta(type):
This metaclass ensures that instance, subclass and equality checks on parametrized Tensors
are handled as expected:
- assert issubclass(TorchTensor[128], TorchTensor[128])
+ assert safe_issubclass(TorchTensor[128], TorchTensor[128])
t = parse_obj_as(TorchTensor[128], torch.zeros(128))
assert isinstance(t, TorchTensor[128])
TorchTensor[128] == TorchTensor[128]
@@ -58,8 +61,8 @@ class _ParametrizedMeta(type):
def _equals_special_case(cls, other):
is_type = isinstance(other, type)
- is_tensor = is_type and AbstractTensor in other.mro()
- same_parents = is_tensor and cls.mro()[1:] == other.mro()[1:]
+ is_tensor = is_type and AbstractTensor in other.__mro__
+ same_parents = is_tensor and cls.__mro__[1:] == other.__mro__[1:]
subclass_target_shape = getattr(other, '__docarray_target_shape__', False)
self_target_shape = getattr(cls, '__docarray_target_shape__', False)
@@ -90,10 +93,12 @@ def __instancecheck__(cls, instance):
):
return False
return any(
- issubclass(candidate, _cls.__unparametrizedcls__)
- for candidate in type(instance).mro()
+ safe_issubclass(candidate, _cls.__unparametrizedcls__)
+ for candidate in type(instance).__mro__
)
- return any(issubclass(candidate, cls) for candidate in type(instance).mro())
+ return any(
+ safe_issubclass(candidate, cls) for candidate in type(instance).__mro__
+ )
return super().__instancecheck__(instance)
def __eq__(cls, other):
@@ -234,23 +239,57 @@ def __docarray_validate_getitem__(cls, item: Any) -> Tuple[int]:
raise TypeError(f'{item} is not a valid tensor shape.')
return item
- @classmethod
- def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
- field_schema.update(type='array', items={'type': 'number'})
- if cls.__docarray_target_shape__ is not None:
- shape_info = (
- '[' + ', '.join([str(s) for s in cls.__docarray_target_shape__]) + ']'
- )
- if (
- reduce(mul, cls.__docarray_target_shape__, 1)
- <= DISPLAY_TENSOR_OPENAPI_MAX_ITEMS
- ):
- # custom example only for 'small' shapes, otherwise it is too big to display
- example_payload = orjson_dumps(np.zeros(cls.__docarray_target_shape__))
- field_schema.update(example=example_payload)
- else:
- shape_info = 'not specified'
- field_schema['tensor/array shape'] = shape_info
+ if is_pydantic_v2:
+
+ @classmethod
+ def __get_pydantic_json_schema__(
+ cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler
+ ) -> Dict[str, Any]:
+ json_schema = {}
+ json_schema.update(type='array', items={'type': 'number'})
+ if cls.__docarray_target_shape__ is not None:
+ shape_info = (
+ '['
+ + ', '.join([str(s) for s in cls.__docarray_target_shape__])
+ + ']'
+ )
+ if (
+ reduce(mul, cls.__docarray_target_shape__, 1)
+ <= DISPLAY_TENSOR_OPENAPI_MAX_ITEMS
+ ):
+ # custom example only for 'small' shapes, otherwise it is too big to display
+ example_payload = orjson_dumps(
+ np.zeros(cls.__docarray_target_shape__)
+ ).decode()
+ json_schema.update(example=example_payload)
+ else:
+ shape_info = 'not specified'
+ json_schema['tensor/array shape'] = shape_info
+ return json_schema
+
+ else:
+
+ @classmethod
+ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
+ field_schema.update(type='array', items={'type': 'number'})
+ if cls.__docarray_target_shape__ is not None:
+ shape_info = (
+ '['
+ + ', '.join([str(s) for s in cls.__docarray_target_shape__])
+ + ']'
+ )
+ if (
+ reduce(mul, cls.__docarray_target_shape__, 1)
+ <= DISPLAY_TENSOR_OPENAPI_MAX_ITEMS
+ ):
+ # custom example only for 'small' shapes, otherwise it is too big to display
+ example_payload = orjson_dumps(
+ np.zeros(cls.__docarray_target_shape__)
+ ).decode()
+ field_schema.update(example=example_payload)
+ else:
+ shape_info = 'not specified'
+ field_schema['tensor/array shape'] = shape_info
@classmethod
def _docarray_create_parametrized_type(cls: Type[T], shape: Tuple[int]):
@@ -264,13 +303,11 @@ class _ParametrizedTensor(
__docarray_target_shape__ = shape
@classmethod
- def validate(
+ def _docarray_validate(
_cls,
value: Any,
- field: 'ModelField',
- config: 'BaseConfig',
):
- t = super().validate(value, field, config)
+ t = super()._docarray_validate(value)
return _cls.__docarray_validate_shape__(
t, _cls.__docarray_target_shape__
)
@@ -337,3 +374,32 @@ def _docarray_to_json_compatible(self):
:return: a representation of the tensor compatible with orjson
"""
return self
+
+ @classmethod
+ @abc.abstractmethod
+ def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
+ """Create a tensor from a numpy array.
+ PS: this function is different from `from_ndarray` because it is private under the docarray namespace.
+ This allows us to avoid a breaking change if one day we introduce a Tensor backend with a `from_ndarray` method.
+ """
+ ...
+
+ @abc.abstractmethod
+ def _docarray_to_ndarray(self) -> np.ndarray:
+ """cast itself to a numpy array"""
+ ...
+
+ if is_pydantic_v2:
+
+ @classmethod
+ def __get_pydantic_core_schema__(
+ cls, _source_type: Any, handler: GetCoreSchemaHandler
+ ) -> core_schema.CoreSchema:
+ return core_schema.with_info_plain_validator_function(
+ cls.validate,
+ serialization=core_schema.plain_serializer_function_ser_schema(
+ function=lambda x: x._docarray_to_ndarray().tolist(),
+ return_schema=handler.generate_schema(bytes),
+ when_used="json-unless-none",
+ ),
+ )
diff --git a/docarray/typing/tensor/audio/__init__.py b/docarray/typing/tensor/audio/__init__.py
index a505ab05720..5f304ae544f 100644
--- a/docarray/typing/tensor/audio/__init__.py
+++ b/docarray/typing/tensor/audio/__init__.py
@@ -9,12 +9,13 @@
)
if TYPE_CHECKING:
+ from docarray.typing.tensor.audio.audio_jax_array import AudioJaxArray # noqa
from docarray.typing.tensor.audio.audio_tensorflow_tensor import ( # noqa
AudioTensorFlowTensor,
)
from docarray.typing.tensor.audio.audio_torch_tensor import AudioTorchTensor # noqa
-__all__ = ['AudioNdArray', 'AudioTensor']
+__all__ = ['AudioNdArray', 'AudioTensor', 'AudioJaxArray']
def __getattr__(name: str):
@@ -25,6 +26,9 @@ def __getattr__(name: str):
elif name == 'AudioTensorFlowTensor':
import_library('tensorflow', raise_error=True)
import docarray.typing.tensor.audio.audio_tensorflow_tensor as lib
+ elif name == 'AudioJaxArray':
+ import_library('jax', raise_error=True)
+ import docarray.typing.tensor.audio.audio_jax_array as lib
else:
raise ImportError(
f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\''
diff --git a/docarray/typing/tensor/audio/audio_jax_array.py b/docarray/typing/tensor/audio/audio_jax_array.py
new file mode 100644
index 00000000000..50ce9c97438
--- /dev/null
+++ b/docarray/typing/tensor/audio/audio_jax_array.py
@@ -0,0 +1,27 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TypeVar
+
+from docarray.typing.proto_register import _register_proto
+from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor
+from docarray.typing.tensor.jaxarray import JaxArray, metaJax
+
+T = TypeVar('T', bound='AudioJaxArray')
+
+
+@_register_proto(proto_type_name='audio_jaxarray')
+class AudioJaxArray(AbstractAudioTensor, JaxArray, metaclass=metaJax):
+ ...
diff --git a/docarray/typing/tensor/audio/audio_ndarray.py b/docarray/typing/tensor/audio/audio_ndarray.py
index 3b15c0bc932..3b5c79aad1f 100644
--- a/docarray/typing/tensor/audio/audio_ndarray.py
+++ b/docarray/typing/tensor/audio/audio_ndarray.py
@@ -22,9 +22,9 @@ class AudioNdArray(AbstractAudioTensor, NdArray):
class MyAudioDoc(BaseDoc):
title: str
- audio_tensor: Optional[AudioNdArray]
- url: Optional[AudioUrl]
- bytes_: Optional[AudioBytes]
+ audio_tensor: Optional[AudioNdArray] = None
+ url: Optional[AudioUrl] = None
+ bytes_: Optional[AudioBytes] = None
# from tensor
diff --git a/docarray/typing/tensor/audio/audio_tensor.py b/docarray/typing/tensor/audio/audio_tensor.py
index 7d1f4a9e315..27c5efddff5 100644
--- a/docarray/typing/tensor/audio/audio_tensor.py
+++ b/docarray/typing/tensor/audio/audio_tensor.py
@@ -1,23 +1,106 @@
-from typing import Union
+from typing import Any, Type, TypeVar, Union, cast
+import numpy as np
+
+from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor
from docarray.typing.tensor.audio.audio_ndarray import AudioNdArray
-from docarray.utils._internal.misc import is_tf_available, is_torch_available
+from docarray.typing.tensor.tensor import AnyTensor
+from docarray.utils._internal.misc import (
+ is_jax_available,
+ is_tf_available,
+ is_torch_available,
+)
torch_available = is_torch_available()
if torch_available:
+ import torch
+
from docarray.typing.tensor.audio.audio_torch_tensor import AudioTorchTensor
+ from docarray.typing.tensor.torch_tensor import TorchTensor
tf_available = is_tf_available()
if tf_available:
+ import tensorflow as tf # type: ignore
+
from docarray.typing.tensor.audio.audio_tensorflow_tensor import (
- AudioTensorFlowTensor as AudioTFTensor,
+ AudioTensorFlowTensor,
)
+ from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor
+
+jax_available = is_jax_available()
+if jax_available:
+ import jax.numpy as jnp # type: ignore
+
+ from docarray.typing.tensor.audio.audio_jax_array import AudioJaxArray
+ from docarray.typing.tensor.jaxarray import JaxArray
+
+T = TypeVar("T", bound="AudioTensor")
+
+
+class AudioTensor(AnyTensor, AbstractAudioTensor):
+ """
+ Represents an audio tensor object that can be used with TensorFlow, PyTorch, JAX, and NumPy types.
+
+ ---
+ '''python
+ from docarray import BaseDoc
+ from docarray.typing import AudioTensor
+
+
+ class MyAudioDoc(BaseDoc):
+ tensor: AudioTensor
+
+
+ # Example usage with TensorFlow:
+ import tensorflow as tf
+
+ doc = MyAudioDoc(tensor=tf.zeros((1000, 2)))
+ type(doc.tensor) # AudioTensorFlowTensor
+
+ # Example usage with PyTorch:
+ import torch
+
+ doc = MyAudioDoc(tensor=torch.zeros(1000, 2))
+ type(doc.tensor) # AudioTorchTensor
+
+ # Example usage with NumPy:
+ import numpy as np
+
+ doc = MyAudioDoc(tensor=np.zeros((1000, 2)))
+ type(doc.tensor) # AudioNdArray
+ '''
+ ---
+
+ Raises:
+ TypeError: If the input value is not a compatible type (torch.Tensor, tensorflow.Tensor, jax.numpy.ndarray, numpy.ndarray).
+ """
-AudioTensor = AudioNdArray
-if tf_available and torch_available:
- AudioTensor = Union[AudioNdArray, AudioTorchTensor, AudioTFTensor] # type: ignore
-elif tf_available:
- AudioTensor = Union[AudioNdArray, AudioTFTensor] # type: ignore
-elif torch_available:
- AudioTensor = Union[AudioNdArray, AudioTorchTensor] # type: ignore
+ @classmethod
+ def _docarray_validate(
+ cls: Type[T],
+ value: Union[T, np.ndarray, Any],
+ ):
+ if torch_available:
+ if isinstance(value, TorchTensor):
+ return cast(AudioTorchTensor, value)
+ elif isinstance(value, torch.Tensor):
+ return AudioTorchTensor._docarray_from_native(value) # noqa
+ if tf_available:
+ if isinstance(value, TensorFlowTensor):
+ return cast(AudioTensorFlowTensor, value)
+ elif isinstance(value, tf.Tensor):
+ return AudioTensorFlowTensor._docarray_from_native(value) # noqa
+ if jax_available:
+ if isinstance(value, JaxArray):
+ return cast(AudioJaxArray, value)
+ elif isinstance(value, jnp.ndarray):
+ return AudioJaxArray._docarray_from_native(value) # noqa
+ try:
+ return AudioNdArray._docarray_validate(value)
+ except Exception: # noqa
+ pass
+ raise TypeError(
+ f"Expected one of [torch.Tensor, tensorflow.Tensor, numpy.ndarray] "
+ f"compatible type, got {type(value)}"
+ )
diff --git a/docarray/typing/tensor/audio/audio_torch_tensor.py b/docarray/typing/tensor/audio/audio_torch_tensor.py
index 974ddff120b..06b6c649e4c 100644
--- a/docarray/typing/tensor/audio/audio_torch_tensor.py
+++ b/docarray/typing/tensor/audio/audio_torch_tensor.py
@@ -22,9 +22,9 @@ class AudioTorchTensor(AbstractAudioTensor, TorchTensor, metaclass=metaTorchAndN
class MyAudioDoc(BaseDoc):
title: str
- audio_tensor: Optional[AudioTorchTensor]
- url: Optional[AudioUrl]
- bytes_: Optional[AudioBytes]
+ audio_tensor: Optional[AudioTorchTensor] = None
+ url: Optional[AudioUrl] = None
+ bytes_: Optional[AudioBytes] = None
doc_1 = MyAudioDoc(
diff --git a/docarray/typing/tensor/embedding/__init__.py b/docarray/typing/tensor/embedding/__init__.py
index c32048b21c6..0e518b67a57 100644
--- a/docarray/typing/tensor/embedding/__init__.py
+++ b/docarray/typing/tensor/embedding/__init__.py
@@ -10,6 +10,7 @@
)
if TYPE_CHECKING:
+ from docarray.typing.tensor.embedding.jax_array import JaxArrayEmbedding # noqa
from docarray.typing.tensor.embedding.tensorflow import TensorFlowEmbedding # noqa
from docarray.typing.tensor.embedding.torch import TorchEmbedding # noqa
@@ -24,6 +25,9 @@ def __getattr__(name: str):
elif name == 'TensorFlowEmbedding':
import_library('tensorflow', raise_error=True)
import docarray.typing.tensor.embedding.tensorflow as lib
+ elif name == 'JaxArrayEmbedding':
+ import_library('jax', raise_error=True)
+ import docarray.typing.tensor.embedding.jax_array as lib
else:
raise ImportError(
f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\''
diff --git a/docarray/typing/tensor/embedding/embedding.py b/docarray/typing/tensor/embedding/embedding.py
index 2eb1fa7e842..1f498f39f50 100644
--- a/docarray/typing/tensor/embedding/embedding.py
+++ b/docarray/typing/tensor/embedding/embedding.py
@@ -1,27 +1,105 @@
-from typing import Union
+from typing import Any, Type, TypeVar, Union, cast
+import numpy as np
+
+from docarray.typing.tensor.embedding.embedding_mixin import EmbeddingMixin
from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding
-from docarray.utils._internal.misc import is_tf_available, is_torch_available
+from docarray.typing.tensor.tensor import AnyTensor
+from docarray.utils._internal.misc import ( # noqa
+ is_jax_available,
+ is_tf_available,
+ is_torch_available,
+)
+
+jax_available = is_jax_available()
+if jax_available:
+ import jax.numpy as jnp # type: ignore
+
+ from docarray.typing.tensor.embedding.jax_array import JaxArrayEmbedding
+ from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401
torch_available = is_torch_available()
if torch_available:
+ import torch
+
from docarray.typing.tensor.embedding.torch import TorchEmbedding
+ from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401
tf_available = is_tf_available()
if tf_available:
- from docarray.typing.tensor.embedding.tensorflow import (
- TensorFlowEmbedding as TFEmbedding,
- )
+ import tensorflow as tf # type: ignore
+
+ from docarray.typing.tensor.embedding.tensorflow import TensorFlowEmbedding
+ from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401
+
+
+T = TypeVar("T", bound="AnyEmbedding")
+
+
+class AnyEmbedding(AnyTensor, EmbeddingMixin):
+ """
+ Represents an embedding tensor object that can be used with TensorFlow, PyTorch, JAX, and NumPy types.
+
+ ---
+ '''python
+ from docarray import BaseDoc
+ from docarray.typing import AnyEmbedding
+
+
+ class MyEmbeddingDoc(BaseDoc):
+ embedding: AnyEmbedding
+
+
+ # Example usage with TensorFlow:
+ import tensorflow as tf
+
+ doc = MyEmbeddingDoc(embedding=tf.zeros((1000, 2)))
+ type(doc.embedding) # TensorFlowEmbedding
+
+ # Example usage with PyTorch:
+ import torch
+
+ doc = MyEmbeddingDoc(embedding=torch.zeros(1000, 2))
+ type(doc.embedding) # TorchEmbedding
+
+ # Example usage with NumPy:
+ import numpy as np
+ doc = MyEmbeddingDoc(embedding=np.zeros((1000, 2)))
+ type(doc.embedding) # NdArrayEmbedding
+ '''
+ ---
-if tf_available and torch_available:
- AnyEmbedding = Union[NdArrayEmbedding, TorchEmbedding, TFEmbedding] # type: ignore
-elif tf_available:
- AnyEmbedding = Union[NdArrayEmbedding, TFEmbedding] # type: ignore
-elif torch_available:
- AnyEmbedding = Union[NdArrayEmbedding, TorchEmbedding] # type: ignore
-else:
- AnyEmbedding = Union[NdArrayEmbedding] # type: ignore
+ Raises:
+ TypeError: If the type of the value is not one of [torch.Tensor, tensorflow.Tensor, jax.numpy.ndarray, numpy.ndarray]
+ """
-__all__ = ['AnyEmbedding']
+ @classmethod
+ def _docarray_validate(
+ cls: Type[T],
+ value: Union[T, np.ndarray, Any],
+ ):
+ if torch_available:
+ if isinstance(value, TorchTensor):
+ return cast(TorchEmbedding, value)
+ elif isinstance(value, torch.Tensor):
+ return TorchEmbedding._docarray_from_native(value) # noqa
+ if tf_available:
+ if isinstance(value, TensorFlowTensor):
+ return cast(TensorFlowEmbedding, value)
+ elif isinstance(value, tf.Tensor):
+ return TensorFlowEmbedding._docarray_from_native(value) # noqa
+ if jax_available:
+ if isinstance(value, JaxArray):
+ return cast(JaxArrayEmbedding, value)
+ elif isinstance(value, jnp.ndarray):
+ return JaxArrayEmbedding._docarray_from_native(value) # noqa
+ try:
+ return NdArrayEmbedding._docarray_validate(value)
+ except Exception: # noqa
+ pass
+ raise TypeError(
+ f"Expected one of [torch.Tensor, tensorflow.Tensor, numpy.ndarray] "
+ f"compatible type, got {type(value)}"
+ )
diff --git a/docarray/typing/tensor/embedding/embedding_mixin.py b/docarray/typing/tensor/embedding/embedding_mixin.py
index a80cfc3d666..1310fae15ca 100644
--- a/docarray/typing/tensor/embedding/embedding_mixin.py
+++ b/docarray/typing/tensor/embedding/embedding_mixin.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from abc import ABC
from typing import Any, Optional, Tuple, Type
diff --git a/docarray/typing/tensor/embedding/jax_array.py b/docarray/typing/tensor/embedding/jax_array.py
new file mode 100644
index 00000000000..4dbb7a67ee0
--- /dev/null
+++ b/docarray/typing/tensor/embedding/jax_array.py
@@ -0,0 +1,17 @@
+from typing import Any # noqa: F401
+
+from docarray.typing.proto_register import _register_proto
+from docarray.typing.tensor.embedding.embedding_mixin import EmbeddingMixin
+from docarray.typing.tensor.jaxarray import JaxArray
+
+jax_base = type(JaxArray) # type: Any
+embedding_base = type(EmbeddingMixin) # type: Any
+
+
+class metaJaxAndEmbedding(jax_base, embedding_base):
+ pass
+
+
+@_register_proto(proto_type_name='jaxarray_embedding')
+class JaxArrayEmbedding(JaxArray, EmbeddingMixin, metaclass=metaJaxAndEmbedding):
+ alternative_type = JaxArray
diff --git a/docarray/typing/tensor/embedding/ndarray.py b/docarray/typing/tensor/embedding/ndarray.py
index 631268e7c26..a320eb6942d 100644
--- a/docarray/typing/tensor/embedding/ndarray.py
+++ b/docarray/typing/tensor/embedding/ndarray.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.embedding.embedding_mixin import EmbeddingMixin
from docarray.typing.tensor.ndarray import NdArray
diff --git a/docarray/typing/tensor/image/__init__.py b/docarray/typing/tensor/image/__init__.py
index 7af4b852206..d62b096c1fe 100644
--- a/docarray/typing/tensor/image/__init__.py
+++ b/docarray/typing/tensor/image/__init__.py
@@ -10,6 +10,7 @@
)
if TYPE_CHECKING:
+ from docarray.typing.tensor.image.image_jax_array import ImageJaxArray # noqa
from docarray.typing.tensor.image.image_tensorflow_tensor import ( # noqa
ImageTensorFlowTensor,
)
@@ -26,6 +27,9 @@ def __getattr__(name: str):
elif name == 'ImageTensorFlowTensor':
import_library('tensorflow', raise_error=True)
import docarray.typing.tensor.image.image_tensorflow_tensor as lib
+ elif name == 'ImageJaxArray':
+ import_library('jax', raise_error=True)
+ import docarray.typing.tensor.image.image_jax_array as lib
else:
raise ImportError(
f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\''
diff --git a/docarray/typing/tensor/image/image_jax_array.py b/docarray/typing/tensor/image/image_jax_array.py
new file mode 100644
index 00000000000..a814f2f7dae
--- /dev/null
+++ b/docarray/typing/tensor/image/image_jax_array.py
@@ -0,0 +1,25 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from docarray.typing.proto_register import _register_proto
+from docarray.typing.tensor.image.abstract_image_tensor import AbstractImageTensor
+from docarray.typing.tensor.jaxarray import JaxArray, metaJax
+
+MAX_INT_16 = 2**15
+
+
+@_register_proto(proto_type_name='image_jaxarray')
+class ImageJaxArray(JaxArray, AbstractImageTensor, metaclass=metaJax):
+    """`JaxArray` combined with `AbstractImageTensor`, registered under the proto type name 'image_jaxarray'.
+
+    Defines no members of its own; everything is inherited from the two base classes.
+    """
+
+    ...
diff --git a/docarray/typing/tensor/image/image_ndarray.py b/docarray/typing/tensor/image/image_ndarray.py
index 1ff3a14eaa0..b5e588961b0 100644
--- a/docarray/typing/tensor/image/image_ndarray.py
+++ b/docarray/typing/tensor/image/image_ndarray.py
@@ -25,9 +25,9 @@ class ImageNdArray(AbstractImageTensor, NdArray):
class MyImageDoc(BaseDoc):
title: str
- tensor: Optional[ImageNdArray]
- url: Optional[ImageUrl]
- bytes: Optional[ImageBytes]
+ tensor: Optional[ImageNdArray] = None
+ url: Optional[ImageUrl] = None
+ bytes: Optional[ImageBytes] = None
# from url
diff --git a/docarray/typing/tensor/image/image_tensor.py b/docarray/typing/tensor/image/image_tensor.py
index af439ce607e..b8920f7dba5 100644
--- a/docarray/typing/tensor/image/image_tensor.py
+++ b/docarray/typing/tensor/image/image_tensor.py
@@ -1,23 +1,109 @@
-from typing import Union
+from typing import Any, Type, TypeVar, Union, cast
+import numpy as np
+
+from docarray.typing.tensor.image.abstract_image_tensor import AbstractImageTensor
from docarray.typing.tensor.image.image_ndarray import ImageNdArray
-from docarray.utils._internal.misc import is_tf_available, is_torch_available
+from docarray.typing.tensor.tensor import AnyTensor
+from docarray.utils._internal.misc import (
+ is_jax_available,
+ is_tf_available,
+ is_torch_available,
+)
+
+jax_available = is_jax_available()
+if jax_available:
+ import jax.numpy as jnp # type: ignore
+
+ from docarray.typing.tensor.image.image_jax_array import ImageJaxArray
+ from docarray.typing.tensor.jaxarray import JaxArray
torch_available = is_torch_available()
if torch_available:
- from docarray.typing.tensor.image.image_torch_tensor import ImageTorchTensor
+ import torch
+ from docarray.typing.tensor.image.image_torch_tensor import ImageTorchTensor
+ from docarray.typing.tensor.torch_tensor import TorchTensor
tf_available = is_tf_available()
if tf_available:
+ import tensorflow as tf # type: ignore
+
from docarray.typing.tensor.image.image_tensorflow_tensor import (
- ImageTensorFlowTensor as ImageTFTensor,
+ ImageTensorFlowTensor,
)
+ from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor
+
+
+T = TypeVar("T", bound="ImageTensor")
+
+
+class ImageTensor(AnyTensor, AbstractImageTensor):
+ """
+ Represents an image tensor that can be backed by a TensorFlow, PyTorch, JAX, or NumPy tensor.
+
+ ---
+ '''python
+ from docarray import BaseDoc
+ from docarray.typing import ImageTensor
+
+
+ class MyImageDoc(BaseDoc):
+ image: ImageTensor
+
+
+ # Example usage with TensorFlow:
+ import tensorflow as tf
+
+ doc = MyImageDoc(image=tf.zeros((1000, 2)))
+ type(doc.image) # ImageTensorFlowTensor
+
+ # Example usage with PyTorch:
+ import torch
+
+ doc = MyImageDoc(image=torch.zeros((1000, 2)))
+ type(doc.image) # ImageTorchTensor
+
+ # Example usage with NumPy:
+ import numpy as np
+
+ doc = MyImageDoc(image=np.zeros((1000, 2)))
+ type(doc.image) # ImageNdArray
+ '''
+ ---
+
+ Returns:
+ Union[ImageTorchTensor, ImageTensorFlowTensor, ImageJaxArray, ImageNdArray]: The validated and converted image tensor.
+
+ Raises:
+ TypeError: If the input type is not one of [torch.Tensor, tensorflow.Tensor, numpy.ndarray].
+ """
-ImageTensor = Union[ImageNdArray] # type: ignore
-if tf_available and torch_available:
- ImageTensor = Union[ImageNdArray, ImageTorchTensor, ImageTFTensor] # type: ignore
-elif tf_available:
- ImageTensor = Union[ImageNdArray, ImageTFTensor] # type: ignore
-elif torch_available:
- ImageTensor = Union[ImageNdArray, ImageTorchTensor] # type: ignore
+    @classmethod
+    def _docarray_validate(
+        cls: Type[T],
+        value: Union[T, np.ndarray, Any],
+    ):
+        """Validate `value` and return it as a backend-specific image tensor.
+
+        Backends are tried in a fixed order (torch, tensorflow, jax), each only when
+        its `*_available` flag is set. Anything not matched by those checks is handed
+        to `ImageNdArray._docarray_validate`.
+
+        :param value: a docarray tensor, a native backend tensor, or any other object
+        :return: `value` unchanged for a docarray tensor of an available backend, an
+            `Image*` tensor built from a native backend tensor, or the result of
+            `ImageNdArray._docarray_validate`
+        :raises TypeError: if `ImageNdArray._docarray_validate` raises for `value`
+        """
+        if torch_available:
+            if isinstance(value, TorchTensor):
+                # `cast` only informs the type checker; the object is returned unchanged
+                return cast(ImageTorchTensor, value)
+            elif isinstance(value, torch.Tensor):
+                return ImageTorchTensor._docarray_from_native(value)  # noqa
+        if tf_available:
+            if isinstance(value, TensorFlowTensor):
+                return cast(ImageTensorFlowTensor, value)
+            elif isinstance(value, tf.Tensor):
+                return ImageTensorFlowTensor._docarray_from_native(value)  # noqa
+        if jax_available:
+            if isinstance(value, JaxArray):
+                return cast(ImageJaxArray, value)
+            elif isinstance(value, jnp.ndarray):
+                return ImageJaxArray._docarray_from_native(value)  # noqa
+        try:
+            # last resort: let the numpy-backed image tensor try to coerce the value
+            return ImageNdArray._docarray_validate(value)
+        except Exception:  # noqa
+            # any failure above is reported as the TypeError raised below
+            pass
+        raise TypeError(
+            f"Expected one of [torch.Tensor, tensorflow.Tensor, numpy.ndarray] "
+            f"compatible type, got {type(value)}"
+        )
diff --git a/docarray/typing/tensor/image/image_tensorflow_tensor.py b/docarray/typing/tensor/image/image_tensorflow_tensor.py
index f373f45b30e..2120df5626a 100644
--- a/docarray/typing/tensor/image/image_tensorflow_tensor.py
+++ b/docarray/typing/tensor/image/image_tensorflow_tensor.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TypeVar
from docarray.typing.proto_register import _register_proto
diff --git a/docarray/typing/tensor/image/image_torch_tensor.py b/docarray/typing/tensor/image/image_torch_tensor.py
index 103a936d705..7edc5aaa5fa 100644
--- a/docarray/typing/tensor/image/image_torch_tensor.py
+++ b/docarray/typing/tensor/image/image_torch_tensor.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TypeVar
from docarray.typing.proto_register import _register_proto
@@ -28,9 +43,9 @@ class ImageTorchTensor(AbstractImageTensor, TorchTensor, metaclass=metaTorchAndN
class MyImageDoc(BaseDoc):
title: str
- tensor: Optional[ImageTorchTensor]
- url: Optional[ImageUrl]
- bytes: Optional[ImageBytes]
+ tensor: Optional[ImageTorchTensor] = None
+ url: Optional[ImageUrl] = None
+ bytes: Optional[ImageBytes] = None
doc = MyImageDoc(
diff --git a/docarray/typing/tensor/jaxarray.py b/docarray/typing/tensor/jaxarray.py
new file mode 100644
index 00000000000..db49aa6bf29
--- /dev/null
+++ b/docarray/typing/tensor/jaxarray.py
@@ -0,0 +1,266 @@
+from typing import TYPE_CHECKING, Any, Generic, Type, TypeVar, Union, cast
+
+import numpy as np
+import orjson
+
+from docarray.typing.proto_register import _register_proto
+from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from docarray.utils._internal.misc import import_library
+
+if TYPE_CHECKING:
+ import jax
+ import jax.numpy as jnp
+
+ from docarray.computation.jax_backend import JaxCompBackend
+ from docarray.proto import NdArrayProto
+else:
+ jax = import_library('jax', raise_error=True)
+ jnp = jax.numpy
+from docarray.base_doc.base_node import BaseNode
+
+T = TypeVar('T', bound='JaxArray')
+ShapeT = TypeVar('ShapeT')
+
+node_base: type = type(BaseNode)
+
+
+# the mypy error suppression below should not be necessary anymore once the following
+# is released in mypy: https://github.com/python/mypy/pull/14135
+class metaJax(
+    AbstractTensor.__parametrized_meta__,  # type: ignore
+    node_base,  # type: ignore
+):  # type: ignore
+    """Metaclass of `JaxArray`: derives from `AbstractTensor.__parametrized_meta__` and the metaclass of `BaseNode`."""
+
+    pass
+
+
+@_register_proto(proto_type_name='jaxarray')
+class JaxArray(AbstractTensor, Generic[ShapeT], metaclass=metaJax):
+    """
+    Wrapper around a `jnp.ndarray`, intended for use in a Document.
+    This enables (de)serialization from/to protobuf and json, data validation,
+    and coercion from compatible types like lists, tuples and `np.ndarray`.
+
+    This type can also be used in a parametrized way, specifying the shape of the array.
+
+    ---
+
+    ```python
+    from docarray import BaseDoc
+    from docarray.typing import JaxArray
+    import jax.numpy as jnp
+    import numpy as np
+
+
+    class MyDoc(BaseDoc):
+        arr: JaxArray
+        image_arr: JaxArray[3, 224, 224]
+        square_crop: JaxArray[3, 'x', 'x']
+        random_image: JaxArray[3, ...]  # first dimension is fixed, can have arbitrary shape
+
+
+    # create a document with tensors
+    doc = MyDoc(
+        arr=jnp.zeros((128,)),
+        image_arr=jnp.zeros((3, 224, 224)),
+        square_crop=jnp.zeros((3, 64, 64)),
+        random_image=jnp.zeros((3, 128, 256)),
+    )
+    assert doc.image_arr.shape == (3, 224, 224)
+
+    # automatic shape conversion
+    doc = MyDoc(
+        arr=np.zeros((128,)),
+        image_arr=np.zeros((224, 224, 3)),  # will reshape to (3, 224, 224)
+        square_crop=np.zeros((3, 128, 128)),
+        random_image=np.zeros((3, 64, 128)),
+    )
+    assert doc.image_arr.shape == (3, 224, 224)
+
+    # !! The following will raise an error due to shape mismatch !!
+    from pydantic import ValidationError
+
+    try:
+        doc = MyDoc(
+            arr=np.zeros((128,)),
+            image_arr=np.zeros((224, 224)),  # this will fail validation
+            square_crop=np.zeros((3, 128, 64)),  # this will also fail validation
+            random_image=np.zeros((4, 64, 128)),  # this will also fail validation
+        )
+    except ValidationError as e:
+        pass
+    ```
+
+    ---
+    """
+
+    __parametrized_meta__ = metaJax
+
+    def __init__(self, tensor: jnp.ndarray):
+        """Wrap `tensor`; it is stored unchanged on `self.tensor`."""
+        super().__init__()
+        self.tensor = tensor
+
+    def __getitem__(self, item):
+        """Index the wrapped array and pass the result through `JaxCompBackend._cast_output`."""
+        from docarray.computation.jax_backend import JaxCompBackend
+
+        tensor = self.unwrap()
+        if tensor is not None:
+            tensor = tensor[item]
+        return JaxCompBackend._cast_output(t=tensor)
+
+    def __setitem__(self, index, value):
+        """Write `value` at position `index` by rebinding `self.tensor` to an updated array.
+
+        The update goes through `.at[...].set(...)`, whose result is assigned back to
+        `self.tensor` rather than modifying the existing array object.
+        """
+        # NOTE(review): `index : index + 1` requires `index` to support `+ 1`, so this
+        # presumably only handles a single integer position; confirm with callers.
+        self.tensor = self.tensor.at[index : index + 1].set(value)
+
+    def __iter__(self):
+        """Yield `self[i]` for every position along the first axis."""
+        for i in range(len(self)):
+            yield self[i]
+
+    def __len__(self):
+        """Return `len()` of the wrapped array."""
+        return len(self.tensor)
+
+    @classmethod
+    def __get_validators__(cls):
+        # one or more validators may be yielded which will be called in the
+        # order to validate the input, each validator will receive as an input
+        # the value returned from the previous validator
+        yield cls.validate
+
+    @classmethod
+    def _docarray_validate(
+        cls: Type[T],
+        value: Union[T, np.ndarray, str, Any],
+    ) -> T:
+        """Validate `value` and coerce it into a `JaxArray`.
+
+        :param value: a native jax array, a `JaxArray`, a list/tuple, a JSON string
+            (decoded with `orjson.loads`), or any object `jnp.asarray` can convert
+        :return: the validated tensor
+        :raises ValueError: if `value` cannot be converted to a jax array
+        """
+        if isinstance(value, jax.Array):
+            return cls._docarray_from_native(value)
+        elif isinstance(value, JaxArray):
+            return cast(T, value)
+        elif isinstance(value, list) or isinstance(value, tuple):
+            try:
+                arr_from_list: jnp.ndarray = jnp.asarray(value)
+                return cls._docarray_from_native(arr_from_list)
+            except Exception:
+                pass  # handled below
+        elif isinstance(value, str):
+            value = orjson.loads(value)
+
+        try:
+            # `jnp.ndarray` is the abstract `jax.Array` type and cannot be called as a
+            # constructor, so the conversion has to go through `jnp.asarray`.
+            arr: jnp.ndarray = jnp.asarray(value)
+            return cls._docarray_from_native(arr)
+        except Exception:
+            pass  # handled below
+
+        raise ValueError(f'Expected a numpy.ndarray compatible type, got {type(value)}')
+
+    @classmethod
+    def _docarray_from_native(cls: Type[T], value: jnp.ndarray) -> T:
+        """Turn a native jax array (or an existing `JaxArray`) into an instance of this class.
+
+        An existing `JaxArray` is re-classed in place and returned; any other value is
+        wrapped in a new instance.
+
+        :param value: the array to wrap, or a `JaxArray` to re-class
+        :return: a `JaxArray` of the (unparametrized) class
+        """
+        if isinstance(value, JaxArray):
+            # truthy when `cls` carries an unparametrized class to fall back to
+            if cls.__unparametrizedcls__:
+                value.__class__ = cls.__unparametrizedcls__  # type: ignore
+            else:
+                value.__class__ = cls  # type: ignore
+            return cast(T, value)
+        else:
+            # same class selection as above, but for constructing a new wrapper
+            if cls.__unparametrizedcls__:
+                cls_param_ = cls.__unparametrizedcls__
+                cls_param = cast(Type[T], cls_param_)
+            else:
+                cls_param = cls
+
+            return cls_param(tensor=value)
+
+    @classmethod
+    def from_ndarray(cls: Type[T], value: np.ndarray) -> T:
+        """Create a `JaxArray` from a numpy array.
+
+        :param value: the numpy array
+        :return: a `JaxArray`
+        """
+        return cls._docarray_from_native(jnp.array(value))
+
+    def _docarray_to_json_compatible(self) -> jnp.ndarray:
+        """
+        Convert `JaxArray` into a json compatible object
+        :return: a representation of the tensor compatible with orjson
+        """
+        return self.unwrap()
+
+    def unwrap(self) -> jnp.ndarray:
+        """
+        Return the original jax ndarray without making a copy in memory.
+
+        The original view remains intact and is still a Document `JaxArray`
+        but the return object is a pure `jnp.ndarray` and both objects share
+        the same underlying memory.
+
+        ---
+
+        ```python
+        from docarray.typing import JaxArray
+        import jax.numpy as jnp
+        from pydantic import parse_obj_as
+
+        t1 = parse_obj_as(JaxArray, jnp.zeros((3, 224, 224)))
+        # here t1 is a docarray JaxArray
+        t2 = t1.unwrap()
+        # here t2 is a pure jnp.ndarray but t1 is still a Docarray JaxArray
+        # But both share the same underlying memory
+        ```
+
+        ---
+
+        :return: a `jnp.ndarray`
+        """
+        return self.tensor
+
+    @classmethod
+    def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T':
+        """
+        Read a `JaxArray` from a proto msg
+        :param pb_msg: the `NdArrayProto` message to read from
+        :return: a `JaxArray`
+        :raises ValueError: if the message has neither a buffer nor a non-empty shape
+        """
+        source = pb_msg.dense
+        if source.buffer:
+            x = np.frombuffer(bytearray(source.buffer), dtype=source.dtype)
+            return cls.from_ndarray(x.reshape(source.shape))
+        elif len(source.shape) > 0:
+            # no data but a known shape: build a zero-filled array of that shape
+            return cls.from_ndarray(np.zeros(source.shape))
+        else:
+            raise ValueError(
+                f'Proto message {pb_msg} cannot be cast to a JaxArray.'
+            )
+
+    def to_protobuf(self) -> 'NdArrayProto':
+        """
+        Transform self into a NdArrayProto protobuf message
+        """
+        from docarray.proto import NdArrayProto
+
+        nd_proto = NdArrayProto()
+
+        value_np = self.tensor
+        nd_proto.dense.buffer = value_np.tobytes()
+        nd_proto.dense.ClearField('shape')
+        nd_proto.dense.shape.extend(list(value_np.shape))
+        nd_proto.dense.dtype = value_np.dtype.str
+
+        return nd_proto
+
+    @staticmethod
+    def get_comp_backend() -> 'JaxCompBackend':
+        """Return the computational backend of the tensor"""
+        from docarray.computation.jax_backend import JaxCompBackend
+
+        return JaxCompBackend()
+
+    def __class_getitem__(cls, item: Any, *args, **kwargs):
+        # see here for mypy bug: https://github.com/python/mypy/issues/14123
+        return AbstractTensor.__class_getitem__.__func__(cls, item)  # type: ignore
+
+    @classmethod
+    def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
+        """Create a `JaxArray` from a numpy array; delegates to `from_ndarray`."""
+        return cls.from_ndarray(value)
+
+    def _docarray_to_ndarray(self) -> np.ndarray:
+        """cast itself to a numpy array"""
+        return self.tensor.__array__()
diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py
index 18e84050a25..6ea94dc6979 100644
--- a/docarray/typing/tensor/ndarray.py
+++ b/docarray/typing/tensor/ndarray.py
@@ -1,18 +1,40 @@
from typing import TYPE_CHECKING, Any, Generic, List, Tuple, Type, TypeVar, Union, cast
import numpy as np
+import orjson
+from docarray.base_doc.base_node import BaseNode
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from docarray.utils._internal.misc import ( # noqa
+ is_jax_available,
+ is_tf_available,
+ is_torch_available,
+)
+
+jax_available = is_jax_available()
+if jax_available:
+ import jax.numpy as jnp
+
+ from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401
+
+torch_available = is_torch_available()
+if torch_available:
+ import torch
+
+ from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401
+
+tf_available = is_tf_available()
+if tf_available:
+ import tensorflow as tf # type: ignore
+
+ from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401
if TYPE_CHECKING:
- from pydantic import BaseConfig
- from pydantic.fields import ModelField
from docarray.computation.numpy_backend import NumpyCompBackend
from docarray.proto import NdArrayProto
-from docarray.base_doc.base_node import BaseNode
T = TypeVar('T', bound='NdArray')
ShapeT = TypeVar('ShapeT')
@@ -31,7 +53,7 @@ class NdArray(np.ndarray, AbstractTensor, Generic[ShapeT]):
"""
Subclass of `np.ndarray`, intended for use in a Document.
This enables (de)serialization from/to protobuf and json, data validation,
- and coersion from compatible types like `torch.Tensor`.
+ and coercion from compatible types like `torch.Tensor`.
This type can also be used in a parametrized way, specifying the shape of the array.
@@ -88,35 +110,38 @@ class MyDoc(BaseDoc):
__parametrized_meta__ = metaNumpy
@classmethod
- def __get_validators__(cls):
- # one or more validators may be yielded which will be called in the
- # order to validate the input, each validator will receive as an input
- # the value returned from the previous validator
- yield cls.validate
-
- @classmethod
- def validate(
+ def _docarray_validate(
cls: Type[T],
- value: Union[T, np.ndarray, List[Any], Tuple[Any], Any],
- field: 'ModelField',
- config: 'BaseConfig',
+ value: Union[T, np.ndarray, str, List[Any], Tuple[Any], Any],
) -> T:
+
+ if isinstance(value, str):
+ value = orjson.loads(value)
+
if isinstance(value, np.ndarray):
return cls._docarray_from_native(value)
elif isinstance(value, NdArray):
return cast(T, value)
+ elif isinstance(value, AbstractTensor):
+ return cls._docarray_from_native(value._docarray_to_ndarray())
+ elif torch_available and isinstance(value, torch.Tensor):
+ return cls._docarray_from_native(value.detach().cpu().numpy())
+ elif tf_available and isinstance(value, tf.Tensor):
+ return cls._docarray_from_native(value.numpy())
+
+ elif jax_available and isinstance(value, jnp.ndarray):
+ return cls._docarray_from_native(value.__array__())
elif isinstance(value, list) or isinstance(value, tuple):
try:
arr_from_list: np.ndarray = np.asarray(value)
return cls._docarray_from_native(arr_from_list)
except Exception:
pass # handled below
- else:
- try:
- arr: np.ndarray = np.ndarray(value)
- return cls._docarray_from_native(arr)
- except Exception:
- pass # handled below
+ try:
+ arr: np.ndarray = np.ndarray(value)
+ return cls._docarray_from_native(arr)
+ except Exception:
+ pass # handled below
raise ValueError(f'Expected a numpy.ndarray compatible type, got {type(value)}')
@classmethod
@@ -145,9 +170,9 @@ def unwrap(self) -> np.ndarray:
```python
from docarray.typing import NdArray
import numpy as np
+ from pydantic import parse_obj_as
- t1 = NdArray.validate(np.zeros((3, 224, 224)), None, None)
- # here t1 is a docarray NdArray
+ t1 = parse_obj_as(NdArray, np.zeros((3, 224, 224)))
t2 = t1.unwrap()
# here t2 is a pure np.ndarray but t1 is still a Docarray NdArray
# But both share the same underlying memory
@@ -200,3 +225,18 @@ def get_comp_backend() -> 'NumpyCompBackend':
def __class_getitem__(cls, item: Any, *args, **kwargs):
# see here for mypy bug: https://github.com/python/mypy/issues/14123
return AbstractTensor.__class_getitem__.__func__(cls, item) # type: ignore
+
+    @classmethod
+    def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
+        """Create a tensor from a numpy array.
+        PS: this function is different from `from_ndarray` because it is private under the docarray namespace.
+        This allows us to avoid a breaking change if one day we introduce a Tensor backend with a `from_ndarray` method.
+        """
+        return cls._docarray_from_native(value)
+
+    def _docarray_to_ndarray(self) -> np.ndarray:
+        """Cast itself to a numpy array.
+
+        Returns the result of `self.unwrap()`.
+        """
+        return self.unwrap()
diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py
index 19a89cbed3f..d515f33bfaa 100644
--- a/docarray/typing/tensor/tensor.py
+++ b/docarray/typing/tensor/tensor.py
@@ -1,22 +1,150 @@
-from typing import Union
+from typing import TYPE_CHECKING, Any, Generic, Type, TypeVar, Union
+import numpy as np
+
+from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.ndarray import NdArray
-from docarray.utils._internal.misc import is_tf_available, is_torch_available
+from docarray.utils._internal.misc import ( # noqa
+ is_jax_available,
+ is_tf_available,
+ is_torch_available,
+)
+
+jax_available = is_jax_available()
+if jax_available:
+ import jax.numpy as jnp
+
+ from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401
torch_available = is_torch_available()
if torch_available:
- from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401
+ import torch
+ from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401
tf_available = is_tf_available()
if tf_available:
+ import tensorflow as tf # type: ignore
+
from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401
-AnyTensor = Union[NdArray]
-if torch_available and tf_available:
- AnyTensor = Union[NdArray, TorchTensor, TensorFlowTensor] # type: ignore
-elif torch_available:
- AnyTensor = Union[NdArray, TorchTensor] # type: ignore
-elif tf_available:
- AnyTensor = Union[NdArray, TensorFlowTensor] # type: ignore
+if TYPE_CHECKING:
+
+ # Below is the hack to make the type checker happy. But `AnyTensor` is defined as a class and with same underlying
+ # behavior as `Union[TorchTensor, TensorFlowTensor, NdArray]` so it should be fine to use `AnyTensor` as
+ # the type for `tensor` field in `BaseDoc` class.
+ AnyTensor = Union[NdArray]
+ if torch_available and tf_available and jax_available:
+ AnyTensor = Union[NdArray, TorchTensor, TensorFlowTensor, JaxArray] # type: ignore
+ elif torch_available and tf_available:
+ AnyTensor = Union[NdArray, TorchTensor, TensorFlowTensor] # type: ignore
+ elif tf_available and jax_available:
+ AnyTensor = Union[NdArray, TensorFlowTensor, JaxArray] # type: ignore
+ elif torch_available and jax_available:
+ AnyTensor = Union[NdArray, TorchTensor, JaxArray] # type: ignore
+ elif tf_available:
+ AnyTensor = Union[NdArray, TensorFlowTensor] # type: ignore
+ elif torch_available:
+ AnyTensor = Union[NdArray, TorchTensor] # type: ignore
+ elif jax_available:
+ AnyTensor = Union[NdArray, JaxArray] # type: ignore
+
+else:
+
+ T = TypeVar("T", bound="AnyTensor")
+ ShapeT = TypeVar('ShapeT')
+
+    class AnyTensor(AbstractTensor, Generic[ShapeT]):
+        """
+        Represents a tensor object that can be backed by TensorFlow, PyTorch, JAX, or NumPy.
+        !!! note:
+            when doing type checking (mypy or pycharm type checker), this class will actually be replaced by a Union of the
+            available tensor types. You can reason about this class as if it was a Union.
+
+        ```python
+        from docarray import BaseDoc
+        from docarray.typing import AnyTensor
+
+
+        class MyTensorDoc(BaseDoc):
+            tensor: AnyTensor
+
+
+        # Example usage with TensorFlow:
+        # import tensorflow as tf
+
+        # doc = MyTensorDoc(tensor=tf.zeros((1000, 2)))
+
+        # Example usage with PyTorch:
+        import torch
+
+        doc = MyTensorDoc(tensor=torch.zeros(1000, 2))
+
+        # Example usage with NumPy:
+        import numpy as np
+
+        doc = MyTensorDoc(tensor=np.zeros((1000, 2)))
+        ```
+        """
+
+        # The four dunder methods below are empty placeholders (they return None);
+        # `_docarray_validate` never returns an `AnyTensor` instance, only
+        # backend-specific tensors.
+        def __getitem__(self: T, item):
+            pass
+
+        def __setitem__(self, index, value):
+            pass
+
+        def __iter__(self):
+            pass
+
+        def __len__(self):
+            pass
+
+        @classmethod
+        def _docarray_from_native(cls: Type[T], value: Any):
+            """Not supported on `AnyTensor`; always raises `RuntimeError`."""
+            raise RuntimeError(f'This method should not be called on {cls}.')
+
+        @staticmethod
+        def get_comp_backend():
+            """Not supported on `AnyTensor`; always raises `RuntimeError`."""
+            raise RuntimeError('This method should not be called on AnyTensor.')
+
+        def to_protobuf(self):
+            """Not supported on `AnyTensor`; always raises `RuntimeError`."""
+            raise RuntimeError(f'This method should not be called on {self.__class__}.')
+
+        def _docarray_to_json_compatible(self):
+            """Not supported on `AnyTensor`; always raises `RuntimeError`."""
+            raise RuntimeError(f'This method should not be called on {self.__class__}.')
+
+        @classmethod
+        def from_protobuf(cls: Type[T], pb_msg: T):
+            """Not supported on `AnyTensor`; always raises `RuntimeError`."""
+            raise RuntimeError(f'This method should not be called on {cls}.')
+
+        @classmethod
+        def _docarray_validate(
+            cls: Type[T],
+            value: Union[T, np.ndarray, Any],
+        ):
+            """Validate `value` and return it as a backend-specific docarray tensor.
+
+            :param value: a docarray tensor, a native backend tensor, or any other object
+            :return: `value` unchanged for a docarray tensor of an available backend, a
+                docarray tensor built from a native backend tensor, or the result of
+                `NdArray._docarray_validate`
+            :raises TypeError: if `NdArray._docarray_validate` raises for `value`; the
+                original exception is attached as `__cause__`
+            """
+            # Check for TorchTensor first, then TensorFlowTensor, then JaxArray, then NdArray
+            if torch_available:
+                if isinstance(value, TorchTensor):
+                    return value
+                elif isinstance(value, torch.Tensor):
+                    return TorchTensor._docarray_from_native(value)  # noqa
+            if tf_available:
+                if isinstance(value, TensorFlowTensor):
+                    return value
+                elif isinstance(value, tf.Tensor):
+                    return TensorFlowTensor._docarray_from_native(value)  # noqa
+            if jax_available:
+                if isinstance(value, JaxArray):
+                    return value
+                elif isinstance(value, jnp.ndarray):
+                    return JaxArray._docarray_from_native(value)  # noqa
+            try:
+                return NdArray._docarray_validate(value)
+            except Exception as e:  # noqa
+                # keep the NdArray failure as the cause instead of printing it to stdout
+                raise TypeError(
+                    f"Expected one of [torch.Tensor, tensorflow.Tensor, numpy.ndarray] "
+                    f"compatible type, got {type(value)}"
+                ) from e
diff --git a/docarray/typing/tensor/tensorflow_tensor.py b/docarray/typing/tensor/tensorflow_tensor.py
index 5b9d53a76ab..8a66dcc0864 100644
--- a/docarray/typing/tensor/tensorflow_tensor.py
+++ b/docarray/typing/tensor/tensorflow_tensor.py
@@ -1,22 +1,32 @@
from typing import TYPE_CHECKING, Any, Generic, Type, TypeVar, Union, cast
import numpy as np
+import orjson
from docarray.base_doc.base_node import BaseNode
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.abstract_tensor import AbstractTensor
-from docarray.utils._internal.misc import import_library
+from docarray.utils._internal.misc import (
+ import_library,
+ is_jax_available,
+ is_torch_available,
+)
if TYPE_CHECKING:
import tensorflow as tf # type: ignore
- from pydantic import BaseConfig
- from pydantic.fields import ModelField
from docarray.computation.tensorflow_backend import TensorFlowCompBackend
from docarray.proto import NdArrayProto
else:
tf = import_library('tensorflow', raise_error=True)
+torch_available = is_torch_available()
+if torch_available:
+ import torch
+
+jax_available = is_jax_available()
+if jax_available:
+ import jax.numpy as jnp
T = TypeVar('T', bound='TensorFlowTensor')
ShapeT = TypeVar('ShapeT')
@@ -42,7 +52,7 @@ class TensorFlowTensor(AbstractTensor, Generic[ShapeT], metaclass=metaTensorFlow
intended for use in a Document.
This enables (de)serialization from/to protobuf and json, data validation,
- and coersion from compatible types like numpy.ndarray.
+ and coercion from compatible types like numpy.ndarray.
This type can also be used in a parametrized way, specifying the shape of the
tensor.
@@ -185,29 +195,31 @@ def __iter__(self):
yield self[i]
@classmethod
- def __get_validators__(cls):
- # one or more validators may be yielded which will be called in the
- # order to validate the input, each validator will receive as an input
- # the value returned from the previous validator
- yield cls.validate
-
- @classmethod
- def validate(
+ def _docarray_validate(
cls: Type[T],
- value: Union[T, np.ndarray, Any],
- field: 'ModelField',
- config: 'BaseConfig',
+ value: Union[T, np.ndarray, str, Any],
) -> T:
if isinstance(value, TensorFlowTensor):
return cast(T, value)
elif isinstance(value, tf.Tensor):
return cls._docarray_from_native(value)
- else:
- try:
- arr: tf.Tensor = tf.constant(value)
- return cls(tensor=arr)
- except Exception:
- pass # handled below
+ elif isinstance(value, np.ndarray):
+ return cls._docarray_from_ndarray(value)
+ elif isinstance(value, AbstractTensor):
+ return cls._docarray_from_ndarray(value._docarray_to_ndarray())
+ elif torch_available and isinstance(value, torch.Tensor):
+ return cls._docarray_from_native(value.detach().cpu().numpy())
+ elif jax_available and isinstance(value, jnp.ndarray):
+ return cls._docarray_from_native(value.__array__())
+ elif isinstance(value, str):
+ value = orjson.loads(value)
+
+ try:
+ arr: tf.Tensor = tf.constant(value)
+ return cls(tensor=arr)
+ except Exception:
+ pass # handled below
+
raise ValueError(
f'Expected a tensorflow.Tensor compatible type, got {type(value)}'
)
@@ -320,3 +332,19 @@ def unwrap(self) -> tf.Tensor:
def __len__(self) -> int:
return len(self.tensor)
+
+    @classmethod
+    def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
+        """Create a tensor from a numpy array.
+        PS: this function is different from `from_ndarray` because it is private under the docarray namespace.
+        This allows us to avoid a breaking change if one day we introduce a Tensor backend with a `from_ndarray` method.
+        """
+        return cls.from_ndarray(value)
+
+    def _docarray_to_ndarray(self) -> np.ndarray:
+        """Cast itself to a numpy array by calling `.numpy()` on the wrapped tensor."""
+        return self.tensor.numpy()
+
+    @property
+    def shape(self):
+        """Shape of the wrapped tensor, as returned by `tf.shape`."""
+        return tf.shape(self.tensor)
diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py
index 3fb9246e46d..7ed3bd3800e 100644
--- a/docarray/typing/tensor/torch_tensor.py
+++ b/docarray/typing/tensor/torch_tensor.py
@@ -2,22 +2,32 @@
from typing import TYPE_CHECKING, Any, Generic, Type, TypeVar, Union, cast
import numpy as np
+import orjson
from docarray.base_doc.base_node import BaseNode
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.abstract_tensor import AbstractTensor
-from docarray.utils._internal.misc import import_library
+from docarray.utils._internal.misc import (
+ import_library,
+ is_jax_available,
+ is_tf_available,
+)
if TYPE_CHECKING:
import torch
- from pydantic import BaseConfig
- from pydantic.fields import ModelField
from docarray.computation.torch_backend import TorchCompBackend
from docarray.proto import NdArrayProto
else:
torch = import_library('torch', raise_error=True)
+tf_available = is_tf_available()
+if tf_available:
+ import tensorflow as tf # type: ignore
+
+jax_available = is_jax_available()
+if jax_available:
+ import jax.numpy as jnp
T = TypeVar('T', bound='TorchTensor')
ShapeT = TypeVar('ShapeT')
@@ -48,7 +58,7 @@ class TorchTensor(
"""
Subclass of `torch.Tensor`, intended for use in a Document.
This enables (de)serialization from/to protobuf and json, data validation,
- and coersion from compatible types like numpy.ndarray.
+ and coercion from compatible types like numpy.ndarray.
This type can also be used in a parametrized way,
specifying the shape of the tensor.
@@ -101,35 +111,74 @@ class MyDoc(BaseDoc):
```
---
+
+
+ ## Compatibility with `torch.compile()`
+
+
+ PyTorch 2 [introduced compilation support](https://pytorch.org/blog/pytorch-2.0-release/) in the form of `torch.compile()`.
+
+ Currently, **`torch.compile()` does not properly support subclasses of `torch.Tensor` such as `TorchTensor`**.
+ The PyTorch team is currently working on a [fix for this issue](https://github.com/pytorch/pytorch/pull/105167#issuecomment-1678050808).
+
+ In the meantime, you can use the following workaround:
+
+ ### Workaround: Convert `TorchTensor` to `torch.Tensor` before calling `torch.compile()`
+
+ Converting any `TorchTensor`s to `torch.Tensor` before calling `torch.compile()` side-steps the issue:
+
+ ```python
+ from docarray import BaseDoc
+ from docarray.typing import TorchTensor
+ import torch
+
+
+ class MyDoc(BaseDoc):
+ tensor: TorchTensor
+
+
+ doc = MyDoc(tensor=torch.zeros(128))
+
+
+ def foo(tensor: torch.Tensor):
+ return tensor @ tensor.t()
+
+
+ foo_compiled = torch.compile(foo)
+
+ # unwrap the tensor before passing it to torch.compile()
+ foo_compiled(doc.tensor.unwrap())
+ ```
+
"""
__parametrized_meta__ = metaTorchAndNode
@classmethod
- def __get_validators__(cls):
- # one or more validators may be yielded which will be called in the
- # order to validate the input, each validator will receive as an input
- # the value returned from the previous validator
- yield cls.validate
-
- @classmethod
- def validate(
+ def _docarray_validate(
cls: Type[T],
- value: Union[T, np.ndarray, Any],
- field: 'ModelField',
- config: 'BaseConfig',
+ value: Union[T, np.ndarray, str, Any],
) -> T:
if isinstance(value, TorchTensor):
return cast(T, value)
elif isinstance(value, torch.Tensor):
return cls._docarray_from_native(value)
+ elif isinstance(value, AbstractTensor):
+ return cls._docarray_from_ndarray(value._docarray_to_ndarray())
+ elif tf_available and isinstance(value, tf.Tensor):
+ return cls._docarray_from_ndarray(value.numpy())
+ elif isinstance(value, np.ndarray):
+ return cls._docarray_from_ndarray(value)
+ elif jax_available and isinstance(value, jnp.ndarray):
+ return cls._docarray_from_ndarray(value.__array__())
+ elif isinstance(value, str):
+ value = orjson.loads(value)
+ try:
+ arr: torch.Tensor = torch.tensor(value)
+ return cls._docarray_from_native(arr)
+ except Exception:
+ pass # handled below
- else:
- try:
- arr: torch.Tensor = torch.tensor(value)
- return cls._docarray_from_native(arr)
- except Exception:
- pass # handled below
raise ValueError(f'Expected a torch.Tensor compatible type, got {type(value)}')
def _docarray_to_json_compatible(self) -> np.ndarray:
@@ -137,7 +186,7 @@ def _docarray_to_json_compatible(self) -> np.ndarray:
Convert `TorchTensor` into a json compatible object
:return: a representation of the tensor compatible with orjson
"""
- return self.numpy() ## might need to check device later
+ return self.detach().numpy() # might need to check device later
def unwrap(self) -> torch.Tensor:
"""
@@ -152,8 +201,10 @@ def unwrap(self) -> torch.Tensor:
```python
from docarray.typing import TorchTensor
import torch
+ from pydantic import parse_obj_as
+
- t = TorchTensor.validate(torch.zeros(3, 224, 224), None, None)
+ t = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224))
# here t is a docarray TorchTensor
t2 = t.unwrap()
# here t2 is a pure torch.Tensor but t1 is still a Docarray TorchTensor
@@ -241,3 +292,32 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
torch.Tensor if t in docarray_torch_tensors else t for t in types
)
return super().__torch_function__(func, types_, args, kwargs)
+
+ def __deepcopy__(self, memo):
+ """
+ Custom implementation of deepcopy for TorchTensor to avoid storage sharing issues.
+ """
+ # Create a new tensor with the same data and properties
+ new_tensor = self.clone()
+ # Set the class to the custom TorchTensor class
+ new_tensor.__class__ = self.__class__
+ return new_tensor
+
+ @classmethod
+ def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
+ """Create a `tensor from a numpy array
+ PS: this function is different from `from_ndarray` because it is private under the docarray namespace.
+ This allows us to avoid a breaking change if one day we introduce a Tensor backend with a `from_ndarray` method.
+ """
+ return cls.from_ndarray(value)
+
+ def _docarray_to_ndarray(self) -> np.ndarray:
+ """cast itself to a numpy array"""
+ return self.detach().cpu().numpy()
+
+ def new_empty(self, *args, **kwargs):
+ """
+ This method enables the deepcopy of `TorchTensor` by returning another instance of this subclass.
+ If this function is not implemented, the deepcopy will throw a RuntimeError from Torch.
+ """
+ return self.__class__(*args, **kwargs)
diff --git a/docarray/typing/tensor/video/__init__.py b/docarray/typing/tensor/video/__init__.py
index a575e7b6201..18f0a2e5d8b 100644
--- a/docarray/typing/tensor/video/__init__.py
+++ b/docarray/typing/tensor/video/__init__.py
@@ -10,6 +10,7 @@
)
if TYPE_CHECKING:
+ from docarray.typing.tensor.video.video_jax_array import VideoJaxArray # noqa
from docarray.typing.tensor.video.video_tensorflow_tensor import ( # noqa
VideoTensorFlowTensor,
)
@@ -26,6 +27,9 @@ def __getattr__(name: str):
elif name == 'VideoTensorFlowTensor':
import_library('tensorflow', raise_error=True)
import docarray.typing.tensor.video.video_tensorflow_tensor as lib
+ elif name == 'VideoJaxArray':
+ import_library('jax', raise_error=True)
+ import docarray.typing.tensor.video.video_jax_array as lib
else:
raise ImportError(
f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\''
diff --git a/docarray/typing/tensor/video/video_jax_array.py b/docarray/typing/tensor/video/video_jax_array.py
new file mode 100644
index 00000000000..07aecea7439
--- /dev/null
+++ b/docarray/typing/tensor/video/video_jax_array.py
@@ -0,0 +1,43 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Any, List, Tuple, Type, TypeVar, Union
+
+import numpy as np
+
+from docarray.typing.proto_register import _register_proto
+from docarray.typing.tensor.jaxarray import JaxArray, metaJax
+from docarray.typing.tensor.video.video_tensor_mixin import VideoTensorMixin
+
+T = TypeVar('T', bound='VideoJaxArray')
+
+if TYPE_CHECKING:
+ from pydantic import BaseConfig
+ from pydantic.fields import ModelField
+
+
+@_register_proto(proto_type_name='video_jaxarray')
+class VideoJaxArray(JaxArray, VideoTensorMixin, metaclass=metaJax):
+ """ """
+
+ @classmethod
+ def validate(
+ cls: Type[T],
+ value: Union[T, np.ndarray, List[Any], Tuple[Any], Any],
+ field: 'ModelField',
+ config: 'BaseConfig',
+ ) -> T:
+ tensor = super().validate(value=value, field=field, config=config)
+ return cls.validate_shape(value=tensor)
diff --git a/docarray/typing/tensor/video/video_ndarray.py b/docarray/typing/tensor/video/video_ndarray.py
index 5b11e75bd94..30129e40f97 100644
--- a/docarray/typing/tensor/video/video_ndarray.py
+++ b/docarray/typing/tensor/video/video_ndarray.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, List, Tuple, Type, TypeVar, Union
+from typing import Any, List, Tuple, Type, TypeVar, Union
import numpy as np
@@ -8,10 +8,6 @@
T = TypeVar('T', bound='VideoNdArray')
-if TYPE_CHECKING:
- from pydantic import BaseConfig
- from pydantic.fields import ModelField
-
@_register_proto(proto_type_name='video_ndarray')
class VideoNdArray(NdArray, VideoTensorMixin):
@@ -33,8 +29,8 @@ class VideoNdArray(NdArray, VideoTensorMixin):
class MyVideoDoc(BaseDoc):
title: str
- url: Optional[VideoUrl]
- video_tensor: Optional[VideoNdArray]
+ url: Optional[VideoUrl] = None
+ video_tensor: Optional[VideoNdArray] = None
doc_1 = MyVideoDoc(
@@ -55,11 +51,9 @@ class MyVideoDoc(BaseDoc):
"""
@classmethod
- def validate(
+ def _docarray_validate(
cls: Type[T],
value: Union[T, np.ndarray, List[Any], Tuple[Any], Any],
- field: 'ModelField',
- config: 'BaseConfig',
) -> T:
- tensor = super().validate(value=value, field=field, config=config)
+ tensor = super()._docarray_validate(value=value)
return cls.validate_shape(value=tensor)
diff --git a/docarray/typing/tensor/video/video_tensor.py b/docarray/typing/tensor/video/video_tensor.py
index fa4fbb47d6b..56f91b14731 100644
--- a/docarray/typing/tensor/video/video_tensor.py
+++ b/docarray/typing/tensor/video/video_tensor.py
@@ -1,24 +1,114 @@
-from typing import Union
+from typing import Any, Type, TypeVar, Union, cast
+import numpy as np
+
+from docarray.typing.tensor.tensor import AnyTensor
from docarray.typing.tensor.video.video_ndarray import VideoNdArray
-from docarray.utils._internal.misc import is_tf_available, is_torch_available
+from docarray.typing.tensor.video.video_tensor_mixin import VideoTensorMixin
+from docarray.utils._internal.misc import (
+ is_jax_available,
+ is_tf_available,
+ is_torch_available,
+)
+
+jax_available = is_jax_available()
+if jax_available:
+ import jax.numpy as jnp
+
+ from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401
+ from docarray.typing.tensor.video.video_jax_array import VideoJaxArray
torch_available = is_torch_available()
if torch_available:
+ import torch
+
+ from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401
from docarray.typing.tensor.video.video_torch_tensor import VideoTorchTensor
tf_available = is_tf_available()
if tf_available:
+ import tensorflow as tf # type: ignore
+
+ from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401
from docarray.typing.tensor.video.video_tensorflow_tensor import (
- VideoTensorFlowTensor as VideoTFTensor,
+ VideoTensorFlowTensor,
)
-if tf_available and torch_available:
- VideoTensor = Union[VideoNdArray, VideoTorchTensor, VideoTFTensor] # type: ignore
-elif tf_available:
- VideoTensor = Union[VideoNdArray, VideoTFTensor] # type: ignore
-elif torch_available:
- VideoTensor = Union[VideoNdArray, VideoTorchTensor] # type: ignore
-else:
- VideoTensor = Union[VideoNdArray] # type: ignore
+
+T = TypeVar("T", bound="VideoTensor")
+
+
+class VideoTensor(AnyTensor, VideoTensorMixin):
+ """
+ Represents a video tensor object that can be used with TensorFlow, PyTorch, JAX, and NumPy types.
+
+ ---
+ ```python
+ from docarray import BaseDoc
+ from docarray.typing import VideoTensor
+
+
+ class MyVideoDoc(BaseDoc):
+ video: VideoTensor
+
+
+ # Example usage with TensorFlow:
+ import tensorflow as tf
+
+ doc = MyVideoDoc(video=tf.zeros((1000, 2)))
+ type(doc.video) # VideoTensorFlowTensor
+
+ # Example usage with PyTorch:
+ import torch
+
+ doc = MyVideoDoc(video=torch.zeros(1000, 2))
+ type(doc.video) # VideoTorchTensor
+
+ # Example usage with NumPy:
+ import numpy as np
+
+ doc = MyVideoDoc(video=np.zeros((1000, 2)))
+ type(doc.video) # VideoNdArray
+ ```
+ ---
+
+ Returns:
+ Union[VideoTorchTensor, VideoTensorFlowTensor, VideoJaxArray, VideoNdArray]: The validated and converted video tensor.
+
+ Raises:
+ TypeError: If the input value is not a compatible type (torch.Tensor, tensorflow.Tensor, numpy.ndarray).
+
+ """
+
+ @classmethod
+ def _docarray_validate(
+ cls: Type[T],
+ value: Union[T, np.ndarray, Any],
+ ):
+ if torch_available:
+ if isinstance(value, TorchTensor):
+ return cast(VideoTorchTensor, value)
+ elif isinstance(value, torch.Tensor):
+ return VideoTorchTensor._docarray_from_native(value) # noqa
+ if tf_available:
+ if isinstance(value, TensorFlowTensor):
+ return cast(VideoTensorFlowTensor, value)
+ elif isinstance(value, tf.Tensor):
+ return VideoTensorFlowTensor._docarray_from_native(value) # noqa
+ if jax_available:
+ if isinstance(value, JaxArray):
+ return cast(VideoJaxArray, value)
+ elif isinstance(value, jnp.ndarray):
+ return VideoJaxArray._docarray_from_native(value) # noqa
+ if isinstance(value, VideoNdArray):
+ return cast(VideoNdArray, value)
+ if isinstance(value, np.ndarray):
+ try:
+ return VideoNdArray._docarray_validate(value)
+ except Exception as e: # noqa
+ raise e
+ raise TypeError(
+ f"Expected one of [torch.Tensor, tensorflow.Tensor, numpy.ndarray] "
+ f"compatible type, got {type(value)}"
+ )
diff --git a/docarray/typing/tensor/video/video_tensorflow_tensor.py b/docarray/typing/tensor/video/video_tensorflow_tensor.py
index d98794f8aa3..940a85a012b 100644
--- a/docarray/typing/tensor/video/video_tensorflow_tensor.py
+++ b/docarray/typing/tensor/video/video_tensorflow_tensor.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, List, Tuple, Type, TypeVar, Union
+from typing import Any, List, Tuple, Type, TypeVar, Union
import numpy as np
@@ -8,10 +8,6 @@
T = TypeVar('T', bound='VideoTensorFlowTensor')
-if TYPE_CHECKING:
- from pydantic import BaseConfig
- from pydantic.fields import ModelField
-
@_register_proto(proto_type_name='video_tensorflow_tensor')
class VideoTensorFlowTensor(
@@ -57,11 +53,9 @@ class MyVideoDoc(BaseDoc):
"""
@classmethod
- def validate(
+ def _docarray_validate(
cls: Type[T],
value: Union[T, np.ndarray, List[Any], Tuple[Any], Any],
- field: 'ModelField',
- config: 'BaseConfig',
) -> T:
- tensor = super().validate(value=value, field=field, config=config)
+ tensor = super()._docarray_validate(value=value)
return cls.validate_shape(value=tensor)
diff --git a/docarray/typing/tensor/video/video_torch_tensor.py b/docarray/typing/tensor/video/video_torch_tensor.py
index dd4c5a5dcd3..a1e1a73e33a 100644
--- a/docarray/typing/tensor/video/video_torch_tensor.py
+++ b/docarray/typing/tensor/video/video_torch_tensor.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, List, Tuple, Type, TypeVar, Union
+from typing import Any, List, Tuple, Type, TypeVar, Union
import numpy as np
@@ -8,10 +8,6 @@
T = TypeVar('T', bound='VideoTorchTensor')
-if TYPE_CHECKING:
- from pydantic import BaseConfig
- from pydantic.fields import ModelField
-
@_register_proto(proto_type_name='video_torch_tensor')
class VideoTorchTensor(TorchTensor, VideoTensorMixin, metaclass=metaTorchAndNode):
@@ -32,8 +28,8 @@ class VideoTorchTensor(TorchTensor, VideoTensorMixin, metaclass=metaTorchAndNode
class MyVideoDoc(BaseDoc):
title: str
- url: Optional[VideoUrl]
- video_tensor: Optional[VideoTorchTensor]
+ url: Optional[VideoUrl] = None
+ video_tensor: Optional[VideoTorchTensor] = None
doc_1 = MyVideoDoc(
@@ -56,11 +52,9 @@ class MyVideoDoc(BaseDoc):
"""
@classmethod
- def validate(
+ def _docarray_validate(
cls: Type[T],
value: Union[T, np.ndarray, List[Any], Tuple[Any], Any],
- field: 'ModelField',
- config: 'BaseConfig',
) -> T:
- tensor = super().validate(value=value, field=field, config=config)
+ tensor = super()._docarray_validate(value=value)
return cls.validate_shape(value=tensor)
diff --git a/docarray/typing/url/__init__.py b/docarray/typing/url/__init__.py
index b1a4416744d..f0483c43285 100644
--- a/docarray/typing/url/__init__.py
+++ b/docarray/typing/url/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.typing.url.any_url import AnyUrl
from docarray.typing.url.audio_url import AudioUrl
from docarray.typing.url.image_url import ImageUrl
diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py
index 6d930aa53f3..b7c5d71f835 100644
--- a/docarray/typing/url/any_url.py
+++ b/docarray/typing/url/any_url.py
@@ -1,8 +1,9 @@
+import mimetypes
import os
import urllib
import urllib.parse
import urllib.request
-from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Type, TypeVar, Union
import numpy as np
from pydantic import AnyUrl as BaseAnyUrl
@@ -10,148 +11,327 @@
from docarray.typing.abstract_type import AbstractType
from docarray.typing.proto_register import _register_proto
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+if is_pydantic_v2:
+ from pydantic_core import core_schema
if TYPE_CHECKING:
- from pydantic import BaseConfig
- from pydantic.fields import ModelField
+ if not is_pydantic_v2:
+ from pydantic import BaseConfig
+ from pydantic.fields import ModelField
+ else:
+ from pydantic import GetCoreSchemaHandler
+
from pydantic.networks import Parts
from docarray.proto import NodeProto
T = TypeVar('T', bound='AnyUrl')
+mimetypes.init([])
+
+# TODO need refactoring here
+# - code is duplicated across both versions
+# - validation is only a minimal placeholder for pydantic v2
+
+if is_pydantic_v2:
-@_register_proto(proto_type_name='any_url')
-class AnyUrl(BaseAnyUrl, AbstractType):
- host_required = (
- False # turn off host requirement to allow passing of local paths as URL
- )
-
- def _to_node_protobuf(self) -> 'NodeProto':
- """Convert Document into a NodeProto protobuf message. This function should
- be called when the Document is nested into another Document that need to
- be converted into a protobuf
-
- :return: the nested item protobuf message
- """
- from docarray.proto import NodeProto
-
- return NodeProto(text=str(self), type=self._proto_type_name)
-
- @classmethod
- def validate(
- cls: Type[T],
- value: Union[T, np.ndarray, Any],
- field: 'ModelField',
- config: 'BaseConfig',
- ) -> T:
- import os
-
- abs_path: Union[T, np.ndarray, Any]
- if (
- isinstance(value, str)
- and not value.startswith('http')
- and not os.path.isabs(value)
+ @_register_proto(proto_type_name='any_url')
+ class AnyUrl(str, AbstractType): # todo dummy url for now
+ @classmethod
+ def _docarray_validate(
+ cls: Type[T],
+ value: Any,
+ _: Any,
):
- input_is_relative_path = True
- abs_path = os.path.abspath(value)
- else:
- input_is_relative_path = False
- abs_path = value
-
- url = super().validate(abs_path, field, config) # basic url validation
-
- if input_is_relative_path:
- return cls(str(value), scheme=None)
- else:
- return cls(str(url), scheme=None)
-
- @classmethod
- def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
- """
- A method used to validate parts of a URL.
- Our URLs should be able to function both in local and remote settings.
- Therefore, we allow missing `scheme`, making it possible to pass a file
- path without prefix.
- If `scheme` is missing, we assume it is a local file path.
- """
- scheme = parts['scheme']
- if scheme is None:
- # allow missing scheme, unlike pydantic
- pass
-
- elif cls.allowed_schemes and scheme.lower() not in cls.allowed_schemes:
- raise errors.UrlSchemePermittedError(set(cls.allowed_schemes))
-
- if validate_port:
- cls._validate_port(parts['port'])
-
- user = parts['user']
- if cls.user_required and user is None:
- raise errors.UrlUserInfoError()
-
- return parts
-
- @classmethod
- def build(
- cls,
- *,
- scheme: str,
- user: Optional[str] = None,
- password: Optional[str] = None,
- host: str,
- port: Optional[str] = None,
- path: Optional[str] = None,
- query: Optional[str] = None,
- fragment: Optional[str] = None,
- **_kwargs: str,
- ) -> str:
- """
- Build a URL from its parts.
- The only difference from the pydantic implementation is that we allow
- missing `scheme`, making it possible to pass a file path without prefix.
- """
-
- # allow missing scheme, unlike pydantic
- scheme_ = scheme if scheme is not None else ''
- url = super().build(
- scheme=scheme_,
- user=user,
- password=password,
- host=host,
- port=port,
- path=path,
- query=query,
- fragment=fragment,
- **_kwargs,
+
+ if not cls.is_extension_allowed(value):
+ raise ValueError(
+ f"The file '{value}' is not in a valid format for class '{cls.__name__}'."
+ )
+
+ return cls(str(value))
+
+ def __get_pydantic_core_schema__(
+ cls, source: Type[Any], handler: Optional['GetCoreSchemaHandler'] = None
+ ) -> core_schema.CoreSchema:
+ return core_schema.with_info_after_validator_function(
+ cls._docarray_validate,
+ core_schema.str_schema(),
+ )
+
+ def load_bytes(self, timeout: Optional[float] = None) -> bytes:
+ """Convert url to bytes. This will either load or download the file and save
+ it into a bytes object.
+ :param timeout: timeout for urlopen. Only relevant if URI is not local
+ :return: bytes.
+ """
+ if urllib.parse.urlparse(self).scheme in {'http', 'https', 'data'}:
+ req = urllib.request.Request(
+ self, headers={'User-Agent': 'Mozilla/5.0'}
+ )
+ urlopen_kwargs = {'timeout': timeout} if timeout is not None else {}
+ with urllib.request.urlopen(req, **urlopen_kwargs) as fp: # type: ignore
+ return fp.read()
+ elif os.path.exists(self):
+ with open(self, 'rb') as fp:
+ return fp.read()
+ else:
+ raise FileNotFoundError(f'`{self}` is not a URL or a valid local path')
+
+ def _to_node_protobuf(self) -> 'NodeProto':
+ """Convert Document into a NodeProto protobuf message. This function should
+ be called when the Document is nested into another Document that needs to
+ be converted into a protobuf
+
+ :return: the nested item protobuf message
+ """
+ from docarray.proto import NodeProto
+
+ return NodeProto(text=str(self), type=self._proto_type_name)
+
+ @classmethod
+ def from_protobuf(cls: Type[T], pb_msg: 'str') -> T:
+ """
+ Read url from a proto msg.
+ :param pb_msg:
+ :return: url
+ """
+ return parse_obj_as(cls, pb_msg)
+
+ @classmethod
+ def is_extension_allowed(cls, value: Any) -> bool:
+ """
+ Check if the file extension of the URL is allowed for this class.
+ First, it guesses the mime type of the file. If it fails to detect the
+ mime type, it then checks the extra file extensions.
+ Note: This method assumes that any URL without an extension is valid.
+
+ :param value: The URL or file path.
+ :return: True if the extension is allowed, False otherwise
+ """
+ if cls is AnyUrl:
+ return True
+
+ url_parts = value.split('?')
+ extension = cls._get_url_extension(value)
+ if not extension:
+ return True
+
+ mimetype, _ = mimetypes.guess_type(url_parts[0])
+ if mimetype and mimetype.startswith(cls.mime_type()):
+ return True
+
+ return extension in cls.extra_extensions()
+
+ @staticmethod
+ def _get_url_extension(url: str) -> str:
+ """
+ Extracts and returns the file extension from a given URL.
+ If no file extension is present, the function returns an empty string.
+
+
+ :param url: The URL to extract the file extension from.
+ :return: The file extension without the period, if one exists,
+ otherwise an empty string.
+ """
+
+ parsed_url = urllib.parse.urlparse(url)
+ ext = os.path.splitext(parsed_url.path)[1]
+ ext = ext[1:] if ext.startswith('.') else ext
+ return ext
+
+else:
+
+ @_register_proto(proto_type_name='any_url')
+ class AnyUrl(BaseAnyUrl, AbstractType):
+ host_required = (
+ False # turn off host requirement to allow passing of local paths as URL
)
- if scheme is None and url.startswith('://'):
- # remove the `://` prefix, since scheme is missing
- url = url[3:]
- return url
-
- @classmethod
- def from_protobuf(cls: Type[T], pb_msg: 'str') -> T:
- """
- Read url from a proto msg.
- :param pb_msg:
- :return: url
- """
- return parse_obj_as(cls, pb_msg)
-
- def load_bytes(self, timeout: Optional[float] = None) -> bytes:
- """Convert url to bytes. This will either load or download the file and save
- it into a bytes object.
- :param timeout: timeout for urlopen. Only relevant if URI is not local
- :return: bytes.
- """
- if urllib.parse.urlparse(self).scheme in {'http', 'https', 'data'}:
- req = urllib.request.Request(self, headers={'User-Agent': 'Mozilla/5.0'})
- urlopen_kwargs = {'timeout': timeout} if timeout is not None else {}
- with urllib.request.urlopen(req, **urlopen_kwargs) as fp: # type: ignore
- return fp.read()
- elif os.path.exists(self):
- with open(self, 'rb') as fp:
- return fp.read()
- else:
- raise FileNotFoundError(f'`{self}` is not a URL or a valid local path')
+
+ @classmethod
+ def mime_type(cls) -> str:
+ """Returns the mime type associated with the class."""
+ raise NotImplementedError
+
+ @classmethod
+ def extra_extensions(cls) -> List[str]:
+ """Returns a list of allowed file extensions for the class
+ that are not covered by the mimetypes library."""
+ raise NotImplementedError
+
+ def _to_node_protobuf(self) -> 'NodeProto':
+ """Convert Document into a NodeProto protobuf message. This function should
+ be called when the Document is nested into another Document that needs to
+ be converted into a protobuf
+
+ :return: the nested item protobuf message
+ """
+ from docarray.proto import NodeProto
+
+ return NodeProto(text=str(self), type=self._proto_type_name)
+
+ @staticmethod
+ def _get_url_extension(url: str) -> str:
+ """
+ Extracts and returns the file extension from a given URL.
+ If no file extension is present, the function returns an empty string.
+
+
+ :param url: The URL to extract the file extension from.
+ :return: The file extension without the period, if one exists,
+ otherwise an empty string.
+ """
+
+ parsed_url = urllib.parse.urlparse(url)
+ ext = os.path.splitext(parsed_url.path)[1]
+ ext = ext[1:] if ext.startswith('.') else ext
+ return ext
+
+ @classmethod
+ def is_extension_allowed(cls, value: Any) -> bool:
+ """
+ Check if the file extension of the URL is allowed for this class.
+ First, it guesses the mime type of the file. If it fails to detect the
+ mime type, it then checks the extra file extensions.
+ Note: This method assumes that any URL without an extension is valid.
+
+ :param value: The URL or file path.
+ :return: True if the extension is allowed, False otherwise
+ """
+ if cls is AnyUrl:
+ return True
+
+ url_parts = value.split('?')
+ extension = cls._get_url_extension(value)
+ if not extension:
+ return True
+
+ mimetype, _ = mimetypes.guess_type(url_parts[0])
+ if mimetype and mimetype.startswith(cls.mime_type()):
+ return True
+
+ return extension in cls.extra_extensions()
+
+ @classmethod
+ def validate(
+ cls: Type[T],
+ value: Union[T, np.ndarray, Any],
+ field: 'ModelField',
+ config: 'BaseConfig',
+ ) -> T:
+ import os
+
+ abs_path: Union[T, np.ndarray, Any]
+ if (
+ isinstance(value, str)
+ and not value.startswith('http')
+ and not os.path.isabs(value)
+ ):
+ input_is_relative_path = True
+ abs_path = os.path.abspath(value)
+ else:
+ input_is_relative_path = False
+ abs_path = value
+
+ url = super().validate(abs_path, field, config) # basic url validation
+
+ if not cls.is_extension_allowed(value):
+ raise ValueError(
+ f"The file '{value}' is not in a valid format for class '{cls.__name__}'."
+ )
+
+ return cls(str(value if input_is_relative_path else url), scheme=None)
+
+ @classmethod
+ def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
+ """
+ A method used to validate parts of a URL.
+ Our URLs should be able to function both in local and remote settings.
+ Therefore, we allow missing `scheme`, making it possible to pass a file
+ path without prefix.
+ If `scheme` is missing, we assume it is a local file path.
+ """
+ scheme = parts['scheme']
+ if scheme is None:
+ # allow missing scheme, unlike pydantic
+ pass
+
+ elif cls.allowed_schemes and scheme.lower() not in cls.allowed_schemes:
+ raise errors.UrlSchemePermittedError(set(cls.allowed_schemes))
+
+ if validate_port:
+ cls._validate_port(parts['port'])
+
+ user = parts['user']
+ if cls.user_required and user is None:
+ raise errors.UrlUserInfoError()
+
+ return parts
+
+ @classmethod
+ def build(
+ cls,
+ *,
+ scheme: str,
+ user: Optional[str] = None,
+ password: Optional[str] = None,
+ host: str,
+ port: Optional[str] = None,
+ path: Optional[str] = None,
+ query: Optional[str] = None,
+ fragment: Optional[str] = None,
+ **_kwargs: str,
+ ) -> str:
+ """
+ Build a URL from its parts.
+ The only difference from the pydantic implementation is that we allow
+ missing `scheme`, making it possible to pass a file path without prefix.
+ """
+
+ # allow missing scheme, unlike pydantic
+ scheme_ = scheme if scheme is not None else ''
+ url = super().build(
+ scheme=scheme_,
+ user=user,
+ password=password,
+ host=host,
+ port=port,
+ path=path,
+ query=query,
+ fragment=fragment,
+ **_kwargs,
+ )
+ if scheme is None and url.startswith('://'):
+ # remove the `://` prefix, since scheme is missing
+ url = url[3:]
+ return url
+
+ @classmethod
+ def from_protobuf(cls: Type[T], pb_msg: 'str') -> T:
+ """
+ Read url from a proto msg.
+ :param pb_msg:
+ :return: url
+ """
+ return parse_obj_as(cls, pb_msg)
+
+ def load_bytes(self, timeout: Optional[float] = None) -> bytes:
+ """Convert url to bytes. This will either load or download the file and save
+ it into a bytes object.
+ :param timeout: timeout for urlopen. Only relevant if URI is not local
+ :return: bytes.
+ """
+ if urllib.parse.urlparse(self).scheme in {'http', 'https', 'data'}:
+ req = urllib.request.Request(
+ self, headers={'User-Agent': 'Mozilla/5.0'}
+ )
+ urlopen_kwargs = {'timeout': timeout} if timeout is not None else {}
+ with urllib.request.urlopen(req, **urlopen_kwargs) as fp: # type: ignore
+ return fp.read()
+ elif os.path.exists(self):
+ with open(self, 'rb') as fp:
+ return fp.read()
+ else:
+ raise FileNotFoundError(f'`{self}` is not a URL or a valid local path')
diff --git a/docarray/typing/url/audio_url.py b/docarray/typing/url/audio_url.py
index a84a68754ee..95700681b84 100644
--- a/docarray/typing/url/audio_url.py
+++ b/docarray/typing/url/audio_url.py
@@ -1,10 +1,26 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import warnings
-from typing import Optional, Tuple, TypeVar
+from typing import List, Optional, Tuple, TypeVar
from docarray.typing import AudioNdArray
from docarray.typing.bytes.audio_bytes import AudioBytes
from docarray.typing.proto_register import _register_proto
from docarray.typing.url.any_url import AnyUrl
+from docarray.typing.url.mimetypes import AUDIO_MIMETYPE
from docarray.utils._internal.misc import is_notebook
T = TypeVar('T', bound='AudioUrl')
@@ -17,6 +33,18 @@ class AudioUrl(AnyUrl):
Can be remote (web) URL, or a local file path.
"""
+ @classmethod
+ def mime_type(cls) -> str:
+ """Return the MIME type string for audio urls (`AUDIO_MIMETYPE`)."""
+ return AUDIO_MIMETYPE
+
+ @classmethod
+ def extra_extensions(cls) -> List[str]:
+ """
+ Returns a list of additional file extensions that are valid for this class
+ but cannot be identified by the mimetypes library.
+
+ :return: an empty list, i.e. no extra extensions are declared for audio urls
+ """
+ return []
+
def load(self: T) -> Tuple[AudioNdArray, int]:
"""
Load the data from the url into an [`AudioNdArray`][docarray.typing.AudioNdArray]
@@ -33,7 +61,7 @@ def load(self: T) -> Tuple[AudioNdArray, int]:
class MyDoc(BaseDoc):
audio_url: AudioUrl
- audio_tensor: Optional[AudioNdArray]
+ audio_tensor: Optional[AudioNdArray] = None
doc = MyDoc(audio_url='https://www.kozco.com/tech/piano2.wav')
diff --git a/docarray/typing/url/image_url.py b/docarray/typing/url/image_url.py
index 43758cf7436..8c6691a7ff6 100644
--- a/docarray/typing/url/image_url.py
+++ b/docarray/typing/url/image_url.py
@@ -1,10 +1,26 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar
from docarray.typing import ImageBytes
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.image import ImageNdArray
from docarray.typing.url.any_url import AnyUrl
+from docarray.typing.url.mimetypes import IMAGE_MIMETYPE
from docarray.utils._internal.misc import is_notebook
if TYPE_CHECKING:
@@ -20,6 +36,18 @@ class ImageUrl(AnyUrl):
Can be remote (web) URL, or a local file path.
"""
+ @classmethod
+ def mime_type(cls) -> str:
+ """Return the MIME type string for image urls (`IMAGE_MIMETYPE`)."""
+ return IMAGE_MIMETYPE
+
+ @classmethod
+ def extra_extensions(cls) -> List[str]:
+ """
+ Returns a list of additional file extensions that are valid for this class
+ but cannot be identified by the mimetypes library.
+
+ :return: an empty list, i.e. no extra extensions are declared for image urls
+ """
+ return []
+
def load_pil(self, timeout: Optional[float] = None) -> 'PILImage.Image':
"""
Load the image from the bytes into a `PIL.Image.Image` instance
diff --git a/docarray/typing/url/mimetypes.py b/docarray/typing/url/mimetypes.py
new file mode 100644
index 00000000000..828a47b962b
--- /dev/null
+++ b/docarray/typing/url/mimetypes.py
@@ -0,0 +1,109 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# MIME type strings returned by the `mime_type()` classmethods of the url classes.
+TEXT_MIMETYPE = 'text'
+AUDIO_MIMETYPE = 'audio'
+IMAGE_MIMETYPE = 'image'
+# NOTE(review): presumably the type the mimetypes library reports for `.obj`
+# files - confirm.
+OBJ_MIMETYPE = 'application/x-tgif'
+VIDEO_MIMETYPE = 'video'
+
+# File extensions (without the leading dot) returned by
+# `Mesh3DUrl.extra_extensions()`.
+MESH_EXTRA_EXTENSIONS = [
+ '3ds',
+ '3mf',
+ 'ac',
+ 'ac3d',
+ 'amf',
+ 'assimp',
+ 'bvh',
+ 'cob',
+ 'collada',
+ 'ctm',
+ 'dxf',
+ 'e57',
+ 'fbx',
+ 'gltf',
+ 'glb',
+ 'ifc',
+ 'lwo',
+ 'lws',
+ 'lxo',
+ 'md2',
+ 'md3',
+ 'md5',
+ 'mdc',
+ 'm3d',
+ 'mdl',
+ 'ms3d',
+ 'nff',
+ 'obj',
+ 'off',
+ 'pcd',
+ 'pod',
+ 'pmd',
+ 'pmx',
+ 'ply',
+ 'q3o',
+ 'q3s',
+ 'raw',
+ 'sib',
+ 'smd',
+ 'stl',
+ 'ter',
+ 'terragen',
+ 'vtk',
+ 'vrml',
+ 'x3d',
+ 'xaml',
+ 'xgl',
+ 'xml',
+ 'xyz',
+ 'zgl',
+ 'vta',
+]
+
+# File extensions returned by `TextUrl.extra_extensions()`.
+TEXT_EXTRA_EXTENSIONS = ['md', 'log']
+
+# File extensions (without the leading dot) returned by
+# `PointCloud3DUrl.extra_extensions()`; each extension is listed exactly once.
+POINT_CLOUD_EXTRA_EXTENSIONS = [
+    'ascii',
+    'bin',
+    'b3dm',
+    'bpf',
+    'dp',
+    'dxf',
+    'e57',
+    'fls',
+    'glb',
+    'gpf',
+    'las',
+    'obj',
+    'osgb',
+    'pcap',
+    'pcd',
+    'pdal',
+    'pfm',
+    'ply',
+    'ply2',
+    'pod',
+    'pods',
+    'pnts',
+    'ptg',
+    'ptx',
+    'pts',
+    'rcp',
+    'xyz',
+    'zfs',
+]
diff --git a/docarray/typing/url/text_url.py b/docarray/typing/url/text_url.py
index 86da87790e6..24ae669ce69 100644
--- a/docarray/typing/url/text_url.py
+++ b/docarray/typing/url/text_url.py
@@ -1,7 +1,23 @@
-from typing import Optional, TypeVar
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional, TypeVar
from docarray.typing.proto_register import _register_proto
from docarray.typing.url.any_url import AnyUrl
+from docarray.typing.url.mimetypes import TEXT_EXTRA_EXTENSIONS, TEXT_MIMETYPE
T = TypeVar('T', bound='TextUrl')
@@ -13,6 +29,18 @@ class TextUrl(AnyUrl):
Can be remote (web) URL, or a local file path.
"""
+ @classmethod
+ def mime_type(cls) -> str:
+ """Return the MIME type string for text urls (`TEXT_MIMETYPE`)."""
+ return TEXT_MIMETYPE
+
+ @classmethod
+ def extra_extensions(cls) -> List[str]:
+ """
+ Returns a list of additional file extensions that are valid for this class
+ but cannot be identified by the mimetypes library.
+
+ :return: `TEXT_EXTRA_EXTENSIONS`; the module-level list itself, not a copy
+ """
+ return TEXT_EXTRA_EXTENSIONS
+
def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str:
"""
Load the text file into a string.
diff --git a/docarray/typing/url/url_3d/__init__.py b/docarray/typing/url/url_3d/__init__.py
index a8aaf02e49d..58717ab952f 100644
--- a/docarray/typing/url/url_3d/__init__.py
+++ b/docarray/typing/url/url_3d/__init__.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from docarray.typing.url.url_3d.mesh_url import Mesh3DUrl
from docarray.typing.url.url_3d.point_cloud_url import PointCloud3DUrl
diff --git a/docarray/typing/url/url_3d/mesh_url.py b/docarray/typing/url/url_3d/mesh_url.py
index 70f32eb5581..094a6c4af2a 100644
--- a/docarray/typing/url/url_3d/mesh_url.py
+++ b/docarray/typing/url/url_3d/mesh_url.py
@@ -1,10 +1,26 @@
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar
import numpy as np
from pydantic import parse_obj_as
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.ndarray import NdArray
+from docarray.typing.url.mimetypes import MESH_EXTRA_EXTENSIONS
from docarray.typing.url.url_3d.url_3d import Url3D
if TYPE_CHECKING:
@@ -20,6 +36,14 @@ class Mesh3DUrl(Url3D):
Can be remote (web) URL, or a local file path.
"""
+ @classmethod
+ def extra_extensions(cls) -> List[str]:
+ """
+ Returns a list of additional file extensions that are valid for this class
+ but cannot be identified by the mimetypes library.
+
+ :return: `MESH_EXTRA_EXTENSIONS`; the module-level list itself, not a copy
+ """
+ return MESH_EXTRA_EXTENSIONS
+
def load(
self: T,
skip_materials: bool = True,
diff --git a/docarray/typing/url/url_3d/point_cloud_url.py b/docarray/typing/url/url_3d/point_cloud_url.py
index efe6ce6ae0e..94bbf19b0cc 100644
--- a/docarray/typing/url/url_3d/point_cloud_url.py
+++ b/docarray/typing/url/url_3d/point_cloud_url.py
@@ -1,10 +1,11 @@
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar
import numpy as np
from pydantic import parse_obj_as
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.ndarray import NdArray
+from docarray.typing.url.mimetypes import POINT_CLOUD_EXTRA_EXTENSIONS
from docarray.typing.url.url_3d.url_3d import Url3D
if TYPE_CHECKING:
@@ -21,6 +22,14 @@ class PointCloud3DUrl(Url3D):
Can be remote (web) URL, or a local file path.
"""
+ @classmethod
+ def extra_extensions(cls) -> List[str]:
+ """
+ Returns a list of additional file extensions that are valid for this class
+ but cannot be identified by the mimetypes library.
+
+ :return: `POINT_CLOUD_EXTRA_EXTENSIONS`; the module-level list itself, not a copy
+ """
+ return POINT_CLOUD_EXTRA_EXTENSIONS
+
def load(
self: T,
samples: int,
diff --git a/docarray/typing/url/url_3d/url_3d.py b/docarray/typing/url/url_3d/url_3d.py
index c55c0f954e7..0f93e2bc00d 100644
--- a/docarray/typing/url/url_3d/url_3d.py
+++ b/docarray/typing/url/url_3d/url_3d.py
@@ -1,8 +1,24 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from abc import ABC
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
from docarray.typing.proto_register import _register_proto
from docarray.typing.url.any_url import AnyUrl
+from docarray.typing.url.mimetypes import OBJ_MIMETYPE
from docarray.utils._internal.misc import import_library
if TYPE_CHECKING:
@@ -18,6 +34,10 @@ class Url3D(AnyUrl, ABC):
Can be remote (web) URL, or a local file path.
"""
+ @classmethod
+ def mime_type(cls) -> str:
+ """Return the MIME type string for 3D urls (`OBJ_MIMETYPE`)."""
+ return OBJ_MIMETYPE
+
def _load_trimesh_instance(
self: T,
force: Optional[str] = None,
diff --git a/docarray/typing/url/video_url.py b/docarray/typing/url/video_url.py
index 5bd7b1be0b9..385e4946f84 100644
--- a/docarray/typing/url/video_url.py
+++ b/docarray/typing/url/video_url.py
@@ -1,9 +1,10 @@
import warnings
-from typing import Optional, TypeVar
+from typing import List, Optional, TypeVar
from docarray.typing.bytes.video_bytes import VideoBytes, VideoLoadResult
from docarray.typing.proto_register import _register_proto
from docarray.typing.url.any_url import AnyUrl
+from docarray.typing.url.mimetypes import VIDEO_MIMETYPE
from docarray.utils._internal.misc import is_notebook
T = TypeVar('T', bound='VideoUrl')
@@ -16,6 +17,18 @@ class VideoUrl(AnyUrl):
Can be remote (web) URL, or a local file path.
"""
+ @classmethod
+ def mime_type(cls) -> str:
+ """Return the MIME type string for video urls (`VIDEO_MIMETYPE`)."""
+ return VIDEO_MIMETYPE
+
+ @classmethod
+ def extra_extensions(cls) -> List[str]:
+ """
+ Returns a list of additional file extensions that are valid for this class
+ but cannot be identified by the mimetypes library.
+
+ :return: an empty list, i.e. no extra extensions are declared for video urls
+ """
+ return []
+
def load(self: T, **kwargs) -> VideoLoadResult:
"""
Load the data from the url into a `NamedTuple` of
@@ -35,9 +48,9 @@ def load(self: T, **kwargs) -> VideoLoadResult:
class MyDoc(BaseDoc):
video_url: VideoUrl
- video: Optional[VideoNdArray]
- audio: Optional[AudioNdArray]
- key_frame_indices: Optional[NdArray]
+ video: Optional[VideoNdArray] = None
+ audio: Optional[AudioNdArray] = None
+ key_frame_indices: Optional[NdArray] = None
doc = MyDoc(
diff --git a/docarray/utils/__init__.py b/docarray/utils/__init__.py
index e69de29bb2d..74f8f7582cd 100644
--- a/docarray/utils/__init__.py
+++ b/docarray/utils/__init__.py
@@ -0,0 +1,15 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/docarray/utils/_internal/__init__.py b/docarray/utils/_internal/__init__.py
index e69de29bb2d..74f8f7582cd 100644
--- a/docarray/utils/_internal/__init__.py
+++ b/docarray/utils/_internal/__init__.py
@@ -0,0 +1,15 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/docarray/utils/_internal/_typing.py b/docarray/utils/_internal/_typing.py
index 62680cf964e..3c2bd89a8e5 100644
--- a/docarray/utils/_internal/_typing.py
+++ b/docarray/utils/_internal/_typing.py
@@ -1,13 +1,29 @@
-from typing import Any, Optional
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, ForwardRef, Optional, Union
-from typing_inspect import get_args, is_union_type
-
-from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from typing_extensions import get_origin
+from typing_inspect import get_args, is_typevar, is_union_type
def is_type_tensor(type_: Any) -> bool:
"""Return True if type is a type Tensor or an Optional Tensor type."""
- return isinstance(type_, type) and issubclass(type_, AbstractTensor)
+ # NOTE(review): function-level import, presumably to avoid a circular
+ # import with docarray.typing - confirm.
+ from docarray.typing.tensor.abstract_tensor import AbstractTensor
+
+ return isinstance(type_, type) and safe_issubclass(type_, AbstractTensor)
def is_tensor_union(type_: Any) -> bool:
@@ -17,7 +33,8 @@ def is_tensor_union(type_: Any) -> bool:
return False
else:
return is_union and all(
- (is_type_tensor(t) or issubclass(t, type(None))) for t in get_args(type_)
+ (is_type_tensor(t) or safe_issubclass(t, type(None)))
+ for t in get_args(type_)
)
@@ -32,3 +49,27 @@ def change_cls_name(cls: type, new_name: str, scope: Optional[dict] = None) -> N
scope[new_name] = cls
cls.__qualname__ = cls.__qualname__[: -len(cls.__name__)] + new_name
cls.__name__ = new_name
+
+
+def safe_issubclass(x: type, a_tuple: type) -> bool:
+    """
+    This is a modified version of the built-in 'issubclass' function to support non-class input.
+    Traditional 'issubclass' calls can result in a crash if the input is non-class type (e.g. list/tuple).
+
+    :param x: The object to test. A generic type like DocList[SomeDoc] is
+        replaced by its origin before the check.
+    :param a_tuple: A class, or a tuple of classes.
+    :return: A boolean value - 'True' if 'x' is a subclass of 'a_tuple', 'False' otherwise.
+        'False' is also returned when the origin of 'x' is list, tuple, dict,
+        set or Union, or when 'x' is a TypeVar, a ForwardRef or not a class.
+    """
+    origin = get_origin(x)
+    if origin:  # If x is a generic type like DocList[SomeDoc], get its origin
+        x = origin
+    if (
+        origin in (list, tuple, dict, set, Union)
+        or is_typevar(x)
+        or isinstance(x, ForwardRef)
+    ):
+        return False
+
+    return isinstance(x, type) and issubclass(x, a_tuple)
diff --git a/docarray/utils/_internal/cache.py b/docarray/utils/_internal/cache.py
index 249c4f9d179..83ffcf4b9c8 100644
--- a/docarray/utils/_internal/cache.py
+++ b/docarray/utils/_internal/cache.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import os
from functools import lru_cache
from pathlib import Path
diff --git a/docarray/utils/_internal/misc.py b/docarray/utils/_internal/misc.py
index ea1b7399ffd..b44da92dc7e 100644
--- a/docarray/utils/_internal/misc.py
+++ b/docarray/utils/_internal/misc.py
@@ -2,7 +2,7 @@
import os
import re
import types
-from typing import Any, Optional
+from typing import Any, Literal, Optional
import numpy as np
@@ -22,6 +22,13 @@
tf_imported = True
+# Probe for jax once at module import; the outcome is stored in `jnp_imported`.
+# Both ImportError and TypeError raised by the import count as "not available".
+try:
+ import jax.numpy as jnp # type: ignore # noqa: F401
+except (ImportError, TypeError):
+ jnp_imported = False
+else:
+ jnp_imported = True
+
INSTALL_INSTRUCTIONS = {
'google.protobuf': '"docarray[proto]"',
'lz4': '"docarray[proto]"',
@@ -38,12 +45,18 @@
'fastapi': '"docarray[web]"',
'torch': '"docarray[torch]"',
'tensorflow': 'protobuf==3.19.0 tensorflow',
- 'hubble': '"docarray[jac]"',
'smart_open': '"docarray[aws]"',
'boto3': '"docarray[aws]"',
'botocore': '"docarray[aws]"',
+ 'redis': '"docarray[redis]"',
+ 'pymilvus': '"docarray[milvus]"',
+ "pymongo": '"docarray[mongo]"',
}
+# String-literal type listing the protocol names 'protobuf', 'pickle', 'json'
+# and their '-array' variants.
+# NOTE(review): presumably the (de)serialization protocols - confirm at call sites.
+ProtocolType = Literal[
+ 'protobuf', 'pickle', 'json', 'json-array', 'protobuf-array', 'pickle-array'
+]
+
def import_library(
package: str, raise_error: bool = True
@@ -77,6 +90,10 @@ def is_tf_available():
return tf_imported
+def is_jax_available():
+ """Return the module-level flag `jnp_imported` (True if `jax.numpy` was imported)."""
+ return jnp_imported
+
+
def is_np_int(item: Any) -> bool:
dtype = getattr(item, 'dtype', None)
ndim = getattr(item, 'ndim', None)
diff --git a/docarray/utils/_internal/progress_bar.py b/docarray/utils/_internal/progress_bar.py
index 4750c509a1a..b5460c31148 100644
--- a/docarray/utils/_internal/progress_bar.py
+++ b/docarray/utils/_internal/progress_bar.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import Optional
from rich.progress import (
diff --git a/docarray/utils/_internal/pydantic.py b/docarray/utils/_internal/pydantic.py
new file mode 100644
index 00000000000..d8e28df3b56
--- /dev/null
+++ b/docarray/utils/_internal/pydantic.py
@@ -0,0 +1,27 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pydantic
+
+is_pydantic_v2 = pydantic.__version__.startswith('2.')
+
+# `bytes_validator` comes from `pydantic.v1.validators` under pydantic 2.x
+# and from `pydantic.validators` otherwise.
+if is_pydantic_v2:
+    from pydantic.v1.validators import bytes_validator
+else:
+    from pydantic.validators import bytes_validator
+
+__all__ = ['is_pydantic_v2', 'bytes_validator']
diff --git a/docarray/utils/_internal/query_language/__init__.py b/docarray/utils/_internal/query_language/__init__.py
index e69de29bb2d..74f8f7582cd 100644
--- a/docarray/utils/_internal/query_language/__init__.py
+++ b/docarray/utils/_internal/query_language/__init__.py
@@ -0,0 +1,15 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/docarray/utils/_internal/query_language/query_parser.py b/docarray/utils/_internal/query_language/query_parser.py
index b635d296d8e..8656fbd8406 100644
--- a/docarray/utils/_internal/query_language/query_parser.py
+++ b/docarray/utils/_internal/query_language/query_parser.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import Any, Dict, List, Optional, Union
from docarray.utils._internal.query_language.lookup import (
diff --git a/docarray/utils/create_dynamic_doc_class.py b/docarray/utils/create_dynamic_doc_class.py
new file mode 100644
index 00000000000..c82a7c89487
--- /dev/null
+++ b/docarray/utils/create_dynamic_doc_class.py
@@ -0,0 +1,358 @@
+from typing import Any, Dict, List, Optional, Type, Union
+
+from pydantic import BaseModel, create_model
+from pydantic.fields import FieldInfo
+
+from docarray.base_doc.doc import BaseDocWithoutId
+from docarray import BaseDoc, DocList
+from docarray.typing import AnyTensor
+from docarray.utils._internal._typing import safe_issubclass
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+RESERVED_KEYS = [
+ 'type',
+ 'anyOf',
+ '$ref',
+ 'additionalProperties',
+ 'allOf',
+ 'items',
+ 'definitions',
+ 'properties',
+ 'default',
+]
+
+
+def create_pure_python_type_model(model: BaseModel) -> BaseDoc:
+    """
+    Build a copy of a Pydantic model class whose DocList fields are typed as List.
+
+    This may be necessary due to limitations in Pydantic:
+
+    https://github.com/docarray/docarray/issues/1521
+    https://github.com/pydantic/pydantic/issues/1457
+
+    ---
+
+    ```python
+    from docarray import BaseDoc
+
+
+    class MyDoc(BaseDoc):
+        tensor: Optional[AnyTensor]
+        url: ImageUrl
+        title: str
+        texts: DocList[TextDoc]
+
+
+    MyDocCorrected = create_pure_python_type_model(MyDoc)
+    ```
+
+    ---
+    :param model: The input model class; it is deep-copied before being used as base.
+    :return: A subclass of the copied model where, under pydantic v1, each DocList field becomes a List of the recursively converted document type.
+    """
+    import copy
+
+    model_copy = copy.deepcopy(model)
+    pydantic_fields = model_copy.__fields__
+    converted: Dict[str, Any] = {}
+    for name, annotation in model_copy.__annotations__.items():
+        if name not in pydantic_fields:
+            continue
+        # pydantic v2 exposes the FieldInfo directly, v1 through `.field_info`
+        info = pydantic_fields[name]
+        if not is_pydantic_v2:
+            info = info.field_info
+        new_annotation = annotation
+        try:
+            if safe_issubclass(annotation, DocList) and not is_pydantic_v2:
+                inner_doc: Any = annotation.doc_type
+                new_annotation = List[create_pure_python_type_model(inner_doc)]
+        except TypeError:
+            new_annotation = annotation
+        converted[name] = (new_annotation, info)
+
+    return create_model(
+        model_copy.__name__,
+        __base__=model_copy,
+        __doc__=model_copy.__doc__,
+        **converted,
+    )
+
+
+def _get_field_annotation_from_schema(
+    field_schema: Dict[str, Any],
+    field_name: str,
+    cached_models: Dict[str, Any],
+    is_tensor: bool = False,
+    num_recursions: int = 0,
+    definitions: Optional[Dict] = None,
+) -> type:
+    """
+    Private helper that maps one field's JSON-schema fragment to a Python type.
+    :param field_schema: The schema from which to extract the type
+    :param field_name: The name of the field to be created
+    :param cached_models: Models already built during this conversion, shared across recursive calls so nested classes are reused.
+    :param is_tensor: Whether a one-level array of numbers should be typed as AnyTensor instead of List[float]
+    :param num_recursions: Number of enclosing 'array' levels (items become List[...], or DocList[...] for objects)
+    :param definitions: Root definitions used to resolve '$ref' entries; forwarded to every nested model that is built.
+    :return: A type created from the schema. A ValueError is raised for an unknown 'type'.
+    """
+    if not definitions:
+        definitions = {}
+    field_type = field_schema.get('type', None)
+    tensor_shape = field_schema.get('tensor/array shape', None)
+    ret: Any
+    if 'anyOf' in field_schema:
+        any_of_types = []
+        for any_of_schema in field_schema['anyOf']:
+            if '$ref' in any_of_schema:
+                obj_ref = any_of_schema.get('$ref')
+                ref_name = obj_ref.split('/')[-1]
+                any_of_types.append(
+                    create_base_doc_from_schema(
+                        definitions[ref_name],
+                        ref_name,
+                        cached_models=cached_models,
+                        definitions=definitions,
+                    )
+                )
+            else:
+                any_of_types.append(
+                    _get_field_annotation_from_schema(
+                        any_of_schema,
+                        field_name,
+                        cached_models=cached_models,
+                        is_tensor=tensor_shape is not None,
+                        num_recursions=0,
+                        definitions=definitions,
+                    )
+                )  # No Union of Lists
+        ret = Union[tuple(any_of_types)]
+        for rec in range(num_recursions):
+            ret = List[ret]
+    elif field_type == 'string':
+        ret = str
+        for rec in range(num_recursions):
+            ret = List[ret]
+    elif field_type == 'integer':
+        ret = int
+        for rec in range(num_recursions):
+            ret = List[ret]
+    elif field_type == 'number':
+        if num_recursions == 0:
+            ret = float
+        elif num_recursions == 1:
+            # This is a hack because AnyTensor is more generic than a simple List and it comes as simple List
+            if is_tensor:
+                ret = AnyTensor
+            else:
+                ret = List[float]
+        else:
+            ret = float
+            for rec in range(num_recursions):
+                ret = List[ret]
+    elif field_type == 'boolean':
+        ret = bool
+        for rec in range(num_recursions):
+            ret = List[ret]
+    elif field_type == 'object' or field_type is None:
+        doc_type: Any
+        if 'additionalProperties' in field_schema:  # handle Dictionaries
+            additional_props = field_schema['additionalProperties']
+            if (
+                isinstance(additional_props, dict)
+                and additional_props.get('type') == 'object'
+            ):
+                doc_type = create_base_doc_from_schema(
+                    additional_props, field_name, cached_models, definitions
+                )
+                ret = Dict[str, doc_type]
+            else:
+                ret = Dict[str, Any]
+        else:
+            obj_ref = field_schema.get('$ref') or field_schema.get('allOf', [{}])[
+                0
+            ].get('$ref', None)
+            if num_recursions == 0:  # single object reference
+                if obj_ref:
+                    ref_name = obj_ref.split('/')[-1]
+                    ret = create_base_doc_from_schema(
+                        definitions[ref_name],
+                        ref_name,
+                        cached_models=cached_models,
+                        definitions=definitions,
+                    )
+                else:
+                    ret = Any
+            else:  # object reference in definitions
+                if obj_ref:
+                    ref_name = obj_ref.split('/')[-1]
+                    doc_type = create_base_doc_from_schema(
+                        definitions[ref_name],
+                        ref_name,
+                        cached_models=cached_models,
+                        definitions=definitions,
+                    )
+                    ret = DocList[doc_type]
+                else:
+                    doc_type = create_base_doc_from_schema(
+                        field_schema, field_name, cached_models, definitions
+                    )
+                    ret = DocList[doc_type]
+    elif field_type == 'array':
+        ret = _get_field_annotation_from_schema(
+            field_schema=field_schema.get('items', {}),
+            field_name=field_name,
+            cached_models=cached_models,
+            is_tensor=tensor_shape is not None,
+            num_recursions=num_recursions + 1,
+            definitions=definitions,
+        )
+    elif field_type == 'null':
+        ret = None
+    else:
+        if num_recursions > 0:
+            raise ValueError(
+                f"Unknown array item type: {field_type} for field_name {field_name}"
+            )
+        else:
+            raise ValueError(
+                f"Unknown field type: {field_type} for field_name {field_name}"
+            )
+    return ret
+
+
+def create_base_doc_from_schema(
+    schema: Dict[str, Any],
+    base_doc_name: str,
+    cached_models: Optional[Dict] = None,
+    definitions: Optional[Dict] = None,
+) -> Type:
+    """
+    Dynamically create a `BaseDoc` subclass from a `schema` of another `BaseDoc`.
+
+    This method is intended to dynamically create a `BaseDoc` compatible with the schema
+    of another BaseDoc. This is useful when that other `BaseDoc` is not available in the current scope. For instance, you may have stored the schema
+    as a JSON, or sent it to another service, etc.
+
+    Due to this Pydantic limitation (https://github.com/docarray/docarray/issues/1521, https://github.com/pydantic/pydantic/issues/1457), we need to make sure that the
+    input schema uses `List` and not `DocList`. Therefore this is recommended to be used in combination with `create_pure_python_type_model`
+    to make sure that `DocLists` in schema are converted to `List`.
+
+    ---
+
+    ```python
+    from docarray import BaseDoc
+
+
+    class MyDoc(BaseDoc):
+        tensor: Optional[AnyTensor]
+        url: ImageUrl
+        title: str
+        texts: DocList[TextDoc]
+
+
+    MyDocCorrected = create_pure_python_type_model(MyDoc)
+    new_my_doc_cls = create_base_doc_from_schema(MyDocCorrected.schema(), 'MyDoc')
+    ```
+
+    ---
+    :param schema: The schema of the original `BaseDoc` where DocLists are passed as regular Lists of Documents. It is not modified.
+    :param base_doc_name: The name of the new pydantic model created.
+    :param cached_models: Parameter used when this method is called recursively to reuse partial nested classes.
+    :param definitions: Parameter used when this method is called recursively to reuse root definitions of other schemas.
+    :return: A class created following the `schema`: a `BaseDoc` subclass if the schema has an `id` property, otherwise a `BaseDocWithoutId` subclass.
+    """
+
+    def clean_refs(value):
+        """Recursively drop every '$ref' key from nested dicts and lists."""
+        if isinstance(value, dict):
+            # Build a new dictionary that omits '$ref' keys; values are cleaned recursively
+            cleaned_dict = {}
+            for k, v in value.items():
+                if k == '$ref':
+                    continue
+                cleaned_dict[k] = clean_refs(v)
+            return cleaned_dict
+        elif isinstance(value, list):
+            # Process each item in the list
+            return [clean_refs(item) for item in value]
+        else:
+            # Return primitive values as-is
+            return value
+
+    if not definitions:
+        # Nested models are read from 'definitions' (pydantic v1) or '$defs' (v2);
+        # default to {} in both cases so a missing key never yields None.
+        definitions = schema.get('$defs' if is_pydantic_v2 else 'definitions', {})
+
+    cached_models = cached_models if cached_models is not None else {}
+    fields: Dict[str, Any] = {}
+    if base_doc_name in cached_models:
+        return cached_models[base_doc_name]
+    has_id = False
+    for field_name, field_schema in schema.get('properties', {}).items():
+        if field_name == 'id':
+            has_id = True
+        # Get the field type
+        field_type = _get_field_annotation_from_schema(
+            field_schema=field_schema,
+            field_name=field_name,
+            cached_models=cached_models,
+            is_tensor=False,
+            num_recursions=0,
+            definitions=definitions,
+        )
+        if not is_pydantic_v2:
+            # Default to None on a copy so the caller's schema stays untouched
+            fields[field_name] = (
+                field_type,
+                FieldInfo(**{'default': None, **field_schema}),
+            )
+        else:
+            field_kwargs = {}
+            field_json_schema_extra = {}
+            for k, v in field_schema.items():
+                if field_name == 'id':
+                    # Skip default_factory for Optional fields and use None
+                    field_kwargs['default'] = None
+                if k in FieldInfo.__slots__:
+                    field_kwargs[k] = v
+                else:
+                    if k != '$ref':
+                        if isinstance(v, dict):
+                            cleaned_v = clean_refs(v)
+                            if (
+                                cleaned_v
+                            ):  # Only add if there's something left after cleaning
+                                field_json_schema_extra[k] = cleaned_v
+                        else:
+                            field_json_schema_extra[k] = v
+
+            fields[field_name] = (
+                field_type,
+                FieldInfo(
+                    json_schema_extra=field_json_schema_extra,
+                    **field_kwargs,
+                ),
+            )
+
+    base_model = BaseDoc if has_id else BaseDocWithoutId
+    model = create_model(base_doc_name, __base__=base_model, **fields)
+    if not is_pydantic_v2:
+        model.__config__.title = schema.get('title', model.__config__.title)
+    else:
+        set_title = schema.get('title', model.model_config.get('title', None))
+        if set_title:
+            model.model_config['title'] = set_title
+
+    # Filter into a new dict instead of popping, so the caller's schema keeps
+    # its 'properties' etc. and can be converted again later.
+    schema_extra = {k: v for k, v in schema.items() if k not in RESERVED_KEYS}
+    if not is_pydantic_v2:
+        model.__config__.schema_extra = schema_extra
+    else:
+        model.model_config['json_schema_extra'] = schema_extra
+    cached_models[base_doc_name] = model
+    return model
diff --git a/docarray/utils/filter.py b/docarray/utils/filter.py
index 5b7daa1e6f2..2f95a95ffe1 100644
--- a/docarray/utils/filter.py
+++ b/docarray/utils/filter.py
@@ -12,63 +12,64 @@ def filter_docs(
query: Union[str, Dict, List[Dict]],
) -> AnyDocArray:
"""
- Filter the Documents in the index according to the given filter query.
-
-
-
- ---
-
- ```python
- from docarray import DocList, BaseDoc
- from docarray.documents import TextDoc, ImageDoc
- from docarray.utils.filter import filter_docs
-
-
- class MyDocument(BaseDoc):
- caption: TextDoc
- ImageDoc: ImageDoc
- price: int
-
-
- docs = DocList[MyDocument](
- [
- MyDocument(
- caption='A tiger in the jungle',
- ImageDoc=ImageDoc(url='tigerphoto.png'),
- price=100,
- ),
- MyDocument(
- caption='A swimming turtle',
- ImageDoc=ImageDoc(url='turtlepic.png'),
- price=50,
- ),
- MyDocument(
- caption='A couple birdwatching with binoculars',
- ImageDoc=ImageDoc(url='binocularsphoto.png'),
- price=30,
- ),
- ]
- )
- query = {
- '$and': {
- 'ImageDoc__url': {'$regex': 'photo'},
- 'price': {'$lte': 50},
- }
- }
-
- results = filter_docs(docs, query)
- assert len(results) == 1
- assert results[0].price == 30
- assert results[0].caption == 'A couple birdwatching with binoculars'
- assert results[0].ImageDoc.url == 'binocularsphoto.png'
- ```
-
- ---
-
- :param docs: the DocList where to apply the filter
- :param query: the query to filter by
- :return: A DocList containing the Documents
- in `docs` that fulfill the filter conditions in the `query`
+ Filter the Documents in the index according to the given filter query.
+ Filter queries use the same syntax as the MongoDB query language (https://www.mongodb.com/docs/manual/tutorial/query-documents/#specify-conditions-using-query-operators).
+ You can see a list of the supported operators here (https://www.mongodb.com/docs/manual/reference/operator/query/#std-label-query-selectors)
+
+
+ ---
+
+ ```python
+ from docarray import DocList, BaseDoc
+ from docarray.documents import TextDoc, ImageDoc
+ from docarray.utils.filter import filter_docs
+
+
+ class MyDocument(BaseDoc):
+ caption: TextDoc
+ ImageDoc: ImageDoc
+ price: int
+
+
+ docs = DocList[MyDocument](
+ [
+ MyDocument(
+ caption='A tiger in the jungle',
+ ImageDoc=ImageDoc(url='tigerphoto.png'),
+ price=100,
+ ),
+ MyDocument(
+ caption='A swimming turtle',
+ ImageDoc=ImageDoc(url='turtlepic.png'),
+ price=50,
+ ),
+ MyDocument(
+ caption='A couple birdwatching with binoculars',
+ ImageDoc=ImageDoc(url='binocularsphoto.png'),
+ price=30,
+ ),
+ ]
+ )
+ query = {
+ '$and': {
+ 'ImageDoc__url': {'$regex': 'photo'},
+ 'price': {'$lte': 50},
+ }
+ }
+
+ results = filter_docs(docs, query)
+ assert len(results) == 1
+ assert results[0].price == 30
+ assert results[0].caption == 'A couple birdwatching with binoculars'
+ assert results[0].ImageDoc.url == 'binocularsphoto.png'
+ ```
+
+ ---
+
+ :param docs: the DocList where to apply the filter
+ :param query: the query to filter by
+ :return: A DocList containing the Documents
+ in `docs` that fulfill the filter conditions in the `query`
"""
from docarray.utils._internal.query_language.query_parser import QueryParser
diff --git a/docarray/utils/find.py b/docarray/utils/find.py
index f522a78f297..2b77bcbb77e 100644
--- a/docarray/utils/find.py
+++ b/docarray/utils/find.py
@@ -1,16 +1,57 @@
__all__ = ['find', 'find_batched']
-from typing import Any, Dict, List, NamedTuple, Optional, Type, Union, cast
-
-from typing_inspect import is_union_type
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+ cast,
+)
from docarray.array.any_array import AnyDocArray
from docarray.array.doc_list.doc_list import DocList
from docarray.array.doc_vec.doc_vec import DocVec
from docarray.base_doc import BaseDoc
-from docarray.helper import _get_field_type_by_access_path
+from docarray.computation.numpy_backend import NumpyCompBackend
from docarray.typing import AnyTensor
-from docarray.typing.tensor.abstract_tensor import AbstractTensor
+from docarray.typing.tensor import NdArray
+from docarray.utils._internal.misc import ( # noqa
+ is_jax_available,
+ is_tf_available,
+ is_torch_available,
+)
+
+jax_available = is_jax_available()
+if jax_available:
+ import jax.numpy as jnp
+
+ from docarray.computation.jax_backend import JaxCompBackend
+ from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401
+
+torch_available = is_torch_available()
+if torch_available:
+ import torch
+
+ from docarray.computation.torch_backend import TorchCompBackend
+ from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401
+
+tf_available = is_tf_available()
+if tf_available:
+ import tensorflow as tf # type: ignore
+
+ from docarray.computation.tensorflow_backend import TensorFlowCompBackend
+ from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401
+
+if TYPE_CHECKING:
+ from docarray.computation.abstract_numpy_based_backend import (
+ AbstractComputationalBackend,
+ )
+ from docarray.typing.tensor.abstract_tensor import AbstractTensor
class FindResult(NamedTuple):
@@ -23,6 +64,12 @@ class _FindResult(NamedTuple):
scores: AnyTensor
+class SubindexFindResult(NamedTuple):
+ root_documents: DocList  # NOTE(review): presumably the parents of sub_documents - confirm
+ sub_documents: DocList  # NOTE(review): presumably the matched nested documents - confirm
+ scores: AnyTensor  # NOTE(review): presumably one score per match, as in FindResult
+
+
class FindResultBatched(NamedTuple):
documents: List[DocList]
scores: List[AnyTensor]
@@ -41,6 +88,7 @@ def find(
limit: int = 10,
device: Optional[str] = None,
descending: Optional[bool] = None,
+ cache: Optional[Dict[str, Tuple[AnyTensor, Optional[List[int]]]]] = None,
) -> FindResult:
"""
Find the closest Documents in the index to the query.
@@ -100,6 +148,7 @@ class MyDocument(BaseDoc):
can be either `cpu` or a `cuda` device.
:param descending: sort the results in descending order.
Per default, this is chosen based on the `metric` argument.
+ :param cache: Precomputed data storing the valid index data per search field together with the valid indexes to account for deleted entries.
:return: A named tuple of the form (DocList, AnyTensor),
where the first element contains the closes matches for the query,
and the second element contains the corresponding scores.
@@ -113,6 +162,7 @@ class MyDocument(BaseDoc):
limit=limit,
device=device,
descending=descending,
+ cache=cache,
)
return FindResult(documents=docs[0], scores=scores[0])
@@ -125,6 +175,7 @@ def find_batched(
limit: int = 10,
device: Optional[str] = None,
descending: Optional[bool] = None,
+ cache: Optional[Dict[str, Tuple[AnyTensor, Optional[List[int]]]]] = None,
) -> FindResultBatched:
"""
Find the closest Documents in the index to the queries.
@@ -135,6 +186,8 @@ def find_batched(
search using approximate nearest neighbours search or hybrid search or
multi vector search please take a look at the [`BaseDoc`][docarray.base_doc.doc.BaseDoc]
+ !!! note
+ Only non-None embeddings will be considered from the `index` array
---
@@ -187,6 +240,7 @@ class MyDocument(BaseDoc):
can be either `cpu` or a `cuda` device.
:param descending: sort the results in descending order.
Per default, this is chosen based on the `metric` argument.
+ :param cache: Precomputed data storing the valid index data per search field together with the valid indexes to account for deleted entries.
:return: A named tuple of the form (DocList, AnyTensor),
where the first element contains the closest matches for each query,
and the second element contains the corresponding scores.
@@ -194,29 +248,35 @@ class MyDocument(BaseDoc):
if descending is None:
descending = metric.endswith('_sim') # similarity metrics are descending
- embedding_type = _da_attr_type(index, search_field)
- comp_backend = embedding_type.get_comp_backend()
-
# extract embeddings from query and index
- index_embeddings = _extract_embeddings(index, search_field, embedding_type)
- query_embeddings = _extract_embeddings(query, search_field, embedding_type)
+ if cache is not None and search_field in cache:
+ index_embeddings, valid_idx = cache[search_field]
+ else:
+ index_embeddings, valid_idx = _extract_embeddings(index, search_field)
+ if cache is not None:
+ cache[search_field] = (
+ index_embeddings,
+ valid_idx,
+ ) # cache embedding for next query
+ query_embeddings, _ = _extract_embeddings(query, search_field)
+ _, comp_backend = _get_tensor_type_and_comp_backend_from_tensor(index_embeddings)
# compute distances and return top results
metric_fn = getattr(comp_backend.Metrics, metric)
dists = metric_fn(query_embeddings, index_embeddings, device=device)
top_scores, top_indices = comp_backend.Retrieval.top_k(
- dists, k=limit, device=device, descending=descending
+ dists, k=int(limit), device=device, descending=descending
)
batched_docs: List[DocList] = []
+ candidate_index = index
+ if valid_idx is not None and len(valid_idx) < len(index):
+ candidate_index = index[valid_idx]
scores = []
for _, (indices_per_query, scores_per_query) in enumerate(
zip(top_indices, top_scores)
):
- doc_type = cast(Type[BaseDoc], index.doc_type)
- docs_per_query: DocList = DocList.__class_getitem__(doc_type)()
- for idx in indices_per_query: # workaround until #930 is fixed
- docs_per_query.append(index[int(idx)])
+ docs_per_query: DocList = candidate_index[indices_per_query]
batched_docs.append(docs_per_query)
scores.append(scores_per_query)
return FindResultBatched(documents=batched_docs, scores=scores)
@@ -245,54 +305,67 @@ def _extract_embedding_single(
return emb
+def _get_tensor_type_and_comp_backend_from_tensor(
+    tensor,
+) -> Tuple[Type['AbstractTensor'], 'AbstractComputationalBackend']:
+    """Pick the DocArray tensor type and computational backend matching a tensor.
+
+    Torch, TensorFlow and JAX tensors are recognised (in that order) when the
+    corresponding framework is available; anything else maps to NumPy.
+
+    :param tensor: the tensor to inspect
+    :return: a tuple of the tensor type and the computational backend
+    """
+    # Start from the NumPy pair; it is replaced when a framework tensor is detected.
+    detected: Tuple[Type['AbstractTensor'], 'AbstractComputationalBackend']
+    detected = NdArray, NumpyCompBackend()
+    if torch_available and isinstance(tensor, (TorchTensor, torch.Tensor)):
+        detected = TorchTensor, TorchCompBackend()
+    elif tf_available and isinstance(tensor, (TensorFlowTensor, tf.Tensor)):
+        detected = TensorFlowTensor, TensorFlowCompBackend()
+    elif jax_available and isinstance(tensor, (JaxArray, jnp.ndarray)):
+        detected = JaxArray, JaxCompBackend()
+    return detected
+
+
def _extract_embeddings(
     data: Union[AnyDocArray, BaseDoc, AnyTensor],
     search_field: str,
-    embedding_type: Type,
-) -> AnyTensor:
+) -> Tuple[AnyTensor, Optional[List[int]]]:
     """Extract the embeddings from the data.
     :param data: the data
     :param search_field: the embedding field
-    :param embedding_type: type of the embedding: torch.Tensor, numpy.ndarray etc.
-    :return: the embeddings
+    :return: a tuple of the embeddings and optionally a list of the non-null indices
     """
     emb: AnyTensor
+    valid_idx = None
+    comp_backend = None
+    da_tensor_type = None
     if isinstance(data, DocList):
-        emb_list = list(AnyDocArray._traverse(data, search_field))
-        emb = embedding_type._docarray_stack(emb_list)
+        emb_valid = [
+            (emb, i)
+            for i, emb in enumerate(AnyDocArray._traverse(data, search_field))
+            if emb is not None
+        ]
+        if not emb_valid:
+            # zip(*[]) cannot be unpacked into two names, so check emptiness first
+            raise Exception(f'No embedding could be extracted from data {data}')
+        emb_list, valid_idx = zip(*emb_valid)
+        (
+            da_tensor_type,
+            comp_backend,
+        ) = _get_tensor_type_and_comp_backend_from_tensor(emb_list[0])
+
+        emb = da_tensor_type._docarray_stack(emb_list)
     elif isinstance(data, (DocVec, BaseDoc)):
         emb = next(AnyDocArray._traverse(data, search_field))
     else:  # treat data as tensor
         emb = cast(AnyTensor, data)
-    if len(emb.shape) == 1:
-        emb = emb.get_comp_backend().reshape(array=emb, shape=(1, -1))
-    return emb
-
+    if comp_backend is None:
+        _, comp_backend = _get_tensor_type_and_comp_backend_from_tensor(emb)
-def _da_attr_type(docs: AnyDocArray, access_path: str) -> Type[AnyTensor]:
-    """Get the type of the attribute according to the Document type
-    (schema) of the DocList.
-
-    :param docs: the DocList
-    :param access_path: the "__"-separated access path
-    :return: the type of the attribute
-    """
-    field_type: Optional[Type] = _get_field_type_by_access_path(
-        docs.doc_type, access_path
-    )
-    if field_type is None:
-        raise ValueError(f"Access path is not valid: {access_path}")
-
-    if is_union_type(field_type):
-        # determine type based on the fist element
-        field_type = type(next(AnyDocArray._traverse(docs[0], access_path)))
-
-    if not issubclass(field_type, AbstractTensor):
-        raise ValueError(
-            f'attribute {access_path} is not a tensor-like type, '
-            f'but {field_type.__class__.__name__}'
-        )
-
-    return cast(Type[AnyTensor], field_type)
+    if len(emb.shape) == 1:
+        emb = comp_backend.reshape(tensor=emb, shape=(1, -1))
+    return emb, valid_idx
diff --git a/docarray/utils/reduce.py b/docarray/utils/reduce.py
index f615b9aeaeb..04433252e53 100644
--- a/docarray/utils/reduce.py
+++ b/docarray/utils/reduce.py
@@ -1,3 +1,18 @@
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
__all__ = ['reduce', 'reduce_all']
from typing import Dict, List, Optional
@@ -31,7 +46,8 @@ def reduce(
if doc.id in left_id_map:
left[left_id_map[doc.id]].update(doc)
else:
- left.append(doc)
+ casted = left.doc_type(**doc.__dict__)
+ left.append(casted)
return left
diff --git a/docs/.gitignore b/docs/.gitignore
index eee951db889..c528ce87543 100644
--- a/docs/.gitignore
+++ b/docs/.gitignore
@@ -1,7 +1,6 @@
api/*
proto/*
-<<<<<<< HEAD
README.md
#index.md
CONTRIBUTING.md
\ No newline at end of file
diff --git a/docs/API_reference/array/da.md b/docs/API_reference/array/da.md
index e1f5b33f008..6de33ec06bc 100644
--- a/docs/API_reference/array/da.md
+++ b/docs/API_reference/array/da.md
@@ -1,5 +1,5 @@
# DocList
::: docarray.array.doc_list.doc_list.DocList
-::: docarray.array.doc_list.io.IOMixinArray
+::: docarray.array.doc_list.io.IOMixinDocList
::: docarray.array.doc_list.pushpull.PushPullMixin
diff --git a/docs/API_reference/array/da_stack.md b/docs/API_reference/array/da_stack.md
index 917b4488d78..8aa38d46182 100644
--- a/docs/API_reference/array/da_stack.md
+++ b/docs/API_reference/array/da_stack.md
@@ -1,3 +1,4 @@
# DocVec
-::: docarray.array.doc_vec.doc_vec.DocVec
\ No newline at end of file
+::: docarray.array.doc_vec.doc_vec.DocVec
+::: docarray.array.doc_vec.io.IOMixinDocVec
\ No newline at end of file
diff --git a/docs/API_reference/doc_index/backends/epsilla.md b/docs/API_reference/doc_index/backends/epsilla.md
new file mode 100644
index 00000000000..6248690c4b0
--- /dev/null
+++ b/docs/API_reference/doc_index/backends/epsilla.md
@@ -0,0 +1,3 @@
+# EpsillaDocumentIndex
+
+::: docarray.index.backends.epsilla.EpsillaDocumentIndex
\ No newline at end of file
diff --git a/docs/API_reference/doc_index/backends/milvus.md b/docs/API_reference/doc_index/backends/milvus.md
new file mode 100644
index 00000000000..38514163cac
--- /dev/null
+++ b/docs/API_reference/doc_index/backends/milvus.md
@@ -0,0 +1,3 @@
+# MilvusDocumentIndex
+
+::: docarray.index.backends.milvus.MilvusDocumentIndex
\ No newline at end of file
diff --git a/docs/API_reference/doc_index/backends/mongodb.md b/docs/API_reference/doc_index/backends/mongodb.md
new file mode 100644
index 00000000000..0a7dc2f6ec1
--- /dev/null
+++ b/docs/API_reference/doc_index/backends/mongodb.md
@@ -0,0 +1,134 @@
+# MongoDBAtlasDocumentIndex
+
+::: docarray.index.backends.mongodb_atlas.MongoDBAtlasDocumentIndex
+
+# Setting up MongoDB Atlas as the Document Index
+
+MongoDB Atlas is a multi-cloud database service made by the same people that build MongoDB.
+Atlas simplifies deploying and managing your databases while offering the versatility you need
+to build resilient and performant global applications on the cloud providers of your choice.
+
+You can perform semantic search on data in your Atlas cluster running MongoDB v6.0.11
+or later using Atlas Vector Search. You can store vector embeddings for any kind of data along
+with other data in your collection on the Atlas cluster.
+
+In this section, we set up a cluster and a database, test the connection, and finally create an Atlas Vector Search Index.
+
+### Deploy a Cluster
+
+Follow the [Getting-Started](https://www.mongodb.com/basics/mongodb-atlas-tutorial) documentation
+to create an account, deploy an Atlas cluster, and connect to a database.
+
+
+### Retrieve the URI used by Python to connect to the Cluster
+
+When you deploy, this will be stored as the environment variable: `MONGODB_URI`
+It will look something like the following. The username and password, if not provided,
+can be configured in *Database Access* under Security in the left panel.
+
+```
+export MONGODB_URI="mongodb+srv://<username>:<password>@cluster0.foo.mongodb.net/?retryWrites=true&w=majority"
+```
+
+There are a number of ways to navigate the Atlas UI. Keep your eye out for "Connect" and "Driver".
+
+On the left panel, navigate and click 'Database' under DEPLOYMENT.
+Click the Connect button that appears, then Drivers. Select Python.
+(Have no concern for the version. This is the PyMongo, not Python, version.)
+Once you have got the Connect Window open, you will see an instruction to `pip install pymongo`.
+You will also see a **connection string**.
+This is the `uri` that a `pymongo.MongoClient` uses to connect to the Database.
+
+
+### Test the connection
+
+Atlas provides a simple check. Once you have your `uri` and `pymongo` installed,
+try the following in a python console.
+
+```python
+from pymongo.mongo_client import MongoClient
+client = MongoClient(uri) # Create a new client and connect to the server
+try:
+ client.admin.command('ping') # Send a ping to confirm a successful connection
+ print("Pinged your deployment. You successfully connected to MongoDB!")
+except Exception as e:
+ print(e)
+```
+
+**Troubleshooting**
+* You can edit a Database's users and passwords on the 'Database Access' page, under Security.
+* Remember to add your IP address. (Try `curl -4 ifconfig.co`)
+
+### Create a Database and Collection
+
+As mentioned, Vector Databases provide two functions. In addition to being the data store,
+they provide very efficient search based on natural language queries.
+With Vector Search, one will index and query data with a powerful vector search algorithm
+using Hierarchical Navigable Small World (HNSW) graphs to find vector similarity.
+
+The indexing runs beside the data as a separate service asynchronously.
+The Search index monitors changes to the Collection that it applies to.
+Consequently, one need not upload the data first.
+We will create an empty collection now, which will simplify setup in the example notebook.
+
+Back in the UI, navigate to the Database Deployments page by clicking Database on the left panel.
+Click the "Browse Collections" and then "+ Create Database" buttons.
+This will open a window where you choose Database and Collection names. (No additional preferences.)
+Remember the Database name, as it will be used as the environment variable
+`MONGODB_DATABASE`.
+
+### MongoDBAtlasDocumentIndex
+
+To connect to the MongoDB Cluster and Database, define the following environment variables.
+You can confirm that the required ones have been set like this: `assert "MONGODB_URI" in os.environ`
+
+**IMPORTANT** It is crucial that the choices are consistent between setup in Atlas and Python environment(s).
+
+| Name | Description | Example |
+|-----------------------|-----------------------------|--------------------------------------------------------------|
+| `MONGODB_URI` | Connection String | `mongodb+srv://<username>:<password>@cluster0.bar.mongodb.net` |
+| `MONGODB_DATABASE` | Database name | docarray_test_db |
+
+
+```python
+
+from docarray.index.backends.mongodb_atlas import MongoDBAtlasDocumentIndex
+import os
+
+index = MongoDBAtlasDocumentIndex(
+ mongo_connection_uri=os.environ["MONGODB_URI"],
+ database_name=os.environ["MONGODB_DATABASE"])
+```
+
+
+### Create an Atlas Vector Search Index
+
+The final step to configure a MongoDBAtlasDocumentIndex is to create a Vector Search Index.
+The procedure is described [here](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure).
+
+Under Services on the left panel, choose Atlas Search > Create Search Index >
+Atlas Vector Search JSON Editor. An index definition looks like the following.
+
+
+```json
+{
+ "fields": [
+ {
+ "numDimensions": 1536,
+ "path": "embedding",
+ "similarity": "cosine",
+ "type": "vector"
+ }
+ ]
+}
+```
+
+
+### Running MongoDB Atlas Integration Tests
+
+Setup is described in detail in `tests/index/mongo_atlas/README.md`.
+There are actually a number of different collections and indexes to be created within your cluster's database.
+
+```bash
+MONGODB_URI=<uri> MONGODB_DATABASE=<database_name> py.test tests/index/mongo_atlas/
+```
diff --git a/docs/API_reference/doc_index/backends/redis.md b/docs/API_reference/doc_index/backends/redis.md
new file mode 100644
index 00000000000..f9622b23d55
--- /dev/null
+++ b/docs/API_reference/doc_index/backends/redis.md
@@ -0,0 +1,3 @@
+# RedisDocumentIndex
+
+::: docarray.index.backends.redis.RedisDocumentIndex
\ No newline at end of file
diff --git a/docs/API_reference/doc_store/jac_doc_store.md b/docs/API_reference/doc_store/jac_doc_store.md
deleted file mode 100644
index 1d4c0a28303..00000000000
--- a/docs/API_reference/doc_store/jac_doc_store.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# JACDocStore
-
-::: docarray.store.jac.JACDocStore
diff --git a/docs/API_reference/typing/tensor/tensor.md b/docs/API_reference/typing/tensor/tensor.md
index 7273dea9476..4c329d29ed5 100644
--- a/docs/API_reference/typing/tensor/tensor.md
+++ b/docs/API_reference/typing/tensor/tensor.md
@@ -4,3 +4,5 @@
::: docarray.typing.tensor.ndarray
::: docarray.typing.tensor.tensorflow_tensor
::: docarray.typing.tensor.torch_tensor
+::: docarray.typing.tensor.AnyTensor
+
diff --git a/docs/_versions.json b/docs/_versions.json
index e5bc801b127..f318a2796a0 100644
--- a/docs/_versions.json
+++ b/docs/_versions.json
@@ -1 +1 @@
-[{"version": "v0.31.1"}, {"version": "v0.31.0"}, {"version": "v0.30.0"}, {"version": "v0.21.0"}, {"version": "v0.20.1"}, {"version": "v0.20.0"}, {"version": "v0.19.0"}, {"version": "v0.18.1"}, {"version": "v0.18.0"}, {"version": "v0.17.0"}, {"version": "v0.16.5"}, {"version": "v0.16.4"}, {"version": "v0.16.3"}, {"version": "v0.16.2"}, {"version": "v0.16.1"}, {"version": "v0.16.0"}, {"version": "v0.15.4"}, {"version": "v0.15.3"}, {"version": "v0.15.2"}, {"version": "v0.15.1"}, {"version": "v0.15.0"}, {"version": "v0.14.11"}, {"version": "v0.14.10"}, {"version": "v0.14.9"}, {"version": "v0.14.8"}, {"version": "v0.14.7"}, {"version": "v0.14.6"}, {"version": "v0.14.5"}, {"version": "v0.14.4"}, {"version": "v0.14.3"}, {"version": "v0.14.2"}, {"version": "v0.14.1"}, {"version": "v0.14.0"}, {"version": "v0.13.33"}, {"version": "v0.13.0"}, {"version": "v0.12.9"}, {"version": "v0.12.0"}, {"version": "v0.11.3"}, {"version": "v0.11.2"}, {"version": "v0.11.1"}, {"version": "v0.11.0"}, {"version": "v0.10.5"}, {"version": "v0.10.4"}, {"version": "v0.10.3"}, {"version": "v0.10.2"}, {"version": "v0.10.1"}, {"version": "v0.10.0"}]
\ No newline at end of file
+[{"version": "v0.40.1"}, {"version": "v0.40.0"}, {"version": "v0.39.1"}, {"version": "v0.39.0"}, {"version": "v0.38.0"}, {"version": "v0.37.1"}, {"version": "v0.37.0"}, {"version": "v0.36.0"}, {"version": "v0.35.0"}, {"version": "v0.34.0"}, {"version": "v0.33.0"}, {"version": "v0.32.1"}, {"version": "v0.32.0"}, {"version": "v0.31.1"}, {"version": "v0.31.0"}, {"version": "v0.30.0"}, {"version": "v0.21.0"}, {"version": "v0.20.1"}, {"version": "v0.20.0"}, {"version": "v0.19.0"}, {"version": "v0.18.1"}, {"version": "v0.18.0"}, {"version": "v0.17.0"}, {"version": "v0.16.5"}, {"version": "v0.16.4"}, {"version": "v0.16.3"}, {"version": "v0.16.2"}, {"version": "v0.16.1"}, {"version": "v0.16.0"}, {"version": "v0.15.4"}, {"version": "v0.15.3"}, {"version": "v0.15.2"}, {"version": "v0.15.1"}, {"version": "v0.15.0"}, {"version": "v0.14.11"}, {"version": "v0.14.10"}, {"version": "v0.14.9"}, {"version": "v0.14.8"}, {"version": "v0.14.7"}, {"version": "v0.14.6"}, {"version": "v0.14.5"}, {"version": "v0.14.4"}, {"version": "v0.14.3"}, {"version": "v0.14.2"}, {"version": "v0.14.1"}, {"version": "v0.14.0"}, {"version": "v0.13.33"}, {"version": "v0.13.0"}, {"version": "v0.12.9"}, {"version": "v0.12.0"}, {"version": "v0.11.3"}, {"version": "v0.11.2"}, {"version": "v0.11.1"}, {"version": "v0.11.0"}, {"version": "v0.10.5"}, {"version": "v0.10.4"}, {"version": "v0.10.3"}, {"version": "v0.10.2"}, {"version": "v0.10.1"}, {"version": "v0.10.0"}]
\ No newline at end of file
diff --git a/docs/assets/docarray-colorful.svg b/docs/assets/docarray-colorful.svg
new file mode 100644
index 00000000000..ed803d09d56
--- /dev/null
+++ b/docs/assets/docarray-colorful.svg
@@ -0,0 +1,16 @@
+
+
\ No newline at end of file
diff --git a/docs/assets/docarray-dark.svg b/docs/assets/docarray-dark.svg
index 7bb9d21c90e..e8c43ac48d4 100644
--- a/docs/assets/docarray-dark.svg
+++ b/docs/assets/docarray-dark.svg
@@ -2,7 +2,7 @@