From 45d63867576c68c511f356c3e8c8631abebd0d2a Mon Sep 17 00:00:00 2001 From: merlin Date: Sat, 17 Jun 2023 13:50:55 +0300 Subject: [PATCH 1/5] fix request parsing --- gemax/request.go | 65 +++++++++++++++++++++++++++------ gemax/request_test.go | 83 +++++++++++++++++++++++++++++++++++++++++++ go.mod | 7 ++-- go.sum | 10 ++++-- 4 files changed, 151 insertions(+), 14 deletions(-) create mode 100644 gemax/request_test.go diff --git a/gemax/request.go b/gemax/request.go index 32a9071..3a50e2b 100644 --- a/gemax/request.go +++ b/gemax/request.go @@ -1,15 +1,23 @@ package gemax import ( + "bytes" + "crypto/tls" + "crypto/x509" "errors" "fmt" "io" + "io/fs" "net/url" "strings" + + "golang.org/x/exp/slices" ) +var requestSuffix = []byte("\n") + // MaxRequestSize is the maximum incoming request size in bytes. -const MaxRequestSize = 1026 +const MaxRequestSize = int64(1024 + len("\r\n")) // IncomingRequest describes a server side request object. type IncomingRequest interface { @@ -21,31 +29,68 @@ var ( errDotPath = errors.New("dots in path are not permitted") ) +var ErrBadRequest = errors.New("bad request") + // ParseIncomingRequest constructs an IncomingRequest from bytestream // and additional parameters (remote address for now). func ParseIncomingRequest(re io.Reader, remoteAddr string) (IncomingRequest, error) { - var reader = io.LimitReader(re, MaxRequestSize) - var u string - var _, errReadRequest = fmt.Fscanf(reader, "%s\r\n", &u) - if errReadRequest != nil { - return nil, fmt.Errorf("bad request: %w", errReadRequest) + var certs []*x509.Certificate + if tlsConn, ok := re.(*tls.Conn); ok { + certs = slices.Clone(tlsConn.ConnectionState().PeerCertificates) + } + + re = io.LimitReader(re, MaxRequestSize) + + line, errLine := io.ReadAll(re) + if errLine != nil { + return nil, fmt.Errorf("%w: %w", ErrBadRequest, errLine) } - var parsed, errParse = url.ParseRequestURI(u) + + if !bytes.HasSuffix(line, requestSuffix) { + return nil, ErrBadRequest + } + + line = bytes.TrimRight(line, "\r\n") + + parsed, errParse := url.ParseRequestURI(string(line)) if errParse != nil { - return nil, fmt.Errorf("bad request: %w", errParse) + return nil, fmt.Errorf("%w: %w", ErrBadRequest, errParse) } - if strings.Contains(parsed.Path, "/..") { - return nil, errDotPath + + if !isValidPath(parsed.Path) { + return nil, fmt.Errorf("%w: %w", ErrBadRequest, errDotPath) } + + if parsed.Path == "" { + parsed.Path = "/" + } + return &incomingRequest{ url: parsed, remoteAddr: remoteAddr, + certs: certs, }, nil } +func isValidPath(path string) bool { + if path == "." { + return false + } + + path = strings.TrimPrefix(path, "/") + path = strings.TrimSuffix(path, "/") + + if path == "" { + return true + } + + return fs.ValidPath(path) +} + type incomingRequest struct { url *url.URL remoteAddr string + certs []*x509.Certificate } func (req *incomingRequest) URL() *url.URL { diff --git a/gemax/request_test.go b/gemax/request_test.go new file mode 100644 index 0000000..f656921 --- /dev/null +++ b/gemax/request_test.go @@ -0,0 +1,83 @@ +package gemax_test + +import ( + "strings" + "testing" + + "github.com/ninedraft/gemax/gemax" +) + +func TestParseIncomingRequest(t *testing.T) { + t.Parallel() + t.Log("parsing incoming request line") + + const remoteAddr = "remote-addr" + type expect struct { + err bool + url string + } + + tc := func(name, input string, expected expect) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + re := strings.NewReader(input) + + parsed, err := gemax.ParseIncomingRequest(re, remoteAddr) + + if (err != nil) != expected.err { + t.Errorf("error = %v, want error = %v", err, expected.err) + } + + if parsed == nil && err == nil { + t.Error("parsed = nil, want not nil") + return + } + + if parsed != nil { + assertEq(t, parsed.RemoteAddr(), remoteAddr, "remote addr") + assertEq(t, parsed.URL().String(), expected.url, "url") + } + }) + } + + tc("valid", + "gemini://example.com\r\n", expect{ + url: "gemini://example.com/", + }) + tc("valid no \\r", + "gemini://example.com\n", expect{ + url: "gemini://example.com/", + }) + tc("valid with path", + "gemini://example.com/path\r\n", expect{ + url: "gemini://example.com/path", + }) + tc("valid with path and query", + "gemini://example.com/path?query=value\r\n", expect{ + url: "gemini://example.com/path?query=value", + }) + tc("valid http", + "http://example.com\r\n", expect{ + url: "http://example.com/", + }) + + tc("too long", + "http://example.com/"+strings.Repeat("a", 2048)+"\r\n", + expect{err: true}) + tc("empty", + "", expect{err: true}) + tc("no new \\r\\n", + "gemini://example.com", expect{err: true}) + tc("no \\n", + "gemini://example.com\r", expect{err: true}) +} + +func assertEq[E comparable](t *testing.T, got, want E, format string, args ...any) { + t.Helper() + + if got != want { + t.Errorf("got %v, want %v", got, want) + t.Errorf(format, args...) + } +} diff --git a/go.mod b/go.mod index 95faf56..451fb9c 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,8 @@ module github.com/ninedraft/gemax -go 1.19 +go 1.20 -require golang.org/x/net v0.2.0 +require ( + golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 + golang.org/x/net v0.11.0 +) diff --git a/go.sum b/go.sum index 2195436..3a65245 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,8 @@ -golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= -golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= +golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 h1:Jvc7gsqn21cJHCmAWx0LiimpP18LZmUxkT5Mp7EZ1mI= +golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= +golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= +golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= From 1a8101a4c5275fd56a939726a7f364e37f112a50 Mon Sep 17 00:00:00 2001 From: merlin Date: Sat, 17 Jun 2023 14:10:54 +0300 Subject: [PATCH 2/5] vendor tailscale memnet --- go.mod | 2 + vend/tailscale.com/net/memnet/LICENSE | 28 ++ vend/tailscale.com/net/memnet/VERSION.txt | 1 + vend/tailscale.com/net/memnet/conn.go | 110 ++++++++ vend/tailscale.com/net/memnet/conn_test.go | 21 ++ vend/tailscale.com/net/memnet/go.mod | 5 + vend/tailscale.com/net/memnet/go.sum | 2 + vend/tailscale.com/net/memnet/listener.go | 100 +++++++ .../tailscale.com/net/memnet/listener_test.go | 33 +++ vend/tailscale.com/net/memnet/memnet.go | 8 + vend/tailscale.com/net/memnet/pipe.go | 244 ++++++++++++++++++ vend/tailscale.com/net/memnet/pipe_test.go | 117 +++++++++ 12 files changed, 671 insertions(+) create mode 100644 vend/tailscale.com/net/memnet/LICENSE create mode 100644 vend/tailscale.com/net/memnet/VERSION.txt create mode 100644 vend/tailscale.com/net/memnet/conn.go create mode 100644 vend/tailscale.com/net/memnet/conn_test.go create mode 100644 vend/tailscale.com/net/memnet/go.mod create mode 100644 vend/tailscale.com/net/memnet/go.sum create mode 100644 vend/tailscale.com/net/memnet/listener.go create mode 100644 vend/tailscale.com/net/memnet/listener_test.go create mode 100644 vend/tailscale.com/net/memnet/memnet.go create mode 100644 vend/tailscale.com/net/memnet/pipe.go create mode 100644 vend/tailscale.com/net/memnet/pipe_test.go diff --git a/go.mod b/go.mod index 451fb9c..1d1c341 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module github.com/ninedraft/gemax go 1.20 +replace tailscale.com/net/memnet => ./vend/tailscale.com/net/memnet + require ( golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 golang.org/x/net v0.11.0 diff --git a/vend/tailscale.com/net/memnet/LICENSE b/vend/tailscale.com/net/memnet/LICENSE new file mode 100644 index 0000000..394db19 --- /dev/null +++ b/vend/tailscale.com/net/memnet/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright (c) 2020 Tailscale Inc & AUTHORS. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vend/tailscale.com/net/memnet/VERSION.txt b/vend/tailscale.com/net/memnet/VERSION.txt new file mode 100644 index 0000000..b978278 --- /dev/null +++ b/vend/tailscale.com/net/memnet/VERSION.txt @@ -0,0 +1 @@ +1.43.0 diff --git a/vend/tailscale.com/net/memnet/conn.go b/vend/tailscale.com/net/memnet/conn.go new file mode 100644 index 0000000..fb7776e --- /dev/null +++ b/vend/tailscale.com/net/memnet/conn.go @@ -0,0 +1,110 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "net" + "net/netip" + "time" +) + +// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked. +type Conn interface { + net.Conn + + // SetReadBlock blocks or unblocks the Read method of this Conn. + // It reports an error if the existing value matches the new value, + // or if the Conn has been Closed. + SetReadBlock(bool) error + + // SetWriteBlock blocks or unblocks the Write method of this Conn. + // It reports an error if the existing value matches the new value, + // or if the Conn has been Closed. + SetWriteBlock(bool) error +} + +// NewConn creates a pair of Conns that are wired together by pipes. +func NewConn(name string, maxBuf int) (Conn, Conn) { + r := NewPipe(name+"|0", maxBuf) + w := NewPipe(name+"|1", maxBuf) + + return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} +} + +// NewTCPConn creates a pair of Conns that are wired together by pipes. +func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) { + r := NewPipe(src.String(), maxBuf) + w := NewPipe(dst.String(), maxBuf) + + lAddr := net.TCPAddrFromAddrPort(src) + rAddr := net.TCPAddrFromAddrPort(dst) + + return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr} +} + +type connAddr string + +func (a connAddr) Network() string { return "mem" } +func (a connAddr) String() string { return string(a) } + +type connHalf struct { + local, remote net.Addr + r, w *Pipe +} + +func (c *connHalf) LocalAddr() net.Addr { + if c.local != nil { + return c.local + } + return connAddr(c.r.name) +} + +func (c *connHalf) RemoteAddr() net.Addr { + if c.remote != nil { + return c.remote + } + return connAddr(c.w.name) +} + +func (c *connHalf) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} +func (c *connHalf) Write(b []byte) (n int, err error) { + return c.w.Write(b) +} + +func (c *connHalf) Close() error { + if err := c.w.Close(); err != nil { + return err + } + return c.r.Close() +} + +func (c *connHalf) SetDeadline(t time.Time) error { + err1 := c.SetReadDeadline(t) + err2 := c.SetWriteDeadline(t) + if err1 != nil { + return err1 + } + return err2 +} +func (c *connHalf) SetReadDeadline(t time.Time) error { + return c.r.SetReadDeadline(t) +} +func (c *connHalf) SetWriteDeadline(t time.Time) error { + return c.w.SetWriteDeadline(t) +} + +func (c *connHalf) SetReadBlock(b bool) error { + if b { + return c.r.Block() + } + return c.r.Unblock() +} +func (c *connHalf) SetWriteBlock(b bool) error { + if b { + return c.w.Block() + } + return c.w.Unblock() +} diff --git a/vend/tailscale.com/net/memnet/conn_test.go b/vend/tailscale.com/net/memnet/conn_test.go new file mode 100644 index 0000000..743ce52 --- /dev/null +++ b/vend/tailscale.com/net/memnet/conn_test.go @@ -0,0 +1,21 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "net" + "testing" + + "golang.org/x/net/nettest" +) + +func TestConn(t *testing.T) { + nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { + c1, c2 = NewConn("test", bufferSize) + return c1, c2, func() { + c1.Close() + c2.Close() + }, nil + }) +} diff --git a/vend/tailscale.com/net/memnet/go.mod b/vend/tailscale.com/net/memnet/go.mod new file mode 100644 index 0000000..264fb44 --- /dev/null +++ b/vend/tailscale.com/net/memnet/go.mod @@ -0,0 +1,5 @@ +module tailscale.com + +go 1.20 + +require golang.org/x/net v0.10.0 diff --git a/vend/tailscale.com/net/memnet/go.sum b/vend/tailscale.com/net/memnet/go.sum new file mode 100644 index 0000000..ac972c8 --- /dev/null +++ b/vend/tailscale.com/net/memnet/go.sum @@ -0,0 +1,2 @@ +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= diff --git a/vend/tailscale.com/net/memnet/listener.go b/vend/tailscale.com/net/memnet/listener.go new file mode 100644 index 0000000..d84a2e4 --- /dev/null +++ b/vend/tailscale.com/net/memnet/listener.go @@ -0,0 +1,100 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "context" + "net" + "strings" + "sync" +) + +const ( + bufferSize = 256 * 1024 +) + +// Listener is a net.Listener using NewConn to create pairs of network +// connections connected in memory using a buffered pipe. It also provides a +// Dial method to establish new connections. +type Listener struct { + addr connAddr + ch chan Conn + closeOnce sync.Once + closed chan struct{} + + // NewConn, if non-nil, is called to create a new pair of connections + // when dialing. If nil, NewConn is used. + NewConn func(network, addr string, maxBuf int) (Conn, Conn) +} + +// Listen returns a new Listener for the provided address. +func Listen(addr string) *Listener { + return &Listener{ + addr: connAddr(addr), + ch: make(chan Conn), + closed: make(chan struct{}), + } +} + +// Addr implements net.Listener.Addr. +func (l *Listener) Addr() net.Addr { + return l.addr +} + +// Close closes the pipe listener. +func (l *Listener) Close() error { + l.closeOnce.Do(func() { + close(l.closed) + }) + return nil +} + +// Accept blocks until a new connection is available or the listener is closed. +func (l *Listener) Accept() (net.Conn, error) { + select { + case c := <-l.ch: + return c, nil + case <-l.closed: + return nil, net.ErrClosed + } +} + +// Dial connects to the listener using the provided context. +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected +// any expiration of the context will not affect the connection. +func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) { + if !strings.HasSuffix(network, "tcp") { + return nil, net.UnknownNetworkError(network) + } + if connAddr(addr) != l.addr { + return nil, &net.AddrError{ + Err: "invalid address", + Addr: addr, + } + } + + newConn := l.NewConn + if newConn == nil { + newConn = func(network, addr string, maxBuf int) (Conn, Conn) { + return NewConn(addr, maxBuf) + } + } + c, s := newConn(network, addr, bufferSize) + defer func() { + if err != nil { + c.Close() + s.Close() + } + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-l.closed: + return nil, net.ErrClosed + case l.ch <- s: + return c, nil + } +} diff --git a/vend/tailscale.com/net/memnet/listener_test.go b/vend/tailscale.com/net/memnet/listener_test.go new file mode 100644 index 0000000..73b6784 --- /dev/null +++ b/vend/tailscale.com/net/memnet/listener_test.go @@ -0,0 +1,33 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "context" + "testing" +) + +func TestListener(t *testing.T) { + l := Listen("srv.local") + defer l.Close() + go func() { + c, err := l.Accept() + if err != nil { + t.Error(err) + return + } + defer c.Close() + }() + + if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil { + c.Close() + t.Fatalf("dial to invalid address succeeded") + } + c, err := l.Dial(context.Background(), "tcp", "srv.local") + if err != nil { + t.Fatalf("dial failed: %v", err) + return + } + c.Close() +} diff --git a/vend/tailscale.com/net/memnet/memnet.go b/vend/tailscale.com/net/memnet/memnet.go new file mode 100644 index 0000000..c8799bc --- /dev/null +++ b/vend/tailscale.com/net/memnet/memnet.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package memnet implements an in-memory network implementation. +// It is useful for dialing and listening on in-memory addresses +// in tests and other situations where you don't want to use the +// network. +package memnet diff --git a/vend/tailscale.com/net/memnet/pipe.go b/vend/tailscale.com/net/memnet/pipe.go new file mode 100644 index 0000000..4716350 --- /dev/null +++ b/vend/tailscale.com/net/memnet/pipe.go @@ -0,0 +1,244 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "bytes" + "context" + "fmt" + "io" + "log" + "net" + "os" + "sync" + "time" +) + +const debugPipe = false + +// Pipe implements an in-memory FIFO with timeouts. +type Pipe struct { + name string + maxBuf int + mu sync.Mutex + cnd *sync.Cond + + blocked bool + closed bool + buf bytes.Buffer + readTimeout time.Time + writeTimeout time.Time + cancelReadTimer func() + cancelWriteTimer func() +} + +// NewPipe creates a Pipe with a buffer size fixed at maxBuf. +func NewPipe(name string, maxBuf int) *Pipe { + p := &Pipe{ + name: name, + maxBuf: maxBuf, + } + p.cnd = sync.NewCond(&p.mu) + return p +} + +// readOrBlock attempts to read from the buffer, if the buffer is empty and +// the connection hasn't been closed it will block until there is a change. +func (p *Pipe) readOrBlock(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) { + return 0, os.ErrDeadlineExceeded + } + if p.blocked { + p.cnd.Wait() + return 0, nil + } + + n, err := p.buf.Read(b) + // err will either be nil or io.EOF. + if err == io.EOF { + if p.closed { + return n, err + } + // Wait for something to change. + p.cnd.Wait() + } + return n, nil +} + +// Read implements io.Reader. +// Once the buffer is drained (i.e. after Close), subsequent calls will +// return io.EOF. +func (p *Pipe) Read(b []byte) (n int, err error) { + if debugPipe { + orig := b + defer func() { + log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err) + }() + } + for n == 0 { + n2, err := p.readOrBlock(b) + if err != nil { + return n2, err + } + n += n2 + } + p.cnd.Signal() + return n, nil +} + +// writeOrBlock attempts to write to the buffer, if the buffer is full it will +// block until there is a change. +func (p *Pipe) writeOrBlock(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + return 0, net.ErrClosed + } + if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) { + return 0, os.ErrDeadlineExceeded + } + if p.blocked { + p.cnd.Wait() + return 0, nil + } + + // Optimistically we want to write the entire slice. + n := len(b) + if limit := p.maxBuf - p.buf.Len(); limit < n { + // However, we don't have enough capacity to write everything. + n = limit + } + if n == 0 { + // Wait for something to change. + p.cnd.Wait() + return 0, nil + } + + p.buf.Write(b[:n]) + p.cnd.Signal() + return n, nil +} + +// Write implements io.Writer. +func (p *Pipe) Write(b []byte) (n int, err error) { + if debugPipe { + orig := b + defer func() { + log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) + }() + } + for len(b) > 0 { + n2, err := p.writeOrBlock(b) + if err != nil { + return n + n2, err + } + n += n2 + b = b[n2:] + } + return n, nil +} + +// Close closes the pipe. +func (p *Pipe) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + p.closed = true + p.blocked = false + if p.cancelWriteTimer != nil { + p.cancelWriteTimer() + p.cancelWriteTimer = nil + } + if p.cancelReadTimer != nil { + p.cancelReadTimer() + p.cancelReadTimer = nil + } + p.cnd.Broadcast() + + return nil +} + +func (p *Pipe) deadlineTimer(t time.Time) func() { + if t.IsZero() { + return nil + } + if t.Before(time.Now()) { + p.cnd.Broadcast() + return nil + } + ctx, cancel := context.WithDeadline(context.Background(), t) + go func() { + <-ctx.Done() + if ctx.Err() == context.DeadlineExceeded { + p.cnd.Broadcast() + } + }() + return cancel +} + +// SetReadDeadline sets the deadline for future Read calls. +func (p *Pipe) SetReadDeadline(t time.Time) error { + p.mu.Lock() + defer p.mu.Unlock() + p.readTimeout = t + // If we already have a deadline, cancel it and create a new one. + if p.cancelReadTimer != nil { + p.cancelReadTimer() + p.cancelReadTimer = nil + } + p.cancelReadTimer = p.deadlineTimer(t) + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (p *Pipe) SetWriteDeadline(t time.Time) error { + p.mu.Lock() + defer p.mu.Unlock() + p.writeTimeout = t + // If we already have a deadline, cancel it and create a new one. + if p.cancelWriteTimer != nil { + p.cancelWriteTimer() + p.cancelWriteTimer = nil + } + p.cancelWriteTimer = p.deadlineTimer(t) + return nil +} + +// Block will cause all calls to Read and Write to block until they either +// timeout, are unblocked or the pipe is closed. +func (p *Pipe) Block() error { + p.mu.Lock() + defer p.mu.Unlock() + closed := p.closed + blocked := p.blocked + p.blocked = true + + if closed { + return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) + } + if blocked { + return fmt.Errorf("memnet.Pipe(%q).Block: already blocked", p.name) + } + p.cnd.Broadcast() + return nil +} + +// Unblock will cause all blocked Read/Write calls to continue execution. +func (p *Pipe) Unblock() error { + p.mu.Lock() + defer p.mu.Unlock() + closed := p.closed + blocked := p.blocked + p.blocked = false + + if closed { + return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) + } + if !blocked { + return fmt.Errorf("memnet.Pipe(%q).Block: already unblocked", p.name) + } + p.cnd.Broadcast() + return nil +} diff --git a/vend/tailscale.com/net/memnet/pipe_test.go b/vend/tailscale.com/net/memnet/pipe_test.go new file mode 100644 index 0000000..a86d653 --- /dev/null +++ b/vend/tailscale.com/net/memnet/pipe_test.go @@ -0,0 +1,117 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "errors" + "fmt" + "os" + "testing" + "time" +) + +func TestPipeHello(t *testing.T) { + p := NewPipe("p1", 1<<16) + msg := "Hello, World!" + if n, err := p.Write([]byte(msg)); err != nil { + t.Fatal(err) + } else if n != len(msg) { + t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg)) + } + b := make([]byte, len(msg)) + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != len(b) { + t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b)) + } + if got := string(b); got != msg { + t.Errorf("p.Read: %q, want %q", got, msg) + } +} + +func TestPipeTimeout(t *testing.T) { + t.Run("write", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.SetWriteDeadline(time.Now().Add(-1 * time.Second)) + n, err := p.Write([]byte{'h'}) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("missing write timeout got err: %v", err) + } + if n != 0 { + t.Errorf("n=%d on timeout", n) + } + }) + t.Run("read", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.Write([]byte{'h'}) + + p.SetReadDeadline(time.Now().Add(-1 * time.Second)) + b := make([]byte, 1) + n, err := p.Read(b) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("missing read timeout got err: %v", err) + } + if n != 0 { + t.Errorf("n=%d on timeout", n) + } + }) + t.Run("block-write", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) + if err := p.Block(); err != nil { + t.Fatal(err) + } + if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("want write timeout got: %v", err) + } + }) + t.Run("block-read", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.Write([]byte{'h', 'i'}) + p.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + b := make([]byte, 1) + if err := p.Block(); err != nil { + t.Fatal(err) + } + if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("want read timeout got: %v", err) + } + }) +} + +func TestLimit(t *testing.T) { + p := NewPipe("p1", 1) + errCh := make(chan error) + go func() { + n, err := p.Write([]byte{'a', 'b', 'c'}) + if err != nil { + errCh <- err + } else if n != 3 { + errCh <- fmt.Errorf("p.Write n=%d, want 3", n) + } else { + errCh <- nil + } + }() + b := make([]byte, 3) + + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + + if err := <-errCh; err != nil { + t.Error(err) + } +} From d49eb3f4cd31b23e4b18bf76905b677137f0d822 Mon Sep 17 00:00:00 2001 From: merlin Date: Sat, 17 Jun 2023 15:13:54 +0300 Subject: [PATCH 3/5] fix request parsing --- gemax/request.go | 45 +++++++++++-- gemax/server.go | 2 +- gemax/server_test.go | 157 ++++++++++--------------------------------- go.mod | 1 + 4 files changed, 78 insertions(+), 127 deletions(-) diff --git a/gemax/request.go b/gemax/request.go index 3a50e2b..d7b3c94 100644 --- a/gemax/request.go +++ b/gemax/request.go @@ -41,9 +41,9 @@ func ParseIncomingRequest(re io.Reader, remoteAddr string) (IncomingRequest, err re = io.LimitReader(re, MaxRequestSize) - line, errLine := io.ReadAll(re) + line, errLine := readUntil(re, '\n') if errLine != nil { - return nil, fmt.Errorf("%w: %w", ErrBadRequest, errLine) + return nil, errLine } if !bytes.HasSuffix(line, requestSuffix) { @@ -73,14 +73,14 @@ func ParseIncomingRequest(re io.Reader, remoteAddr string) (IncomingRequest, err } func isValidPath(path string) bool { - if path == "." { - return false - } path = strings.TrimPrefix(path, "/") path = strings.TrimSuffix(path, "/") - if path == "" { + switch path { + case ".", "..": + return false + case "": return true } @@ -100,3 +100,36 @@ func (req *incomingRequest) URL() *url.URL { func (req *incomingRequest) RemoteAddr() string { return req.remoteAddr } + +// - found delimiter -> return data[:delimIndex+1], err +// - found EOF -> return data, err +// - found error -> return data, err +func readUntil(re io.Reader, delim byte) ([]byte, error) { + b := make([]byte, 0, MaxRequestSize/4) + var errRead error + for { + if len(b) == cap(b) { + // Add more capacity (let append pick how much). + b = append(b, 0)[:len(b)] + } + n, err := re.Read(b[len(b):cap(b)]) + b = b[:len(b)+n] + + delimIndex := bytes.IndexByte(b, delim) + if delimIndex >= 0 { + b = b[:delimIndex+1] + } + + if errors.Is(err, io.EOF) && delimIndex < 0 { + // EOF, but no delimiter found. + err = errors.Join(ErrBadRequest, io.ErrUnexpectedEOF) + } + + if delimIndex >= 0 || err != nil { + errRead = err + break + } + } + + return b, errRead +} diff --git a/gemax/server.go b/gemax/server.go index 5ad8d3b..9f3339c 100644 --- a/gemax/server.go +++ b/gemax/server.go @@ -154,7 +154,7 @@ func (server *Server) handle(ctx context.Context, conn net.Conn) { code = status.BadRequest } if errParseReq != nil { - server.logf("WARN: bad request: remote_ip=%s, code=%s: %v", conn.RemoteAddr(), code, errParseReq) + server.logf("WARN: bad request: remote_addr=%s, code=%s: %v", conn.RemoteAddr(), code, errParseReq) rw.WriteStatus(code, status.Text(code)) return } diff --git a/gemax/server_test.go b/gemax/server_test.go index a354ea4..47f6ee9 100644 --- a/gemax/server_test.go +++ b/gemax/server_test.go @@ -16,6 +16,8 @@ import ( "github.com/ninedraft/gemax/gemax" "github.com/ninedraft/gemax/gemax/internal/testaddr" "github.com/ninedraft/gemax/gemax/status" + + "tailscale.com/net/memnet" ) func TestServerSuccess(test *testing.T) { @@ -31,9 +33,9 @@ func TestServerSuccess(test *testing.T) { } }) - var resp = listener.next(test.Name(), strings.NewReader("gemini://example.com/path")) + var resp = dialAndWrite(test, ctx, listener, "gemini://example.com/path\r\n") - expectResponse(test, resp, "20 text/gemini\r\ngemini://example.com/path") + expectResponse(test, strings.NewReader(resp), "20 text/gemini\r\ngemini://example.com/path") } func TestServerBadRequest(test *testing.T) { @@ -48,9 +50,9 @@ func TestServerBadRequest(test *testing.T) { } }) - var resp = listener.next(test.Name(), strings.NewReader("invalid URL")) + var resp = dialAndWrite(test, ctx, listener, "invalid URL\r\n") - expectResponse(test, resp, "59 "+status.Text(status.BadRequest)+"\r\n") + expectResponse(test, strings.NewReader(resp), "59 "+status.Text(status.BadRequest)+"\r\n") } func TestServerInvalidHost(test *testing.T) { @@ -67,9 +69,9 @@ func TestServerInvalidHost(test *testing.T) { } }) - var resp = listener.next(test.Name(), strings.NewReader("gemini://another.com/path")) + var resp = dialAndWrite(test, ctx, listener, "gemini://another.com/path\r\n") - expectResponse(test, resp, "50 host not found\r\n") + expectResponse(test, strings.NewReader(resp), "50 host not found\r\n") } func TestServerCancelListen(test *testing.T) { @@ -234,9 +236,9 @@ func TestURLDotEscape(test *testing.T) { } }) - var resp = listener.next(test.Name(), strings.NewReader("gemini://example.com/../../\r\n")) + var resp = dialAndWrite(test, ctx, listener, "gemini://example.com/./\r\n") - expectResponse(test, resp, "50 50 PERMANENT FAILURE\r\n") + expectResponse(test, strings.NewReader(resp), "50 50 PERMANENT FAILURE\r\n") } // emulates michael-lazar/gemini-diagnostics localhost 9999 --checks='PageNotFound' @@ -257,15 +259,15 @@ func TestPageNotFound(test *testing.T) { } }) - var resp = listener.next(test.Name(), strings.NewReader("gemini://example.com/notexist\r\n")) + var resp = dialAndWrite(test, ctx, listener, "gemini://example.com/notexist\r\n") - expectResponse(test, resp, "51 gemini://example.com/notexist is not found\r\n") + expectResponse(test, strings.NewReader(resp), "51 gemini://example.com/notexist is not found\r\n") }) test.Run("custom", func(test *testing.T) { test.Log("meta must not interfere with response body") var listener, server = setupServer(test, - func(_ context.Context, rw gemax.ResponseWriter, req gemax.IncomingRequest) { + func(_ context.Context, rw gemax.ResponseWriter, _ gemax.IncomingRequest) { rw.WriteStatus(status.NotFound, "page is not found\r\ndotdot") }) server.Hosts = []string{"example.com"} @@ -279,25 +281,25 @@ func TestPageNotFound(test *testing.T) { } }) - var resp = listener.next(test.Name(), strings.NewReader("gemini://example.com/notexist\r\n")) + var resp = dialAndWrite(test, ctx, listener, "gemini://example.com/notexist\r\n") - expectResponse(test, resp, "51 page is not found\tdotdot\r\n") + expectResponse(test, strings.NewReader(resp), "51 page is not found\tdotdot\r\n") }) } -func setupServer(t *testing.T, handler gemax.Handler) (*fakeListener, *gemax.Server) { +func setupServer(t *testing.T, handler gemax.Handler) (*memnet.Listener, *gemax.Server) { t.Helper() var server = &gemax.Server{ Logf: t.Logf, Handler: handler, } - var listener = newListener(t.Name()) + var listener = memnet.Listen(t.Name()) return listener, server } -func setupEchoServer(t *testing.T) (*fakeListener, *gemax.Server) { +func setupEchoServer(t *testing.T) (*memnet.Listener, *gemax.Server) { t.Helper() - return setupServer(t, func(ctx context.Context, rw gemax.ResponseWriter, req gemax.IncomingRequest) { + return setupServer(t, func(_ context.Context, rw gemax.ResponseWriter, req gemax.IncomingRequest) { _, _ = rw.Write([]byte(req.URL().String())) }) } @@ -313,79 +315,6 @@ func expectResponse(t *testing.T, got io.Reader, want string) { } } -type fakeListener struct { - conns chan *fakeConn - addr string -} - -func newListener(addr string) *fakeListener { - return &fakeListener{ - addr: addr, - conns: make(chan *fakeConn), - } -} - -func (listener *fakeListener) next(addr string, data io.Reader) io.Reader { - var pipe = newPipe() - listener.conns <- &fakeConn{ - addr: addr, - localAddr: addr, - Reader: data, - WriteCloser: pipe, - } - return pipe -} - -func (listener *fakeListener) Close() error { - close(listener.conns) - return nil -} - -func (listener *fakeListener) Accept() (net.Conn, error) { - var conn, ok = <-listener.conns - if !ok { - return nil, fmt.Errorf("listener closed: %w", io.EOF) - } - return conn, nil -} - -func (listener *fakeListener) Addr() net.Addr { - return fakeAddr(listener.addr) -} - -type fakeConn struct { - addr string - localAddr string - io.Reader - io.WriteCloser -} - -func (conn *fakeConn) RemoteAddr() net.Addr { - return fakeAddr(conn.addr) -} - -func (conn *fakeConn) LocalAddr() net.Addr { - return fakeAddr(conn.localAddr) -} - -func (conn *fakeConn) SetDeadline(t time.Time) error { - return nil -} - -func (conn *fakeConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (conn *fakeConn) SetWriteDeadline(t time.Time) error { - return nil -} - -type fakeAddr string - -func (fakeAddr) Network() string { return "fake network" } - -func (addr fakeAddr) String() string { return string(addr) } - func runTask(t *testing.T, task func()) { var done = make(chan struct{}) go func() { @@ -397,42 +326,30 @@ func runTask(t *testing.T, task func()) { }) } -type chPipe struct { - closed bool - ch chan byte -} +func dialAndWrite(t *testing.T, ctx context.Context, dialer *memnet.Listener, format string, args ...any) string { + t.Helper() -func newPipe() *chPipe { - return &chPipe{ - ch: make(chan byte), + t.Log("dialing in-memory network") + conn, errDial := dialer.Dial(ctx, "tcp", t.Name()) + if errDial != nil { + panic("dialin in-memory network: " + errDial.Error()) } -} -func (p *chPipe) Read(dst []byte) (int, error) { - for i := range dst { - var b, ok = <-p.ch - if !ok { - return i, io.EOF - } - dst[i] = b - } - return len(dst), nil -} + defer func() { _ = conn.Close() }() -func (p *chPipe) Write(data []byte) (int, error) { - for _, b := range data { - p.ch <- b + t.Log("writing to in-memory network") + _, errWrite := fmt.Fprintf(conn, format, args...) + if errWrite != nil { + panic("writing to in-memory network: " + errWrite.Error()) } - return len(data), nil -} -var errAlreadyClosed = errors.New("already closed") + var resp = &strings.Builder{} -func (p *chPipe) Close() error { - if p.closed { - return errAlreadyClosed + t.Log("reading from in-memory network") + _, errRead := io.Copy(resp, conn) + if errRead != nil { + panic("reading from in-memory network: " + errRead.Error()) } - close(p.ch) - p.closed = true - return nil + + return resp.String() } diff --git a/go.mod b/go.mod index 1d1c341..386b4f1 100644 --- a/go.mod +++ b/go.mod @@ -7,4 +7,5 @@ replace tailscale.com/net/memnet => ./vend/tailscale.com/net/memnet require ( golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 golang.org/x/net v0.11.0 + tailscale.com/net/memnet v1.42.0 ) From ccfc15915470e6b8924225fc4a11b5a92b4b03e0 Mon Sep 17 00:00:00 2001 From: merlin Date: Sat, 17 Jun 2023 15:36:29 +0300 Subject: [PATCH 4/5] handle no scheme as bad request --- gemax/request.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gemax/request.go b/gemax/request.go index d7b3c94..0c4ed4e 100644 --- a/gemax/request.go +++ b/gemax/request.go @@ -57,6 +57,10 @@ func ParseIncomingRequest(re io.Reader, remoteAddr string) (IncomingRequest, err return nil, fmt.Errorf("%w: %w", ErrBadRequest, errParse) } + if parsed.Scheme == "" { + return nil, fmt.Errorf("%w: missing scheme", ErrBadRequest) + } + if !isValidPath(parsed.Path) { return nil, fmt.Errorf("%w: %w", ErrBadRequest, errDotPath) } From f00ceca3a1d3de3e7cfbbc37c30f2f3b5865f0d9 Mon Sep 17 00:00:00 2001 From: merlin Date: Sat, 17 Jun 2023 15:36:47 +0300 Subject: [PATCH 5/5] always return bad request on parsing error --- gemax/server.go | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/gemax/server.go b/gemax/server.go index 9f3339c..3a98a0a 100644 --- a/gemax/server.go +++ b/gemax/server.go @@ -3,7 +3,6 @@ package gemax import ( "context" "crypto/tls" - "errors" "fmt" "net" "net/url" @@ -146,14 +145,8 @@ func (server *Server) handle(ctx context.Context, conn net.Conn) { } }() var req, errParseReq = ParseIncomingRequest(conn, conn.RemoteAddr().String()) - var code = status.Success - switch { - case errors.Is(errParseReq, errDotPath): - code = status.PermanentFailure - case errParseReq != nil: - code = status.BadRequest - } if errParseReq != nil { + const code = status.BadRequest server.logf("WARN: bad request: remote_addr=%s, code=%s: %v", conn.RemoteAddr(), code, errParseReq) rw.WriteStatus(code, status.Text(code)) return @@ -217,7 +210,7 @@ func (server *Server) validHost(u *url.URL) bool { func (server *Server) buildHosts() { if server.hosts == nil { - server.hosts = map[string]struct{}{} + server.hosts = make(map[string]struct{}, len(server.Hosts)) } for _, host := range server.Hosts { server.hosts[host] = struct{}{}