diff --git a/go.mod b/go.mod index 2281a15..16c65c6 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/marten-seemann/webtransport-go go 1.17 require ( - github.com/lucas-clemente/quic-go v0.27.1-0.20220520111257-8185d1b4e072 + github.com/lucas-clemente/quic-go v0.27.1-0.20220526175250-9d5de12933f2 github.com/stretchr/testify v1.7.1 ) diff --git a/go.sum b/go.sum index 19ca405..f349cc4 100644 --- a/go.sum +++ b/go.sum @@ -75,8 +75,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lucas-clemente/quic-go v0.27.1-0.20220520111257-8185d1b4e072 h1:s2A290bfKQCyFrBjpWti5klxhHgWOLfyV8LX3PfeVBM= -github.com/lucas-clemente/quic-go v0.27.1-0.20220520111257-8185d1b4e072/go.mod h1:AzgQoPda7N+3IqMMMkywBKggIFo2KT6pfnlrQ2QieeI= +github.com/lucas-clemente/quic-go v0.27.1-0.20220526175250-9d5de12933f2 h1:qUX8EnM3HZ03HK3NFZx7pd6zCuUsQjA+VW8gO5vtbEQ= +github.com/lucas-clemente/quic-go v0.27.1-0.20220526175250-9d5de12933f2/go.mod h1:AzgQoPda7N+3IqMMMkywBKggIFo2KT6pfnlrQ2QieeI= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= diff --git a/stream.go b/stream.go index 33b9e93..3a011da 100644 --- a/stream.go +++ b/stream.go @@ -46,12 +46,20 @@ func newSendStream(str quic.SendStream, hdr []byte) SendStream { return &sendStream{str: str, streamHdr: hdr} } +func (s *sendStream) maybeSendStreamHeader() error { + if len(s.streamHdr) == 0 { + return nil + } + if _, err := s.str.Write(s.streamHdr); err != nil { + return err + } + s.streamHdr = nil + return nil +} + func (s *sendStream) Write(b []byte) (int, error) { - if len(s.streamHdr) > 0 { - if _, err := s.str.Write(s.streamHdr); err != nil { - return 0, err - } - s.streamHdr = nil + if err := s.maybeSendStreamHeader(); err != nil { + return 0, err } n, err := s.str.Write(b) return n, maybeConvertStreamError(err) @@ -62,6 +70,9 @@ func (s *sendStream) CancelWrite(e ErrorCode) { } func (s *sendStream) Close() error { + if err := s.maybeSendStreamHeader(); err != nil { + return err + } return maybeConvertStreamError(s.str.Close()) } diff --git a/webtransport_test.go b/webtransport_test.go index 5e973e5..3aa7e9e 100644 --- a/webtransport_test.go +++ b/webtransport_test.go @@ -109,7 +109,7 @@ func getRandomData(l int) []byte { return data } -func TestBidirectionalStreams(t *testing.T) { +func TestBidirectionalStreamsDataTransfer(t *testing.T) { t.Run("client-initiated", func(t *testing.T) { conn, closeServer := establishConn(t, newEchoHandler(t)) defer closeServer() @@ -130,6 +130,89 @@ func TestBidirectionalStreams(t *testing.T) { }) } +func TestBidirectionalStreamsImmediateClose(t *testing.T) { + t.Run("bidirectional streams", func(t *testing.T) { + t.Run("client-initiated", func(t *testing.T) { + conn, closeServer := establishConn(t, func(c *webtransport.Conn) { + str, err := c.AcceptStream(context.Background()) + require.NoError(t, err) + n, err := str.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) + require.NoError(t, str.Close()) + }) + defer closeServer() + + str, err := conn.OpenStream() + require.NoError(t, err) + require.NoError(t, str.Close()) + n, err := str.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) + }) + + t.Run("server-initiated", func(t *testing.T) { + done := make(chan struct{}) + conn, closeServer := establishConn(t, func(c *webtransport.Conn) { + defer close(done) + str, err := c.OpenStream() + require.NoError(t, err) + require.NoError(t, str.Close()) + n, err := str.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) + }) + defer closeServer() + + str, err := conn.AcceptStream(context.Background()) + require.NoError(t, err) + n, err := str.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) + require.NoError(t, str.Close()) + <-done + }) + }) + + t.Run("unidirectional", func(t *testing.T) { + t.Run("client-initiated", func(t *testing.T) { + done := make(chan struct{}) + conn, closeServer := establishConn(t, func(c *webtransport.Conn) { + defer close(done) + str, err := c.AcceptUniStream(context.Background()) + require.NoError(t, err) + n, err := str.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) + }) + defer closeServer() + + str, err := conn.OpenUniStream() + require.NoError(t, err) + require.NoError(t, str.Close()) + <-done + }) + + t.Run("server-initiated", func(t *testing.T) { + done := make(chan struct{}) + conn, closeServer := establishConn(t, func(c *webtransport.Conn) { + defer close(done) + str, err := c.OpenUniStream() + require.NoError(t, err) + require.NoError(t, str.Close()) + }) + defer closeServer() + + str, err := conn.AcceptUniStream(context.Background()) + require.NoError(t, err) + n, err := str.Read([]byte{0}) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) + <-done + }) + }) +} + func TestUnidirectionalStreams(t *testing.T) { conn, closeServer := establishConn(t, func(conn *webtransport.Conn) { // Accept a unidirectional stream, read all of its contents,