Skip to content

Commit

Permalink
feat: add max-sigterm-delay flag (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom authored Jul 11, 2022
1 parent 0fd062c commit 2c9864d
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 39 deletions.
12 changes: 9 additions & 3 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ without having to manage any client SSL certificates.`,
cmd.PersistentFlags().Uint64Var(&c.conf.MaxConnections, "max-connections", 0,
`Limits the number of connections by refusing any additional connections.
When this flag is not set, there is no limit.`)
cmd.PersistentFlags().DurationVar(&c.conf.WaitOnClose, "max-sigterm-delay", 0,
`Maximum amount of time to wait after for any open connections
to close after receiving a TERM signal. The proxy will shut
down when the number of open connections reaches 0 or when
the maximum time has passed. Defaults to 0s.`)

cmd.PersistentFlags().StringVar(&c.telemetryProject, "telemetry-project", "",
"Enable Cloud Monitoring and Cloud Trace integration with the provided project ID.")
cmd.PersistentFlags().BoolVar(&c.disableTraces, "disable-traces", false,
Expand Down Expand Up @@ -389,7 +395,7 @@ func runSignalWrapper(cmd *Command) error {
cmd.Println("The proxy has started successfully and is ready for new connections!")
defer func() {
if cErr := p.Close(); cErr != nil {
cmd.PrintErrf("error during shutdown: %v\n", cErr)
cmd.PrintErrf("The proxy failed to close cleanly: %v\n", cErr)
}
}()

Expand All @@ -400,9 +406,9 @@ func runSignalWrapper(cmd *Command) error {
err := <-shutdownCh
switch {
case errors.Is(err, errSigInt):
cmd.PrintErrln("SIGINT signal received. Shuting down...")
cmd.PrintErrln("SIGINT signal received. Shutting down...")
case errors.Is(err, errSigTerm):
cmd.PrintErrln("SIGTERM signal received. Shuting down...")
cmd.PrintErrln("SIGTERM signal received. Shutting down...")
default:
cmd.PrintErrf("The proxy has encountered a terminal error: %v\n", err)
}
Expand Down
7 changes: 7 additions & 0 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ func TestNewCommandArguments(t *testing.T) {
MaxConnections: 1,
}),
},
{
desc: "using wait after signterm flag",
args: []string{"--max-sigterm-delay", "10s", "/projects/proj/locations/region/clusters/clust/instances/inst"},
want: withDefaults(&proxy.Config{
WaitOnClose: 10 * time.Second,
}),
},
}

for _, tc := range tcs {
Expand Down
42 changes: 38 additions & 4 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ type Config struct {
// connections. A zero-value indicates no limit.
MaxConnections uint64

// WaitOnClose sets the duration to wait for connections to close before
// shutting down. Not setting this field means to close immediately
// regardless of any open connections.
WaitOnClose time.Duration

// Dialer specifies the dialer to use when connecting to AlloyDB
// instances.
Dialer alloydb.Dialer
Expand Down Expand Up @@ -172,6 +177,10 @@ type Client struct {

// mnts is a list of all mounted sockets for this client
mnts []*socketMount

// waitOnClose is the maximum duration to wait for open connections to close
// when shutting down.
waitOnClose time.Duration
}

// NewClient completes the initial setup required to get the proxy to a "steady" state.
Expand Down Expand Up @@ -210,10 +219,11 @@ func NewClient(ctx context.Context, cmd *cobra.Command, conf *Config) (*Client,
}

c := &Client{
mnts: mnts,
cmd: cmd,
dialer: d,
maxConns: conf.MaxConnections,
mnts: mnts,
cmd: cmd,
dialer: d,
maxConns: conf.MaxConnections,
waitOnClose: conf.WaitOnClose,
}
return c, nil
}
Expand Down Expand Up @@ -262,16 +272,40 @@ func (m MultiErr) Error() string {

func (c *Client) Close() error {
var mErr MultiErr
// First, close all open socket listeners to prevent additional connections.
for _, m := range c.mnts {
err := m.Close()
if err != nil {
mErr = append(mErr, err)
}
}
// Next, close the dialer to prevent any additional refreshes.
cErr := c.dialer.Close()
if cErr != nil {
mErr = append(mErr, cErr)
}
if c.waitOnClose == 0 {
if len(mErr) > 0 {
return mErr
}
return nil
}
timeout := time.After(c.waitOnClose)
tick := time.Tick(100 * time.Millisecond)
for {
select {
case <-tick:
if atomic.LoadUint64(&c.connCount) > 0 {
continue
}
case <-timeout:
}
break
}
open := atomic.LoadUint64(&c.connCount)
if open > 0 {
mErr = append(mErr, fmt.Errorf("%d connection(s) still open after waiting %v", open, c.waitOnClose))
}
if len(mErr) > 0 {
return mErr
}
Expand Down
118 changes: 86 additions & 32 deletions internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ type errorDialer struct {
fakeDialer
}

func (errorDialer) Close() error {
func (*errorDialer) Close() error {
return errors.New("errorDialer returns error on Close")
}

Expand Down Expand Up @@ -143,15 +143,15 @@ func TestClientInitialization(t *testing.T) {
desc: "with incrementing automatic port selection",
in: &proxy.Config{
Addr: "127.0.0.1",
Port: 5432, // default port
Port: 6000,
Instances: []proxy.InstanceConnConfig{
{Name: inst1},
{Name: inst2},
},
},
wantTCPAddrs: []string{
"127.0.0.1:5432",
"127.0.0.1:5433",
"127.0.0.1:6000",
"127.0.0.1:6001",
},
},
{
Expand Down Expand Up @@ -238,25 +238,6 @@ func TestClientInitialization(t *testing.T) {
}
}

func tryTCPDial(t *testing.T, addr string) net.Conn {
attempts := 10
var (
conn net.Conn
err error
)
for i := 0; i < attempts; i++ {
conn, err = net.Dial("tcp", addr)
if err != nil {
time.Sleep(100 * time.Millisecond)
continue
}
return conn
}

t.Fatalf("failed to dial in %v attempts: %v", attempts, err)
return nil
}

func TestClientLimitsMaxConnections(t *testing.T) {
d := &fakeDialer{}
in := &proxy.Config{
Expand Down Expand Up @@ -291,17 +272,92 @@ func TestClientLimitsMaxConnections(t *testing.T) {
// wait only a second for the result (since nothing is writing to the
// socket)
conn2.SetReadDeadline(time.Now().Add(time.Second))
_, rErr := conn2.Read(make([]byte, 1))
if rErr != io.EOF {
t.Fatalf("conn.Read should return io.EOF, got = %v", rErr)

wantEOF := func(t *testing.T, c net.Conn) {
var got error
for i := 0; i < 10; i++ {
_, got = c.Read(make([]byte, 1))
if got == io.EOF {
return
}
time.Sleep(100 * time.Millisecond)
}
t.Fatalf("conn.Read should return io.EOF, got = %v", got)
}

wantEOF(t, conn2)

want := 1
if got := d.dialAttempts(); got != want {
t.Fatalf("dial attempts did not match expected, want = %v, got = %v", want, got)
}
}

func tryTCPDial(t *testing.T, addr string) net.Conn {
attempts := 10
var (
conn net.Conn
err error
)
for i := 0; i < attempts; i++ {
conn, err = net.Dial("tcp", addr)
if err != nil {
time.Sleep(100 * time.Millisecond)
continue
}
return conn
}

t.Fatalf("failed to dial in %v attempts: %v", attempts, err)
return nil
}

func TestClientCloseWaitsForActiveConnections(t *testing.T) {
in := &proxy.Config{
Addr: "127.0.0.1",
Port: 5000,
Instances: []proxy.InstanceConnConfig{
{Name: "proj:region:pg"},
},
Dialer: &fakeDialer{},
}
c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
go c.Serve(context.Background())

conn := tryTCPDial(t, "127.0.0.1:5000")
_ = conn.Close()

if err := c.Close(); err != nil {
t.Fatalf("c.Close error: %v", err)
}

in.WaitOnClose = time.Second
in.Port = 5001
c, err = proxy.NewClient(context.Background(), &cobra.Command{}, in)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
go c.Serve(context.Background())

var open []net.Conn
for i := 0; i < 5; i++ {
conn = tryTCPDial(t, "127.0.0.1:5001")
open = append(open, conn)
}
defer func() {
for _, o := range open {
o.Close()
}
}()

if err := c.Close(); err == nil {
t.Fatal("c.Close should error, got = nil")
}
}

func TestClientClosesCleanly(t *testing.T) {
in := &proxy.Config{
Addr: "127.0.0.1",
Expand All @@ -316,12 +372,8 @@ func TestClientClosesCleanly(t *testing.T) {
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
}
go c.Serve(context.Background())
time.Sleep(time.Second) // allow the socket to start listening

conn, dErr := net.Dial("tcp", "127.0.0.1:5000")
if dErr != nil {
t.Fatalf("net.Dial error = %v", dErr)
}
conn := tryTCPDial(t, "127.0.0.1:5000")
_ = conn.Close()

if err := c.Close(); err != nil {
Expand All @@ -343,7 +395,9 @@ func TestClosesWithError(t *testing.T) {
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
}
go c.Serve(context.Background())
time.Sleep(time.Second) // allow the socket to start listening

conn := tryTCPDial(t, "127.0.0.1:5000")
defer conn.Close()

if err = c.Close(); err == nil {
t.Fatal("c.Close() should error, got nil")
Expand Down

0 comments on commit 2c9864d

Please sign in to comment.