Skip to content

Commit

Permalink
Fix segfault in State::serialize method
Browse files Browse the repository at this point in the history
The method gets invoked from public API function Conn::ConnectionState
but the cipherSuite pointer member might not have been initialized yet.
Invoking ConnectionState too early causes a segfault.
Issue is fixed by changing the return type of Conn::ConnectionState from
State to (State, bool) and returning (State{}, false) if the cipherSuite
has not been set.
  • Loading branch information
Danielius1922 authored and Sean-Der committed Aug 6, 2024
1 parent 5a72b12 commit f3e8a9e
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 29 deletions.
8 changes: 6 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,14 @@ func (c *Conn) Close() error {

// ConnectionState returns basic DTLS details about the connection.
// Note that this replaced the `Export` function of v1.
func (c *Conn) ConnectionState() State {
func (c *Conn) ConnectionState() (State, bool) {
c.lock.RLock()
defer c.lock.RUnlock()
return *c.state.clone()
stateClone, err := c.state.clone()
if err != nil {
return State{}, false
}
return *stateClone, true
}

// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
Expand Down
115 changes: 101 additions & 14 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,28 +497,40 @@ func TestExportKeyingMaterial(t *testing.T) {
c.setLocalEpoch(0)
c.setRemoteEpoch(0)

state := c.ConnectionState()
state, ok := c.ConnectionState()
if !ok {
t.Fatal("ConnectionState failed")
}
_, err := state.ExportKeyingMaterial(exportLabel, nil, 0)
if !errors.Is(err, errHandshakeInProgress) {
t.Errorf("ExportKeyingMaterial when epoch == 0: expected '%s' actual '%s'", errHandshakeInProgress, err)
}

c.setLocalEpoch(1)
state = c.ConnectionState()
state, ok = c.ConnectionState()
if !ok {
t.Fatal("ConnectionState failed")
}
_, err = state.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0)
if !errors.Is(err, errContextUnsupported) {
t.Errorf("ExportKeyingMaterial with context: expected '%s' actual '%s'", errContextUnsupported, err)
}

for k := range invalidKeyingLabels() {
state = c.ConnectionState()
state, ok = c.ConnectionState()
if !ok {
t.Fatal("ConnectionState failed")
}
_, err = state.ExportKeyingMaterial(k, nil, 0)
if !errors.Is(err, errReservedExportKeyingMaterial) {
t.Errorf("ExportKeyingMaterial reserved label: expected '%s' actual '%s'", errReservedExportKeyingMaterial, err)
}
}

state = c.ConnectionState()
state, ok = c.ConnectionState()
if !ok {
t.Fatal("ConnectionState failed")
}
keyingMaterial, err := state.ExportKeyingMaterial(exportLabel, nil, 10)
if err != nil {
t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err)
Expand All @@ -527,7 +539,10 @@ func TestExportKeyingMaterial(t *testing.T) {
}

c.state.isClient = true
state = c.ConnectionState()
state, ok = c.ConnectionState()
if !ok {
t.Fatal("ConnectionState failed")
}
keyingMaterial, err = state.ExportKeyingMaterial(exportLabel, nil, 10)
if err != nil {
t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err)
Expand Down Expand Up @@ -669,7 +684,11 @@ func TestPSK(t *testing.T) {
t.Fatalf("TestPSK: Server failed(%v)", err)
}

actualPSKIdentityHint := server.ConnectionState().IdentityHint
state, ok := server.ConnectionState()
if !ok {
t.Fatalf("TestPSK: Server ConnectionState failed")
}
actualPSKIdentityHint := state.IdentityHint
if !bytes.Equal(actualPSKIdentityHint, test.ClientIdentity) {
t.Errorf("TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ClientIdentity, actualPSKIdentityHint)
}
Expand Down Expand Up @@ -1194,7 +1213,11 @@ func TestClientCertificate(t *testing.T) {
t.Errorf("Client failed(%v)", res.err)
}

actualClientCert := server.ConnectionState().PeerCertificates
state, ok := server.ConnectionState()
if !ok {
t.Error("Server connection state not available")
}
actualClientCert := state.PeerCertificates
if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert {
if actualClientCert == nil {
t.Errorf("Client did not provide a certificate")
Expand All @@ -1221,7 +1244,11 @@ func TestClientCertificate(t *testing.T) {
}
}

actualServerCert := res.c.ConnectionState().PeerCertificates
clientState, ok := res.c.ConnectionState()
if !ok {
t.Error("Client connection state not available")
}
actualServerCert := clientState.PeerCertificates
if actualServerCert == nil {
t.Errorf("Server did not provide a certificate")
}
Expand Down Expand Up @@ -2889,8 +2916,12 @@ func TestSessionResume(t *testing.T) {
t.Fatalf("TestSessionResume: Server failed(%v)", err)
}

actualSessionID := server.ConnectionState().SessionID
actualMasterSecret := server.ConnectionState().masterSecret
state, ok := server.ConnectionState()
if !ok {
t.Fatal("TestSessionResume: ConnectionState failed")
}
actualSessionID := state.SessionID
actualMasterSecret := state.masterSecret
if !bytes.Equal(actualSessionID, id) {
t.Errorf("TestSessionResumetion: SessionID Mismatch: expected(%v) actual(%v)", id, actualSessionID)
}
Expand Down Expand Up @@ -2940,8 +2971,12 @@ func TestSessionResume(t *testing.T) {
t.Fatalf("TestSessionResumetion: Server failed(%v)", err)
}

actualSessionID := server.ConnectionState().SessionID
actualMasterSecret := server.ConnectionState().masterSecret
state, ok := server.ConnectionState()
if !ok {
t.Fatal("TestSessionResumetion: ConnectionState failed")
}
actualSessionID := state.SessionID
actualMasterSecret := state.masterSecret
ss, _ := s2.Get(actualSessionID)
if !bytes.Equal(actualMasterSecret, ss.Secret) {
t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret)
Expand Down Expand Up @@ -3071,8 +3106,8 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) {
t.Fatal(err)
} else if err := c.Close(); err != nil {
t.Fatal(err)
} else if c.ConnectionState().cipherSuite.ID() != test.expectedCipher {
t.Fatalf("Expected(%s) and Actual(%s) CipherSuite do not match", test.expectedCipher, c.ConnectionState().cipherSuite.ID())
} else if state, ok := c.ConnectionState(); !ok || state.cipherSuite.ID() != test.expectedCipher {
t.Fatalf("Expected(%s) and Actual(%s) CipherSuite do not match", test.expectedCipher, state.cipherSuite.ID())
}
})
}
Expand Down Expand Up @@ -3527,3 +3562,55 @@ func TestFragmentBuffer_Retransmission(t *testing.T) {
t.Fatal("fragment should be retransmission")
}
}

