Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import logging
import uuid
from datetime import date, datetime
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
)

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -37,6 +46,8 @@
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage

logger = logging.getLogger(__name__)


class BasicAuthModel(FeastConfigBaseModel):
username: StrictStr
Expand Down Expand Up @@ -183,13 +194,16 @@ def __init__(
full_feature_names: bool,
on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
metadata: Optional[RetrievalMetadata] = None,
temp_table: Optional[str] = None,
):
self._query = query
self._client = client
self._config = config
self._full_feature_names = full_feature_names
self._on_demand_feature_views = on_demand_feature_views or []
self._metadata = metadata
self._temp_table = temp_table
self._cleaned_up = False

@property
def full_feature_names(self) -> bool:
Expand All @@ -199,11 +213,29 @@ def full_feature_names(self) -> bool:
def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
return self._on_demand_feature_views

def _drop_temp_table(self) -> None:
if self._cleaned_up or not self._temp_table:
return
self._cleaned_up = True
try:
self._client.execute_query(f"DROP TABLE IF EXISTS {self._temp_table}")
except Exception:
logger.exception(
"Failed to drop temporary entity table %s",
self._temp_table,
)

def __del__(self) -> None:
self._drop_temp_table()

def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
"""Return dataset as Pandas DataFrame synchronously including on demand transforms"""
results = self._client.execute_query(query_text=self._query)
self.pyarrow_schema = results.pyarrow_schema
return results.to_dataframe()
try:
results = self._client.execute_query(query_text=self._query)
self.pyarrow_schema = results.pyarrow_schema
return results.to_dataframe()
finally:
self._drop_temp_table()

def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
"""Return payrrow dataset as synchronously including on demand transforms"""
Expand Down Expand Up @@ -234,8 +266,11 @@ def to_trino(
destination_table = f"{self._client.catalog}.{self._config.offline_store.dataset}.historical_{today}_{rand_id}"

# TODO: Implement the timeout logic
query = f"CREATE TABLE {destination_table} AS ({self._query})"
self._client.execute_query(query_text=query)
try:
create_query = f"CREATE TABLE {destination_table} AS ({self._query})"
self._client.execute_query(query_text=create_query)
finally:
self._drop_temp_table()
return destination_table

def persist(
Expand Down Expand Up @@ -372,11 +407,12 @@ def get_historical_features(
)

# Generate the Trino SQL query from the query context
entity_table_ref = table_reference
if type(entity_df) is str:
table_reference = f"({entity_df})"
entity_table_ref = f"({entity_df})"
query = offline_utils.build_point_in_time_query(
query_context,
left_table_query_string=table_reference,
left_table_query_string=entity_table_ref,
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
entity_df_columns=entity_schema.keys(),
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
Expand All @@ -385,6 +421,7 @@ def get_historical_features(

return TrinoRetrievalJob(
query=query,
temp_table=table_reference if isinstance(entity_df, pd.DataFrame) else None,
client=client,
config=config,
full_feature_names=full_feature_names,
Expand Down Expand Up @@ -483,8 +520,6 @@ def _upload_entity_df_and_get_entity_schema(
else:
raise InvalidEntityType(type(entity_df))

# TODO: Ensure that the table expires after some time


def _get_trino_client(config: RepoConfig) -> Trino:
auth = None
Expand Down
Loading