# Patch summary (commentary placed ahead of the first "diff --git" header).
#
# sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
#   - adds a module-level `logger` and replaces the two print() calls in
#     _connect() with logger.info() using lazy %-style arguments
#   - removes the `full_feature_names` parameter from online_read()
#   - update(): no longer assigns the return value of
#     _get_or_create_collection() to self._collections
#   - plan(): returns [] instead of raising NotImplementedError
#   - retrieve_online_documents_v2(): calls bytes.fromhex() only when the
#     looked-up composite key value is truthy, otherwise uses None
#
# sdk/python/tests/unit/online_store/test_online_retrieval.py
#   - adds three regression tests, one each for the update(), plan() and
#     retrieve_online_documents_v2() changes above
#
# NOTE(review): the incoming text had its line breaks and indentation
# collapsed. Line boundaries below agree with every hunk header's line
# counts; leading indentation (and the position of blank context lines) is
# reconstructed from Python syntax -- confirm against the upstream files
# before applying without --ignore-whitespace.
diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
index ee2534684cc..e33765c1ecf 100644
--- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
+++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
@@ -1,4 +1,5 @@
 import base64
+import logging
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
@@ -42,6 +43,8 @@
     to_naive_utc,
 )
 
+logger = logging.getLogger(__name__)
+
 PROTO_TO_MILVUS_TYPE_MAPPING: Dict[ValueType, DataType] = {
     PROTO_VALUE_TO_VALUE_TYPE_MAP["bytes_val"]: DataType.VARCHAR,
     ValueType.IMAGE_BYTES: DataType.VARCHAR,
@@ -140,11 +143,13 @@ def _connect(self, config: RepoConfig) -> MilvusClient:
         if not self.client:
             if config.provider == "local" and config.online_store.path:
                 db_path = self._get_db_path(config)
-                print(f"Connecting to Milvus in local mode using {db_path}")
+                logger.info("Connecting to Milvus in local mode using %s", db_path)
                 self.client = MilvusClient(db_path)
             else:
-                print(
-                    f"Connecting to Milvus remotely at {config.online_store.host}:{config.online_store.port}"
+                logger.info(
+                    "Connecting to Milvus remotely at %s:%s",
+                    config.online_store.host,
+                    config.online_store.port,
                 )
                 self.client = MilvusClient(
                     uri=f"{config.online_store.host}:{config.online_store.port}",
@@ -339,7 +344,6 @@ def online_read(
         table: FeatureView,
         entity_keys: List[EntityKeyProto],
         requested_features: Optional[List[str]] = None,
-        full_feature_names: bool = False,
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
         self.client = self._connect(config)
         collection_name = _table_id(config.project, table)
@@ -487,7 +491,7 @@ def update(
     ):
         self.client = self._connect(config)
         for table in tables_to_keep:
-            self._collections = self._get_or_create_collection(config, table)
+            self._get_or_create_collection(config, table)
 
         for table in tables_to_delete:
             collection_name = _table_id(config.project, table)
@@ -498,7 +502,7 @@ def update(
     def plan(
         self, config: RepoConfig, desired_registry_proto: RegistryProto
     ) -> List[InfraObject]:
-        raise NotImplementedError
+        return []
 
     def teardown(
         self,
@@ -686,9 +690,8 @@ def retrieve_online_documents_v2(
             for hit in hits:
                 res = {}
                 res_ts = None
-                entity_key_bytes = bytes.fromhex(
-                    hit.get("entity", {}).get(composite_key_name, None)
-                )
+                raw_key = hit.get("entity", {}).get(composite_key_name)
+                entity_key_bytes = bytes.fromhex(raw_key) if raw_key else None
                 entity_key_proto = (
                     deserialize_entity_key(entity_key_bytes)
                     if entity_key_bytes
diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py
index 60f583ad669..63dcc8c7e48 100644
--- a/sdk/python/tests/unit/online_store/test_online_retrieval.py
+++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py
@@ -1714,3 +1714,181 @@ def test_milvus_keyword_search() -> None:
         assert len(result_hybrid["content"]) > 0
         assert any("Feast" in content for content in result_hybrid["content"])
         assert len(result_hybrid["vector"]) > 0
+
+
+def test_milvus_update_preserves_collection_cache() -> None:
+    """
+    Regression test: update() used to overwrite self._collections with the
+    describe_collection() dict of the last processed table, replacing the
+    dict-of-dicts cache with a single flat dict. After the fix, each call
+    to _get_or_create_collection() updates the keyed entry in-place and the
+    cache remains a proper mapping from collection name to collection info.
+    """
+    from datetime import timedelta
+
+    from feast import Entity, FeatureView, Field, FileSource
+    from feast.types import Array, Float32, Int64, String
+
+    runner = CliRunner()
+    with runner.local_repo(
+        example_repo_py=get_example_repo("example_rag_feature_repo.py"),
+        offline_store="file",
+        online_store="milvus",
+        apply=False,
+        teardown=False,
+    ) as store:
+        source = FileSource(
+            path="data/dummy.parquet",
+            timestamp_field="event_timestamp",
+            created_timestamp_column="created_timestamp",
+        )
+        entity_a = Entity(name="id_a", join_keys=["id_a"], value_type=ValueType.INT64)
+        entity_b = Entity(name="id_b", join_keys=["id_b"], value_type=ValueType.INT64)
+
+        fv_a = FeatureView(
+            name="fv_a",
+            entities=[entity_a],
+            schema=[
+                Field(name="id_a", dtype=Int64),
+                Field(
+                    name="vec_a",
+                    dtype=Array(Float32),
+                    vector_index=True,
+                    vector_search_metric="COSINE",
+                ),
+                Field(name="text_a", dtype=String),
+            ],
+            source=source,
+            ttl=timedelta(hours=1),
+        )
+        fv_b = FeatureView(
+            name="fv_b",
+            entities=[entity_b],
+            schema=[
+                Field(name="id_b", dtype=Int64),
+                Field(
+                    name="vec_b",
+                    dtype=Array(Float32),
+                    vector_index=True,
+                    vector_search_metric="COSINE",
+                ),
+                Field(name="text_b", dtype=String),
+            ],
+            source=source,
+            ttl=timedelta(hours=1),
+        )
+
+        store.apply([source, entity_a, entity_b, fv_a, fv_b])
+
+        online_store = store._provider._online_store
+        # After applying two feature views, the cache must be a proper dict
+        # mapping collection names to collection-info dicts, not a flat dict.
+        assert isinstance(online_store._collections, dict), (
+            "_collections should be a dict"
+        )
+        collection_name_a = f"{store.config.project}_fv_a"
+        collection_name_b = f"{store.config.project}_fv_b"
+        assert collection_name_a in online_store._collections, (
+            f"Cache missing entry for {collection_name_a}"
+        )
+        assert collection_name_b in online_store._collections, (
+            f"Cache missing entry for {collection_name_b} — "
+            "update() likely overwrote _collections with a single collection dict"
+        )
+        # Each cached value must be a collection-info dict (has a 'fields' key),
+        # not itself keyed by collection name.
+        for name in [collection_name_a, collection_name_b]:
+            assert "fields" in online_store._collections[name], (
+                f"Cache entry for {name} looks like a corrupted flat dict"
+            )
+
+
+def test_milvus_plan_returns_empty_list() -> None:
+    """
+    Regression test: plan() used to raise NotImplementedError, causing
+    `feast plan` to crash for any project using the Milvus online store.
+    It should return [] matching the OnlineStore base class default.
+    """
+    from feast.infra.online_stores.milvus_online_store.milvus import MilvusOnlineStore
+
+    store = MilvusOnlineStore()
+    result = store.plan(config=None, desired_registry_proto=None)  # type: ignore[arg-type]
+    assert result == [], f"plan() should return [] but returned {result!r}"
+
+
+def test_milvus_retrieve_online_documents_v2_missing_entity_key() -> None:
+    """
+    Regression test: retrieve_online_documents_v2() passed the raw
+    hit.get("entity", {}).get(composite_key_name, None) directly to
+    bytes.fromhex(), raising TypeError when the key was absent.
+    After the fix, a missing composite key produces a None entity_key_proto
+    instead of crashing.
+    """
+    from datetime import timedelta
+    from unittest.mock import patch
+
+    from feast import Entity, FeatureView, Field, FileSource
+    from feast.types import Array, Float32, Int64, String
+
+    runner = CliRunner()
+    with runner.local_repo(
+        example_repo_py=get_example_repo("example_rag_feature_repo.py"),
+        offline_store="file",
+        online_store="milvus",
+        apply=False,
+        teardown=False,
+    ) as store:
+        source = FileSource(
+            path="data/dummy.parquet",
+            timestamp_field="event_timestamp",
+            created_timestamp_column="created_timestamp",
+        )
+        entity = Entity(name="doc_id", join_keys=["doc_id"], value_type=ValueType.INT64)
+        fv = FeatureView(
+            name="docs",
+            entities=[entity],
+            schema=[
+                Field(name="doc_id", dtype=Int64),
+                Field(
+                    name="vec",
+                    dtype=Array(Float32),
+                    vector_index=True,
+                    vector_search_metric="COSINE",
+                ),
+                Field(name="text", dtype=String),
+            ],
+            source=source,
+            ttl=timedelta(hours=1),
+        )
+        store.apply([source, entity, fv])
+
+        online_store = store._provider._online_store
+        fv_obj = store.get_feature_view("docs")
+        # Simulate a search hit that is missing the composite primary key.
+        fake_hit = {
+            "entity": {
+                "event_ts": int(_utc_now().timestamp() * 1e6),
+                "created_ts": int(_utc_now().timestamp() * 1e6),
+                "text": "hello",
+            },
+            "distance": 0.9,
+        }
+
+        mock_results = [[fake_hit]]
+        with patch.object(online_store.client, "search", return_value=mock_results):
+            with patch.object(
+                online_store.client, "load_collection", return_value=None
+            ):
+                # Before the fix this raised TypeError: fromhex argument must be str, not None
+                result = online_store.retrieve_online_documents_v2(
+                    config=store.config,
+                    table=fv_obj,
+                    requested_features=["text"],
+                    embedding=[0.1] * 10,
+                    top_k=1,
+                )
+                assert len(result) == 1
+                _ts, entity_key_proto, _features = result[0]
+                assert entity_key_proto is None, (
+                    "entity_key_proto should be None when the composite key is absent from the hit"
+                )