diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py index c161a5b9554..61334be1a92 100644 --- a/sdk/python/feast/infra/online_stores/dynamodb.py +++ b/sdk/python/feast/infra/online_stores/dynamodb.py @@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple from pydantic import StrictStr -from pydantic.typing import Literal +from pydantic.typing import Literal, Union from feast import Entity, FeatureView, utils from feast.infra.infra_object import DYNAMODB_INFRA_OBJECT_CLASS_TYPE, InfraObject @@ -50,17 +50,20 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel): type: Literal["dynamodb"] = "dynamodb" """Online store type selector""" + batch_size: int = 40 + """Number of items to retrieve in a DynamoDB BatchGetItem call.""" + + endpoint_url: Union[str, None] = None + """DynamoDB local development endpoint Url, i.e. http://localhost:8000""" + region: StrictStr """AWS Region Name""" - table_name_template: StrictStr = "{project}.{table_name}" - """DynamoDB table name template""" - sort_response: bool = True """Whether or not to sort BatchGetItem response.""" - batch_size: int = 40 - """Number of items to retrieve in a DynamoDB BatchGetItem call.""" + table_name_template: StrictStr = "{project}.{table_name}" + """DynamoDB table name template""" class DynamoDBOnlineStore(OnlineStore): @@ -95,8 +98,12 @@ def update( """ online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig) - dynamodb_client = self._get_dynamodb_client(online_config.region) - dynamodb_resource = self._get_dynamodb_resource(online_config.region) + dynamodb_client = self._get_dynamodb_client( + online_config.region, online_config.endpoint_url + ) + dynamodb_resource = self._get_dynamodb_resource( + online_config.region, online_config.endpoint_url + ) for table_instance in tables_to_keep: try: @@ -141,7 +148,9 @@ def teardown( """ online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig) - dynamodb_resource = self._get_dynamodb_resource(online_config.region) + dynamodb_resource = self._get_dynamodb_resource( + online_config.region, online_config.endpoint_url + ) for table in tables: _delete_table_idempotent( @@ -175,7 +184,9 @@ def online_write_batch( """ online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig) - dynamodb_resource = self._get_dynamodb_resource(online_config.region) + dynamodb_resource = self._get_dynamodb_resource( + online_config.region, online_config.endpoint_url + ) table_instance = dynamodb_resource.Table( _get_table_name(online_config, config, table) @@ -217,7 +228,9 @@ def online_read( """ online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig) - dynamodb_resource = self._get_dynamodb_resource(online_config.region) + dynamodb_resource = self._get_dynamodb_resource( + online_config.region, online_config.endpoint_url + ) table_instance = dynamodb_resource.Table( _get_table_name(online_config, config, table) ) @@ -260,14 +273,16 @@ def online_read( result.extend(batch_size_nones) return result - def _get_dynamodb_client(self, region: str): + def _get_dynamodb_client(self, region: str, endpoint_url: Optional[str] = None): if self._dynamodb_client is None: - self._dynamodb_client = _initialize_dynamodb_client(region) + self._dynamodb_client = _initialize_dynamodb_client(region, endpoint_url) return self._dynamodb_client - def _get_dynamodb_resource(self, region: str): + def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None): if self._dynamodb_resource is None: - self._dynamodb_resource = _initialize_dynamodb_resource(region) + self._dynamodb_resource = _initialize_dynamodb_resource( + region, endpoint_url + ) return self._dynamodb_resource def _sort_dynamodb_response(self, responses: list, order: list): @@ -285,12 +300,12 @@ def _sort_dynamodb_response(self, responses: list, order: list): return table_responses_ordered -def _initialize_dynamodb_client(region: str): - return boto3.client("dynamodb", region_name=region) +def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None): + return boto3.client("dynamodb", region_name=region, endpoint_url=endpoint_url) -def _initialize_dynamodb_resource(region: str): - return boto3.resource("dynamodb", region_name=region) +def _initialize_dynamodb_resource(region: str, endpoint_url: Optional[str] = None): + return boto3.resource("dynamodb", region_name=region, endpoint_url=endpoint_url) # TODO(achals): This form of user-facing templating is experimental. @@ -327,13 +342,20 @@ class DynamoDBTable(InfraObject): Attributes: name: The name of the table. region: The region of the table. + endpoint_url: Local DynamoDB Endpoint Url. + _dynamodb_client: Boto3 DynamoDB client. + _dynamodb_resource: Boto3 DynamoDB resource. """ region: str + endpoint_url = None + _dynamodb_client = None + _dynamodb_resource = None - def __init__(self, name: str, region: str): + def __init__(self, name: str, region: str, endpoint_url: Optional[str] = None): super().__init__(name) self.region = region + self.endpoint_url = endpoint_url def to_infra_object_proto(self) -> InfraObjectProto: dynamodb_table_proto = self.to_proto() @@ -362,8 +384,8 @@ def from_proto(dynamodb_table_proto: DynamoDBTableProto) -> Any: ) def update(self): - dynamodb_client = _initialize_dynamodb_client(region=self.region) - dynamodb_resource = _initialize_dynamodb_resource(region=self.region) + dynamodb_client = self._get_dynamodb_client(self.region, self.endpoint_url) + dynamodb_resource = self._get_dynamodb_resource(self.region, self.endpoint_url) try: dynamodb_resource.create_table( @@ -384,5 +406,17 @@ def update(self): dynamodb_client.get_waiter("table_exists").wait(TableName=f"{self.name}") def teardown(self): - dynamodb_resource = _initialize_dynamodb_resource(region=self.region) + dynamodb_resource = self._get_dynamodb_resource(self.region, self.endpoint_url) _delete_table_idempotent(dynamodb_resource, self.name) + + def _get_dynamodb_client(self, region: str, endpoint_url: Optional[str] = None): + if self._dynamodb_client is None: + self._dynamodb_client = _initialize_dynamodb_client(region, endpoint_url) + return self._dynamodb_client + + def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None): + if self._dynamodb_resource is None: + self._dynamodb_resource = _initialize_dynamodb_resource( + region, endpoint_url + ) + return self._dynamodb_resource diff --git a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py index 0f42230ef53..7b0c5a4a619 100644 --- a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py +++ b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py @@ -7,6 +7,7 @@ from feast.infra.online_stores.dynamodb import ( DynamoDBOnlineStore, DynamoDBOnlineStoreConfig, + DynamoDBTable, ) from feast.repo_config import RepoConfig from tests.utils.online_store_utils import ( @@ -38,6 +39,121 @@ def repo_config(): ) +def test_online_store_config_default(): + """Test DynamoDBOnlineStoreConfig default parameters.""" + aws_region = "us-west-2" + dynamodb_store_config = DynamoDBOnlineStoreConfig(region=aws_region) + assert dynamodb_store_config.type == "dynamodb" + assert dynamodb_store_config.batch_size == 40 + assert dynamodb_store_config.endpoint_url is None + assert dynamodb_store_config.region == aws_region + assert dynamodb_store_config.sort_response is True + assert dynamodb_store_config.table_name_template == "{project}.{table_name}" + + +def test_dynamodb_table_default_params(): + """Test DynamoDBTable default parameters.""" + tbl_name = "dynamodb-test" + aws_region = "us-west-2" + dynamodb_table = DynamoDBTable(tbl_name, aws_region) + assert dynamodb_table.name == tbl_name + assert dynamodb_table.region == aws_region + assert dynamodb_table.endpoint_url is None + assert dynamodb_table._dynamodb_client is None + assert dynamodb_table._dynamodb_resource is None + + +def test_online_store_config_custom_params(): + """Test DynamoDBOnlineStoreConfig custom parameters.""" + aws_region = "us-west-2" + batch_size = 20 + endpoint_url = "http://localhost:8000" + sort_response = False + table_name_template = "feast_test.dynamodb_table" + dynamodb_store_config = DynamoDBOnlineStoreConfig( + region=aws_region, + batch_size=batch_size, + endpoint_url=endpoint_url, + sort_response=sort_response, + table_name_template=table_name_template, + ) + assert dynamodb_store_config.type == "dynamodb" + assert dynamodb_store_config.batch_size == batch_size + assert dynamodb_store_config.endpoint_url == endpoint_url + assert dynamodb_store_config.region == aws_region + assert dynamodb_store_config.sort_response == sort_response + assert dynamodb_store_config.table_name_template == table_name_template + + +def test_dynamodb_table_custom_params(): + """Test DynamoDBTable custom parameters.""" + tbl_name = "dynamodb-test" + aws_region = "us-west-2" + endpoint_url = "http://localhost:8000" + dynamodb_table = DynamoDBTable(tbl_name, aws_region, endpoint_url) + assert dynamodb_table.name == tbl_name + assert dynamodb_table.region == aws_region + assert dynamodb_table.endpoint_url == endpoint_url + assert dynamodb_table._dynamodb_client is None + assert dynamodb_table._dynamodb_resource is None + + +def test_online_store_config_dynamodb_client(): + """Test DynamoDBOnlineStoreConfig configure DynamoDB client with endpoint_url.""" + aws_region = "us-west-2" + endpoint_url = "http://localhost:8000" + dynamodb_store = DynamoDBOnlineStore() + dynamodb_store_config = DynamoDBOnlineStoreConfig( + region=aws_region, endpoint_url=endpoint_url + ) + dynamodb_client = dynamodb_store._get_dynamodb_client( + dynamodb_store_config.region, dynamodb_store_config.endpoint_url + ) + assert dynamodb_client.meta.region_name == aws_region + assert dynamodb_client.meta.endpoint_url == endpoint_url + + +def test_dynamodb_table_dynamodb_client(): + """Test DynamoDBTable configure DynamoDB client with endpoint_url.""" + tbl_name = "dynamodb-test" + aws_region = "us-west-2" + endpoint_url = "http://localhost:8000" + dynamodb_table = DynamoDBTable(tbl_name, aws_region, endpoint_url) + dynamodb_client = dynamodb_table._get_dynamodb_client( + dynamodb_table.region, dynamodb_table.endpoint_url + ) + assert dynamodb_client.meta.region_name == aws_region + assert dynamodb_client.meta.endpoint_url == endpoint_url + + +def test_online_store_config_dynamodb_resource(): + """Test DynamoDBOnlineStoreConfig configure DynamoDB Resource with endpoint_url.""" + aws_region = "us-west-2" + endpoint_url = "http://localhost:8000" + dynamodb_store = DynamoDBOnlineStore() + dynamodb_store_config = DynamoDBOnlineStoreConfig( + region=aws_region, endpoint_url=endpoint_url + ) + dynamodb_resource = dynamodb_store._get_dynamodb_resource( + dynamodb_store_config.region, dynamodb_store_config.endpoint_url + ) + assert dynamodb_resource.meta.client.meta.region_name == aws_region + assert dynamodb_resource.meta.client.meta.endpoint_url == endpoint_url + + +def test_dynamodb_table_dynamodb_resource(): + """Test DynamoDBTable configure DynamoDB resource with endpoint_url.""" + tbl_name = "dynamodb-test" + aws_region = "us-west-2" + endpoint_url = "http://localhost:8000" + dynamodb_table = DynamoDBTable(tbl_name, aws_region, endpoint_url) + dynamodb_resource = dynamodb_table._get_dynamodb_resource( + dynamodb_table.region, dynamodb_table.endpoint_url + ) + assert dynamodb_resource.meta.client.meta.region_name == aws_region + assert dynamodb_resource.meta.client.meta.endpoint_url == endpoint_url + + @mock_dynamodb2 @pytest.mark.parametrize("n_samples", [5, 50, 100]) def test_online_read(repo_config, n_samples):