Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decode path-encoded URL components #2332

Merged
merged 4 commits into from
Sep 10, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,31 @@ import (
"google.golang.org/protobuf/proto"
)

// UnescapingMode defines the behavior of grpc-gateway for URL escaping.
v3n marked this conversation as resolved.
Show resolved Hide resolved
type UnescapingMode int

const (
// UnescapingModeLegacy is the default V2 behavior, which escapes the entire
// path string before doing any routing.
UnescapingModeLegacy UnescapingMode = iota

// EscapingTypeExceptReserved unescapes all path parameters except RFC 6570
// reserved characters.
UnescapingModeAllExceptReserved

// EscapingTypeExceptSlash unescapes URL path parameters except path
// seperators, which will be left as "%2F".
UnescapingModeAllExceptSlash

// URL path parameters will be fully decoded.
UnescapingModeAllCharacters

// UnescapingModeDefault is the default escaping type.
// TODO(v3): default this to UnescapingModeAllExceptReserved per grpc-httpjson-transcoding's
// reference implementation
UnescapingModeDefault = UnescapingModeLegacy
)

// A HandlerFunc handles a specific pair of path pattern and HTTP method.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)

Expand All @@ -31,6 +56,7 @@ type ServeMux struct {
streamErrorHandler StreamErrorHandlerFunc
routingErrorHandler RoutingErrorHandlerFunc
disablePathLengthFallback bool
unescapingMode UnescapingMode
}

// ServeMuxOption is an option that can be given to a ServeMux on construction.
Expand All @@ -48,6 +74,14 @@ func WithForwardResponseOption(forwardResponseOption func(context.Context, http.
}
}

// WithEscapingType sets the escaping type. See the definitions of UnescapingMode
// for more information.
func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.unescapingMode = mode
}
}

// SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
// Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
// done with careful consideration.
Expand Down Expand Up @@ -153,6 +187,7 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux {
errorHandler: DefaultHTTPErrorHandler,
streamErrorHandler: DefaultStreamErrorHandler,
routingErrorHandler: DefaultRoutingErrorHandler,
unescapingMode: UnescapingModeDefault,
}

for _, opt := range opts {
Expand Down Expand Up @@ -204,6 +239,11 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

// TODO(v3): remove UnescapingModeLegacy
if s.unescapingMode != UnescapingModeLegacy && r.URL.RawPath != "" {
path = r.URL.RawPath
}

components := strings.Split(path[1:], "/")

if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
Expand Down Expand Up @@ -244,8 +284,12 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
components[l-1], verb = lastComponent[:idx], lastComponent[idx+1:]
}

