diff --git a/src/zeroconf/_dns.pxd b/src/zeroconf/_dns.pxd index 289cd1a1..5622a5ed 100644 --- a/src/zeroconf/_dns.pxd +++ b/src/zeroconf/_dns.pxd @@ -44,6 +44,8 @@ cdef class DNSRecord(DNSEntry): ) cpdef suppressed_by(self, object msg) + cpdef get_remaining_ttl(self, cython.float now) + cpdef get_expiration_time(self, cython.uint percent) cpdef is_expired(self, cython.float now) @@ -54,6 +56,8 @@ cdef class DNSRecord(DNSEntry): cpdef reset_ttl(self, DNSRecord other) + cpdef set_created_ttl(self, cython.float now, cython.float ttl) + cdef class DNSAddress(DNSRecord): cdef public cython.int _hash diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index 6c34f9dd..73b0c751 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -26,7 +26,7 @@ from ._exceptions import AbstractMethodException from ._utils.net import _is_v6_address -from ._utils.time import current_time_millis, millis_to_seconds +from ._utils.time import current_time_millis from .const import _CLASS_MASK, _CLASS_UNIQUE, _CLASSES, _TYPE_ANY, _TYPES _LEN_BYTE = 1 @@ -193,7 +193,8 @@ def get_expiration_time(self, percent: _int) -> float: # TODO: Switch to just int here def get_remaining_ttl(self, now: _float) -> Union[int, float]: """Returns the remaining TTL in seconds.""" - return max(0, millis_to_seconds((self.created + (_EXPIRE_FULL_TIME_MS * self.ttl)) - now)) + remain = (self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) - now) / 1000.0 + return 0 if remain < 0 else remain def is_expired(self, now: _float) -> bool: """Returns true if this record has expired.""" @@ -212,7 +213,7 @@ def reset_ttl(self, other) -> None: # type: ignore[no-untyped-def] another record.""" self.set_created_ttl(other.created, other.ttl) - def set_created_ttl(self, created: float, ttl: Union[float, int]) -> None: + def set_created_ttl(self, created: _float, ttl: Union[float, int]) -> None: """Set the created and ttl of a record.""" self.created = created self.ttl = ttl diff --git a/src/zeroconf/_handlers.py b/src/zeroconf/_handlers.py index 192496c1..be0d619f 100644 --- a/src/zeroconf/_handlers.py +++ b/src/zeroconf/_handlers.py @@ -408,6 +408,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: removes: Set[DNSRecord] = set() now = msg.now unique_types: Set[Tuple[str, int, int]] = set() + cache = self.cache for record in msg.answers: # Protect zeroconf from records that can cause denial of service. @@ -416,7 +417,9 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: # ServiceBrowsers generating excessive queries refresh queries. # Apple uses a 15s minimum TTL, however we do not have the same # level of rate limit and safe guards so we use 1/4 of the recommended value. - if record.ttl and record.type == _TYPE_PTR and record.ttl < _DNS_PTR_MIN_TTL: + record_type = record.type + record_ttl = record.ttl + if record_ttl and record_type == _TYPE_PTR and record_ttl < _DNS_PTR_MIN_TTL: log.debug( "Increasing effective ttl of %s to minimum of %s to protect against excessive refreshes.", record, @@ -425,12 +428,12 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: record.set_created_ttl(record.created, _DNS_PTR_MIN_TTL) if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - unique_types.add((record.name, record.type, record.class_)) + unique_types.add((record.name, record_type, record.class_)) if TYPE_CHECKING: record = cast(_UniqueRecordsType, record) - maybe_entry = self.cache.async_get_unique(record) + maybe_entry = cache.async_get_unique(record) if not record.is_expired(now): if maybe_entry is not None: maybe_entry.reset_ttl(record) @@ -447,7 +450,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: removes.add(record) if unique_types: - self.cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, msg.answers, now) + cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, msg.answers, now) if updates: self.async_updates(now, updates) @@ -468,12 +471,12 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: # processsed. new = False if other_adds or address_adds: - new = self.cache.async_add_records(itertools.chain(address_adds, other_adds)) + new = cache.async_add_records(itertools.chain(address_adds, other_adds)) # Removes are processed last since # ServiceInfo could generate an un-needed query # because the data was not yet populated. if removes: - self.cache.async_remove_records(removes) + cache.async_remove_records(removes) if updates: self.async_updates_complete(new) diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index 02b7137a..5e2bec4c 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -46,7 +46,7 @@ from ..const import ( _ADDRESS_RECORD_TYPES, _CLASS_IN, - _CLASS_UNIQUE, + _CLASS_IN_UNIQUE, _DNS_HOST_TTL, _DNS_OTHER_TTL, _FLAGS_QR_QUERY, @@ -409,15 +409,21 @@ def _get_ip_addresses_from_cache_lifo( def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: """Set IPv6 addresses from the cache.""" - self._ipv6_addresses = cast( - "List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) - ) + if TYPE_CHECKING: + self._ipv6_addresses = cast( + "List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) + ) + else: + self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: """Set IPv4 addresses from the cache.""" - self._ipv4_addresses = cast( - "List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) - ) + if TYPE_CHECKING: + self._ipv4_addresses = cast( + "List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) + ) + else: + self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: """Updates service information from a DNS record. @@ -525,7 +531,7 @@ def dns_addresses( """Return matching DNSAddress from ServiceInfo.""" name = self.server or self.name ttl = override_ttl if override_ttl is not None else self.host_ttl - class_ = _CLASS_IN | _CLASS_UNIQUE + class_ = _CLASS_IN_UNIQUE version_value = version.value return [ DNSAddress( @@ -552,14 +558,17 @@ def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[floa def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService: """Return DNSService from ServiceInfo.""" + port = self.port + if TYPE_CHECKING: + assert isinstance(port, int) return DNSService( self.name, _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, + _CLASS_IN_UNIQUE, override_ttl if override_ttl is not None else self.host_ttl, self.priority, self.weight, - cast(int, self.port), + port, self.server or self.name, created, ) @@ -569,7 +578,7 @@ def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] return DNSText( self.name, _TYPE_TXT, - _CLASS_IN | _CLASS_UNIQUE, + _CLASS_IN_UNIQUE, override_ttl if override_ttl is not None else self.other_ttl, self.text, created, @@ -582,7 +591,7 @@ def dns_nsec( return DNSNsec( self.name, _TYPE_NSEC, - _CLASS_IN | _CLASS_UNIQUE, + _CLASS_IN_UNIQUE, override_ttl if override_ttl is not None else self.host_ttl, self.name, missing_types, @@ -593,12 +602,11 @@ def get_address_and_nsec_records( self, override_ttl: Optional[int] = None, created: Optional[float] = None ) -> Set[DNSRecord]: """Build a set of address records and NSEC records for non-present record types.""" - seen_types: Set[int] = set() + missing_types: Set[int] = _ADDRESS_RECORD_TYPES.copy() records: Set[DNSRecord] = set() for dns_address in self.dns_addresses(override_ttl, IPVersion.All, created): - seen_types.add(dns_address.type) + missing_types.discard(dns_address.type) records.add(dns_address) - missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types if missing_types: assert self.server is not None, "Service server must be set for NSEC record." records.add(self.dns_nsec(list(missing_types), override_ttl, created)) @@ -732,10 +740,13 @@ def generate_request_query( ) -> DNSOutgoing: """Generate the request query.""" out = DNSOutgoing(_FLAGS_QR_QUERY) - out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) - out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) - out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_A, _CLASS_IN) - out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_AAAA, _CLASS_IN) + name = self.name + server_or_name = self.server or name + cache = zc.cache + out.add_question_or_one_cache(cache, now, name, _TYPE_SRV, _CLASS_IN) + out.add_question_or_one_cache(cache, now, name, _TYPE_TXT, _CLASS_IN) + out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_A, _CLASS_IN) + out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_AAAA, _CLASS_IN) if question_type == DNSQuestionType.QU: for question in out.questions: question.unicast = True diff --git a/src/zeroconf/const.py b/src/zeroconf/const.py index f87c1336..ca199df5 100644 --- a/src/zeroconf/const.py +++ b/src/zeroconf/const.py @@ -84,6 +84,7 @@ _CLASS_ANY = 255 _CLASS_MASK = 0x7FFF _CLASS_UNIQUE = 0x8000 +_CLASS_IN_UNIQUE = _CLASS_IN | _CLASS_UNIQUE _TYPE_A = 1 _TYPE_NS = 2