
from ._cares import ffi as _ffi, lib as _lib
import _cffi_backend  # hint for bundler tools

if _lib.ARES_SUCCESS != _lib.ares_library_init(_lib.ARES_LIB_INIT_ALL) or _ffi is None:
    raise RuntimeError('Could not initialize c-ares')

if not _lib.ares_threadsafety():
    raise RuntimeError("c-ares is not built with thread safety")

from . import errno
from .utils import ascii_bytes, maybe_str, parse_name
from ._version import __version__

import math
import socket
import threading
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, Dict, Union
from queue import SimpleQueue

IP4 = tuple[str, int]
IP6 = tuple[str, int, int, int]

# Flag values
ARES_FLAG_USEVC = _lib.ARES_FLAG_USEVC
ARES_FLAG_PRIMARY = _lib.ARES_FLAG_PRIMARY
ARES_FLAG_IGNTC = _lib.ARES_FLAG_IGNTC
ARES_FLAG_NORECURSE = _lib.ARES_FLAG_NORECURSE
ARES_FLAG_STAYOPEN = _lib.ARES_FLAG_STAYOPEN
ARES_FLAG_NOSEARCH = _lib.ARES_FLAG_NOSEARCH
ARES_FLAG_NOALIASES = _lib.ARES_FLAG_NOALIASES
ARES_FLAG_NOCHECKRESP = _lib.ARES_FLAG_NOCHECKRESP
ARES_FLAG_EDNS = _lib.ARES_FLAG_EDNS
ARES_FLAG_NO_DFLT_SVR = _lib.ARES_FLAG_NO_DFLT_SVR

# Nameinfo flag values
ARES_NI_NOFQDN = _lib.ARES_NI_NOFQDN
ARES_NI_NUMERICHOST = _lib.ARES_NI_NUMERICHOST
ARES_NI_NAMEREQD = _lib.ARES_NI_NAMEREQD
ARES_NI_NUMERICSERV = _lib.ARES_NI_NUMERICSERV
ARES_NI_DGRAM = _lib.ARES_NI_DGRAM
ARES_NI_TCP = _lib.ARES_NI_TCP
ARES_NI_UDP = _lib.ARES_NI_UDP
ARES_NI_SCTP = _lib.ARES_NI_SCTP
ARES_NI_DCCP = _lib.ARES_NI_DCCP
ARES_NI_NUMERICSCOPE = _lib.ARES_NI_NUMERICSCOPE
ARES_NI_LOOKUPHOST = _lib.ARES_NI_LOOKUPHOST
ARES_NI_LOOKUPSERVICE = _lib.ARES_NI_LOOKUPSERVICE
ARES_NI_IDN = _lib.ARES_NI_IDN
ARES_NI_IDN_ALLOW_UNASSIGNED = _lib.ARES_NI_IDN_ALLOW_UNASSIGNED
ARES_NI_IDN_USE_STD3_ASCII_RULES = _lib.ARES_NI_IDN_USE_STD3_ASCII_RULES

# Bad socket
ARES_SOCKET_BAD = _lib.ARES_SOCKET_BAD

# Query types
QUERY_TYPE_A = _lib.ARES_REC_TYPE_A
QUERY_TYPE_AAAA = _lib.ARES_REC_TYPE_AAAA
QUERY_TYPE_ANY = _lib.ARES_REC_TYPE_ANY
QUERY_TYPE_CAA = _lib.ARES_REC_TYPE_CAA
QUERY_TYPE_CNAME = _lib.ARES_REC_TYPE_CNAME
QUERY_TYPE_MX = _lib.ARES_REC_TYPE_MX
QUERY_TYPE_NAPTR = _lib.ARES_REC_TYPE_NAPTR
QUERY_TYPE_NS = _lib.ARES_REC_TYPE_NS
QUERY_TYPE_PTR = _lib.ARES_REC_TYPE_PTR
QUERY_TYPE_SOA = _lib.ARES_REC_TYPE_SOA
QUERY_TYPE_SRV = _lib.ARES_REC_TYPE_SRV
QUERY_TYPE_TXT = _lib.ARES_REC_TYPE_TXT
QUERY_TYPE_TLSA = _lib.ARES_REC_TYPE_TLSA
QUERY_TYPE_HTTPS = _lib.ARES_REC_TYPE_HTTPS
QUERY_TYPE_URI = _lib.ARES_REC_TYPE_URI

# Query classes
QUERY_CLASS_IN = _lib.ARES_CLASS_IN
QUERY_CLASS_CHAOS = _lib.ARES_CLASS_CHAOS
QUERY_CLASS_HS = _lib.ARES_CLASS_HESOID
QUERY_CLASS_NONE = _lib.ARES_CLASS_NONE
QUERY_CLASS_ANY = _lib.ARES_CLASS_ANY

ARES_VERSION = maybe_str(_ffi.string(_lib.ares_version(_ffi.NULL)))
PYCARES_ADDRTTL_SIZE = 256


class AresError(Exception):
    """Exception raised for errors reported by c-ares."""


# callback helpers

# Maps handle to channel to prevent use-after-free
_handle_to_channel: Dict[Any, "Channel"] = {}


@_ffi.def_extern()
def _sock_state_cb(data, socket_fd, readable, writable):
    """Forward a socket state change to the Python callable wrapped by ``data``."""
    # Note: sock_state_cb handle is not tracked in _handle_to_channel
    # because it has a different lifecycle (tied to the channel, not individual queries)
    sock_state_cb = _ffi.from_handle(data)
    sock_state_cb(socket_fd, readable, writable)


@_ffi.def_extern()
def _host_cb(arg, status, timeouts, hostent):
    """Deliver a host lookup result to the Python callback wrapped by ``arg``.

    The callback receives ``(result, None)`` on success, with ``result``
    produced by ``parse_hostent``, and ``(None, status)`` on failure.
    Handles that are not registered in ``_handle_to_channel`` are ignored.
    """
    # Get callback data without removing the reference yet
    if arg not in _handle_to_channel:
        return

    callback = _ffi.from_handle(arg)
    try:
        if status != _lib.ARES_SUCCESS:
            result = None
        else:
            result = parse_hostent(hostent)
            status = None
        callback(result, status)
    finally:
        # Always drop the registration, even if parsing or the user callback
        # raised; otherwise the handle (and its channel) would stay
        # referenced forever.
        _handle_to_channel.pop(arg, None)


