Skip to content

Commit

Permalink
Option to listen multiple ports
Browse files Browse the repository at this point in the history
  • Loading branch information
jose-pr committed Oct 21, 2023
1 parent 22e5a4f commit 0784f9f
Showing 1 changed file with 51 additions and 17 deletions.
68 changes: 51 additions & 17 deletions dnserver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from datetime import datetime
from pathlib import Path
from textwrap import wrap
from typing import Any, List, Generic, TypeVar, overload, TypeVarTuple
from types import NoneType
from typing import Any, List, Generic, TypeVar, overload, TypeVarTuple, Iterable, TypeAlias, Sequence
from threading import Lock

from dnslib import QTYPE, RR, DNSLabel, dns, DNSRecord
Expand Down Expand Up @@ -167,29 +168,59 @@ def resolve(self, request: DNSRecord, handler: DNSHandler):
return answer


Port: TypeAlias = tuple[int, bool]


def _ports(obj):
if isinstance(obj, Sequence):
if len(obj) == 2 and isinstance(obj[1], (bool, NoneType)):
return (obj[0], obj[1])
return None
return (obj, None)


class BaseDNSServer(Generic[R]):
resolver: R

@overload
def __new__(self, resolver: R, port: int | None = None) -> BaseDNSServer[R]:
def __new__(self, resolver: R, port: int | Port | Iterable[int | Port] | None = None) -> BaseDNSServer[R]:
...

@overload
def __new__(self, resolver: str, port: int | None = None) -> BaseDNSServer[RoundRobinResolver | ProxyResolver]:
def __new__(
self, resolver: str, port: int | Port | Iterable[int | Port] | None = None
) -> BaseDNSServer[RoundRobinResolver | ProxyResolver]:
...

@overload
def __new__(
self, resolver: Records | SharedObject[Records] | None = None, port: int | None = None
self,
resolver: Records | SharedObject[Records] | None = None,
port: int | Port | Iterable[int | Port] | None = None,
) -> BaseDNSServer[RecordsResolver]:
...

def __new__(cls, *args, **kwargs):
return super().__new__(cls)

def __init__(self, resolver: R | Records | SharedObject[Records] | str | None = None, port: int | None = None):
self.port: int = DEFAULT_PORT if port is None else int(port)
self.servers: list[LibDNSServer] = []
def __init__(
self,
resolver: R | Records | SharedObject[Records] | str | None = None,
port: int | Port | Iterable[int | Port] | None = None,
):
ports: list[Port] = DEFAULT_PORT if port is None else int(port)
_port = _ports(ports)
if _port is not None:
ports = [_port]
self.servers: dict[Port, LibDNSServer | None] = {}
for port in ports:
port, tcp = port
port = int(port or DEFAULT_PORT)
if tcp is None or tcp is False:
self.servers[(port, False)] = None
if tcp is None or tcp is True:
self.servers[(port, True)] = None

self.resolver = resolver or Records(zones=[])
if isinstance(self.resolver, Records):
self.resolver = SharedObject(self.resolver)
Expand All @@ -206,25 +237,28 @@ def __init__(self, resolver: R | Records | SharedObject[Records] | str | None =
raise ValueError(self.resolver)

def start(self):
for port, tcp in [(self.port, False), (self.port, True)]:
logger.info('starting DNS server on port %d protocol: %s"', port, 'tcp' if tcp else 'udp')
server = LibDNSServer(self.resolver, port=self.port)
for port, tcp in self.servers:
logger.info('starting DNS server on port %d protocol: %s', port, 'tcp' if tcp else 'udp')
server = LibDNSServer(self.resolver, port=port, tcp=tcp)
server.start_thread()
self.servers.append(server)
self.servers[(port, tcp)] = server

def stop(self):
for server in self.servers:
for server in self.servers.values():
server.stop()
server.server.server_close()
self.servers = []

@property
def is_running(self):
for server in self.servers:
for server in self.servers.values():
if server.isAlive():
return True
return False

@property
def port(self):
return next(self.servers.keys().__iter__())[0]


class DNSServer(BaseDNSServer[RoundRobinResolver[RecordsResolver, ProxyResolver] | RecordsResolver]):
def __new__(cls, *args, **kwargs) -> 'DNSServer':
Expand All @@ -233,18 +267,18 @@ def __new__(cls, *args, **kwargs) -> 'DNSServer':
def __init__(
self,
records: Records | SharedObject[Records] | None = None,
port: int | str | None = DEFAULT_PORT,
port: int | Port | Iterable[int | Port] | None = DEFAULT_PORT,
upstream: str | None = DEFAULT_UPSTREAM,
):
super().__init__(records, port)
self.records: SharedObject[Records] = self.resolver._records
if upstream:
logger.info('starting DNS server on port %d, upstream DNS server "%s"', self.port, upstream)
logger.info('upstream DNS server "%s"', upstream)
self.resolver = RoundRobinResolver(
[self.resolver, *[ProxyResolver(*upstream.split(":")) for upstream in upstream.split(',')]]
)
else:
logger.info('starting DNS server on port %d, without upstream DNS server', self.port)
logger.info('without upstream DNS server')

@classmethod
def from_toml(
Expand Down

0 comments on commit 0784f9f

Please sign in to comment.