diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py index e33765c1ecf..941de3b64cd 100644 --- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -19,6 +19,7 @@ deserialize_entity_key, serialize_entity_key, ) +from feast.infra.online_stores.helpers import compute_table_id from feast.infra.online_stores.online_store import OnlineStore from feast.infra.online_stores.vector_store import VectorStoreConfig from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto @@ -164,7 +165,9 @@ def _get_or_create_collection( ) -> Dict[str, Any]: self.client = self._connect(config) vector_field_dict = {k.name: k for k in table.schema if k.vector_index} - collection_name = _table_id(config.project, table) + collection_name = _table_id( + config.project, table, config.registry.enable_online_feature_view_versioning + ) if collection_name not in self._collections: # Create a composite key by combining entity fields composite_key_name = _get_composite_key_name(table) @@ -346,7 +349,9 @@ def online_read( requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: self.client = self._connect(config) - collection_name = _table_id(config.project, table) + collection_name = _table_id( + config.project, table, config.registry.enable_online_feature_view_versioning + ) collection = self._get_or_create_collection(config, table) composite_key_name = _get_composite_key_name(table) @@ -493,11 +498,12 @@ def update( for table in tables_to_keep: self._get_or_create_collection(config, table) + # Always drop the base collection plus any "_v{N}" siblings, regardless of + # the current versioning flag. 
def _drop_all_version_collections(self, project: str, table: FeatureView) -> None:
    """Drop the base collection and every ``_v{N}`` versioned sibling.

    Mirrors the ``_drop_all_version_tables`` helpers in the MySQL/PostgreSQL
    online stores. Always called from ``update`` and ``teardown`` so a
    repo that toggles versioning on and off does not leave orphan
    collections behind in Milvus.

    Args:
        project: Feast project name; first component of the collection name.
        table: Feature view whose collections are removed. Only ``table.name``
            is read.

    Raises:
        RuntimeError: If ``self.client`` has not been initialized.

    NOTE(review): matching is purely name based, so a *different* feature view
    that is literally named ``<table.name>_v<N>`` has a base collection that
    matches the versioned pattern and would be dropped too. Confirm such names
    are ruled out upstream.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would turn this into an opaque AttributeError.
    if self.client is None:
        raise RuntimeError("Milvus client is not initialized")

    base = f"{project}_{table.name}"
    versioned_prefix = f"{base}_v"

    for collection_name in self.client.list_collections():
        if collection_name != base:
            if not collection_name.startswith(versioned_prefix):
                continue
            suffix = collection_name[len(versioned_prefix):]
            # ``str.isdigit`` alone also accepts non-ASCII digit characters
            # (e.g. superscripts); require ASCII so only a literal
            # ``_v<integer>`` suffix counts as a versioned sibling.
            if not (suffix.isascii() and suffix.isdigit()):
                continue
        self.client.drop_collection(collection_name)
        # Keep the in-memory collection cache in sync with the server.
        self._collections.pop(collection_name, None)
def _make_config(project="test_project", versioning=False):
    """Build a mock repo config for the Milvus versioning tests.

    The returned ``MagicMock`` exposes ``project``,
    ``entity_key_serialization_version`` (fixed at 2) and
    ``registry.enable_online_feature_view_versioning`` (set from
    ``versioning``); every other attribute is an auto-created mock.
    """
    repo_config = MagicMock()
    # ``configure_mock`` accepts dotted keys, so the nested registry flag can
    # be set in the same call as the top-level attributes.
    repo_config.configure_mock(
        **{
            "project": project,
            "entity_key_serialization_version": 2,
            "registry.enable_online_feature_view_versioning": versioning,
        }
    )
    return repo_config
class TestTeardownDropsAllVersions:
    """Teardown/update should drop the base collection AND all versioned collections.

    Besides the happy path, these tests pin two properties the drop logic
    must keep: collections whose names only *resemble* the feature view's
    (different view, non-numeric or trailing-text version suffix) are left
    alone, and the store's in-memory ``_collections`` cache is evicted for
    exactly the collections that were dropped.
    """

    # Names sharing a prefix with ``test_project_driver_stats`` that are NOT
    # the base collection or a ``_v<integer>`` sibling; none may be dropped.
    _LOOKALIKES = (
        "test_project_other_view",
        "test_project_driver_stats_extra",
        "test_project_driver_stats_vnext",
        "test_project_driver_stats_v2_backup",
    )

    def _build_store_with_collections(self, existing_collections):
        """Return a ``MilvusOnlineStore`` wired to a mocked Milvus client.

        ``list_collections`` reports ``existing_collections``, ``_connect`` is
        stubbed to hand back the same mock client, and the local cache is
        pre-populated with one entry per existing collection.
        """
        from feast.infra.online_stores.milvus_online_store.milvus import (
            MilvusOnlineStore,
        )

        store = MilvusOnlineStore()
        store.client = MagicMock()
        store.client.list_collections.return_value = list(existing_collections)
        store._connect = MagicMock(return_value=store.client)
        store._collections = {name: MagicMock() for name in existing_collections}
        return store

    @staticmethod
    def _dropped_names(store):
        """Collect the collection names passed to ``drop_collection``."""
        return {call.args[0] for call in store.client.drop_collection.call_args_list}

    def _assert_dropped_exactly(self, store, expected):
        """Check both the server-side drops and the local cache eviction."""
        assert self._dropped_names(store) == expected
        # Dropped collections must also disappear from the in-memory cache...
        assert not expected & set(store._collections)
        # ...while every look-alike stays cached and untouched.
        for name in self._LOOKALIKES:
            assert name in store._collections

    def test_teardown_drops_base_and_all_versioned_collections(self):
        fv = _make_feature_view()
        config = _make_config(versioning=True)
        expected = {
            "test_project_driver_stats",
            "test_project_driver_stats_v1",
            "test_project_driver_stats_v2",
        }
        store = self._build_store_with_collections(
            sorted(expected) + list(self._LOOKALIKES)
        )

        store.teardown(config, [fv], [])

        self._assert_dropped_exactly(store, expected)

    def test_update_drops_all_versions_for_deleted_table(self):
        fv = _make_feature_view()
        config = _make_config(versioning=True)
        expected = {
            "test_project_driver_stats",
            "test_project_driver_stats_v3",
            "test_project_driver_stats_v4",
        }
        store = self._build_store_with_collections(
            sorted(expected) + list(self._LOOKALIKES)
        )

        store.update(
            config=config,
            tables_to_delete=[fv],
            tables_to_keep=[],
            entities_to_delete=[],
            entities_to_keep=[],
            partial=False,
        )

        self._assert_dropped_exactly(store, expected)