diff --git a/docarray/array/storage/annlite/backend.py b/docarray/array/storage/annlite/backend.py index 4523c057db0..b56e94b5ae8 100644 --- a/docarray/array/storage/annlite/backend.py +++ b/docarray/array/storage/annlite/backend.py @@ -11,8 +11,8 @@ import numpy as np -from ..base.backend import BaseBackendMixin -from ....helper import dataclass_from_dict, filter_dict +from ..base.backend import BaseBackendMixin, TypeMap +from ....helper import dataclass_from_dict, filter_dict, _safe_cast_int if TYPE_CHECKING: from ....typing import DocumentArraySourceType, ArrayType @@ -33,7 +33,11 @@ class AnnliteConfig: class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend.""" - TYPE_MAP = {'str': 'TEXT', 'float': 'float', 'int': 'integer'} + TYPE_MAP = { + 'str': TypeMap(type='TEXT', converter=str), + 'float': TypeMap(type='float', converter=float), + 'int': TypeMap(type='integer', converter=_safe_cast_int), + } def _map_embedding(self, embedding: 'ArrayType') -> 'ArrayType': if embedding is None: diff --git a/docarray/array/storage/base/backend.py b/docarray/array/storage/base/backend.py index 14f32434aa1..1ec44b92e74 100644 --- a/docarray/array/storage/base/backend.py +++ b/docarray/array/storage/base/backend.py @@ -1,13 +1,16 @@ from abc import ABC +from collections import namedtuple from dataclasses import is_dataclass, asdict from typing import Dict, Optional, TYPE_CHECKING if TYPE_CHECKING: from ....typing import DocumentArraySourceType, ArrayType +TypeMap = namedtuple('TypeMap', ['type', 'converter']) + class BaseBackendMixin(ABC): - TYPE_MAP: Dict + TYPE_MAP: Dict[str, TypeMap] def _init_storage( self, @@ -25,10 +28,13 @@ def _get_storage_infos(self) -> Optional[Dict]: def _map_id(self, _id: str) -> str: return _id + def _map_column(self, value, col_type) -> str: + return self.TYPE_MAP[col_type].converter(value) + def _map_embedding(self, embedding: 'ArrayType') -> 'ArrayType': from ....math.ndarray import to_numpy_array return to_numpy_array(embedding) def _map_type(self, col_type: str) -> str: - return self.TYPE_MAP[col_type] + return self.TYPE_MAP[col_type].type diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index c693c1cecd9..2c3763c62b0 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -14,8 +14,8 @@ import weaviate from .... import Document -from ....helper import dataclass_from_dict, filter_dict -from ..base.backend import BaseBackendMixin +from ....helper import dataclass_from_dict, filter_dict, _safe_cast_int +from ..base.backend import BaseBackendMixin, TypeMap from ..registry import _REGISTRY if TYPE_CHECKING: @@ -52,7 +52,11 @@ class WeaviateConfig: class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend.""" - TYPE_MAP = {'str': 'string', 'float': 'number', 'int': 'int'} + TYPE_MAP = { + 'str': TypeMap(type='string', converter=str), + 'float': TypeMap(type='number', converter=float), + 'int': TypeMap(type='int', converter=_safe_cast_int), + } def _init_storage( self, @@ -310,7 +314,11 @@ def _doc2weaviate_create_payload(self, value: 'Document'): :param value: document to create a payload for :return: the payload dictionary """ - extra_columns = {col: value.tags.get(col) for col, _ in self._config.columns} + columns_dict = {key: val for [key, val] in self._config.columns} + extra_columns = { + col: self._map_column(value.tags.get(col), columns_dict[col]) + for col, _ in self._config.columns + } return dict( data_object={ diff --git a/docarray/helper.py b/docarray/helper.py index 30c074ae833..e97b46d17d0 100644 --- a/docarray/helper.py +++ b/docarray/helper.py @@ -5,7 +5,7 @@ import sys import uuid import warnings -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple, Union __resources_path__ = os.path.join( os.path.dirname( @@ -441,3 +441,14 @@ def filter_dict(d: Dict) -> Dict: :return: filtered dict """ return dict(filter(lambda item: item[1] is not None, d.items())) + + +def _safe_cast_int(value: Union[str, int, float]) -> int: + """Safely cast string and float to an integer + It mainly avoids silently rounding down the float value + :param value: value to be cast + :return: cast integer + """ + if isinstance(value, float) and not value.is_integer(): + raise ValueError(f"Can't safely cast {value} to an int") + return int(value) diff --git a/tests/unit/array/test_backend_configuration.py b/tests/unit/array/test_backend_configuration.py index 000dcb79d49..86ddc969a52 100644 --- a/tests/unit/array/test_backend_configuration.py +++ b/tests/unit/array/test_backend_configuration.py @@ -1,6 +1,10 @@ +from typing import Tuple, Iterator + +import pytest import requests +import itertools -from docarray import DocumentArray +from docarray import DocumentArray, Document def test_weaviate_hnsw(start_storage): @@ -45,3 +49,95 @@ def test_weaviate_hnsw(start_storage): assert main_class.get('vectorIndexConfig', {}).get('cleanupIntervalSeconds') == 1000 assert main_class.get('vectorIndexConfig', {}).get('skip') is True assert main_class.get('vectorIndexConfig', {}).get('distance') == 'l2-squared' + + +def test_weaviate_da_w_protobuff(start_storage): + + N = 10 + + index = DocumentArray( + storage='weaviate', + config={ + 'name': 'Test', + 'columns': [('price', 'int')], + }, + ) + + docs = DocumentArray([Document(tags={'price': i}) for i in range(N)]) + docs = DocumentArray.from_protobuf( + docs.to_protobuf() + ) # same as streaming the da in jina + + index.extend(docs) + + assert len(index) == N + + +@pytest.mark.parametrize('type_da', [int, float, str]) +@pytest.mark.parametrize('type_column', ['int', 'float', 'str']) +def test_cast_columns_weaviate(start_storage, type_da, type_column, request): + + test_id = request.node.callspec.id.replace( + '-', '' + ) # remove '-' from the test id for the weaviate name + N = 10 + + index = DocumentArray( + storage='weaviate', + config={ + 'name': f'Test{test_id}', + 'columns': [('price', type_column)], + }, + ) + + docs = DocumentArray([Document(tags={'price': type_da(i)}) for i in range(10)]) + + index.extend(docs) + + assert len(index) == N + + +@pytest.mark.parametrize('type_da', [int, float, str]) +@pytest.mark.parametrize('type_column', ['int', 'float', 'str']) +def test_cast_columns_annlite(start_storage, type_da, type_column): + + N = 10 + + index = DocumentArray( + storage='annlite', + config={ + 'n_dim': 3, + 'columns': [('price', type_column)], + }, + ) + + docs = DocumentArray([Document(tags={'price': type_da(i)}) for i in range(10)]) + + index.extend(docs) + + assert len(index) == N + + +@pytest.mark.parametrize('type_da', [int, float, str]) +@pytest.mark.parametrize('type_column', ['int', 'float', 'str']) +def test_cast_columns_qdrant(start_storage, type_da, type_column, request): + + test_id = request.node.callspec.id.replace( + '-', '' + ) # remove '-' from the test id for the weaviate name + N = 10 + + index = DocumentArray( + storage='qdrant', + config={ + 'collection_name': f'test{test_id}', + 'n_dim': 3, + 'columns': [('price', type_column)], + }, + ) + + docs = DocumentArray([Document(tags={'price': type_da(i)}) for i in range(10)]) + + index.extend(docs) + + assert len(index) == N diff --git a/tests/unit/test_helper.py b/tests/unit/test_helper.py index e6bc4f6c4df..bb271f826e8 100644 --- a/tests/unit/test_helper.py +++ b/tests/unit/test_helper.py @@ -8,6 +8,7 @@ add_protocol_and_compress_to_file_path, filter_dict, get_full_version, + _safe_cast_int, ) @@ -61,3 +62,14 @@ def test_filter_dict(): def test_ci_vendor(): if 'GITHUB_WORKFLOW' in os.environ: assert get_full_version()['ci-vendor'] == 'GITHUB_ACTIONS' + + +@pytest.mark.parametrize('input,output', [(1, 1), (1.0, 1), ('1', 1)]) +def test_safe_cast(input, output): + assert output == _safe_cast_int(input) + + +@pytest.mark.parametrize('wrong_input', [1.5, 1.001, 2 / 3]) +def test_safe_cast_raise_error(wrong_input): + with pytest.raises(ValueError): + _safe_cast_int(wrong_input)