Skip to content

Commit

Permalink
[tls] Carry TLS state within (possibly) response writer (#728)
Browse files Browse the repository at this point in the history
* [tls] Carry TLS state within (possibly) response writer

This allows a server to make decision wether or not the link used to
connect to the DNS server is using TLS.
This can be used by the handler for instance to (but not limited to):
- log that the request was TLS vs TCP
- craft specific responsed knowing that the link is secured
- return custom answers based on client cert (if provided)
...

Fixes #711

* Address @tmthrgd comments:
- do not check whether w.tcp is nil
- create RR after setting txt value

* Address @miekg comments.

Attempt to make a TLS connection state specific test, it goes over
testing each individual server types (TLS, TCP, UDP) and validate that
tls.Connectionstate is only accessible when expected.

* ConnectionState() returns value instead of pointer

* * make ConnectionStater.ConnectionState() return a pointer again
* rename interface ConnectionState to ConnectionStater
* fix nits pointed by @tmthrgd

* @tmthrgd comment: Do not use concret type in `ConnectionState`
  • Loading branch information
chantra authored and miekg committed Sep 22, 2018
1 parent 426ea78 commit 833bf76
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
18 changes: 18 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ type ResponseWriter interface {
Hijack()
}

// A ConnectionStater interface is used by a DNS Handler to access TLS connection state
// when available.
type ConnectionStater interface {
ConnectionState() *tls.ConnectionState
}

type response struct {
msg []byte
hijacked bool // connection has been hijacked by handler
Expand Down Expand Up @@ -894,3 +900,15 @@ func (w *response) Close() error {
}
return nil
}

// ConnectionState() implements the ConnectionStater.ConnectionState() interface.
func (w *response) ConnectionState() *tls.ConnectionState {
type tlsConnectionStater interface {
ConnectionState() tls.ConnectionState
}
if v, ok := w.tcp.(tlsConnectionStater); ok {
t := v.ConnectionState()
return &t
}
return nil
}
92 changes: 92 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,98 @@ func TestServingTLS(t *testing.T) {
}
}

// TestServingTLSConnectionState tests that we only can access
// tls.ConnectionState under a DNS query handled by a TLS DNS server.
// This test will sequentially create a TLS, UDP and TCP server, attach a custom
// handler which will set a testing error if tls.ConnectionState is available
// when it is not expected, or the other way around.
func TestServingTLSConnectionState(t *testing.T) {
handlerResponse := "Hello example"
// tlsHandlerTLS is a HandlerFunc that can be set to expect or not TLS
// connection state.
tlsHandlerTLS := func(tlsExpected bool) func(ResponseWriter, *Msg) {
return func(w ResponseWriter, req *Msg) {
m := new(Msg)
m.SetReply(req)
tlsFound := true
if connState := w.(ConnectionStater).ConnectionState(); connState == nil {
tlsFound = false
}
if tlsFound != tlsExpected {
t.Errorf("TLS connection state available: %t, expected: %t", tlsFound, tlsExpected)
}
m.Extra = make([]RR, 1)
m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{handlerResponse}}
w.WriteMsg(m)
}
}

// Question used in tests
m := new(Msg)
m.SetQuestion("tlsstate.example.net.", TypeTXT)

// TLS DNS server
HandleFunc(".", tlsHandlerTLS(true))
cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
if err != nil {
t.Fatalf("unable to build certificate: %v", err)
}

config := tls.Config{
Certificates: []tls.Certificate{cert},
}

s, addrstr, err := RunLocalTLSServer(":0", &config)
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()

// TLS DNS query
c := &Client{
Net: "tcp-tls",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
}

_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Error("failed to exchange tlsstate.example.net", err)
}

HandleRemove(".")
// UDP DNS Server
HandleFunc(".", tlsHandlerTLS(false))
defer HandleRemove(".")
s, addrstr, err = RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()

// UDP DNS query
c = new(Client)
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Error("failed to exchange tlsstate.example.net", err)
}

// TCP DNS Server
s, addrstr, err = RunLocalTCPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()

// TCP DNS query
c = &Client{Net: "tcp"}
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Error("failed to exchange tlsstate.example.net", err)
}
}

func TestServingListenAndServe(t *testing.T) {
HandleFunc("example.com.", AnotherHelloServer)
defer HandleRemove("example.com.")
Expand Down

0 comments on commit 833bf76

Please sign in to comment.