diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index e6ca6a6b735..adc36f4c915 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -281,4 +281,5 @@ def _save_offset2ids(self): ... def __del__(self): - self._save_offset2ids() + if hasattr(self, '_offset2ids'): + self._save_offset2ids() diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index 5c711cf6d75..13aabb03ca8 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -33,12 +33,12 @@ class WeaviateConfig: """This class stores the config variables to initialize connection to the Weaviate server""" - n_dim: int - host: Optional[str] = field(default="localhost") + host: Optional[str] = field(default='localhost') port: Optional[int] = field(default=8080) - protocol: Optional[int] = field(default="http") + protocol: Optional[str] = field(default='http') name: Optional[str] = None serialize_config: Dict = field(default_factory=dict) + n_dim: Optional[int] = None # deprecated, not used anymore since weaviate 1.10 class BackendMixin(BaseBackendMixin): @@ -60,13 +60,12 @@ def _init_storage( """ if not config: - raise ValueError('Config object must be specified') + config = WeaviateConfig() elif isinstance(config, dict): config = dataclass_from_dict(WeaviateConfig, config) from ... import DocumentArray - self._n_dim = config.n_dim self._serialize_config = config.serialize_config if config.name and config.name != config.name.capitalize(): @@ -278,25 +277,21 @@ def _doc2weaviate_create_payload(self, value: 'Document'): :param value: document to create a payload for :return: the payload dictionary """ - if value.embedding is None: - embedding = np.zeros(self._n_dim) - else: + if value.embedding is not None: from ....math.ndarray import to_numpy_array embedding = to_numpy_array(value.embedding) - if embedding.ndim > 1: - embedding = np.asarray(embedding).squeeze() - if embedding.shape != (self._n_dim,): - raise ValueError( - f'All documents must have embedding of shape n_dim: {self._n_dim}, receiving shape: {embedding.shape}' - ) + if embedding.ndim > 1: + embedding = np.asarray(embedding).squeeze() - # Weaviate expects vector to have dim 2 at least - # or get weaviate.exceptions.UnexpectedStatusCodeException: models.C11yVector - # hence we cast it to list of a single element - if len(embedding) == 1: - embedding = [embedding[0]] + # Weaviate expects vector to have dim 2 at least + # or get weaviate.exceptions.UnexpectedStatusCodeException: models.C11yVector + # hence we cast it to list of a single element + if len(embedding) == 1: + embedding = [embedding[0]] + else: + embedding = None return dict( data_object={'_serialized': value.to_base64(**self._serialize_config)}, diff --git a/docarray/array/storage/weaviate/getsetdel.py b/docarray/array/storage/weaviate/getsetdel.py index 2998cdbe7bb..f3a0da4af0e 100644 --- a/docarray/array/storage/weaviate/getsetdel.py +++ b/docarray/array/storage/weaviate/getsetdel.py @@ -14,12 +14,13 @@ def _getitem(self, wid: str) -> 'Document': :raises KeyError: raise error when weaviate id does not exist in storage :return: Document """ - resp = self._client.data_object.get_by_id(wid, with_vector=True) - if not resp: - raise KeyError(wid) - return Document.from_base64( - resp['properties']['_serialized'], **self._serialize_config - ) + try: + resp = self._client.data_object.get_by_id(wid, with_vector=True) + return Document.from_base64( + resp['properties']['_serialized'], **self._serialize_config + ) + except Exception as ex: + raise KeyError(wid) from ex def _get_doc_by_id(self, _id: str) -> 'Document': """Concrete implementation of base class' ``_get_doc_by_id`` @@ -37,10 +38,10 @@ def _set_doc_by_id(self, _id: str, value: 'Document'): """ if _id != value.id: self._del_doc_by_id(_id) - wid = self._wmap(value.id) + payload = self._doc2weaviate_create_payload(value) - if self._client.data_object.exists(wid): - self._client.data_object.delete(wid) + if self._client.data_object.exists(payload['uuid']): + self._client.data_object.delete(payload['uuid']) self._client.data_object.create(**payload) def _del_doc_by_id(self, _id: str): diff --git a/docs/advanced/document-store/weaviate.md b/docs/advanced/document-store/weaviate.md index 8848b84480f..e9614e1ca8c 100644 --- a/docs/advanced/document-store/weaviate.md +++ b/docs/advanced/document-store/weaviate.md @@ -21,7 +21,7 @@ services: - '8080' - --scheme - http - image: semitechnologies/weaviate:1.9.0 + image: semitechnologies/weaviate:1.10.0 ports: - 8080:8080 restart: on-failure:0 @@ -48,7 +48,7 @@ Assuming service is started using the default configuration (i.e. server address ```python from docarray import DocumentArray -da = DocumentArray(storage='weaviate', config={'n_dim': 10}) +da = DocumentArray(storage='weaviate') ``` The usage would be the same as the ordinary DocumentArray. @@ -60,7 +60,7 @@ Note, that the `name` parameter in `config` needs to be capitalized. ```python from docarray import DocumentArray -da = DocumentArray(storage='weaviate', config={'name': 'Persisted', 'host': 'localhost', 'port': 1234, 'n_dim': 10}) +da = DocumentArray(storage='weaviate', config={'name': 'Persisted', 'host': 'localhost', 'port': 1234}) da.summary() ``` @@ -73,7 +73,6 @@ The following configs can be set: | Name | Description | Default | |--------------------|----------------------------------------------------------------------------------------|-----------------------------| -| `n_dim` | Number of dimensions of embeddings to be stored and retrieved | **This is always required** | | `host` | Hostname of the Weaviate server | 'localhost' | | `port` | port of the Weaviate server | 8080 | | `protocol` | protocol to be used. Can be 'http' or 'https' | 'http' | diff --git a/tests/unit/array/docker-compose.yml b/tests/unit/array/docker-compose.yml index 175777badd7..21806311f3f 100644 --- a/tests/unit/array/docker-compose.yml +++ b/tests/unit/array/docker-compose.yml @@ -1,7 +1,7 @@ version: "3.3" services: weaviate: - image: semitechnologies/weaviate:1.9.0 + image: semitechnologies/weaviate:1.10.0 ports: - 8080:8080 environment: diff --git a/tests/unit/array/mixins/test_content.py b/tests/unit/array/mixins/test_content.py index d38a423aa0c..8659341e1b9 100644 --- a/tests/unit/array/mixins/test_content.py +++ b/tests/unit/array/mixins/test_content.py @@ -147,16 +147,3 @@ def test_embeddings_setter(da_len, da_cls, config, start_storage): da.embeddings = np.random.rand(da_len, 5) for doc in da: assert doc.embedding.shape == (5,) - - -@pytest.mark.parametrize('da_len', [0, 1]) -@pytest.mark.parametrize('da_cls', [DocumentArrayWeaviate]) -@pytest.mark.parametrize( - 'config, n_dim', [({'n_dim': 1}, 1), (WeaviateConfig(n_dim=5), 5)] -) -def test_content_by_config(da_len, da_cls, config, n_dim): - with pytest.raises(ValueError): - da_cls(da_len) - - da = da_cls.empty(da_len, config=config) - assert da._n_dim == n_dim