@_ffi.def_extern()
def _nameinfo_cb(arg, status, timeouts, node, service):
    """Deliver a name-info lookup result to the Python callback wrapped by ``arg``.

    The callback receives ``(result, None)`` on success, with ``result``
    produced by ``parse_nameinfo``, and ``(None, status)`` on failure.
    Handles that are not registered in ``_handle_to_channel`` are ignored.
    """
    # Get callback data without removing the reference yet
    if arg not in _handle_to_channel:
        return

    callback = _ffi.from_handle(arg)
    try:
        if status != _lib.ARES_SUCCESS:
            result = None
        else:
            result = parse_nameinfo(node, service)
            status = None
        callback(result, status)
    finally:
        # Always drop the registration, even if parsing or the user callback
        # raised; otherwise the handle (and its channel) would stay
        # referenced forever.
        _handle_to_channel.pop(arg, None)

@_ffi.def_extern()
def _query_dnsrec_cb(arg, status, timeouts, dnsrec):
    """Callback for new DNS record API queries.

    The Python callback wrapped by ``arg`` receives ``(result, None)`` on
    success, ``(None, status)`` when c-ares reports a failure, and
    ``(None, parse_status)`` when ``parse_dnsrec`` rejects the record.
    Handles that are not registered in ``_handle_to_channel`` are ignored.
    """
    # Get callback data without removing the reference yet
    if arg not in _handle_to_channel:
        return

    callback = _ffi.from_handle(arg)
    try:
        if status != _lib.ARES_SUCCESS:
            result = None
        else:
            # parse_status is None on success, which is also how success is
            # reported to the callback.
            result, status = parse_dnsrec(dnsrec)
        callback(result, status)
    finally:
        # Always drop the registration, even if parsing or the user callback
        # raised; otherwise the handle (and its channel) would stay
        # referenced forever.
        _handle_to_channel.pop(arg, None)


@_ffi.def_extern()
def _addrinfo_cb(arg, status, timeouts, res):
    """Deliver an address-info result to the Python callback wrapped by ``arg``.

    The callback receives ``(result, None)`` on success, with ``result``
    produced by ``parse_addrinfo``, and ``(None, status)`` on failure.
    Handles that are not registered in ``_handle_to_channel`` are ignored.
    """
    # Get callback data without removing the reference yet
    if arg not in _handle_to_channel:
        return

    callback = _ffi.from_handle(arg)
    try:
        if status != _lib.ARES_SUCCESS:
            result = None
        else:
            result = parse_addrinfo(res)
            status = None
        callback(result, status)
    finally:
        # Always drop the registration, even if parsing or the user callback
        # raised; otherwise the handle (and its channel) would stay
        # referenced forever.
        _handle_to_channel.pop(arg, None)


def _extract_opt_params(rr, key):
    """Extract OPT params as list of (key, value) tuples for HTTPS/SVCB records."""
    opt_cnt = _lib.ares_dns_rr_get_opt_cnt(rr, key)
    if opt_cnt == 0:
        return []

    # Collect all options as a list of (key, value) tuples
    params = []
    for i in range(opt_cnt):
        val_ptr = _ffi.new("unsigned char **")
        val_len = _ffi.new("size_t *")
        opt_key = _lib.ares_dns_rr_get_opt(rr, key, i, val_ptr, val_len)
        if val_ptr[0] != _ffi.NULL:
            val = bytes(_ffi.buffer(val_ptr[0], val_len[0]))
        else:
            val = b''
        params.append((opt_key, val))
    return params


def _rr_str(rr, key):
    """Read string field ``key`` of ``rr`` and convert it with ``maybe_str``."""
    return maybe_str(_ffi.string(_lib.ares_dns_rr_get_str(rr, key)))


