diff --git a/connection.go b/connection.go index d906bb0..eaeaa2d 100644 --- a/connection.go +++ b/connection.go @@ -23,6 +23,7 @@ import ( "net" "net/http" "sync" + "sync/atomic" "time" "github.com/moby/spdystream/spdy" @@ -40,18 +41,38 @@ const ( QUEUE_SIZE = 50 ) +// atomicBool uses load/store operations on an int32 to simulate an atomic boolean. +type atomicBool struct { + v int32 +} + +// set sets the int32 to the given boolean. +func (a *atomicBool) set(value bool) { + if value { + atomic.StoreInt32(&a.v, 1) + return + } + atomic.StoreInt32(&a.v, 0) +} + +// get returns true if the int32 == 1 +func (a *atomicBool) get() bool { + return atomic.LoadInt32(&a.v) == 1 +} + type StreamHandler func(stream *Stream) type AuthHandler func(header http.Header, slot uint8, parent uint32) bool type idleAwareFramer struct { - f *spdy.Framer - conn *Connection - writeLock sync.Mutex - resetChan chan struct{} - setTimeoutLock sync.Mutex - setTimeoutChan chan time.Duration - timeout time.Duration + f *spdy.Framer + conn *Connection + writeLock sync.Mutex + resetChan chan struct{} + setTimeoutLock sync.Mutex + setTimeoutChan chan time.Duration + timeout time.Duration + ignorePingFrames *atomicBool } func newIdleAwareFramer(framer *spdy.Framer) *idleAwareFramer { @@ -60,7 +81,8 @@ func newIdleAwareFramer(framer *spdy.Framer) *idleAwareFramer { resetChan: make(chan struct{}, 2), // setTimeoutChan needs to be buffered to avoid deadlocks when calling setIdleTimeout at about // the same time the connection is being closed - setTimeoutChan: make(chan time.Duration, 1), + setTimeoutChan: make(chan time.Duration, 1), + ignorePingFrames: &atomicBool{0}, } return iaf } @@ -158,6 +180,13 @@ func (i *idleAwareFramer) WriteFrame(frame spdy.Frame) error { return err } + if i.ignorePingFrames.get() { + _, ok := frame.(*spdy.PingFrame) + if ok { + return nil + } + } + i.resetChan <- struct{}{} return nil @@ -169,16 +198,24 @@ func (i *idleAwareFramer) ReadFrame() (spdy.Frame, error) { return nil, err } + if i.ignorePingFrames.get() { + _, ok := frame.(*spdy.PingFrame) + if ok { + return frame, nil + } + } + // resetChan should never be closed since it is only closed // when the connection has closed its closeChan. This closure // only occurs after all Reads have finished // TODO (dmcgowan): refactor relationship into connection i.resetChan <- struct{}{} - return frame, nil } -func (i *idleAwareFramer) setIdleTimeout(timeout time.Duration) { +func (i *idleAwareFramer) setIdleTimeout(timeout time.Duration, ignorePingFrames bool) { + i.ignorePingFrames.set(ignorePingFrames) + i.setTimeoutLock.Lock() defer i.setTimeoutLock.Unlock() @@ -834,7 +871,13 @@ func (s *Connection) SetCloseTimeout(timeout time.Duration) { // SetIdleTimeout sets the amount of time the connection may sit idle before // it is forcefully terminated. func (s *Connection) SetIdleTimeout(timeout time.Duration) { - s.framer.setIdleTimeout(timeout) + s.framer.setIdleTimeout(timeout, false) +} + +// SetUserIdleTimeout sets the amount of time the connection may sit idle, +// not taking into account SPDY Ping frames, before it is forcefully terminated +func (s *Connection) SetUserIdleTimeout(timeout time.Duration) { + s.framer.setIdleTimeout(timeout, true) } func (s *Connection) sendHeaders(headers http.Header, stream *Stream, fin bool) error { diff --git a/spdy_test.go b/spdy_test.go index 312a950..582fb5a 100644 --- a/spdy_test.go +++ b/spdy_test.go @@ -50,7 +50,9 @@ func TestSpdyStreams(t *testing.T) { } go spdyConn.Serve(NoOpStreamHandler) + authMu.Lock() authenticated = true + authMu.Unlock() stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false) if streamErr != nil { t.Fatalf("Error creating stream: %s", streamErr) @@ -144,7 +146,9 @@ func TestSpdyStreams(t *testing.T) { t.Fatalf("Error reseting stream: %s", streamResetErr) } + authMu.Lock() authenticated = false + authMu.Unlock() badStream, badStreamErr := spdyConn.CreateStream(http.Header{}, nil, false) if badStreamErr != nil { t.Fatalf("Error creating stream: %s", badStreamErr) @@ -225,7 +229,9 @@ func TestHalfClose(t *testing.T) { } go spdyConn.Serve(NoOpStreamHandler) + authMu.Lock() authenticated = true + authMu.Unlock() stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false) if streamErr != nil { t.Fatalf("Error creating stream: %s", streamErr) @@ -311,7 +317,9 @@ func TestUnexpectedRemoteConnectionClosed(t *testing.T) { } go spdyConn.Serve(NoOpStreamHandler) + authMu.Lock() authenticated = true + authMu.Unlock() stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false) if streamErr != nil { t.Fatalf("Error creating stream: %s", streamErr) @@ -427,7 +435,9 @@ func TestIdleShutdownRace(t *testing.T) { } go spdyConn.Serve(NoOpStreamHandler) + authMu.Lock() authenticated = true + authMu.Unlock() stream, err := spdyConn.CreateStream(http.Header{}, nil, false) if err != nil { t.Fatalf("Error creating stream: %v", err) @@ -544,6 +554,34 @@ func TestIdleNoData(t *testing.T) { wg.Wait() } +func TestUserIdleNoData(t *testing.T) { + var wg sync.WaitGroup + server, listen, serverErr := runServer(&wg) + if serverErr != nil { + t.Fatalf("Error initializing server: %s", serverErr) + } + + conn, dialErr := net.Dial("tcp", listen) + if dialErr != nil { + t.Fatalf("Error dialing server: %s", dialErr) + } + + spdyConn, spdyErr := NewConnection(conn, false) + if spdyErr != nil { + t.Fatalf("Error creating spdy connection: %s", spdyErr) + } + go spdyConn.Serve(NoOpStreamHandler) + + spdyConn.SetUserIdleTimeout(10 * time.Millisecond) + <-spdyConn.CloseChan() + + closeErr := server.Close() + if closeErr != nil { + t.Fatalf("Error shutting down server: %s", closeErr) + } + wg.Wait() +} + func TestIdleWithData(t *testing.T) { var wg sync.WaitGroup server, listen, serverErr := runServer(&wg) @@ -564,7 +602,9 @@ func TestIdleWithData(t *testing.T) { spdyConn.SetIdleTimeout(25 * time.Millisecond) + authMu.Lock() authenticated = true + authMu.Unlock() stream, err := spdyConn.CreateStream(http.Header{}, nil, false) if err != nil { t.Fatalf("Error creating stream: %v", err) @@ -606,6 +646,144 @@ Loop: wg.Wait() } +func TestIdleWithPing(t *testing.T) { + var wg sync.WaitGroup + server, listen, serverErr := runServer(&wg) + if serverErr != nil { + t.Fatalf("Error initializing server: %s", serverErr) + } + + conn, dialErr := net.Dial("tcp", listen) + if dialErr != nil { + t.Fatalf("Error dialing server: %s", dialErr) + } + + spdyConn, spdyErr := NewConnection(conn, false) + if spdyErr != nil { + t.Fatalf("Error creating spdy connection: %s", spdyErr) + } + go spdyConn.Serve(NoOpStreamHandler) + + spdyConn.SetIdleTimeout(25 * time.Millisecond) + + authMu.Lock() + authenticated = true + authMu.Unlock() + _, err := spdyConn.CreateStream(http.Header{}, nil, false) + if err != nil { + t.Fatalf("Error creating stream: %v", err) + } + + writeCh := make(chan struct{}) + + go func() { + for i := 0; i < 10; i++ { + pingTime, pingErr := spdyConn.Ping() + if pingErr != nil { + t.Errorf("Error pinging server: %s", pingErr) + } + + if pingTime == time.Duration(0) { + t.Errorf("Expecting non-zero ping time") + } + time.Sleep(10 * time.Millisecond) + } + close(writeCh) + }() + + writesFinished := false + +Loop: + for { + select { + case <-writeCh: + writesFinished = true + case <-spdyConn.CloseChan(): + if !writesFinished { + t.Fatal("Connection closed before all writes finished") + } + break Loop + } + } + + closeErr := server.Close() + if closeErr != nil { + t.Fatalf("Error shutting down server: %s", closeErr) + } + wg.Wait() +} + +func TestUserIdleWithPing(t *testing.T) { + var wg sync.WaitGroup + server, listen, serverErr := runServer(&wg) + if serverErr != nil { + t.Fatalf("Error initializing server: %s", serverErr) + } + + conn, dialErr := net.Dial("tcp", listen) + if dialErr != nil { + t.Fatalf("Error dialing server: %s", dialErr) + } + + spdyConn, spdyErr := NewConnection(conn, false) + if spdyErr != nil { + t.Fatalf("Error creating spdy connection: %s", spdyErr) + } + go spdyConn.Serve(NoOpStreamHandler) + + spdyConn.SetUserIdleTimeout(25 * time.Millisecond) + + authMu.Lock() + authenticated = true + authMu.Unlock() + _, err := spdyConn.CreateStream(http.Header{}, nil, false) + if err != nil { + t.Fatalf("Error creating stream: %v", err) + } + + writeCh := make(chan struct{}) + + go func() { + for i := 0; i < 10; i++ { + select { + case <-spdyConn.CloseChan(): + default: + pingTime, pingErr := spdyConn.Ping() + if pingErr != nil { + t.Errorf("Error pinging server: %s", pingErr) + } + + if pingTime == time.Duration(0) { + t.Errorf("Expecting non-zero ping time") + } + time.Sleep(10 * time.Millisecond) + } + } + close(writeCh) + }() + + writesFinished := false + +Loop: + for { + select { + case <-writeCh: + writesFinished = true + case <-spdyConn.CloseChan(): + if writesFinished { + t.Fatal("Connection closed after all writes finished") + } + break Loop + } + } + + closeErr := server.Close() + if closeErr != nil { + t.Fatalf("Error shutting down server: %s", closeErr) + } + wg.Wait() +} + func TestIdleRace(t *testing.T) { var wg sync.WaitGroup server, listen, serverErr := runServer(&wg) @@ -626,8 +804,9 @@ func TestIdleRace(t *testing.T) { spdyConn.SetIdleTimeout(10 * time.Millisecond) + authMu.Lock() authenticated = true - + authMu.Unlock() for i := 0; i < 10; i++ { _, err := spdyConn.CreateStream(http.Header{}, nil, false) if err != nil { @@ -711,7 +890,9 @@ func TestStreamReset(t *testing.T) { } go spdyConn.Serve(NoOpStreamHandler) + authMu.Lock() authenticated = true + authMu.Unlock() stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false) if streamErr != nil { t.Fatalf("Error creating stream: %s", streamErr) @@ -759,7 +940,9 @@ func TestStreamResetWithDataRemaining(t *testing.T) { } go spdyConn.Serve(NoOpStreamHandler) + authMu.Lock() authenticated = true + authMu.Unlock() stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false) if streamErr != nil { t.Fatalf("Error creating stream: %s", streamErr) @@ -1154,12 +1337,16 @@ func TestStreamReadUnblocksAfterCloseThenReset(t *testing.T) { } var authenticated bool +var authMu sync.Mutex func authStreamHandler(stream *Stream) { + authMu.Lock() if !authenticated { + authMu.Unlock() stream.Refuse() return } + authMu.Unlock() MirrorStreamHandler(stream) }