diff --git a/api/api.go b/api/api.go index 5617293e4498..0ff9a22c2348 100644 --- a/api/api.go +++ b/api/api.go @@ -120,8 +120,8 @@ func DefaultConfig() *Config { HttpClient: http.DefaultClient, } - if len(os.Getenv("CONSUL_HTTP_ADDR")) > 0 { - config.Address = os.Getenv("CONSUL_HTTP_ADDR") + if addr := os.Getenv("CONSUL_HTTP_ADDR"); addr != "" { + config.Address = addr } return config @@ -137,11 +137,7 @@ func NewClient(config *Config) (*Client, error) { // bootstrap the config defConfig := DefaultConfig() - switch { - case len(config.Address) != 0: - case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0: - config.Address = os.Getenv("CONSUL_HTTP_ADDR") - default: + if len(config.Address) == 0 { config.Address = defConfig.Address } @@ -153,14 +149,15 @@ func NewClient(config *Config) (*Client, error) { config.HttpClient = defConfig.HttpClient } - if strings.HasPrefix(config.Address, "unix://") { - shortStr := strings.TrimPrefix(config.Address, "unix://") - t := &http.Transport{} - t.Dial = func(_, _ string) (net.Conn, error) { - return net.Dial("unix", shortStr) + if parts := strings.SplitN(config.Address, "unix://", 2); len(parts) == 2 { + config.HttpClient = &http.Client{ + Transport: &http.Transport{ + Dial: func(_, _ string) (net.Conn, error) { + return net.Dial("unix", parts[1]) + }, + }, } - config.HttpClient.Transport = t - config.Address = shortStr + config.Address = parts[1] } client := &Client{ diff --git a/api/api_test.go b/api/api_test.go index 488fcb1ee0ff..dc25f8b5a46f 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -8,6 +8,8 @@ import ( "net/http" "os" "os/exec" + "path/filepath" + "runtime" "testing" "time" @@ -42,6 +44,10 @@ type testServerConfig struct { Ports testPortConfig `json:"ports,omitempty"` } +// Callback functions for modifying config +type configCallback func(c *Config) +type serverConfigCallback func(c *testServerConfig) + func defaultConfig() *testServerConfig { return &testServerConfig{ Bootstrap: true, @@ -72,7 +78,7 @@ func newTestServer(t *testing.T) *testServer { return newTestServerWithConfig(t, func(c *testServerConfig) {}) } -func newTestServerWithConfig(t *testing.T, cb func(c *testServerConfig)) *testServer { +func newTestServerWithConfig(t *testing.T, cb serverConfigCallback) *testServer { if path, err := exec.LookPath("consul"); err != nil || path == "" { t.Log("consul not found on $PATH, skipping") t.SkipNow() @@ -131,15 +137,20 @@ func makeClient(t *testing.T) (*Client, *testServer) { }, func(c *testServerConfig) {}) } -func makeClientWithConfig(t *testing.T, clientConfig func(c *Config), serverConfig func(c *testServerConfig)) (*Client, *testServer) { - server := newTestServerWithConfig(t, serverConfig) +func makeClientWithConfig(t *testing.T, cb1 configCallback, cb2 serverConfigCallback) (*Client, *testServer) { + // Make client config conf := DefaultConfig() - clientConfig(conf) + cb1(conf) + + // Create client client, err := NewClient(conf) if err != nil { t.Fatalf("err: %v", err) } + // Create server + server := newTestServerWithConfig(t, cb2) + // Allow the server some time to start, and verify we have a leader. testutil.WaitForResult(func() (bool, error) { req := client.newRequest("GET", "/v1/catalog/nodes") @@ -278,3 +289,35 @@ func TestParseQueryMeta(t *testing.T) { t.Fatalf("Bad: %v", qm) } } + +func TestAPI_UnixSocket(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + + tempDir, err := ioutil.TempDir("", "consul") + if err != nil { + t.Fatalf("err: %s", err) + } + defer os.RemoveAll(tempDir) + socket := filepath.Join(tempDir, "test.sock") + + c, s := makeClientWithConfig(t, func(c *Config) { + c.Address = "unix://" + socket + }, func(c *testServerConfig) { + c.Addresses = &testAddressConfig{ + HTTP: "unix://" + socket, + } + }) + defer s.stop() + + agent := c.Agent() + + info, err := agent.Self() + if err != nil { + t.Fatalf("err: %s", err) + } + if info["Config"]["NodeName"] == "" { + t.Fatalf("bad: %v", info) + } +} diff --git a/api/status_test.go b/api/status_test.go index 5e7acd274060..096b13da090b 100644 --- a/api/status_test.go +++ b/api/status_test.go @@ -1,13 +1,10 @@ package api import ( - "io/ioutil" - "os/user" - "runtime" "testing" ) -func TestStatusLeaderTCP(t *testing.T) { +func TestStatusLeader(t *testing.T) { c, s := makeClient(t) defer s.stop() @@ -22,48 +19,6 @@ func TestStatusLeaderTCP(t *testing.T) { } } -func TestStatusLeaderUnix(t *testing.T) { - if runtime.GOOS == "windows" { - t.SkipNow() - } - - tempdir, err := ioutil.TempDir("", "consul-test-") - if err != nil { - t.Fatal("Could not create a working directory") - } - - socket := "unix://" + tempdir + "/unix-http-test.sock" - - clientConfig := func(c *Config) { - c.Address = socket - } - - serverConfig := func(c *testServerConfig) { - user, err := user.Current() - if err != nil { - t.Fatal("Could not get current user") - } - - if c.Addresses == nil { - c.Addresses = &testAddressConfig{} - } - c.Addresses.HTTP = socket + ";" + user.Uid + ";" + user.Gid + ";640" - } - - c, s := makeClientWithConfig(t, clientConfig, serverConfig) - defer s.stop() - - status := c.Status() - - leader, err := status.Leader() - if err != nil { - t.Fatalf("err: %v", err) - } - if leader == "" { - t.Fatalf("Expected leader") - } -} - func TestStatusPeers(t *testing.T) { c, s := makeClient(t) defer s.stop() diff --git a/command/agent/agent.go b/command/agent/agent.go index 778d071354a3..d5ef05af4c35 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -22,6 +22,13 @@ const ( // Path to save local agent checks checksDir = "checks" + + // errSocketFileExists is the human-friendly error message displayed when + // trying to bind a socket to an existing file. + errSocketFileExists = "A file exists at the requested socket path %q. " + + "If Consul was not shut down properly, the socket file may " + + "be left behind. If the path looks correct, remove the file " + + "and try again." ) /* diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index 91cb5c1ef51f..0add36702c70 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -7,10 +7,8 @@ import ( "io" "io/ioutil" "os" - "os/user" "path/filepath" "reflect" - "runtime" "sync/atomic" "testing" "time" @@ -125,7 +123,7 @@ func TestAgentStartStop(t *testing.T) { } } -func TestAgent_RPCPingTCP(t *testing.T) { +func TestAgent_RPCPing(t *testing.T) { dir, agent := makeAgent(t, nextConfig()) defer os.RemoveAll(dir) defer agent.Shutdown() @@ -136,35 +134,6 @@ func TestAgent_RPCPingTCP(t *testing.T) { } } -func TestAgent_RPCPingUnix(t *testing.T) { - if runtime.GOOS == "windows" { - t.SkipNow() - } - - nextConf := nextConfig() - - tempdir, err := ioutil.TempDir("", "consul-test-") - if err != nil { - t.Fatal("Could not create a working directory") - } - - user, err := user.Current() - if err != nil { - t.Fatal("Could not get current user") - } - - nextConf.Addresses.RPC = "unix://" + tempdir + "/unix-rpc-test.sock;" + user.Uid + ";" + user.Gid + ";640" - - dir, agent := makeAgent(t, nextConf) - defer os.RemoveAll(dir) - defer agent.Shutdown() - - var out struct{} - if err := agent.RPC("Status.Ping", struct{}{}, &out); err != nil { - t.Fatalf("err: %v", err) - } -} - func TestAgent_AddService(t *testing.T) { dir, agent := makeAgent(t, nextConfig()) defer os.RemoveAll(dir) diff --git a/command/agent/command.go b/command/agent/command.go index b9a82e19aeb7..17eb5c577930 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -295,9 +295,12 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log return err } - if _, ok := rpcAddr.(*net.UnixAddr); ok { - // Remove the socket if it exists, or we'll get a bind error - _ = os.Remove(rpcAddr.String()) + // Error if we are trying to bind a domain socket to an existing path + if path, ok := unixSocketAddr(config.Addresses.RPC); ok { + if _, err := os.Stat(path); err == nil || !os.IsNotExist(err) { + c.Ui.Output(fmt.Sprintf(errSocketFileExists, path)) + return fmt.Errorf(errSocketFileExists, path) + } } rpcListener, err := net.Listen(rpcAddr.Network(), rpcAddr.String()) @@ -307,14 +310,6 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log return err } - if _, ok := rpcAddr.(*net.UnixAddr); ok { - if err := adjustUnixSocketPermissions(config.Addresses.RPC); err != nil { - agent.Shutdown() - c.Ui.Error(fmt.Sprintf("Error adjusting Unix socket permissions: %s", err)) - return err - } - } - // Start the IPC layer c.Ui.Output("Starting Consul agent RPC...") c.rpcServer = NewAgentRPC(agent, rpcListener, logOutput, logWriter) diff --git a/command/agent/config.go b/command/agent/config.go index 74ed4c3ad965..92b9e64d3794 100644 --- a/command/agent/config.go +++ b/command/agent/config.go @@ -7,11 +7,8 @@ import ( "io" "net" "os" - "os/user" "path/filepath" - "regexp" "sort" - "strconv" "strings" "time" @@ -348,89 +345,13 @@ type Config struct { WatchPlans []*watch.WatchPlan `mapstructure:"-" json:"-"` } -// UnixSocket contains the parameters for a Unix socket interface -type UnixSocket struct { - // Path to the socket on-disk - Path string - - // uid of the owner of the socket - Uid int - - // gid of the group of the socket - Gid int - - // Permissions for the socket file - Permissions os.FileMode -} - -func populateUnixSocket(addr string) (*UnixSocket, error) { +// unixSocketAddr tests if a given address describes a domain socket, +// and returns the relevant path part of the string if it is. +func unixSocketAddr(addr string) (string, bool) { if !strings.HasPrefix(addr, "unix://") { - return nil, fmt.Errorf("Failed to parse Unix address, format is unix://[path];[user];[group];[mode]: %v", addr) - } - - splitAddr := strings.Split(strings.TrimPrefix(addr, "unix://"), ";") - if len(splitAddr) != 4 { - return nil, fmt.Errorf("Failed to parse Unix address, format is unix://[path];[user];[group];[mode]: %v", addr) - } - - ret := &UnixSocket{Path: splitAddr[0]} - - var userVal *user.User - var err error - - regex := regexp.MustCompile("[\\d]+") - if regex.MatchString(splitAddr[1]) { - userVal, err = user.LookupId(splitAddr[1]) - } else { - userVal, err = user.Lookup(splitAddr[1]) - } - if err != nil { - return nil, fmt.Errorf("Invalid user given for Unix socket ownership: %v", splitAddr[1]) - } - - if uid64, err := strconv.ParseInt(userVal.Uid, 10, 32); err != nil { - return nil, fmt.Errorf("Failed to parse given user ID of %v into integer", userVal.Uid) - } else { - ret.Uid = int(uid64) - } - - // Go doesn't currently have a way to look up gid from group name, - // so require a numeric gid; see - // https://codereview.appspot.com/101310044 - if gid64, err := strconv.ParseInt(splitAddr[2], 10, 32); err != nil { - return nil, fmt.Errorf("Socket group must be given as numeric gid. Failed to parse given group ID of %v into integer", splitAddr[2]) - } else { - ret.Gid = int(gid64) - } - - if mode, err := strconv.ParseUint(splitAddr[3], 8, 32); err != nil { - return nil, fmt.Errorf("Failed to parse given mode of %v into integer", splitAddr[3]) - } else { - if mode > 0777 { - return nil, fmt.Errorf("Given mode is invalid; must be an octal number between 0 and 777") - } else { - ret.Permissions = os.FileMode(mode) - } - } - - return ret, nil -} - -func adjustUnixSocketPermissions(addr string) error { - sock, err := populateUnixSocket(addr) - if err != nil { - return err + return "", false } - - if err = os.Chown(sock.Path, sock.Uid, sock.Gid); err != nil { - return fmt.Errorf("Error attempting to change socket permissions to userid %v and groupid %v: %v", sock.Uid, sock.Gid, err) - } - - if err = os.Chmod(sock.Path, sock.Permissions); err != nil { - return fmt.Errorf("Error attempting to change socket permissions to mode %v: %v", sock.Permissions, err) - } - - return nil + return strings.TrimPrefix(addr, "unix://"), true } type dirEnts []os.FileInfo @@ -485,31 +406,14 @@ func (c *Config) ClientListener(override string, port int) (net.Addr, error) { addr = c.ClientAddr } - switch { - case strings.HasPrefix(addr, "unix://"): - sock, err := populateUnixSocket(addr) - if err != nil { - return nil, err - } - - return &net.UnixAddr{Name: sock.Path, Net: "unix"}, nil - - default: - ip := net.ParseIP(addr) - if ip == nil { - return nil, fmt.Errorf("Failed to parse IP: %v", addr) - } - - if ip.IsUnspecified() { - ip = net.ParseIP("127.0.0.1") - } - - if ip == nil { - return nil, fmt.Errorf("Failed to parse IP 127.0.0.1") - } - - return &net.TCPAddr{IP: ip, Port: port}, nil + if path, ok := unixSocketAddr(addr); ok { + return &net.UnixAddr{Name: path, Net: "unix"}, nil + } + ip := net.ParseIP(addr) + if ip == nil { + return nil, fmt.Errorf("Failed to parse IP: %v", addr) } + return &net.TCPAddr{IP: ip, Port: port}, nil } // DecodeConfig reads the configuration from the given reader in JSON diff --git a/command/agent/config_test.go b/command/agent/config_test.go index f10c5b72422e..fa7bf6f274f6 100644 --- a/command/agent/config_test.go +++ b/command/agent/config_test.go @@ -4,12 +4,9 @@ import ( "bytes" "encoding/base64" "io/ioutil" - "net" "os" - "os/user" "path/filepath" "reflect" - "runtime" "strings" "testing" "time" @@ -1073,107 +1070,13 @@ func TestReadConfigPaths_dir(t *testing.T) { } func TestUnixSockets(t *testing.T) { - if runtime.GOOS == "windows" { - t.SkipNow() + path1, ok := unixSocketAddr("unix:///path/to/socket") + if !ok || path1 != "/path/to/socket" { + t.Fatalf("bad: %v %v", ok, path1) } - usr, err := user.Current() - if err != nil { - t.Fatal("Could not get current user: ", err) - } - - tempdir, err := ioutil.TempDir("", "consul-test-") - if err != nil { - t.Fatal("Could not create a working directory: ", err) - } - - type SocketTestData struct { - Path string - Uid string - Gid string - Mode string - } - - testUnixSocketPopulation := func(s SocketTestData) (*UnixSocket, error) { - return populateUnixSocket("unix://" + s.Path + ";" + s.Uid + ";" + s.Gid + ";" + s.Mode) - } - - testUnixSocketPermissions := func(s SocketTestData) error { - return adjustUnixSocketPermissions("unix://" + s.Path + ";" + s.Uid + ";" + s.Gid + ";" + s.Mode) - } - - _, err = populateUnixSocket("tcp://abc123") - if err == nil { - t.Fatal("Should have rejected invalid scheme") - } - - _, err = populateUnixSocket("unix://x;y;z") - if err == nil { - t.Fatal("Should have rejected invalid number of parameters in Unix socket definition") - } - - std := SocketTestData{ - Path: tempdir + "/unix-config-test.sock", - Uid: usr.Uid, - Gid: usr.Gid, - Mode: "640", - } - - std.Uid = "orasdfdsnfoinweroiu" - _, err = testUnixSocketPopulation(std) - if err == nil { - t.Fatal("Did not error on invalid username") - } - - std.Uid = usr.Username - std.Gid = "foinfphawepofhewof" - _, err = testUnixSocketPopulation(std) - if err == nil { - t.Fatal("Did not error on invalid group (a name, must be gid)") - } - - std.Gid = usr.Gid - std.Mode = "999" - _, err = testUnixSocketPopulation(std) - if err == nil { - t.Fatal("Did not error on invalid socket mode") - } - - std.Uid = usr.Username - std.Mode = "640" - _, err = testUnixSocketPopulation(std) - if err != nil { - t.Fatal("Unix socket test failed (using username): ", err) - } - - std.Uid = usr.Uid - sock, err := testUnixSocketPopulation(std) - if err != nil { - t.Fatal("Unix socket test failed (using uid): ", err) - } - - addr := &net.UnixAddr{Name: sock.Path, Net: "unix"} - _, err = net.Listen(addr.Network(), addr.String()) - if err != nil { - t.Fatal("Error creating socket for futher tests: ", err) - } - - std.Uid = "-999999" - err = testUnixSocketPermissions(std) - if err == nil { - t.Fatal("Did not error on invalid uid") - } - - std.Uid = usr.Uid - std.Gid = "-999999" - err = testUnixSocketPermissions(std) - if err == nil { - t.Fatal("Did not error on invalid uid") - } - - std.Gid = usr.Gid - err = testUnixSocketPermissions(std) - if err != nil { - t.Fatal("Adjusting socket permissions failed: ", err) + path2, ok := unixSocketAddr("notunix://blah") + if ok || path2 != "" { + t.Fatalf("bad: %v %v", ok, path2) } } diff --git a/command/agent/http.go b/command/agent/http.go index d480de816f52..708fe5b17413 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -59,28 +59,15 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS return nil, err } - if _, ok := httpAddr.(*net.UnixAddr); ok { - // Remove the socket if it exists, or we'll get a bind error - _ = os.Remove(httpAddr.String()) - } - ln, err := net.Listen(httpAddr.Network(), httpAddr.String()) if err != nil { return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err) } - switch httpAddr.(type) { - case *net.UnixAddr: - if err := adjustUnixSocketPermissions(config.Addresses.HTTPS); err != nil { - return nil, err - } + if _, ok := unixSocketAddr(config.Addresses.HTTPS); ok { list = tls.NewListener(ln, tlsConfig) - - case *net.TCPAddr: + } else { list = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig) - - default: - return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err) } // Create the mux @@ -108,9 +95,11 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS return nil, fmt.Errorf("Failed to get ClientListener address:port: %v", err) } - if _, ok := httpAddr.(*net.UnixAddr); ok { - // Remove the socket if it exists, or we'll get a bind error - _ = os.Remove(httpAddr.String()) + // Error if we are trying to bind a domain socket to an existing path + if path, ok := unixSocketAddr(config.Addresses.HTTP); ok { + if _, err := os.Stat(path); err == nil || !os.IsNotExist(err) { + return nil, fmt.Errorf(errSocketFileExists, path) + } } ln, err := net.Listen(httpAddr.Network(), httpAddr.String()) @@ -118,18 +107,10 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err) } - switch httpAddr.(type) { - case *net.UnixAddr: - if err := adjustUnixSocketPermissions(config.Addresses.HTTP); err != nil { - return nil, err - } + if _, ok := unixSocketAddr(config.Addresses.HTTP); ok { list = ln - - case *net.TCPAddr: + } else { list = tcpKeepAliveListener{ln.(*net.TCPListener)} - - default: - return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err) } // Create the mux diff --git a/command/agent/http_test.go b/command/agent/http_test.go index eca844f1d375..4677d5ba537c 100644 --- a/command/agent/http_test.go +++ b/command/agent/http_test.go @@ -6,10 +6,12 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" "net/http/httptest" "os" "path/filepath" + "runtime" "strconv" "testing" "time" @@ -19,7 +21,15 @@ import ( ) func makeHTTPServer(t *testing.T) (string, *HTTPServer) { + return makeHTTPServerWithConfig(t, nil) +} + +func makeHTTPServerWithConfig(t *testing.T, cb func(c *Config)) (string, *HTTPServer) { conf := nextConfig() + if cb != nil { + cb(conf) + } + dir, agent := makeAgent(t, conf) uiDir := filepath.Join(dir, "ui") if err := os.Mkdir(uiDir, 755); err != nil { @@ -43,6 +53,93 @@ func encodeReq(obj interface{}) io.ReadCloser { return ioutil.NopCloser(buf) } +func TestHTTPServer_UnixSocket(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + + tempDir, err := ioutil.TempDir("", "consul") + if err != nil { + t.Fatalf("err: %s", err) + } + defer os.RemoveAll(tempDir) + socket := filepath.Join(tempDir, "test.sock") + + dir, srv := makeHTTPServerWithConfig(t, func(c *Config) { + c.Addresses.HTTP = "unix://" + socket + }) + defer os.RemoveAll(dir) + defer srv.Shutdown() + defer srv.agent.Shutdown() + + // Ensure the socket was created + if _, err := os.Stat(socket); err != nil { + t.Fatalf("err: %s", err) + } + + // Ensure we can get a response from the socket. + path, _ := unixSocketAddr(srv.agent.config.Addresses.HTTP) + client := &http.Client{ + Transport: &http.Transport{ + Dial: func(_, _ string) (net.Conn, error) { + return net.Dial("unix", path) + }, + }, + } + + // This URL doesn't look like it makes sense, but the scheme (http://) and + // the host (127.0.0.1) are required by the HTTP client library. In reality + // this will just use the custom dialer and talk to the socket. + resp, err := client.Get("http://127.0.0.1/v1/agent/self") + if err != nil { + t.Fatalf("err: %s", err) + } + defer resp.Body.Close() + + if body, err := ioutil.ReadAll(resp.Body); err != nil || len(body) == 0 { + t.Fatalf("bad: %s %v", body, err) + } +} + +func TestHTTPServer_UnixSocket_FileExists(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + + tempDir, err := ioutil.TempDir("", "consul") + if err != nil { + t.Fatalf("err: %s", err) + } + defer os.RemoveAll(tempDir) + socket := filepath.Join(tempDir, "test.sock") + + // Create a regular file at the socket path + if err := ioutil.WriteFile(socket, []byte("hello world"), 0644); err != nil { + t.Fatalf("err: %s", err) + } + fi, err := os.Stat(socket) + if err != nil { + t.Fatalf("err: %s", err) + } + if !fi.Mode().IsRegular() { + t.Fatalf("not a regular file: %s", socket) + } + + conf := nextConfig() + conf.Addresses.HTTP = "unix://" + socket + + dir, agent := makeAgent(t, conf) + defer os.RemoveAll(dir) + + // Try to start the server with the same path anyways. + if servers, err := NewHTTPServers(agent, conf, agent.logOutput); err == nil { + for _, server := range servers { + server.Shutdown() + } + t.Fatalf("expected socket binding error") + } +} + func TestSetIndex(t *testing.T) { resp := httptest.NewRecorder() setIndex(resp, 1000) diff --git a/command/agent/rpc_client.go b/command/agent/rpc_client.go index 1490674f146f..cbc9689cfb82 100644 --- a/command/agent/rpc_client.go +++ b/command/agent/rpc_client.go @@ -81,24 +81,19 @@ func (c *RPCClient) send(header *requestHeader, obj interface{}) error { // NewRPCClient is used to create a new RPC client given the address. // This will properly dial, handshake, and start listening func NewRPCClient(addr string) (*RPCClient, error) { - sanedAddr := os.Getenv("CONSUL_RPC_ADDR") - if len(sanedAddr) == 0 { - sanedAddr = addr - } - - mode := "tcp" + var conn net.Conn + var err error - if strings.HasPrefix(sanedAddr, "unix://") { - sanedAddr = strings.TrimPrefix(sanedAddr, "unix://") + if envAddr := os.Getenv("CONSUL_RPC_ADDR"); envAddr != "" { + addr = envAddr } - if strings.HasPrefix(sanedAddr, "/") { + // Try to dial to agent + mode := "tcp" + if strings.HasPrefix(addr, "/") { mode = "unix" } - - // Try to dial to agent - conn, err := net.Dial(mode, sanedAddr) - if err != nil { + if conn, err = net.Dial(mode, addr); err != nil { return nil, err } diff --git a/command/agent/rpc_client_test.go b/command/agent/rpc_client_test.go index 2d8dfc9c0841..48d833564956 100644 --- a/command/agent/rpc_client_test.go +++ b/command/agent/rpc_client_test.go @@ -9,7 +9,7 @@ import ( "io/ioutil" "net" "os" - "os/user" + "path/filepath" "runtime" "strings" "testing" @@ -69,6 +69,38 @@ func testRPCClientWithConfig(t *testing.T, cb func(c *Config)) *rpcParts { } } +func TestRPCClient_UnixSocket(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + + tempDir, err := ioutil.TempDir("", "consul") + if err != nil { + t.Fatalf("err: %s", err) + } + defer os.RemoveAll(tempDir) + socket := filepath.Join(tempDir, "test.sock") + + p1 := testRPCClientWithConfig(t, func(c *Config) { + c.Addresses.RPC = "unix://" + socket + }) + defer p1.Close() + + // Ensure the socket was created + if _, err := os.Stat(socket); err != nil { + t.Fatalf("err: %s", err) + } + + // Ensure we can talk with the socket + mem, err := p1.client.LANMembers() + if err != nil { + t.Fatalf("err: %s", err) + } + if len(mem) != 1 { + t.Fatalf("bad: %#v", mem) + } +} + func TestRPCClientForceLeave(t *testing.T) { p1 := testRPCClient(t) p2 := testRPCClient(t) @@ -216,41 +248,6 @@ func TestRPCClientStats(t *testing.T) { } } -func TestRPCClientStatsUnix(t *testing.T) { - if runtime.GOOS == "windows" { - t.SkipNow() - } - - tempdir, err := ioutil.TempDir("", "consul-test-") - if err != nil { - t.Fatal("Could not create a working directory: ", err) - } - - user, err := user.Current() - if err != nil { - t.Fatal("Could not get current user: ", err) - } - - cb := func(c *Config) { - c.Addresses.RPC = "unix://" + tempdir + "/unix-rpc-test.sock;" + user.Uid + ";" + user.Gid + ";640" - } - - p1 := testRPCClientWithConfig(t, cb) - - stats, err := p1.client.Stats() - if err != nil { - t.Fatalf("err: %s", err) - } - - if _, ok := stats["agent"]; !ok { - t.Fatalf("bad: %#v", stats) - } - - if _, ok := stats["consul"]; !ok { - t.Fatalf("bad: %#v", stats) - } -} - func TestRPCClientLeave(t *testing.T) { p1 := testRPCClient(t) defer p1.Close() diff --git a/command/rpc.go b/command/rpc.go index f0c9e5b1fabf..84d41e822ef5 100644 --- a/command/rpc.go +++ b/command/rpc.go @@ -43,12 +43,10 @@ func HTTPClient(addr string) (*consulapi.Client, error) { // HTTPClientDC returns a new Consul HTTP client with the given address and datacenter func HTTPClientDC(addr, dc string) (*consulapi.Client, error) { conf := consulapi.DefaultConfig() - switch { - case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0: - conf.Address = os.Getenv("CONSUL_HTTP_ADDR") - default: - conf.Address = addr + if envAddr := os.Getenv("CONSUL_HTTP_ADDR"); envAddr != "" { + addr = envAddr } + conf.Address = addr conf.Datacenter = dc return consulapi.NewClient(conf) } diff --git a/website/source/docs/agent/options.html.markdown b/website/source/docs/agent/options.html.markdown index 6aa6b61d872e..2ecc5e2db7e5 100644 --- a/website/source/docs/agent/options.html.markdown +++ b/website/source/docs/agent/options.html.markdown @@ -239,20 +239,23 @@ definitions support being updated during a reload. However, because the caches are not actively invalidated, ACL policy may be stale up to the TTL value. -* `addresses` - This is a nested object that allows setting bind addresses. For `rpc` - and `http`, a Unix socket can be specified in the following form: - unix://[/path/to/socket];[username|uid];[gid];[mode]. The socket will be created - in the specified location with the given username or uid, gid, and mode. The - user Consul is running as must have appropriate permissions to change the socket - ownership to the given uid or gid. When running Consul agent commands against - Unix socket interfaces, use the `-rpc-addr` or `-http-addr` arguments to specify - the path to the socket, e.g. "unix://path/to/socket". You can also place the desired - values in `CONSUL_RPC_ADDR` and `CONSUL_HTTP_ADDR` environment variables. For TCP - addresses, these should be in the form ip:port. - The following keys are valid: - * `dns` - The DNS server. Defaults to `client_addr` - * `http` - The HTTP API. Defaults to `client_addr` - * `rpc` - The RPC endpoint. Defaults to `client_addr` +* `addresses` - This is a nested object that allows setting bind addresses. +

+ Both `rpc` and `http` support binding to Unix domain sockets. A socket can be + specified in the form `unix:///path/to/socket`. A new domain socket will be + created at the given path. If the specified file path already exists, Consul + will refuse to start and return an error. For information on how to secure + socket file permissions, refer to the manual page for your operating system. +

+ When running Consul agent commands against Unix socket interfaces, use the + `-rpc-addr` or `-http-addr` arguments to specify the path to the socket. You + can also place the desired values in `CONSUL_RPC_ADDR` and `CONSUL_HTTP_ADDR` + environment variables. For TCP addresses, these should be in the form ip:port. +

+ The following keys are valid: + * `dns` - The DNS server. Defaults to `client_addr` + * `http` - The HTTP API. Defaults to `client_addr` + * `rpc` - The RPC endpoint. Defaults to `client_addr` * `advertise_addr` - Equivalent to the `-advertise` command-line flag.