Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change HandledMethods to AllowGet and cleanup method handling #191

Merged
merged 1 commit into from
Apr 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 7 additions & 19 deletions http/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ type ServerConfig struct {
// Headers is an optional map of headers that is written out.
Headers map[string][]string

// HandledMethods set which methods will be handled for the HTTP
// requests. Other methods will return 405. This is different from CORS
// AllowedMethods (the API may handle GET and POST, but only allow GETs
// for CORS-enabled requests via AllowedMethods).
HandledMethods []string
// AllowGet indicates whether or not this server accepts GET requests.
// When unset, the server only accepts POST, HEAD, and OPTIONS.
//
// This is different from CORS AllowedMethods. The API may allow GET
// requests in general, but reject them in CORS. That will allow
// websites to include resources from the API but not _read_ them.
AllowGet bool

// corsOpts is a set of options for CORS headers.
corsOpts *cors.Options
Expand All @@ -38,7 +40,6 @@ type ServerConfig struct {
func NewServerConfig() *ServerConfig {
cfg := new(ServerConfig)
cfg.corsOpts = new(cors.Options)
cfg.HandledMethods = []string{http.MethodPost}
return cfg
}

Expand Down Expand Up @@ -149,16 +150,3 @@ func allowReferer(r *http.Request, cfg *ServerConfig) bool {

return false
}

// handleRequestMethod returns true if the request method is among
// HandledMethods.
func handleRequestMethod(r *http.Request, cfg *ServerConfig) bool {
// For very small slices as these, this should be faster than
// a map lookup.
for _, m := range cfg.HandledMethods {
if r.Method == m {
return true
}
}
return false
}
10 changes: 5 additions & 5 deletions http/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func TestErrors(t *testing.T) {

mkTest := func(tc testcase) func(*testing.T) {
return func(t *testing.T) {
_, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/
_, srv := getTestServer(t, nil, false) // handler_test:/^func getTestServer/
c := NewClient(srv.URL)
req, err := cmds.NewRequest(context.Background(), tc.path, tc.opts, nil, nil, cmdRoot)
if err != nil {
Expand Down Expand Up @@ -161,11 +161,11 @@ func TestErrors(t *testing.T) {

func TestUnhandledMethod(t *testing.T) {
tc := httpTestCase{
Method: "GET",
HandledMethods: []string{"POST"},
Code: http.StatusMethodNotAllowed,
Method: "GET",
AllowGet: false,
Code: http.StatusMethodNotAllowed,
ResHeaders: map[string]string{
"Allow": "POST",
"Allow": "POST, HEAD, OPTIONS",
},
}
tc.test(t)
Expand Down
46 changes: 33 additions & 13 deletions http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,27 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// First of all, check if we are allowed to handle the request method
// or we are configured not to.
if !handleRequestMethod(r, h.cfg) {
setAllowedHeaders(w, h.cfg.HandledMethods)
//
// Always allow OPTIONS, POST
switch r.Method {
case http.MethodOptions:
// If we get here, this is a normal (non-preflight) request.
// The CORS library handles all other requests.

// Tell the user the allowed methods, and return.
setAllowedHeaders(w, h.cfg.AllowGet)
w.WriteHeader(http.StatusNoContent)
return
case http.MethodPost:
case http.MethodGet, http.MethodHead:
if h.cfg.AllowGet {
break
}
fallthrough
default:
setAllowedHeaders(w, h.cfg.AllowGet)
http.Error(w, "405 - Method Not Allowed", http.StatusMethodNotAllowed)
log.Warnf("The IPFS API does not support %s requests. All requests must use %s", h.cfg.HandledMethods)
log.Warnf("The IPFS API does not support %s requests.", r.Method)
return
}

Expand Down Expand Up @@ -139,6 +156,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

// set user's headers first.
for k, v := range h.cfg.Headers {
if !skipAPIHeader(k) {
w.Header()[k] = v
}
}

// Handle the timeout up front.
var cancel func()
if timeoutStr, ok := req.Options[cmds.TimeoutOpt]; ok {
Expand All @@ -163,13 +187,6 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer done()
}

// set user's headers first.
for k, v := range h.cfg.Headers {
if !skipAPIHeader(k) {
w.Header()[k] = v
}
}

h.root.Call(req, re, h.env)
}

Expand All @@ -180,8 +197,11 @@ func sanitizedErrStr(err error) string {
return s
}

func setAllowedHeaders(w http.ResponseWriter, methods []string) {
for _, m := range methods {
w.Header().Add("Allow", m)
func setAllowedHeaders(w http.ResponseWriter, allowGet bool) {
w.Header().Add("Allow", http.MethodHead)
w.Header().Add("Allow", http.MethodOptions)
w.Header().Add("Allow", http.MethodPost)
if allowGet {
w.Header().Add("Allow", http.MethodGet)
}
}
9 changes: 2 additions & 7 deletions http/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ var (
}
)

func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmds.Environment, *httptest.Server) {
func getTestServer(t *testing.T, origins []string, allowGet bool) (cmds.Environment, *httptest.Server) {
if len(origins) == 0 {
origins = defaultOrigins
}
Expand All @@ -306,12 +306,7 @@ func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmd
}

srvCfg := originCfg(origins)

if len(handledMethods) == 0 {
srvCfg.HandledMethods = []string{"GET", "POST"}
} else {
srvCfg.HandledMethods = handledMethods
}
srvCfg.AllowGet = allowGet

return env, httptest.NewServer(NewHandler(env, cmdRoot, srvCfg))
}
Expand Down
2 changes: 1 addition & 1 deletion http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func TestHTTP(t *testing.T) {

mkTest := func(tc testcase) func(*testing.T) {
return func(t *testing.T) {
env, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/
env, srv := getTestServer(t, nil, true) // handler_test:/^func getTestServer/
c := NewClient(srv.URL)
req, err := cmds.NewRequest(context.Background(), tc.path, nil, nil, nil, cmdRoot)
if err != nil {
Expand Down
61 changes: 48 additions & 13 deletions http/reforigin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,42 @@ import (
"fmt"
"net/http"
"net/url"
"strings"
"testing"

cmds "github.com/ipfs/go-ipfs-cmds"
)

func assertHeaders(t *testing.T, resHeaders http.Header, reqHeaders map[string]string) {
t.Helper()
t.Logf("headers: %v", resHeaders)
for name, value := range reqHeaders {
if resHeaders.Get(name) != value {
t.Errorf("Invalid header '%s', wanted '%s', got '%s'", name, value, resHeaders.Get(name))
header := resHeaders[http.CanonicalHeaderKey(name)]
switch len(header) {
case 0:
if value != "" {
t.Errorf("expected a header for %s", name)
}
case 1:
if header[0] != value {
t.Errorf("Invalid header '%s', wanted '%s', got '%s'", name, value, header[0])
}
default:
values := strings.Split(value, ",")
set := make(map[string]bool, len(values))
for _, v := range values {
set[strings.Trim(v, " ")] = true
}
for _, got := range header {
if !set[got] {
t.Errorf("found unexpected value %s in header %s", got, name)
continue
}
delete(set, got)
}
for missing := range set {
t.Errorf("missing value %s in header %s", missing, name)
}
}
}
}
Expand All @@ -27,7 +54,7 @@ func originCfg(origins []string) *ServerConfig {
cfg := NewServerConfig()
cfg.SetAllowedOrigins(origins...)
cfg.SetAllowedMethods("GET", "PUT", "POST")
cfg.HandledMethods = []string{"GET", "POST"}
cfg.AllowGet = true
return cfg
}

Expand All @@ -39,18 +66,19 @@ var defaultOrigins = []string{
}

type httpTestCase struct {
Method string
Path string
Code int
Origin string
Referer string
AllowOrigins []string
HandledMethods []string
ReqHeaders map[string]string
ResHeaders map[string]string
Method string
Path string
Code int
Origin string
Referer string
AllowOrigins []string
AllowGet bool
ReqHeaders map[string]string
ResHeaders map[string]string
}

func (tc *httpTestCase) test(t *testing.T) {
t.Helper()
// defaults
method := tc.Method
if method == "" {
Expand Down Expand Up @@ -85,7 +113,7 @@ func (tc *httpTestCase) test(t *testing.T) {
}

// server
_, server := getTestServer(t, tc.AllowOrigins, tc.HandledMethods)
_, server := getTestServer(t, tc.AllowOrigins, tc.AllowGet)
if server == nil {
return
}
Expand Down Expand Up @@ -114,6 +142,7 @@ func TestDisallowedOrigins(t *testing.T) {
return httpTestCase{
Origin: origin,
AllowOrigins: allowedOrigins,
AllowGet: true,
ResHeaders: map[string]string{
ACAOrigin: "",
ACAMethods: "",
Expand Down Expand Up @@ -144,6 +173,7 @@ func TestAllowedOrigins(t *testing.T) {
return httpTestCase{
Origin: origin,
AllowOrigins: allowedOrigins,
AllowGet: true,
ResHeaders: map[string]string{
ACAOrigin: origin,
ACAMethods: "",
Expand Down Expand Up @@ -171,6 +201,7 @@ func TestWildcardOrigin(t *testing.T) {
gtc := func(origin string, allowedOrigins []string) httpTestCase {
return httpTestCase{
Origin: origin,
AllowGet: true,
AllowOrigins: allowedOrigins,
ResHeaders: map[string]string{
ACAOrigin: "*",
Expand Down Expand Up @@ -204,6 +235,7 @@ func TestDisallowedReferer(t *testing.T) {
return httpTestCase{
Origin: "http://localhost",
Referer: referer,
AllowGet: true,
AllowOrigins: allowedOrigins,
ResHeaders: map[string]string{
ACAOrigin: "http://localhost",
Expand Down Expand Up @@ -232,6 +264,7 @@ func TestAllowedReferer(t *testing.T) {
return httpTestCase{
Origin: "http://localhost",
AllowOrigins: allowedOrigins,
AllowGet: true,
ResHeaders: map[string]string{
ACAOrigin: "http://localhost",
ACAMethods: "",
Expand Down Expand Up @@ -260,6 +293,7 @@ func TestWildcardReferer(t *testing.T) {
return httpTestCase{
Origin: origin,
AllowOrigins: allowedOrigins,
AllowGet: true,
ResHeaders: map[string]string{
ACAOrigin: "*",
ACAMethods: "",
Expand Down Expand Up @@ -338,6 +372,7 @@ func TestEncoding(t *testing.T) {
return httpTestCase{
Method: "GET",
Path: path,
AllowGet: true,
Origin: "http://localhost",
AllowOrigins: []string{"*"},
ReqHeaders: map[string]string{
Expand Down
2 changes: 1 addition & 1 deletion http/responseemitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (re *responseEmitter) Emit(value interface{}) error {
var err error

// return immediately if this is a head request
if re.method == "HEAD" {
if re.method == http.MethodHead {
return nil
}

Expand Down