Skip to content

Commit

Permalink
Merge pull request #1819 from pfreixes/master
Browse files Browse the repository at this point in the history
Add new ttl option to expire DNS entries after N seconds
  • Loading branch information
fafhrd91 authored Apr 19, 2017
2 parents 2248325 + 25f51ee commit ff77836
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 5 deletions.
18 changes: 16 additions & 2 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .helpers import SimpleCookie, is_ip_address, noop, sentinel
from .resolver import DefaultResolver


__all__ = ('BaseConnector', 'TCPConnector', 'UnixConnector')

HASHFUNC_BY_DIGESTLEN = {
Expand Down Expand Up @@ -495,6 +496,7 @@ class TCPConnector(BaseConnector):
resolver - Enable DNS lookups and use this
resolver
use_dns_cache - Use memory cache for DNS lookups.
ttl_dns_cache - Max seconds having cached a DNS entry, None forever.
family - socket address family
local_addr - local tuple of (host, port) to bind socket to
Expand All @@ -507,7 +509,7 @@ class TCPConnector(BaseConnector):
"""

def __init__(self, *, verify_ssl=True, fingerprint=None,
resolve=sentinel, use_dns_cache=True,
resolve=sentinel, use_dns_cache=True, ttl_dns_cache=10,
family=0, ssl_context=None, local_addr=None,
resolver=None, keepalive_timeout=sentinel,
force_close=False, limit=100, limit_per_host=0,
Expand Down Expand Up @@ -543,7 +545,9 @@ def __init__(self, *, verify_ssl=True, fingerprint=None,
self._resolver = resolver

self._use_dns_cache = use_dns_cache
self._ttl_dns_cache = ttl_dns_cache
self._cached_hosts = {}
self._cached_hosts_timestamp = {}
self._ssl_context = ssl_context
self._family = family
self._local_addr = local_addr
Expand Down Expand Up @@ -595,11 +599,20 @@ def clear_dns_cache(self, host=None, port=None):
"""Remove specified host/port or clear all dns local cache."""
if host is not None and port is not None:
self._cached_hosts.pop((host, port), None)
self._cached_hosts_timestamp.pop((host, port), None)
elif host is not None or port is not None:
raise ValueError("either both host and port "
"or none of them are allowed")
else:
self._cached_hosts.clear()
self._cached_hosts_timestamp.clear()

def _dns_entry_expired(self, key):
if self._ttl_dns_cache is None:
return False
return (
self._cached_hosts_timestamp[key] + self._ttl_dns_cache
) < self._loop.time()

@asyncio.coroutine
def _resolve_host(self, host, port):
Expand All @@ -610,9 +623,10 @@ def _resolve_host(self, host, port):
if self._use_dns_cache:
key = (host, port)

if key not in self._cached_hosts:
if key not in self._cached_hosts or (self._dns_entry_expired(key)):
self._cached_hosts[key] = yield from \
self._resolver.resolve(host, port, family=self._family)
self._cached_hosts_timestamp[key] = self._loop.time()

return self._cached_hosts[key]
else:
Expand Down
11 changes: 11 additions & 0 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,7 @@ TCPConnector

.. class:: TCPConnector(*, verify_ssl=True, fingerprint=None,\
use_dns_cache=True, \
ttl_dns_cache=10, \
family=0, ssl_context=None, conn_timeout=None, \
keepalive_timeout=30, limit=None, \
force_close=False, loop=None, local_addr=None, \
Expand Down Expand Up @@ -730,6 +731,16 @@ TCPConnector

The default is changed to ``True``

:param int ttl_dns_cache: expire after some seconds the DNS entries, ``None``
means cached forever. By default 10 seconds.

By default DNS entries are cached forever, in some environments the IP
addresses related to a specific HOST can change after a specific time. Use
this option to keep the DNS cache updated refreshing each entry after N
seconds.

.. versionadded:: 2.0.8

:param aiohttp.abc.AbstractResolver resolver: Custom resolver
instance to use. ``aiohttp.DefaultResolver`` by
default (asynchronous if ``aiodns>=1.1`` is installed).
Expand Down
38 changes: 35 additions & 3 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def test_release_close(loop):


@asyncio.coroutine
def test_tcp_connector_resolve_host_use_dns_cache(loop):
def test_tcp_connector_resolve_host(loop):
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True)

res = yield from conn._resolve_host('localhost', 8080)
Expand All @@ -365,15 +365,42 @@ def test_tcp_connector_resolve_host_use_dns_cache(loop):


@asyncio.coroutine
def test_tcp_connector_resolve_host_twice_use_dns_cache(loop):
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True)
def test_tcp_connector_dns_cache_not_expired(loop):
conn = aiohttp.TCPConnector(
loop=loop,
use_dns_cache=True,
ttl_dns_cache=10
)

res = yield from conn._resolve_host('localhost', 8080)
res2 = yield from conn._resolve_host('localhost', 8080)

assert res is res2


@asyncio.coroutine
def test_tcp_connector_dns_cache_forever(loop):
conn = aiohttp.TCPConnector(
loop=loop,
use_dns_cache=True,
ttl_dns_cache=None
)

res = yield from conn._resolve_host('localhost', 8080)
res2 = yield from conn._resolve_host('localhost', 8080)
assert res is res2


@asyncio.coroutine
def test_tcp_connector_use_dns_cache_disabled(loop):
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False)

res = yield from conn._resolve_host('localhost', 8080)
res2 = yield from conn._resolve_host('localhost', 8080)

assert res is not res2


def test_get_pop_empty_conns(loop):
# see issue #473
conn = aiohttp.BaseConnector(loop=loop)
Expand Down Expand Up @@ -606,13 +633,18 @@ def test_tcp_connector_clear_dns_cache(loop):
conn = aiohttp.TCPConnector(loop=loop)
info = object()
conn._cached_hosts[('localhost', 123)] = info
conn._cached_hosts_timestamp[('localhost', 123)] = 100
conn._cached_hosts[('localhost', 124)] = info
conn._cached_hosts_timestamp[('localhost', 124)] = 101
conn.clear_dns_cache('localhost', 123)
assert conn.cached_hosts == {('localhost', 124): info}
assert conn._cached_hosts_timestamp == {('localhost', 124): 101}
conn.clear_dns_cache('localhost', 123)
assert conn.cached_hosts == {('localhost', 124): info}
assert conn._cached_hosts_timestamp == {('localhost', 124): 101}
conn.clear_dns_cache()
assert conn.cached_hosts == {}
assert conn._cached_hosts_timestamp == {}


def test_tcp_connector_clear_dns_cache_bad_args(loop):
Expand Down

0 comments on commit ff77836

Please sign in to comment.