Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions docarray/array/storage/annlite/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

import numpy as np

from ..base.backend import BaseBackendMixin
from ....helper import dataclass_from_dict, filter_dict
from ..base.backend import BaseBackendMixin, TypeMap
from ....helper import dataclass_from_dict, filter_dict, _safe_cast_int

if TYPE_CHECKING:
from ....typing import DocumentArraySourceType, ArrayType
Expand All @@ -33,7 +33,11 @@ class AnnliteConfig:
class BackendMixin(BaseBackendMixin):
"""Provide necessary functions to enable this storage backend."""

TYPE_MAP = {'str': 'TEXT', 'float': 'float', 'int': 'integer'}
TYPE_MAP = {
'str': TypeMap(type='TEXT', converter=str),
'float': TypeMap(type='float', converter=float),
'int': TypeMap(type='integer', converter=_safe_cast_int),
}

def _map_embedding(self, embedding: 'ArrayType') -> 'ArrayType':
if embedding is None:
Expand Down
10 changes: 8 additions & 2 deletions docarray/array/storage/base/backend.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from abc import ABC
from collections import namedtuple
from dataclasses import is_dataclass, asdict
from typing import Dict, Optional, TYPE_CHECKING

if TYPE_CHECKING:
from ....typing import DocumentArraySourceType, ArrayType

TypeMap = namedtuple('TypeMap', ['type', 'converter'])


class BaseBackendMixin(ABC):
TYPE_MAP: Dict
TYPE_MAP: Dict[str, TypeMap]

