From 6270a5b04c13cefc75ae35f49b0dda15097354a8 Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Sun, 14 Aug 2022 23:13:35 +0900
Subject: [PATCH 01/23] Update webbrowser from CPython 3.10.6
---
Lib/test/test_webbrowser.py | 19 +---
Lib/webbrowser.py | 217 ++++++++++++++++++++++--------------
2 files changed, 137 insertions(+), 99 deletions(-)
diff --git a/Lib/test/test_webbrowser.py b/Lib/test/test_webbrowser.py
index a97e64a753..673cc995d3 100644
--- a/Lib/test/test_webbrowser.py
+++ b/Lib/test/test_webbrowser.py
@@ -5,7 +5,8 @@
import subprocess
from unittest import mock
from test import support
-from test.support import os_helper, import_helper
+from test.support import import_helper
+from test.support import os_helper
URL = 'http://www.example.com'
@@ -171,29 +172,21 @@ class OperaCommandTest(CommandTestMixin, unittest.TestCase):
browser_class = webbrowser.Opera
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_open(self):
self._test('open',
options=[],
arguments=[URL])
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_open_with_autoraise_false(self):
self._test('open', kw=dict(autoraise=False),
options=[],
arguments=[URL])
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_open_new(self):
self._test('open_new',
options=['--new-window'],
arguments=[URL])
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_open_new_tab(self):
self._test('open_new_tab',
options=[],
@@ -267,23 +260,17 @@ class ExampleBrowser:
self.assertEqual(webbrowser._tryorder, expected_tryorder)
self.assertEqual(webbrowser._browsers, expected_browsers)
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_register(self):
self._check_registration(preferred=False)
def test_register_default(self):
self._check_registration(preferred=None)
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_register_preferred(self):
self._check_registration(preferred=True)
class ImportTest(unittest.TestCase):
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_register(self):
webbrowser = import_helper.import_fresh_module('webbrowser')
self.assertIsNone(webbrowser._tryorder)
@@ -298,8 +285,6 @@ class ExampleBrowser:
self.assertIn('example1', webbrowser._browsers)
self.assertEqual(webbrowser._browsers['example1'], [ExampleBrowser, None])
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_get(self):
webbrowser = import_helper.import_fresh_module('webbrowser')
self.assertIsNone(webbrowser._tryorder)
diff --git a/Lib/webbrowser.py b/Lib/webbrowser.py
index 6f43b7f126..ec3cece48c 100755
--- a/Lib/webbrowser.py
+++ b/Lib/webbrowser.py
@@ -1,5 +1,5 @@
#! /usr/bin/env python3
-"""Interfaces for launching and remotely controlling Web browsers."""
+"""Interfaces for launching and remotely controlling web browsers."""
# Maintained by Georg Brandl.
import os
@@ -7,25 +7,39 @@
import shutil
import sys
import subprocess
+import threading
__all__ = ["Error", "open", "open_new", "open_new_tab", "get", "register"]
class Error(Exception):
pass
-_browsers = {} # Dictionary of available browser controllers
-_tryorder = [] # Preference order of available browsers
-
-def register(name, klass, instance=None, update_tryorder=1):
- """Register a browser connector and, optionally, connection."""
- _browsers[name.lower()] = [klass, instance]
- if update_tryorder > 0:
- _tryorder.append(name)
- elif update_tryorder < 0:
- _tryorder.insert(0, name)
+_lock = threading.RLock()
+_browsers = {} # Dictionary of available browser controllers
+_tryorder = None # Preference order of available browsers
+_os_preferred_browser = None # The preferred browser
+
+def register(name, klass, instance=None, *, preferred=False):
+ """Register a browser connector."""
+ with _lock:
+ if _tryorder is None:
+ register_standard_browsers()
+ _browsers[name.lower()] = [klass, instance]
+
+ # Preferred browsers go to the front of the list.
+ # Need to match to the default browser returned by xdg-settings, which
+ # may be of the form e.g. "firefox.desktop".
+ if preferred or (_os_preferred_browser and name in _os_preferred_browser):
+ _tryorder.insert(0, name)
+ else:
+ _tryorder.append(name)
def get(using=None):
"""Return a browser launcher instance appropriate for the environment."""
+ if _tryorder is None:
+ with _lock:
+ if _tryorder is None:
+ register_standard_browsers()
if using is not None:
alternatives = [using]
else:
@@ -55,6 +69,18 @@ def get(using=None):
# instead of "from webbrowser import *".
def open(url, new=0, autoraise=True):
+ """Display url using the default browser.
+
+ If possible, open url in a location determined by new.
+ - 0: the same browser window (the default).
+ - 1: a new browser window.
+ - 2: a new browser page ("tab").
+ If possible, autoraise raises the window (the default) or not.
+ """
+ if _tryorder is None:
+ with _lock:
+ if _tryorder is None:
+ register_standard_browsers()
for name in _tryorder:
browser = get(name)
if browser.open(url, new, autoraise):
@@ -62,14 +88,22 @@ def open(url, new=0, autoraise=True):
return False
def open_new(url):
+ """Open url in a new window of the default browser.
+
+ If not possible, then open url in the only browser window.
+ """
return open(url, 1)
def open_new_tab(url):
+ """Open url in a new page ("tab") of the default browser.
+
+ If not possible, then the behavior becomes equivalent to open_new().
+ """
return open(url, 2)
-def _synthesize(browser, update_tryorder=1):
- """Attempt to synthesize a controller base on existing controllers.
+def _synthesize(browser, *, preferred=False):
+ """Attempt to synthesize a controller based on existing controllers.
This is useful to create a controller when a user specifies a path to
an entry in the BROWSER environment variable -- we can copy a general
@@ -95,7 +129,7 @@ def _synthesize(browser, update_tryorder=1):
controller = copy.copy(controller)
controller.name = browser
controller.basename = os.path.basename(browser)
- register(browser, None, controller, update_tryorder)
+ register(browser, None, instance=controller, preferred=preferred)
return [None, controller]
return [None, None]
@@ -136,6 +170,7 @@ def __init__(self, name):
self.basename = os.path.basename(self.name)
def open(self, url, new=0, autoraise=True):
+ sys.audit("webbrowser.open", url)
cmdline = [self.name] + [arg.replace("%s", url)
for arg in self.args]
try:
@@ -155,6 +190,7 @@ class BackgroundBrowser(GenericBrowser):
def open(self, url, new=0, autoraise=True):
cmdline = [self.name] + [arg.replace("%s", url)
for arg in self.args]
+ sys.audit("webbrowser.open", url)
try:
if sys.platform[:3] == 'win':
p = subprocess.Popen(cmdline)
@@ -183,7 +219,7 @@ class UnixBrowser(BaseBrowser):
remote_action_newwin = None
remote_action_newtab = None
- def _invoke(self, args, remote, autoraise):
+ def _invoke(self, args, remote, autoraise, url=None):
raise_opt = []
if remote and self.raise_opts:
# use autoraise argument only for remote invocation
@@ -219,6 +255,7 @@ def _invoke(self, args, remote, autoraise):
return not p.wait()
def open(self, url, new=0, autoraise=True):
+ sys.audit("webbrowser.open", url)
if new == 0:
action = self.remote_action
elif new == 1:
@@ -235,7 +272,7 @@ def open(self, url, new=0, autoraise=True):
args = [arg.replace("%s", url).replace("%action", action)
for arg in self.remote_args]
args = [arg for arg in args if arg]
- success = self._invoke(args, True, autoraise)
+ success = self._invoke(args, True, autoraise, url)
if not success:
# remote invocation failed, try straight way
args = [arg.replace("%s", url) for arg in self.args]
@@ -290,11 +327,10 @@ class Chrome(UnixBrowser):
class Opera(UnixBrowser):
"Launcher class for Opera browser."
- raise_opts = ["-noraise", ""]
- remote_args = ['-remote', 'openURL(%s%action)']
+ remote_args = ['%action', '%s']
remote_action = ""
- remote_action_newwin = ",new-window"
- remote_action_newtab = ",new-page"
+ remote_action_newwin = "--new-window"
+ remote_action_newtab = ""
background = True
@@ -320,6 +356,7 @@ class Konqueror(BaseBrowser):
"""
def open(self, url, new=0, autoraise=True):
+ sys.audit("webbrowser.open", url)
# XXX Currently I know no way to prevent KFM from opening a new win.
if new == 2:
action = "newTab"
@@ -376,7 +413,7 @@ def _find_grail_rc(self):
tempdir = os.path.join(tempfile.gettempdir(),
".grail-unix")
user = pwd.getpwuid(os.getuid())[0]
- filename = os.path.join(tempdir, user + "-*")
+ filename = os.path.join(glob.escape(tempdir), glob.escape(user) + "-*")
maybes = glob.glob(filename)
if not maybes:
return None
@@ -403,6 +440,7 @@ def _remote(self, action):
return 1
def open(self, url, new=0, autoraise=True):
+ sys.audit("webbrowser.open", url)
if new:
ok = self._remote("LOADNEW " + url)
else:
@@ -482,25 +520,80 @@ def register_X_browsers():
if shutil.which("grail"):
register("grail", Grail, None)
-# Prefer X browsers if present
-if os.environ.get("DISPLAY"):
- register_X_browsers()
-
-# Also try console browsers
-if os.environ.get("TERM"):
- if shutil.which("www-browser"):
- register("www-browser", None, GenericBrowser("www-browser"))
- # The Links/elinks browsers
- if shutil.which("links"):
- register("links", None, GenericBrowser("links"))
- if shutil.which("elinks"):
- register("elinks", None, Elinks("elinks"))
- # The Lynx browser ,
- if shutil.which("lynx"):
- register("lynx", None, GenericBrowser("lynx"))
- # The w3m browser
- if shutil.which("w3m"):
- register("w3m", None, GenericBrowser("w3m"))
+def register_standard_browsers():
+ global _tryorder
+ _tryorder = []
+
+ if sys.platform == 'darwin':
+ register("MacOSX", None, MacOSXOSAScript('default'))
+ register("chrome", None, MacOSXOSAScript('chrome'))
+ register("firefox", None, MacOSXOSAScript('firefox'))
+ register("safari", None, MacOSXOSAScript('safari'))
+ # OS X can use below Unix support (but we prefer using the OS X
+ # specific stuff)
+
+ if sys.platform == "serenityos":
+ # SerenityOS webbrowser, simply called "Browser".
+ register("Browser", None, BackgroundBrowser("Browser"))
+
+ if sys.platform[:3] == "win":
+ # First try to use the default Windows browser
+ register("windows-default", WindowsDefault)
+
+ # Detect some common Windows browsers, fallback to IE
+ iexplore = os.path.join(os.environ.get("PROGRAMFILES", "C:\\Program Files"),
+ "Internet Explorer\\IEXPLORE.EXE")
+ for browser in ("firefox", "firebird", "seamonkey", "mozilla",
+ "netscape", "opera", iexplore):
+ if shutil.which(browser):
+ register(browser, None, BackgroundBrowser(browser))
+ else:
+ # Prefer X browsers if present
+ if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"):
+ try:
+ cmd = "xdg-settings get default-web-browser".split()
+ raw_result = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
+ result = raw_result.decode().strip()
+ except (FileNotFoundError, subprocess.CalledProcessError, PermissionError, NotADirectoryError) :
+ pass
+ else:
+ global _os_preferred_browser
+ _os_preferred_browser = result
+
+ register_X_browsers()
+
+ # Also try console browsers
+ if os.environ.get("TERM"):
+ if shutil.which("www-browser"):
+ register("www-browser", None, GenericBrowser("www-browser"))
+ # The Links/elinks browsers
+ if shutil.which("links"):
+ register("links", None, GenericBrowser("links"))
+ if shutil.which("elinks"):
+ register("elinks", None, Elinks("elinks"))
+ # The Lynx browser ,
+ if shutil.which("lynx"):
+ register("lynx", None, GenericBrowser("lynx"))
+ # The w3m browser
+ if shutil.which("w3m"):
+ register("w3m", None, GenericBrowser("w3m"))
+
+ # OK, now that we know what the default preference orders for each
+ # platform are, allow user to override them with the BROWSER variable.
+ if "BROWSER" in os.environ:
+ userchoices = os.environ["BROWSER"].split(os.pathsep)
+ userchoices.reverse()
+
+ # Treat choices in same way as if passed into get() but do register
+ # and prepend to _tryorder
+ for cmdline in userchoices:
+ if cmdline != '':
+ cmd = _synthesize(cmdline, preferred=True)
+ if cmd[1] is None:
+ register(cmdline, None, GenericBrowser(cmdline), preferred=True)
+
+ # what to do if _tryorder is now empty?
+
#
# Platform support for Windows
@@ -509,6 +602,7 @@ def register_X_browsers():
if sys.platform[:3] == "win":
class WindowsDefault(BaseBrowser):
def open(self, url, new=0, autoraise=True):
+ sys.audit("webbrowser.open", url)
try:
os.startfile(url)
except OSError:
@@ -518,20 +612,6 @@ def open(self, url, new=0, autoraise=True):
else:
return True
- _tryorder = []
- _browsers = {}
-
- # First try to use the default Windows browser
- register("windows-default", WindowsDefault)
-
- # Detect some common Windows browsers, fallback to IE
- iexplore = os.path.join(os.environ.get("PROGRAMFILES", "C:\\Program Files"),
- "Internet Explorer\\IEXPLORE.EXE")
- for browser in ("firefox", "firebird", "seamonkey", "mozilla",
- "netscape", "opera", iexplore):
- if shutil.which(browser):
- register(browser, None, BackgroundBrowser(browser))
-
#
# Platform support for MacOS
#
@@ -552,6 +632,7 @@ def __init__(self, name):
self.name = name
def open(self, url, new=0, autoraise=True):
+ sys.audit("webbrowser.open", url)
assert "'" not in url
# hack for local urls
if not ':' in url:
@@ -608,34 +689,6 @@ def open(self, url, new=0, autoraise=True):
return not rc
- # Don't clear _tryorder or _browsers since OS X can use above Unix support
- # (but we prefer using the OS X specific stuff)
- register("safari", None, MacOSXOSAScript('safari'), -1)
- register("firefox", None, MacOSXOSAScript('firefox'), -1)
- register("chrome", None, MacOSXOSAScript('chrome'), -1)
- register("MacOSX", None, MacOSXOSAScript('default'), -1)
-
-
-# OK, now that we know what the default preference orders for each
-# platform are, allow user to override them with the BROWSER variable.
-if "BROWSER" in os.environ:
- _userchoices = os.environ["BROWSER"].split(os.pathsep)
- _userchoices.reverse()
-
- # Treat choices in same way as if passed into get() but do register
- # and prepend to _tryorder
- for cmdline in _userchoices:
- if cmdline != '':
- cmd = _synthesize(cmdline, -1)
- if cmd[1] is None:
- register(cmdline, None, GenericBrowser(cmdline), -1)
- cmdline = None # to make del work if _userchoices was empty
- del cmdline
- del _userchoices
-
-# what to do if _tryorder is now empty?
-
-
def main():
import getopt
usage = """Usage: %s [-n | -t] url
From ead12b5611bf466bd079b8c45794c0683d33e5ff Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Sun, 14 Aug 2022 23:15:13 +0900
Subject: [PATCH 02/23] Update test_zlib from CPython 3.10.6
---
Lib/test/test_zlib.py | 30 +++++++++++-------------------
1 file changed, 11 insertions(+), 19 deletions(-)
diff --git a/Lib/test/test_zlib.py b/Lib/test/test_zlib.py
index a6575002d1..9c1c7ddba9 100644
--- a/Lib/test/test_zlib.py
+++ b/Lib/test/test_zlib.py
@@ -1,11 +1,13 @@
import unittest
from test import support
+from test.support import import_helper
import binascii
import copy
import pickle
import random
import sys
-from test.support import bigmemtest, _1G, _4G, import_helper
+from test.support import bigmemtest, _1G, _4G
+
zlib = import_helper.import_module('zlib')
@@ -129,6 +131,12 @@ def test_overflow(self):
with self.assertRaisesRegex(OverflowError, 'int too large'):
zlib.decompressobj().flush(sys.maxsize + 1)
+ @support.cpython_only
+ def test_disallow_instantiation(self):
+ # Ensure that the type disallows instantiation (bpo-43916)
+ support.check_disallow_instantiation(self, type(zlib.compressobj()))
+ support.check_disallow_instantiation(self, type(zlib.decompressobj()))
+
class BaseCompressTestCase(object):
def check_big_compress_buffer(self, size, compress_func):
@@ -136,8 +144,7 @@ def check_big_compress_buffer(self, size, compress_func):
# Generate 10 MiB worth of random, and expand it by repeating it.
# The assumption is that zlib's memory is not big enough to exploit
# such spread out redundancy.
- data = b''.join([random.getrandbits(8 * _1M).to_bytes(_1M, 'little')
- for i in range(10)])
+ data = random.randbytes(_1M * 10)
data = data * (size // len(data) + 1)
try:
compress_func(data)
@@ -498,7 +505,7 @@ def test_odd_flush(self):
# others might simply have a single RNG
gen = random
gen.seed(1)
- data = genblock(1, 17 * 1024, generator=gen)
+ data = gen.randbytes(17 * 1024)
# compress, sync-flush, and decompress
first = co.compress(data)
@@ -847,20 +854,6 @@ def test_wbits(self):
self.assertEqual(dco.decompress(gzip), HAMLET_SCENE)
-def genblock(seed, length, step=1024, generator=random):
- """length-byte stream of random data from a seed (in step-byte blocks)."""
- if seed is not None:
- generator.seed(seed)
- randint = generator.randint
- if length < step or step < 2:
- step = length
- blocks = bytes()
- for i in range(0, length, step):
- blocks += bytes(randint(0, 255) for x in range(step))
- return blocks
-
-
-
def choose_lines(source, number, seed=None, generator=random):
"""Return a list of number lines randomly chosen from the source"""
if seed is not None:
@@ -869,7 +862,6 @@ def choose_lines(source, number, seed=None, generator=random):
return [generator.choice(sources) for n in range(number)]
-
HAMLET_SCENE = b"""
LAERTES
From 9571a68c01284191d4b18229694c303dad8992ff Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Sun, 14 Aug 2022 23:17:42 +0900
Subject: [PATCH 03/23] Update test_yield_from from CPython 3.10.6
---
Lib/test/test_yield_from.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/Lib/test/test_yield_from.py b/Lib/test/test_yield_from.py
index a4a971072a..97acfd5413 100644
--- a/Lib/test/test_yield_from.py
+++ b/Lib/test/test_yield_from.py
@@ -946,6 +946,9 @@ def two():
res.append(g1.throw(MyErr))
except StopIteration:
pass
+ except:
+ self.assertEqual(res, [0, 1, 2, 3])
+ raise
# Check with close
class MyIt(object):
def __iter__(self):
From 3c5ab3a8ef1b3d175ceba9f0982221c8bde4f85c Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Sun, 14 Aug 2022 23:26:31 +0900
Subject: [PATCH 04/23] Update test_with from CPython 3.10.6
---
Lib/test/test_with.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/Lib/test/test_with.py b/Lib/test/test_with.py
index b1d7a15b5e..f21bf65fed 100644
--- a/Lib/test/test_with.py
+++ b/Lib/test/test_with.py
@@ -7,7 +7,7 @@
import sys
import unittest
from collections import deque
-from contextlib import _GeneratorContextManager, contextmanager
+from contextlib import _GeneratorContextManager, contextmanager, nullcontext
class MockContextManager(_GeneratorContextManager):
@@ -641,6 +641,12 @@ class B: pass
self.assertEqual(blah.two, 2)
self.assertEqual(blah.three, 3)
+ def testWithExtendedTargets(self):
+ with nullcontext(range(1, 5)) as (a, *b, c):
+ self.assertEqual(a, 1)
+ self.assertEqual(b, [2, 3])
+ self.assertEqual(c, 4)
+
class ExitSwallowsExceptionTestCase(unittest.TestCase):
From be1636a9477cbab3be758931e701593331545e82 Mon Sep 17 00:00:00 2001
From: Jeong YunWon
Date: Sun, 14 Aug 2022 23:28:21 +0900
Subject: [PATCH 05/23] crlf -> lf
---
Lib/test/test_trace.py | 1172 ++++++++++++++++++++--------------------
1 file changed, 586 insertions(+), 586 deletions(-)
diff --git a/Lib/test/test_trace.py b/Lib/test/test_trace.py
index e5183b90eb..d74501549a 100644
--- a/Lib/test/test_trace.py
+++ b/Lib/test/test_trace.py
@@ -1,586 +1,586 @@
-import os
-import sys
-from test.support import captured_stdout
-from test.support.os_helper import TESTFN, TESTFN_UNICODE, FS_NONASCII, rmtree, unlink
-from test.support.script_helper import assert_python_ok, assert_python_failure
-import textwrap
-import unittest
-
-import trace
-from trace import Trace
-
-from test.tracedmodules import testmod
-
-#------------------------------- Utilities -----------------------------------#
-
-def fix_ext_py(filename):
- """Given a .pyc filename converts it to the appropriate .py"""
- if filename.endswith('.pyc'):
- filename = filename[:-1]
- return filename
-
-def my_file_and_modname():
- """The .py file and module name of this file (__file__)"""
- modname = os.path.splitext(os.path.basename(__file__))[0]
- return fix_ext_py(__file__), modname
-
-def get_firstlineno(func):
- return func.__code__.co_firstlineno
-
-#-------------------- Target functions for tracing ---------------------------#
-#
-# The relative line numbers of lines in these functions matter for verifying
-# tracing. Please modify the appropriate tests if you change one of the
-# functions. Absolute line numbers don't matter.
-#
-
-def traced_func_linear(x, y):
- a = x
- b = y
- c = a + b
- return c
-
-def traced_func_loop(x, y):
- c = x
- for i in range(5):
- c += y
- return c
-
-def traced_func_importing(x, y):
- return x + y + testmod.func(1)
-
-def traced_func_simple_caller(x):
- c = traced_func_linear(x, x)
- return c + x
-
-def traced_func_importing_caller(x):
- k = traced_func_simple_caller(x)
- k += traced_func_importing(k, x)
- return k
-
-def traced_func_generator(num):
- c = 5 # executed once
- for i in range(num):
- yield i + c
-
-def traced_func_calling_generator():
- k = 0
- for i in traced_func_generator(10):
- k += i
-
-def traced_doubler(num):
- return num * 2
-
-def traced_capturer(*args, **kwargs):
- return args, kwargs
-
-def traced_caller_list_comprehension():
- k = 10
- mylist = [traced_doubler(i) for i in range(k)]
- return mylist
-
-def traced_decorated_function():
- def decorator1(f):
- return f
- def decorator_fabric():
- def decorator2(f):
- return f
- return decorator2
- @decorator1
- @decorator_fabric()
- def func():
- pass
- func()
-
-
-class TracedClass(object):
- def __init__(self, x):
- self.a = x
-
- def inst_method_linear(self, y):
- return self.a + y
-
- def inst_method_calling(self, x):
- c = self.inst_method_linear(x)
- return c + traced_func_linear(x, c)
-
- @classmethod
- def class_method_linear(cls, y):
- return y * 2
-
- @staticmethod
- def static_method_linear(y):
- return y * 2
-
-
-#------------------------------ Test cases -----------------------------------#
-
-
-class TestLineCounts(unittest.TestCase):
- """White-box testing of line-counting, via runfunc"""
- def setUp(self):
- self.addCleanup(sys.settrace, sys.gettrace())
- self.tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0)
- self.my_py_filename = fix_ext_py(__file__)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_traced_func_linear(self):
- result = self.tracer.runfunc(traced_func_linear, 2, 5)
- self.assertEqual(result, 7)
-
- # all lines are executed once
- expected = {}
- firstlineno = get_firstlineno(traced_func_linear)
- for i in range(1, 5):
- expected[(self.my_py_filename, firstlineno + i)] = 1
-
- self.assertEqual(self.tracer.results().counts, expected)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_traced_func_loop(self):
- self.tracer.runfunc(traced_func_loop, 2, 3)
-
- firstlineno = get_firstlineno(traced_func_loop)
- expected = {
- (self.my_py_filename, firstlineno + 1): 1,
- (self.my_py_filename, firstlineno + 2): 6,
- (self.my_py_filename, firstlineno + 3): 5,
- (self.my_py_filename, firstlineno + 4): 1,
- }
- self.assertEqual(self.tracer.results().counts, expected)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_traced_func_importing(self):
- self.tracer.runfunc(traced_func_importing, 2, 5)
-
- firstlineno = get_firstlineno(traced_func_importing)
- expected = {
- (self.my_py_filename, firstlineno + 1): 1,
- (fix_ext_py(testmod.__file__), 2): 1,
- (fix_ext_py(testmod.__file__), 3): 1,
- }
-
- self.assertEqual(self.tracer.results().counts, expected)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_trace_func_generator(self):
- self.tracer.runfunc(traced_func_calling_generator)
-
- firstlineno_calling = get_firstlineno(traced_func_calling_generator)
- firstlineno_gen = get_firstlineno(traced_func_generator)
- expected = {
- (self.my_py_filename, firstlineno_calling + 1): 1,
- (self.my_py_filename, firstlineno_calling + 2): 11,
- (self.my_py_filename, firstlineno_calling + 3): 10,
- (self.my_py_filename, firstlineno_gen + 1): 1,
- (self.my_py_filename, firstlineno_gen + 2): 11,
- (self.my_py_filename, firstlineno_gen + 3): 10,
- }
- self.assertEqual(self.tracer.results().counts, expected)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_trace_list_comprehension(self):
- self.tracer.runfunc(traced_caller_list_comprehension)
-
- firstlineno_calling = get_firstlineno(traced_caller_list_comprehension)
- firstlineno_called = get_firstlineno(traced_doubler)
- expected = {
- (self.my_py_filename, firstlineno_calling + 1): 1,
- # List compehentions work differently in 3.x, so the count
- # below changed compared to 2.x.
- (self.my_py_filename, firstlineno_calling + 2): 12,
- (self.my_py_filename, firstlineno_calling + 3): 1,
- (self.my_py_filename, firstlineno_called + 1): 10,
- }
- self.assertEqual(self.tracer.results().counts, expected)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_traced_decorated_function(self):
- self.tracer.runfunc(traced_decorated_function)
-
- firstlineno = get_firstlineno(traced_decorated_function)
- expected = {
- (self.my_py_filename, firstlineno + 1): 1,
- (self.my_py_filename, firstlineno + 2): 1,
- (self.my_py_filename, firstlineno + 3): 1,
- (self.my_py_filename, firstlineno + 4): 1,
- (self.my_py_filename, firstlineno + 5): 1,
- (self.my_py_filename, firstlineno + 6): 1,
- (self.my_py_filename, firstlineno + 7): 1,
- (self.my_py_filename, firstlineno + 8): 1,
- (self.my_py_filename, firstlineno + 9): 1,
- (self.my_py_filename, firstlineno + 10): 1,
- (self.my_py_filename, firstlineno + 11): 1,
- }
- self.assertEqual(self.tracer.results().counts, expected)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_linear_methods(self):
- # XXX todo: later add 'static_method_linear' and 'class_method_linear'
- # here, once issue1764286 is resolved
- #
- for methname in ['inst_method_linear',]:
- tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0)
- traced_obj = TracedClass(25)
- method = getattr(traced_obj, methname)
- tracer.runfunc(method, 20)
-
- firstlineno = get_firstlineno(method)
- expected = {
- (self.my_py_filename, firstlineno + 1): 1,
- }
- self.assertEqual(tracer.results().counts, expected)
-
-
-class TestRunExecCounts(unittest.TestCase):
- """A simple sanity test of line-counting, via runctx (exec)"""
- def setUp(self):
- self.my_py_filename = fix_ext_py(__file__)
- self.addCleanup(sys.settrace, sys.gettrace())
-
- # TODO: RUSTPYTHON, KeyError: ('Lib/test/test_trace.py', 43)
- @unittest.expectedFailure
- def test_exec_counts(self):
- self.tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0)
- code = r'''traced_func_loop(2, 5)'''
- code = compile(code, __file__, 'exec')
- self.tracer.runctx(code, globals(), vars())
-
- firstlineno = get_firstlineno(traced_func_loop)
- expected = {
- (self.my_py_filename, firstlineno + 1): 1,
- (self.my_py_filename, firstlineno + 2): 6,
- (self.my_py_filename, firstlineno + 3): 5,
- (self.my_py_filename, firstlineno + 4): 1,
- }
-
- # When used through 'run', some other spurious counts are produced, like
- # the settrace of threading, which we ignore, just making sure that the
- # counts fo traced_func_loop were right.
- #
- for k in expected.keys():
- self.assertEqual(self.tracer.results().counts[k], expected[k])
-
-
-class TestFuncs(unittest.TestCase):
- """White-box testing of funcs tracing"""
- def setUp(self):
- self.addCleanup(sys.settrace, sys.gettrace())
- self.tracer = Trace(count=0, trace=0, countfuncs=1)
- self.filemod = my_file_and_modname()
- self._saved_tracefunc = sys.gettrace()
-
- def tearDown(self):
- if self._saved_tracefunc is not None:
- sys.settrace(self._saved_tracefunc)
-
- # TODO: RUSTPYTHON, gc
- @unittest.expectedFailure
- def test_simple_caller(self):
- self.tracer.runfunc(traced_func_simple_caller, 1)
-
- expected = {
- self.filemod + ('traced_func_simple_caller',): 1,
- self.filemod + ('traced_func_linear',): 1,
- }
- self.assertEqual(self.tracer.results().calledfuncs, expected)
-
- # TODO: RUSTPYTHON, gc
- @unittest.expectedFailure
- def test_arg_errors(self):
- res = self.tracer.runfunc(traced_capturer, 1, 2, self=3, func=4)
- self.assertEqual(res, ((1, 2), {'self': 3, 'func': 4}))
- with self.assertWarns(DeprecationWarning):
- res = self.tracer.runfunc(func=traced_capturer, arg=1)
- self.assertEqual(res, ((), {'arg': 1}))
- with self.assertRaises(TypeError):
- self.tracer.runfunc()
-
- # TODO: RUSTPYTHON, gc
- @unittest.expectedFailure
- def test_loop_caller_importing(self):
- self.tracer.runfunc(traced_func_importing_caller, 1)
-
- expected = {
- self.filemod + ('traced_func_simple_caller',): 1,
- self.filemod + ('traced_func_linear',): 1,
- self.filemod + ('traced_func_importing_caller',): 1,
- self.filemod + ('traced_func_importing',): 1,
- (fix_ext_py(testmod.__file__), 'testmod', 'func'): 1,
- }
- self.assertEqual(self.tracer.results().calledfuncs, expected)
-
- # TODO: RUSTPYTHON, gc
- @unittest.expectedFailure
- @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
- 'pre-existing trace function throws off measurements')
- def test_inst_method_calling(self):
- obj = TracedClass(20)
- self.tracer.runfunc(obj.inst_method_calling, 1)
-
- expected = {
- self.filemod + ('TracedClass.inst_method_calling',): 1,
- self.filemod + ('TracedClass.inst_method_linear',): 1,
- self.filemod + ('traced_func_linear',): 1,
- }
- self.assertEqual(self.tracer.results().calledfuncs, expected)
-
- # TODO: RUSTPYTHON, gc
- @unittest.expectedFailure
- def test_traced_decorated_function(self):
- self.tracer.runfunc(traced_decorated_function)
-
- expected = {
- self.filemod + ('traced_decorated_function',): 1,
- self.filemod + ('decorator_fabric',): 1,
- self.filemod + ('decorator2',): 1,
- self.filemod + ('decorator1',): 1,
- self.filemod + ('func',): 1,
- }
- self.assertEqual(self.tracer.results().calledfuncs, expected)
-
-
-class TestCallers(unittest.TestCase):
- """White-box testing of callers tracing"""
- def setUp(self):
- self.addCleanup(sys.settrace, sys.gettrace())
- self.tracer = Trace(count=0, trace=0, countcallers=1)
- self.filemod = my_file_and_modname()
-
- @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
- 'pre-existing trace function throws off measurements')
- @unittest.skip("TODO: RUSTPYTHON, Error in atexit._run_exitfuncs")
- def test_loop_caller_importing(self):
- self.tracer.runfunc(traced_func_importing_caller, 1)
-
- expected = {
- ((os.path.splitext(trace.__file__)[0] + '.py', 'trace', 'Trace.runfunc'),
- (self.filemod + ('traced_func_importing_caller',))): 1,
- ((self.filemod + ('traced_func_simple_caller',)),
- (self.filemod + ('traced_func_linear',))): 1,
- ((self.filemod + ('traced_func_importing_caller',)),
- (self.filemod + ('traced_func_simple_caller',))): 1,
- ((self.filemod + ('traced_func_importing_caller',)),
- (self.filemod + ('traced_func_importing',))): 1,
- ((self.filemod + ('traced_func_importing',)),
- (fix_ext_py(testmod.__file__), 'testmod', 'func')): 1,
- }
- self.assertEqual(self.tracer.results().callers, expected)
-
-
-# Created separately for issue #3821
-class TestCoverage(unittest.TestCase):
- def setUp(self):
- self.addCleanup(sys.settrace, sys.gettrace())
-
- def tearDown(self):
- rmtree(TESTFN)
- unlink(TESTFN)
-
- def _coverage(self, tracer,
- cmd='import test.support, test.test_pprint;'
- 'test.support.run_unittest(test.test_pprint.QueryTestCase)'):
- tracer.run(cmd)
- r = tracer.results()
- r.write_results(show_missing=True, summary=True, coverdir=TESTFN)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_coverage(self):
- tracer = trace.Trace(trace=0, count=1)
- with captured_stdout() as stdout:
- self._coverage(tracer)
- stdout = stdout.getvalue()
- self.assertIn("pprint.py", stdout)
- self.assertIn("case.py", stdout) # from unittest
- files = os.listdir(TESTFN)
- self.assertIn("pprint.cover", files)
- self.assertIn("unittest.case.cover", files)
-
- def test_coverage_ignore(self):
- # Ignore all files, nothing should be traced nor printed
- libpath = os.path.normpath(os.path.dirname(os.__file__))
- # sys.prefix does not work when running from a checkout
- tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,
- libpath], trace=0, count=1)
- with captured_stdout() as stdout:
- self._coverage(tracer)
- if os.path.exists(TESTFN):
- files = os.listdir(TESTFN)
- self.assertEqual(files, ['_importlib.cover']) # Ignore __import__
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_issue9936(self):
- tracer = trace.Trace(trace=0, count=1)
- modname = 'test.tracedmodules.testmod'
- # Ensure that the module is executed in import
- if modname in sys.modules:
- del sys.modules[modname]
- cmd = ("import test.tracedmodules.testmod as t;"
- "t.func(0); t.func2();")
- with captured_stdout() as stdout:
- self._coverage(tracer, cmd)
- stdout.seek(0)
- stdout.readline()
- coverage = {}
- for line in stdout:
- lines, cov, module = line.split()[:3]
- coverage[module] = (int(lines), int(cov[:-1]))
- # XXX This is needed to run regrtest.py as a script
- modname = trace._fullmodname(sys.modules[modname].__file__)
- self.assertIn(modname, coverage)
- self.assertEqual(coverage[modname], (5, 100))
-
-### Tests that don't mess with sys.settrace and can be traced
-### themselves TODO: Skip tests that do mess with sys.settrace when
-### regrtest is invoked with -T option.
-class Test_Ignore(unittest.TestCase):
- def test_ignored(self):
- jn = os.path.join
- ignore = trace._Ignore(['x', 'y.z'], [jn('foo', 'bar')])
- self.assertTrue(ignore.names('x.py', 'x'))
- self.assertFalse(ignore.names('xy.py', 'xy'))
- self.assertFalse(ignore.names('y.py', 'y'))
- self.assertTrue(ignore.names(jn('foo', 'bar', 'baz.py'), 'baz'))
- self.assertFalse(ignore.names(jn('bar', 'z.py'), 'z'))
- # Matched before.
- self.assertTrue(ignore.names(jn('bar', 'baz.py'), 'baz'))
-
-# Created for Issue 31908 -- CLI utility not writing cover files
-class TestCoverageCommandLineOutput(unittest.TestCase):
-
- codefile = 'tmp.py'
- coverfile = 'tmp.cover'
-
- def setUp(self):
- with open(self.codefile, 'w', encoding='iso-8859-15') as f:
- f.write(textwrap.dedent('''\
- # coding: iso-8859-15
- x = 'spœm'
- if []:
- print('unreachable')
- '''))
-
- def tearDown(self):
- unlink(self.codefile)
- unlink(self.coverfile)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_cover_files_written_no_highlight(self):
- # Test also that the cover file for the trace module is not created
- # (issue #34171).
- tracedir = os.path.dirname(os.path.abspath(trace.__file__))
- tracecoverpath = os.path.join(tracedir, 'trace.cover')
- unlink(tracecoverpath)
-
- argv = '-m trace --count'.split() + [self.codefile]
- status, stdout, stderr = assert_python_ok(*argv)
- self.assertEqual(stderr, b'')
- self.assertFalse(os.path.exists(tracecoverpath))
- self.assertTrue(os.path.exists(self.coverfile))
- with open(self.coverfile, encoding='iso-8859-15') as f:
- self.assertEqual(f.read(),
- " # coding: iso-8859-15\n"
- " 1: x = 'spœm'\n"
- " 1: if []:\n"
- " print('unreachable')\n"
- )
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_cover_files_written_with_highlight(self):
- argv = '-m trace --count --missing'.split() + [self.codefile]
- status, stdout, stderr = assert_python_ok(*argv)
- self.assertTrue(os.path.exists(self.coverfile))
- with open(self.coverfile, encoding='iso-8859-15') as f:
- self.assertEqual(f.read(), textwrap.dedent('''\
- # coding: iso-8859-15
- 1: x = 'spœm'
- 1: if []:
- >>>>>> print('unreachable')
- '''))
-
-class TestCommandLine(unittest.TestCase):
-
- def test_failures(self):
- _errors = (
- (b'progname is missing: required with the main options', '-l', '-T'),
- (b'cannot specify both --listfuncs and (--trace or --count)', '-lc'),
- (b'argument -R/--no-report: not allowed with argument -r/--report', '-rR'),
- (b'must specify one of --trace, --count, --report, --listfuncs, or --trackcalls', '-g'),
- (b'-r/--report requires -f/--file', '-r'),
- (b'--summary can only be used with --count or --report', '-sT'),
- (b'unrecognized arguments: -y', '-y'))
- for message, *args in _errors:
- *_, stderr = assert_python_failure('-m', 'trace', *args)
- self.assertIn(message, stderr)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_listfuncs_flag_success(self):
- filename = TESTFN + '.py'
- modulename = os.path.basename(TESTFN)
- with open(filename, 'w', encoding='utf-8') as fd:
- self.addCleanup(unlink, filename)
- fd.write("a = 1\n")
- status, stdout, stderr = assert_python_ok('-m', 'trace', '-l', filename,
- PYTHONIOENCODING='utf-8')
- self.assertIn(b'functions called:', stdout)
- expected = f'filename: {filename}, modulename: {modulename}, funcname: '
- self.assertIn(expected.encode(), stdout)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_sys_argv_list(self):
- with open(TESTFN, 'w', encoding='utf-8') as fd:
- self.addCleanup(unlink, TESTFN)
- fd.write("import sys\n")
- fd.write("print(type(sys.argv))\n")
-
- status, direct_stdout, stderr = assert_python_ok(TESTFN)
- status, trace_stdout, stderr = assert_python_ok('-m', 'trace', '-l', TESTFN)
- self.assertIn(direct_stdout.strip(), trace_stdout)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_count_and_summary(self):
- filename = f'{TESTFN}.py'
- coverfilename = f'{TESTFN}.cover'
- modulename = os.path.basename(TESTFN)
- with open(filename, 'w', encoding='utf-8') as fd:
- self.addCleanup(unlink, filename)
- self.addCleanup(unlink, coverfilename)
- fd.write(textwrap.dedent("""\
- x = 1
- y = 2
-
- def f():
- return x + y
-
- for i in range(10):
- f()
- """))
- status, stdout, _ = assert_python_ok('-m', 'trace', '-cs', filename)
- stdout = stdout.decode()
- self.assertEqual(status, 0)
- self.assertIn('lines cov% module (path)', stdout)
- self.assertIn(f'6 100% {modulename} ({filename})', stdout)
-
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_run_as_module(self):
- assert_python_ok('-m', 'trace', '-l', '--module', 'timeit', '-n', '1')
- assert_python_failure('-m', 'trace', '-l', '--module', 'not_a_module_zzz')
-
-
-if __name__ == '__main__':
- unittest.main()
+import os
+import sys
+from test.support import captured_stdout
+from test.support.os_helper import TESTFN, TESTFN_UNICODE, FS_NONASCII, rmtree, unlink
+from test.support.script_helper import assert_python_ok, assert_python_failure
+import textwrap
+import unittest
+
+import trace
+from trace import Trace
+
+from test.tracedmodules import testmod
+
+#------------------------------- Utilities -----------------------------------#
+
+def fix_ext_py(filename):
+ """Given a .pyc filename converts it to the appropriate .py"""
+ if filename.endswith('.pyc'):
+ filename = filename[:-1]
+ return filename
+
+def my_file_and_modname():
+ """The .py file and module name of this file (__file__)"""
+ modname = os.path.splitext(os.path.basename(__file__))[0]
+ return fix_ext_py(__file__), modname
+
+def get_firstlineno(func):
+ return func.__code__.co_firstlineno
+
+#-------------------- Target functions for tracing ---------------------------#
+#
+# The relative line numbers of lines in these functions matter for verifying
+# tracing. Please modify the appropriate tests if you change one of the
+# functions. Absolute line numbers don't matter.
+#
+
+def traced_func_linear(x, y):
+ a = x
+ b = y
+ c = a + b
+ return c
+
+def traced_func_loop(x, y):
+ c = x
+ for i in range(5):
+ c += y
+ return c
+
+def traced_func_importing(x, y):
+ return x + y + testmod.func(1)
+
+def traced_func_simple_caller(x):
+ c = traced_func_linear(x, x)
+ return c + x
+
+def traced_func_importing_caller(x):
+ k = traced_func_simple_caller(x)
+ k += traced_func_importing(k, x)
+ return k
+
+def traced_func_generator(num):
+ c = 5 # executed once
+ for i in range(num):
+ yield i + c
+
+def traced_func_calling_generator():
+ k = 0
+ for i in traced_func_generator(10):
+ k += i
+
+def traced_doubler(num):
+ return num * 2
+
+def traced_capturer(*args, **kwargs):
+ return args, kwargs
+
+def traced_caller_list_comprehension():
+ k = 10
+ mylist = [traced_doubler(i) for i in range(k)]
+ return mylist
+
+def traced_decorated_function():
+ def decorator1(f):
+ return f
+ def decorator_fabric():
+ def decorator2(f):
+ return f
+ return decorator2
+ @decorator1
+ @decorator_fabric()
+ def func():
+ pass
+ func()
+
+
+class TracedClass(object):
+ def __init__(self, x):
+ self.a = x
+
+ def inst_method_linear(self, y):
+ return self.a + y
+
+ def inst_method_calling(self, x):
+ c = self.inst_method_linear(x)
+ return c + traced_func_linear(x, c)
+
+ @classmethod
+ def class_method_linear(cls, y):
+ return y * 2
+
+ @staticmethod
+ def static_method_linear(y):
+ return y * 2
+
+
+#------------------------------ Test cases -----------------------------------#
+
+
+class TestLineCounts(unittest.TestCase):
+ """White-box testing of line-counting, via runfunc"""
+ def setUp(self):
+ self.addCleanup(sys.settrace, sys.gettrace())
+ self.tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0)
+ self.my_py_filename = fix_ext_py(__file__)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_traced_func_linear(self):
+ result = self.tracer.runfunc(traced_func_linear, 2, 5)
+ self.assertEqual(result, 7)
+
+ # all lines are executed once
+ expected = {}
+ firstlineno = get_firstlineno(traced_func_linear)
+ for i in range(1, 5):
+ expected[(self.my_py_filename, firstlineno + i)] = 1
+
+ self.assertEqual(self.tracer.results().counts, expected)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_traced_func_loop(self):
+ self.tracer.runfunc(traced_func_loop, 2, 3)
+
+ firstlineno = get_firstlineno(traced_func_loop)
+ expected = {
+ (self.my_py_filename, firstlineno + 1): 1,
+ (self.my_py_filename, firstlineno + 2): 6,
+ (self.my_py_filename, firstlineno + 3): 5,
+ (self.my_py_filename, firstlineno + 4): 1,
+ }
+ self.assertEqual(self.tracer.results().counts, expected)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_traced_func_importing(self):
+ self.tracer.runfunc(traced_func_importing, 2, 5)
+
+ firstlineno = get_firstlineno(traced_func_importing)
+ expected = {
+ (self.my_py_filename, firstlineno + 1): 1,
+ (fix_ext_py(testmod.__file__), 2): 1,
+ (fix_ext_py(testmod.__file__), 3): 1,
+ }
+
+ self.assertEqual(self.tracer.results().counts, expected)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_trace_func_generator(self):
+ self.tracer.runfunc(traced_func_calling_generator)
+
+ firstlineno_calling = get_firstlineno(traced_func_calling_generator)
+ firstlineno_gen = get_firstlineno(traced_func_generator)
+ expected = {
+ (self.my_py_filename, firstlineno_calling + 1): 1,
+ (self.my_py_filename, firstlineno_calling + 2): 11,
+ (self.my_py_filename, firstlineno_calling + 3): 10,
+ (self.my_py_filename, firstlineno_gen + 1): 1,
+ (self.my_py_filename, firstlineno_gen + 2): 11,
+ (self.my_py_filename, firstlineno_gen + 3): 10,
+ }
+ self.assertEqual(self.tracer.results().counts, expected)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_trace_list_comprehension(self):
+ self.tracer.runfunc(traced_caller_list_comprehension)
+
+ firstlineno_calling = get_firstlineno(traced_caller_list_comprehension)
+ firstlineno_called = get_firstlineno(traced_doubler)
+ expected = {
+ (self.my_py_filename, firstlineno_calling + 1): 1,
+ # List compehentions work differently in 3.x, so the count
+ # below changed compared to 2.x.
+ (self.my_py_filename, firstlineno_calling + 2): 12,
+ (self.my_py_filename, firstlineno_calling + 3): 1,
+ (self.my_py_filename, firstlineno_called + 1): 10,
+ }
+ self.assertEqual(self.tracer.results().counts, expected)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_traced_decorated_function(self):
+ self.tracer.runfunc(traced_decorated_function)
+
+ firstlineno = get_firstlineno(traced_decorated_function)
+ expected = {
+ (self.my_py_filename, firstlineno + 1): 1,
+ (self.my_py_filename, firstlineno + 2): 1,
+ (self.my_py_filename, firstlineno + 3): 1,
+ (self.my_py_filename, firstlineno + 4): 1,
+ (self.my_py_filename, firstlineno + 5): 1,
+ (self.my_py_filename, firstlineno + 6): 1,
+ (self.my_py_filename, firstlineno + 7): 1,
+ (self.my_py_filename, firstlineno + 8): 1,
+ (self.my_py_filename, firstlineno + 9): 1,
+ (self.my_py_filename, firstlineno + 10): 1,
+ (self.my_py_filename, firstlineno + 11): 1,
+ }
+ self.assertEqual(self.tracer.results().counts, expected)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_linear_methods(self):
+ # XXX todo: later add 'static_method_linear' and 'class_method_linear'
+ # here, once issue1764286 is resolved
+ #
+ for methname in ['inst_method_linear',]:
+ tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0)
+ traced_obj = TracedClass(25)
+ method = getattr(traced_obj, methname)
+ tracer.runfunc(method, 20)
+
+ firstlineno = get_firstlineno(method)
+ expected = {
+ (self.my_py_filename, firstlineno + 1): 1,
+ }
+ self.assertEqual(tracer.results().counts, expected)
+
+
+class TestRunExecCounts(unittest.TestCase):
+ """A simple sanity test of line-counting, via runctx (exec)"""
+ def setUp(self):
+ self.my_py_filename = fix_ext_py(__file__)
+ self.addCleanup(sys.settrace, sys.gettrace())
+
+ # TODO: RUSTPYTHON, KeyError: ('Lib/test/test_trace.py', 43)
+ @unittest.expectedFailure
+ def test_exec_counts(self):
+ self.tracer = Trace(count=1, trace=0, countfuncs=0, countcallers=0)
+ code = r'''traced_func_loop(2, 5)'''
+ code = compile(code, __file__, 'exec')
+ self.tracer.runctx(code, globals(), vars())
+
+ firstlineno = get_firstlineno(traced_func_loop)
+ expected = {
+ (self.my_py_filename, firstlineno + 1): 1,
+ (self.my_py_filename, firstlineno + 2): 6,
+ (self.my_py_filename, firstlineno + 3): 5,
+ (self.my_py_filename, firstlineno + 4): 1,
+ }
+
+ # When used through 'run', some other spurious counts are produced, like
+ # the settrace of threading, which we ignore, just making sure that the
+ # counts fo traced_func_loop were right.
+ #
+ for k in expected.keys():
+ self.assertEqual(self.tracer.results().counts[k], expected[k])
+
+
+class TestFuncs(unittest.TestCase):
+ """White-box testing of funcs tracing"""
+ def setUp(self):
+ self.addCleanup(sys.settrace, sys.gettrace())
+ self.tracer = Trace(count=0, trace=0, countfuncs=1)
+ self.filemod = my_file_and_modname()
+ self._saved_tracefunc = sys.gettrace()
+
+ def tearDown(self):
+ if self._saved_tracefunc is not None:
+ sys.settrace(self._saved_tracefunc)
+
+ # TODO: RUSTPYTHON, gc
+ @unittest.expectedFailure
+ def test_simple_caller(self):
+ self.tracer.runfunc(traced_func_simple_caller, 1)
+
+ expected = {
+ self.filemod + ('traced_func_simple_caller',): 1,
+ self.filemod + ('traced_func_linear',): 1,
+ }
+ self.assertEqual(self.tracer.results().calledfuncs, expected)
+
+ # TODO: RUSTPYTHON, gc
+ @unittest.expectedFailure
+ def test_arg_errors(self):
+ res = self.tracer.runfunc(traced_capturer, 1, 2, self=3, func=4)
+ self.assertEqual(res, ((1, 2), {'self': 3, 'func': 4}))
+ with self.assertWarns(DeprecationWarning):
+ res = self.tracer.runfunc(func=traced_capturer, arg=1)
+ self.assertEqual(res, ((), {'arg': 1}))
+ with self.assertRaises(TypeError):
+ self.tracer.runfunc()
+
+ # TODO: RUSTPYTHON, gc
+ @unittest.expectedFailure
+ def test_loop_caller_importing(self):
+ self.tracer.runfunc(traced_func_importing_caller, 1)
+
+ expected = {
+ self.filemod + ('traced_func_simple_caller',): 1,
+ self.filemod + ('traced_func_linear',): 1,
+ self.filemod + ('traced_func_importing_caller',): 1,
+ self.filemod + ('traced_func_importing',): 1,
+ (fix_ext_py(testmod.__file__), 'testmod', 'func'): 1,
+ }
+ self.assertEqual(self.tracer.results().calledfuncs, expected)
+
+ # TODO: RUSTPYTHON, gc
+ @unittest.expectedFailure
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'pre-existing trace function throws off measurements')
+ def test_inst_method_calling(self):
+ obj = TracedClass(20)
+ self.tracer.runfunc(obj.inst_method_calling, 1)
+
+ expected = {
+ self.filemod + ('TracedClass.inst_method_calling',): 1,
+ self.filemod + ('TracedClass.inst_method_linear',): 1,
+ self.filemod + ('traced_func_linear',): 1,
+ }
+ self.assertEqual(self.tracer.results().calledfuncs, expected)
+
+ # TODO: RUSTPYTHON, gc
+ @unittest.expectedFailure
+ def test_traced_decorated_function(self):
+ self.tracer.runfunc(traced_decorated_function)
+
+ expected = {
+ self.filemod + ('traced_decorated_function',): 1,
+ self.filemod + ('decorator_fabric',): 1,
+ self.filemod + ('decorator2',): 1,
+ self.filemod + ('decorator1',): 1,
+ self.filemod + ('func',): 1,
+ }
+ self.assertEqual(self.tracer.results().calledfuncs, expected)
+
+
+class TestCallers(unittest.TestCase):
+ """White-box testing of callers tracing"""
+ def setUp(self):
+ self.addCleanup(sys.settrace, sys.gettrace())
+ self.tracer = Trace(count=0, trace=0, countcallers=1)
+ self.filemod = my_file_and_modname()
+
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'pre-existing trace function throws off measurements')
+ @unittest.skip("TODO: RUSTPYTHON, Error in atexit._run_exitfuncs")
+ def test_loop_caller_importing(self):
+ self.tracer.runfunc(traced_func_importing_caller, 1)
+
+ expected = {
+ ((os.path.splitext(trace.__file__)[0] + '.py', 'trace', 'Trace.runfunc'),
+ (self.filemod + ('traced_func_importing_caller',))): 1,
+ ((self.filemod + ('traced_func_simple_caller',)),
+ (self.filemod + ('traced_func_linear',))): 1,
+ ((self.filemod + ('traced_func_importing_caller',)),
+ (self.filemod + ('traced_func_simple_caller',))): 1,
+ ((self.filemod + ('traced_func_importing_caller',)),
+ (self.filemod + ('traced_func_importing',))): 1,
+ ((self.filemod + ('traced_func_importing',)),
+ (fix_ext_py(testmod.__file__), 'testmod', 'func')): 1,
+ }
+ self.assertEqual(self.tracer.results().callers, expected)
+
+
+# Created separately for issue #3821
+class TestCoverage(unittest.TestCase):
+ def setUp(self):
+ self.addCleanup(sys.settrace, sys.gettrace())
+
+ def tearDown(self):
+ rmtree(TESTFN)
+ unlink(TESTFN)
+
+ def _coverage(self, tracer,
+ cmd='import test.support, test.test_pprint;'
+ 'test.support.run_unittest(test.test_pprint.QueryTestCase)'):
+ tracer.run(cmd)
+ r = tracer.results()
+ r.write_results(show_missing=True, summary=True, coverdir=TESTFN)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_coverage(self):
+ tracer = trace.Trace(trace=0, count=1)
+ with captured_stdout() as stdout:
+ self._coverage(tracer)
+ stdout = stdout.getvalue()
+ self.assertIn("pprint.py", stdout)
+ self.assertIn("case.py", stdout) # from unittest
+ files = os.listdir(TESTFN)
+ self.assertIn("pprint.cover", files)
+ self.assertIn("unittest.case.cover", files)
+
+ def test_coverage_ignore(self):
+ # Ignore all files, nothing should be traced nor printed
+ libpath = os.path.normpath(os.path.dirname(os.__file__))
+ # sys.prefix does not work when running from a checkout
+ tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,
+ libpath], trace=0, count=1)
+ with captured_stdout() as stdout:
+ self._coverage(tracer)
+ if os.path.exists(TESTFN):
+ files = os.listdir(TESTFN)
+ self.assertEqual(files, ['_importlib.cover']) # Ignore __import__
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_issue9936(self):
+ tracer = trace.Trace(trace=0, count=1)
+ modname = 'test.tracedmodules.testmod'
+ # Ensure that the module is executed in import
+ if modname in sys.modules:
+ del sys.modules[modname]
+ cmd = ("import test.tracedmodules.testmod as t;"
+ "t.func(0); t.func2();")
+ with captured_stdout() as stdout:
+ self._coverage(tracer, cmd)
+ stdout.seek(0)
+ stdout.readline()
+ coverage = {}
+ for line in stdout:
+ lines, cov, module = line.split()[:3]
+ coverage[module] = (int(lines), int(cov[:-1]))
+ # XXX This is needed to run regrtest.py as a script
+ modname = trace._fullmodname(sys.modules[modname].__file__)
+ self.assertIn(modname, coverage)
+ self.assertEqual(coverage[modname], (5, 100))
+
+### Tests that don't mess with sys.settrace and can be traced
+### themselves TODO: Skip tests that do mess with sys.settrace when
+### regrtest is invoked with -T option.
+class Test_Ignore(unittest.TestCase):
+ def test_ignored(self):
+ jn = os.path.join
+ ignore = trace._Ignore(['x', 'y.z'], [jn('foo', 'bar')])
+ self.assertTrue(ignore.names('x.py', 'x'))
+ self.assertFalse(ignore.names('xy.py', 'xy'))
+ self.assertFalse(ignore.names('y.py', 'y'))
+ self.assertTrue(ignore.names(jn('foo', 'bar', 'baz.py'), 'baz'))
+ self.assertFalse(ignore.names(jn('bar', 'z.py'), 'z'))
+ # Matched before.
+ self.assertTrue(ignore.names(jn('bar', 'baz.py'), 'baz'))
+
+# Created for Issue 31908 -- CLI utility not writing cover files
+class TestCoverageCommandLineOutput(unittest.TestCase):
+
+ codefile = 'tmp.py'
+ coverfile = 'tmp.cover'
+
+ def setUp(self):
+ with open(self.codefile, 'w', encoding='iso-8859-15') as f:
+ f.write(textwrap.dedent('''\
+ # coding: iso-8859-15
+ x = 'spœm'
+ if []:
+ print('unreachable')
+ '''))
+
+ def tearDown(self):
+ unlink(self.codefile)
+ unlink(self.coverfile)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_cover_files_written_no_highlight(self):
+ # Test also that the cover file for the trace module is not created
+ # (issue #34171).
+ tracedir = os.path.dirname(os.path.abspath(trace.__file__))
+ tracecoverpath = os.path.join(tracedir, 'trace.cover')
+ unlink(tracecoverpath)
+
+ argv = '-m trace --count'.split() + [self.codefile]
+ status, stdout, stderr = assert_python_ok(*argv)
+ self.assertEqual(stderr, b'')
+ self.assertFalse(os.path.exists(tracecoverpath))
+ self.assertTrue(os.path.exists(self.coverfile))
+ with open(self.coverfile, encoding='iso-8859-15') as f:
+ self.assertEqual(f.read(),
+ " # coding: iso-8859-15\n"
+ " 1: x = 'spœm'\n"
+ " 1: if []:\n"
+ " print('unreachable')\n"
+ )
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_cover_files_written_with_highlight(self):
+ argv = '-m trace --count --missing'.split() + [self.codefile]
+ status, stdout, stderr = assert_python_ok(*argv)
+ self.assertTrue(os.path.exists(self.coverfile))
+ with open(self.coverfile, encoding='iso-8859-15') as f:
+ self.assertEqual(f.read(), textwrap.dedent('''\
+ # coding: iso-8859-15
+ 1: x = 'spœm'
+ 1: if []:
+ >>>>>> print('unreachable')
+ '''))
+
+class TestCommandLine(unittest.TestCase):
+
+ def test_failures(self):
+ _errors = (
+ (b'progname is missing: required with the main options', '-l', '-T'),
+ (b'cannot specify both --listfuncs and (--trace or --count)', '-lc'),
+ (b'argument -R/--no-report: not allowed with argument -r/--report', '-rR'),
+ (b'must specify one of --trace, --count, --report, --listfuncs, or --trackcalls', '-g'),
+ (b'-r/--report requires -f/--file', '-r'),
+ (b'--summary can only be used with --count or --report', '-sT'),
+ (b'unrecognized arguments: -y', '-y'))
+ for message, *args in _errors:
+ *_, stderr = assert_python_failure('-m', 'trace', *args)
+ self.assertIn(message, stderr)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_listfuncs_flag_success(self):
+ filename = TESTFN + '.py'
+ modulename = os.path.basename(TESTFN)
+ with open(filename, 'w', encoding='utf-8') as fd:
+ self.addCleanup(unlink, filename)
+ fd.write("a = 1\n")
+ status, stdout, stderr = assert_python_ok('-m', 'trace', '-l', filename,
+ PYTHONIOENCODING='utf-8')
+ self.assertIn(b'functions called:', stdout)
+ expected = f'filename: {filename}, modulename: {modulename}, funcname: '
+ self.assertIn(expected.encode(), stdout)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_sys_argv_list(self):
+ with open(TESTFN, 'w', encoding='utf-8') as fd:
+ self.addCleanup(unlink, TESTFN)
+ fd.write("import sys\n")
+ fd.write("print(type(sys.argv))\n")
+
+ status, direct_stdout, stderr = assert_python_ok(TESTFN)
+ status, trace_stdout, stderr = assert_python_ok('-m', 'trace', '-l', TESTFN)
+ self.assertIn(direct_stdout.strip(), trace_stdout)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_count_and_summary(self):
+ filename = f'{TESTFN}.py'
+ coverfilename = f'{TESTFN}.cover'
+ modulename = os.path.basename(TESTFN)
+ with open(filename, 'w', encoding='utf-8') as fd:
+ self.addCleanup(unlink, filename)
+ self.addCleanup(unlink, coverfilename)
+ fd.write(textwrap.dedent("""\
+ x = 1
+ y = 2
+
+ def f():
+ return x + y
+
+ for i in range(10):
+ f()
+ """))
+ status, stdout, _ = assert_python_ok('-m', 'trace', '-cs', filename)
+ stdout = stdout.decode()
+ self.assertEqual(status, 0)
+ self.assertIn('lines cov% module (path)', stdout)
+ self.assertIn(f'6 100% {modulename} ({filename})', stdout)
+
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_run_as_module(self):
+ assert_python_ok('-m', 'trace', '-l', '--module', 'timeit', '-n', '1')
+ assert_python_failure('-m', 'trace', '-l', '--module', 'not_a_module_zzz')
+
+
+if __name__ == '__main__':
+ unittest.main()
From c05e9cd52dfcea9608c53de506a4361c2138d6ff Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Sun, 14 Aug 2022 23:34:09 +0900
Subject: [PATCH 06/23] Update tarfile from CPython 3.10.6
---
Lib/tarfile.py | 10 ++
Lib/test/test_tarfile.py | 250 +++++++++++++++++++++++++++++++++++----
2 files changed, 238 insertions(+), 22 deletions(-)
diff --git a/Lib/tarfile.py b/Lib/tarfile.py
index 6ada9a05db..dea150e8db 100755
--- a/Lib/tarfile.py
+++ b/Lib/tarfile.py
@@ -1163,6 +1163,11 @@ def _proc_builtin(self, tarfile):
# header information.
self._apply_pax_info(tarfile.pax_headers, tarfile.encoding, tarfile.errors)
+ # Remove redundant slashes from directories. This is to be consistent
+ # with frombuf().
+ if self.isdir():
+ self.name = self.name.rstrip("/")
+
return self
def _proc_gnulong(self, tarfile):
@@ -1185,6 +1190,11 @@ def _proc_gnulong(self, tarfile):
elif self.type == GNUTYPE_LONGLINK:
next.linkname = nts(buf, tarfile.encoding, tarfile.errors)
+ # Remove redundant slashes from directories. This is to be consistent
+ # with frombuf().
+ if next.isdir():
+ next.name = next.name.removesuffix("/")
+
return next
def _proc_sparse(self, tarfile):
diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py
index 36838e7f0f..bfd8693998 100644
--- a/Lib/test/test_tarfile.py
+++ b/Lib/test/test_tarfile.py
@@ -220,6 +220,26 @@ def test_fileobj_symlink2(self):
def test_issue14160(self):
self._test_fileobj_link("symtype2", "ustar/regtype")
+ def test_add_dir_getmember(self):
+ # bpo-21987
+ self.add_dir_and_getmember('bar')
+ self.add_dir_and_getmember('a'*101)
+
+ def add_dir_and_getmember(self, name):
+ with os_helper.temp_cwd():
+ with tarfile.open(tmpname, 'w') as tar:
+ tar.format = tarfile.USTAR_FORMAT
+ try:
+ os.mkdir(name)
+ tar.add(name)
+ finally:
+ os.rmdir(name)
+ with tarfile.open(tmpname) as tar:
+ self.assertEqual(
+ tar.getmember(name),
+ tar.getmember(name + '/')
+ )
+
class GzipUstarReadTest(GzipTest, UstarReadTest):
pass
@@ -330,6 +350,38 @@ class LzmaListTest(LzmaTest, ListTest):
class CommonReadTest(ReadTest):
+ def test_is_tarfile_erroneous(self):
+ with open(tmpname, "wb"):
+ pass
+
+ # is_tarfile works on filenames
+ self.assertFalse(tarfile.is_tarfile(tmpname))
+
+ # is_tarfile works on path-like objects
+ self.assertFalse(tarfile.is_tarfile(pathlib.Path(tmpname)))
+
+ # is_tarfile works on file objects
+ with open(tmpname, "rb") as fobj:
+ self.assertFalse(tarfile.is_tarfile(fobj))
+
+ # is_tarfile works on file-like objects
+ self.assertFalse(tarfile.is_tarfile(io.BytesIO(b"invalid")))
+
+ def test_is_tarfile_valid(self):
+ # is_tarfile works on filenames
+ self.assertTrue(tarfile.is_tarfile(self.tarname))
+
+ # is_tarfile works on path-like objects
+ self.assertTrue(tarfile.is_tarfile(pathlib.Path(self.tarname)))
+
+ # is_tarfile works on file objects
+ with open(self.tarname, "rb") as fobj:
+ self.assertTrue(tarfile.is_tarfile(fobj))
+
+ # is_tarfile works on file-like objects
+ with open(self.tarname, "rb") as fobj:
+ self.assertTrue(tarfile.is_tarfile(io.BytesIO(fobj.read())))
+
def test_empty_tarfile(self):
# Test for issue6123: Allow opening empty archives.
# This test checks if tarfile.open() is able to open an empty tar
@@ -365,7 +417,7 @@ def test_null_tarfile(self):
def test_ignore_zeros(self):
# Test TarFile's ignore_zeros option.
# generate 512 pseudorandom bytes
- data = Random(0).getrandbits(512*8).to_bytes(512, 'big')
+ data = Random(0).randbytes(512)
for char in (b'\0', b'a'):
# Test if EOFHeaderError ('\0') and InvalidHeaderError ('a')
# are ignored correctly.
@@ -682,6 +734,16 @@ def test_parallel_iteration(self):
self.assertEqual(m1.offset, m2.offset)
self.assertEqual(m1.get_info(), m2.get_info())
+ @unittest.skipIf(zlib is None, "requires zlib")
+ def test_zlib_error_does_not_leak(self):
+ # bpo-39039: tarfile.open allowed zlib exceptions to bubble up when
+ # parsing certain types of invalid data
+ with unittest.mock.patch("tarfile.TarInfo.fromtarfile") as mock:
+ mock.side_effect = zlib.error
+ with self.assertRaises(tarfile.ReadError):
+ tarfile.open(self.tarname)
+
+
class MiscReadTest(MiscReadTestBase, unittest.TestCase):
test_fail_comp = None
@@ -968,11 +1030,26 @@ def test_header_offset(self):
"iso8859-1", "strict")
self.assertEqual(tarinfo.type, self.longnametype)
+ def test_longname_directory(self):
+ # Test reading a longlink directory. Issue #47231.
+ longdir = ('a' * 101) + '/'
+ with os_helper.temp_cwd():
+ with tarfile.open(tmpname, 'w') as tar:
+ tar.format = self.format
+ try:
+ os.mkdir(longdir)
+ tar.add(longdir)
+ finally:
+ os.rmdir(longdir)
+ with tarfile.open(tmpname) as tar:
+ self.assertIsNotNone(tar.getmember(longdir))
+ self.assertIsNotNone(tar.getmember(longdir.removesuffix('/')))
class GNUReadTest(LongnameTest, ReadTest, unittest.TestCase):
subdir = "gnu"
longnametype = tarfile.GNUTYPE_LONGNAME
+ format = tarfile.GNU_FORMAT
# Since 3.2 tarfile is supposed to accurately restore sparse members and
# produce files with holes. This is what we actually want to test here.
@@ -1048,6 +1125,7 @@ class PaxReadTest(LongnameTest, ReadTest, unittest.TestCase):
subdir = "pax"
longnametype = tarfile.XHDTYPE
+ format = tarfile.PAX_FORMAT
def test_pax_global_headers(self):
tar = tarfile.open(tarname, encoding="iso8859-1")
@@ -1600,6 +1678,52 @@ def test_longnamelink_1025(self):
("longlnk/" * 127) + "longlink_")
+class DeviceHeaderTest(WriteTestBase, unittest.TestCase):
+
+ prefix = "w:"
+
+ def test_headers_written_only_for_device_files(self):
+ # Regression test for bpo-18819.
+ tempdir = os.path.join(TEMPDIR, "device_header_test")
+ os.mkdir(tempdir)
+ try:
+ tar = tarfile.open(tmpname, self.mode)
+ try:
+ input_blk = tarfile.TarInfo(name="my_block_device")
+ input_reg = tarfile.TarInfo(name="my_regular_file")
+ input_blk.type = tarfile.BLKTYPE
+ input_reg.type = tarfile.REGTYPE
+ tar.addfile(input_blk)
+ tar.addfile(input_reg)
+ finally:
+ tar.close()
+
+ # devmajor and devminor should be *interpreted* as 0 in both...
+ tar = tarfile.open(tmpname, "r")
+ try:
+ output_blk = tar.getmember("my_block_device")
+ output_reg = tar.getmember("my_regular_file")
+ finally:
+ tar.close()
+ self.assertEqual(output_blk.devmajor, 0)
+ self.assertEqual(output_blk.devminor, 0)
+ self.assertEqual(output_reg.devmajor, 0)
+ self.assertEqual(output_reg.devminor, 0)
+
+ # ...but the fields should not actually be set on regular files:
+ with open(tmpname, "rb") as infile:
+ buf = infile.read()
+ buf_blk = buf[output_blk.offset:output_blk.offset_data]
+ buf_reg = buf[output_reg.offset:output_reg.offset_data]
+ # See `struct posixheader` in GNU docs for byte offsets:
+ #
+ device_headers = slice(329, 329 + 16)
+ self.assertEqual(buf_blk[device_headers], b"0000000\0" * 2)
+ self.assertEqual(buf_reg[device_headers], b"\0" * 16)
+ finally:
+ os_helper.rmtree(tempdir)
+
+
class CreateTest(WriteTestBase, unittest.TestCase):
prefix = "x:"
@@ -1691,15 +1815,30 @@ def test_create_taropen_pathlike_name(self):
class GzipCreateTest(GzipTest, CreateTest):
- pass
+
+ def test_create_with_compresslevel(self):
+ with tarfile.open(tmpname, self.mode, compresslevel=1) as tobj:
+ tobj.add(self.file_path)
+ with tarfile.open(tmpname, 'r:gz', compresslevel=1) as tobj:
+ pass
class Bz2CreateTest(Bz2Test, CreateTest):
- pass
+
+ def test_create_with_compresslevel(self):
+ with tarfile.open(tmpname, self.mode, compresslevel=1) as tobj:
+ tobj.add(self.file_path)
+ with tarfile.open(tmpname, 'r:bz2', compresslevel=1) as tobj:
+ pass
class LzmaCreateTest(LzmaTest, CreateTest):
- pass
+
+ # Unlike gz and bz2, xz uses the preset keyword instead of compresslevel.
+ # It does not allow for preset to be specified when reading.
+ def test_create_with_preset(self):
+ with tarfile.open(tmpname, self.mode, preset=1) as tobj:
+ tobj.add(self.file_path)
class CreateWithXModeTest(CreateTest):
@@ -1840,6 +1979,61 @@ def test_pax_extended_header(self):
finally:
tar.close()
+ def test_create_pax_header(self):
+ # The ustar header should contain values that can be
+ # represented reasonably, even if a better (e.g. higher
+ # precision) version is set in the pax header.
+ # Issue #45863
+
+ # values that should be kept
+ t = tarfile.TarInfo()
+ t.name = "foo"
+ t.mtime = 1000.1
+ t.size = 100
+ t.uid = 123
+ t.gid = 124
+ info = t.get_info()
+ header = t.create_pax_header(info, encoding="iso8859-1")
+ self.assertEqual(info['name'], "foo")
+ # mtime should be rounded to nearest second
+ self.assertIsInstance(info['mtime'], int)
+ self.assertEqual(info['mtime'], 1000)
+ self.assertEqual(info['size'], 100)
+ self.assertEqual(info['uid'], 123)
+ self.assertEqual(info['gid'], 124)
+ self.assertEqual(header,
+ b'././@PaxHeader' + bytes(86) \
+ + b'0000000\x000000000\x000000000\x0000000000020\x0000000000000\x00010205\x00 x' \
+ + bytes(100) + b'ustar\x0000'+ bytes(247) \
+ + b'16 mtime=1000.1\n' + bytes(496) + b'foo' + bytes(97) \
+ + b'0000644\x000000173\x000000174\x0000000000144\x0000000001750\x00006516\x00 0' \
+ + bytes(100) + b'ustar\x0000' + bytes(247))
+
+ # values that should be changed
+ t = tarfile.TarInfo()
+ t.name = "foo\u3374" # can't be represented in ascii
+ t.mtime = 10**10 # too big
+ t.size = 10**10 # too big
+ t.uid = 8**8 # too big
+ t.gid = 8**8+1 # too big
+ info = t.get_info()
+ header = t.create_pax_header(info, encoding="iso8859-1")
+ # name is kept as-is in info but should be added to pax header
+ self.assertEqual(info['name'], "foo\u3374")
+ self.assertEqual(info['mtime'], 0)
+ self.assertEqual(info['size'], 0)
+ self.assertEqual(info['uid'], 0)
+ self.assertEqual(info['gid'], 0)
+ self.assertEqual(header,
+ b'././@PaxHeader' + bytes(86) \
+ + b'0000000\x000000000\x000000000\x0000000000130\x0000000000000\x00010207\x00 x' \
+ + bytes(100) + b'ustar\x0000' + bytes(247) \
+ + b'15 path=foo\xe3\x8d\xb4\n16 uid=16777216\n' \
+ + b'16 gid=16777217\n20 size=10000000000\n' \
+ + b'21 mtime=10000000000\n'+ bytes(424) + b'foo?' + bytes(96) \
+ + b'0000644\x000000000\x000000000\x0000000000000\x0000000000000\x00006540\x00 0' \
+ + bytes(100) + b'ustar\x0000' + bytes(247))
+
class UnicodeTest:
@@ -2274,23 +2468,32 @@ def test_number_field_limits(self):
tarfile.itn(0x10000000000, 6, tarfile.GNU_FORMAT)
def test__all__(self):
- not_exported = {'version', 'grp', 'pwd', 'symlink_exception',
- 'NUL', 'BLOCKSIZE', 'RECORDSIZE', 'GNU_MAGIC',
- 'POSIX_MAGIC', 'LENGTH_NAME', 'LENGTH_LINK',
- 'LENGTH_PREFIX', 'REGTYPE', 'AREGTYPE', 'LNKTYPE',
- 'SYMTYPE', 'CHRTYPE', 'BLKTYPE', 'DIRTYPE', 'FIFOTYPE',
- 'CONTTYPE', 'GNUTYPE_LONGNAME', 'GNUTYPE_LONGLINK',
- 'GNUTYPE_SPARSE', 'XHDTYPE', 'XGLTYPE', 'SOLARIS_XHDTYPE',
- 'SUPPORTED_TYPES', 'REGULAR_TYPES', 'GNU_TYPES',
- 'PAX_FIELDS', 'PAX_NAME_FIELDS', 'PAX_NUMBER_FIELDS',
- 'stn', 'nts', 'nti', 'itn', 'calc_chksums', 'copyfileobj',
- 'filemode',
- 'EmptyHeaderError', 'TruncatedHeaderError',
- 'EOFHeaderError', 'InvalidHeaderError',
- 'SubsequentHeaderError', 'ExFileObject',
- 'main'}
+ not_exported = {
+ 'version', 'grp', 'pwd', 'symlink_exception', 'NUL', 'BLOCKSIZE',
+ 'RECORDSIZE', 'GNU_MAGIC', 'POSIX_MAGIC', 'LENGTH_NAME',
+ 'LENGTH_LINK', 'LENGTH_PREFIX', 'REGTYPE', 'AREGTYPE', 'LNKTYPE',
+ 'SYMTYPE', 'CHRTYPE', 'BLKTYPE', 'DIRTYPE', 'FIFOTYPE', 'CONTTYPE',
+ 'GNUTYPE_LONGNAME', 'GNUTYPE_LONGLINK', 'GNUTYPE_SPARSE',
+ 'XHDTYPE', 'XGLTYPE', 'SOLARIS_XHDTYPE', 'SUPPORTED_TYPES',
+ 'REGULAR_TYPES', 'GNU_TYPES', 'PAX_FIELDS', 'PAX_NAME_FIELDS',
+ 'PAX_NUMBER_FIELDS', 'stn', 'nts', 'nti', 'itn', 'calc_chksums',
+ 'copyfileobj', 'filemode', 'EmptyHeaderError',
+ 'TruncatedHeaderError', 'EOFHeaderError', 'InvalidHeaderError',
+ 'SubsequentHeaderError', 'ExFileObject', 'main'}
support.check__all__(self, tarfile, not_exported=not_exported)
+ def test_useful_error_message_when_modules_missing(self):
+ fname = os.path.join(os.path.dirname(__file__), 'testtar.tar.xz')
+ with self.assertRaises(tarfile.ReadError) as excinfo:
+ error = tarfile.CompressionError('lzma module is not available'),
+ with unittest.mock.patch.object(tarfile.TarFile, 'xzopen', side_effect=error):
+ tarfile.open(fname)
+
+ self.assertIn(
+ "\n- method xz: CompressionError('lzma module is not available')\n",
+ str(excinfo.exception),
+ )
+
class CommandLineTest(unittest.TestCase):
@@ -2330,7 +2533,8 @@ def test_test_command(self):
def test_test_command_verbose(self):
for tar_name in testtarnames:
for opt in '-v', '--verbose':
- out = self.tarfilecmd(opt, '-t', tar_name)
+ out = self.tarfilecmd(opt, '-t', tar_name,
+ PYTHONIOENCODING='utf-8')
self.assertIn(b'is a tar archive.\n', out)
def test_test_command_invalid_file(self):
@@ -2405,7 +2609,8 @@ def test_create_command_verbose(self):
'and-utf8-bom-sig-only.txt')]
for opt in '-v', '--verbose':
try:
- out = self.tarfilecmd(opt, '-c', tmpname, *files)
+ out = self.tarfilecmd(opt, '-c', tmpname, *files,
+ PYTHONIOENCODING='utf-8')
self.assertIn(b' file created.', out)
with tarfile.open(tmpname) as tar:
tar.getmembers()
@@ -2463,7 +2668,8 @@ def test_extract_command_verbose(self):
for opt in '-v', '--verbose':
try:
with os_helper.temp_cwd(tarextdir):
- out = self.tarfilecmd(opt, '-e', tmpname)
+ out = self.tarfilecmd(opt, '-e', tmpname,
+ PYTHONIOENCODING='utf-8')
self.assertIn(b' file is extracted.', out)
finally:
os_helper.rmtree(tarextdir)
From 43b095f18de9edc0a825800fc1620580249a24cb Mon Sep 17 00:00:00 2001
From: Jeong YunWon
Date: Sun, 14 Aug 2022 23:35:30 +0900
Subject: [PATCH 07/23] mark failing test
---
Lib/test/test_tarfile.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py
index bfd8693998..ea701cf96f 100644
--- a/Lib/test/test_tarfile.py
+++ b/Lib/test/test_tarfile.py
@@ -2482,6 +2482,8 @@ def test__all__(self):
'SubsequentHeaderError', 'ExFileObject', 'main'}
support.check__all__(self, tarfile, not_exported=not_exported)
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
def test_useful_error_message_when_modules_missing(self):
fname = os.path.join(os.path.dirname(__file__), 'testtar.tar.xz')
with self.assertRaises(tarfile.ReadError) as excinfo:
From a01511cf604e32712f056448f4e0ed1cb9ed3b43 Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Sun, 14 Aug 2022 23:37:03 +0900
Subject: [PATCH 08/23] Update test_structseq from CPython 3.10.6
---
Lib/test/test_structseq.py | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py
index 2a328c5d7f..013797e730 100644
--- a/Lib/test/test_structseq.py
+++ b/Lib/test/test_structseq.py
@@ -124,5 +124,17 @@ def test_extended_getslice(self):
self.assertEqual(list(t[start:stop:step]),
L[start:stop:step])
+ def test_match_args(self):
+ expected_args = ('tm_year', 'tm_mon', 'tm_mday', 'tm_hour', 'tm_min',
+ 'tm_sec', 'tm_wday', 'tm_yday', 'tm_isdst')
+ self.assertEqual(time.struct_time.__match_args__, expected_args)
+
+ def test_match_args_with_unnamed_fields(self):
+ expected_args = ('st_mode', 'st_ino', 'st_dev', 'st_nlink', 'st_uid',
+ 'st_gid', 'st_size')
+ self.assertEqual(os.stat_result.n_unnamed_fields, 3)
+ self.assertEqual(os.stat_result.__match_args__, expected_args)
+
+
if __name__ == "__main__":
unittest.main()
From a60cceac53213b75c2f606e6ab188a0c2eb91d50 Mon Sep 17 00:00:00 2001
From: Jeong YunWon
Date: Sun, 14 Aug 2022 23:38:03 +0900
Subject: [PATCH 09/23] mark failing tests frmo test_structseq
---
Lib/test/test_structseq.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py
index 013797e730..43e4cb51ea 100644
--- a/Lib/test/test_structseq.py
+++ b/Lib/test/test_structseq.py
@@ -124,11 +124,15 @@ def test_extended_getslice(self):
self.assertEqual(list(t[start:stop:step]),
L[start:stop:step])
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
def test_match_args(self):
expected_args = ('tm_year', 'tm_mon', 'tm_mday', 'tm_hour', 'tm_min',
'tm_sec', 'tm_wday', 'tm_yday', 'tm_isdst')
self.assertEqual(time.struct_time.__match_args__, expected_args)
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
def test_match_args_with_unnamed_fields(self):
expected_args = ('st_mode', 'st_ino', 'st_dev', 'st_nlink', 'st_uid',
'st_gid', 'st_size')
From 193fa88ea91abcf64c1e944d080e6dcacabf339a Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Sun, 14 Aug 2022 23:41:37 +0900
Subject: [PATCH 10/23] Update test_struct from CPython 3.10.6
---
Lib/test/test_struct.py | 60 +++++++++++++++++++++++++++++++++++++++++
1 file changed, 60 insertions(+)
diff --git a/Lib/test/test_struct.py b/Lib/test/test_struct.py
index 6ce1cfe1d2..89a330ba93 100644
--- a/Lib/test/test_struct.py
+++ b/Lib/test/test_struct.py
@@ -1,12 +1,16 @@
from collections import abc
import array
+import gc
import math
import operator
import unittest
import struct
import sys
+import weakref
from test import support
+from test.support import import_helper
+from test.support.script_helper import assert_python_ok
ISBIGENDIAN = sys.byteorder == "big"
@@ -654,6 +658,58 @@ def test_format_attr(self):
s2 = struct.Struct(s.format.encode())
self.assertEqual(s2.format, s.format)
+ def test_struct_cleans_up_at_runtime_shutdown(self):
+ code = """if 1:
+ import struct
+
+ class C:
+ def __init__(self):
+ self.pack = struct.pack
+ def __del__(self):
+ self.pack('I', -42)
+
+ struct.x = C()
+ """
+ rc, stdout, stderr = assert_python_ok("-c", code)
+ self.assertEqual(rc, 0)
+ self.assertEqual(stdout.rstrip(), b"")
+ self.assertIn(b"Exception ignored in:", stderr)
+ self.assertIn(b"C.__del__", stderr)
+
+ def test__struct_reference_cycle_cleaned_up(self):
+ # Regression test for python/cpython#94207.
+
+ # When we create a new struct module, trigger use of its cache,
+ # and then delete it ...
+ _struct_module = import_helper.import_fresh_module("_struct")
+ module_ref = weakref.ref(_struct_module)
+ _struct_module.calcsize("b")
+ del _struct_module
+
+ # Then the module should have been garbage collected.
+ gc.collect()
+ self.assertIsNone(
+ module_ref(), "_struct module was not garbage collected")
+
+ @support.cpython_only
+ def test__struct_types_immutable(self):
+ # See https://github.com/python/cpython/issues/94254
+
+ Struct = struct.Struct
+ unpack_iterator = type(struct.iter_unpack("b", b'x'))
+ for cls in (Struct, unpack_iterator):
+ with self.subTest(cls=cls):
+ with self.assertRaises(TypeError):
+ cls.x = 1
+
+
+ def test_issue35714(self):
+ # Embedded null characters should not be allowed in format strings.
+ for s in '\0', '2\0i', b'\0':
+ with self.assertRaisesRegex(struct.error,
+ 'embedded null character'):
+ struct.calcsize(s)
+
class UnpackIteratorTest(unittest.TestCase):
"""
@@ -681,6 +737,10 @@ def _check_iterator(it):
with self.assertRaises(struct.error):
s.iter_unpack(b"12")
+ def test_uninstantiable(self):
+ iter_unpack_type = type(struct.Struct(">ibcp").iter_unpack(b""))
+ self.assertRaises(TypeError, iter_unpack_type)
+
def test_iterate(self):
s = struct.Struct('>IB')
b = bytes(range(1, 16))
From 14e86e1374626e89e7054295595663905adbd415 Mon Sep 17 00:00:00 2001
From: Jeong YunWon
Date: Sun, 14 Aug 2022 23:43:34 +0900
Subject: [PATCH 11/23] mark filing tests for test_struct
---
Lib/test/test_struct.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/Lib/test/test_struct.py b/Lib/test/test_struct.py
index 89a330ba93..6cb7b610e5 100644
--- a/Lib/test/test_struct.py
+++ b/Lib/test/test_struct.py
@@ -658,6 +658,8 @@ def test_format_attr(self):
s2 = struct.Struct(s.format.encode())
self.assertEqual(s2.format, s.format)
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
def test_struct_cleans_up_at_runtime_shutdown(self):
code = """if 1:
import struct
@@ -703,6 +705,8 @@ def test__struct_types_immutable(self):
cls.x = 1
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
def test_issue35714(self):
# Embedded null characters should not be allowed in format strings.
for s in '\0', '2\0i', b'\0':
@@ -737,6 +741,8 @@ def _check_iterator(it):
with self.assertRaises(struct.error):
s.iter_unpack(b"12")
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
def test_uninstantiable(self):
iter_unpack_type = type(struct.Struct(">ibcp").iter_unpack(b""))
self.assertRaises(TypeError, iter_unpack_type)
From 17e12dea1ec11b871a9e5d3a4a26e2f9ad83fde2 Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Sun, 14 Aug 2022 23:52:26 +0900
Subject: [PATCH 12/23] Update string, test_str* from cpython 3.10.6
---
Lib/string.py | 46 +++++++++++++++-----------------
Lib/test/test_strftime.py | 2 +-
Lib/test/test_string_literals.py | 3 +--
3 files changed, 24 insertions(+), 27 deletions(-)
diff --git a/Lib/string.py b/Lib/string.py
index b423ff5dc6..489777b10c 100644
--- a/Lib/string.py
+++ b/Lib/string.py
@@ -54,30 +54,7 @@ def capwords(s, sep=None):
_sentinel_dict = {}
-class _TemplateMetaclass(type):
- pattern = r"""
- %(delim)s(?:
- (?P%(delim)s) | # Escape sequence of two delimiters
- (?P%(id)s) | # delimiter and a Python identifier
- {(?P%(bid)s)} | # delimiter and a braced identifier
- (?P) # Other ill-formed delimiter exprs
- )
- """
-
- def __init__(cls, name, bases, dct):
- super(_TemplateMetaclass, cls).__init__(name, bases, dct)
- if 'pattern' in dct:
- pattern = cls.pattern
- else:
- pattern = _TemplateMetaclass.pattern % {
- 'delim' : _re.escape(cls.delimiter),
- 'id' : cls.idpattern,
- 'bid' : cls.braceidpattern or cls.idpattern,
- }
- cls.pattern = _re.compile(pattern, cls.flags | _re.VERBOSE)
-
-
-class Template(metaclass=_TemplateMetaclass):
+class Template:
"""A string class for supporting $-substitutions."""
delimiter = '$'
@@ -89,6 +66,24 @@ class Template(metaclass=_TemplateMetaclass):
braceidpattern = None
flags = _re.IGNORECASE
+ def __init_subclass__(cls):
+ super().__init_subclass__()
+ if 'pattern' in cls.__dict__:
+ pattern = cls.pattern
+ else:
+ delim = _re.escape(cls.delimiter)
+ id = cls.idpattern
+ bid = cls.braceidpattern or cls.idpattern
+ pattern = fr"""
+ {delim}(?:
+ (?P{delim}) | # Escape sequence of two delimiters
+ (?P{id}) | # delimiter and a Python identifier
+ {{(?P{bid})}} | # delimiter and a braced identifier
+ (?P) # Other ill-formed delimiter exprs
+ )
+ """
+ cls.pattern = _re.compile(pattern, cls.flags | _re.VERBOSE)
+
def __init__(self, template):
self.template = template
@@ -146,6 +141,9 @@ def convert(mo):
self.pattern)
return self.pattern.sub(convert, self.template)
+# Initialize Template.pattern. __init_subclass__() is automatically called
+# only for subclasses, not for the Template class itself.
+Template.__init_subclass__()
########################################################################
diff --git a/Lib/test/test_strftime.py b/Lib/test/test_strftime.py
index 16440ef3bd..f4c0bd0de9 100644
--- a/Lib/test/test_strftime.py
+++ b/Lib/test/test_strftime.py
@@ -115,7 +115,7 @@ def strftest1(self, now):
)
for e in expectations:
- # musn't raise a value error
+ # mustn't raise a value error
try:
result = time.strftime(e[0], now)
except ValueError as error:
diff --git a/Lib/test/test_string_literals.py b/Lib/test/test_string_literals.py
index f117428152..8a822dd2f7 100644
--- a/Lib/test/test_string_literals.py
+++ b/Lib/test/test_string_literals.py
@@ -63,8 +63,6 @@ def byte(i):
class TestLiterals(unittest.TestCase):
- from test.support.warnings_helper import check_syntax_warning
-
def setUp(self):
self.save_path = sys.path[:]
self.tmpdir = tempfile.mkdtemp()
@@ -133,6 +131,7 @@ def test_eval_str_invalid_escape(self):
self.assertEqual(w, [])
self.assertEqual(exc.filename, '')
self.assertEqual(exc.lineno, 1)
+ self.assertEqual(exc.offset, 1)
# TODO: RUSTPYTHON
@unittest.expectedFailure
From 3fd08b32ef0fcdba113aaeb9e93c3b35a53251ef Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Sun, 14 Aug 2022 23:54:55 +0900
Subject: [PATCH 13/23] Update stat, statistics from CPython 3.10.6
---
Lib/test/test_stat.py | 17 ++++++++++++-----
1 file changed, 12 insertions(+), 5 deletions(-)
diff --git a/Lib/test/test_stat.py b/Lib/test/test_stat.py
index cc2c85bfb9..c1edea8491 100644
--- a/Lib/test/test_stat.py
+++ b/Lib/test/test_stat.py
@@ -2,9 +2,11 @@
import os
import socket
import sys
-from test.support.socket_helper import skip_unless_bind_unix_socket
-from test.support.os_helper import TESTFN
+from test.support import os_helper
+from test.support import socket_helper
from test.support.import_helper import import_fresh_module
+from test.support.os_helper import TESTFN
+
c_stat = import_fresh_module('stat', fresh=['_stat'])
py_stat = import_fresh_module('stat', blocked=['_stat'])
@@ -172,11 +174,16 @@ def test_link(self):
@unittest.skipUnless(hasattr(os, 'mkfifo'), 'os.mkfifo not available')
def test_fifo(self):
+ if sys.platform == "vxworks":
+ fifo_path = os.path.join("/fifos/", TESTFN)
+ else:
+ fifo_path = TESTFN
+ self.addCleanup(os_helper.unlink, fifo_path)
try:
- os.mkfifo(TESTFN, 0o700)
+ os.mkfifo(fifo_path, 0o700)
except PermissionError as e:
self.skipTest('os.mkfifo(): %s' % e)
- st_mode, modestr = self.get_mode()
+ st_mode, modestr = self.get_mode(fifo_path)
self.assertEqual(modestr, 'prwx------')
self.assertS_IS("FIFO", st_mode)
@@ -194,7 +201,7 @@ def test_devices(self):
self.assertS_IS("BLK", st_mode)
break
- @skip_unless_bind_unix_socket
+ @socket_helper.skip_unless_bind_unix_socket
def test_socket(self):
with socket.socket(socket.AF_UNIX) as s:
s.bind(TESTFN)
From 939257411bde3f48c9dc08ec59b6103ccec5f727 Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Mon, 15 Aug 2022 00:03:36 +0900
Subject: [PATCH 14/23] Update socket{server} from CPython 3.10.6
---
Lib/socket.py | 49 ++-
Lib/socketserver.py | 72 +++-
Lib/test/test_socket.py | 617 ++++++++++++++++++++++++++--------
Lib/test/test_socketserver.py | 13 +-
4 files changed, 588 insertions(+), 163 deletions(-)
diff --git a/Lib/socket.py b/Lib/socket.py
index f83f36d0ad..63ba0acc90 100644
--- a/Lib/socket.py
+++ b/Lib/socket.py
@@ -12,6 +12,8 @@
socket() -- create a new socket object
socketpair() -- create a pair of new socket objects [*]
fromfd() -- create a socket object from an open file descriptor [*]
+send_fds() -- Send file descriptor to the socket.
+recv_fds() -- Recieve file descriptors from the socket.
fromshare() -- create a socket object from data received from socket.share() [*]
gethostname() -- return the current hostname
gethostbyname() -- map a hostname to its IP number
@@ -104,7 +106,6 @@ def _intenum_converter(value, enum_klass):
except ValueError:
return value
-_realsocket = socket
# WSA error codes
if sys.platform.lower().startswith("win"):
@@ -336,6 +337,7 @@ def makefile(self, mode="r", buffering=None, *,
buffer = io.BufferedWriter(raw, buffering)
if binary:
return buffer
+ encoding = io.text_encoding(encoding)
text = io.TextIOWrapper(buffer, encoding, errors, newline)
text.mode = mode
return text
@@ -376,7 +378,7 @@ def _sendfile_use_sendfile(self, file, offset=0, count=None):
try:
while True:
if timeout and not selector_select(timeout):
- raise _socket.timeout('timed out')
+ raise TimeoutError('timed out')
if count:
blocksize = count - total_sent
if blocksize <= 0:
@@ -543,6 +545,40 @@ def fromfd(fd, family, type, proto=0):
nfd = dup(fd)
return socket(family, type, proto, nfd)
+if hasattr(_socket.socket, "sendmsg"):
+ import array
+
+ def send_fds(sock, buffers, fds, flags=0, address=None):
+ """ send_fds(sock, buffers, fds[, flags[, address]]) -> integer
+
+ Send the list of file descriptors fds over an AF_UNIX socket.
+ """
+ return sock.sendmsg(buffers, [(_socket.SOL_SOCKET,
+ _socket.SCM_RIGHTS, array.array("i", fds))])
+ __all__.append("send_fds")
+
+if hasattr(_socket.socket, "recvmsg"):
+ import array
+
+ def recv_fds(sock, bufsize, maxfds, flags=0):
+ """ recv_fds(sock, bufsize, maxfds[, flags]) -> (data, list of file
+ descriptors, msg_flags, address)
+
+ Receive up to maxfds file descriptors returning the message
+ data and a list containing the descriptors.
+ """
+ # Array of ints
+ fds = array.array("i")
+ msg, ancdata, flags, addr = sock.recvmsg(bufsize,
+ _socket.CMSG_LEN(maxfds * fds.itemsize))
+ for cmsg_level, cmsg_type, cmsg_data in ancdata:
+ if (cmsg_level == _socket.SOL_SOCKET and cmsg_type == _socket.SCM_RIGHTS):
+ fds.frombytes(cmsg_data[:
+ len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
+
+ return msg, list(fds), flags, addr
+ __all__.append("recv_fds")
+
if hasattr(_socket.socket, "share"):
def fromshare(info):
""" fromshare(info) -> socket object
@@ -671,7 +707,7 @@ def readinto(self, b):
self._timeout_occurred = True
raise
except error as e:
- if e.args[0] in _blocking_errnos:
+ if e.errno in _blocking_errnos:
return None
raise
@@ -687,7 +723,7 @@ def write(self, b):
return self._sock.send(b)
except error as e:
# XXX what about EINTR?
- if e.args[0] in _blocking_errnos:
+ if e.errno in _blocking_errnos:
return None
raise
@@ -746,8 +782,9 @@ def getfqdn(name=''):
An empty argument is interpreted as meaning the local host.
First the hostname returned by gethostbyaddr() is checked, then
- possibly existing aliases. In case no FQDN is available, hostname
- from gethostname() is returned.
+ possibly existing aliases. In case no FQDN is available and `name`
+ was given, it is returned unchanged. If `name` was empty or '0.0.0.0',
+ hostname from gethostname() is returned.
"""
name = name.strip()
if not name or name == '0.0.0.0':
diff --git a/Lib/socketserver.py b/Lib/socketserver.py
index b654c06d84..2905e3eac3 100644
--- a/Lib/socketserver.py
+++ b/Lib/socketserver.py
@@ -24,7 +24,7 @@
The classes in this module favor the server type that is simplest to
write: a synchronous TCP/IP server. This is bad class design, but
-save some typing. (There's also the issue that a deep class hierarchy
+saves some typing. (There's also the issue that a deep class hierarchy
slows down method lookups.)
There are five classes in an inheritance diagram, four of which represent
@@ -126,7 +126,6 @@ class will essentially render the service "deaf" while one request is
import socket
import selectors
import os
-import errno
import sys
try:
import threading
@@ -234,6 +233,9 @@ def serve_forever(self, poll_interval=0.5):
while not self.__shutdown_request:
ready = selector.select(poll_interval)
+ # bpo-35017: shutdown() called during select(), exit immediately.
+ if self.__shutdown_request:
+ break
if ready:
self._handle_request_noblock()
@@ -375,7 +377,7 @@ def handle_error(self, request, client_address):
"""
print('-'*40, file=sys.stderr)
- print('Exception happened during processing of request from',
+ print('Exception occurred during processing of request from',
client_address, file=sys.stderr)
import traceback
traceback.print_exc()
@@ -547,8 +549,10 @@ class ForkingMixIn:
timeout = 300
active_children = None
max_children = 40
+ # If true, server_close() waits until all child processes complete.
+ block_on_close = True
- def collect_children(self):
+ def collect_children(self, *, blocking=False):
"""Internal routine to wait for children that have exited."""
if self.active_children is None:
return
@@ -572,7 +576,8 @@ def collect_children(self):
# Now reap all defunct children.
for pid in self.active_children.copy():
try:
- pid, _ = os.waitpid(pid, os.WNOHANG)
+ flags = 0 if blocking else os.WNOHANG
+ pid, _ = os.waitpid(pid, flags)
# if the child hasn't exited yet, pid will be 0 and ignored by
# discard() below
self.active_children.discard(pid)
@@ -592,7 +597,7 @@ def handle_timeout(self):
def service_actions(self):
"""Collect the zombie child processes regularly in the ForkingMixIn.
- service_actions is called in the BaseServer's serve_forver loop.
+ service_actions is called in the BaseServer's serve_forever loop.
"""
self.collect_children()
@@ -621,6 +626,43 @@ def process_request(self, request, client_address):
finally:
os._exit(status)
+ def server_close(self):
+ super().server_close()
+ self.collect_children(blocking=self.block_on_close)
+
+
+class _Threads(list):
+ """
+ Joinable list of all non-daemon threads.
+ """
+ def append(self, thread):
+ self.reap()
+ if thread.daemon:
+ return
+ super().append(thread)
+
+ def pop_all(self):
+ self[:], result = [], self[:]
+ return result
+
+ def join(self):
+ for thread in self.pop_all():
+ thread.join()
+
+ def reap(self):
+ self[:] = (thread for thread in self if thread.is_alive())
+
+
+class _NoThreads:
+ """
+ Degenerate version of _Threads.
+ """
+ def append(self, thread):
+ pass
+
+ def join(self):
+ pass
+
class ThreadingMixIn:
"""Mix-in class to handle each request in a new thread."""
@@ -628,6 +670,11 @@ class ThreadingMixIn:
# Decides how threads will act upon termination of the
# main process
daemon_threads = False
+ # If true, server_close() waits until all non-daemonic threads terminate.
+ block_on_close = True
+ # Threads object
+ # used by server_close() to wait for all threads completion.
+ _threads = _NoThreads()
def process_request_thread(self, request, client_address):
"""Same as in BaseServer but as a thread.
@@ -644,11 +691,18 @@ def process_request_thread(self, request, client_address):
def process_request(self, request, client_address):
"""Start a new thread to process the request."""
+ if self.block_on_close:
+ vars(self).setdefault('_threads', _Threads())
t = threading.Thread(target = self.process_request_thread,
args = (request, client_address))
t.daemon = self.daemon_threads
+ self._threads.append(t)
t.start()
+ def server_close(self):
+ super().server_close()
+ self._threads.join()
+
if hasattr(os, "fork"):
class ForkingUDPServer(ForkingMixIn, UDPServer): pass
@@ -773,10 +827,8 @@ def writable(self):
def write(self, b):
self._sock.sendall(b)
- # XXX RustPython TODO: implement memoryview properly
- #with memoryview(b) as view:
- # return view.nbytes
- return len(b)
+ with memoryview(b) as view:
+ return view.nbytes
def fileno(self):
return self._sock.fileno()
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index f3b7146da7..80ba3d5f3e 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -40,7 +40,6 @@
HOST = socket_helper.HOST
# test unicode string and carriage return
MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8')
-MAIN_TIMEOUT = 60.0
VSOCKPORT = 1234
AIX = platform.system() == "AIX"
@@ -83,6 +82,16 @@ def _have_socket_can_isotp():
s.close()
return True
+def _have_socket_can_j1939():
+ """Check whether CAN J1939 sockets are supported on this host."""
+ try:
+ s = socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939)
+ except (AttributeError, OSError):
+ return False
+ else:
+ s.close()
+ return True
+
def _have_socket_rds():
"""Check whether RDS sockets are supported on this host."""
try:
@@ -119,6 +128,19 @@ def _have_socket_vsock():
return ret
+def _have_socket_bluetooth():
+ """Check whether AF_BLUETOOTH sockets are supported on this host."""
+ try:
+ # RFCOMM is supported by all platforms with bluetooth support. Windows
+ # does not support omitting the protocol.
+ s = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_STREAM, socket.BTPROTO_RFCOMM)
+ except (AttributeError, OSError):
+ return False
+ else:
+ s.close()
+ return True
+
+
@contextlib.contextmanager
def socket_setdefaulttimeout(timeout):
old_timeout = socket.getdefaulttimeout()
@@ -133,6 +155,8 @@ def socket_setdefaulttimeout(timeout):
HAVE_SOCKET_CAN_ISOTP = _have_socket_can_isotp()
+HAVE_SOCKET_CAN_J1939 = _have_socket_can_j1939()
+
HAVE_SOCKET_RDS = _have_socket_rds()
HAVE_SOCKET_ALG = _have_socket_alg()
@@ -141,6 +165,10 @@ def socket_setdefaulttimeout(timeout):
HAVE_SOCKET_VSOCK = _have_socket_vsock()
+HAVE_SOCKET_UDPLITE = hasattr(socket, "IPPROTO_UDPLITE")
+
+HAVE_SOCKET_BLUETOOTH = _have_socket_bluetooth()
+
# Size in bytes of the int type
SIZEOF_INT = array.array("i").itemsize
@@ -165,7 +193,13 @@ def tearDown(self):
self.serv.close()
self.serv = None
-class ThreadSafeCleanupTestCase(unittest.TestCase):
+class SocketUDPLITETest(SocketUDPTest):
+
+ def setUp(self):
+ self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE)
+ self.port = socket_helper.bind_port(self.serv)
+
+class ThreadSafeCleanupTestCase:
"""Subclass of unittest.TestCase with thread-safe cleanup methods.
This subclass protects the addCleanup() and doCleanups() methods
@@ -292,9 +326,7 @@ def _testFoo(self):
def __init__(self):
# Swap the true setup function
self.__setUp = self.setUp
- self.__tearDown = self.tearDown
self.setUp = self._setUp
- self.tearDown = self._tearDown
def serverExplicitReady(self):
"""This method allows the server to explicitly indicate that
@@ -306,6 +338,7 @@ def serverExplicitReady(self):
def _setUp(self):
self.wait_threads = threading_helper.wait_threads_exit()
self.wait_threads.__enter__()
+ self.addCleanup(self.wait_threads.__exit__, None, None, None)
self.server_ready = threading.Event()
self.client_ready = threading.Event()
@@ -313,6 +346,11 @@ def _setUp(self):
self.queue = queue.Queue(1)
self.server_crashed = False
+ def raise_queued_exception():
+ if self.queue.qsize():
+ raise self.queue.get()
+ self.addCleanup(raise_queued_exception)
+
# Do some munging to start the client test.
methodname = self.id()
i = methodname.rfind('.')
@@ -329,15 +367,7 @@ def _setUp(self):
finally:
self.server_ready.set()
self.client_ready.wait()
-
- def _tearDown(self):
- self.__tearDown()
- self.done.wait()
- self.wait_threads.__exit__(None, None, None)
-
- if self.queue.qsize():
- exc = self.queue.get()
- raise exc
+ self.addCleanup(self.done.wait)
def clientRun(self, test_func):
self.server_ready.wait()
@@ -396,6 +426,22 @@ def clientTearDown(self):
self.cli = None
ThreadableTest.clientTearDown(self)
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+class ThreadedUDPLITESocketTest(SocketUDPLITETest, ThreadableTest):
+
+ def __init__(self, methodName='runTest'):
+ SocketUDPLITETest.__init__(self, methodName=methodName)
+ ThreadableTest.__init__(self)
+
+ def clientSetUp(self):
+ self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE)
+
+ def clientTearDown(self):
+ self.cli.close()
+ self.cli = None
+ ThreadableTest.clientTearDown(self)
+
class ThreadedCANSocketTest(SocketCANTest, ThreadableTest):
def __init__(self, methodName='runTest'):
@@ -681,6 +727,12 @@ class UDPTestBase(InetTestBase):
def newSocket(self):
return socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+class UDPLITETestBase(InetTestBase):
+ """Base class for UDPLITE-over-IPv4 tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE)
+
class SCTPStreamBase(InetTestBase):
"""Base class for SCTP tests in one-to-one (SOCK_STREAM) mode."""
@@ -700,6 +752,12 @@ class UDP6TestBase(Inet6TestBase):
def newSocket(self):
return socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+class UDPLITE6TestBase(Inet6TestBase):
+ """Base class for UDPLITE-over-IPv6 tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET6, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE)
+
# Test-skipping decorators for use with ThreadableTest.
@@ -808,6 +866,7 @@ def test_weakref(self):
p = proxy(s)
self.assertEqual(p.fileno(), s.fileno())
s = None
+ support.gc_collect() # For PyPy or other GCs.
try:
p.fileno()
except ReferenceError:
@@ -859,10 +918,8 @@ def testSendtoErrors(self):
self.assertIn('not NoneType', str(cm.exception))
with self.assertRaises(TypeError) as cm:
s.sendto(b'foo', 'bar', sockname)
- self.assertIn('an integer is required', str(cm.exception))
with self.assertRaises(TypeError) as cm:
s.sendto(b'foo', None, None)
- self.assertIn('an integer is required', str(cm.exception))
# wrong number of args
with self.assertRaises(TypeError) as cm:
s.sendto(b'foo')
@@ -901,6 +958,37 @@ def testWindowsSpecificConstants(self):
socket.IPPROTO_L2TP
socket.IPPROTO_SCTP
+ @unittest.skipUnless(sys.platform == 'darwin', 'macOS specific test')
+ @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
+ def test3542SocketOptions(self):
+ # Ref. issue #35569 and https://tools.ietf.org/html/rfc3542
+ opts = {
+ 'IPV6_CHECKSUM',
+ 'IPV6_DONTFRAG',
+ 'IPV6_DSTOPTS',
+ 'IPV6_HOPLIMIT',
+ 'IPV6_HOPOPTS',
+ 'IPV6_NEXTHOP',
+ 'IPV6_PATHMTU',
+ 'IPV6_PKTINFO',
+ 'IPV6_RECVDSTOPTS',
+ 'IPV6_RECVHOPLIMIT',
+ 'IPV6_RECVHOPOPTS',
+ 'IPV6_RECVPATHMTU',
+ 'IPV6_RECVPKTINFO',
+ 'IPV6_RECVRTHDR',
+ 'IPV6_RECVTCLASS',
+ 'IPV6_RTHDR',
+ 'IPV6_RTHDRDSTOPTS',
+ 'IPV6_RTHDR_TYPE_0',
+ 'IPV6_TCLASS',
+ 'IPV6_USE_MIN_MTU',
+ }
+ for opt in opts:
+ self.assertTrue(
+ hasattr(socket, opt), f"Missing RFC3542 socket option '{opt}'"
+ )
+
def testHostnameRes(self):
# Testing hostname resolution mechanisms
hostname = socket.gethostname()
@@ -1032,9 +1120,11 @@ def testNtoHErrors(self):
s_good_values = [0, 1, 2, 0xffff]
l_good_values = s_good_values + [0xffffffff]
l_bad_values = [-1, -2, 1<<32, 1<<1000]
- s_bad_values = l_bad_values + [_testcapi.INT_MIN - 1,
- _testcapi.INT_MAX + 1]
- s_deprecated_values = [1<<16, _testcapi.INT_MAX]
+ s_bad_values = (
+ l_bad_values +
+ [_testcapi.INT_MIN-1, _testcapi.INT_MAX+1] +
+ [1 << 16, _testcapi.INT_MAX]
+ )
for k in s_good_values:
socket.ntohs(k)
socket.htons(k)
@@ -1047,9 +1137,6 @@ def testNtoHErrors(self):
for k in l_bad_values:
self.assertRaises(OverflowError, socket.ntohl, k)
self.assertRaises(OverflowError, socket.htonl, k)
- for k in s_deprecated_values:
- self.assertWarns(DeprecationWarning, socket.ntohs, k)
- self.assertWarns(DeprecationWarning, socket.htons, k)
def testGetServBy(self):
eq = self.assertEqual
@@ -1531,7 +1618,7 @@ def raising_handler(*args):
if with_timeout:
signal.signal(signal.SIGALRM, ok_handler)
signal.alarm(1)
- self.assertRaises(socket.timeout, c.sendall,
+ self.assertRaises(TimeoutError, c.sendall,
b"x" * support.SOCK_MAX_SIZE)
finally:
signal.alarm(0)
@@ -1603,7 +1690,8 @@ def test_makefile_mode(self):
for mode in 'r', 'rb', 'rw', 'w', 'wb':
with self.subTest(mode=mode):
with socket.socket() as sock:
- with sock.makefile(mode) as fp:
+ encoding = None if "b" in mode else "utf-8"
+ with sock.makefile(mode, encoding=encoding) as fp:
self.assertEqual(fp.mode, mode)
def test_makefile_invalid_mode(self):
@@ -1662,6 +1750,7 @@ def test_getaddrinfo_ipv6_basic(self):
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@unittest.skipIf(sys.platform == 'win32', 'does not work on Windows')
@unittest.skipIf(AIX, 'Symbolic scope id does not work')
+ @unittest.skipUnless(hasattr(socket, 'if_nameindex'), "test needs socket.if_nameindex()")
def test_getaddrinfo_ipv6_scopeid_symbolic(self):
# Just pick up any network interface (Linux, Mac OS X)
(ifindex, test_interface) = socket.if_nameindex()[0]
@@ -1694,6 +1783,7 @@ def test_getaddrinfo_ipv6_scopeid_numeric(self):
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
@unittest.skipIf(sys.platform == 'win32', 'does not work on Windows')
@unittest.skipIf(AIX, 'Symbolic scope id does not work')
+ @unittest.skipUnless(hasattr(socket, 'if_nameindex'), "test needs socket.if_nameindex()")
def test_getnameinfo_ipv6_scopeid_symbolic(self):
# Just pick up any network interface.
(ifindex, test_interface) = socket.if_nameindex()[0]
@@ -1827,11 +1917,11 @@ def test_socket_fileno(self):
socket.SOCK_STREAM)
def test_socket_fileno_rejects_float(self):
- with self.assertRaisesRegex(TypeError, "integer argument expected"):
+ with self.assertRaises(TypeError):
socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=42.5)
def test_socket_fileno_rejects_other_types(self):
- with self.assertRaisesRegex(TypeError, "integer is required"):
+ with self.assertRaises(TypeError):
socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno="foo")
def test_socket_fileno_rejects_invalid_socket(self):
@@ -2093,6 +2183,68 @@ def testBind(self):
raise
+@unittest.skipUnless(HAVE_SOCKET_CAN_J1939, 'CAN J1939 required for this test.')
+class J1939Test(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.interface = "vcan0"
+
+ @unittest.skipUnless(hasattr(socket, "CAN_J1939"),
+ 'socket.CAN_J1939 required for this test.')
+ def testJ1939Constants(self):
+ socket.CAN_J1939
+
+ socket.J1939_MAX_UNICAST_ADDR
+ socket.J1939_IDLE_ADDR
+ socket.J1939_NO_ADDR
+ socket.J1939_NO_NAME
+ socket.J1939_PGN_REQUEST
+ socket.J1939_PGN_ADDRESS_CLAIMED
+ socket.J1939_PGN_ADDRESS_COMMANDED
+ socket.J1939_PGN_PDU1_MAX
+ socket.J1939_PGN_MAX
+ socket.J1939_NO_PGN
+
+ # J1939 socket options
+ socket.SO_J1939_FILTER
+ socket.SO_J1939_PROMISC
+ socket.SO_J1939_SEND_PRIO
+ socket.SO_J1939_ERRQUEUE
+
+ socket.SCM_J1939_DEST_ADDR
+ socket.SCM_J1939_DEST_NAME
+ socket.SCM_J1939_PRIO
+ socket.SCM_J1939_ERRQUEUE
+
+ socket.J1939_NLA_PAD
+ socket.J1939_NLA_BYTES_ACKED
+
+ socket.J1939_EE_INFO_NONE
+ socket.J1939_EE_INFO_TX_ABORT
+
+ socket.J1939_FILTER_MAX
+
+ @unittest.skipUnless(hasattr(socket, "CAN_J1939"),
+ 'socket.CAN_J1939 required for this test.')
+ def testCreateJ1939Socket(self):
+ with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939) as s:
+ pass
+
+ def testBind(self):
+ try:
+ with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939) as s:
+ addr = self.interface, socket.J1939_NO_NAME, socket.J1939_NO_PGN, socket.J1939_NO_ADDR
+ s.bind(addr)
+ self.assertEqual(s.getsockname(), addr)
+ except OSError as e:
+ if e.errno == errno.ENODEV:
+ self.skipTest('network interface `%s` does not exist' %
+ self.interface)
+ else:
+ raise
+
+
@unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.')
class BasicRDSTest(unittest.TestCase):
@@ -2252,6 +2404,45 @@ def testSocketBufferSize(self):
socket.SO_VM_SOCKETS_BUFFER_MIN_SIZE))
+@unittest.skipUnless(HAVE_SOCKET_BLUETOOTH,
+ 'Bluetooth sockets required for this test.')
+class BasicBluetoothTest(unittest.TestCase):
+
+ def testBluetoothConstants(self):
+ socket.BDADDR_ANY
+ socket.BDADDR_LOCAL
+ socket.AF_BLUETOOTH
+ socket.BTPROTO_RFCOMM
+
+ if sys.platform != "win32":
+ socket.BTPROTO_HCI
+ socket.SOL_HCI
+ socket.BTPROTO_L2CAP
+
+ if not sys.platform.startswith("freebsd"):
+ socket.BTPROTO_SCO
+
+ def testCreateRfcommSocket(self):
+ with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_STREAM, socket.BTPROTO_RFCOMM) as s:
+ pass
+
+ @unittest.skipIf(sys.platform == "win32", "windows does not support L2CAP sockets")
+ def testCreateL2capSocket(self):
+ with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) as s:
+ pass
+
+ @unittest.skipIf(sys.platform == "win32", "windows does not support HCI sockets")
+ def testCreateHciSocket(self):
+ with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_RAW, socket.BTPROTO_HCI) as s:
+ pass
+
+ @unittest.skipIf(sys.platform == "win32" or sys.platform.startswith("freebsd"),
+ "windows and freebsd do not support SCO sockets")
+ def testCreateScoSocket(self):
+ with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_SCO) as s:
+ pass
+
+
class BasicTCPTest(SocketConnectedTest):
def __init__(self, methodName='runTest'):
@@ -2403,6 +2594,37 @@ def testRecvFromNegative(self):
def _testRecvFromNegative(self):
self.cli.sendto(MSG, 0, (HOST, self.port))
+
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+class BasicUDPLITETest(ThreadedUDPLITESocketTest):
+
+ def __init__(self, methodName='runTest'):
+ ThreadedUDPLITESocketTest.__init__(self, methodName=methodName)
+
+ def testSendtoAndRecv(self):
+ # Testing sendto() and Recv() over UDPLITE
+ msg = self.serv.recv(len(MSG))
+ self.assertEqual(msg, MSG)
+
+ def _testSendtoAndRecv(self):
+ self.cli.sendto(MSG, 0, (HOST, self.port))
+
+ def testRecvFrom(self):
+ # Testing recvfrom() over UDPLITE
+ msg, addr = self.serv.recvfrom(len(MSG))
+ self.assertEqual(msg, MSG)
+
+ def _testRecvFrom(self):
+ self.cli.sendto(MSG, 0, (HOST, self.port))
+
+ def testRecvFromNegative(self):
+ # Negative lengths passed to recvfrom should give ValueError.
+ self.assertRaises(ValueError, self.serv.recvfrom, -1)
+
+ def _testRecvFromNegative(self):
+ self.cli.sendto(MSG, 0, (HOST, self.port))
+
# Tests for the sendmsg()/recvmsg() interface. Where possible, the
# same test code is used with different families and types of socket
# (e.g. stream, datagram), and tests using recvmsg() are repeated
@@ -2767,7 +2989,7 @@ def _testSendmsgTimeout(self):
try:
while True:
self.sendmsgToServer([b"a"*512])
- except socket.timeout:
+ except TimeoutError:
pass
except OSError as exc:
if exc.errno != errno.ENOMEM:
@@ -2775,7 +2997,7 @@ def _testSendmsgTimeout(self):
# bpo-33937 the test randomly fails on Travis CI with
# "OSError: [Errno 12] Cannot allocate memory"
else:
- self.fail("socket.timeout not raised")
+ self.fail("TimeoutError not raised")
finally:
self.misc_event.set()
@@ -2910,7 +3132,7 @@ def testRecvmsgTimeout(self):
# Check that timeout works.
try:
self.serv_sock.settimeout(0.03)
- self.assertRaises(socket.timeout,
+ self.assertRaises(TimeoutError,
self.doRecvmsg, self.serv_sock, len(MSG))
finally:
self.misc_event.set()
@@ -3930,7 +4152,7 @@ def _testSecondCmsgTruncLen0Minus1(self):
@requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
"IPV6_RECVTCLASS", "IPV6_TCLASS")
- def testSecomdCmsgTruncInData(self):
+ def testSecondCmsgTruncInData(self):
# Test truncation of the second of two control messages inside
# its associated data.
self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
@@ -3965,8 +4187,8 @@ def testSecomdCmsgTruncInData(self):
self.assertEqual(ancdata, [])
- @testSecomdCmsgTruncInData.client_skip
- def _testSecomdCmsgTruncInData(self):
+ @testSecondCmsgTruncInData.client_skip
+ def _testSecondCmsgTruncInData(self):
self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
self.sendToServer(MSG)
@@ -4036,6 +4258,89 @@ class RecvmsgIntoRFC3542AncillaryUDP6Test(RecvmsgIntoMixin,
pass
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+class SendrecvmsgUDPLITETestBase(SendrecvmsgDgramFlagsBase,
+ SendrecvmsgConnectionlessBase,
+ ThreadedSocketTestMixin, UDPLITETestBase):
+ pass
+
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+@requireAttrs(socket.socket, "sendmsg")
+class SendmsgUDPLITETest(SendmsgConnectionlessTests, SendrecvmsgUDPLITETestBase):
+ pass
+
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+@requireAttrs(socket.socket, "recvmsg")
+class RecvmsgUDPLITETest(RecvmsgTests, SendrecvmsgUDPLITETestBase):
+ pass
+
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+@requireAttrs(socket.socket, "recvmsg_into")
+class RecvmsgIntoUDPLITETest(RecvmsgIntoTests, SendrecvmsgUDPLITETestBase):
+ pass
+
+
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+class SendrecvmsgUDPLITE6TestBase(SendrecvmsgDgramFlagsBase,
+ SendrecvmsgConnectionlessBase,
+ ThreadedSocketTestMixin, UDPLITE6TestBase):
+
+ def checkRecvmsgAddress(self, addr1, addr2):
+ # Called to compare the received address with the address of
+ # the peer, ignoring scope ID
+ self.assertEqual(addr1[:-1], addr2[:-1])
+
+@requireAttrs(socket.socket, "sendmsg")
+@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+class SendmsgUDPLITE6Test(SendmsgConnectionlessTests, SendrecvmsgUDPLITE6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+class RecvmsgUDPLITE6Test(RecvmsgTests, SendrecvmsgUDPLITE6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+class RecvmsgIntoUDPLITE6Test(RecvmsgIntoTests, SendrecvmsgUDPLITE6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+@requireAttrs(socket, "IPPROTO_IPV6")
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+class RecvmsgRFC3542AncillaryUDPLITE6Test(RFC3542AncillaryTest,
+ SendrecvmsgUDPLITE6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.')
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+@requireAttrs(socket, "IPPROTO_IPV6")
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+class RecvmsgIntoRFC3542AncillaryUDPLITE6Test(RecvmsgIntoMixin,
+ RFC3542AncillaryTest,
+ SendrecvmsgUDPLITE6TestBase):
+ pass
+
+
class SendrecvmsgTCPTestBase(SendrecvmsgConnectedBase,
ConnectedStreamTestMixin, TCPTestBase):
pass
@@ -4133,7 +4438,7 @@ class RecvmsgIntoSCMRightsStreamTest(RecvmsgIntoMixin, SCMRightsTest,
# threads alive during the test so that the OS cannot deliver the
# signal to the wrong one.
-class InterruptedTimeoutBase(unittest.TestCase):
+class InterruptedTimeoutBase:
# Base class for interrupted send/receive tests. Installs an
# empty handler for SIGALRM and removes it on teardown, along with
# any scheduled alarms.
@@ -4145,7 +4450,7 @@ def setUp(self):
self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)
# Timeout for socket operations
- timeout = 4.0
+ timeout = support.LOOPBACK_TIMEOUT
# Provide setAlarm() method to schedule delivery of SIGALRM after
# given number of seconds, or cancel it if zero, and an
@@ -4248,12 +4553,10 @@ def checkInterruptedSend(self, func, *args, **kwargs):
self.setAlarm(0)
# Issue #12958: The following tests have problems on OS X prior to 10.7
- @unittest.skip("TODO: RUSTPYTHON, plistlib")
@support.requires_mac_ver(10, 7)
def testInterruptedSendTimeout(self):
self.checkInterruptedSend(self.serv_conn.send, b"a"*512)
- @unittest.skip("TODO: RUSTPYTHON, plistlib")
@support.requires_mac_ver(10, 7)
def testInterruptedSendtoTimeout(self):
# Passing an actual address here as Python's wrapper for
@@ -4263,7 +4566,6 @@ def testInterruptedSendtoTimeout(self):
self.checkInterruptedSend(self.serv_conn.sendto, b"a"*512,
self.serv_addr)
- @unittest.skip("TODO: RUSTPYTHON, plistlib")
@support.requires_mac_ver(10, 7)
@requireAttrs(socket.socket, "sendmsg")
def testInterruptedSendmsgTimeout(self):
@@ -4429,7 +4731,7 @@ def _testInheritFlagsTimeout(self):
def testAccept(self):
# Testing non-blocking accept
- self.serv.setblocking(0)
+ self.serv.setblocking(False)
# connect() didn't start: non-blocking accept() fails
start_time = time.monotonic()
@@ -4440,7 +4742,7 @@ def testAccept(self):
self.event.set()
- read, write, err = select.select([self.serv], [], [], MAIN_TIMEOUT)
+ read, write, err = select.select([self.serv], [], [], support.LONG_TIMEOUT)
if self.serv not in read:
self.fail("Error trying to do accept after select.")
@@ -4460,7 +4762,7 @@ def testRecv(self):
# Testing non-blocking recv
conn, addr = self.serv.accept()
self.addCleanup(conn.close)
- conn.setblocking(0)
+ conn.setblocking(False)
# the server didn't send data yet: non-blocking recv() fails
with self.assertRaises(BlockingIOError):
@@ -4468,7 +4770,7 @@ def testRecv(self):
self.event.set()
- read, write, err = select.select([conn], [], [], MAIN_TIMEOUT)
+ read, write, err = select.select([conn], [], [], support.LONG_TIMEOUT)
if conn not in read:
self.fail("Error during select call to non-blocking socket.")
@@ -4550,7 +4852,7 @@ def testReadAfterTimeout(self):
self.cli_conn.settimeout(1)
self.read_file.read(3)
# First read raises a timeout
- self.assertRaises(socket.timeout, self.read_file.read, 1)
+ self.assertRaises(TimeoutError, self.read_file.read, 1)
# Second read is disallowed
with self.assertRaises(OSError) as ctx:
self.read_file.read(1)
@@ -4815,7 +5117,7 @@ class NetworkConnectionNoServer(unittest.TestCase):
class MockSocket(socket.socket):
def connect(self, *args):
- raise socket.timeout('timed out')
+ raise TimeoutError('timed out')
@contextlib.contextmanager
def mocked_socket_module(self):
@@ -4865,13 +5167,13 @@ def test_create_connection_timeout(self):
with self.mocked_socket_module():
try:
socket.create_connection((HOST, 1234))
- except socket.timeout:
+ except TimeoutError:
pass
except OSError as exc:
if socket_helper.IPV6_ENABLED or exc.errno != errno.EAFNOSUPPORT:
raise
else:
- self.fail('socket.timeout not raised')
+ self.fail('TimeoutError not raised')
class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
@@ -4894,14 +5196,16 @@ def _justAccept(self):
testFamily = _justAccept
def _testFamily(self):
- self.cli = socket.create_connection((HOST, self.port), timeout=30)
+ self.cli = socket.create_connection((HOST, self.port),
+ timeout=support.LOOPBACK_TIMEOUT)
self.addCleanup(self.cli.close)
self.assertEqual(self.cli.family, 2)
testSourceAddress = _justAccept
def _testSourceAddress(self):
- self.cli = socket.create_connection((HOST, self.port), timeout=30,
- source_address=('', self.source_port))
+ self.cli = socket.create_connection((HOST, self.port),
+ timeout=support.LOOPBACK_TIMEOUT,
+ source_address=('', self.source_port))
self.addCleanup(self.cli.close)
self.assertEqual(self.cli.getsockname()[1], self.source_port)
# The port number being used is sufficient to show that the bind()
@@ -4971,7 +5275,7 @@ def _testInsideTimeout(self):
def _testOutsideTimeout(self):
self.cli = sock = socket.create_connection((HOST, self.port), timeout=1)
- self.assertRaises(socket.timeout, lambda: sock.recv(5))
+ self.assertRaises(TimeoutError, lambda: sock.recv(5))
class TCPTimeoutTest(SocketTCPTest):
@@ -4980,7 +5284,7 @@ def testTCPTimeout(self):
def raise_timeout(*args, **kwargs):
self.serv.settimeout(1.0)
self.serv.accept()
- self.assertRaises(socket.timeout, raise_timeout,
+ self.assertRaises(TimeoutError, raise_timeout,
"Error generating a timeout exception (TCP)")
def testTimeoutZero(self):
@@ -4988,7 +5292,7 @@ def testTimeoutZero(self):
try:
self.serv.settimeout(0.0)
foo = self.serv.accept()
- except socket.timeout:
+ except TimeoutError:
self.fail("caught timeout instead of error (TCP)")
except OSError:
ok = True
@@ -5013,7 +5317,7 @@ def alarm_handler(signal, frame):
try:
signal.alarm(2) # POSIX allows alarm to be up to 1 second early
foo = self.serv.accept()
- except socket.timeout:
+ except TimeoutError:
self.fail("caught timeout instead of Alarm")
except Alarm:
pass
@@ -5037,7 +5341,7 @@ def testUDPTimeout(self):
def raise_timeout(*args, **kwargs):
self.serv.settimeout(1.0)
self.serv.recv(1024)
- self.assertRaises(socket.timeout, raise_timeout,
+ self.assertRaises(TimeoutError, raise_timeout,
"Error generating a timeout exception (UDP)")
def testTimeoutZero(self):
@@ -5045,7 +5349,7 @@ def testTimeoutZero(self):
try:
self.serv.settimeout(0.0)
foo = self.serv.recv(1024)
- except socket.timeout:
+ except TimeoutError:
self.fail("caught timeout instead of error (UDP)")
except OSError:
ok = True
@@ -5054,6 +5358,31 @@ def testTimeoutZero(self):
if not ok:
self.fail("recv() returned success when we did not expect it")
+@unittest.skipUnless(HAVE_SOCKET_UDPLITE,
+ 'UDPLITE sockets required for this test.')
+class UDPLITETimeoutTest(SocketUDPLITETest):
+
+ def testUDPLITETimeout(self):
+ def raise_timeout(*args, **kwargs):
+ self.serv.settimeout(1.0)
+ self.serv.recv(1024)
+ self.assertRaises(TimeoutError, raise_timeout,
+ "Error generating a timeout exception (UDPLITE)")
+
+ def testTimeoutZero(self):
+ ok = False
+ try:
+ self.serv.settimeout(0.0)
+ foo = self.serv.recv(1024)
+ except TimeoutError:
+ self.fail("caught timeout instead of error (UDPLITE)")
+ except OSError:
+ ok = True
+ except:
+ self.fail("caught unexpected exception (UDPLITE)")
+ if not ok:
+ self.fail("recv() returned success when we did not expect it")
+
class TestExceptions(unittest.TestCase):
def testExceptionTree(self):
@@ -5061,6 +5390,8 @@ def testExceptionTree(self):
self.assertTrue(issubclass(socket.herror, OSError))
self.assertTrue(issubclass(socket.gaierror, OSError))
self.assertTrue(issubclass(socket.timeout, OSError))
+ self.assertIs(socket.error, OSError)
+ self.assertIs(socket.timeout, TimeoutError)
def test_setblocking_invalidfd(self):
# Regression test for issue #28471
@@ -5117,6 +5448,20 @@ def testBytearrayName(self):
s.bind(bytearray(b"\x00python\x00test\x00"))
self.assertEqual(s.getsockname(), b"\x00python\x00test\x00")
+ def testAutobind(self):
+ # Check that binding to an empty string binds to an available address
+ # in the abstract namespace as specified in unix(7) "Autobind feature".
+ abstract_address = b"^\0[0-9a-f]{5}"
+ with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s1:
+ s1.bind("")
+ self.assertRegex(s1.getsockname(), abstract_address)
+ # Each socket is bound to a different abstract address.
+ with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s2:
+ s2.bind("")
+ self.assertRegex(s2.getsockname(), abstract_address)
+ self.assertNotEqual(s1.getsockname(), s2.getsockname())
+
+
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'test needs socket.AF_UNIX')
class TestUnixDomain(unittest.TestCase):
@@ -5172,7 +5517,7 @@ def testBytesAddr(self):
def testSurrogateescapeBind(self):
# Test binding to a valid non-ASCII pathname, with the
# non-ASCII bytes supplied using surrogateescape encoding.
- path = os.path.abspath(support.TESTFN_UNICODE)
+ path = os.path.abspath(os_helper.TESTFN_UNICODE)
b = self.encoded(path)
self.bind(self.sock, b.decode("ascii", "surrogateescape"))
self.addCleanup(os_helper.unlink, path)
@@ -5188,6 +5533,11 @@ def testUnencodableAddr(self):
self.addCleanup(os_helper.unlink, path)
self.assertEqual(self.sock.getsockname(), path)
+ @unittest.skipIf(sys.platform == 'linux', 'Linux specific test')
+ def testEmptyAddress(self):
+ # Test that binding empty address fails.
+ self.assertRaises(OSError, self.sock.bind, "")
+
class BufferIOTest(SocketConnectedTest):
"""
@@ -5285,7 +5635,7 @@ def isTipcAvailable():
if not hasattr(socket, "AF_TIPC"):
return False
try:
- f = open("/proc/modules")
+ f = open("/proc/modules", encoding="utf-8")
except (FileNotFoundError, IsADirectoryError, PermissionError):
# It's ok if the file does not exist, is a directory or if we
# have not the permission to read it.
@@ -5511,15 +5861,15 @@ def test_SOCK_NONBLOCK(self):
with socket.socket(socket.AF_INET,
socket.SOCK_STREAM | socket.SOCK_NONBLOCK) as s:
self.checkNonblock(s)
- s.setblocking(1)
+ s.setblocking(True)
self.checkNonblock(s, nonblock=False)
- s.setblocking(0)
+ s.setblocking(False)
self.checkNonblock(s)
s.settimeout(None)
self.checkNonblock(s, nonblock=False)
s.settimeout(2.0)
self.checkNonblock(s, timeout=2.0)
- s.setblocking(1)
+ s.setblocking(True)
self.checkNonblock(s, nonblock=False)
# defaulttimeout
t = socket.getdefaulttimeout()
@@ -5650,7 +6000,7 @@ class SendfileUsingSendTest(ThreadedTCPSocketTest):
FILESIZE = (10 * 1024 * 1024) # 10 MiB
BUFSIZE = 8192
FILEDATA = b""
- TIMEOUT = 2
+ TIMEOUT = support.LOOPBACK_TIMEOUT
@classmethod
def setUpClass(cls):
@@ -5676,7 +6026,7 @@ def tearDownClass(cls):
os_helper.unlink(os_helper.TESTFN)
def accept_conn(self):
- self.serv.settimeout(MAIN_TIMEOUT)
+ self.serv.settimeout(support.LONG_TIMEOUT)
conn, addr = self.serv.accept()
conn.settimeout(self.TIMEOUT)
self.addCleanup(conn.close)
@@ -5772,7 +6122,9 @@ def testOffset(self):
def _testCount(self):
address = self.serv.getsockname()
file = open(os_helper.TESTFN, 'rb')
- with socket.create_connection(address, timeout=2) as sock, file as file:
+ sock = socket.create_connection(address,
+ timeout=support.LOOPBACK_TIMEOUT)
+ with sock, file:
count = 5000007
meth = self.meth_from_sock(sock)
sent = meth(file, count=count)
@@ -5792,7 +6144,9 @@ def testCount(self):
def _testCountSmall(self):
address = self.serv.getsockname()
file = open(os_helper.TESTFN, 'rb')
- with socket.create_connection(address, timeout=2) as sock, file as file:
+ sock = socket.create_connection(address,
+ timeout=support.LOOPBACK_TIMEOUT)
+ with sock, file:
count = 1
meth = self.meth_from_sock(sock)
sent = meth(file, count=count)
@@ -5846,12 +6200,13 @@ def testNonBlocking(self):
def _testWithTimeout(self):
address = self.serv.getsockname()
file = open(os_helper.TESTFN, 'rb')
- with socket.create_connection(address, timeout=2) as sock, file as file:
+ sock = socket.create_connection(address,
+ timeout=support.LOOPBACK_TIMEOUT)
+ with sock, file:
meth = self.meth_from_sock(sock)
sent = meth(file)
self.assertEqual(sent, self.FILESIZE)
- @unittest.skipIf(sys.platform == "darwin", "TODO: RUSTPYTHON, killed (for OOM?)")
def testWithTimeout(self):
conn = self.accept_conn()
data = self.recv_data(conn)
@@ -5866,11 +6221,12 @@ def _testWithTimeoutTriggeredSend(self):
with socket.create_connection(address) as sock:
sock.settimeout(0.01)
meth = self.meth_from_sock(sock)
- self.assertRaises(socket.timeout, meth, file)
+ self.assertRaises(TimeoutError, meth, file)
def testWithTimeoutTriggeredSend(self):
conn = self.accept_conn()
conn.recv(88192)
+ time.sleep(1)
# errors
@@ -5883,7 +6239,7 @@ def test_errors(self):
meth = self.meth_from_sock(s)
self.assertRaisesRegex(
ValueError, "SOCK_STREAM", meth, file)
- with open(os_helper.TESTFN, 'rt') as file:
+ with open(os_helper.TESTFN, encoding="utf-8") as file:
with socket.socket() as s:
meth = self.meth_from_sock(s)
self.assertRaisesRegex(
@@ -6121,6 +6477,12 @@ def test_length_restriction(self):
sock.bind(("type", "n" * 64))
+@unittest.skipUnless(sys.platform == 'darwin', 'macOS specific test')
+class TestMacOSTCPFlags(unittest.TestCase):
+ def test_tcp_keepalive(self):
+ self.assertTrue(socket.TCP_KEEPALIVE)
+
+
@unittest.skipUnless(sys.platform.startswith("win"), "requires Windows")
class TestMSWindowsTCPFlags(unittest.TestCase):
knownTCPFlags = {
@@ -6196,14 +6558,7 @@ def test_dualstack_ipv6_family(self):
class CreateServerFunctionalTest(unittest.TestCase):
- timeout = 3
-
- def setUp(self):
- self.thread = None
-
- def tearDown(self):
- if self.thread is not None:
- self.thread.join(self.timeout)
+ timeout = support.LOOPBACK_TIMEOUT
def echo_server(self, sock):
def run(sock):
@@ -6218,8 +6573,9 @@ def run(sock):
event = threading.Event()
sock.settimeout(self.timeout)
- self.thread = threading.Thread(target=run, args=(sock, ))
- self.thread.start()
+ thread = threading.Thread(target=run, args=(sock, ))
+ thread.start()
+ self.addCleanup(thread.join, self.timeout)
event.set()
def echo_client(self, addr, family):
@@ -6265,73 +6621,52 @@ def test_dual_stack_client_v6(self):
self.echo_server(sock)
self.echo_client(("::1", port), socket.AF_INET6)
+@requireAttrs(socket, "send_fds")
+@requireAttrs(socket, "recv_fds")
+@requireAttrs(socket, "AF_UNIX")
+class SendRecvFdsTests(unittest.TestCase):
+ def testSendAndRecvFds(self):
+ def close_pipes(pipes):
+ for fd1, fd2 in pipes:
+ os.close(fd1)
+ os.close(fd2)
+
+ def close_fds(fds):
+ for fd in fds:
+ os.close(fd)
+
+ # send 10 file descriptors
+ pipes = [os.pipe() for _ in range(10)]
+ self.addCleanup(close_pipes, pipes)
+ fds = [rfd for rfd, wfd in pipes]
+
+ # use a UNIX socket pair to exchange file descriptors locally
+ sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
+ with sock1, sock2:
+ socket.send_fds(sock1, [MSG], fds)
+ # request more data and file descriptors than expected
+ msg, fds2, flags, addr = socket.recv_fds(sock2, len(MSG) * 2, len(fds) * 2)
+ self.addCleanup(close_fds, fds2)
+
+ self.assertEqual(msg, MSG)
+ self.assertEqual(len(fds2), len(fds))
+ self.assertEqual(flags, 0)
+ # don't test addr
-def test_main():
- tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
- TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest,
- UDPTimeoutTest, CreateServerTest, CreateServerFunctionalTest]
-
- tests.extend([
- NonBlockingTCPTests,
- FileObjectClassTestCase,
- UnbufferedFileObjectClassTestCase,
- LineBufferedFileObjectClassTestCase,
- SmallBufferedFileObjectClassTestCase,
- UnicodeReadFileObjectClassTestCase,
- UnicodeWriteFileObjectClassTestCase,
- UnicodeReadWriteFileObjectClassTestCase,
- NetworkConnectionNoServer,
- NetworkConnectionAttributesTest,
- NetworkConnectionBehaviourTest,
- ContextManagersTest,
- InheritanceTest,
- NonblockConstantTest
- ])
- tests.append(BasicSocketPairTest)
- tests.append(TestUnixDomain)
- tests.append(TestLinuxAbstractNamespace)
- tests.extend([TIPCTest, TIPCThreadableTest])
- tests.extend([BasicCANTest, CANTest])
- tests.extend([BasicRDSTest, RDSTest])
- tests.append(LinuxKernelCryptoAPI)
- tests.append(BasicQIPCRTRTest)
- tests.extend([
- BasicVSOCKTest,
- ThreadedVSOCKSocketStreamTest,
- ])
- tests.extend([
- CmsgMacroTests,
- SendmsgUDPTest,
- RecvmsgUDPTest,
- RecvmsgIntoUDPTest,
- SendmsgUDP6Test,
- RecvmsgUDP6Test,
- RecvmsgRFC3542AncillaryUDP6Test,
- RecvmsgIntoRFC3542AncillaryUDP6Test,
- RecvmsgIntoUDP6Test,
- SendmsgTCPTest,
- RecvmsgTCPTest,
- RecvmsgIntoTCPTest,
- SendmsgSCTPStreamTest,
- RecvmsgSCTPStreamTest,
- RecvmsgIntoSCTPStreamTest,
- SendmsgUnixStreamTest,
- RecvmsgUnixStreamTest,
- RecvmsgIntoUnixStreamTest,
- RecvmsgSCMRightsStreamTest,
- RecvmsgIntoSCMRightsStreamTest,
- # These are slow when setitimer() is not available
- InterruptedRecvTimeoutTest,
- InterruptedSendTimeoutTest,
- TestSocketSharing,
- SendfileUsingSendTest,
- SendfileUsingSendfileTest,
- ])
- tests.append(TestMSWindowsTCPFlags)
+ # test that file descriptors are connected
+ for index, fds in enumerate(pipes):
+ rfd, wfd = fds
+ os.write(wfd, str(index).encode())
+ for index, rfd in enumerate(fds2):
+ data = os.read(rfd, 100)
+ self.assertEqual(data, str(index).encode())
+
+
+def setUpModule():
thread_info = threading_helper.threading_setup()
- support.run_unittest(*tests)
- threading_helper.threading_cleanup(*thread_info)
+ unittest.addModuleCleanup(threading_helper.threading_cleanup, *thread_info)
+
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py
index 6af948e922..0d089edc9b 100644
--- a/Lib/test/test_socketserver.py
+++ b/Lib/test/test_socketserver.py
@@ -39,7 +39,7 @@ def signal_alarm(n):
# Remember real select() to avoid interferences with mocking
_real_select = select.select
-def receive(sock, n, timeout=20):
+def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
r, w, x = _real_select([sock], [], [], timeout)
if sock in r:
return sock.recv(n)
@@ -68,9 +68,7 @@ def simple_subprocess(testcase):
except:
raise
finally:
- pid2, status = os.waitpid(pid, 0)
- testcase.assertEqual(pid2, pid)
- testcase.assertEqual(72 << 8, status)
+ test.support.wait_process(pid, exitcode=72)
class SocketServerTest(unittest.TestCase):
@@ -345,8 +343,11 @@ def test_threading_handled(self):
self.check_result(handled=True)
def test_threading_not_handled(self):
- ThreadingErrorTestServer(SystemExit)
- self.check_result(handled=False)
+ with threading_helper.catch_threading_exception() as cm:
+ ThreadingErrorTestServer(SystemExit)
+ self.check_result(handled=False)
+
+ self.assertIs(cm.exc_type, SystemExit)
@requires_forking
def test_forking_handled(self):
From 19504a6b9dd6d00d64be545d2429494181bb7ec7 Mon Sep 17 00:00:00 2001
From: Jeong YunWon
Date: Mon, 15 Aug 2022 00:11:01 +0900
Subject: [PATCH 15/23] mark failing tests again
---
Lib/test/test_socket.py | 7 +++++++
Lib/test/test_socketserver.py | 2 --
2 files changed, 7 insertions(+), 2 deletions(-)
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index 80ba3d5f3e..d6cba6b8b3 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -958,6 +958,8 @@ def testWindowsSpecificConstants(self):
socket.IPPROTO_L2TP
socket.IPPROTO_SCTP
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
@unittest.skipUnless(sys.platform == 'darwin', 'macOS specific test')
@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test')
def test3542SocketOptions(self):
@@ -2190,6 +2192,8 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.interface = "vcan0"
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
@unittest.skipUnless(hasattr(socket, "CAN_J1939"),
'socket.CAN_J1939 required for this test.')
def testJ1939Constants(self):
@@ -2231,6 +2235,8 @@ def testCreateJ1939Socket(self):
with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939) as s:
pass
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
def testBind(self):
try:
with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939) as s:
@@ -6207,6 +6213,7 @@ def _testWithTimeout(self):
sent = meth(file)
self.assertEqual(sent, self.FILESIZE)
+ @unittest.skip("TODO: RUSTPYTHON")
def testWithTimeout(self):
conn = self.accept_conn()
data = self.recv_data(conn)
diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py
index 0d089edc9b..92b0ad71c6 100644
--- a/Lib/test/test_socketserver.py
+++ b/Lib/test/test_socketserver.py
@@ -521,8 +521,6 @@ def shutdown_request(self, request):
self.assertEqual(server.shutdown_called, 1)
server.server_close()
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_threads_reaped(self):
"""
In #37193, users reported a memory leak
From b6523ddb90eb7e4c8e5d74492d36b124ac624b8d Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Mon, 15 Aug 2022 00:17:17 +0900
Subject: [PATCH 16/23] Update secrets from CPython 3.10.6
---
Lib/secrets.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/Lib/secrets.py b/Lib/secrets.py
index 130434229e..a546efbdd4 100644
--- a/Lib/secrets.py
+++ b/Lib/secrets.py
@@ -14,7 +14,6 @@
import base64
import binascii
-import os
from hmac import compare_digest
from random import SystemRandom
@@ -44,7 +43,7 @@ def token_bytes(nbytes=None):
"""
if nbytes is None:
nbytes = DEFAULT_ENTROPY
- return os.urandom(nbytes)
+ return _sysrand.randbytes(nbytes)
def token_hex(nbytes=None):
"""Return a random text string, in hexadecimal.
From 56cc13f81d1585de6b40a5d8e26ebb37c4078585 Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Mon, 15 Aug 2022 00:20:01 +0900
Subject: [PATCH 17/23] Update test_scope from CPython 3.10.6
---
Lib/test/test_scope.py | 44 +++++++++++++++---------------------------
1 file changed, 16 insertions(+), 28 deletions(-)
diff --git a/Lib/test/test_scope.py b/Lib/test/test_scope.py
index be868e50e2..29a4ac3c16 100644
--- a/Lib/test/test_scope.py
+++ b/Lib/test/test_scope.py
@@ -2,11 +2,11 @@
import weakref
from test.support import check_syntax_error, cpython_only
+from test.support import gc_collect
class ScopeTests(unittest.TestCase):
-
def testSimpleNesting(self):
def make_adder(x):
@@ -20,7 +20,6 @@ def adder(y):
self.assertEqual(inc(1), 2)
self.assertEqual(plus10(-2), 8)
-
def testExtraNesting(self):
def make_adder2(x):
@@ -36,7 +35,6 @@ def adder(y):
self.assertEqual(inc(1), 2)
self.assertEqual(plus10(-2), 8)
-
def testSimpleAndRebinding(self):
def make_adder3(x):
@@ -50,8 +48,7 @@ def adder(y):
self.assertEqual(inc(1), 2)
self.assertEqual(plus10(-2), 8)
-
-
+
def testNestingGlobalNoFree(self):
def make_adder4(): # XXX add exta level of indirection
@@ -70,7 +67,6 @@ def adder(y):
global_x = 10
self.assertEqual(adder(-2), 8)
-
def testNestingThroughClass(self):
def make_adder5(x):
@@ -85,7 +81,6 @@ def __call__(self, y):
self.assertEqual(inc(1), 2)
self.assertEqual(plus10(-2), 8)
-
def testNestingPlusFreeRefToGlobal(self):
def make_adder6(x):
@@ -101,7 +96,6 @@ def adder(y):
self.assertEqual(inc(1), 11) # there's only one global
self.assertEqual(plus10(-2), 8)
-
def testNearestEnclosingScope(self):
def f(x):
@@ -115,7 +109,6 @@ def h(z):
test_func = f(10)
self.assertEqual(test_func(5), 47)
-
def testMixedFreevarsAndCellvars(self):
def identity(x):
@@ -173,7 +166,6 @@ def str(self):
self.assertEqual(t.method_and_var(), "method")
self.assertEqual(t.actual_global(), "global")
-
def testCellIsKwonlyArg(self):
# Issue 1409: Initialisation of a cell value,
# when it comes from a keyword-only parameter
@@ -185,7 +177,6 @@ def bar():
self.assertEqual(foo(a=42), 50)
self.assertEqual(foo(), 25)
-
def testRecursion(self):
def f(x):
@@ -201,6 +192,7 @@ def fact(n):
self.assertEqual(f(6), 720)
+
def testUnoptimizedNamespaces(self):
check_syntax_error(self, """if 1:
@@ -235,7 +227,6 @@ def g():
return getrefcount # global or local?
""")
-
def testLambdas(self):
f1 = lambda x: lambda y: x + y
@@ -312,8 +303,6 @@ def f():
fail('scope of global_x not correctly determined')
""", {'fail': self.fail})
-
-
def testComplexDefinitions(self):
def makeReturner(*lst):
@@ -348,6 +337,7 @@ def h():
return g()
self.assertEqual(f(), 7)
self.assertEqual(x, 7)
+
# II
x = 7
def f():
@@ -362,6 +352,7 @@ def h():
return g()
self.assertEqual(f(), 2)
self.assertEqual(x, 7)
+
# III
x = 7
def f():
@@ -377,6 +368,7 @@ def h():
return g()
self.assertEqual(f(), 2)
self.assertEqual(x, 2)
+
# IV
x = 7
def f():
@@ -392,8 +384,10 @@ def h():
return g()
self.assertEqual(f(), 2)
self.assertEqual(x, 2)
+
# XXX what about global statements in class blocks?
# do they affect methods?
+
x = 12
class Global:
global x
@@ -402,6 +396,7 @@ def set(self, val):
x = val
def get(self):
return x
+
g = Global()
self.assertEqual(g.get(), 13)
g.set(15)
@@ -428,6 +423,7 @@ def f2():
for i in range(100):
f1()
+ gc_collect() # For PyPy or other GCs.
self.assertEqual(Foo.count, 0)
def testClassAndGlobal(self):
@@ -439,16 +435,19 @@ class Foo:
def __call__(self, y):
return x + y
return Foo()
+
x = 0
self.assertEqual(test(6)(2), 8)
x = -1
self.assertEqual(test(3)(2), 5)
+
looked_up_by_load_name = False
class X:
# Implicit globals inside classes are be looked up by LOAD_NAME, not
# LOAD_GLOBAL.
locals()['looked_up_by_load_name'] = True
passed = looked_up_by_load_name
+
self.assertTrue(X.passed)
""")
@@ -468,7 +467,6 @@ def h(z):
del d['h']
self.assertEqual(d, {'x': 2, 'y': 7, 'w': 6})
-
def testLocalsClass(self):
# This test verifies that calling locals() does not pollute
# the local namespace of the class with free variables. Old
@@ -502,7 +500,6 @@ def m(self):
self.assertNotIn("x", varnames)
self.assertIn("y", varnames)
-
@cpython_only
def testLocalsClass_WithTrace(self):
# Issue23728: after the trace function returns, the locals()
@@ -520,7 +517,6 @@ def f(self):
self.assertEqual(x, 12) # Used to raise UnboundLocalError
-
def testBoundAndFree(self):
# var is bound and free in class
@@ -534,7 +530,6 @@ def m(self):
inst = f(3)()
self.assertEqual(inst.a, inst.m())
-
@cpython_only
def testInteractionWithTraceFunc(self):
@@ -574,7 +569,6 @@ def f(x):
else:
self.fail("exec should have failed, because code contained free vars")
-
def testListCompLocalVars(self):
try:
@@ -603,7 +597,6 @@ def g():
f(4)()
-
def testFreeingCell(self):
# Test what happens when a finalizer accesses
# the cell where the object was stored.
@@ -611,7 +604,6 @@ class Special:
def __del__(self):
nestedcell_get()
-
def testNonLocalFunction(self):
def f(x):
@@ -631,7 +623,6 @@ def dec():
self.assertEqual(dec(), 1)
self.assertEqual(dec(), 0)
-
def testNonLocalMethod(self):
def f(x):
class c:
@@ -678,7 +669,6 @@ def h():
self.assertEqual(2, global_ns["result2"])
self.assertEqual(9, global_ns["result9"])
-
def testNonLocalClass(self):
def f(x):
@@ -693,7 +683,7 @@ def get(self):
self.assertEqual(c.get(), 1)
self.assertNotIn("x", c.__class__.__dict__)
-
+
def testNonLocalGenerator(self):
def f(x):
@@ -707,7 +697,6 @@ def g(y):
g = f(0)
self.assertEqual(list(g(5)), [1, 2, 3, 4, 5])
-
def testNestedNonLocal(self):
def f(x):
@@ -725,7 +714,6 @@ def h():
h = g()
self.assertEqual(h(), 3)
-
def testTopIsNotSignificant(self):
# See #9997.
def top(a):
@@ -746,7 +734,6 @@ class X:
self.assertFalse(hasattr(X, "x"))
self.assertEqual(x, 42)
-
@cpython_only
def testCellLeak(self):
# Issue 17927.
@@ -773,8 +760,9 @@ def dig(self):
tester.dig()
ref = weakref.ref(tester)
del tester
+ gc_collect() # For PyPy or other GCs.
self.assertIsNone(ref())
if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+ unittest.main()
From 95e16a94fa627c24add1ee0b81f3687be74d36bf Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Mon, 15 Aug 2022 00:21:08 +0900
Subject: [PATCH 18/23] Update sched from CPython 3.10.6
---
Lib/sched.py | 20 ++++++++++----------
Lib/test/test_sched.py | 11 +++++++++++
2 files changed, 21 insertions(+), 10 deletions(-)
diff --git a/Lib/sched.py b/Lib/sched.py
index ff87874a3a..14613cf298 100644
--- a/Lib/sched.py
+++ b/Lib/sched.py
@@ -26,23 +26,19 @@
import time
import heapq
from collections import namedtuple
+from itertools import count
import threading
from time import monotonic as _time
__all__ = ["scheduler"]
-class Event(namedtuple('Event', 'time, priority, action, argument, kwargs')):
- __slots__ = []
- def __eq__(s, o): return (s.time, s.priority) == (o.time, o.priority)
- def __lt__(s, o): return (s.time, s.priority) < (o.time, o.priority)
- def __le__(s, o): return (s.time, s.priority) <= (o.time, o.priority)
- def __gt__(s, o): return (s.time, s.priority) > (o.time, o.priority)
- def __ge__(s, o): return (s.time, s.priority) >= (o.time, o.priority)
-
+Event = namedtuple('Event', 'time, priority, sequence, action, argument, kwargs')
Event.time.__doc__ = ('''Numeric type compatible with the return value of the
timefunc function passed to the constructor.''')
Event.priority.__doc__ = ('''Events scheduled for the same time will be executed
in the order of their priority.''')
+Event.sequence.__doc__ = ('''A continually increasing sequence number that
+ separates events if time and priority are equal.''')
Event.action.__doc__ = ('''Executing the event means executing
action(*argument, **kwargs)''')
Event.argument.__doc__ = ('''argument is a sequence holding the positional
@@ -61,6 +57,7 @@ def __init__(self, timefunc=_time, delayfunc=time.sleep):
self._lock = threading.RLock()
self.timefunc = timefunc
self.delayfunc = delayfunc
+ self._sequence_generator = count()
def enterabs(self, time, priority, action, argument=(), kwargs=_sentinel):
"""Enter a new event in the queue at an absolute time.
@@ -71,8 +68,10 @@ def enterabs(self, time, priority, action, argument=(), kwargs=_sentinel):
"""
if kwargs is _sentinel:
kwargs = {}
- event = Event(time, priority, action, argument, kwargs)
+
with self._lock:
+ event = Event(time, priority, next(self._sequence_generator),
+ action, argument, kwargs)
heapq.heappush(self._queue, event)
return event # The ID
@@ -136,7 +135,8 @@ def run(self, blocking=True):
with lock:
if not q:
break
- time, priority, action, argument, kwargs = q[0]
+ (time, priority, sequence, action,
+ argument, kwargs) = q[0]
now = timefunc()
if time > now:
delay = True
diff --git a/Lib/test/test_sched.py b/Lib/test/test_sched.py
index 491d7b3a74..7ae7baae85 100644
--- a/Lib/test/test_sched.py
+++ b/Lib/test/test_sched.py
@@ -142,6 +142,17 @@ def test_cancel_concurrent(self):
self.assertTrue(q.empty())
self.assertEqual(timer.time(), 4)
+ def test_cancel_correct_event(self):
+ # bpo-19270
+ events = []
+ scheduler = sched.scheduler()
+ scheduler.enterabs(1, 1, events.append, ("a",))
+ b = scheduler.enterabs(1, 1, events.append, ("b",))
+ scheduler.enterabs(1, 1, events.append, ("c",))
+ scheduler.cancel(b)
+ scheduler.run()
+ self.assertEqual(events, ["a", "c"])
+
def test_empty(self):
l = []
fun = lambda x: l.append(x)
From 628c4c0edf7a76d03749f6dd4c772e40bd3803cf Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Mon, 15 Aug 2022 00:22:40 +0900
Subject: [PATCH 19/23] Update runpy from CPython 3.10.6
---
Lib/runpy.py | 49 +++++++++++++++++++++++++-------
Lib/test/test_cmd_line_script.py | 2 --
Lib/test/test_runpy.py | 7 ++---
3 files changed, 41 insertions(+), 17 deletions(-)
diff --git a/Lib/runpy.py b/Lib/runpy.py
index d86f0e4a3c..c7d3d8caad 100644
--- a/Lib/runpy.py
+++ b/Lib/runpy.py
@@ -13,8 +13,9 @@
import sys
import importlib.machinery # importlib first so we can test #15386 via -m
import importlib.util
+import io
import types
-from pkgutil import read_code, get_importer
+import os
__all__ = [
"run_module", "run_path",
@@ -131,6 +132,9 @@ def _get_module_details(mod_name, error=ImportError):
# importlib, where the latter raises other errors for cases where
# pkgutil previously raised ImportError
msg = "Error while finding module specification for {!r} ({}: {})"
+ if mod_name.endswith(".py"):
+ msg += (f". Try using '{mod_name[:-3]}' instead of "
+ f"'{mod_name}' as the module name.")
raise error(msg.format(mod_name, type(ex).__name__, ex)) from ex
if spec is None:
raise error("No module named %s" % mod_name)
@@ -194,9 +198,24 @@ def _run_module_as_main(mod_name, alter_argv=True):
def run_module(mod_name, init_globals=None,
run_name=None, alter_sys=False):
- """Execute a module's code without importing it
+ """Execute a module's code without importing it.
- Returns the resulting top level namespace dictionary
+ mod_name -- an absolute module name or package name.
+
+ Optional arguments:
+ init_globals -- dictionary used to pre-populate the module’s
+ globals dictionary before the code is executed.
+
+ run_name -- if not None, this will be used for setting __name__;
+ otherwise, __name__ will be set to mod_name + '__main__' if the
+ named module is a package and to just mod_name otherwise.
+
+ alter_sys -- if True, sys.argv[0] is updated with the value of
+ __file__ and sys.modules[__name__] is updated with a temporary
+ module object for the module being executed. Both are
+ restored to their original values before the function returns.
+
+ Returns the resulting module globals dictionary.
"""
mod_name, mod_spec, code = _get_module_details(mod_name)
if run_name is None:
@@ -228,27 +247,35 @@ def _get_main_module_details(error=ImportError):
def _get_code_from_file(run_name, fname):
# Check for a compiled file first
- with open(fname, "rb") as f:
+ from pkgutil import read_code
+ decoded_path = os.path.abspath(os.fsdecode(fname))
+ with io.open_code(decoded_path) as f:
code = read_code(f)
if code is None:
# That didn't work, so try it as normal source code
- with open(fname, "rb") as f:
+ with io.open_code(decoded_path) as f:
code = compile(f.read(), fname, 'exec')
return code, fname
def run_path(path_name, init_globals=None, run_name=None):
- """Execute code located at the specified filesystem location
+ """Execute code located at the specified filesystem location.
+
+ path_name -- filesystem location of a Python script, zipfile,
+ or directory containing a top level __main__.py script.
+
+ Optional arguments:
+ init_globals -- dictionary used to pre-populate the module’s
+ globals dictionary before the code is executed.
- Returns the resulting top level namespace dictionary
+ run_name -- if not None, this will be used to set __name__;
+ otherwise, '' will be used for __name__.
- The file path may refer directly to a Python script (i.e.
- one that could be directly executed with execfile) or else
- it may refer to a zipfile or directory containing a top
- level __main__.py script.
+ Returns the resulting module globals dictionary.
"""
if run_name is None:
run_name = ""
pkg_name = run_name.rpartition(".")[0]
+ from pkgutil import get_importer
importer = get_importer(path_name)
# Trying to avoid importing imp so as to not consume the deprecation warning.
is_NullImporter = False
diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py
index 88845d26b9..8b9eb7ffd9 100644
--- a/Lib/test/test_cmd_line_script.py
+++ b/Lib/test/test_cmd_line_script.py
@@ -523,8 +523,6 @@ def test_dash_m_bad_pyc(self):
self.assertNotIn(b'is a package', err)
self.assertNotIn(b'Traceback', err)
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_hint_when_triying_to_import_a_py_file(self):
with os_helper.temp_dir() as script_dir, \
os_helper.change_cwd(path=script_dir):
diff --git a/Lib/test/test_runpy.py b/Lib/test/test_runpy.py
index 4bec101b07..c93f0f7ec1 100644
--- a/Lib/test/test_runpy.py
+++ b/Lib/test/test_runpy.py
@@ -12,10 +12,10 @@
import textwrap
import unittest
import warnings
-from test.support import verbose, no_tracing
+from test.support import no_tracing, verbose
+from test.support.import_helper import forget, make_legacy_pyc, unload
from test.support.os_helper import create_empty_file, temp_dir
from test.support.script_helper import make_script, make_zip_script
-from test.support.import_helper import unload, forget, make_legacy_pyc
import runpy
@@ -745,8 +745,7 @@ def test_main_recursion_error(self):
"runpy.run_path(%r)\n") % dummy_dir
script_name = self._make_test_script(script_dir, mod_name, source)
zip_name, fname = make_zip_script(script_dir, 'test_zip', script_name)
- msg = "recursion depth exceeded"
- self.assertRaisesRegex(RecursionError, msg, run_path, zip_name)
+ self.assertRaises(RecursionError, run_path, zip_name)
# TODO: RUSTPYTHON, detect encoding comments in files
@unittest.expectedFailure
From 2eec9194166266c9a87066aad5a9178380dc8067 Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Sun, 14 Aug 2022 23:21:01 +0900
Subject: [PATCH 20/23] Update xml and test_xml_etree
---
Lib/test/test_xml_etree.py | 102 +++++++++++++++-----
Lib/xml/dom/expatbuilder.py | 8 +-
Lib/xml/dom/minidom.py | 56 ++++++++---
Lib/xml/dom/pulldom.py | 7 ++
Lib/xml/dom/xmlbuilder.py | 25 +----
Lib/xml/etree/ElementInclude.py | 58 ++++++++++--
Lib/xml/etree/ElementPath.py | 34 +++++--
Lib/xml/etree/ElementTree.py | 160 ++++++++++++++++++--------------
Lib/xml/etree/__init__.py | 2 +-
Lib/xml/sax/__init__.py | 10 +-
Lib/xml/sax/expatreader.py | 14 +--
Lib/xml/sax/handler.py | 45 +++++++++
Lib/xml/sax/saxutils.py | 5 +-
13 files changed, 358 insertions(+), 168 deletions(-)
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index 4d67101218..b721f028a1 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -10,7 +10,6 @@
import html
import io
import itertools
-import locale
import operator
import os
import pickle
@@ -130,6 +129,9 @@ def newtest(*args, **kwargs):
return newtest
return decorator
+def convlinesep(data):
+ return data.replace(b'\n', os.linesep.encode())
+
class ModuleTest(unittest.TestCase):
def test_sanity(self):
@@ -1023,17 +1025,15 @@ def test_tostring_xml_declaration(self):
@unittest.expectedFailure
def test_tostring_xml_declaration_unicode_encoding(self):
elem = ET.XML('')
- preferredencoding = locale.getpreferredencoding()
self.assertEqual(
- f"\n",
- ET.tostring(elem, encoding='unicode', xml_declaration=True)
+ ET.tostring(elem, encoding='unicode', xml_declaration=True),
+ "\n"
)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_tostring_xml_declaration_cases(self):
elem = ET.XML('ø')
- preferredencoding = locale.getpreferredencoding()
TESTCASES = [
# (expected_retval, encoding, xml_declaration)
# ... xml_declaration = None
@@ -1060,7 +1060,7 @@ def test_tostring_xml_declaration_cases(self):
b"ø", 'US-ASCII', True),
(b"\n"
b"\xf8", 'ISO-8859-1', True),
- (f"\n"
+ ("\n"
"ø", 'unicode', True),
]
@@ -1102,11 +1102,10 @@ def test_tostringlist_xml_declaration(self):
b"\n"
)
- preferredencoding = locale.getpreferredencoding()
stringlist = ET.tostringlist(elem, encoding='unicode', xml_declaration=True)
self.assertEqual(
''.join(stringlist),
- f"\n"
+ "\n"
)
self.assertRegex(stringlist[0], r"^<\?xml version='1.0' encoding='.+'?>")
self.assertEqual(['', '', ''], stringlist[1:])
@@ -3914,38 +3913,91 @@ def test_encoding(self):
@unittest.expectedFailure
def test_write_to_filename(self):
self.addCleanup(os_helper.unlink, TESTFN)
- tree = ET.ElementTree(ET.XML(''''''))
+ tree = ET.ElementTree(ET.XML('''\xf8'''))
tree.write(TESTFN)
with open(TESTFN, 'rb') as f:
- self.assertEqual(f.read(), b'''''')
+ self.assertEqual(f.read(), b'''ø''')
+
+ def test_write_to_filename_with_encoding(self):
+ self.addCleanup(os_helper.unlink, TESTFN)
+ tree = ET.ElementTree(ET.XML('''\xf8'''))
+ tree.write(TESTFN, encoding='utf-8')
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(), b'''\xc3\xb8''')
+
+ tree.write(TESTFN, encoding='ISO-8859-1')
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(), convlinesep(
+ b'''\n'''
+ b'''\xf8'''))
+
+ def test_write_to_filename_as_unicode(self):
+ self.addCleanup(os_helper.unlink, TESTFN)
+ with open(TESTFN, 'w') as f:
+ encoding = f.encoding
+ os_helper.unlink(TESTFN)
+
+ tree = ET.ElementTree(ET.XML('''\xf8'''))
+ tree.write(TESTFN, encoding='unicode')
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(), b"\xc3\xb8")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_text_file(self):
self.addCleanup(os_helper.unlink, TESTFN)
- tree = ET.ElementTree(ET.XML(''''''))
+ tree = ET.ElementTree(ET.XML('''\xf8'''))
with open(TESTFN, 'w', encoding='utf-8') as f:
tree.write(f, encoding='unicode')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
- self.assertEqual(f.read(), b'''''')
+ self.assertEqual(f.read(), b'''\xc3\xb8''')
+
+ with open(TESTFN, 'w', encoding='ascii', errors='xmlcharrefreplace') as f:
+ tree.write(f, encoding='unicode')
+ self.assertFalse(f.closed)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(), b'''ø''')
+
+ with open(TESTFN, 'w', encoding='ISO-8859-1') as f:
+ tree.write(f, encoding='unicode')
+ self.assertFalse(f.closed)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(), b'''\xf8''')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_binary_file(self):
self.addCleanup(os_helper.unlink, TESTFN)
- tree = ET.ElementTree(ET.XML(''''''))
+ tree = ET.ElementTree(ET.XML('''\xf8'''))
with open(TESTFN, 'wb') as f:
tree.write(f)
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
- self.assertEqual(f.read(), b'''''')
+ self.assertEqual(f.read(), b'''ø''')
+
+ def test_write_to_binary_file_with_encoding(self):
+ self.addCleanup(os_helper.unlink, TESTFN)
+ tree = ET.ElementTree(ET.XML('''\xf8'''))
+ with open(TESTFN, 'wb') as f:
+ tree.write(f, encoding='utf-8')
+ self.assertFalse(f.closed)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(), b'''\xc3\xb8''')
+
+ with open(TESTFN, 'wb') as f:
+ tree.write(f, encoding='ISO-8859-1')
+ self.assertFalse(f.closed)
+ with open(TESTFN, 'rb') as f:
+ self.assertEqual(f.read(),
+ b'''\n'''
+ b'''\xf8''')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_binary_file_with_bom(self):
self.addCleanup(os_helper.unlink, TESTFN)
- tree = ET.ElementTree(ET.XML(''''''))
+ tree = ET.ElementTree(ET.XML('''\xf8'''))
# test BOM writing to buffered file
with open(TESTFN, 'wb') as f:
tree.write(f, encoding='utf-16')
@@ -3953,7 +4005,7 @@ def test_write_to_binary_file_with_bom(self):
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(),
'''\n'''
- ''''''.encode("utf-16"))
+ '''\xf8'''.encode("utf-16"))
# test BOM writing to non-buffered file
with open(TESTFN, 'wb', buffering=0) as f:
tree.write(f, encoding='utf-16')
@@ -3961,7 +4013,7 @@ def test_write_to_binary_file_with_bom(self):
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(),
'''\n'''
- ''''''.encode("utf-16"))
+ '''\xf8'''.encode("utf-16"))
# TODO: RUSTPYTHON
@unittest.expectedFailure
@@ -3974,10 +4026,10 @@ def test_read_from_stringio(self):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_stringio(self):
- tree = ET.ElementTree(ET.XML(''''''))
+ tree = ET.ElementTree(ET.XML('''\xf8'''))
stream = io.StringIO()
tree.write(stream, encoding='unicode')
- self.assertEqual(stream.getvalue(), '''''')
+ self.assertEqual(stream.getvalue(), '''\xf8''')
# TODO: RUSTPYTHON
@unittest.expectedFailure
@@ -3990,10 +4042,10 @@ def test_read_from_bytesio(self):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_bytesio(self):
- tree = ET.ElementTree(ET.XML(''''''))
+ tree = ET.ElementTree(ET.XML('''\xf8'''))
raw = io.BytesIO()
tree.write(raw)
- self.assertEqual(raw.getvalue(), b'''''')
+ self.assertEqual(raw.getvalue(), b'''ø''')
class dummy:
pass
@@ -4011,12 +4063,12 @@ def test_read_from_user_text_reader(self):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_user_text_writer(self):
- tree = ET.ElementTree(ET.XML(''''''))
+ tree = ET.ElementTree(ET.XML('''\xf8'''))
stream = io.StringIO()
writer = self.dummy()
writer.write = stream.write
tree.write(writer, encoding='unicode')
- self.assertEqual(stream.getvalue(), '''''')
+ self.assertEqual(stream.getvalue(), '''\xf8''')
# TODO: RUSTPYTHON
@unittest.expectedFailure
@@ -4032,12 +4084,12 @@ def test_read_from_user_binary_reader(self):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_user_binary_writer(self):
- tree = ET.ElementTree(ET.XML(''''''))
+ tree = ET.ElementTree(ET.XML('''\xf8'''))
raw = io.BytesIO()
writer = self.dummy()
writer.write = raw.write
tree.write(writer)
- self.assertEqual(raw.getvalue(), b'''''')
+ self.assertEqual(raw.getvalue(), b'''ø''')
# TODO: RUSTPYTHON
@unittest.expectedFailure
diff --git a/Lib/xml/dom/expatbuilder.py b/Lib/xml/dom/expatbuilder.py
index 2bd835b035..199c22d0af 100644
--- a/Lib/xml/dom/expatbuilder.py
+++ b/Lib/xml/dom/expatbuilder.py
@@ -204,11 +204,11 @@ def parseFile(self, file):
buffer = file.read(16*1024)
if not buffer:
break
- parser.Parse(buffer, 0)
+ parser.Parse(buffer, False)
if first_buffer and self.document.documentElement:
self._setup_subset(buffer)
first_buffer = False
- parser.Parse("", True)
+ parser.Parse(b"", True)
except ParseEscape:
pass
doc = self.document
@@ -637,7 +637,7 @@ def parseString(self, string):
nsattrs = self._getNSattrs() # get ns decls from node's ancestors
document = _FRAGMENT_BUILDER_TEMPLATE % (ident, subset, nsattrs)
try:
- parser.Parse(document, 1)
+ parser.Parse(document, True)
except:
self.reset()
raise
@@ -697,7 +697,7 @@ def external_entity_ref_handler(self, context, base, systemId, publicId):
self.fragment = self.document.createDocumentFragment()
self.curNode = self.fragment
try:
- parser.Parse(self._source, 1)
+ parser.Parse(self._source, True)
finally:
self.curNode = old_cur_node
self.document = old_document
diff --git a/Lib/xml/dom/minidom.py b/Lib/xml/dom/minidom.py
index 24957ea14f..d09ef5e7d0 100644
--- a/Lib/xml/dom/minidom.py
+++ b/Lib/xml/dom/minidom.py
@@ -43,10 +43,11 @@ class Node(xml.dom.Node):
def __bool__(self):
return True
- def toxml(self, encoding=None):
- return self.toprettyxml("", "", encoding)
+ def toxml(self, encoding=None, standalone=None):
+ return self.toprettyxml("", "", encoding, standalone)
- def toprettyxml(self, indent="\t", newl="\n", encoding=None):
+ def toprettyxml(self, indent="\t", newl="\n", encoding=None,
+ standalone=None):
if encoding is None:
writer = io.StringIO()
else:
@@ -56,7 +57,7 @@ def toprettyxml(self, indent="\t", newl="\n", encoding=None):
newline='\n')
if self.nodeType == Node.DOCUMENT_NODE:
# Can pass encoding only to document, to put it into XML header
- self.writexml(writer, "", indent, newl, encoding)
+ self.writexml(writer, "", indent, newl, encoding, standalone)
else:
self.writexml(writer, "", indent, newl)
if encoding is None:
@@ -718,6 +719,14 @@ def unlink(self):
Node.unlink(self)
def getAttribute(self, attname):
+ """Returns the value of the specified attribute.
+
+ Returns the value of the element's attribute named attname as
+ a string. An empty string is returned if the element does not
+ have such an attribute. Note that an empty string may also be
+ returned as an explicitly given attribute value, use the
+ hasAttribute method to distinguish these two cases.
+ """
if self._attrs is None:
return ""
try:
@@ -823,10 +832,16 @@ def removeAttributeNode(self, node):
# Restore this since the node is still useful and otherwise
# unlinked
node.ownerDocument = self.ownerDocument
+ return node
removeAttributeNodeNS = removeAttributeNode
def hasAttribute(self, name):
+ """Checks whether the element has an attribute with the specified name.
+
+ Returns True if the element has an attribute with the specified name.
+ Otherwise, returns False.
+ """
if self._attrs is None:
return False
return name in self._attrs
@@ -837,6 +852,11 @@ def hasAttributeNS(self, namespaceURI, localName):
return (namespaceURI, localName) in self._attrsNS
def getElementsByTagName(self, name):
+ """Returns all descendant elements with the given tag name.
+
+ Returns the list of all descendant elements (not direct children
+ only) with the specified tag name.
+ """
return _get_elements_by_tagName_helper(self, name, NodeList())
def getElementsByTagNameNS(self, namespaceURI, localName):
@@ -847,22 +867,27 @@ def __repr__(self):
return "" % (self.tagName, id(self))
def writexml(self, writer, indent="", addindent="", newl=""):
+ """Write an XML element to a file-like object
+
+ Write the element to the writer object that must provide
+ a write method (e.g. a file or StringIO object).
+ """
# indent = current indentation
# addindent = indentation to add to higher levels
# newl = newline string
writer.write(indent+"<" + self.tagName)
attrs = self._get_attributes()
- a_names = sorted(attrs.keys())
- for a_name in a_names:
+ for a_name in attrs.keys():
writer.write(" %s=\"" % a_name)
_write_data(writer, attrs[a_name].value)
writer.write("\"")
if self.childNodes:
writer.write(">")
if (len(self.childNodes) == 1 and
- self.childNodes[0].nodeType == Node.TEXT_NODE):
+ self.childNodes[0].nodeType in (
+ Node.TEXT_NODE, Node.CDATA_SECTION_NODE)):
self.childNodes[0].writexml(writer, '', '', '')
else:
writer.write(newl)
@@ -1786,12 +1811,17 @@ def importNode(self, node, deep):
raise xml.dom.NotSupportedErr("cannot import document type nodes")
return _clone_node(node, deep, self)
- def writexml(self, writer, indent="", addindent="", newl="", encoding=None):
- if encoding is None:
- writer.write(''+newl)
- else:
- writer.write('%s' % (
- encoding, newl))
+ def writexml(self, writer, indent="", addindent="", newl="", encoding=None,
+ standalone=None):
+ declarations = []
+
+ if encoding:
+ declarations.append(f'encoding="{encoding}"')
+ if standalone is not None:
+ declarations.append(f'standalone="{"yes" if standalone else "no"}"')
+
+ writer.write(f'{newl}')
+
for node in self.childNodes:
node.writexml(writer, indent, addindent, newl)
diff --git a/Lib/xml/dom/pulldom.py b/Lib/xml/dom/pulldom.py
index 43504f7656..96a8d59519 100644
--- a/Lib/xml/dom/pulldom.py
+++ b/Lib/xml/dom/pulldom.py
@@ -217,6 +217,13 @@ def reset(self):
self.parser.setContentHandler(self.pulldom)
def __getitem__(self, pos):
+ import warnings
+ warnings.warn(
+ "DOMEventStream's __getitem__ method ignores 'pos' parameter. "
+ "Use iterator protocol instead.",
+ DeprecationWarning,
+ stacklevel=2
+ )
rc = self.getEvent()
if rc:
return rc
diff --git a/Lib/xml/dom/xmlbuilder.py b/Lib/xml/dom/xmlbuilder.py
index e9a1536472..8a20026349 100644
--- a/Lib/xml/dom/xmlbuilder.py
+++ b/Lib/xml/dom/xmlbuilder.py
@@ -1,7 +1,6 @@
"""Implementation of the DOM Level 3 'LS-Load' feature."""
import copy
-import warnings
import xml.dom
from xml.dom.NodeFilter import NodeFilter
@@ -80,7 +79,7 @@ def setFeature(self, name, state):
settings = self._settings[(_name_xform(name), state)]
except KeyError:
raise xml.dom.NotSupportedErr(
- "unsupported feature: %r" % (name,))
+ "unsupported feature: %r" % (name,)) from None
else:
for name, value in settings:
setattr(self._options, name, value)
@@ -332,29 +331,10 @@ def startContainer(self, element):
del NodeFilter
-class _AsyncDeprecatedProperty:
- def warn(self, cls):
- clsname = cls.__name__
- warnings.warn(
- "{cls}.async is deprecated; use {cls}.async_".format(cls=clsname),
- DeprecationWarning)
-
- def __get__(self, instance, cls):
- self.warn(cls)
- if instance is not None:
- return instance.async_
- return False
-
- def __set__(self, instance, value):
- self.warn(type(instance))
- setattr(instance, 'async_', value)
-
-
class DocumentLS:
"""Mixin to create documents that conform to the load/save spec."""
async_ = False
- locals()['async'] = _AsyncDeprecatedProperty() # Avoid DeprecationWarning
def _get_async(self):
return False
@@ -384,9 +364,6 @@ def saveXML(self, snode):
return snode.toxml()
-del _AsyncDeprecatedProperty
-
-
class DOMImplementationLS:
MODE_SYNCHRONOUS = 1
MODE_ASYNCHRONOUS = 2
diff --git a/Lib/xml/etree/ElementInclude.py b/Lib/xml/etree/ElementInclude.py
index 963470e3b1..40a9b22292 100644
--- a/Lib/xml/etree/ElementInclude.py
+++ b/Lib/xml/etree/ElementInclude.py
@@ -42,7 +42,7 @@
# --------------------------------------------------------------------
# Licensed to PSF under a Contributor Agreement.
-# See http://www.python.org/psf/license for licensing details.
+# See https://www.python.org/psf/license for licensing details.
##
# Limited XInclude support for the ElementTree package.
@@ -50,18 +50,28 @@
import copy
from . import ElementTree
+from urllib.parse import urljoin
XINCLUDE = "{http://www.w3.org/2001/XInclude}"
XINCLUDE_INCLUDE = XINCLUDE + "include"
XINCLUDE_FALLBACK = XINCLUDE + "fallback"
+# For security reasons, the inclusion depth is limited to this read-only value by default.
+DEFAULT_MAX_INCLUSION_DEPTH = 6
+
+
##
# Fatal include error.
class FatalIncludeError(SyntaxError):
pass
+
+class LimitedRecursiveIncludeError(FatalIncludeError):
+ pass
+
+
##
# Default loader. This loader reads an included resource from disk.
#
@@ -92,13 +102,33 @@ def default_loader(href, parse, encoding=None):
# @param loader Optional resource loader. If omitted, it defaults
# to {@link default_loader}. If given, it should be a callable
# that implements the same interface as default_loader.
+# @param base_url The base URL of the original file, to resolve
+# relative include file references.
+# @param max_depth The maximum number of recursive inclusions.
+# Limited to reduce the risk of malicious content explosion.
+# Pass a negative value to disable the limitation.
+# @throws LimitedRecursiveIncludeError If the {@link max_depth} was exceeded.
# @throws FatalIncludeError If the function fails to include a given
# resource, or if the tree contains malformed XInclude elements.
-# @throws OSError If the function fails to load a given resource.
+# @throws IOError If the function fails to load a given resource.
+# @returns the node or its replacement if it was an XInclude node
-def include(elem, loader=None):
+def include(elem, loader=None, base_url=None,
+ max_depth=DEFAULT_MAX_INCLUSION_DEPTH):
+ if max_depth is None:
+ max_depth = -1
+ elif max_depth < 0:
+ raise ValueError("expected non-negative depth or None for 'max_depth', got %r" % max_depth)
+
+ if hasattr(elem, 'getroot'):
+ elem = elem.getroot()
if loader is None:
loader = default_loader
+
+ _include(elem, loader, base_url, max_depth, set())
+
+
+def _include(elem, loader, base_url, max_depth, _parent_hrefs):
# look for xinclude elements
i = 0
while i < len(elem):
@@ -106,14 +136,24 @@ def include(elem, loader=None):
if e.tag == XINCLUDE_INCLUDE:
# process xinclude directive
href = e.get("href")
+ if base_url:
+ href = urljoin(base_url, href)
parse = e.get("parse", "xml")
if parse == "xml":
+ if href in _parent_hrefs:
+ raise FatalIncludeError("recursive include of %s" % href)
+ if max_depth == 0:
+ raise LimitedRecursiveIncludeError(
+ "maximum xinclude depth reached when including file %s" % href)
+ _parent_hrefs.add(href)
node = loader(href, parse)
if node is None:
raise FatalIncludeError(
"cannot load %r as %r" % (href, parse)
)
- node = copy.copy(node)
+ node = copy.copy(node) # FIXME: this makes little sense with recursive includes
+ _include(node, loader, href, max_depth - 1, _parent_hrefs)
+ _parent_hrefs.remove(href)
if e.tail:
node.tail = (node.tail or "") + e.tail
elem[i] = node
@@ -123,11 +163,13 @@ def include(elem, loader=None):
raise FatalIncludeError(
"cannot load %r as %r" % (href, parse)
)
+ if e.tail:
+ text += e.tail
if i:
node = elem[i-1]
- node.tail = (node.tail or "") + text + (e.tail or "")
+ node.tail = (node.tail or "") + text
else:
- elem.text = (elem.text or "") + text + (e.tail or "")
+ elem.text = (elem.text or "") + text
del elem[i]
continue
else:
@@ -139,5 +181,5 @@ def include(elem, loader=None):
"xi:fallback tag must be child of xi:include (%r)" % e.tag
)
else:
- include(e, loader)
- i = i + 1
+ _include(e, loader, base_url, max_depth, _parent_hrefs)
+ i += 1
diff --git a/Lib/xml/etree/ElementPath.py b/Lib/xml/etree/ElementPath.py
index d318e65d84..cd3c354d08 100644
--- a/Lib/xml/etree/ElementPath.py
+++ b/Lib/xml/etree/ElementPath.py
@@ -48,7 +48,7 @@
# --------------------------------------------------------------------
# Licensed to PSF under a Contributor Agreement.
-# See http://www.python.org/psf/license for licensing details.
+# See https://www.python.org/psf/license for licensing details.
##
# Implementation module for XPath support. There's usually no reason
@@ -65,8 +65,9 @@
r"//?|"
r"\.\.|"
r"\(\)|"
+ r"!=|"
r"[/.*:\[\]\(\)@=])|"
- r"((?:\{[^}]+\})?[^/\[\]\(\)@=\s]+)|"
+ r"((?:\{[^}]+\})?[^/\[\]\(\)@!=\s]+)|"
r"\s+"
)
@@ -225,7 +226,6 @@ def select(context, result):
def prepare_predicate(next, token):
# FIXME: replace with real parser!!! refs:
- # http://effbot.org/zone/simple-iterator-parser.htm
# http://javascript.crockford.com/tdop/tdop.html
signature = []
predicate = []
@@ -253,15 +253,19 @@ def select(context, result):
if elem.get(key) is not None:
yield elem
return select
- if signature == "@-='":
- # [@attribute='value']
+ if signature == "@-='" or signature == "@-!='":
+ # [@attribute='value'] or [@attribute!='value']
key = predicate[1]
value = predicate[-1]
def select(context, result):
for elem in result:
if elem.get(key) == value:
yield elem
- return select
+ def select_negated(context, result):
+ for elem in result:
+ if (attr_value := elem.get(key)) is not None and attr_value != value:
+ yield elem
+ return select_negated if '!=' in signature else select
if signature == "-" and not re.match(r"\-?\d+$", predicate[0]):
# [tag]
tag = predicate[0]
@@ -270,8 +274,10 @@ def select(context, result):
if elem.find(tag) is not None:
yield elem
return select
- if signature == ".='" or (signature == "-='" and not re.match(r"\-?\d+$", predicate[0])):
- # [.='value'] or [tag='value']
+ if signature == ".='" or signature == ".!='" or (
+ (signature == "-='" or signature == "-!='")
+ and not re.match(r"\-?\d+$", predicate[0])):
+ # [.='value'] or [tag='value'] or [.!='value'] or [tag!='value']
tag = predicate[0]
value = predicate[-1]
if tag:
@@ -281,12 +287,22 @@ def select(context, result):
if "".join(e.itertext()) == value:
yield elem
break
+ def select_negated(context, result):
+ for elem in result:
+ for e in elem.iterfind(tag):
+ if "".join(e.itertext()) != value:
+ yield elem
+ break
else:
def select(context, result):
for elem in result:
if "".join(elem.itertext()) == value:
yield elem
- return select
+ def select_negated(context, result):
+ for elem in result:
+ if "".join(elem.itertext()) != value:
+ yield elem
+ return select_negated if '!=' in signature else select
if signature == "-" or signature == "-()" or signature == "-()-":
# [index] or [last()] or [last()-index]
if signature == "-":
diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py
index f8538dfb2b..2503d9ee76 100644
--- a/Lib/xml/etree/ElementTree.py
+++ b/Lib/xml/etree/ElementTree.py
@@ -35,7 +35,7 @@
#---------------------------------------------------------------------
# Licensed to PSF under a Contributor Agreement.
-# See http://www.python.org/psf/license for licensing details.
+# See https://www.python.org/psf/license for licensing details.
#
# ElementTree
# Copyright (c) 1999-2008 by Fredrik Lundh. All rights reserved.
@@ -76,7 +76,7 @@
"dump",
"Element", "ElementTree",
"fromstring", "fromstringlist",
- "iselement", "iterparse",
+ "indent", "iselement", "iterparse",
"parse", "ParseError",
"PI", "ProcessingInstruction",
"QName",
@@ -195,6 +195,13 @@ def copy(self):
original tree.
"""
+ warnings.warn(
+ "elem.copy() is deprecated. Use copy.copy(elem) instead.",
+ DeprecationWarning
+ )
+ return self.__copy__()
+
+ def __copy__(self):
elem = self.makeelement(self.tag, self.attrib)
elem.text = self.text
elem.tail = self.tail
@@ -273,19 +280,6 @@ def remove(self, subelement):
# assert iselement(element)
self._children.remove(subelement)
- def getchildren(self):
- """(Deprecated) Return all subelements.
-
- Elements are returned in document order.
-
- """
- warnings.warn(
- "This method will be removed in future versions. "
- "Use 'list(elem)' or iteration over elem instead.",
- DeprecationWarning, stacklevel=2
- )
- return self._children
-
def find(self, path, namespaces=None):
"""Find first matching element by tag name or path.
@@ -409,15 +403,6 @@ def iter(self, tag=None):
for e in self._children:
yield from e.iter(tag)
- # compatibility
- def getiterator(self, tag=None):
- warnings.warn(
- "This method will be removed in future versions. "
- "Use 'elem.iter()' or 'list(elem.iter())' instead.",
- DeprecationWarning, stacklevel=2
- )
- return list(self.iter(tag))
-
def itertext(self):
"""Create text iterator.
@@ -617,15 +602,6 @@ def iter(self, tag=None):
# assert self._root is not None
return self._root.iter(tag)
- # compatibility
- def getiterator(self, tag=None):
- warnings.warn(
- "This method will be removed in future versions. "
- "Use 'tree.iter()' or 'list(tree.iter())' instead.",
- DeprecationWarning, stacklevel=2
- )
- return list(self.iter(tag))
-
def find(self, path, namespaces=None):
"""Find first matching element by tag name or path.
@@ -752,16 +728,11 @@ def write(self, file_or_filename,
encoding = "utf-8"
else:
encoding = "us-ascii"
- enc_lower = encoding.lower()
- with _get_writer(file_or_filename, enc_lower) as write:
+ with _get_writer(file_or_filename, encoding) as (write, declared_encoding):
if method == "xml" and (xml_declaration or
(xml_declaration is None and
- enc_lower not in ("utf-8", "us-ascii", "unicode"))):
- declared_encoding = encoding
- if enc_lower == "unicode":
- # Retrieve the default encoding for the xml declaration
- import locale
- declared_encoding = locale.getpreferredencoding()
+ encoding.lower() != "unicode" and
+ declared_encoding.lower() not in ("utf-8", "us-ascii"))):
write("\n" % (
declared_encoding,))
if method == "text":
@@ -786,19 +757,17 @@ def _get_writer(file_or_filename, encoding):
write = file_or_filename.write
except AttributeError:
# file_or_filename is a file name
- if encoding == "unicode":
- file = open(file_or_filename, "w")
- else:
- file = open(file_or_filename, "w", encoding=encoding,
- errors="xmlcharrefreplace")
- with file:
- yield file.write
+ if encoding.lower() == "unicode":
+ encoding="utf-8"
+ with open(file_or_filename, "w", encoding=encoding,
+ errors="xmlcharrefreplace") as file:
+ yield file.write, encoding
else:
# file_or_filename is a file-like object
# encoding determines if it is a text or binary writer
- if encoding == "unicode":
+ if encoding.lower() == "unicode":
# use a text writer as is
- yield write
+ yield write, getattr(file_or_filename, "encoding", None) or "utf-8"
else:
# wrap a binary writer with TextIOWrapper
with contextlib.ExitStack() as stack:
@@ -829,7 +798,7 @@ def _get_writer(file_or_filename, encoding):
# Keep the original file open when the TextIOWrapper is
# destroyed
stack.callback(file.detach)
- yield file.write
+ yield file.write, encoding
def _namespaces(elem, default_namespace=None):
# identify namespaces used in this tree
@@ -1081,15 +1050,15 @@ def _escape_attrib(text):
text = text.replace(">", ">")
if "\"" in text:
text = text.replace("\"", """)
- # The following business with carriage returns is to satisfy
- # Section 2.11 of the XML specification, stating that
- # CR or CR LN should be replaced with just LN
+ # Although section 2.11 of the XML specification states that CR or
+ # CR LN should be replaced with just LN, it applies only to EOLNs
+ # which take part of organizing file into lines. Within attributes,
+ # we are replacing these with entity numbers, so they do not count.
# http://www.w3.org/TR/REC-xml/#sec-line-ends
- if "\r\n" in text:
- text = text.replace("\r\n", "\n")
+ # The current solution, contained in following six lines, was
+ # discussed in issue 17582 and 39011.
if "\r" in text:
- text = text.replace("\r", "\n")
- #The following four lines are issue 17582
+ text = text.replace("\r", "
")
if "\n" in text:
text = text.replace("\n", "
")
if "\t" in text:
@@ -1185,6 +1154,57 @@ def dump(elem):
if not tail or tail[-1] != "\n":
sys.stdout.write("\n")
+
+def indent(tree, space=" ", level=0):
+ """Indent an XML document by inserting newlines and indentation space
+ after elements.
+
+ *tree* is the ElementTree or Element to modify. The (root) element
+ itself will not be changed, but the tail text of all elements in its
+ subtree will be adapted.
+
+ *space* is the whitespace to insert for each indentation level, two
+ space characters by default.
+
+ *level* is the initial indentation level. Setting this to a higher
+ value than 0 can be used for indenting subtrees that are more deeply
+ nested inside of a document.
+ """
+ if isinstance(tree, ElementTree):
+ tree = tree.getroot()
+ if level < 0:
+ raise ValueError(f"Initial indentation level must be >= 0, got {level}")
+ if not len(tree):
+ return
+
+ # Reduce the memory consumption by reusing indentation strings.
+ indentations = ["\n" + level * space]
+
+ def _indent_children(elem, level):
+ # Start a new indentation level for the first child.
+ child_level = level + 1
+ try:
+ child_indentation = indentations[child_level]
+ except IndexError:
+ child_indentation = indentations[level] + space
+ indentations.append(child_indentation)
+
+ if not elem.text or not elem.text.strip():
+ elem.text = child_indentation
+
+ for child in elem:
+ if len(child):
+ _indent_children(child, child_level)
+ if not child.tail or not child.tail.strip():
+ child.tail = child_indentation
+
+ # Dedent after the last child by overwriting the previous indentation.
+ if not child.tail.strip():
+ child.tail = indentations[level]
+
+ _indent_children(tree, 0)
+
+
# --------------------------------------------------------------------
# parsing
@@ -1221,8 +1241,14 @@ def iterparse(source, events=None, parser=None):
# Use the internal, undocumented _parser argument for now; When the
# parser argument of iterparse is removed, this can be killed.
pullparser = XMLPullParser(events=events, _parser=parser)
- def iterator():
+
+ def iterator(source):
+ close_source = False
try:
+ if not hasattr(source, "read"):
+ source = open(source, "rb")
+ close_source = True
+ yield None
while True:
yield from pullparser.read_events()
# load event buffer
@@ -1238,16 +1264,12 @@ def iterator():
source.close()
class IterParseIterator(collections.abc.Iterator):
- __next__ = iterator().__next__
+ __next__ = iterator(source).__next__
it = IterParseIterator()
it.root = None
del iterator, IterParseIterator
- close_source = False
- if not hasattr(source, "read"):
- source = open(source, "rb")
- close_source = True
-
+ next(it)
return it
@@ -1256,7 +1278,7 @@ class XMLPullParser:
def __init__(self, events=None, *, _parser=None):
# The _parser argument is for internal use only and must not be relied
# upon in user code. It will be removed in a future release.
- # See http://bugs.python.org/issue17741 for more details.
+ # See https://bugs.python.org/issue17741 for more details.
self._events_queue = collections.deque()
self._parser = _parser or XMLParser(target=TreeBuilder())
@@ -1533,7 +1555,6 @@ def __init__(self, *, target=None, encoding=None):
# Configure pyexpat: buffering, new-style attribute handling.
parser.buffer_text = 1
parser.ordered_attributes = 1
- parser.specified_attributes = 1
self._doctype = None
self.entity = {}
try:
@@ -1553,7 +1574,6 @@ def _setevents(self, events_queue, events_to_report):
for event_name in events_to_report:
if event_name == "start":
parser.ordered_attributes = 1
- parser.specified_attributes = 1
def handler(tag, attrib_in, event=event_name, append=append,
start=self._start):
append((event, start(tag, attrib_in)))
@@ -1690,14 +1710,14 @@ def _default(self, text):
def feed(self, data):
"""Feed encoded data to parser."""
try:
- self.parser.Parse(data, 0)
+ self.parser.Parse(data, False)
except self._error as v:
self._raiseerror(v)
def close(self):
"""Finish feeding data to parser and return element structure."""
try:
- self.parser.Parse("", 1) # end of data
+ self.parser.Parse(b"", True) # end of data
except self._error as v:
self._raiseerror(v)
try:
diff --git a/Lib/xml/etree/__init__.py b/Lib/xml/etree/__init__.py
index 27fd8f6d4e..e2ec53421d 100644
--- a/Lib/xml/etree/__init__.py
+++ b/Lib/xml/etree/__init__.py
@@ -30,4 +30,4 @@
# --------------------------------------------------------------------
# Licensed to PSF under a Contributor Agreement.
-# See http://www.python.org/psf/license for licensing details.
+# See https://www.python.org/psf/license for licensing details.
diff --git a/Lib/xml/sax/__init__.py b/Lib/xml/sax/__init__.py
index 13f6cf58d0..17b75879eb 100644
--- a/Lib/xml/sax/__init__.py
+++ b/Lib/xml/sax/__init__.py
@@ -67,18 +67,18 @@ def parseString(string, handler, errorHandler=ErrorHandler()):
default_parser_list = sys.registry.getProperty(_key).split(",")
-def make_parser(parser_list = []):
+def make_parser(parser_list=()):
"""Creates and returns a SAX parser.
Creates the first parser it is able to instantiate of the ones
- given in the list created by doing parser_list +
- default_parser_list. The lists must contain the names of Python
+ given in the iterable created by chaining parser_list and
+ default_parser_list. The iterables must contain the names of Python
modules containing both a SAX parser and a create_parser function."""
- for parser_name in parser_list + default_parser_list:
+ for parser_name in list(parser_list) + default_parser_list:
try:
return _create_parser(parser_name)
- except ImportError as e:
+ except ImportError:
import sys
if parser_name in sys.modules:
# The parser module was found, but importing it
diff --git a/Lib/xml/sax/expatreader.py b/Lib/xml/sax/expatreader.py
index 5066ffc2fa..e334ac9fea 100644
--- a/Lib/xml/sax/expatreader.py
+++ b/Lib/xml/sax/expatreader.py
@@ -93,7 +93,7 @@ def __init__(self, namespaceHandling=0, bufsize=2**16-20):
self._parser = None
self._namespaces = namespaceHandling
self._lex_handler_prop = None
- self._parsing = 0
+ self._parsing = False
self._entity_stack = []
self._external_ges = 0
self._interning = None
@@ -203,10 +203,10 @@ def setProperty(self, name, value):
# IncrementalParser methods
- def feed(self, data, isFinal = 0):
+ def feed(self, data, isFinal=False):
if not self._parsing:
self.reset()
- self._parsing = 1
+ self._parsing = True
self._cont_handler.startDocument()
try:
@@ -237,13 +237,13 @@ def close(self):
# If we are completing an external entity, do nothing here
return
try:
- self.feed("", isFinal = 1)
+ self.feed(b"", isFinal=True)
self._cont_handler.endDocument()
- self._parsing = 0
+ self._parsing = False
# break cycle created by expat handlers pointing to our methods
self._parser = None
finally:
- self._parsing = 0
+ self._parsing = False
if self._parser is not None:
# Keep ErrorColumnNumber and ErrorLineNumber after closing.
parser = _ClosedParser()
@@ -307,7 +307,7 @@ def reset(self):
self._parser.SetParamEntityParsing(
expat.XML_PARAM_ENTITY_PARSING_UNLESS_STANDALONE)
- self._parsing = 0
+ self._parsing = False
self._entity_stack = []
# Locator methods
diff --git a/Lib/xml/sax/handler.py b/Lib/xml/sax/handler.py
index 481733d2cb..e8d417e519 100644
--- a/Lib/xml/sax/handler.py
+++ b/Lib/xml/sax/handler.py
@@ -340,3 +340,48 @@ def resolveEntity(self, publicId, systemId):
property_xml_string,
property_encoding,
property_interning_dict]
+
+
+class LexicalHandler:
+ """Optional SAX2 handler for lexical events.
+
+ This handler is used to obtain lexical information about an XML
+ document, that is, information about how the document was encoded
+ (as opposed to what it contains, which is reported to the
+ ContentHandler), such as comments and CDATA marked section
+ boundaries.
+
+ To set the LexicalHandler of an XMLReader, use the setProperty
+ method with the property identifier
+ 'http://xml.org/sax/properties/lexical-handler'."""
+
+ def comment(self, content):
+ """Reports a comment anywhere in the document (including the
+ DTD and outside the document element).
+
+ content is a string that holds the contents of the comment."""
+
+ def startDTD(self, name, public_id, system_id):
+ """Report the start of the DTD declarations, if the document
+ has an associated DTD.
+
+ A startEntity event will be reported before declaration events
+ from the external DTD subset are reported, and this can be
+ used to infer from which subset DTD declarations derive.
+
+ name is the name of the document element type, public_id the
+ public identifier of the DTD (or None if none were supplied)
+ and system_id the system identfier of the external subset (or
+ None if none were supplied)."""
+
+ def endDTD(self):
+ """Signals the end of DTD declarations."""
+
+ def startCDATA(self):
+ """Reports the beginning of a CDATA marked section.
+
+ The contents of the CDATA marked section will be reported
+ through the characters event."""
+
+ def endCDATA(self):
+ """Reports the end of a CDATA marked section."""
diff --git a/Lib/xml/sax/saxutils.py b/Lib/xml/sax/saxutils.py
index a69c7f7621..c1612ea1ce 100644
--- a/Lib/xml/sax/saxutils.py
+++ b/Lib/xml/sax/saxutils.py
@@ -56,8 +56,7 @@ def quoteattr(data, entities={}):
the optional entities parameter. The keys and values must all be
strings; each key will be replaced with its corresponding value.
"""
- entities = entities.copy()
- entities.update({'\n': '
', '\r': '
', '\t':' '})
+ entities = {**entities, '\n': '
', '\r': '
', '\t':' '}
data = escape(data, entities)
if '"' in data:
if "'" in data:
@@ -340,6 +339,8 @@ def prepare_input_source(source, base=""):
"""This function takes an InputSource and an optional base URL and
returns a fully resolved InputSource object ready for reading."""
+ if isinstance(source, os.PathLike):
+ source = os.fspath(source)
if isinstance(source, str):
source = xmlreader.InputSource(source)
elif hasattr(source, "read"):
From 2539e8276c36e364b2a96918bd503cadc1e3c0a6 Mon Sep 17 00:00:00 2001
From: Jeong YunWon
Date: Sun, 14 Aug 2022 23:25:26 +0900
Subject: [PATCH 21/23] adjust failing test markers for xml
---
Lib/test/test_xml_etree.py | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index b721f028a1..53552dc8b6 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -2466,8 +2466,6 @@ def test___init__(self):
self.assertIsNot(element_foo.attrib, attrib)
self.assertNotEqual(element_foo.attrib, attrib)
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_copy(self):
# Only run this test if Element.copy() is defined.
if "copy" not in dir(ET.Element):
@@ -3918,6 +3916,8 @@ def test_write_to_filename(self):
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''ø''')
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
def test_write_to_filename_with_encoding(self):
self.addCleanup(os_helper.unlink, TESTFN)
tree = ET.ElementTree(ET.XML('''\xf8'''))
@@ -3931,6 +3931,8 @@ def test_write_to_filename_with_encoding(self):
b'''\n'''
b'''\xf8'''))
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
def test_write_to_filename_as_unicode(self):
self.addCleanup(os_helper.unlink, TESTFN)
with open(TESTFN, 'w') as f:
@@ -3976,6 +3978,8 @@ def test_write_to_binary_file(self):
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''ø''')
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
def test_write_to_binary_file_with_encoding(self):
self.addCleanup(os_helper.unlink, TESTFN)
tree = ET.ElementTree(ET.XML('''\xf8'''))
From 563f9ecd53ca6a91b78dbe72172add132c016e27 Mon Sep 17 00:00:00 2001
From: CPython Developers <>
Date: Mon, 15 Aug 2022 00:35:20 +0900
Subject: [PATCH 22/23] Update test_repl from cpython 3.10.6
---
Lib/test/test_repl.py | 26 ++++++++++++++++++++------
1 file changed, 20 insertions(+), 6 deletions(-)
diff --git a/Lib/test/test_repl.py b/Lib/test/test_repl.py
index 71f192f90d..03bf8d8b54 100644
--- a/Lib/test/test_repl.py
+++ b/Lib/test/test_repl.py
@@ -29,7 +29,9 @@ def spawn_repl(*args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kw):
# test.support.script_helper.
env = kw.setdefault('env', dict(os.environ))
env['TERM'] = 'vt100'
- return subprocess.Popen(cmd_line, executable=sys.executable,
+ return subprocess.Popen(cmd_line,
+ executable=sys.executable,
+ text=True,
stdin=subprocess.PIPE,
stdout=stdout, stderr=stderr,
**kw)
@@ -49,12 +51,11 @@ def test_no_memory(self):
sys.exit(0)
"""
user_input = dedent(user_input)
- user_input = user_input.encode()
p = spawn_repl()
with SuppressCrashReport():
p.stdin.write(user_input)
output = kill_python(p)
- self.assertIn(b'After the exception.', output)
+ self.assertIn('After the exception.', output)
# Exit code 120: Py_FinalizeEx() failed to flush stdout and stderr.
self.assertIn(p.returncode, (1, 120))
@@ -86,13 +87,26 @@ def test_multiline_string_parsing(self):
"""
'''
user_input = dedent(user_input)
- user_input = user_input.encode()
p = spawn_repl()
- with SuppressCrashReport():
- p.stdin.write(user_input)
+ p.stdin.write(user_input)
output = kill_python(p)
self.assertEqual(p.returncode, 0)
+ def test_close_stdin(self):
+ user_input = dedent('''
+ import os
+ print("before close")
+ os.close(0)
+ ''')
+ prepare_repl = dedent('''
+ from test.support import suppress_msvcrt_asserts
+ suppress_msvcrt_asserts()
+ ''')
+ process = spawn_repl('-c', prepare_repl)
+ output = process.communicate(user_input)[0]
+ self.assertEqual(process.returncode, 0)
+ self.assertIn('before close', output)
+
if __name__ == "__main__":
unittest.main()
From bf53bb0c5fc6f0d591395021b53aee883ae77bb8 Mon Sep 17 00:00:00 2001
From: Jeong YunWon
Date: Mon, 15 Aug 2022 00:37:39 +0900
Subject: [PATCH 23/23] Mark failing tests from test_repl
---
Lib/test/test_repl.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/Lib/test/test_repl.py b/Lib/test/test_repl.py
index 03bf8d8b54..a63c1213c6 100644
--- a/Lib/test/test_repl.py
+++ b/Lib/test/test_repl.py
@@ -92,6 +92,8 @@ def test_multiline_string_parsing(self):
output = kill_python(p)
self.assertEqual(p.returncode, 0)
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
def test_close_stdin(self):
user_input = dedent('''
import os