Skip to content

Commit

Permalink
Merge pull request #3514 from lucas-clemente/closed-session
Browse files Browse the repository at this point in the history
use a single Go routine to send copies of CONNECTION_CLOSE packets
  • Loading branch information
marten-seemann authored Aug 22, 2022
2 parents 57650fc + 263f728 commit 509616c
Show file tree
Hide file tree
Showing 11 changed files with 269 additions and 219 deletions.
86 changes: 19 additions & 67 deletions closed_conn.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package quic

import (
"sync"
"math/bits"
"net"

"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
Expand All @@ -11,87 +12,38 @@ import (
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
// with an exponential backoff.
type closedLocalConn struct {
conn sendConn
connClosePacket []byte

closeOnce sync.Once
closeChan chan struct{} // is closed when the connection is closed or destroyed

receivedPackets chan *receivedPacket
counter uint64 // number of packets received

counter uint32
perspective protocol.Perspective
logger utils.Logger

logger utils.Logger
sendPacket func(net.Addr, *packetInfo)
}

var _ packetHandler = &closedLocalConn{}

// newClosedLocalConn creates a new closedLocalConn and runs it.
func newClosedLocalConn(
conn sendConn,
connClosePacket []byte,
perspective protocol.Perspective,
logger utils.Logger,
) packetHandler {
s := &closedLocalConn{
conn: conn,
connClosePacket: connClosePacket,
perspective: perspective,
logger: logger,
closeChan: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, 64),
}
go s.run()
return s
}

func (s *closedLocalConn) run() {
for {
select {
case p := <-s.receivedPackets:
s.handlePacketImpl(p)
case <-s.closeChan:
return
}
func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
return &closedLocalConn{
sendPacket: sendPacket,
perspective: pers,
logger: logger,
}
}

func (s *closedLocalConn) handlePacket(p *receivedPacket) {
select {
case s.receivedPackets <- p:
default:
}
}

func (s *closedLocalConn) handlePacketImpl(_ *receivedPacket) {
s.counter++
func (c *closedLocalConn) handlePacket(p *receivedPacket) {
c.counter++
// exponential backoff
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
for n := s.counter; n > 1; n = n / 2 {
if n%2 != 0 {
return
}
}
s.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", s.counter)
if err := s.conn.Write(s.connClosePacket); err != nil {
s.logger.Debugf("Error retransmitting CONNECTION_CLOSE: %s", err)
if bits.OnesCount32(c.counter) != 1 {
return
}
c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", c.counter)
c.sendPacket(p.remoteAddr, p.info)
}

func (s *closedLocalConn) shutdown() {
s.destroy(nil)
}

func (s *closedLocalConn) destroy(error) {
s.closeOnce.Do(func() {
close(s.closeChan)
})
}

func (s *closedLocalConn) getPerspective() protocol.Perspective {
return s.perspective
}
func (c *closedLocalConn) shutdown() {}
func (c *closedLocalConn) destroy(error) {}
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }

// A closedRemoteConn is a connection that was closed remotely.
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
Expand Down
42 changes: 12 additions & 30 deletions closed_conn_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package quic

import (
"errors"
"time"
"net"

"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"

Expand All @@ -13,44 +11,28 @@ import (
)

var _ = Describe("Closed local connection", func() {
var (
conn packetHandler
mconn *MockSendConn
)

BeforeEach(func() {
mconn = NewMockSendConn(mockCtrl)
conn = newClosedLocalConn(mconn, []byte("close"), protocol.PerspectiveClient, utils.DefaultLogger)
})

AfterEach(func() {
Eventually(areClosedConnsRunning).Should(BeFalse())
})

It("tells its perspective", func() {
conn := newClosedLocalConn(nil, protocol.PerspectiveClient, utils.DefaultLogger)
Expect(conn.getPerspective()).To(Equal(protocol.PerspectiveClient))
// stop the connection
conn.shutdown()
})

It("repeats the packet containing the CONNECTION_CLOSE frame", func() {
written := make(chan []byte)
mconn.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }).AnyTimes()
written := make(chan net.Addr, 1)
conn := newClosedLocalConn(
func(addr net.Addr, _ *packetInfo) { written <- addr },
protocol.PerspectiveClient,
utils.DefaultLogger,
)
addr := &net.UDPAddr{IP: net.IPv4(127, 1, 2, 3), Port: 1337}
for i := 1; i <= 20; i++ {
conn.handlePacket(&receivedPacket{})
conn.handlePacket(&receivedPacket{remoteAddr: addr})
if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 {
Eventually(written).Should(Receive(Equal([]byte("close")))) // receive the CONNECTION_CLOSE
Expect(written).To(Receive(Equal(addr))) // receive the CONNECTION_CLOSE
} else {
Consistently(written, 10*time.Millisecond).Should(HaveLen(0))
Expect(written).ToNot(Receive())
}
}
// stop the connection
conn.shutdown()
})

It("destroys connections", func() {
Eventually(areClosedConnsRunning).Should(BeTrue())
conn.destroy(errors.New("destroy"))
Eventually(areClosedConnsRunning).Should(BeFalse())
})
})
12 changes: 7 additions & 5 deletions conn_id_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type connIDGenerator struct {
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func(protocol.ConnectionID, packetHandler)
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte)
queueControlFrame func(wire.Frame)

version protocol.VersionNumber
Expand All @@ -33,7 +33,7 @@ func newConnIDGenerator(
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func(protocol.ConnectionID, packetHandler),
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
queueControlFrame func(wire.Frame),
version protocol.VersionNumber,
) *connIDGenerator {
Expand Down Expand Up @@ -130,11 +130,13 @@ func (m *connIDGenerator) RemoveAll() {
}
}

func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) {
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
if m.initialClientDestConnID != nil {
m.replaceWithClosed(m.initialClientDestConnID, handler)
connIDs = append(connIDs, m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
m.replaceWithClosed(connID, handler)
connIDs = append(connIDs, connID)
}
m.replaceWithClosed(connIDs, pers, connClose)
}
17 changes: 9 additions & 8 deletions conn_id_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var _ = Describe("Connection ID Generator", func() {
addedConnIDs []protocol.ConnectionID
retiredConnIDs []protocol.ConnectionID
removedConnIDs []protocol.ConnectionID
replacedWithClosed map[string]packetHandler
replacedWithClosed []protocol.ConnectionID
queuedFrames []wire.Frame
g *connIDGenerator
)
Expand All @@ -32,15 +32,17 @@ var _ = Describe("Connection ID Generator", func() {
retiredConnIDs = nil
removedConnIDs = nil
queuedFrames = nil
replacedWithClosed = make(map[string]packetHandler)
replacedWithClosed = nil
g = newConnIDGenerator(
initialConnID,
initialClientDestConnID,
func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) },
connIDToToken,
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h },
func(cs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
replacedWithClosed = append(replacedWithClosed, cs...)
},
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
protocol.VersionDraft29,
)
Expand Down Expand Up @@ -174,14 +176,13 @@ var _ = Describe("Connection ID Generator", func() {
It("replaces with a closed connection for all connection IDs", func() {
Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
Expect(queuedFrames).To(HaveLen(4))
sess := NewMockPacketHandler(mockCtrl)
g.ReplaceWithClosed(sess)
g.ReplaceWithClosed(protocol.PerspectiveClient, []byte("foobar"))
Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones
Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialClientDestConnID), sess))
Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialConnID), sess))
Expect(replacedWithClosed).To(ContainElement(initialClientDestConnID))
Expect(replacedWithClosed).To(ContainElement(initialConnID))
for _, f := range queuedFrames {
nf := f.(*wire.NewConnectionIDFrame)
Expect(replacedWithClosed).To(HaveKeyWithValue(string(nf.ConnectionID), sess))
Expect(replacedWithClosed).To(ContainElement(nf.ConnectionID))
}
})
})
7 changes: 3 additions & 4 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ type connRunner interface {
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
Retire(protocol.ConnectionID)
Remove(protocol.ConnectionID)
ReplaceWithClosed(protocol.ConnectionID, packetHandler)
ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte)
AddResetToken(protocol.StatelessResetToken, packetHandler)
RemoveResetToken(protocol.StatelessResetToken)
}
Expand Down Expand Up @@ -1521,7 +1521,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {

// If this is a remote close we're done here
if closeErr.remote {
s.connIDGenerator.ReplaceWithClosed(newClosedRemoteConn(s.perspective))
s.connIDGenerator.ReplaceWithClosed(s.perspective, nil)
return
}
if closeErr.immediate {
Expand All @@ -1538,8 +1538,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
if err != nil {
s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err)
}
cs := newClosedLocalConn(s.conn, connClosePacket, s.perspective, s.logger)
s.connIDGenerator.ReplaceWithClosed(cs)
s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket)
}