def _init_storage(
self,
Expand All @@ -25,10 +28,13 @@ def _get_storage_infos(self) -> Optional[Dict]:
def _map_id(self, _id: str) -> str:
return _id

def _map_column(self, value, col_type) -> str:
return self.TYPE_MAP[col_type].converter(value)

def _map_embedding(self, embedding: 'ArrayType') -> 'ArrayType':
from ....math.ndarray import to_numpy_array

return to_numpy_array(embedding)

def _map_type(self, col_type: str) -> str:
return self.TYPE_MAP[col_type]
return self.TYPE_MAP[col_type].type
16 changes: 12 additions & 4 deletions docarray/array/storage/weaviate/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import weaviate

from .... import Document
from ....helper import dataclass_from_dict, filter_dict
from ..base.backend import BaseBackendMixin
from ....helper import dataclass_from_dict, filter_dict, _safe_cast_int
from ..base.backend import BaseBackendMixin, TypeMap
from ..registry import _REGISTRY

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,7 +52,11 @@ class WeaviateConfig:
class BackendMixin(BaseBackendMixin):
"""Provide necessary functions to enable this storage backend."""

TYPE_MAP = {'str': 'string', 'float': 'number', 'int': 'int'}
TYPE_MAP = {
'str': TypeMap(type='string', converter=str),
'float': TypeMap(type='number', converter=float),
'int': TypeMap(type='int', converter=_safe_cast_int),
}

def _init_storage(
self,
Expand Down Expand Up @@ -310,7 +314,11 @@ def _doc2weaviate_create_payload(self, value: 'Document'):
:param value: document to create a payload for
:return: the payload dictionary
"""
extra_columns = {col: value.tags.get(col) for col, _ in self._config.columns}
columns_dict = {key: val for [key, val] in self._config.columns}
extra_columns = {
col: self._map_column(value.tags.get(col), columns_dict[col])
for col, _ in self._config.columns
}

return dict(
data_object={
Expand Down
13 changes: 12 additions & 1 deletion docarray/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
import uuid
import warnings
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple, Union

__resources_path__ = os.path.join(
os.path.dirname(
Expand Down Expand Up @@ -441,3 +441,14 @@ def filter_dict(d: Dict) -> Dict:
:return: filtered dict
"""
return dict(filter(lambda item: item[1] is not None, d.items()))


def _safe_cast_int(value: Union[str, int, float]) -> int:
"""Safely cast string and float to an integer
It mainly avoids silently rounding down the float value
:param value: value to be cast
:return: cast integer
"""
if isinstance(value, float) and not value.is_integer():
raise ValueError(f"Can't safely cast {value} to an int")
return int(value)
98 changes: 97 additions & 1 deletion tests/unit/array/test_backend_configuration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Tuple, Iterator

import pytest
import requests
import itertools

from docarray import DocumentArray
from docarray import DocumentArray, Document


def test_weaviate_hnsw(start_storage):
Expand Down Expand Up @@ -45,3 +49,95 @@ def test_weaviate_hnsw(start_storage):
assert main_class.get('vectorIndexConfig', {}).get('cleanupIntervalSeconds') == 1000
assert main_class.get('vectorIndexConfig', {}).get('skip') is True
assert main_class.get('vectorIndexConfig', {}).get('distance') == 'l2-squared'


def test_weaviate_da_w_protobuff(start_storage):

N = 10

index = DocumentArray(
storage='weaviate',
config={
'name': 'Test',
'columns': [('price', 'int')],
},
)

docs = DocumentArray([Document(tags={'price': i}) for i in range(N)])
docs = DocumentArray.from_protobuf(
docs.to_protobuf()
) # same as streaming the da in jina

index.extend(docs)

assert len(index) == N


@pytest.mark.parametrize('type_da', [int, float, str])
@pytest.mark.parametrize('type_column', ['int', 'float', 'str'])
def test_cast_columns_weaviate(start_storage, type_da, type_column, request):

test_id = request.node.callspec.id.replace(
'-', ''
) # remove '-' from the test id for the weaviate name
N = 10

index = DocumentArray(
storage='weaviate',
config={
'name': f'Test{test_id}',
'columns': [('price', type_column)],
},
)

docs = DocumentArray([Document(tags={'price': type_da(i)}) for i in range(10)])

index.extend(docs)

assert len(index) == N


@pytest.mark.parametrize('type_da', [int, float, str])
@pytest.mark.parametrize('type_column', ['int', 'float', 'str'])
def test_cast_columns_annlite(start_storage, type_da, type_column):

N = 10

index = DocumentArray(
storage='annlite',
config={
'n_dim': 3,
'columns': [('price', type_column)],
},
)

docs = DocumentArray([Document(tags={'price': type_da(i)}) for i in range(10)])

index.extend(docs)

assert len(index) == N


@pytest.mark.parametrize('type_da', [int, float, str])
@pytest.mark.parametrize('type_column', ['int', 'float', 'str'])
def test_cast_columns_qdrant(start_storage, type_da, type_column, request):

test_id = request.node.callspec.id.replace(
'-', ''
) # remove '-' from the test id for the weaviate name
N = 10

index = DocumentArray(
storage='qdrant',
config={
'collection_name': f'test{test_id}',
'n_dim': 3,
'columns': [('price', type_column)],
},
)

docs = DocumentArray([Document(tags={'price': type_da(i)}) for i in range(10)])

index.extend(docs)

assert len(index) == N
12 changes: 12 additions & 0 deletions tests/unit/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
add_protocol_and_compress_to_file_path,
filter_dict,
get_full_version,
_safe_cast_int,
)


Expand Down Expand Up @@ -61,3 +62,14 @@ def test_filter_dict():
def test_ci_vendor():
if 'GITHUB_WORKFLOW' in os.environ:
assert get_full_version()['ci-vendor'] == 'GITHUB_ACTIONS'


@pytest.mark.parametrize('input,output', [(1, 1), (1.0, 1), ('1', 1)])
def test_safe_cast(input, output):
assert output == _safe_cast_int(input)


@pytest.mark.parametrize('wrong_input', [1.5, 1.001, 2 / 3])
def test_safe_cast_raise_error(wrong_input):
with pytest.raises(ValueError):
_safe_cast_int(wrong_input)