diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 072279ea3a..61fe109728 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -1,5 +1,6 @@ import doctest import unittest +import itertools from test import support from test.support import threading_helper, script_helper from itertools import * @@ -144,8 +145,7 @@ def expand(it, i=0): c = expand(compare[took:]) self.assertEqual(a, c); - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_accumulate(self): self.assertEqual(list(accumulate(range(10))), # one positional arg @@ -189,7 +189,11 @@ def test_batched(self): [('A', 'B'), ('C', 'D'), ('E', 'F'), ('G',)]) self.assertEqual(list(batched('ABCDEFG', 1)), [('A',), ('B',), ('C',), ('D',), ('E',), ('F',), ('G',)]) + self.assertEqual(list(batched('ABCDEF', 2, strict=True)), + [('A', 'B'), ('C', 'D'), ('E', 'F')]) + with self.assertRaises(ValueError): # Incomplete batch when strict + list(batched('ABCDEFG', 3, strict=True)) with self.assertRaises(TypeError): # Too few arguments list(batched('ABCDEFG')) with self.assertRaises(TypeError): @@ -243,8 +247,7 @@ def test_chain_from_iterable(self): self.assertRaises(TypeError, list, chain.from_iterable([2, 3])) self.assertEqual(list(islice(chain.from_iterable(repeat(range(5))), 2)), [0, 1]) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_chain_reducible(self): for oper in [copy.deepcopy] + picklecopiers: @@ -259,8 +262,7 @@ def test_chain_reducible(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, chain('abc', 'def'), compare=list('abcdef')) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_chain_setstate(self): self.assertRaises(TypeError, chain().__setstate__, ()) @@ -275,8 +277,7 @@ def test_chain_setstate(self): it.__setstate__((iter(['abc', 'def']), iter(['ghi']))) self.assertEqual(list(it), ['ghi', 'a', 'b', 'c', 'd', 'e', 'f']) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_combinations(self): self.assertRaises(TypeError, combinations, 'abc') # missing r argument @@ -368,8 +369,7 @@ def test_combinations_tuple_reuse(self): self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1) self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_combinations_with_replacement(self): cwr = combinations_with_replacement @@ -459,8 +459,7 @@ def test_combinations_with_replacement_tuple_reuse(self): self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1) self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_permutations(self): self.assertRaises(TypeError, permutations) # too few arguments @@ -568,8 +567,7 @@ def test_combinatorics(self): self.assertEqual(comb, list(filter(set(perm).__contains__, cwr))) # comb: cwr that is a perm self.assertEqual(comb, sorted(set(cwr) & set(perm))) # comb: both a cwr and a perm - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_compress(self): self.assertEqual(list(compress(data='ABCDEF', selectors=[1,0,1,0,1,1])), list('ACEF')) @@ -604,8 +602,7 @@ def test_compress(self): next(testIntermediate) self.assertEqual(list(op(testIntermediate)), list(result2)) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: DeprecationWarning not triggered @pickle_deprecated def test_count(self): self.assertEqual(lzip('abc',count()), [('a', 0), ('b', 1), ('c', 2)]) @@ -615,6 +612,8 @@ def test_count(self): self.assertEqual(take(2, zip('abc',count(-3))), [('a', -3), ('b', -2)]) self.assertRaises(TypeError, count, 2, 3, 4) self.assertRaises(TypeError, count, 'a') + self.assertEqual(take(3, count(maxsize)), + [maxsize, maxsize + 1, maxsize + 2]) self.assertEqual(take(10, count(maxsize-5)), list(range(maxsize-5, maxsize+5))) self.assertEqual(take(10, count(-maxsize-5)), @@ -637,6 +636,15 @@ def test_count(self): self.assertEqual(next(c), -8) self.assertEqual(repr(count(10.25)), 'count(10.25)') self.assertEqual(repr(count(10.0)), 'count(10.0)') + + self.assertEqual(repr(count(maxsize)), f'count({maxsize})') + c = count(maxsize - 1) + self.assertEqual(repr(c), f'count({maxsize - 1})') + next(c) # c is now at masize + self.assertEqual(repr(c), f'count({maxsize})') + next(c) + self.assertEqual(repr(c), f'count({maxsize + 1})') + self.assertEqual(type(next(count(10.0))), float) for i in (-sys.maxsize-5, -sys.maxsize+5 ,-10, -1, 0, 10, sys.maxsize-5, sys.maxsize+5): # Test repr @@ -655,10 +663,9 @@ def test_count(self): #check proper internal error handling for large "step' sizes count(1, maxsize+5); sys.exc_info() - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated - def test_count_with_stride(self): + def test_count_with_step(self): self.assertEqual(lzip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)]) self.assertEqual(lzip('abc',count(start=2,step=3)), [('a', 2), ('b', 5), ('c', 8)]) @@ -672,6 +679,12 @@ def test_count_with_stride(self): self.assertEqual(take(20, count(-maxsize-15, 3)), take(20, range(-maxsize-15,-maxsize+100, 3))) self.assertEqual(take(3, count(10, maxsize+5)), list(range(10, 10+3*(maxsize+5), maxsize+5))) + self.assertEqual(take(3, count(maxsize, 2)), + [maxsize, maxsize + 2, maxsize + 4]) + self.assertEqual(take(3, count(maxsize, maxsize)), + [maxsize, 2 * maxsize, 3 * maxsize]) + self.assertEqual(take(3, count(-maxsize, maxsize)), + [-maxsize, 0, maxsize]) self.assertEqual(take(3, count(2, 1.25)), [2, 3.25, 4.5]) self.assertEqual(take(3, count(2, 3.25-4j)), [2, 5.25-4j, 8.5-8j]) self.assertEqual(take(3, count(Decimal('1.1'), Decimal('.1'))), @@ -713,6 +726,42 @@ def test_count_with_stride(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, count(i, j)) + c = count(maxsize -2, 2) + self.assertEqual(repr(c), f'count({maxsize - 2}, 2)') + next(c) # c is now at masize + self.assertEqual(repr(c), f'count({maxsize}, 2)') + next(c) + self.assertEqual(repr(c), f'count({maxsize + 2}, 2)') + + c = count(maxsize + 1, -1) + self.assertEqual(repr(c), f'count({maxsize + 1}, -1)') + next(c) # c is now at masize + self.assertEqual(repr(c), f'count({maxsize}, -1)') + next(c) + self.assertEqual(repr(c), f'count({maxsize - 1}, -1)') + + @threading_helper.requires_working_threading() + def test_count_threading(self, step=1): + # this test verifies multithreading consistency, which is + # mostly for testing builds without GIL, but nice to test anyway + count_to = 10_000 + num_threads = 10 + c = count(step=step) + def counting_thread(): + for i in range(count_to): + next(c) + threads = [] + for i in range(num_threads): + thread = threading.Thread(target=counting_thread) + thread.start() + threads.append(thread) + for thread in threads: + thread.join() + self.assertEqual(next(c), count_to * num_threads * step) + + def test_count_with_step_threading(self): + self.test_count_threading(step=5) + def test_cycle(self): self.assertEqual(take(10, cycle('abc')), list('abcabcabca')) self.assertEqual(list(cycle('')), []) @@ -720,8 +769,7 @@ def test_cycle(self): self.assertRaises(TypeError, cycle, 5) self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0]) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_cycle_copy_pickle(self): # check copy, deepcopy, pickle @@ -759,8 +807,7 @@ def test_cycle_copy_pickle(self): d = pickle.loads(p) # rebuild the cycle object self.assertEqual(take(20, d), list('cdeabcdeabcdeabcdeab')) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_cycle_unpickle_compat(self): testcases = [ @@ -793,8 +840,7 @@ def test_cycle_unpickle_compat(self): it = pickle.loads(t) self.assertEqual(take(10, it), [2, 3, 1, 2, 3, 1, 2, 3, 1, 2]) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_cycle_setstate(self): # Verify both modes for restoring state @@ -832,8 +878,7 @@ def test_cycle_setstate(self): self.assertRaises(TypeError, cycle('').__setstate__, ()) self.assertRaises(TypeError, cycle('').__setstate__, ([],)) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_groupby(self): # Check whether it accepts arguments correctly @@ -992,8 +1037,7 @@ def test_filter(self): c = filter(isEven, range(6)) self.pickletest(proto, c) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_filterfalse(self): self.assertEqual(list(filterfalse(isEven, range(6))), [1,3,5]) @@ -1101,8 +1145,7 @@ def test_zip_longest_tuple_reuse(self): ids = list(map(id, list(zip_longest('abc', 'def')))) self.assertEqual(len(dict.fromkeys(ids)), len(ids)) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_zip_longest_pickling(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -1292,12 +1335,16 @@ def product1(*args, **kwds): else: return - def product2(*args, **kwds): + def product2(*iterables, repeat=1): 'Pure python version used in docs' - pools = list(map(tuple, args)) * kwds.get('repeat', 1) + if repeat < 0: + raise ValueError('repeat argument cannot be negative') + pools = [tuple(pool) for pool in iterables] * repeat + result = [[]] for pool in pools: result = [x+[y] for x in result for y in pool] + for prod in result: yield tuple(prod) @@ -1322,8 +1369,7 @@ def test_product_tuple_reuse(self): self.assertEqual(len(set(map(id, product('abc', 'def')))), 1) self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_product_pickling(self): # check copy, deepcopy, pickle @@ -1340,8 +1386,7 @@ def test_product_pickling(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, product(*args)) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_product_issue_25021(self): # test that indices are properly clamped to the length of the tuples @@ -1353,8 +1398,7 @@ def test_product_issue_25021(self): p.__setstate__((0, 0, 0x1000)) # will access tuple element 1 if not clamped self.assertRaises(StopIteration, next, p) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_repeat(self): self.assertEqual(list(repeat(object='a', times=3)), ['a', 'a', 'a']) @@ -1388,8 +1432,7 @@ def test_repeat_with_negative_times(self): self.assertEqual(repr(repeat('a', times=-1)), "repeat('a', 0)") self.assertEqual(repr(repeat('a', times=-2)), "repeat('a', 0)") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_map(self): self.assertEqual(list(map(operator.pow, range(3), range(1,7))), @@ -1421,8 +1464,7 @@ def test_map(self): c = map(tupleize, 'abc', count()) self.pickletest(proto, c) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_starmap(self): self.assertEqual(list(starmap(operator.pow, zip(range(3), range(1,7)))), @@ -1451,8 +1493,7 @@ def test_starmap(self): c = starmap(operator.pow, zip(range(3), range(1,7))) self.pickletest(proto, c) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_islice(self): for args in [ # islice(args) should agree with range(args) @@ -1548,8 +1589,7 @@ def __index__(self): self.assertEqual(list(islice(range(100), IntLike(10), IntLike(50), IntLike(5))), list(range(10,50,5))) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_takewhile(self): data = [1, 3, 5, 20, 2, 4, 6, 8] @@ -1571,8 +1611,7 @@ def test_takewhile(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, takewhile(underten, data)) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_dropwhile(self): data = [1, 3, 5, 20, 2, 4, 6, 8] @@ -1591,8 +1630,7 @@ def test_dropwhile(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, dropwhile(underten, data)) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_tee(self): n = 200 @@ -1660,10 +1698,11 @@ def test_tee(self): self.assertEqual(len(result), n) self.assertEqual([list(x) for x in result], [list('abc')]*n) - # tee pass-through to copyable iterator + # tee objects are independent (see bug gh-123884) a, b = tee('abc') c, d = tee(a) - self.assertTrue(a is c) + e, f = tee(c) + self.assertTrue(len({a, b, c, d, e, f}) == 6) # test tee_new t1, t2 = tee('abc') @@ -1776,7 +1815,7 @@ def __next__(self): with self.assertRaisesRegex(RuntimeError, "tee"): next(a) - @unittest.skip("TODO: RUSTPYTHON, hangs") + @unittest.skip("TODO: RUSTPYTHON; , hangs") @threading_helper.requires_working_threading() def test_tee_concurrent(self): start = threading.Event() @@ -1866,6 +1905,13 @@ def test_zip_longest_result_gc(self): gc.collect() self.assertTrue(gc.is_tracked(next(it))) + @support.cpython_only + def test_pairwise_result_gc(self): + # Ditto for pairwise. + it = pairwise([None, None]) + gc.collect() + self.assertTrue(gc.is_tracked(next(it))) + @support.cpython_only def test_immutable_types(self): from itertools import _grouper, _tee, _tee_dataobject @@ -1904,8 +1950,7 @@ class TestExamples(unittest.TestCase): def test_accumulate(self): self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15]) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_accumulate_reducible(self): # check copy, deepcopy, pickle @@ -1922,8 +1967,7 @@ def test_accumulate_reducible(self): self.assertEqual(list(copy.deepcopy(it)), accumulated[1:]) self.assertEqual(list(copy.copy(it)), accumulated[1:]) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @pickle_deprecated def test_accumulate_reducible_none(self): # Issue #25718: total is None @@ -2082,6 +2126,172 @@ def test_islice_recipe(self): self.assertEqual(next(c), 3) + def test_tee_recipe(self): + + # Begin tee() recipe ########################################### + + def tee(iterable, n=2): + if n < 0: + raise ValueError + if n == 0: + return () + iterator = _tee(iterable) + result = [iterator] + for _ in range(n - 1): + result.append(_tee(iterator)) + return tuple(result) + + class _tee: + + def __init__(self, iterable): + it = iter(iterable) + if isinstance(it, _tee): + self.iterator = it.iterator + self.link = it.link + else: + self.iterator = it + self.link = [None, None] + + def __iter__(self): + return self + + def __next__(self): + link = self.link + if link[1] is None: + link[0] = next(self.iterator) + link[1] = [None, None] + value, self.link = link + return value + + # End tee() recipe ############################################# + + n = 200 + + a, b = tee([]) # test empty iterator + self.assertEqual(list(a), []) + self.assertEqual(list(b), []) + + a, b = tee(irange(n)) # test 100% interleaved + self.assertEqual(lzip(a,b), lzip(range(n), range(n))) + + a, b = tee(irange(n)) # test 0% interleaved + self.assertEqual(list(a), list(range(n))) + self.assertEqual(list(b), list(range(n))) + + a, b = tee(irange(n)) # test dealloc of leading iterator + for i in range(100): + self.assertEqual(next(a), i) + del a + self.assertEqual(list(b), list(range(n))) + + a, b = tee(irange(n)) # test dealloc of trailing iterator + for i in range(100): + self.assertEqual(next(a), i) + del b + self.assertEqual(list(a), list(range(100, n))) + + for j in range(5): # test randomly interleaved + order = [0]*n + [1]*n + random.shuffle(order) + lists = ([], []) + its = tee(irange(n)) + for i in order: + value = next(its[i]) + lists[i].append(value) + self.assertEqual(lists[0], list(range(n))) + self.assertEqual(lists[1], list(range(n))) + + # test argument format checking + self.assertRaises(TypeError, tee) + self.assertRaises(TypeError, tee, 3) + self.assertRaises(TypeError, tee, [1,2], 'x') + self.assertRaises(TypeError, tee, [1,2], 3, 'x') + + # tee object should be instantiable + a, b = tee('abc') + c = type(a)('def') + self.assertEqual(list(c), list('def')) + + # test long-lagged and multi-way split + a, b, c = tee(range(2000), 3) + for i in range(100): + self.assertEqual(next(a), i) + self.assertEqual(list(b), list(range(2000))) + self.assertEqual([next(c), next(c)], list(range(2))) + self.assertEqual(list(a), list(range(100,2000))) + self.assertEqual(list(c), list(range(2,2000))) + + # test invalid values of n + self.assertRaises(TypeError, tee, 'abc', 'invalid') + self.assertRaises(ValueError, tee, [], -1) + + for n in range(5): + result = tee('abc', n) + self.assertEqual(type(result), tuple) + self.assertEqual(len(result), n) + self.assertEqual([list(x) for x in result], [list('abc')]*n) + + # tee objects are independent (see bug gh-123884) + a, b = tee('abc') + c, d = tee(a) + e, f = tee(c) + self.assertTrue(len({a, b, c, d, e, f}) == 6) + + # test tee_new + t1, t2 = tee('abc') + tnew = type(t1) + self.assertRaises(TypeError, tnew) + self.assertRaises(TypeError, tnew, 10) + t3 = tnew(t1) + self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc')) + + # test that tee objects are weak referencable + a, b = tee(range(10)) + p = weakref.proxy(a) + self.assertEqual(getattr(p, '__class__'), type(b)) + del a + gc.collect() # For PyPy or other GCs. + self.assertRaises(ReferenceError, getattr, p, '__class__') + + ans = list('abc') + long_ans = list(range(10000)) + + # Tests not applicable to the tee() recipe + if False: + # check copy + a, b = tee('abc') + self.assertEqual(list(copy.copy(a)), ans) + self.assertEqual(list(copy.copy(b)), ans) + a, b = tee(list(range(10000))) + self.assertEqual(list(copy.copy(a)), long_ans) + self.assertEqual(list(copy.copy(b)), long_ans) + + # check partially consumed copy + a, b = tee('abc') + take(2, a) + take(1, b) + self.assertEqual(list(copy.copy(a)), ans[2:]) + self.assertEqual(list(copy.copy(b)), ans[1:]) + self.assertEqual(list(a), ans[2:]) + self.assertEqual(list(b), ans[1:]) + a, b = tee(range(10000)) + take(100, a) + take(60, b) + self.assertEqual(list(copy.copy(a)), long_ans[100:]) + self.assertEqual(list(copy.copy(b)), long_ans[60:]) + self.assertEqual(list(a), long_ans[100:]) + self.assertEqual(list(b), long_ans[60:]) + + # Issue 13454: Crash when deleting backward iterator from tee() + forward, backward = tee(repeat(None, 2000)) # 20000000 + try: + any(forward) # exhaust the iterator + del backward + except: + del forward, backward + raise + + class TestGC(unittest.TestCase): def makecycle(self, iterator, container): @@ -2568,8 +2778,7 @@ def __eq__(self, other): class SubclassWithKwargsTest(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_keywords_in_subclass(self): # count is not subclassable... testcases = [ @@ -2664,7 +2873,7 @@ def test_permutations_sizeof(self): def load_tests(loader, tests, pattern): - tests.addTest(doctest.DocTestSuite()) + tests.addTest(doctest.DocTestSuite(itertools)) return tests diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 6a8d6e50a1..addfddd176 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -238,10 +238,10 @@ mod decl { #[derive(FromArgs)] struct CountNewArgs { - #[pyarg(positional, optional)] + #[pyarg(any, optional)] start: OptionalArg, - #[pyarg(positional, optional)] + #[pyarg(any, optional)] step: OptionalArg, } @@ -1945,6 +1945,7 @@ mod decl { exhausted: AtomicCell, iterable: PyIter, n: AtomicCell, + strict: AtomicCell, } #[derive(FromArgs)] @@ -1953,6 +1954,8 @@ mod decl { iterable_ref: PyObjectRef, #[pyarg(positional)] n: PyIntRef, + #[pyarg(named, default = false)] + strict: bool, } impl Constructor for PyItertoolsBatched { @@ -1960,7 +1963,11 @@ mod decl { fn py_new( cls: PyTypeRef, - Self::Args { iterable_ref, n }: Self::Args, + Self::Args { + iterable_ref, + n, + strict, + }: Self::Args, vm: &VirtualMachine, ) -> PyResult { let n = n.as_bigint(); @@ -1976,6 +1983,7 @@ mod decl { iterable, n: AtomicCell::new(n), exhausted: AtomicCell::new(false), + strict: AtomicCell::new(strict), } .into_ref_with_type(vm, cls) .map(Into::into) @@ -2005,9 +2013,16 @@ mod decl { } } } - match result.len() { + let res_len = result.len(); + match res_len { 0 => Ok(PyIterReturn::StopIteration(None)), - _ => Ok(PyIterReturn::Return(vm.ctx.new_tuple(result).into())), + _ => { + if zelf.strict.load() && res_len != n { + Err(vm.new_value_error("batched(): incomplete batch")) + } else { + Ok(PyIterReturn::Return(vm.ctx.new_tuple(result).into())) + } + } } } }