diff --git a/Makefile b/Makefile index 813a27f4e3b..6fcf95dc7da 100644 --- a/Makefile +++ b/Makefile @@ -200,7 +200,7 @@ test-python-universal-postgres-offline: test-python-universal-postgres-online: PYTHONPATH='.' \ FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.postgres_repo_configuration \ - PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.postgres_offline_store.tests \ + PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.postgres \ python -m pytest -n 8 --integration \ -k "not test_universal_cli and \ not test_go_feature_server and \ diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 62ce9d6e382..15598e1d609 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1690,6 +1690,72 @@ def _get_online_features( ) return OnlineResponse(online_features_response) + @log_exceptions_and_usage + def retrieve_online_documents( + self, + feature: str, + query: Union[str, List[float]], + top_k: int, + ) -> OnlineResponse: + """ + Retrieves the top k closest document features. Note, embeddings are a subset of features. + + Args: + feature: The list of document features that should be retrieved from the online document store. These features can be + specified either as a list of string document feature references or as a feature service. String feature + references must have format "feature_view:feature", e.g, "document_fv:document_embeddings". + query: The query to retrieve the closest document features for. + top_k: The number of closest document features to retrieve. + """ + return self._retrieve_online_documents( + feature=feature, + query=query, + top_k=top_k, + ) + + def _retrieve_online_documents( + self, + feature: str, + query: Union[str, List[float]], + top_k: int, + ): + if isinstance(query, str): + raise ValueError( + "Using embedding functionality is not supported for document retrieval. 
Please embed the query before calling retrieve_online_documents." + ) + ( + requested_feature_views, + _, + ) = self._get_feature_views_to_use( + features=[feature], allow_cache=True, hide_dummy_entity=False + ) + requested_feature = ( + feature.split(":")[1] if isinstance(feature, str) else feature + ) + provider = self._get_provider() + document_features = self._retrieve_from_online_store( + provider, + requested_feature_views[0], + requested_feature, + query, + top_k, + ) + document_feature_vals = [feature[2] for feature in document_features] + document_feature_distance_vals = [feature[3] for feature in document_features] + online_features_response = GetOnlineFeaturesResponse(results=[]) + + # TODO Refactor to better way of populating result + # TODO populate entity in the response after returning entity in document_features is supported + self._populate_result_rows_from_columnar( + online_features_response=online_features_response, + data={requested_feature: document_feature_vals}, + ) + self._populate_result_rows_from_columnar( + online_features_response=online_features_response, + data={"distance": document_feature_distance_vals}, + ) + return OnlineResponse(online_features_response) + @staticmethod def _get_columnar_entity_values( rowise: Optional[List[Dict[str, Any]]], columnar: Optional[Dict[str, List[Any]]] @@ -1906,6 +1972,43 @@ def _read_from_online_store( read_row_protos.append((event_timestamps, statuses, values)) return read_row_protos + def _retrieve_from_online_store( + self, + provider: Provider, + table: FeatureView, + requested_feature: str, + query: List[float], + top_k: int, + ) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value]]: + """ + Search and return document features from the online document store. 
+ """ + documents = provider.retrieve_online_documents( + config=self.config, + table=table, + requested_feature=requested_feature, + query=query, + top_k=top_k, + ) + + read_row_protos = [] + row_ts_proto = Timestamp() + + for row_ts, feature_val, distance_val in documents: + # Reset timestamp to default or update if row_ts is not None + if row_ts is not None: + row_ts_proto.FromDatetime(row_ts) + + if feature_val is None or distance_val is None: + feature_val = Value() + distance_val = Value() + status = FieldStatus.NOT_FOUND + else: + status = FieldStatus.PRESENT + + read_row_protos.append((row_ts_proto, status, feature_val, distance_val)) + return read_row_protos + @staticmethod def _populate_response_from_feature_data( feature_data: Iterable[ diff --git a/sdk/python/feast/infra/key_encoding_utils.py b/sdk/python/feast/infra/key_encoding_utils.py index 62b6b72724e..e50e438c3de 100644 --- a/sdk/python/feast/infra/key_encoding_utils.py +++ b/sdk/python/feast/infra/key_encoding_utils.py @@ -72,3 +72,11 @@ def serialize_entity_key( output.append(val_bytes) return b"".join(output) + + +def get_val_str(val): + accept_value_types = ["float_list_val", "double_list_val", "int_list_val"] + for accept_type in accept_value_types: + if val.HasField(accept_type): + return str(getattr(val, accept_type).val) + return None diff --git a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py index f50cdc4c41f..a23d90e1868 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py @@ -7,6 +7,7 @@ from testcontainers.core.waiting_utils import wait_for_logs from feast.data_source import DataSource +from feast.feature_logging import LoggingDestination from feast.infra.offline_stores.contrib.postgres_offline_store.postgres 
import ( PostgreSQLOfflineStoreConfig, PostgreSQLSource, @@ -57,6 +58,9 @@ def postgres_container(): class PostgreSQLDataSourceCreator(DataSourceCreator, OnlineStoreCreator): + def create_logged_features_destination(self) -> LoggingDestination: + return None # type: ignore + def __init__( self, project_name: str, fixture_request: pytest.FixtureRequest, **kwargs ): diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 308528aaec2..2dcb6187837 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -2,7 +2,7 @@ import logging from collections import defaultdict from datetime import datetime -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union import psycopg2 import pytz @@ -12,7 +12,7 @@ from feast import Entity from feast.feature_view import FeatureView -from feast.infra.key_encoding_utils import serialize_entity_key +from feast.infra.key_encoding_utils import get_val_str, serialize_entity_key from feast.infra.online_stores.online_store import OnlineStore from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig @@ -25,6 +25,12 @@ class PostgreSQLOnlineStoreConfig(PostgreSQLConfig): type: Literal["postgres"] = "postgres" + # Whether to enable the pgvector extension for vector similarity search + pgvector_enabled: Optional[bool] = False + + # If pgvector is enabled, the length of the vector field + vector_len: Optional[int] = 512 + class PostgreSQLOnlineStore(OnlineStore): _conn: Optional[psycopg2._psycopg.connection] = None @@ -68,11 +74,19 @@ def online_write_batch( created_ts = _to_naive_utc(created_ts) for feature_name, val in values.items(): + val_str: Union[str, bytes] + 
if ( + "pgvector_enabled" in config.online_config + and config.online_config["pgvector_enabled"] + ): + val_str = get_val_str(val) + else: + val_str = val.SerializeToString() insert_values.append( ( entity_key_bin, feature_name, - val.SerializeToString(), + val_str, timestamp, created_ts, ) @@ -212,6 +226,12 @@ def update( for table in tables_to_keep: table_name = _table_id(project, table) + value_type = "BYTEA" + if ( + "pgvector_enabled" in config.online_config + and config.online_config["pgvector_enabled"] + ): + value_type = f'vector({config.online_config["vector_len"]})' cur.execute( sql.SQL( """ @@ -219,7 +239,7 @@ def update( ( entity_key BYTEA, feature_name TEXT, - value BYTEA, + value {}, event_ts TIMESTAMPTZ, created_ts TIMESTAMPTZ, PRIMARY KEY(entity_key, feature_name) @@ -228,6 +248,7 @@ def update( """ ).format( sql.Identifier(table_name), + sql.SQL(value_type), sql.Identifier(f"{table_name}_ek"), sql.Identifier(table_name), ) @@ -251,6 +272,74 @@ def teardown( logging.exception("Teardown failed") raise + def retrieve_online_documents( + self, + config: RepoConfig, + table: FeatureView, + requested_feature: str, + embedding: List[float], + top_k: int, + ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]: + """ + + Args: + config: Feast configuration object + table: FeatureView object as the table to search + requested_feature: The requested feature as the column to search + embedding: The query embedding to search for + top_k: The number of items to return + Returns: + List of tuples containing the event timestamp, the document feature and its distance to the query embedding + + """ + project = config.project + + # Convert the embedding to a string to be used in postgres vector search + query_embedding_str = f"[{','.join(str(el) for el in embedding)}]" + + result: List[ + Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]] + ] = [] + with self._get_conn(config) as conn, conn.cursor() as cur: + table_name = _table_id(project, table) + + # 
Search query template to find the top k items that are closest to the given embedding + # SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5; + cur.execute( + sql.SQL( + """ + SELECT + entity_key, + feature_name, + value, + value <-> %s as distance, + event_ts FROM {table_name} + WHERE feature_name = {feature_name} + ORDER BY distance + LIMIT {top_k}; + """ + ).format( + table_name=sql.Identifier(table_name), + feature_name=sql.Literal(requested_feature), + top_k=sql.Literal(top_k), + ), + (query_embedding_str,), + ) + rows = cur.fetchall() + + for entity_key, feature_name, value, distance, event_ts in rows: + # TODO Deserialize entity_key to return the entity in response + # entity_key_proto = EntityKeyProto() + # entity_key_proto_bin = bytes(entity_key) + + # TODO Convert to List[float] for value type proto + feature_value_proto = ValueProto(string_val=value) + + distance_value_proto = ValueProto(float_val=distance) + result.append((event_ts, feature_value_proto, distance_value_proto)) + + return result + def _table_id(project: str, table: FeatureView) -> str: return f"{project}_{table.name}" diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py b/sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py index 2a9f0d54cd4..6e4ca3f9501 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py @@ -1,10 +1,18 @@ -from feast.infra.offline_stores.contrib.postgres_offline_store.tests.data_source import ( - PostgreSQLDataSourceCreator, -) from tests.integration.feature_repos.integration_test_repo_config import ( IntegrationTestRepoConfig, ) +from tests.integration.feature_repos.universal.online_store.postgres import ( + PGVectorOnlineStoreCreator, + PostgresOnlineStoreCreator, +) FULL_REPO_CONFIGS = [ - IntegrationTestRepoConfig(online_store_creator=PostgreSQLDataSourceCreator), + 
IntegrationTestRepoConfig( + online_store="postgres", online_store_creator=PostgresOnlineStoreCreator + ), + IntegrationTestRepoConfig( + online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator + ), ] + +AVAILABLE_ONLINE_STORES = {"pgvector": PGVectorOnlineStoreCreator} diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index fcc3376dce2..fc1b3d4ad30 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -134,3 +134,30 @@ def teardown( entities: Entities whose corresponding infrastructure should be deleted. """ pass + + def retrieve_online_documents( + self, + config: RepoConfig, + table: FeatureView, + requested_feature: str, + embedding: List[float], + top_k: int, + ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]: + """ + Retrieves online feature values for the specified embeddings. + + Args: + config: The config for the current feature store. + table: The feature view whose feature values should be read. + requested_feature: The name of the feature whose embeddings should be used for retrieval. + embedding: The embeddings to use for retrieval. + top_k: The number of nearest neighbors to retrieve. + + Returns: + object: A list of top k closest documents to the specified embedding. Each item in the list is a tuple + where the first item is the event timestamp for the row, the second item is the document feature + value, and the third item is its distance from the query embedding. 
+ """ + raise NotImplementedError( + f"Online store {self.__class__.__name__} does not support online retrieval" + ) diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index aca18f4856b..ec4df66d43a 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -190,6 +190,23 @@ def online_read( ) return result + @log_exceptions_and_usage(sampler=RatioSampler(ratio=0.001)) + def retrieve_online_documents( + self, + config: RepoConfig, + table: FeatureView, + requested_feature: str, + query: List[float], + top_k: int, + ) -> List: + set_usage_attribute("provider", self.__class__.__name__) + result = [] + if self.online_store: + result = self.online_store.retrieve_online_documents( + config, table, requested_feature, query, top_k + ) + return result + def ingest_df( self, feature_view: FeatureView, diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 2a9670cacef..e71e87488d7 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -295,6 +295,30 @@ def get_feature_server_endpoint(self) -> Optional[str]: """Returns endpoint for the feature server, if it exists.""" return None + @abstractmethod + def retrieve_online_documents( + self, + config: RepoConfig, + table: FeatureView, + requested_feature: str, + query: List[float], + top_k: int, + ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]: + """ + Searches for the top-k nearest neighbors of the given document in the online document store. + + Args: + config: The config for the current feature store. + table: The feature view whose embeddings should be searched. + requested_feature: the requested document feature name. + query: The query embedding to search for. + top_k: The number of nearest neighbors to return. + + Returns: + A list of tuples, one per retrieved document, each holding the event timestamp, the document feature value, and the distance to the query. 
+ """ + pass + def get_provider(config: RepoConfig) -> Provider: if "." not in config.provider: diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py index 1c9a958ce36..6abe30822f2 100644 --- a/sdk/python/tests/conftest.py +++ b/sdk/python/tests/conftest.py @@ -23,9 +23,13 @@ import pytest from _pytest.nodes import Item +from feast.data_source import DataSource from feast.feature_store import FeatureStore # noqa: E402 from feast.wait import wait_retry_backoff # noqa: E402 -from tests.data.data_creator import create_basic_driver_dataset # noqa: E402 +from tests.data.data_creator import ( # noqa: E402 + create_basic_driver_dataset, + create_document_dataset, +) from tests.integration.feature_repos.integration_test_repo_config import ( # noqa: E402 IntegrationTestRepoConfig, ) @@ -405,3 +409,13 @@ def fake_ingest_data(): "created": [pd.Timestamp(datetime.utcnow()).round("ms")], } return pd.DataFrame(data) + + +@pytest.fixture +def fake_document_data(environment: Environment) -> Tuple[pd.DataFrame, DataSource]: + df = create_document_dataset() + data_source = environment.data_source_creator.create_data_source( + df, + environment.feature_store.project, + ) + return df, data_source diff --git a/sdk/python/tests/data/data_creator.py b/sdk/python/tests/data/data_creator.py index 1fc66aee845..1be96f753a7 100644 --- a/sdk/python/tests/data/data_creator.py +++ b/sdk/python/tests/data/data_creator.py @@ -78,3 +78,22 @@ def get_feature_values_for_dtype( return [[n, n] if n is not None else None for n in non_list_val] else: return non_list_val + + +def create_document_dataset() -> pd.DataFrame: + data = { + "item_id": [1, 2, 3], + "embedding_float": [[4.0, 5.0], [1.0, 2.0], [3.0, 4.0]], + "embedding_double": [[4.0, 5.0], [1.0, 2.0], [3.0, 4.0]], + "ts": [ + pd.Timestamp(datetime.utcnow()).round("ms"), + pd.Timestamp(datetime.utcnow()).round("ms"), + pd.Timestamp(datetime.utcnow()).round("ms"), + ], + "created_ts": [ + 
pd.Timestamp(datetime.utcnow()).round("ms"), + pd.Timestamp(datetime.utcnow()).round("ms"), + pd.Timestamp(datetime.utcnow()).round("ms"), + ], + } + return pd.DataFrame(data) diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index ba256a3813c..7ba4adb114b 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -103,3 +103,13 @@ def retrieve_feature_service_logs( registry: BaseRegistry, ) -> RetrievalJob: return RetrievalJob() + + def retrieve_online_documents( + self, + config: RepoConfig, + table: FeatureView, + requested_feature: str, + query: List[float], + top_k: int, + ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]: + return [] diff --git a/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py b/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py new file mode 100644 index 00000000000..58e7af9c468 --- /dev/null +++ b/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py @@ -0,0 +1,68 @@ +from typing import Dict + +from testcontainers.core.container import DockerContainer +from testcontainers.core.waiting_utils import wait_for_logs +from testcontainers.postgres import PostgresContainer + +from tests.integration.feature_repos.universal.online_store_creator import ( + OnlineStoreCreator, +) + + +class PostgresOnlineStoreCreator(OnlineStoreCreator): + def __init__(self, project_name: str, **kwargs): + super().__init__(project_name) + self.container = PostgresContainer( + "postgres:16", + username="root", + password="test", + dbname="test", + ).with_exposed_ports(5432) + + def create_online_store(self) -> Dict[str, str]: + self.container.start() + return { + "host": "localhost", + "type": "postgres", + "user": "root", + "password": "test", + "database": "test", + "port": self.container.get_exposed_port(5432), + } + + def teardown(self): + self.container.stop() + + +class 
PGVectorOnlineStoreCreator(OnlineStoreCreator): + def __init__(self, project_name: str, **kwargs): + super().__init__(project_name) + self.container = ( + DockerContainer("pgvector/pgvector:pg16") + .with_env("POSTGRES_USER", "root") + .with_env("POSTGRES_PASSWORD", "test") + .with_env("POSTGRES_DB", "test") + .with_exposed_ports(5432) + ) + + def create_online_store(self) -> Dict[str, str]: + self.container.start() + log_string_to_wait_for = "database system is ready to accept connections" + wait_for_logs( + container=self.container, predicate=log_string_to_wait_for, timeout=10 + ) + command = "psql -h localhost -p 5432 -U root -d test -c 'CREATE EXTENSION IF NOT EXISTS vector;'" + self.container.exec(command) + return { + "host": "localhost", + "type": "postgres", + "user": "root", + "password": "test", + "database": "test", + "pgvector_enabled": True, + "vector_len": 2, + "port": self.container.get_exposed_port(5432), + } + + def teardown(self): + self.container.stop() diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 82189713151..3ae7be9e1e4 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -25,9 +25,10 @@ Environment, construct_universal_feature_views, ) -from tests.integration.feature_repos.universal.entities import driver +from tests.integration.feature_repos.universal.entities import driver, item from tests.integration.feature_repos.universal.feature_views import ( create_driver_hourly_stats_feature_view, + create_item_embeddings_feature_view, driver_feature_view, ) from tests.utils.data_source_test_creator import prep_file_source @@ -785,3 +786,18 @@ def assert_feature_service_entity_mapping_correctness( entity_rows=entity_rows, full_feature_names=full_feature_names, ) + + +@pytest.mark.integration 
+@pytest.mark.universal_online_stores(only=["pgvector"]) +def test_retrieve_online_documents(environment, fake_document_data): + fs = environment.feature_store + df, data_source = fake_document_data + item_embeddings_feature_view = create_item_embeddings_feature_view(data_source) + fs.apply([item_embeddings_feature_view, item()]) + fs.write_to_online_store("item_embeddings", df) + + documents = fs.retrieve_online_documents( + feature="item_embeddings:embedding_float", query=[1.0, 2.0], top_k=2 + ).to_dict() + assert len(documents["embedding_float"]) == 2 diff --git a/setup.py b/setup.py index f94fb25bb55..e686ad70620 100644 --- a/setup.py +++ b/setup.py @@ -177,7 +177,7 @@ "pytest-mock==1.10.4", "pytest-env", "Sphinx>4.0.0,<7", - "testcontainers>=3.5,<4", + "testcontainers==4.3.3", "firebase-admin>=5.2.0,<6", "pre-commit<3.3.2", "assertpy==1.1",