Skip to content

Commit

Permalink
core: add support for ares_search
Browse files Browse the repository at this point in the history
It's like ares_query but it takes the "domain" and "search" directives
in resolv.conf into accout.

Closes: #67
  • Loading branch information
saghul committed Feb 14, 2019
1 parent 53e3971 commit 44a892a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 4 deletions.
10 changes: 10 additions & 0 deletions docs/channel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@
.. note::
TTL is not implemented for CNAME and NS), so it's set to -1.

.. py:method:: search(name, query_type, callback)
:param string name: Name to query.

:param int query_type: Type of query to perform.

:param callable callback: Callback to be called with the result of the query.

Tis function does the same as :py:meth:`query` but it will honor the ``domain`` and ``search`` directives in
``resolv.conf``.

.. py:method:: cancel()
Expand Down
16 changes: 12 additions & 4 deletions src/pycares/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ def parse_result(query_type, abuf, alen):


class Channel:
__qtypes__ = (_lib.T_A, _lib.T_AAAA, _lib.T_ANY, _lib.T_CNAME, _lib.T_MX, _lib.T_NAPTR, _lib.T_NS, _lib.T_PTR, _lib.T_SOA, _lib.T_SRV, _lib.T_TXT)

def __init__(self,
flags = None,
timeout = None,
Expand Down Expand Up @@ -500,15 +502,21 @@ def gethostbyname(self, name, family, callback):
_lib.ares_gethostbyname(self._channel[0], parse_name(name), family, _host_cb, userdata)

def query(self, name, query_type, callback):
self._do_query(_lib.ares_query, name, query_type, callback)

def search(self, name, query_type, callback):
self._do_query(_lib.ares_search, name, query_type, callback)

def _do_query(self, func, name, query_type, callback):
if not callable(callback):
raise TypeError("a callable is required")
raise TypeError('a callable is required')

if query_type not in (_lib.T_A, _lib.T_AAAA, _lib.T_ANY, _lib.T_CNAME, _lib.T_MX, _lib.T_NAPTR, _lib.T_NS, _lib.T_PTR, _lib.T_SOA, _lib.T_SRV, _lib.T_TXT):
raise ValueError("invalid query type specified")
if query_type not in self.__qtypes__:
raise ValueError('invalid query type specified')

userdata = _ffi.new_handle((callback, query_type))
_global_set.add(userdata)
_lib.ares_query(self._channel[0], parse_name(name), _lib.C_IN, query_type, _query_cb, userdata)
func(self._channel[0], parse_name(name), _lib.C_IN, query_type, _query_cb, userdata)

def set_local_ip(self, ip):
addr4 = _ffi.new("struct in_addr*")
Expand Down
12 changes: 12 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,18 @@ def test_errorcode_dict(self):
val = getattr(pycares.errno, err)
self.assertEqual(pycares.errno.errorcode[val], err)

def test_search(self):
self.result, self.errorno = None, None
def cb(result, errorno):
self.result, self.errorno = result, errorno
self.channel = pycares.Channel(timeout=5.0, tries=1, domains=['google.com'])
self.channel.search('cloud', pycares.QUERY_TYPE_A, cb)
self.wait()
self.assertNoError(self.errorno)
for r in self.result:
self.assertEqual(type(r), pycares.ares_query_a_result)
self.assertNotEqual(r.host, None)


if __name__ == '__main__':
unittest.main(verbosity=2)
Expand Down

0 comments on commit 44a892a

Please sign in to comment.