diff --git a/docarray/array/mixins/io/pushpull.py b/docarray/array/mixins/io/pushpull.py index 80f626b9d0c..2926a30a59f 100644 --- a/docarray/array/mixins/io/pushpull.py +++ b/docarray/array/mixins/io/pushpull.py @@ -27,6 +27,14 @@ def _get_hub_config() -> Optional[Dict]: return json.load(f) +@lru_cache() +def _get_auth_token() -> Optional[str]: + hub_config = _get_hub_config() + config_auth_token = hub_config.get('auth_token') if hub_config else None + env_auth_token = os.environ.get('JINA_AUTH_TOKEN') + return env_auth_token or config_auth_token + + @lru_cache() def _get_cloud_api() -> str: """Get Cloud Api for transmitting data to the cloud. @@ -73,9 +81,8 @@ def push(self, name: str, show_progress: bool = False, public: bool = True) -> D headers = {'Content-Type': ctype, **get_request_header()} - _hub_config = _get_hub_config() - if _hub_config: - auth_token = _hub_config.get('auth_token') + auth_token = _get_auth_token() + if auth_token: headers['Authorization'] = f'token {auth_token}' _head, _tail = data.split(delimiter) @@ -146,9 +153,8 @@ def pull( headers = {} - _hub_config = _get_hub_config() - if _hub_config: - auth_token = _hub_config.get('auth_token') + auth_token = _get_auth_token() + if auth_token: headers['Authorization'] = f'token {auth_token}' url = f'{_get_cloud_api()}/v2/rpc/artifact.getDownloadUrl?name={name}' diff --git a/tests/conftest.py b/tests/conftest.py index 09989e06156..7399317f1b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import tempfile import os import time +from typing import Dict import pytest @@ -33,3 +34,12 @@ def start_storage(): f"docker-compose -f {compose_yml} --project-directory . down " f"--remove-orphans" ) + + +@pytest.fixture(scope='session') +def set_env_vars(request): + _old_environ = dict(os.environ) + os.environ.update(request.param) + yield + os.environ.clear() + os.environ.update(_old_environ) diff --git a/tests/unit/array/mixins/test_pushpull.py b/tests/unit/array/mixins/test_pushpull.py index 491ecf5e636..c651e847aae 100644 --- a/tests/unit/array/mixins/test_pushpull.py +++ b/tests/unit/array/mixins/test_pushpull.py @@ -160,10 +160,12 @@ def test_api_url_change(mocker, monkeypatch): assert pull_kwargs['url'].startswith(test_api_url) -def test_api_authorization_header(mocker, monkeypatch, tmpdir): - from docarray.array.mixins.io.pushpull import _get_hub_config +def test_api_authorization_header_from_config(mocker, monkeypatch, tmpdir): + from docarray.array.mixins.io.pushpull import _get_hub_config, _get_auth_token _get_hub_config.cache_clear() + _get_auth_token.cache_clear() + os.environ['JINA_HUB_ROOT'] = str(tmpdir) token = 'test-auth-token' @@ -179,7 +181,9 @@ def test_api_authorization_header(mocker, monkeypatch, tmpdir): DocumentArray.pull(name='test_name') del os.environ['JINA_HUB_ROOT'] + _get_hub_config.cache_clear() + _get_auth_token.cache_clear() assert mock.call_count == 3 # 1 for push, 1 for pull, 1 for download @@ -188,3 +192,71 @@ def test_api_authorization_header(mocker, monkeypatch, tmpdir): assert push_kwargs['headers'].get('Authorization') == f'token {token}' assert pull_kwargs['headers'].get('Authorization') == f'token {token}' + + +@pytest.mark.parametrize( + 'set_env_vars', [{'JINA_AUTH_TOKEN': 'test-auth-token'}], indirect=True +) +def test_api_authorization_header_from_env(mocker, monkeypatch, set_env_vars): + from docarray.array.mixins.io.pushpull import _get_hub_config, _get_auth_token + + _get_hub_config.cache_clear() + _get_auth_token.cache_clear() + + mock = mocker.Mock() + _mock_post(mock, monkeypatch) + _mock_get(mock, monkeypatch) + + docs = random_docs(2) + docs.push(name='test_name') + DocumentArray.pull(name='test_name') + + _get_hub_config.cache_clear() + _get_auth_token.cache_clear() + + assert mock.call_count == 3 # 1 for push, 1 for pull, 1 for download + + _, push_kwargs = mock.call_args_list[0] + _, pull_kwargs = mock.call_args_list[1] + + assert push_kwargs['headers'].get('Authorization') == 'token test-auth-token' + assert pull_kwargs['headers'].get('Authorization') == 'token test-auth-token' + + +@pytest.mark.parametrize( + 'set_env_vars', [{'JINA_AUTH_TOKEN': 'test-auth-token-env'}], indirect=True +) +def test_api_authorization_header_env_and_config( + mocker, monkeypatch, tmpdir, set_env_vars +): + from docarray.array.mixins.io.pushpull import _get_hub_config, _get_auth_token + + _get_hub_config.cache_clear() + _get_auth_token.cache_clear() + + os.environ['JINA_HUB_ROOT'] = str(tmpdir) + + token = 'test-auth-token-config' + with open(tmpdir / JINA_CLOUD_CONFIG, 'w') as f: + json.dump({'auth_token': token}, f) + + mock = mocker.Mock() + _mock_post(mock, monkeypatch) + _mock_get(mock, monkeypatch) + + docs = random_docs(2) + docs.push(name='test_name') + DocumentArray.pull(name='test_name') + + del os.environ['JINA_HUB_ROOT'] + + _get_hub_config.cache_clear() + _get_auth_token.cache_clear() + + assert mock.call_count == 3 # 1 for push, 1 for pull, 1 for download + + _, push_kwargs = mock.call_args_list[0] + _, pull_kwargs = mock.call_args_list[1] + + assert push_kwargs['headers'].get('Authorization') == 'token test-auth-token-env' + assert pull_kwargs['headers'].get('Authorization') == 'token test-auth-token-env'