diff --git a/server.go b/server.go index 4fbf7db6f5..2901f8724c 100644 --- a/server.go +++ b/server.go @@ -63,6 +63,12 @@ type ResponseWriter interface { Hijack() } +// A ConnectionStater interface is used by a DNS Handler to access TLS connection state +// when available. +type ConnectionStater interface { + ConnectionState() *tls.ConnectionState +} + type response struct { msg []byte hijacked bool // connection has been hijacked by handler @@ -894,3 +900,15 @@ func (w *response) Close() error { } return nil } + +// ConnectionState() implements the ConnectionStater.ConnectionState() interface. +func (w *response) ConnectionState() *tls.ConnectionState { + type tlsConnectionStater interface { + ConnectionState() tls.ConnectionState + } + if v, ok := w.tcp.(tlsConnectionStater); ok { + t := v.ConnectionState() + return &t + } + return nil +} diff --git a/server_test.go b/server_test.go index c53edd0383..d60f0cd010 100644 --- a/server_test.go +++ b/server_test.go @@ -263,6 +263,98 @@ func TestServingTLS(t *testing.T) { } } +// TestServingTLSConnectionState tests that we only can access +// tls.ConnectionState under a DNS query handled by a TLS DNS server. +// This test will sequentially create a TLS, UDP and TCP server, attach a custom +// handler which will set a testing error if tls.ConnectionState is available +// when it is not expected, or the other way around. +func TestServingTLSConnectionState(t *testing.T) { + handlerResponse := "Hello example" + // tlsHandlerTLS is a HandlerFunc that can be set to expect or not TLS + // connection state. + tlsHandlerTLS := func(tlsExpected bool) func(ResponseWriter, *Msg) { + return func(w ResponseWriter, req *Msg) { + m := new(Msg) + m.SetReply(req) + tlsFound := true + if connState := w.(ConnectionStater).ConnectionState(); connState == nil { + tlsFound = false + } + if tlsFound != tlsExpected { + t.Errorf("TLS connection state available: %t, expected: %t", tlsFound, tlsExpected) + } + m.Extra = make([]RR, 1) + m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{handlerResponse}} + w.WriteMsg(m) + } + } + + // Question used in tests + m := new(Msg) + m.SetQuestion("tlsstate.example.net.", TypeTXT) + + // TLS DNS server + HandleFunc(".", tlsHandlerTLS(true)) + cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) + if err != nil { + t.Fatalf("unable to build certificate: %v", err) + } + + config := tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + s, addrstr, err := RunLocalTLSServer(":0", &config) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + // TLS DNS query + c := &Client{ + Net: "tcp-tls", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + + _, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Error("failed to exchange tlsstate.example.net", err) + } + + HandleRemove(".") + // UDP DNS Server + HandleFunc(".", tlsHandlerTLS(false)) + defer HandleRemove(".") + s, addrstr, err = RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + // UDP DNS query + c = new(Client) + _, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Error("failed to exchange tlsstate.example.net", err) + } + + // TCP DNS Server + s, addrstr, err = RunLocalTCPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + // TCP DNS query + c = &Client{Net: "tcp"} + _, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Error("failed to exchange tlsstate.example.net", err) + } +} + func TestServingListenAndServe(t *testing.T) { HandleFunc("example.com.", AnotherHelloServer) defer HandleRemove("example.com.")