diff --git a/command/agent/agent_endpoint.go b/command/agent/agent_endpoint.go index 5e348941a404..9c02611c3a9a 100644 --- a/command/agent/agent_endpoint.go +++ b/command/agent/agent_endpoint.go @@ -452,6 +452,14 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re return nil, nil } + // Check the service address here and in the catalog RPC endpoint + // since service registration isn't sychronous. + if args.Address == "0.0.0.0" { + resp.WriteHeader(400) + fmt.Fprintf(resp, "Invalid service address") + return nil, nil + } + // Get the node service. ns := args.NodeService() diff --git a/command/agent/agent_endpoint_test.go b/command/agent/agent_endpoint_test.go index e348271e93cc..dc40b32d56a2 100644 --- a/command/agent/agent_endpoint_test.go +++ b/command/agent/agent_endpoint_test.go @@ -1497,6 +1497,33 @@ func TestAgent_RegisterService_ACLDeny(t *testing.T) { } } +func TestAgent_RegisterService_InvalidAddress(t *testing.T) { + dir, srv := makeHTTPServer(t) + defer os.RemoveAll(dir) + defer srv.Shutdown() + defer srv.agent.Shutdown() + + req, err := http.NewRequest("GET", "/v1/agent/service/register?token=abc123", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + args := &ServiceDefinition{ + Name: "test", + Address: "0.0.0.0", + Port: 8000, + } + req.Body = encodeReq(args) + + resp := httptest.NewRecorder() + _, err = srv.AgentRegisterService(resp, req) + if got, want := resp.Code, 400; got != want { + t.Fatalf("got code %d want %d", got, want) + } + if got, want := resp.Body.String(), "Invalid service address"; got != want { + t.Fatalf("got body %q want %q", got, want) + } +} + func TestAgent_DeregisterService(t *testing.T) { dir, srv := makeHTTPServer(t) defer os.RemoveAll(dir) diff --git a/command/agent/catalog_endpoint_test.go b/command/agent/catalog_endpoint_test.go index 535dc8b1497c..6d04a94a132a 100644 --- a/command/agent/catalog_endpoint_test.go +++ b/command/agent/catalog_endpoint_test.go @@ -55,6 +55,36 @@ func TestCatalogRegister(t *testing.T) { } } +func TestCatalogRegister_Service_InvalidAddress(t *testing.T) { + dir, srv := makeHTTPServer(t) + defer os.RemoveAll(dir) + defer srv.Shutdown() + defer srv.agent.Shutdown() + + testrpc.WaitForLeader(t, srv.agent.RPC, "dc1") + + // Register node + req, err := http.NewRequest("GET", "/v1/catalog/register", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + args := &structs.RegisterRequest{ + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Service: "test", + Address: "0.0.0.0", + Port: 8080, + }, + } + req.Body = encodeReq(args) + + _, err = srv.CatalogRegister(nil, req) + if err == nil || err.Error() != "Invalid service address" { + t.Fatalf("err: %v", err) + } +} + func TestCatalogDeregister(t *testing.T) { dir, srv := makeHTTPServer(t) defer os.RemoveAll(dir) diff --git a/consul/catalog_endpoint.go b/consul/catalog_endpoint.go index b329060ce0cf..4e5424537283 100644 --- a/consul/catalog_endpoint.go +++ b/consul/catalog_endpoint.go @@ -52,6 +52,12 @@ func (c *Catalog) Register(args *structs.RegisterRequest, reply *struct{}) error return fmt.Errorf("Must provide service name with ID") } + // Check the service address here and in the agent endpoint + // since service registration isn't sychronous. + if args.Service.Address == "0.0.0.0" { + return fmt.Errorf("Invalid service address") + } + // Apply the ACL policy if any. The 'consul' service is excluded // since it is managed automatically internally (that behavior // is going away after version 0.8). We check this same policy diff --git a/consul/catalog_endpoint_test.go b/consul/catalog_endpoint_test.go index f27207f7ed25..9737ce2fe245 100644 --- a/consul/catalog_endpoint_test.go +++ b/consul/catalog_endpoint_test.go @@ -46,6 +46,31 @@ func TestCatalog_Register(t *testing.T) { } } +func TestCatalog_RegisterService_InvalidAddress(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + codec := rpcClient(t, s1) + defer codec.Close() + + arg := structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Service: "db", + Address: "0.0.0.0", + Port: 8000, + }, + } + var out struct{} + + err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out) + if err == nil || err.Error() != "Invalid service address" { + t.Fatalf("got error %v want 'Invalid service address'", err) + } +} + func TestCatalog_Register_NodeID(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1)