diff --git a/.travis.yml b/.travis.yml index f912ec0..51bd86c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,8 +8,10 @@ go: env: global: - - GOTFLAGS="-race" - BUILD_DEPTYPE=gomod + matrix: + - GOTFLAGS="-race" + - GOTFLAGS="-count 5" # disable travis install diff --git a/bench_test.go b/bench_test.go index 34302de..cf33cff 100644 --- a/bench_test.go +++ b/bench_test.go @@ -48,13 +48,14 @@ func BenchmarkAccept(b *testing.B) { func BenchmarkSendRecv(b *testing.B) { client, server := testClientServer() defer client.Close() - defer server.Close() sendBuf := make([]byte, 512) recvBuf := make([]byte, 512) doneCh := make(chan struct{}) go func() { + defer close(doneCh) + defer server.Close() stream, err := server.AcceptStream() if err != nil { return @@ -62,10 +63,10 @@ func BenchmarkSendRecv(b *testing.B) { defer stream.Close() for i := 0; i < b.N; i++ { if _, err := io.ReadFull(stream, recvBuf); err != nil { - b.Fatalf("err: %v", err) + b.Errorf("err: %v", err) + return } } - close(doneCh) }() stream, err := client.Open() @@ -95,6 +96,8 @@ func BenchmarkSendRecvLarge(b *testing.B) { recvDone := make(chan struct{}) go func() { + defer close(recvDone) + defer server.Close() stream, err := server.AcceptStream() if err != nil { return @@ -103,11 +106,11 @@ func BenchmarkSendRecvLarge(b *testing.B) { for i := 0; i < b.N; i++ { for j := 0; j < sendSize/recvSize; j++ { if _, err := io.ReadFull(stream, recvBuf); err != nil { - b.Fatalf("err: %v", err) + b.Errorf("err: %v", err) + return } } } - close(recvDone) }() stream, err := client.Open() diff --git a/session.go b/session.go index 42caeb3..de1ce17 100644 --- a/session.go +++ b/session.go @@ -87,15 +87,11 @@ type Session struct { // keepaliveTimer is a periodic timer for keepalive messages. It's nil // when keepalives are disabled. - keepaliveLock sync.Mutex - keepaliveTimer *time.Timer + keepaliveLock sync.Mutex + keepaliveTimer *time.Timer + keepaliveActive bool } -const ( - stageInitial uint32 = iota - stageFinal -) - // newSession is used to construct a new session func newSession(config *Config, conn net.Conn, client bool, readBuf int) *Session { var reader io.Reader = conn @@ -327,23 +323,27 @@ func (s *Session) startKeepalive() { defer s.keepaliveLock.Unlock() s.keepaliveTimer = time.AfterFunc(s.config.KeepAliveInterval, func() { s.keepaliveLock.Lock() - - if s.keepaliveTimer == nil { + if s.keepaliveTimer == nil || s.keepaliveActive { + // keepalives have been stopped or a keepalive is active. s.keepaliveLock.Unlock() - // keepalives have been stopped. return } + s.keepaliveActive = true + s.keepaliveLock.Unlock() + _, err := s.Ping() + + s.keepaliveLock.Lock() + s.keepaliveActive = false + if s.keepaliveTimer != nil { + s.keepaliveTimer.Reset(s.config.KeepAliveInterval) + } + s.keepaliveLock.Unlock() + if err != nil { - // Make sure to unlock before exiting so we don't - // deadlock trying to shutdown keepalives. - s.keepaliveLock.Unlock() s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) s.exitErr(ErrKeepAliveTimeout) - return } - s.keepaliveTimer.Reset(s.config.KeepAliveInterval) - s.keepaliveLock.Unlock() }) } @@ -353,7 +353,24 @@ func (s *Session) stopKeepalive() { defer s.keepaliveLock.Unlock() if s.keepaliveTimer != nil { s.keepaliveTimer.Stop() + s.keepaliveTimer = nil + } +} + +func (s *Session) extendKeepalive() { + s.keepaliveLock.Lock() + if s.keepaliveTimer != nil && !s.keepaliveActive { + // Don't stop the timer and drain the channel. This is an + // AfterFunc, not a normal timer, and any attempts to drain the + // channel will block forever. + // + // Go will stop the timer for us internally anyways. The docs + // say one must stop the timer before calling reset but that's + // to ensure that the timer doesn't end up firing immediately + // after calling Reset. + s.keepaliveTimer.Reset(s.config.KeepAliveInterval) } + s.keepaliveLock.Unlock() } // send sends the header and body. @@ -512,9 +529,7 @@ func (s *Session) recvLoop() error { // There's no reason to keepalive if we're active. Worse, if the // peer is busy sending us stuff, the pong might get stuck // behind a bunch of data. - if s.keepaliveTimer != nil { - s.keepaliveTimer.Reset(s.config.KeepAliveInterval) - } + s.extendKeepalive() // Verify the version if hdr.Version() != protoVersion { diff --git a/session_norace_test.go b/session_norace_test.go new file mode 100644 index 0000000..c011a4a --- /dev/null +++ b/session_norace_test.go @@ -0,0 +1,163 @@ +//+build !race + +package yamux + +import ( + "bytes" + "io" + "io/ioutil" + "sync" + "testing" + "time" +) + +func TestSession_PingOfDeath(t *testing.T) { + client, server := testClientServerConfig(testConfNoKeepAlive()) + defer client.Close() + defer server.Close() + + count := 10000 + + var wg sync.WaitGroup + begin := make(chan struct{}) + for i := 0; i < count; i++ { + wg.Add(2) + go func() { + defer wg.Done() + <-begin + if _, err := server.Ping(); err != nil { + t.Error(err) + } + }() + go func() { + defer wg.Done() + <-begin + if _, err := client.Ping(); err != nil { + t.Error(err) + } + }() + } + close(begin) + wg.Wait() +} + +func TestSendData_VeryLarge(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + var n int64 = 1 * 1024 * 1024 * 1024 + var workers int = 16 + + wg := &sync.WaitGroup{} + wg.Add(workers * 2) + + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + stream, err := server.AcceptStream() + if err != nil { + t.Errorf("err: %v", err) + return + } + defer stream.Close() + + buf := make([]byte, 4) + _, err = io.ReadFull(stream, buf) + if err != nil { + t.Errorf("err: %v", err) + return + } + if !bytes.Equal(buf, []byte{0, 1, 2, 3}) { + t.Errorf("bad header") + return + } + + recv, err := io.Copy(ioutil.Discard, stream) + if err != nil { + t.Errorf("err: %v", err) + return + } + if recv != n { + t.Errorf("bad: %v", recv) + return + } + }() + } + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + stream, err := client.Open() + if err != nil { + t.Errorf("err: %v", err) + return + } + defer stream.Close() + + _, err = stream.Write([]byte{0, 1, 2, 3}) + if err != nil { + t.Errorf("err: %v", err) + return + } + + unlimited := &UnlimitedReader{} + sent, err := io.Copy(stream, io.LimitReader(unlimited, n)) + if err != nil { + t.Errorf("err: %v", err) + return + } + if sent != n { + t.Errorf("bad: %v", sent) + return + } + }() + } + + doneCh := make(chan struct{}) + go func() { + wg.Wait() + close(doneCh) + }() + select { + case <-doneCh: + case <-time.After(20 * time.Second): + server.Close() + client.Close() + wg.Wait() + t.Fatal("timeout") + } +} + +func TestLargeWindow(t *testing.T) { + conf := DefaultConfig() + conf.MaxStreamWindowSize *= 2 + + client, server := testClientServerConfig(conf) + defer client.Close() + defer server.Close() + + stream, err := client.Open() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + stream2, err := server.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream2.Close() + + err = stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, conf.MaxStreamWindowSize) + n, err := stream.Write(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != len(buf) { + t.Fatalf("short write: %d", n) + } +} diff --git a/session_test.go b/session_test.go index b3ab237..aaa75b5 100644 --- a/session_test.go +++ b/session_test.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "io/ioutil" - "log" "net" "reflect" "runtime" @@ -19,19 +18,19 @@ import ( type logCapture struct{ bytes.Buffer } func (l *logCapture) logs() []string { - return strings.Split(strings.TrimSpace(l.String()), "\n") + lines := strings.Split(strings.TrimSpace(l.String()), "\n") + for i, line := range lines { + // trim leading date. + split := strings.SplitN(line, " ", 3) + lines[i] = split[2] + } + return lines } func (l *logCapture) match(expect []string) bool { return reflect.DeepEqual(l.logs(), expect) } -func captureLogs(s *Session) *logCapture { - buf := new(logCapture) - s.logger = log.New(buf, "", 0) - return buf -} - type pipeConn struct { net.Conn writeDeadline pipeDeadline @@ -127,12 +126,12 @@ func TestClientClient(t *testing.T) { defer client1.Close() defer client2.Close() - client1.OpenStream() + _, _ = client1.OpenStream() _, err := client2.AcceptStream() if err == nil { t.Fatalf("should have failed to open a stream with two clients") } - client2.OpenStream() + _, _ = client2.OpenStream() _, err = client1.AcceptStream() if err == nil { t.Fatalf("should have failed to open a stream with two clients") @@ -150,12 +149,12 @@ func TestServerServer(t *testing.T) { defer server1.Close() defer server2.Close() - server1.OpenStream() + _, _ = server1.OpenStream() _, err := server2.AcceptStream() if err == nil { t.Fatalf("should have failed to open a stream with two servers") } - server2.OpenStream() + _, _ = server2.OpenStream() _, err = server1.AcceptStream() if err == nil { t.Fatalf("should have failed to open a stream with two servers") @@ -176,7 +175,7 @@ func TestStreamAfterShutdown(t *testing.T) { s, err := client.OpenStream() if err == nil { cb(s) - s.Reset() + _ = s.Reset() } client.Close() }() @@ -200,14 +199,14 @@ func TestStreamAfterShutdown(t *testing.T) { // test write for i := 0; i < 100; i++ { do(func(s *Stream) { - s.Write([]byte{10}) + _, _ = s.Write([]byte{10}) }) } // test read for i := 0; i < 100; i++ { do(func(s *Stream) { - s.Read([]byte{10}) + _, _ = s.Read([]byte{10}) }) } } @@ -263,7 +262,8 @@ func TestCloseBeforeAck(t *testing.T) { defer close(done) s, err := client.OpenStream() if err != nil { - t.Fatal(err) + t.Error(err) + return } s.Close() }() @@ -271,6 +271,9 @@ func TestCloseBeforeAck(t *testing.T) { select { case <-done: case <-time.After(time.Second * 5): + client.Close() + server.Close() + <-done t.Fatal("timed out trying to open stream") } } @@ -294,13 +297,16 @@ func TestAccept(t *testing.T) { defer wg.Done() stream, err := server.AcceptStream() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if id := stream.StreamID(); id != 1 { - t.Fatalf("bad: %v", id) + t.Errorf("bad: %v", id) + return } if err := stream.Close(); err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } }() @@ -308,13 +314,16 @@ func TestAccept(t *testing.T) { defer wg.Done() stream, err := client.AcceptStream() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if id := stream.StreamID(); id != 2 { - t.Fatalf("bad: %v", id) + t.Errorf("bad: %v", id) + return } if err := stream.Close(); err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } }() @@ -322,13 +331,16 @@ func TestAccept(t *testing.T) { defer wg.Done() stream, err := server.OpenStream() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if id := stream.StreamID(); id != 2 { - t.Fatalf("bad: %v", id) + t.Errorf("bad: %v", id) + return } if err := stream.Close(); err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } }() @@ -336,13 +348,16 @@ func TestAccept(t *testing.T) { defer wg.Done() stream, err := client.OpenStream() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if id := stream.StreamID(); id != 1 { - t.Fatalf("bad: %v", id) + t.Errorf("bad: %v", id) + return } if err := stream.Close(); err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } }() @@ -355,7 +370,10 @@ func TestAccept(t *testing.T) { select { case <-doneCh: case <-time.After(time.Second): - panic("timeout") + server.Close() + client.Close() + wg.Wait() + t.Fatal("timeout") } } @@ -386,29 +404,35 @@ func TestSendData_Small(t *testing.T) { defer wg.Done() stream, err := server.AcceptStream() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if server.NumStreams() != 1 { - t.Fatalf("bad") + t.Errorf("bad") + return } buf := make([]byte, 4) for i := 0; i < 1000; i++ { n, err := stream.Read(buf) if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if n != 4 { - t.Fatalf("short read: %d", n) + t.Errorf("short read: %d", n) + return } if string(buf) != "test" { - t.Fatalf("bad: %s", buf) + t.Errorf("bad: %s", buf) + return } } if err := stream.Close(); err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } n, err := stream.Read([]byte{0}) if n != 0 || err != io.EOF { @@ -420,25 +444,30 @@ func TestSendData_Small(t *testing.T) { defer wg.Done() stream, err := client.Open() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if client.NumStreams() != 1 { - t.Fatalf("bad") + t.Errorf("bad") + return } for i := 0; i < 1000; i++ { n, err := stream.Write([]byte("test")) if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if n != 4 { - t.Fatalf("short write %d", n) + t.Errorf("short write %d", n) + return } } if err := stream.Close(); err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } n, err := stream.Read([]byte{0}) if n != 0 || err != io.EOF { @@ -454,7 +483,10 @@ func TestSendData_Small(t *testing.T) { select { case <-doneCh: case <-time.After(time.Second): - panic("timeout") + client.Close() + server.Close() + wg.Wait() + t.Fatal("timeout") } if client.NumStreams() != 0 { @@ -487,28 +519,34 @@ func TestSendData_Large(t *testing.T) { defer wg.Done() stream, err := server.AcceptStream() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } + var sz int buf := make([]byte, recvSize) for i := 0; i < sendSize/recvSize; i++ { n, err := io.ReadFull(stream, buf) if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if n != recvSize { - t.Fatalf("short read: %d", n) + t.Errorf("short read: %d", n) + return } sz += n for idx := range buf { if buf[idx] != byte(idx%256) { - t.Fatalf("bad: %v %v %v", i, idx, buf[idx]) + t.Errorf("bad: %v %v %v", i, idx, buf[idx]) + return } } } if err := stream.Close(); err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz) @@ -518,19 +556,23 @@ func TestSendData_Large(t *testing.T) { defer wg.Done() stream, err := client.Open() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } n, err := stream.Write(data) if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if n != len(data) { - t.Fatalf("short write %d", n) + t.Errorf("short write %d", n) + return } if err := stream.Close(); err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } }() @@ -541,13 +583,20 @@ func TestSendData_Large(t *testing.T) { }() select { case <-doneCh: - case <-time.After(5 * time.Second): - panic("timeout") + case <-time.After(20 * time.Second): + client.Close() + server.Close() + wg.Wait() + t.Fatal("timeout") } } func TestGoAway(t *testing.T) { - client, server := testClientServer() + // This test is noisy. + conf := testConf() + conf.LogOutput = ioutil.Discard + + client, server := testClientServerConfig(conf) defer client.Close() defer server.Close() @@ -580,7 +629,8 @@ func TestManyStreams(t *testing.T) { defer wg.Done() stream, err := server.AcceptStream() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } defer stream.Close() @@ -591,10 +641,12 @@ func TestManyStreams(t *testing.T) { return } if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if n == 0 { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } } } @@ -602,7 +654,8 @@ func TestManyStreams(t *testing.T) { defer wg.Done() stream, err := client.Open() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } defer stream.Close() @@ -610,10 +663,12 @@ func TestManyStreams(t *testing.T) { for i := 0; i < 1000; i++ { n, err := stream.Write([]byte(msg)) if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if n != len(msg) { - t.Fatalf("short write %d", n) + t.Errorf("short write %d", n) + return } } } @@ -641,7 +696,8 @@ func TestManyStreams_PingPong(t *testing.T) { defer wg.Done() stream, err := server.AcceptStream() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } defer stream.Close() @@ -653,13 +709,16 @@ func TestManyStreams_PingPong(t *testing.T) { return } if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if n != 4 { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if !bytes.Equal(buf, ping) { - t.Fatalf("bad: %s", buf) + t.Errorf("bad: %s", buf) + return } // Shrink the internal buffer! @@ -668,10 +727,12 @@ func TestManyStreams_PingPong(t *testing.T) { // Write out the 'pong' n, err = stream.Write(pong) if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if n != 4 { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } } } @@ -679,7 +740,8 @@ func TestManyStreams_PingPong(t *testing.T) { defer wg.Done() stream, err := client.OpenStream() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } defer stream.Close() @@ -688,22 +750,27 @@ func TestManyStreams_PingPong(t *testing.T) { // Send the 'ping' n, err := stream.Write(ping) if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if n != 4 { - t.Fatalf("short write %d", n) + t.Errorf("short write %d", n) + return } // Read the 'pong' n, err = io.ReadFull(stream, buf) if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if n != 4 { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if !bytes.Equal(buf, pong) { - t.Fatalf("bad: %s", buf) + t.Errorf("bad: %s", buf) + return } // Shrink the buffer @@ -905,11 +972,12 @@ func TestKeepAlive_Timeout(t *testing.T) { client, _ := Client(conn1, clientConf) defer client.Close() - server, _ := Server(conn2, testConf()) - defer server.Close() + serverLogs := new(logCapture) + serverConf := testConf() + serverConf.LogOutput = serverLogs - _ = captureLogs(client) // Client logs aren't part of the test - serverLogs := captureLogs(server) + server, _ := Server(conn2, serverConf) + defer server.Close() errCh := make(chan error, 1) go func() { @@ -939,37 +1007,6 @@ func TestKeepAlive_Timeout(t *testing.T) { } } -func TestLargeWindow(t *testing.T) { - conf := DefaultConfig() - conf.MaxStreamWindowSize *= 2 - - client, server := testClientServerConfig(conf) - defer client.Close() - defer server.Close() - - stream, err := client.Open() - if err != nil { - t.Fatalf("err: %v", err) - } - defer stream.Close() - - stream2, err := server.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } - defer stream2.Close() - - stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) - buf := make([]byte, conf.MaxStreamWindowSize) - n, err := stream.Write(buf) - if err != nil { - t.Fatalf("err: %v", err) - } - if n != len(buf) { - t.Fatalf("short write: %d", n) - } -} - type UnlimitedReader struct{} func (u *UnlimitedReader) Read(p []byte) (int, error) { @@ -977,92 +1014,20 @@ func (u *UnlimitedReader) Read(p []byte) (int, error) { return len(p), nil } -func TestSendData_VeryLarge(t *testing.T) { - client, server := testClientServer() - defer client.Close() - defer server.Close() - - var n int64 = 1 * 1024 * 1024 * 1024 - var workers int = 16 - - wg := &sync.WaitGroup{} - wg.Add(workers * 2) - - for i := 0; i < workers; i++ { - go func() { - defer wg.Done() - stream, err := server.AcceptStream() - if err != nil { - t.Fatalf("err: %v", err) - } - defer stream.Close() - - buf := make([]byte, 4) - _, err = io.ReadFull(stream, buf) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(buf, []byte{0, 1, 2, 3}) { - t.Fatalf("bad header") - } - - recv, err := io.Copy(ioutil.Discard, stream) - if err != nil { - t.Fatalf("err: %v", err) - } - if recv != n { - t.Fatalf("bad: %v", recv) - } - }() - } - for i := 0; i < workers; i++ { - go func() { - defer wg.Done() - stream, err := client.Open() - if err != nil { - t.Fatalf("err: %v", err) - } - defer stream.Close() - - _, err = stream.Write([]byte{0, 1, 2, 3}) - if err != nil { - t.Fatalf("err: %v", err) - } - - unlimited := &UnlimitedReader{} - sent, err := io.Copy(stream, io.LimitReader(unlimited, n)) - if err != nil { - t.Fatalf("err: %v", err) - } - if sent != n { - t.Fatalf("bad: %v", sent) - } - }() - } - - doneCh := make(chan struct{}) - go func() { - wg.Wait() - close(doneCh) - }() - select { - case <-doneCh: - case <-time.After(20 * time.Second): - panic("timeout") - } -} - func TestBacklogExceeded_Accept(t *testing.T) { client, server := testClientServer() defer client.Close() - defer server.Close() max := 5 * client.config.AcceptBacklog + done := make(chan struct{}) go func() { + defer close(done) + defer server.Close() for i := 0; i < max; i++ { stream, err := server.Accept() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } defer stream.Close() } @@ -1072,14 +1037,15 @@ func TestBacklogExceeded_Accept(t *testing.T) { for i := 0; i < max; i++ { stream, err := client.Open() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) } defer stream.Close() if _, err := stream.Write([]byte("foo")); err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) } } + <-done } func TestSession_WindowUpdateWriteDuringRead(t *testing.T) { @@ -1164,25 +1130,30 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) { var err error wr, err = server.AcceptStream() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } defer wr.Close() sendWindow := atomic.LoadUint32(&wr.sendWindow) if sendWindow != client.config.MaxStreamWindowSize { - t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, sendWindow) + t.Errorf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, sendWindow) + return } n, err := wr.Write(make([]byte, flood)) if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } if int64(n) != flood { - t.Fatalf("short write: %d", n) + t.Errorf("short write: %d", n) + return } sendWindow = atomic.LoadUint32(&wr.sendWindow) if sendWindow != 0 { - t.Fatalf("sendWindow: exp=%d, got=%d", 0, sendWindow) + t.Errorf("sendWindow: exp=%d, got=%d", 0, sendWindow) + return } }() @@ -1199,12 +1170,21 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) { t.Fatal(err) } - time.Sleep(1 * time.Millisecond) + var ( + exp = uint32(flood / 2) + sendWindow uint32 + ) - sendWindow := atomic.LoadUint32(&wr.sendWindow) - if exp := uint32(flood / 2); sendWindow != exp { - t.Errorf("sendWindow: exp=%d, got=%d", exp, sendWindow) + // This test is racy. Wait a short period, then longer and longer. At + // most ~1s. + for i := 1; i < 15; i++ { + time.Sleep(time.Duration(i*i) * time.Millisecond) + sendWindow = atomic.LoadUint32(&wr.sendWindow) + if sendWindow == exp { + return + } } + t.Errorf("sendWindow: exp=%d, got=%d", exp, sendWindow) } func TestSession_sendMsg_Timeout(t *testing.T) { @@ -1228,34 +1208,6 @@ func TestSession_sendMsg_Timeout(t *testing.T) { } } -func TestSession_PingOfDeath(t *testing.T) { - client, server := testClientServerConfig(testConfNoKeepAlive()) - defer client.Close() - defer server.Close() - - var wg sync.WaitGroup - begin := make(chan struct{}) - for i := 0; i < 10000; i++ { - wg.Add(2) - go func() { - defer wg.Done() - <-begin - if _, err := server.Ping(); err != nil { - t.Error(err) - } - }() - go func() { - defer wg.Done() - <-begin - if _, err := client.Ping(); err != nil { - t.Error(err) - } - }() - } - close(begin) - wg.Wait() -} - func TestSession_ConnectionWriteTimeout(t *testing.T) { client, server := testClientServerConfig(testConfNoKeepAlive()) defer client.Close() @@ -1271,7 +1223,8 @@ func TestSession_ConnectionWriteTimeout(t *testing.T) { stream, err := server.AcceptStream() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } <-sync @@ -1287,7 +1240,8 @@ func TestSession_ConnectionWriteTimeout(t *testing.T) { stream, err := client.OpenStream() if err != nil { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } defer stream.Close() @@ -1304,7 +1258,8 @@ func TestSession_ConnectionWriteTimeout(t *testing.T) { } else if err == ErrConnectionWriteTimeout { break } else { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } } }() @@ -1338,7 +1293,10 @@ func TestStreamResetWrite(t *testing.T) { t.Fatalf("err: %v", err) } - stream.Reset() + err = stream.Reset() + if err != nil { + t.Fatal(err) + } <-wait } @@ -1417,7 +1375,10 @@ func TestStreamResetRead(t *testing.T) { }() time.Sleep(1 * time.Second) - stream.Reset() + err = stream.Reset() + if err != nil { + t.Fatal(err) + } wc.Wait() } @@ -1453,7 +1414,11 @@ func TestLotsOfWritesWithStreamDeadline(t *testing.T) { <-waitCh // stream2 should've received no messages, as they all expired in the buffer. - stream2.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + err = stream2.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + if err != nil { + t.Error(err) + return + } if b, err := ioutil.ReadAll(stream2); len(b) != 0 || err != ErrTimeout { t.Errorf("writes from the client should've expired; got: %v, bytes: %v", err, b) return @@ -1471,7 +1436,7 @@ func TestLotsOfWritesWithStreamDeadline(t *testing.T) { if err != nil { t.Fatal(err) } - defer stream2.Reset() + defer stream2.Reset() //nolint // wait for the server to accept the streams. <-waitCh @@ -1481,13 +1446,20 @@ func TestLotsOfWritesWithStreamDeadline(t *testing.T) { // Send a clogging write on stream1. go func() { - stream1.SetWriteDeadline(time.Now().Add(5 * time.Second)) - stream1.Write([]byte{100}) + err := stream1.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + t.Error(err) + return + } + _, _ = stream1.Write([]byte{100}) }() // Keep writing till we fill the buffer and timeout. var wg sync.WaitGroup - stream2.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + err = stream2.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + if err != nil { + t.Fatal(err) + } for { _, err := stream2.Write([]byte("foobar")) if err == nil { @@ -1531,7 +1503,8 @@ func TestReadDeadlineInterrupt(t *testing.T) { defer close(done) buf := make([]byte, 4) if _, err := stream.Read(buf); err != ErrTimeout { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) + return } }()