diff --git a/http/config.go b/http/config.go index aff4164d..0232922e 100644 --- a/http/config.go +++ b/http/config.go @@ -22,11 +22,13 @@ type ServerConfig struct { // Headers is an optional map of headers that is written out. Headers map[string][]string - // HandledMethods set which methods will be handled for the HTTP - // requests. Other methods will return 405. This is different from CORS - // AllowedMethods (the API may handle GET and POST, but only allow GETs - // for CORS-enabled requests via AllowedMethods). - HandledMethods []string + // AllowGet indicates whether or not this server accepts GET requests. + // When unset, the server only accepts POST, HEAD, and OPTIONS. + // + // This is different from CORS AllowedMethods. The API may allow GET + // requests in general, but reject them in CORS. That will allow + // websites to include resources from the API but not _read_ them. + AllowGet bool // corsOpts is a set of options for CORS headers. corsOpts *cors.Options @@ -38,7 +40,6 @@ type ServerConfig struct { func NewServerConfig() *ServerConfig { cfg := new(ServerConfig) cfg.corsOpts = new(cors.Options) - cfg.HandledMethods = []string{http.MethodPost} return cfg } @@ -149,16 +150,3 @@ func allowReferer(r *http.Request, cfg *ServerConfig) bool { return false } - -// handleRequestMethod returns true if the request method is among -// HandledMethods. -func handleRequestMethod(r *http.Request, cfg *ServerConfig) bool { - // For very small slices as these, this should be faster than - // a map lookup. - for _, m := range cfg.HandledMethods { - if r.Method == m { - return true - } - } - return false -} diff --git a/http/errors_test.go b/http/errors_test.go index 3c004eee..a320b7d7 100644 --- a/http/errors_test.go +++ b/http/errors_test.go @@ -116,7 +116,7 @@ func TestErrors(t *testing.T) { mkTest := func(tc testcase) func(*testing.T) { return func(t *testing.T) { - _, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/ + _, srv := getTestServer(t, nil, false) // handler_test:/^func getTestServer/ c := NewClient(srv.URL) req, err := cmds.NewRequest(context.Background(), tc.path, tc.opts, nil, nil, cmdRoot) if err != nil { @@ -161,11 +161,11 @@ func TestErrors(t *testing.T) { func TestUnhandledMethod(t *testing.T) { tc := httpTestCase{ - Method: "GET", - HandledMethods: []string{"POST"}, - Code: http.StatusMethodNotAllowed, + Method: "GET", + AllowGet: false, + Code: http.StatusMethodNotAllowed, ResHeaders: map[string]string{ - "Allow": "POST", + "Allow": "POST, HEAD, OPTIONS", }, } tc.test(t) diff --git a/http/handler.go b/http/handler.go index a5ccb053..b88e4e63 100644 --- a/http/handler.go +++ b/http/handler.go @@ -97,10 +97,27 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // First of all, check if we are allowed to handle the request method // or we are configured not to. - if !handleRequestMethod(r, h.cfg) { - setAllowedHeaders(w, h.cfg.HandledMethods) + // + // Always allow OPTIONS, POST + switch r.Method { + case http.MethodOptions: + // If we get here, this is a normal (non-preflight) request. + // The CORS library handles all other requests. + + // Tell the user the allowed methods, and return. + setAllowedHeaders(w, h.cfg.AllowGet) + w.WriteHeader(http.StatusNoContent) + return + case http.MethodPost: + case http.MethodGet, http.MethodHead: + if h.cfg.AllowGet { + break + } + fallthrough + default: + setAllowedHeaders(w, h.cfg.AllowGet) http.Error(w, "405 - Method Not Allowed", http.StatusMethodNotAllowed) - log.Warnf("The IPFS API does not support %s requests. All requests must use %s", h.cfg.HandledMethods) + log.Warnf("The IPFS API does not support %s requests.", r.Method) return } @@ -139,6 +156,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + // set user's headers first. + for k, v := range h.cfg.Headers { + if !skipAPIHeader(k) { + w.Header()[k] = v + } + } + // Handle the timeout up front. var cancel func() if timeoutStr, ok := req.Options[cmds.TimeoutOpt]; ok { @@ -163,13 +187,6 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer done() } - // set user's headers first. - for k, v := range h.cfg.Headers { - if !skipAPIHeader(k) { - w.Header()[k] = v - } - } - h.root.Call(req, re, h.env) } @@ -180,8 +197,11 @@ func sanitizedErrStr(err error) string { return s } -func setAllowedHeaders(w http.ResponseWriter, methods []string) { - for _, m := range methods { - w.Header().Add("Allow", m) +func setAllowedHeaders(w http.ResponseWriter, allowGet bool) { + w.Header().Add("Allow", http.MethodHead) + w.Header().Add("Allow", http.MethodOptions) + w.Header().Add("Allow", http.MethodPost) + if allowGet { + w.Header().Add("Allow", http.MethodGet) } } diff --git a/http/handler_test.go b/http/handler_test.go index 9c268a81..ba8dd147 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -292,7 +292,7 @@ var ( } ) -func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmds.Environment, *httptest.Server) { +func getTestServer(t *testing.T, origins []string, allowGet bool) (cmds.Environment, *httptest.Server) { if len(origins) == 0 { origins = defaultOrigins } @@ -306,12 +306,7 @@ func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmd } srvCfg := originCfg(origins) - - if len(handledMethods) == 0 { - srvCfg.HandledMethods = []string{"GET", "POST"} - } else { - srvCfg.HandledMethods = handledMethods - } + srvCfg.AllowGet = allowGet return env, httptest.NewServer(NewHandler(env, cmdRoot, srvCfg)) } diff --git a/http/http_test.go b/http/http_test.go index 4792d533..66aa85fa 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -88,7 +88,7 @@ func TestHTTP(t *testing.T) { mkTest := func(tc testcase) func(*testing.T) { return func(t *testing.T) { - env, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/ + env, srv := getTestServer(t, nil, true) // handler_test:/^func getTestServer/ c := NewClient(srv.URL) req, err := cmds.NewRequest(context.Background(), tc.path, nil, nil, nil, cmdRoot) if err != nil { diff --git a/http/reforigin_test.go b/http/reforigin_test.go index 08958ccc..e7dde10c 100644 --- a/http/reforigin_test.go +++ b/http/reforigin_test.go @@ -4,15 +4,42 @@ import ( "fmt" "net/http" "net/url" + "strings" "testing" cmds "github.com/ipfs/go-ipfs-cmds" ) func assertHeaders(t *testing.T, resHeaders http.Header, reqHeaders map[string]string) { + t.Helper() + t.Logf("headers: %v", resHeaders) for name, value := range reqHeaders { - if resHeaders.Get(name) != value { - t.Errorf("Invalid header '%s', wanted '%s', got '%s'", name, value, resHeaders.Get(name)) + header := resHeaders[http.CanonicalHeaderKey(name)] + switch len(header) { + case 0: + if value != "" { + t.Errorf("expected a header for %s", name) + } + case 1: + if header[0] != value { + t.Errorf("Invalid header '%s', wanted '%s', got '%s'", name, value, header[0]) + } + default: + values := strings.Split(value, ",") + set := make(map[string]bool, len(values)) + for _, v := range values { + set[strings.Trim(v, " ")] = true + } + for _, got := range header { + if !set[got] { + t.Errorf("found unexpected value %s in header %s", got, name) + continue + } + delete(set, got) + } + for missing := range set { + t.Errorf("missing value %s in header %s", missing, name) + } } } } @@ -27,7 +54,7 @@ func originCfg(origins []string) *ServerConfig { cfg := NewServerConfig() cfg.SetAllowedOrigins(origins...) cfg.SetAllowedMethods("GET", "PUT", "POST") - cfg.HandledMethods = []string{"GET", "POST"} + cfg.AllowGet = true return cfg } @@ -39,18 +66,19 @@ var defaultOrigins = []string{ } type httpTestCase struct { - Method string - Path string - Code int - Origin string - Referer string - AllowOrigins []string - HandledMethods []string - ReqHeaders map[string]string - ResHeaders map[string]string + Method string + Path string + Code int + Origin string + Referer string + AllowOrigins []string + AllowGet bool + ReqHeaders map[string]string + ResHeaders map[string]string } func (tc *httpTestCase) test(t *testing.T) { + t.Helper() // defaults method := tc.Method if method == "" { @@ -85,7 +113,7 @@ func (tc *httpTestCase) test(t *testing.T) { } // server - _, server := getTestServer(t, tc.AllowOrigins, tc.HandledMethods) + _, server := getTestServer(t, tc.AllowOrigins, tc.AllowGet) if server == nil { return } @@ -114,6 +142,7 @@ func TestDisallowedOrigins(t *testing.T) { return httpTestCase{ Origin: origin, AllowOrigins: allowedOrigins, + AllowGet: true, ResHeaders: map[string]string{ ACAOrigin: "", ACAMethods: "", @@ -144,6 +173,7 @@ func TestAllowedOrigins(t *testing.T) { return httpTestCase{ Origin: origin, AllowOrigins: allowedOrigins, + AllowGet: true, ResHeaders: map[string]string{ ACAOrigin: origin, ACAMethods: "", @@ -171,6 +201,7 @@ func TestWildcardOrigin(t *testing.T) { gtc := func(origin string, allowedOrigins []string) httpTestCase { return httpTestCase{ Origin: origin, + AllowGet: true, AllowOrigins: allowedOrigins, ResHeaders: map[string]string{ ACAOrigin: "*", @@ -204,6 +235,7 @@ func TestDisallowedReferer(t *testing.T) { return httpTestCase{ Origin: "http://localhost", Referer: referer, + AllowGet: true, AllowOrigins: allowedOrigins, ResHeaders: map[string]string{ ACAOrigin: "http://localhost", @@ -232,6 +264,7 @@ func TestAllowedReferer(t *testing.T) { return httpTestCase{ Origin: "http://localhost", AllowOrigins: allowedOrigins, + AllowGet: true, ResHeaders: map[string]string{ ACAOrigin: "http://localhost", ACAMethods: "", @@ -260,6 +293,7 @@ func TestWildcardReferer(t *testing.T) { return httpTestCase{ Origin: origin, AllowOrigins: allowedOrigins, + AllowGet: true, ResHeaders: map[string]string{ ACAOrigin: "*", ACAMethods: "", @@ -338,6 +372,7 @@ func TestEncoding(t *testing.T) { return httpTestCase{ Method: "GET", Path: path, + AllowGet: true, Origin: "http://localhost", AllowOrigins: []string{"*"}, ReqHeaders: map[string]string{ diff --git a/http/responseemitter.go b/http/responseemitter.go index 27a16317..a405cec3 100644 --- a/http/responseemitter.go +++ b/http/responseemitter.go @@ -106,7 +106,7 @@ func (re *responseEmitter) Emit(value interface{}) error { var err error // return immediately if this is a head request - if re.method == "HEAD" { + if re.method == http.MethodHead { return nil }