Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #3016 from silkeh/improve-service-lookups
Browse files Browse the repository at this point in the history
Improve handling of SRV records for federation connections
  • Loading branch information
richvdh authored Apr 9, 2018
2 parents aea3a93 + 72251d1 commit 664adb4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 122 deletions.
103 changes: 8 additions & 95 deletions synapse/http/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket

from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor
from twisted.internet.error import ConnectError
Expand All @@ -33,7 +31,7 @@

# our record of an individual server which can be tried to reach a destination.
#
# "host" is actually a dotted-quad or ipv6 address string. Except when there's
# "host" is the hostname acquired from the SRV record. Except when there's
# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
Expand Down Expand Up @@ -297,20 +295,13 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t

payload = answer.payload

hosts = yield _get_hosts_for_srv_record(
dns_client, str(payload.target)
)

for (ip, ttl) in hosts:
host_ttl = min(answer.ttl, ttl)

servers.append(_Server(
host=ip,
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight),
expires=int(clock.time()) + host_ttl,
))
servers.append(_Server(
host=str(payload.target),
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight),
expires=int(clock.time()) + answer.ttl,
))

servers.sort()
cache[service_name] = list(servers)
Expand All @@ -328,81 +319,3 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
raise e

defer.returnValue(servers)


@defer.inlineCallbacks
def _get_hosts_for_srv_record(dns_client, host):
"""Look up each of the hosts in a SRV record
Args:
dns_client (twisted.names.dns.IResolver):
host (basestring): host to look up
Returns:
Deferred[list[(str, int)]]: a list of (host, ttl) pairs
"""
ip4_servers = []
ip6_servers = []

def cb(res):
# lookupAddress and lookupIP6Address return a three-tuple
# giving the answer, authority, and additional sections of the
# response.
#
# we only care about the answers.

return res[0]

def eb(res, record_type):
if res.check(DNSNameError):
return []
logger.warn("Error looking up %s for %s: %s", record_type, host, res)
return res

# no logcontexts here, so we can safely fire these off and gatherResults
d1 = dns_client.lookupAddress(host).addCallbacks(
cb, eb, errbackArgs=("A", ))
d2 = dns_client.lookupIPV6Address(host).addCallbacks(
cb, eb, errbackArgs=("AAAA", ))
results = yield defer.DeferredList(
[d1, d2], consumeErrors=True)

# if all of the lookups failed, raise an exception rather than blowing out
# the cache with an empty result.
if results and all(s == defer.FAILURE for (s, _) in results):
defer.returnValue(results[0][1])

for (success, result) in results:
if success == defer.FAILURE:
continue

for answer in result:
if not answer.payload:
continue

try:
if answer.type == dns.A:
ip = answer.payload.dottedQuad()
ip4_servers.append((ip, answer.ttl))
elif answer.type == dns.AAAA:
ip = socket.inet_ntop(
socket.AF_INET6, answer.payload.address,
)
ip6_servers.append((ip, answer.ttl))
else:
# the most likely candidate here is a CNAME record.
# rfc2782 says srvs may not point to aliases.
logger.warn(
"Ignoring unexpected DNS record type %s for %s",
answer.type, host,
)
continue
except Exception as e:
logger.warn("Ignoring invalid DNS response for %s: %s",
host, e)
continue

# keep the ipv4 results before the ipv6 results, mostly to match historical
# behaviour.
defer.returnValue(ip4_servers + ip6_servers)
29 changes: 2 additions & 27 deletions tests/test_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ def test_resolve(self):

service_name = "test_service.example.com"
host_name = "example.com"
ip_address = "127.0.0.1"
ip6_address = "::1"

answer_srv = dns.RRHeader(
type=dns.SRV,
Expand All @@ -43,29 +41,9 @@ def test_resolve(self):
)
)

answer_a = dns.RRHeader(
type=dns.A,
payload=dns.Record_A(
address=ip_address,
)
)

answer_aaaa = dns.RRHeader(
type=dns.AAAA,
payload=dns.Record_AAAA(
address=ip6_address,
)
)

dns_client_mock.lookupService.return_value = defer.succeed(
([answer_srv], None, None),
)
dns_client_mock.lookupAddress.return_value = defer.succeed(
([answer_a], None, None),
)
dns_client_mock.lookupIPV6Address.return_value = defer.succeed(
([answer_aaaa], None, None),
)

cache = {}

Expand All @@ -74,13 +52,10 @@ def test_resolve(self):
)

dns_client_mock.lookupService.assert_called_once_with(service_name)
dns_client_mock.lookupAddress.assert_called_once_with(host_name)
dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name)

self.assertEquals(len(servers), 2)
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, ip_address)
self.assertEquals(servers[1].host, ip6_address)
self.assertEquals(servers[0].host, host_name)

@defer.inlineCallbacks
def test_from_cache_expired_and_dns_fail(self):
Expand Down

0 comments on commit 664adb4

Please sign in to comment.