pathParams, err := h.pat.Match(components, verb)
pathParams, err := h.pat.MatchAndEscape(components, verb, s.unescapingMode)
if err != nil {
if err == ErrMalformedSequence {
johanbrandhorst marked this conversation as resolved.
Show resolved Hide resolved
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest)
}
continue
}
h.h(w, r, pathParams)
Expand All @@ -259,8 +303,12 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
continue
}
for _, h := range handlers {
pathParams, err := h.pat.Match(components, verb)
pathParams, err := h.pat.MatchAndEscape(components, verb, s.unescapingMode)
if err != nil {
if err == ErrMalformedSequence {
johanbrandhorst marked this conversation as resolved.
Show resolved Hide resolved
_, outboundMarshaler := MarshalerForRequest(s, r)
s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest)
}
continue
}
// X-HTTP-Method-Override is optional. Always allow fallback to POST.
Expand Down
66 changes: 65 additions & 1 deletion runtime/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func TestMuxServeHTTP(t *testing.T) {
respContent string

disablePathLengthFallback bool
unescapingMode runtime.UnescapingMode
}{
{
patterns: nil,
Expand Down Expand Up @@ -330,11 +331,74 @@ func TestMuxServeHTTP(t *testing.T) {
respStatus: http.StatusOK,
respContent: "POST /foo/{id=*}:verb:subverb",
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{int(utilities.OpLitPush), 0, int(utilities.OpPush), 1, int(utilities.OpCapture), 1, int(utilities.OpLitPush), 2},
pool: []string{"foo", "id", "bar"},
},
},
reqMethod: "POST",
reqPath: "/foo/404%2fwith%2Fspace/bar",
headers: map[string]string{
"Content-Type": "application/json",
},
respStatus: http.StatusNotFound,
unescapingMode: runtime.UnescapingModeLegacy,
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpPush), 0,
int(utilities.OpConcatN), 1,
int(utilities.OpCapture), 1,
int(utilities.OpLitPush), 2},
pool: []string{"foo", "id", "bar"},
},
},
reqMethod: "GET",
reqPath: "/foo/success%2fwith%2Fspace/bar",
headers: map[string]string{
"Content-Type": "application/json",
},
respStatus: http.StatusOK,
unescapingMode: runtime.UnescapingModeAllExceptReserved,
respContent: "GET /foo/{id=*}/bar",
},
{
patterns: []stubPattern{
{
method: "GET",
ops: []int{
int(utilities.OpLitPush), 0,
int(utilities.OpPushM), 0,
int(utilities.OpConcatN), 1,
int(utilities.OpCapture), 1,
},
pool: []string{"foo", "id", "bar"},
},
},
reqMethod: "GET",
reqPath: "/foo/success%2fwith%2Fspace",
headers: map[string]string{
"Content-Type": "application/json",
},
respStatus: http.StatusOK,
unescapingMode: runtime.UnescapingModeAllExceptReserved,
respContent: "GET /foo/{id=**}",
},
} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
var opts []runtime.ServeMuxOption
opts = append(opts, runtime.WithUnescapingMode(spec.unescapingMode))
if spec.disablePathLengthFallback {
opts = append(opts, runtime.WithDisablePathLengthFallback())
opts = append(opts,
runtime.WithDisablePathLengthFallback(),
)
}
mux := runtime.NewServeMux(opts...)
for _, p := range spec.patterns {
Expand Down
151 changes: 146 additions & 5 deletions runtime/pattern.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ var (
ErrNotMatch = errors.New("not match to the path pattern")
// ErrInvalidPattern indicates that the given definition of Pattern is not valid.
ErrInvalidPattern = errors.New("invalid pattern")
// ErrMalformedSequence indicates that an escape sequence was malformed.
ErrMalformedSequence = errors.New("malformed escape sequence")
)

type op struct {
Expand Down Expand Up @@ -140,10 +142,11 @@ func MustPattern(p Pattern, err error) Pattern {
return p
}

// Match examines components if it matches to the Pattern.
// If it matches, the function returns a mapping from field paths to their captured values.
// If otherwise, the function returns an error.
func (p Pattern) Match(components []string, verb string) (map[string]string, error) {
// MatchAndEscape examines components if it matches to the Pattern. If it matches,
// the function returns a mapping from field paths to their captured values while
// applying the provided unescaping mode, returning an error if the URL encoding
// is malformed. Otherwise, the function returns an error.
v3n marked this conversation as resolved.
Show resolved Hide resolved
func (p Pattern) MatchAndEscape(components []string, verb string, unescapingMode UnescapingMode) (map[string]string, error) {
if p.verb != verb {
if p.verb != "" {
return nil, ErrNotMatch
Expand All @@ -161,6 +164,8 @@ func (p Pattern) Match(components []string, verb string) (map[string]string, err
captured := make([]string, len(p.vars))
l := len(components)
for _, op := range p.ops {
var err error

switch op.code {
case utilities.OpNop:
continue
Expand All @@ -173,6 +178,10 @@ func (p Pattern) Match(components []string, verb string) (map[string]string, err
if lit := p.pool[op.operand]; c != lit {
return nil, ErrNotMatch
}
} else if op.code == utilities.OpPush {
if c, err = unescape(c, unescapingMode, false); err != nil {
return nil, ErrMalformedSequence
}
}
stack = append(stack, c)
pos++
Expand All @@ -182,7 +191,11 @@ func (p Pattern) Match(components []string, verb string) (map[string]string, err
return nil, ErrNotMatch
}
end -= p.tailLen
stack = append(stack, strings.Join(components[pos:end], "/"))
c := strings.Join(components[pos:end], "/")
if c, err = unescape(c, unescapingMode, true); err != nil {
return nil, ErrMalformedSequence
}
stack = append(stack, c)
pos = end
case utilities.OpConcatN:
n := op.operand
Expand All @@ -204,6 +217,15 @@ func (p Pattern) Match(components []string, verb string) (map[string]string, err
return bindings, nil
}

// Match examines components if it matches to the Pattern.
// If it matches, the function returns a mapping from field paths to their captured values.
// If otherwise, the function returns an error.
//
// Deprecated: Use MatchAndEscape.
v3n marked this conversation as resolved.
Show resolved Hide resolved
func (p Pattern) Match(components []string, verb string) (map[string]string, error) {
return p.MatchAndEscape(components, verb, UnescapingModeDefault)
}

// Verb returns the verb part of the Pattern.
func (p Pattern) Verb() string { return p.verb }

Expand Down Expand Up @@ -234,3 +256,122 @@ func (p Pattern) String() string {
}
return "/" + segs
}

/*
* The following code is adopted and modified from Go's standard library
* and carries the attached license.
*
* Copyright 2009 The Go Authors. All rights reserved.
* Use of this source code is governed by a BSD-style
* license that can be found in the LICENSE file.
*/

// ishex returns whether or not the given byte is a valid hex character
func ishex(c byte) bool {
switch {
case '0' <= c && c <= '9':
return true
case 'a' <= c && c <= 'f':
return true
case 'A' <= c && c <= 'F':
return true
}
return false
}


func isRFC6570Reserved(c byte) bool {
switch c {
case '!', '#', '$', '&', '\'', '(', ')', '*',
'+', ',', '/', ':', ';', '=', '?', '@', '[', ']':
return true
default:
return false
}
}

// unhex converts a hex point to the bit representation
func unhex(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 0
}

// shouldUnescapeWithMode returns true if the character is escapable with the
// given mode
func shouldUnescapeWithMode(c byte, mode UnescapingMode) bool {
switch mode {
case UnescapingModeAllExceptReserved:
if isRFC6570Reserved(c) {
return false
}
case UnescapingModeAllExceptSlash:
if c == '/' {
return false
}
case UnescapingModeAllCharacters:
return true
}
return true
}

// unescape unescapes a path string using the provided mode
func unescape(s string, mode UnescapingMode, multisegment bool) (string, error) {
// TODO(v3): remove UnescapingModeLegacy
if mode == UnescapingModeLegacy {
return s, nil
}

if !multisegment {
mode = UnescapingModeAllCharacters
}

// Count %, check that they're well-formed.
n := 0
for i := 0; i < len(s); {
if s[i] == '%' {
n++
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
s = s[i:]
if len(s) > 3 {
s = s[:3]
}

return "", ErrMalformedSequence
}
i += 3
} else {
i++
}
}

if n == 0 {
return s, nil
}

var t strings.Builder
t.Grow(len(s))
for i := 0; i < len(s); i++ {
switch s[i] {
case '%':
c := unhex(s[i+1])<<4 | unhex(s[i+2])
if shouldUnescapeWithMode(c, mode) {
t.WriteByte(c)
i += 2
continue
}
fallthrough
default:
t.WriteByte(s[i])
}
}

return t.String(), nil
}

Loading