From 933789baf2892a399fe560f1addded8e1ee48b46 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Aug 2023 15:26:04 -0500 Subject: [PATCH 1/6] feat: speed up processing incoming records add cython defs for set_created_ttl and reset_ttl --- src/zeroconf/_dns.pxd | 2 ++ src/zeroconf/_handlers.py | 15 +++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/zeroconf/_dns.pxd b/src/zeroconf/_dns.pxd index 289cd1a1..969c4386 100644 --- a/src/zeroconf/_dns.pxd +++ b/src/zeroconf/_dns.pxd @@ -54,6 +54,8 @@ cdef class DNSRecord(DNSEntry): cpdef reset_ttl(self, DNSRecord other) + cpdef set_created_ttl(self, cython.float now, cython.float ttl) + cdef class DNSAddress(DNSRecord): cdef public cython.int _hash diff --git a/src/zeroconf/_handlers.py b/src/zeroconf/_handlers.py index 192496c1..be0d619f 100644 --- a/src/zeroconf/_handlers.py +++ b/src/zeroconf/_handlers.py @@ -408,6 +408,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: removes: Set[DNSRecord] = set() now = msg.now unique_types: Set[Tuple[str, int, int]] = set() + cache = self.cache for record in msg.answers: # Protect zeroconf from records that can cause denial of service. @@ -416,7 +417,9 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: # ServiceBrowsers generating excessive queries refresh queries. # Apple uses a 15s minimum TTL, however we do not have the same # level of rate limit and safe guards so we use 1/4 of the recommended value. - if record.ttl and record.type == _TYPE_PTR and record.ttl < _DNS_PTR_MIN_TTL: + record_type = record.type + record_ttl = record.ttl + if record_ttl and record_type == _TYPE_PTR and record_ttl < _DNS_PTR_MIN_TTL: log.debug( "Increasing effective ttl of %s to minimum of %s to protect against excessive refreshes.", record, @@ -425,12 +428,12 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: record.set_created_ttl(record.created, _DNS_PTR_MIN_TTL) if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - unique_types.add((record.name, record.type, record.class_)) + unique_types.add((record.name, record_type, record.class_)) if TYPE_CHECKING: record = cast(_UniqueRecordsType, record) - maybe_entry = self.cache.async_get_unique(record) + maybe_entry = cache.async_get_unique(record) if not record.is_expired(now): if maybe_entry is not None: maybe_entry.reset_ttl(record) @@ -447,7 +450,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: removes.add(record) if unique_types: - self.cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, msg.answers, now) + cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, msg.answers, now) if updates: self.async_updates(now, updates) @@ -468,12 +471,12 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: # processsed. new = False if other_adds or address_adds: - new = self.cache.async_add_records(itertools.chain(address_adds, other_adds)) + new = cache.async_add_records(itertools.chain(address_adds, other_adds)) # Removes are processed last since # ServiceInfo could generate an un-needed query # because the data was not yet populated. if removes: - self.cache.async_remove_records(removes) + cache.async_remove_records(removes) if updates: self.async_updates_complete(new) From 2e7fefbd987d203c0d9e6180fb5acf26f023fde6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Aug 2023 15:31:01 -0500 Subject: [PATCH 2/6] feat: speed up processing incoming records add cython defs for set_created_ttl and reset_ttl --- src/zeroconf/_dns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index 6c34f9dd..8e083f33 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -212,7 +212,7 @@ def reset_ttl(self, other) -> None: # type: ignore[no-untyped-def] another record.""" self.set_created_ttl(other.created, other.ttl) - def set_created_ttl(self, created: float, ttl: Union[float, int]) -> None: + def set_created_ttl(self, created: _float, ttl: Union[float, int]) -> None: """Set the created and ttl of a record.""" self.created = created self.ttl = ttl From 35600fe83f9651adf56f7c33185b67b436bb9c88 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Aug 2023 15:40:31 -0500 Subject: [PATCH 3/6] feat: speed up processing incoming records add cython defs for set_created_ttl and reset_ttl --- src/zeroconf/_dns.pxd | 5 +++++ src/zeroconf/_dns.py | 7 +++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/zeroconf/_dns.pxd b/src/zeroconf/_dns.pxd index 969c4386..d0cae879 100644 --- a/src/zeroconf/_dns.pxd +++ b/src/zeroconf/_dns.pxd @@ -44,6 +44,11 @@ cdef class DNSRecord(DNSEntry): ) cpdef suppressed_by(self, object msg) + @cython.locals( + remaining=cython.float, + ) + cpdef get_remaining_ttl(self, cython.float now) + cpdef get_expiration_time(self, cython.uint percent) cpdef is_expired(self, cython.float now) diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index 8e083f33..18fb5235 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -26,7 +26,7 @@ from ._exceptions import AbstractMethodException from ._utils.net import _is_v6_address -from ._utils.time import current_time_millis, millis_to_seconds +from ._utils.time import current_time_millis from .const import _CLASS_MASK, _CLASS_UNIQUE, _CLASSES, _TYPE_ANY, _TYPES _LEN_BYTE = 1 @@ -193,7 +193,10 @@ def get_expiration_time(self, percent: _int) -> float: # TODO: Switch to just int here def get_remaining_ttl(self, now: _float) -> Union[int, float]: """Returns the remaining TTL in seconds.""" - return max(0, millis_to_seconds((self.created + (_EXPIRE_FULL_TIME_MS * self.ttl)) - now)) + remaining = self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) - now + if remaining < 0: + return 0 + return remaining / 1000.0 def is_expired(self, now: _float) -> bool: """Returns true if this record has expired.""" From 0b7602822de5a678219d68f36f9056d9844707cb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Aug 2023 15:49:13 -0500 Subject: [PATCH 4/6] fix: tweak --- src/zeroconf/_dns.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index 18fb5235..73b0c751 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -193,10 +193,8 @@ def get_expiration_time(self, percent: _int) -> float: # TODO: Switch to just int here def get_remaining_ttl(self, now: _float) -> Union[int, float]: """Returns the remaining TTL in seconds.""" - remaining = self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) - now - if remaining < 0: - return 0 - return remaining / 1000.0 + remain = (self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) - now) / 1000.0 + return 0 if remain < 0 else remain def is_expired(self, now: _float) -> bool: """Returns true if this record has expired.""" From 4a16180aabd00e12b0d0e83819f75f25fc499ced Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Aug 2023 15:52:13 -0500 Subject: [PATCH 5/6] Update src/zeroconf/_dns.pxd --- src/zeroconf/_dns.pxd | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/zeroconf/_dns.pxd b/src/zeroconf/_dns.pxd index d0cae879..5622a5ed 100644 --- a/src/zeroconf/_dns.pxd +++ b/src/zeroconf/_dns.pxd @@ -44,9 +44,6 @@ cdef class DNSRecord(DNSEntry): ) cpdef suppressed_by(self, object msg) - @cython.locals( - remaining=cython.float, - ) cpdef get_remaining_ttl(self, cython.float now) cpdef get_expiration_time(self, cython.uint percent) From 180f1be201bce9ff24d5226d1231582b4b550431 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Aug 2023 20:17:43 -0500 Subject: [PATCH 6/6] feat: improve performance responding to queries --- src/zeroconf/_services/info.py | 49 +++++++++++++++++++++------------- src/zeroconf/const.py | 1 + 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index 02b7137a..5e2bec4c 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -46,7 +46,7 @@ from ..const import ( _ADDRESS_RECORD_TYPES, _CLASS_IN, - _CLASS_UNIQUE, + _CLASS_IN_UNIQUE, _DNS_HOST_TTL, _DNS_OTHER_TTL, _FLAGS_QR_QUERY, @@ -409,15 +409,21 @@ def _get_ip_addresses_from_cache_lifo( def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: """Set IPv6 addresses from the cache.""" - self._ipv6_addresses = cast( - "List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) - ) + if TYPE_CHECKING: + self._ipv6_addresses = cast( + "List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) + ) + else: + self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: """Set IPv4 addresses from the cache.""" - self._ipv4_addresses = cast( - "List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) - ) + if TYPE_CHECKING: + self._ipv4_addresses = cast( + "List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) + ) + else: + self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: """Updates service information from a DNS record. @@ -525,7 +531,7 @@ def dns_addresses( """Return matching DNSAddress from ServiceInfo.""" name = self.server or self.name ttl = override_ttl if override_ttl is not None else self.host_ttl - class_ = _CLASS_IN | _CLASS_UNIQUE + class_ = _CLASS_IN_UNIQUE version_value = version.value return [ DNSAddress( @@ -552,14 +558,17 @@ def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[floa def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService: """Return DNSService from ServiceInfo.""" + port = self.port + if TYPE_CHECKING: + assert isinstance(port, int) return DNSService( self.name, _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, + _CLASS_IN_UNIQUE, override_ttl if override_ttl is not None else self.host_ttl, self.priority, self.weight, - cast(int, self.port), + port, self.server or self.name, created, ) @@ -569,7 +578,7 @@ def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] return DNSText( self.name, _TYPE_TXT, - _CLASS_IN | _CLASS_UNIQUE, + _CLASS_IN_UNIQUE, override_ttl if override_ttl is not None else self.other_ttl, self.text, created, @@ -582,7 +591,7 @@ def dns_nsec( return DNSNsec( self.name, _TYPE_NSEC, - _CLASS_IN | _CLASS_UNIQUE, + _CLASS_IN_UNIQUE, override_ttl if override_ttl is not None else self.host_ttl, self.name, missing_types, @@ -593,12 +602,11 @@ def get_address_and_nsec_records( self, override_ttl: Optional[int] = None, created: Optional[float] = None ) -> Set[DNSRecord]: """Build a set of address records and NSEC records for non-present record types.""" - seen_types: Set[int] = set() + missing_types: Set[int] = _ADDRESS_RECORD_TYPES.copy() records: Set[DNSRecord] = set() for dns_address in self.dns_addresses(override_ttl, IPVersion.All, created): - seen_types.add(dns_address.type) + missing_types.discard(dns_address.type) records.add(dns_address) - missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types if missing_types: assert self.server is not None, "Service server must be set for NSEC record." records.add(self.dns_nsec(list(missing_types), override_ttl, created)) @@ -732,10 +740,13 @@ def generate_request_query( ) -> DNSOutgoing: """Generate the request query.""" out = DNSOutgoing(_FLAGS_QR_QUERY) - out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) - out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) - out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_A, _CLASS_IN) - out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_AAAA, _CLASS_IN) + name = self.name + server_or_name = self.server or name + cache = zc.cache + out.add_question_or_one_cache(cache, now, name, _TYPE_SRV, _CLASS_IN) + out.add_question_or_one_cache(cache, now, name, _TYPE_TXT, _CLASS_IN) + out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_A, _CLASS_IN) + out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_AAAA, _CLASS_IN) if question_type == DNSQuestionType.QU: for question in out.questions: question.unicast = True diff --git a/src/zeroconf/const.py b/src/zeroconf/const.py index f87c1336..ca199df5 100644 --- a/src/zeroconf/const.py +++ b/src/zeroconf/const.py @@ -84,6 +84,7 @@ _CLASS_ANY = 255 _CLASS_MASK = 0x7FFF _CLASS_UNIQUE = 0x8000 +_CLASS_IN_UNIQUE = _CLASS_IN | _CLASS_UNIQUE _TYPE_A = 1 _TYPE_NS = 2