diff --git a/internal/xds/bootstrap/bootstrap.go b/internal/xds/bootstrap/bootstrap.go index 8317859e1e95..c725bc1eac97 100644 --- a/internal/xds/bootstrap/bootstrap.go +++ b/internal/xds/bootstrap/bootstrap.go @@ -22,9 +22,11 @@ package bootstrap import ( "bytes" + "context" "encoding/json" "fmt" "maps" + "net" "net/url" "os" "slices" @@ -179,6 +181,7 @@ type ServerConfig struct { // credentials and store it here for easy access. selectedCreds ChannelCreds credsDialOption grpc.DialOption + dialerOption grpc.DialOption cleanups []func() } @@ -223,6 +226,16 @@ func (sc *ServerConfig) CredsDialOption() grpc.DialOption { return sc.credsDialOption } +// DialerOption returns the Dialer function that specifies how to dial the xDS +// server determined by the first supported credentials from the configuration, +// as a dial option. +// +// TODO(https://github.com/grpc/grpc-go/issues/7661): change ServerConfig type +// to have a single method that returns all configured dial options. +func (sc *ServerConfig) DialerOption() grpc.DialOption { + return sc.dialerOption +} + // Cleanups returns a collection of functions to be called when the xDS client // for this server is closed. Allows cleaning up resources created specifically // for this server. @@ -275,6 +288,12 @@ func (sc *ServerConfig) MarshalJSON() ([]byte, error) { return json.Marshal(server) } +// dialer captures the Dialer method specified via the credentials bundle. +type dialer interface { + // Dialer specifies how to dial the xDS server. + Dialer(context.Context, string) (net.Conn, error) +} + // UnmarshalJSON takes the json data (a server) and unmarshals it to the struct. func (sc *ServerConfig) UnmarshalJSON(data []byte) error { server := serverConfigJSON{} @@ -298,6 +317,9 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { } sc.selectedCreds = cc sc.credsDialOption = grpc.WithCredentialsBundle(bundle) + if d, ok := bundle.(dialer); ok { + sc.dialerOption = grpc.WithContextDialer(d.Dialer) + } sc.cleanups = append(sc.cleanups, cancel) break } diff --git a/test/xds/xds_client_custom_dialer_test.go b/test/xds/xds_client_custom_dialer_test.go new file mode 100644 index 000000000000..7d6832ddec5f --- /dev/null +++ b/test/xds/xds_client_custom_dialer_test.go @@ -0,0 +1,154 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xds_test + +import ( + "context" + "encoding/json" + "fmt" + "net" + "testing" + + "github.com/google/uuid" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/testutils/xds/e2e" + internalbootstrap "google.golang.org/grpc/internal/xds/bootstrap" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/bootstrap" + + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +const testDialerCredsBuilderName = "test_dialer_creds" + +// testDialerCredsBuilder implements the `Credentials` interface defined in +// package `xds/bootstrap` and encapsulates an insecure credential with a +// custom Dialer that specifies how to dial the xDS server. +type testDialerCredsBuilder struct { + dialerCalled chan struct{} +} + +func (t *testDialerCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, func(), error) { + cfg := &struct { + MgmtServerAddress string `json:"mgmt_server_address"` + }{} + if err := json.Unmarshal(config, &cfg); err != nil { + return nil, func() {}, fmt.Errorf("failed to unmarshal config: %v", err) + } + return &testDialerCredsBundle{insecure.NewBundle(), t.dialerCalled, cfg.MgmtServerAddress}, func() {}, nil +} + +func (t *testDialerCredsBuilder) Name() string { + return testDialerCredsBuilderName +} + +// testDialerCredsBundle implements the `Bundle` interface defined in package +// `credentials` and encapsulates an insecure credential with a custom Dialer +// that specifies how to dial the xDS server. +type testDialerCredsBundle struct { + credentials.Bundle + dialerCalled chan struct{} + mgmtServerAddress string +} + +// Dialer specifies how to dial the xDS management server. +func (t *testDialerCredsBundle) Dialer(context.Context, string) (net.Conn, error) { + close(t.dialerCalled) + // Create a pass-through connection (no-op) to the xDS management server. + return net.Dial("tcp", t.mgmtServerAddress) +} + +func (s) TestClientCustomDialerFromCredentialsBundle(t *testing.T) { + // Create and register the credentials bundle builder. + credsBuilder := &testDialerCredsBuilder{dialerCalled: make(chan struct{})} + bootstrap.RegisterCredentials(credsBuilder) + + // Start an xDS management server. + mgmtServer := e2e.StartManagementServer(t, e2e.ManagementServerOptions{}) + + // Create bootstrap configuration pointing to the above management server. + nodeID := uuid.New().String() + bc, err := internalbootstrap.NewContentsForTesting(internalbootstrap.ConfigOptionsForTesting{ + Servers: []byte(fmt.Sprintf(`[{ + "server_uri": %q, + "channel_creds": [{ + "type": %q, + "config": {"mgmt_server_address": %q} + }] + }]`, mgmtServer.Address, testDialerCredsBuilderName, mgmtServer.Address)), + Node: []byte(fmt.Sprintf(`{"id": "%s"}`, nodeID)), + }) + if err != nil { + t.Fatalf("Failed to create bootstrap configuration: %v", err) + } + + // Create an xDS resolver with the above bootstrap configuration. + var resolverBuilder resolver.Builder + if newResolver := internal.NewXDSResolverWithConfigForTesting; newResolver != nil { + resolverBuilder, err = newResolver.(func([]byte) (resolver.Builder, error))(bc) + if err != nil { + t.Fatalf("Failed to create xDS resolver for testing: %v", err) + } + } + + // Spin up a test backend. + server := stubserver.StartTestService(t, nil) + defer server.Stop() + + // Configure client side xDS resources on the management server. + const serviceName = "my-service-client-side-xds" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: "localhost", + Port: testutils.ParsePort(t, server.Address), + SecLevel: e2e.SecurityLevelNone, + }) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := mgmtServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create a ClientConn and make a successful RPC. The insecure transport credentials passed into + // the gRPC.NewClient is the credentials for the data plane communication with the test backend. + cc, err := grpc.NewClient(fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(resolverBuilder)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testgrpc.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall() failed: %v", err) + } + + // Verify that the custom dialer was called. + select { + case <-ctx.Done(): + t.Fatalf("Timeout when waiting for custom dialer to be called") + case <-credsBuilder.dialerCalled: + } +} diff --git a/xds/internal/xdsclient/transport/transport.go b/xds/internal/xdsclient/transport/transport.go index 0bc0d386802d..134a9519f19f 100644 --- a/xds/internal/xdsclient/transport/transport.go +++ b/xds/internal/xdsclient/transport/transport.go @@ -202,6 +202,9 @@ func New(opts Options) (*Transport, error) { Timeout: 20 * time.Second, }), } + if dialerOpts := opts.ServerCfg.DialerOption(); dialerOpts != nil { + dopts = append(dopts, dialerOpts) + } grpcNewClient := transportinternal.GRPCNewClient.(func(string, ...grpc.DialOption) (*grpc.ClientConn, error)) cc, err := grpcNewClient(opts.ServerCfg.ServerURI(), dopts...) if err != nil { diff --git a/xds/internal/xdsclient/transport/transport_test.go b/xds/internal/xdsclient/transport/transport_test.go index 24aad924bd92..6c2c1f2835e2 100644 --- a/xds/internal/xdsclient/transport/transport_test.go +++ b/xds/internal/xdsclient/transport/transport_test.go @@ -18,10 +18,16 @@ package transport_test import ( + "context" + "encoding/json" + "net" "testing" "google.golang.org/grpc" - "google.golang.org/grpc/internal/xds/bootstrap" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + internalbootstrap "google.golang.org/grpc/internal/xds/bootstrap" + "google.golang.org/grpc/xds/bootstrap" "google.golang.org/grpc/xds/internal/xdsclient/transport" "google.golang.org/grpc/xds/internal/xdsclient/transport/internal" @@ -39,7 +45,7 @@ func (s) TestNewWithGRPCDial(t *testing.T) { internal.GRPCNewClient = customDialer defer func() { internal.GRPCNewClient = oldDial }() - serverCfg, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{URI: "server-address"}) + serverCfg, err := internalbootstrap.ServerConfigForTesting(internalbootstrap.ServerConfigTestingOptions{URI: "server-address"}) if err != nil { t.Fatalf("Failed to create server config for testing: %v", err) } @@ -82,3 +88,80 @@ func (s) TestNewWithGRPCDial(t *testing.T) { t.Fatalf("transport.New(%+v) custom dialer called = true, want false", opts) } } + +const testDialerCredsBuilderName = "test_dialer_creds" + +// testDialerCredsBuilder implements the `Credentials` interface defined in +// package `xds/bootstrap` and encapsulates an insecure credential with a +// custom Dialer that specifies how to dial the xDS server. +type testDialerCredsBuilder struct{} + +func (t *testDialerCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) { + return &testDialerCredsBundle{insecure.NewBundle()}, func() {}, nil +} + +func (t *testDialerCredsBuilder) Name() string { + return testDialerCredsBuilderName +} + +// testDialerCredsBundle implements the `Bundle` interface defined in package +// `credentials` and encapsulates an insecure credential with a custom Dialer +// that specifies how to dial the xDS server. +type testDialerCredsBundle struct { + credentials.Bundle +} + +func (t *testDialerCredsBundle) Dialer(context.Context, string) (net.Conn, error) { + return nil, nil +} + +func (s) TestNewWithDialerFromCredentialsBundle(t *testing.T) { + // Override grpc.NewClient with a custom one. + doptsLen := 0 + customGRPCNewClient := func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + doptsLen = len(opts) + return grpc.NewClient(target, opts...) + } + oldGRPCNewClient := internal.GRPCNewClient + internal.GRPCNewClient = customGRPCNewClient + defer func() { internal.GRPCNewClient = oldGRPCNewClient }() + + bootstrap.RegisterCredentials(&testDialerCredsBuilder{}) + serverCfg, err := internalbootstrap.ServerConfigForTesting(internalbootstrap.ServerConfigTestingOptions{ + URI: "trafficdirector.googleapis.com:443", + ChannelCreds: []internalbootstrap.ChannelCreds{{Type: testDialerCredsBuilderName}}, + }) + if err != nil { + t.Fatalf("Failed to create server config for testing: %v", err) + } + if serverCfg.DialerOption() == nil { + t.Fatalf("Dialer for xDS transport in server config for testing is nil, want non-nil") + } + // Create a new transport. + opts := transport.Options{ + ServerCfg: serverCfg, + NodeProto: &v3corepb.Node{}, + OnRecvHandler: func(update transport.ResourceUpdate, onDone func()) error { + onDone() + return nil + }, + OnErrorHandler: func(error) {}, + OnSendHandler: func(*transport.ResourceSendInfo) {}, + } + c, err := transport.New(opts) + defer func() { + if c != nil { + c.Close() + } + }() + if err != nil { + t.Fatalf("transport.New(%v) failed: %v", opts, err) + } + // Verify there are three dial options passed to the custom grpc.NewClient. + // The first is opts.ServerCfg.CredsDialOption(), the second is + // grpc.WithKeepaliveParams(), and the third is opts.ServerCfg.DialerOption() + // from the credentials bundle. + if doptsLen != 3 { + t.Fatalf("transport.New(%v) custom grpc.NewClient called with %d dial options, want 3", opts, doptsLen) + } +}