diff --git a/sdk/python/feast/infra/online_stores/redis.py b/sdk/python/feast/infra/online_stores/redis.py index ad9e378a95f..d5cd1f998e0 100644 --- a/sdk/python/feast/infra/online_stores/redis.py +++ b/sdk/python/feast/infra/online_stores/redis.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import base64 import json import logging +import math from datetime import datetime, timezone from enum import Enum from typing import ( @@ -32,7 +34,7 @@ from pydantic import StrictStr from feast import Entity, FeatureView, RepoConfig, utils -from feast.infra.key_encoding_utils import serialize_entity_key +from feast.infra.key_encoding_utils import deserialize_entity_key, serialize_entity_key from feast.infra.online_stores.helpers import ( _mmh3, _redis_key, @@ -40,10 +42,12 @@ compute_versioned_name, ) from feast.infra.online_stores.online_store import OnlineStore +from feast.infra.online_stores.vector_store import VectorStoreConfig from feast.infra.supported_async_methods import SupportedAsyncMethods from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel +from feast.type_map import feast_value_type_to_python_type try: from redis import Redis @@ -71,7 +75,7 @@ class RedisType(str, Enum): redis_sentinel = "redis_sentinel" -class RedisOnlineStoreConfig(FeastConfigBaseModel): +class RedisOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): """Online store config for Redis store""" type: Literal["redis"] = "redis" @@ -114,15 +118,163 @@ class RedisOnlineStore(OnlineStore): _client_async: Optional[Union[redis_asyncio.Redis, redis_asyncio.RedisCluster]] = ( None ) + _vadd_supported: Optional[bool] = None @property def async_supported(self) -> SupportedAsyncMethods: return SupportedAsyncMethods(read=True, write=True) + def _check_vadd_supported(self, client: Union[Redis, RedisCluster]) -> bool: + """Check if the connected Redis server supports VADD (Redis 8+ Vector Sets).""" + if self._vadd_supported is not None: + return self._vadd_supported + try: + info = client.execute_command("COMMAND", "INFO", "VADD") + self._vadd_supported = ( + info is not None and len(info) > 0 and info[0] is not None + ) + except Exception: + self._vadd_supported = False + return self._vadd_supported + + @staticmethod + def _vector_set_key(project: str, fv_name: str) -> str: + """Return the Redis key used for the Vector Set of a feature view.""" + return f"vs:{project}:{fv_name}" + + @staticmethod + def _vector_element_id( + entity_key: EntityKeyProto, entity_key_serialization_version: int = 3 + ) -> str: + """Create a unique string element ID from an entity key for use in VADD/VSIM.""" + return serialize_entity_key( + entity_key, + entity_key_serialization_version=entity_key_serialization_version, + ).hex() + + @staticmethod + def _normalize_vector(vector: Sequence[float]) -> List[float]: + """Return a unit-length copy of *vector* for vector-set indexing/search.""" + floats = [float(v) for v in vector] + norm = math.sqrt(sum(v * v for v in floats)) + if norm == 0: + return floats + return [v / norm for v in floats] + + @staticmethod + def _json_safe_attr_value(value: Any) -> Any: + """Convert Redis vector-set attributes into JSON-serializable values.""" + if isinstance(value, bytes): + return base64.b64encode(value).decode("ascii") + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, list): + return [RedisOnlineStore._json_safe_attr_value(v) for v in value] + if isinstance(value, tuple): + return [RedisOnlineStore._json_safe_attr_value(v) for v in value] + if isinstance(value, set): + return [RedisOnlineStore._json_safe_attr_value(v) for v in value] + if isinstance(value, dict): + return { + str(k): RedisOnlineStore._json_safe_attr_value(v) + for k, v in value.items() + } + return value + + def _vector_search_metric( + self, + config: RepoConfig, + table: FeatureView, + distance_metric: Optional[str] = None, + ) -> str: + """Resolve the vector-search metric for the feature view.""" + vector_fields = [f for f in table.features if getattr(f, "vector_index", None)] + field_metric = ( + getattr(vector_fields[0], "vector_search_metric", None) + if vector_fields + else None + ) + metric = ( + distance_metric + or field_metric + or getattr(config.online_store, "similarity", None) + ) + metric = (metric or "COSINE").upper() + if metric != "COSINE": + raise ValueError( + f"Unsupported distance metric {metric}. Redis online store only supports COSINE." + ) + return metric + + def _vadd_vectors( + self, + client: Union[Redis, RedisCluster], + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + ) -> None: + """Index vector embeddings into a Redis Vector Set using VADD. + + Identifies the vector field from the FeatureView schema and issues + VADD commands for each entity row that contains a vector value. + Non-vector features and entity join-key values are stored as JSON + attributes via SETATTR so they can be used with VSIM FILTER. + """ + vector_fields = [f for f in table.features if getattr(f, "vector_index", None)] + if not vector_fields: + return + vector_field = vector_fields[0] + fv_name = _versioned_fv_name(table, config) + vs_key = self._vector_set_key(config.project, fv_name) + self._vector_search_metric(config, table) + + with client.pipeline(transaction=False) as pipe: + for entity_key, values, _ts, _created in data: + vec_val = values.get(vector_field.name) + if vec_val is None: + continue + python_vector = feast_value_type_to_python_type(vec_val) + if not isinstance(python_vector, Sequence) or isinstance( + python_vector, (bytes, str) + ): + continue + floats = [float(v) for v in python_vector] + if not floats: + continue + floats = self._normalize_vector(floats) + dim = len(floats) + element_id = self._vector_element_id( + entity_key, config.entity_key_serialization_version + ) + # Build attribute JSON for entity join keys + attrs: Dict[str, Any] = {} + for jk, ev in zip(entity_key.join_keys, entity_key.entity_values): + attrs[jk] = self._json_safe_attr_value( + feast_value_type_to_python_type(ev) + ) + attr_json = json.dumps(attrs) + + # VADD vs_key VALUES ... SETATTR + cmd_args: List[Any] = [vs_key, "VALUES", str(dim)] + cmd_args.extend(str(f) for f in floats) + cmd_args.append(element_id) + cmd_args.extend(["SETATTR", attr_json]) + pipe.execute_command("VADD", *cmd_args) + pipe.execute() + + key_ttl_seconds = getattr(config.online_store, "key_ttl_seconds", None) + if key_ttl_seconds: + client.expire(name=vs_key, time=key_ttl_seconds) + def delete_entity_values(self, config: RepoConfig, join_keys: List[str]): client = self._get_client(config.online_store) deleted_count = 0 prefix = _redis_key_prefix(join_keys) + deleted_entity_ids: List[str] = [] + project_bytes = config.project.encode("utf8") + can_delete_vectors = getattr(config.online_store, "vector_enabled", False) with client.pipeline(transaction=False) as pipe: for _k in client.scan_iter( @@ -130,8 +282,36 @@ def delete_entity_values(self, config: RepoConfig, join_keys: List[str]): ): pipe.delete(_k) deleted_count += 1 + if can_delete_vectors and _k.endswith(project_bytes): + serialized_entity_key = _k[: -len(project_bytes)] + try: + entity_key = deserialize_entity_key( + serialized_entity_key, + entity_key_serialization_version=config.entity_key_serialization_version, + ) + deleted_entity_ids.append( + self._vector_element_id( + entity_key, config.entity_key_serialization_version + ) + ) + except Exception: + # Keep deleting hash rows even if a key cannot be decoded. + continue pipe.execute() + if ( + deleted_entity_ids + and can_delete_vectors + and self._check_vadd_supported(client) + ): + vector_set_keys = list(client.scan_iter(f"vs:{config.project}:*")) + if vector_set_keys: + with client.pipeline(transaction=False) as pipe: + for vs_key in vector_set_keys: + for entity_id in deleted_entity_ids: + pipe.execute_command("VREM", vs_key, entity_id) + pipe.execute() + logger.debug(f"Deleted {deleted_count} rows for entity {', '.join(join_keys)}") def delete_table(self, config: RepoConfig, table: FeatureView): @@ -153,6 +333,10 @@ def delete_table(self, config: RepoConfig, table: FeatureView): redis_hash_keys = [_mmh3(f"{fv_name}:{f.name}") for f in table.features] redis_hash_keys.append(bytes(f"_ts:{fv_name}", "utf8")) + # Clean up the Vector Set key unconditionally before the early-return + vs_key = self._vector_set_key(config.project, fv_name) + client.delete(vs_key) + # Phase 1: collect all matching entity keys from SCAN (no per-key round trips) scan_pattern = b"".join([prefix, b"*", config.project.encode("utf8")]) all_keys = list(client.scan_iter(scan_pattern)) @@ -220,11 +404,18 @@ def teardown( """ We delete the keys in redis for tables/views being removed. """ + client = self._get_client(config.online_store) join_keys_to_delete = set(tuple(table.join_keys) for table in tables) for join_keys in join_keys_to_delete: self.delete_entity_values(config, list(join_keys)) + # Clean up any Vector Set keys for the removed tables + for table in tables: + fv_name = _versioned_fv_name(table, config) + vs_key = self._vector_set_key(config.project, fv_name) + client.delete(vs_key) + @staticmethod def _parse_connection_string(connection_string: str): """ @@ -322,13 +513,16 @@ def online_write_batch( feature_view = _versioned_fv_name(table, config) ts_key = f"_ts:{feature_view}" + vector_data_to_index: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ] = [] if online_store_config.skip_dedup: # Single-pipeline fast path: no timestamp read, directly write all rows. # Reduces round trips from 2 to 1. Suitable for initial loads or # append-only pipelines where out-of-order writes are not a concern. with client.pipeline(transaction=False) as pipe: - for entity_key, values, timestamp, _ in data: + for entity_key, values, timestamp, created_ts in data: redis_key_bin = _redis_key( project, entity_key, @@ -347,9 +541,18 @@ def online_write_batch( name=redis_key_bin, time=online_store_config.key_ttl_seconds, ) + vector_data_to_index.append( + (entity_key, values, timestamp, created_ts) + ) results = pipe.execute() if progress: progress(len(results)) + if ( + getattr(online_store_config, "vector_enabled", False) + and vector_data_to_index + and self._check_vadd_supported(client) + ): + self._vadd_vectors(client, config, table, vector_data_to_index) return keys = [] @@ -370,9 +573,12 @@ def online_write_batch( # flattening the list of lists. `hmget` does the lookup assuming a list of keys in the key bin prev_event_timestamps = [i[0] for i in prev_event_timestamps] - for redis_key_bin, prev_event_time, (_, values, timestamp, _) in zip( - keys, prev_event_timestamps, data - ): + for redis_key_bin, prev_event_time, ( + entity_key, + values, + timestamp, + created_ts, + ) in zip(keys, prev_event_timestamps, data): # Convert incoming timestamp to millisecond-aware datetime aware_ts = utils.make_tzaware(timestamp) # Build protobuf timestamp with nanos @@ -390,6 +596,7 @@ def online_write_batch( if progress: progress(1) continue + vector_data_to_index.append((entity_key, values, timestamp, created_ts)) # Store full timestamp (seconds + nanos) entity_hset = {ts_key: ts.SerializeToString()} @@ -407,6 +614,14 @@ def online_write_batch( if progress: progress(len(results)) + # Index vectors into Redis Vector Sets if vector_enabled + if ( + getattr(online_store_config, "vector_enabled", False) + and vector_data_to_index + and self._check_vadd_supported(client) + ): + self._vadd_vectors(client, config, table, vector_data_to_index) + async def online_write_batch_async( self, config: RepoConfig, @@ -425,10 +640,13 @@ async def online_write_batch_async( feature_view = _versioned_fv_name(table, config) ts_key = f"_ts:{feature_view}" + vector_data_to_index: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ] = [] if online_store_config.skip_dedup: async with client.pipeline(transaction=False) as pipe: - for entity_key, values, timestamp, _ in data: + for entity_key, values, timestamp, created_ts in data: redis_key_bin = _redis_key( project, entity_key, @@ -447,9 +665,16 @@ async def online_write_batch_async( name=redis_key_bin, time=online_store_config.key_ttl_seconds, ) + vector_data_to_index.append( + (entity_key, values, timestamp, created_ts) + ) results = await pipe.execute() if progress: progress(len(results)) + if getattr(online_store_config, "vector_enabled", False): + sync_client = self._get_client(online_store_config) + if vector_data_to_index and self._check_vadd_supported(sync_client): + self._vadd_vectors(sync_client, config, table, vector_data_to_index) return keys = [] @@ -467,9 +692,12 @@ async def online_write_batch_async( prev_event_timestamps = [i[0] for i in prev_event_timestamps] async with client.pipeline(transaction=False) as pipe: - for redis_key_bin, prev_event_time, (_, values, timestamp, _) in zip( - keys, prev_event_timestamps, data - ): + for redis_key_bin, prev_event_time, ( + entity_key, + values, + timestamp, + created_ts, + ) in zip(keys, prev_event_timestamps, data): aware_ts = utils.make_tzaware(timestamp) ts = Timestamp() ts.FromDatetime(aware_ts) @@ -484,6 +712,7 @@ async def online_write_batch_async( progress(1) continue + vector_data_to_index.append((entity_key, values, timestamp, created_ts)) entity_hset = {ts_key: ts.SerializeToString()} for feature_name, val in values.items(): f_key = _mmh3(f"{feature_view}:{feature_name}") @@ -499,6 +728,15 @@ async def online_write_batch_async( if progress: progress(len(results)) + # Index vectors into Redis Vector Sets if vector_enabled (sync fallback) + if ( + getattr(online_store_config, "vector_enabled", False) + and vector_data_to_index + ): + sync_client = self._get_client(online_store_config) + if self._check_vadd_supported(sync_client): + self._vadd_vectors(sync_client, config, table, vector_data_to_index) + def _generate_redis_keys_for_entities( self, config: RepoConfig, entity_keys: List[EntityKeyProto] ) -> List[bytes]: @@ -755,3 +993,160 @@ def _get_features_for_entity( timestamp = datetime.fromtimestamp(total_seconds, tz=timezone.utc) return timestamp, res + + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embedding: Optional[List[float]], + top_k: int, + distance_metric: Optional[str] = None, + query_string: Optional[str] = None, + include_feature_view_version_metadata: bool = False, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """Retrieve documents via Redis 8 Vector Sets (VSIM). + + Uses the native ``VSIM`` command to find the *top_k* vectors closest + to the provided *embedding*. For each match, the full feature values + are fetched from the existing Redis HSET storage and returned together + with a synthetic ``distance`` key containing the similarity score. + + Args: + config: Feast repo configuration. + table: The FeatureView to search. + requested_features: Feature names to include in the result. + embedding: The query vector (list of floats). + top_k: Number of nearest neighbours to return. + distance_metric: Vector metric to use (COSINE or L2). + query_string: Not supported; raises ValueError if provided. + include_feature_view_version_metadata: Unused, kept for API compat. + + Returns: + A list of ``(event_ts, entity_key_proto, feature_dict)`` tuples. + """ + online_store_config = config.online_store + assert isinstance(online_store_config, RedisOnlineStoreConfig) + + if not getattr(online_store_config, "vector_enabled", False): + raise ValueError( + "Vector search is not enabled in the online store config. " + "Set vector_enabled=True in RedisOnlineStoreConfig." + ) + if embedding is None: + raise ValueError( + "An embedding vector must be provided for Redis vector search." + ) + if query_string is not None: + raise ValueError( + "Full-text search (query_string) is not supported by RedisOnlineStore. " + "Use embedding-based vector search instead." + ) + + self._vector_search_metric(config, table, distance_metric) + normalized_embedding = self._normalize_vector(embedding) + + client = self._get_client(online_store_config) + + if not self._check_vadd_supported(client): + raise NotImplementedError( + "The connected Redis server does not support Vector Sets (VADD/VSIM). " + "Redis 8.0 or later is required for vector search." + ) + + fv_name = _versioned_fv_name(table, config) + vs_key = self._vector_set_key(config.project, fv_name) + dim = len(normalized_embedding) + + # Build VSIM command: VSIM key VALUES WITHSCORES COUNT + cmd_args: List[Any] = [vs_key, "VALUES", str(dim)] + cmd_args.extend(str(f) for f in normalized_embedding) + cmd_args.extend(["WITHSCORES", "COUNT", str(top_k)]) + raw_results = client.execute_command("VSIM", *cmd_args) + + # raw_results is a flat list: [element_id, score, element_id, score, ...] + if not raw_results: + return [] + + # Parse the flat response into (element_id, score) pairs + pairs: List[Tuple[str, float]] = [] + for i in range(0, len(raw_results), 2): + eid = raw_results[i] + if isinstance(eid, bytes): + eid = eid.decode("utf-8") + score = float(raw_results[i + 1]) + pairs.append((eid, score)) + + if not pairs: + return [] + + # Recover entity key protos and build HMGET pipeline to fetch features + entity_key_protos: List[EntityKeyProto] = [] + redis_keys: List[bytes] = [] + for eid, _ in pairs: + entity_key_bin = bytes.fromhex(eid) + ek_proto = deserialize_entity_key( + entity_key_bin, + entity_key_serialization_version=config.entity_key_serialization_version, + ) + entity_key_protos.append(ek_proto) + redis_keys.append( + _redis_key( + config.project, + ek_proto, + entity_key_serialization_version=config.entity_key_serialization_version, + ) + ) + + # Build hash field keys for requested features + timestamp + hset_keys = [_mmh3(f"{fv_name}:{feat}") for feat in requested_features] + ts_hkey = f"_ts:{fv_name}" + hset_keys.append(ts_hkey) + + with client.pipeline(transaction=False) as pipe: + for rk in redis_keys: + pipe.hmget(rk, hset_keys) + all_values = pipe.execute() + + results: List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ] = [] + + for idx, ((_eid, score), ek_proto) in enumerate(zip(pairs, entity_key_protos)): + raw_vals = all_values[idx] + feature_dict: Dict[str, ValueProto] = {} + + # Parse timestamp + event_ts: Optional[datetime] = None + ts_bin = raw_vals[-1] # last element is the timestamp + if ts_bin: + ts_proto = Timestamp() + ts_proto.ParseFromString(ts_bin) + total_seconds = ts_proto.seconds + ts_proto.nanos / 1_000_000_000.0 + event_ts = datetime.fromtimestamp(total_seconds, tz=timezone.utc) + + # Parse feature values + for feat_idx, feat_name in enumerate(requested_features): + val = ValueProto() + val_bin = raw_vals[feat_idx] + if val_bin: + val.ParseFromString(val_bin) + feature_dict[feat_name] = val + + distance = max(0.0, 1.0 - float(score)) + + feature_dict["distance"] = ValueProto(float_val=distance) + + results.append((event_ts, ek_proto, feature_dict)) + + return results diff --git a/sdk/python/tests/unit/infra/online_store/test_redis.py b/sdk/python/tests/unit/infra/online_store/test_redis.py index 68eb28c4c11..cd9d734f37d 100644 --- a/sdk/python/tests/unit/infra/online_store/test_redis.py +++ b/sdk/python/tests/unit/infra/online_store/test_redis.py @@ -9,7 +9,7 @@ from feast.infra.online_stores.redis import RedisOnlineStore, RedisOnlineStoreConfig from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto -from feast.types import Int32 +from feast.types import Array, Float32, Int32 @pytest.fixture @@ -44,6 +44,39 @@ def feature_view(): return feature_view +@pytest.fixture +def vector_feature_view(): + file_source = FileSource(name="my_file_source", path="test.parquet") + entity = Entity(name="entity", join_keys=["entity"]) + feature_view = FeatureView( + name="vector_feature_view", + entities=[entity], + schema=[ + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=2, + vector_search_metric="COSINE", + ), + Field(name="title", dtype=Int32), + ], + source=file_source, + ) + return feature_view + + +@pytest.fixture +def vector_repo_config(): + return RepoConfig( + provider="local", + project="test", + entity_key_serialization_version=3, + registry="dummy_registry.db", + online_store=RedisOnlineStoreConfig(vector_enabled=True, key_ttl_seconds=60), + ) + + def test_generate_entity_redis_keys(redis_online_store: RedisOnlineStore, repo_config): entity_keys = [ EntityKeyProto(join_keys=["entity"], entity_values=[ValueProto(int32_val=1)]), @@ -482,3 +515,188 @@ def test_online_write_batch_async_exists_and_is_coroutine(): store = RedisOnlineStore() assert hasattr(store, "online_write_batch_async") assert inspect.iscoroutinefunction(store.online_write_batch_async) + + +def test_vadd_vectors_normalizes_vectors_and_sets_ttl( + redis_online_store: RedisOnlineStore, vector_repo_config, vector_feature_view +): + entity_key = EntityKeyProto( + join_keys=["entity"], entity_values=[ValueProto(int32_val=1)] + ) + embedding = ValueProto() + embedding.float_list_val.val.extend([3.0, 4.0]) + data = [ + ( + entity_key, + {"embedding": embedding, "title": ValueProto(int32_val=7)}, + datetime.now(tz=timezone.utc), + None, + ) + ] + + mock_client = MagicMock() + pipe = MagicMock() + pipe.__enter__ = MagicMock(return_value=pipe) + pipe.__exit__ = MagicMock(return_value=False) + pipe.execute.return_value = [] + mock_client.pipeline.return_value = pipe + + redis_online_store._vadd_vectors( + mock_client, vector_repo_config, vector_feature_view, data + ) + + args = pipe.execute_command.call_args.args + assert args[0] == "VADD" + assert args[1] == "vs:test:vector_feature_view" + assert args[2:4] == ("VALUES", "2") + assert args[4:6] == ("0.6", "0.8") + mock_client.expire.assert_called_once_with( + name="vs:test:vector_feature_view", time=60 + ) + + +def test_online_write_batch_skips_stale_rows_for_vector_indexing( + redis_online_store: RedisOnlineStore, vector_repo_config, vector_feature_view +): + mock_client = MagicMock() + pipe = MagicMock() + pipe.__enter__ = MagicMock(return_value=pipe) + pipe.__exit__ = MagicMock(return_value=False) + pipe.execute.side_effect = [ + [ + [Timestamp(seconds=20).SerializeToString()], + [None], + ], + [1], + ] + mock_client.pipeline.return_value = pipe + + entity_key_old = EntityKeyProto( + join_keys=["entity"], entity_values=[ValueProto(int32_val=1)] + ) + entity_key_new = EntityKeyProto( + join_keys=["entity"], entity_values=[ValueProto(int32_val=2)] + ) + embedding_old = ValueProto() + embedding_old.float_list_val.val.extend([1.0, 0.0]) + embedding_new = ValueProto() + embedding_new.float_list_val.val.extend([0.0, 1.0]) + data = [ + ( + entity_key_old, + {"embedding": embedding_old, "title": ValueProto(int32_val=1)}, + datetime.fromtimestamp(10, tz=timezone.utc), + None, + ), + ( + entity_key_new, + {"embedding": embedding_new, "title": ValueProto(int32_val=2)}, + datetime.fromtimestamp(30, tz=timezone.utc), + None, + ), + ] + + with ( + patch.object(redis_online_store, "_get_client", return_value=mock_client), + patch.object(redis_online_store, "_check_vadd_supported", return_value=True), + patch.object(redis_online_store, "_vadd_vectors") as mock_vadd, + ): + redis_online_store.online_write_batch( + vector_repo_config, vector_feature_view, data, progress=None + ) + + vector_batch = mock_vadd.call_args.args[3] + assert len(vector_batch) == 1 + assert vector_batch[0][0] == entity_key_new + + +@pytest.mark.parametrize( + ("distance_metric", "score", "expected_distance"), + [ + ("COSINE", 0.75, 0.25), + ], +) +def test_retrieve_online_documents_v2_converts_scores_to_distances( + redis_online_store: RedisOnlineStore, + vector_repo_config, + vector_feature_view, + distance_metric: str, + score: float, + expected_distance: float, +): + entity_key = EntityKeyProto( + join_keys=["entity"], entity_values=[ValueProto(int32_val=1)] + ) + entity_id = redis_online_store._vector_element_id(entity_key, 3) + embedding = ValueProto() + embedding.float_list_val.val.extend([3.0, 4.0]) + title = ValueProto(int32_val=7) + ts = Timestamp(seconds=123, nanos=456) + + mock_client = MagicMock() + pipe = MagicMock() + pipe.__enter__ = MagicMock(return_value=pipe) + pipe.__exit__ = MagicMock(return_value=False) + pipe.execute.return_value = [ + [ + embedding.SerializeToString(), + title.SerializeToString(), + ts.SerializeToString(), + ] + ] + mock_client.pipeline.return_value = pipe + mock_client.execute_command.return_value = [entity_id, str(score)] + + with ( + patch.object(redis_online_store, "_get_client", return_value=mock_client), + patch.object(redis_online_store, "_check_vadd_supported", return_value=True), + ): + results = redis_online_store.retrieve_online_documents_v2( + config=vector_repo_config, + table=vector_feature_view, + requested_features=["embedding", "title"], + embedding=[3.0, 4.0], + top_k=1, + distance_metric=distance_metric, + ) + + assert len(results) == 1 + event_ts, returned_entity_key, feature_dict = results[0] + assert event_ts is not None + assert returned_entity_key == entity_key + assert feature_dict is not None + assert feature_dict["title"].int32_val == 7 + assert feature_dict["distance"].float_val == pytest.approx(expected_distance) + + +def test_delete_entity_values_removes_vector_members( + redis_online_store: RedisOnlineStore, vector_repo_config +): + entity_key = EntityKeyProto( + join_keys=["entity"], entity_values=[ValueProto(int32_val=1)] + ) + redis_key = redis_online_store._generate_redis_keys_for_entities( + vector_repo_config, [entity_key] + )[0] + vector_id = redis_online_store._vector_element_id(entity_key, 3) + + mock_client = MagicMock() + mock_client.scan_iter.side_effect = [ + iter([redis_key]), + iter([b"vs:test:vector_feature_view"]), + ] + pipe = MagicMock() + pipe.__enter__ = MagicMock(return_value=pipe) + pipe.__exit__ = MagicMock(return_value=False) + pipe.execute.side_effect = [None, None] + mock_client.pipeline.return_value = pipe + + with ( + patch.object(redis_online_store, "_get_client", return_value=mock_client), + patch.object(redis_online_store, "_check_vadd_supported", return_value=True), + ): + redis_online_store.delete_entity_values(vector_repo_config, ["entity"]) + + pipe.execute_command.assert_any_call( + "VREM", b"vs:test:vector_feature_view", vector_id + )