diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py index 33190bd4635..0f77d6e18fc 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py +++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py @@ -1,6 +1,15 @@ +import logging import uuid from datetime import date, datetime -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Tuple, + Union, +) import numpy as np import pandas as pd @@ -37,6 +46,8 @@ from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage +logger = logging.getLogger(__name__) + class BasicAuthModel(FeastConfigBaseModel): username: StrictStr @@ -183,6 +194,7 @@ def __init__( full_feature_names: bool, on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, metadata: Optional[RetrievalMetadata] = None, + temp_table: Optional[str] = None, ): self._query = query self._client = client @@ -190,6 +202,8 @@ def __init__( self._full_feature_names = full_feature_names self._on_demand_feature_views = on_demand_feature_views or [] self._metadata = metadata + self._temp_table = temp_table + self._cleaned_up = False @property def full_feature_names(self) -> bool: @@ -199,11 +213,29 @@ def full_feature_names(self) -> bool: def on_demand_feature_views(self) -> List[OnDemandFeatureView]: return self._on_demand_feature_views + def _drop_temp_table(self) -> None: + if self._cleaned_up or not self._temp_table: + return + self._cleaned_up = True + try: + self._client.execute_query(f"DROP TABLE IF EXISTS {self._temp_table}") + except Exception: + logger.exception( + "Failed to drop temporary entity table %s", + self._temp_table, + ) + + def __del__(self) -> None: + self._drop_temp_table() + def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: """Return dataset as Pandas DataFrame synchronously including on demand transforms""" - results = self._client.execute_query(query_text=self._query) - self.pyarrow_schema = results.pyarrow_schema - return results.to_dataframe() + try: + results = self._client.execute_query(query_text=self._query) + self.pyarrow_schema = results.pyarrow_schema + return results.to_dataframe() + finally: + self._drop_temp_table() def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table: """Return payrrow dataset as synchronously including on demand transforms""" @@ -234,8 +266,11 @@ def to_trino( destination_table = f"{self._client.catalog}.{self._config.offline_store.dataset}.historical_{today}_{rand_id}" # TODO: Implement the timeout logic - query = f"CREATE TABLE {destination_table} AS ({self._query})" - self._client.execute_query(query_text=query) + try: + create_query = f"CREATE TABLE {destination_table} AS ({self._query})" + self._client.execute_query(query_text=create_query) + finally: + self._drop_temp_table() return destination_table def persist( @@ -372,11 +407,12 @@ def get_historical_features( ) # Generate the Trino SQL query from the query context + entity_table_ref = table_reference if type(entity_df) is str: - table_reference = f"({entity_df})" + entity_table_ref = f"({entity_df})" query = offline_utils.build_point_in_time_query( query_context, - left_table_query_string=table_reference, + left_table_query_string=entity_table_ref, entity_df_event_timestamp_col=entity_df_event_timestamp_col, entity_df_columns=entity_schema.keys(), query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN, @@ -385,6 +421,7 @@ def get_historical_features( return TrinoRetrievalJob( query=query, + temp_table=table_reference if isinstance(entity_df, pd.DataFrame) else None, client=client, config=config, full_feature_names=full_feature_names, @@ -483,8 +520,6 @@ def _upload_entity_df_and_get_entity_schema( else: raise InvalidEntityType(type(entity_df)) - # TODO: Ensure that the table expires after some time - def _get_trino_client(config: RepoConfig) -> Trino: auth = None