diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 5e3a7f46..63608090 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -178,10 +178,10 @@ def __init__( raise RuntimeError("Option `apple_p2p` is not supported on non-Apple platforms.") self.unicast = unicast - listen_socket, respond_sockets = create_sockets(interfaces, unicast, ip_version, apple_p2p=apple_p2p) - log.debug("Listen socket %s, respond sockets %s", listen_socket, respond_sockets) + listen_sockets, respond_sockets = create_sockets(interfaces, unicast, ip_version, apple_p2p=apple_p2p) + log.debug("Listen socket %s, respond sockets %s", listen_sockets, respond_sockets) - self.engine = AsyncEngine(self, listen_socket, respond_sockets) + self.engine = AsyncEngine(self, listen_sockets, respond_sockets) self.browsers: dict[ServiceListener, ServiceBrowser] = {} self.registry = ServiceRegistry() diff --git a/src/zeroconf/_engine.py b/src/zeroconf/_engine.py index 8a371e1e..903bd7bc 100644 --- a/src/zeroconf/_engine.py +++ b/src/zeroconf/_engine.py @@ -28,6 +28,7 @@ import threading from typing import TYPE_CHECKING, cast +from ._logger import log from ._record_update import RecordUpdate from ._utils.asyncio import get_running_loop, run_coro_with_timeout from ._utils.time import current_time_millis @@ -48,7 +49,7 @@ class AsyncEngine: __slots__ = ( "_cleanup_timer", - "_listen_socket", + "_listen_sockets", "_respond_sockets", "_setup_task", "loop", @@ -62,7 +63,7 @@ class AsyncEngine: def __init__( self, zeroconf: Zeroconf, - listen_socket: socket.socket | None, + listen_sockets: list[tuple[socket.socket, int]], respond_sockets: list[socket.socket], ) -> None: self.loop: asyncio.AbstractEventLoop | None = None @@ -71,7 +72,7 @@ def __init__( self.readers: list[_WrappedTransport] = [] self.senders: list[_WrappedTransport] = [] self.running_future: asyncio.Future[bool | None] | None = None - self._listen_socket = listen_socket + self._listen_sockets = listen_sockets self._respond_sockets = respond_sockets self._cleanup_timer: asyncio.TimerHandle | None = None self._setup_task: asyncio.Task[None] | None = None @@ -100,20 +101,27 @@ async def _async_create_endpoints(self) -> None: """Create endpoints to send and receive.""" assert self.loop is not None loop = self.loop - reader_sockets = [] + reader_socket_tuples = self._listen_sockets.copy() sender_sockets = [] - if self._listen_socket: - reader_sockets.append(self._listen_socket) + reader_sockets = (t[0] for t in reader_socket_tuples) for s in self._respond_sockets: if s not in reader_sockets: - reader_sockets.append(s) + reader_socket_tuples.append((s, 0)) sender_sockets.append(s) - for s in reader_sockets: + log.info( + "Creating %d reader sockets (%s) and %d sender sockets", + len(reader_socket_tuples), + reader_socket_tuples, + len(sender_sockets), + ) + for s, interface_idx in reader_socket_tuples: + log.debug("Creating endpoint for socket %s", s) transport, protocol = await loop.create_datagram_endpoint( # type: ignore[type-var] - lambda: AsyncListener(self.zc), # type: ignore[arg-type, return-value] + lambda: AsyncListener(self.zc, interface_idx), # type: ignore[arg-type, return-value] sock=s, ) + log.debug("Creating endpoint for socket %s, transport %s, protocol %s", s, transport, protocol) self.protocols.append(cast(AsyncListener, protocol)) self.readers.append(make_wrapped_transport(cast(asyncio.DatagramTransport, transport))) if s in sender_sockets: diff --git a/src/zeroconf/_listener.pxd b/src/zeroconf/_listener.pxd index 20084b47..6133169e 100644 --- a/src/zeroconf/_listener.pxd +++ b/src/zeroconf/_listener.pxd @@ -19,6 +19,7 @@ cdef cython.uint _DUPLICATE_PACKET_SUPPRESSION_INTERVAL cdef class AsyncListener: cdef public object zc + cdef public cython.uint interface_idx cdef ServiceRegistry _registry cdef RecordManager _record_manager cdef QueryHandler _query_handler diff --git a/src/zeroconf/_listener.py b/src/zeroconf/_listener.py index ed503169..f5964bb5 100644 --- a/src/zeroconf/_listener.py +++ b/src/zeroconf/_listener.py @@ -63,6 +63,7 @@ class AsyncListener: "_registry", "_timers", "data", + "interface_idx", "last_message", "last_time", "sock_description", @@ -70,8 +71,9 @@ class AsyncListener: "zc", ) - def __init__(self, zc: Zeroconf) -> None: + def __init__(self, zc: Zeroconf, interface_idx: int = 0) -> None: self.zc = zc + self.interface_idx = interface_idx self._registry = zc.registry self._record_manager = zc.record_manager self._query_handler = zc.query_handler @@ -139,12 +141,10 @@ def _process_datagram_at_time( else: # https://github.com/python/mypy/issues/1178 addr, port, flow, scope = addrs - if debug: # pragma: no branch - log.debug("IPv6 scope_id %d associated to the receiving interface", scope) v6_flow_scope = (flow, scope) addr_port = (addr, port) - msg = DNSIncoming(data, addr_port, scope, now) + msg = DNSIncoming(data, addr_port, self.interface_idx, now) self.data = data self.last_time = now self.last_message = msg diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index fff9e125..8d52160d 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -183,6 +183,7 @@ def __init__( interface_index: int | None = None, ) -> None: # Accept both none, or one, but not both. + log.info("ServiceInfo.__init__() called with %s", properties) if addresses is not None and parsed_addresses is not None: raise TypeError("addresses and parsed_addresses cannot be provided together") if not type_.endswith(service_type_name(name, strict=False)): @@ -505,6 +506,8 @@ def _process_record_threadsafe(self, zc: Zeroconf, record: DNSRecord, now: float if TYPE_CHECKING: assert isinstance(dns_address_record, DNSAddress) ip_addr = get_ip_address_object_from_record(dns_address_record) + + # log.info("Got ip addr: %r with scope %d from %r", ip_addr, ip_addr.scope_id if ip_addr.scope_id else 0, dns_address_record) if ip_addr is None: log.warning( "Encountered invalid address while processing %s: %s", @@ -533,6 +536,9 @@ def _process_record_threadsafe(self, zc: Zeroconf, record: DNSRecord, now: float if TYPE_CHECKING: assert isinstance(ip_addr, ZeroconfIPv6Address) ipv6_addresses = self._ipv6_addresses + if ip_addr.is_link_local and not ip_addr.scope_id: + log.debug("Ignoring link-local address without scope %s", ip_addr) + return False if ip_addr not in self._ipv6_addresses: ipv6_addresses.insert(0, ip_addr) return True @@ -838,6 +844,7 @@ async def async_request( :param addr: address to send the request to :param port: port to send the request to """ + log.info("Asking for %s %s", question_type, self._name) if not zc.started: await zc.async_wait_for_start() @@ -860,6 +867,7 @@ async def async_request( return False if next_ <= now: this_question_type = question_type or (QU_QUESTION if first_request else QM_QUESTION) + log.info("Generating request for %s %s", this_question_type, self._name) out = self._generate_request_query(zc, now, this_question_type) first_request = False if out.questions: diff --git a/src/zeroconf/_utils/ipaddress.py b/src/zeroconf/_utils/ipaddress.py index d172d0c9..24cf5ccd 100644 --- a/src/zeroconf/_utils/ipaddress.py +++ b/src/zeroconf/_utils/ipaddress.py @@ -27,6 +27,7 @@ from typing import Any from .._dns import DNSAddress +from .._logger import log from ..const import _TYPE_AAAA bytes_ = bytes @@ -123,6 +124,14 @@ def get_ip_address_object_from_record( record: DNSAddress, ) -> ZeroconfIPv4Address | ZeroconfIPv6Address | None: """Get the IP address object from the record.""" + + log.info( + "Got ip addr: %r from %r with scope %d", + record.address, + record, + record.scope_id if record.scope_id else 0, + ) + if record.type == _TYPE_AAAA and record.scope_id: return ip_bytes_and_scope_to_address(record.address, record.scope_id) return cached_ip_addresses_wrapper(record.address) diff --git a/src/zeroconf/_utils/net.py b/src/zeroconf/_utils/net.py index e67edf78..a8d43041 100644 --- a/src/zeroconf/_utils/net.py +++ b/src/zeroconf/_utils/net.py @@ -30,6 +30,7 @@ import sys import warnings from collections.abc import Iterable, Sequence +from functools import lru_cache from typing import Any, Union, cast import ifaddr @@ -74,8 +75,16 @@ def _encode_address(address: str) -> bytes: return socket.inet_pton(address_family, address) +@lru_cache(maxsize=512) +def _get_interface_name(iface: int) -> str: + """Get the interface name from the interface index.""" + return socket.if_indextoname(iface) + + def get_all_addresses_ipv4(adapters: Iterable[ifaddr.Adapter]) -> list[str]: - return list({addr.ip for iface in adapters for addr in iface.ips if addr.is_IPv4}) # type: ignore[misc] + return list( + {(addr.ip, iface.index) for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv4} + ) # type: ignore[misc] def get_all_addresses_ipv6(adapters: Iterable[ifaddr.Adapter]) -> list[tuple[tuple[str, int, int], int]]: @@ -86,6 +95,17 @@ def get_all_addresses_ipv6(adapters: Iterable[ifaddr.Adapter]) -> list[tuple[tup ) +def get_all_addresses_ipv6_link_local( + adapters: Iterable[ifaddr.Adapter], +) -> list[tuple[tuple[str, int, int], int]]: + return [ + (ip.ip, iface.index) + for iface in adapters + for ip in iface.ips + if ip.is_IPv6 and ip.ip[2] == iface.index + ] + + def get_all_addresses() -> list[str]: warnings.warn( "get_all_addresses is deprecated, and will be removed in a future version. Use ifaddr" @@ -106,20 +126,36 @@ def get_all_addresses_v6() -> list[tuple[tuple[str, int, int], int]]: return get_all_addresses_ipv6(ifaddr.get_adapters()) -def ip6_to_address_and_index(adapters: Iterable[ifaddr.Adapter], ip: str) -> tuple[tuple[str, int, int], int]: - if "%" in ip: - ip = ip[: ip.index("%")] # Strip scope_id. - ipaddr = ipaddress.ip_address(ip) +def ip_to_address_and_index( + adapters: Iterable[ifaddr.Adapter], ip: str +) -> tuple[tuple[str, int, int], int] | tuple[str, int]: + """Convert IP to IP address and interface index. + + :param adapters: List of adapters. + :param ip: IP address to convert. If the address is an IPv6 with scope id, the scope id is being validated. + :returns: Tuple of IP address and interface index. + """ + addr, sep, scope_id = ip.partition("%") + if not sep: + scope_id = None + + ipaddr = ipaddress.ip_address(addr) for adapter in adapters: for adapter_ip in adapter.ips: + if adapter.index is None: + continue # IPv6 addresses are represented as tuples if ( - adapter.index is not None - and isinstance(adapter_ip.ip, tuple) + isinstance(adapter_ip.ip, tuple) and ipaddress.ip_address(adapter_ip.ip[0]) == ipaddr + and (scope_id is None or adapter_ip.ip[2] == int(scope_id)) ): return (adapter_ip.ip, adapter.index) + # IPv4 addresses are represented as strings + if isinstance(adapter_ip.ip, str) and ipaddress.ip_address(adapter_ip.ip) == ipaddr: + return (adapter_ip.ip, adapter.index) + raise RuntimeError(f"No adapter found for IP address {ip}") @@ -134,12 +170,10 @@ def interface_index_to_ip6_address(adapters: Iterable[ifaddr.Adapter], index: in raise RuntimeError(f"No adapter found for index {index}") -def ip6_addresses_to_indexes( +def ip_addresses_to_indexes( interfaces: Sequence[str | int | tuple[tuple[str, int, int], int]], ) -> list[tuple[tuple[str, int, int], int]]: - """Convert IPv6 interface addresses to interface indexes. - - IPv4 addresses are ignored. + """Convert IPv4 and IPv6 interface addresses and bare interface indexes to addresses and interface indexes. :param interfaces: List of IP addresses and indexes. :returns: List of indexes. @@ -150,22 +184,22 @@ def ip6_addresses_to_indexes( for iface in interfaces: if isinstance(iface, int): result.append((interface_index_to_ip6_address(adapters, iface), iface)) # type: ignore[arg-type] - elif isinstance(iface, str) and ipaddress.ip_address(iface).version == 6: - result.append(ip6_to_address_and_index(adapters, iface)) # type: ignore[arg-type] + elif isinstance(iface, str): + result.append(ip_to_address_and_index(adapters, iface)) # type: ignore[arg-type] return result def normalize_interface_choice( choice: InterfacesType, ip_version: IPVersion = IPVersion.V4Only -) -> list[str | tuple[tuple[str, int, int], int]]: +) -> list[tuple[str, int] | tuple[tuple[str, int, int], int]]: """Convert the interfaces choice into internal representation. :param choice: `InterfaceChoice` or list of interface addresses or indexes (IPv6 only). :param ip_address: IP version to use (ignored if `choice` is a list). :returns: List of IP addresses (for IPv4) and indexes (for IPv6). """ - result: list[str | tuple[tuple[str, int, int], int]] = [] + result: list[tuple[str, int] | tuple[tuple[str, int, int], int]] = [] if choice is InterfaceChoice.Default: if ip_version != IPVersion.V4Only: # IPv6 multicast uses interface 0 to mean the default. However, @@ -184,7 +218,7 @@ def normalize_interface_choice( elif choice is InterfaceChoice.All: adapters = ifaddr.get_adapters() if ip_version != IPVersion.V4Only: - result.extend(get_all_addresses_ipv6(adapters)) + result.extend(get_all_addresses_ipv6_link_local(adapters)) if ip_version != IPVersion.V6Only: result.extend(get_all_addresses_ipv4(adapters)) if not result: @@ -192,19 +226,16 @@ def normalize_interface_choice( f"No interfaces to listen on, check that any interfaces have IP version {ip_version}" ) elif isinstance(choice, list): - # First, take IPv4 addresses. - result = [i for i in choice if isinstance(i, str) and ipaddress.ip_address(i).version == 4] - # Unlike IP_ADD_MEMBERSHIP, IPV6_JOIN_GROUP requires interface indexes. - result += ip6_addresses_to_indexes(choice) + result = ip_addresses_to_indexes(choice) else: raise TypeError(f"choice must be a list or InterfaceChoice, got {choice!r}") return result -def disable_ipv6_only_or_raise(s: socket.socket) -> None: +def ipv6_only_or_raise(s: socket.socket, value: bool) -> None: """Make V6 sockets work for both V4 and V6 (required for Windows).""" try: - s.setsockopt(_IPPROTO_IPV6, socket.IPV6_V6ONLY, False) + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_V6ONLY, value) except OSError: log.error("Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6") raise @@ -275,7 +306,9 @@ def new_socket( s = socket.socket(socket_family, socket.SOCK_DGRAM) if ip_version == IPVersion.All: - disable_ipv6_only_or_raise(s) + ipv6_only_or_raise(s, False) + if ip_version == IPVersion.V6Only: + ipv6_only_or_raise(s, True) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) set_so_reuseport_if_available(s) @@ -287,6 +320,7 @@ def new_socket( # Bind expects (address, port) for AF_INET and (address, port, flowinfo, scope_id) for AF_INET6 bind_tup = (bind_addr[0], port, *bind_addr[1:]) + log.info("Binding to %s", bind_tup) try: s.bind(bind_tup) except OSError as ex: @@ -320,18 +354,37 @@ def new_socket( return s +def _bind_to_interface( + s: socket.socket, + ifaceidx: int, +) -> None: + if hasattr(socket, "SO_BINDTODEVICE"): + iface = _get_interface_name(ifaceidx).encode() + s.setsockopt(socket.SOL_SOCKET, socket.SO_BINDTODEVICE, iface) + elif sys.platform != "win32": + log.warning( + "SO_BINDTODEVICE not available on this platform, " + "this may cause duplicated reception with multicast on some systems", + ) + else: + # SO_BINDTODEVICE is not available on Windows + # TODO: Check if bind by a IP of the interface in question is maybe enough on Windows + pass + + def add_multicast_member( listen_socket: socket.socket, - interface: str | tuple[tuple[str, int, int], int], + interface: tuple[str, int] | tuple[tuple[str, int, int], int], ) -> bool: # This is based on assumptions in normalize_interface_choice - is_v6 = isinstance(interface, tuple) + is_v6 = isinstance(interface[0], tuple) err_einval = {errno.EINVAL} if sys.platform == "win32": # No WSAEINVAL definition in typeshed err_einval |= {cast(Any, errno).WSAEINVAL} # pylint: disable=no-member log.debug("Adding %r (socket %d) to multicast group", interface, listen_socket.fileno()) try: + _bind_to_interface(listen_socket, interface[1]) if is_v6: try: mdns_addr6_bytes = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6) @@ -342,12 +395,13 @@ def add_multicast_member( interface, ) return False + log.info("IPV6_JOIN_GROUP to interface %d", interface[1]) iface_bin = struct.pack("@I", cast(int, interface[1])) _value = mdns_addr6_bytes + iface_bin listen_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, _value) else: - _value = socket.inet_aton(_MDNS_ADDR) + socket.inet_aton(cast(str, interface)) - listen_socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value) + mreq = struct.pack("=4s4s", socket.inet_aton(_MDNS_ADDR), socket.inet_aton(interface[0])) + listen_socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq) except OSError as e: _errno = get_errno(e) if _errno == errno.EADDRINUSE: @@ -402,19 +456,20 @@ def add_multicast_member( def new_respond_socket( - interface: str | tuple[tuple[str, int, int], int], + interface: tuple[str, int] | tuple[tuple[str, int, int], int], apple_p2p: bool = False, unicast: bool = False, ) -> socket.socket | None: """Create interface specific socket for responding to multicast queries.""" - is_v6 = isinstance(interface, tuple) + is_v6 = isinstance(interface[0], tuple) # For response sockets: # - Bind explicitly to the interface address # - Use ephemeral ports if in unicast mode # - Create socket according to the interface IP type (IPv4 or IPv6) respond_socket = new_socket( - bind_addr=cast(tuple[tuple[str, int, int], int], interface)[0] if is_v6 else (cast(str, interface),), + # bind_addr=cast(tuple[tuple[str, int, int], int], interface)[0] if is_v6 else (cast(tuple[str, int], interface)[0],), + bind_addr=("::", 0, interface[1]), port=0 if unicast else _MDNS_PORT, ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), apple_p2p=apple_p2p, @@ -433,7 +488,7 @@ def new_respond_socket( respond_socket.setsockopt( socket.IPPROTO_IP, socket.IP_MULTICAST_IF, - socket.inet_aton(cast(str, interface)), + socket.inet_aton(cast(tuple[str, int], interface)[0]), ) set_respond_socket_multicast_options(respond_socket, IPVersion.V6Only if is_v6 else IPVersion.V4Only) return respond_socket @@ -444,36 +499,28 @@ def create_sockets( unicast: bool = False, ip_version: IPVersion = IPVersion.V4Only, apple_p2p: bool = False, -) -> tuple[socket.socket | None, list[socket.socket]]: - if unicast: - listen_socket = None - else: - listen_socket = new_socket(bind_addr=("",), ip_version=ip_version, apple_p2p=apple_p2p) - +) -> tuple[list[tuple[socket.socket, int]], list[socket.socket]]: normalized_interfaces = normalize_interface_choice(interfaces, ip_version) + log.info("Creating sockets for interfaces %s", normalized_interfaces) - # If we are using InterfaceChoice.Default with only IPv4 or only IPv6, we can use - # a single socket to listen and respond. - if not unicast and interfaces is InterfaceChoice.Default and ip_version != IPVersion.All: - for interface in normalized_interfaces: - add_multicast_member(cast(socket.socket, listen_socket), interface) - # Sent responder socket options to the dual-use listen socket - set_respond_socket_multicast_options(cast(socket.socket, listen_socket), ip_version) - return listen_socket, [cast(socket.socket, listen_socket)] - + listen_sockets = [] respond_sockets = [] for interface in normalized_interfaces: # Only create response socket if unicast or becoming multicast member was successful - if not unicast and not add_multicast_member(cast(socket.socket, listen_socket), interface): + sck = new_respond_socket(interface, apple_p2p=apple_p2p, unicast=unicast) + if not sck: continue - respond_socket = new_respond_socket(interface, apple_p2p=apple_p2p, unicast=unicast) + if not unicast: + if not add_multicast_member(cast(socket.socket, sck), interface): + sck.close() + continue + listen_sockets.append((sck, interface[1])) - if respond_socket is not None: - respond_sockets.append(respond_socket) + respond_sockets.append(sck) - return listen_socket, respond_sockets + return listen_sockets, respond_sockets def get_errno(e: OSError) -> int: diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index 7de10661..98fbc9d0 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -107,14 +107,14 @@ def test_ip6_addresses_to_indexes(): "zeroconf._utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters(), ): - assert netutils.ip6_addresses_to_indexes(interfaces) == [(("2001:db8::", 1, 1), 1)] + assert netutils.ip_addresses_to_indexes(interfaces) == [(("2001:db8::", 1, 1), 1)] interfaces_2 = ["2001:db8::"] with patch( "zeroconf._utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters(), ): - assert netutils.ip6_addresses_to_indexes(interfaces_2) == [(("2001:db8::", 1, 1), 1)] + assert netutils.ip_addresses_to_indexes(interfaces_2) == [(("2001:db8::", 1, 1), 1)] def test_normalize_interface_choice_errors():