func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
Expand Down
42 changes: 12 additions & 30 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@ func areConnsRunning() bool {
return strings.Contains(b.String(), "quic-go.(*connection).run")
}

func areClosedConnsRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*closedLocalConn).run")
}

var _ = Describe("Connection", func() {
var (
conn *connection
Expand Down Expand Up @@ -72,11 +66,11 @@ var _ = Describe("Connection", func() {
}

expectReplaceWithClosed := func() {
connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).MaxTimes(1)
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
s.shutdown()
Eventually(areClosedConnsRunning).Should(BeFalse())
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
Expect(connIDs).To(ContainElement(srcConnID))
if len(connIDs) > 1 {
Expect(connIDs).To(ContainElement(clientDestConnID))
}
})
}

Expand Down Expand Up @@ -330,11 +324,8 @@ var _ = Describe("Connection", func() {
ErrorMessage: "foobar",
}
streamManager.EXPECT().CloseWithError(expectedErr)
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
})
connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
})
cryptoSetup.EXPECT().Close()
gomock.InOrder(
Expand All @@ -361,11 +352,8 @@ var _ = Describe("Connection", func() {
ErrorMessage: "foobar",
}
streamManager.EXPECT().CloseWithError(testErr)
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
})
connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
})
cryptoSetup.EXPECT().Close()
gomock.InOrder(
Expand Down Expand Up @@ -565,7 +553,7 @@ var _ = Describe("Connection", func() {
runConn()
cryptoSetup.EXPECT().Close()
streamManager.EXPECT().CloseWithError(gomock.Any())
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
buf := &bytes.Buffer{}
hdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
Expand Down Expand Up @@ -2433,10 +2421,7 @@ var _ = Describe("Client Connection", func() {
}

expectReplaceWithClosed := func() {
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
s.shutdown()
Eventually(areClosedConnsRunning).Should(BeFalse())
})
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any())
}

BeforeEach(func() {
Expand Down Expand Up @@ -2767,10 +2752,7 @@ var _ = Describe("Client Connection", func() {

expectClose := func(applicationClose bool) {
if !closed {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
s.shutdown()
})
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any())
if applicationClose {
packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1)
} else {
Expand Down
Loading

0 comments on commit 509616c

Please sign in to comment.