Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 40 additions & 60 deletions Lib/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,16 @@ def __init__(self, selector):
def __len__(self):
return len(self._selector._fd_to_key)

def get(self, fileobj, default=None):
fd = self._selector._fileobj_lookup(fileobj)
return self._selector._fd_to_key.get(fd, default)

def __getitem__(self, fileobj):
try:
fd = self._selector._fileobj_lookup(fileobj)
return self._selector._fd_to_key[fd]
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj)) from None
fd = self._selector._fileobj_lookup(fileobj)
key = self._selector._fd_to_key.get(fd)
if key is None:
raise KeyError("{!r} is not registered".format(fileobj))
return key

def __iter__(self):
return iter(self._selector._fd_to_key)
Expand Down Expand Up @@ -272,19 +276,6 @@ def close(self):
def get_map(self):
return self._map

def _key_from_fd(self, fd):
"""Return the key associated to a given file descriptor.

Parameters:
fd -- file descriptor

Returns:
corresponding key, or None if not found
"""
try:
return self._fd_to_key[fd]
except KeyError:
return None


class SelectSelector(_BaseSelectorImpl):
Expand Down Expand Up @@ -323,17 +314,15 @@ def select(self, timeout=None):
r, w, _ = self._select(self._readers, self._writers, [], timeout)
except InterruptedError:
return ready
r = set(r)
w = set(w)
for fd in r | w:
events = 0
if fd in r:
events |= EVENT_READ
if fd in w:
events |= EVENT_WRITE

key = self._key_from_fd(fd)
r = frozenset(r)
w = frozenset(w)
rw = r | w
fd_to_key_get = self._fd_to_key.get
for fd in rw:
key = fd_to_key_get(fd)
if key:
events = ((fd in r and EVENT_READ)
| (fd in w and EVENT_WRITE))
ready.append((key, events & key.events))
return ready

Expand All @@ -350,11 +339,8 @@ def __init__(self):

def register(self, fileobj, events, data=None):
key = super().register(fileobj, events, data)
poller_events = 0
if events & EVENT_READ:
poller_events |= self._EVENT_READ
if events & EVENT_WRITE:
poller_events |= self._EVENT_WRITE
poller_events = ((events & EVENT_READ and self._EVENT_READ)
| (events & EVENT_WRITE and self._EVENT_WRITE) )
try:
self._selector.register(key.fd, poller_events)
except:
Expand All @@ -380,11 +366,8 @@ def modify(self, fileobj, events, data=None):

changed = False
if events != key.events:
selector_events = 0
if events & EVENT_READ:
selector_events |= self._EVENT_READ
if events & EVENT_WRITE:
selector_events |= self._EVENT_WRITE
selector_events = ((events & EVENT_READ and self._EVENT_READ)
| (events & EVENT_WRITE and self._EVENT_WRITE))
try:
self._selector.modify(key.fd, selector_events)
except:
Expand Down Expand Up @@ -415,15 +398,13 @@ def select(self, timeout=None):
fd_event_list = self._selector.poll(timeout)
except InterruptedError:
return ready
for fd, event in fd_event_list:
events = 0
if event & ~self._EVENT_READ:
events |= EVENT_WRITE
if event & ~self._EVENT_WRITE:
events |= EVENT_READ

key = self._key_from_fd(fd)
fd_to_key_get = self._fd_to_key.get
for fd, event in fd_event_list:
key = fd_to_key_get(fd)
if key:
events = ((event & ~self._EVENT_READ and EVENT_WRITE)
| (event & ~self._EVENT_WRITE and EVENT_READ))
ready.append((key, events & key.events))
return ready

Expand All @@ -439,6 +420,9 @@ class PollSelector(_PollLikeSelector):

if hasattr(select, 'epoll'):

_NOT_EPOLLIN = ~select.EPOLLIN
_NOT_EPOLLOUT = ~select.EPOLLOUT

