From 6d0c527ce4a00102f4b1574ee957fe1ed76de468 Mon Sep 17 00:00:00 2001 From: abhijeet-dhumal Date: Wed, 22 Apr 2026 01:57:52 +0530 Subject: [PATCH] fix(spark): Use SELECT * when feature_name_columns is empty in pull_all_from_table_or_query MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit pull_all_from_table_or_query always builds an explicit SELECT projection from join_key_columns + feature_name_columns + timestamp_fields. When feature_name_columns=[] — the "read all source columns" signal used by FeatureBuilder.get_column_info for BatchFeatureView with TransformationMode.PYTHON, ray, and pandas — the generated SQL becomes: SELECT user_id, event_timestamp FROM source WHERE ... All raw feature columns (rating, text, helpful_vote, …) are silently dropped. The UDF receives a 2-column DataFrame and every aggregation returns null or fails. Fix: guard on feature_name_columns being non-empty before building the explicit projection; fall through to SELECT * when it is empty. Signed-off-by: abhijeet-dhumal --- .../contrib/spark_offline_store/spark.py | 16 +- .../test_spark_offline_store_pull_all.py | 172 ++++++++++++++++++ 2 files changed, 183 insertions(+), 5 deletions(-) create mode 100644 sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_offline_store_pull_all.py diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index c7ed40ccc02..bede2a6f44c 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -387,12 +387,18 @@ def pull_all_from_table_or_query( timestamp_fields = [timestamp_field] if created_timestamp_column: timestamp_fields.append(created_timestamp_column) - (fields_with_aliases, aliases) = _get_fields_with_aliases( - fields=join_key_columns + feature_name_columns + timestamp_fields, - field_mappings=data_source.field_mapping, - ) - fields_with_alias_string = ", ".join(fields_with_aliases) + if feature_name_columns: + (fields_with_aliases, _) = _get_fields_with_aliases( + fields=join_key_columns + feature_name_columns + timestamp_fields, + field_mappings=data_source.field_mapping, + ) + fields_with_alias_string = ", ".join(fields_with_aliases) + else: + # Empty feature_name_columns signals "read all source columns". + # Used by BatchFeatureView with TransformationMode.PYTHON/ray/pandas where + # the UDF computes output features from raw input — don't project upfront. + fields_with_alias_string = "*" from_expression = data_source.get_table_query_string() timestamp_filter = get_timestamp_filter_sql( diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_offline_store_pull_all.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_offline_store_pull_all.py new file mode 100644 index 00000000000..c29b89db001 --- /dev/null +++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_offline_store_pull_all.py @@ -0,0 +1,172 @@ +""" +Unit tests for SparkOfflineStore.pull_all_from_table_or_query SQL generation. + +Covers the bug where feature_name_columns=[] (signalling "read all source +columns" for BatchFeatureView UDF transformations) caused a bare + SELECT user_id, event_timestamp FROM source +instead of SELECT *, silently dropping all columns the UDF needs. +""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( # noqa: E402 + SparkOfflineStore, + SparkOfflineStoreConfig, +) +from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( # noqa: E402 + SparkSource, +) +from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig # noqa: E402 +from feast.repo_config import RepoConfig # noqa: E402 + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + +START = datetime(2023, 1, 1, tzinfo=timezone.utc) +END = datetime(2024, 1, 1, tzinfo=timezone.utc) + +# Fixed table name returned by the mocked get_table_query_string +_TABLE_EXPR = "`raw_reviews`" + + +@pytest.fixture() +def repo_config(): + return RepoConfig( + registry="file:///tmp/registry.db", + project="test", + provider="local", + online_store=SqliteOnlineStoreConfig(type="sqlite"), + offline_store=SparkOfflineStoreConfig(type="spark"), + ) + + +@pytest.fixture() +def spark_source(): + return SparkSource( + name="raw_reviews", + path="s3a://bucket/processed/reviews/", + file_format="parquet", + timestamp_field="event_timestamp", + ) + + +def _run_pull_all(repo_config, spark_source, feature_name_columns): + """ + Call pull_all_from_table_or_query with a mocked SparkSession and mocked + data-source table resolution, then return the SQL query string. + + Two things are patched so no real Spark/S3 access occurs: + 1. get_spark_session_or_start_new_with_repoconfig → MagicMock session + 2. spark_source.get_table_query_string → fixed table expression + (avoids SparkSource.validate / _load_dataframe_from_path hitting S3) + """ + mock_spark = MagicMock() + + with ( + patch( + "feast.infra.offline_stores.contrib.spark_offline_store.spark" + ".get_spark_session_or_start_new_with_repoconfig", + return_value=mock_spark, + ), + patch.object( + spark_source, + "get_table_query_string", + return_value=_TABLE_EXPR, + ), + ): + job = SparkOfflineStore.pull_all_from_table_or_query( + config=repo_config, + data_source=spark_source, + join_key_columns=["user_id"], + feature_name_columns=feature_name_columns, + timestamp_field="event_timestamp", + created_timestamp_column=None, + start_date=START, + end_date=END, + ) + + return job.query.strip() + + +def test_pull_all_with_empty_feature_cols_generates_select_star( + repo_config, spark_source +): + """ + feature_name_columns=[] must produce SELECT * so UDF-based + BatchFeatureViews receive all raw source columns for aggregation. + """ + sql = _run_pull_all(repo_config, spark_source, feature_name_columns=[]) + + assert sql.startswith("SELECT *"), ( + "Expected 'SELECT *' when feature_name_columns=[], " + f"got: {sql[:120]!r}\n\n" + "BatchFeatureView UDFs need all raw source columns to compute " + "aggregations — projecting only join key + timestamp silently " + "drops rating, text, helpful_vote, etc." + ) + assert "user_id" not in sql.split("FROM")[0], ( + "SELECT * must not also explicitly list join key columns" + ) + + +def test_pull_all_with_feature_cols_generates_explicit_projection( + repo_config, spark_source +): + """ + When feature_name_columns is non-empty (normal FeatureView path), + the query must project only the requested columns — not SELECT *. + """ + sql = _run_pull_all( + repo_config, + spark_source, + feature_name_columns=["avg_rating", "review_count"], + ) + + assert "SELECT *" not in sql, ( + "Non-empty feature_name_columns must produce explicit SELECT projection, not SELECT *" + ) + assert "avg_rating" in sql + assert "review_count" in sql + assert "user_id" in sql + assert "event_timestamp" in sql + + +def test_pull_all_empty_feature_cols_upstream_regression(repo_config, spark_source): + """ + Regression guard: the upstream (unfixed) behaviour with feature_name_columns=[] + produced a query that only selected join key + timestamp, dropping all columns + the UDF needs. Verify the fixed code does NOT produce that broken query. + + Broken upstream SQL looked like: + SELECT user_id, event_timestamp FROM ... WHERE ... + """ + sql = _run_pull_all(repo_config, spark_source, feature_name_columns=[]) + + projection = sql.split("FROM")[0] + assert "user_id" not in projection, ( + "Upstream bug: query projected only 'user_id, event_timestamp', " + "silently dropping all columns needed by the BFV UDF. " + "Fixed query should use SELECT *." + ) + + +@pytest.mark.parametrize( + "feature_cols,expect_star", + [ + ([], True), + (["f1"], False), + (["f1", "f2", "f3"], False), + ], +) +def test_pull_all_select_star_only_when_feature_cols_empty( + repo_config, spark_source, feature_cols, expect_star +): + sql = _run_pull_all(repo_config, spark_source, feature_name_columns=feature_cols) + has_star = sql.strip().upper().startswith("SELECT *") + assert has_star == expect_star, ( + f"feature_cols={feature_cols!r}: expected SELECT *={expect_star}, got SQL: {sql[:100]!r}" + )