From 6270a5b04c13cefc75ae35f49b0dda15097354a8 Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Sun, 14 Aug 2022 23:13:35 +0900 Subject: [PATCH 01/23] Update webbrowser from CPython 3.10.6 --- Lib/test/test_webbrowser.py | 19 +--- Lib/webbrowser.py | 217 ++++++++++++++++++++++-------------- 2 files changed, 137 insertions(+), 99 deletions(-) diff --git a/Lib/test/test_webbrowser.py b/Lib/test/test_webbrowser.py index a97e64a753..673cc995d3 100644 --- a/Lib/test/test_webbrowser.py +++ b/Lib/test/test_webbrowser.py @@ -5,7 +5,8 @@ import subprocess from unittest import mock from test import support -from test.support import os_helper, import_helper +from test.support import import_helper +from test.support import os_helper URL = 'http://www.example.com' @@ -171,29 +172,21 @@ class OperaCommandTest(CommandTestMixin, unittest.TestCase): browser_class = webbrowser.Opera - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_open(self): self._test('open', options=[], arguments=[URL]) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_open_with_autoraise_false(self): self._test('open', kw=dict(autoraise=False), options=[], arguments=[URL]) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_open_new(self): self._test('open_new', options=['--new-window'], arguments=[URL]) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_open_new_tab(self): self._test('open_new_tab', options=[], @@ -267,23 +260,17 @@ class ExampleBrowser: self.assertEqual(webbrowser._tryorder, expected_tryorder) self.assertEqual(webbrowser._browsers, expected_browsers) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_register(self): self._check_registration(preferred=False) def test_register_default(self): self._check_registration(preferred=None) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_register_preferred(self): self._check_registration(preferred=True) class ImportTest(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_register(self): webbrowser = import_helper.import_fresh_module('webbrowser') self.assertIsNone(webbrowser._tryorder) @@ -298,8 +285,6 @@ class ExampleBrowser: self.assertIn('example1', webbrowser._browsers) self.assertEqual(webbrowser._browsers['example1'], [ExampleBrowser, None]) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_get(self): webbrowser = import_helper.import_fresh_module('webbrowser') self.assertIsNone(webbrowser._tryorder) diff --git a/Lib/webbrowser.py b/Lib/webbrowser.py index 6f43b7f126..ec3cece48c 100755 --- a/Lib/webbrowser.py +++ b/Lib/webbrowser.py @@ -1,5 +1,5 @@ #! /usr/bin/env python3 -"""Interfaces for launching and remotely controlling Web browsers.""" +"""Interfaces for launching and remotely controlling web browsers.""" # Maintained by Georg Brandl. import os @@ -7,25 +7,39 @@ import shutil import sys import subprocess +import threading __all__ = ["Error", "open", "open_new", "open_new_tab", "get", "register"] class Error(Exception): pass -_browsers = {} # Dictionary of available browser controllers -_tryorder = [] # Preference order of available browsers - -def register(name, klass, instance=None, update_tryorder=1): - """Register a browser connector and, optionally, connection.""" - _browsers[name.lower()] = [klass, instance] - if update_tryorder > 0: - _tryorder.append(name) - elif update_tryorder < 0: - _tryorder.insert(0, name) +_lock = threading.RLock() +_browsers = {} # Dictionary of available browser controllers +_tryorder = None # Preference order of available browsers +_os_preferred_browser = None # The preferred browser + +def register(name, klass, instance=None, *, preferred=False): + """Register a browser connector.""" + with _lock: + if _tryorder is None: + register_standard_browsers() + _browsers[name.lower()] = [klass, instance] + + # Preferred browsers go to the front of the list. + # Need to match to the default browser returned by xdg-settings, which + # may be of the form e.g. "firefox.desktop". + if preferred or (_os_preferred_browser and name in _os_preferred_browser): + _tryorder.insert(0, name) + else: + _tryorder.append(name) def get(using=None): """Return a browser launcher instance appropriate for the environment.""" + if _tryorder is None: + with _lock: + if _tryorder is None: + register_standard_browsers() if using is not None: alternatives = [using] else: @@ -55,6 +69,18 @@ def get(using=None): # instead of "from webbrowser import *". def open(url, new=0, autoraise=True): + """Display url using the default browser. + + If possible, open url in a location determined by new. + - 0: the same browser window (the default). + - 1: a new browser window. + - 2: a new browser page ("tab"). + If possible, autoraise raises the window (the default) or not. + """ + if _tryorder is None: + with _lock: + if _tryorder is None: + register_standard_browsers() for name in _tryorder: browser = get(name) if browser.open(url, new, autoraise): @@ -62,14 +88,22 @@ def open(url, new=0, autoraise=True): return False def open_new(url): + """Open url in a new window of the default browser. + + If not possible, then open url in the only browser window. + """ return open(url, 1) def open_new_tab(url): + """Open url in a new page ("tab") of the default browser. + + If not possible, then the behavior becomes equivalent to open_new(). + """ return open(url, 2) -def _synthesize(browser, update_tryorder=1): - """Attempt to synthesize a controller base on existing controllers. +def _synthesize(browser, *, preferred=False): + """Attempt to synthesize a controller based on existing controllers. This is useful to create a controller when a user specifies a path to an entry in the BROWSER environment variable -- we can copy a general @@ -95,7 +129,7 @@ def _synthesize(browser, update_tryorder=1): controller = copy.copy(controller) controller.name = browser controller.basename = os.path.basename(browser) - register(browser, None, controller, update_tryorder) + register(browser, None, instance=controller, preferred=preferred) return [None, controller] return [None, None] @@ -136,6 +170,7 @@ def __init__(self, name): self.basename = os.path.basename(self.name) def open(self, url, new=0, autoraise=True): + sys.audit("webbrowser.open", url) cmdline = [self.name] + [arg.replace("%s", url) for arg in self.args] try: @@ -155,6 +190,7 @@ class BackgroundBrowser(GenericBrowser): def open(self, url, new=0, autoraise=True): cmdline = [self.name] + [arg.replace("%s", url) for arg in self.args] + sys.audit("webbrowser.open", url) try: if sys.platform[:3] == 'win': p = subprocess.Popen(cmdline) @@ -183,7 +219,7 @@ class UnixBrowser(BaseBrowser): remote_action_newwin = None remote_action_newtab = None - def _invoke(self, args, remote, autoraise): + def _invoke(self, args, remote, autoraise, url=None): raise_opt = [] if remote and self.raise_opts: # use autoraise argument only for remote invocation @@ -219,6 +255,7 @@ def _invoke(self, args, remote, autoraise): return not p.wait() def open(self, url, new=0, autoraise=True): + sys.audit("webbrowser.open", url) if new == 0: action = self.remote_action elif new == 1: @@ -235,7 +272,7 @@ def open(self, url, new=0, autoraise=True): args = [arg.replace("%s", url).replace("%action", action) for arg in self.remote_args] args = [arg for arg in args if arg] - success = self._invoke(args, True, autoraise) + success = self._invoke(args, True, autoraise, url) if not success: # remote invocation failed, try straight way args = [arg.replace("%s", url) for arg in self.args] @@ -290,11 +327,10 @@ class Chrome(UnixBrowser): class Opera(UnixBrowser): "Launcher class for Opera browser." - raise_opts = ["-noraise", ""] - remote_args = ['-remote', 'openURL(%s%action)'] + remote_args = ['%action', '%s'] remote_action = "" - remote_action_newwin = ",new-window" - remote_action_newtab = ",new-page" + remote_action_newwin = "--new-window" + remote_action_newtab = "" background = True @@ -320,6 +356,7 @@ class Konqueror(BaseBrowser): """ def open(self, url, new=0, autoraise=True): + sys.audit("webbrowser.open", url) # XXX Currently I know no way to prevent KFM from opening a new win. if new == 2: action = "newTab" @@ -376,7 +413,7 @@ def _find_grail_rc(self): tempdir = os.path.join(tempfile.gettempdir(), ".grail-unix") user = pwd.getpwuid(os.getuid())[0] - filename = os.path.join(tempdir, user + "-*") + filename = os.path.join(glob.escape(tempdir), glob.escape(user) + "-*") maybes = glob.glob(filename) if not maybes: return None @@ -403,6 +440,7 @@ def _remote(self, action): return 1 def open(self, url, new=0, autoraise=True): + sys.audit("webbrowser.open", url) if new: ok = self._remote("LOADNEW " + url) else: @@ -482,25 +520,80 @@ def register_X_browsers(): if shutil.which("grail"): register("grail", Grail, None) -# Prefer X browsers if present -if os.environ.get("DISPLAY"): - register_X_browsers() - -# Also try console browsers -if os.environ.get("TERM"): - if shutil.which("www-browser"): - register("www-browser", None, GenericBrowser("www-browser")) - # The Links/elinks browsers - if shutil.which("links"): - register("links", None, GenericBrowser("links")) - if shutil.which("elinks"): - register("elinks", None, Elinks("elinks")) - # The Lynx browser , - if shutil.which("lynx"): - register("lynx", None, GenericBrowser("lynx")) - # The w3m browser - if shutil.which("w3m"): - register("w3m", None, GenericBrowser("w3m")) +def register_standard_browsers(): + global _tryorder + _tryorder = [] + + if sys.platform == 'darwin': + register("MacOSX", None, MacOSXOSAScript('default')) + register("chrome", None, MacOSXOSAScript('chrome')) + register("firefox", None, MacOSXOSAScript('firefox')) + register("safari", None, MacOSXOSAScript('safari')) + # OS X can use below Unix support (but we prefer using the OS X + # specific stuff) + + if sys.platform == "serenityos": + # SerenityOS webbrowser, simply called "Browser". + register("Browser", None, BackgroundBrowser("Browser")) + + if sys.platform[:3] == "win": + # First try to use the default Windows browser + register("windows-default", WindowsDefault) + + # Detect some common Windows browsers, fallback to IE + iexplore = os.path.join(os.environ.get("PROGRAMFILES", "C:\\Program Files"), + "Internet Explorer\\IEXPLORE.EXE") + for browser in ("firefox", "firebird", "seamonkey", "mozilla", + "netscape", "opera", iexplore): + if shutil.which(browser): + register(browser, None, BackgroundBrowser(browser)) + else: + # Prefer X browsers if present + if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"): + try: + cmd = "xdg-settings get default-web-browser".split() + raw_result = subprocess.check_output(cmd, stderr=subprocess.DEVNULL) + result = raw_result.decode().strip() + except (FileNotFoundError, subprocess.CalledProcessError, PermissionError, NotADirectoryError) : + pass + else: + global _os_preferred_browser + _os_preferred_browser = result + + register_X_browsers() + + # Also try console browsers + if os.environ.get("TERM"): + if shutil.which("www-browser"): + register("www-browser", None, GenericBrowser("www-browser")) + # The Links/elinks browsers + if shutil.which("links"): + register("links", None, GenericBrowser("links")) + if shutil.which("elinks"): + register("elinks", None, Elinks("elinks")) + # The Lynx browser , + if shutil.which("lynx"): + register("lynx", None, GenericBrowser("lynx")) + # The w3m browser + if shutil.which("w3m"): + register("w3m", None, GenericBrowser("w3m")) + + # OK, now that we know what the default preference orders for each + # platform are, allow user to override them with the BROWSER variable. + if "BROWSER" in os.environ: + userchoices = os.environ["BROWSER"].split(os.pathsep) + userchoices.reverse() + + # Treat choices in same way as if passed into get() but do register + # and prepend to _tryorder + for cmdline in userchoices: + if cmdline != '': + cmd = _synthesize(cmdline, preferred=True) + if cmd[1] is None: + register(cmdline, None, GenericBrowser(cmdline), preferred=True) + + # what to do if _tryorder is now empty? + # # Platform support for Windows @@ -509,6 +602,7 @@ def register_X_browsers(): if sys.platform[:3] == "win": class WindowsDefault(BaseBrowser): def open(self, url, new=0, autoraise=True): + sys.audit("webbrowser.open", url) try: os.startfile(url) except OSError: @@ -518,20 +612,6 @@ def open(self, url, new=0, autoraise=True): else: return True - _tryorder = [] - _browsers = {} - - # First try to use the default Windows browser - register("windows-default", WindowsDefault) - - # Detect some common Windows browsers, fallback to IE - iexplore = os.path.join(os.environ.get("PROGRAMFILES", "C:\\Program Files"), - "Internet Explorer\\IEXPLORE.EXE") - for browser in ("firefox", "firebird", "seamonkey", "mozilla", - "netscape", "opera", iexplore): - if shutil.which(browser): - register(browser, None, BackgroundBrowser(browser)) - # # Platform support for MacOS # @@ -552,6 +632,7 @@ def __init__(self, name): self.name = name def open(self, url, new=0, autoraise=True): + sys.audit("webbrowser.open", url) assert "'" not in url # hack for local urls if not ':' in url: @@ -608,34 +689,6 @@ def open(self, url, new=0, autoraise=True): return not rc - # Don't clear _tryorder or _browsers since OS X can use above Unix support - # (but we prefer using the OS X specific stuff) - register("safari", None, MacOSXOSAScript('safari'), -1) - register("firefox", None, MacOSXOSAScript('firefox'), -1) - register("chrome", None, MacOSXOSAScript('chrome'), -1) - register("MacOSX", None, MacOSXOSAScript('default'), -1) - - -# OK, now that we know what the default preference orders for each -# platform are, allow user to override them with the BROWSER variable. -if "BROWSER" in os.environ: - _userchoices = os.environ["BROWSER"].split(os.pathsep) - _userchoices.reverse() - - # Treat choices in same way as if passed into get() but do register - # and prepend to _tryorder - for cmdline in _userchoices: - if cmdline != '': - cmd = _synthesize(cmdline, -1) - if cmd[1] is None: - register(cmdline, None, GenericBrowser(cmdline), -1) - cmdline = None # to make del work if _userchoices was empty - del cmdline - del _userchoices - -# what to do if _tryorder is now empty? - - def main(): import getopt usage = """Usage: %s [-n | -t] url From ead12b5611bf466bd079b8c45794c0683d33e5ff Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Sun, 14 Aug 2022 23:15:13 +0900 Subject: [PATCH 02/23] Update test_zlib from CPython 3.10.6 --- Lib/test/test_zlib.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/Lib/test/test_zlib.py b/Lib/test/test_zlib.py index a6575002d1..9c1c7ddba9 100644 --- a/Lib/test/test_zlib.py +++ b/Lib/test/test_zlib.py @@ -1,11 +1,13 @@ import unittest from test import support +from test.support import import_helper import binascii import copy import pickle import random import sys -from test.support import bigmemtest, _1G, _4G, import_helper +from test.support import bigmemtest, _1G, _4G + zlib = import_helper.import_module('zlib') @@ -129,6 +131,12 @@ def test_overflow(self): with self.assertRaisesRegex(OverflowError, 'int too large'): zlib.decompressobj().flush(sys.maxsize + 1) + @support.cpython_only + def test_disallow_instantiation(self): + # Ensure that the type disallows instantiation (bpo-43916) + support.check_disallow_instantiation(self, type(zlib.compressobj())) + support.check_disallow_instantiation(self, type(zlib.decompressobj())) + class BaseCompressTestCase(object): def check_big_compress_buffer(self, size, compress_func): @@ -136,8 +144,7 @@ def check_big_compress_buffer(self, size, compress_func): # Generate 10 MiB worth of random, and expand it by repeating it. # The assumption is that zlib's memory is not big enough to exploit # such spread out redundancy. - data = b''.join([random.getrandbits(8 * _1M).to_bytes(_1M, 'little') - for i in range(10)]) + data = random.randbytes(_1M * 10) data = data * (size // len(data) + 1) try: compress_func(data) @@ -498,7 +505,7 @@ def test_odd_flush(self): # others might simply have a single RNG gen = random gen.seed(1) - data = genblock(1, 17 * 1024, generator=gen) + data = gen.randbytes(17 * 1024) # compress, sync-flush, and decompress first = co.compress(data) @@ -847,20 +854,6 @@ def test_wbits(self): self.assertEqual(dco.decompress(gzip), HAMLET_SCENE) -def genblock(seed, length, step=1024, generator=random): - """length-byte stream of random data from a seed (in step-byte blocks).""" - if seed is not None: - generator.seed(seed) - randint = generator.randint - if length < step or step < 2: - step = length - blocks = bytes() - for i in range(0, length, step): - blocks += bytes(randint(0, 255) for x in range(step)) - return blocks - - - def choose_lines(source, number, seed=None, generator=random): """Return a list of number lines randomly chosen from the source""" if seed is not None: @@ -869,7 +862,6 @@ def choose_lines(source, number, seed=None, generator=random): return [generator.choice(sources) for n in range(number)] - HAMLET_SCENE = b""" LAERTES From 9571a68c01284191d4b18229694c303dad8992ff Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Sun, 14 Aug 2022 23:17:42 +0900 Subject: [PATCH 03/23] Update test_yield_from from CPython 3.10.6 --- Lib/test/test_yield_from.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Lib/test/test_yield_from.py b/Lib/test/test_yield_from.py index a4a971072a..97acfd5413 100644 --- a/Lib/test/test_yield_from.py +++ b/Lib/test/test_yield_from.py @@ -946,6 +946,9 @@ def two(): res.append(g1.throw(MyErr)) except StopIteration: pass + except: + self.assertEqual(res, [0, 1, 2, 3]) + raise # Check with close class MyIt(object): def __iter__(self): From 3c5ab3a8ef1b3d175ceba9f0982221c8bde4f85c Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Sun, 14 Aug 2022 23:26:31 +0900 Subject: [PATCH 04/23] Update test_with from CPython 3.10.6 --- Lib/test/test_with.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_with.py b/Lib/test/test_with.py index b1d7a15b5e..f21bf65fed 100644 --- a/Lib/test/test_with.py +++ b/Lib/test/test_with.py @@ -7,7 +7,7 @@ import sys import unittest from collections import deque -from contextlib import _GeneratorContextManager, contextmanager +from contextlib import _GeneratorContextManager, contextmanager, nullcontext class MockContextManager(_GeneratorContextManager): @@ -641,6 +641,12 @@ class B: pass self.assertEqual(blah.two, 2) self.assertEqual(blah.three, 3) + def testWithExtendedTargets(self): + with nullcontext(range(1, 5)) as (a, *b, c): + self.assertEqual(a, 1) + self.assertEqual(b, [2, 3]) + self.assertEqual(c, 4) + class ExitSwallowsExceptionTestCase(unittest.TestCase): From be1636a9477cbab3be758931e701593331545e82 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sun, 14 Aug 2022 23:28:21 +0900 Subject: [PATCH 05/23] crlf -> lf --- Lib/test/test_trace.py | 1172 ++++++++++++++++++++-------------------- 1 file changed, 586 insertions(+), 586 deletions(-) diff --git a/Lib/test/test_trace.py b/Lib/test/test_trace.py index e5183b90eb..d74501549a 100644 --- a/Lib/test/test_trace.py +++ b/Lib/test/test_trace.py @@ -1,586 +1,586 @@ -import os -import sys -from test.support import captured_stdout -from test.support.os_helper import TESTFN, TESTFN_UNICODE, FS_NONASCII, rmtree, unlink -from test.support.script_helper import assert_python_ok, assert_python_failure -import textwrap -import unittest - -import trace -from trace import Trace - -from test.tracedmodules import testmod - -#------------------------------- Utilities -----------------------------------# - -def fix_ext_py(filename): - """Given a .pyc filename converts it to the appropriate .py""" - if filename.endswith('.pyc'): - filename = filename[:-1] - return filename - -def my_file_and_modname(): - """The .py file and module name of this file (__file__)""" - modname = os.path.splitext(os.path.basename(__file__))[0] - return fix_ext_py(__file__), modname - -def get_firstlineno(func): - return func.__code__.co_firstlineno - -#-------------------- Target functions for tracing ---------------------------# -# -# The relative line numbers of lines in these functions matter for verifying -# tracing. Please modify the appropriate tests if you change one of the -# functions. Absolute line numbers don't matter. -# - -def traced_func_linear(x, y): - a = x - b = y - c = a + b - return c - -def traced_func_loop(x, y): - c = x - for i in range(5): - c += y - return c - -def traced_func_importing(x, y): - return x + y + testmod.func(1) - -def traced_func_simple_caller(x): - c = traced_func_linear(x, x) - return c + x - -def traced_func_importing_caller(x): - k = traced_func_simple_caller(x) - k += traced_func_importing(k, x) - return k - -def traced_func_generator(num): - c = 5 # executed once - for i in range(num): - yield i + c - -def traced_func_calling_generator(): - k = 0 - for i in traced_func_generator(10): - k += i - -def traced_doubler(num): - return num * 2 - -def traced_capturer(*args, **kwargs): - return args, kwargs - -def traced_caller_list_comprehension(): - k = 10 - mylist = [traced_doubler(i) for i in range(k)] - return mylist - -def traced_decorated_function(): - def decorator1(f): - return f - def decorator_fabric(): - def decorator2(f): - return f - return decorator2 - @decorator1 - @decorator_fabric() - def func(): - pass - func() - - -class TracedClass(object): - def __init__(self, x): - self.a = x - - def inst_method_linear(self, y): - return self.a + y - - def inst_method_calling(self, x): - c = self.inst_method_linear(x) - return c + traced_func_linear(x, c) - - @classmethod - def class_method_linear(cls, y): - return y * 2 - - @staticmethod - def static_method_linear(y): - return y * 2 - - -#------------------------------ Test cases -----------------------------------# - - -class TestLineCounts(unittest.TestCase): - """White-box testing of line-counting, via runfunc""" - def setUp(self): - self.addCleanup(sys.settrace, sys.gettrace()) - self.tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0) - self.my_py_filename = fix_ext_py(__file__) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_traced_func_linear(self): - result = self.tracer.runfunc(traced_func_linear, 2, 5) - self.assertEqual(result, 7) - - # all lines are executed once - expected = {} - firstlineno = get_firstlineno(traced_func_linear) - for i in range(1, 5): - expected[(self.my_py_filename, firstlineno + i)] = 1 - - self.assertEqual(self.tracer.results().counts, expected) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_traced_func_loop(self): - self.tracer.runfunc(traced_func_loop, 2, 3) - - firstlineno = get_firstlineno(traced_func_loop) - expected = { - (self.my_py_filename, firstlineno + 1): 1, - (self.my_py_filename, firstlineno + 2): 6, - (self.my_py_filename, firstlineno + 3): 5, - (self.my_py_filename, firstlineno + 4): 1, - } - self.assertEqual(self.tracer.results().counts, expected) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_traced_func_importing(self): - self.tracer.runfunc(traced_func_importing, 2, 5) - - firstlineno = get_firstlineno(traced_func_importing) - expected = { - (self.my_py_filename, firstlineno + 1): 1, - (fix_ext_py(testmod.__file__), 2): 1, - (fix_ext_py(testmod.__file__), 3): 1, - } - - self.assertEqual(self.tracer.results().counts, expected) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_trace_func_generator(self): - self.tracer.runfunc(traced_func_calling_generator) - - firstlineno_calling = get_firstlineno(traced_func_calling_generator) - firstlineno_gen = get_firstlineno(traced_func_generator) - expected = { - (self.my_py_filename, firstlineno_calling + 1): 1, - (self.my_py_filename, firstlineno_calling + 2): 11, - (self.my_py_filename, firstlineno_calling + 3): 10, - (self.my_py_filename, firstlineno_gen + 1): 1, - (self.my_py_filename, firstlineno_gen + 2): 11, - (self.my_py_filename, firstlineno_gen + 3): 10, - } - self.assertEqual(self.tracer.results().counts, expected) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_trace_list_comprehension(self): - self.tracer.runfunc(traced_caller_list_comprehension) - - firstlineno_calling = get_firstlineno(traced_caller_list_comprehension) - firstlineno_called = get_firstlineno(traced_doubler) - expected = { - (self.my_py_filename, firstlineno_calling + 1): 1, - # List compehentions work differently in 3.x, so the count - # below changed compared to 2.x. - (self.my_py_filename, firstlineno_calling + 2): 12, - (self.my_py_filename, firstlineno_calling + 3): 1, - (self.my_py_filename, firstlineno_called + 1): 10, - } - self.assertEqual(self.tracer.results().counts, expected) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_traced_decorated_function(self): - self.tracer.runfunc(traced_decorated_function) - - firstlineno = get_firstlineno(traced_decorated_function) - expected = { - (self.my_py_filename, firstlineno + 1): 1, - (self.my_py_filename, firstlineno + 2): 1, - (self.my_py_filename, firstlineno + 3): 1, - (self.my_py_filename, firstlineno + 4): 1, - (self.my_py_filename, firstlineno + 5): 1, - (self.my_py_filename, firstlineno + 6): 1, - (self.my_py_filename, firstlineno + 7): 1, - (self.my_py_filename, firstlineno + 8): 1, - (self.my_py_filename, firstlineno + 9): 1, - (self.my_py_filename, firstlineno + 10): 1, - (self.my_py_filename, firstlineno + 11): 1, - } - self.assertEqual(self.tracer.results().counts, expected) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_linear_methods(self): - # XXX todo: later add 'static_method_linear' and 'class_method_linear' - # here, once issue1764286 is resolved - # - for methname in ['inst_method_linear',]: - tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0) - traced_obj = TracedClass(25) - method = getattr(traced_obj, methname) - tracer.runfunc(method, 20) - - firstlineno = get_firstlineno(method) - expected = { - (self.my_py_filename, firstlineno + 1): 1, - } - self.assertEqual(tracer.results().counts, expected) - - -class TestRunExecCounts(unittest.TestCase): - """A simple sanity test of line-counting, via runctx (exec)""" - def setUp(self): - self.my_py_filename = fix_ext_py(__file__) - self.addCleanup(sys.settrace, sys.gettrace()) - - # TODO: RUSTPYTHON, KeyError: ('Lib/test/test_trace.py', 43) - @unittest.expectedFailure - def test_exec_counts(self): - self.tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0) - code = r'''traced_func_loop(2, 5)''' - code = compile(code, __file__, 'exec') - self.tracer.runctx(code, globals(), vars()) - - firstlineno = get_firstlineno(traced_func_loop) - expected = { - (self.my_py_filename, firstlineno + 1): 1, - (self.my_py_filename, firstlineno + 2): 6, - (self.my_py_filename, firstlineno + 3): 5, - (self.my_py_filename, firstlineno + 4): 1, - } - - # When used through 'run', some other spurious counts are produced, like - # the settrace of threading, which we ignore, just making sure that the - # counts fo traced_func_loop were right. - # - for k in expected.keys(): - self.assertEqual(self.tracer.results().counts[k], expected[k]) - - -class TestFuncs(unittest.TestCase): - """White-box testing of funcs tracing""" - def setUp(self): - self.addCleanup(sys.settrace, sys.gettrace()) - self.tracer = Trace(count=0, trace=0, countfuncs=1) - self.filemod = my_file_and_modname() - self._saved_tracefunc = sys.gettrace() - - def tearDown(self): - if self._saved_tracefunc is not None: - sys.settrace(self._saved_tracefunc) - - # TODO: RUSTPYTHON, gc - @unittest.expectedFailure - def test_simple_caller(self): - self.tracer.runfunc(traced_func_simple_caller, 1) - - expected = { - self.filemod + ('traced_func_simple_caller',): 1, - self.filemod + ('traced_func_linear',): 1, - } - self.assertEqual(self.tracer.results().calledfuncs, expected) - - # TODO: RUSTPYTHON, gc - @unittest.expectedFailure - def test_arg_errors(self): - res = self.tracer.runfunc(traced_capturer, 1, 2, self=3, func=4) - self.assertEqual(res, ((1, 2), {'self': 3, 'func': 4})) - with self.assertWarns(DeprecationWarning): - res = self.tracer.runfunc(func=traced_capturer, arg=1) - self.assertEqual(res, ((), {'arg': 1})) - with self.assertRaises(TypeError): - self.tracer.runfunc() - - # TODO: RUSTPYTHON, gc - @unittest.expectedFailure - def test_loop_caller_importing(self): - self.tracer.runfunc(traced_func_importing_caller, 1) - - expected = { - self.filemod + ('traced_func_simple_caller',): 1, - self.filemod + ('traced_func_linear',): 1, - self.filemod + ('traced_func_importing_caller',): 1, - self.filemod + ('traced_func_importing',): 1, - (fix_ext_py(testmod.__file__), 'testmod', 'func'): 1, - } - self.assertEqual(self.tracer.results().calledfuncs, expected) - - # TODO: RUSTPYTHON, gc - @unittest.expectedFailure - @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(), - 'pre-existing trace function throws off measurements') - def test_inst_method_calling(self): - obj = TracedClass(20) - self.tracer.runfunc(obj.inst_method_calling, 1) - - expected = { - self.filemod + ('TracedClass.inst_method_calling',): 1, - self.filemod + ('TracedClass.inst_method_linear',): 1, - self.filemod + ('traced_func_linear',): 1, - } - self.assertEqual(self.tracer.results().calledfuncs, expected) - - # TODO: RUSTPYTHON, gc - @unittest.expectedFailure - def test_traced_decorated_function(self): - self.tracer.runfunc(traced_decorated_function) - - expected = { - self.filemod + ('traced_decorated_function',): 1, - self.filemod + ('decorator_fabric',): 1, - self.filemod + ('decorator2',): 1, - self.filemod + ('decorator1',): 1, - self.filemod + ('func',): 1, - } - self.assertEqual(self.tracer.results().calledfuncs, expected) - - -class TestCallers(unittest.TestCase): - """White-box testing of callers tracing""" - def setUp(self): - self.addCleanup(sys.settrace, sys.gettrace()) - self.tracer = Trace(count=0, trace=0, countcallers=1) - self.filemod = my_file_and_modname() - - @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(), - 'pre-existing trace function throws off measurements') - @unittest.skip("TODO: RUSTPYTHON, Error in atexit._run_exitfuncs") - def test_loop_caller_importing(self): - self.tracer.runfunc(traced_func_importing_caller, 1) - - expected = { - ((os.path.splitext(trace.__file__)[0] + '.py', 'trace', 'Trace.runfunc'), - (self.filemod + ('traced_func_importing_caller',))): 1, - ((self.filemod + ('traced_func_simple_caller',)), - (self.filemod + ('traced_func_linear',))): 1, - ((self.filemod + ('traced_func_importing_caller',)), - (self.filemod + ('traced_func_simple_caller',))): 1, - ((self.filemod + ('traced_func_importing_caller',)), - (self.filemod + ('traced_func_importing',))): 1, - ((self.filemod + ('traced_func_importing',)), - (fix_ext_py(testmod.__file__), 'testmod', 'func')): 1, - } - self.assertEqual(self.tracer.results().callers, expected) - - -# Created separately for issue #3821 -class TestCoverage(unittest.TestCase): - def setUp(self): - self.addCleanup(sys.settrace, sys.gettrace()) - - def tearDown(self): - rmtree(TESTFN) - unlink(TESTFN) - - def _coverage(self, tracer, - cmd='import test.support, test.test_pprint;' - 'test.support.run_unittest(test.test_pprint.QueryTestCase)'): - tracer.run(cmd) - r = tracer.results() - r.write_results(show_missing=True, summary=True, coverdir=TESTFN) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_coverage(self): - tracer = trace.Trace(trace=0, count=1) - with captured_stdout() as stdout: - self._coverage(tracer) - stdout = stdout.getvalue() - self.assertIn("pprint.py", stdout) - self.assertIn("case.py", stdout) # from unittest - files = os.listdir(TESTFN) - self.assertIn("pprint.cover", files) - self.assertIn("unittest.case.cover", files) - - def test_coverage_ignore(self): - # Ignore all files, nothing should be traced nor printed - libpath = os.path.normpath(os.path.dirname(os.__file__)) - # sys.prefix does not work when running from a checkout - tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix, - libpath], trace=0, count=1) - with captured_stdout() as stdout: - self._coverage(tracer) - if os.path.exists(TESTFN): - files = os.listdir(TESTFN) - self.assertEqual(files, ['_importlib.cover']) # Ignore __import__ - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_issue9936(self): - tracer = trace.Trace(trace=0, count=1) - modname = 'test.tracedmodules.testmod' - # Ensure that the module is executed in import - if modname in sys.modules: - del sys.modules[modname] - cmd = ("import test.tracedmodules.testmod as t;" - "t.func(0); t.func2();") - with captured_stdout() as stdout: - self._coverage(tracer, cmd) - stdout.seek(0) - stdout.readline() - coverage = {} - for line in stdout: - lines, cov, module = line.split()[:3] - coverage[module] = (int(lines), int(cov[:-1])) - # XXX This is needed to run regrtest.py as a script - modname = trace._fullmodname(sys.modules[modname].__file__) - self.assertIn(modname, coverage) - self.assertEqual(coverage[modname], (5, 100)) - -### Tests that don't mess with sys.settrace and can be traced -### themselves TODO: Skip tests that do mess with sys.settrace when -### regrtest is invoked with -T option. -class Test_Ignore(unittest.TestCase): - def test_ignored(self): - jn = os.path.join - ignore = trace._Ignore(['x', 'y.z'], [jn('foo', 'bar')]) - self.assertTrue(ignore.names('x.py', 'x')) - self.assertFalse(ignore.names('xy.py', 'xy')) - self.assertFalse(ignore.names('y.py', 'y')) - self.assertTrue(ignore.names(jn('foo', 'bar', 'baz.py'), 'baz')) - self.assertFalse(ignore.names(jn('bar', 'z.py'), 'z')) - # Matched before. - self.assertTrue(ignore.names(jn('bar', 'baz.py'), 'baz')) - -# Created for Issue 31908 -- CLI utility not writing cover files -class TestCoverageCommandLineOutput(unittest.TestCase): - - codefile = 'tmp.py' - coverfile = 'tmp.cover' - - def setUp(self): - with open(self.codefile, 'w', encoding='iso-8859-15') as f: - f.write(textwrap.dedent('''\ - # coding: iso-8859-15 - x = 'spœm' - if []: - print('unreachable') - ''')) - - def tearDown(self): - unlink(self.codefile) - unlink(self.coverfile) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_cover_files_written_no_highlight(self): - # Test also that the cover file for the trace module is not created - # (issue #34171). - tracedir = os.path.dirname(os.path.abspath(trace.__file__)) - tracecoverpath = os.path.join(tracedir, 'trace.cover') - unlink(tracecoverpath) - - argv = '-m trace --count'.split() + [self.codefile] - status, stdout, stderr = assert_python_ok(*argv) - self.assertEqual(stderr, b'') - self.assertFalse(os.path.exists(tracecoverpath)) - self.assertTrue(os.path.exists(self.coverfile)) - with open(self.coverfile, encoding='iso-8859-15') as f: - self.assertEqual(f.read(), - " # coding: iso-8859-15\n" - " 1: x = 'spœm'\n" - " 1: if []:\n" - " print('unreachable')\n" - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_cover_files_written_with_highlight(self): - argv = '-m trace --count --missing'.split() + [self.codefile] - status, stdout, stderr = assert_python_ok(*argv) - self.assertTrue(os.path.exists(self.coverfile)) - with open(self.coverfile, encoding='iso-8859-15') as f: - self.assertEqual(f.read(), textwrap.dedent('''\ - # coding: iso-8859-15 - 1: x = 'spœm' - 1: if []: - >>>>>> print('unreachable') - ''')) - -class TestCommandLine(unittest.TestCase): - - def test_failures(self): - _errors = ( - (b'progname is missing: required with the main options', '-l', '-T'), - (b'cannot specify both --listfuncs and (--trace or --count)', '-lc'), - (b'argument -R/--no-report: not allowed with argument -r/--report', '-rR'), - (b'must specify one of --trace, --count, --report, --listfuncs, or --trackcalls', '-g'), - (b'-r/--report requires -f/--file', '-r'), - (b'--summary can only be used with --count or --report', '-sT'), - (b'unrecognized arguments: -y', '-y')) - for message, *args in _errors: - *_, stderr = assert_python_failure('-m', 'trace', *args) - self.assertIn(message, stderr) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_listfuncs_flag_success(self): - filename = TESTFN + '.py' - modulename = os.path.basename(TESTFN) - with open(filename, 'w', encoding='utf-8') as fd: - self.addCleanup(unlink, filename) - fd.write("a = 1\n") - status, stdout, stderr = assert_python_ok('-m', 'trace', '-l', filename, - PYTHONIOENCODING='utf-8') - self.assertIn(b'functions called:', stdout) - expected = f'filename: {filename}, modulename: {modulename}, funcname: ' - self.assertIn(expected.encode(), stdout) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_sys_argv_list(self): - with open(TESTFN, 'w', encoding='utf-8') as fd: - self.addCleanup(unlink, TESTFN) - fd.write("import sys\n") - fd.write("print(type(sys.argv))\n") - - status, direct_stdout, stderr = assert_python_ok(TESTFN) - status, trace_stdout, stderr = assert_python_ok('-m', 'trace', '-l', TESTFN) - self.assertIn(direct_stdout.strip(), trace_stdout) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_count_and_summary(self): - filename = f'{TESTFN}.py' - coverfilename = f'{TESTFN}.cover' - modulename = os.path.basename(TESTFN) - with open(filename, 'w', encoding='utf-8') as fd: - self.addCleanup(unlink, filename) - self.addCleanup(unlink, coverfilename) - fd.write(textwrap.dedent("""\ - x = 1 - y = 2 - - def f(): - return x + y - - for i in range(10): - f() - """)) - status, stdout, _ = assert_python_ok('-m', 'trace', '-cs', filename) - stdout = stdout.decode() - self.assertEqual(status, 0) - self.assertIn('lines cov% module (path)', stdout) - self.assertIn(f'6 100% {modulename} ({filename})', stdout) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_run_as_module(self): - assert_python_ok('-m', 'trace', '-l', '--module', 'timeit', '-n', '1') - assert_python_failure('-m', 'trace', '-l', '--module', 'not_a_module_zzz') - - -if __name__ == '__main__': - unittest.main() +import os +import sys +from test.support import captured_stdout +from test.support.os_helper import TESTFN, TESTFN_UNICODE, FS_NONASCII, rmtree, unlink +from test.support.script_helper import assert_python_ok, assert_python_failure +import textwrap +import unittest + +import trace +from trace import Trace + +from test.tracedmodules import testmod + +#------------------------------- Utilities -----------------------------------# + +def fix_ext_py(filename): + """Given a .pyc filename converts it to the appropriate .py""" + if filename.endswith('.pyc'): + filename = filename[:-1] + return filename + +def my_file_and_modname(): + """The .py file and module name of this file (__file__)""" + modname = os.path.splitext(os.path.basename(__file__))[0] + return fix_ext_py(__file__), modname + +def get_firstlineno(func): + return func.__code__.co_firstlineno + +#-------------------- Target functions for tracing ---------------------------# +# +# The relative line numbers of lines in these functions matter for verifying +# tracing. Please modify the appropriate tests if you change one of the +# functions. Absolute line numbers don't matter. +# + +def traced_func_linear(x, y): + a = x + b = y + c = a + b + return c + +def traced_func_loop(x, y): + c = x + for i in range(5): + c += y + return c + +def traced_func_importing(x, y): + return x + y + testmod.func(1) + +def traced_func_simple_caller(x): + c = traced_func_linear(x, x) + return c + x + +def traced_func_importing_caller(x): + k = traced_func_simple_caller(x) + k += traced_func_importing(k, x) + return k + +def traced_func_generator(num): + c = 5 # executed once + for i in range(num): + yield i + c + +def traced_func_calling_generator(): + k = 0 + for i in traced_func_generator(10): + k += i + +def traced_doubler(num): + return num * 2 + +def traced_capturer(*args, **kwargs): + return args, kwargs + +def traced_caller_list_comprehension(): + k = 10 + mylist = [traced_doubler(i) for i in range(k)] + return mylist + +def traced_decorated_function(): + def decorator1(f): + return f + def decorator_fabric(): + def decorator2(f): + return f + return decorator2 + @decorator1 + @decorator_fabric() + def func(): + pass + func() + + +class TracedClass(object): + def __init__(self, x): + self.a = x + + def inst_method_linear(self, y): + return self.a + y + + def inst_method_calling(self, x): + c = self.inst_method_linear(x) + return c + traced_func_linear(x, c) + + @classmethod + def class_method_linear(cls, y): + return y * 2 + + @staticmethod + def static_method_linear(y): + return y * 2 + + +#------------------------------ Test cases -----------------------------------# + + +class TestLineCounts(unittest.TestCase): + """White-box testing of line-counting, via runfunc""" + def setUp(self): + self.addCleanup(sys.settrace, sys.gettrace()) + self.tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0) + self.my_py_filename = fix_ext_py(__file__) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_traced_func_linear(self): + result = self.tracer.runfunc(traced_func_linear, 2, 5) + self.assertEqual(result, 7) + + # all lines are executed once + expected = {} + firstlineno = get_firstlineno(traced_func_linear) + for i in range(1, 5): + expected[(self.my_py_filename, firstlineno + i)] = 1 + + self.assertEqual(self.tracer.results().counts, expected) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_traced_func_loop(self): + self.tracer.runfunc(traced_func_loop, 2, 3) + + firstlineno = get_firstlineno(traced_func_loop) + expected = { + (self.my_py_filename, firstlineno + 1): 1, + (self.my_py_filename, firstlineno + 2): 6, + (self.my_py_filename, firstlineno + 3): 5, + (self.my_py_filename, firstlineno + 4): 1, + } + self.assertEqual(self.tracer.results().counts, expected) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_traced_func_importing(self): + self.tracer.runfunc(traced_func_importing, 2, 5) + + firstlineno = get_firstlineno(traced_func_importing) + expected = { + (self.my_py_filename, firstlineno + 1): 1, + (fix_ext_py(testmod.__file__), 2): 1, + (fix_ext_py(testmod.__file__), 3): 1, + } + + self.assertEqual(self.tracer.results().counts, expected) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_trace_func_generator(self): + self.tracer.runfunc(traced_func_calling_generator) + + firstlineno_calling = get_firstlineno(traced_func_calling_generator) + firstlineno_gen = get_firstlineno(traced_func_generator) + expected = { + (self.my_py_filename, firstlineno_calling + 1): 1, + (self.my_py_filename, firstlineno_calling + 2): 11, + (self.my_py_filename, firstlineno_calling + 3): 10, + (self.my_py_filename, firstlineno_gen + 1): 1, + (self.my_py_filename, firstlineno_gen + 2): 11, + (self.my_py_filename, firstlineno_gen + 3): 10, + } + self.assertEqual(self.tracer.results().counts, expected) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_trace_list_comprehension(self): + self.tracer.runfunc(traced_caller_list_comprehension) + + firstlineno_calling = get_firstlineno(traced_caller_list_comprehension) + firstlineno_called = get_firstlineno(traced_doubler) + expected = { + (self.my_py_filename, firstlineno_calling + 1): 1, + # List compehentions work differently in 3.x, so the count + # below changed compared to 2.x. + (self.my_py_filename, firstlineno_calling + 2): 12, + (self.my_py_filename, firstlineno_calling + 3): 1, + (self.my_py_filename, firstlineno_called + 1): 10, + } + self.assertEqual(self.tracer.results().counts, expected) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_traced_decorated_function(self): + self.tracer.runfunc(traced_decorated_function) + + firstlineno = get_firstlineno(traced_decorated_function) + expected = { + (self.my_py_filename, firstlineno + 1): 1, + (self.my_py_filename, firstlineno + 2): 1, + (self.my_py_filename, firstlineno + 3): 1, + (self.my_py_filename, firstlineno + 4): 1, + (self.my_py_filename, firstlineno + 5): 1, + (self.my_py_filename, firstlineno + 6): 1, + (self.my_py_filename, firstlineno + 7): 1, + (self.my_py_filename, firstlineno + 8): 1, + (self.my_py_filename, firstlineno + 9): 1, + (self.my_py_filename, firstlineno + 10): 1, + (self.my_py_filename, firstlineno + 11): 1, + } + self.assertEqual(self.tracer.results().counts, expected) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_linear_methods(self): + # XXX todo: later add 'static_method_linear' and 'class_method_linear' + # here, once issue1764286 is resolved + # + for methname in ['inst_method_linear',]: + tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0) + traced_obj = TracedClass(25) + method = getattr(traced_obj, methname) + tracer.runfunc(method, 20) + + firstlineno = get_firstlineno(method) + expected = { + (self.my_py_filename, firstlineno + 1): 1, + } + self.assertEqual(tracer.results().counts, expected) + + +class TestRunExecCounts(unittest.TestCase): + """A simple sanity test of line-counting, via runctx (exec)""" + def setUp(self): + self.my_py_filename = fix_ext_py(__file__) + self.addCleanup(sys.settrace, sys.gettrace()) + + # TODO: RUSTPYTHON, KeyError: ('Lib/test/test_trace.py', 43) + @unittest.expectedFailure + def test_exec_counts(self): + self.tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0) + code = r'''traced_func_loop(2, 5)''' + code = compile(code, __file__, 'exec') + self.tracer.runctx(code, globals(), vars()) + + firstlineno = get_firstlineno(traced_func_loop) + expected = { + (self.my_py_filename, firstlineno + 1): 1, + (self.my_py_filename, firstlineno + 2): 6, + (self.my_py_filename, firstlineno + 3): 5, + (self.my_py_filename, firstlineno + 4): 1, + } + + # When used through 'run', some other spurious counts are produced, like + # the settrace of threading, which we ignore, just making sure that the + # counts fo traced_func_loop were right. + # + for k in expected.keys(): + self.assertEqual(self.tracer.results().counts[k], expected[k]) + + +class TestFuncs(unittest.TestCase): + """White-box testing of funcs tracing""" + def setUp(self): + self.addCleanup(sys.settrace, sys.gettrace()) + self.tracer = Trace(count=0, trace=0, countfuncs=1) + self.filemod = my_file_and_modname() + self._saved_tracefunc = sys.gettrace() + + def tearDown(self): + if self._saved_tracefunc is not None: + sys.settrace(self._saved_tracefunc) + + # TODO: RUSTPYTHON, gc + @unittest.expectedFailure + def test_simple_caller(self): + self.tracer.runfunc(traced_func_simple_caller, 1) + + expected = { + self.filemod + ('traced_func_simple_caller',): 1, + self.filemod + ('traced_func_linear',): 1, + } + self.assertEqual(self.tracer.results().calledfuncs, expected) + + # TODO: RUSTPYTHON, gc + @unittest.expectedFailure + def test_arg_errors(self): + res = self.tracer.runfunc(traced_capturer, 1, 2, self=3, func=4) + self.assertEqual(res, ((1, 2), {'self': 3, 'func': 4})) + with self.assertWarns(DeprecationWarning): + res = self.tracer.runfunc(func=traced_capturer, arg=1) + self.assertEqual(res, ((), {'arg': 1})) + with self.assertRaises(TypeError): + self.tracer.runfunc() + + # TODO: RUSTPYTHON, gc + @unittest.expectedFailure + def test_loop_caller_importing(self): + self.tracer.runfunc(traced_func_importing_caller, 1) + + expected = { + self.filemod + ('traced_func_simple_caller',): 1, + self.filemod + ('traced_func_linear',): 1, + self.filemod + ('traced_func_importing_caller',): 1, + self.filemod + ('traced_func_importing',): 1, + (fix_ext_py(testmod.__file__), 'testmod', 'func'): 1, + } + self.assertEqual(self.tracer.results().calledfuncs, expected) + + # TODO: RUSTPYTHON, gc + @unittest.expectedFailure + @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(), + 'pre-existing trace function throws off measurements') + def test_inst_method_calling(self): + obj = TracedClass(20) + self.tracer.runfunc(obj.inst_method_calling, 1) + + expected = { + self.filemod + ('TracedClass.inst_method_calling',): 1, + self.filemod + ('TracedClass.inst_method_linear',): 1, + self.filemod + ('traced_func_linear',): 1, + } + self.assertEqual(self.tracer.results().calledfuncs, expected) + + # TODO: RUSTPYTHON, gc + @unittest.expectedFailure + def test_traced_decorated_function(self): + self.tracer.runfunc(traced_decorated_function) + + expected = { + self.filemod + ('traced_decorated_function',): 1, + self.filemod + ('decorator_fabric',): 1, + self.filemod + ('decorator2',): 1, + self.filemod + ('decorator1',): 1, + self.filemod + ('func',): 1, + } + self.assertEqual(self.tracer.results().calledfuncs, expected) + + +class TestCallers(unittest.TestCase): + """White-box testing of callers tracing""" + def setUp(self): + self.addCleanup(sys.settrace, sys.gettrace()) + self.tracer = Trace(count=0, trace=0, countcallers=1) + self.filemod = my_file_and_modname() + + @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(), + 'pre-existing trace function throws off measurements') + @unittest.skip("TODO: RUSTPYTHON, Error in atexit._run_exitfuncs") + def test_loop_caller_importing(self): + self.tracer.runfunc(traced_func_importing_caller, 1) + + expected = { + ((os.path.splitext(trace.__file__)[0] + '.py', 'trace', 'Trace.runfunc'), + (self.filemod + ('traced_func_importing_caller',))): 1, + ((self.filemod + ('traced_func_simple_caller',)), + (self.filemod + ('traced_func_linear',))): 1, + ((self.filemod + ('traced_func_importing_caller',)), + (self.filemod + ('traced_func_simple_caller',))): 1, + ((self.filemod + ('traced_func_importing_caller',)), + (self.filemod + ('traced_func_importing',))): 1, + ((self.filemod + ('traced_func_importing',)), + (fix_ext_py(testmod.__file__), 'testmod', 'func')): 1, + } + self.assertEqual(self.tracer.results().callers, expected) + + +# Created separately for issue #3821 +class TestCoverage(unittest.TestCase): + def setUp(self): + self.addCleanup(sys.settrace, sys.gettrace()) + + def tearDown(self): + rmtree(TESTFN) + unlink(TESTFN) + + def _coverage(self, tracer, + cmd='import test.support, test.test_pprint;' + 'test.support.run_unittest(test.test_pprint.QueryTestCase)'): + tracer.run(cmd) + r = tracer.results() + r.write_results(show_missing=True, summary=True, coverdir=TESTFN) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_coverage(self): + tracer = trace.Trace(trace=0, count=1) + with captured_stdout() as stdout: + self._coverage(tracer) + stdout = stdout.getvalue() + self.assertIn("pprint.py", stdout) + self.assertIn("case.py", stdout) # from unittest + files = os.listdir(TESTFN) + self.assertIn("pprint.cover", files) + self.assertIn("unittest.case.cover", files) + + def test_coverage_ignore(self): + # Ignore all files, nothing should be traced nor printed + libpath = os.path.normpath(os.path.dirname(os.__file__)) + # sys.prefix does not work when running from a checkout + tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix, + libpath], trace=0, count=1) + with captured_stdout() as stdout: + self._coverage(tracer) + if os.path.exists(TESTFN): + files = os.listdir(TESTFN) + self.assertEqual(files, ['_importlib.cover']) # Ignore __import__ + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_issue9936(self): + tracer = trace.Trace(trace=0, count=1) + modname = 'test.tracedmodules.testmod' + # Ensure that the module is executed in import + if modname in sys.modules: + del sys.modules[modname] + cmd = ("import test.tracedmodules.testmod as t;" + "t.func(0); t.func2();") + with captured_stdout() as stdout: + self._coverage(tracer, cmd) + stdout.seek(0) + stdout.readline() + coverage = {} + for line in stdout: + lines, cov, module = line.split()[:3] + coverage[module] = (int(lines), int(cov[:-1])) + # XXX This is needed to run regrtest.py as a script + modname = trace._fullmodname(sys.modules[modname].__file__) + self.assertIn(modname, coverage) + self.assertEqual(coverage[modname], (5, 100)) + +### Tests that don't mess with sys.settrace and can be traced +### themselves TODO: Skip tests that do mess with sys.settrace when +### regrtest is invoked with -T option. +class Test_Ignore(unittest.TestCase): + def test_ignored(self): + jn = os.path.join + ignore = trace._Ignore(['x', 'y.z'], [jn('foo', 'bar')]) + self.assertTrue(ignore.names('x.py', 'x')) + self.assertFalse(ignore.names('xy.py', 'xy')) + self.assertFalse(ignore.names('y.py', 'y')) + self.assertTrue(ignore.names(jn('foo', 'bar', 'baz.py'), 'baz')) + self.assertFalse(ignore.names(jn('bar', 'z.py'), 'z')) + # Matched before. + self.assertTrue(ignore.names(jn('bar', 'baz.py'), 'baz')) + +# Created for Issue 31908 -- CLI utility not writing cover files +class TestCoverageCommandLineOutput(unittest.TestCase): + + codefile = 'tmp.py' + coverfile = 'tmp.cover' + + def setUp(self): + with open(self.codefile, 'w', encoding='iso-8859-15') as f: + f.write(textwrap.dedent('''\ + # coding: iso-8859-15 + x = 'spœm' + if []: + print('unreachable') + ''')) + + def tearDown(self): + unlink(self.codefile) + unlink(self.coverfile) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_cover_files_written_no_highlight(self): + # Test also that the cover file for the trace module is not created + # (issue #34171). + tracedir = os.path.dirname(os.path.abspath(trace.__file__)) + tracecoverpath = os.path.join(tracedir, 'trace.cover') + unlink(tracecoverpath) + + argv = '-m trace --count'.split() + [self.codefile] + status, stdout, stderr = assert_python_ok(*argv) + self.assertEqual(stderr, b'') + self.assertFalse(os.path.exists(tracecoverpath)) + self.assertTrue(os.path.exists(self.coverfile)) + with open(self.coverfile, encoding='iso-8859-15') as f: + self.assertEqual(f.read(), + " # coding: iso-8859-15\n" + " 1: x = 'spœm'\n" + " 1: if []:\n" + " print('unreachable')\n" + ) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_cover_files_written_with_highlight(self): + argv = '-m trace --count --missing'.split() + [self.codefile] + status, stdout, stderr = assert_python_ok(*argv) + self.assertTrue(os.path.exists(self.coverfile)) + with open(self.coverfile, encoding='iso-8859-15') as f: + self.assertEqual(f.read(), textwrap.dedent('''\ + # coding: iso-8859-15 + 1: x = 'spœm' + 1: if []: + >>>>>> print('unreachable') + ''')) + +class TestCommandLine(unittest.TestCase): + + def test_failures(self): + _errors = ( + (b'progname is missing: required with the main options', '-l', '-T'), + (b'cannot specify both --listfuncs and (--trace or --count)', '-lc'), + (b'argument -R/--no-report: not allowed with argument -r/--report', '-rR'), + (b'must specify one of --trace, --count, --report, --listfuncs, or --trackcalls', '-g'), + (b'-r/--report requires -f/--file', '-r'), + (b'--summary can only be used with --count or --report', '-sT'), + (b'unrecognized arguments: -y', '-y')) + for message, *args in _errors: + *_, stderr = assert_python_failure('-m', 'trace', *args) + self.assertIn(message, stderr) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_listfuncs_flag_success(self): + filename = TESTFN + '.py' + modulename = os.path.basename(TESTFN) + with open(filename, 'w', encoding='utf-8') as fd: + self.addCleanup(unlink, filename) + fd.write("a = 1\n") + status, stdout, stderr = assert_python_ok('-m', 'trace', '-l', filename, + PYTHONIOENCODING='utf-8') + self.assertIn(b'functions called:', stdout) + expected = f'filename: {filename}, modulename: {modulename}, funcname: ' + self.assertIn(expected.encode(), stdout) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_sys_argv_list(self): + with open(TESTFN, 'w', encoding='utf-8') as fd: + self.addCleanup(unlink, TESTFN) + fd.write("import sys\n") + fd.write("print(type(sys.argv))\n") + + status, direct_stdout, stderr = assert_python_ok(TESTFN) + status, trace_stdout, stderr = assert_python_ok('-m', 'trace', '-l', TESTFN) + self.assertIn(direct_stdout.strip(), trace_stdout) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_count_and_summary(self): + filename = f'{TESTFN}.py' + coverfilename = f'{TESTFN}.cover' + modulename = os.path.basename(TESTFN) + with open(filename, 'w', encoding='utf-8') as fd: + self.addCleanup(unlink, filename) + self.addCleanup(unlink, coverfilename) + fd.write(textwrap.dedent("""\ + x = 1 + y = 2 + + def f(): + return x + y + + for i in range(10): + f() + """)) + status, stdout, _ = assert_python_ok('-m', 'trace', '-cs', filename) + stdout = stdout.decode() + self.assertEqual(status, 0) + self.assertIn('lines cov% module (path)', stdout) + self.assertIn(f'6 100% {modulename} ({filename})', stdout) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_run_as_module(self): + assert_python_ok('-m', 'trace', '-l', '--module', 'timeit', '-n', '1') + assert_python_failure('-m', 'trace', '-l', '--module', 'not_a_module_zzz') + + +if __name__ == '__main__': + unittest.main() From c05e9cd52dfcea9608c53de506a4361c2138d6ff Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Sun, 14 Aug 2022 23:34:09 +0900 Subject: [PATCH 06/23] Update tarfile from CPython 3.10.6 --- Lib/tarfile.py | 10 ++ Lib/test/test_tarfile.py | 250 +++++++++++++++++++++++++++++++++++---- 2 files changed, 238 insertions(+), 22 deletions(-) diff --git a/Lib/tarfile.py b/Lib/tarfile.py index 6ada9a05db..dea150e8db 100755 --- a/Lib/tarfile.py +++ b/Lib/tarfile.py @@ -1163,6 +1163,11 @@ def _proc_builtin(self, tarfile): # header information. self._apply_pax_info(tarfile.pax_headers, tarfile.encoding, tarfile.errors) + # Remove redundant slashes from directories. This is to be consistent + # with frombuf(). + if self.isdir(): + self.name = self.name.rstrip("/") + return self def _proc_gnulong(self, tarfile): @@ -1185,6 +1190,11 @@ def _proc_gnulong(self, tarfile): elif self.type == GNUTYPE_LONGLINK: next.linkname = nts(buf, tarfile.encoding, tarfile.errors) + # Remove redundant slashes from directories. This is to be consistent + # with frombuf(). + if next.isdir(): + next.name = next.name.removesuffix("/") + return next def _proc_sparse(self, tarfile): diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py index 36838e7f0f..bfd8693998 100644 --- a/Lib/test/test_tarfile.py +++ b/Lib/test/test_tarfile.py @@ -220,6 +220,26 @@ def test_fileobj_symlink2(self): def test_issue14160(self): self._test_fileobj_link("symtype2", "ustar/regtype") + def test_add_dir_getmember(self): + # bpo-21987 + self.add_dir_and_getmember('bar') + self.add_dir_and_getmember('a'*101) + + def add_dir_and_getmember(self, name): + with os_helper.temp_cwd(): + with tarfile.open(tmpname, 'w') as tar: + tar.format = tarfile.USTAR_FORMAT + try: + os.mkdir(name) + tar.add(name) + finally: + os.rmdir(name) + with tarfile.open(tmpname) as tar: + self.assertEqual( + tar.getmember(name), + tar.getmember(name + '/') + ) + class GzipUstarReadTest(GzipTest, UstarReadTest): pass @@ -330,6 +350,38 @@ class LzmaListTest(LzmaTest, ListTest): class CommonReadTest(ReadTest): + def test_is_tarfile_erroneous(self): + with open(tmpname, "wb"): + pass + + # is_tarfile works on filenames + self.assertFalse(tarfile.is_tarfile(tmpname)) + + # is_tarfile works on path-like objects + self.assertFalse(tarfile.is_tarfile(pathlib.Path(tmpname))) + + # is_tarfile works on file objects + with open(tmpname, "rb") as fobj: + self.assertFalse(tarfile.is_tarfile(fobj)) + + # is_tarfile works on file-like objects + self.assertFalse(tarfile.is_tarfile(io.BytesIO(b"invalid"))) + + def test_is_tarfile_valid(self): + # is_tarfile works on filenames + self.assertTrue(tarfile.is_tarfile(self.tarname)) + + # is_tarfile works on path-like objects + self.assertTrue(tarfile.is_tarfile(pathlib.Path(self.tarname))) + + # is_tarfile works on file objects + with open(self.tarname, "rb") as fobj: + self.assertTrue(tarfile.is_tarfile(fobj)) + + # is_tarfile works on file-like objects + with open(self.tarname, "rb") as fobj: + self.assertTrue(tarfile.is_tarfile(io.BytesIO(fobj.read()))) + def test_empty_tarfile(self): # Test for issue6123: Allow opening empty archives. # This test checks if tarfile.open() is able to open an empty tar @@ -365,7 +417,7 @@ def test_null_tarfile(self): def test_ignore_zeros(self): # Test TarFile's ignore_zeros option. # generate 512 pseudorandom bytes - data = Random(0).getrandbits(512*8).to_bytes(512, 'big') + data = Random(0).randbytes(512) for char in (b'\0', b'a'): # Test if EOFHeaderError ('\0') and InvalidHeaderError ('a') # are ignored correctly. @@ -682,6 +734,16 @@ def test_parallel_iteration(self): self.assertEqual(m1.offset, m2.offset) self.assertEqual(m1.get_info(), m2.get_info()) + @unittest.skipIf(zlib is None, "requires zlib") + def test_zlib_error_does_not_leak(self): + # bpo-39039: tarfile.open allowed zlib exceptions to bubble up when + # parsing certain types of invalid data + with unittest.mock.patch("tarfile.TarInfo.fromtarfile") as mock: + mock.side_effect = zlib.error + with self.assertRaises(tarfile.ReadError): + tarfile.open(self.tarname) + + class MiscReadTest(MiscReadTestBase, unittest.TestCase): test_fail_comp = None @@ -968,11 +1030,26 @@ def test_header_offset(self): "iso8859-1", "strict") self.assertEqual(tarinfo.type, self.longnametype) + def test_longname_directory(self): + # Test reading a longlink directory. Issue #47231. + longdir = ('a' * 101) + '/' + with os_helper.temp_cwd(): + with tarfile.open(tmpname, 'w') as tar: + tar.format = self.format + try: + os.mkdir(longdir) + tar.add(longdir) + finally: + os.rmdir(longdir) + with tarfile.open(tmpname) as tar: + self.assertIsNotNone(tar.getmember(longdir)) + self.assertIsNotNone(tar.getmember(longdir.removesuffix('/'))) class GNUReadTest(LongnameTest, ReadTest, unittest.TestCase): subdir = "gnu" longnametype = tarfile.GNUTYPE_LONGNAME + format = tarfile.GNU_FORMAT # Since 3.2 tarfile is supposed to accurately restore sparse members and # produce files with holes. This is what we actually want to test here. @@ -1048,6 +1125,7 @@ class PaxReadTest(LongnameTest, ReadTest, unittest.TestCase): subdir = "pax" longnametype = tarfile.XHDTYPE + format = tarfile.PAX_FORMAT def test_pax_global_headers(self): tar = tarfile.open(tarname, encoding="iso8859-1") @@ -1600,6 +1678,52 @@ def test_longnamelink_1025(self): ("longlnk/" * 127) + "longlink_") +class DeviceHeaderTest(WriteTestBase, unittest.TestCase): + + prefix = "w:" + + def test_headers_written_only_for_device_files(self): + # Regression test for bpo-18819. + tempdir = os.path.join(TEMPDIR, "device_header_test") + os.mkdir(tempdir) + try: + tar = tarfile.open(tmpname, self.mode) + try: + input_blk = tarfile.TarInfo(name="my_block_device") + input_reg = tarfile.TarInfo(name="my_regular_file") + input_blk.type = tarfile.BLKTYPE + input_reg.type = tarfile.REGTYPE + tar.addfile(input_blk) + tar.addfile(input_reg) + finally: + tar.close() + + # devmajor and devminor should be *interpreted* as 0 in both... + tar = tarfile.open(tmpname, "r") + try: + output_blk = tar.getmember("my_block_device") + output_reg = tar.getmember("my_regular_file") + finally: + tar.close() + self.assertEqual(output_blk.devmajor, 0) + self.assertEqual(output_blk.devminor, 0) + self.assertEqual(output_reg.devmajor, 0) + self.assertEqual(output_reg.devminor, 0) + + # ...but the fields should not actually be set on regular files: + with open(tmpname, "rb") as infile: + buf = infile.read() + buf_blk = buf[output_blk.offset:output_blk.offset_data] + buf_reg = buf[output_reg.offset:output_reg.offset_data] + # See `struct posixheader` in GNU docs for byte offsets: + # + device_headers = slice(329, 329 + 16) + self.assertEqual(buf_blk[device_headers], b"0000000\0" * 2) + self.assertEqual(buf_reg[device_headers], b"\0" * 16) + finally: + os_helper.rmtree(tempdir) + + class CreateTest(WriteTestBase, unittest.TestCase): prefix = "x:" @@ -1691,15 +1815,30 @@ def test_create_taropen_pathlike_name(self): class GzipCreateTest(GzipTest, CreateTest): - pass + + def test_create_with_compresslevel(self): + with tarfile.open(tmpname, self.mode, compresslevel=1) as tobj: + tobj.add(self.file_path) + with tarfile.open(tmpname, 'r:gz', compresslevel=1) as tobj: + pass class Bz2CreateTest(Bz2Test, CreateTest): - pass + + def test_create_with_compresslevel(self): + with tarfile.open(tmpname, self.mode, compresslevel=1) as tobj: + tobj.add(self.file_path) + with tarfile.open(tmpname, 'r:bz2', compresslevel=1) as tobj: + pass class LzmaCreateTest(LzmaTest, CreateTest): - pass + + # Unlike gz and bz2, xz uses the preset keyword instead of compresslevel. + # It does not allow for preset to be specified when reading. + def test_create_with_preset(self): + with tarfile.open(tmpname, self.mode, preset=1) as tobj: + tobj.add(self.file_path) class CreateWithXModeTest(CreateTest): @@ -1840,6 +1979,61 @@ def test_pax_extended_header(self): finally: tar.close() + def test_create_pax_header(self): + # The ustar header should contain values that can be + # represented reasonably, even if a better (e.g. higher + # precision) version is set in the pax header. + # Issue #45863 + + # values that should be kept + t = tarfile.TarInfo() + t.name = "foo" + t.mtime = 1000.1 + t.size = 100 + t.uid = 123 + t.gid = 124 + info = t.get_info() + header = t.create_pax_header(info, encoding="iso8859-1") + self.assertEqual(info['name'], "foo") + # mtime should be rounded to nearest second + self.assertIsInstance(info['mtime'], int) + self.assertEqual(info['mtime'], 1000) + self.assertEqual(info['size'], 100) + self.assertEqual(info['uid'], 123) + self.assertEqual(info['gid'], 124) + self.assertEqual(header, + b'././@PaxHeader' + bytes(86) \ + + b'0000000\x000000000\x000000000\x0000000000020\x0000000000000\x00010205\x00 x' \ + + bytes(100) + b'ustar\x0000'+ bytes(247) \ + + b'16 mtime=1000.1\n' + bytes(496) + b'foo' + bytes(97) \ + + b'0000644\x000000173\x000000174\x0000000000144\x0000000001750\x00006516\x00 0' \ + + bytes(100) + b'ustar\x0000' + bytes(247)) + + # values that should be changed + t = tarfile.TarInfo() + t.name = "foo\u3374" # can't be represented in ascii + t.mtime = 10**10 # too big + t.size = 10**10 # too big + t.uid = 8**8 # too big + t.gid = 8**8+1 # too big + info = t.get_info() + header = t.create_pax_header(info, encoding="iso8859-1") + # name is kept as-is in info but should be added to pax header + self.assertEqual(info['name'], "foo\u3374") + self.assertEqual(info['mtime'], 0) + self.assertEqual(info['size'], 0) + self.assertEqual(info['uid'], 0) + self.assertEqual(info['gid'], 0) + self.assertEqual(header, + b'././@PaxHeader' + bytes(86) \ + + b'0000000\x000000000\x000000000\x0000000000130\x0000000000000\x00010207\x00 x' \ + + bytes(100) + b'ustar\x0000' + bytes(247) \ + + b'15 path=foo\xe3\x8d\xb4\n16 uid=16777216\n' \ + + b'16 gid=16777217\n20 size=10000000000\n' \ + + b'21 mtime=10000000000\n'+ bytes(424) + b'foo?' + bytes(96) \ + + b'0000644\x000000000\x000000000\x0000000000000\x0000000000000\x00006540\x00 0' \ + + bytes(100) + b'ustar\x0000' + bytes(247)) + class UnicodeTest: @@ -2274,23 +2468,32 @@ def test_number_field_limits(self): tarfile.itn(0x10000000000, 6, tarfile.GNU_FORMAT) def test__all__(self): - not_exported = {'version', 'grp', 'pwd', 'symlink_exception', - 'NUL', 'BLOCKSIZE', 'RECORDSIZE', 'GNU_MAGIC', - 'POSIX_MAGIC', 'LENGTH_NAME', 'LENGTH_LINK', - 'LENGTH_PREFIX', 'REGTYPE', 'AREGTYPE', 'LNKTYPE', - 'SYMTYPE', 'CHRTYPE', 'BLKTYPE', 'DIRTYPE', 'FIFOTYPE', - 'CONTTYPE', 'GNUTYPE_LONGNAME', 'GNUTYPE_LONGLINK', - 'GNUTYPE_SPARSE', 'XHDTYPE', 'XGLTYPE', 'SOLARIS_XHDTYPE', - 'SUPPORTED_TYPES', 'REGULAR_TYPES', 'GNU_TYPES', - 'PAX_FIELDS', 'PAX_NAME_FIELDS', 'PAX_NUMBER_FIELDS', - 'stn', 'nts', 'nti', 'itn', 'calc_chksums', 'copyfileobj', - 'filemode', - 'EmptyHeaderError', 'TruncatedHeaderError', - 'EOFHeaderError', 'InvalidHeaderError', - 'SubsequentHeaderError', 'ExFileObject', - 'main'} + not_exported = { + 'version', 'grp', 'pwd', 'symlink_exception', 'NUL', 'BLOCKSIZE', + 'RECORDSIZE', 'GNU_MAGIC', 'POSIX_MAGIC', 'LENGTH_NAME', + 'LENGTH_LINK', 'LENGTH_PREFIX', 'REGTYPE', 'AREGTYPE', 'LNKTYPE', + 'SYMTYPE', 'CHRTYPE', 'BLKTYPE', 'DIRTYPE', 'FIFOTYPE', 'CONTTYPE', + 'GNUTYPE_LONGNAME', 'GNUTYPE_LONGLINK', 'GNUTYPE_SPARSE', + 'XHDTYPE', 'XGLTYPE', 'SOLARIS_XHDTYPE', 'SUPPORTED_TYPES', + 'REGULAR_TYPES', 'GNU_TYPES', 'PAX_FIELDS', 'PAX_NAME_FIELDS', + 'PAX_NUMBER_FIELDS', 'stn', 'nts', 'nti', 'itn', 'calc_chksums', + 'copyfileobj', 'filemode', 'EmptyHeaderError', + 'TruncatedHeaderError', 'EOFHeaderError', 'InvalidHeaderError', + 'SubsequentHeaderError', 'ExFileObject', 'main'} support.check__all__(self, tarfile, not_exported=not_exported) + def test_useful_error_message_when_modules_missing(self): + fname = os.path.join(os.path.dirname(__file__), 'testtar.tar.xz') + with self.assertRaises(tarfile.ReadError) as excinfo: + error = tarfile.CompressionError('lzma module is not available'), + with unittest.mock.patch.object(tarfile.TarFile, 'xzopen', side_effect=error): + tarfile.open(fname) + + self.assertIn( + "\n- method xz: CompressionError('lzma module is not available')\n", + str(excinfo.exception), + ) + class CommandLineTest(unittest.TestCase): @@ -2330,7 +2533,8 @@ def test_test_command(self): def test_test_command_verbose(self): for tar_name in testtarnames: for opt in '-v', '--verbose': - out = self.tarfilecmd(opt, '-t', tar_name) + out = self.tarfilecmd(opt, '-t', tar_name, + PYTHONIOENCODING='utf-8') self.assertIn(b'is a tar archive.\n', out) def test_test_command_invalid_file(self): @@ -2405,7 +2609,8 @@ def test_create_command_verbose(self): 'and-utf8-bom-sig-only.txt')] for opt in '-v', '--verbose': try: - out = self.tarfilecmd(opt, '-c', tmpname, *files) + out = self.tarfilecmd(opt, '-c', tmpname, *files, + PYTHONIOENCODING='utf-8') self.assertIn(b' file created.', out) with tarfile.open(tmpname) as tar: tar.getmembers() @@ -2463,7 +2668,8 @@ def test_extract_command_verbose(self): for opt in '-v', '--verbose': try: with os_helper.temp_cwd(tarextdir): - out = self.tarfilecmd(opt, '-e', tmpname) + out = self.tarfilecmd(opt, '-e', tmpname, + PYTHONIOENCODING='utf-8') self.assertIn(b' file is extracted.', out) finally: os_helper.rmtree(tarextdir) From 43b095f18de9edc0a825800fc1620580249a24cb Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sun, 14 Aug 2022 23:35:30 +0900 Subject: [PATCH 07/23] mark failing test --- Lib/test/test_tarfile.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py index bfd8693998..ea701cf96f 100644 --- a/Lib/test/test_tarfile.py +++ b/Lib/test/test_tarfile.py @@ -2482,6 +2482,8 @@ def test__all__(self): 'SubsequentHeaderError', 'ExFileObject', 'main'} support.check__all__(self, tarfile, not_exported=not_exported) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_useful_error_message_when_modules_missing(self): fname = os.path.join(os.path.dirname(__file__), 'testtar.tar.xz') with self.assertRaises(tarfile.ReadError) as excinfo: From a01511cf604e32712f056448f4e0ed1cb9ed3b43 Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Sun, 14 Aug 2022 23:37:03 +0900 Subject: [PATCH 08/23] Update test_structseq from CPython 3.10.6 --- Lib/test/test_structseq.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py index 2a328c5d7f..013797e730 100644 --- a/Lib/test/test_structseq.py +++ b/Lib/test/test_structseq.py @@ -124,5 +124,17 @@ def test_extended_getslice(self): self.assertEqual(list(t[start:stop:step]), L[start:stop:step]) + def test_match_args(self): + expected_args = ('tm_year', 'tm_mon', 'tm_mday', 'tm_hour', 'tm_min', + 'tm_sec', 'tm_wday', 'tm_yday', 'tm_isdst') + self.assertEqual(time.struct_time.__match_args__, expected_args) + + def test_match_args_with_unnamed_fields(self): + expected_args = ('st_mode', 'st_ino', 'st_dev', 'st_nlink', 'st_uid', + 'st_gid', 'st_size') + self.assertEqual(os.stat_result.n_unnamed_fields, 3) + self.assertEqual(os.stat_result.__match_args__, expected_args) + + if __name__ == "__main__": unittest.main() From a60cceac53213b75c2f606e6ab188a0c2eb91d50 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sun, 14 Aug 2022 23:38:03 +0900 Subject: [PATCH 09/23] mark failing tests frmo test_structseq --- Lib/test/test_structseq.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py index 013797e730..43e4cb51ea 100644 --- a/Lib/test/test_structseq.py +++ b/Lib/test/test_structseq.py @@ -124,11 +124,15 @@ def test_extended_getslice(self): self.assertEqual(list(t[start:stop:step]), L[start:stop:step]) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_match_args(self): expected_args = ('tm_year', 'tm_mon', 'tm_mday', 'tm_hour', 'tm_min', 'tm_sec', 'tm_wday', 'tm_yday', 'tm_isdst') self.assertEqual(time.struct_time.__match_args__, expected_args) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_match_args_with_unnamed_fields(self): expected_args = ('st_mode', 'st_ino', 'st_dev', 'st_nlink', 'st_uid', 'st_gid', 'st_size') From 193fa88ea91abcf64c1e944d080e6dcacabf339a Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Sun, 14 Aug 2022 23:41:37 +0900 Subject: [PATCH 10/23] Update test_struct from CPython 3.10.6 --- Lib/test/test_struct.py | 60 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/Lib/test/test_struct.py b/Lib/test/test_struct.py index 6ce1cfe1d2..89a330ba93 100644 --- a/Lib/test/test_struct.py +++ b/Lib/test/test_struct.py @@ -1,12 +1,16 @@ from collections import abc import array +import gc import math import operator import unittest import struct import sys +import weakref from test import support +from test.support import import_helper +from test.support.script_helper import assert_python_ok ISBIGENDIAN = sys.byteorder == "big" @@ -654,6 +658,58 @@ def test_format_attr(self): s2 = struct.Struct(s.format.encode()) self.assertEqual(s2.format, s.format) + def test_struct_cleans_up_at_runtime_shutdown(self): + code = """if 1: + import struct + + class C: + def __init__(self): + self.pack = struct.pack + def __del__(self): + self.pack('I', -42) + + struct.x = C() + """ + rc, stdout, stderr = assert_python_ok("-c", code) + self.assertEqual(rc, 0) + self.assertEqual(stdout.rstrip(), b"") + self.assertIn(b"Exception ignored in:", stderr) + self.assertIn(b"C.__del__", stderr) + + def test__struct_reference_cycle_cleaned_up(self): + # Regression test for python/cpython#94207. + + # When we create a new struct module, trigger use of its cache, + # and then delete it ... + _struct_module = import_helper.import_fresh_module("_struct") + module_ref = weakref.ref(_struct_module) + _struct_module.calcsize("b") + del _struct_module + + # Then the module should have been garbage collected. + gc.collect() + self.assertIsNone( + module_ref(), "_struct module was not garbage collected") + + @support.cpython_only + def test__struct_types_immutable(self): + # See https://github.com/python/cpython/issues/94254 + + Struct = struct.Struct + unpack_iterator = type(struct.iter_unpack("b", b'x')) + for cls in (Struct, unpack_iterator): + with self.subTest(cls=cls): + with self.assertRaises(TypeError): + cls.x = 1 + + + def test_issue35714(self): + # Embedded null characters should not be allowed in format strings. + for s in '\0', '2\0i', b'\0': + with self.assertRaisesRegex(struct.error, + 'embedded null character'): + struct.calcsize(s) + class UnpackIteratorTest(unittest.TestCase): """ @@ -681,6 +737,10 @@ def _check_iterator(it): with self.assertRaises(struct.error): s.iter_unpack(b"12") + def test_uninstantiable(self): + iter_unpack_type = type(struct.Struct(">ibcp").iter_unpack(b"")) + self.assertRaises(TypeError, iter_unpack_type) + def test_iterate(self): s = struct.Struct('>IB') b = bytes(range(1, 16)) From 14e86e1374626e89e7054295595663905adbd415 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sun, 14 Aug 2022 23:43:34 +0900 Subject: [PATCH 11/23] mark filing tests for test_struct --- Lib/test/test_struct.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Lib/test/test_struct.py b/Lib/test/test_struct.py index 89a330ba93..6cb7b610e5 100644 --- a/Lib/test/test_struct.py +++ b/Lib/test/test_struct.py @@ -658,6 +658,8 @@ def test_format_attr(self): s2 = struct.Struct(s.format.encode()) self.assertEqual(s2.format, s.format) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_struct_cleans_up_at_runtime_shutdown(self): code = """if 1: import struct @@ -703,6 +705,8 @@ def test__struct_types_immutable(self): cls.x = 1 + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_issue35714(self): # Embedded null characters should not be allowed in format strings. for s in '\0', '2\0i', b'\0': @@ -737,6 +741,8 @@ def _check_iterator(it): with self.assertRaises(struct.error): s.iter_unpack(b"12") + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_uninstantiable(self): iter_unpack_type = type(struct.Struct(">ibcp").iter_unpack(b"")) self.assertRaises(TypeError, iter_unpack_type) From 17e12dea1ec11b871a9e5d3a4a26e2f9ad83fde2 Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Sun, 14 Aug 2022 23:52:26 +0900 Subject: [PATCH 12/23] Update string, test_str* from cpython 3.10.6 --- Lib/string.py | 46 +++++++++++++++----------------- Lib/test/test_strftime.py | 2 +- Lib/test/test_string_literals.py | 3 +-- 3 files changed, 24 insertions(+), 27 deletions(-) diff --git a/Lib/string.py b/Lib/string.py index b423ff5dc6..489777b10c 100644 --- a/Lib/string.py +++ b/Lib/string.py @@ -54,30 +54,7 @@ def capwords(s, sep=None): _sentinel_dict = {} -class _TemplateMetaclass(type): - pattern = r""" - %(delim)s(?: - (?P%(delim)s) | # Escape sequence of two delimiters - (?P%(id)s) | # delimiter and a Python identifier - {(?P%(bid)s)} | # delimiter and a braced identifier - (?P) # Other ill-formed delimiter exprs - ) - """ - - def __init__(cls, name, bases, dct): - super(_TemplateMetaclass, cls).__init__(name, bases, dct) - if 'pattern' in dct: - pattern = cls.pattern - else: - pattern = _TemplateMetaclass.pattern % { - 'delim' : _re.escape(cls.delimiter), - 'id' : cls.idpattern, - 'bid' : cls.braceidpattern or cls.idpattern, - } - cls.pattern = _re.compile(pattern, cls.flags | _re.VERBOSE) - - -class Template(metaclass=_TemplateMetaclass): +class Template: """A string class for supporting $-substitutions.""" delimiter = '$' @@ -89,6 +66,24 @@ class Template(metaclass=_TemplateMetaclass): braceidpattern = None flags = _re.IGNORECASE + def __init_subclass__(cls): + super().__init_subclass__() + if 'pattern' in cls.__dict__: + pattern = cls.pattern + else: + delim = _re.escape(cls.delimiter) + id = cls.idpattern + bid = cls.braceidpattern or cls.idpattern + pattern = fr""" + {delim}(?: + (?P{delim}) | # Escape sequence of two delimiters + (?P{id}) | # delimiter and a Python identifier + {{(?P{bid})}} | # delimiter and a braced identifier + (?P) # Other ill-formed delimiter exprs + ) + """ + cls.pattern = _re.compile(pattern, cls.flags | _re.VERBOSE) + def __init__(self, template): self.template = template @@ -146,6 +141,9 @@ def convert(mo): self.pattern) return self.pattern.sub(convert, self.template) +# Initialize Template.pattern. __init_subclass__() is automatically called +# only for subclasses, not for the Template class itself. +Template.__init_subclass__() ######################################################################## diff --git a/Lib/test/test_strftime.py b/Lib/test/test_strftime.py index 16440ef3bd..f4c0bd0de9 100644 --- a/Lib/test/test_strftime.py +++ b/Lib/test/test_strftime.py @@ -115,7 +115,7 @@ def strftest1(self, now): ) for e in expectations: - # musn't raise a value error + # mustn't raise a value error try: result = time.strftime(e[0], now) except ValueError as error: diff --git a/Lib/test/test_string_literals.py b/Lib/test/test_string_literals.py index f117428152..8a822dd2f7 100644 --- a/Lib/test/test_string_literals.py +++ b/Lib/test/test_string_literals.py @@ -63,8 +63,6 @@ def byte(i): class TestLiterals(unittest.TestCase): - from test.support.warnings_helper import check_syntax_warning - def setUp(self): self.save_path = sys.path[:] self.tmpdir = tempfile.mkdtemp() @@ -133,6 +131,7 @@ def test_eval_str_invalid_escape(self): self.assertEqual(w, []) self.assertEqual(exc.filename, '') self.assertEqual(exc.lineno, 1) + self.assertEqual(exc.offset, 1) # TODO: RUSTPYTHON @unittest.expectedFailure From 3fd08b32ef0fcdba113aaeb9e93c3b35a53251ef Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Sun, 14 Aug 2022 23:54:55 +0900 Subject: [PATCH 13/23] Update stat, statistics from CPython 3.10.6 --- Lib/test/test_stat.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/Lib/test/test_stat.py b/Lib/test/test_stat.py index cc2c85bfb9..c1edea8491 100644 --- a/Lib/test/test_stat.py +++ b/Lib/test/test_stat.py @@ -2,9 +2,11 @@ import os import socket import sys -from test.support.socket_helper import skip_unless_bind_unix_socket -from test.support.os_helper import TESTFN +from test.support import os_helper +from test.support import socket_helper from test.support.import_helper import import_fresh_module +from test.support.os_helper import TESTFN + c_stat = import_fresh_module('stat', fresh=['_stat']) py_stat = import_fresh_module('stat', blocked=['_stat']) @@ -172,11 +174,16 @@ def test_link(self): @unittest.skipUnless(hasattr(os, 'mkfifo'), 'os.mkfifo not available') def test_fifo(self): + if sys.platform == "vxworks": + fifo_path = os.path.join("/fifos/", TESTFN) + else: + fifo_path = TESTFN + self.addCleanup(os_helper.unlink, fifo_path) try: - os.mkfifo(TESTFN, 0o700) + os.mkfifo(fifo_path, 0o700) except PermissionError as e: self.skipTest('os.mkfifo(): %s' % e) - st_mode, modestr = self.get_mode() + st_mode, modestr = self.get_mode(fifo_path) self.assertEqual(modestr, 'prwx------') self.assertS_IS("FIFO", st_mode) @@ -194,7 +201,7 @@ def test_devices(self): self.assertS_IS("BLK", st_mode) break - @skip_unless_bind_unix_socket + @socket_helper.skip_unless_bind_unix_socket def test_socket(self): with socket.socket(socket.AF_UNIX) as s: s.bind(TESTFN) From 939257411bde3f48c9dc08ec59b6103ccec5f727 Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Mon, 15 Aug 2022 00:03:36 +0900 Subject: [PATCH 14/23] Update socket{server} from CPython 3.10.6 --- Lib/socket.py | 49 ++- Lib/socketserver.py | 72 +++- Lib/test/test_socket.py | 617 ++++++++++++++++++++++++++-------- Lib/test/test_socketserver.py | 13 +- 4 files changed, 588 insertions(+), 163 deletions(-) diff --git a/Lib/socket.py b/Lib/socket.py index f83f36d0ad..63ba0acc90 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -12,6 +12,8 @@ socket() -- create a new socket object socketpair() -- create a pair of new socket objects [*] fromfd() -- create a socket object from an open file descriptor [*] +send_fds() -- Send file descriptor to the socket. +recv_fds() -- Recieve file descriptors from the socket. fromshare() -- create a socket object from data received from socket.share() [*] gethostname() -- return the current hostname gethostbyname() -- map a hostname to its IP number @@ -104,7 +106,6 @@ def _intenum_converter(value, enum_klass): except ValueError: return value -_realsocket = socket # WSA error codes if sys.platform.lower().startswith("win"): @@ -336,6 +337,7 @@ def makefile(self, mode="r", buffering=None, *, buffer = io.BufferedWriter(raw, buffering) if binary: return buffer + encoding = io.text_encoding(encoding) text = io.TextIOWrapper(buffer, encoding, errors, newline) text.mode = mode return text @@ -376,7 +378,7 @@ def _sendfile_use_sendfile(self, file, offset=0, count=None): try: while True: if timeout and not selector_select(timeout): - raise _socket.timeout('timed out') + raise TimeoutError('timed out') if count: blocksize = count - total_sent if blocksize <= 0: @@ -543,6 +545,40 @@ def fromfd(fd, family, type, proto=0): nfd = dup(fd) return socket(family, type, proto, nfd) +if hasattr(_socket.socket, "sendmsg"): + import array + + def send_fds(sock, buffers, fds, flags=0, address=None): + """ send_fds(sock, buffers, fds[, flags[, address]]) -> integer + + Send the list of file descriptors fds over an AF_UNIX socket. + """ + return sock.sendmsg(buffers, [(_socket.SOL_SOCKET, + _socket.SCM_RIGHTS, array.array("i", fds))]) + __all__.append("send_fds") + +if hasattr(_socket.socket, "recvmsg"): + import array + + def recv_fds(sock, bufsize, maxfds, flags=0): + """ recv_fds(sock, bufsize, maxfds[, flags]) -> (data, list of file + descriptors, msg_flags, address) + + Receive up to maxfds file descriptors returning the message + data and a list containing the descriptors. + """ + # Array of ints + fds = array.array("i") + msg, ancdata, flags, addr = sock.recvmsg(bufsize, + _socket.CMSG_LEN(maxfds * fds.itemsize)) + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if (cmsg_level == _socket.SOL_SOCKET and cmsg_type == _socket.SCM_RIGHTS): + fds.frombytes(cmsg_data[: + len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + + return msg, list(fds), flags, addr + __all__.append("recv_fds") + if hasattr(_socket.socket, "share"): def fromshare(info): """ fromshare(info) -> socket object @@ -671,7 +707,7 @@ def readinto(self, b): self._timeout_occurred = True raise except error as e: - if e.args[0] in _blocking_errnos: + if e.errno in _blocking_errnos: return None raise @@ -687,7 +723,7 @@ def write(self, b): return self._sock.send(b) except error as e: # XXX what about EINTR? - if e.args[0] in _blocking_errnos: + if e.errno in _blocking_errnos: return None raise @@ -746,8 +782,9 @@ def getfqdn(name=''): An empty argument is interpreted as meaning the local host. First the hostname returned by gethostbyaddr() is checked, then - possibly existing aliases. In case no FQDN is available, hostname - from gethostname() is returned. + possibly existing aliases. In case no FQDN is available and `name` + was given, it is returned unchanged. If `name` was empty or '0.0.0.0', + hostname from gethostname() is returned. """ name = name.strip() if not name or name == '0.0.0.0': diff --git a/Lib/socketserver.py b/Lib/socketserver.py index b654c06d84..2905e3eac3 100644 --- a/Lib/socketserver.py +++ b/Lib/socketserver.py @@ -24,7 +24,7 @@ The classes in this module favor the server type that is simplest to write: a synchronous TCP/IP server. This is bad class design, but -save some typing. (There's also the issue that a deep class hierarchy +saves some typing. (There's also the issue that a deep class hierarchy slows down method lookups.) There are five classes in an inheritance diagram, four of which represent @@ -126,7 +126,6 @@ class will essentially render the service "deaf" while one request is import socket import selectors import os -import errno import sys try: import threading @@ -234,6 +233,9 @@ def serve_forever(self, poll_interval=0.5): while not self.__shutdown_request: ready = selector.select(poll_interval) + # bpo-35017: shutdown() called during select(), exit immediately. + if self.__shutdown_request: + break if ready: self._handle_request_noblock() @@ -375,7 +377,7 @@ def handle_error(self, request, client_address): """ print('-'*40, file=sys.stderr) - print('Exception happened during processing of request from', + print('Exception occurred during processing of request from', client_address, file=sys.stderr) import traceback traceback.print_exc() @@ -547,8 +549,10 @@ class ForkingMixIn: timeout = 300 active_children = None max_children = 40 + # If true, server_close() waits until all child processes complete. + block_on_close = True - def collect_children(self): + def collect_children(self, *, blocking=False): """Internal routine to wait for children that have exited.""" if self.active_children is None: return @@ -572,7 +576,8 @@ def collect_children(self): # Now reap all defunct children. for pid in self.active_children.copy(): try: - pid, _ = os.waitpid(pid, os.WNOHANG) + flags = 0 if blocking else os.WNOHANG + pid, _ = os.waitpid(pid, flags) # if the child hasn't exited yet, pid will be 0 and ignored by # discard() below self.active_children.discard(pid) @@ -592,7 +597,7 @@ def handle_timeout(self): def service_actions(self): """Collect the zombie child processes regularly in the ForkingMixIn. - service_actions is called in the BaseServer's serve_forver loop. + service_actions is called in the BaseServer's serve_forever loop. """ self.collect_children() @@ -621,6 +626,43 @@ def process_request(self, request, client_address): finally: os._exit(status) + def server_close(self): + super().server_close() + self.collect_children(blocking=self.block_on_close) + + +class _Threads(list): + """ + Joinable list of all non-daemon threads. + """ + def append(self, thread): + self.reap() + if thread.daemon: + return + super().append(thread) + + def pop_all(self): + self[:], result = [], self[:] + return result + + def join(self): + for thread in self.pop_all(): + thread.join() + + def reap(self): + self[:] = (thread for thread in self if thread.is_alive()) + + +class _NoThreads: + """ + Degenerate version of _Threads. + """ + def append(self, thread): + pass + + def join(self): + pass + class ThreadingMixIn: """Mix-in class to handle each request in a new thread.""" @@ -628,6 +670,11 @@ class ThreadingMixIn: # Decides how threads will act upon termination of the # main process daemon_threads = False + # If true, server_close() waits until all non-daemonic threads terminate. + block_on_close = True + # Threads object + # used by server_close() to wait for all threads completion. + _threads = _NoThreads() def process_request_thread(self, request, client_address): """Same as in BaseServer but as a thread. @@ -644,11 +691,18 @@ def process_request_thread(self, request, client_address): def process_request(self, request, client_address): """Start a new thread to process the request.""" + if self.block_on_close: + vars(self).setdefault('_threads', _Threads()) t = threading.Thread(target = self.process_request_thread, args = (request, client_address)) t.daemon = self.daemon_threads + self._threads.append(t) t.start() + def server_close(self): + super().server_close() + self._threads.join() + if hasattr(os, "fork"): class ForkingUDPServer(ForkingMixIn, UDPServer): pass @@ -773,10 +827,8 @@ def writable(self): def write(self, b): self._sock.sendall(b) - # XXX RustPython TODO: implement memoryview properly - #with memoryview(b) as view: - # return view.nbytes - return len(b) + with memoryview(b) as view: + return view.nbytes def fileno(self): return self._sock.fileno() diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index f3b7146da7..80ba3d5f3e 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -40,7 +40,6 @@ HOST = socket_helper.HOST # test unicode string and carriage return MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') -MAIN_TIMEOUT = 60.0 VSOCKPORT = 1234 AIX = platform.system() == "AIX" @@ -83,6 +82,16 @@ def _have_socket_can_isotp(): s.close() return True +def _have_socket_can_j1939(): + """Check whether CAN J1939 sockets are supported on this host.""" + try: + s = socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939) + except (AttributeError, OSError): + return False + else: + s.close() + return True + def _have_socket_rds(): """Check whether RDS sockets are supported on this host.""" try: @@ -119,6 +128,19 @@ def _have_socket_vsock(): return ret +def _have_socket_bluetooth(): + """Check whether AF_BLUETOOTH sockets are supported on this host.""" + try: + # RFCOMM is supported by all platforms with bluetooth support. Windows + # does not support omitting the protocol. + s = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_STREAM, socket.BTPROTO_RFCOMM) + except (AttributeError, OSError): + return False + else: + s.close() + return True + + @contextlib.contextmanager def socket_setdefaulttimeout(timeout): old_timeout = socket.getdefaulttimeout() @@ -133,6 +155,8 @@ def socket_setdefaulttimeout(timeout): HAVE_SOCKET_CAN_ISOTP = _have_socket_can_isotp() +HAVE_SOCKET_CAN_J1939 = _have_socket_can_j1939() + HAVE_SOCKET_RDS = _have_socket_rds() HAVE_SOCKET_ALG = _have_socket_alg() @@ -141,6 +165,10 @@ def socket_setdefaulttimeout(timeout): HAVE_SOCKET_VSOCK = _have_socket_vsock() +HAVE_SOCKET_UDPLITE = hasattr(socket, "IPPROTO_UDPLITE") + +HAVE_SOCKET_BLUETOOTH = _have_socket_bluetooth() + # Size in bytes of the int type SIZEOF_INT = array.array("i").itemsize @@ -165,7 +193,13 @@ def tearDown(self): self.serv.close() self.serv = None -class ThreadSafeCleanupTestCase(unittest.TestCase): +class SocketUDPLITETest(SocketUDPTest): + + def setUp(self): + self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE) + self.port = socket_helper.bind_port(self.serv) + +class ThreadSafeCleanupTestCase: """Subclass of unittest.TestCase with thread-safe cleanup methods. This subclass protects the addCleanup() and doCleanups() methods @@ -292,9 +326,7 @@ def _testFoo(self): def __init__(self): # Swap the true setup function self.__setUp = self.setUp - self.__tearDown = self.tearDown self.setUp = self._setUp - self.tearDown = self._tearDown def serverExplicitReady(self): """This method allows the server to explicitly indicate that @@ -306,6 +338,7 @@ def serverExplicitReady(self): def _setUp(self): self.wait_threads = threading_helper.wait_threads_exit() self.wait_threads.__enter__() + self.addCleanup(self.wait_threads.__exit__, None, None, None) self.server_ready = threading.Event() self.client_ready = threading.Event() @@ -313,6 +346,11 @@ def _setUp(self): self.queue = queue.Queue(1) self.server_crashed = False + def raise_queued_exception(): + if self.queue.qsize(): + raise self.queue.get() + self.addCleanup(raise_queued_exception) + # Do some munging to start the client test. methodname = self.id() i = methodname.rfind('.') @@ -329,15 +367,7 @@ def _setUp(self): finally: self.server_ready.set() self.client_ready.wait() - - def _tearDown(self): - self.__tearDown() - self.done.wait() - self.wait_threads.__exit__(None, None, None) - - if self.queue.qsize(): - exc = self.queue.get() - raise exc + self.addCleanup(self.done.wait) def clientRun(self, test_func): self.server_ready.wait() @@ -396,6 +426,22 @@ def clientTearDown(self): self.cli = None ThreadableTest.clientTearDown(self) +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +class ThreadedUDPLITESocketTest(SocketUDPLITETest, ThreadableTest): + + def __init__(self, methodName='runTest'): + SocketUDPLITETest.__init__(self, methodName=methodName) + ThreadableTest.__init__(self) + + def clientSetUp(self): + self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE) + + def clientTearDown(self): + self.cli.close() + self.cli = None + ThreadableTest.clientTearDown(self) + class ThreadedCANSocketTest(SocketCANTest, ThreadableTest): def __init__(self, methodName='runTest'): @@ -681,6 +727,12 @@ class UDPTestBase(InetTestBase): def newSocket(self): return socket.socket(socket.AF_INET, socket.SOCK_DGRAM) +class UDPLITETestBase(InetTestBase): + """Base class for UDPLITE-over-IPv4 tests.""" + + def newSocket(self): + return socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE) + class SCTPStreamBase(InetTestBase): """Base class for SCTP tests in one-to-one (SOCK_STREAM) mode.""" @@ -700,6 +752,12 @@ class UDP6TestBase(Inet6TestBase): def newSocket(self): return socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) +class UDPLITE6TestBase(Inet6TestBase): + """Base class for UDPLITE-over-IPv6 tests.""" + + def newSocket(self): + return socket.socket(socket.AF_INET6, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE) + # Test-skipping decorators for use with ThreadableTest. @@ -808,6 +866,7 @@ def test_weakref(self): p = proxy(s) self.assertEqual(p.fileno(), s.fileno()) s = None + support.gc_collect() # For PyPy or other GCs. try: p.fileno() except ReferenceError: @@ -859,10 +918,8 @@ def testSendtoErrors(self): self.assertIn('not NoneType', str(cm.exception)) with self.assertRaises(TypeError) as cm: s.sendto(b'foo', 'bar', sockname) - self.assertIn('an integer is required', str(cm.exception)) with self.assertRaises(TypeError) as cm: s.sendto(b'foo', None, None) - self.assertIn('an integer is required', str(cm.exception)) # wrong number of args with self.assertRaises(TypeError) as cm: s.sendto(b'foo') @@ -901,6 +958,37 @@ def testWindowsSpecificConstants(self): socket.IPPROTO_L2TP socket.IPPROTO_SCTP + @unittest.skipUnless(sys.platform == 'darwin', 'macOS specific test') + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test') + def test3542SocketOptions(self): + # Ref. issue #35569 and https://tools.ietf.org/html/rfc3542 + opts = { + 'IPV6_CHECKSUM', + 'IPV6_DONTFRAG', + 'IPV6_DSTOPTS', + 'IPV6_HOPLIMIT', + 'IPV6_HOPOPTS', + 'IPV6_NEXTHOP', + 'IPV6_PATHMTU', + 'IPV6_PKTINFO', + 'IPV6_RECVDSTOPTS', + 'IPV6_RECVHOPLIMIT', + 'IPV6_RECVHOPOPTS', + 'IPV6_RECVPATHMTU', + 'IPV6_RECVPKTINFO', + 'IPV6_RECVRTHDR', + 'IPV6_RECVTCLASS', + 'IPV6_RTHDR', + 'IPV6_RTHDRDSTOPTS', + 'IPV6_RTHDR_TYPE_0', + 'IPV6_TCLASS', + 'IPV6_USE_MIN_MTU', + } + for opt in opts: + self.assertTrue( + hasattr(socket, opt), f"Missing RFC3542 socket option '{opt}'" + ) + def testHostnameRes(self): # Testing hostname resolution mechanisms hostname = socket.gethostname() @@ -1032,9 +1120,11 @@ def testNtoHErrors(self): s_good_values = [0, 1, 2, 0xffff] l_good_values = s_good_values + [0xffffffff] l_bad_values = [-1, -2, 1<<32, 1<<1000] - s_bad_values = l_bad_values + [_testcapi.INT_MIN - 1, - _testcapi.INT_MAX + 1] - s_deprecated_values = [1<<16, _testcapi.INT_MAX] + s_bad_values = ( + l_bad_values + + [_testcapi.INT_MIN-1, _testcapi.INT_MAX+1] + + [1 << 16, _testcapi.INT_MAX] + ) for k in s_good_values: socket.ntohs(k) socket.htons(k) @@ -1047,9 +1137,6 @@ def testNtoHErrors(self): for k in l_bad_values: self.assertRaises(OverflowError, socket.ntohl, k) self.assertRaises(OverflowError, socket.htonl, k) - for k in s_deprecated_values: - self.assertWarns(DeprecationWarning, socket.ntohs, k) - self.assertWarns(DeprecationWarning, socket.htons, k) def testGetServBy(self): eq = self.assertEqual @@ -1531,7 +1618,7 @@ def raising_handler(*args): if with_timeout: signal.signal(signal.SIGALRM, ok_handler) signal.alarm(1) - self.assertRaises(socket.timeout, c.sendall, + self.assertRaises(TimeoutError, c.sendall, b"x" * support.SOCK_MAX_SIZE) finally: signal.alarm(0) @@ -1603,7 +1690,8 @@ def test_makefile_mode(self): for mode in 'r', 'rb', 'rw', 'w', 'wb': with self.subTest(mode=mode): with socket.socket() as sock: - with sock.makefile(mode) as fp: + encoding = None if "b" in mode else "utf-8" + with sock.makefile(mode, encoding=encoding) as fp: self.assertEqual(fp.mode, mode) def test_makefile_invalid_mode(self): @@ -1662,6 +1750,7 @@ def test_getaddrinfo_ipv6_basic(self): @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.') @unittest.skipIf(sys.platform == 'win32', 'does not work on Windows') @unittest.skipIf(AIX, 'Symbolic scope id does not work') + @unittest.skipUnless(hasattr(socket, 'if_nameindex'), "test needs socket.if_nameindex()") def test_getaddrinfo_ipv6_scopeid_symbolic(self): # Just pick up any network interface (Linux, Mac OS X) (ifindex, test_interface) = socket.if_nameindex()[0] @@ -1694,6 +1783,7 @@ def test_getaddrinfo_ipv6_scopeid_numeric(self): @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.') @unittest.skipIf(sys.platform == 'win32', 'does not work on Windows') @unittest.skipIf(AIX, 'Symbolic scope id does not work') + @unittest.skipUnless(hasattr(socket, 'if_nameindex'), "test needs socket.if_nameindex()") def test_getnameinfo_ipv6_scopeid_symbolic(self): # Just pick up any network interface. (ifindex, test_interface) = socket.if_nameindex()[0] @@ -1827,11 +1917,11 @@ def test_socket_fileno(self): socket.SOCK_STREAM) def test_socket_fileno_rejects_float(self): - with self.assertRaisesRegex(TypeError, "integer argument expected"): + with self.assertRaises(TypeError): socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=42.5) def test_socket_fileno_rejects_other_types(self): - with self.assertRaisesRegex(TypeError, "integer is required"): + with self.assertRaises(TypeError): socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno="foo") def test_socket_fileno_rejects_invalid_socket(self): @@ -2093,6 +2183,68 @@ def testBind(self): raise +@unittest.skipUnless(HAVE_SOCKET_CAN_J1939, 'CAN J1939 required for this test.') +class J1939Test(unittest.TestCase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.interface = "vcan0" + + @unittest.skipUnless(hasattr(socket, "CAN_J1939"), + 'socket.CAN_J1939 required for this test.') + def testJ1939Constants(self): + socket.CAN_J1939 + + socket.J1939_MAX_UNICAST_ADDR + socket.J1939_IDLE_ADDR + socket.J1939_NO_ADDR + socket.J1939_NO_NAME + socket.J1939_PGN_REQUEST + socket.J1939_PGN_ADDRESS_CLAIMED + socket.J1939_PGN_ADDRESS_COMMANDED + socket.J1939_PGN_PDU1_MAX + socket.J1939_PGN_MAX + socket.J1939_NO_PGN + + # J1939 socket options + socket.SO_J1939_FILTER + socket.SO_J1939_PROMISC + socket.SO_J1939_SEND_PRIO + socket.SO_J1939_ERRQUEUE + + socket.SCM_J1939_DEST_ADDR + socket.SCM_J1939_DEST_NAME + socket.SCM_J1939_PRIO + socket.SCM_J1939_ERRQUEUE + + socket.J1939_NLA_PAD + socket.J1939_NLA_BYTES_ACKED + + socket.J1939_EE_INFO_NONE + socket.J1939_EE_INFO_TX_ABORT + + socket.J1939_FILTER_MAX + + @unittest.skipUnless(hasattr(socket, "CAN_J1939"), + 'socket.CAN_J1939 required for this test.') + def testCreateJ1939Socket(self): + with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939) as s: + pass + + def testBind(self): + try: + with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939) as s: + addr = self.interface, socket.J1939_NO_NAME, socket.J1939_NO_PGN, socket.J1939_NO_ADDR + s.bind(addr) + self.assertEqual(s.getsockname(), addr) + except OSError as e: + if e.errno == errno.ENODEV: + self.skipTest('network interface `%s` does not exist' % + self.interface) + else: + raise + + @unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.') class BasicRDSTest(unittest.TestCase): @@ -2252,6 +2404,45 @@ def testSocketBufferSize(self): socket.SO_VM_SOCKETS_BUFFER_MIN_SIZE)) +@unittest.skipUnless(HAVE_SOCKET_BLUETOOTH, + 'Bluetooth sockets required for this test.') +class BasicBluetoothTest(unittest.TestCase): + + def testBluetoothConstants(self): + socket.BDADDR_ANY + socket.BDADDR_LOCAL + socket.AF_BLUETOOTH + socket.BTPROTO_RFCOMM + + if sys.platform != "win32": + socket.BTPROTO_HCI + socket.SOL_HCI + socket.BTPROTO_L2CAP + + if not sys.platform.startswith("freebsd"): + socket.BTPROTO_SCO + + def testCreateRfcommSocket(self): + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_STREAM, socket.BTPROTO_RFCOMM) as s: + pass + + @unittest.skipIf(sys.platform == "win32", "windows does not support L2CAP sockets") + def testCreateL2capSocket(self): + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) as s: + pass + + @unittest.skipIf(sys.platform == "win32", "windows does not support HCI sockets") + def testCreateHciSocket(self): + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_RAW, socket.BTPROTO_HCI) as s: + pass + + @unittest.skipIf(sys.platform == "win32" or sys.platform.startswith("freebsd"), + "windows and freebsd do not support SCO sockets") + def testCreateScoSocket(self): + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_SCO) as s: + pass + + class BasicTCPTest(SocketConnectedTest): def __init__(self, methodName='runTest'): @@ -2403,6 +2594,37 @@ def testRecvFromNegative(self): def _testRecvFromNegative(self): self.cli.sendto(MSG, 0, (HOST, self.port)) + +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +class BasicUDPLITETest(ThreadedUDPLITESocketTest): + + def __init__(self, methodName='runTest'): + ThreadedUDPLITESocketTest.__init__(self, methodName=methodName) + + def testSendtoAndRecv(self): + # Testing sendto() and Recv() over UDPLITE + msg = self.serv.recv(len(MSG)) + self.assertEqual(msg, MSG) + + def _testSendtoAndRecv(self): + self.cli.sendto(MSG, 0, (HOST, self.port)) + + def testRecvFrom(self): + # Testing recvfrom() over UDPLITE + msg, addr = self.serv.recvfrom(len(MSG)) + self.assertEqual(msg, MSG) + + def _testRecvFrom(self): + self.cli.sendto(MSG, 0, (HOST, self.port)) + + def testRecvFromNegative(self): + # Negative lengths passed to recvfrom should give ValueError. + self.assertRaises(ValueError, self.serv.recvfrom, -1) + + def _testRecvFromNegative(self): + self.cli.sendto(MSG, 0, (HOST, self.port)) + # Tests for the sendmsg()/recvmsg() interface. Where possible, the # same test code is used with different families and types of socket # (e.g. stream, datagram), and tests using recvmsg() are repeated @@ -2767,7 +2989,7 @@ def _testSendmsgTimeout(self): try: while True: self.sendmsgToServer([b"a"*512]) - except socket.timeout: + except TimeoutError: pass except OSError as exc: if exc.errno != errno.ENOMEM: @@ -2775,7 +2997,7 @@ def _testSendmsgTimeout(self): # bpo-33937 the test randomly fails on Travis CI with # "OSError: [Errno 12] Cannot allocate memory" else: - self.fail("socket.timeout not raised") + self.fail("TimeoutError not raised") finally: self.misc_event.set() @@ -2910,7 +3132,7 @@ def testRecvmsgTimeout(self): # Check that timeout works. try: self.serv_sock.settimeout(0.03) - self.assertRaises(socket.timeout, + self.assertRaises(TimeoutError, self.doRecvmsg, self.serv_sock, len(MSG)) finally: self.misc_event.set() @@ -3930,7 +4152,7 @@ def _testSecondCmsgTruncLen0Minus1(self): @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT", "IPV6_RECVTCLASS", "IPV6_TCLASS") - def testSecomdCmsgTruncInData(self): + def testSecondCmsgTruncInData(self): # Test truncation of the second of two control messages inside # its associated data. self.serv_sock.setsockopt(socket.IPPROTO_IPV6, @@ -3965,8 +4187,8 @@ def testSecomdCmsgTruncInData(self): self.assertEqual(ancdata, []) - @testSecomdCmsgTruncInData.client_skip - def _testSecomdCmsgTruncInData(self): + @testSecondCmsgTruncInData.client_skip + def _testSecondCmsgTruncInData(self): self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) self.sendToServer(MSG) @@ -4036,6 +4258,89 @@ class RecvmsgIntoRFC3542AncillaryUDP6Test(RecvmsgIntoMixin, pass +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +class SendrecvmsgUDPLITETestBase(SendrecvmsgDgramFlagsBase, + SendrecvmsgConnectionlessBase, + ThreadedSocketTestMixin, UDPLITETestBase): + pass + +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +@requireAttrs(socket.socket, "sendmsg") +class SendmsgUDPLITETest(SendmsgConnectionlessTests, SendrecvmsgUDPLITETestBase): + pass + +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +@requireAttrs(socket.socket, "recvmsg") +class RecvmsgUDPLITETest(RecvmsgTests, SendrecvmsgUDPLITETestBase): + pass + +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +@requireAttrs(socket.socket, "recvmsg_into") +class RecvmsgIntoUDPLITETest(RecvmsgIntoTests, SendrecvmsgUDPLITETestBase): + pass + + +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +class SendrecvmsgUDPLITE6TestBase(SendrecvmsgDgramFlagsBase, + SendrecvmsgConnectionlessBase, + ThreadedSocketTestMixin, UDPLITE6TestBase): + + def checkRecvmsgAddress(self, addr1, addr2): + # Called to compare the received address with the address of + # the peer, ignoring scope ID + self.assertEqual(addr1[:-1], addr2[:-1]) + +@requireAttrs(socket.socket, "sendmsg") +@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.') +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +@requireSocket("AF_INET6", "SOCK_DGRAM") +class SendmsgUDPLITE6Test(SendmsgConnectionlessTests, SendrecvmsgUDPLITE6TestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.') +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +@requireSocket("AF_INET6", "SOCK_DGRAM") +class RecvmsgUDPLITE6Test(RecvmsgTests, SendrecvmsgUDPLITE6TestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.') +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +@requireSocket("AF_INET6", "SOCK_DGRAM") +class RecvmsgIntoUDPLITE6Test(RecvmsgIntoTests, SendrecvmsgUDPLITE6TestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.') +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +@requireAttrs(socket, "IPPROTO_IPV6") +@requireSocket("AF_INET6", "SOCK_DGRAM") +class RecvmsgRFC3542AncillaryUDPLITE6Test(RFC3542AncillaryTest, + SendrecvmsgUDPLITE6TestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.') +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +@requireAttrs(socket, "IPPROTO_IPV6") +@requireSocket("AF_INET6", "SOCK_DGRAM") +class RecvmsgIntoRFC3542AncillaryUDPLITE6Test(RecvmsgIntoMixin, + RFC3542AncillaryTest, + SendrecvmsgUDPLITE6TestBase): + pass + + class SendrecvmsgTCPTestBase(SendrecvmsgConnectedBase, ConnectedStreamTestMixin, TCPTestBase): pass @@ -4133,7 +4438,7 @@ class RecvmsgIntoSCMRightsStreamTest(RecvmsgIntoMixin, SCMRightsTest, # threads alive during the test so that the OS cannot deliver the # signal to the wrong one. -class InterruptedTimeoutBase(unittest.TestCase): +class InterruptedTimeoutBase: # Base class for interrupted send/receive tests. Installs an # empty handler for SIGALRM and removes it on teardown, along with # any scheduled alarms. @@ -4145,7 +4450,7 @@ def setUp(self): self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler) # Timeout for socket operations - timeout = 4.0 + timeout = support.LOOPBACK_TIMEOUT # Provide setAlarm() method to schedule delivery of SIGALRM after # given number of seconds, or cancel it if zero, and an @@ -4248,12 +4553,10 @@ def checkInterruptedSend(self, func, *args, **kwargs): self.setAlarm(0) # Issue #12958: The following tests have problems on OS X prior to 10.7 - @unittest.skip("TODO: RUSTPYTHON, plistlib") @support.requires_mac_ver(10, 7) def testInterruptedSendTimeout(self): self.checkInterruptedSend(self.serv_conn.send, b"a"*512) - @unittest.skip("TODO: RUSTPYTHON, plistlib") @support.requires_mac_ver(10, 7) def testInterruptedSendtoTimeout(self): # Passing an actual address here as Python's wrapper for @@ -4263,7 +4566,6 @@ def testInterruptedSendtoTimeout(self): self.checkInterruptedSend(self.serv_conn.sendto, b"a"*512, self.serv_addr) - @unittest.skip("TODO: RUSTPYTHON, plistlib") @support.requires_mac_ver(10, 7) @requireAttrs(socket.socket, "sendmsg") def testInterruptedSendmsgTimeout(self): @@ -4429,7 +4731,7 @@ def _testInheritFlagsTimeout(self): def testAccept(self): # Testing non-blocking accept - self.serv.setblocking(0) + self.serv.setblocking(False) # connect() didn't start: non-blocking accept() fails start_time = time.monotonic() @@ -4440,7 +4742,7 @@ def testAccept(self): self.event.set() - read, write, err = select.select([self.serv], [], [], MAIN_TIMEOUT) + read, write, err = select.select([self.serv], [], [], support.LONG_TIMEOUT) if self.serv not in read: self.fail("Error trying to do accept after select.") @@ -4460,7 +4762,7 @@ def testRecv(self): # Testing non-blocking recv conn, addr = self.serv.accept() self.addCleanup(conn.close) - conn.setblocking(0) + conn.setblocking(False) # the server didn't send data yet: non-blocking recv() fails with self.assertRaises(BlockingIOError): @@ -4468,7 +4770,7 @@ def testRecv(self): self.event.set() - read, write, err = select.select([conn], [], [], MAIN_TIMEOUT) + read, write, err = select.select([conn], [], [], support.LONG_TIMEOUT) if conn not in read: self.fail("Error during select call to non-blocking socket.") @@ -4550,7 +4852,7 @@ def testReadAfterTimeout(self): self.cli_conn.settimeout(1) self.read_file.read(3) # First read raises a timeout - self.assertRaises(socket.timeout, self.read_file.read, 1) + self.assertRaises(TimeoutError, self.read_file.read, 1) # Second read is disallowed with self.assertRaises(OSError) as ctx: self.read_file.read(1) @@ -4815,7 +5117,7 @@ class NetworkConnectionNoServer(unittest.TestCase): class MockSocket(socket.socket): def connect(self, *args): - raise socket.timeout('timed out') + raise TimeoutError('timed out') @contextlib.contextmanager def mocked_socket_module(self): @@ -4865,13 +5167,13 @@ def test_create_connection_timeout(self): with self.mocked_socket_module(): try: socket.create_connection((HOST, 1234)) - except socket.timeout: + except TimeoutError: pass except OSError as exc: if socket_helper.IPV6_ENABLED or exc.errno != errno.EAFNOSUPPORT: raise else: - self.fail('socket.timeout not raised') + self.fail('TimeoutError not raised') class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest): @@ -4894,14 +5196,16 @@ def _justAccept(self): testFamily = _justAccept def _testFamily(self): - self.cli = socket.create_connection((HOST, self.port), timeout=30) + self.cli = socket.create_connection((HOST, self.port), + timeout=support.LOOPBACK_TIMEOUT) self.addCleanup(self.cli.close) self.assertEqual(self.cli.family, 2) testSourceAddress = _justAccept def _testSourceAddress(self): - self.cli = socket.create_connection((HOST, self.port), timeout=30, - source_address=('', self.source_port)) + self.cli = socket.create_connection((HOST, self.port), + timeout=support.LOOPBACK_TIMEOUT, + source_address=('', self.source_port)) self.addCleanup(self.cli.close) self.assertEqual(self.cli.getsockname()[1], self.source_port) # The port number being used is sufficient to show that the bind() @@ -4971,7 +5275,7 @@ def _testInsideTimeout(self): def _testOutsideTimeout(self): self.cli = sock = socket.create_connection((HOST, self.port), timeout=1) - self.assertRaises(socket.timeout, lambda: sock.recv(5)) + self.assertRaises(TimeoutError, lambda: sock.recv(5)) class TCPTimeoutTest(SocketTCPTest): @@ -4980,7 +5284,7 @@ def testTCPTimeout(self): def raise_timeout(*args, **kwargs): self.serv.settimeout(1.0) self.serv.accept() - self.assertRaises(socket.timeout, raise_timeout, + self.assertRaises(TimeoutError, raise_timeout, "Error generating a timeout exception (TCP)") def testTimeoutZero(self): @@ -4988,7 +5292,7 @@ def testTimeoutZero(self): try: self.serv.settimeout(0.0) foo = self.serv.accept() - except socket.timeout: + except TimeoutError: self.fail("caught timeout instead of error (TCP)") except OSError: ok = True @@ -5013,7 +5317,7 @@ def alarm_handler(signal, frame): try: signal.alarm(2) # POSIX allows alarm to be up to 1 second early foo = self.serv.accept() - except socket.timeout: + except TimeoutError: self.fail("caught timeout instead of Alarm") except Alarm: pass @@ -5037,7 +5341,7 @@ def testUDPTimeout(self): def raise_timeout(*args, **kwargs): self.serv.settimeout(1.0) self.serv.recv(1024) - self.assertRaises(socket.timeout, raise_timeout, + self.assertRaises(TimeoutError, raise_timeout, "Error generating a timeout exception (UDP)") def testTimeoutZero(self): @@ -5045,7 +5349,7 @@ def testTimeoutZero(self): try: self.serv.settimeout(0.0) foo = self.serv.recv(1024) - except socket.timeout: + except TimeoutError: self.fail("caught timeout instead of error (UDP)") except OSError: ok = True @@ -5054,6 +5358,31 @@ def testTimeoutZero(self): if not ok: self.fail("recv() returned success when we did not expect it") +@unittest.skipUnless(HAVE_SOCKET_UDPLITE, + 'UDPLITE sockets required for this test.') +class UDPLITETimeoutTest(SocketUDPLITETest): + + def testUDPLITETimeout(self): + def raise_timeout(*args, **kwargs): + self.serv.settimeout(1.0) + self.serv.recv(1024) + self.assertRaises(TimeoutError, raise_timeout, + "Error generating a timeout exception (UDPLITE)") + + def testTimeoutZero(self): + ok = False + try: + self.serv.settimeout(0.0) + foo = self.serv.recv(1024) + except TimeoutError: + self.fail("caught timeout instead of error (UDPLITE)") + except OSError: + ok = True + except: + self.fail("caught unexpected exception (UDPLITE)") + if not ok: + self.fail("recv() returned success when we did not expect it") + class TestExceptions(unittest.TestCase): def testExceptionTree(self): @@ -5061,6 +5390,8 @@ def testExceptionTree(self): self.assertTrue(issubclass(socket.herror, OSError)) self.assertTrue(issubclass(socket.gaierror, OSError)) self.assertTrue(issubclass(socket.timeout, OSError)) + self.assertIs(socket.error, OSError) + self.assertIs(socket.timeout, TimeoutError) def test_setblocking_invalidfd(self): # Regression test for issue #28471 @@ -5117,6 +5448,20 @@ def testBytearrayName(self): s.bind(bytearray(b"\x00python\x00test\x00")) self.assertEqual(s.getsockname(), b"\x00python\x00test\x00") + def testAutobind(self): + # Check that binding to an empty string binds to an available address + # in the abstract namespace as specified in unix(7) "Autobind feature". + abstract_address = b"^\0[0-9a-f]{5}" + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s1: + s1.bind("") + self.assertRegex(s1.getsockname(), abstract_address) + # Each socket is bound to a different abstract address. + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s2: + s2.bind("") + self.assertRegex(s2.getsockname(), abstract_address) + self.assertNotEqual(s1.getsockname(), s2.getsockname()) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'test needs socket.AF_UNIX') class TestUnixDomain(unittest.TestCase): @@ -5172,7 +5517,7 @@ def testBytesAddr(self): def testSurrogateescapeBind(self): # Test binding to a valid non-ASCII pathname, with the # non-ASCII bytes supplied using surrogateescape encoding. - path = os.path.abspath(support.TESTFN_UNICODE) + path = os.path.abspath(os_helper.TESTFN_UNICODE) b = self.encoded(path) self.bind(self.sock, b.decode("ascii", "surrogateescape")) self.addCleanup(os_helper.unlink, path) @@ -5188,6 +5533,11 @@ def testUnencodableAddr(self): self.addCleanup(os_helper.unlink, path) self.assertEqual(self.sock.getsockname(), path) + @unittest.skipIf(sys.platform == 'linux', 'Linux specific test') + def testEmptyAddress(self): + # Test that binding empty address fails. + self.assertRaises(OSError, self.sock.bind, "") + class BufferIOTest(SocketConnectedTest): """ @@ -5285,7 +5635,7 @@ def isTipcAvailable(): if not hasattr(socket, "AF_TIPC"): return False try: - f = open("/proc/modules") + f = open("/proc/modules", encoding="utf-8") except (FileNotFoundError, IsADirectoryError, PermissionError): # It's ok if the file does not exist, is a directory or if we # have not the permission to read it. @@ -5511,15 +5861,15 @@ def test_SOCK_NONBLOCK(self): with socket.socket(socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_NONBLOCK) as s: self.checkNonblock(s) - s.setblocking(1) + s.setblocking(True) self.checkNonblock(s, nonblock=False) - s.setblocking(0) + s.setblocking(False) self.checkNonblock(s) s.settimeout(None) self.checkNonblock(s, nonblock=False) s.settimeout(2.0) self.checkNonblock(s, timeout=2.0) - s.setblocking(1) + s.setblocking(True) self.checkNonblock(s, nonblock=False) # defaulttimeout t = socket.getdefaulttimeout() @@ -5650,7 +6000,7 @@ class SendfileUsingSendTest(ThreadedTCPSocketTest): FILESIZE = (10 * 1024 * 1024) # 10 MiB BUFSIZE = 8192 FILEDATA = b"" - TIMEOUT = 2 + TIMEOUT = support.LOOPBACK_TIMEOUT @classmethod def setUpClass(cls): @@ -5676,7 +6026,7 @@ def tearDownClass(cls): os_helper.unlink(os_helper.TESTFN) def accept_conn(self): - self.serv.settimeout(MAIN_TIMEOUT) + self.serv.settimeout(support.LONG_TIMEOUT) conn, addr = self.serv.accept() conn.settimeout(self.TIMEOUT) self.addCleanup(conn.close) @@ -5772,7 +6122,9 @@ def testOffset(self): def _testCount(self): address = self.serv.getsockname() file = open(os_helper.TESTFN, 'rb') - with socket.create_connection(address, timeout=2) as sock, file as file: + sock = socket.create_connection(address, + timeout=support.LOOPBACK_TIMEOUT) + with sock, file: count = 5000007 meth = self.meth_from_sock(sock) sent = meth(file, count=count) @@ -5792,7 +6144,9 @@ def testCount(self): def _testCountSmall(self): address = self.serv.getsockname() file = open(os_helper.TESTFN, 'rb') - with socket.create_connection(address, timeout=2) as sock, file as file: + sock = socket.create_connection(address, + timeout=support.LOOPBACK_TIMEOUT) + with sock, file: count = 1 meth = self.meth_from_sock(sock) sent = meth(file, count=count) @@ -5846,12 +6200,13 @@ def testNonBlocking(self): def _testWithTimeout(self): address = self.serv.getsockname() file = open(os_helper.TESTFN, 'rb') - with socket.create_connection(address, timeout=2) as sock, file as file: + sock = socket.create_connection(address, + timeout=support.LOOPBACK_TIMEOUT) + with sock, file: meth = self.meth_from_sock(sock) sent = meth(file) self.assertEqual(sent, self.FILESIZE) - @unittest.skipIf(sys.platform == "darwin", "TODO: RUSTPYTHON, killed (for OOM?)") def testWithTimeout(self): conn = self.accept_conn() data = self.recv_data(conn) @@ -5866,11 +6221,12 @@ def _testWithTimeoutTriggeredSend(self): with socket.create_connection(address) as sock: sock.settimeout(0.01) meth = self.meth_from_sock(sock) - self.assertRaises(socket.timeout, meth, file) + self.assertRaises(TimeoutError, meth, file) def testWithTimeoutTriggeredSend(self): conn = self.accept_conn() conn.recv(88192) + time.sleep(1) # errors @@ -5883,7 +6239,7 @@ def test_errors(self): meth = self.meth_from_sock(s) self.assertRaisesRegex( ValueError, "SOCK_STREAM", meth, file) - with open(os_helper.TESTFN, 'rt') as file: + with open(os_helper.TESTFN, encoding="utf-8") as file: with socket.socket() as s: meth = self.meth_from_sock(s) self.assertRaisesRegex( @@ -6121,6 +6477,12 @@ def test_length_restriction(self): sock.bind(("type", "n" * 64)) +@unittest.skipUnless(sys.platform == 'darwin', 'macOS specific test') +class TestMacOSTCPFlags(unittest.TestCase): + def test_tcp_keepalive(self): + self.assertTrue(socket.TCP_KEEPALIVE) + + @unittest.skipUnless(sys.platform.startswith("win"), "requires Windows") class TestMSWindowsTCPFlags(unittest.TestCase): knownTCPFlags = { @@ -6196,14 +6558,7 @@ def test_dualstack_ipv6_family(self): class CreateServerFunctionalTest(unittest.TestCase): - timeout = 3 - - def setUp(self): - self.thread = None - - def tearDown(self): - if self.thread is not None: - self.thread.join(self.timeout) + timeout = support.LOOPBACK_TIMEOUT def echo_server(self, sock): def run(sock): @@ -6218,8 +6573,9 @@ def run(sock): event = threading.Event() sock.settimeout(self.timeout) - self.thread = threading.Thread(target=run, args=(sock, )) - self.thread.start() + thread = threading.Thread(target=run, args=(sock, )) + thread.start() + self.addCleanup(thread.join, self.timeout) event.set() def echo_client(self, addr, family): @@ -6265,73 +6621,52 @@ def test_dual_stack_client_v6(self): self.echo_server(sock) self.echo_client(("::1", port), socket.AF_INET6) +@requireAttrs(socket, "send_fds") +@requireAttrs(socket, "recv_fds") +@requireAttrs(socket, "AF_UNIX") +class SendRecvFdsTests(unittest.TestCase): + def testSendAndRecvFds(self): + def close_pipes(pipes): + for fd1, fd2 in pipes: + os.close(fd1) + os.close(fd2) + + def close_fds(fds): + for fd in fds: + os.close(fd) + + # send 10 file descriptors + pipes = [os.pipe() for _ in range(10)] + self.addCleanup(close_pipes, pipes) + fds = [rfd for rfd, wfd in pipes] + + # use a UNIX socket pair to exchange file descriptors locally + sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + with sock1, sock2: + socket.send_fds(sock1, [MSG], fds) + # request more data and file descriptors than expected + msg, fds2, flags, addr = socket.recv_fds(sock2, len(MSG) * 2, len(fds) * 2) + self.addCleanup(close_fds, fds2) + + self.assertEqual(msg, MSG) + self.assertEqual(len(fds2), len(fds)) + self.assertEqual(flags, 0) + # don't test addr -def test_main(): - tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, - TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, - UDPTimeoutTest, CreateServerTest, CreateServerFunctionalTest] - - tests.extend([ - NonBlockingTCPTests, - FileObjectClassTestCase, - UnbufferedFileObjectClassTestCase, - LineBufferedFileObjectClassTestCase, - SmallBufferedFileObjectClassTestCase, - UnicodeReadFileObjectClassTestCase, - UnicodeWriteFileObjectClassTestCase, - UnicodeReadWriteFileObjectClassTestCase, - NetworkConnectionNoServer, - NetworkConnectionAttributesTest, - NetworkConnectionBehaviourTest, - ContextManagersTest, - InheritanceTest, - NonblockConstantTest - ]) - tests.append(BasicSocketPairTest) - tests.append(TestUnixDomain) - tests.append(TestLinuxAbstractNamespace) - tests.extend([TIPCTest, TIPCThreadableTest]) - tests.extend([BasicCANTest, CANTest]) - tests.extend([BasicRDSTest, RDSTest]) - tests.append(LinuxKernelCryptoAPI) - tests.append(BasicQIPCRTRTest) - tests.extend([ - BasicVSOCKTest, - ThreadedVSOCKSocketStreamTest, - ]) - tests.extend([ - CmsgMacroTests, - SendmsgUDPTest, - RecvmsgUDPTest, - RecvmsgIntoUDPTest, - SendmsgUDP6Test, - RecvmsgUDP6Test, - RecvmsgRFC3542AncillaryUDP6Test, - RecvmsgIntoRFC3542AncillaryUDP6Test, - RecvmsgIntoUDP6Test, - SendmsgTCPTest, - RecvmsgTCPTest, - RecvmsgIntoTCPTest, - SendmsgSCTPStreamTest, - RecvmsgSCTPStreamTest, - RecvmsgIntoSCTPStreamTest, - SendmsgUnixStreamTest, - RecvmsgUnixStreamTest, - RecvmsgIntoUnixStreamTest, - RecvmsgSCMRightsStreamTest, - RecvmsgIntoSCMRightsStreamTest, - # These are slow when setitimer() is not available - InterruptedRecvTimeoutTest, - InterruptedSendTimeoutTest, - TestSocketSharing, - SendfileUsingSendTest, - SendfileUsingSendfileTest, - ]) - tests.append(TestMSWindowsTCPFlags) + # test that file descriptors are connected + for index, fds in enumerate(pipes): + rfd, wfd = fds + os.write(wfd, str(index).encode()) + for index, rfd in enumerate(fds2): + data = os.read(rfd, 100) + self.assertEqual(data, str(index).encode()) + + +def setUpModule(): thread_info = threading_helper.threading_setup() - support.run_unittest(*tests) - threading_helper.threading_cleanup(*thread_info) + unittest.addModuleCleanup(threading_helper.threading_cleanup, *thread_info) + if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 6af948e922..0d089edc9b 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -39,7 +39,7 @@ def signal_alarm(n): # Remember real select() to avoid interferences with mocking _real_select = select.select -def receive(sock, n, timeout=20): +def receive(sock, n, timeout=test.support.SHORT_TIMEOUT): r, w, x = _real_select([sock], [], [], timeout) if sock in r: return sock.recv(n) @@ -68,9 +68,7 @@ def simple_subprocess(testcase): except: raise finally: - pid2, status = os.waitpid(pid, 0) - testcase.assertEqual(pid2, pid) - testcase.assertEqual(72 << 8, status) + test.support.wait_process(pid, exitcode=72) class SocketServerTest(unittest.TestCase): @@ -345,8 +343,11 @@ def test_threading_handled(self): self.check_result(handled=True) def test_threading_not_handled(self): - ThreadingErrorTestServer(SystemExit) - self.check_result(handled=False) + with threading_helper.catch_threading_exception() as cm: + ThreadingErrorTestServer(SystemExit) + self.check_result(handled=False) + + self.assertIs(cm.exc_type, SystemExit) @requires_forking def test_forking_handled(self): From 19504a6b9dd6d00d64be545d2429494181bb7ec7 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Mon, 15 Aug 2022 00:11:01 +0900 Subject: [PATCH 15/23] mark failing tests again --- Lib/test/test_socket.py | 7 +++++++ Lib/test/test_socketserver.py | 2 -- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 80ba3d5f3e..d6cba6b8b3 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -958,6 +958,8 @@ def testWindowsSpecificConstants(self): socket.IPPROTO_L2TP socket.IPPROTO_SCTP + # TODO: RUSTPYTHON + @unittest.expectedFailure @unittest.skipUnless(sys.platform == 'darwin', 'macOS specific test') @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test') def test3542SocketOptions(self): @@ -2190,6 +2192,8 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.interface = "vcan0" + # TODO: RUSTPYTHON + @unittest.expectedFailure @unittest.skipUnless(hasattr(socket, "CAN_J1939"), 'socket.CAN_J1939 required for this test.') def testJ1939Constants(self): @@ -2231,6 +2235,8 @@ def testCreateJ1939Socket(self): with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939) as s: pass + # TODO: RUSTPYTHON + @unittest.expectedFailure def testBind(self): try: with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939) as s: @@ -6207,6 +6213,7 @@ def _testWithTimeout(self): sent = meth(file) self.assertEqual(sent, self.FILESIZE) + @unittest.skip("TODO: RUSTPYTHON") def testWithTimeout(self): conn = self.accept_conn() data = self.recv_data(conn) diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 0d089edc9b..92b0ad71c6 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -521,8 +521,6 @@ def shutdown_request(self, request): self.assertEqual(server.shutdown_called, 1) server.server_close() - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_threads_reaped(self): """ In #37193, users reported a memory leak From b6523ddb90eb7e4c8e5d74492d36b124ac624b8d Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Mon, 15 Aug 2022 00:17:17 +0900 Subject: [PATCH 16/23] Update secrets from CPython 3.10.6 --- Lib/secrets.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Lib/secrets.py b/Lib/secrets.py index 130434229e..a546efbdd4 100644 --- a/Lib/secrets.py +++ b/Lib/secrets.py @@ -14,7 +14,6 @@ import base64 import binascii -import os from hmac import compare_digest from random import SystemRandom @@ -44,7 +43,7 @@ def token_bytes(nbytes=None): """ if nbytes is None: nbytes = DEFAULT_ENTROPY - return os.urandom(nbytes) + return _sysrand.randbytes(nbytes) def token_hex(nbytes=None): """Return a random text string, in hexadecimal. From 56cc13f81d1585de6b40a5d8e26ebb37c4078585 Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Mon, 15 Aug 2022 00:20:01 +0900 Subject: [PATCH 17/23] Update test_scope from CPython 3.10.6 --- Lib/test/test_scope.py | 44 +++++++++++++++--------------------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/Lib/test/test_scope.py b/Lib/test/test_scope.py index be868e50e2..29a4ac3c16 100644 --- a/Lib/test/test_scope.py +++ b/Lib/test/test_scope.py @@ -2,11 +2,11 @@ import weakref from test.support import check_syntax_error, cpython_only +from test.support import gc_collect class ScopeTests(unittest.TestCase): - def testSimpleNesting(self): def make_adder(x): @@ -20,7 +20,6 @@ def adder(y): self.assertEqual(inc(1), 2) self.assertEqual(plus10(-2), 8) - def testExtraNesting(self): def make_adder2(x): @@ -36,7 +35,6 @@ def adder(y): self.assertEqual(inc(1), 2) self.assertEqual(plus10(-2), 8) - def testSimpleAndRebinding(self): def make_adder3(x): @@ -50,8 +48,7 @@ def adder(y): self.assertEqual(inc(1), 2) self.assertEqual(plus10(-2), 8) - - + def testNestingGlobalNoFree(self): def make_adder4(): # XXX add exta level of indirection @@ -70,7 +67,6 @@ def adder(y): global_x = 10 self.assertEqual(adder(-2), 8) - def testNestingThroughClass(self): def make_adder5(x): @@ -85,7 +81,6 @@ def __call__(self, y): self.assertEqual(inc(1), 2) self.assertEqual(plus10(-2), 8) - def testNestingPlusFreeRefToGlobal(self): def make_adder6(x): @@ -101,7 +96,6 @@ def adder(y): self.assertEqual(inc(1), 11) # there's only one global self.assertEqual(plus10(-2), 8) - def testNearestEnclosingScope(self): def f(x): @@ -115,7 +109,6 @@ def h(z): test_func = f(10) self.assertEqual(test_func(5), 47) - def testMixedFreevarsAndCellvars(self): def identity(x): @@ -173,7 +166,6 @@ def str(self): self.assertEqual(t.method_and_var(), "method") self.assertEqual(t.actual_global(), "global") - def testCellIsKwonlyArg(self): # Issue 1409: Initialisation of a cell value, # when it comes from a keyword-only parameter @@ -185,7 +177,6 @@ def bar(): self.assertEqual(foo(a=42), 50) self.assertEqual(foo(), 25) - def testRecursion(self): def f(x): @@ -201,6 +192,7 @@ def fact(n): self.assertEqual(f(6), 720) + def testUnoptimizedNamespaces(self): check_syntax_error(self, """if 1: @@ -235,7 +227,6 @@ def g(): return getrefcount # global or local? """) - def testLambdas(self): f1 = lambda x: lambda y: x + y @@ -312,8 +303,6 @@ def f(): fail('scope of global_x not correctly determined') """, {'fail': self.fail}) - - def testComplexDefinitions(self): def makeReturner(*lst): @@ -348,6 +337,7 @@ def h(): return g() self.assertEqual(f(), 7) self.assertEqual(x, 7) + # II x = 7 def f(): @@ -362,6 +352,7 @@ def h(): return g() self.assertEqual(f(), 2) self.assertEqual(x, 7) + # III x = 7 def f(): @@ -377,6 +368,7 @@ def h(): return g() self.assertEqual(f(), 2) self.assertEqual(x, 2) + # IV x = 7 def f(): @@ -392,8 +384,10 @@ def h(): return g() self.assertEqual(f(), 2) self.assertEqual(x, 2) + # XXX what about global statements in class blocks? # do they affect methods? + x = 12 class Global: global x @@ -402,6 +396,7 @@ def set(self, val): x = val def get(self): return x + g = Global() self.assertEqual(g.get(), 13) g.set(15) @@ -428,6 +423,7 @@ def f2(): for i in range(100): f1() + gc_collect() # For PyPy or other GCs. self.assertEqual(Foo.count, 0) def testClassAndGlobal(self): @@ -439,16 +435,19 @@ class Foo: def __call__(self, y): return x + y return Foo() + x = 0 self.assertEqual(test(6)(2), 8) x = -1 self.assertEqual(test(3)(2), 5) + looked_up_by_load_name = False class X: # Implicit globals inside classes are be looked up by LOAD_NAME, not # LOAD_GLOBAL. locals()['looked_up_by_load_name'] = True passed = looked_up_by_load_name + self.assertTrue(X.passed) """) @@ -468,7 +467,6 @@ def h(z): del d['h'] self.assertEqual(d, {'x': 2, 'y': 7, 'w': 6}) - def testLocalsClass(self): # This test verifies that calling locals() does not pollute # the local namespace of the class with free variables. Old @@ -502,7 +500,6 @@ def m(self): self.assertNotIn("x", varnames) self.assertIn("y", varnames) - @cpython_only def testLocalsClass_WithTrace(self): # Issue23728: after the trace function returns, the locals() @@ -520,7 +517,6 @@ def f(self): self.assertEqual(x, 12) # Used to raise UnboundLocalError - def testBoundAndFree(self): # var is bound and free in class @@ -534,7 +530,6 @@ def m(self): inst = f(3)() self.assertEqual(inst.a, inst.m()) - @cpython_only def testInteractionWithTraceFunc(self): @@ -574,7 +569,6 @@ def f(x): else: self.fail("exec should have failed, because code contained free vars") - def testListCompLocalVars(self): try: @@ -603,7 +597,6 @@ def g(): f(4)() - def testFreeingCell(self): # Test what happens when a finalizer accesses # the cell where the object was stored. @@ -611,7 +604,6 @@ class Special: def __del__(self): nestedcell_get() - def testNonLocalFunction(self): def f(x): @@ -631,7 +623,6 @@ def dec(): self.assertEqual(dec(), 1) self.assertEqual(dec(), 0) - def testNonLocalMethod(self): def f(x): class c: @@ -678,7 +669,6 @@ def h(): self.assertEqual(2, global_ns["result2"]) self.assertEqual(9, global_ns["result9"]) - def testNonLocalClass(self): def f(x): @@ -693,7 +683,7 @@ def get(self): self.assertEqual(c.get(), 1) self.assertNotIn("x", c.__class__.__dict__) - + def testNonLocalGenerator(self): def f(x): @@ -707,7 +697,6 @@ def g(y): g = f(0) self.assertEqual(list(g(5)), [1, 2, 3, 4, 5]) - def testNestedNonLocal(self): def f(x): @@ -725,7 +714,6 @@ def h(): h = g() self.assertEqual(h(), 3) - def testTopIsNotSignificant(self): # See #9997. def top(a): @@ -746,7 +734,6 @@ class X: self.assertFalse(hasattr(X, "x")) self.assertEqual(x, 42) - @cpython_only def testCellLeak(self): # Issue 17927. @@ -773,8 +760,9 @@ def dig(self): tester.dig() ref = weakref.ref(tester) del tester + gc_collect() # For PyPy or other GCs. self.assertIsNone(ref()) if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() From 95e16a94fa627c24add1ee0b81f3687be74d36bf Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Mon, 15 Aug 2022 00:21:08 +0900 Subject: [PATCH 18/23] Update sched from CPython 3.10.6 --- Lib/sched.py | 20 ++++++++++---------- Lib/test/test_sched.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/Lib/sched.py b/Lib/sched.py index ff87874a3a..14613cf298 100644 --- a/Lib/sched.py +++ b/Lib/sched.py @@ -26,23 +26,19 @@ import time import heapq from collections import namedtuple +from itertools import count import threading from time import monotonic as _time __all__ = ["scheduler"] -class Event(namedtuple('Event', 'time, priority, action, argument, kwargs')): - __slots__ = [] - def __eq__(s, o): return (s.time, s.priority) == (o.time, o.priority) - def __lt__(s, o): return (s.time, s.priority) < (o.time, o.priority) - def __le__(s, o): return (s.time, s.priority) <= (o.time, o.priority) - def __gt__(s, o): return (s.time, s.priority) > (o.time, o.priority) - def __ge__(s, o): return (s.time, s.priority) >= (o.time, o.priority) - +Event = namedtuple('Event', 'time, priority, sequence, action, argument, kwargs') Event.time.__doc__ = ('''Numeric type compatible with the return value of the timefunc function passed to the constructor.''') Event.priority.__doc__ = ('''Events scheduled for the same time will be executed in the order of their priority.''') +Event.sequence.__doc__ = ('''A continually increasing sequence number that + separates events if time and priority are equal.''') Event.action.__doc__ = ('''Executing the event means executing action(*argument, **kwargs)''') Event.argument.__doc__ = ('''argument is a sequence holding the positional @@ -61,6 +57,7 @@ def __init__(self, timefunc=_time, delayfunc=time.sleep): self._lock = threading.RLock() self.timefunc = timefunc self.delayfunc = delayfunc + self._sequence_generator = count() def enterabs(self, time, priority, action, argument=(), kwargs=_sentinel): """Enter a new event in the queue at an absolute time. @@ -71,8 +68,10 @@ def enterabs(self, time, priority, action, argument=(), kwargs=_sentinel): """ if kwargs is _sentinel: kwargs = {} - event = Event(time, priority, action, argument, kwargs) + with self._lock: + event = Event(time, priority, next(self._sequence_generator), + action, argument, kwargs) heapq.heappush(self._queue, event) return event # The ID @@ -136,7 +135,8 @@ def run(self, blocking=True): with lock: if not q: break - time, priority, action, argument, kwargs = q[0] + (time, priority, sequence, action, + argument, kwargs) = q[0] now = timefunc() if time > now: delay = True diff --git a/Lib/test/test_sched.py b/Lib/test/test_sched.py index 491d7b3a74..7ae7baae85 100644 --- a/Lib/test/test_sched.py +++ b/Lib/test/test_sched.py @@ -142,6 +142,17 @@ def test_cancel_concurrent(self): self.assertTrue(q.empty()) self.assertEqual(timer.time(), 4) + def test_cancel_correct_event(self): + # bpo-19270 + events = [] + scheduler = sched.scheduler() + scheduler.enterabs(1, 1, events.append, ("a",)) + b = scheduler.enterabs(1, 1, events.append, ("b",)) + scheduler.enterabs(1, 1, events.append, ("c",)) + scheduler.cancel(b) + scheduler.run() + self.assertEqual(events, ["a", "c"]) + def test_empty(self): l = [] fun = lambda x: l.append(x) From 628c4c0edf7a76d03749f6dd4c772e40bd3803cf Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Mon, 15 Aug 2022 00:22:40 +0900 Subject: [PATCH 19/23] Update runpy from CPython 3.10.6 --- Lib/runpy.py | 49 +++++++++++++++++++++++++------- Lib/test/test_cmd_line_script.py | 2 -- Lib/test/test_runpy.py | 7 ++--- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/Lib/runpy.py b/Lib/runpy.py index d86f0e4a3c..c7d3d8caad 100644 --- a/Lib/runpy.py +++ b/Lib/runpy.py @@ -13,8 +13,9 @@ import sys import importlib.machinery # importlib first so we can test #15386 via -m import importlib.util +import io import types -from pkgutil import read_code, get_importer +import os __all__ = [ "run_module", "run_path", @@ -131,6 +132,9 @@ def _get_module_details(mod_name, error=ImportError): # importlib, where the latter raises other errors for cases where # pkgutil previously raised ImportError msg = "Error while finding module specification for {!r} ({}: {})" + if mod_name.endswith(".py"): + msg += (f". Try using '{mod_name[:-3]}' instead of " + f"'{mod_name}' as the module name.") raise error(msg.format(mod_name, type(ex).__name__, ex)) from ex if spec is None: raise error("No module named %s" % mod_name) @@ -194,9 +198,24 @@ def _run_module_as_main(mod_name, alter_argv=True): def run_module(mod_name, init_globals=None, run_name=None, alter_sys=False): - """Execute a module's code without importing it + """Execute a module's code without importing it. - Returns the resulting top level namespace dictionary + mod_name -- an absolute module name or package name. + + Optional arguments: + init_globals -- dictionary used to pre-populate the module’s + globals dictionary before the code is executed. + + run_name -- if not None, this will be used for setting __name__; + otherwise, __name__ will be set to mod_name + '__main__' if the + named module is a package and to just mod_name otherwise. + + alter_sys -- if True, sys.argv[0] is updated with the value of + __file__ and sys.modules[__name__] is updated with a temporary + module object for the module being executed. Both are + restored to their original values before the function returns. + + Returns the resulting module globals dictionary. """ mod_name, mod_spec, code = _get_module_details(mod_name) if run_name is None: @@ -228,27 +247,35 @@ def _get_main_module_details(error=ImportError): def _get_code_from_file(run_name, fname): # Check for a compiled file first - with open(fname, "rb") as f: + from pkgutil import read_code + decoded_path = os.path.abspath(os.fsdecode(fname)) + with io.open_code(decoded_path) as f: code = read_code(f) if code is None: # That didn't work, so try it as normal source code - with open(fname, "rb") as f: + with io.open_code(decoded_path) as f: code = compile(f.read(), fname, 'exec') return code, fname def run_path(path_name, init_globals=None, run_name=None): - """Execute code located at the specified filesystem location + """Execute code located at the specified filesystem location. + + path_name -- filesystem location of a Python script, zipfile, + or directory containing a top level __main__.py script. + + Optional arguments: + init_globals -- dictionary used to pre-populate the module’s + globals dictionary before the code is executed. - Returns the resulting top level namespace dictionary + run_name -- if not None, this will be used to set __name__; + otherwise, '' will be used for __name__. - The file path may refer directly to a Python script (i.e. - one that could be directly executed with execfile) or else - it may refer to a zipfile or directory containing a top - level __main__.py script. + Returns the resulting module globals dictionary. """ if run_name is None: run_name = "" pkg_name = run_name.rpartition(".")[0] + from pkgutil import get_importer importer = get_importer(path_name) # Trying to avoid importing imp so as to not consume the deprecation warning. is_NullImporter = False diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py index 88845d26b9..8b9eb7ffd9 100644 --- a/Lib/test/test_cmd_line_script.py +++ b/Lib/test/test_cmd_line_script.py @@ -523,8 +523,6 @@ def test_dash_m_bad_pyc(self): self.assertNotIn(b'is a package', err) self.assertNotIn(b'Traceback', err) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_hint_when_triying_to_import_a_py_file(self): with os_helper.temp_dir() as script_dir, \ os_helper.change_cwd(path=script_dir): diff --git a/Lib/test/test_runpy.py b/Lib/test/test_runpy.py index 4bec101b07..c93f0f7ec1 100644 --- a/Lib/test/test_runpy.py +++ b/Lib/test/test_runpy.py @@ -12,10 +12,10 @@ import textwrap import unittest import warnings -from test.support import verbose, no_tracing +from test.support import no_tracing, verbose +from test.support.import_helper import forget, make_legacy_pyc, unload from test.support.os_helper import create_empty_file, temp_dir from test.support.script_helper import make_script, make_zip_script -from test.support.import_helper import unload, forget, make_legacy_pyc import runpy @@ -745,8 +745,7 @@ def test_main_recursion_error(self): "runpy.run_path(%r)\n") % dummy_dir script_name = self._make_test_script(script_dir, mod_name, source) zip_name, fname = make_zip_script(script_dir, 'test_zip', script_name) - msg = "recursion depth exceeded" - self.assertRaisesRegex(RecursionError, msg, run_path, zip_name) + self.assertRaises(RecursionError, run_path, zip_name) # TODO: RUSTPYTHON, detect encoding comments in files @unittest.expectedFailure From 2eec9194166266c9a87066aad5a9178380dc8067 Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Sun, 14 Aug 2022 23:21:01 +0900 Subject: [PATCH 20/23] Update xml and test_xml_etree --- Lib/test/test_xml_etree.py | 102 +++++++++++++++----- Lib/xml/dom/expatbuilder.py | 8 +- Lib/xml/dom/minidom.py | 56 ++++++++--- Lib/xml/dom/pulldom.py | 7 ++ Lib/xml/dom/xmlbuilder.py | 25 +---- Lib/xml/etree/ElementInclude.py | 58 ++++++++++-- Lib/xml/etree/ElementPath.py | 34 +++++-- Lib/xml/etree/ElementTree.py | 160 ++++++++++++++++++-------------- Lib/xml/etree/__init__.py | 2 +- Lib/xml/sax/__init__.py | 10 +- Lib/xml/sax/expatreader.py | 14 +-- Lib/xml/sax/handler.py | 45 +++++++++ Lib/xml/sax/saxutils.py | 5 +- 13 files changed, 358 insertions(+), 168 deletions(-) diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index 4d67101218..b721f028a1 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -10,7 +10,6 @@ import html import io import itertools -import locale import operator import os import pickle @@ -130,6 +129,9 @@ def newtest(*args, **kwargs): return newtest return decorator +def convlinesep(data): + return data.replace(b'\n', os.linesep.encode()) + class ModuleTest(unittest.TestCase): def test_sanity(self): @@ -1023,17 +1025,15 @@ def test_tostring_xml_declaration(self): @unittest.expectedFailure def test_tostring_xml_declaration_unicode_encoding(self): elem = ET.XML('') - preferredencoding = locale.getpreferredencoding() self.assertEqual( - f"\n", - ET.tostring(elem, encoding='unicode', xml_declaration=True) + ET.tostring(elem, encoding='unicode', xml_declaration=True), + "\n" ) # TODO: RUSTPYTHON @unittest.expectedFailure def test_tostring_xml_declaration_cases(self): elem = ET.XML('ø') - preferredencoding = locale.getpreferredencoding() TESTCASES = [ # (expected_retval, encoding, xml_declaration) # ... xml_declaration = None @@ -1060,7 +1060,7 @@ def test_tostring_xml_declaration_cases(self): b"ø", 'US-ASCII', True), (b"\n" b"\xf8", 'ISO-8859-1', True), - (f"\n" + ("\n" "ø", 'unicode', True), ] @@ -1102,11 +1102,10 @@ def test_tostringlist_xml_declaration(self): b"\n" ) - preferredencoding = locale.getpreferredencoding() stringlist = ET.tostringlist(elem, encoding='unicode', xml_declaration=True) self.assertEqual( ''.join(stringlist), - f"\n" + "\n" ) self.assertRegex(stringlist[0], r"^<\?xml version='1.0' encoding='.+'?>") self.assertEqual(['', '', ''], stringlist[1:]) @@ -3914,38 +3913,91 @@ def test_encoding(self): @unittest.expectedFailure def test_write_to_filename(self): self.addCleanup(os_helper.unlink, TESTFN) - tree = ET.ElementTree(ET.XML('''''')) + tree = ET.ElementTree(ET.XML('''\xf8''')) tree.write(TESTFN) with open(TESTFN, 'rb') as f: - self.assertEqual(f.read(), b'''''') + self.assertEqual(f.read(), b'''ø''') + + def test_write_to_filename_with_encoding(self): + self.addCleanup(os_helper.unlink, TESTFN) + tree = ET.ElementTree(ET.XML('''\xf8''')) + tree.write(TESTFN, encoding='utf-8') + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''\xc3\xb8''') + + tree.write(TESTFN, encoding='ISO-8859-1') + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), convlinesep( + b'''\n''' + b'''\xf8''')) + + def test_write_to_filename_as_unicode(self): + self.addCleanup(os_helper.unlink, TESTFN) + with open(TESTFN, 'w') as f: + encoding = f.encoding + os_helper.unlink(TESTFN) + + tree = ET.ElementTree(ET.XML('''\xf8''')) + tree.write(TESTFN, encoding='unicode') + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b"\xc3\xb8") # TODO: RUSTPYTHON @unittest.expectedFailure def test_write_to_text_file(self): self.addCleanup(os_helper.unlink, TESTFN) - tree = ET.ElementTree(ET.XML('''''')) + tree = ET.ElementTree(ET.XML('''\xf8''')) with open(TESTFN, 'w', encoding='utf-8') as f: tree.write(f, encoding='unicode') self.assertFalse(f.closed) with open(TESTFN, 'rb') as f: - self.assertEqual(f.read(), b'''''') + self.assertEqual(f.read(), b'''\xc3\xb8''') + + with open(TESTFN, 'w', encoding='ascii', errors='xmlcharrefreplace') as f: + tree.write(f, encoding='unicode') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''ø''') + + with open(TESTFN, 'w', encoding='ISO-8859-1') as f: + tree.write(f, encoding='unicode') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''\xf8''') # TODO: RUSTPYTHON @unittest.expectedFailure def test_write_to_binary_file(self): self.addCleanup(os_helper.unlink, TESTFN) - tree = ET.ElementTree(ET.XML('''''')) + tree = ET.ElementTree(ET.XML('''\xf8''')) with open(TESTFN, 'wb') as f: tree.write(f) self.assertFalse(f.closed) with open(TESTFN, 'rb') as f: - self.assertEqual(f.read(), b'''''') + self.assertEqual(f.read(), b'''ø''') + + def test_write_to_binary_file_with_encoding(self): + self.addCleanup(os_helper.unlink, TESTFN) + tree = ET.ElementTree(ET.XML('''\xf8''')) + with open(TESTFN, 'wb') as f: + tree.write(f, encoding='utf-8') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''\xc3\xb8''') + + with open(TESTFN, 'wb') as f: + tree.write(f, encoding='ISO-8859-1') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), + b'''\n''' + b'''\xf8''') # TODO: RUSTPYTHON @unittest.expectedFailure def test_write_to_binary_file_with_bom(self): self.addCleanup(os_helper.unlink, TESTFN) - tree = ET.ElementTree(ET.XML('''''')) + tree = ET.ElementTree(ET.XML('''\xf8''')) # test BOM writing to buffered file with open(TESTFN, 'wb') as f: tree.write(f, encoding='utf-16') @@ -3953,7 +4005,7 @@ def test_write_to_binary_file_with_bom(self): with open(TESTFN, 'rb') as f: self.assertEqual(f.read(), '''\n''' - ''''''.encode("utf-16")) + '''\xf8'''.encode("utf-16")) # test BOM writing to non-buffered file with open(TESTFN, 'wb', buffering=0) as f: tree.write(f, encoding='utf-16') @@ -3961,7 +4013,7 @@ def test_write_to_binary_file_with_bom(self): with open(TESTFN, 'rb') as f: self.assertEqual(f.read(), '''\n''' - ''''''.encode("utf-16")) + '''\xf8'''.encode("utf-16")) # TODO: RUSTPYTHON @unittest.expectedFailure @@ -3974,10 +4026,10 @@ def test_read_from_stringio(self): # TODO: RUSTPYTHON @unittest.expectedFailure def test_write_to_stringio(self): - tree = ET.ElementTree(ET.XML('''''')) + tree = ET.ElementTree(ET.XML('''\xf8''')) stream = io.StringIO() tree.write(stream, encoding='unicode') - self.assertEqual(stream.getvalue(), '''''') + self.assertEqual(stream.getvalue(), '''\xf8''') # TODO: RUSTPYTHON @unittest.expectedFailure @@ -3990,10 +4042,10 @@ def test_read_from_bytesio(self): # TODO: RUSTPYTHON @unittest.expectedFailure def test_write_to_bytesio(self): - tree = ET.ElementTree(ET.XML('''''')) + tree = ET.ElementTree(ET.XML('''\xf8''')) raw = io.BytesIO() tree.write(raw) - self.assertEqual(raw.getvalue(), b'''''') + self.assertEqual(raw.getvalue(), b'''ø''') class dummy: pass @@ -4011,12 +4063,12 @@ def test_read_from_user_text_reader(self): # TODO: RUSTPYTHON @unittest.expectedFailure def test_write_to_user_text_writer(self): - tree = ET.ElementTree(ET.XML('''''')) + tree = ET.ElementTree(ET.XML('''\xf8''')) stream = io.StringIO() writer = self.dummy() writer.write = stream.write tree.write(writer, encoding='unicode') - self.assertEqual(stream.getvalue(), '''''') + self.assertEqual(stream.getvalue(), '''\xf8''') # TODO: RUSTPYTHON @unittest.expectedFailure @@ -4032,12 +4084,12 @@ def test_read_from_user_binary_reader(self): # TODO: RUSTPYTHON @unittest.expectedFailure def test_write_to_user_binary_writer(self): - tree = ET.ElementTree(ET.XML('''''')) + tree = ET.ElementTree(ET.XML('''\xf8''')) raw = io.BytesIO() writer = self.dummy() writer.write = raw.write tree.write(writer) - self.assertEqual(raw.getvalue(), b'''''') + self.assertEqual(raw.getvalue(), b'''ø''') # TODO: RUSTPYTHON @unittest.expectedFailure diff --git a/Lib/xml/dom/expatbuilder.py b/Lib/xml/dom/expatbuilder.py index 2bd835b035..199c22d0af 100644 --- a/Lib/xml/dom/expatbuilder.py +++ b/Lib/xml/dom/expatbuilder.py @@ -204,11 +204,11 @@ def parseFile(self, file): buffer = file.read(16*1024) if not buffer: break - parser.Parse(buffer, 0) + parser.Parse(buffer, False) if first_buffer and self.document.documentElement: self._setup_subset(buffer) first_buffer = False - parser.Parse("", True) + parser.Parse(b"", True) except ParseEscape: pass doc = self.document @@ -637,7 +637,7 @@ def parseString(self, string): nsattrs = self._getNSattrs() # get ns decls from node's ancestors document = _FRAGMENT_BUILDER_TEMPLATE % (ident, subset, nsattrs) try: - parser.Parse(document, 1) + parser.Parse(document, True) except: self.reset() raise @@ -697,7 +697,7 @@ def external_entity_ref_handler(self, context, base, systemId, publicId): self.fragment = self.document.createDocumentFragment() self.curNode = self.fragment try: - parser.Parse(self._source, 1) + parser.Parse(self._source, True) finally: self.curNode = old_cur_node self.document = old_document diff --git a/Lib/xml/dom/minidom.py b/Lib/xml/dom/minidom.py index 24957ea14f..d09ef5e7d0 100644 --- a/Lib/xml/dom/minidom.py +++ b/Lib/xml/dom/minidom.py @@ -43,10 +43,11 @@ class Node(xml.dom.Node): def __bool__(self): return True - def toxml(self, encoding=None): - return self.toprettyxml("", "", encoding) + def toxml(self, encoding=None, standalone=None): + return self.toprettyxml("", "", encoding, standalone) - def toprettyxml(self, indent="\t", newl="\n", encoding=None): + def toprettyxml(self, indent="\t", newl="\n", encoding=None, + standalone=None): if encoding is None: writer = io.StringIO() else: @@ -56,7 +57,7 @@ def toprettyxml(self, indent="\t", newl="\n", encoding=None): newline='\n') if self.nodeType == Node.DOCUMENT_NODE: # Can pass encoding only to document, to put it into XML header - self.writexml(writer, "", indent, newl, encoding) + self.writexml(writer, "", indent, newl, encoding, standalone) else: self.writexml(writer, "", indent, newl) if encoding is None: @@ -718,6 +719,14 @@ def unlink(self): Node.unlink(self) def getAttribute(self, attname): + """Returns the value of the specified attribute. + + Returns the value of the element's attribute named attname as + a string. An empty string is returned if the element does not + have such an attribute. Note that an empty string may also be + returned as an explicitly given attribute value, use the + hasAttribute method to distinguish these two cases. + """ if self._attrs is None: return "" try: @@ -823,10 +832,16 @@ def removeAttributeNode(self, node): # Restore this since the node is still useful and otherwise # unlinked node.ownerDocument = self.ownerDocument + return node removeAttributeNodeNS = removeAttributeNode def hasAttribute(self, name): + """Checks whether the element has an attribute with the specified name. + + Returns True if the element has an attribute with the specified name. + Otherwise, returns False. + """ if self._attrs is None: return False return name in self._attrs @@ -837,6 +852,11 @@ def hasAttributeNS(self, namespaceURI, localName): return (namespaceURI, localName) in self._attrsNS def getElementsByTagName(self, name): + """Returns all descendant elements with the given tag name. + + Returns the list of all descendant elements (not direct children + only) with the specified tag name. + """ return _get_elements_by_tagName_helper(self, name, NodeList()) def getElementsByTagNameNS(self, namespaceURI, localName): @@ -847,22 +867,27 @@ def __repr__(self): return "" % (self.tagName, id(self)) def writexml(self, writer, indent="", addindent="", newl=""): + """Write an XML element to a file-like object + + Write the element to the writer object that must provide + a write method (e.g. a file or StringIO object). + """ # indent = current indentation # addindent = indentation to add to higher levels # newl = newline string writer.write(indent+"<" + self.tagName) attrs = self._get_attributes() - a_names = sorted(attrs.keys()) - for a_name in a_names: + for a_name in attrs.keys(): writer.write(" %s=\"" % a_name) _write_data(writer, attrs[a_name].value) writer.write("\"") if self.childNodes: writer.write(">") if (len(self.childNodes) == 1 and - self.childNodes[0].nodeType == Node.TEXT_NODE): + self.childNodes[0].nodeType in ( + Node.TEXT_NODE, Node.CDATA_SECTION_NODE)): self.childNodes[0].writexml(writer, '', '', '') else: writer.write(newl) @@ -1786,12 +1811,17 @@ def importNode(self, node, deep): raise xml.dom.NotSupportedErr("cannot import document type nodes") return _clone_node(node, deep, self) - def writexml(self, writer, indent="", addindent="", newl="", encoding=None): - if encoding is None: - writer.write(''+newl) - else: - writer.write('%s' % ( - encoding, newl)) + def writexml(self, writer, indent="", addindent="", newl="", encoding=None, + standalone=None): + declarations = [] + + if encoding: + declarations.append(f'encoding="{encoding}"') + if standalone is not None: + declarations.append(f'standalone="{"yes" if standalone else "no"}"') + + writer.write(f'{newl}') + for node in self.childNodes: node.writexml(writer, indent, addindent, newl) diff --git a/Lib/xml/dom/pulldom.py b/Lib/xml/dom/pulldom.py index 43504f7656..96a8d59519 100644 --- a/Lib/xml/dom/pulldom.py +++ b/Lib/xml/dom/pulldom.py @@ -217,6 +217,13 @@ def reset(self): self.parser.setContentHandler(self.pulldom) def __getitem__(self, pos): + import warnings + warnings.warn( + "DOMEventStream's __getitem__ method ignores 'pos' parameter. " + "Use iterator protocol instead.", + DeprecationWarning, + stacklevel=2 + ) rc = self.getEvent() if rc: return rc diff --git a/Lib/xml/dom/xmlbuilder.py b/Lib/xml/dom/xmlbuilder.py index e9a1536472..8a20026349 100644 --- a/Lib/xml/dom/xmlbuilder.py +++ b/Lib/xml/dom/xmlbuilder.py @@ -1,7 +1,6 @@ """Implementation of the DOM Level 3 'LS-Load' feature.""" import copy -import warnings import xml.dom from xml.dom.NodeFilter import NodeFilter @@ -80,7 +79,7 @@ def setFeature(self, name, state): settings = self._settings[(_name_xform(name), state)] except KeyError: raise xml.dom.NotSupportedErr( - "unsupported feature: %r" % (name,)) + "unsupported feature: %r" % (name,)) from None else: for name, value in settings: setattr(self._options, name, value) @@ -332,29 +331,10 @@ def startContainer(self, element): del NodeFilter -class _AsyncDeprecatedProperty: - def warn(self, cls): - clsname = cls.__name__ - warnings.warn( - "{cls}.async is deprecated; use {cls}.async_".format(cls=clsname), - DeprecationWarning) - - def __get__(self, instance, cls): - self.warn(cls) - if instance is not None: - return instance.async_ - return False - - def __set__(self, instance, value): - self.warn(type(instance)) - setattr(instance, 'async_', value) - - class DocumentLS: """Mixin to create documents that conform to the load/save spec.""" async_ = False - locals()['async'] = _AsyncDeprecatedProperty() # Avoid DeprecationWarning def _get_async(self): return False @@ -384,9 +364,6 @@ def saveXML(self, snode): return snode.toxml() -del _AsyncDeprecatedProperty - - class DOMImplementationLS: MODE_SYNCHRONOUS = 1 MODE_ASYNCHRONOUS = 2 diff --git a/Lib/xml/etree/ElementInclude.py b/Lib/xml/etree/ElementInclude.py index 963470e3b1..40a9b22292 100644 --- a/Lib/xml/etree/ElementInclude.py +++ b/Lib/xml/etree/ElementInclude.py @@ -42,7 +42,7 @@ # -------------------------------------------------------------------- # Licensed to PSF under a Contributor Agreement. -# See http://www.python.org/psf/license for licensing details. +# See https://www.python.org/psf/license for licensing details. ## # Limited XInclude support for the ElementTree package. @@ -50,18 +50,28 @@ import copy from . import ElementTree +from urllib.parse import urljoin XINCLUDE = "{http://www.w3.org/2001/XInclude}" XINCLUDE_INCLUDE = XINCLUDE + "include" XINCLUDE_FALLBACK = XINCLUDE + "fallback" +# For security reasons, the inclusion depth is limited to this read-only value by default. +DEFAULT_MAX_INCLUSION_DEPTH = 6 + + ## # Fatal include error. class FatalIncludeError(SyntaxError): pass + +class LimitedRecursiveIncludeError(FatalIncludeError): + pass + + ## # Default loader. This loader reads an included resource from disk. # @@ -92,13 +102,33 @@ def default_loader(href, parse, encoding=None): # @param loader Optional resource loader. If omitted, it defaults # to {@link default_loader}. If given, it should be a callable # that implements the same interface as default_loader. +# @param base_url The base URL of the original file, to resolve +# relative include file references. +# @param max_depth The maximum number of recursive inclusions. +# Limited to reduce the risk of malicious content explosion. +# Pass a negative value to disable the limitation. +# @throws LimitedRecursiveIncludeError If the {@link max_depth} was exceeded. # @throws FatalIncludeError If the function fails to include a given # resource, or if the tree contains malformed XInclude elements. -# @throws OSError If the function fails to load a given resource. +# @throws IOError If the function fails to load a given resource. +# @returns the node or its replacement if it was an XInclude node -def include(elem, loader=None): +def include(elem, loader=None, base_url=None, + max_depth=DEFAULT_MAX_INCLUSION_DEPTH): + if max_depth is None: + max_depth = -1 + elif max_depth < 0: + raise ValueError("expected non-negative depth or None for 'max_depth', got %r" % max_depth) + + if hasattr(elem, 'getroot'): + elem = elem.getroot() if loader is None: loader = default_loader + + _include(elem, loader, base_url, max_depth, set()) + + +def _include(elem, loader, base_url, max_depth, _parent_hrefs): # look for xinclude elements i = 0 while i < len(elem): @@ -106,14 +136,24 @@ def include(elem, loader=None): if e.tag == XINCLUDE_INCLUDE: # process xinclude directive href = e.get("href") + if base_url: + href = urljoin(base_url, href) parse = e.get("parse", "xml") if parse == "xml": + if href in _parent_hrefs: + raise FatalIncludeError("recursive include of %s" % href) + if max_depth == 0: + raise LimitedRecursiveIncludeError( + "maximum xinclude depth reached when including file %s" % href) + _parent_hrefs.add(href) node = loader(href, parse) if node is None: raise FatalIncludeError( "cannot load %r as %r" % (href, parse) ) - node = copy.copy(node) + node = copy.copy(node) # FIXME: this makes little sense with recursive includes + _include(node, loader, href, max_depth - 1, _parent_hrefs) + _parent_hrefs.remove(href) if e.tail: node.tail = (node.tail or "") + e.tail elem[i] = node @@ -123,11 +163,13 @@ def include(elem, loader=None): raise FatalIncludeError( "cannot load %r as %r" % (href, parse) ) + if e.tail: + text += e.tail if i: node = elem[i-1] - node.tail = (node.tail or "") + text + (e.tail or "") + node.tail = (node.tail or "") + text else: - elem.text = (elem.text or "") + text + (e.tail or "") + elem.text = (elem.text or "") + text del elem[i] continue else: @@ -139,5 +181,5 @@ def include(elem, loader=None): "xi:fallback tag must be child of xi:include (%r)" % e.tag ) else: - include(e, loader) - i = i + 1 + _include(e, loader, base_url, max_depth, _parent_hrefs) + i += 1 diff --git a/Lib/xml/etree/ElementPath.py b/Lib/xml/etree/ElementPath.py index d318e65d84..cd3c354d08 100644 --- a/Lib/xml/etree/ElementPath.py +++ b/Lib/xml/etree/ElementPath.py @@ -48,7 +48,7 @@ # -------------------------------------------------------------------- # Licensed to PSF under a Contributor Agreement. -# See http://www.python.org/psf/license for licensing details. +# See https://www.python.org/psf/license for licensing details. ## # Implementation module for XPath support. There's usually no reason @@ -65,8 +65,9 @@ r"//?|" r"\.\.|" r"\(\)|" + r"!=|" r"[/.*:\[\]\(\)@=])|" - r"((?:\{[^}]+\})?[^/\[\]\(\)@=\s]+)|" + r"((?:\{[^}]+\})?[^/\[\]\(\)@!=\s]+)|" r"\s+" ) @@ -225,7 +226,6 @@ def select(context, result): def prepare_predicate(next, token): # FIXME: replace with real parser!!! refs: - # http://effbot.org/zone/simple-iterator-parser.htm # http://javascript.crockford.com/tdop/tdop.html signature = [] predicate = [] @@ -253,15 +253,19 @@ def select(context, result): if elem.get(key) is not None: yield elem return select - if signature == "@-='": - # [@attribute='value'] + if signature == "@-='" or signature == "@-!='": + # [@attribute='value'] or [@attribute!='value'] key = predicate[1] value = predicate[-1] def select(context, result): for elem in result: if elem.get(key) == value: yield elem - return select + def select_negated(context, result): + for elem in result: + if (attr_value := elem.get(key)) is not None and attr_value != value: + yield elem + return select_negated if '!=' in signature else select if signature == "-" and not re.match(r"\-?\d+$", predicate[0]): # [tag] tag = predicate[0] @@ -270,8 +274,10 @@ def select(context, result): if elem.find(tag) is not None: yield elem return select - if signature == ".='" or (signature == "-='" and not re.match(r"\-?\d+$", predicate[0])): - # [.='value'] or [tag='value'] + if signature == ".='" or signature == ".!='" or ( + (signature == "-='" or signature == "-!='") + and not re.match(r"\-?\d+$", predicate[0])): + # [.='value'] or [tag='value'] or [.!='value'] or [tag!='value'] tag = predicate[0] value = predicate[-1] if tag: @@ -281,12 +287,22 @@ def select(context, result): if "".join(e.itertext()) == value: yield elem break + def select_negated(context, result): + for elem in result: + for e in elem.iterfind(tag): + if "".join(e.itertext()) != value: + yield elem + break else: def select(context, result): for elem in result: if "".join(elem.itertext()) == value: yield elem - return select + def select_negated(context, result): + for elem in result: + if "".join(elem.itertext()) != value: + yield elem + return select_negated if '!=' in signature else select if signature == "-" or signature == "-()" or signature == "-()-": # [index] or [last()] or [last()-index] if signature == "-": diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py index f8538dfb2b..2503d9ee76 100644 --- a/Lib/xml/etree/ElementTree.py +++ b/Lib/xml/etree/ElementTree.py @@ -35,7 +35,7 @@ #--------------------------------------------------------------------- # Licensed to PSF under a Contributor Agreement. -# See http://www.python.org/psf/license for licensing details. +# See https://www.python.org/psf/license for licensing details. # # ElementTree # Copyright (c) 1999-2008 by Fredrik Lundh. All rights reserved. @@ -76,7 +76,7 @@ "dump", "Element", "ElementTree", "fromstring", "fromstringlist", - "iselement", "iterparse", + "indent", "iselement", "iterparse", "parse", "ParseError", "PI", "ProcessingInstruction", "QName", @@ -195,6 +195,13 @@ def copy(self): original tree. """ + warnings.warn( + "elem.copy() is deprecated. Use copy.copy(elem) instead.", + DeprecationWarning + ) + return self.__copy__() + + def __copy__(self): elem = self.makeelement(self.tag, self.attrib) elem.text = self.text elem.tail = self.tail @@ -273,19 +280,6 @@ def remove(self, subelement): # assert iselement(element) self._children.remove(subelement) - def getchildren(self): - """(Deprecated) Return all subelements. - - Elements are returned in document order. - - """ - warnings.warn( - "This method will be removed in future versions. " - "Use 'list(elem)' or iteration over elem instead.", - DeprecationWarning, stacklevel=2 - ) - return self._children - def find(self, path, namespaces=None): """Find first matching element by tag name or path. @@ -409,15 +403,6 @@ def iter(self, tag=None): for e in self._children: yield from e.iter(tag) - # compatibility - def getiterator(self, tag=None): - warnings.warn( - "This method will be removed in future versions. " - "Use 'elem.iter()' or 'list(elem.iter())' instead.", - DeprecationWarning, stacklevel=2 - ) - return list(self.iter(tag)) - def itertext(self): """Create text iterator. @@ -617,15 +602,6 @@ def iter(self, tag=None): # assert self._root is not None return self._root.iter(tag) - # compatibility - def getiterator(self, tag=None): - warnings.warn( - "This method will be removed in future versions. " - "Use 'tree.iter()' or 'list(tree.iter())' instead.", - DeprecationWarning, stacklevel=2 - ) - return list(self.iter(tag)) - def find(self, path, namespaces=None): """Find first matching element by tag name or path. @@ -752,16 +728,11 @@ def write(self, file_or_filename, encoding = "utf-8" else: encoding = "us-ascii" - enc_lower = encoding.lower() - with _get_writer(file_or_filename, enc_lower) as write: + with _get_writer(file_or_filename, encoding) as (write, declared_encoding): if method == "xml" and (xml_declaration or (xml_declaration is None and - enc_lower not in ("utf-8", "us-ascii", "unicode"))): - declared_encoding = encoding - if enc_lower == "unicode": - # Retrieve the default encoding for the xml declaration - import locale - declared_encoding = locale.getpreferredencoding() + encoding.lower() != "unicode" and + declared_encoding.lower() not in ("utf-8", "us-ascii"))): write("\n" % ( declared_encoding,)) if method == "text": @@ -786,19 +757,17 @@ def _get_writer(file_or_filename, encoding): write = file_or_filename.write except AttributeError: # file_or_filename is a file name - if encoding == "unicode": - file = open(file_or_filename, "w") - else: - file = open(file_or_filename, "w", encoding=encoding, - errors="xmlcharrefreplace") - with file: - yield file.write + if encoding.lower() == "unicode": + encoding="utf-8" + with open(file_or_filename, "w", encoding=encoding, + errors="xmlcharrefreplace") as file: + yield file.write, encoding else: # file_or_filename is a file-like object # encoding determines if it is a text or binary writer - if encoding == "unicode": + if encoding.lower() == "unicode": # use a text writer as is - yield write + yield write, getattr(file_or_filename, "encoding", None) or "utf-8" else: # wrap a binary writer with TextIOWrapper with contextlib.ExitStack() as stack: @@ -829,7 +798,7 @@ def _get_writer(file_or_filename, encoding): # Keep the original file open when the TextIOWrapper is # destroyed stack.callback(file.detach) - yield file.write + yield file.write, encoding def _namespaces(elem, default_namespace=None): # identify namespaces used in this tree @@ -1081,15 +1050,15 @@ def _escape_attrib(text): text = text.replace(">", ">") if "\"" in text: text = text.replace("\"", """) - # The following business with carriage returns is to satisfy - # Section 2.11 of the XML specification, stating that - # CR or CR LN should be replaced with just LN + # Although section 2.11 of the XML specification states that CR or + # CR LN should be replaced with just LN, it applies only to EOLNs + # which take part of organizing file into lines. Within attributes, + # we are replacing these with entity numbers, so they do not count. # http://www.w3.org/TR/REC-xml/#sec-line-ends - if "\r\n" in text: - text = text.replace("\r\n", "\n") + # The current solution, contained in following six lines, was + # discussed in issue 17582 and 39011. if "\r" in text: - text = text.replace("\r", "\n") - #The following four lines are issue 17582 + text = text.replace("\r", " ") if "\n" in text: text = text.replace("\n", " ") if "\t" in text: @@ -1185,6 +1154,57 @@ def dump(elem): if not tail or tail[-1] != "\n": sys.stdout.write("\n") + +def indent(tree, space=" ", level=0): + """Indent an XML document by inserting newlines and indentation space + after elements. + + *tree* is the ElementTree or Element to modify. The (root) element + itself will not be changed, but the tail text of all elements in its + subtree will be adapted. + + *space* is the whitespace to insert for each indentation level, two + space characters by default. + + *level* is the initial indentation level. Setting this to a higher + value than 0 can be used for indenting subtrees that are more deeply + nested inside of a document. + """ + if isinstance(tree, ElementTree): + tree = tree.getroot() + if level < 0: + raise ValueError(f"Initial indentation level must be >= 0, got {level}") + if not len(tree): + return + + # Reduce the memory consumption by reusing indentation strings. + indentations = ["\n" + level * space] + + def _indent_children(elem, level): + # Start a new indentation level for the first child. + child_level = level + 1 + try: + child_indentation = indentations[child_level] + except IndexError: + child_indentation = indentations[level] + space + indentations.append(child_indentation) + + if not elem.text or not elem.text.strip(): + elem.text = child_indentation + + for child in elem: + if len(child): + _indent_children(child, child_level) + if not child.tail or not child.tail.strip(): + child.tail = child_indentation + + # Dedent after the last child by overwriting the previous indentation. + if not child.tail.strip(): + child.tail = indentations[level] + + _indent_children(tree, 0) + + # -------------------------------------------------------------------- # parsing @@ -1221,8 +1241,14 @@ def iterparse(source, events=None, parser=None): # Use the internal, undocumented _parser argument for now; When the # parser argument of iterparse is removed, this can be killed. pullparser = XMLPullParser(events=events, _parser=parser) - def iterator(): + + def iterator(source): + close_source = False try: + if not hasattr(source, "read"): + source = open(source, "rb") + close_source = True + yield None while True: yield from pullparser.read_events() # load event buffer @@ -1238,16 +1264,12 @@ def iterator(): source.close() class IterParseIterator(collections.abc.Iterator): - __next__ = iterator().__next__ + __next__ = iterator(source).__next__ it = IterParseIterator() it.root = None del iterator, IterParseIterator - close_source = False - if not hasattr(source, "read"): - source = open(source, "rb") - close_source = True - + next(it) return it @@ -1256,7 +1278,7 @@ class XMLPullParser: def __init__(self, events=None, *, _parser=None): # The _parser argument is for internal use only and must not be relied # upon in user code. It will be removed in a future release. - # See http://bugs.python.org/issue17741 for more details. + # See https://bugs.python.org/issue17741 for more details. self._events_queue = collections.deque() self._parser = _parser or XMLParser(target=TreeBuilder()) @@ -1533,7 +1555,6 @@ def __init__(self, *, target=None, encoding=None): # Configure pyexpat: buffering, new-style attribute handling. parser.buffer_text = 1 parser.ordered_attributes = 1 - parser.specified_attributes = 1 self._doctype = None self.entity = {} try: @@ -1553,7 +1574,6 @@ def _setevents(self, events_queue, events_to_report): for event_name in events_to_report: if event_name == "start": parser.ordered_attributes = 1 - parser.specified_attributes = 1 def handler(tag, attrib_in, event=event_name, append=append, start=self._start): append((event, start(tag, attrib_in))) @@ -1690,14 +1710,14 @@ def _default(self, text): def feed(self, data): """Feed encoded data to parser.""" try: - self.parser.Parse(data, 0) + self.parser.Parse(data, False) except self._error as v: self._raiseerror(v) def close(self): """Finish feeding data to parser and return element structure.""" try: - self.parser.Parse("", 1) # end of data + self.parser.Parse(b"", True) # end of data except self._error as v: self._raiseerror(v) try: diff --git a/Lib/xml/etree/__init__.py b/Lib/xml/etree/__init__.py index 27fd8f6d4e..e2ec53421d 100644 --- a/Lib/xml/etree/__init__.py +++ b/Lib/xml/etree/__init__.py @@ -30,4 +30,4 @@ # -------------------------------------------------------------------- # Licensed to PSF under a Contributor Agreement. -# See http://www.python.org/psf/license for licensing details. +# See https://www.python.org/psf/license for licensing details. diff --git a/Lib/xml/sax/__init__.py b/Lib/xml/sax/__init__.py index 13f6cf58d0..17b75879eb 100644 --- a/Lib/xml/sax/__init__.py +++ b/Lib/xml/sax/__init__.py @@ -67,18 +67,18 @@ def parseString(string, handler, errorHandler=ErrorHandler()): default_parser_list = sys.registry.getProperty(_key).split(",") -def make_parser(parser_list = []): +def make_parser(parser_list=()): """Creates and returns a SAX parser. Creates the first parser it is able to instantiate of the ones - given in the list created by doing parser_list + - default_parser_list. The lists must contain the names of Python + given in the iterable created by chaining parser_list and + default_parser_list. The iterables must contain the names of Python modules containing both a SAX parser and a create_parser function.""" - for parser_name in parser_list + default_parser_list: + for parser_name in list(parser_list) + default_parser_list: try: return _create_parser(parser_name) - except ImportError as e: + except ImportError: import sys if parser_name in sys.modules: # The parser module was found, but importing it diff --git a/Lib/xml/sax/expatreader.py b/Lib/xml/sax/expatreader.py index 5066ffc2fa..e334ac9fea 100644 --- a/Lib/xml/sax/expatreader.py +++ b/Lib/xml/sax/expatreader.py @@ -93,7 +93,7 @@ def __init__(self, namespaceHandling=0, bufsize=2**16-20): self._parser = None self._namespaces = namespaceHandling self._lex_handler_prop = None - self._parsing = 0 + self._parsing = False self._entity_stack = [] self._external_ges = 0 self._interning = None @@ -203,10 +203,10 @@ def setProperty(self, name, value): # IncrementalParser methods - def feed(self, data, isFinal = 0): + def feed(self, data, isFinal=False): if not self._parsing: self.reset() - self._parsing = 1 + self._parsing = True self._cont_handler.startDocument() try: @@ -237,13 +237,13 @@ def close(self): # If we are completing an external entity, do nothing here return try: - self.feed("", isFinal = 1) + self.feed(b"", isFinal=True) self._cont_handler.endDocument() - self._parsing = 0 + self._parsing = False # break cycle created by expat handlers pointing to our methods self._parser = None finally: - self._parsing = 0 + self._parsing = False if self._parser is not None: # Keep ErrorColumnNumber and ErrorLineNumber after closing. parser = _ClosedParser() @@ -307,7 +307,7 @@ def reset(self): self._parser.SetParamEntityParsing( expat.XML_PARAM_ENTITY_PARSING_UNLESS_STANDALONE) - self._parsing = 0 + self._parsing = False self._entity_stack = [] # Locator methods diff --git a/Lib/xml/sax/handler.py b/Lib/xml/sax/handler.py index 481733d2cb..e8d417e519 100644 --- a/Lib/xml/sax/handler.py +++ b/Lib/xml/sax/handler.py @@ -340,3 +340,48 @@ def resolveEntity(self, publicId, systemId): property_xml_string, property_encoding, property_interning_dict] + + +class LexicalHandler: + """Optional SAX2 handler for lexical events. + + This handler is used to obtain lexical information about an XML + document, that is, information about how the document was encoded + (as opposed to what it contains, which is reported to the + ContentHandler), such as comments and CDATA marked section + boundaries. + + To set the LexicalHandler of an XMLReader, use the setProperty + method with the property identifier + 'http://xml.org/sax/properties/lexical-handler'.""" + + def comment(self, content): + """Reports a comment anywhere in the document (including the + DTD and outside the document element). + + content is a string that holds the contents of the comment.""" + + def startDTD(self, name, public_id, system_id): + """Report the start of the DTD declarations, if the document + has an associated DTD. + + A startEntity event will be reported before declaration events + from the external DTD subset are reported, and this can be + used to infer from which subset DTD declarations derive. + + name is the name of the document element type, public_id the + public identifier of the DTD (or None if none were supplied) + and system_id the system identfier of the external subset (or + None if none were supplied).""" + + def endDTD(self): + """Signals the end of DTD declarations.""" + + def startCDATA(self): + """Reports the beginning of a CDATA marked section. + + The contents of the CDATA marked section will be reported + through the characters event.""" + + def endCDATA(self): + """Reports the end of a CDATA marked section.""" diff --git a/Lib/xml/sax/saxutils.py b/Lib/xml/sax/saxutils.py index a69c7f7621..c1612ea1ce 100644 --- a/Lib/xml/sax/saxutils.py +++ b/Lib/xml/sax/saxutils.py @@ -56,8 +56,7 @@ def quoteattr(data, entities={}): the optional entities parameter. The keys and values must all be strings; each key will be replaced with its corresponding value. """ - entities = entities.copy() - entities.update({'\n': ' ', '\r': ' ', '\t':' '}) + entities = {**entities, '\n': ' ', '\r': ' ', '\t':' '} data = escape(data, entities) if '"' in data: if "'" in data: @@ -340,6 +339,8 @@ def prepare_input_source(source, base=""): """This function takes an InputSource and an optional base URL and returns a fully resolved InputSource object ready for reading.""" + if isinstance(source, os.PathLike): + source = os.fspath(source) if isinstance(source, str): source = xmlreader.InputSource(source) elif hasattr(source, "read"): From 2539e8276c36e364b2a96918bd503cadc1e3c0a6 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sun, 14 Aug 2022 23:25:26 +0900 Subject: [PATCH 21/23] adjust failing test markers for xml --- Lib/test/test_xml_etree.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index b721f028a1..53552dc8b6 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -2466,8 +2466,6 @@ def test___init__(self): self.assertIsNot(element_foo.attrib, attrib) self.assertNotEqual(element_foo.attrib, attrib) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_copy(self): # Only run this test if Element.copy() is defined. if "copy" not in dir(ET.Element): @@ -3918,6 +3916,8 @@ def test_write_to_filename(self): with open(TESTFN, 'rb') as f: self.assertEqual(f.read(), b'''ø''') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_write_to_filename_with_encoding(self): self.addCleanup(os_helper.unlink, TESTFN) tree = ET.ElementTree(ET.XML('''\xf8''')) @@ -3931,6 +3931,8 @@ def test_write_to_filename_with_encoding(self): b'''\n''' b'''\xf8''')) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_write_to_filename_as_unicode(self): self.addCleanup(os_helper.unlink, TESTFN) with open(TESTFN, 'w') as f: @@ -3976,6 +3978,8 @@ def test_write_to_binary_file(self): with open(TESTFN, 'rb') as f: self.assertEqual(f.read(), b'''ø''') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_write_to_binary_file_with_encoding(self): self.addCleanup(os_helper.unlink, TESTFN) tree = ET.ElementTree(ET.XML('''\xf8''')) From 563f9ecd53ca6a91b78dbe72172add132c016e27 Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Mon, 15 Aug 2022 00:35:20 +0900 Subject: [PATCH 22/23] Update test_repl from cpython 3.10.6 --- Lib/test/test_repl.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/Lib/test/test_repl.py b/Lib/test/test_repl.py index 71f192f90d..03bf8d8b54 100644 --- a/Lib/test/test_repl.py +++ b/Lib/test/test_repl.py @@ -29,7 +29,9 @@ def spawn_repl(*args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kw): # test.support.script_helper. env = kw.setdefault('env', dict(os.environ)) env['TERM'] = 'vt100' - return subprocess.Popen(cmd_line, executable=sys.executable, + return subprocess.Popen(cmd_line, + executable=sys.executable, + text=True, stdin=subprocess.PIPE, stdout=stdout, stderr=stderr, **kw) @@ -49,12 +51,11 @@ def test_no_memory(self): sys.exit(0) """ user_input = dedent(user_input) - user_input = user_input.encode() p = spawn_repl() with SuppressCrashReport(): p.stdin.write(user_input) output = kill_python(p) - self.assertIn(b'After the exception.', output) + self.assertIn('After the exception.', output) # Exit code 120: Py_FinalizeEx() failed to flush stdout and stderr. self.assertIn(p.returncode, (1, 120)) @@ -86,13 +87,26 @@ def test_multiline_string_parsing(self): """ ''' user_input = dedent(user_input) - user_input = user_input.encode() p = spawn_repl() - with SuppressCrashReport(): - p.stdin.write(user_input) + p.stdin.write(user_input) output = kill_python(p) self.assertEqual(p.returncode, 0) + def test_close_stdin(self): + user_input = dedent(''' + import os + print("before close") + os.close(0) + ''') + prepare_repl = dedent(''' + from test.support import suppress_msvcrt_asserts + suppress_msvcrt_asserts() + ''') + process = spawn_repl('-c', prepare_repl) + output = process.communicate(user_input)[0] + self.assertEqual(process.returncode, 0) + self.assertIn('before close', output) + if __name__ == "__main__": unittest.main() From bf53bb0c5fc6f0d591395021b53aee883ae77bb8 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Mon, 15 Aug 2022 00:37:39 +0900 Subject: [PATCH 23/23] Mark failing tests from test_repl --- Lib/test/test_repl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_repl.py b/Lib/test/test_repl.py index 03bf8d8b54..a63c1213c6 100644 --- a/Lib/test/test_repl.py +++ b/Lib/test/test_repl.py @@ -92,6 +92,8 @@ def test_multiline_string_parsing(self): output = kill_python(p) self.assertEqual(p.returncode, 0) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_close_stdin(self): user_input = dedent(''' import os