def extract_record_data(rr, record_type):
    """Extract type-specific data from a DNS resource record.

    Args:
        rr: The c-ares resource record to read from.
        record_type: The record type, one of the ``ARES_REC_TYPE_*`` values.

    Returns:
        The ``*RecordData`` dataclass instance matching ``record_type``.

    Raises:
        ValueError: If ``record_type`` is not one of the supported types.
    """
    if record_type == _lib.ARES_REC_TYPE_A:
        addr = _lib.ares_dns_rr_get_addr(rr, _lib.ARES_RR_A_ADDR)
        buf = _ffi.new("char[]", _lib.INET6_ADDRSTRLEN)
        _lib.ares_inet_ntop(socket.AF_INET, addr, buf, _lib.INET6_ADDRSTRLEN)
        return ARecordData(addr=maybe_str(_ffi.string(buf)))

    elif record_type == _lib.ARES_REC_TYPE_AAAA:
        addr = _lib.ares_dns_rr_get_addr6(rr, _lib.ARES_RR_AAAA_ADDR)
        buf = _ffi.new("char[]", _lib.INET6_ADDRSTRLEN)
        _lib.ares_inet_ntop(socket.AF_INET6, addr, buf, _lib.INET6_ADDRSTRLEN)
        return AAAARecordData(addr=maybe_str(_ffi.string(buf)))

    elif record_type == _lib.ARES_REC_TYPE_MX:
        priority = _lib.ares_dns_rr_get_u16(rr, _lib.ARES_RR_MX_PREFERENCE)
        return MXRecordData(priority=priority,
                            exchange=_rr_str(rr, _lib.ARES_RR_MX_EXCHANGE))

    elif record_type == _lib.ARES_REC_TYPE_TXT:
        # TXT records use ABIN (array of binary) for chunks
        cnt = _lib.ares_dns_rr_get_abin_cnt(rr, _lib.ARES_RR_TXT_DATA)
        chunks = []
        for i in range(cnt):
            length = _ffi.new("size_t *")
            data = _lib.ares_dns_rr_get_abin(rr, _lib.ARES_RR_TXT_DATA, i, length)
            if data != _ffi.NULL:
                chunks.append(_ffi.buffer(data, length[0])[:])
        return TXTRecordData(data=b''.join(chunks))

    elif record_type == _lib.ARES_REC_TYPE_CAA:
        critical = _lib.ares_dns_rr_get_u8(rr, _lib.ARES_RR_CAA_CRITICAL)
        length = _ffi.new("size_t *")
        value = _lib.ares_dns_rr_get_bin(rr, _lib.ARES_RR_CAA_VALUE, length)
        # Guard against a NULL value pointer, as the TLSA branch does.
        raw_value = _ffi.buffer(value, length[0])[:] if value != _ffi.NULL else b''
        return CAARecordData(critical=critical,
                             tag=_rr_str(rr, _lib.ARES_RR_CAA_TAG),
                             value=maybe_str(raw_value))

    elif record_type == _lib.ARES_REC_TYPE_CNAME:
        return CNAMERecordData(cname=_rr_str(rr, _lib.ARES_RR_CNAME_CNAME))

    elif record_type == _lib.ARES_REC_TYPE_NAPTR:
        order = _lib.ares_dns_rr_get_u16(rr, _lib.ARES_RR_NAPTR_ORDER)
        preference = _lib.ares_dns_rr_get_u16(rr, _lib.ARES_RR_NAPTR_PREFERENCE)
        return NAPTRRecordData(
            order=order,
            preference=preference,
            flags=_rr_str(rr, _lib.ARES_RR_NAPTR_FLAGS),
            service=_rr_str(rr, _lib.ARES_RR_NAPTR_SERVICES),
            regexp=_rr_str(rr, _lib.ARES_RR_NAPTR_REGEXP),
            replacement=_rr_str(rr, _lib.ARES_RR_NAPTR_REPLACEMENT)
        )

    elif record_type == _lib.ARES_REC_TYPE_NS:
        return NSRecordData(nsdname=_rr_str(rr, _lib.ARES_RR_NS_NSDNAME))

    elif record_type == _lib.ARES_REC_TYPE_PTR:
        return PTRRecordData(dname=_rr_str(rr, _lib.ARES_RR_PTR_DNAME))

    elif record_type == _lib.ARES_REC_TYPE_SOA:
        serial = _lib.ares_dns_rr_get_u32(rr, _lib.ARES_RR_SOA_SERIAL)
        refresh = _lib.ares_dns_rr_get_u32(rr, _lib.ARES_RR_SOA_REFRESH)
        retry = _lib.ares_dns_rr_get_u32(rr, _lib.ARES_RR_SOA_RETRY)
        expire = _lib.ares_dns_rr_get_u32(rr, _lib.ARES_RR_SOA_EXPIRE)
        minimum = _lib.ares_dns_rr_get_u32(rr, _lib.ARES_RR_SOA_MINIMUM)
        return SOARecordData(
            mname=_rr_str(rr, _lib.ARES_RR_SOA_MNAME),
            rname=_rr_str(rr, _lib.ARES_RR_SOA_RNAME),
            serial=serial,
            refresh=refresh,
            retry=retry,
            expire=expire,
            minimum=minimum
        )

    elif record_type == _lib.ARES_REC_TYPE_SRV:
        priority = _lib.ares_dns_rr_get_u16(rr, _lib.ARES_RR_SRV_PRIORITY)
        weight = _lib.ares_dns_rr_get_u16(rr, _lib.ARES_RR_SRV_WEIGHT)
        port = _lib.ares_dns_rr_get_u16(rr, _lib.ARES_RR_SRV_PORT)
        return SRVRecordData(
            priority=priority,
            weight=weight,
            port=port,
            target=_rr_str(rr, _lib.ARES_RR_SRV_TARGET)
        )

    elif record_type == _lib.ARES_REC_TYPE_TLSA:
        cert_usage = _lib.ares_dns_rr_get_u8(rr, _lib.ARES_RR_TLSA_CERT_USAGE)
        selector = _lib.ares_dns_rr_get_u8(rr, _lib.ARES_RR_TLSA_SELECTOR)
        matching_type = _lib.ares_dns_rr_get_u8(rr, _lib.ARES_RR_TLSA_MATCH)
        data_len = _ffi.new("size_t *")
        data_ptr = _lib.ares_dns_rr_get_bin(rr, _lib.ARES_RR_TLSA_DATA, data_len)
        cert_data = bytes(_ffi.buffer(data_ptr, data_len[0])) if data_ptr != _ffi.NULL else b''
        return TLSARecordData(
            cert_usage=cert_usage,
            selector=selector,
            matching_type=matching_type,
            cert_association_data=cert_data
        )

    elif record_type == _lib.ARES_REC_TYPE_HTTPS:
        priority = _lib.ares_dns_rr_get_u16(rr, _lib.ARES_RR_HTTPS_PRIORITY)
        target = _rr_str(rr, _lib.ARES_RR_HTTPS_TARGET)
        params = _extract_opt_params(rr, _lib.ARES_RR_HTTPS_PARAMS)
        return HTTPSRecordData(
            priority=priority,
            target=target,
            params=params
        )

    elif record_type == _lib.ARES_REC_TYPE_URI:
        priority = _lib.ares_dns_rr_get_u16(rr, _lib.ARES_RR_URI_PRIORITY)
        weight = _lib.ares_dns_rr_get_u16(rr, _lib.ARES_RR_URI_WEIGHT)
        return URIRecordData(
            priority=priority,
            weight=weight,
            target=_rr_str(rr, _lib.ARES_RR_URI_TARGET)
        )

    else:
        # Unknown record type
        raise ValueError(f"Unsupported DNS record type: {record_type}")


def _parse_section(dnsrec, section):
    """Return the parsable records of one ``section`` of ``dnsrec`` as a list of DNSRecord."""
    records = []
    for i in range(_lib.ares_dns_record_rr_cnt(dnsrec, section)):
        rr = _lib.ares_dns_record_rr_get_const(dnsrec, section, i)
        if rr == _ffi.NULL:
            continue
        name = maybe_str(_ffi.string(_lib.ares_dns_rr_get_name(rr)))
        rec_type = _lib.ares_dns_rr_get_type(rr)
        rec_class = _lib.ares_dns_rr_get_class(rr)
        ttl = _lib.ares_dns_rr_get_ttl(rr)
        try:
            data = extract_record_data(rr, rec_type)
            records.append(DNSRecord(
                name=name,
                type=rec_type,
                record_class=rec_class,
                ttl=ttl,
                data=data
            ))
        except Exception:
            # Deliberately best-effort: skip records whose type is
            # unsupported or whose data cannot be extracted, so one bad
            # record does not discard the whole response.
            continue
    return records


def parse_dnsrec(dnsrec):
    """Parse ares_dns_record_t into DNSResult with all sections.

    Returns:
        A ``(result, status)`` tuple: ``(DNSResult, None)`` on success, or
        ``(None, ARES_EBADRESP)`` when ``dnsrec`` is NULL.
    """
    if dnsrec == _ffi.NULL:
        return None, _lib.ARES_EBADRESP

    result = DNSResult(
        answer=_parse_section(dnsrec, _lib.ARES_SECTION_ANSWER),
        authority=_parse_section(dnsrec, _lib.ARES_SECTION_AUTHORITY),
        additional=_parse_section(dnsrec, _lib.ARES_SECTION_ADDITIONAL)
    )
    return result, None


class _ChannelShutdownManager:
    """Manages channel destruction in a single background thread using SimpleQueue."""

    def __init__(self) -> None:
        self._queue: SimpleQueue = SimpleQueue()
        self._thread: Optional[threading.Thread] = None
        self._start_lock = threading.Lock()

    def _run_safe_shutdown_loop(self) -> None:
        """Process channel destruction requests from the queue."""
        while True:
            # Block forever until we get a channel to destroy. The socket
            # state callback handle is only held here so that it stays
            # referenced while the channel is being torn down.
            channel, _sock_state_cb_handle = self._queue.get()

            # Check before dereferencing: a None entry must not kill the
            # shutdown thread.
            if channel is None:
                continue

            # Cancel all pending queries - this will trigger callbacks with ARES_ECANCELLED
            _lib.ares_cancel(channel[0])

            # Wait for all queries to finish
            _lib.ares_queue_wait_empty(channel[0], -1)

            # Destroy the channel
            _lib.ares_destroy(channel[0])

    def start(self) -> None:
        """Start the background thread if not already started."""
        if self._thread is not None:
            return
        with self._start_lock:
            if self._thread is not None:
                # Started by another thread while waiting for the lock
                return
            self._thread = threading.Thread(target=self._run_safe_shutdown_loop, daemon=True)
            self._thread.start()

    def destroy_channel(self, channel, sock_state_cb_handle) -> None:
        """
        Schedule channel destruction on the background thread.

        The socket state callback handle is passed along to ensure it remains
        alive until the channel is destroyed.

        Thread Safety and Synchronization:
        This method uses SimpleQueue which is thread-safe for putting items
        from multiple threads. The background thread processes channels
        sequentially waiting for queries to end before each destruction.
        """
        self._queue.put((channel, sock_state_cb_handle))


# Global shutdown manager instance
_shutdown_manager = _ChannelShutdownManager()


class Channel:
    """A c-ares resolver channel.

    Queries are asynchronous: each query method takes a ``callback`` that is
    invoked with ``(result, error)`` once the query completes.
    """

    __qtypes__ = (_lib.ARES_REC_TYPE_A, _lib.ARES_REC_TYPE_AAAA, _lib.ARES_REC_TYPE_ANY,
                  _lib.ARES_REC_TYPE_CAA, _lib.ARES_REC_TYPE_CNAME, _lib.ARES_REC_TYPE_HTTPS,
                  _lib.ARES_REC_TYPE_MX, _lib.ARES_REC_TYPE_NAPTR, _lib.ARES_REC_TYPE_NS,
                  _lib.ARES_REC_TYPE_PTR, _lib.ARES_REC_TYPE_SOA, _lib.ARES_REC_TYPE_SRV,
                  _lib.ARES_REC_TYPE_TLSA, _lib.ARES_REC_TYPE_TXT, _lib.ARES_REC_TYPE_URI)
    __qclasses__ = (_lib.ARES_CLASS_IN, _lib.ARES_CLASS_CHAOS, _lib.ARES_CLASS_HESOID,
                    _lib.ARES_CLASS_NONE, _lib.ARES_CLASS_ANY)

    def __init__(self,
                 *,
                 flags: Optional[int] = None,
                 timeout: Optional[float] = None,
                 tries: Optional[int] = None,
                 ndots: Optional[int] = None,
                 tcp_port: Optional[int] = None,
                 udp_port: Optional[int] = None,
                 servers: Optional[Iterable[Union[str, bytes]]] = None,
                 domains: Optional[Iterable[Union[str, bytes]]] = None,
                 lookups: Union[str, bytes, None] = None,
                 sock_state_cb: Optional[Callable[[int, bool, bool], None]] = None,
                 socket_send_buffer_size: Optional[int] = None,
                 socket_receive_buffer_size: Optional[int] = None,
                 rotate: bool = False,
                 local_ip: Union[str, bytes, None] = None,
                 local_dev: Optional[str] = None,
                 resolvconf_path: Union[str, bytes, None] = None) -> None:
        """Create and initialize the channel.

        Every argument that is left as ``None`` (or falsy, for ``servers``,
        ``domains``, ``lookups``, ``rotate``, ``local_ip`` and ``local_dev``)
        is not passed to c-ares. ``timeout`` is given in seconds. When no
        ``sock_state_cb`` is supplied, the c-ares event thread is enabled
        instead.

        Raises:
            TypeError: If ``sock_state_cb`` is given but not callable.
            AresError: If c-ares fails to initialize the channel.
        """
        # Initialize _channel to None first to ensure __del__ doesn't fail
        self._channel = None
        self._sock_state_cb_handle = None

        # Store flags for later use (default is 0 if not specified)
        self._flags = flags if flags is not None else 0

        channel = _ffi.new("ares_channel *")
        options = _ffi.new("struct ares_options *")
        optmask = 0

        if flags is not None:
            options.flags = flags
            optmask = optmask | _lib.ARES_OPT_FLAGS

        if timeout is not None:
            options.timeout = int(timeout * 1000)
            optmask = optmask | _lib.ARES_OPT_TIMEOUTMS

        if tries is not None:
            options.tries = tries
            optmask = optmask | _lib.ARES_OPT_TRIES

        if ndots is not None:
            options.ndots = ndots
            optmask = optmask | _lib.ARES_OPT_NDOTS

        if tcp_port is not None:
            options.tcp_port = tcp_port
            optmask = optmask | _lib.ARES_OPT_TCP_PORT

        if udp_port is not None:
            options.udp_port = udp_port
            optmask = optmask | _lib.ARES_OPT_UDP_PORT

        if socket_send_buffer_size is not None:
            options.socket_send_buffer_size = socket_send_buffer_size
            optmask = optmask | _lib.ARES_OPT_SOCK_SNDBUF

        if socket_receive_buffer_size is not None:
            options.socket_receive_buffer_size = socket_receive_buffer_size
            optmask = optmask | _lib.ARES_OPT_SOCK_RCVBUF

        if sock_state_cb:
            if not callable(sock_state_cb):
                raise TypeError("sock_state_cb is not callable")

            userdata = _ffi.new_handle(sock_state_cb)

            # This must be kept alive while the channel is alive.
            self._sock_state_cb_handle = userdata

            options.sock_state_cb = _lib._sock_state_cb
            options.sock_state_cb_data = userdata
            optmask = optmask | _lib.ARES_OPT_SOCK_STATE_CB
        else:
            self._sock_state_cb_handle = None
            optmask = optmask | _lib.ARES_OPT_EVENT_THREAD
            options.evsys = _lib.ARES_EVSYS_DEFAULT

        # Storing a cdata pointer in a struct field does NOT keep the owning
        # object alive, so every buffer referenced from ``options`` is bound
        # to a local name that outlives the ares_init_options() call below.
        if lookups:
            lookups_buf = _ffi.new('char[]', ascii_bytes(lookups))
            options.lookups = lookups_buf
            optmask = optmask | _lib.ARES_OPT_LOOKUPS

        # ``domains`` may be any iterable (including a generator), so
        # materialize it once before taking its length.
        domain_list = list(domains) if domains else []
        if domain_list:
            strs = [_ffi.new("char[]", ascii_bytes(d)) for d in domain_list]
            c = _ffi.new("char *[%d]" % (len(domain_list) + 1))
            for i, s in enumerate(strs):
                c[i] = s

            options.domains = c
            options.ndomains = len(domain_list)
            optmask = optmask | _lib.ARES_OPT_DOMAINS

        if rotate:
            optmask = optmask | _lib.ARES_OPT_ROTATE

        if resolvconf_path is not None:
            optmask = optmask | _lib.ARES_OPT_RESOLVCONF
            resolvconf_buf = _ffi.new('char[]', ascii_bytes(resolvconf_path))
            options.resolvconf_path = resolvconf_buf

        r = _lib.ares_init_options(channel, options, optmask)
        if r != _lib.ARES_SUCCESS:
            raise AresError('Failed to initialize c-ares channel')

        self._channel = channel

        # Ensure the shutdown thread is started before any further setup can
        # raise, so a channel closed from __del__ is actually destroyed.
        _shutdown_manager.start()

        if servers:
            self.servers = servers

        if local_ip:
            self.set_local_ip(local_ip)

        if local_dev:
            self.set_local_dev(local_dev)

    def __del__(self) -> None:
        """Ensure the channel is destroyed when the object is deleted."""
        self.close()

    def _create_callback_handle(self, callback_data):
        """
        Create a callback handle and register it for tracking.

        This ensures that:
        1. The callback data is wrapped in a CFFI handle
        2. The handle is mapped to this channel to keep it alive

        Args:
            callback_data: The data to pass to the callback (usually a callable or tuple)

        Returns:
            The CFFI handle that can be passed to C functions

        Raises:
            RuntimeError: If the channel is destroyed
        """
        if self._channel is None:
            raise RuntimeError("Channel is destroyed, no new queries allowed")

        userdata = _ffi.new_handle(callback_data)
        _handle_to_channel[userdata] = self
        return userdata

    def cancel(self) -> None:
        """Cancel the channel's queries via ``ares_cancel``."""
        _lib.ares_cancel(self._channel[0])

    def reinit(self) -> None:
        """Reinitialize the channel via ``ares_reinit``.

        Raises:
            AresError: If c-ares reports a failure.
        """
        r = _lib.ares_reinit(self._channel[0])
        if r != _lib.ARES_SUCCESS:
            raise AresError(r, errno.strerror(r))

    @property
    def servers(self) -> list[str]:
        """The channel's servers, as the comma-separated entries reported by c-ares."""
        csv_str = _lib.ares_get_servers_csv(self._channel[0])

        if csv_str == _ffi.NULL:
            raise AresError(_lib.ARES_ENOMEM, errno.strerror(_lib.ARES_ENOMEM))

        csv_string = maybe_str(_ffi.string(csv_str))
        _lib.ares_free_string(csv_str)

        return [s.strip() for s in csv_string.split(',')]

    @servers.setter
    def servers(self, servers: Iterable[Union[str, bytes]]) -> None:
        server_list = [ascii_bytes(s).decode('ascii') if isinstance(s, bytes) else s for s in servers]
        csv_str = ','.join(server_list)

        r = _lib.ares_set_servers_csv(self._channel[0], csv_str.encode('ascii'))
        if r != _lib.ARES_SUCCESS:
            raise AresError(r, errno.strerror(r))

    def process_fd(self, read_fd: int, write_fd: int) -> None:
        """Pass a readable and a writable descriptor to ``ares_process_fd``."""
        _lib.ares_process_fd(self._channel[0], _ffi.cast("ares_socket_t", read_fd),
                             _ffi.cast("ares_socket_t", write_fd))

    def process_read_fd(self, read_fd: int) -> None:
        """Pass only a readable descriptor to ``ares_process_fd``."""
        _lib.ares_process_fd(self._channel[0], _ffi.cast("ares_socket_t", read_fd),
                             _ffi.cast("ares_socket_t", ARES_SOCKET_BAD))

    def process_write_fd(self, write_fd: int) -> None:
        """Pass only a writable descriptor to ``ares_process_fd``."""
        _lib.ares_process_fd(self._channel[0], _ffi.cast("ares_socket_t", ARES_SOCKET_BAD),
                             _ffi.cast("ares_socket_t", write_fd))

    def timeout(self, t=None):
        """Return the timeout, in seconds, that ``ares_timeout`` wrote to its output buffer.

        Args:
            t: Optional non-negative maximum, in seconds, passed to c-ares.

        Raises:
            ValueError: If ``t`` is negative.
        """
        maxtv = _ffi.NULL
        tv = _ffi.new("struct timeval*")

        if t is not None:
            if t >= 0.0:
                maxtv = _ffi.new("struct timeval*")
                maxtv.tv_sec = int(math.floor(t))
                maxtv.tv_usec = int(math.fmod(t, 1.0) * 1000000)
            else:
                raise ValueError("timeout needs to be a positive number or None")

        # NOTE(review): the pointer returned by ares_timeout is ignored and
        # only ``tv`` is read - confirm this matches the c-ares contract.
        _lib.ares_timeout(self._channel[0], maxtv, tv)

        if tv == _ffi.NULL:
            return 0.0

        return (tv.tv_sec + tv.tv_usec / 1000000.0)

    def gethostbyaddr(self, addr: str, *, callback: Callable[[Any, int], None]) -> None:
        """Start a reverse lookup for the IPv4 or IPv6 address ``addr``.

        Raises:
            TypeError: If ``callback`` is not callable.
            ValueError: If ``addr`` is not a valid IP address.
        """
        if not callable(callback):
            raise TypeError("a callable is required")

        addr4 = _ffi.new("struct in_addr*")
        addr6 = _ffi.new("struct ares_in6_addr*")
        if _lib.ares_inet_pton(socket.AF_INET, ascii_bytes(addr), (addr4)) == 1:
            address = addr4
            family = socket.AF_INET
        elif _lib.ares_inet_pton(socket.AF_INET6, ascii_bytes(addr), (addr6)) == 1:
            address = addr6
            family = socket.AF_INET6
        else:
            raise ValueError("invalid IP address")

        userdata = self._create_callback_handle(callback)
        _lib.ares_gethostbyaddr(self._channel[0], address, _ffi.sizeof(address[0]), family,
                                _lib._host_cb, userdata)

    def getaddrinfo(
        self,
        host: str,
        port: Optional[int],
        *,
        family: socket.AddressFamily = 0,
        type: int = 0,
        proto: int = 0,
        flags: int = 0,
        callback: Callable[[Any, int], None]
    ) -> None:
        """Start an address-info lookup for ``host`` and ``port``.

        ``port`` may be ``None`` (no service), an int, or a service name.

        Raises:
            TypeError: If ``callback`` is not callable.
        """
        if not callable(callback):
            raise TypeError("a callable is required")

        if port is None:
            service = _ffi.NULL
        elif isinstance(port, int):
            service = str(port).encode('ascii')
        else:
            service = ascii_bytes(port)

        userdata = self._create_callback_handle(callback)

        hints = _ffi.new('struct ares_addrinfo_hints*')
        hints.ai_flags = flags
        hints.ai_family = family
        hints.ai_socktype = type
        hints.ai_protocol = proto
        _lib.ares_getaddrinfo(self._channel[0], parse_name(host), service, hints,
                              _lib._addrinfo_cb, userdata)

    def query(self, name: str, query_type: int, *, query_class: int = QUERY_CLASS_IN,
              callback: Callable[[Any, int], None]) -> None:
        """
        Perform a DNS query.

        Args:
            name: Domain name to query
            query_type: Type of query (e.g., QUERY_TYPE_A, QUERY_TYPE_AAAA, etc.)
            query_class: Query class (default: QUERY_CLASS_IN)
            callback: Callback function that receives (result, errno)

        The callback will receive a DNSResult object containing answer,
        authority, and additional sections.

        Raises:
            TypeError: If ``callback`` is not callable.
            ValueError: If ``query_type`` or ``query_class`` is not supported.
            AresError: If c-ares rejects the query.
        """
        if not callable(callback):
            raise TypeError('a callable is required')

        if query_type not in self.__qtypes__:
            raise ValueError('invalid query type specified')
        if query_class not in self.__qclasses__:
            raise ValueError('invalid query class specified')

        userdata = self._create_callback_handle(callback)
        qid = _ffi.new("unsigned short *")
        status = _lib.ares_query_dnsrec(
            self._channel[0],
            parse_name(name),
            query_class,
            query_type,
            _lib._query_dnsrec_cb,
            userdata,
            qid
        )
        if status != _lib.ARES_SUCCESS:
            _handle_to_channel.pop(userdata, None)
            raise AresError(status, errno.strerror(status))

    def search(self, name: str, query_type: int, *, query_class: int = QUERY_CLASS_IN,
               callback: Callable[[Any, int], None]) -> None:
        """
        Perform a DNS search (honors resolv.conf search domains).

        Args:
            name: Domain name to search
            query_type: Type of query (e.g., QUERY_TYPE_A, QUERY_TYPE_AAAA, etc.)
            query_class: Query class (default: QUERY_CLASS_IN)
            callback: Callback function that receives (result, errno)

        The callback will receive a DNSResult object containing answer,
        authority, and additional sections.

        Raises:
            TypeError: If ``callback`` is not callable.
            ValueError: If ``query_type`` or ``query_class`` is not supported.
            AresError: If c-ares fails to build the record or start the search.
        """
        if not callable(callback):
            raise TypeError('a callable is required')

        if query_type not in self.__qtypes__:
            raise ValueError('invalid query type specified')
        if query_class not in self.__qclasses__:
            raise ValueError('invalid query class specified')

        # Create a DNS record for the search query
        # Set RD (Recursion Desired) flag unless ARES_FLAG_NORECURSE is set
        dns_flags = 0 if (self._flags & _lib.ARES_FLAG_NORECURSE) else _lib.ARES_FLAG_RD

        dnsrec_p = _ffi.new("ares_dns_record_t **")
        status = _lib.ares_dns_record_create(
            dnsrec_p,
            0,  # id (will be set by c-ares)
            dns_flags,  # flags - include RD for recursive queries
            _lib.ARES_OPCODE_QUERY,
            _lib.ARES_RCODE_NOERROR
        )
        if status != _lib.ARES_SUCCESS:
            raise AresError(status, errno.strerror(status))

        dnsrec = dnsrec_p[0]

        # Add the query to the DNS record
        status = _lib.ares_dns_record_query_add(
            dnsrec,
            parse_name(name),
            query_type,
            query_class
        )
        if status != _lib.ARES_SUCCESS:
            _lib.ares_dns_record_destroy(dnsrec)
            raise AresError(status, errno.strerror(status))

        # The record is destroyed from two places (after the callback, and on
        # a failed start), so make destruction one-shot.
        # NOTE(review): this protects against a double free should c-ares
        # invoke the callback and also return a failure status - confirm
        # against the c-ares ares_search_dnsrec contract.
        destroy_lock = threading.Lock()
        destroyed = False

        def destroy_record():
            nonlocal destroyed
            with destroy_lock:
                if destroyed:
                    return
                destroyed = True
            _lib.ares_dns_record_destroy(dnsrec)

        # Wrap callback to destroy DNS record after it's called
        original_callback = callback

        def cleanup_callback(result, error):
            try:
                original_callback(result, error)
            finally:
                # Clean up the DNS record after the callback completes
                destroy_record()

        # Perform the search with the created DNS record
        userdata = self._create_callback_handle(cleanup_callback)
        status = _lib.ares_search_dnsrec(
            self._channel[0],
            dnsrec,
            _lib._query_dnsrec_cb,
            userdata
        )
        if status != _lib.ARES_SUCCESS:
            _handle_to_channel.pop(userdata, None)
            destroy_record()
            raise AresError(status, errno.strerror(status))

    def set_local_ip(self, ip):
        """Set the channel's local IPv4 or IPv6 address.

        Raises:
            ValueError: If ``ip`` is not a valid IP address.
        """
        addr4 = _ffi.new("struct in_addr*")
        addr6 = _ffi.new("struct ares_in6_addr*")
        if _lib.ares_inet_pton(socket.AF_INET, ascii_bytes(ip), addr4) == 1:
            _lib.ares_set_local_ip4(self._channel[0], socket.ntohl(addr4.s_addr))
        elif _lib.ares_inet_pton(socket.AF_INET6, ascii_bytes(ip), addr6) == 1:
            _lib.ares_set_local_ip6(self._channel[0], addr6)
        else:
            raise ValueError("invalid IP address")

    def getnameinfo(self, address: Union[IP4, IP6], flags: int, *,
                    callback: Callable[[Any, int], None]) -> None:
        """Start a name-info lookup for a socket address.

        Args:
            address: ``(ip, port)`` for IPv4 or
                ``(ip, port, flowinfo, scope_id)`` for IPv6.
            flags: Combination of ``ARES_NI_*`` flags.
            callback: Callback function that receives (result, errno)

        Raises:
            TypeError: If ``callback`` is not callable.
            ValueError: If ``address`` has the wrong length or an invalid IP.
        """
        if not callable(callback):
            raise TypeError("a callable is required")

        if len(address) == 2:
            (ip, port) = address
            sa4 = _ffi.new("struct sockaddr_in*")
            if _lib.ares_inet_pton(socket.AF_INET, ascii_bytes(ip), _ffi.addressof(sa4.sin_addr)) != 1:
                raise ValueError("Invalid IPv4 address %r" % ip)
            sa4.sin_family = socket.AF_INET
            sa4.sin_port = socket.htons(port)
            sa = sa4
        elif len(address) == 4:
            (ip, port, flowinfo, scope_id) = address
            sa6 = _ffi.new("struct sockaddr_in6*")
            if _lib.ares_inet_pton(socket.AF_INET6, ascii_bytes(ip), _ffi.addressof(sa6.sin6_addr)) != 1:
                raise ValueError("Invalid IPv6 address %r" % ip)
            sa6.sin6_family = socket.AF_INET6
            sa6.sin6_port = socket.htons(port)
            sa6.sin6_flowinfo = socket.htonl(flowinfo)  # I'm unsure about byteorder here.
            sa6.sin6_scope_id = scope_id  # Yes, without htonl.
            sa = sa6
        else:
            raise ValueError("Invalid address argument")

        userdata = self._create_callback_handle(callback)
        _lib.ares_getnameinfo(self._channel[0], _ffi.cast("struct sockaddr*", sa),
                              _ffi.sizeof(sa[0]), flags, _lib._nameinfo_cb, userdata)

    def set_local_dev(self, dev):
        """Set the channel's local device name; ``dev`` may be ``str`` or ``bytes``."""
        # The C parameter is a char*, which cffi only accepts as bytes, so
        # normalize str input the same way the other text options are.
        _lib.ares_set_local_dev(self._channel[0], ascii_bytes(dev))

    def close(self) -> None:
        """
        Close the channel as soon as it's safe to do so.

        This method can be called from any thread. The channel is handed to
        a background thread, which cancels pending queries, waits for the
        query queue to drain and then destroys the channel.

        Note: Once close() is called, no new queries can be started. Any
        pending queries will be cancelled and their callbacks will receive
        ARES_ECANCELLED.
        """
        if self._channel is None:
            # Already destroyed
            return

        # NB: don't cancel queries here, it may lead to problem if done from a
        # query callback.

        # Schedule channel destruction
        channel, self._channel = self._channel, None
        _shutdown_manager.destroy_channel(channel, self._sock_state_cb_handle)

    def wait(self, timeout: Optional[float] = None) -> bool:
        """
        Wait until all pending queries are complete or timeout occurs.

        Args:
            timeout: Maximum time to wait in seconds. ``None`` or a negative
                     value waits indefinitely.

        Returns:
            True if the queue drained, False if the wait timed out.

        Raises:
            AresError: If c-ares reports any other status.
        """
        timeout_ms = int(timeout * 1000) if timeout is not None and timeout >= 0 else -1
        r = _lib.ares_queue_wait_empty(self._channel[0], timeout_ms)
        if r == _lib.ARES_SUCCESS:
            return True
        elif r == _lib.ARES_ETIMEOUT:
            return False
        else:
            raise AresError(r, errno.strerror(r))


# DNS query result types - New dataclass-based API
#

@dataclass
class ARecordData:
    """Data for A (IPv4 address) record"""
    addr: str


@dataclass
class AAAARecordData:
    """Data for AAAA (IPv6 address) record"""
    addr: str


@dataclass
class MXRecordData:
    """Data for MX (mail exchange) record"""
    priority: int
    exchange: str


@dataclass
class TXTRecordData:
    """Data for TXT (text) record"""
    data: bytes


@dataclass
class CAARecordData:
    """Data for CAA (certification authority authorization) record"""
    critical: int
    tag: str
    value: str


@dataclass
class CNAMERecordData:
    """Data for CNAME (canonical name) record"""
    cname: str


@dataclass
class NAPTRRecordData:
    """Data for NAPTR (naming authority pointer) record"""
    order: int
    preference: int
    flags: str
    service: str
    regexp: str
    replacement: str


@dataclass
class NSRecordData:
    """Data for NS (name server) record"""
    nsdname: str


@dataclass
class PTRRecordData:
    """Data for PTR (pointer) record"""
    dname: str


@dataclass
class SOARecordData:
    """Data for SOA (start of authority) record"""
    mname: str
    rname: str
    serial: int
    refresh: int
    retry: int
    expire: int
    minimum: int


@dataclass
class SRVRecordData:
    """Data for SRV (service) record"""
    priority: int
    weight: int
    port: int
    target: str


@dataclass
class TLSARecordData:
    """Data for TLSA (DANE TLS authentication) record - RFC 6698"""
    cert_usage: int
    selector: int
    matching_type: int
    cert_association_data: bytes


@dataclass
class HTTPSRecordData:
    """Data for HTTPS (service binding) record - RFC 9460"""
    priority: int
    target: str
    params: list  # List of (key: int, value: bytes) tuples


@dataclass
class URIRecordData:
    """Data for URI (Uniform Resource Identifier) record - RFC 7553"""
    priority: int
    weight: int
    target: str


@dataclass
class DNSRecord:
    """Represents a single DNS resource record"""
    name: str
    type: int
    record_class: int
    ttl: int
    data: Union[ARecordData, AAAARecordData, MXRecordData, TXTRecordData,
                CAARecordData, CNAMERecordData, HTTPSRecordData, NAPTRRecordData,
                NSRecordData, PTRRecordData, SOARecordData, SRVRecordData,
                TLSARecordData, URIRecordData]


@dataclass
class DNSResult:
    """Represents a complete DNS query result with all sections"""
    answer: list[DNSRecord]
    authority: list[DNSRecord]
    additional: list[DNSRecord]


# Host/AddrInfo result types

@dataclass
class HostResult:
    """Result from gethostbyaddr() operation"""
    name: str
    aliases: list[str]
    addresses: list[str]


@dataclass
class NameInfoResult:
    """Result from getnameinfo() operation"""
    node: str
    service: Optional[str]


@dataclass
class AddrInfoNode:
    """Single address node from getaddrinfo() result"""
    ttl: int
    flags: int
    family: int
    socktype: int
    protocol: int
    addr: tuple  # (ip, port) or (ip, port, flowinfo, scope_id)


@dataclass
class AddrInfoCname:
    """CNAME information from getaddrinfo() result"""
    ttl: int
    alias: str
    name: str


@dataclass
class AddrInfoResult:
    """Complete result from getaddrinfo() operation"""
    cnames: list[AddrInfoCname]
    nodes: list[AddrInfoNode]


# Parser functions for Host/AddrInfo results

def parse_hostent(hostent) -> HostResult:
    """Parse c-ares hostent structure into HostResult"""
    name = maybe_str(_ffi.string(hostent.h_name))
    aliases = []
    addresses = []

    i = 0
    while hostent.h_aliases[i] != _ffi.NULL:
        aliases.append(maybe_str(_ffi.string(hostent.h_aliases[i])))
        i += 1

    i = 0
    while hostent.h_addr_list[i] != _ffi.NULL:
        buf = _ffi.new("char[]", _lib.INET6_ADDRSTRLEN)
        # Addresses that cannot be converted to text are skipped.
        if _ffi.NULL != _lib.ares_inet_ntop(hostent.h_addrtype, hostent.h_addr_list[i], buf,
                                            _lib.INET6_ADDRSTRLEN):
            addresses.append(maybe_str(_ffi.string(buf, _lib.INET6_ADDRSTRLEN)))
        i += 1

    return HostResult(name=name, aliases=aliases, addresses=addresses)


