forked from folbricht/routedns
-
Notifications
You must be signed in to change notification settings - Fork 0
/
route.go
136 lines (125 loc) · 2.58 KB
/
route.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package rdns
import (
"errors"
"fmt"
"net"
"regexp"
"strings"
"github.com/miekg/dns"
)
// route pairs a set of query-matching criteria with a resolver. Criteria left
// at their zero value (no types, class 0, nil source) are not checked by match.
type route struct {
	types    []uint16       // query types to accept; empty accepts any type
	class    uint16         // query class to accept; 0 accepts any class
	name     *regexp.Regexp // pattern matched (unanchored unless the pattern anchors itself) against the question name
	source   *net.IPNet     // client network to accept; nil accepts any client
	inverted bool           // invert the matching behavior
	resolver Resolver       // resolver selected by this route; NewRoute rejects nil
}
// NewRoute builds a route from its textual configuration.
//
// name is a regular expression matched against the question name, class is a
// DNS class mnemonic, types is a list of DNS type mnemonics (empty class/types
// mean "any"), and source is an optional CIDR restricting the client address.
// It returns an error if resolver is nil or if any parameter fails to parse;
// parameters are validated in the order resolver, types, class, name, source.
func NewRoute(name, class string, types []string, source string, resolver Resolver) (*route, error) {
	if resolver == nil {
		return nil, errors.New("no resolver defined for route")
	}
	qtypes, err := stringToType(types)
	if err != nil {
		return nil, err
	}
	qclass, err := stringToClass(class)
	if err != nil {
		return nil, err
	}
	pattern, err := regexp.Compile(name)
	if err != nil {
		return nil, err
	}
	r := &route{
		types:    qtypes,
		class:    qclass,
		name:     pattern,
		resolver: resolver,
	}
	// Without a source restriction the route applies to every client.
	if source == "" {
		return r, nil
	}
	_, network, err := net.ParseCIDR(source)
	if err != nil {
		return nil, err
	}
	r.source = network
	return r, nil
}
// match reports whether query q, sent by the client described by ci, satisfies
// this route. Only the first question is examined. It matches when its type is
// accepted, its class equals the configured class, its name matches the
// pattern, and the client address lies in the configured source network.
// Criteria left at their zero value are skipped. For an inverted route the
// result is negated.
func (r *route) match(q *dns.Msg, ci ClientInfo) bool {
	// A message without a question section (e.g. a malformed query) cannot
	// satisfy the criteria; indexing Question[0] unguarded would panic.
	if len(q.Question) == 0 {
		return r.inverted
	}
	question := q.Question[0]
	if !r.matchType(question.Qtype) {
		return r.inverted
	}
	if r.class != 0 && r.class != question.Qclass {
		return r.inverted
	}
	if !r.name.MatchString(question.Name) {
		return r.inverted
	}
	if r.source != nil && !r.source.Contains(ci.SourceIP) {
		return r.inverted
	}
	return !r.inverted
}
// Invert controls whether match negates its result. With value true, the route
// accepts exactly the queries it would otherwise reject; false restores the
// normal behavior.
func (r *route) Invert(value bool) {
	r.inverted = value
}
// String renders the route as "PATTERN->RESOLVER", or as "default->RESOLVER"
// when isDefault reports that the route has no name/class/type restriction.
func (r *route) String() string {
	if !r.isDefault() {
		return fmt.Sprintf("%s->%s", r.name, r.resolver)
	}
	return fmt.Sprintf("default->%s", r.resolver)
}
// isDefault reports whether the route was configured without a class, type or
// name restriction.
//
// NOTE(review): the source network and the inverted flag are not consulted, so
// a route restricted only by source still counts as default — confirm intended.
func (r *route) isDefault() bool {
	if r.class != 0 || len(r.types) != 0 {
		return false
	}
	return r.name.String() == ""
}
// matchType reports whether typ is one of the route's configured query types.
// A route with no configured types accepts every type.
func (r *route) matchType(typ uint16) bool {
	found := len(r.types) == 0
	for i := 0; !found && i < len(r.types); i++ {
		found = r.types[i] == typ
	}
	return found
}
// stringToType converts DNS type mnemonics into their numeric values, for
// example "A" -> 1. Matching is case-insensitive.
//
// An empty input yields a nil slice, which matchType treats as "any type".
// If an entry is not a known type, an error naming that entry is returned.
func stringToType(s []string) ([]uint16, error) {
	if len(s) == 0 {
		return nil, nil
	}
	types := make([]uint16, 0, len(s))
	for _, typ := range s {
		// dns.StringToType is the reverse of dns.TypeToString, so a direct
		// lookup replaces a linear scan over every known type.
		t, ok := dns.StringToType[strings.ToUpper(typ)]
		if !ok {
			// Report the offending entry rather than the whole list.
			return nil, fmt.Errorf("unknown type '%s'", typ)
		}
		types = append(types, t)
	}
	return types, nil
}
// stringToClass converts a DNS class mnemonic into its numeric form, for
// example "INET" -> 1. Matching is case-insensitive.
//
// Both short and long spellings are accepted for the classic classes:
// IN/INET, CH/CHAOS and HS/HESIOD. An empty string yields 0, which match
// treats as "any class". Any other value returns an error.
func stringToClass(s string) (uint16, error) {
	switch strings.ToUpper(s) {
	case "":
		return 0, nil
	case "IN", "INET":
		return 1, nil // RFC 1035 class IN
	case "CH", "CHAOS":
		return 3, nil // RFC 1035 class CH
	case "HS", "HESIOD":
		return 4, nil // RFC 1035 class HS
	case "NONE":
		return 254, nil // RFC 2136 class NONE
	case "ANY":
		return 255, nil // RFC 1035 QCLASS *
	default:
		return 0, fmt.Errorf("unknown class '%s'", s)
	}
}