Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docarray/array/storage/base/getsetdel.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,5 @@ def _save_offset2ids(self):
...

def __del__(self):
self._save_offset2ids()
if hasattr(self, '_offset2ids'):
self._save_offset2ids()
33 changes: 14 additions & 19 deletions docarray/array/storage/weaviate/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ class WeaviateConfig:
"""This class stores the config variables to initialize
connection to the Weaviate server"""

n_dim: int
host: Optional[str] = field(default="localhost")
host: Optional[str] = field(default='localhost')
port: Optional[int] = field(default=8080)
protocol: Optional[int] = field(default="http")
protocol: Optional[str] = field(default='http')
name: Optional[str] = None
serialize_config: Dict = field(default_factory=dict)
n_dim: Optional[int] = None # deprecated, not used anymore since weaviate 1.10


class BackendMixin(BaseBackendMixin):
Expand All @@ -60,13 +60,12 @@ def _init_storage(
"""

if not config:
raise ValueError('Config object must be specified')
config = WeaviateConfig()
elif isinstance(config, dict):
config = dataclass_from_dict(WeaviateConfig, config)

from ... import DocumentArray

self._n_dim = config.n_dim
self._serialize_config = config.serialize_config

if config.name and config.name != config.name.capitalize():
Expand Down Expand Up @@ -278,25 +277,21 @@ def _doc2weaviate_create_payload(self, value: 'Document'):
:param value: document to create a payload for
:return: the payload dictionary
"""
if value.embedding is None:
embedding = np.zeros(self._n_dim)
else:
if value.embedding is not None:
from ....math.ndarray import to_numpy_array

embedding = to_numpy_array(value.embedding)

if embedding.ndim > 1:
embedding = np.asarray(embedding).squeeze()
if embedding.shape != (self._n_dim,):
raise ValueError(
f'All documents must have embedding of shape n_dim: {self._n_dim}, receiving shape: {embedding.shape}'
)
if embedding.ndim > 1:
embedding = np.asarray(embedding).squeeze()

# Weaviate expects vector to have dim 2 at least
# or get weaviate.exceptions.UnexpectedStatusCodeException: models.C11yVector
# hence we cast it to list of a single element
if len(embedding) == 1:
embedding = [embedding[0]]
# Weaviate expects vector to have dim 2 at least
# or get weaviate.exceptions.UnexpectedStatusCodeException: models.C11yVector
# hence we cast it to list of a single element
if len(embedding) == 1:
embedding = [embedding[0]]
else:
embedding = None

return dict(
data_object={'_serialized': value.to_base64(**self._serialize_config)},
Expand Down
19 changes: 10 additions & 9 deletions docarray/array/storage/weaviate/getsetdel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ def _getitem(self, wid: str) -> 'Document':
:raises KeyError: raise error when weaviate id does not exist in storage
:return: Document
"""
resp = self._client.data_object.get_by_id(wid, with_vector=True)
if not resp:
raise KeyError(wid)
return Document.from_base64(
resp['properties']['_serialized'], **self._serialize_config
)
try:
resp = self._client.data_object.get_by_id(wid, with_vector=True)
return Document.from_base64(
resp['properties']['_serialized'], **self._serialize_config
)
except Exception as ex:
raise KeyError(wid) from ex

def _get_doc_by_id(self, _id: str) -> 'Document':
"""Concrete implementation of base class' ``_get_doc_by_id``
Expand All @@ -37,10 +38,10 @@ def _set_doc_by_id(self, _id: str, value: 'Document'):
"""
if _id != value.id:
self._del_doc_by_id(_id)
wid = self._wmap(value.id)

payload = self._doc2weaviate_create_payload(value)
if self._client.data_object.exists(wid):
self._client.data_object.delete(wid)
if self._client.data_object.exists(payload['uuid']):
self._client.data_object.delete(payload['uuid'])
self._client.data_object.create(**payload)

def _del_doc_by_id(self, _id: str):
Expand Down
7 changes: 3 additions & 4 deletions docs/advanced/document-store/weaviate.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ services:
- '8080'
- --scheme
- http
image: semitechnologies/weaviate:1.9.0
image: semitechnologies/weaviate:1.10.0
ports:
- 8080:8080
restart: on-failure:0
Expand All @@ -48,7 +48,7 @@ Assuming service is started using the default configuration (i.e. server address
```python
from docarray import DocumentArray

da = DocumentArray(storage='weaviate', config={'n_dim': 10})
da = DocumentArray(storage='weaviate')
```

The usage would be the same as the ordinary DocumentArray.
Expand All @@ -60,7 +60,7 @@ Note, that the `name` parameter in `config` needs to be capitalized.
```python
from docarray import DocumentArray

da = DocumentArray(storage='weaviate', config={'name': 'Persisted', 'host': 'localhost', 'port': 1234, 'n_dim': 10})
da = DocumentArray(storage='weaviate', config={'name': 'Persisted', 'host': 'localhost', 'port': 1234})

da.summary()
```
Expand All @@ -73,7 +73,6 @@ The following configs can be set:

| Name | Description | Default |
|--------------------|----------------------------------------------------------------------------------------|-----------------------------|
| `n_dim` | Number of dimensions of embeddings to be stored and retrieved | **This is always required** |
| `host` | Hostname of the Weaviate server | 'localhost' |
| `port` | port of the Weaviate server | 8080 |
| `protocol` | protocol to be used. Can be 'http' or 'https' | 'http' |
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/array/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
version: "3.3"
services:
weaviate:
image: semitechnologies/weaviate:1.9.0
image: semitechnologies/weaviate:1.10.0
ports:
- 8080:8080
environment:
Expand Down
13 changes: 0 additions & 13 deletions tests/unit/array/mixins/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,3 @@ def test_embeddings_setter(da_len, da_cls, config, start_storage):
da.embeddings = np.random.rand(da_len, 5)
for doc in da:
assert doc.embedding.shape == (5,)


@pytest.mark.parametrize('da_len', [0, 1])
@pytest.mark.parametrize('da_cls', [DocumentArrayWeaviate])
@pytest.mark.parametrize(
'config, n_dim', [({'n_dim': 1}, 1), (WeaviateConfig(n_dim=5), 5)]
)
def test_content_by_config(da_len, da_cls, config, n_dim):
with pytest.raises(ValueError):
da_cls(da_len)

da = da_cls.empty(da_len, config=config)
assert da._n_dim == n_dim