Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow more resolver options #19

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name: CI

on:
workflow_dispatch:
push:
branches:
- main
Expand Down
257 changes: 184 additions & 73 deletions dnserver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from datetime import datetime
from pathlib import Path
from textwrap import wrap
from typing import Any, List
from threading import Lock
from typing import Any, Dict, Generic, Iterable, List, Sequence, Tuple, TypeVar, overload

from dnslib import QTYPE, RR, DNSLabel, dns
from dnslib import QTYPE, RR, DNSLabel, DNSRecord, dns
from dnslib.proxy import ProxyResolver as LibProxyResolver
from dnslib.server import BaseResolver as LibBaseResolver, DNSServer as LibDNSServer
from dnslib.server import BaseResolver as LibBaseResolver, DNSHandler, DNSServer as LibDNSServer

from .load_records import Records, Zone, load_records

Expand Down Expand Up @@ -85,74 +86,205 @@ def __str__(self):
return str(self.rr)


def resolve(request, handler, records):
records = [Record(zone) for zone in records.zones]
type_name = QTYPE[request.q.qtype]
reply = request.reply()
for record in records:
if record.match(request.q):
reply.add_answer(record.rr)
T = TypeVar('T')

if reply.rr:
logger.info('found zone for %s[%s], %d replies', request.q.qname, type_name, len(reply.rr))
return reply

# no direct zone so look for an SOA record for a higher level zone
for record in records:
if record.sub_match(request.q):
reply.add_answer(record.rr)
class SharedObject(Generic[T]):
def __init__(self, obj: T, lock: Lock = None) -> None:
self._obj = obj
self.lock = lock or Lock()

if reply.rr:
logger.info('found higher level SOA resource for %s[%s]', request.q.qname, type_name)
return reply
def __enter__(self):
self.lock.acquire()
return self._obj

def __exit__(self, exc_type, exc_value, traceback):
self.lock.release()

class BaseResolver(LibBaseResolver):
def __init__(self, records: Records):
self.records = records
super().__init__()
def set(self, obj: T):
with self:
self._obj = obj

def resolve(self, request, handler):
answer = resolve(request, handler, self.records)
if answer:
return answer

class RecordsResolver(LibBaseResolver):
def __init__(self, records: SharedObject[Records]):
self._records = records

def records(self):
with self._records as records:
return [Record(zone) for zone in records.zones]

def resolve(self, request: DNSRecord, handler: DNSHandler):
records = self.records()
type_name = QTYPE[request.q.qtype]
logger.info('no local zone found, not proxying %s[%s]', request.q.qname, type_name)
reply = request.reply()
for record in records:
if record.match(request.q):
reply.add_answer(record.rr)

if reply.rr:
logger.info('found zone for %s[%s], %d replies', request.q.qname, type_name, len(reply.rr))
return reply

# no direct zone so look for an SOA record for a higher level zone
for record in records:
if record.sub_match(request.q):
reply.add_answer(record.rr)

if reply.rr:
logger.info('found higher level SOA resource for %s[%s]', request.q.qname, type_name)
return reply

logger.info('no local zone found %s[%s]', request.q.qname, type_name)
return request.reply()


class ProxyResolver(LibProxyResolver):
def __init__(self, records: Records, upstream: str):
self.records = records
super().__init__(address=upstream, port=53, timeout=5)

def resolve(self, request, handler):
answer = resolve(request, handler, self.records)
if answer:
return answer
def __init__(self, upstream: str, port=DEFAULT_PORT, timeout=5):
super().__init__(address=upstream, port=int(port or DEFAULT_PORT), timeout=int(timeout or 5))

def resolve(self, request: DNSRecord, handler: DNSHandler):
type_name = QTYPE[request.q.qtype]
logger.info('no local zone found, proxying %s[%s]', request.q.qname, type_name)
logger.info('proxying %s[%s]', request.q.qname, type_name)
return super().resolve(request, handler)


class DNSServer:
R = TypeVar('R', bound=LibBaseResolver)


class RoundRobinResolver(LibBaseResolver, Generic[R]):
def __init__(self, resolvers: Iterable[R]):
self.resolvers = tuple(resolvers)

def resolve(self, request: DNSRecord, handler: DNSHandler):
answer = request.reply()
resolver: LibBaseResolver
for resolver in self.resolvers:
answer: DNSRecord = resolver.resolve(request, handler)
if answer.header.rcode == 0 and answer.rr:
return answer
return answer


Port = Tuple[int, bool]


def _ports(obj):
if isinstance(obj, Sequence):
if len(obj) == 2 and isinstance(obj[1], (bool, type(None))):
return (obj[0], obj[1])
return None
return (obj, None)


class BaseDNSServer(Generic[R]):
resolver: R

@overload
def __new__(self, resolver: R, port: 'int | Port | Iterable[int | Port] | None' = None) -> BaseDNSServer[R]:
...

@overload
def __new__(
self, resolver: str, port: 'int | Port | Iterable[int | Port] | None' = None
) -> BaseDNSServer[RoundRobinResolver | ProxyResolver]:
...

