diff --git a/api/api.go b/api/api.go index 9506b4e1f4f0..5617293e4498 100644 --- a/api/api.go +++ b/api/api.go @@ -5,9 +5,12 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/url" + "os" "strconv" + "strings" "time" ) @@ -111,11 +114,17 @@ type Config struct { // DefaultConfig returns a default configuration for the client func DefaultConfig() *Config { - return &Config{ + config := &Config{ Address: "127.0.0.1:8500", Scheme: "http", HttpClient: http.DefaultClient, } + + if len(os.Getenv("CONSUL_HTTP_ADDR")) > 0 { + config.Address = os.Getenv("CONSUL_HTTP_ADDR") + } + + return config } // Client provides a client to the Consul API @@ -128,7 +137,11 @@ func NewClient(config *Config) (*Client, error) { // bootstrap the config defConfig := DefaultConfig() - if len(config.Address) == 0 { + switch { + case len(config.Address) != 0: + case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0: + config.Address = os.Getenv("CONSUL_HTTP_ADDR") + default: config.Address = defConfig.Address } @@ -140,6 +153,16 @@ func NewClient(config *Config) (*Client, error) { config.HttpClient = defConfig.HttpClient } + if strings.HasPrefix(config.Address, "unix://") { + shortStr := strings.TrimPrefix(config.Address, "unix://") + t := &http.Transport{} + t.Dial = func(_, _ string) (net.Conn, error) { + return net.Dial("unix", shortStr) + } + config.HttpClient.Transport = t + config.Address = shortStr + } + client := &Client{ config: *config, } @@ -206,9 +229,6 @@ func (r *request) toHTTP() (*http.Request, error) { // Encode the query parameters r.url.RawQuery = r.params.Encode() - // Get the url sring - urlRaw := r.url.String() - // Check if we should encode the body if r.body == nil && r.obj != nil { if b, err := encodeBody(r.obj); err != nil { @@ -219,14 +239,21 @@ func (r *request) toHTTP() (*http.Request, error) { } // Create the HTTP request - req, err := http.NewRequest(r.method, urlRaw, r.body) + req, err := http.NewRequest(r.method, r.url.RequestURI(), r.body) + if err != nil { + return nil, err + } + + req.URL.Host = r.url.Host + req.URL.Scheme = r.url.Scheme + req.Host = r.url.Host // Setup auth - if err == nil && r.config.HttpAuth != nil { + if r.config.HttpAuth != nil { req.SetBasicAuth(r.config.HttpAuth.Username, r.config.HttpAuth.Password) } - return req, err + return req, nil } // newRequest is used to create a new request diff --git a/api/api_test.go b/api/api_test.go index 4b697ba58fd0..488fcb1ee0ff 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -2,6 +2,7 @@ package api import ( crand "crypto/rand" + "encoding/json" "fmt" "io/ioutil" "net/http" @@ -13,27 +14,50 @@ import ( "github.com/hashicorp/consul/testutil" ) -var consulConfig = `{ - "ports": { - "dns": 19000, - "http": 18800, - "rpc": 18600, - "serf_lan": 18200, - "serf_wan": 18400, - "server": 18000 - }, - "data_dir": "%s", - "bootstrap": true, - "log_level": "debug", - "server": true -}` - type testServer struct { pid int dataDir string configFile string } +type testPortConfig struct { + DNS int `json:"dns,omitempty"` + HTTP int `json:"http,omitempty"` + RPC int `json:"rpc,omitempty"` + SerfLan int `json:"serf_lan,omitempty"` + SerfWan int `json:"serf_wan,omitempty"` + Server int `json:"server,omitempty"` +} + +type testAddressConfig struct { + HTTP string `json:"http,omitempty"` +} + +type testServerConfig struct { + Bootstrap bool `json:"bootstrap,omitempty"` + Server bool `json:"server,omitempty"` + DataDir string `json:"data_dir,omitempty"` + LogLevel string `json:"log_level,omitempty"` + Addresses *testAddressConfig `json:"addresses,omitempty"` + Ports testPortConfig `json:"ports,omitempty"` +} + +func defaultConfig() *testServerConfig { + return &testServerConfig{ + Bootstrap: true, + Server: true, + LogLevel: "debug", + Ports: testPortConfig{ + DNS: 19000, + HTTP: 18800, + RPC: 18600, + SerfLan: 18200, + SerfWan: 18400, + Server: 18000, + }, + } +} + func (s *testServer) stop() { defer os.RemoveAll(s.dataDir) defer os.RemoveAll(s.configFile) @@ -45,6 +69,10 @@ func (s *testServer) stop() { } func newTestServer(t *testing.T) *testServer { + return newTestServerWithConfig(t, func(c *testServerConfig) {}) +} + +func newTestServerWithConfig(t *testing.T, cb func(c *testServerConfig)) *testServer { if path, err := exec.LookPath("consul"); err != nil || path == "" { t.Log("consul not found on $PATH, skipping") t.SkipNow() @@ -66,8 +94,18 @@ func newTestServer(t *testing.T) *testServer { if err != nil { t.Fatalf("err: %s", err) } - configContent := fmt.Sprintf(consulConfig, dataDir) - if _, err := configFile.WriteString(configContent); err != nil { + + consulConfig := defaultConfig() + consulConfig.DataDir = dataDir + + cb(consulConfig) + + configContent, err := json.Marshal(consulConfig) + if err != nil { + t.Fatalf("err: %s", err) + } + + if _, err := configFile.Write(configContent); err != nil { t.Fatalf("err: %s", err) } configFile.Close() @@ -80,10 +118,32 @@ func newTestServer(t *testing.T) *testServer { t.Fatalf("err: %s", err) } + return &testServer{ + pid: cmd.Process.Pid, + dataDir: dataDir, + configFile: configFile.Name(), + } +} + +func makeClient(t *testing.T) (*Client, *testServer) { + return makeClientWithConfig(t, func(c *Config) { + c.Address = "127.0.0.1:18800" + }, func(c *testServerConfig) {}) +} + +func makeClientWithConfig(t *testing.T, clientConfig func(c *Config), serverConfig func(c *testServerConfig)) (*Client, *testServer) { + server := newTestServerWithConfig(t, serverConfig) + conf := DefaultConfig() + clientConfig(conf) + client, err := NewClient(conf) + if err != nil { + t.Fatalf("err: %v", err) + } + // Allow the server some time to start, and verify we have a leader. - client := new(http.Client) testutil.WaitForResult(func() (bool, error) { - resp, err := client.Get("http://127.0.0.1:18800/v1/catalog/nodes") + req := client.newRequest("GET", "/v1/catalog/nodes") + _, resp, err := client.doRequest(req) if err != nil { return false, err } @@ -102,21 +162,6 @@ func newTestServer(t *testing.T) *testServer { t.Fatalf("err: %s", err) }) - return &testServer{ - pid: cmd.Process.Pid, - dataDir: dataDir, - configFile: configFile.Name(), - } -} - -func makeClient(t *testing.T) (*Client, *testServer) { - server := newTestServer(t) - conf := DefaultConfig() - conf.Address = "127.0.0.1:18800" - client, err := NewClient(conf) - if err != nil { - t.Fatalf("err: %v", err) - } return client, server } @@ -205,7 +250,7 @@ func TestRequestToHTTP(t *testing.T) { if req.Method != "DELETE" { t.Fatalf("bad: %v", req) } - if req.URL.String() != "http://127.0.0.1:18800/v1/kv/foo?dc=foo" { + if req.URL.RequestURI() != "/v1/kv/foo?dc=foo" { t.Fatalf("bad: %v", req) } } diff --git a/api/status_test.go b/api/status_test.go index 096b13da090b..5e7acd274060 100644 --- a/api/status_test.go +++ b/api/status_test.go @@ -1,10 +1,13 @@ package api import ( + "io/ioutil" + "os/user" + "runtime" "testing" ) -func TestStatusLeader(t *testing.T) { +func TestStatusLeaderTCP(t *testing.T) { c, s := makeClient(t) defer s.stop() @@ -19,6 +22,48 @@ func TestStatusLeader(t *testing.T) { } } +func TestStatusLeaderUnix(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + + tempdir, err := ioutil.TempDir("", "consul-test-") + if err != nil { + t.Fatal("Could not create a working directory") + } + + socket := "unix://" + tempdir + "/unix-http-test.sock" + + clientConfig := func(c *Config) { + c.Address = socket + } + + serverConfig := func(c *testServerConfig) { + user, err := user.Current() + if err != nil { + t.Fatal("Could not get current user") + } + + if c.Addresses == nil { + c.Addresses = &testAddressConfig{} + } + c.Addresses.HTTP = socket + ";" + user.Uid + ";" + user.Gid + ";640" + } + + c, s := makeClientWithConfig(t, clientConfig, serverConfig) + defer s.stop() + + status := c.Status() + + leader, err := status.Leader() + if err != nil { + t.Fatalf("err: %v", err) + } + if leader == "" { + t.Fatalf("Expected leader") + } +} + func TestStatusPeers(t *testing.T) { c, s := makeClient(t) defer s.stop() diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index 0add36702c70..91cb5c1ef51f 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -7,8 +7,10 @@ import ( "io" "io/ioutil" "os" + "os/user" "path/filepath" "reflect" + "runtime" "sync/atomic" "testing" "time" @@ -123,7 +125,7 @@ func TestAgentStartStop(t *testing.T) { } } -func TestAgent_RPCPing(t *testing.T) { +func TestAgent_RPCPingTCP(t *testing.T) { dir, agent := makeAgent(t, nextConfig()) defer os.RemoveAll(dir) defer agent.Shutdown() @@ -134,6 +136,35 @@ func TestAgent_RPCPing(t *testing.T) { } } +func TestAgent_RPCPingUnix(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + + nextConf := nextConfig() + + tempdir, err := ioutil.TempDir("", "consul-test-") + if err != nil { + t.Fatal("Could not create a working directory") + } + + user, err := user.Current() + if err != nil { + t.Fatal("Could not get current user") + } + + nextConf.Addresses.RPC = "unix://" + tempdir + "/unix-rpc-test.sock;" + user.Uid + ";" + user.Gid + ";640" + + dir, agent := makeAgent(t, nextConf) + defer os.RemoveAll(dir) + defer agent.Shutdown() + + var out struct{} + if err := agent.RPC("Status.Ping", struct{}{}, &out); err != nil { + t.Fatalf("err: %v", err) + } +} + func TestAgent_AddService(t *testing.T) { dir, agent := makeAgent(t, nextConfig()) defer os.RemoveAll(dir) diff --git a/command/agent/command.go b/command/agent/command.go index 82e111caa882..b9a82e19aeb7 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -295,13 +295,26 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log return err } - rpcListener, err := net.Listen("tcp", rpcAddr.String()) + if _, ok := rpcAddr.(*net.UnixAddr); ok { + // Remove the socket if it exists, or we'll get a bind error + _ = os.Remove(rpcAddr.String()) + } + + rpcListener, err := net.Listen(rpcAddr.Network(), rpcAddr.String()) if err != nil { agent.Shutdown() c.Ui.Error(fmt.Sprintf("Error starting RPC listener: %s", err)) return err } + if _, ok := rpcAddr.(*net.UnixAddr); ok { + if err := adjustUnixSocketPermissions(config.Addresses.RPC); err != nil { + agent.Shutdown() + c.Ui.Error(fmt.Sprintf("Error adjusting Unix socket permissions: %s", err)) + return err + } + } + // Start the IPC layer c.Ui.Output("Starting Consul agent RPC...") c.rpcServer = NewAgentRPC(agent, rpcListener, logOutput, logWriter) @@ -319,6 +332,7 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log if config.Ports.DNS > 0 { dnsAddr, err := config.ClientListener(config.Addresses.DNS, config.Ports.DNS) if err != nil { + agent.Shutdown() c.Ui.Error(fmt.Sprintf("Invalid DNS bind address: %s", err)) return err } @@ -575,7 +589,7 @@ func (c *Command) Run(args []string) int { } // Get the new client http listener addr - httpAddr, err := config.ClientListenerAddr(config.Addresses.HTTP, config.Ports.HTTP) + httpAddr, err := config.ClientListener(config.Addresses.HTTP, config.Ports.HTTP) if err != nil { c.Ui.Error(fmt.Sprintf("Failed to determine HTTP address: %v", err)) } @@ -585,7 +599,7 @@ func (c *Command) Run(args []string) int { go func(wp *watch.WatchPlan) { wp.Handler = makeWatchHandler(logOutput, wp.Exempt["handler"]) wp.LogOutput = c.logOutput - if err := wp.Run(httpAddr); err != nil { + if err := wp.Run(httpAddr.String()); err != nil { c.Ui.Error(fmt.Sprintf("Error running watch: %v", err)) } }(wp) @@ -744,7 +758,7 @@ func (c *Command) handleReload(config *Config) *Config { } // Get the new client listener addr - httpAddr, err := newConf.ClientListenerAddr(config.Addresses.HTTP, config.Ports.HTTP) + httpAddr, err := newConf.ClientListener(config.Addresses.HTTP, config.Ports.HTTP) if err != nil { c.Ui.Error(fmt.Sprintf("Failed to determine HTTP address: %v", err)) } @@ -759,7 +773,7 @@ func (c *Command) handleReload(config *Config) *Config { go func(wp *watch.WatchPlan) { wp.Handler = makeWatchHandler(c.logOutput, wp.Exempt["handler"]) wp.LogOutput = c.logOutput - if err := wp.Run(httpAddr); err != nil { + if err := wp.Run(httpAddr.String()); err != nil { c.Ui.Error(fmt.Sprintf("Error running watch: %v", err)) } }(wp) diff --git a/command/agent/config.go b/command/agent/config.go index 3e14134921b2..74ed4c3ad965 100644 --- a/command/agent/config.go +++ b/command/agent/config.go @@ -7,8 +7,11 @@ import ( "io" "net" "os" + "os/user" "path/filepath" + "regexp" "sort" + "strconv" "strings" "time" @@ -345,6 +348,91 @@ type Config struct { WatchPlans []*watch.WatchPlan `mapstructure:"-" json:"-"` } +// UnixSocket contains the parameters for a Unix socket interface +type UnixSocket struct { + // Path to the socket on-disk + Path string + + // uid of the owner of the socket + Uid int + + // gid of the group of the socket + Gid int + + // Permissions for the socket file + Permissions os.FileMode +} + +func populateUnixSocket(addr string) (*UnixSocket, error) { + if !strings.HasPrefix(addr, "unix://") { + return nil, fmt.Errorf("Failed to parse Unix address, format is unix://[path];[user];[group];[mode]: %v", addr) + } + + splitAddr := strings.Split(strings.TrimPrefix(addr, "unix://"), ";") + if len(splitAddr) != 4 { + return nil, fmt.Errorf("Failed to parse Unix address, format is unix://[path];[user];[group];[mode]: %v", addr) + } + + ret := &UnixSocket{Path: splitAddr[0]} + + var userVal *user.User + var err error + + regex := regexp.MustCompile("[\\d]+") + if regex.MatchString(splitAddr[1]) { + userVal, err = user.LookupId(splitAddr[1]) + } else { + userVal, err = user.Lookup(splitAddr[1]) + } + if err != nil { + return nil, fmt.Errorf("Invalid user given for Unix socket ownership: %v", splitAddr[1]) + } + + if uid64, err := strconv.ParseInt(userVal.Uid, 10, 32); err != nil { + return nil, fmt.Errorf("Failed to parse given user ID of %v into integer", userVal.Uid) + } else { + ret.Uid = int(uid64) + } + + // Go doesn't currently have a way to look up gid from group name, + // so require a numeric gid; see + // https://codereview.appspot.com/101310044 + if gid64, err := strconv.ParseInt(splitAddr[2], 10, 32); err != nil { + return nil, fmt.Errorf("Socket group must be given as numeric gid. Failed to parse given group ID of %v into integer", splitAddr[2]) + } else { + ret.Gid = int(gid64) + } + + if mode, err := strconv.ParseUint(splitAddr[3], 8, 32); err != nil { + return nil, fmt.Errorf("Failed to parse given mode of %v into integer", splitAddr[3]) + } else { + if mode > 0777 { + return nil, fmt.Errorf("Given mode is invalid; must be an octal number between 0 and 777") + } else { + ret.Permissions = os.FileMode(mode) + } + } + + return ret, nil +} + +func adjustUnixSocketPermissions(addr string) error { + sock, err := populateUnixSocket(addr) + if err != nil { + return err + } + + if err = os.Chown(sock.Path, sock.Uid, sock.Gid); err != nil { + return fmt.Errorf("Error attempting to change socket permissions to userid %v and groupid %v: %v", sock.Uid, sock.Gid, err) + } + + if err = os.Chmod(sock.Path, sock.Permissions); err != nil { + return fmt.Errorf("Error attempting to change socket permissions to mode %v: %v", sock.Permissions, err) + } + + return nil +} + type dirEnts []os.FileInfo // DefaultConfig is used to return a sane default configuration @@ -389,31 +477,39 @@ func (c *Config) EncryptBytes() ([]byte, error) { // ClientListener is used to format a listener for a // port on a ClientAddr -func (c *Config) ClientListener(override string, port int) (*net.TCPAddr, error) { +func (c *Config) ClientListener(override string, port int) (net.Addr, error) { var addr string if override != "" { addr = override } else { addr = c.ClientAddr } - ip := net.ParseIP(addr) - if ip == nil { - return nil, fmt.Errorf("Failed to parse IP: %v", addr) - } - return &net.TCPAddr{IP: ip, Port: port}, nil -} -// ClientListenerAddr is used to format an address for a -// port on a ClientAddr, handling the zero IP. -func (c *Config) ClientListenerAddr(override string, port int) (string, error) { - addr, err := c.ClientListener(override, port) - if err != nil { - return "", err - } - if addr.IP.IsUnspecified() { - addr.IP = net.ParseIP("127.0.0.1") + switch { + case strings.HasPrefix(addr, "unix://"): + sock, err := populateUnixSocket(addr) + if err != nil { + return nil, err + } + + return &net.UnixAddr{Name: sock.Path, Net: "unix"}, nil + + default: + ip := net.ParseIP(addr) + if ip == nil { + return nil, fmt.Errorf("Failed to parse IP: %v", addr) + } + + if ip.IsUnspecified() { + ip = net.ParseIP("127.0.0.1") + } + + if ip == nil { + return nil, fmt.Errorf("Failed to parse IP 127.0.0.1") + } + + return &net.TCPAddr{IP: ip, Port: port}, nil } - return addr.String(), nil } // DecodeConfig reads the configuration from the given reader in JSON diff --git a/command/agent/config_test.go b/command/agent/config_test.go index 292e1b2721bb..f10c5b72422e 100644 --- a/command/agent/config_test.go +++ b/command/agent/config_test.go @@ -4,9 +4,12 @@ import ( "bytes" "encoding/base64" "io/ioutil" + "net" "os" + "os/user" "path/filepath" "reflect" + "runtime" "strings" "testing" "time" @@ -1068,3 +1071,109 @@ func TestReadConfigPaths_dir(t *testing.T) { t.Fatalf("bad: %#v", config) } } + +func TestUnixSockets(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + + usr, err := user.Current() + if err != nil { + t.Fatal("Could not get current user: ", err) + } + + tempdir, err := ioutil.TempDir("", "consul-test-") + if err != nil { + t.Fatal("Could not create a working directory: ", err) + } + + type SocketTestData struct { + Path string + Uid string + Gid string + Mode string + } + + testUnixSocketPopulation := func(s SocketTestData) (*UnixSocket, error) { + return populateUnixSocket("unix://" + s.Path + ";" + s.Uid + ";" + s.Gid + ";" + s.Mode) + } + + testUnixSocketPermissions := func(s SocketTestData) error { + return adjustUnixSocketPermissions("unix://" + s.Path + ";" + s.Uid + ";" + s.Gid + ";" + s.Mode) + } + + _, err = populateUnixSocket("tcp://abc123") + if err == nil { + t.Fatal("Should have rejected invalid scheme") + } + + _, err = populateUnixSocket("unix://x;y;z") + if err == nil { + t.Fatal("Should have rejected invalid number of parameters in Unix socket definition") + } + + std := SocketTestData{ + Path: tempdir + "/unix-config-test.sock", + Uid: usr.Uid, + Gid: usr.Gid, + Mode: "640", + } + + std.Uid = "orasdfdsnfoinweroiu" + _, err = testUnixSocketPopulation(std) + if err == nil { + t.Fatal("Did not error on invalid username") + } + + std.Uid = usr.Username + std.Gid = "foinfphawepofhewof" + _, err = testUnixSocketPopulation(std) + if err == nil { + t.Fatal("Did not error on invalid group (a name, must be gid)") + } + + std.Gid = usr.Gid + std.Mode = "999" + _, err = testUnixSocketPopulation(std) + if err == nil { + t.Fatal("Did not error on invalid socket mode") + } + + std.Uid = usr.Username + std.Mode = "640" + _, err = testUnixSocketPopulation(std) + if err != nil { + t.Fatal("Unix socket test failed (using username): ", err) + } + + std.Uid = usr.Uid + sock, err := testUnixSocketPopulation(std) + if err != nil { + t.Fatal("Unix socket test failed (using uid): ", err) + } + + addr := &net.UnixAddr{Name: sock.Path, Net: "unix"} + _, err = net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal("Error creating socket for futher tests: ", err) + } + + std.Uid = "-999999" + err = testUnixSocketPermissions(std) + if err == nil { + t.Fatal("Did not error on invalid uid") + } + + std.Uid = usr.Uid + std.Gid = "-999999" + err = testUnixSocketPermissions(std) + if err == nil { + t.Fatal("Did not error on invalid uid") + } + + std.Gid = usr.Gid + err = testUnixSocketPermissions(std) + if err != nil { + t.Fatal("Adjusting socket permissions failed: ", err) + } +} diff --git a/command/agent/http.go b/command/agent/http.go index 3fe2feec51e5..d480de816f52 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/http/pprof" + "os" "strconv" "strings" "time" @@ -34,7 +35,7 @@ type HTTPServer struct { func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPServer, error) { var tlsConfig *tls.Config var list net.Listener - var httpAddr *net.TCPAddr + var httpAddr net.Addr var err error var servers []*HTTPServer @@ -58,12 +59,29 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS return nil, err } - ln, err := net.Listen("tcp", httpAddr.String()) + if _, ok := httpAddr.(*net.UnixAddr); ok { + // Remove the socket if it exists, or we'll get a bind error + _ = os.Remove(httpAddr.String()) + } + + ln, err := net.Listen(httpAddr.Network(), httpAddr.String()) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err) } - list = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig) + switch httpAddr.(type) { + case *net.UnixAddr: + if err := adjustUnixSocketPermissions(config.Addresses.HTTPS); err != nil { + return nil, err + } + list = tls.NewListener(ln, tlsConfig) + + case *net.TCPAddr: + list = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, tlsConfig) + + default: + return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err) + } // Create the mux mux := http.NewServeMux() @@ -90,13 +108,29 @@ func NewHTTPServers(agent *Agent, config *Config, logOutput io.Writer) ([]*HTTPS return nil, fmt.Errorf("Failed to get ClientListener address:port: %v", err) } - // Create non-TLS listener - ln, err := net.Listen("tcp", httpAddr.String()) + if _, ok := httpAddr.(*net.UnixAddr); ok { + // Remove the socket if it exists, or we'll get a bind error + _ = os.Remove(httpAddr.String()) + } + + ln, err := net.Listen(httpAddr.Network(), httpAddr.String()) if err != nil { return nil, fmt.Errorf("Failed to get Listen on %s: %v", httpAddr.String(), err) } - list = tcpKeepAliveListener{ln.(*net.TCPListener)} + switch httpAddr.(type) { + case *net.UnixAddr: + if err := adjustUnixSocketPermissions(config.Addresses.HTTP); err != nil { + return nil, err + } + list = ln + + case *net.TCPAddr: + list = tcpKeepAliveListener{ln.(*net.TCPListener)} + + default: + return nil, fmt.Errorf("Error determining address type when attempting to get Listen on %s: %v", httpAddr.String(), err) + } // Create the mux mux := http.NewServeMux() diff --git a/command/agent/rpc_client.go b/command/agent/rpc_client.go index 7ba1907b248a..1490674f146f 100644 --- a/command/agent/rpc_client.go +++ b/command/agent/rpc_client.go @@ -7,6 +7,8 @@ import ( "github.com/hashicorp/logutils" "log" "net" + "os" + "strings" "sync" "sync/atomic" ) @@ -34,7 +36,7 @@ type seqHandler interface { type RPCClient struct { seq uint64 - conn *net.TCPConn + conn net.Conn reader *bufio.Reader writer *bufio.Writer dec *codec.Decoder @@ -79,8 +81,23 @@ func (c *RPCClient) send(header *requestHeader, obj interface{}) error { // NewRPCClient is used to create a new RPC client given the address. // This will properly dial, handshake, and start listening func NewRPCClient(addr string) (*RPCClient, error) { + sanedAddr := os.Getenv("CONSUL_RPC_ADDR") + if len(sanedAddr) == 0 { + sanedAddr = addr + } + + mode := "tcp" + + if strings.HasPrefix(sanedAddr, "unix://") { + sanedAddr = strings.TrimPrefix(sanedAddr, "unix://") + } + + if strings.HasPrefix(sanedAddr, "/") { + mode = "unix" + } + // Try to dial to agent - conn, err := net.Dial("tcp", addr) + conn, err := net.Dial(mode, sanedAddr) if err != nil { return nil, err } @@ -88,7 +105,7 @@ func NewRPCClient(addr string) (*RPCClient, error) { // Create the client client := &RPCClient{ seq: 0, - conn: conn.(*net.TCPConn), + conn: conn, reader: bufio.NewReader(conn), writer: bufio.NewWriter(conn), dispatch: make(map[uint64]seqHandler), diff --git a/command/agent/rpc_client_test.go b/command/agent/rpc_client_test.go index 3bf03d6dc6eb..2d8dfc9c0841 100644 --- a/command/agent/rpc_client_test.go +++ b/command/agent/rpc_client_test.go @@ -6,8 +6,11 @@ import ( "github.com/hashicorp/consul/testutil" "github.com/hashicorp/serf/serf" "io" + "io/ioutil" "net" "os" + "os/user" + "runtime" "strings" "testing" "time" @@ -34,17 +37,22 @@ func testRPCClient(t *testing.T) *rpcParts { } func testRPCClientWithConfig(t *testing.T, cb func(c *Config)) *rpcParts { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %s", err) - } - lw := NewLogWriter(512) mult := io.MultiWriter(os.Stderr, lw) conf := nextConfig() cb(conf) + rpcAddr, err := conf.ClientListener(conf.Addresses.RPC, conf.Ports.RPC) + if err != nil { + t.Fatalf("err: %s", err) + } + + l, err := net.Listen(rpcAddr.Network(), rpcAddr.String()) + if err != nil { + t.Fatalf("err: %s", err) + } + dir, agent := makeAgentLog(t, conf, mult) rpc := NewAgentRPC(agent, l, mult, lw) @@ -208,6 +216,41 @@ func TestRPCClientStats(t *testing.T) { } } +func TestRPCClientStatsUnix(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + + tempdir, err := ioutil.TempDir("", "consul-test-") + if err != nil { + t.Fatal("Could not create a working directory: ", err) + } + + user, err := user.Current() + if err != nil { + t.Fatal("Could not get current user: ", err) + } + + cb := func(c *Config) { + c.Addresses.RPC = "unix://" + tempdir + "/unix-rpc-test.sock;" + user.Uid + ";" + user.Gid + ";640" + } + + p1 := testRPCClientWithConfig(t, cb) + + stats, err := p1.client.Stats() + if err != nil { + t.Fatalf("err: %s", err) + } + + if _, ok := stats["agent"]; !ok { + t.Fatalf("bad: %#v", stats) + } + + if _, ok := stats["consul"]; !ok { + t.Fatalf("bad: %#v", stats) + } +} + func TestRPCClientLeave(t *testing.T) { p1 := testRPCClient(t) defer p1.Close() diff --git a/command/rpc.go b/command/rpc.go index f70fb4f23c5d..f0c9e5b1fabf 100644 --- a/command/rpc.go +++ b/command/rpc.go @@ -8,8 +8,8 @@ import ( "github.com/hashicorp/consul/command/agent" ) -// RPCAddrEnvName defines the environment variable name, which can set -// a default RPC address in case there is no -rpc-addr specified. +// RPCAddrEnvName defines an environment variable name which sets +// an RPC address if there is no -rpc-addr specified. const RPCAddrEnvName = "CONSUL_RPC_ADDR" // RPCAddrFlag returns a pointer to a string that will be populated @@ -43,7 +43,12 @@ func HTTPClient(addr string) (*consulapi.Client, error) { // HTTPClientDC returns a new Consul HTTP client with the given address and datacenter func HTTPClientDC(addr, dc string) (*consulapi.Client, error) { conf := consulapi.DefaultConfig() - conf.Address = addr + switch { + case len(os.Getenv("CONSUL_HTTP_ADDR")) > 0: + conf.Address = os.Getenv("CONSUL_HTTP_ADDR") + default: + conf.Address = addr + } conf.Datacenter = dc return consulapi.NewClient(conf) } diff --git a/website/source/docs/agent/options.html.markdown b/website/source/docs/agent/options.html.markdown index 5144342a005f..6aa6b61d872e 100644 --- a/website/source/docs/agent/options.html.markdown +++ b/website/source/docs/agent/options.html.markdown @@ -239,8 +239,17 @@ definitions support being updated during a reload. However, because the caches are not actively invalidated, ACL policy may be stale up to the TTL value. -* `addresses` - This is a nested object that allows setting the bind address - for the following keys: +* `addresses` - This is a nested object that allows setting bind addresses. For `rpc` + and `http`, a Unix socket can be specified in the following form: + unix://[/path/to/socket];[username|uid];[gid];[mode]. The socket will be created + in the specified location with the given username or uid, gid, and mode. The + user Consul is running as must have appropriate permissions to change the socket + ownership to the given uid or gid. When running Consul agent commands against + Unix socket interfaces, use the `-rpc-addr` or `-http-addr` arguments to specify + the path to the socket, e.g. "unix://path/to/socket". You can also place the desired + values in `CONSUL_RPC_ADDR` and `CONSUL_HTTP_ADDR` environment variables. For TCP + addresses, these should be in the form ip:port. + The following keys are valid: * `dns` - The DNS server. Defaults to `client_addr` * `http` - The HTTP API. Defaults to `client_addr` * `rpc` - The RPC endpoint. Defaults to `client_addr`