From db3b7eb8b8a5d351ca0eab7451c69b93eaa45920 Mon Sep 17 00:00:00 2001 From: tokoko Date: Mon, 20 May 2024 23:48:05 +0000 Subject: [PATCH 1/4] refactor it test environment setup Signed-off-by: tokoko --- sdk/python/tests/conftest.py | 7 +- .../feature_repos/repo_configuration.py | 76 +++++++++++-------- .../universal/data_source_creator.py | 5 +- .../contrib/spark/test_spark.py | 2 + .../registration/test_universal_cli.py | 9 ++- .../registration/test_universal_types.py | 9 +-- sdk/python/tests/utils/e2e_test_validation.py | 9 ++- 7 files changed, 68 insertions(+), 49 deletions(-) diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py index c4a62be0c0a..7c875fc9bde 100644 --- a/sdk/python/tests/conftest.py +++ b/sdk/python/tests/conftest.py @@ -182,16 +182,15 @@ def environment(request, worker_id): request.param, worker_id=worker_id, fixture_request=request ) + e.setup() + if hasattr(e.data_source_creator, "mock_environ"): with mock.patch.dict(os.environ, e.data_source_creator.mock_environ): yield e else: yield e - e.feature_store.teardown() - e.data_source_creator.teardown() - if e.online_store_creator: - e.online_store_creator.teardown() + e.teardown() _config_cache: Any = {} diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index 311325536ed..ffdbbb36c1f 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -1,6 +1,5 @@ import dataclasses import importlib -import json import os import tempfile import uuid @@ -11,13 +10,15 @@ import pandas as pd import pytest -import yaml from feast import FeatureStore, FeatureView, OnDemandFeatureView, driver_test_data from feast.constants import FULL_REPO_CONFIGS_MODULE_ENV_NAME from feast.data_source import DataSource from feast.errors import FeastModuleImportError -from feast.infra.feature_servers.base_config import FeatureLoggingConfig +from feast.infra.feature_servers.base_config import ( + BaseFeatureServerConfig, + FeatureLoggingConfig, +) from feast.infra.feature_servers.local_process.config import LocalFeatureServerConfig from feast.repo_config import RegistryConfig, RepoConfig from tests.integration.feature_repos.integration_test_repo_config import ( @@ -397,18 +398,48 @@ def construct_universal_feature_views( @dataclass class Environment: name: str - test_repo_config: IntegrationTestRepoConfig - feature_store: FeatureStore + project: str + provider: str + registry: RegistryConfig data_source_creator: DataSourceCreator + online_store_creator: Optional[OnlineStoreCreator] + online_store: Optional[Union[str, Dict]] + batch_engine: Optional[Union[str, Dict]] python_feature_server: bool worker_id: str - online_store_creator: Optional[OnlineStoreCreator] = None + feature_server: BaseFeatureServerConfig + entity_key_serialization_version: int + repo_dir_name: str fixture_request: Optional[pytest.FixtureRequest] = None def __post_init__(self): self.end_date = datetime.utcnow().replace(microsecond=0, second=0, minute=0) self.start_date: datetime = self.end_date - timedelta(days=3) + def setup(self): + self.data_source_creator.setup(self.registry) + + config = RepoConfig( + registry=self.registry, + project=self.project, + provider=self.provider, + offline_store=self.data_source_creator.create_offline_store_config(), + online_store=self.online_store_creator.create_online_store() + if self.online_store_creator + else self.online_store, + batch_engine=self.batch_engine, + repo_path=self.repo_dir_name, + feature_server=self.feature_server, + entity_key_serialization_version=self.entity_key_serialization_version, + ) + self.feature_store = FeatureStore(config=config) + + def teardown(self): + self.feature_store.teardown() + self.data_source_creator.teardown() + if self.online_store_creator: + self.online_store_creator.teardown() + def table_name_from_data_source(ds: DataSource) -> Optional[str]: if hasattr(ds, "table_ref"): @@ -436,16 +467,13 @@ def construct_test_environment( offline_creator: DataSourceCreator = test_repo_config.offline_store_creator( project, fixture_request=fixture_request ) - offline_store_config = offline_creator.create_offline_store_config() if test_repo_config.online_store_creator: online_creator = test_repo_config.online_store_creator( project, fixture_request=fixture_request ) - online_store = online_creator.create_online_store() else: online_creator = None - online_store = test_repo_config.online_store if test_repo_config.python_feature_server and test_repo_config.provider == "aws": from feast.infra.feature_servers.aws_lambda.config import ( @@ -481,35 +509,21 @@ def construct_test_environment( cache_ttl_seconds=1, ) - config = RepoConfig( - registry=registry, - project=project, - provider=test_repo_config.provider, - offline_store=offline_store_config, - online_store=online_store, - batch_engine=test_repo_config.batch_engine, - repo_path=repo_dir_name, - feature_server=feature_server, - entity_key_serialization_version=entity_key_serialization_version, - ) - - # Create feature_store.yaml out of the config - with open(Path(repo_dir_name) / "feature_store.yaml", "w") as f: - yaml.safe_dump(json.loads(config.model_dump_json(by_alias=True)), f) - - fs = FeatureStore(repo_dir_name) - # We need to initialize the registry, because if nothing is applied in the test before tearing down - # the feature store, that will cause the teardown method to blow up. - fs.registry._initialize_registry(project) environment = Environment( name=project, - test_repo_config=test_repo_config, - feature_store=fs, + provider=test_repo_config.provider, data_source_creator=offline_creator, python_feature_server=test_repo_config.python_feature_server, worker_id=worker_id, online_store_creator=online_creator, fixture_request=fixture_request, + project=project, + registry=registry, + feature_server=feature_server, + entity_key_serialization_version=entity_key_serialization_version, + repo_dir_name=repo_dir_name, + batch_engine=test_repo_config.batch_engine, + online_store=test_repo_config.online_store, ) return environment diff --git a/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py b/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py index 5e5062291d5..62d458d6f4a 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py @@ -5,7 +5,7 @@ from feast.data_source import DataSource from feast.feature_logging import LoggingDestination -from feast.repo_config import FeastConfigBaseModel +from feast.repo_config import FeastConfigBaseModel, RegistryConfig from feast.saved_dataset import SavedDatasetStorage @@ -44,6 +44,9 @@ def create_data_source( """ raise NotImplementedError + def setup(self, registry: RegistryConfig): + pass + @abstractmethod def create_offline_store_config(self) -> FeastConfigBaseModel: raise NotImplementedError diff --git a/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py b/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py index bb4c4e63fc2..ae0e03c9441 100644 --- a/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py +++ b/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py @@ -34,6 +34,8 @@ def test_spark_materialization_consistency(): spark_config, None, entity_key_serialization_version=2 ) + spark_environment.setup() + df = create_basic_driver_dataset() ds = spark_environment.data_source_creator.create_data_source( diff --git a/sdk/python/tests/integration/registration/test_universal_cli.py b/sdk/python/tests/integration/registration/test_universal_cli.py index e7f7a7cb633..e7331a07894 100644 --- a/sdk/python/tests/integration/registration/test_universal_cli.py +++ b/sdk/python/tests/integration/registration/test_universal_cli.py @@ -27,9 +27,10 @@ def test_universal_cli(environment: Environment): repo_path = Path(repo_dir_name) feature_store_yaml = make_feature_store_yaml( project, - environment.test_repo_config, repo_path, environment.data_source_creator, + environment.provider, + environment.online_store, ) repo_config = repo_path / "feature_store.yaml" @@ -124,9 +125,10 @@ def test_odfv_apply(environment) -> None: repo_path = Path(repo_dir_name) feature_store_yaml = make_feature_store_yaml( project, - environment.test_repo_config, repo_path, environment.data_source_creator, + environment.provider, + environment.online_store, ) repo_config = repo_path / "feature_store.yaml" @@ -158,9 +160,10 @@ def test_nullable_online_store(test_nullable_online_store) -> None: repo_path = Path(repo_dir_name) feature_store_yaml = make_feature_store_yaml( project, - test_nullable_online_store, repo_path, test_nullable_online_store.offline_store_creator(project), + test_nullable_online_store.provider, + test_nullable_online_store.online_store, ) repo_config = repo_path / "feature_store.yaml" diff --git a/sdk/python/tests/integration/registration/test_universal_types.py b/sdk/python/tests/integration/registration/test_universal_types.py index 3ce5876bd60..ca15681c9b2 100644 --- a/sdk/python/tests/integration/registration/test_universal_types.py +++ b/sdk/python/tests/integration/registration/test_universal_types.py @@ -110,7 +110,7 @@ def test_feature_get_historical_features_types_match( if config.feature_is_list: assert_feature_list_types( - environment.test_repo_config.provider, + environment.provider, config.feature_dtype, historical_features_df, ) @@ -119,7 +119,7 @@ def test_feature_get_historical_features_types_match( config.feature_dtype, historical_features_df ) assert_expected_arrow_types( - environment.test_repo_config.provider, + environment.provider, config.feature_dtype, config.feature_is_list, historical_features, @@ -335,10 +335,7 @@ class TypeTestConfig: ) def offline_types_test_fixtures(request, environment): config: TypeTestConfig = request.param - if ( - environment.test_repo_config.provider == "aws" - and config.feature_is_list is True - ): + if environment.provider == "aws" and config.feature_is_list is True: pytest.skip("Redshift doesn't support list features") return get_fixtures(request, environment) diff --git a/sdk/python/tests/utils/e2e_test_validation.py b/sdk/python/tests/utils/e2e_test_validation.py index 798e82de9b9..bcfc16f6e38 100644 --- a/sdk/python/tests/utils/e2e_test_validation.py +++ b/sdk/python/tests/utils/e2e_test_validation.py @@ -3,7 +3,7 @@ import time from datetime import datetime, timedelta from pathlib import Path -from typing import List, Optional +from typing import Dict, List, Optional, Union import pandas as pd import pytest @@ -176,17 +176,18 @@ def _check_offline_and_online_features( def make_feature_store_yaml( project, - test_repo_config, repo_dir_name: Path, offline_creator: DataSourceCreator, + provider: str, + online_store: Optional[Union[str, Dict]], ): offline_store_config = offline_creator.create_offline_store_config() - online_store = test_repo_config.online_store + online_store = online_store config = RepoConfig( registry=str(Path(repo_dir_name) / "registry.db"), project=project, - provider=test_repo_config.provider, + provider=provider, offline_store=offline_store_config, online_store=online_store, repo_path=str(Path(repo_dir_name)), From cb2abc3c931da698cb96661a97debac97442e55d Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 21 May 2024 00:23:50 +0000 Subject: [PATCH 2/4] fix test_offline_store env setup Signed-off-by: tokoko --- .../feature_repos/repo_configuration.py | 4 +-- .../offline_stores/test_offline_store.py | 31 +++++++++---------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index ffdbbb36c1f..2f260e87a60 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -419,7 +419,7 @@ def __post_init__(self): def setup(self): self.data_source_creator.setup(self.registry) - config = RepoConfig( + self.config = RepoConfig( registry=self.registry, project=self.project, provider=self.provider, @@ -432,7 +432,7 @@ def setup(self): feature_server=self.feature_server, entity_key_serialization_version=self.entity_key_serialization_version, ) - self.feature_store = FeatureStore(config=config) + self.feature_store = FeatureStore(config=self.config) def teardown(self): self.feature_store.teardown() diff --git a/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py b/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py index e5768a81b21..79a3a27b67a 100644 --- a/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py +++ b/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py @@ -10,7 +10,6 @@ AthenaRetrievalJob, ) from feast.infra.offline_stores.contrib.mssql_offline_store.mssql import ( - MsSqlServerOfflineStoreConfig, MsSqlServerRetrievalJob, ) from feast.infra.offline_stores.contrib.postgres_offline_store.postgres import ( @@ -120,12 +119,14 @@ def retrieval_job(request, environment): iam_role="arn:aws:iam::585132637328:role/service-role/AmazonRedshift-CommandsAccessRole-20240403T092631", workgroup="", ) - environment.test_repo_config.offline_store = offline_store_config + config = environment.config.copy( + update={"offline_config": offline_store_config} + ) return RedshiftRetrievalJob( query="query", redshift_client="", s3_resource="", - config=environment.test_repo_config, + config=config, full_feature_names=False, ) elif request.param is SnowflakeRetrievalJob: @@ -141,12 +142,14 @@ def retrieval_job(request, environment): storage_integration_name="FEAST_S3", blob_export_location="s3://feast-snowflake-offload/export", ) - environment.test_repo_config.offline_store = offline_store_config - environment.test_repo_config.project = "project" + config = environment.config.copy( + update={"offline_config": offline_store_config} + ) + environment.project = "project" return SnowflakeRetrievalJob( query="query", snowflake_conn=MagicMock(), - config=environment.test_repo_config, + config=config, full_feature_names=False, ) elif request.param is AthenaRetrievalJob: @@ -158,21 +161,18 @@ def retrieval_job(request, environment): s3_staging_location="athena", ) - environment.test_repo_config.offline_store = offline_store_config return AthenaRetrievalJob( query="query", athena_client="client", s3_resource="", - config=environment.test_repo_config.offline_store, + config=environment.config, full_feature_names=False, ) elif request.param is MsSqlServerRetrievalJob: return MsSqlServerRetrievalJob( query="query", engine=MagicMock(), - config=MsSqlServerOfflineStoreConfig( - connection_string="str" - ), # TODO: this does not match the RetrievalJob pattern. Suppose to be RepoConfig + config=environment.config, full_feature_names=False, ) elif request.param is PostgreSQLRetrievalJob: @@ -182,28 +182,25 @@ def retrieval_job(request, environment): user="str", password="str", ) - environment.test_repo_config.offline_store = offline_store_config return PostgreSQLRetrievalJob( query="query", - config=environment.test_repo_config.offline_store, + config=environment.config, full_feature_names=False, ) elif request.param is SparkRetrievalJob: offline_store_config = SparkOfflineStoreConfig() - environment.test_repo_config.offline_store = offline_store_config return SparkRetrievalJob( spark_session=MagicMock(), query="str", full_feature_names=False, - config=environment.test_repo_config, + config=environment.config, ) elif request.param is TrinoRetrievalJob: offline_store_config = SparkOfflineStoreConfig() - environment.test_repo_config.offline_store = offline_store_config return TrinoRetrievalJob( query="str", client=MagicMock(), - config=environment.test_repo_config, + config=environment.config, full_feature_names=False, ) else: From 3b1278b3e3ec02865b36b376aa31c423c7828bca Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 21 May 2024 07:20:51 +0000 Subject: [PATCH 3/4] fix it tests Signed-off-by: tokoko --- sdk/python/tests/integration/materialization/test_snowflake.py | 3 +++ .../offline_store/test_universal_historical_retrieval.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sdk/python/tests/integration/materialization/test_snowflake.py b/sdk/python/tests/integration/materialization/test_snowflake.py index 60fa9b30aab..adb2bd7e7df 100644 --- a/sdk/python/tests/integration/materialization/test_snowflake.py +++ b/sdk/python/tests/integration/materialization/test_snowflake.py @@ -52,6 +52,7 @@ def test_snowflake_materialization_consistency(online_store): batch_engine=SNOWFLAKE_ENGINE_CONFIG, ) snowflake_environment = construct_test_environment(snowflake_config, None) + snowflake_environment.setup() df = create_basic_driver_dataset() ds = snowflake_environment.data_source_creator.create_data_source( @@ -112,6 +113,7 @@ def test_snowflake_materialization_consistency_internal_with_lists( batch_engine=SNOWFLAKE_ENGINE_CONFIG, ) snowflake_environment = construct_test_environment(snowflake_config, None) + snowflake_environment.setup() df = create_basic_driver_dataset(Int32, feature_dtype, True, feature_is_empty_list) ds = snowflake_environment.data_source_creator.create_data_source( @@ -195,6 +197,7 @@ def test_snowflake_materialization_entityless_fv(): batch_engine=SNOWFLAKE_ENGINE_CONFIG, ) snowflake_environment = construct_test_environment(snowflake_config, None) + snowflake_environment.setup() df = create_basic_driver_dataset() entityless_df = df.drop("driver_id", axis=1) diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py index 2a2820c10a7..303cb42b286 100644 --- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py +++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py @@ -269,7 +269,7 @@ def test_historical_features_with_entities_from_query( if not orders_table: raise pytest.skip("Offline source is not sql-based") - data_source_creator = environment.test_repo_config.offline_store_creator + data_source_creator = environment.data_source_creator if data_source_creator.__name__ == SnowflakeDataSourceCreator.__name__: entity_df_query = f""" SELECT "customer_id", "driver_id", "order_id", "origin_id", "destination_id", "event_timestamp" From 0720cf2483947212d663526da0ef1980cec1ebb6 Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 21 May 2024 07:48:41 +0000 Subject: [PATCH 4/4] fix it tests Signed-off-by: tokoko --- .../offline_store/test_universal_historical_retrieval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py index 303cb42b286..a6db7f2535c 100644 --- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py +++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py @@ -270,7 +270,7 @@ def test_historical_features_with_entities_from_query( raise pytest.skip("Offline source is not sql-based") data_source_creator = environment.data_source_creator - if data_source_creator.__name__ == SnowflakeDataSourceCreator.__name__: + if isinstance(data_source_creator, SnowflakeDataSourceCreator): entity_df_query = f""" SELECT "customer_id", "driver_id", "order_id", "origin_id", "destination_id", "event_timestamp" FROM "{orders_table}"