From e8af4dadd930f35890a782a2b93aa6b4292c6854 Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Fri, 2 Sep 2022 20:02:27 -0600 Subject: [PATCH 1/2] feat: add support for README in FUSE mode (#1312) --- cmd/root.go | 28 ++++++- cmd/root_linux_test.go | 73 ++++++++++++++++++ cmd/root_test.go | 19 ++++- cmd/root_windows_test.go | 36 +++++++++ go.mod | 1 + go.sum | 5 ++ internal/proxy/fuse.go | 79 +++++++++++++++++++ internal/proxy/fuse_darwin.go | 41 ++++++++++ internal/proxy/fuse_linux.go | 33 ++++++++ internal/proxy/fuse_linux_test.go | 43 +++++++++++ internal/proxy/fuse_test.go | 122 ++++++++++++++++++++++++++++++ internal/proxy/fuse_windows.go | 24 ++++++ internal/proxy/proxy.go | 81 ++++++++++++++++++-- 13 files changed, 570 insertions(+), 15 deletions(-) create mode 100644 cmd/root_linux_test.go create mode 100644 cmd/root_windows_test.go create mode 100644 internal/proxy/fuse.go create mode 100644 internal/proxy/fuse_darwin.go create mode 100644 internal/proxy/fuse_linux.go create mode 100644 internal/proxy/fuse_linux_test.go create mode 100644 internal/proxy/fuse_test.go create mode 100644 internal/proxy/fuse_windows.go diff --git a/cmd/root.go b/cmd/root.go index 88c4eec4f..30a34b25c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -24,6 +24,7 @@ import ( "net/url" "os" "os/signal" + "path/filepath" "strconv" "strings" "syscall" @@ -241,7 +242,15 @@ func NewCommand(opts ...Option) *Command { cmd.PersistentFlags().StringVar(&c.conf.APIEndpointURL, "sqladmin-api-endpoint", "", "API endpoint for all Cloud SQL Admin API requests. (default: https://sqladmin.googleapis.com)") cmd.PersistentFlags().StringVar(&c.conf.QuotaProject, "quota-project", "", - `Specifies the project for Cloud SQL Admin API quota tracking. Must have "serviceusage.service.use" IAM permission.`) + `Specifies the project to use for Cloud SQL Admin API quota tracking. +The IAM principal must have the "serviceusage.services.use" permission +for the given project. See https://cloud.google.com/service-usage/docs/overview and +https://cloud.google.com/storage/docs/requester-pays`) + cmd.PersistentFlags().StringVar(&c.conf.FUSEDir, "fuse", "", + "Mount a directory at the path using FUSE to access Cloud SQL instances.") + cmd.PersistentFlags().StringVar(&c.conf.FUSETempDir, "fuse-tmp-dir", + filepath.Join(os.TempDir(), "csql-tmp"), + "Temp dir for Unix sockets created with FUSE") // Global and per instance flags cmd.PersistentFlags().StringVarP(&c.conf.Addr, "address", "a", "127.0.0.1", @@ -259,11 +268,24 @@ func NewCommand(opts ...Option) *Command { } func parseConfig(cmd *Command, conf *proxy.Config, args []string) error { - // If no instance connection names were provided, error. - if len(args) == 0 { + // If no instance connection names were provided AND FUSE isn't enabled, + // error. + if len(args) == 0 && conf.FUSEDir == "" { return newBadCommandError("missing instance_connection_name (e.g., project:region:instance)") } + if conf.FUSEDir != "" { + if err := proxy.SupportsFUSE(); err != nil { + return newBadCommandError( + fmt.Sprintf("--fuse is not supported: %v", err), + ) + } + } + + if len(args) == 0 && conf.FUSEDir == "" && conf.FUSETempDir != "" { + return newBadCommandError("cannot specify --fuse-tmp-dir without --fuse") + } + userHasSet := func(f string) bool { return cmd.PersistentFlags().Lookup(f).Changed } diff --git a/cmd/root_linux_test.go b/cmd/root_linux_test.go new file mode 100644 index 000000000..0a2e8e499 --- /dev/null +++ b/cmd/root_linux_test.go @@ -0,0 +1,73 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "os" + "path/filepath" + "testing" + + "github.com/spf13/cobra" +) + +func TestNewCommandArgumentsOnLinux(t *testing.T) { + defaultTmp := filepath.Join(os.TempDir(), "csql-tmp") + tcs := []struct { + desc string + args []string + wantDir string + wantTempDir string + }{ + { + desc: "using the fuse flag", + args: []string{"--fuse", "/cloudsql"}, + wantDir: "/cloudsql", + wantTempDir: defaultTmp, + }, + { + desc: "using the fuse temporary directory flag", + args: []string{"--fuse", "/cloudsql", "--fuse-tmp-dir", "/mycooldir"}, + wantDir: "/cloudsql", + wantTempDir: "/mycooldir", + }, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + c := NewCommand() + // Keep the test output quiet + c.SilenceUsage = true + c.SilenceErrors = true + // Disable execute behavior + c.RunE = func(*cobra.Command, []string) error { + return nil + } + c.SetArgs(tc.args) + + err := c.Execute() + if err != nil { + t.Fatalf("want error = nil, got = %v", err) + } + + if got, want := c.conf.FUSEDir, tc.wantDir; got != want { + t.Fatalf("FUSEDir: want = %v, got = %v", want, got) + } + + if got, want := c.conf.FUSETempDir, tc.wantTempDir; got != want { + t.Fatalf("FUSEDir: want = %v, got = %v", want, got) + } + }) + } +} diff --git a/cmd/root_test.go b/cmd/root_test.go index 5c9d1e3b6..bd96e730f 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -19,6 +19,8 @@ import ( "errors" "net" "net/http" + "os" + "path/filepath" "sync" "testing" "time" @@ -41,11 +43,16 @@ func TestNewCommandArguments(t *testing.T) { if c.Addr == "" { c.Addr = "127.0.0.1" } - if c.Instances == nil { - c.Instances = []proxy.InstanceConnConfig{{}} + if c.FUSEDir == "" { + if c.Instances == nil { + c.Instances = []proxy.InstanceConnConfig{{}} + } + if i := &c.Instances[0]; i.Name == "" { + i.Name = "proj:region:inst" + } } - if i := &c.Instances[0]; i.Name == "" { - i.Name = "proj:region:inst" + if c.FUSETempDir == "" { + c.FUSETempDir = filepath.Join(os.TempDir(), "csql-tmp") } return c } @@ -520,6 +527,10 @@ func TestNewCommandWithErrors(t *testing.T) { desc: "using an invalid url for sqladmin-api-endpoint", args: []string{"--sqladmin-api-endpoint", "https://user:abc{DEf1=ghi@example.com:5432/db?sslmode=require", "proj:region:inst"}, }, + { + desc: "using fuse-tmp-dir without fuse", + args: []string{"--fuse-tmp-dir", "/mydir"}, + }, } for _, tc := range tcs { diff --git a/cmd/root_windows_test.go b/cmd/root_windows_test.go new file mode 100644 index 000000000..78b17674d --- /dev/null +++ b/cmd/root_windows_test.go @@ -0,0 +1,36 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "testing" + + "github.com/spf13/cobra" +) + +func TestWindowsDoesNotSupportFUSE(t *testing.T) { + c := NewCommand() + // Keep the test output quiet + c.SilenceUsage = true + c.SilenceErrors = true + // Disable execute behavior + c.RunE = func(*cobra.Command, []string) error { return nil } + c.SetArgs([]string{"--fuse"}) + + err := c.Execute() + if err == nil { + t.Fatal("want error != nil, got = nil") + } +} diff --git a/go.mod b/go.mod index 1613aca6c..3cb2669de 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/denisenkom/go-mssqldb v0.12.2 github.com/go-sql-driver/mysql v1.6.0 github.com/google/go-cmp v0.5.8 + github.com/hanwen/go-fuse/v2 v2.1.0 github.com/jackc/pgx/v4 v4.17.0 github.com/spf13/cobra v1.5.0 go.opencensus.io v0.23.0 diff --git a/go.sum b/go.sum index f364c103c..53f0fbca1 100644 --- a/go.sum +++ b/go.sum @@ -643,6 +643,9 @@ github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks= +github.com/hanwen/go-fuse v1.0.0/go.mod h1:unqXarDXqzAk0rt98O2tVndEPIpUgLD9+rwFisZH3Ok= +github.com/hanwen/go-fuse/v2 v2.1.0 h1:+32ffteETaLYClUj0a3aHjZ1hOPxxaNEHiZiujuDaek= +github.com/hanwen/go-fuse/v2 v2.1.0/go.mod h1:oRyA5eK+pvJyv5otpO/DgccS8y/RvYMaO00GgRLGryc= github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q= github.com/hashicorp/consul/api v1.12.0/go.mod h1:6pVBMo0ebnYdt2S3H87XhekM/HHrUoTD2XXb/VrZVy0= github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= @@ -798,6 +801,8 @@ github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= diff --git a/internal/proxy/fuse.go b/internal/proxy/fuse.go new file mode 100644 index 000000000..025b1011a --- /dev/null +++ b/internal/proxy/fuse.go @@ -0,0 +1,79 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "syscall" + + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" + "github.com/hanwen/go-fuse/v2/fuse/nodefs" +) + +// readme represents a static read-only text file. +type readme struct { + fs.Inode +} + +const readmeText = ` +When programs attempt to open files in this directory, a remote connection to +the Cloud SQL instance of the same name will be established. + +For example, when you run one of the followg commands, the proxy will initiate a +connection to the corresponding Cloud SQL instance, given you have the correct +IAM permissions. + + mysql -u root -S "/somedir/project:region:instance" + + # or + + psql "host=/somedir/project:region:instance dbname=mydb user=myuser" + +For MySQL, the proxy will create a socket with the instance connection name +(e.g., project:region:instance) in this directory. For Postgres, the proxy will +create a directory with the instance connection name, and create a socket inside +that directory with the special Postgres name: .s.PGSQL.5432. + +Listing the contents of this directory will show all instances with active +connections. +` + +// Getattr implements fs.NodeGetattrer and indicates that this file is a regular +// file. +func (*readme) Getattr(ctx context.Context, f fs.FileHandle, out *fuse.AttrOut) syscall.Errno { + *out = fuse.AttrOut{Attr: fuse.Attr{ + Mode: 0444 | syscall.S_IFREG, + Size: uint64(len(readmeText)), + }} + return fs.OK +} + +// Read implements fs.NodeReader and supports incremental reads. +func (*readme) Read(ctx context.Context, f fs.FileHandle, dest []byte, off int64) (fuse.ReadResult, syscall.Errno) { + end := int(off) + len(dest) + if end > len(readmeText) { + end = len(readmeText) + } + return fuse.ReadResultData([]byte(readmeText[off:end])), fs.OK +} + +// Open implements fs.NodeOpener and supports opening the README as a read-only +// file. +func (*readme) Open(ctx context.Context, mode uint32) (fs.FileHandle, uint32, syscall.Errno) { + df := nodefs.NewDataFile([]byte(readmeText)) + rf := nodefs.NewReadOnlyFile(df) + return rf, 0, fs.OK +} diff --git a/internal/proxy/fuse_darwin.go b/internal/proxy/fuse_darwin.go new file mode 100644 index 000000000..ceb5db269 --- /dev/null +++ b/internal/proxy/fuse_darwin.go @@ -0,0 +1,41 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "errors" + "os" +) + +const ( + macfusePath = "/Library/Filesystems/macfuse.fs/Contents/Resources/mount_macfuse" + osxfusePath = "/Library/Filesystems/osxfuse.fs/Contents/Resources/mount_osxfuse" +) + +// SupportsFUSE checks if macfuse or osxfuse are installed on the host by +// looking for both in their known installation location. +func SupportsFUSE() error { + // This code follows the same strategy as hanwen/go-fuse. + // See https://github.com/hanwen/go-fuse/blob/0f728ba15b38579efefc3dc47821882ca18ffea7/fuse/mount_darwin.go#L121-L124. + + // check for macfuse first (newer version of osxfuse) + if _, err := os.Stat(macfusePath); err != nil { + // if that fails, check for osxfuse next + if _, err := os.Stat(osxfusePath); err != nil { + return errors.New("failed to find osxfuse or macfuse: verify FUSE installation and try again (see https://osxfuse.github.io).") + } + } + return nil +} diff --git a/internal/proxy/fuse_linux.go b/internal/proxy/fuse_linux.go new file mode 100644 index 000000000..264e2acf2 --- /dev/null +++ b/internal/proxy/fuse_linux.go @@ -0,0 +1,33 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "errors" + "os/exec" +) + +// SupportsFUSE checks if the fusermount binary is present in the PATH or a well +// known location. +func SupportsFUSE() error { + // This code follows the same strategy found in hanwen/go-fuse. + // See https://github.com/hanwen/go-fuse/blob/0f728ba15b38579efefc3dc47821882ca18ffea7/fuse/mount_linux.go#L184-L198. + if _, err := exec.LookPath("fusermount"); err != nil { + if _, err := exec.LookPath("/bin/fusermount"); err != nil { + return errors.New("fusermount binary not found in PATH or /bin") + } + } + return nil +} diff --git a/internal/proxy/fuse_linux_test.go b/internal/proxy/fuse_linux_test.go new file mode 100644 index 000000000..b8ad06ea2 --- /dev/null +++ b/internal/proxy/fuse_linux_test.go @@ -0,0 +1,43 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy_test + +import ( + "os" + "testing" + + "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/internal/proxy" +) + +func TestFUSESupport(t *testing.T) { + if testing.Short() { + t.Skip("skipping fuse tests in short mode.") + } + + removePath := func() func() { + original := os.Getenv("PATH") + os.Unsetenv("PATH") + return func() { os.Setenv("PATH", original) } + } + if err := proxy.SupportsFUSE(); err != nil { + t.Fatalf("expected FUSE to be support (PATH set): %v", err) + } + cleanup := removePath() + defer cleanup() + + if err := proxy.SupportsFUSE(); err != nil { + t.Fatalf("expected FUSE to be supported (PATH unset): %v", err) + } +} diff --git a/internal/proxy/fuse_test.go b/internal/proxy/fuse_test.go new file mode 100644 index 000000000..67b801108 --- /dev/null +++ b/internal/proxy/fuse_test.go @@ -0,0 +1,122 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows && !darwin +// +build !windows,!darwin + +package proxy_test + +import ( + "context" + "io/ioutil" + "os" + "path/filepath" + "testing" + "time" + + "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/internal/log" + "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/internal/proxy" +) + +func randTmpDir(t interface { + Fatalf(format string, args ...interface{}) +}) string { + name, err := ioutil.TempDir("", "*") + if err != nil { + t.Fatalf("failed to create tmp dir: %v", err) + } + return name +} + +// tryFunc executes the provided function up to maxCount times, sleeping 100ms +// between attempts. +func tryFunc(f func() error, maxCount int) error { + var errCount int + for { + err := f() + if err == nil { + return nil + } + errCount++ + if errCount == maxCount { + return err + } + time.Sleep(100 * time.Millisecond) + } +} + +func TestREADME(t *testing.T) { + if testing.Short() { + t.Skip("skipping fuse tests in short mode.") + } + ctx := context.Background() + + dir := randTmpDir(t) + conf := &proxy.Config{ + FUSEDir: dir, + FUSETempDir: randTmpDir(t), + } + logger := log.NewStdLogger(os.Stdout, os.Stdout) + d := &fakeDialer{} + c, err := proxy.NewClient(ctx, d, logger, conf) + if err != nil { + t.Fatalf("want error = nil, got = %v", err) + } + + ready := make(chan struct{}) + go c.Serve(ctx, func() { close(ready) }) + select { + case <-ready: + case <-time.After(time.Minute): + t.Fatal("proxy.Client failed to start serving") + } + + fi, err := os.Stat(dir) + if err != nil { + t.Fatalf("os.Stat: %v", err) + } + if !fi.IsDir() { + t.Fatalf("fuse mount mode: want = dir, got = %v", fi.Mode()) + } + + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatalf("os.ReadDir: %v", err) + } + if len(entries) != 1 { + t.Fatalf("dir entries: want = 1, got = %v", len(entries)) + } + e := entries[0] + if want, got := "README", e.Name(); want != got { + t.Fatalf("want = %v, got = %v", want, got) + } + + data, err := ioutil.ReadFile(filepath.Join(dir, "README")) + if err != nil { + t.Fatal(err) + } + if len(data) == 0 { + t.Fatalf("expected README data, got no data (dir = %v)", dir) + } + + if cErr := c.Close(); cErr != nil { + t.Fatalf("c.Close(): %v", cErr) + } + + // verify that c.Close unmounts the FUSE server + _, err = ioutil.ReadFile(filepath.Join(dir, "README")) + if err == nil { + t.Fatal("expected ioutil.Readfile to fail, but it succeeded") + } +} diff --git a/internal/proxy/fuse_windows.go b/internal/proxy/fuse_windows.go new file mode 100644 index 000000000..6e5289cf4 --- /dev/null +++ b/internal/proxy/fuse_windows.go @@ -0,0 +1,24 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "errors" +) + +// SupportsFUSE is false on Windows. +func SupportsFUSE() error { + return errors.New("fuse is not supported on Windows") +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index d4e4fb2cf..42bae9bd1 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -23,11 +23,14 @@ import ( "strings" "sync" "sync/atomic" + "syscall" "time" "cloud.google.com/go/cloudsqlconn" "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/cloudsql" "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/internal/gcloud" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" "golang.org/x/oauth2" ) @@ -83,6 +86,15 @@ type Config struct { // connected to any Instances. If set, takes precedence over Addr and Port. UnixSocket string + // FUSEDir enables a file system in user space at the provided path that + // connects to the requested instance only when a client requests it. + FUSEDir string + + // FUSETempDir sets the temporary directory where the FUSE mount will place + // Unix domain sockets connected to Cloud SQL instances. The temp directory + // is not accessed directly. + FUSETempDir string + // IAMAuthN enables automatic IAM DB Authentication for all instances. // Postgres-only. IAMAuthN bool @@ -243,9 +255,18 @@ type Client struct { waitOnClose time.Duration logger cloudsql.Logger + + // fuseDir specifies the directory where a FUSE server is mounted. The value + // is empty if FUSE is not enabled. + fuseDir string + fuseServer *fuse.Server + + // Inode adds support for FUSE operations. + fs.Inode } -// NewClient completes the initial setup required to get the proxy to a "steady" state. +// NewClient completes the initial setup required to get the proxy to a "steady" +// state. func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf *Config) (*Client, error) { // Check if the caller has configured a dialer. // Otherwise, initialize a new one. @@ -260,6 +281,18 @@ func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf * } } + c := &Client{ + logger: l, + dialer: d, + maxConns: conf.MaxConnections, + waitOnClose: conf.WaitOnClose, + } + + if conf.FUSEDir != "" { + c.fuseDir = conf.FUSEDir + return c, nil + } + for _, inst := range conf.Instances { // Initiate refresh operation and warm the cache. go func(name string) { d.EngineVersion(ctx, name) }(inst.Name) @@ -287,16 +320,29 @@ func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf * l.Infof("[%s] Listening on %s", inst.Name, m.Addr()) mnts = append(mnts, m) } - c := &Client{ - mnts: mnts, - logger: l, - dialer: d, - maxConns: conf.MaxConnections, - waitOnClose: conf.WaitOnClose, - } + c.mnts = mnts return c, nil } +func (c *Client) Readdir(ctx context.Context) (fs.DirStream, syscall.Errno) { + entries := []fuse.DirEntry{ + {Name: "README", Mode: 0555 | fuse.S_IFREG}, + } + return fs.NewListDirStream(entries), fs.OK +} + +// Lookup implements the fs.NodeLookuper interface and returns an index node +// (inode) for a symlink that points to a Unix domain socket. The Unix domain +// socket is connected to the requested Cloud SQL instance. Lookup returns a +// symlink (instead of the socket itself) so that multiple callers all use the +// same Unix socket. +func (c *Client) Lookup(ctx context.Context, instance string, out *fuse.EntryOut) (*fs.Inode, syscall.Errno) { + if instance == "README" { + return c.NewInode(ctx, &readme{}, fs.StableAttr{}), fs.OK + } + return nil, syscall.ENOENT +} + // CheckConnections dials each registered instance and reports any errors that // may have occurred. func (c *Client) CheckConnections(ctx context.Context) error { @@ -347,6 +393,20 @@ func (c *Client) ConnCount() (uint64, uint64) { func (c *Client) Serve(ctx context.Context, notify func()) error { ctx, cancel := context.WithCancel(ctx) defer cancel() + + if c.fuseDir != "" { + srv, err := fs.Mount(c.fuseDir, c, &fs.Options{ + MountOptions: fuse.MountOptions{AllowOther: true}, + }) + if err != nil { + return fmt.Errorf("FUSE mount failed: %q: %v", c.fuseDir, err) + } + c.fuseServer = srv + notify() + <-ctx.Done() + return ctx.Err() + } + exitCh := make(chan error) for _, m := range c.mnts { go func(mnt *socketMount) { @@ -388,6 +448,11 @@ func (m MultiErr) Error() string { // Close triggers the proxyClient to shutdown. func (c *Client) Close() error { var mErr MultiErr + if c.fuseServer != nil { + if err := c.fuseServer.Unmount(); err != nil { + mErr = append(mErr, err) + } + } // First, close all open socket listeners to prevent additional connections. for _, m := range c.mnts { err := m.Close() From f84448841113b5a9801b863d67f627cb9de48b9a Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Wed, 7 Sep 2022 10:57:43 -0600 Subject: [PATCH 2/2] feat: add support for FUSE connections (#1373) This commit also ensures that closing the proxy.Client blocks until all listeners are closed. --- internal/proxy/fuse.go | 16 ++- internal/proxy/fuse_test.go | 266 ++++++++++++++++++++++++++++++----- internal/proxy/proxy.go | 164 +++++++++++++++++++-- internal/proxy/proxy_test.go | 57 +++++--- 4 files changed, 432 insertions(+), 71 deletions(-) diff --git a/internal/proxy/fuse.go b/internal/proxy/fuse.go index 025b1011a..f03125b6f 100644 --- a/internal/proxy/fuse.go +++ b/internal/proxy/fuse.go @@ -23,14 +23,26 @@ import ( "github.com/hanwen/go-fuse/v2/fuse/nodefs" ) +// symlink implements a symbolic link, returning the underlying path when +// Readlink is called. +type symlink struct { + fs.Inode + path string +} + +// Readlink implements fs.NodeReadlinker and returns the symlink's path. +func (s *symlink) Readlink(ctx context.Context) ([]byte, syscall.Errno) { + return []byte(s.path), fs.OK +} + // readme represents a static read-only text file. type readme struct { fs.Inode } const readmeText = ` -When programs attempt to open files in this directory, a remote connection to -the Cloud SQL instance of the same name will be established. +When applications attempt to open files in this directory, a remote connection +to the Cloud SQL instance of the same name will be established. For example, when you run one of the followg commands, the proxy will initiate a connection to the corresponding Cloud SQL instance, given you have the correct diff --git a/internal/proxy/fuse_test.go b/internal/proxy/fuse_test.go index 67b801108..35e1d851e 100644 --- a/internal/proxy/fuse_test.go +++ b/internal/proxy/fuse_test.go @@ -20,13 +20,15 @@ package proxy_test import ( "context" "io/ioutil" + "net" "os" "path/filepath" "testing" "time" - "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/internal/log" + "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/cloudsql" "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/internal/proxy" + "github.com/hanwen/go-fuse/v2/fs" ) func randTmpDir(t interface { @@ -39,48 +41,37 @@ func randTmpDir(t interface { return name } -// tryFunc executes the provided function up to maxCount times, sleeping 100ms -// between attempts. -func tryFunc(f func() error, maxCount int) error { - var errCount int - for { - err := f() - if err == nil { - return nil - } - errCount++ - if errCount == maxCount { - return err +// newTestClient is a convenience function for testing that creates a +// proxy.Client and starts it. The returned cleanup function is also a +// convenience. Callers may choose to ignore it and manually close the client. +func newTestClient(t *testing.T, d cloudsql.Dialer, fuseDir, fuseTempDir string) (*proxy.Client, func()) { + conf := &proxy.Config{FUSEDir: fuseDir, FUSETempDir: fuseTempDir} + c, err := proxy.NewClient(context.Background(), d, testLogger, conf) + if err != nil { + t.Fatalf("want error = nil, got = %v", err) + } + + ready := make(chan struct{}) + go c.Serve(context.Background(), func() { close(ready) }) + select { + case <-ready: + case <-time.Tick(5 * time.Second): + t.Fatal("failed to Serve") + } + return c, func() { + if cErr := c.Close(); cErr != nil { + t.Logf("failed to close client: %v", cErr) } - time.Sleep(100 * time.Millisecond) } } -func TestREADME(t *testing.T) { +func TestFUSEREADME(t *testing.T) { if testing.Short() { t.Skip("skipping fuse tests in short mode.") } - ctx := context.Background() - dir := randTmpDir(t) - conf := &proxy.Config{ - FUSEDir: dir, - FUSETempDir: randTmpDir(t), - } - logger := log.NewStdLogger(os.Stdout, os.Stdout) d := &fakeDialer{} - c, err := proxy.NewClient(ctx, d, logger, conf) - if err != nil { - t.Fatalf("want error = nil, got = %v", err) - } - - ready := make(chan struct{}) - go c.Serve(ctx, func() { close(ready) }) - select { - case <-ready: - case <-time.After(time.Minute): - t.Fatal("proxy.Client failed to start serving") - } + _, cleanup := newTestClient(t, d, dir, randTmpDir(t)) fi, err := os.Stat(dir) if err != nil { @@ -110,13 +101,212 @@ func TestREADME(t *testing.T) { t.Fatalf("expected README data, got no data (dir = %v)", dir) } - if cErr := c.Close(); cErr != nil { - t.Fatalf("c.Close(): %v", cErr) - } + cleanup() // close the client - // verify that c.Close unmounts the FUSE server + // verify that the FUSE server is no longer mounted _, err = ioutil.ReadFile(filepath.Join(dir, "README")) if err == nil { t.Fatal("expected ioutil.Readfile to fail, but it succeeded") } } + +func tryDialUnix(t *testing.T, addr string) net.Conn { + var ( + conn net.Conn + dialErr error + ) + for i := 0; i < 10; i++ { + conn, dialErr = net.Dial("unix", addr) + if conn != nil { + break + } + time.Sleep(100 * time.Millisecond) + } + if dialErr != nil { + t.Fatalf("net.Dial(): %v", dialErr) + } + return conn +} + +func TestFUSEDialInstance(t *testing.T) { + fuseDir := randTmpDir(t) + fuseTempDir := randTmpDir(t) + tcs := []struct { + desc string + wantInstance string + socketPath string + fuseTempDir string + }{ + { + desc: "mysql connections create a Unix socket", + wantInstance: "proj:region:mysql", + socketPath: filepath.Join(fuseDir, "proj:region:mysql"), + fuseTempDir: fuseTempDir, + }, + { + desc: "postgres connections create a directory with a special file", + wantInstance: "proj:region:pg", + socketPath: filepath.Join(fuseDir, "proj:region:pg", ".s.PGSQL.5432"), + fuseTempDir: fuseTempDir, + }, + { + desc: "connecting creates intermediate temp directories", + wantInstance: "proj:region:mysql", + socketPath: filepath.Join(fuseDir, "proj:region:mysql"), + fuseTempDir: filepath.Join(fuseTempDir, "doesntexist"), + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + d := &fakeDialer{} + _, cleanup := newTestClient(t, d, fuseDir, tc.fuseTempDir) + defer cleanup() + + conn := tryDialUnix(t, tc.socketPath) + defer conn.Close() + + var got []string + for i := 0; i < 10; i++ { + got = d.dialedInstances() + if len(got) == 1 { + break + } + time.Sleep(100 * time.Millisecond) + } + if len(got) != 1 { + t.Fatalf("dialed instances len: want = 1, got = %v", got) + } + if want, inst := tc.wantInstance, got[0]; want != inst { + t.Fatalf("instance: want = %v, got = %v", want, inst) + } + + }) + } +} + +func TestFUSEReadDir(t *testing.T) { + fuseDir := randTmpDir(t) + _, cleanup := newTestClient(t, &fakeDialer{}, fuseDir, randTmpDir(t)) + defer cleanup() + + // Initiate a connection so the FUSE server will list it in the dir entries. + conn := tryDialUnix(t, filepath.Join(fuseDir, "proj:reg:mysql")) + defer conn.Close() + + entries, err := os.ReadDir(fuseDir) + if err != nil { + t.Fatalf("os.ReadDir(): %v", err) + } + // len should be README plus the proj:reg:mysql socket + if got, want := len(entries), 2; got != want { + t.Fatalf("want = %v, got = %v", want, got) + } + var names []string + for _, e := range entries { + names = append(names, e.Name()) + } + if names[0] != "README" || names[1] != "proj:reg:mysql" { + t.Fatalf("want = %v, got = %v", []string{"README", "proj:reg:mysql"}, names) + } +} + +func TestFUSEErrors(t *testing.T) { + ctx := context.Background() + d := &fakeDialer{} + c, _ := newTestClient(t, d, randTmpDir(t), randTmpDir(t)) + + // Simulate FUSE file access by invoking Lookup directly to control + // how the socket cache is populated. + _, err := c.Lookup(ctx, "proj:reg:mysql", nil) + if err != fs.OK { + t.Fatalf("proxy.Client.Lookup(): %v", err) + } + + // Close the client to close all open sockets. + if err := c.Close(); err != nil { + t.Fatalf("c.Close(): %v", err) + } + + // Simulate another FUSE file access to directly populated the socket cache. + _, err = c.Lookup(ctx, "proj:reg:mysql", nil) + if err != fs.OK { + t.Fatalf("proxy.Client.Lookup(): %v", err) + } + + // Verify the dialer was called twice, to prove the previous cache entry was + // removed when the socket was closed. + var attempts int + wantAttempts := 2 + for i := 0; i < 10; i++ { + attempts = d.engineVersionAttempts() + if attempts == wantAttempts { + return + } + time.Sleep(100 * time.Millisecond) + } + t.Fatalf("engine version attempts: want = %v, got = %v", wantAttempts, attempts) +} + +func TestFUSEWithBadInstanceName(t *testing.T) { + fuseDir := randTmpDir(t) + d := &fakeDialer{} + _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) + defer cleanup() + + _, dialErr := net.Dial("unix", filepath.Join(fuseDir, "notvalid")) + if dialErr == nil { + t.Fatalf("net.Dial() should fail") + } + + if got := d.engineVersionAttempts(); got > 0 { + t.Fatalf("engine version calls: want = 0, got = %v", got) + } +} + +func TestFUSECheckConnections(t *testing.T) { + fuseDir := randTmpDir(t) + d := &fakeDialer{} + c, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) + defer cleanup() + + // first establish a connection to "register" it with the proxy + conn := tryDialUnix(t, filepath.Join(fuseDir, "proj:reg:mysql")) + defer conn.Close() + + if err := c.CheckConnections(context.Background()); err != nil { + t.Fatalf("c.CheckConnections(): %v", err) + } + + // verify the dialer was invoked twice, once for connect, once for check + // connection + var attempts int + wantAttempts := 2 + for i := 0; i < 10; i++ { + attempts = d.dialAttempts() + if attempts == wantAttempts { + return + } + time.Sleep(100 * time.Millisecond) + } + t.Fatalf("dial attempts: want = %v, got = %v", wantAttempts, attempts) +} + +func TestFUSEClose(t *testing.T) { + fuseDir := randTmpDir(t) + d := &fakeDialer{} + c, _ := newTestClient(t, d, fuseDir, randTmpDir(t)) + + // first establish a connection to "register" it with the proxy + conn := tryDialUnix(t, filepath.Join(fuseDir, "proj:reg:mysql")) + defer conn.Close() + + // Close the proxy which should close all listeners + if err := c.Close(); err != nil { + t.Fatalf("c.Close(): %v", err) + } + + _, err := net.Dial("unix", filepath.Join(fuseDir, "proj:reg:mysql")) + if err == nil { + t.Fatal("net.Dial() should fail") + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 42bae9bd1..7683f1685 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -20,6 +20,8 @@ import ( "io" "net" "os" + "path/filepath" + "regexp" "strings" "sync" "sync/atomic" @@ -34,6 +36,44 @@ import ( "golang.org/x/oauth2" ) +var ( + // Instance connection name is the format :: + // Additionally, we have to support legacy "domain-scoped" projects (e.g. "google.com:PROJECT") + connNameRegex = regexp.MustCompile("([^:]+(:[^:]+)?):([^:]+):([^:]+)") +) + +// connName represents the "instance connection name", in the format +// "project:region:name". Use the "parseConnName" method to initialize this +// struct. +type connName struct { + project string + region string + name string +} + +func (c *connName) String() string { + return fmt.Sprintf("%s:%s:%s", c.project, c.region, c.name) +} + +// parseConnName initializes a new connName struct. +func parseConnName(cn string) (connName, error) { + b := []byte(cn) + m := connNameRegex.FindSubmatch(b) + if m == nil { + return connName{}, fmt.Errorf( + "invalid instance connection name, want = PROJECT:REGION:INSTANCE, got = %v", + cn, + ) + } + + c := connName{ + project: string(m[1]), + region: string(m[3]), + name: string(m[4]), + } + return c, nil +} + // InstanceConnConfig holds the configuration for an individual instance // connection. type InstanceConnConfig struct { @@ -234,6 +274,11 @@ func (c *portConfig) nextDBPort(version string) int { } } +type socketSymlink struct { + socket *socketMount + symlink *symlink +} + // Client proxies connections from a local client to the remote server side // proxy for multiple Cloud SQL instances. type Client struct { @@ -257,9 +302,18 @@ type Client struct { logger cloudsql.Logger // fuseDir specifies the directory where a FUSE server is mounted. The value - // is empty if FUSE is not enabled. - fuseDir string - fuseServer *fuse.Server + // is empty if FUSE is not enabled. The directory holds symlinks to Unix + // domain sockets in the fuseTmpDir. + fuseDir string + fuseTempDir string + // fuseMu protects access to fuseSockets. + fuseMu sync.Mutex + // fuseSockets is a map of instance connection name to socketMount and + // symlink. + fuseSockets map[string]socketSymlink + fuseServerMu sync.Mutex + fuseServer *fuse.Server + fuseWg sync.WaitGroup // Inode adds support for FUSE operations. fs.Inode @@ -289,7 +343,12 @@ func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf * } if conf.FUSEDir != "" { + if err := os.MkdirAll(conf.FUSETempDir, 0777); err != nil { + return nil, err + } c.fuseDir = conf.FUSEDir + c.fuseTempDir = conf.FUSETempDir + c.fuseSockets = map[string]socketSymlink{} return c, nil } @@ -324,10 +383,24 @@ func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf * return c, nil } +// Readdir returns a list of all active Unix sockets in addition to the README. func (c *Client) Readdir(ctx context.Context) (fs.DirStream, syscall.Errno) { entries := []fuse.DirEntry{ {Name: "README", Mode: 0555 | fuse.S_IFREG}, } + var active []string + c.fuseMu.Lock() + for k := range c.fuseSockets { + active = append(active, k) + } + c.fuseMu.Unlock() + + for _, a := range active { + entries = append(entries, fuse.DirEntry{ + Name: a, + Mode: 0777 | syscall.S_IFSOCK, + }) + } return fs.NewListDirStream(entries), fs.OK } @@ -340,7 +413,54 @@ func (c *Client) Lookup(ctx context.Context, instance string, out *fuse.EntryOut if instance == "README" { return c.NewInode(ctx, &readme{}, fs.StableAttr{}), fs.OK } - return nil, syscall.ENOENT + + if _, err := parseConnName(instance); err != nil { + return nil, syscall.ENOENT + } + + c.fuseMu.Lock() + defer c.fuseMu.Unlock() + if l, ok := c.fuseSockets[instance]; ok { + return l.symlink.EmbeddedInode(), fs.OK + } + + version, err := c.dialer.EngineVersion(ctx, instance) + if err != nil { + c.logger.Errorf("could not resolve version for %q: %v", instance, err) + return nil, syscall.ENOENT + } + + s, err := newSocketMount( + ctx, &Config{UnixSocket: c.fuseTempDir}, + nil, InstanceConnConfig{Name: instance}, version, + ) + if err != nil { + c.logger.Errorf("could not create socket for %q: %v", instance, err) + return nil, syscall.ENOENT + } + + c.fuseWg.Add(1) + go func() { + defer c.fuseWg.Done() + sErr := c.serveSocketMount(ctx, s) + if sErr != nil { + c.fuseMu.Lock() + delete(c.fuseSockets, instance) + c.fuseMu.Unlock() + } + }() + + // Return a symlink that points to the actual Unix socket within the + // temporary directory. For Postgres, return a symlink that points to the + // directory which holds the ".s.PGSQL.5432" Unix socket. + sl := &symlink{path: filepath.Join(c.fuseTempDir, instance)} + c.fuseSockets[instance] = socketSymlink{ + socket: s, + symlink: sl, + } + return c.NewInode(ctx, sl, fs.StableAttr{ + Mode: 0777 | fuse.S_IFLNK}, + ), fs.OK } // CheckConnections dials each registered instance and reports any errors that @@ -349,8 +469,17 @@ func (c *Client) CheckConnections(ctx context.Context) error { var ( wg sync.WaitGroup errCh = make(chan error, len(c.mnts)) + mnts = c.mnts ) - for _, mnt := range c.mnts { + if c.fuseDir != "" { + mnts = []*socketMount{} + c.fuseMu.Lock() + for _, m := range c.fuseSockets { + mnts = append(mnts, m.socket) + } + c.fuseMu.Unlock() + } + for _, mnt := range mnts { wg.Add(1) go func(m *socketMount) { defer wg.Done() @@ -401,7 +530,9 @@ func (c *Client) Serve(ctx context.Context, notify func()) error { if err != nil { return fmt.Errorf("FUSE mount failed: %q: %v", c.fuseDir, err) } + c.fuseServerMu.Lock() c.fuseServer = srv + c.fuseServerMu.Unlock() notify() <-ctx.Done() return ctx.Err() @@ -447,19 +578,35 @@ func (m MultiErr) Error() string { // Close triggers the proxyClient to shutdown. func (c *Client) Close() error { + mnts := c.mnts + + c.fuseServerMu.Lock() + hasFuseServer := c.fuseServer != nil + c.fuseServerMu.Unlock() + var mErr MultiErr - if c.fuseServer != nil { + if hasFuseServer { if err := c.fuseServer.Unmount(); err != nil { mErr = append(mErr, err) } + mnts = []*socketMount{} + c.fuseMu.Lock() + for _, m := range c.fuseSockets { + mnts = append(mnts, m.socket) + } + c.fuseMu.Unlock() } + // First, close all open socket listeners to prevent additional connections. - for _, m := range c.mnts { + for _, m := range mnts { err := m.Close() if err != nil { mErr = append(mErr, err) } } + if hasFuseServer { + c.fuseWg.Wait() + } // Next, close the dialer to prevent any additional refreshes. cErr := c.dialer.Close() if cErr != nil { @@ -541,9 +688,10 @@ func (c *Client) serveSocketMount(ctx context.Context, s *socketMount) error { // socketMount is a tcp/unix socket that listens for a Cloud SQL instance. type socketMount struct { + fs.Inode inst string - dialOpts []cloudsqlconn.DialOption listener net.Listener + dialOpts []cloudsqlconn.DialOption } func newSocketMount(ctx context.Context, conf *Config, pc *portConfig, inst InstanceConnConfig, version string) (*socketMount, error) { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index cfe25f591..f7fe6b78b 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -32,9 +32,13 @@ import ( "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/internal/proxy" ) +var testLogger = log.NewStdLogger(os.Stdout, os.Stdout) + type fakeDialer struct { - mu sync.Mutex - dialCount int + mu sync.Mutex + dialCount int + engineVersionCount int + instances []string } func (*fakeDialer) Close() error { @@ -47,15 +51,31 @@ func (f *fakeDialer) dialAttempts() int { return f.dialCount } +func (f *fakeDialer) engineVersionAttempts() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.engineVersionCount +} + +func (f *fakeDialer) dialedInstances() []string { + f.mu.Lock() + defer f.mu.Unlock() + return append([]string{}, f.instances...) +} + func (f *fakeDialer) Dial(ctx context.Context, inst string, opts ...cloudsqlconn.DialOption) (net.Conn, error) { f.mu.Lock() defer f.mu.Unlock() f.dialCount++ + f.instances = append(f.instances, inst) c1, _ := net.Pipe() return c1, nil } -func (*fakeDialer) EngineVersion(_ context.Context, inst string) (string, error) { +func (f *fakeDialer) EngineVersion(_ context.Context, inst string) (string, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.engineVersionCount++ switch { case strings.Contains(inst, "pg"): return "POSTGRES_14", nil @@ -242,8 +262,7 @@ func TestClientInitialization(t *testing.T) { for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(ctx, &fakeDialer{}, logger, tc.in) + c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in) if err != nil { t.Fatalf("want error = nil, got = %v", err) } @@ -287,8 +306,7 @@ func TestClientLimitsMaxConnections(t *testing.T) { }, MaxConnections: 1, } - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(context.Background(), d, logger, in) + c, err := proxy.NewClient(context.Background(), d, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -350,7 +368,6 @@ func tryTCPDial(t *testing.T, addr string) net.Conn { } func TestClientCloseWaitsForActiveConnections(t *testing.T) { - logger := log.NewStdLogger(os.Stdout, os.Stdout) in := &proxy.Config{ Addr: "127.0.0.1", Port: 5000, @@ -359,7 +376,7 @@ func TestClientCloseWaitsForActiveConnections(t *testing.T) { }, WaitOnClose: 5 * time.Second, } - c, err := proxy.NewClient(context.Background(), &fakeDialer{}, logger, in) + c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -389,8 +406,7 @@ func TestClientClosesCleanly(t *testing.T) { {Name: "proj:reg:inst"}, }, } - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(context.Background(), &fakeDialer{}, logger, in) + c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error want = nil, got = %v", err) } @@ -412,8 +428,7 @@ func TestClosesWithError(t *testing.T) { {Name: "proj:reg:inst"}, }, } - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(context.Background(), &errorDialer{}, logger, in) + c, err := proxy.NewClient(context.Background(), &errorDialer{}, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error want = nil, got = %v", err) } @@ -469,14 +484,13 @@ func TestClientInitializationWorksRepeatedly(t *testing.T) { }, } - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(ctx, &fakeDialer{}, logger, in) + c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("want error = nil, got = %v", err) } c.Close() - c, err = proxy.NewClient(ctx, &fakeDialer{}, logger, in) + c, err = proxy.NewClient(ctx, &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("want error = nil, got = %v", err) } @@ -490,8 +504,7 @@ func TestClientNotifiesCallerOnServe(t *testing.T) { {Name: "proj:region:pg"}, }, } - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(ctx, &fakeDialer{}, logger, in) + c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("want error = nil, got = %v", err) } @@ -515,7 +528,6 @@ func TestClientNotifiesCallerOnServe(t *testing.T) { } func TestClientConnCount(t *testing.T) { - logger := log.NewStdLogger(os.Stdout, os.Stdout) in := &proxy.Config{ Addr: "127.0.0.1", Port: 5000, @@ -525,7 +537,7 @@ func TestClientConnCount(t *testing.T) { MaxConnections: 10, } - c, err := proxy.NewClient(context.Background(), &fakeDialer{}, logger, in) + c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -558,7 +570,6 @@ func TestClientConnCount(t *testing.T) { } func TestCheckConnections(t *testing.T) { - logger := log.NewStdLogger(os.Stdout, os.Stdout) in := &proxy.Config{ Addr: "127.0.0.1", Port: 5000, @@ -567,7 +578,7 @@ func TestCheckConnections(t *testing.T) { }, } d := &fakeDialer{} - c, err := proxy.NewClient(context.Background(), d, logger, in) + c, err := proxy.NewClient(context.Background(), d, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -591,7 +602,7 @@ func TestCheckConnections(t *testing.T) { }, } ed := &errorDialer{} - c, err = proxy.NewClient(context.Background(), ed, logger, in) + c, err = proxy.NewClient(context.Background(), ed, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) }