@overload
def __new__(
self,
resolver: 'Records | SharedObject[Records] | None' = None,
port: 'int | Port | Iterable[int | Port] | None' = None,
) -> BaseDNSServer[RecordsResolver]:
...

def __new__(cls, *args, **kwargs):
return super().__new__(cls)

def __init__(
self,
records: Records | None = None,
port: int | str | None = DEFAULT_PORT,
upstream: str | None = DEFAULT_UPSTREAM,
resolver: 'R | Records | SharedObject[Records] | str | None' = None,
port: 'int | Port | Iterable[int | Port] | None' = None,
):
self.port: int = DEFAULT_PORT if port is None else int(port)
self.upstream: str | None = upstream
self.udp_server: LibDNSServer | None = None
self.tcp_server: LibDNSServer | None = None
self.records: Records = records if records else Records(zones=[])
ports: List[Port] = DEFAULT_PORT if port is None else port
_port = _ports(ports)
if _port is not None:
ports = [_port]
self.servers: Dict[Port, 'LibDNSServer | None'] = {}
for port in ports:
port, tcp = _ports(port)
port = int(port or DEFAULT_PORT)
if tcp is None or tcp is False:
self.servers[(port, False)] = None
if tcp is None or tcp is True:
self.servers[(port, True)] = None

self.resolver = resolver or Records(zones=[])
if isinstance(self.resolver, Records):
self.resolver = SharedObject(self.resolver)
if isinstance(self.resolver, SharedObject):
self.resolver = RecordsResolver(self.resolver)
if isinstance(self.resolver, str):
resolvers = [ProxyResolver(*upstream.split(':')) for upstream in resolver.split(',')]
if len(resolvers) > 1:
self.resolver = RoundRobinResolver(resolvers)
else:
self.resolver = resolvers[0]

if not isinstance(self.resolver, LibBaseResolver):
raise ValueError(self.resolver)

def start(self):
for port, tcp in self.servers:
logger.info('starting DNS server on port %d protocol: %s', port, 'tcp' if tcp else 'udp')
server = LibDNSServer(self.resolver, port=port, tcp=tcp)
server.start_thread()
self.servers[(port, tcp)] = server

def stop(self):
for server in self.servers.values():
server.stop()
server.server.server_close()

@property
def is_running(self):
for server in self.servers.values():
if server.isAlive():
return True
return False

@property
def port(self):
return next(self.servers.keys().__iter__())[0]


class DNSServer(BaseDNSServer['RoundRobinResolver[RecordsResolver | ProxyResolver] | RecordsResolver']):
def __new__(cls, *args, **kwargs) -> 'DNSServer':
return super().__new__(cls)

def __init__(
self,
records: 'Records | SharedObject[Records] | None' = None,
port: 'int | Port | Iterable[int | Port] | None' = DEFAULT_PORT,
upstream: 'str | None' = DEFAULT_UPSTREAM,
):
super().__init__(records, port)
self.records: SharedObject[Records] = self.resolver._records
if upstream:
logger.info('upstream DNS server "%s"', upstream)
self.resolver = RoundRobinResolver(
[self.resolver, *[ProxyResolver(*upstream.split(':')) for upstream in upstream.split(',')]]
)
else:
logger.info('without upstream DNS server')

@classmethod
def from_toml(
cls, zones_file: str | Path, *, port: int | str | None = DEFAULT_PORT, upstream: str | None = DEFAULT_UPSTREAM
cls,
zones_file: 'str | Path',
*,
port: 'int | str | None' = DEFAULT_PORT,
upstream: 'str | None' = DEFAULT_UPSTREAM,
) -> 'DNSServer':
records = load_records(zones_file)
logger.info(
Expand All @@ -163,31 +295,10 @@ def from_toml(
)
return DNSServer(records, port=port, upstream=upstream)

def start(self):
if self.upstream:
logger.info('starting DNS server on port %d, upstream DNS server "%s"', self.port, self.upstream)
resolver = ProxyResolver(self.records, self.upstream)
else:
logger.info('starting DNS server on port %d, without upstream DNS server', self.port)
resolver = BaseResolver(self.records)

self.udp_server = LibDNSServer(resolver, port=self.port)
self.tcp_server = LibDNSServer(resolver, port=self.port, tcp=True)
self.udp_server.start_thread()
self.tcp_server.start_thread()

def stop(self):
self.udp_server.stop()
self.udp_server.server.server_close()
self.tcp_server.stop()
self.tcp_server.server.server_close()

@property
def is_running(self):
return (self.udp_server and self.udp_server.isAlive()) or (self.tcp_server and self.tcp_server.isAlive())

def add_record(self, zone: Zone):
self.records.zones.append(zone)
with self.records as records:
records.zones.append(zone)

def set_records(self, zones: List[Zone]):
self.records.zones = zones
with self.records as records:
records.zones = zones
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Changelog = 'https://github.com/samuelcolvin/dnserver/releases'

[tool.pytest.ini_options]
testpaths = 'tests'
filterwarnings = ['error']
filterwarnings = ['error','ignore::DeprecationWarning']
timeout = 10

[tool.coverage.run]
Expand Down