diff --git a/balancer/endpointsharding/endpointsharding.go b/balancer/endpointsharding/endpointsharding.go
new file mode 100644
index 000000000000..c3d061e47b06
--- /dev/null
+++ b/balancer/endpointsharding/endpointsharding.go
@@ -0,0 +1,293 @@
+/*
+ *
+ * Copyright 2024 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package endpointsharding implements a load balancing policy that manages
+// homogeneous child policies, each owning a single endpoint.
+//
+// # Experimental
+//
+// Notice: This package is EXPERIMENTAL and may be changed or removed in a
+// later release.
+package endpointsharding
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"sync"
+	"sync/atomic"
+
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/balancer/base"
+	"google.golang.org/grpc/connectivity"
+	"google.golang.org/grpc/internal/balancer/gracefulswitch"
+	"google.golang.org/grpc/internal/grpcrand"
+	"google.golang.org/grpc/resolver"
+	"google.golang.org/grpc/serviceconfig"
+)
+
+// ChildState is the balancer state of a child, along with the endpoint which
+// identifies the child balancer.
+type ChildState struct {
+	Endpoint resolver.Endpoint
+	State    balancer.State
+}
+
+// NewBalancer returns a load balancing policy that manages homogeneous child
+// policies, each owning a single endpoint.
+func NewBalancer(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
+	es := &endpointSharding{
+		cc:    cc,
+		bOpts: opts,
+	}
+	es.children.Store(resolver.NewEndpointMap())
+	return es
+}
+
+// endpointSharding is a balancer that wraps child balancers. It creates a
+// child balancer with the child config for every unique endpoint received.
+// It updates the child states on any update from a parent or child.
+type endpointSharding struct {
+	cc    balancer.ClientConn
+	bOpts balancer.BuildOptions
+
+	children atomic.Pointer[resolver.EndpointMap]
+
+	// inhibitChildUpdates is set during UpdateClientConnState/ResolverError
+	// calls: each call to a child produces an update, but only a single
+	// aggregated update should be sent upward.
+	inhibitChildUpdates atomic.Bool
+
+	mu sync.Mutex // Synchronizes updateState calls and the children's most recent state updates.
+}
+
+// UpdateClientConnState creates a child for new endpoints and deletes children
+// for endpoints that are no longer present. It also updates all the children,
+// and sends a single synchronous update of the children's aggregated state at
+// the end of the UpdateClientConnState operation. If any endpoint has no
+// addresses, it returns an error without forwarding any updates. Otherwise it
+// returns the first error found in a child, but fully processes the new
+// update.
+func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState) error {
+	if len(state.ResolverState.Endpoints) == 0 {
+		return errors.New("endpoints list is empty")
+	}
+	// Check/return early if any endpoints have no addresses.
+	// TODO: make this configurable if needed.
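+	// For example, a resolver state such as
+	//
+	//	resolver.State{Endpoints: []resolver.Endpoint{
+	//		{Addresses: []resolver.Address{{Addr: "localhost:50050"}}},
+	//		{Addresses: []resolver.Address{{Addr: "localhost:50051"}}},
+	//	}}
+	//
+	// passes this validation and produces one child policy per endpoint
+	// below.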
+	for i, endpoint := range state.ResolverState.Endpoints {
+		if len(endpoint.Addresses) == 0 {
+			return fmt.Errorf("endpoint %d has empty addresses", i)
+		}
+	}
+
+	es.inhibitChildUpdates.Store(true)
+	defer func() {
+		es.inhibitChildUpdates.Store(false)
+		es.updateState()
+	}()
+	var ret error
+
+	children := es.children.Load()
+	newChildren := resolver.NewEndpointMap()
+
+	// Update/create new children.
+	for _, endpoint := range state.ResolverState.Endpoints {
+		if _, ok := newChildren.Get(endpoint); ok {
+			// Endpoint child was already created; continue to avoid a
+			// duplicate update.
+			continue
+		}
+		var bal *balancerWrapper
+		if child, ok := children.Get(endpoint); ok {
+			bal = child.(*balancerWrapper)
+		} else {
+			bal = &balancerWrapper{
+				childState: ChildState{Endpoint: endpoint},
+				ClientConn: es.cc,
+				es:         es,
+			}
+			bal.Balancer = gracefulswitch.NewBalancer(bal, es.bOpts)
+		}
+		newChildren.Set(endpoint, bal)
+		if err := bal.UpdateClientConnState(balancer.ClientConnState{
+			BalancerConfig: state.BalancerConfig,
+			ResolverState: resolver.State{
+				Endpoints:  []resolver.Endpoint{endpoint},
+				Attributes: state.ResolverState.Attributes,
+			},
+		}); err != nil && ret == nil {
+			// Return the first error found, but always commit full processing
+			// of the children. Callers that need more specific per-endpoint
+			// error handling should perform those validations themselves;
+			// this is a current limitation kept for simplicity's sake.
+			ret = err
+		}
+	}
+	// Delete old children that are no longer present.
+	for _, e := range children.Keys() {
+		child, _ := children.Get(e)
+		bal := child.(balancer.Balancer)
+		if _, ok := newChildren.Get(e); !ok {
+			bal.Close()
+		}
+	}
+	es.children.Store(newChildren)
+	return ret
+}
+
+// ResolverError forwards the resolver error to all of the endpointSharding's
+// children and sends a single synchronous update of the child states at the
+// end of the ResolverError operation.
+func (es *endpointSharding) ResolverError(err error) {
+	es.inhibitChildUpdates.Store(true)
+	defer func() {
+		es.inhibitChildUpdates.Store(false)
+		es.updateState()
+	}()
+	children := es.children.Load()
+	for _, child := range children.Values() {
+		bal := child.(balancer.Balancer)
+		bal.ResolverError(err)
+	}
+}
+
+func (es *endpointSharding) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
+	// UpdateSubConnState is deprecated.
+}
+
+func (es *endpointSharding) Close() {
+	children := es.children.Load()
+	for _, child := range children.Values() {
+		bal := child.(balancer.Balancer)
+		bal.Close()
+	}
+}
+
+// updateState updates this component's state. It sends the aggregated state,
+// and a picker with round robin behavior across all the child states present,
+// if an update is needed.
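+// The aggregated connectivity state is chosen by precedence: READY if any
+// child is READY; else CONNECTING if any child is CONNECTING; else IDLE if
+// any child is IDLE; else TRANSIENT_FAILURE.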
+func (es *endpointSharding) updateState() {
+	if es.inhibitChildUpdates.Load() {
+		return
+	}
+	var readyPickers, connectingPickers, idlePickers, transientFailurePickers []balancer.Picker
+
+	es.mu.Lock()
+	defer es.mu.Unlock()
+
+	children := es.children.Load()
+	childStates := make([]ChildState, 0, children.Len())
+
+	for _, child := range children.Values() {
+		bw := child.(*balancerWrapper)
+		childState := bw.childState
+		childStates = append(childStates, childState)
+		childPicker := childState.State.Picker
+		switch childState.State.ConnectivityState {
+		case connectivity.Ready:
+			readyPickers = append(readyPickers, childPicker)
+		case connectivity.Connecting:
+			connectingPickers = append(connectingPickers, childPicker)
+		case connectivity.Idle:
+			idlePickers = append(idlePickers, childPicker)
+		case connectivity.TransientFailure:
+			transientFailurePickers = append(transientFailurePickers, childPicker)
+			// connectivity.Shutdown shouldn't appear.
+		}
+	}
+
+	// Construct the round robin picker based on the aggregated state.
+	// Whatever the aggregated state, use only the pickers that are currently
+	// in that state.
+	var aggState connectivity.State
+	var pickers []balancer.Picker
+	if len(readyPickers) >= 1 {
+		aggState = connectivity.Ready
+		pickers = readyPickers
+	} else if len(connectingPickers) >= 1 {
+		aggState = connectivity.Connecting
+		pickers = connectingPickers
+	} else if len(idlePickers) >= 1 {
+		aggState = connectivity.Idle
+		pickers = idlePickers
+	} else if len(transientFailurePickers) >= 1 {
+		aggState = connectivity.TransientFailure
+		pickers = transientFailurePickers
+	} else {
+		// No children (e.g. a resolver error arrived before a valid update).
+		aggState = connectivity.TransientFailure
+		pickers = []balancer.Picker{base.NewErrPicker(errors.New("no children to pick from"))}
+	}
+	p := &pickerWithChildStates{
+		pickers:     pickers,
+		childStates: childStates,
+		next:        uint32(grpcrand.Intn(len(pickers))),
+	}
+	es.cc.UpdateState(balancer.State{
+		ConnectivityState: aggState,
+		Picker:            p,
+	})
+}
+
+// pickerWithChildStates delegates to the pickers it holds in a round robin
+// fashion. It also contains the childStates of all the endpointSharding's
+// children.
+type pickerWithChildStates struct {
+	pickers     []balancer.Picker
+	childStates []ChildState
+	next        uint32
+}
+
+func (p *pickerWithChildStates) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
+	nextIndex := atomic.AddUint32(&p.next, 1)
+	picker := p.pickers[nextIndex%uint32(len(p.pickers))]
+	return picker.Pick(info)
+}
+
+// ChildStatesFromPicker returns the states of all the children managed by the
+// endpoint sharding balancer that created this picker.
+func ChildStatesFromPicker(picker balancer.Picker) []ChildState {
+	p, ok := picker.(*pickerWithChildStates)
+	if !ok {
+		return nil
+	}
+	return p.childStates
+}
+
+// balancerWrapper is a wrapper of a balancer. It identifies the child
+// balancer by its endpoint, and persists the child balancer's most recent
+// state.
+type balancerWrapper struct {
+	balancer.Balancer // Simply forwards balancer.Balancer operations.
+	balancer.ClientConn // Embedded to intercept UpdateState; SubConn operations pass through.
+
+	es *endpointSharding
+
+	childState ChildState
+}
+
+func (bw *balancerWrapper) UpdateState(state balancer.State) {
+	bw.es.mu.Lock()
+	bw.childState.State = state
+	bw.es.mu.Unlock()
+	bw.es.updateState()
+}
+
+// ParseConfig parses a child config list and returns a LoadBalancingConfig to
+// use with the endpointsharding balancer.
+func ParseConfig(cfg json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
+	return gracefulswitch.ParseConfig(cfg)
+}
+
+// PickFirstConfig is a pick_first config without shuffling enabled.
+const PickFirstConfig = "[{\"pick_first\": {}}]"
diff --git a/balancer/endpointsharding/endpointsharding_test.go b/balancer/endpointsharding/endpointsharding_test.go
new file mode 100644
index 000000000000..6b23063b5d9c
--- /dev/null
+++ b/balancer/endpointsharding/endpointsharding_test.go
@@ -0,0 +1,159 @@
+/*
+ *
+ * Copyright 2024 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package endpointsharding
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"testing"
+	"time"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/credentials/insecure"
+	"google.golang.org/grpc/grpclog"
+	"google.golang.org/grpc/internal"
+	"google.golang.org/grpc/internal/grpctest"
+	"google.golang.org/grpc/internal/stubserver"
+	"google.golang.org/grpc/internal/testutils/roundrobin"
+	"google.golang.org/grpc/resolver"
+	"google.golang.org/grpc/resolver/manual"
+	"google.golang.org/grpc/serviceconfig"
+
+	testgrpc "google.golang.org/grpc/interop/grpc_testing"
+)
+
+type s struct {
+	grpctest.Tester
+}
+
+func Test(t *testing.T) {
+	grpctest.RunSubTests(t, s{})
+}
+
+var gracefulSwitchPickFirst serviceconfig.LoadBalancingConfig
+
+var logger = grpclog.Component("endpoint-sharding-test")
+
+func init() {
+	var err error
+	gracefulSwitchPickFirst, err = ParseConfig(json.RawMessage(PickFirstConfig))
+	if err != nil {
+		logger.Fatal(err)
+	}
+	balancer.Register(fakePetioleBuilder{})
+}
+
+const fakePetioleName = "fake_petiole"
+
+type fakePetioleBuilder struct{}
+
+func (fakePetioleBuilder) Name() string {
+	return fakePetioleName
+}
+
+func (fakePetioleBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
+	fp := &fakePetiole{
+		ClientConn: cc,
+		bOpts:      opts,
+	}
+	fp.Balancer = NewBalancer(fp, opts)
+	return fp
+}
+
+func (fakePetioleBuilder) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
+	return nil, nil
+}
+
+// fakePetiole is a load balancer that wraps the endpointsharding balancer,
+// and forwards ClientConn updates with a child config of graceful switch
+// wrapping pick_first. It also intercepts UpdateState to make sure it can
+// access the child state maintained by endpointsharding.
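+//
+// ("Petiole" is the term gRFC A61 uses for the policy layer, e.g.
+// round_robin or weighted_round_robin, that sits between the channel and the
+// leaf pick_first policies.)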
+type fakePetiole struct {
+	balancer.Balancer
+	balancer.ClientConn
+	bOpts balancer.BuildOptions
+}
+
+func (fp *fakePetiole) UpdateClientConnState(state balancer.ClientConnState) error {
+	if el := state.ResolverState.Endpoints; len(el) != 2 {
+		return fmt.Errorf("UpdateClientConnState wants two endpoints, got: %v", el)
+	}
+
+	return fp.Balancer.UpdateClientConnState(balancer.ClientConnState{
+		BalancerConfig: gracefulSwitchPickFirst,
+		ResolverState:  state.ResolverState,
+	})
+}
+
+func (fp *fakePetiole) UpdateState(state balancer.State) {
+	childStates := ChildStatesFromPicker(state.Picker)
+	// Both child states should be present in the child picker. The states and
+	// picker change over the lifecycle of the test, but there should always
+	// be two children.
+	if len(childStates) != 2 {
+		logger.Fatal(fmt.Errorf("length of child states received: %v, want 2", len(childStates)))
+	}
+
+	fp.ClientConn.UpdateState(state)
+}
+
+// TestEndpointShardingBasic tests the basic functionality of the endpoint
+// sharding balancer. It specifies a petiole policy that is essentially a
+// wrapper around the endpoint sharder. Two backends are started, with each
+// backend's address specified in an endpoint. The petiole does not have a
+// special picker, so it falls back to the default behavior, which is to
+// round robin amongst the endpoint children that are in the aggregated state.
+// The test also verifies the petiole has access to the raw child state in
+// case it wants to implement a custom picker.
+func (s) TestEndpointShardingBasic(t *testing.T) {
+	backend1 := stubserver.StartTestService(t, nil)
+	defer backend1.Stop()
+	backend2 := stubserver.StartTestService(t, nil)
+	defer backend2.Stop()
+
+	mr := manual.NewBuilderWithScheme("e2e-test")
+	defer mr.Close()
+
+	json := `{"loadBalancingConfig": [{"fake_petiole":{}}]}`
+	sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(json)
+	mr.InitialState(resolver.State{
+		Endpoints: []resolver.Endpoint{
+			{Addresses: []resolver.Address{{Addr: backend1.Address}}},
+			{Addresses: []resolver.Address{{Addr: backend2.Address}}},
+		},
+		ServiceConfig: sc,
+	})
+
+	cc, err := grpc.Dial(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
+	if err != nil {
+		t.Fatalf("Failed to dial: %v", err)
+	}
+	defer cc.Close()
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+	defer cancel()
+	client := testgrpc.NewTestServiceClient(cc)
+	// Assert a round robin distribution between the two spun up backends.
+	// This requires polling and eventual consistency, as both endpoint
+	// children do not start in state READY.
+	if err = roundrobin.CheckRoundRobinRPCs(ctx, client, []resolver.Address{{Addr: backend1.Address}, {Addr: backend2.Address}}); err != nil {
+		t.Fatalf("error in expected round robin: %v", err)
+	}
+}
diff --git a/examples/examples_test.sh b/examples/examples_test.sh
index 250840205abb..5e95120498c7 100755
--- a/examples/examples_test.sh
+++ b/examples/examples_test.sh
@@ -55,6 +55,7 @@ EXAMPLES=(
     "features/authz"
     "features/cancellation"
     "features/compression"
+    "features/customloadbalancer"
     "features/deadline"
     "features/encryption/TLS"
     "features/error_details"
@@ -109,6 +110,7 @@ declare -A EXPECTED_SERVER_OUTPUT=(
     ["features/authz"]="unary echoing message \"hello world\""
     ["features/cancellation"]="server: error receiving from stream: rpc error: code = Canceled desc = context canceled"
     ["features/compression"]="UnaryEcho called with message \"compress\""
+    ["features/customloadbalancer"]="serving on localhost:50051"
     ["features/deadline"]=""
     ["features/encryption/TLS"]=""
     ["features/error_details"]=""
@@ -132,6 +134,7 @@ declare -A EXPECTED_CLIENT_OUTPUT=(
     ["features/authz"]="UnaryEcho: hello world"
     ["features/cancellation"]="cancelling context"
     ["features/compression"]="UnaryEcho call returned \"compress\", "
+    ["features/customloadbalancer"]="Successful multiple iterations of 1:2 ratio"
    ["features/deadline"]="wanted = DeadlineExceeded, got = DeadlineExceeded"
    ["features/encryption/TLS"]="UnaryEcho: hello world"
    ["features/error_details"]="Greeting: Hello world"
diff --git a/examples/features/customloadbalancer/README.md b/examples/features/customloadbalancer/README.md
new file mode 100644
index 000000000000..db39bf20b014
--- /dev/null
+++ b/examples/features/customloadbalancer/README.md
@@ -0,0 +1,52 @@
+# Custom Load Balancer
+
+This example shows how to deploy a custom load balancer in a `ClientConn`.
+
+## Try it
+
+```
+go run server/main.go
+```
+
+```
+go run client/main.go
+```
+
+## Explanation
+
+Two echo servers are serving on "localhost:50050" and "localhost:50051". They
+include their serving address in the response, so the server on
+"localhost:50051" replies to an RPC with `this is
+examples/customloadbalancing (from localhost:50051)`.
+
+A client is created that connects to both of these servers (it receives both
+server addresses from the name resolver in two separate endpoints). The client
+is configured with the load balancer specified in the service config, which in
+this case is custom_round_robin.
+
+### custom_round_robin
+
+The client is configured to use `custom_round_robin`. `custom_round_robin`
+creates a pick_first child for every endpoint it receives. It waits until both
+pick_first children become READY, then defers to the first pick_first child's
+picker, choosing the connection to localhost:50050, except that every
+chooseSecond picks it defers to the second pick_first child's picker, choosing
+the connection to localhost:50051 (or vice versa).
+
+`custom_round_robin` is written as a delegating policy wrapping `pick_first`
+load balancers, one for every endpoint received. This is the intended way a
+user-written custom LB should be structured, as pick_first contains a lot of
+useful functionality, such as sticky transient failure, Happy Eyeballs, and
+health checking.
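+
+As a standalone sketch (not one of this example's files), the following
+program demonstrates the pick pattern that `chooseSecond` produces: with
+chooseSecond = 3, every third pick goes to the second child, which yields the
+2:1 split visible in the client output below.
+
+```go
+package main
+
+import "fmt"
+
+func main() {
+	// chooseSecond mirrors the policy's config knob; 3 is its default.
+	const chooseSecond = 3
+	for next := uint32(1); next <= 9; next++ {
+		index := 0 // The first child is picked by default...
+		if next%chooseSecond == 0 {
+			index = 1 // ...and every chooseSecond-th pick uses the second.
+		}
+		fmt.Printf("pick %d -> child %d\n", next, index)
+	}
+}
+```
+
+Running the client against the two servers produces output like: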
+
+```
+this is examples/customloadbalancing (from localhost:50050)
+this is examples/customloadbalancing (from localhost:50050)
+this is examples/customloadbalancing (from localhost:50051)
+this is examples/customloadbalancing (from localhost:50050)
+this is examples/customloadbalancing (from localhost:50050)
+this is examples/customloadbalancing (from localhost:50051)
+this is examples/customloadbalancing (from localhost:50050)
+this is examples/customloadbalancing (from localhost:50050)
+this is examples/customloadbalancing (from localhost:50051)
+```
diff --git a/examples/features/customloadbalancer/client/customroundrobin/customroundrobin.go b/examples/features/customloadbalancer/client/customroundrobin/customroundrobin.go
new file mode 100644
index 000000000000..60e10743a131
--- /dev/null
+++ b/examples/features/customloadbalancer/client/customroundrobin/customroundrobin.go
@@ -0,0 +1,157 @@
+/*
+ *
+ * Copyright 2023 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package customroundrobin
+
+import (
+	"encoding/json"
+	"fmt"
+	"sync/atomic"
+
+	_ "google.golang.org/grpc" // Imported to register pick_first.
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/balancer/endpointsharding"
+	"google.golang.org/grpc/connectivity"
+	"google.golang.org/grpc/grpclog"
+	"google.golang.org/grpc/serviceconfig"
+)
+
+var gracefulSwitchPickFirst serviceconfig.LoadBalancingConfig
+
+func init() {
+	balancer.Register(customRoundRobinBuilder{})
+	var err error
+	gracefulSwitchPickFirst, err = endpointsharding.ParseConfig(json.RawMessage(endpointsharding.PickFirstConfig))
+	if err != nil {
+		logger.Fatal(err)
+	}
+}
+
+const customRRName = "custom_round_robin"
+
+type customRRConfig struct {
+	serviceconfig.LoadBalancingConfig `json:"-"`
+
+	// ChooseSecond represents how often picks choose the second SubConn in
+	// the list: every chooseSecond-th pick goes to the second SubConn.
+	// Defaults to 3. If 0, the second SubConn is never chosen.
+	ChooseSecond uint32 `json:"chooseSecond,omitempty"`
+}
+
+type customRoundRobinBuilder struct{}
+
+func (customRoundRobinBuilder) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
+	lbConfig := &customRRConfig{
+		ChooseSecond: 3,
+	}
+
+	if err := json.Unmarshal(s, lbConfig); err != nil {
+		return nil, fmt.Errorf("custom-round-robin: unable to unmarshal customRRConfig: %v", err)
+	}
+	return lbConfig, nil
+}
+
+func (customRoundRobinBuilder) Name() string {
+	return customRRName
+}
+
+func (customRoundRobinBuilder) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
+	crr := &customRoundRobin{
+		ClientConn: cc,
+		bOpts:      bOpts,
+	}
+	crr.Balancer = endpointsharding.NewBalancer(crr, bOpts)
+	return crr
+}
+
+var logger = grpclog.Component("example")
+
+type customRoundRobin struct {
+	// All state and operations on this balancer are either initialized at
+	// build time and read-only afterward, or are accessed only as part of its
+	// balancer.Balancer API, whose calls are guaranteed to arrive
+	// synchronously (UpdateState from children also only comes in through
+	// balancer.Balancer calls, and children are called one at a time). Thus,
+	// no extra synchronization is required in this balancer.
+	balancer.Balancer
+	balancer.ClientConn
+	bOpts balancer.BuildOptions
+
+	cfg atomic.Pointer[customRRConfig]
+}
+
+func (crr *customRoundRobin) UpdateClientConnState(state balancer.ClientConnState) error {
+	crrCfg, ok := state.BalancerConfig.(*customRRConfig)
+	if !ok {
+		return balancer.ErrBadResolverState
+	}
+	if el := state.ResolverState.Endpoints; len(el) != 2 {
+		return fmt.Errorf("UpdateClientConnState wants two endpoints, got: %v", el)
+	}
+	crr.cfg.Store(crrCfg)
+	// A call to UpdateClientConnState should always produce a new picker.
+	// That is guaranteed to happen, since the endpointsharding aggregator
+	// always calls updateState as part of its UpdateClientConnState.
+	return crr.Balancer.UpdateClientConnState(balancer.ClientConnState{
+		BalancerConfig: gracefulSwitchPickFirst,
+		ResolverState:  state.ResolverState,
+	})
+}
+
+func (crr *customRoundRobin) UpdateState(state balancer.State) {
+	if state.ConnectivityState == connectivity.Ready {
+		childStates := endpointsharding.ChildStatesFromPicker(state.Picker)
+		var readyPickers []balancer.Picker
+		for _, childState := range childStates {
+			if childState.State.ConnectivityState == connectivity.Ready {
+				readyPickers = append(readyPickers, childState.State.Picker)
+			}
+		}
+		// If both children are ready, pick using the custom round robin
+		// algorithm.
+		if len(readyPickers) == 2 {
+			picker := &customRoundRobinPicker{
+				pickers:      readyPickers,
+				chooseSecond: crr.cfg.Load().ChooseSecond,
+				next:         0,
+			}
+			crr.ClientConn.UpdateState(balancer.State{
+				ConnectivityState: connectivity.Ready,
+				Picker:            picker,
+			})
+			return
+		}
+	}
+	// Otherwise, delegate to the default picker from endpointsharding below.
+	crr.ClientConn.UpdateState(state)
+}
+
+type customRoundRobinPicker struct {
+	pickers      []balancer.Picker
+	chooseSecond uint32
+	next         uint32
+}
+
+func (crrp *customRoundRobinPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
+	next := atomic.AddUint32(&crrp.next, 1)
+	index := 0
+	// Guard against chooseSecond == 0 (never choose the second SubConn, and
+	// avoid a modulo-by-zero panic), and against the counter wrapping to 0.
+	if crrp.chooseSecond != 0 && next != 0 && next%crrp.chooseSecond == 0 {
+		index = 1
+	}
+	childPicker := crrp.pickers[index%len(crrp.pickers)]
+	return childPicker.Pick(info)
+}
diff --git a/examples/features/customloadbalancer/client/main.go b/examples/features/customloadbalancer/client/main.go
new file mode 100644
index 000000000000..921717d1fe17
--- /dev/null
+++ b/examples/features/customloadbalancer/client/main.go
@@ -0,0 +1,154 @@
+/*
+ *
+ * Copyright 2023 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"strings"
+	"time"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+	_ "google.golang.org/grpc/examples/features/customloadbalancer/client/customroundrobin" // To register custom_round_robin.
+	pb "google.golang.org/grpc/examples/features/proto/echo"
+	"google.golang.org/grpc/internal"
+	"google.golang.org/grpc/peer"
+	"google.golang.org/grpc/resolver"
+	"google.golang.org/grpc/resolver/manual"
+	"google.golang.org/grpc/serviceconfig"
+)
+
+var (
+	addr1 = "localhost:50050"
+	addr2 = "localhost:50051"
+)
+
+func main() {
+	mr := manual.NewBuilderWithScheme("example")
+	defer mr.Close()
+
+	// You can plug in your own custom LB policy here, as long as it is
+	// configurable. chooseSecond below is configurable; try changing it and
+	// see how the behavior changes.
+	json := `{"loadBalancingConfig": [{"custom_round_robin":{"chooseSecond": 3}}]}`
+	sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(json)
+	mr.InitialState(resolver.State{
+		Endpoints: []resolver.Endpoint{
+			{Addresses: []resolver.Address{{Addr: addr1}}},
+			{Addresses: []resolver.Address{{Addr: addr2}}},
+		},
+		ServiceConfig: sc,
+	})
+
+	cc, err := grpc.Dial(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
+	if err != nil {
+		log.Fatalf("Failed to dial: %v", err)
+	}
+	defer cc.Close()
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+	defer cancel()
+	ec := pb.NewEchoClient(cc)
+	if err := waitForDistribution(ctx, ec); err != nil {
+		log.Fatal(err)
+	}
+	fmt.Println("Successful multiple iterations of 1:2 ratio")
+}
+
+// waitForDistribution makes RPCs on the echo client until three consecutive
+// iterations of three RPCs each follow the same 1:2 address ratio for the
+// peer. It returns an error if it fails to do so before the context times
+// out.
+func waitForDistribution(ctx context.Context, ec pb.EchoClient) error {
+	for {
+		results := make(map[string]uint32)
+	InnerLoop:
+		for {
+			if ctx.Err() != nil {
+				return fmt.Errorf("timeout waiting for 1:2 distribution between addresses %v and %v", addr1, addr2)
+			}
+
+			for i := 0; i < 3; i++ {
+				res := make(map[string]uint32)
+				for j := 0; j < 3; j++ {
+					var peer peer.Peer
+					r, err := ec.UnaryEcho(ctx, &pb.EchoRequest{Message: "this is examples/customloadbalancing"}, grpc.Peer(&peer))
+					if err != nil {
+						return fmt.Errorf("UnaryEcho failed: %v", err)
+					}
+					fmt.Println(r)
+					peerAddr := peer.Addr.String()
+					if !strings.HasSuffix(peerAddr, "50050") && !strings.HasSuffix(peerAddr, "50051") {
+						return fmt.Errorf("peer address was not one of %v or %v, got: %v", addr1, addr2, peerAddr)
+					}
+					res[peerAddr]++
+					time.Sleep(time.Millisecond)
+				}
+				// Make sure the addresses come in a 1:2 ratio for this
+				// iteration.
+				var seen1, seen2 bool
+				for addr, count := range res {
+					if count != 1 && count != 2 {
+						break InnerLoop
+					}
+					if count == 1 {
+						if seen1 {
+							break InnerLoop
+						}
+						seen1 = true
+					}
+					if count == 2 {
+						if seen2 {
+							break InnerLoop
+						}
+						seen2 = true
+					}
+					results[addr] = results[addr] + count
+				}
+				if !seen1 || !seen2 {
+					break InnerLoop
+				}
+			}
+			// Make sure the totals are 3 and 6 for the addresses seen. This
+			// verifies the distribution was the same 1:2 ratio for each
+			// iteration.
+			var seen3, seen6 bool
+			for _, count := range results {
+				if count != 3 && count != 6 {
+					break InnerLoop
+				}
+				if count == 3 {
+					if seen3 {
+						break InnerLoop
+					}
+					seen3 = true
+				}
+				if count == 6 {
+					if seen6 {
+						break InnerLoop
+					}
+					seen6 = true
+				}
+			}
+			if !seen3 || !seen6 {
+				break InnerLoop
+			}
+			// Both totals matched: the 1:2 distribution held for three
+			// consecutive iterations.
+			return nil
+		}
+	}
+}
diff --git a/examples/features/customloadbalancer/server/main.go b/examples/features/customloadbalancer/server/main.go
new file mode 100644
index 000000000000..ec5f2337c007
--- /dev/null
+++ b/examples/features/customloadbalancer/server/main.go
@@ -0,0 +1,66 @@
+/*
+ *
+ * Copyright 2023 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"net"
+	"sync"
+
+	"google.golang.org/grpc"
+	pb "google.golang.org/grpc/examples/features/proto/echo"
+)
+
+var (
+	addrs = []string{"localhost:50050", "localhost:50051"}
+)
+
+type echoServer struct {
+	pb.UnimplementedEchoServer
+	addr string
+}
+
+func (s *echoServer) UnaryEcho(ctx context.Context, req *pb.EchoRequest) (*pb.EchoResponse, error) {
+	return &pb.EchoResponse{Message: fmt.Sprintf("%s (from %s)", req.Message, s.addr)}, nil
+}
+
+func main() {
+	var wg sync.WaitGroup
+	for _, addr := range addrs {
+		lis, err := net.Listen("tcp", addr)
+		if err != nil {
+			log.Fatalf("failed to listen: %v", err)
+		}
+		s := grpc.NewServer()
+		pb.RegisterEchoServer(s, &echoServer{
+			addr: addr,
+		})
+		log.Printf("serving on %s\n", addr)
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if err := s.Serve(lis); err != nil {
+				log.Fatalf("failed to serve: %v", err)
+			}
+		}()
+	}
+	wg.Wait()
+}
diff --git a/internal/balancer/gracefulswitch/config.go b/internal/balancer/gracefulswitch/config.go
index 6bf7f87396f6..13821a926606 100644
--- a/internal/balancer/gracefulswitch/config.go
+++ b/internal/balancer/gracefulswitch/config.go
@@ -75,7 +75,6 @@ func ParseConfig(cfg json.RawMessage) (serviceconfig.LoadBalancingConfig, error)
 	if err != nil {
 		return nil, fmt.Errorf("error parsing config for policy %q: %v", name, err)
 	}
-
 	return &lbConfig{childBuilder: builder, childConfig: cfg}, nil
 }
diff --git a/internal/balancer/gracefulswitch/gracefulswitch.go b/internal/balancer/gracefulswitch/gracefulswitch.go
index 45d5e50ea9b1..73bb4c4ee9a3 100644
--- a/internal/balancer/gracefulswitch/gracefulswitch.go
+++ b/internal/balancer/gracefulswitch/gracefulswitch.go
@@ -169,7 +169,6 @@ func (gsb *Balancer) latestBalancer() *balancerWrapper {
 func (gsb *Balancer) UpdateClientConnState(state balancer.ClientConnState) error {
 	// The resolver data is only relevant to the most recent LB Policy.
 	balToUpdate := gsb.latestBalancer()
-
 	gsbCfg, ok := state.BalancerConfig.(*lbConfig)
 	if ok {
 		// Switch to the child in the config unless it is already active.
diff --git a/pickfirst.go b/pickfirst.go
index e3ea42ba962b..8853626614e8 100644
--- a/pickfirst.go
+++ b/pickfirst.go
@@ -54,7 +54,7 @@ type pfConfig struct {
 	serviceconfig.LoadBalancingConfig `json:"-"`
 
 	// If set to true, instructs the LB policy to shuffle the order of the list
-	// of addresses received from the name resolver before attempting to
+	// of endpoints received from the name resolver before attempting to
 	// connect to them.
 	ShuffleAddressList bool `json:"shuffleAddressList"`
 }
@@ -94,8 +94,7 @@ func (b *pickfirstBalancer) ResolverError(err error) {
 }
 
 func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
-	addrs := state.ResolverState.Addresses
-	if len(addrs) == 0 {
+	if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 {
 		// The resolver reported an empty address list. Treat it like an error by
 		// calling b.ResolverError.
 		if b.subConn != nil {
@@ -107,22 +106,49 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState
 		b.ResolverError(errors.New("produced zero addresses"))
 		return balancer.ErrBadResolverState
 	}
-
 	// We don't have to guard this block with the env var because ParseConfig
 	// already does so.
 	cfg, ok := state.BalancerConfig.(pfConfig)
 	if state.BalancerConfig != nil && !ok {
 		return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v", state.BalancerConfig, state.BalancerConfig)
 	}
-	if cfg.ShuffleAddressList {
-		addrs = append([]resolver.Address{}, addrs...)
-		grpcrand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] })
-	}
 	if b.logger.V(2) {
 		b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState))
 	}
+	var addrs []resolver.Address
+	if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 {
+		// Perform the optional shuffling described in gRFC A62. The shuffling
+		// will change the order of the endpoints but not touch the order of
+		// the addresses within each endpoint. - A61
+		if cfg.ShuffleAddressList {
+			endpoints = append([]resolver.Endpoint{}, endpoints...)
+			grpcrand.Shuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
+		}
+
+		// "Flatten the list by concatenating the ordered list of addresses for
+		// each of the endpoints, in order." - A61
+		for _, endpoint := range endpoints {
+			// "In the flattened list, interleave addresses from the two address
+			// families, as per RFC-8305 section 4." - A61
+			// TODO: support the above language.
+			addrs = append(addrs, endpoint.Addresses...)
+		}
+	} else {
+		// Endpoints are not set; process addresses until resolver emissions
+		// are fully migrated to endpoints. The top-level channel does wrap
+		// emitted addresses in endpoints, but some balancers, such as
+		// weighted_target, do not forward the corresponding endpoints down or
+		// split endpoints properly. Once all balancers correctly forward
+		// endpoints down, this else branch can be deleted.
+		addrs = state.ResolverState.Addresses
+		if cfg.ShuffleAddressList {
+			addrs = append([]resolver.Address{}, addrs...)
+			grpcrand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] })
+		}
+	}
+
 	if b.subConn != nil {
 		b.cc.UpdateAddresses(b.subConn, addrs)
 		return nil
diff --git a/test/pickfirst_test.go b/test/pickfirst_test.go
index 171d40d29d7a..52f6c531d146 100644
--- a/test/pickfirst_test.go
+++ b/test/pickfirst_test.go
@@ -397,7 +397,10 @@ func (s) TestPickFirst_ShuffleAddressList(t *testing.T) {
 
 	// Push an update with both addresses and shuffling disabled. We should
 	// connect to backend 0.
-	r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}})
+	r.UpdateState(resolver.State{Endpoints: []resolver.Endpoint{
+		{Addresses: []resolver.Address{addrs[0]}},
+		{Addresses: []resolver.Address{addrs[1]}},
+	}})
 	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
 		t.Fatal(err)
 	}
@@ -406,7 +409,10 @@ func (s) TestPickFirst_ShuffleAddressList(t *testing.T) {
 	// but the channel should still be connected to backend 0.
 	shufState := resolver.State{
 		ServiceConfig: parseServiceConfig(t, r, serviceConfig),
-		Addresses:     []resolver.Address{addrs[0], addrs[1]},
+		Endpoints: []resolver.Endpoint{
+			{Addresses: []resolver.Address{addrs[0]}},
+			{Addresses: []resolver.Address{addrs[1]}},
+		},
 	}
 	r.UpdateState(shufState)
 	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {