diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ca930386..76f32174 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2239,7 +2239,7 @@ def new_socket( def add_multicast_member( listen_socket: socket.socket, interface: Union[str, Tuple[Tuple[str, int, int], int]], -) -> None: +) -> bool: # This is based on assumptions in normalize_interface_choice is_v6 = isinstance(interface, tuple) err_einval = {errno.EINVAL} @@ -2263,19 +2263,20 @@ def add_multicast_member( 'it is expected to happen on some systems', interface, ) - return None + return False elif _errno == errno.EADDRNOTAVAIL: log.info( 'Address not available when adding %s to multicast ' 'group, it is expected to happen on some systems', interface, ) - return None + return False elif _errno in err_einval: log.info('Interface of %s does not support multicast, ' 'it is expected in WSL', interface) - return None + return False else: raise + return True def new_respond_socket( @@ -2323,8 +2324,10 @@ def create_sockets( for i in normalized_interfaces: if not unicast: - add_multicast_member(cast(socket.socket, listen_socket), i) - respond_socket = new_respond_socket(i, apple_p2p=apple_p2p) + if add_multicast_member(cast(socket.socket, listen_socket), i): + respond_socket = new_respond_socket(i, apple_p2p=apple_p2p) + else: + respond_socket = None else: respond_socket = new_socket( port=0, diff --git a/zeroconf/test.py b/zeroconf/test.py index 7349b42a..60bfb4d5 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -5,6 +5,7 @@ """ Unit tests for zeroconf.py """ import copy +import errno import itertools import logging import os @@ -16,8 +17,7 @@ import unittest import unittest.mock from threading import Event -from typing import Dict, Optional # noqa # used in type hints -from typing import cast +from typing import Dict, Optional, cast # noqa # used in type hints import pytest @@ -2105,3 +2105,18 @@ def test_dns_compression_rollback_for_corruption(): # ensure there is no corruption with the dns compression incoming = r.DNSIncoming(packet) assert incoming.valid is True + + +@pytest.mark.parametrize( + "errno,expected_result", + [(errno.EADDRINUSE, False), (errno.EADDRNOTAVAIL, False), (errno.EINVAL, False), (0, True)], +) +def test_add_multicast_member_socket_errors(errno, expected_result): + """Test we handle socket errors when adding multicast members.""" + if errno: + setsockopt_mock = unittest.mock.Mock(side_effect=OSError(errno, "Error: {}".format(errno))) + else: + setsockopt_mock = unittest.mock.Mock() + fileno_mock = unittest.mock.PropertyMock(return_value=10) + socket_mock = unittest.mock.Mock(setsockopt=setsockopt_mock, fileno=fileno_mock) + assert r.add_multicast_member(socket_mock, "0.0.0.0") == expected_result