From f28e8f88f2bf901530d8756dbd5e4ad1a3e60e04 Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sat, 6 Jun 2026 15:23:49 +0700 Subject: [PATCH 01/15] feat: Add apache flink compute engine Signed-off-by: Le Xuan An --- docs/SUMMARY.md | 1 + .../components/compute-engine.md | 4 +- docs/reference/compute-engine/README.md | 8 + docs/reference/compute-engine/flink.md | 121 +++ pyproject.toml | 23 +- sdk/python/feast/batch_feature_view.py | 3 +- .../feast/infra/compute_engines/dag/model.py | 1 + .../infra/compute_engines/feature_builder.py | 5 +- .../infra/compute_engines/flink/__init__.py | 11 + .../infra/compute_engines/flink/compute.py | 154 +++ .../compute_engines/flink/feature_builder.py | 217 ++++ .../feast/infra/compute_engines/flink/job.py | 116 +++ .../infra/compute_engines/flink/nodes.py | 768 ++++++++++++++ .../infra/compute_engines/flink/utils.py | 53 + sdk/python/feast/repo_config.py | 1 + sdk/python/feast/stream_feature_view.py | 3 +- sdk/python/feast/transformation/factory.py | 1 + .../transformation/flink_transformation.py | 78 ++ sdk/python/feast/transformation/mode.py | 1 + .../infra/compute_engines/flink/__init__.py | 1 + .../flink/test_flink_compute_engine.py | 969 ++++++++++++++++++ 21 files changed, 2533 insertions(+), 6 deletions(-) create mode 100644 docs/reference/compute-engine/flink.md create mode 100644 sdk/python/feast/infra/compute_engines/flink/__init__.py create mode 100644 sdk/python/feast/infra/compute_engines/flink/compute.py create mode 100644 sdk/python/feast/infra/compute_engines/flink/feature_builder.py create mode 100644 sdk/python/feast/infra/compute_engines/flink/job.py create mode 100644 sdk/python/feast/infra/compute_engines/flink/nodes.py create mode 100644 sdk/python/feast/infra/compute_engines/flink/utils.py create mode 100644 sdk/python/feast/transformation/flink_transformation.py create mode 100644 sdk/python/tests/unit/infra/compute_engines/flink/__init__.py create mode 100644 sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 590005d8fbd..ce750d939c3 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -174,6 +174,7 @@ * [Snowflake](reference/compute-engine/snowflake.md) * [AWS Lambda (alpha)](reference/compute-engine/lambda.md) * [Spark (contrib)](reference/compute-engine/spark.md) + * [Apache Flink](reference/compute-engine/flink.md) * [Ray (contrib)](reference/compute-engine/ray.md) * [Feature repository](reference/feature-repository/README.md) * [feature\_store.yaml](reference/feature-repository/feature-store-yaml.md) diff --git a/docs/getting-started/components/compute-engine.md b/docs/getting-started/components/compute-engine.md index 60da1575932..d115ec5debb 100644 --- a/docs/getting-started/components/compute-engine.md +++ b/docs/getting-started/components/compute-engine.md @@ -24,7 +24,7 @@ engines. | SparkComputeEngine | Runs on Apache Spark, designed for large-scale distributed feature generation. | ✅ | | | SnowflakeComputeEngine | Runs on Snowflake, designed for scalable feature generation using Snowflake SQL. | ✅ | | | LambdaComputeEngine | Runs on AWS Lambda, designed for serverless feature generation. | ✅ | | -| FlinkComputeEngine | Runs on Apache Flink, designed for stream processing and real-time feature generation. | ❌ | | +| FlinkComputeEngine | Runs on Apache Flink, designed for distributed feature generation through PyFlink Table API. | ✅ | | | RayComputeEngine | Runs on Ray, designed for distributed feature generation and machine learning workloads. | ✅ | | ``` @@ -156,4 +156,4 @@ DAG nodes are defined as follows: +----------------+ +----------------+ | OnlineStoreWrite| OfflineStoreWrite| +----------------+ +----------------+ -``` \ No newline at end of file +``` diff --git a/docs/reference/compute-engine/README.md b/docs/reference/compute-engine/README.md index dad2ede75a6..920d5761d28 100644 --- a/docs/reference/compute-engine/README.md +++ b/docs/reference/compute-engine/README.md @@ -57,6 +57,14 @@ An example of built output from FeatureBuilder: - Supports point-in-time joins and large-scale materialization - Integrates with `SparkOfflineStore` and `SparkMaterializationJob` +### 🌊 FlinkComputeEngine + +{% page-ref page="flink.md" %} + +- Distributed DAG execution through Apache Flink's PyFlink Table API +- Supports materialization and historical retrieval with Feast offline stores +- Integrates with `FlinkMaterializationJob` and `FlinkDAGRetrievalJob` + ### ⚡ RayComputeEngine (contrib) - Distributed DAG execution via Ray diff --git a/docs/reference/compute-engine/flink.md b/docs/reference/compute-engine/flink.md new file mode 100644 index 00000000000..fee130ba550 --- /dev/null +++ b/docs/reference/compute-engine/flink.md @@ -0,0 +1,121 @@ +# Apache Flink + +## Description + +The Apache Flink compute engine provides a distributed execution engine for +feature pipelines through the PyFlink Table API. It implements Feast's unified +`ComputeEngine` interface and can be used for batch materialization operations +(`materialize` and `materialize-incremental`) and historical retrieval +(`get_historical_features`). + +The engine reads data through the configured Feast offline store and executes +the Feast DAG as PyFlink tables. Offline stores that expose a native +`to_flink_table(table_env)` retrieval job hand Flink tables directly to the +engine. The engine then uses Flink Table/SQL operations for join, filter, +aggregate, dedupe, and projection steps, and writes materialization results to +the configured online and/or offline store. + +## Configuration + +Install Feast with the Flink extra before using the engine: + +```bash +python -m pip install 'feast[flink]' +``` + +The `flink` extra installs PyFlink directly. Feast's Arrow dependency range is +kept compatible with PyFlink's supported `pyarrow` range so `feast[flink]` +resolves without a separate PyFlink install step. + +Configure the engine in `feature_store.yaml`: + +```yaml +project: my_project +registry: data/registry.db +provider: local +offline_store: + type: file +online_store: + type: sqlite + path: data/online_store.db +batch_engine: + type: flink.engine + execution_mode: batch + parallelism: 4 + table_config: + pipeline.name: "Feast Flink Compute Engine" + pandas_split_num: 4 +``` + +## Configuration Options + +| Option | Type | Default | Description | +| --- | --- | --- | --- | +| `type` | string | `flink.engine` | Must be `flink.engine`. | +| `execution_mode` | string | `batch` | PyFlink execution mode: `batch` or `streaming`. | +| `parallelism` | integer | `null` | Default Flink parallelism for jobs created by the engine. | +| `table_config` | map | `null` | Additional PyFlink table configuration entries. | +| `pandas_split_num` | integer | `1` | Number of PyFlink Arrow source splits when converting pandas entity DataFrames into Flink tables. | + +## Flink Transformations + +Use `mode="flink"` when a `BatchFeatureView` transformation should receive and +return PyFlink table objects: + +```python +from feast import BatchFeatureView, Field +from feast.types import Float32 + + +def double_rates(table): + # In production this can use PyFlink Table API operations and return a table. + return table + + +driver_stats = BatchFeatureView( + name="driver_stats", + entities=[driver], + mode="flink", + udf=double_rates, + schema=[Field(name="conv_rate", dtype=Float32)], + source=driver_stats_source, + online=True, +) +``` + +Flink transformations must return PyFlink table objects. pandas-returning UDFs +are not accepted by the Flink compute engine. + +## DAG Support + +The Flink engine implements Feast's compute DAG with Flink-specific nodes: + +- Source reads from Feast offline stores, preferring native Flink tables when a + retrieval job supports `to_flink_table(table_env)`. +- Transform nodes pass PyFlink tables to `mode="flink"` UDFs and preserve native + Flink table outputs. +- Join nodes use Flink SQL temporary views for feature joins and entity joins. +- Filter nodes apply point-in-time, TTL, and custom filter expressions in Flink + SQL. +- Aggregate nodes support non-windowed Feast aggregations using Flink SQL + aggregate functions. +- Dedupe nodes use `ROW_NUMBER()` over entity keys or internal entity-row ids so + historical retrieval keeps one latest feature row per entity row. +- Validation nodes check required output columns. JSON value validation must be + handled upstream in Flink SQL. +- Output nodes write only for materialization tasks; historical retrieval is + read-only. +- Historical retrieval accepts pandas entity DataFrames and SQL-string entity + DataFrames. SQL strings are interpreted as Flink SQL queries against the + configured TableEnvironment/catalog and must select an `event_timestamp` + column. + +## Current Limitations + +- Windowed aggregations are not yet implemented in the Flink compute engine. Use + non-windowed Feast aggregations or pre-window upstream in Flink. +- Offline store retrieval jobs must implement `to_flink_table(table_env)`. + Arrow/pandas-only retrieval jobs are rejected instead of converted. +- JSON value validation is not implemented inside the Flink compute engine + because the engine does not collect intermediate data out of Flink for + validation. diff --git a/pyproject.toml b/pyproject.toml index 21c48dd09e1..ae82cbf1173 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "mmh3", "numpy>=2.0.0,<3", "pandas>=1.4.3,<3", - "pyarrow>=21.0.0", + "pyarrow>=16.1.0", "pydantic>=2.10.6", "pygments>=2.12.0,<3", "PyYAML>=5.4.0,<7", @@ -63,6 +63,7 @@ docling = ["docling==2.27.0"] duckdb = ["ibis-framework[duckdb]>=10.0.0"] elasticsearch = ["elasticsearch>=8.13.0"] faiss = ["faiss-cpu>=1.7.0,<=1.10.0"] +flink = ["apache-flink>=2.2.1,<3"] gcp = [ "google-api-core>=1.23.0,<3", "googleapis-common-protos>=1.52.0,<2", @@ -278,6 +279,26 @@ dev = [ "pytest-xdist>=3.8.0", ] +[tool.uv] +conflicts = [ + [ + { extra = "flink" }, + { extra = "ge" }, + ], + [ + { extra = "flink" }, + { extra = "ci" }, + ], + [ + { extra = "flink" }, + { extra = "dev" }, + ], + [ + { extra = "flink" }, + { extra = "docs" }, + ], +] + # Pixi configuration [tool.pixi.workspace] channels = ["conda-forge"] diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py index 95385da1d91..3bdd7d83606 100644 --- a/sdk/python/feast/batch_feature_view.py +++ b/sdk/python/feast/batch_feature_view.py @@ -169,7 +169,8 @@ def get_feature_transformation(self) -> Optional[Transformation]: TransformationMode.PYTHON, TransformationMode.SQL, TransformationMode.RAY, - ) or self.mode in ("pandas", "python", "sql", "ray"): + TransformationMode.FLINK, + ) or self.mode in ("pandas", "python", "sql", "ray", "flink"): return Transformation( mode=self.mode, udf=self.udf, udf_string=self.udf_string or "" ) diff --git a/sdk/python/feast/infra/compute_engines/dag/model.py b/sdk/python/feast/infra/compute_engines/dag/model.py index 5990eea6141..263c5029f4f 100644 --- a/sdk/python/feast/infra/compute_engines/dag/model.py +++ b/sdk/python/feast/infra/compute_engines/dag/model.py @@ -6,3 +6,4 @@ class DAGFormat(str, Enum): PANDAS = "pandas" ARROW = "arrow" RAY = "ray" + FLINK = "flink" diff --git a/sdk/python/feast/infra/compute_engines/feature_builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py index 43f17ee2986..2a102bf9f2f 100644 --- a/sdk/python/feast/infra/compute_engines/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -158,10 +158,13 @@ def get_column_info( # we need to read ALL source columns, not just the output feature columns. # This is specifically for transformations that create new columns or need raw data. mode = getattr(getattr(view, "feature_transformation", None), "mode", None) - if mode in ("ray", "pandas", "python") or getattr(mode, "value", None) in ( + if mode in ("ray", "pandas", "python", "flink") or getattr( + mode, "value", None + ) in ( "ray", "pandas", "python", + "flink", ): # Signal to read all columns by passing empty list for feature_cols. # "python" (BatchFeatureView) transformations need all raw source columns — the diff --git a/sdk/python/feast/infra/compute_engines/flink/__init__.py b/sdk/python/feast/infra/compute_engines/flink/__init__.py new file mode 100644 index 00000000000..678bae5afc6 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/flink/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from feast.infra.compute_engines.flink.compute import ( + FlinkComputeEngine, + FlinkComputeEngineConfig, +) + +__all__ = [ + "FlinkComputeEngine", + "FlinkComputeEngineConfig", +] diff --git a/sdk/python/feast/infra/compute_engines/flink/compute.py b/sdk/python/feast/infra/compute_engines/flink/compute.py new file mode 100644 index 00000000000..4018427374e --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/flink/compute.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, Literal, Optional, Sequence, Union + +from feast import ( + BatchFeatureView, + Entity, + FeatureView, + OnDemandFeatureView, + StreamFeatureView, +) +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.base import ComputeEngine +from feast.infra.compute_engines.flink.feature_builder import FlinkFeatureBuilder +from feast.infra.compute_engines.flink.job import ( + FlinkDAGRetrievalJob, + FlinkMaterializationJob, +) +from feast.infra.compute_engines.flink.utils import create_flink_table_environment +from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob +from feast.infra.online_stores.online_store import OnlineStore +from feast.infra.registry.base_registry import BaseRegistry +from feast.repo_config import FeastConfigBaseModel, RepoConfig + +logger = logging.getLogger(__name__) + + +class FlinkComputeEngineConfig(FeastConfigBaseModel): + """Configuration for the Apache Flink compute engine.""" + + type: Literal["flink.engine"] = "flink.engine" + """Flink compute engine type selector.""" + + execution_mode: Literal["batch", "streaming"] = "batch" + """PyFlink TableEnvironment execution mode.""" + + parallelism: Optional[int] = None + """Default Flink parallelism for jobs created by this engine.""" + + table_config: Optional[Dict[str, str]] = None + """Additional PyFlink table configuration entries.""" + + pandas_split_num: int = 1 + """Number of PyFlink Arrow source splits for pandas entity DataFrames.""" + + +class FlinkComputeEngine(ComputeEngine): + def __init__( + self, + *, + repo_config: RepoConfig, + offline_store: OfflineStore, + online_store: OnlineStore, + table_environment: Optional[Any] = None, + **kwargs, + ) -> None: + super().__init__( + repo_config=repo_config, + offline_store=offline_store, + online_store=online_store, + **kwargs, + ) + self.config = repo_config.batch_engine + assert isinstance(self.config, FlinkComputeEngineConfig) + self.table_env = table_environment or create_flink_table_environment( + self.config + ) + + def update( + self, + project: str, + views_to_delete: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView] + ], + views_to_keep: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] + ], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + ) -> None: + """Flink compute engine does not provision Feast-managed infrastructure.""" + pass + + def teardown_infra( + self, + project: str, + fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]], + entities: Sequence[Entity], + ) -> None: + """Flink compute engine does not tear down Feast-managed infrastructure.""" + pass + + def _materialize_one( + self, registry: BaseRegistry, task: MaterializationTask, **kwargs + ) -> MaterializationJob: + job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}" + context = self.get_execution_context(registry, task) + + try: + builder = FlinkFeatureBuilder( + registry=registry, + table_env=self.table_env, + task=task, + split_num=self.config.pandas_split_num, + ) + plan = builder.build() + plan.execute(context) + return FlinkMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.SUCCEEDED, + ) + except Exception as exc: + logger.error("Flink materialization failed for %s: %s", job_id, exc) + return FlinkMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.ERROR, + error=exc, + ) + + def get_historical_features( + self, registry: BaseRegistry, task: HistoricalRetrievalTask + ) -> RetrievalJob: + context = self.get_execution_context(registry, task) + try: + builder = FlinkFeatureBuilder( + registry=registry, + table_env=self.table_env, + task=task, + split_num=self.config.pandas_split_num, + ) + plan = builder.build() + return FlinkDAGRetrievalJob( + plan=plan, + context=context, + full_feature_names=task.full_feature_name, + ) + except Exception as exc: + logger.error( + "Flink historical retrieval setup failed for %s: %s", + task.feature_view.name, + exc, + ) + return FlinkDAGRetrievalJob( + plan=None, + context=context, + full_feature_names=task.full_feature_name, + error=exc, + ) diff --git a/sdk/python/feast/infra/compute_engines/flink/feature_builder.py b/sdk/python/feast/infra/compute_engines/flink/feature_builder.py new file mode 100644 index 00000000000..4f4abe7bea1 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/flink/feature_builder.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import logging +from typing import Any, Union + +import pandas as pd + +from feast.infra.common.materialization_job import MaterializationTask +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.feature_builder import FeatureBuilder +from feast.infra.compute_engines.flink.nodes import ( + FlinkAggregationNode, + FlinkDedupNode, + FlinkFilterNode, + FlinkJoinNode, + FlinkOutputNode, + FlinkSourceReadNode, + FlinkTransformationNode, + FlinkValidationNode, +) +from feast.infra.registry.base_registry import BaseRegistry +from feast.types import PrimitiveFeastType, from_feast_to_pyarrow_type + +logger = logging.getLogger(__name__) + + +class FlinkFeatureBuilder(FeatureBuilder): + def __init__( + self, + registry: BaseRegistry, + table_env: Any, + task: Union[MaterializationTask, HistoricalRetrievalTask], + split_num: int, + ) -> None: + super().__init__(registry, task.feature_view, task) + self.table_env = table_env + self.split_num = split_num + + def _should_join_entity_df(self) -> bool: + return isinstance(self.task, HistoricalRetrievalTask) and ( + ( + isinstance(self.task.entity_df, pd.DataFrame) + and not self.task.entity_df.empty + ) + or ( + isinstance(self.task.entity_df, str) + and bool(self.task.entity_df.strip()) + ) + ) + + def _build(self, view: Any, input_nodes: list[DAGNode] | None) -> DAGNode: + if view.data_source: + last_node = self.build_source_node(view) + + if self._should_transform(view): + last_node = self.build_transformation_node(view, [last_node]) + + if self._should_join_entity_df(): + last_node = self.build_join_node(view, [last_node]) + + elif input_nodes: + if self._should_transform(view): + last_node = self.build_transformation_node(view, input_nodes) + else: + last_node = self.build_join_node(view, input_nodes) + else: + raise ValueError(f"FeatureView {view.name} has no valid source or inputs") + + last_node = self.build_filter_node(view, last_node) + + if self._should_aggregate(view): + last_node = self.build_aggregation_node(view, last_node) + elif self._should_dedupe(view): + last_node = self.build_dedup_node(view, last_node) + + if self._should_validate(view): + last_node = self.build_validation_node(view, last_node) + + return last_node + + def build_source_node(self, view: Any) -> FlinkSourceReadNode: + source = view.batch_source + column_info = self.get_column_info(view) + node = FlinkSourceReadNode( + f"{view.name}:source", + source, + column_info, + self.table_env, + self.split_num, + self.task.start_time, + self.task.end_time, + ) + self.nodes.append(node) + return node + + def build_aggregation_node( + self, view: Any, input_node: DAGNode + ) -> FlinkAggregationNode: + column_info = self.get_column_info(view) + node = FlinkAggregationNode( + f"{view.name}:agg", + column_info.join_keys_columns, + view.aggregations, + self.table_env, + self.split_num, + inputs=[input_node], + ) + self.nodes.append(node) + return node + + def build_join_node(self, view: Any, input_nodes: list[DAGNode]) -> FlinkJoinNode: + column_info = self.get_column_info(view) + node = FlinkJoinNode( + f"{view.name}:join", + column_info, + self.table_env, + self.split_num, + inputs=input_nodes, + ) + self.nodes.append(node) + return node + + def build_filter_node(self, view: Any, input_node: DAGNode) -> FlinkFilterNode: + filter_expr = getattr(view, "filter", None) + ttl = getattr(view, "ttl", None) + column_info = self.get_column_info(view) + node = FlinkFilterNode( + f"{view.name}:filter", + column_info, + self.table_env, + self.split_num, + filter_expr, + ttl, + inputs=[input_node], + ) + self.nodes.append(node) + return node + + def build_dedup_node(self, view: Any, input_node: DAGNode) -> FlinkDedupNode: + column_info = self.get_column_info(view) + node = FlinkDedupNode( + f"{view.name}:dedup", + column_info, + self.table_env, + self.split_num, + inputs=[input_node], + ) + self.nodes.append(node) + return node + + def build_transformation_node( + self, view: Any, input_nodes: list[DAGNode] + ) -> FlinkTransformationNode: + transform_config = view.feature_transformation + transformation_fn = ( + transform_config.udf + if hasattr(transform_config, "udf") + else transform_config + ) + node = FlinkTransformationNode( + f"{view.name}:transform", + transformation_fn, + self.table_env, + self.split_num, + inputs=input_nodes, + ) + self.nodes.append(node) + return node + + def build_output_nodes(self, view: Any, input_node: DAGNode) -> FlinkOutputNode: + node = FlinkOutputNode( + f"{view.name}:output", + self.dag_root.view, + self.table_env, + self.split_num, + isinstance(self.task, MaterializationTask), + [input_node], + ) + self.nodes.append(node) + return node + + def build_validation_node( + self, view: Any, input_node: DAGNode + ) -> FlinkValidationNode: + expected_columns = {} + json_columns: set[str] = set() + if hasattr(view, "features"): + for feature in view.features: + try: + expected_columns[feature.name] = from_feast_to_pyarrow_type( + feature.dtype + ) + except (ValueError, KeyError): + logger.debug( + "Could not resolve PyArrow type for feature '%s' " + "(dtype=%s), skipping type check for this column.", + feature.name, + feature.dtype, + ) + expected_columns[feature.name] = None + if ( + isinstance(feature.dtype, PrimitiveFeastType) + and feature.dtype.name == "JSON" + ): + json_columns.add(feature.name) + + node = FlinkValidationNode( + f"{view.name}:validate", + expected_columns, + json_columns, + self.table_env, + self.split_num, + inputs=[input_node], + ) + self.nodes.append(node) + return node diff --git a/sdk/python/feast/infra/compute_engines/flink/job.py b/sdk/python/feast/infra/compute_engines/flink/job.py new file mode 100644 index 00000000000..44d9afd1824 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/flink/job.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional + +import pandas as pd +import pyarrow as pa + +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationJobStatus, +) +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.flink.utils import flink_table_to_arrow +from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.saved_dataset import SavedDatasetStorage + + +class FlinkDAGRetrievalJob(RetrievalJob): + def __init__( + self, + plan: Optional[ExecutionPlan], + context: ExecutionContext, + full_feature_names: bool, + on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, + metadata: Optional[RetrievalMetadata] = None, + error: Optional[BaseException] = None, + ) -> None: + self._plan = plan + self._context = context + self._full_feature_names = full_feature_names + self._on_demand_feature_views = on_demand_feature_views or [] + self._metadata = metadata + self._error = error + self._arrow_table: Optional[pa.Table] = None + + def error(self) -> Optional[BaseException]: + return self._error + + def _ensure_executed(self) -> None: + if self._arrow_table is None: + if self._error is not None: + raise self._error + if self._plan is None: + raise RuntimeError("Execution plan is not set") + result = self._plan.execute(self._context) + self._arrow_table = flink_table_to_arrow(result.data) + + def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: + self._ensure_executed() + assert self._arrow_table is not None + return self._arrow_table.to_pandas() + + def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: + self._ensure_executed() + assert self._arrow_table is not None + return self._arrow_table + + @property + def full_feature_names(self) -> bool: + return self._full_feature_names + + @property + def on_demand_feature_views(self) -> List[OnDemandFeatureView]: + return self._on_demand_feature_views + + @property + def metadata(self) -> Optional[RetrievalMetadata]: + return self._metadata + + def persist( + self, + storage: SavedDatasetStorage, + allow_overwrite: bool = False, + timeout: Optional[int] = None, + ) -> None: + raise NotImplementedError("Persisting Flink retrieval jobs is not supported.") + + def to_remote_storage(self) -> List[str]: + raise NotImplementedError( + "Remote storage is not supported in FlinkDAGRetrievalJob." + ) + + def to_sql(self) -> str: + raise NotImplementedError("SQL generation is not supported for Flink DAGs.") + + +@dataclass +class FlinkMaterializationJob(MaterializationJob): + def __init__( + self, + job_id: str, + status: MaterializationJobStatus, + error: Optional[BaseException] = None, + ) -> None: + super().__init__() + self._job_id = job_id + self._status = status + self._error = error + + def status(self) -> MaterializationJobStatus: + return self._status + + def error(self) -> Optional[BaseException]: + return self._error + + def should_be_retried(self) -> bool: + return False + + def job_id(self) -> str: + return self._job_id + + def url(self) -> Optional[str]: + return None diff --git a/sdk/python/feast/infra/compute_engines/flink/nodes.py b/sdk/python/feast/infra/compute_engines/flink/nodes.py new file mode 100644 index 00000000000..818efd0b1ef --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/flink/nodes.py @@ -0,0 +1,768 @@ +from __future__ import annotations + +import logging +import uuid +from datetime import datetime, timedelta +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union + +import pandas as pd +import pyarrow as pa + +from feast import BatchFeatureView, StreamFeatureView +from feast.aggregation import Aggregation, aggregation_specs_to_agg_ops +from feast.data_source import DataSource +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.flink.utils import ( + flink_table_to_pandas, + pandas_to_flink_table, +) +from feast.infra.compute_engines.utils import create_offline_store_retrieval_job +from feast.infra.offline_stores.offline_utils import ( + infer_event_timestamp_from_entity_df, +) +from feast.utils import _convert_arrow_to_proto + +logger = logging.getLogger(__name__) + +ENTITY_TS_ALIAS = "__entity_event_timestamp" +ENTITY_ROW_ID = "__feast_entity_row_id" +DEDUP_ROW_NUMBER = "__feast_row_number" + + +def _quote_identifier(identifier: str) -> str: + return f"`{identifier.replace('`', '``')}`" + + +def _qualified_column(alias: str, column: str) -> str: + return f"{alias}.{_quote_identifier(column)}" + + +def _select_column(alias: str, column: str, output_name: Optional[str] = None) -> str: + expr = _qualified_column(alias, column) + if output_name and output_name != column: + return f"{expr} AS {_quote_identifier(output_name)}" + return expr + + +def _flink_interval_literal(value: timedelta) -> str: + total_seconds = int(value.total_seconds()) + if total_seconds <= 0: + return "INTERVAL '0' SECOND" + + days, remainder = divmod(total_seconds, 24 * 60 * 60) + hours, remainder = divmod(remainder, 60 * 60) + minutes, seconds = divmod(remainder, 60) + parts = [] + if days: + parts.append(f"INTERVAL '{days}' DAY") + if hours: + parts.append(f"INTERVAL '{hours}' HOUR") + if minutes: + parts.append(f"INTERVAL '{minutes}' MINUTE") + if seconds: + parts.append(f"INTERVAL '{seconds}' SECOND") + return " + ".join(parts) + + +def _get_columns_from_schema(table: Any) -> Optional[List[str]]: + if not hasattr(table, "get_schema"): + return None + schema = table.get_schema() + if hasattr(schema, "get_field_names"): + return list(schema.get_field_names()) + if hasattr(schema, "get_field_count") and hasattr(schema, "get_field_name"): + return [schema.get_field_name(i) for i in range(schema.get_field_count())] + return None + + +def _get_columns(value: DAGValue) -> List[str]: + metadata_columns = value.metadata.get("columns") if value.metadata else None + if metadata_columns: + return list(metadata_columns) + schema_columns = _get_columns_from_schema(value.data) + if schema_columns: + return schema_columns + raise ValueError( + "Could not infer columns for Flink DAG value from metadata or PyFlink schema." + ) + + +def _can_use_sql(table_env: Any) -> bool: + return hasattr(table_env, "create_temporary_view") and hasattr( + table_env, "sql_query" + ) + + +def _require_sql(table_env: Any, node_name: str) -> None: + if not _can_use_sql(table_env): + raise RuntimeError( + f"Flink node '{node_name}' requires a PyFlink TableEnvironment with " + "create_temporary_view() and sql_query()." + ) + + +def _register_table(table_env: Any, table: Any, prefix: str) -> str: + view_name = f"__feast_{prefix}_{uuid.uuid4().hex}" + table_env.create_temporary_view(view_name, table) + return view_name + + +def _sql_value( + table_env: Any, + query: str, + columns: Iterable[str], + metadata: Optional[dict] = None, +) -> DAGValue: + return DAGValue( + data=table_env.sql_query(query), + format=DAGFormat.FLINK, + metadata={**(metadata or {}), "columns": list(columns), "native_sql": query}, + ) + + +def _entity_timestamp_column_from_columns(columns: List[str]) -> str: + if ENTITY_TS_ALIAS in columns: + return ENTITY_TS_ALIAS + if "event_timestamp" in columns: + return "event_timestamp" + raise ValueError( + "SQL-based entity_df for FlinkComputeEngine must select an " + "`event_timestamp` column." + ) + + +def _entity_value_from_dataframe( + table_env: Any, + entity_df: pd.DataFrame, + split_num: int, +) -> tuple[Any, List[str], str]: + entity_df = entity_df.copy() + entity_df[ENTITY_ROW_ID] = range(len(entity_df)) + entity_schema = dict(zip(entity_df.columns, entity_df.dtypes)) + entity_ts_col = infer_event_timestamp_from_entity_df(entity_schema) + if entity_ts_col != ENTITY_TS_ALIAS: + entity_df = entity_df.rename(columns={entity_ts_col: ENTITY_TS_ALIAS}) + return ( + pandas_to_flink_table(table_env, entity_df, split_num), + list(entity_df.columns), + entity_ts_col, + ) + + +def _entity_value_from_sql( + table_env: Any, + entity_sql: str, + join_keys: List[str], +) -> tuple[Any, List[str], str]: + _require_sql(table_env, "entity_df") + entity_table = table_env.sql_query(entity_sql) + entity_columns = _get_columns_from_schema(entity_table) + if entity_columns is None: + raise ValueError("Could not infer columns for SQL-based entity_df.") + + entity_ts_col = _entity_timestamp_column_from_columns(entity_columns) + entity_view = _register_table(table_env, entity_table, "entity_sql") + output_columns = [ + ENTITY_TS_ALIAS if column == entity_ts_col else column + for column in entity_columns + ] + select_exprs = [ + _select_column( + "entity_src", + column, + ENTITY_TS_ALIAS if column == entity_ts_col else column, + ) + for column in entity_columns + ] + order_columns = [ + column for column in [entity_ts_col, *join_keys] if column in entity_columns + ] + order_expr = ", ".join( + _qualified_column("entity_src", col) for col in order_columns + ) + if not order_expr: + order_expr = _qualified_column("entity_src", entity_columns[0]) + select_exprs.append( + f"ROW_NUMBER() OVER (ORDER BY {order_expr}) - 1 AS " + f"{_quote_identifier(ENTITY_ROW_ID)}" + ) + query = ( + f"SELECT {', '.join(select_exprs)} " + f"FROM {_quote_identifier(entity_view)} AS entity_src" + ) + output_columns.append(ENTITY_ROW_ID) + value = _sql_value( + table_env, + query, + output_columns, + metadata={"entity_timestamp_column": entity_ts_col}, + ) + return value.data, output_columns, entity_ts_col + + +def _entity_value_from_context( + table_env: Any, + context: ExecutionContext, + split_num: int, + join_keys: List[str], +) -> tuple[Any, List[str], str]: + if isinstance(context.entity_df, pd.DataFrame): + return _entity_value_from_dataframe(table_env, context.entity_df, split_num) + if isinstance(context.entity_df, str): + return _entity_value_from_sql(table_env, context.entity_df, join_keys) + raise TypeError( + "FlinkComputeEngine entity_df must be a pandas DataFrame, SQL string, or None." + ) + + +class FlinkSourceReadNode(DAGNode): + def __init__( + self, + name: str, + source: DataSource, + column_info: ColumnInfo, + table_env: Any, + split_num: int, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + ) -> None: + super().__init__(name) + self.source = source + self.column_info = column_info + self.table_env = table_env + self.split_num = split_num + self.start_time = start_time + self.end_time = end_time + + def execute(self, context: ExecutionContext) -> DAGValue: + retrieval_job = create_offline_store_retrieval_job( + data_source=self.source, + column_info=self.column_info, + context=context, + start_time=self.start_time, + end_time=self.end_time, + ) + if not hasattr(retrieval_job, "to_flink_table"): + raise TypeError( + "FlinkComputeEngine source reads require RetrievalJob.to_flink_table(" + "table_env). Configure an offline store retrieval job that returns " + "native PyFlink tables instead of Arrow/pandas results." + ) + + flink_table = retrieval_job.to_flink_table(self.table_env) + columns = _get_columns_from_schema(flink_table) + if columns is None: + raise ValueError( + "Could not infer columns for source Flink table returned by " + "RetrievalJob.to_flink_table(table_env)." + ) + + if self.column_info.field_mapping: + view_name = _register_table(self.table_env, flink_table, "source_read") + select_exprs = [ + _select_column( + "src", + col, + self.column_info.field_mapping.get(col, col), + ) + for col in columns + ] + renamed_columns = [ + self.column_info.field_mapping.get(col, col) for col in columns + ] + query = ( + f"SELECT {', '.join(select_exprs)} " + f"FROM {_quote_identifier(view_name)} AS src" + ) + return _sql_value( + self.table_env, + query, + renamed_columns, + metadata={ + "source": "feature_view_batch_source", + "timestamp_field": self.column_info.timestamp_column, + "created_timestamp_column": ( + self.column_info.created_timestamp_column + ), + "start_date": self.start_time, + "end_date": self.end_time, + }, + ) + + return DAGValue( + data=flink_table, + format=DAGFormat.FLINK, + metadata={ + "source": "feature_view_batch_source", + "timestamp_field": self.column_info.timestamp_column, + "created_timestamp_column": (self.column_info.created_timestamp_column), + "start_date": self.start_time, + "end_date": self.end_time, + "columns": columns, + }, + ) + + +class FlinkJoinNode(DAGNode): + def __init__( + self, + name: str, + column_info: ColumnInfo, + table_env: Any, + split_num: int, + inputs: Optional[List[DAGNode]] = None, + how: str = "left", + ) -> None: + super().__init__(name, inputs=inputs or []) + self.column_info = column_info + self.table_env = table_env + self.split_num = split_num + self.how = how + + def execute(self, context: ExecutionContext) -> DAGValue: + input_values = self.get_input_values(context) + for value in input_values: + value.assert_format(DAGFormat.FLINK) + if not input_values: + raise RuntimeError(f"FlinkJoinNode '{self.name}' requires inputs") + + _require_sql(self.table_env, self.name) + return self._execute_sql_join(input_values, context) + + def _execute_sql_join( + self, input_values: List[DAGValue], context: ExecutionContext + ) -> DAGValue: + join_keys = self.column_info.join_keys_columns + view_names = [ + _register_table(self.table_env, value.data, f"join_{index}") + for index, value in enumerate(input_values) + ] + columns_by_input = [_get_columns(value) for value in input_values] + output_columns = list(columns_by_input[0]) + seen_columns = set(output_columns) + select_exprs = [_select_column("t0", column) for column in columns_by_input[0]] + + joins = [] + for index, view_name in enumerate(view_names[1:], start=1): + alias = f"t{index}" + on_clause = " AND ".join( + f"{_qualified_column('t0', key)} = {_qualified_column(alias, key)}" + for key in join_keys + ) + joins.append( + f"{self.how.upper()} JOIN {_quote_identifier(view_name)} AS {alias} " + f"ON {on_clause}" + ) + for column in columns_by_input[index]: + if column in join_keys or column in seen_columns: + continue + output_columns.append(column) + seen_columns.add(column) + select_exprs.append(_select_column(alias, column)) + + query = ( + f"SELECT {', '.join(select_exprs)} " + f"FROM {_quote_identifier(view_names[0])} AS t0 " + f"{' '.join(joins)}" + ) + joined_value = _sql_value( + self.table_env, + query, + output_columns, + metadata={"joined_on": join_keys, "join_type": self.how}, + ) + + if context.entity_df is None: + return joined_value + + entity_table, entity_columns, entity_ts_col = _entity_value_from_context( + self.table_env, context, self.split_num, join_keys + ) + entity_view = _register_table(self.table_env, entity_table, "entity") + feature_view = _register_table(self.table_env, joined_value.data, "features") + feature_columns = [ + column + for column in output_columns + if column not in join_keys and column not in entity_columns + ] + select_entity = [_select_column("e", column) for column in entity_columns] + select_features = [_select_column("f", column) for column in feature_columns] + on_clause = " AND ".join( + f"{_qualified_column('e', key)} = {_qualified_column('f', key)}" + for key in join_keys + ) + entity_join_query = ( + f"SELECT {', '.join(select_entity + select_features)} " + f"FROM {_quote_identifier(entity_view)} AS e " + f"LEFT JOIN {_quote_identifier(feature_view)} AS f ON {on_clause}" + ) + return _sql_value( + self.table_env, + entity_join_query, + entity_columns + feature_columns, + metadata={ + "joined_on": join_keys, + "join_type": "left", + "entity_timestamp_column": entity_ts_col, + }, + ) + + +class FlinkFilterNode(DAGNode): + def __init__( + self, + name: str, + column_info: ColumnInfo, + table_env: Any, + split_num: int, + filter_expr: Optional[str] = None, + ttl: Optional[timedelta] = None, + inputs: Optional[List[DAGNode]] = None, + ) -> None: + super().__init__(name, inputs=inputs) + self.column_info = column_info + self.table_env = table_env + self.split_num = split_num + self.filter_expr = filter_expr + self.ttl = ttl + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.FLINK) + + _require_sql(self.table_env, self.name) + return self._execute_sql_filter(input_value) + + def _execute_sql_filter(self, input_value: DAGValue) -> DAGValue: + columns = _get_columns(input_value) + timestamp_column = self.column_info.timestamp_column + conditions = [] + + if ENTITY_TS_ALIAS in columns and timestamp_column in columns: + conditions.append( + f"{_quote_identifier(timestamp_column)} <= " + f"{_quote_identifier(ENTITY_TS_ALIAS)}" + ) + if self.ttl: + ttl_interval = _flink_interval_literal(self.ttl) + conditions.append( + f"{_quote_identifier(timestamp_column)} >= " + f"{_quote_identifier(ENTITY_TS_ALIAS)} - " + f"({ttl_interval})" + ) + + if self.filter_expr: + conditions.append(f"({self.filter_expr})") + + if not conditions: + return input_value + + view_name = _register_table(self.table_env, input_value.data, "filter") + query = ( + f"SELECT * FROM {_quote_identifier(view_name)} " + f"WHERE {' AND '.join(conditions)}" + ) + return _sql_value( + self.table_env, + query, + columns, + metadata={**(input_value.metadata or {}), "filter_applied": True}, + ) + + +class FlinkAggregationNode(DAGNode): + def __init__( + self, + name: str, + group_keys: List[str], + aggregations: List[Aggregation], + table_env: Any, + split_num: int, + inputs: Optional[List[DAGNode]] = None, + ) -> None: + super().__init__(name, inputs=inputs) + self.group_keys = group_keys + self.aggregations = aggregations + self.table_env = table_env + self.split_num = split_num + + def execute(self, context: ExecutionContext) -> DAGValue: + agg_ops = aggregation_specs_to_agg_ops( + self.aggregations, + time_window_unsupported_error_message=( + "Time window aggregation is not yet supported in the Flink compute " + "engine. Use non-windowed aggregations or pre-window upstream in Flink." + ), + ) + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.FLINK) + + _require_sql(self.table_env, self.name) + return self._execute_sql_aggregation(input_value, agg_ops) + + def _execute_sql_aggregation( + self, input_value: DAGValue, agg_ops: Dict[str, tuple[str, str]] + ) -> DAGValue: + view_name = _register_table(self.table_env, input_value.data, "aggregate") + select_exprs = [_quote_identifier(key) for key in self.group_keys] + for alias, (function, column) in agg_ops.items(): + sql_function = { + "mean": "AVG", + "avg": "AVG", + "sum": "SUM", + "min": "MIN", + "max": "MAX", + "count": "COUNT", + "nunique": "COUNT_DISTINCT", + "std": "STDDEV_SAMP", + "var": "VAR_SAMP", + }.get(function, function.upper()) + if sql_function == "COUNT_DISTINCT": + expr = ( + f"COUNT(DISTINCT {_quote_identifier(column)}) " + f"AS {_quote_identifier(alias)}" + ) + else: + expr = ( + f"{sql_function}({_quote_identifier(column)}) " + f"AS {_quote_identifier(alias)}" + ) + select_exprs.append(expr) + + query = ( + f"SELECT {', '.join(select_exprs)} " + f"FROM {_quote_identifier(view_name)} " + f"GROUP BY {', '.join(_quote_identifier(key) for key in self.group_keys)}" + ) + return _sql_value( + self.table_env, + query, + [*self.group_keys, *agg_ops.keys()], + metadata={"aggregated": True}, + ) + + +class FlinkDedupNode(DAGNode): + def __init__( + self, + name: str, + column_info: ColumnInfo, + table_env: Any, + split_num: int, + inputs: Optional[List[DAGNode]] = None, + ) -> None: + super().__init__(name, inputs=inputs) + self.column_info = column_info + self.table_env = table_env + self.split_num = split_num + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.FLINK) + + _require_sql(self.table_env, self.name) + return self._execute_sql_dedup(input_value) + + def _execute_sql_dedup(self, input_value: DAGValue) -> DAGValue: + columns = _get_columns(input_value) + dedup_keys = ( + [ENTITY_ROW_ID] + if ENTITY_ROW_ID in columns + else self.column_info.join_keys_columns + ) + dedup_keys = [key for key in dedup_keys if key in columns] + if not dedup_keys: + return input_value + + order_columns = [ + self.column_info.timestamp_column, + self.column_info.created_timestamp_column, + ] + order_exprs = [ + f"{_quote_identifier(column)} DESC" + for column in order_columns + if column and column in columns + ] + if not order_exprs: + order_exprs = [f"{_quote_identifier(dedup_keys[0])} ASC"] + + view_name = _register_table(self.table_env, input_value.data, "dedup") + select_columns = ", ".join(_quote_identifier(column) for column in columns) + query = ( + f"SELECT {select_columns} FROM (" + f"SELECT *, ROW_NUMBER() OVER (" + f"PARTITION BY {', '.join(_quote_identifier(key) for key in dedup_keys)} " + f"ORDER BY {', '.join(order_exprs)}" + f") AS {_quote_identifier(DEDUP_ROW_NUMBER)} " + f"FROM {_quote_identifier(view_name)}" + f") WHERE {_quote_identifier(DEDUP_ROW_NUMBER)} = 1" + ) + return _sql_value( + self.table_env, + query, + columns, + metadata={**(input_value.metadata or {}), "deduped": True}, + ) + + +class FlinkTransformationNode(DAGNode): + def __init__( + self, + name: str, + transformation_fn: Callable[..., Any], + table_env: Any, + split_num: int, + inputs: Optional[List[DAGNode]] = None, + ) -> None: + super().__init__(name, inputs=inputs) + self.transformation_fn = transformation_fn + self.table_env = table_env + self.split_num = split_num + + def execute(self, context: ExecutionContext) -> DAGValue: + input_values = self.get_input_values(context) + for value in input_values: + value.assert_format(DAGFormat.FLINK) + + input_tables = [value.data for value in input_values] + transformed = self.transformation_fn(*input_tables) + + columns = _get_columns_from_schema(transformed) + if columns is None: + raise TypeError( + "Flink transformations must return a PyFlink Table with a schema." + ) + + return DAGValue( + data=transformed, + format=DAGFormat.FLINK, + metadata={"transformed": True, "columns": columns or []}, + ) + + +class FlinkValidationNode(DAGNode): + def __init__( + self, + name: str, + expected_columns: dict[str, Optional[pa.DataType]], + json_columns: Optional[Set[str]], + table_env: Any, + split_num: int, + inputs: Optional[List[DAGNode]] = None, + ) -> None: + super().__init__(name, inputs=inputs) + self.expected_columns = expected_columns + self.json_columns = json_columns or set() + self.table_env = table_env + self.split_num = split_num + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.FLINK) + + columns = _get_columns(input_value) + missing = set(self.expected_columns.keys()) - set(columns) + if missing: + raise ValueError( + f"[Validation: {self.name}] Missing expected columns: {missing}. " + f"Actual columns: {sorted(columns)}" + ) + if not self.json_columns: + return DAGValue( + data=input_value.data, + format=DAGFormat.FLINK, + metadata={**(input_value.metadata or {}), "validated": True}, + ) + + raise NotImplementedError( + "JSON value validation is not supported by FlinkComputeEngine without " + "collecting data out of Flink. Validate JSON upstream in Flink SQL or " + "disable JSON validation for this FeatureView." + ) + + +class FlinkOutputNode(DAGNode): + def __init__( + self, + name: str, + feature_view: Union[BatchFeatureView, StreamFeatureView], + table_env: Any, + split_num: int, + write_output: bool, + inputs: Optional[List[DAGNode]] = None, + ) -> None: + super().__init__(name, inputs=inputs) + self.feature_view = feature_view + self.table_env = table_env + self.split_num = split_num + self.write_output = write_output + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.FLINK) + output_value = self._drop_internal_columns(input_value) + output_table = output_value.data + if not self.write_output: + return output_value + + output_df = flink_table_to_pandas(output_table) + output_arrow = pa.Table.from_pandas(output_df) + + if output_arrow.num_rows == 0: + return output_value + + if self.feature_view.online: + join_key_to_value_type = { + entity.name: entity.dtype.to_value_type() + for entity in self.feature_view.entity_columns + } + batch_size = ( + context.repo_config.materialization_config.online_write_batch_size + ) + batches = ( + [output_arrow] + if batch_size is None + else output_arrow.to_batches(max_chunksize=batch_size) + ) + for batch in batches: + rows_to_write = _convert_arrow_to_proto( + batch, self.feature_view, join_key_to_value_type + ) + context.online_store.online_write_batch( + config=context.repo_config, + table=self.feature_view, + data=rows_to_write, + progress=lambda x: None, + ) + + if self.feature_view.offline: + context.offline_store.offline_write_batch( + config=context.repo_config, + feature_view=self.feature_view, + table=output_arrow, + progress=lambda x: None, + ) + + return output_value + + def _drop_internal_columns(self, input_value: DAGValue) -> DAGValue: + columns = _get_columns(input_value) + output_columns = [column for column in columns if column != ENTITY_ROW_ID] + if output_columns == columns: + return input_value + + _require_sql(self.table_env, self.name) + view_name = _register_table(self.table_env, input_value.data, "output") + query = ( + f"SELECT {', '.join(_quote_identifier(column) for column in output_columns)} " + f"FROM {_quote_identifier(view_name)}" + ) + return _sql_value( + self.table_env, + query, + output_columns, + metadata={**(input_value.metadata or {}), "output_cleaned": True}, + ) diff --git a/sdk/python/feast/infra/compute_engines/flink/utils.py b/sdk/python/feast/infra/compute_engines/flink/utils.py new file mode 100644 index 00000000000..43c9f285a97 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/flink/utils.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pandas as pd +import pyarrow as pa + +if TYPE_CHECKING: + from feast.infra.compute_engines.flink.compute import FlinkComputeEngineConfig + + +def create_flink_table_environment(config: FlinkComputeEngineConfig) -> Any: + """Create a PyFlink TableEnvironment from Feast engine config.""" + try: + from pyflink.common import Configuration + from pyflink.table import EnvironmentSettings, TableEnvironment + except ImportError as exc: + raise ImportError( + "FlinkComputeEngine requires PyFlink. Install Feast with the `flink` " + "extra or otherwise make the `pyflink` package available to Feast." + ) from exc + + flink_conf = Configuration() + for key, value in (config.table_config or {}).items(): + flink_conf.set_string(key, value) + if config.parallelism is not None: + flink_conf.set_string("parallelism.default", str(config.parallelism)) + + builder = EnvironmentSettings.new_instance().with_configuration(flink_conf) + if config.execution_mode == "streaming": + builder = builder.in_streaming_mode() + else: + builder = builder.in_batch_mode() + return TableEnvironment.create(builder.build()) + + +def pandas_to_flink_table(table_env: Any, df: pd.DataFrame, split_num: int = 1) -> Any: + """Convert a pandas DataFrame to a PyFlink table.""" + schema = list(df.columns) + return table_env.from_pandas(df, schema=schema, splits_num=split_num) + + +def flink_table_to_pandas(table: Any) -> pd.DataFrame: + """Collect a PyFlink table into pandas.""" + if hasattr(table, "to_pandas"): + return table.to_pandas() + raise TypeError(f"Expected a PyFlink table, got {type(table)}") + + +def flink_table_to_arrow(table: Any) -> pa.Table: + """Collect a PyFlink table into Arrow.""" + value = flink_table_to_pandas(table) + return pa.Table.from_pandas(value) diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 7518f613788..0f832ea0d69 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -50,6 +50,7 @@ "k8s": "feast.infra.compute_engines.kubernetes.k8s_engine.KubernetesComputeEngine", "spark.engine": "feast.infra.compute_engines.spark.compute.SparkComputeEngine", "ray.engine": "feast.infra.compute_engines.ray.compute.RayComputeEngine", + "flink.engine": "feast.infra.compute_engines.flink.compute.FlinkComputeEngine", } LEGACY_ONLINE_STORE_CLASS_FOR_TYPE = { diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py index 9ee07e6a199..c2b4625214a 100644 --- a/sdk/python/feast/stream_feature_view.py +++ b/sdk/python/feast/stream_feature_view.py @@ -207,7 +207,8 @@ def get_feature_transformation(self) -> Optional[Transformation]: TransformationMode.PYTHON, TransformationMode.SPARK_SQL, TransformationMode.SPARK, - ) or self.mode in ("pandas", "python", "spark_sql", "spark"): + TransformationMode.FLINK, + ) or self.mode in ("pandas", "python", "spark_sql", "spark", "flink"): return Transformation( mode=self.mode, udf=self.udf, udf_string=self.udf_string or "" ) diff --git a/sdk/python/feast/transformation/factory.py b/sdk/python/feast/transformation/factory.py index 16d7a7570d5..a181b7dea69 100644 --- a/sdk/python/feast/transformation/factory.py +++ b/sdk/python/feast/transformation/factory.py @@ -7,6 +7,7 @@ "sql": "feast.transformation.sql_transformation.SQLTransformation", "spark_sql": "feast.transformation.spark_transformation.SparkTransformation", "spark": "feast.transformation.spark_transformation.SparkTransformation", + "flink": "feast.transformation.flink_transformation.FlinkTransformation", "ray": "feast.transformation.ray_transformation.RayTransformation", } diff --git a/sdk/python/feast/transformation/flink_transformation.py b/sdk/python/feast/transformation/flink_transformation.py new file mode 100644 index 00000000000..83b929c4c96 --- /dev/null +++ b/sdk/python/feast/transformation/flink_transformation.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Any, Callable, Optional, cast + +from feast.transformation.base import Transformation +from feast.transformation.mode import TransformationMode + + +class FlinkTransformation(Transformation): + """Transformation wrapper for Flink compute-engine UDFs. + + The UDF is expected to accept PyFlink Table objects and return a PyFlink + Table. + """ + + def __new__( + cls, + udf: Optional[Callable[..., Any]] = None, + udf_string: Optional[str] = None, + name: Optional[str] = None, + tags: Optional[dict[str, str]] = None, + description: str = "", + owner: str = "", + *args, + **kwargs, + ) -> "FlinkTransformation": + if udf is None and udf_string is None: + return cast("FlinkTransformation", object.__new__(cls)) + if udf is None: + raise ValueError("udf parameter cannot be None") + if udf_string is None: + raise ValueError("udf_string parameter cannot be None") + return cast( + "FlinkTransformation", + super(FlinkTransformation, cls).__new__( + cls, + mode=TransformationMode.FLINK, + udf=udf, + name=name, + udf_string=udf_string, + tags=tags, + description=description, + owner=owner, + ), + ) + + def __init__( + self, + udf: Optional[Callable[..., Any]] = None, + udf_string: Optional[str] = None, + name: Optional[str] = None, + tags: Optional[dict[str, str]] = None, + description: str = "", + owner: str = "", + *args, + **kwargs, + ) -> None: + if udf is None and udf_string is None: + return + if udf is None: + raise ValueError("udf parameter cannot be None") + if udf_string is None: + raise ValueError("udf_string parameter cannot be None") + super().__init__( + mode=TransformationMode.FLINK, + udf=udf, + name=name, + udf_string=udf_string, + tags=tags, + description=description, + owner=owner, + ) + + def transform(self, *inputs: Any) -> Any: + return self.udf(*inputs) + + def infer_features(self, *args, **kwargs) -> Any: + pass diff --git a/sdk/python/feast/transformation/mode.py b/sdk/python/feast/transformation/mode.py index 44d38d8e99c..bd6fdf22424 100644 --- a/sdk/python/feast/transformation/mode.py +++ b/sdk/python/feast/transformation/mode.py @@ -6,6 +6,7 @@ class TransformationMode(Enum): PANDAS = "pandas" SPARK_SQL = "spark_sql" SPARK = "spark" + FLINK = "flink" RAY = "ray" SQL = "sql" SUBSTRAIT = "substrait" diff --git a/sdk/python/tests/unit/infra/compute_engines/flink/__init__.py b/sdk/python/tests/unit/infra/compute_engines/flink/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/flink/__init__.py @@ -0,0 +1 @@ + diff --git a/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py b/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py new file mode 100644 index 00000000000..98bb3f6d3d8 --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py @@ -0,0 +1,969 @@ +from __future__ import annotations + +import re +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, List, Optional +from unittest.mock import MagicMock + +import pandas as pd +import pyarrow as pa +import pytest + +from feast import BatchFeatureView, Entity, Field, FileSource +from feast.aggregation import Aggregation +from feast.infra.common.materialization_job import ( + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.flink.compute import ( + FlinkComputeEngine, + FlinkComputeEngineConfig, +) +from feast.infra.compute_engines.flink.nodes import ( + ENTITY_ROW_ID, + ENTITY_TS_ALIAS, + FlinkAggregationNode, + FlinkDedupNode, + FlinkFilterNode, + FlinkJoinNode, + FlinkSourceReadNode, + FlinkTransformationNode, + FlinkValidationNode, + _flink_interval_literal, +) +from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.repo_config import RepoConfig +from feast.saved_dataset import SavedDatasetStorage +from feast.types import Float32 +from feast.value_type import ValueType + + +class FakeFlinkTable: + def __init__(self, df: pd.DataFrame) -> None: + self._df = df.copy() + + def to_pandas(self) -> pd.DataFrame: + return self._df.copy() + + def get_schema(self) -> FakeFlinkSchema: + return FakeFlinkSchema(list(self._df.columns)) + + +class FakeTableEnvironment: + def __init__(self) -> None: + self.created_tables: List[pd.DataFrame] = [] + self.split_nums: List[int] = [] + self.views: dict[str, object] = {} + self.queries: List[str] = [] + + def from_pandas( + self, + df: pd.DataFrame, + schema: object = None, + splits_num: int = 1, + split_num: Optional[int] = None, + ) -> FakeFlinkTable: + self.created_tables.append(df.copy()) + self.split_nums.append(split_num if split_num is not None else splits_num) + return FakeFlinkTable(df) + + def create_temporary_view( + self, view_path: str, table_or_data_stream: object, *args: object + ) -> None: + self.views[view_path] = table_or_data_stream + + def sql_query(self, query: str) -> Any: + self.queries.append(query) + return FakeFlinkTable(self._evaluate_sql(query)) + + def _view_df(self, view_name: str) -> pd.DataFrame: + table = self.views[view_name] + if isinstance(table, FakeFlinkTable): + return table.to_pandas() + if isinstance(table, FakeNativeFlinkTable): + return pd.DataFrame(columns=table.get_schema().get_field_names()) + raise TypeError(f"Unsupported fake Flink table type: {type(table)}") + + def _evaluate_sql(self, query: str) -> pd.DataFrame: + if "ROW_NUMBER() OVER" in query: + return self._evaluate_row_number_query(query) + if " GROUP BY " in query: + return self._evaluate_group_by_query(query) + if " JOIN " in query: + return self._evaluate_join_query(query) + if " WHERE " in query: + return self._evaluate_where_query(query) + return self._evaluate_select_query(query) + + def _extract_views(self, query: str) -> List[str]: + return re.findall(r"(?:FROM|JOIN)\s+`([^`]+)`", query) + + def _evaluate_select_query(self, query: str) -> pd.DataFrame: + views = self._extract_views(query) + if not views: + if "FROM entities" in query: + return self._view_df("entities")[["driver_id", "event_timestamp"]] + raise ValueError(f"Could not infer source view from query: {query}") + source_df = self._view_df(views[-1]) + select_clause = query.split(" FROM ", 1)[0].removeprefix("SELECT ") + if select_clause == "*": + return source_df + result = pd.DataFrame() + for column_expr in select_clause.split(", "): + parts = re.findall(r"`([^`]+)`", column_expr) + if not parts: + continue + source_column = parts[0] + output_column = parts[-1] + result[output_column] = source_df[source_column] + return result + + def _evaluate_where_query(self, query: str) -> pd.DataFrame: + view_name = self._extract_views(query)[0] + df = self._view_df(view_name) + if "`event_timestamp` <= `__entity_event_timestamp`" in query: + df = df[df["event_timestamp"] <= df[ENTITY_TS_ALIAS]] + if "conv_rate > 0.15" in query: + df = df[df["conv_rate"] > 0.15] + return df.reset_index(drop=True) + + def _evaluate_group_by_query(self, query: str) -> pd.DataFrame: + view_name = self._extract_views(query)[0] + df = self._view_df(view_name) + return ( + df.groupby("driver_id") + .agg(sum_conv_rate=pd.NamedAgg(column="conv_rate", aggfunc="sum")) + .reset_index() + ) + + def _evaluate_row_number_query(self, query: str) -> pd.DataFrame: + view_name = self._extract_views(query)[-1] + df = self._view_df(view_name) + if " - 1 AS " in query: + df = df.copy() + df[ENTITY_ROW_ID] = range(len(df)) + if " AS `__entity_event_timestamp`" in query: + df = df.rename(columns={"event_timestamp": ENTITY_TS_ALIAS}) + return df.reset_index(drop=True) + + dedup_keys = [ENTITY_ROW_ID] if ENTITY_ROW_ID in df.columns else ["driver_id"] + sort_keys = [ + column for column in ["event_timestamp", "created"] if column in df + ] + return ( + df.sort_values(by=sort_keys, ascending=False) + .drop_duplicates(subset=dedup_keys) + .reset_index(drop=True) + ) + + def _evaluate_join_query(self, query: str) -> pd.DataFrame: + views = self._extract_views(query) + if " AS e LEFT JOIN " in query: + entity_df = self._view_df(views[-2]) + feature_df = self._view_df(views[-1]) + feature_columns = [ + column + for column in feature_df.columns + if column not in entity_df.columns and column != "driver_id" + ] + return entity_df.merge( + feature_df[["driver_id", *feature_columns]], + on="driver_id", + how="left", + ) + + joined_df = self._view_df(views[0]) + for view_name in views[1:]: + joined_df = joined_df.merge( + self._view_df(view_name), on="driver_id", how="left" + ) + return joined_df + + +class FakeFlinkSchema: + def __init__(self, columns: List[str]) -> None: + self._columns = columns + + def get_field_names(self) -> List[str]: + return list(self._columns) + + +class FakeNativeFlinkTable: + def __init__(self, columns: List[str]) -> None: + self._columns = columns + + def get_schema(self) -> FakeFlinkSchema: + return FakeFlinkSchema(self._columns) + + +class RecordingTableEnvironment(FakeTableEnvironment): + def __init__(self) -> None: + super().__init__() + + def create_temporary_view( + self, view_path: str, table_or_data_stream: object, *args: object + ) -> None: + self.views[view_path] = table_or_data_stream + + def sql_query(self, query: str) -> FakeNativeFlinkTable: + self.queries.append(query) + return FakeNativeFlinkTable([]) + + +class InputNode(DAGNode): + def execute(self, context: ExecutionContext) -> DAGValue: + return context.node_outputs[self.name] + + +class FakeRetrievalJob(RetrievalJob): + def __init__(self, table: pa.Table) -> None: + self._table = table + + def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: + return self._table.to_pandas() + + def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: + return self._table + + @property + def full_feature_names(self) -> bool: + return False + + @property + def on_demand_feature_views(self) -> List[OnDemandFeatureView]: + return [] + + @property + def metadata(self) -> Optional[RetrievalMetadata]: + return None + + def persist( + self, + storage: SavedDatasetStorage, + allow_overwrite: bool = False, + timeout: Optional[int] = None, + ) -> None: + raise NotImplementedError + + def to_remote_storage(self) -> List[str]: + raise NotImplementedError + + def to_sql(self) -> str: + raise NotImplementedError + + +class FakeFlinkRetrievalJob: + def __init__(self, df: pd.DataFrame) -> None: + self._table = FakeFlinkTable(df) + + def to_flink_table(self, table_env: object) -> FakeFlinkTable: + return self._table + + +def _repo_config(tmp_path: Path, batch_engine: dict[str, object]) -> RepoConfig: + return RepoConfig( + project="test_project", + registry=str(tmp_path / "registry.db"), + provider="local", + offline_store={"type": "file"}, + online_store={"type": "sqlite", "path": str(tmp_path / "online.db")}, + batch_engine=batch_engine, + ) + + +def _driver() -> Entity: + return Entity(name="driver_id", value_type=ValueType.INT64) + + +def _source() -> FileSource: + return FileSource( + name="driver_stats_source", + path="unused.parquet", + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + + +def _feature_view(source: FileSource, **kwargs: Any) -> BatchFeatureView: + return BatchFeatureView( + name="driver_stats", + entities=[_driver()], + ttl=timedelta(days=2), + schema=[Field(name="conv_rate", dtype=Float32)], + source=source, + **kwargs, + ) + + +def _feature_data() -> pd.DataFrame: + return pd.DataFrame( + { + "driver_id": [1, 1, 2], + "event_timestamp": [ + datetime(2024, 1, 1, 9, 0, 0), + datetime(2024, 1, 1, 10, 0, 0), + datetime(2024, 1, 1, 10, 0, 0), + ], + "created": [ + datetime(2024, 1, 1, 9, 1, 0), + datetime(2024, 1, 1, 10, 1, 0), + datetime(2024, 1, 1, 10, 1, 0), + ], + "conv_rate": [0.1, 0.2, 0.3], + } + ) + + +def _offline_store(df: pd.DataFrame) -> MagicMock: + store = MagicMock() + store.pull_all_from_table_or_query.return_value = FakeFlinkRetrievalJob(df) + store.pull_latest_from_table_or_query.return_value = FakeFlinkRetrievalJob(df) + return store + + +def _registry(entity: Entity) -> MagicMock: + registry = MagicMock() + registry.get_entity.return_value = entity + return registry + + +def _column_info() -> ColumnInfo: + return ColumnInfo( + join_keys=["driver_id"], + feature_cols=["conv_rate"], + ts_col="event_timestamp", + created_ts_col="created", + ) + + +def _execution_context( + tmp_path: Path, node_outputs: dict[str, DAGValue] +) -> ExecutionContext: + return ExecutionContext( + project="test_project", + repo_config=_repo_config(tmp_path, {"type": "flink.engine"}), + offline_store=MagicMock(), + online_store=MagicMock(), + entity_defs=[_driver()], + node_outputs=node_outputs, + ) + + +def _flink_value(df: pd.DataFrame) -> DAGValue: + return DAGValue( + data=FakeFlinkTable(df), + format=DAGFormat.FLINK, + metadata={"columns": list(df.columns)}, + ) + + +def _native_flink_value(columns: List[str]) -> DAGValue: + return DAGValue( + data=FakeNativeFlinkTable(columns), + format=DAGFormat.FLINK, + metadata={"columns": columns}, + ) + + +def test_repo_config_loads_flink_batch_engine_config(tmp_path: Path) -> None: + config = _repo_config( + tmp_path, + { + "type": "flink.engine", + "execution_mode": "streaming", + "parallelism": 3, + "table_config": {"pipeline.name": "feast-flink-test"}, + "pandas_split_num": 2, + }, + ) + + assert isinstance(config.batch_engine, FlinkComputeEngineConfig) + assert config.batch_engine.execution_mode == "streaming" + assert config.batch_engine.parallelism == 3 + assert config.batch_engine.table_config == {"pipeline.name": "feast-flink-test"} + assert config.batch_engine.pandas_split_num == 2 + + +def test_flink_source_read_node_rejects_arrow_retrieval_jobs(tmp_path: Path) -> None: + offline_store = MagicMock() + offline_store.pull_all_from_table_or_query.return_value = FakeRetrievalJob( + pa.Table.from_pandas(_feature_data()) + ) + context = _execution_context(tmp_path, {}) + context.offline_store = offline_store + node = FlinkSourceReadNode( + "source", + _source(), + _column_info(), + FakeTableEnvironment(), + split_num=1, + ) + + with pytest.raises(TypeError, match="to_flink_table"): + node.execute(context) + + +def test_flink_historical_retrieval_executes_dag_with_transformation( + tmp_path: Path, +) -> None: + entity = _driver() + source = _source() + + def double_conv_rate(table: FakeFlinkTable) -> FakeFlinkTable: + df = table.to_pandas() + df["conv_rate"] = df["conv_rate"] * 2 + return FakeFlinkTable(df) + + feature_view = _feature_view( + source, + mode="flink", + udf=double_conv_rate, + udf_string="double_conv_rate", + online=False, + offline=False, + ) + config = _repo_config( + tmp_path, + {"type": "flink.engine", "pandas_split_num": 4}, + ) + table_env = FakeTableEnvironment() + engine = FlinkComputeEngine( + repo_config=config, + offline_store=_offline_store(_feature_data()), + online_store=MagicMock(), + table_environment=table_env, + ) + task = HistoricalRetrievalTask( + project=config.project, + entity_df=pd.DataFrame(), + feature_view=feature_view, + full_feature_name=False, + registry=_registry(entity), + ) + + job = engine.get_historical_features(_registry(entity), task) + result = job.to_df().sort_values("driver_id").reset_index(drop=True) + + assert job.error() is None + assert result["driver_id"].tolist() == [1, 2] + assert result["conv_rate"].tolist() == [0.4, 0.6] + + +def test_flink_historical_retrieval_is_read_only_and_dedupes_per_entity_row( + tmp_path: Path, +) -> None: + entity = _driver() + source = _source() + feature_view = _feature_view(source, online=True, offline=True) + config = _repo_config(tmp_path, {"type": "flink.engine", "pandas_split_num": 4}) + feature_data = pd.DataFrame( + { + "driver_id": [1, 1], + "event_timestamp": [ + datetime(2024, 1, 1, 9, 0, 0), + datetime(2024, 1, 1, 10, 0, 0), + ], + "created": [ + datetime(2024, 1, 1, 9, 1, 0), + datetime(2024, 1, 1, 10, 1, 0), + ], + "conv_rate": [0.1, 0.2], + } + ) + offline_store = _offline_store(feature_data) + online_store = MagicMock() + table_env = FakeTableEnvironment() + engine = FlinkComputeEngine( + repo_config=config, + offline_store=offline_store, + online_store=online_store, + table_environment=table_env, + ) + task = HistoricalRetrievalTask( + project=config.project, + entity_df=pd.DataFrame( + { + "driver_id": [1, 1], + "event_timestamp": [ + datetime(2024, 1, 1, 9, 30, 0), + datetime(2024, 1, 1, 10, 30, 0), + ], + } + ), + feature_view=feature_view, + full_feature_name=False, + registry=_registry(entity), + ) + + result = engine.get_historical_features(_registry(entity), task).to_df() + result = result.sort_values(ENTITY_TS_ALIAS).reset_index(drop=True) + + assert result["conv_rate"].tolist() == [0.1, 0.2] + assert table_env.split_nums == [4] + assert ENTITY_ROW_ID not in result.columns + online_store.online_write_batch.assert_not_called() + offline_store.offline_write_batch.assert_not_called() + + +def test_flink_historical_retrieval_supports_sql_entity_df(tmp_path: Path) -> None: + entity = _driver() + source = _source() + feature_view = _feature_view(source, online=False, offline=False) + config = _repo_config(tmp_path, {"type": "flink.engine"}) + table_env = FakeTableEnvironment() + table_env.create_temporary_view( + "entities", + FakeFlinkTable( + pd.DataFrame( + { + "driver_id": [1, 1], + "event_timestamp": [ + datetime(2024, 1, 1, 9, 30, 0), + datetime(2024, 1, 1, 10, 30, 0), + ], + } + ) + ), + ) + engine = FlinkComputeEngine( + repo_config=config, + offline_store=_offline_store(_feature_data()), + online_store=MagicMock(), + table_environment=table_env, + ) + task = HistoricalRetrievalTask( + project=config.project, + entity_df="SELECT driver_id, event_timestamp FROM entities", + feature_view=feature_view, + full_feature_name=False, + registry=_registry(entity), + ) + + job = engine.get_historical_features(_registry(entity), task) + result = job.to_df().sort_values(ENTITY_TS_ALIAS).reset_index(drop=True) + + assert job.error() is None + assert result["conv_rate"].tolist() == [0.1, 0.2] + assert any( + "SELECT driver_id, event_timestamp FROM entities" in query + for query in table_env.queries + ) + + +def test_flink_materialize_writes_online_and_offline(tmp_path: Path) -> None: + entity = _driver() + source = _source() + feature_view = _feature_view(source, online=True, offline=True) + config = _repo_config(tmp_path, {"type": "flink.engine"}) + offline_store = _offline_store(_feature_data().head(1)) + online_store = MagicMock() + engine = FlinkComputeEngine( + repo_config=config, + offline_store=offline_store, + online_store=online_store, + table_environment=FakeTableEnvironment(), + ) + task = MaterializationTask( + project=config.project, + feature_view=feature_view, + start_time=datetime(2024, 1, 1), + end_time=datetime(2024, 1, 2), + ) + + jobs = engine.materialize(_registry(entity), task) + + assert len(jobs) == 1 + assert jobs[0].status() == MaterializationJobStatus.SUCCEEDED + assert jobs[0].error() is None + online_store.online_write_batch.assert_called_once() + offline_store.offline_write_batch.assert_called_once() + + +def test_flink_engine_reports_materialization_errors(tmp_path: Path) -> None: + entity = _driver() + source = _source() + feature_view = _feature_view(source, online=False, offline=False) + offline_store = MagicMock() + offline_store.pull_all_from_table_or_query.side_effect = RuntimeError("boom") + config = _repo_config(tmp_path, {"type": "flink.engine"}) + engine = FlinkComputeEngine( + repo_config=config, + offline_store=offline_store, + online_store=MagicMock(), + table_environment=FakeTableEnvironment(), + ) + task = MaterializationTask( + project=config.project, + feature_view=feature_view, + start_time=datetime(2024, 1, 1), + end_time=datetime(2024, 1, 2), + ) + + jobs = engine.materialize(_registry(entity), task) + + assert jobs[0].status() == MaterializationJobStatus.ERROR + assert isinstance(jobs[0].error(), RuntimeError) + + +def test_flink_join_node_merges_input_tables(tmp_path: Path) -> None: + left = InputNode("left") + right = InputNode("right") + node = FlinkJoinNode( + "join", + _column_info(), + FakeTableEnvironment(), + split_num=1, + inputs=[left, right], + ) + context = _execution_context( + tmp_path, + { + "left": _flink_value( + pd.DataFrame({"driver_id": [1, 2], "conv_rate": [0.1, 0.2]}) + ), + "right": _flink_value( + pd.DataFrame({"driver_id": [1, 2], "acc_rate": [0.3, 0.4]}) + ), + }, + ) + + result = node.execute(context).data.to_pandas().sort_values("driver_id") + + assert result["conv_rate"].tolist() == [0.1, 0.2] + assert result["acc_rate"].tolist() == [0.3, 0.4] + + +def test_flink_join_node_uses_native_sql_when_available(tmp_path: Path) -> None: + left = InputNode("left") + right = InputNode("right") + table_env = RecordingTableEnvironment() + node = FlinkJoinNode( + "join", + _column_info(), + table_env, + split_num=1, + inputs=[left, right], + ) + context = _execution_context( + tmp_path, + { + "left": _native_flink_value(["driver_id", "conv_rate"]), + "right": _native_flink_value(["driver_id", "acc_rate"]), + }, + ) + + result = node.execute(context) + + assert result.format == DAGFormat.FLINK + assert result.metadata["columns"] == ["driver_id", "conv_rate", "acc_rate"] + assert any("JOIN" in query for query in table_env.queries) + + +def test_flink_filter_node_applies_filter_expression(tmp_path: Path) -> None: + input_node = InputNode("input") + node = FlinkFilterNode( + "filter", + _column_info(), + FakeTableEnvironment(), + split_num=1, + filter_expr="conv_rate > 0.15", + inputs=[input_node], + ) + context = _execution_context( + tmp_path, + { + "input": _flink_value( + pd.DataFrame({"driver_id": [1, 2], "conv_rate": [0.1, 0.2]}) + ) + }, + ) + + result = node.execute(context).data.to_pandas() + + assert result["driver_id"].tolist() == [2] + + +def test_flink_filter_node_uses_native_sql_when_available(tmp_path: Path) -> None: + input_node = InputNode("input") + table_env = RecordingTableEnvironment() + node = FlinkFilterNode( + "filter", + _column_info(), + table_env, + split_num=1, + filter_expr="conv_rate > 0.15", + inputs=[input_node], + ) + context = _execution_context( + tmp_path, + {"input": _native_flink_value(["driver_id", "conv_rate"])}, + ) + + result = node.execute(context) + + assert result.format == DAGFormat.FLINK + assert any( + "WHERE" in query and "conv_rate > 0.15" in query for query in table_env.queries + ) + + +def test_flink_filter_node_renders_ttl_as_valid_flink_interval( + tmp_path: Path, +) -> None: + input_node = InputNode("input") + table_env = RecordingTableEnvironment() + node = FlinkFilterNode( + "filter", + _column_info(), + table_env, + split_num=1, + ttl=timedelta(days=2, hours=3, minutes=4, seconds=5), + inputs=[input_node], + ) + context = _execution_context( + tmp_path, + { + "input": _native_flink_value( + ["driver_id", "conv_rate", "event_timestamp", ENTITY_TS_ALIAS] + ) + }, + ) + + node.execute(context) + + assert _flink_interval_literal( + timedelta(days=2, hours=3, minutes=4, seconds=5) + ) == ( + "INTERVAL '2' DAY + INTERVAL '3' HOUR + " + "INTERVAL '4' MINUTE + INTERVAL '5' SECOND" + ) + assert any("INTERVAL '2' DAY" in query for query in table_env.queries) + + +def test_flink_aggregation_node_groups_features(tmp_path: Path) -> None: + input_node = InputNode("input") + node = FlinkAggregationNode( + "agg", + ["driver_id"], + aggregations=[Aggregation(column="conv_rate", function="sum")], + table_env=FakeTableEnvironment(), + split_num=1, + inputs=[input_node], + ) + context = _execution_context( + tmp_path, + { + "input": _flink_value( + pd.DataFrame({"driver_id": [1, 1, 2], "conv_rate": [0.1, 0.2, 0.3]}) + ) + }, + ) + + result = node.execute(context).data.to_pandas().sort_values("driver_id") + + assert result["sum_conv_rate"].tolist() == pytest.approx([0.3, 0.3]) + + +def test_flink_aggregation_node_uses_native_sql_when_available(tmp_path: Path) -> None: + input_node = InputNode("input") + table_env = RecordingTableEnvironment() + node = FlinkAggregationNode( + "agg", + ["driver_id"], + aggregations=[Aggregation(column="conv_rate", function="sum")], + table_env=table_env, + split_num=1, + inputs=[input_node], + ) + context = _execution_context( + tmp_path, + {"input": _native_flink_value(["driver_id", "conv_rate"])}, + ) + + result = node.execute(context) + + assert result.format == DAGFormat.FLINK + assert result.metadata["columns"] == ["driver_id", "sum_conv_rate"] + assert any("GROUP BY" in query and "SUM" in query for query in table_env.queries) + + +def test_flink_dedup_node_uses_entity_row_id_for_historical_retrieval( + tmp_path: Path, +) -> None: + input_node = InputNode("input") + node = FlinkDedupNode( + "dedup", + _column_info(), + FakeTableEnvironment(), + split_num=1, + inputs=[input_node], + ) + context = _execution_context( + tmp_path, + { + "input": _flink_value( + pd.DataFrame( + { + ENTITY_ROW_ID: [0, 0, 1], + "driver_id": [1, 1, 1], + "event_timestamp": [ + datetime(2024, 1, 1, 9, 0, 0), + datetime(2024, 1, 1, 10, 0, 0), + datetime(2024, 1, 1, 10, 0, 0), + ], + "created": [ + datetime(2024, 1, 1, 9, 1, 0), + datetime(2024, 1, 1, 10, 1, 0), + datetime(2024, 1, 1, 10, 1, 0), + ], + "conv_rate": [0.1, 0.2, 0.3], + } + ) + ) + }, + ) + + result = node.execute(context).data.to_pandas().sort_values(ENTITY_ROW_ID) + + assert result["conv_rate"].tolist() == [0.2, 0.3] + + +def test_flink_dedup_node_uses_native_row_number_when_available( + tmp_path: Path, +) -> None: + input_node = InputNode("input") + table_env = RecordingTableEnvironment() + node = FlinkDedupNode( + "dedup", + _column_info(), + table_env, + split_num=1, + inputs=[input_node], + ) + context = _execution_context( + tmp_path, + { + "input": _native_flink_value( + [ENTITY_ROW_ID, "driver_id", "event_timestamp", "created", "conv_rate"] + ) + }, + ) + + result = node.execute(context) + + assert result.format == DAGFormat.FLINK + assert any("ROW_NUMBER() OVER" in query for query in table_env.queries) + assert ENTITY_ROW_ID in result.metadata["columns"] + + +def test_flink_transformation_node_keeps_native_flink_table(tmp_path: Path) -> None: + input_node = InputNode("input") + native_result = FakeNativeFlinkTable(["driver_id", "conv_rate"]) + + def native_udf(table: object) -> FakeNativeFlinkTable: + return native_result + + node = FlinkTransformationNode( + "transform", + native_udf, + RecordingTableEnvironment(), + split_num=1, + inputs=[input_node], + ) + context = _execution_context( + tmp_path, + {"input": _native_flink_value(["driver_id", "conv_rate"])}, + ) + + result = node.execute(context) + + assert result.data is native_result + assert result.metadata["columns"] == ["driver_id", "conv_rate"] + + +def test_flink_validation_node_raises_for_missing_columns(tmp_path: Path) -> None: + input_node = InputNode("input") + node = FlinkValidationNode( + "validate", + expected_columns={"missing_feature": pa.float32()}, + json_columns=set(), + table_env=FakeTableEnvironment(), + split_num=1, + inputs=[input_node], + ) + context = _execution_context( + tmp_path, + {"input": _flink_value(pd.DataFrame({"driver_id": [1]}))}, + ) + + with pytest.raises(ValueError, match="Missing expected columns"): + node.execute(context) + + +@pytest.mark.integration +@pytest.mark.slow +def test_flink_compute_engine_executes_with_real_pyflink_when_installed( + tmp_path: Path, +) -> None: + pyflink_table = pytest.importorskip( + "pyflink.table", reason="PyFlink is required for this runtime smoke test" + ) + entity = _driver() + source = _source() + feature_view = _feature_view(source, online=True, offline=True) + config = _repo_config(tmp_path, {"type": "flink.engine"}) + offline_store = _offline_store(_feature_data()) + online_store = MagicMock() + table_env = pyflink_table.TableEnvironment.create( + pyflink_table.EnvironmentSettings.new_instance().in_batch_mode().build() + ) + engine = FlinkComputeEngine( + repo_config=config, + offline_store=offline_store, + online_store=online_store, + table_environment=table_env, + ) + task = HistoricalRetrievalTask( + project=config.project, + entity_df=pd.DataFrame( + { + "driver_id": [1, 1, 2], + "event_timestamp": [ + datetime(2024, 1, 1, 9, 30, 0), + datetime(2024, 1, 1, 10, 30, 0), + datetime(2024, 1, 1, 10, 30, 0), + ], + } + ), + feature_view=feature_view, + full_feature_name=False, + registry=_registry(entity), + ) + + result = engine.get_historical_features(_registry(entity), task).to_df() + result = result.sort_values(["driver_id", ENTITY_TS_ALIAS]).reset_index(drop=True) + + assert result["conv_rate"].tolist() == [0.1, 0.2, 0.3] + assert ENTITY_ROW_ID not in result.columns + online_store.online_write_batch.assert_not_called() + + materialization_task = MaterializationTask( + project=config.project, + feature_view=feature_view, + start_time=datetime(2024, 1, 1), + end_time=datetime(2024, 1, 2), + ) + + jobs = engine.materialize(_registry(entity), materialization_task) + + assert jobs[0].status() == MaterializationJobStatus.SUCCEEDED + assert jobs[0].error() is None + online_store.online_write_batch.assert_called_once() + offline_store.offline_write_batch.assert_called_once() From 485c606ba9bec16c194d25f7aa68ef62c442e6cf Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sat, 6 Jun 2026 15:40:21 +0700 Subject: [PATCH 02/15] fix: Address flink review comments Signed-off-by: Le Xuan An --- docs/reference/compute-engine/flink.md | 12 +++++++----- pyproject.toml | 3 ++- sdk/python/feast/batch_feature_view.py | 2 +- .../feast/infra/compute_engines/flink/utils.py | 5 +++-- .../flink/test_flink_compute_engine.py | 15 +++++++++++++++ 5 files changed, 28 insertions(+), 9 deletions(-) diff --git a/docs/reference/compute-engine/flink.md b/docs/reference/compute-engine/flink.md index fee130ba550..0f598996ffa 100644 --- a/docs/reference/compute-engine/flink.md +++ b/docs/reference/compute-engine/flink.md @@ -17,15 +17,17 @@ the configured online and/or offline store. ## Configuration -Install Feast with the Flink extra before using the engine: +Install the Flink extra from a Feast source checkout with `uv` before using the +engine: ```bash -python -m pip install 'feast[flink]' +uv sync --extra flink --no-dev ``` -The `flink` extra installs PyFlink directly. Feast's Arrow dependency range is -kept compatible with PyFlink's supported `pyarrow` range so `feast[flink]` -resolves without a separate PyFlink install step. +The `flink` extra installs PyFlink directly. PyFlink currently requires +`pyarrow<21`, while the default Feast install keeps `pyarrow>=21`; Feast's uv +lock resolves the Flink extra in a separate dependency fork so normal Feast +installs do not downgrade Arrow. Configure the engine in `feature_store.yaml`: diff --git a/pyproject.toml b/pyproject.toml index ae82cbf1173..32049cd05ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ dependencies = [ "mmh3", "numpy>=2.0.0,<3", "pandas>=1.4.3,<3", - "pyarrow>=16.1.0", + "pyarrow>=21.0.0; extra != 'flink'", + "pyarrow>=16.1.0,<21.0.0; extra == 'flink'", "pydantic>=2.10.6", "pygments>=2.12.0,<3", "PyYAML>=5.4.0,<7", diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py index 3bdd7d83606..0bfbdf9d936 100644 --- a/sdk/python/feast/batch_feature_view.py +++ b/sdk/python/feast/batch_feature_view.py @@ -176,7 +176,7 @@ def get_feature_transformation(self) -> Optional[Transformation]: ) else: raise ValueError( - f"Unsupported transformation mode: {self.mode} for StreamFeatureView" + f"Unsupported transformation mode: {self.mode} for BatchFeatureView" ) diff --git a/sdk/python/feast/infra/compute_engines/flink/utils.py b/sdk/python/feast/infra/compute_engines/flink/utils.py index 43c9f285a97..dc330d45d2a 100644 --- a/sdk/python/feast/infra/compute_engines/flink/utils.py +++ b/sdk/python/feast/infra/compute_engines/flink/utils.py @@ -16,8 +16,9 @@ def create_flink_table_environment(config: FlinkComputeEngineConfig) -> Any: from pyflink.table import EnvironmentSettings, TableEnvironment except ImportError as exc: raise ImportError( - "FlinkComputeEngine requires PyFlink. Install Feast with the `flink` " - "extra or otherwise make the `pyflink` package available to Feast." + "FlinkComputeEngine requires PyFlink. Install the `flink` extra with " + "uv from a Feast source checkout, or otherwise make the `pyflink` " + "package available to Feast." ) from exc flink_conf = Configuration() diff --git a/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py b/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py index 98bb3f6d3d8..f1d7ad2e8b6 100644 --- a/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py +++ b/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py @@ -9,6 +9,7 @@ import pandas as pd import pyarrow as pa import pytest +import toml # type: ignore[import-untyped] from feast import BatchFeatureView, Entity, Field, FileSource from feast.aggregation import Aggregation @@ -45,6 +46,20 @@ from feast.value_type import ValueType +def test_flink_extra_does_not_downgrade_default_pyarrow_dependency() -> None: + pyproject_path = Path(__file__).resolve().parents[7] / "pyproject.toml" + pyproject = toml.loads(pyproject_path.read_text()) + + dependencies = pyproject["project"]["dependencies"] + assert "pyarrow>=21.0.0; extra != 'flink'" in dependencies + assert "pyarrow>=16.1.0,<21.0.0; extra == 'flink'" in dependencies + assert "pyarrow>=16.1.0" not in dependencies + assert ( + "apache-flink>=2.2.1,<3" + in pyproject["project"]["optional-dependencies"]["flink"] + ) + + class FakeFlinkTable: def __init__(self, df: pd.DataFrame) -> None: self._df = df.copy() From b3731e63a896bc57147ccfd94d87c25a478fdb7d Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sun, 7 Jun 2026 11:26:03 +0700 Subject: [PATCH 03/15] fix: Address flink copilot review comments Signed-off-by: Le Xuan An --- .../infra/compute_engines/flink/compute.py | 9 +- .../feast/infra/compute_engines/flink/job.py | 14 +- .../infra/compute_engines/flink/nodes.py | 88 ++++---- .../infra/compute_engines/flink/utils.py | 170 ++++++++++++++- .../flink/test_flink_compute_engine.py | 204 +++++++++++++++++- 5 files changed, 433 insertions(+), 52 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/flink/compute.py b/sdk/python/feast/infra/compute_engines/flink/compute.py index 4018427374e..b0cfc645a7a 100644 --- a/sdk/python/feast/infra/compute_engines/flink/compute.py +++ b/sdk/python/feast/infra/compute_engines/flink/compute.py @@ -22,7 +22,10 @@ FlinkDAGRetrievalJob, FlinkMaterializationJob, ) -from feast.infra.compute_engines.flink.utils import create_flink_table_environment +from feast.infra.compute_engines.flink.utils import ( + cleanup_flink_temporary_views, + create_flink_table_environment, +) from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry @@ -122,6 +125,8 @@ def _materialize_one( status=MaterializationJobStatus.ERROR, error=exc, ) + finally: + cleanup_flink_temporary_views(self.table_env) def get_historical_features( self, registry: BaseRegistry, task: HistoricalRetrievalTask @@ -138,6 +143,7 @@ def get_historical_features( return FlinkDAGRetrievalJob( plan=plan, context=context, + table_env=self.table_env, full_feature_names=task.full_feature_name, ) except Exception as exc: @@ -149,6 +155,7 @@ def get_historical_features( return FlinkDAGRetrievalJob( plan=None, context=context, + table_env=self.table_env, full_feature_names=task.full_feature_name, error=exc, ) diff --git a/sdk/python/feast/infra/compute_engines/flink/job.py b/sdk/python/feast/infra/compute_engines/flink/job.py index 44d9afd1824..45e80df5506 100644 --- a/sdk/python/feast/infra/compute_engines/flink/job.py +++ b/sdk/python/feast/infra/compute_engines/flink/job.py @@ -12,7 +12,10 @@ ) from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.plan import ExecutionPlan -from feast.infra.compute_engines.flink.utils import flink_table_to_arrow +from feast.infra.compute_engines.flink.utils import ( + cleanup_flink_temporary_views, + flink_table_to_arrow, +) from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata from feast.on_demand_feature_view import OnDemandFeatureView from feast.saved_dataset import SavedDatasetStorage @@ -23,6 +26,7 @@ def __init__( self, plan: Optional[ExecutionPlan], context: ExecutionContext, + table_env: object, full_feature_names: bool, on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, metadata: Optional[RetrievalMetadata] = None, @@ -30,6 +34,7 @@ def __init__( ) -> None: self._plan = plan self._context = context + self._table_env = table_env self._full_feature_names = full_feature_names self._on_demand_feature_views = on_demand_feature_views or [] self._metadata = metadata @@ -45,8 +50,11 @@ def _ensure_executed(self) -> None: raise self._error if self._plan is None: raise RuntimeError("Execution plan is not set") - result = self._plan.execute(self._context) - self._arrow_table = flink_table_to_arrow(result.data) + try: + result = self._plan.execute(self._context) + self._arrow_table = flink_table_to_arrow(result.data) + finally: + cleanup_flink_temporary_views(self._table_env) def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: self._ensure_executed() diff --git a/sdk/python/feast/infra/compute_engines/flink/nodes.py b/sdk/python/feast/infra/compute_engines/flink/nodes.py index 818efd0b1ef..a00235f4f18 100644 --- a/sdk/python/feast/infra/compute_engines/flink/nodes.py +++ b/sdk/python/feast/infra/compute_engines/flink/nodes.py @@ -16,8 +16,9 @@ from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue from feast.infra.compute_engines.flink.utils import ( - flink_table_to_pandas, + flink_table_to_arrow_batches, pandas_to_flink_table, + register_flink_temporary_view, ) from feast.infra.compute_engines.utils import create_offline_store_retrieval_job from feast.infra.offline_stores.offline_utils import ( @@ -47,15 +48,15 @@ def _select_column(alias: str, column: str, output_name: Optional[str] = None) - return expr -def _flink_interval_literal(value: timedelta) -> str: +def _flink_interval_literals(value: timedelta) -> List[str]: total_seconds = int(value.total_seconds()) if total_seconds <= 0: - return "INTERVAL '0' SECOND" + return ["INTERVAL '0' SECOND"] days, remainder = divmod(total_seconds, 24 * 60 * 60) hours, remainder = divmod(remainder, 60 * 60) minutes, seconds = divmod(remainder, 60) - parts = [] + parts: List[str] = [] if days: parts.append(f"INTERVAL '{days}' DAY") if hours: @@ -64,7 +65,14 @@ def _flink_interval_literal(value: timedelta) -> str: parts.append(f"INTERVAL '{minutes}' MINUTE") if seconds: parts.append(f"INTERVAL '{seconds}' SECOND") - return " + ".join(parts) + return parts + + +def _subtract_flink_intervals(timestamp_expr: str, value: timedelta) -> str: + result = timestamp_expr + for interval in _flink_interval_literals(value): + result = f"{result} - {interval}" + return result def _get_columns_from_schema(table: Any) -> Optional[List[str]]: @@ -107,6 +115,7 @@ def _require_sql(table_env: Any, node_name: str) -> None: def _register_table(table_env: Any, table: Any, prefix: str) -> str: view_name = f"__feast_{prefix}_{uuid.uuid4().hex}" table_env.create_temporary_view(view_name, table) + register_flink_temporary_view(table_env, view_name) return view_name @@ -447,11 +456,11 @@ def _execute_sql_filter(self, input_value: DAGValue) -> DAGValue: f"{_quote_identifier(ENTITY_TS_ALIAS)}" ) if self.ttl: - ttl_interval = _flink_interval_literal(self.ttl) + lower_bound = _subtract_flink_intervals( + _quote_identifier(ENTITY_TS_ALIAS), self.ttl + ) conditions.append( - f"{_quote_identifier(timestamp_column)} >= " - f"{_quote_identifier(ENTITY_TS_ALIAS)} - " - f"({ttl_interval})" + f"{_quote_identifier(timestamp_column)} >= {lower_bound}" ) if self.filter_expr: @@ -708,44 +717,49 @@ def execute(self, context: ExecutionContext) -> DAGValue: if not self.write_output: return output_value - output_df = flink_table_to_pandas(output_table) - output_arrow = pa.Table.from_pandas(output_df) - - if output_arrow.num_rows == 0: - return output_value - + columns = _get_columns(output_value) + batch_size = context.repo_config.materialization_config.online_write_batch_size if self.feature_view.online: join_key_to_value_type = { entity.name: entity.dtype.to_value_type() for entity in self.feature_view.entity_columns } - batch_size = ( - context.repo_config.materialization_config.online_write_batch_size - ) - batches = ( - [output_arrow] - if batch_size is None - else output_arrow.to_batches(max_chunksize=batch_size) - ) - for batch in batches: - rows_to_write = _convert_arrow_to_proto( - batch, self.feature_view, join_key_to_value_type + else: + join_key_to_value_type = {} + + for output_arrow in flink_table_to_arrow_batches( + output_table, + columns, + batch_size, + ): + if output_arrow.num_rows == 0: + continue + + if self.feature_view.online: + arrow_batches = ( + [output_arrow] + if batch_size is None + else output_arrow.to_batches(max_chunksize=batch_size) ) - context.online_store.online_write_batch( + for batch in arrow_batches: + rows_to_write = _convert_arrow_to_proto( + batch, self.feature_view, join_key_to_value_type + ) + context.online_store.online_write_batch( + config=context.repo_config, + table=self.feature_view, + data=rows_to_write, + progress=lambda x: None, + ) + + if self.feature_view.offline: + context.offline_store.offline_write_batch( config=context.repo_config, - table=self.feature_view, - data=rows_to_write, + feature_view=self.feature_view, + table=output_arrow, progress=lambda x: None, ) - if self.feature_view.offline: - context.offline_store.offline_write_batch( - config=context.repo_config, - feature_view=self.feature_view, - table=output_arrow, - progress=lambda x: None, - ) - return output_value def _drop_internal_columns(self, input_value: DAGValue) -> DAGValue: diff --git a/sdk/python/feast/infra/compute_engines/flink/utils.py b/sdk/python/feast/infra/compute_engines/flink/utils.py index dc330d45d2a..f271e8bb113 100644 --- a/sdk/python/feast/infra/compute_engines/flink/utils.py +++ b/sdk/python/feast/infra/compute_engines/flink/utils.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +import logging +from datetime import date, datetime +from decimal import Decimal +from typing import TYPE_CHECKING, Any, Iterator, Optional import pandas as pd import pyarrow as pa @@ -8,6 +11,11 @@ if TYPE_CHECKING: from feast.infra.compute_engines.flink.compute import FlinkComputeEngineConfig +logger = logging.getLogger(__name__) + +DEFAULT_FLINK_RESULT_BATCH_SIZE = 10_000 +_TEMP_VIEW_REGISTRY: dict[int, set[str]] = {} + def create_flink_table_environment(config: FlinkComputeEngineConfig) -> Any: """Create a PyFlink TableEnvironment from Feast engine config.""" @@ -37,8 +45,11 @@ def create_flink_table_environment(config: FlinkComputeEngineConfig) -> Any: def pandas_to_flink_table(table_env: Any, df: pd.DataFrame, split_num: int = 1) -> Any: """Convert a pandas DataFrame to a PyFlink table.""" - schema = list(df.columns) - return table_env.from_pandas(df, schema=schema, splits_num=split_num) + schema = _build_pandas_flink_schema(df) + kwargs: dict[str, Any] = {"splits_num": split_num} + if schema is not None: + kwargs["schema"] = schema + return table_env.from_pandas(df, **kwargs) def flink_table_to_pandas(table: Any) -> pd.DataFrame: @@ -52,3 +63,156 @@ def flink_table_to_arrow(table: Any) -> pa.Table: """Collect a PyFlink table into Arrow.""" value = flink_table_to_pandas(table) return pa.Table.from_pandas(value) + + +def flink_table_to_arrow_batches( + table: Any, + columns: list[str], + batch_size: Optional[int], +) -> Iterator[pa.Table]: + """Stream a PyFlink table into Arrow tables without collecting it all first.""" + effective_batch_size = batch_size or DEFAULT_FLINK_RESULT_BATCH_SIZE + if effective_batch_size <= 0: + raise ValueError("Flink result batch size must be positive.") + + if hasattr(table, "execute"): + rows: list[dict[str, Any]] = [] + row_iterator = table.execute().collect() + try: + for row in row_iterator: + rows.append(_row_to_dict(row, columns)) + if len(rows) >= effective_batch_size: + yield _rows_to_arrow(rows, columns) + rows = [] + if rows: + yield _rows_to_arrow(rows, columns) + finally: + close = getattr(row_iterator, "close", None) + if callable(close): + close() + return + + if hasattr(table, "to_pandas"): + # Used by unit-test fakes and other table-like objects. Real PyFlink tables use + # the execute().collect() path above. + df = table.to_pandas() + for start in range(0, len(df), effective_batch_size): + yield pa.Table.from_pandas( + df.iloc[start : start + effective_batch_size], + preserve_index=False, + ) + return + + raise TypeError(f"Expected a PyFlink table, got {type(table)}") + + +def register_flink_temporary_view(table_env: Any, view_name: str) -> None: + """Track Feast-created temporary views so long-lived table envs can clean up.""" + views = _temporary_views_for_env(table_env) + views.add(view_name) + + +def cleanup_flink_temporary_views(table_env: Any) -> None: + """Drop Feast-created temporary views from a PyFlink TableEnvironment.""" + views = _temporary_views_for_env(table_env) + if not views: + return + + drop_view = getattr(table_env, "drop_temporary_view", None) + if not callable(drop_view): + return + + for view_name in list(views): + try: + drop_view(view_name) + except Exception as exc: + logger.debug("Failed to drop Flink temporary view %s: %s", view_name, exc) + finally: + views.discard(view_name) + if not views: + _TEMP_VIEW_REGISTRY.pop(id(table_env), None) + + +def _build_pandas_flink_schema(df: pd.DataFrame) -> Any: + try: + from pyflink.table import Schema + except ImportError: + return None + + builder = Schema.new_builder() + for column in df.columns: + builder.column(str(column), _pandas_dtype_to_flink_type(df[column])) + return builder.build() + + +def _pandas_dtype_to_flink_type(series: pd.Series) -> Any: + from pyflink.table import DataTypes + + dtype = series.dtype + if pd.api.types.is_bool_dtype(dtype): + return DataTypes.BOOLEAN() + if pd.api.types.is_integer_dtype(dtype): + dtype_name = str(dtype).lower() + if dtype_name.endswith("int8"): + return DataTypes.TINYINT() + if dtype_name.endswith("int16"): + return DataTypes.SMALLINT() + if dtype_name.endswith("int32"): + return DataTypes.INT() + return DataTypes.BIGINT() + if pd.api.types.is_float_dtype(dtype): + return DataTypes.FLOAT() if str(dtype) == "float32" else DataTypes.DOUBLE() + if pd.api.types.is_datetime64_any_dtype(dtype): + if getattr(dtype, "tz", None) is not None: + return DataTypes.TIMESTAMP_LTZ(3) + return DataTypes.TIMESTAMP(3) + if pd.api.types.is_timedelta64_dtype(dtype): + return DataTypes.BIGINT() + + sample = _first_non_null_value(series) + if isinstance(sample, (bytes, bytearray)): + return DataTypes.BYTES() + if isinstance(sample, date) and not isinstance(sample, datetime): + return DataTypes.DATE() + if isinstance(sample, Decimal): + return DataTypes.DECIMAL(38, 18) + return DataTypes.STRING() + + +def _first_non_null_value(series: pd.Series) -> Any: + for value in series: + if pd.notna(value): + return value + return None + + +def _row_to_dict(row: Any, columns: list[str]) -> dict[str, Any]: + if isinstance(row, dict): + return {column: row.get(column) for column in columns} + + as_dict = getattr(row, "as_dict", None) + if callable(as_dict): + row_dict = as_dict() + if row_dict: + return {column: row_dict.get(column) for column in columns} + + values = list(row) + return dict(zip(columns, values)) + + +def _rows_to_arrow(rows: list[dict[str, Any]], columns: list[str]) -> pa.Table: + df = pd.DataFrame.from_records(rows, columns=columns) + return pa.Table.from_pandas(df, preserve_index=False) + + +def _temporary_views_for_env(table_env: Any) -> set[str]: + views = getattr(table_env, "_feast_temporary_views", None) + if isinstance(views, set): + return views + + views = _TEMP_VIEW_REGISTRY.setdefault(id(table_env), set()) + try: + setattr(table_env, "_feast_temporary_views", views) + except Exception: + pass + return views diff --git a/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py b/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py index f1d7ad2e8b6..13ee329616b 100644 --- a/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py +++ b/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py @@ -1,6 +1,8 @@ from __future__ import annotations import re +import sys +import types from datetime import datetime, timedelta from pathlib import Path from typing import Any, List, Optional @@ -33,11 +35,13 @@ FlinkDedupNode, FlinkFilterNode, FlinkJoinNode, + FlinkOutputNode, FlinkSourceReadNode, FlinkTransformationNode, FlinkValidationNode, - _flink_interval_literal, + _subtract_flink_intervals, ) +from feast.infra.compute_engines.flink.utils import pandas_to_flink_table from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import RepoConfig @@ -61,21 +65,52 @@ def test_flink_extra_does_not_downgrade_default_pyarrow_dependency() -> None: class FakeFlinkTable: - def __init__(self, df: pd.DataFrame) -> None: + def __init__(self, df: pd.DataFrame, fail_on_to_pandas: bool = False) -> None: self._df = df.copy() + self.fail_on_to_pandas = fail_on_to_pandas def to_pandas(self) -> pd.DataFrame: + if self.fail_on_to_pandas: + raise AssertionError("FlinkOutputNode should not collect via to_pandas()") return self._df.copy() def get_schema(self) -> FakeFlinkSchema: return FakeFlinkSchema(list(self._df.columns)) + def execute(self) -> FakeTableResult: + return FakeTableResult(self._df) + + +class FakeTableResult: + def __init__(self, df: pd.DataFrame) -> None: + self._df = df.copy() + + def collect(self) -> FakeCloseableIterator: + return FakeCloseableIterator(self._df) + + +class FakeCloseableIterator: + def __init__(self, df: pd.DataFrame) -> None: + self._rows = iter(df.itertuples(index=False, name=None)) + self.closed = False + + def __iter__(self) -> FakeCloseableIterator: + return self + + def __next__(self) -> tuple[Any, ...]: + return next(self._rows) + + def close(self) -> None: + self.closed = True + class FakeTableEnvironment: def __init__(self) -> None: self.created_tables: List[pd.DataFrame] = [] + self.schemas: List[object] = [] self.split_nums: List[int] = [] self.views: dict[str, object] = {} + self.dropped_views: List[str] = [] self.queries: List[str] = [] def from_pandas( @@ -86,6 +121,7 @@ def from_pandas( split_num: Optional[int] = None, ) -> FakeFlinkTable: self.created_tables.append(df.copy()) + self.schemas.append(schema) self.split_nums.append(split_num if split_num is not None else splits_num) return FakeFlinkTable(df) @@ -94,6 +130,10 @@ def create_temporary_view( ) -> None: self.views[view_path] = table_or_data_stream + def drop_temporary_view(self, view_path: str) -> None: + self.dropped_views.append(view_path) + self.views.pop(view_path, None) + def sql_query(self, query: str) -> Any: self.queries.append(query) return FakeFlinkTable(self._evaluate_sql(query)) @@ -379,6 +419,109 @@ def _flink_value(df: pd.DataFrame) -> DAGValue: ) +def test_pandas_to_flink_table_builds_typed_schema( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeSchemaBuilder: + def __init__(self) -> None: + self.columns: List[tuple[str, str]] = [] + + def column(self, name: str, dtype: str) -> FakeSchemaBuilder: + self.columns.append((name, dtype)) + return self + + def build(self) -> List[tuple[str, str]]: + return self.columns + + class FakeSchema: + @staticmethod + def new_builder() -> FakeSchemaBuilder: + return FakeSchemaBuilder() + + class FakeDataTypes: + @staticmethod + def BOOLEAN() -> str: + return "BOOLEAN" + + @staticmethod + def TINYINT() -> str: + return "TINYINT" + + @staticmethod + def SMALLINT() -> str: + return "SMALLINT" + + @staticmethod + def INT() -> str: + return "INT" + + @staticmethod + def BIGINT() -> str: + return "BIGINT" + + @staticmethod + def FLOAT() -> str: + return "FLOAT" + + @staticmethod + def DOUBLE() -> str: + return "DOUBLE" + + @staticmethod + def TIMESTAMP(precision: int) -> str: + return f"TIMESTAMP({precision})" + + @staticmethod + def TIMESTAMP_LTZ(precision: int) -> str: + return f"TIMESTAMP_LTZ({precision})" + + @staticmethod + def BYTES() -> str: + return "BYTES" + + @staticmethod + def DATE() -> str: + return "DATE" + + @staticmethod + def DECIMAL(precision: int, scale: int) -> str: + return f"DECIMAL({precision},{scale})" + + @staticmethod + def STRING() -> str: + return "STRING" + + pyflink_module = types.ModuleType("pyflink") + table_module = types.ModuleType("pyflink.table") + setattr(table_module, "Schema", FakeSchema) + setattr(table_module, "DataTypes", FakeDataTypes) + monkeypatch.setitem(sys.modules, "pyflink", pyflink_module) + monkeypatch.setitem(sys.modules, "pyflink.table", table_module) + + table_env = FakeTableEnvironment() + df = pd.DataFrame( + { + "driver_id": pd.Series([1, 2], dtype="int64"), + "conv_rate": pd.Series([0.1, 0.2], dtype="float64"), + "event_timestamp": pd.to_datetime( + ["2024-01-01 00:00:00", "2024-01-02 00:00:00"] + ), + "active": pd.Series([True, False], dtype="bool"), + } + ) + + pandas_to_flink_table(table_env, df, split_num=4) + + assert table_env.schemas[-1] == [ + ("driver_id", "BIGINT"), + ("conv_rate", "DOUBLE"), + ("event_timestamp", "TIMESTAMP(3)"), + ("active", "BOOLEAN"), + ] + assert table_env.schemas[-1] != list(df.columns) + assert table_env.split_nums == [4] + + def _native_flink_value(columns: List[str]) -> DAGValue: return DAGValue( data=FakeNativeFlinkTable(columns), @@ -469,6 +612,8 @@ def double_conv_rate(table: FakeFlinkTable) -> FakeFlinkTable: assert job.error() is None assert result["driver_id"].tolist() == [1, 2] assert result["conv_rate"].tolist() == [0.4, 0.6] + assert table_env.dropped_views + assert table_env.views == {} def test_flink_historical_retrieval_is_read_only_and_dedupes_per_entity_row( @@ -570,6 +715,7 @@ def test_flink_historical_retrieval_supports_sql_entity_df(tmp_path: Path) -> No "SELECT driver_id, event_timestamp FROM entities" in query for query in table_env.queries ) + assert set(table_env.views) == {"entities"} def test_flink_materialize_writes_online_and_offline(tmp_path: Path) -> None: @@ -579,11 +725,12 @@ def test_flink_materialize_writes_online_and_offline(tmp_path: Path) -> None: config = _repo_config(tmp_path, {"type": "flink.engine"}) offline_store = _offline_store(_feature_data().head(1)) online_store = MagicMock() + table_env = FakeTableEnvironment() engine = FlinkComputeEngine( repo_config=config, offline_store=offline_store, online_store=online_store, - table_environment=FakeTableEnvironment(), + table_environment=table_env, ) task = MaterializationTask( project=config.project, @@ -599,6 +746,42 @@ def test_flink_materialize_writes_online_and_offline(tmp_path: Path) -> None: assert jobs[0].error() is None online_store.online_write_batch.assert_called_once() offline_store.offline_write_batch.assert_called_once() + assert table_env.dropped_views + assert table_env.views == {} + + +def test_flink_output_node_streams_batches_without_full_pandas_collect( + tmp_path: Path, +) -> None: + source = _source() + feature_view = _feature_view(source, online=True, offline=True) + input_node = InputNode("input") + node = FlinkOutputNode( + "output", + feature_view, + FakeTableEnvironment(), + split_num=1, + write_output=True, + inputs=[input_node], + ) + context = _execution_context( + tmp_path, + { + "input": DAGValue( + data=FakeFlinkTable(_feature_data().head(2), fail_on_to_pandas=True), + format=DAGFormat.FLINK, + metadata={"columns": list(_feature_data().columns)}, + ) + }, + ) + context.repo_config.materialization_config.online_write_batch_size = 1 + context.online_store = MagicMock() + context.offline_store = MagicMock() + + node.execute(context) + + assert context.online_store.online_write_batch.call_count == 2 + assert context.offline_store.offline_write_batch.call_count == 2 def test_flink_engine_reports_materialization_errors(tmp_path: Path) -> None: @@ -753,13 +936,18 @@ def test_flink_filter_node_renders_ttl_as_valid_flink_interval( node.execute(context) - assert _flink_interval_literal( - timedelta(days=2, hours=3, minutes=4, seconds=5) + assert _subtract_flink_intervals( + "`__entity_event_timestamp`", timedelta(days=2, hours=3, minutes=4, seconds=5) ) == ( - "INTERVAL '2' DAY + INTERVAL '3' HOUR + " - "INTERVAL '4' MINUTE + INTERVAL '5' SECOND" + "`__entity_event_timestamp` - INTERVAL '2' DAY - INTERVAL '3' HOUR " + "- INTERVAL '4' MINUTE - INTERVAL '5' SECOND" + ) + assert any( + "`event_timestamp` >= `__entity_event_timestamp` - INTERVAL '2' DAY " + "- INTERVAL '3' HOUR - INTERVAL '4' MINUTE - INTERVAL '5' SECOND" in query + for query in table_env.queries ) - assert any("INTERVAL '2' DAY" in query for query in table_env.queries) + assert all("+ INTERVAL" not in query for query in table_env.queries) def test_flink_aggregation_node_groups_features(tmp_path: Path) -> None: From 2aae6005d00937ca5fec4b5102aced48bf4cc462 Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sun, 7 Jun 2026 11:32:30 +0700 Subject: [PATCH 04/15] fix: Refresh pixi lock for flink extra Signed-off-by: Le Xuan An --- pixi.lock | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pixi.lock b/pixi.lock index a1ba2cf937a..13b92074447 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2399,8 +2399,8 @@ packages: requires_python: '>=3.10' - pypi: ./ name: feast - version: 0.63.1.dev37+g55c2f185f.d20260529 - sha256: 0e2ee7ed164b2fd366ec4d7fca9b273e92393cf16030342ecc102585d1e8e037 + version: 0.63.1.dev62+gb3731e63a + sha256: e130b9bd91de0cede93d1f22fe696c8b1ac6866ebeeb4c89ab18d38f5d8c89a0 requires_dist: - click>=7.0.0,<9.0.0 - colorama>=0.3.9,<1 @@ -2411,7 +2411,8 @@ packages: - mmh3 - numpy>=2.0.0,<3 - pandas>=1.4.3,<3 - - pyarrow>=21.0.0 + - pyarrow>=21.0.0 ; extra != 'flink' + - pyarrow>=16.1.0,<21.0.0 ; extra == 'flink' - pydantic>=2.10.6 - pygments>=2.12.0,<3 - pyyaml>=5.4.0,<7 @@ -2449,6 +2450,7 @@ packages: - ibis-framework[duckdb]>=10.0.0 ; extra == 'duckdb' - elasticsearch>=8.13.0 ; extra == 'elasticsearch' - faiss-cpu>=1.7.0,<=1.10.0 ; extra == 'faiss' + - apache-flink>=2.2.1,<3 ; extra == 'flink' - google-api-core>=1.23.0,<3 ; extra == 'gcp' - googleapis-common-protos>=1.52.0,<2 ; extra == 'gcp' - google-cloud-bigquery[pandas]>=2,<4 ; extra == 'gcp' From 37accc60b0566681b8db667a56fc482ec49f210d Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sun, 7 Jun 2026 11:44:21 +0700 Subject: [PATCH 05/15] fix: Satisfy flink feature builder typing Signed-off-by: Le Xuan An --- .../feast/infra/compute_engines/flink/feature_builder.py | 2 +- sdk/python/feast/infra/compute_engines/flink/nodes.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/flink/feature_builder.py b/sdk/python/feast/infra/compute_engines/flink/feature_builder.py index 4f4abe7bea1..1b8aa76304c 100644 --- a/sdk/python/feast/infra/compute_engines/flink/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/flink/feature_builder.py @@ -51,7 +51,7 @@ def _should_join_entity_df(self) -> bool: def _build(self, view: Any, input_nodes: list[DAGNode] | None) -> DAGNode: if view.data_source: - last_node = self.build_source_node(view) + last_node: DAGNode = self.build_source_node(view) if self._should_transform(view): last_node = self.build_transformation_node(view, [last_node]) diff --git a/sdk/python/feast/infra/compute_engines/flink/nodes.py b/sdk/python/feast/infra/compute_engines/flink/nodes.py index a00235f4f18..780d8acbaf4 100644 --- a/sdk/python/feast/infra/compute_engines/flink/nodes.py +++ b/sdk/python/feast/infra/compute_engines/flink/nodes.py @@ -8,7 +8,7 @@ import pandas as pd import pyarrow as pa -from feast import BatchFeatureView, StreamFeatureView +from feast import BatchFeatureView, FeatureView, StreamFeatureView from feast.aggregation import Aggregation, aggregation_specs_to_agg_ops from feast.data_source import DataSource from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext @@ -697,7 +697,7 @@ class FlinkOutputNode(DAGNode): def __init__( self, name: str, - feature_view: Union[BatchFeatureView, StreamFeatureView], + feature_view: Union[BatchFeatureView, FeatureView, StreamFeatureView], table_env: Any, split_num: int, write_output: bool, From c99bc7fe5d2e0a18bd9ceacf54f00d840fb2d9ee Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sun, 7 Jun 2026 12:27:39 +0700 Subject: [PATCH 06/15] refactor: Share compute engine timestamp helpers Signed-off-by: Le Xuan An --- .../infra/compute_engines/flink/nodes.py | 20 +++++++-------- .../infra/compute_engines/local/nodes.py | 9 +++---- .../feast/infra/compute_engines/ray/nodes.py | 8 +++--- .../infra/compute_engines/spark/nodes.py | 12 +++------ .../feast/infra/compute_engines/utils.py | 25 ++++++++++++++++++- 5 files changed, 44 insertions(+), 30 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/flink/nodes.py b/sdk/python/feast/infra/compute_engines/flink/nodes.py index 780d8acbaf4..33c567f03ce 100644 --- a/sdk/python/feast/infra/compute_engines/flink/nodes.py +++ b/sdk/python/feast/infra/compute_engines/flink/nodes.py @@ -20,16 +20,17 @@ pandas_to_flink_table, register_flink_temporary_view, ) -from feast.infra.compute_engines.utils import create_offline_store_retrieval_job -from feast.infra.offline_stores.offline_utils import ( - infer_event_timestamp_from_entity_df, +from feast.infra.compute_engines.utils import ( + ENTITY_ROW_ID, + ENTITY_TS_ALIAS, + create_offline_store_retrieval_job, + find_entity_timestamp_column, + infer_entity_timestamp_column, ) from feast.utils import _convert_arrow_to_proto logger = logging.getLogger(__name__) -ENTITY_TS_ALIAS = "__entity_event_timestamp" -ENTITY_ROW_ID = "__feast_entity_row_id" DEDUP_ROW_NUMBER = "__feast_row_number" @@ -133,10 +134,9 @@ def _sql_value( def _entity_timestamp_column_from_columns(columns: List[str]) -> str: - if ENTITY_TS_ALIAS in columns: - return ENTITY_TS_ALIAS - if "event_timestamp" in columns: - return "event_timestamp" + entity_ts_col = find_entity_timestamp_column(columns) + if entity_ts_col: + return entity_ts_col raise ValueError( "SQL-based entity_df for FlinkComputeEngine must select an " "`event_timestamp` column." @@ -151,7 +151,7 @@ def _entity_value_from_dataframe( entity_df = entity_df.copy() entity_df[ENTITY_ROW_ID] = range(len(entity_df)) entity_schema = dict(zip(entity_df.columns, entity_df.dtypes)) - entity_ts_col = infer_event_timestamp_from_entity_df(entity_schema) + entity_ts_col = infer_entity_timestamp_column(entity_schema) if entity_ts_col != ENTITY_TS_ALIAS: entity_df = entity_df.rename(columns={entity_ts_col: ENTITY_TS_ALIAS}) return ( diff --git a/sdk/python/feast/infra/compute_engines/local/nodes.py b/sdk/python/feast/infra/compute_engines/local/nodes.py index 3274568671b..9d3e1a48881 100644 --- a/sdk/python/feast/infra/compute_engines/local/nodes.py +++ b/sdk/python/feast/infra/compute_engines/local/nodes.py @@ -14,17 +14,14 @@ from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue from feast.infra.compute_engines.local.local_node import LocalNode from feast.infra.compute_engines.utils import ( + ENTITY_TS_ALIAS, create_offline_store_retrieval_job, -) -from feast.infra.offline_stores.offline_utils import ( - infer_event_timestamp_from_entity_df, + infer_entity_timestamp_column, ) from feast.utils import _convert_arrow_to_proto logger = logging.getLogger(__name__) -ENTITY_TS_ALIAS = "__entity_event_timestamp" - class LocalSourceReadNode(LocalNode): def __init__( @@ -99,7 +96,7 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: entity_df = self.backend.from_arrow(pa.Table.from_pandas(context.entity_df)) entity_schema = dict(zip(entity_df.columns, entity_df.dtypes)) - entity_ts_col = infer_event_timestamp_from_entity_df(entity_schema) + entity_ts_col = infer_entity_timestamp_column(entity_schema) if entity_ts_col != ENTITY_TS_ALIAS: entity_df = self.backend.rename_columns( diff --git a/sdk/python/feast/infra/compute_engines/ray/nodes.py b/sdk/python/feast/infra/compute_engines/ray/nodes.py index 026c46d4233..edc6dafd5cd 100644 --- a/sdk/python/feast/infra/compute_engines/ray/nodes.py +++ b/sdk/python/feast/infra/compute_engines/ray/nodes.py @@ -23,7 +23,10 @@ safe_batch_processor, write_to_online_store, ) -from feast.infra.compute_engines.utils import create_offline_store_retrieval_job +from feast.infra.compute_engines.utils import ( + ENTITY_TS_ALIAS, + create_offline_store_retrieval_job, +) from feast.infra.ray_initializer import get_ray_wrapper from feast.infra.ray_shared_utils import ( apply_field_mapping, @@ -33,9 +36,6 @@ logger = logging.getLogger(__name__) -# Entity timestamp alias for historical feature retrieval -ENTITY_TS_ALIAS = "__entity_event_timestamp" - class RayReadNode(DAGNode): """ diff --git a/sdk/python/feast/infra/compute_engines/spark/nodes.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py index 5a8c4368fc5..92964b72bc9 100644 --- a/sdk/python/feast/infra/compute_engines/spark/nodes.py +++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py @@ -34,7 +34,9 @@ from feast.infra.compute_engines.dag.value import DAGValue from feast.infra.compute_engines.spark.utils import map_in_arrow from feast.infra.compute_engines.utils import ( + ENTITY_TS_ALIAS, create_offline_store_retrieval_job, + infer_entity_timestamp_column, ) from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkRetrievalJob, @@ -43,9 +45,6 @@ from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( SparkSource, ) -from feast.infra.offline_stores.offline_utils import ( - infer_event_timestamp_from_entity_df, -) logger = logging.getLogger(__name__) @@ -144,9 +143,6 @@ def _spark_types_compatible(expected: SparkDataType, actual: SparkDataType) -> b return False -ENTITY_TS_ALIAS = "__entity_event_timestamp" - - # Rename entity_df event_timestamp_col to match feature_df def rename_entity_ts_column( spark_session: SparkSession, entity_df: DataFrame @@ -159,9 +155,7 @@ def rename_entity_ts_column( spark_session=spark_session, entity_df=entity_df, ) - event_timestamp_col = infer_event_timestamp_from_entity_df( - entity_schema=entity_schema, - ) + event_timestamp_col = infer_entity_timestamp_column(entity_schema) if not isinstance(entity_df, DataFrame): entity_df = spark_session.createDataFrame(entity_df) entity_df = entity_df.withColumnRenamed(event_timestamp_col, ENTITY_TS_ALIAS) diff --git a/sdk/python/feast/infra/compute_engines/utils.py b/sdk/python/feast/infra/compute_engines/utils.py index d2c49305376..c36f4d19f07 100644 --- a/sdk/python/feast/infra/compute_engines/utils.py +++ b/sdk/python/feast/infra/compute_engines/utils.py @@ -1,9 +1,32 @@ from datetime import datetime -from typing import Optional +from typing import Any, Mapping, Optional, Sequence from feast.data_source import DataSource from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext from feast.infra.offline_stores.offline_store import RetrievalJob +from feast.infra.offline_stores.offline_utils import ( + DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, + infer_event_timestamp_from_entity_df, +) + +ENTITY_TS_ALIAS = "__entity_event_timestamp" +ENTITY_ROW_ID = "__feast_entity_row_id" + + +def infer_entity_timestamp_column(entity_schema: Mapping[str, Any]) -> str: + """Resolve the entity timestamp column used for point-in-time joins.""" + if ENTITY_TS_ALIAS in entity_schema: + return ENTITY_TS_ALIAS + return infer_event_timestamp_from_entity_df(dict(entity_schema)) + + +def find_entity_timestamp_column(columns: Sequence[str]) -> Optional[str]: + """Find the timestamp column in an entity DataFrame schema, if present.""" + if ENTITY_TS_ALIAS in columns: + return ENTITY_TS_ALIAS + if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in columns: + return DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL + return None def create_offline_store_retrieval_job( From 6ba7a5305feee99bfabc86fdda29ca4ae9fd5ec0 Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sun, 7 Jun 2026 13:14:11 +0700 Subject: [PATCH 07/15] test: Cover flink transformation mode Signed-off-by: Le Xuan An --- .../tests/unit/transformation/test_factory.py | 14 ++++++++++++++ sdk/python/tests/unit/transformation/test_mode.py | 12 +++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sdk/python/tests/unit/transformation/test_factory.py b/sdk/python/tests/unit/transformation/test_factory.py index 4484af0ae23..64c0e282239 100644 --- a/sdk/python/tests/unit/transformation/test_factory.py +++ b/sdk/python/tests/unit/transformation/test_factory.py @@ -17,6 +17,7 @@ def test_all_expected_types_registered(self): "sql", "spark_sql", "spark", + "flink", "ray", } assert set(TRANSFORMATION_CLASS_FOR_TYPE.keys()) == expected_types @@ -69,6 +70,19 @@ def test_sql_type_resolves(self, mock_import): "SQLTransformation", ) + @patch("feast.transformation.factory.import_class") + def test_flink_type_resolves(self, mock_import): + mock_cls = MagicMock() + mock_import.return_value = mock_cls + + get_transformation_class_from_type("flink") + + mock_import.assert_called_once_with( + "feast.transformation.flink_transformation", + "FlinkTransformation", + "FlinkTransformation", + ) + def test_invalid_type_raises_value_error(self): with pytest.raises(ValueError, match="Invalid transformation type"): get_transformation_class_from_type("nonexistent") diff --git a/sdk/python/tests/unit/transformation/test_mode.py b/sdk/python/tests/unit/transformation/test_mode.py index d37fc97e28d..3b35edcba0f 100644 --- a/sdk/python/tests/unit/transformation/test_mode.py +++ b/sdk/python/tests/unit/transformation/test_mode.py @@ -3,7 +3,16 @@ class TestTransformationMode: def test_all_modes_defined(self): - expected = {"PYTHON", "PANDAS", "SPARK_SQL", "SPARK", "RAY", "SQL", "SUBSTRAIT"} + expected = { + "PYTHON", + "PANDAS", + "SPARK_SQL", + "SPARK", + "FLINK", + "RAY", + "SQL", + "SUBSTRAIT", + } actual = {m.name for m in TransformationMode} assert actual == expected @@ -12,6 +21,7 @@ def test_mode_values(self): assert TransformationMode.PANDAS.value == "pandas" assert TransformationMode.SPARK_SQL.value == "spark_sql" assert TransformationMode.SPARK.value == "spark" + assert TransformationMode.FLINK.value == "flink" assert TransformationMode.RAY.value == "ray" assert TransformationMode.SQL.value == "sql" assert TransformationMode.SUBSTRAIT.value == "substrait" From 4bf143394d6b7367404ecd9165f1213dd0348abb Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sun, 7 Jun 2026 14:31:10 +0700 Subject: [PATCH 08/15] test: Stabilize feature server performance check Signed-off-by: Le Xuan An --- sdk/python/tests/unit/test_feature_server_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sdk/python/tests/unit/test_feature_server_utils.py b/sdk/python/tests/unit/test_feature_server_utils.py index 3c749f9e70a..9a53d69ff54 100644 --- a/sdk/python/tests/unit/test_feature_server_utils.py +++ b/sdk/python/tests/unit/test_feature_server_utils.py @@ -23,6 +23,7 @@ import base64 import json +import platform import time import pytest @@ -655,7 +656,12 @@ def test_faster_than_message_to_dict(self): print(f"\nPerformance: fast={fast_time:.3f}s, standard={standard_time:.3f}s") print(f"Speedup: {speedup:.2f}x") - assert speedup >= 1.5, f"Expected at least 1.5x speedup, got {speedup:.2f}x" + # GitHub-hosted macOS runners show more timing variance on this + # microbenchmark while still validating a meaningful speedup. + min_speedup = 1.3 if platform.system() == "Darwin" else 1.5 + assert speedup >= min_speedup, ( + f"Expected at least {min_speedup:.1f}x speedup, got {speedup:.2f}x" + ) class TestStatusNames: From 7aebf7872661d21dd2b2406d71834887ee0b5ea5 Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sun, 7 Jun 2026 15:29:19 +0700 Subject: [PATCH 09/15] test: Mock image extractor model loading Signed-off-by: Le Xuan An --- sdk/python/tests/unit/test_image_utils.py | 35 ++++++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/sdk/python/tests/unit/test_image_utils.py b/sdk/python/tests/unit/test_image_utils.py index e635b477ce6..f54077f08cf 100644 --- a/sdk/python/tests/unit/test_image_utils.py +++ b/sdk/python/tests/unit/test_image_utils.py @@ -17,10 +17,7 @@ import pytest from PIL import Image -pytest.importorskip("torch") -pytest.importorskip("timm") -pytest.importorskip("sklearn") - +import feast.image_utils as image_utils from feast.image_utils import ( ImageFeatureExtractor, combine_embeddings, @@ -28,6 +25,36 @@ validate_image_format, ) +torch = pytest.importorskip("torch") +pytest.importorskip("timm") +pytest.importorskip("sklearn") + + +@pytest.fixture(autouse=True) +def mock_timm_model(monkeypatch): + class DummyModel: + def eval(self): + return self + + def __call__(self, input_tensor): + batch_size = input_tensor.shape[0] + return torch.arange(1, 5, dtype=torch.float32).repeat(batch_size, 1) + + def create_model(*_args, **_kwargs): + return DummyModel() + + def create_transform(**_config): + def transform(_image): + return torch.ones((3, 224, 224), dtype=torch.float32) + + return transform + + monkeypatch.setattr(image_utils.timm, "create_model", create_model) + monkeypatch.setattr( + image_utils, "resolve_data_config", lambda *_args, **_kwargs: {} + ) + monkeypatch.setattr(image_utils, "create_transform", create_transform) + class TestImageFeatureExtractor: """Test ImageFeatureExtractor functionality.""" From ebbc86b011dcdb395298befeb28884b3f50bd1e0 Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sun, 7 Jun 2026 16:19:10 +0700 Subject: [PATCH 10/15] test: Use localhost for remote offline clients Signed-off-by: Le Xuan An --- .../universal/data_sources/file.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/sdk/python/tests/universal/feature_repos/universal/data_sources/file.py b/sdk/python/tests/universal/feature_repos/universal/data_sources/file.py index 1084685e361..e05052b120a 100644 --- a/sdk/python/tests/universal/feature_repos/universal/data_sources/file.py +++ b/sdk/python/tests/universal/feature_repos/universal/data_sources/file.py @@ -372,6 +372,7 @@ def setup(self, registry: RegistryConfig): self.server_port = free_port() host = "0.0.0.0" + client_host = "localhost" cmd = [ "feast", "-c" + str(repo_path), @@ -388,11 +389,11 @@ def setup(self, registry: RegistryConfig): _time_out_sec: int = 60 # Wait for server to start wait_retry_backoff( - lambda: (None, check_port_open(host, self.server_port)), + lambda: (None, check_port_open(client_host, self.server_port)), timeout_secs=_time_out_sec, timeout_msg=f"Unable to start the feast remote offline server in {_time_out_sec} seconds at port={self.server_port}", ) - return "grpc+tcp://{}:{}".format(host, self.server_port) + return "grpc+tcp://{}:{}".format(client_host, self.server_port) def teardown(self): super().teardown() @@ -441,6 +442,7 @@ def setup(self, registry: RegistryConfig): self.server_port = free_port() host = "0.0.0.0" + client_host = "localhost" cmd = [ "feast", "-c" + str(repo_path), @@ -461,16 +463,16 @@ def setup(self, registry: RegistryConfig): _time_out_sec: int = 60 # Wait for server to start wait_retry_backoff( - lambda: (None, check_port_open(host, self.server_port)), + lambda: (None, check_port_open(client_host, self.server_port)), timeout_secs=_time_out_sec, timeout_msg=f"Unable to start the feast remote offline server in {_time_out_sec} seconds at port={self.server_port}", ) - return "grpc+tls://{}:{}".format(host, self.server_port) + return "grpc+tls://{}:{}".format(client_host, self.server_port) def create_offline_store_config(self) -> FeastConfigBaseModel: remote_offline_store_config = RemoteOfflineStoreConfig( type="remote", - host="0.0.0.0", + host="localhost", port=self.server_port, scheme="https", cert=self.tls_cert_path, @@ -541,6 +543,7 @@ def setup(self, registry: RegistryConfig): self.server_port = free_port() host = "0.0.0.0" + client_host = "localhost" cmd = [ "feast", "-c" + repo_path, @@ -557,15 +560,15 @@ def setup(self, registry: RegistryConfig): _time_out_sec: int = 60 # Wait for server to start wait_retry_backoff( - lambda: (None, check_port_open(host, self.server_port)), + lambda: (None, check_port_open(client_host, self.server_port)), timeout_secs=_time_out_sec, timeout_msg=f"Unable to start the feast remote offline server in {_time_out_sec} seconds at port={self.server_port}", ) - return "grpc+tcp://{}:{}".format(host, self.server_port) + return "grpc+tcp://{}:{}".format(client_host, self.server_port) def create_offline_store_config(self) -> FeastConfigBaseModel: remote_offline_store_config = RemoteOfflineStoreConfig( - type="remote", host="0.0.0.0", port=self.server_port + type="remote", host="localhost", port=self.server_port ) return remote_offline_store_config From 1a51179e4f0cce85f10b886d18933d23564166bd Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sun, 7 Jun 2026 17:20:26 +0700 Subject: [PATCH 11/15] test: Stabilize local Milvus integration index Signed-off-by: Le Xuan An --- .../feast/infra/online_stores/milvus_online_store/milvus.py | 6 ++++-- .../feature_repos/universal/online_store/milvus.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py index 941de3b64cd..0685103aef3 100644 --- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -125,8 +125,10 @@ class MilvusOnlineStore(OnlineStore): _collections: Dictionary to cache Milvus collections. """ - client: Optional[MilvusClient] = None - _collections: Dict[str, Any] = {} + def __init__(self): + super().__init__() + self.client: Optional[MilvusClient] = None + self._collections: Dict[str, Any] = {} def _get_db_path(self, config: RepoConfig) -> str: assert ( diff --git a/sdk/python/tests/universal/feature_repos/universal/online_store/milvus.py b/sdk/python/tests/universal/feature_repos/universal/online_store/milvus.py index cfe6aec3677..b4d4a0896bd 100644 --- a/sdk/python/tests/universal/feature_repos/universal/online_store/milvus.py +++ b/sdk/python/tests/universal/feature_repos/universal/online_store/milvus.py @@ -14,7 +14,7 @@ def create_online_store(self) -> dict[str, Any]: return { "type": "milvus", "path": self.db_path, - "index_type": "IVF_FLAT", + "index_type": "FLAT", "metric_type": "L2", "embedding_dim": 2, "vector_enabled": True, From 71f5728fdaefdad92be0c35251aae5d4786d3b09 Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sun, 7 Jun 2026 18:17:00 +0700 Subject: [PATCH 12/15] test: Stabilize feature server benchmark timing Signed-off-by: Le Xuan An --- .../tests/unit/test_feature_server_utils.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/sdk/python/tests/unit/test_feature_server_utils.py b/sdk/python/tests/unit/test_feature_server_utils.py index 9a53d69ff54..1642a19f488 100644 --- a/sdk/python/tests/unit/test_feature_server_utils.py +++ b/sdk/python/tests/unit/test_feature_server_utils.py @@ -23,7 +23,6 @@ import base64 import json -import platform import time import pytest @@ -642,26 +641,21 @@ def test_faster_than_message_to_dict(self): convert_response_to_dict(response) MessageToDict(response, preserving_proto_field_name=True) - start = time.perf_counter() + start = time.process_time() for _ in range(iterations): convert_response_to_dict(response) - fast_time = time.perf_counter() - start + fast_time = time.process_time() - start - start = time.perf_counter() + start = time.process_time() for _ in range(iterations): MessageToDict(response, preserving_proto_field_name=True) - standard_time = time.perf_counter() - start + standard_time = time.process_time() - start speedup = standard_time / fast_time print(f"\nPerformance: fast={fast_time:.3f}s, standard={standard_time:.3f}s") print(f"Speedup: {speedup:.2f}x") - # GitHub-hosted macOS runners show more timing variance on this - # microbenchmark while still validating a meaningful speedup. - min_speedup = 1.3 if platform.system() == "Darwin" else 1.5 - assert speedup >= min_speedup, ( - f"Expected at least {min_speedup:.1f}x speedup, got {speedup:.2f}x" - ) + assert speedup >= 1.5, f"Expected at least 1.5x speedup, got {speedup:.2f}x" class TestStatusNames: From 77463caf439a47cf1b9dd424b40beafaaf326f25 Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sun, 7 Jun 2026 21:01:38 +0700 Subject: [PATCH 13/15] test: Clean up remote offline server processes Signed-off-by: Le Xuan An --- .secrets.baseline | 6 +- .../universal/data_sources/file.py | 92 +++++++++++-------- 2 files changed, 55 insertions(+), 43 deletions(-) diff --git a/.secrets.baseline b/.secrets.baseline index e9e8acd786b..3ba2b824a13 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -1476,14 +1476,14 @@ "filename": "sdk/python/tests/universal/feature_repos/universal/data_sources/file.py", "hashed_secret": "d70eab08607a4d05faa2d0d6647206599e9abc65", "is_verified": false, - "line_number": 257 + "line_number": 302 }, { "type": "Secret Keyword", "filename": "sdk/python/tests/universal/feature_repos/universal/data_sources/file.py", "hashed_secret": "d70eab08607a4d05faa2d0d6647206599e9abc65", "is_verified": false, - "line_number": 257 + "line_number": 302 } ], "sdk/python/tests/universal/feature_repos/universal/online_store/couchbase.py": [ @@ -1539,5 +1539,5 @@ } ] }, - "generated_at": "2026-05-22T11:36:48Z" + "generated_at": "2026-06-07T14:01:04Z" } diff --git a/sdk/python/tests/universal/feature_repos/universal/data_sources/file.py b/sdk/python/tests/universal/feature_repos/universal/data_sources/file.py index e05052b120a..c43309d61c8 100644 --- a/sdk/python/tests/universal/feature_repos/universal/data_sources/file.py +++ b/sdk/python/tests/universal/feature_repos/universal/data_sources/file.py @@ -1,6 +1,7 @@ import logging import os.path import shutil +import signal import subprocess import tempfile import uuid @@ -41,6 +42,50 @@ logger = logging.getLogger(__name__) +def _start_offline_server_process(cmd: list[str]) -> Popen[bytes]: + kwargs: dict[str, Any] = { + "stdout": subprocess.DEVNULL, + "stderr": subprocess.DEVNULL, + } + if os.name == "posix": + kwargs["start_new_session"] = True + return subprocess.Popen(cmd, **kwargs) + + +def _stop_offline_server_process(proc: Popen[bytes], port: int) -> None: + _signal_offline_server_process(proc, signal.SIGTERM) + try: + proc.wait(timeout=10) + except subprocess.TimeoutExpired: + _signal_offline_server_process(proc, signal.SIGKILL) + proc.wait(timeout=10) + + wait_retry_backoff( + lambda: ( + None, + not check_port_open("localhost", port), + ), + timeout_secs=30, + timeout_msg=f"Timed out waiting for remote offline server port {port} to close.", + ) + + +def _signal_offline_server_process(proc: Popen[bytes], sig: signal.Signals) -> None: + if os.name == "posix": + try: + os.killpg(proc.pid, sig) + return + except ProcessLookupError: + return + + if proc.poll() is not None: + return + if sig == signal.SIGTERM: + proc.terminate() + else: + proc.kill() + + class FileDataSourceCreator(DataSourceCreator): files: List[Any] dirs: List[Any] @@ -382,9 +427,7 @@ def setup(self, registry: RegistryConfig): "--port", str(self.server_port), ] - self.proc = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL - ) + self.proc = _start_offline_server_process(cmd) _time_out_sec: int = 60 # Wait for server to start @@ -398,16 +441,7 @@ def setup(self, registry: RegistryConfig): def teardown(self): super().teardown() if self.proc is not None: - self.proc.kill() - - # wait server to free the port - wait_retry_backoff( - lambda: ( - None, - not check_port_open("localhost", self.server_port), - ), - timeout_secs=30, - ) + _stop_offline_server_process(self.proc, self.server_port) class RemoteOfflineTlsStoreDataSourceCreator(FileDataSourceCreator): @@ -456,9 +490,7 @@ def setup(self, registry: RegistryConfig): "--cert", str(self.tls_cert_path), ] - self.proc = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL - ) + self.proc = _start_offline_server_process(cmd) _time_out_sec: int = 60 # Wait for server to start @@ -482,16 +514,7 @@ def create_offline_store_config(self) -> FeastConfigBaseModel: def teardown(self): super().teardown() if self.proc is not None: - self.proc.kill() - - # wait server to free the port - wait_retry_backoff( - lambda: ( - None, - not check_port_open("localhost", self.server_port), - ), - timeout_secs=30, - ) + _stop_offline_server_process(self.proc, self.server_port) class RemoteOfflineOidcAuthStoreDataSourceCreator(FileDataSourceCreator): @@ -512,7 +535,7 @@ def __init__(self, project_name: str, *args, **kwargs): """ self.auth_config = auth_config_template.format(keycloak_url=self.keycloak_url) self.server_port: int = 0 - self.proc = None + self.proc: Optional[Popen[bytes]] = None @staticmethod def xdist_groups() -> list[str]: @@ -553,9 +576,7 @@ def setup(self, registry: RegistryConfig): "--port", str(self.server_port), ] - self.proc = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL - ) # type: ignore + self.proc = _start_offline_server_process(cmd) _time_out_sec: int = 60 # Wait for server to start @@ -578,13 +599,4 @@ def get_keycloak_url(self): def teardown(self): super().teardown() if self.proc is not None: - self.proc.kill() - - # wait server to free the port - wait_retry_backoff( - lambda: ( - None, - not check_port_open("localhost", self.server_port), - ), - timeout_secs=30, - ) + _stop_offline_server_process(self.proc, self.server_port) From 8ef7bbb43a5026ae882b19ee7925ea724c52fdae Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Sun, 7 Jun 2026 21:44:17 +0700 Subject: [PATCH 14/15] test: Revert unrelated universal test changes Signed-off-by: Le Xuan An --- .secrets.baseline | 6 +- .../universal/data_sources/file.py | 111 ++++++++---------- .../universal/online_store/milvus.py | 2 +- 3 files changed, 52 insertions(+), 67 deletions(-) diff --git a/.secrets.baseline b/.secrets.baseline index 3ba2b824a13..9f2c42a14fd 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -1476,14 +1476,14 @@ "filename": "sdk/python/tests/universal/feature_repos/universal/data_sources/file.py", "hashed_secret": "d70eab08607a4d05faa2d0d6647206599e9abc65", "is_verified": false, - "line_number": 302 + "line_number": 257 }, { "type": "Secret Keyword", "filename": "sdk/python/tests/universal/feature_repos/universal/data_sources/file.py", "hashed_secret": "d70eab08607a4d05faa2d0d6647206599e9abc65", "is_verified": false, - "line_number": 302 + "line_number": 257 } ], "sdk/python/tests/universal/feature_repos/universal/online_store/couchbase.py": [ @@ -1539,5 +1539,5 @@ } ] }, - "generated_at": "2026-06-07T14:01:04Z" + "generated_at": "2026-06-07T14:43:54Z" } diff --git a/sdk/python/tests/universal/feature_repos/universal/data_sources/file.py b/sdk/python/tests/universal/feature_repos/universal/data_sources/file.py index c43309d61c8..1084685e361 100644 --- a/sdk/python/tests/universal/feature_repos/universal/data_sources/file.py +++ b/sdk/python/tests/universal/feature_repos/universal/data_sources/file.py @@ -1,7 +1,6 @@ import logging import os.path import shutil -import signal import subprocess import tempfile import uuid @@ -42,50 +41,6 @@ logger = logging.getLogger(__name__) -def _start_offline_server_process(cmd: list[str]) -> Popen[bytes]: - kwargs: dict[str, Any] = { - "stdout": subprocess.DEVNULL, - "stderr": subprocess.DEVNULL, - } - if os.name == "posix": - kwargs["start_new_session"] = True - return subprocess.Popen(cmd, **kwargs) - - -def _stop_offline_server_process(proc: Popen[bytes], port: int) -> None: - _signal_offline_server_process(proc, signal.SIGTERM) - try: - proc.wait(timeout=10) - except subprocess.TimeoutExpired: - _signal_offline_server_process(proc, signal.SIGKILL) - proc.wait(timeout=10) - - wait_retry_backoff( - lambda: ( - None, - not check_port_open("localhost", port), - ), - timeout_secs=30, - timeout_msg=f"Timed out waiting for remote offline server port {port} to close.", - ) - - -def _signal_offline_server_process(proc: Popen[bytes], sig: signal.Signals) -> None: - if os.name == "posix": - try: - os.killpg(proc.pid, sig) - return - except ProcessLookupError: - return - - if proc.poll() is not None: - return - if sig == signal.SIGTERM: - proc.terminate() - else: - proc.kill() - - class FileDataSourceCreator(DataSourceCreator): files: List[Any] dirs: List[Any] @@ -417,7 +372,6 @@ def setup(self, registry: RegistryConfig): self.server_port = free_port() host = "0.0.0.0" - client_host = "localhost" cmd = [ "feast", "-c" + str(repo_path), @@ -427,21 +381,32 @@ def setup(self, registry: RegistryConfig): "--port", str(self.server_port), ] - self.proc = _start_offline_server_process(cmd) + self.proc = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL + ) _time_out_sec: int = 60 # Wait for server to start wait_retry_backoff( - lambda: (None, check_port_open(client_host, self.server_port)), + lambda: (None, check_port_open(host, self.server_port)), timeout_secs=_time_out_sec, timeout_msg=f"Unable to start the feast remote offline server in {_time_out_sec} seconds at port={self.server_port}", ) - return "grpc+tcp://{}:{}".format(client_host, self.server_port) + return "grpc+tcp://{}:{}".format(host, self.server_port) def teardown(self): super().teardown() if self.proc is not None: - _stop_offline_server_process(self.proc, self.server_port) + self.proc.kill() + + # wait server to free the port + wait_retry_backoff( + lambda: ( + None, + not check_port_open("localhost", self.server_port), + ), + timeout_secs=30, + ) class RemoteOfflineTlsStoreDataSourceCreator(FileDataSourceCreator): @@ -476,7 +441,6 @@ def setup(self, registry: RegistryConfig): self.server_port = free_port() host = "0.0.0.0" - client_host = "localhost" cmd = [ "feast", "-c" + str(repo_path), @@ -490,21 +454,23 @@ def setup(self, registry: RegistryConfig): "--cert", str(self.tls_cert_path), ] - self.proc = _start_offline_server_process(cmd) + self.proc = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL + ) _time_out_sec: int = 60 # Wait for server to start wait_retry_backoff( - lambda: (None, check_port_open(client_host, self.server_port)), + lambda: (None, check_port_open(host, self.server_port)), timeout_secs=_time_out_sec, timeout_msg=f"Unable to start the feast remote offline server in {_time_out_sec} seconds at port={self.server_port}", ) - return "grpc+tls://{}:{}".format(client_host, self.server_port) + return "grpc+tls://{}:{}".format(host, self.server_port) def create_offline_store_config(self) -> FeastConfigBaseModel: remote_offline_store_config = RemoteOfflineStoreConfig( type="remote", - host="localhost", + host="0.0.0.0", port=self.server_port, scheme="https", cert=self.tls_cert_path, @@ -514,7 +480,16 @@ def create_offline_store_config(self) -> FeastConfigBaseModel: def teardown(self): super().teardown() if self.proc is not None: - _stop_offline_server_process(self.proc, self.server_port) + self.proc.kill() + + # wait server to free the port + wait_retry_backoff( + lambda: ( + None, + not check_port_open("localhost", self.server_port), + ), + timeout_secs=30, + ) class RemoteOfflineOidcAuthStoreDataSourceCreator(FileDataSourceCreator): @@ -535,7 +510,7 @@ def __init__(self, project_name: str, *args, **kwargs): """ self.auth_config = auth_config_template.format(keycloak_url=self.keycloak_url) self.server_port: int = 0 - self.proc: Optional[Popen[bytes]] = None + self.proc = None @staticmethod def xdist_groups() -> list[str]: @@ -566,7 +541,6 @@ def setup(self, registry: RegistryConfig): self.server_port = free_port() host = "0.0.0.0" - client_host = "localhost" cmd = [ "feast", "-c" + repo_path, @@ -576,20 +550,22 @@ def setup(self, registry: RegistryConfig): "--port", str(self.server_port), ] - self.proc = _start_offline_server_process(cmd) + self.proc = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL + ) # type: ignore _time_out_sec: int = 60 # Wait for server to start wait_retry_backoff( - lambda: (None, check_port_open(client_host, self.server_port)), + lambda: (None, check_port_open(host, self.server_port)), timeout_secs=_time_out_sec, timeout_msg=f"Unable to start the feast remote offline server in {_time_out_sec} seconds at port={self.server_port}", ) - return "grpc+tcp://{}:{}".format(client_host, self.server_port) + return "grpc+tcp://{}:{}".format(host, self.server_port) def create_offline_store_config(self) -> FeastConfigBaseModel: remote_offline_store_config = RemoteOfflineStoreConfig( - type="remote", host="localhost", port=self.server_port + type="remote", host="0.0.0.0", port=self.server_port ) return remote_offline_store_config @@ -599,4 +575,13 @@ def get_keycloak_url(self): def teardown(self): super().teardown() if self.proc is not None: - _stop_offline_server_process(self.proc, self.server_port) + self.proc.kill() + + # wait server to free the port + wait_retry_backoff( + lambda: ( + None, + not check_port_open("localhost", self.server_port), + ), + timeout_secs=30, + ) diff --git a/sdk/python/tests/universal/feature_repos/universal/online_store/milvus.py b/sdk/python/tests/universal/feature_repos/universal/online_store/milvus.py index b4d4a0896bd..cfe6aec3677 100644 --- a/sdk/python/tests/universal/feature_repos/universal/online_store/milvus.py +++ b/sdk/python/tests/universal/feature_repos/universal/online_store/milvus.py @@ -14,7 +14,7 @@ def create_online_store(self) -> dict[str, Any]: return { "type": "milvus", "path": self.db_path, - "index_type": "FLAT", + "index_type": "IVF_FLAT", "metric_type": "L2", "embedding_dim": 2, "vector_enabled": True, From 3af9cd7b67d911f48574bd8ca8a9eb8117eb9711 Mon Sep 17 00:00:00 2001 From: Le Xuan An Date: Thu, 11 Jun 2026 14:50:41 +0700 Subject: [PATCH 15/15] fix: Address Flink review blockers Signed-off-by: Le Xuan An --- .../compute_engines/flink/feature_builder.py | 5 +- .../infra/compute_engines/flink/nodes.py | 68 ++++++++++++++----- .../flink/test_flink_compute_engine.py | 49 +++++++++++-- 3 files changed, 97 insertions(+), 25 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/flink/feature_builder.py b/sdk/python/feast/infra/compute_engines/flink/feature_builder.py index 1b8aa76304c..3fd276c92fe 100644 --- a/sdk/python/feast/infra/compute_engines/flink/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/flink/feature_builder.py @@ -39,10 +39,7 @@ def __init__( def _should_join_entity_df(self) -> bool: return isinstance(self.task, HistoricalRetrievalTask) and ( - ( - isinstance(self.task.entity_df, pd.DataFrame) - and not self.task.entity_df.empty - ) + isinstance(self.task.entity_df, pd.DataFrame) or ( isinstance(self.task.entity_df, str) and bool(self.task.entity_df.strip()) diff --git a/sdk/python/feast/infra/compute_engines/flink/nodes.py b/sdk/python/feast/infra/compute_engines/flink/nodes.py index 33c567f03ce..2bdb29cd8c4 100644 --- a/sdk/python/feast/infra/compute_engines/flink/nodes.py +++ b/sdk/python/feast/infra/compute_engines/flink/nodes.py @@ -147,11 +147,22 @@ def _entity_value_from_dataframe( table_env: Any, entity_df: pd.DataFrame, split_num: int, + join_keys: List[str], ) -> tuple[Any, List[str], str]: entity_df = entity_df.copy() + if entity_df.empty: + for join_key in join_keys: + if join_key not in entity_df.columns: + entity_df[join_key] = pd.Series(dtype="object") + entity_ts_col = find_entity_timestamp_column(list(entity_df.columns)) + if entity_ts_col is None: + entity_ts_col = ENTITY_TS_ALIAS + entity_df[entity_ts_col] = pd.Series(dtype="datetime64[ns]") + else: + entity_schema = dict(zip(entity_df.columns, entity_df.dtypes)) + entity_ts_col = infer_entity_timestamp_column(entity_schema) + entity_df[ENTITY_ROW_ID] = range(len(entity_df)) - entity_schema = dict(zip(entity_df.columns, entity_df.dtypes)) - entity_ts_col = infer_entity_timestamp_column(entity_schema) if entity_ts_col != ENTITY_TS_ALIAS: entity_df = entity_df.rename(columns={entity_ts_col: ENTITY_TS_ALIAS}) return ( @@ -219,7 +230,9 @@ def _entity_value_from_context( join_keys: List[str], ) -> tuple[Any, List[str], str]: if isinstance(context.entity_df, pd.DataFrame): - return _entity_value_from_dataframe(table_env, context.entity_df, split_num) + return _entity_value_from_dataframe( + table_env, context.entity_df, split_num, join_keys + ) if isinstance(context.entity_df, str): return _entity_value_from_sql(table_env, context.entity_df, join_keys) raise TypeError( @@ -227,6 +240,38 @@ def _entity_value_from_context( ) +def _retrieval_job_to_flink_table( + retrieval_job: Any, + table_env: Any, + split_num: int, +) -> tuple[Any, List[str]]: + to_flink_table = getattr(retrieval_job, "to_flink_table", None) + if callable(to_flink_table): + flink_table = to_flink_table(table_env) + columns = _get_columns_from_schema(flink_table) + if columns is None: + raise ValueError( + "Could not infer columns for source Flink table returned by " + "RetrievalJob.to_flink_table(table_env)." + ) + return flink_table, columns + + if not hasattr(retrieval_job, "to_arrow"): + raise TypeError( + "FlinkComputeEngine source reads require a RetrievalJob with either " + "to_flink_table(table_env) or to_arrow()." + ) + + arrow_table = retrieval_job.to_arrow() + if not isinstance(arrow_table, pa.Table): + raise TypeError( + "RetrievalJob.to_arrow() must return a pyarrow.Table for " + "FlinkComputeEngine source reads." + ) + columns = list(arrow_table.column_names) + return pandas_to_flink_table(table_env, arrow_table.to_pandas(), split_num), columns + + class FlinkSourceReadNode(DAGNode): def __init__( self, @@ -254,20 +299,9 @@ def execute(self, context: ExecutionContext) -> DAGValue: start_time=self.start_time, end_time=self.end_time, ) - if not hasattr(retrieval_job, "to_flink_table"): - raise TypeError( - "FlinkComputeEngine source reads require RetrievalJob.to_flink_table(" - "table_env). Configure an offline store retrieval job that returns " - "native PyFlink tables instead of Arrow/pandas results." - ) - - flink_table = retrieval_job.to_flink_table(self.table_env) - columns = _get_columns_from_schema(flink_table) - if columns is None: - raise ValueError( - "Could not infer columns for source Flink table returned by " - "RetrievalJob.to_flink_table(table_env)." - ) + flink_table, columns = _retrieval_job_to_flink_table( + retrieval_job, self.table_env, self.split_num + ) if self.column_info.field_mapping: view_name = _register_table(self.table_env, flink_table, "source_read") diff --git a/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py b/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py index 13ee329616b..7941c8ae405 100644 --- a/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py +++ b/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py @@ -549,7 +549,9 @@ def test_repo_config_loads_flink_batch_engine_config(tmp_path: Path) -> None: assert config.batch_engine.pandas_split_num == 2 -def test_flink_source_read_node_rejects_arrow_retrieval_jobs(tmp_path: Path) -> None: +def test_flink_source_read_node_converts_arrow_retrieval_jobs( + tmp_path: Path, +) -> None: offline_store = MagicMock() offline_store.pull_all_from_table_or_query.return_value = FakeRetrievalJob( pa.Table.from_pandas(_feature_data()) @@ -564,8 +566,11 @@ def test_flink_source_read_node_rejects_arrow_retrieval_jobs(tmp_path: Path) -> split_num=1, ) - with pytest.raises(TypeError, match="to_flink_table"): - node.execute(context) + result = node.execute(context) + + assert result.format == DAGFormat.FLINK + assert result.metadata["columns"] == list(_feature_data().columns) + assert result.data.to_pandas().equals(_feature_data()) def test_flink_historical_retrieval_executes_dag_with_transformation( @@ -600,7 +605,7 @@ def double_conv_rate(table: FakeFlinkTable) -> FakeFlinkTable: ) task = HistoricalRetrievalTask( project=config.project, - entity_df=pd.DataFrame(), + entity_df=None, feature_view=feature_view, full_feature_name=False, registry=_registry(entity), @@ -616,6 +621,42 @@ def double_conv_rate(table: FakeFlinkTable) -> FakeFlinkTable: assert table_env.views == {} +def test_flink_historical_retrieval_with_empty_entity_df_returns_empty_result( + tmp_path: Path, +) -> None: + entity = _driver() + source = _source() + feature_view = _feature_view(source, online=False, offline=False) + config = _repo_config(tmp_path, {"type": "flink.engine", "pandas_split_num": 4}) + table_env = FakeTableEnvironment() + engine = FlinkComputeEngine( + repo_config=config, + offline_store=_offline_store(_feature_data()), + online_store=MagicMock(), + table_environment=table_env, + ) + task = HistoricalRetrievalTask( + project=config.project, + entity_df=pd.DataFrame( + { + "driver_id": pd.Series(dtype="int64"), + "event_timestamp": pd.Series(dtype="datetime64[ns]"), + } + ), + feature_view=feature_view, + full_feature_name=False, + registry=_registry(entity), + ) + + job = engine.get_historical_features(_registry(entity), task) + result = job.to_df() + + assert job.error() is None + assert result.empty + assert "conv_rate" in result.columns + assert table_env.created_tables[-1].empty + + def test_flink_historical_retrieval_is_read_only_and_dedupes_per_entity_row( tmp_path: Path, ) -> None: