diff --git a/server/listener.go b/server/listener.go index cee5fa3771..50aea3c62e 100644 --- a/server/listener.go +++ b/server/listener.go @@ -53,7 +53,7 @@ type Listener struct { // For unix socket connection, 'unixSocketPath' takes a path for the unix socket file. // If 'unixSocketPath' is empty, no need to create the second listener. func NewListener(protocol, address string, unixSocketPath string) (*Listener, error) { - netl, err := net.Listen(protocol, address) + netl, err := newNetListener(protocol, address) if err != nil { return nil, err } diff --git a/server/net_listener_unix.go b/server/net_listener_unix.go new file mode 100644 index 0000000000..e35bf6688f --- /dev/null +++ b/server/net_listener_unix.go @@ -0,0 +1,39 @@ +//go:build !windows + +package server + +import ( + "context" + "net" + "syscall" + + "golang.org/x/sys/unix" +) + +// Very rarely in our CI, the server fails to bind to the port with the error: "port already in use." +// This is odd because the server already confirms that the port is not in use before connecting. +// Using the SO_REUSEADDR and SO_REUSEPORT options prevents this spurious failure. +// This is safe to do because we have already checked that the +func newNetListener(protocol, address string) (net.Listener, error) { + lc := net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + var socketErr error + err := c.Control(func(fd uintptr) { + err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) + if err != nil { + socketErr = err + } + + err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + if err != nil { + socketErr = err + } + }) + if err != nil { + return err + } + return socketErr + }, + } + return lc.Listen(context.Background(), protocol, address) +} diff --git a/server/net_listener_windows.go b/server/net_listener_windows.go new file mode 100644 index 0000000000..71cc70b4fd --- /dev/null +++ b/server/net_listener_windows.go @@ -0,0 +1,7 @@ +package server + +import "net" + +func newNetListener(protocol, address string) (net.Listener, error) { + return net.Listen(protocol, address) +} diff --git a/server/server.go b/server/server.go index 5c4bd7c422..d3cf5b8975 100644 --- a/server/server.go +++ b/server/server.go @@ -16,6 +16,8 @@ package server import ( "errors" + "fmt" + "net" "time" "github.com/dolthub/vitess/go/mysql" @@ -97,6 +99,16 @@ func NewValidatingServer( return newServerFromHandler(cfg, e, sm, handler) } +func portInUse(hostPort string) bool { + timeout := time.Second + conn, _ := net.DialTimeout("tcp", hostPort, timeout) + if conn != nil { + defer conn.Close() + return true + } + return false +} + func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler) (*Server, error) { if cfg.ConnReadTimeout < 0 { cfg.ConnReadTimeout = 0 @@ -109,6 +121,11 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle } var unixSocketInUse error + + if portInUse(cfg.Address) { + unixSocketInUse = fmt.Errorf("Port %s already in use.", cfg.Address) + } + l, err := NewListener(cfg.Protocol, cfg.Address, cfg.Socket) if err != nil { if errors.Is(err, UnixSocketInUseError) {