Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix request parsing #38

Merged
merged 5 commits into from
Jun 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 92 additions & 10 deletions gemax/request.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
package gemax

import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/fs"
"net/url"
"strings"

"golang.org/x/exp/slices"
)

var requestSuffix = []byte("\n")

// MaxRequestSize is the maximum incoming request size in bytes.
const MaxRequestSize = 1026
const MaxRequestSize = int64(1024 + len("\r\n"))

// IncomingRequest describes a server side request object.
type IncomingRequest interface {
Expand All @@ -21,31 +29,72 @@ var (
errDotPath = errors.New("dots in path are not permitted")
)

var ErrBadRequest = errors.New("bad request")

// ParseIncomingRequest constructs an IncomingRequest from bytestream
// and additional parameters (remote address for now).
func ParseIncomingRequest(re io.Reader, remoteAddr string) (IncomingRequest, error) {
var reader = io.LimitReader(re, MaxRequestSize)
var u string
var _, errReadRequest = fmt.Fscanf(reader, "%s\r\n", &u)
if errReadRequest != nil {
return nil, fmt.Errorf("bad request: %w", errReadRequest)
var certs []*x509.Certificate
if tlsConn, ok := re.(*tls.Conn); ok {
certs = slices.Clone(tlsConn.ConnectionState().PeerCertificates)
}

re = io.LimitReader(re, MaxRequestSize)

line, errLine := readUntil(re, '\n')
if errLine != nil {
return nil, errLine
}

if !bytes.HasSuffix(line, requestSuffix) {
return nil, ErrBadRequest
}
var parsed, errParse = url.ParseRequestURI(u)

line = bytes.TrimRight(line, "\r\n")

parsed, errParse := url.ParseRequestURI(string(line))
if errParse != nil {
return nil, fmt.Errorf("bad request: %w", errParse)
return nil, fmt.Errorf("%w: %w", ErrBadRequest, errParse)
}

if parsed.Scheme == "" {
return nil, fmt.Errorf("%w: missing scheme", ErrBadRequest)
}
if strings.Contains(parsed.Path, "/..") {
return nil, errDotPath

if !isValidPath(parsed.Path) {
return nil, fmt.Errorf("%w: %w", ErrBadRequest, errDotPath)
}

if parsed.Path == "" {
parsed.Path = "/"
}

return &incomingRequest{
url: parsed,
remoteAddr: remoteAddr,
certs: certs,
}, nil
}

func isValidPath(path string) bool {

path = strings.TrimPrefix(path, "/")
path = strings.TrimSuffix(path, "/")

switch path {
case ".", "..":
return false
case "":
return true
}

return fs.ValidPath(path)
}

type incomingRequest struct {
url *url.URL
remoteAddr string
certs []*x509.Certificate
}

func (req *incomingRequest) URL() *url.URL {
Expand All @@ -55,3 +104,36 @@ func (req *incomingRequest) URL() *url.URL {
func (req *incomingRequest) RemoteAddr() string {
return req.remoteAddr
}

// - found delimiter -> return data[:delimIndex+1], err
// - found EOF -> return data, err
// - found error -> return data, err
func readUntil(re io.Reader, delim byte) ([]byte, error) {
b := make([]byte, 0, MaxRequestSize/4)
var errRead error
for {
if len(b) == cap(b) {
// Add more capacity (let append pick how much).
b = append(b, 0)[:len(b)]
}
n, err := re.Read(b[len(b):cap(b)])
b = b[:len(b)+n]

delimIndex := bytes.IndexByte(b, delim)
if delimIndex >= 0 {
b = b[:delimIndex+1]
}

if errors.Is(err, io.EOF) && delimIndex < 0 {
// EOF, but no delimiter found.
err = errors.Join(ErrBadRequest, io.ErrUnexpectedEOF)
}

if delimIndex >= 0 || err != nil {
errRead = err
break
}
}

return b, errRead
}
83 changes: 83 additions & 0 deletions gemax/request_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package gemax_test

import (
"strings"
"testing"

"github.com/ninedraft/gemax/gemax"
)

func TestParseIncomingRequest(t *testing.T) {
t.Parallel()
t.Log("parsing incoming request line")

const remoteAddr = "remote-addr"
type expect struct {
err bool
url string
}

tc := func(name, input string, expected expect) {
t.Run(name, func(t *testing.T) {
t.Parallel()

re := strings.NewReader(input)

parsed, err := gemax.ParseIncomingRequest(re, remoteAddr)

if (err != nil) != expected.err {
t.Errorf("error = %v, want error = %v", err, expected.err)
}

if parsed == nil && err == nil {
t.Error("parsed = nil, want not nil")
return
}

if parsed != nil {
assertEq(t, parsed.RemoteAddr(), remoteAddr, "remote addr")
assertEq(t, parsed.URL().String(), expected.url, "url")
}
})
}

tc("valid",
"gemini://example.com\r\n", expect{
url: "gemini://example.com/",
})
tc("valid no \\r",
"gemini://example.com\n", expect{
url: "gemini://example.com/",
})
tc("valid with path",
"gemini://example.com/path\r\n", expect{
url: "gemini://example.com/path",
})
tc("valid with path and query",
"gemini://example.com/path?query=value\r\n", expect{
url: "gemini://example.com/path?query=value",
})
tc("valid http",
"http://example.com\r\n", expect{
url: "http://example.com/",
})

tc("too long",
"http://example.com/"+strings.Repeat("a", 2048)+"\r\n",
expect{err: true})
tc("empty",
"", expect{err: true})
tc("no new \\r\\n",
"gemini://example.com", expect{err: true})
tc("no \\n",
"gemini://example.com\r", expect{err: true})
}

func assertEq[E comparable](t *testing.T, got, want E, format string, args ...any) {
t.Helper()

if got != want {
t.Errorf("got %v, want %v", got, want)
t.Errorf(format, args...)
}
}
13 changes: 3 additions & 10 deletions gemax/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package gemax
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
Expand Down Expand Up @@ -146,15 +145,9 @@ func (server *Server) handle(ctx context.Context, conn net.Conn) {
}
}()
var req, errParseReq = ParseIncomingRequest(conn, conn.RemoteAddr().String())
var code = status.Success
switch {
case errors.Is(errParseReq, errDotPath):
code = status.PermanentFailure
case errParseReq != nil:
code = status.BadRequest
}
if errParseReq != nil {
server.logf("WARN: bad request: remote_ip=%s, code=%s: %v", conn.RemoteAddr(), code, errParseReq)
const code = status.BadRequest
server.logf("WARN: bad request: remote_addr=%s, code=%s: %v", conn.RemoteAddr(), code, errParseReq)
rw.WriteStatus(code, status.Text(code))
return
}
Expand Down Expand Up @@ -217,7 +210,7 @@ func (server *Server) validHost(u *url.URL) bool {

func (server *Server) buildHosts() {
if server.hosts == nil {
server.hosts = map[string]struct{}{}
server.hosts = make(map[string]struct{}, len(server.Hosts))
}
for _, host := range server.Hosts {
server.hosts[host] = struct{}{}
Expand Down
Loading