From ee255e89568cdca89d7abcd76c497610cb76d7d7 Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Wed, 7 Jul 2021 15:25:57 +0200 Subject: [PATCH] Avoid crash after a PC callback has been reset We used to crash if a PC callback was reset, due to confusion between a nil interface and an interface whose value is nil. Fixes #1871 --- peerconnection.go | 20 ++++++++++++-------- peerconnection_go_test.go | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/peerconnection.go b/peerconnection.go index e669bf8c77b..c2b7d9b9fc3 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -296,7 +296,7 @@ func (pc *PeerConnection) onNegotiationNeeded() { func (pc *PeerConnection) negotiationNeededOp() { // Don't run NegotiatedNeeded checks if OnNegotiationNeeded is not set - if handler := pc.onNegotiationNeededHandler.Load(); handler == nil { + if handler, ok := pc.onNegotiationNeededHandler.Load().(func()); !ok || handler == nil { return } @@ -464,8 +464,8 @@ func (pc *PeerConnection) OnICEConnectionStateChange(f func(ICEConnectionState)) func (pc *PeerConnection) onICEConnectionStateChange(cs ICEConnectionState) { pc.iceConnectionState.Store(cs) pc.log.Infof("ICE connection state changed: %s", cs) - if handler := pc.onICEConnectionStateChangeHandler.Load(); handler != nil { - handler.(func(ICEConnectionState))(cs) + if handler, ok := pc.onICEConnectionStateChangeHandler.Load().(func(ICEConnectionState)); ok && handler != nil { + handler(cs) } } @@ -475,6 +475,14 @@ func (pc *PeerConnection) OnConnectionStateChange(f func(PeerConnectionState)) { pc.onConnectionStateChangeHandler.Store(f) } +func (pc *PeerConnection) onConnectionStateChange(cs PeerConnectionState) { + pc.connectionState.Store(cs) + pc.log.Infof("peer connection state changed: %s", cs) + if handler, ok := pc.onConnectionStateChangeHandler.Load().(func(PeerConnectionState)); ok && handler != nil { + go handler(cs) + } +} + // SetConfiguration updates the configuration of this PeerConnection object. func (pc *PeerConnection) SetConfiguration(configuration Configuration) error { //nolint:gocognit // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-setconfiguration (step #2) @@ -736,11 +744,7 @@ func (pc *PeerConnection) updateConnectionState(iceConnectionState ICEConnection return } - pc.log.Infof("peer connection state changed: %s", connectionState) - pc.connectionState.Store(connectionState) - if handler := pc.onConnectionStateChangeHandler.Load(); handler != nil { - go handler.(func(PeerConnectionState))(connectionState) - } + pc.onConnectionStateChange(connectionState) } func (pc *PeerConnection) createICETransport() *ICETransport { diff --git a/peerconnection_go_test.go b/peerconnection_go_test.go index d1cd63c0fbb..6da636f3ecb 100644 --- a/peerconnection_go_test.go +++ b/peerconnection_go_test.go @@ -1397,3 +1397,40 @@ func TestPeerConnection_SessionID(t *testing.T) { } closePairNow(t, pcOffer, pcAnswer) } + +func TestPeerConnectionNilCallback(t *testing.T) { + pc, err := NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + pc.onSignalingStateChange(SignalingStateStable) + pc.OnSignalingStateChange(func(ss SignalingState) { + t.Error("OnSignalingStateChange called") + }) + pc.OnSignalingStateChange(nil) + pc.onSignalingStateChange(SignalingStateStable) + + pc.onConnectionStateChange(PeerConnectionStateNew) + pc.OnConnectionStateChange(func(pcs PeerConnectionState) { + t.Error("OnConnectionStateChange called") + }) + pc.OnConnectionStateChange(nil) + pc.onConnectionStateChange(PeerConnectionStateNew) + + pc.onICEConnectionStateChange(ICEConnectionStateNew) + pc.OnICEConnectionStateChange(func(ics ICEConnectionState) { + t.Error("OnConnectionStateChange called") + }) + pc.OnICEConnectionStateChange(nil) + pc.onICEConnectionStateChange(ICEConnectionStateNew) + + pc.onNegotiationNeeded() + pc.negotiationNeededOp() + pc.OnNegotiationNeeded(func() { + t.Error("OnNegotiationNeeded called") + }) + pc.OnNegotiationNeeded(nil) + pc.onNegotiationNeeded() + pc.negotiationNeededOp() + + assert.NoError(t, pc.Close()) +}