From 0232a30b058388fbd3b8a43111adf58e8e839d36 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 11 Jan 2022 12:12:28 +0100 Subject: [PATCH] fix(array): fix edge case on single boolean index --- docarray/array/document.py | 2 +- docs/fundamentals/documentarray/access-elements.md | 2 ++ tests/unit/array/test_advance_indexing.py | 12 ++++++++++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/docarray/array/document.py b/docarray/array/document.py index be70a655bb9..fb988b50acd 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -129,7 +129,7 @@ def __getitem__( def __getitem__( self, index: 'DocumentArrayIndexType' ) -> Union['Document', 'DocumentArray']: - if isinstance(index, (int, np.generic)): + if isinstance(index, (int, np.generic)) and not isinstance(index, bool): return self._data[int(index)] elif isinstance(index, str): if index.startswith('@'): diff --git a/docs/fundamentals/documentarray/access-elements.md b/docs/fundamentals/documentarray/access-elements.md index a4455c02394..0157b176915 100644 --- a/docs/fundamentals/documentarray/access-elements.md +++ b/docs/fundamentals/documentarray/access-elements.md @@ -120,6 +120,8 @@ print(da) ``` +Note that if the length of the boolean mask is smaller than the length of a DocumentArray, then the remaining part is padded to `False`. + (path-string)= ## Index by nested structure diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index d87fdd80af9..13cf2dd6905 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -219,3 +219,15 @@ def test_advance_selector_mixed(): assert len(da[:, ('id', 'embedding', 'matches')]) == 3 assert len(da[:, ('id', 'embedding', 'matches')][0]) == 10 + + +def test_single_boolean_and_padding(): + from docarray import DocumentArray + + da = DocumentArray.empty(3) + + with pytest.raises(IndexError): + da[True] + + assert len(da[True, False]) == 1 + assert len(da[False, False]) == 0