Skip to content

Commit

Permalink
feat: add max connections flag (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom authored Jul 8, 2022
1 parent e2b40c4 commit a5d8ee8
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 18 deletions.
3 changes: 3 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ without having to manage any client SSL certificates.`,
"Path to a service account key to use for authentication.")
cmd.PersistentFlags().BoolVarP(&c.conf.GcloudAuth, "gcloud-auth", "g", false,
"Use gcloud's user configuration to retrieve a token for authentication.")
cmd.PersistentFlags().Uint64Var(&c.conf.MaxConnections, "max-connections", 0,
`Limits the number of connections by refusing any additional connections.
When this flag is not set, there is no limit.`)
cmd.PersistentFlags().StringVar(&c.telemetryProject, "telemetry-project", "",
"Enable Cloud Monitoring and Cloud Trace integration with the provided project ID.")
cmd.PersistentFlags().BoolVar(&c.disableTraces, "disable-traces", false,
Expand Down
7 changes: 7 additions & 0 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ func TestNewCommandArguments(t *testing.T) {
}},
}),
},
{
desc: "using the max connections flag",
args: []string{"--max-connections", "1", "/projects/proj/locations/region/clusters/clust/instances/inst"},
want: withDefaults(&proxy.Config{
MaxConnections: 1,
}),
},
}

for _, tc := range tcs {
Expand Down
30 changes: 30 additions & 0 deletions internal/proxy/alignment_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package proxy

import (
"testing"
"unsafe"
)

func TestClientUsesSyncAtomicAlignment(t *testing.T) {
// The sync/atomic pkg has a bug that requires the developer to guarantee
// 64-bit alignment when using 64-bit functions on 32-bit systems.
c := &Client{}

if a := unsafe.Offsetof(c.connCount); a%64 != 0 {
t.Errorf("Client.connCount is not 64-bit aligned: want 0, got %v", a)
}
}
43 changes: 39 additions & 4 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"regexp"
"strings"
"sync"
"sync/atomic"
"time"

"cloud.google.com/go/alloydbconn"
Expand All @@ -43,7 +44,7 @@ type InstanceConnConfig struct {
// Port is the port on which to bind a listener for the instance.
Port int
// UnixSocket is the directory where a Unix socket will be created,
// connected to the Cloud SQL instance. If set, takes precedence over Addr
// connected to the AlloyDB instance. If set, takes precedence over Addr
// and Port.
UnixSocket string
}
Expand Down Expand Up @@ -79,6 +80,11 @@ type Config struct {
// configuration takes precedence over global configuration.
Instances []InstanceConnConfig

// MaxConnections are the maximum number of connections the Client may
// establish to the AlloyDB server side proxy before refusing additional
// connections. A zero-value indicates no limit.
MaxConnections uint64

// Dialer specifies the dialer to use when connecting to AlloyDB
// instances.
Dialer alloydb.Dialer
Expand Down Expand Up @@ -150,8 +156,17 @@ func UnixSocketDir(dir, inst string) (string, error) {
return filepath.Join(dir, shortName), nil
}

// Client represents the state of the current instantiation of the proxy.
// Client proxies connections from a local client to the remote server side
// proxy for multiple AlloyDB instances.
type Client struct {
// connCount tracks the number of all open connections from the Client to
// all AlloyDB instances.
connCount uint64

// maxConns is the maximum number of allowed connections tracked by
// connCount. If not set, there is no limit.
maxConns uint64

cmd *cobra.Command
dialer alloydb.Dialer

Expand Down Expand Up @@ -194,10 +209,17 @@ func NewClient(ctx context.Context, cmd *cobra.Command, conf *Config) (*Client,
mnts = append(mnts, m)
}

return &Client{mnts: mnts, cmd: cmd, dialer: d}, nil
c := &Client{
mnts: mnts,
cmd: cmd,
dialer: d,
maxConns: conf.MaxConnections,
}
return c, nil
}

// Serve listens on the mounted ports and beging proxying the connections to the instances.
// Serve starts proxying connections for all configured instances using the
// associated socket.
func (c *Client) Serve(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
Expand Down Expand Up @@ -274,6 +296,19 @@ func (c *Client) serveSocketMount(ctx context.Context, s *socketMount) error {
go func() {
c.cmd.Printf("[%s] accepted connection from %s\n", s.inst, cConn.RemoteAddr())

// A client has established a connection to the local socket. Before
// we initiate a connection to the AlloyDB backend, increment the
// connection counter. If the total number of connections exceeds
// the maximum, refuse to connect and close the client connection.
count := atomic.AddUint64(&c.connCount, 1)
defer atomic.AddUint64(&c.connCount, ^uint64(0))

if c.maxConns > 0 && count > c.maxConns {
c.cmd.Printf("max connections (%v) exceeded, refusing new connection\n", c.maxConns)
_ = cConn.Close()
return
}

// give a max of 30 seconds to connect to the instance
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
Expand Down
103 changes: 89 additions & 14 deletions internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ package proxy_test
import (
"context"
"errors"
"io"
"io/ioutil"
"net"
"os"
"path/filepath"
"sync"
"testing"
"time"

Expand All @@ -29,21 +31,33 @@ import (
"github.com/spf13/cobra"
)

type fakeDialer struct{}

type testCase struct {
desc string
in *proxy.Config
wantTCPAddrs []string
wantUnixAddrs []string
}

func (fakeDialer) Dial(ctx context.Context, inst string, opts ...alloydbconn.DialOption) (net.Conn, error) {
conn, _ := net.Pipe()
return conn, nil
type fakeDialer struct {
mu sync.Mutex
dialCount int
}

func (f *fakeDialer) Dial(ctx context.Context, inst string, opts ...alloydbconn.DialOption) (net.Conn, error) {
f.mu.Lock()
defer f.mu.Unlock()
f.dialCount++
c1, _ := net.Pipe()
return c1, nil
}

func (f *fakeDialer) dialAttempts() int {
f.mu.Lock()
defer f.mu.Unlock()
return f.dialCount
}

func (fakeDialer) Close() error {
func (*fakeDialer) Close() error {
return nil
}

Expand Down Expand Up @@ -196,17 +210,14 @@ func TestClientInitialization(t *testing.T) {

for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
tc.in.Dialer = fakeDialer{}
tc.in.Dialer = &fakeDialer{}
c, err := proxy.NewClient(ctx, &cobra.Command{}, tc.in)
if err != nil {
t.Fatalf("want error = nil, got = %v", err)
}
defer c.Close()
for _, addr := range tc.wantTCPAddrs {
conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("want error = nil, got = %v", err)
}
conn := tryTCPDial(t, addr)
err = conn.Close()
if err != nil {
t.Logf("failed to close connection: %v", err)
Expand All @@ -227,14 +238,78 @@ func TestClientInitialization(t *testing.T) {
}
}

func tryTCPDial(t *testing.T, addr string) net.Conn {
attempts := 10
var (
conn net.Conn
err error
)
for i := 0; i < attempts; i++ {
conn, err = net.Dial("tcp", addr)
if err != nil {
time.Sleep(100 * time.Millisecond)
continue
}
return conn
}

t.Fatalf("failed to dial in %v attempts: %v", attempts, err)
return nil
}

func TestClientLimitsMaxConnections(t *testing.T) {
d := &fakeDialer{}
in := &proxy.Config{
Addr: "127.0.0.1",
Port: 5000,
Instances: []proxy.InstanceConnConfig{
{Name: "proj:region:pg"},
},
MaxConnections: 1,
Dialer: d,
}
c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
defer c.Close()
go c.Serve(context.Background())

conn1, err1 := net.Dial("tcp", "127.0.0.1:5000")
if err1 != nil {
t.Fatalf("net.Dial error: %v", err1)
}
defer conn1.Close()

conn2, err2 := net.Dial("tcp", "127.0.0.1:5000")
if err2 != nil {
t.Fatalf("net.Dial error: %v", err1)
}
defer conn2.Close()

// try to read to check if the connection is closed
// wait only a second for the result (since nothing is writing to the
// socket)
conn2.SetReadDeadline(time.Now().Add(time.Second))
_, rErr := conn2.Read(make([]byte, 1))
if rErr != io.EOF {
t.Fatalf("conn.Read should return io.EOF, got = %v", rErr)
}

want := 1
if got := d.dialAttempts(); got != want {
t.Fatalf("dial attempts did not match expected, want = %v, got = %v", want, got)
}
}

func TestClientClosesCleanly(t *testing.T) {
in := &proxy.Config{
Addr: "127.0.0.1",
Port: 5000,
Instances: []proxy.InstanceConnConfig{
{Name: "proj:reg:inst"},
},
Dialer: fakeDialer{},
Dialer: &fakeDialer{},
}
c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in)
if err != nil {
Expand All @@ -261,7 +336,7 @@ func TestClosesWithError(t *testing.T) {
Instances: []proxy.InstanceConnConfig{
{Name: "proj:reg:inst"},
},
Dialer: errorDialer{},
Dialer: &errorDialer{},
}
c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in)
if err != nil {
Expand Down Expand Up @@ -315,7 +390,7 @@ func TestClientInitializationWorksRepeatedly(t *testing.T) {
Instances: []proxy.InstanceConnConfig{
{Name: "/projects/proj/locations/region/clusters/clust/instances/inst1"},
},
Dialer: fakeDialer{},
Dialer: &fakeDialer{},
}
c, err := proxy.NewClient(ctx, &cobra.Command{}, in)
if err != nil {
Expand Down

0 comments on commit a5d8ee8

Please sign in to comment.