Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions docarray/array/mixins/io/pushpull.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ def _get_hub_config() -> Optional[Dict]:
return json.load(f)


@lru_cache()
def _get_auth_token() -> Optional[str]:
hub_config = _get_hub_config()
config_auth_token = hub_config.get('auth_token') if hub_config else None
env_auth_token = os.environ.get('JINA_AUTH_TOKEN')
return env_auth_token or config_auth_token


@lru_cache()
def _get_cloud_api() -> str:
"""Get Cloud Api for transmitting data to the cloud.
Expand Down Expand Up @@ -73,9 +81,8 @@ def push(self, name: str, show_progress: bool = False, public: bool = True) -> D

headers = {'Content-Type': ctype, **get_request_header()}

_hub_config = _get_hub_config()
if _hub_config:
auth_token = _hub_config.get('auth_token')
auth_token = _get_auth_token()
if auth_token:
headers['Authorization'] = f'token {auth_token}'

_head, _tail = data.split(delimiter)
Expand Down Expand Up @@ -146,9 +153,8 @@ def pull(

headers = {}

_hub_config = _get_hub_config()
if _hub_config:
auth_token = _hub_config.get('auth_token')
auth_token = _get_auth_token()
if auth_token:
headers['Authorization'] = f'token {auth_token}'

url = f'{_get_cloud_api()}/v2/rpc/artifact.getDownloadUrl?name={name}'
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tempfile
import os
import time
from typing import Dict

import pytest

Expand Down Expand Up @@ -33,3 +34,12 @@ def start_storage():
f"docker-compose -f {compose_yml} --project-directory . down "
f"--remove-orphans"
)


@pytest.fixture(scope='session')
def set_env_vars(request):
_old_environ = dict(os.environ)
os.environ.update(request.param)
yield
os.environ.clear()
os.environ.update(_old_environ)
76 changes: 74 additions & 2 deletions tests/unit/array/mixins/test_pushpull.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,12 @@ def test_api_url_change(mocker, monkeypatch):
assert pull_kwargs['url'].startswith(test_api_url)


def test_api_authorization_header(mocker, monkeypatch, tmpdir):
from docarray.array.mixins.io.pushpull import _get_hub_config
def test_api_authorization_header_from_config(mocker, monkeypatch, tmpdir):
from docarray.array.mixins.io.pushpull import _get_hub_config, _get_auth_token

_get_hub_config.cache_clear()
_get_auth_token.cache_clear()

os.environ['JINA_HUB_ROOT'] = str(tmpdir)

token = 'test-auth-token'
Expand All @@ -179,7 +181,9 @@ def test_api_authorization_header(mocker, monkeypatch, tmpdir):
DocumentArray.pull(name='test_name')

del os.environ['JINA_HUB_ROOT']

_get_hub_config.cache_clear()
_get_auth_token.cache_clear()

assert mock.call_count == 3 # 1 for push, 1 for pull, 1 for download

Expand All @@ -188,3 +192,71 @@ def test_api_authorization_header(mocker, monkeypatch, tmpdir):

assert push_kwargs['headers'].get('Authorization') == f'token {token}'
assert pull_kwargs['headers'].get('Authorization') == f'token {token}'


@pytest.mark.parametrize(
'set_env_vars', [{'JINA_AUTH_TOKEN': 'test-auth-token'}], indirect=True
)
def test_api_authorization_header_from_env(mocker, monkeypatch, set_env_vars):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's just add one last test that asserts the priority in case both are set

from docarray.array.mixins.io.pushpull import _get_hub_config, _get_auth_token

_get_hub_config.cache_clear()
_get_auth_token.cache_clear()

mock = mocker.Mock()
_mock_post(mock, monkeypatch)
_mock_get(mock, monkeypatch)

docs = random_docs(2)
docs.push(name='test_name')
DocumentArray.pull(name='test_name')

_get_hub_config.cache_clear()
_get_auth_token.cache_clear()

assert mock.call_count == 3 # 1 for push, 1 for pull, 1 for download

_, push_kwargs = mock.call_args_list[0]
_, pull_kwargs = mock.call_args_list[1]

assert push_kwargs['headers'].get('Authorization') == 'token test-auth-token'
assert pull_kwargs['headers'].get('Authorization') == 'token test-auth-token'


@pytest.mark.parametrize(
'set_env_vars', [{'JINA_AUTH_TOKEN': 'test-auth-token-env'}], indirect=True
)
def test_api_authorization_header_env_and_config(
mocker, monkeypatch, tmpdir, set_env_vars
):
from docarray.array.mixins.io.pushpull import _get_hub_config, _get_auth_token

_get_hub_config.cache_clear()
_get_auth_token.cache_clear()

os.environ['JINA_HUB_ROOT'] = str(tmpdir)

token = 'test-auth-token-config'
with open(tmpdir / JINA_CLOUD_CONFIG, 'w') as f:
json.dump({'auth_token': token}, f)

mock = mocker.Mock()
_mock_post(mock, monkeypatch)
_mock_get(mock, monkeypatch)

docs = random_docs(2)
docs.push(name='test_name')
DocumentArray.pull(name='test_name')

del os.environ['JINA_HUB_ROOT']

_get_hub_config.cache_clear()
_get_auth_token.cache_clear()

assert mock.call_count == 3 # 1 for push, 1 for pull, 1 for download

_, push_kwargs = mock.call_args_list[0]
_, pull_kwargs = mock.call_args_list[1]

assert push_kwargs['headers'].get('Authorization') == 'token test-auth-token-env'
assert pull_kwargs['headers'].get('Authorization') == 'token test-auth-token-env'