diff --git a/dnserver/main.py b/dnserver/main.py index ae7687c..6f1dbbf 100755 --- a/dnserver/main.py +++ b/dnserver/main.py @@ -4,7 +4,8 @@ from datetime import datetime from pathlib import Path from textwrap import wrap -from typing import Any, List, Generic, TypeVar, overload, TypeVarTuple +from types import NoneType +from typing import Any, List, Generic, TypeVar, overload, TypeVarTuple, Iterable, TypeAlias, Sequence from threading import Lock from dnslib import QTYPE, RR, DNSLabel, dns, DNSRecord @@ -167,29 +168,59 @@ def resolve(self, request: DNSRecord, handler: DNSHandler): return answer +Port: TypeAlias = tuple[int, bool] + + +def _ports(obj): + if isinstance(obj, Sequence): + if len(obj) == 2 and isinstance(obj[1], (bool, NoneType)): + return (obj[0], obj[1]) + return None + return (obj, None) + + class BaseDNSServer(Generic[R]): resolver: R @overload - def __new__(self, resolver: R, port: int | None = None) -> BaseDNSServer[R]: + def __new__(self, resolver: R, port: int | Port | Iterable[int | Port] | None = None) -> BaseDNSServer[R]: ... @overload - def __new__(self, resolver: str, port: int | None = None) -> BaseDNSServer[RoundRobinResolver | ProxyResolver]: + def __new__( + self, resolver: str, port: int | Port | Iterable[int | Port] | None = None + ) -> BaseDNSServer[RoundRobinResolver | ProxyResolver]: ... @overload def __new__( - self, resolver: Records | SharedObject[Records] | None = None, port: int | None = None + self, + resolver: Records | SharedObject[Records] | None = None, + port: int | Port | Iterable[int | Port] | None = None, ) -> BaseDNSServer[RecordsResolver]: ... def __new__(cls, *args, **kwargs): return super().__new__(cls) - def __init__(self, resolver: R | Records | SharedObject[Records] | str | None = None, port: int | None = None): - self.port: int = DEFAULT_PORT if port is None else int(port) - self.servers: list[LibDNSServer] = [] + def __init__( + self, + resolver: R | Records | SharedObject[Records] | str | None = None, + port: int | Port | Iterable[int | Port] | None = None, + ): + ports: list[Port] = DEFAULT_PORT if port is None else int(port) + _port = _ports(ports) + if _port is not None: + ports = [_port] + self.servers: dict[Port, LibDNSServer | None] = {} + for port in ports: + port, tcp = port + port = int(port or DEFAULT_PORT) + if tcp is None or tcp is False: + self.servers[(port, False)] = None + if tcp is None or tcp is True: + self.servers[(port, True)] = None + self.resolver = resolver or Records(zones=[]) if isinstance(self.resolver, Records): self.resolver = SharedObject(self.resolver) @@ -206,25 +237,28 @@ def __init__(self, resolver: R | Records | SharedObject[Records] | str | None = raise ValueError(self.resolver) def start(self): - for port, tcp in [(self.port, False), (self.port, True)]: - logger.info('starting DNS server on port %d protocol: %s"', port, 'tcp' if tcp else 'udp') - server = LibDNSServer(self.resolver, port=self.port) + for port, tcp in self.servers: + logger.info('starting DNS server on port %d protocol: %s', port, 'tcp' if tcp else 'udp') + server = LibDNSServer(self.resolver, port=port, tcp=tcp) server.start_thread() - self.servers.append(server) + self.servers[(port, tcp)] = server def stop(self): - for server in self.servers: + for server in self.servers.values(): server.stop() server.server.server_close() - self.servers = [] @property def is_running(self): - for server in self.servers: + for server in self.servers.values(): if server.isAlive(): return True return False + @property + def port(self): + return next(self.servers.keys().__iter__())[0] + class DNSServer(BaseDNSServer[RoundRobinResolver[RecordsResolver, ProxyResolver] | RecordsResolver]): def __new__(cls, *args, **kwargs) -> 'DNSServer': @@ -233,18 +267,18 @@ def __new__(cls, *args, **kwargs) -> 'DNSServer': def __init__( self, records: Records | SharedObject[Records] | None = None, - port: int | str | None = DEFAULT_PORT, + port: int | Port | Iterable[int | Port] | None = DEFAULT_PORT, upstream: str | None = DEFAULT_UPSTREAM, ): super().__init__(records, port) self.records: SharedObject[Records] = self.resolver._records if upstream: - logger.info('starting DNS server on port %d, upstream DNS server "%s"', self.port, upstream) + logger.info('upstream DNS server "%s"', upstream) self.resolver = RoundRobinResolver( [self.resolver, *[ProxyResolver(*upstream.split(":")) for upstream in upstream.split(',')]] ) else: - logger.info('starting DNS server on port %d, without upstream DNS server', self.port) + logger.info('without upstream DNS server') @classmethod def from_toml(