From 44526f5030f1b5b2a3f5742f43dc9943a7251e9b Mon Sep 17 00:00:00 2001 From: Miles Adkins Date: Wed, 5 Apr 2023 09:35:48 -0500 Subject: [PATCH] fix: Batch Snowflake materialization queries to obey Snowpark 100 limit Signed-off-by: Miles Adkins --- .../infra/materialization/snowflake_engine.py | 96 ++++++++++++------- 1 file changed, 61 insertions(+), 35 deletions(-) diff --git a/sdk/python/feast/infra/materialization/snowflake_engine.py b/sdk/python/feast/infra/materialization/snowflake_engine.py index 8a63e008911..36c42cd390c 100644 --- a/sdk/python/feast/infra/materialization/snowflake_engine.py +++ b/sdk/python/feast/infra/materialization/snowflake_engine.py @@ -276,32 +276,65 @@ def _materialize_one( fv_latest_values_sql = offline_job.to_sql() + if feature_view.entity_columns: + join_keys = [entity.name for entity in feature_view.entity_columns] + unique_entities = '"' + '", "'.join(join_keys) + '"' + + query = f""" + SELECT + COUNT(DISTINCT {unique_entities}) + FROM + {feature_view.batch_source.get_table_query_string()} + """ + + with GetSnowflakeConnection(self.repo_config.offline_store) as conn: + entities_to_write = conn.cursor().execute(query).fetchall()[0][0] + else: + entities_to_write = ( + 1 # entityless feature view has a placeholder entity + ) + if feature_view.batch_source.field_mapping is not None: fv_latest_mapped_values_sql = _run_snowflake_field_mapping( fv_latest_values_sql, feature_view.batch_source.field_mapping ) - fv_to_proto_sql = self.generate_snowflake_materialization_query( - self.repo_config, - fv_latest_mapped_values_sql, - feature_view, - project, - ) + features_full_list = feature_view.features + feature_batches = [ + features_full_list[i : i + 100] + for i in range(0, len(features_full_list), 100) + ] if self.repo_config.online_store.type == "snowflake.online": - self.materialize_to_snowflake_online_store( - self.repo_config, - fv_to_proto_sql, - feature_view, - project, - ) + rows_to_write = entities_to_write * len(features_full_list) else: - self.materialize_to_external_online_store( - self.repo_config, - fv_to_proto_sql, - feature_view, - tqdm_builder, - ) + rows_to_write = entities_to_write * len(feature_batches) + + with tqdm_builder(rows_to_write) as pbar: + for i, feature_batch in enumerate(feature_batches): + fv_to_proto_sql = self.generate_snowflake_materialization_query( + self.repo_config, + fv_latest_mapped_values_sql, + feature_view, + feature_batch, + project, + ) + + if self.repo_config.online_store.type == "snowflake.online": + self.materialize_to_snowflake_online_store( + self.repo_config, + fv_to_proto_sql, + feature_view, + project, + ) + pbar.update(entities_to_write * len(feature_batch)) + else: + self.materialize_to_external_online_store( + self.repo_config, + fv_to_proto_sql, + feature_view, + pbar, + ) return SnowflakeMaterializationJob( job_id=job_id, status=MaterializationJobStatus.SUCCEEDED @@ -316,6 +349,7 @@ def generate_snowflake_materialization_query( repo_config: RepoConfig, fv_latest_mapped_values_sql: str, feature_view: Union[BatchFeatureView, FeatureView], + feature_batch: list, project: str, ) -> str: @@ -338,7 +372,7 @@ def generate_snowflake_materialization_query( UDF serialization function. """ feature_sql_list = [] - for feature in feature_view.features: + for feature in feature_batch: feature_value_type_name = feature.dtype.to_value_type().name feature_sql = _convert_value_name_to_snowflake_udf( @@ -434,11 +468,8 @@ def materialize_to_snowflake_online_store( """ with GetSnowflakeConnection(repo_config.batch_engine) as conn: - query_id = execute_snowflake_statement(conn, query).sfqid + execute_snowflake_statement(conn, query).sfqid - click.echo( - f"Snowflake Query ID: {Style.BRIGHT + Fore.GREEN}{query_id}{Style.RESET_ALL}" - ) return None def materialize_to_external_online_store( @@ -446,7 +477,7 @@ def materialize_to_external_online_store( repo_config: RepoConfig, materialization_sql: str, feature_view: Union[StreamFeatureView, FeatureView], - tqdm_builder: Callable[[int], tqdm], + pbar: tqdm, ) -> None: feature_names = [feature.name for feature in feature_view.features] @@ -455,10 +486,6 @@ def materialize_to_external_online_store( query = materialization_sql cursor = execute_snowflake_statement(conn, query) for i, df in enumerate(cursor.fetch_pandas_batches()): - click.echo( - f"Snowflake: Processing Materialization ResultSet Batch #{i+1}" - ) - entity_keys = ( df["entity_key"].apply(EntityKeyProto.FromString).to_numpy() ) @@ -494,11 +521,10 @@ def materialize_to_external_online_store( ) ) - with tqdm_builder(len(rows_to_write)) as pbar: - self.online_store.online_write_batch( - repo_config, - feature_view, - rows_to_write, - lambda x: pbar.update(x), - ) + self.online_store.online_write_batch( + repo_config, + feature_view, + rows_to_write, + lambda x: pbar.update(x), + ) return None