def parse_nameinfo(node, service) -> NameInfoResult:
    """Parse c-ares nameinfo into NameInfoResult"""
    node_str = maybe_str(_ffi.string(node))
    service_str = maybe_str(_ffi.string(service)) if service != _ffi.NULL else None
    return NameInfoResult(node=node_str, service=service_str)


def parse_addrinfo_node(ares_node) -> AddrInfoNode:
    """Parse a single c-ares addrinfo node into AddrInfoNode

    Raises:
        ValueError: If the address family is neither AF_INET nor AF_INET6,
            or the address cannot be converted to text.
    """
    ttl = ares_node.ai_ttl
    flags = ares_node.ai_flags
    socktype = ares_node.ai_socktype
    protocol = ares_node.ai_protocol

    addr_struct = ares_node.ai_addr
    assert addr_struct.sa_family == ares_node.ai_family
    ip = _ffi.new("char []", _lib.INET6_ADDRSTRLEN)

    if addr_struct.sa_family == socket.AF_INET:
        family = socket.AF_INET
        s = _ffi.cast("struct sockaddr_in*", addr_struct)
        if _ffi.NULL != _lib.ares_inet_ntop(s.sin_family, _ffi.addressof(s.sin_addr), ip,
                                            _lib.INET6_ADDRSTRLEN):
            # (address, port) 2-tuple for AF_INET
            addr = (_ffi.string(ip, _lib.INET6_ADDRSTRLEN), socket.ntohs(s.sin_port))
        else:
            raise ValueError("failed to convert IPv4 address")
    elif addr_struct.sa_family == socket.AF_INET6:
        family = socket.AF_INET6
        s = _ffi.cast("struct sockaddr_in6*", addr_struct)
        if _ffi.NULL != _lib.ares_inet_ntop(s.sin6_family, _ffi.addressof(s.sin6_addr), ip,
                                            _lib.INET6_ADDRSTRLEN):
            # (address, port, flow info, scope id) 4-tuple for AF_INET6
            addr = (_ffi.string(ip, _lib.INET6_ADDRSTRLEN), socket.ntohs(s.sin6_port),
                    s.sin6_flowinfo, s.sin6_scope_id)
        else:
            raise ValueError("failed to convert IPv6 address")
    else:
        raise ValueError("invalid sockaddr family")

    return AddrInfoNode(ttl=ttl, flags=flags, family=family, socktype=socktype,
                        protocol=protocol, addr=addr)


def parse_addrinfo_cname(ares_cname) -> AddrInfoCname:
    """Parse a single c-ares addrinfo cname into AddrInfoCname"""
    return AddrInfoCname(
        ttl=ares_cname.ttl,
        alias=maybe_str(_ffi.string(ares_cname.alias)),
        name=maybe_str(_ffi.string(ares_cname.name))
    )


def parse_addrinfo(ares_addrinfo) -> AddrInfoResult:
    """Parse c-ares addrinfo structure into AddrInfoResult

    The c-ares structure is always released with ``ares_freeaddrinfo``
    before returning, including when parsing one of its entries raises.
    """
    cnames = []
    nodes = []

    try:
        cname_ptr = ares_addrinfo.cnames
        while cname_ptr != _ffi.NULL:
            cnames.append(parse_addrinfo_cname(cname_ptr))
            cname_ptr = cname_ptr.next

        node_ptr = ares_addrinfo.nodes
        while node_ptr != _ffi.NULL:
            nodes.append(parse_addrinfo_node(node_ptr))
            node_ptr = node_ptr.ai_next
    finally:
        _lib.ares_freeaddrinfo(ares_addrinfo)

    return AddrInfoResult(cnames=cnames, nodes=nodes)


__all__ = (
    # Channel flags
    "ARES_FLAG_USEVC",
    "ARES_FLAG_PRIMARY",
    "ARES_FLAG_IGNTC",
    "ARES_FLAG_NORECURSE",
    "ARES_FLAG_STAYOPEN",
    "ARES_FLAG_NOSEARCH",
    "ARES_FLAG_NOALIASES",
    "ARES_FLAG_NOCHECKRESP",
    "ARES_FLAG_EDNS",
    "ARES_FLAG_NO_DFLT_SVR",

    # Nameinfo flag values
    "ARES_NI_NOFQDN",
    "ARES_NI_NUMERICHOST",
    "ARES_NI_NAMEREQD",
    "ARES_NI_NUMERICSERV",
    "ARES_NI_DGRAM",
    "ARES_NI_TCP",
    "ARES_NI_UDP",
    "ARES_NI_SCTP",
    "ARES_NI_DCCP",
    "ARES_NI_NUMERICSCOPE",
    "ARES_NI_LOOKUPHOST",
    "ARES_NI_LOOKUPSERVICE",
    "ARES_NI_IDN",
    "ARES_NI_IDN_ALLOW_UNASSIGNED",
    "ARES_NI_IDN_USE_STD3_ASCII_RULES",

    # Bad socket
    "ARES_SOCKET_BAD",

    # Query types
    "QUERY_TYPE_A",
    "QUERY_TYPE_AAAA",
    "QUERY_TYPE_ANY",
    "QUERY_TYPE_CAA",
    "QUERY_TYPE_CNAME",
    "QUERY_TYPE_HTTPS",
    "QUERY_TYPE_MX",
    "QUERY_TYPE_NAPTR",
    "QUERY_TYPE_NS",
    "QUERY_TYPE_PTR",
    "QUERY_TYPE_SOA",
    "QUERY_TYPE_SRV",
    "QUERY_TYPE_TLSA",
    "QUERY_TYPE_TXT",
    "QUERY_TYPE_URI",

    # Query classes
    "QUERY_CLASS_IN",
    "QUERY_CLASS_CHAOS",
    "QUERY_CLASS_HS",
    "QUERY_CLASS_NONE",
    "QUERY_CLASS_ANY",

    # Core stuff
    "ARES_VERSION",
    "AresError",
    "Channel",
    "errno",
    "__version__",

    # DNS record result types
    "DNSResult",
    "DNSRecord",
    "ARecordData",
    "AAAARecordData",
    "MXRecordData",
    "TXTRecordData",
    "CAARecordData",
    "CNAMERecordData",
    "HTTPSRecordData",
    "NAPTRRecordData",
    "NSRecordData",
    "PTRRecordData",
    "SOARecordData",
    "SRVRecordData",
    "TLSARecordData",
    "URIRecordData",

    # Host/AddrInfo result types
    "HostResult",
    "NameInfoResult",
    "AddrInfoResult",
    "AddrInfoNode",
    "AddrInfoCname",
)