diff --git a/sdk/python/feast/infra/offline_stores/bigquery_source.py b/sdk/python/feast/infra/offline_stores/bigquery_source.py index 69e42e3fd09..7b476afdcda 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery_source.py +++ b/sdk/python/feast/infra/offline_stores/bigquery_source.py @@ -47,15 +47,15 @@ def __init__( case the table must be specified. timestamp_field (optional): Event timestamp field used for point in time joins of feature values. - table (optional): BigQuery table where the features are stored. Exactly one of 'table' - and 'query' must be specified. - table (optional): The BigQuery table where features can be found. + table (optional): BigQuery table where the features are stored. At least one of 'table' + and 'query' must be specified. When both are set, 'query' is used for reads and + 'table' is used as the write destination. created_timestamp_column (optional): Timestamp column when row was created, used for deduplicating rows. field_mapping (optional): A dictionary mapping of column names in this data source to feature names in a feature table or view. Only used for feature columns, not entities or timestamp columns. date_partition_column (optional): Timestamp column used for partitioning. - query (optional): The query to be executed to obtain the features. Exactly one of 'table' - and 'query' must be specified. + query (optional): The query to be executed to obtain the features. When both 'table' + and 'query' are provided, 'query' takes priority for reads. description (optional): A human-readable description. tags (optional): A dictionary of key-value pairs to store arbitrary metadata. owner (optional): The owner of the bigquery source, typically the email of the primary @@ -156,10 +156,10 @@ def validate(self, config: RepoConfig): def get_table_query_string(self) -> str: """Returns a string that can directly be used to reference this table in SQL""" - if self.table: - return f"`{self.table}`" - else: + if self.query: return f"({self.query})" + else: + return f"`{self.table}`" @staticmethod def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: @@ -185,14 +185,14 @@ def get_table_column_names_and_types( location=config.offline_store.location, client_info=http_client_info.ClientInfo(user_agent=get_user_agent()), ) - if self.table: - schema = client.get_table(self.table).schema - if not isinstance(schema[0], bigquery.schema.SchemaField): - raise TypeError("Could not parse BigQuery table schema.") - else: + if self.query: bq_columns_query = f"SELECT * FROM ({self.query}) LIMIT 0" query_res = client.query(bq_columns_query).result() schema = query_res.schema + else: + schema = client.get_table(self.table).schema + if not isinstance(schema[0], bigquery.schema.SchemaField): + raise TypeError("Could not parse BigQuery table schema.") name_type_pairs: List[Tuple[str, str]] = [] for field in schema: diff --git a/sdk/python/tests/unit/infra/offline_stores/test_bigquery.py b/sdk/python/tests/unit/infra/offline_stores/test_bigquery.py index 7dbf06e94a8..969a9679971 100644 --- a/sdk/python/tests/unit/infra/offline_stores/test_bigquery.py +++ b/sdk/python/tests/unit/infra/offline_stores/test_bigquery.py @@ -152,3 +152,51 @@ def test_pull_all_from_table_or_query_partition_pruning(mock_get_bigquery_client ) assert "partition_date >= '2021-01-01'" in actual_query assert "partition_date <= '2021-01-02'" in actual_query + + +class TestBigQuerySourceGetTableQueryString: + def test_table_only(self): + source = BigQuerySource( + name="test", + table="project.dataset.table", + timestamp_field="ts", + ) + assert source.get_table_query_string() == "`project.dataset.table`" + + def test_query_only(self): + source = BigQuerySource( + name="test", + query="SELECT * FROM `project.dataset.table` WHERE active = TRUE", + timestamp_field="ts", + ) + assert ( + source.get_table_query_string() + == "(SELECT * FROM `project.dataset.table` WHERE active = TRUE)" + ) + + def test_both_table_and_query_prefers_query(self): + """When both table and query are set, query takes priority for reads.""" + query = ( + "SELECT * FROM `project.dataset.table`" + " QUALIFY ROW_NUMBER() OVER (PARTITION BY entity_id, event_time) = 1" + ) + source = BigQuerySource( + name="test", + table="project.dataset.table", + query=query, + timestamp_field="ts", + ) + result = source.get_table_query_string() + assert result.startswith("(") + assert "QUALIFY" in result + assert result != "`project.dataset.table`" + + def test_table_property_unaffected_by_query_priority(self): + """The .table property is still accessible for write paths.""" + source = BigQuerySource( + name="test", + table="project.dataset.write_target", + query="SELECT * FROM `project.dataset.write_target` WHERE deduped", + timestamp_field="ts", + ) + assert source.table == "project.dataset.write_target"