Skip to content

Commit

Permalink
dnsforward: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
Mizzick committed Apr 24, 2024
1 parent dcd5e41 commit 50888df
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 21 deletions.
2 changes: 2 additions & 0 deletions internal/dnsforward/beforerequest.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ var _ proxy.BeforeRequestHandler = (*Server)(nil)
// HandleBefore is the handler that is called before any other processing,
// including logs. It performs access checks and puts the client ID, if there
// is one, into the server's cache.
//
// TODO(d.kolyshev): Extract to separate package.
func (s *Server) HandleBefore(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
Expand Down
36 changes: 15 additions & 21 deletions internal/dnsforward/beforerequest_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@ import (
"github.com/stretchr/testify/require"
)

func TestServer_HandleBefore(t *testing.T) {
const (
blockedHost = "blockedhost.org"
testFQDN = "example.org."
dnsClientTimeout = 200 * time.Millisecond
)

func TestServer_HandleBefore_tls(t *testing.T) {
t.Parallel()

const (
blockedHost = "blockedhost.org"
clientID = "client-1"
testFQDN = "example.org."
)
const clientID = "client-1"

testCases := []struct {
want assert.ValueAssertionFunc
clientSrvName string
name string
host string
Expand All @@ -33,7 +34,6 @@ func TestServer_HandleBefore(t *testing.T) {
blockedHosts []string
wantRCode int
}{{
want: assert.NotEmpty,
clientSrvName: tlsServerName,
name: "allow_all",
host: testFQDN,
Expand All @@ -42,7 +42,6 @@ func TestServer_HandleBefore(t *testing.T) {
blockedHosts: []string{},
wantRCode: dns.RcodeSuccess,
}, {
want: assert.Empty,
clientSrvName: "%" + "." + tlsServerName,
name: "invalid_client_id",
host: testFQDN,
Expand All @@ -51,7 +50,6 @@ func TestServer_HandleBefore(t *testing.T) {
blockedHosts: []string{},
wantRCode: dns.RcodeServerFailure,
}, {
want: assert.NotEmpty,
clientSrvName: clientID + "." + tlsServerName,
name: "allowed_client_allowed",
host: testFQDN,
Expand All @@ -60,7 +58,6 @@ func TestServer_HandleBefore(t *testing.T) {
blockedHosts: []string{},
wantRCode: dns.RcodeSuccess,
}, {
want: assert.Empty,
clientSrvName: "client-2." + tlsServerName,
name: "allowed_client_rejected",
host: testFQDN,
Expand All @@ -69,7 +66,6 @@ func TestServer_HandleBefore(t *testing.T) {
blockedHosts: []string{},
wantRCode: dns.RcodeRefused,
}, {
want: assert.NotEmpty,
clientSrvName: tlsServerName,
name: "disallowed_client_allowed",
host: testFQDN,
Expand All @@ -78,7 +74,6 @@ func TestServer_HandleBefore(t *testing.T) {
blockedHosts: []string{},
wantRCode: dns.RcodeSuccess,
}, {
want: assert.Empty,
clientSrvName: clientID + "." + tlsServerName,
name: "disallowed_client_rejected",
host: testFQDN,
Expand All @@ -87,7 +82,6 @@ func TestServer_HandleBefore(t *testing.T) {
blockedHosts: []string{},
wantRCode: dns.RcodeRefused,
}, {
want: assert.NotEmpty,
clientSrvName: tlsServerName,
name: "blocked_hosts_allowed",
host: testFQDN,
Expand All @@ -96,7 +90,6 @@ func TestServer_HandleBefore(t *testing.T) {
blockedHosts: []string{blockedHost},
wantRCode: dns.RcodeSuccess,
}, {
want: assert.Empty,
clientSrvName: tlsServerName,
name: "blocked_hosts_rejected",
host: dns.Fqdn(blockedHost),
Expand Down Expand Up @@ -143,7 +136,7 @@ func TestServer_HandleBefore(t *testing.T) {
client := &dns.Client{
Net: "tcp-tls",
TLSConfig: tlsConfig,
Timeout: 200 * time.Millisecond,
Timeout: dnsClientTimeout,
}

req := createTestMessage(tc.host)
Expand All @@ -152,8 +145,12 @@ func TestServer_HandleBefore(t *testing.T) {
reply, _, err := client.Exchange(req, addr)
require.NoError(t, err)

tc.want(t, reply.Answer)
assert.Equal(t, tc.wantRCode, reply.Rcode)
if tc.wantRCode == dns.RcodeSuccess {
assert.NotEmpty(t, reply.Answer)
} else {
assert.Empty(t, reply.Answer)
}
})
}
}
Expand All @@ -164,9 +161,6 @@ func TestServer_HandleBefore_udp(t *testing.T) {
const (
clientIPv4 = "127.0.0.1"
clientIPv6 = "::1"

blockedHost = "blockedhost.org"
testFQDN = "example.org."
)

clientIPs := []string{clientIPv4, clientIPv6}
Expand Down Expand Up @@ -266,7 +260,7 @@ func TestServer_HandleBefore_udp(t *testing.T) {

client := &dns.Client{
Net: "udp",
Timeout: 200 * time.Millisecond,
Timeout: dnsClientTimeout,
}

req := createTestMessage(tc.host)
Expand Down

0 comments on commit 50888df

Please sign in to comment.