diff --git a/clientconn.go b/clientconn.go index 14ce9c76aa37..8aeac74aebf3 100644 --- a/clientconn.go +++ b/clientconn.go @@ -243,7 +243,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * // Only try to parse target when resolver builder is not already set. cc.parsedTarget = parseTarget(cc.target) grpclog.Infof("parsed scheme: %q", cc.parsedTarget.Scheme) - cc.dopts.resolverBuilder = resolver.Get(cc.parsedTarget.Scheme) + cc.dopts.resolverBuilder = cc.getResolver(cc.parsedTarget.Scheme) if cc.dopts.resolverBuilder == nil { // If resolver builder is still nil, the parsed target's scheme is // not registered. Fallback to default resolver and set Endpoint to @@ -253,7 +253,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * Scheme: resolver.GetDefaultScheme(), Endpoint: target, } - cc.dopts.resolverBuilder = resolver.Get(cc.parsedTarget.Scheme) + cc.dopts.resolverBuilder = cc.getResolver(cc.parsedTarget.Scheme) } } else { cc.parsedTarget = resolver.Target{Endpoint: target} @@ -1542,3 +1542,12 @@ func (c *channelzChannel) ChannelzMetric() *channelz.ChannelInternalMetric { // Deprecated: This error is never returned by grpc and should not be // referenced by users. var ErrClientConnTimeout = errors.New("grpc: timed out when dialing") + +func (cc *ClientConn) getResolver(scheme string) resolver.Builder { + for _, rb := range cc.dopts.resolvers { + if cc.parsedTarget.Scheme == rb.Scheme() { + return rb + } + } + return resolver.Get(cc.parsedTarget.Scheme) +} diff --git a/clientconn_state_transition_test.go b/clientconn_state_transition_test.go index 1309e8941352..0e9b1e752156 100644 --- a/clientconn_state_transition_test.go +++ b/clientconn_state_transition_test.go @@ -320,7 +320,7 @@ func (s) TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T) {Addr: lis1.Addr().String()}, {Addr: lis2.Addr().String()}, }}) - client, err := DialContext(ctx, "this-gets-overwritten", WithInsecure(), WithBalancerName(stateRecordingBalancerName), withResolverBuilder(rb)) + client, err := DialContext(ctx, "whatever:///this-gets-overwritten", WithInsecure(), WithBalancerName(stateRecordingBalancerName), WithResolvers(rb)) if err != nil { t.Fatal(err) } @@ -414,7 +414,7 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { {Addr: lis1.Addr().String()}, {Addr: lis2.Addr().String()}, }}) - client, err := DialContext(ctx, "this-gets-overwritten", WithInsecure(), WithBalancerName(stateRecordingBalancerName), withResolverBuilder(rb)) + client, err := DialContext(ctx, "whatever:///this-gets-overwritten", WithInsecure(), WithBalancerName(stateRecordingBalancerName), WithResolvers(rb)) if err != nil { t.Fatal(err) } diff --git a/clientconn_test.go b/clientconn_test.go index 02d15abbc4f2..23a9c67bce07 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -539,10 +539,10 @@ func (s) TestDial_OneBackoffPerRetryGroup(t *testing.T) { {Addr: lis1.Addr().String()}, {Addr: lis2.Addr().String()}, }}) - client, err := DialContext(ctx, "this-gets-overwritten", + client, err := DialContext(ctx, "whatever:///this-gets-overwritten", WithInsecure(), WithBalancerName(stateRecordingBalancerName), - withResolverBuilder(rb), + WithResolvers(rb), withMinConnectDeadline(getMinConnectTimeout)) if err != nil { t.Fatal(err) @@ -1085,9 +1085,9 @@ func (s) TestUpdateAddresses_RetryFromFirstAddr(t *testing.T) { rb := manual.NewBuilderWithScheme("whatever") rb.InitialState(resolver.State{Addresses: addrsList}) - client, err := Dial("this-gets-overwritten", + client, err := Dial("whatever:///this-gets-overwritten", WithInsecure(), - withResolverBuilder(rb), + WithResolvers(rb), withBackoff(noBackoff{}), WithBalancerName(stateRecordingBalancerName), withMinConnectDeadline(func() time.Duration { return time.Hour })) diff --git a/dialoptions.go b/dialoptions.go index 9af3eef7ab34..2fb70e30f278 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -73,6 +73,7 @@ type dialOptions struct { // resolver.ResolveNow(). The user will have no need to configure this, but // we need to be able to configure this in tests. resolveNowBackoff func(int) time.Duration + resolvers []resolver.Builder } // DialOption configures how we set up the connection. @@ -589,3 +590,15 @@ func withResolveNowBackoff(f func(int) time.Duration) DialOption { o.resolveNowBackoff = f }) } + +// WithResolvers allows a list of resolver implementations to be registered +// locally with the ClientConn without needing to be globally registered via +// resolver.Register. They will be matched against the scheme used for the +// current Dial only, and will take precedence over the global registry. +// +// This API is EXPERIMENTAL. +func WithResolvers(rs ...resolver.Builder) DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.resolvers = rs + }) +}