diff options
author | Laurent Ghigonis <laurent@p1sec.com> | 2013-04-20 13:30:05 +0200 |
---|---|---|
committer | Laurent Ghigonis <laurent@p1sec.com> | 2013-04-20 13:30:05 +0200 |
commit | 2f4a252454906325e49cfa4fe2b39e3e00f0621c (patch) | |
tree | 7dfbd336488a80ec1ce1093abbdf5b76915a50e2 /toys/brhute-py/threaded_resolver.py | |
parent | brhute: move to dedicated dir (diff) | |
download | laurent-tools-2f4a252454906325e49cfa4fe2b39e3e00f0621c.tar.xz laurent-tools-2f4a252454906325e49cfa4fe2b39e3e00f0621c.zip |
brhute: move to brhute-py
Diffstat (limited to 'toys/brhute-py/threaded_resolver.py')
-rw-r--r-- | toys/brhute-py/threaded_resolver.py | 302 |
1 files changed, 302 insertions, 0 deletions
diff --git a/toys/brhute-py/threaded_resolver.py b/toys/brhute-py/threaded_resolver.py new file mode 100644 index 0000000..a2d347b --- /dev/null +++ b/toys/brhute-py/threaded_resolver.py @@ -0,0 +1,302 @@ +# From pyxmpp2 resolver.py +# Made standalone by laurent +# https://raw.github.com/Jajcus/pyxmpp2/master/pyxmpp2/resolver.py + +# (C) Copyright 2003-2011 Jacek Konieczny <jajcus@jajcus.net> +# (C) Copyright 2013 Laurent Ghigonis <laurent@p1sec.com> +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License Version +# 2.1 as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this program; if not, write to the Free Software +# Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. +# + +"""DNS resolever with SRV record support. + +Normative reference: + - `RFC 1035 <http://www.ietf.org/rfc/rfc1035.txt>`__ + - `RFC 2782 <http://www.ietf.org/rfc/rfc2782.txt>`__ +""" + +import socket +import random +import logging +import threading +import Queue + +import dns.resolver +import dns.name +import dns.exception + +DEFAULT_SETTINGS = {"ipv4": True, "ipv6": False, "prefer_ipv6": False} + +def is_ipv6_available(): + """Check if IPv6 is available. + + :Return: `True` when an IPv6 socket can be created. + """ + try: + socket.socket(socket.AF_INET6).close() + except (socket.error, AttributeError): + return False + return True + +def is_ipv4_available(): + """Check if IPv4 is available. + + :Return: `True` when an IPv4 socket can be created. + """ + try: + socket.socket(socket.AF_INET).close() + except socket.error: + return False + return True + +def shuffle_srv(records): + """Randomly reorder SRV records using their weights. + + :Parameters: + - `records`: SRV records to shuffle. + :Types: + - `records`: sequence of :dns:`dns.rdtypes.IN.SRV` + + :return: reordered records. + :returntype: `list` of :dns:`dns.rdtypes.IN.SRV`""" + if not records: + return [] + ret = [] + while len(records) > 1: + weight_sum = 0 + for rrecord in records: + weight_sum += rrecord.weight + 0.1 + thres = random.random() * weight_sum + weight_sum = 0 + for rrecord in records: + weight_sum += rrecord.weight + 0.1 + if thres < weight_sum: + records.remove(rrecord) + ret.append(rrecord) + break + ret.append(records[0]) + return ret + +def reorder_srv(records): + """Reorder SRV records using their priorities and weights. + + :Parameters: + - `records`: SRV records to shuffle. + :Types: + - `records`: `list` of :dns:`dns.rdtypes.IN.SRV` + + :return: reordered records. + :returntype: `list` of :dns:`dns.rdtypes.IN.SRV`""" + records = list(records) + records.sort() + ret = [] + tmp = [] + for rrecord in records: + if not tmp or rrecord.priority == tmp[0].priority: + tmp.append(rrecord) + continue + ret += shuffle_srv(tmp) + tmp = [rrecord] + if tmp: + ret += shuffle_srv(tmp) + return ret + +class BlockingResolver(): + """Blocking resolver using the DNSPython package. + + Both `resolve_srv` and `resolve_address` will block until the + lookup completes or fail and then call the callback immediately. + """ + def __init__(self, settings = None): + if settings: + self.settings = settings + else: + self.settings = DEFAULT_SETTINGS + + def resolve_srv(self, domain, service, protocol, callback): + """Start looking up an SRV record for `service` at `domain`. + + `callback` will be called with a properly sorted list of (hostname, + port) pairs on success. The list will be empty on error and it will + contain only (".", 0) when the service is explicitely disabled. + + :Parameters: + - `domain`: domain name to look up + - `service`: service name e.g. 'xmpp-client' + - `protocol`: protocol name, e.g. 'tcp' + - `callback`: a function to be called with a list of received + addresses + :Types: + - `domain`: `unicode` + - `service`: `unicode` + - `protocol`: `unicode` + - `callback`: function accepting a single argument + """ + if isinstance(domain, unicode): + domain = domain.encode("idna").decode("us-ascii") + domain = "_{0}._{1}.{2}".format(service, protocol, domain) + try: + records = dns.resolver.query(domain, 'SRV') + except dns.exception.DNSException, err: + logger.warning("Could not resolve {0!r}: {1}" + .format(domain, err.__class__.__name__)) + callback([]) + return + if not records: + callback([]) + return + + result = [] + for record in reorder_srv(records): + hostname = record.target.to_text() + if hostname in (".", ""): + continue + result.append((hostname, record.port)) + + if not result: + callback([(".", 0)]) + else: + callback(result) + return + + def resolve_address(self, hostname, callback, allow_cname = True): + """Start looking up an A or AAAA record. + + `callback` will be called with a list of (family, address) tuples + (each holiding socket.AF_* and IPv4 or IPv6 address literal) on + success. The list will be empty on error. + + :Parameters: + - `hostname`: the host name to look up + - `callback`: a function to be called with a list of received + addresses + - `allow_cname`: `True` if CNAMEs should be followed + :Types: + - `hostname`: `unicode` + - `callback`: function accepting a single argument + - `allow_cname`: `bool` + """ + if isinstance(hostname, unicode): + hostname = hostname.encode("idna").decode("us-ascii") + rtypes = [] + if self.settings["ipv6"]: + rtypes.append(("AAAA", socket.AF_INET6)) + if self.settings["ipv4"]: + rtypes.append(("A", socket.AF_INET)) + if not self.settings["prefer_ipv6"]: + rtypes.reverse() + exception = None + result = [] + for rtype, rfamily in rtypes: + try: + try: + records = dns.resolver.query(hostname, rtype) + except dns.exception.DNSException: + records = dns.resolver.query(hostname + ".", rtype) + except dns.exception.DNSException, err: + exception = err + continue + if not allow_cname and records.rrset.name != dns.name.from_text( + hostname): + logger.warning("Unexpected CNAME record found for {0!r}" + .format(hostname)) + continue + if records: + for record in records: + result.append((rfamily, record.to_text())) + + if not result and exception: + logger.warning("Could not resolve {0!r}: {1}".format(hostname, + exception.__class__.__name__)) + callback(result) + +class ThreadedResolver(): + """Base class for threaded resolvers. + + Starts worker threads, each running a blocking resolver implementation + and communicates with them to provide non-blocking asynchronous API. + """ + def __init__(self, settings = None, max_threads = 1): + if settings: + self.settings = settings + else: + self.settings = DEFAULT_SETTINGS + self.threads = [] + self.queue = Queue.Queue() + self.lock = threading.RLock() + self.max_threads = max_threads + self.last_thread_n = 0 + +def _make_resolver(self): + """Threaded resolver implementation using the DNSPython + :dns:`dns.resolver` module. + """ + return BlockingResolver(self.settings) + + def stop(self): + """Stop the resolver threads. + """ + with self.lock: + for dummy in self.threads: + self.queue.put(None) + + def _start_thread(self): + """Start a new working thread unless the maximum number of threads + has been reached or the request queue is empty. + """ + with self.lock: + if self.threads and self.queue.empty(): + return + if len(self.threads) >= self.max_threads: + return + thread_n = self.last_thread_n + 1 + self.last_thread_n = thread_n + thread = threading.Thread(target = self._run, + name = "{0!r} #{1}".format(self, thread_n), + args = (thread_n,)) + self.threads.append(thread) + thread.daemon = True + thread.start() + + def resolve_address(self, hostname, callback, allow_cname = True): + request = ("resolve_address", (hostname, callback, allow_cname)) + self._start_thread() + self.queue.put(request) + + def resolve_srv(self, domain, service, protocol, callback): + request = ("resolve_srv", (domain, service, protocol, callback)) + self._start_thread() + self.queue.put(request) + + def _run(self, thread_n): + """The thread function.""" + try: + logger.debug("{0!r}: entering thread #{1}" + .format(self, thread_n)) + resolver = self._make_resolver() + while True: + request = self.queue.get() + if request is None: + break + method, args = request + logger.debug(" calling {0!r}.{1}{2!r}" + .format(resolver, method, args)) + getattr(resolver, method)(*args) # pylint: disable=W0142 + self.queue.task_done() + logger.debug("{0!r}: leaving thread #{1}" + .format(self, thread_n)) + finally: + self.threads.remove(threading.currentThread()) + +# vi: sts=4 et sw=4 |