func TestConnectionState(t *testing.T) {
ca, cb := dpipe.Pipe()

// Setup client
clientCfg := &Config{}
clientCert, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Fatal(err)
}
clientCfg.Certificates = []tls.Certificate{clientCert}
clientCfg.InsecureSkipVerify = true
client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), clientCfg)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = client.Close()
}()

_, ok := client.ConnectionState()
if ok {
t.Fatal("ConnectionState should be nil")
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
c := make(chan error)
go func() {
errC := client.HandshakeContext(ctx)
c <- errC
}()

// Setup server
server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = server.Close()
}()

err = <-c
if err != nil {
t.Fatal(err)
}

_, ok = client.ConnectionState()
if !ok {
t.Fatal("ConnectionState should not be nil")
}
}
12 changes: 10 additions & 2 deletions flight4handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,11 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh

if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous {
if cfg.verifyConnection != nil {
if err := cfg.verifyConnection(state.clone()); err != nil {
stateClone, err := state.clone()
if err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
if err := cfg.verifyConnection(stateClone); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
}
}
Expand All @@ -210,7 +214,11 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh
// go to flight6
}
if cfg.verifyConnection != nil {
if err := cfg.verifyConnection(state.clone()); err != nil {
stateClone, err := state.clone()
if err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
if err := cfg.verifyConnection(stateClone); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
}
}
Expand Down
8 changes: 6 additions & 2 deletions flight5handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,12 @@ func initializeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeCo
}
}
if cfg.verifyConnection != nil {
if err = cfg.verifyConnection(state.clone()); err != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
stateClone, errC := state.clone()
if errC != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errC
}
if errC = cfg.verifyConnection(stateClone); errC != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errC
}
}

Expand Down
10 changes: 8 additions & 2 deletions resume_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ import (
"github.com/pion/transport/v3/test"
)

var errMessageMissmatch = errors.New("messages missmatch")
var (
errMessageMissmatch = errors.New("messages missmatch")
errInvalidConnectionState = errors.New("failed to get connection state")
)

func TestResumeClient(t *testing.T) {
DoTestResume(t, Client, Server)
Expand Down Expand Up @@ -120,7 +123,10 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.PacketConn, net.Add
}

// Serialize and deserialize state
state := local.ConnectionState()
state, ok := local.ConnectionState()
if !ok {
fatal(t, errChan, errInvalidConnectionState)
}
var b []byte
b, err = state.MarshalBinary()
if err != nil {
Expand Down
28 changes: 21 additions & 7 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package dtls
import (
"bytes"
"encoding/gob"
"errors"
"sync/atomic"

"github.com/pion/dtls/v3/pkg/crypto/elliptic"
Expand Down Expand Up @@ -87,15 +88,25 @@ type serializedState struct {
NegotiatedProtocol string
}

func (s *State) clone() *State {
serialized := s.serialize()
var errCipherSuiteNotSet = &InternalError{Err: errors.New("cipher suite not set")} //nolint:goerr113

func (s *State) clone() (*State, error) {
serialized, err := s.serialize()
if err != nil {
return nil, err
}
state := &State{}
state.deserialize(*serialized)

return state
return state, err
}

func (s *State) serialize() *serializedState {
func (s *State) serialize() (*serializedState, error) {
if s.cipherSuite == nil {
return nil, errCipherSuiteNotSet
}
cipherSuiteID := uint16(s.cipherSuite.ID())

// Marshal random values
localRnd := s.localRandom.MarshalFixed()
remoteRnd := s.remoteRandom.MarshalFixed()
Expand All @@ -104,7 +115,7 @@ func (s *State) serialize() *serializedState {
return &serializedState{
LocalEpoch: s.getLocalEpoch(),
RemoteEpoch: s.getRemoteEpoch(),
CipherSuiteID: uint16(s.cipherSuite.ID()),
CipherSuiteID: cipherSuiteID,
MasterSecret: s.masterSecret,
SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]),
LocalRandom: localRnd,
Expand All @@ -117,7 +128,7 @@ func (s *State) serialize() *serializedState {
RemoteConnectionID: s.remoteConnectionID,
IsClient: s.isClient,
NegotiatedProtocol: s.NegotiatedProtocol,
}
}, nil
}

func (s *State) deserialize(serialized serializedState) {
Expand Down Expand Up @@ -187,7 +198,10 @@ func (s *State) initCipherSuite() error {

// MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation
func (s *State) MarshalBinary() ([]byte, error) {
serialized := s.serialize()
serialized, err := s.serialize()
if err != nil {
return nil, err
}

var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
Expand Down

0 comments on commit f3e8a9e

Please sign in to comment.