class EpollSelector(_PollLikeSelector):
"""Epoll-based selector."""
_selector_cls = select.epoll
Expand All @@ -461,22 +445,20 @@ def select(self, timeout=None):
# epoll_wait() expects `maxevents` to be greater than zero;
# we want to make sure that `select()` can be called when no
# FD is registered.
max_ev = max(len(self._fd_to_key), 1)
max_ev = len(self._fd_to_key) or 1

ready = []
try:
fd_event_list = self._selector.poll(timeout, max_ev)
except InterruptedError:
return ready
for fd, event in fd_event_list:
events = 0
if event & ~select.EPOLLIN:
events |= EVENT_WRITE
if event & ~select.EPOLLOUT:
events |= EVENT_READ

key = self._key_from_fd(fd)
fd_to_key = self._fd_to_key
for fd, event in fd_event_list:
key = fd_to_key.get(fd)
if key:
events = ((event & _NOT_EPOLLIN and EVENT_WRITE)
| (event & _NOT_EPOLLOUT and EVENT_READ))
ready.append((key, events & key.events))
return ready

Expand Down Expand Up @@ -566,17 +548,15 @@ def select(self, timeout=None):
kev_list = self._selector.control(None, max_ev, timeout)
except InterruptedError:
return ready

fd_to_key_get = self._fd_to_key.get
for kev in kev_list:
fd = kev.ident
flag = kev.filter
events = 0
if flag == select.KQ_FILTER_READ:
events |= EVENT_READ
if flag == select.KQ_FILTER_WRITE:
events |= EVENT_WRITE

key = self._key_from_fd(fd)
key = fd_to_key_get(fd)
if key:
events = ((flag == select.KQ_FILTER_READ and EVENT_READ)
| (flag == select.KQ_FILTER_WRITE and EVENT_WRITE))
ready.append((key, events & key.events))
return ready

Expand Down
12 changes: 9 additions & 3 deletions Lib/test/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import socket
import sys
from test import support
from test.support import os_helper
from test.support import socket_helper
from test.support import is_apple, os_helper, socket_helper
from time import sleep
import unittest
import unittest.mock
Expand Down Expand Up @@ -132,6 +131,7 @@ def test_unregister_after_fd_close_and_reuse(self):
s.unregister(r)
s.unregister(w)

# TODO: RUSTPYTHON
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_unregister_after_socket_close(self):
s = self.SELECTOR()
Expand Down Expand Up @@ -224,6 +224,8 @@ def test_close(self):
self.assertRaises(RuntimeError, s.get_key, wr)
self.assertRaises(KeyError, mapping.__getitem__, rd)
self.assertRaises(KeyError, mapping.__getitem__, wr)
self.assertEqual(mapping.get(rd), None)
self.assertEqual(mapping.get(wr), None)

def test_get_key(self):
s = self.SELECTOR()
Expand All @@ -242,13 +244,17 @@ def test_get_map(self):
self.addCleanup(s.close)

rd, wr = self.make_socketpair()
sentinel = object()

keys = s.get_map()
self.assertFalse(keys)
self.assertEqual(len(keys), 0)
self.assertEqual(list(keys), [])
self.assertEqual(keys.get(rd), None)
self.assertEqual(keys.get(rd, sentinel), sentinel)
key = s.register(rd, selectors.EVENT_READ, "data")
self.assertIn(rd, keys)
self.assertEqual(key, keys.get(rd))
self.assertEqual(key, keys[rd])
self.assertEqual(len(keys), 1)
self.assertEqual(list(keys), [rd.fileno()])
Expand Down Expand Up @@ -521,7 +527,7 @@ def test_above_fd_setsize(self):
try:
fds = s.select()
except OSError as e:
if e.errno == errno.EINVAL and sys.platform == 'darwin':
if e.errno == errno.EINVAL and is_apple:
# unexplainable errors on macOS don't need to fail the test
self.skipTest("Invalid argument error calling poll()")
raise
Expand Down
Loading