Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First draft of discovery #150

Merged
merged 3 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions go/cmd/vtgateproxy/vtgateproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ func main() {
servenv.ParseFlags("vtgateproxy")
servenv.Init()

vtgateproxy.Init()

servenv.OnRun(func() {
// Flags are parsed now. Parse the template using the actual flag value and overwrite the current template.
vtgateproxy.RegisterJsonDiscovery()
vtgateproxy.Init()
})

servenv.OnClose(func() {
})
servenv.RunDefault()
Expand Down
214 changes: 214 additions & 0 deletions go/vt/vtgateproxy/discovery.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
package vtgateproxy

import (
"encoding/json"
"flag"
"fmt"
"math/rand"
"os"
"strconv"
"time"

"google.golang.org/grpc/resolver"
)

var (
jsonDiscoveryConfig = flag.String("json_config", "", "json file describing the host list to use fot vitess://vtgate resolution")
)

// File based discovery for vtgate grpc endpoints
// This loads the list of hosts from json and watches for changes to the list of hosts. It will select N connection to maintain to backend vtgates.
// Connections will rebalance every 5 minutes
//
// Example json config - based on the slack hosts format
//
// [
// {
// "address": "10.4.56.194",
// "az_id": "use1-az1",
// "grpc": "15999",
// "type": "aux"
// },
//
// Naming scheme:
// vtgate://<type>?num_connections=<int>&az_id=<string>
//
// num_connections: Option number of hosts to open connections to for round-robin selection
// az_id: Filter to just hosts in this az (optional)
// type: Only select from hosts of this type (required)
//

type DiscoveryHost struct {
Address string
NebulaAddress string `json:"nebula_address"`
Grpc string
AZId string `json:"az_id"`
Type string
}

type JSONGateConfigDiscovery struct {
JsonPath string
}

func (b *JSONGateConfigDiscovery) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
queryOpts := target.URL.Query()
queryParamCount := queryOpts.Get("num_connections")
queryAZID := queryOpts.Get("az_id")
num_connections := 0

gateType := target.URL.Host

if queryParamCount != "" {
num_connections, _ = strconv.Atoi(queryParamCount)
}

filters := resolveFilters{
gate_type: gateType,
}

if queryAZID != "" {
filters.az_id = queryAZID
}

r := &resolveJSONGateConfig{
target: target,
cc: cc,
jsonPath: b.JsonPath,
num_connections: num_connections,
filters: filters,
}
r.start()
return r, nil
}
func (*JSONGateConfigDiscovery) Scheme() string { return "vtgate" }

func RegisterJsonDiscovery() {
fmt.Printf("Registering: %v\n", *jsonDiscoveryConfig)
resolver.Register(&JSONGateConfigDiscovery{
JsonPath: *jsonDiscoveryConfig,
})
}

type resolveFilters struct {
gate_type string
az_id string
}

// exampleResolver is a
// Resolver(https://godoc.org/google.golang.org/grpc/resolver#Resolver).
type resolveJSONGateConfig struct {
target resolver.Target
cc resolver.ClientConn
jsonPath string
ticker *time.Ticker
rand *rand.Rand // safe for concurrent use.
num_connections int
filters resolveFilters
}

func (r *resolveJSONGateConfig) loadConfig() (*[]resolver.Address, error) {
config := []DiscoveryHost{}

data, err := os.ReadFile(r.jsonPath)
if err != nil {
return nil, err
}

err = json.Unmarshal(data, &config)
if err != nil {
fmt.Printf("parse err: %v\n", err)
return nil, err
}

fmt.Printf("%v\n", config)

addrs := []resolver.Address{}
for _, s := range config {
// Apply filters
if r.filters.gate_type != "" {
if r.filters.gate_type != s.Type {
fmt.Printf("Dropped non matching type: %v\n", s.Type)
continue
}
}

if r.filters.az_id != "" {
if r.filters.az_id != s.AZId {
fmt.Printf("Dropped non matching az: %v\n", s.AZId)
continue
}
}
// Add matching hosts to registration list
fmt.Printf("selected host for discovery: %v %v\n", fmt.Sprintf("%s:%s", s.NebulaAddress, s.Grpc), s)
addrs = append(addrs, resolver.Address{Addr: fmt.Sprintf("%s:%s", s.NebulaAddress, s.Grpc)})
}

// Shuffle to ensure every host has a different order to iterate through
r.rand.Shuffle(len(addrs), func(i, j int) {
addrs[i], addrs[j] = addrs[j], addrs[i]
})

// Slice off the first N hosts, optionally
if r.num_connections > 0 && r.num_connections <= len(addrs) {
addrs = addrs[0:r.num_connections]
}

fmt.Printf("Returning discovery: %v\n", addrs)

return &addrs, nil
}

func (r *resolveJSONGateConfig) start() {
fmt.Print("Starting discovery checker\n")
r.rand = rand.New(rand.NewSource(time.Now().UnixNano()))

// Immediately load the initial config
addrs, err := r.loadConfig()
if err == nil {
// if we parse ok, populate the local address store
r.cc.UpdateState(resolver.State{Addresses: *addrs})
}

// Start a config watcher
r.ticker = time.NewTicker(100 * time.Millisecond)
fileStat, err := os.Stat(r.jsonPath)
if err != nil {
return
}
lastLoaded := time.Now()

go func() {
for range r.ticker.C {
checkFileStat, err := os.Stat(r.jsonPath)
isUnchanged := checkFileStat.Size() == fileStat.Size() || checkFileStat.ModTime() == fileStat.ModTime()
isNotExpired := time.Since(lastLoaded) < 1*time.Minute
if isUnchanged && isNotExpired {
// no change
continue
}
lastLoaded = time.Now()

fileStat = checkFileStat
fmt.Printf("Detected config change\n")

addrs, err := r.loadConfig()
if err != nil {
// better luck next loop
// TODO: log this
fmt.Print("oh no\n")
continue
}

r.cc.UpdateState(resolver.State{Addresses: *addrs})
}
}()
}
func (r *resolveJSONGateConfig) ResolveNow(o resolver.ResolveNowOptions) {}
func (r *resolveJSONGateConfig) Close() {
r.ticker.Stop()
}

func init() {
// Register the example ResolverBuilder. This is usually done in a package's
// init() function.
}
8 changes: 5 additions & 3 deletions go/vt/vtgateproxy/vtgateproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ var (
defaultDDLStrategy = flag.String("ddl_strategy", string(schema.DDLStrategyDirect), "Set default strategy for DDL statements. Override with @@ddl_strategy session variable")
sysVarSetEnabled = flag.Bool("enable_system_settings", true, "This will enable the system settings to be changed per session at the database connection level")

vtGateProxy *VTGateProxy
vtGateProxy *VTGateProxy = &VTGateProxy{}
)

type VTGateProxy struct {
Expand All @@ -54,6 +54,10 @@ func (proxy *VTGateProxy) connect(ctx context.Context) error {
return append(opts, grpc.WithBlock()), nil
})

grpcclient.RegisterGRPCDialOptions(func(opts []grpc.DialOption) ([]grpc.DialOption, error) {
return append(opts, grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`)), nil
})

conn, err := vtgateconn.DialProtocol(ctx, "grpc", *target)
if err != nil {
return err
Expand Down Expand Up @@ -104,8 +108,6 @@ func (proxy *VTGateProxy) StreamExecute(ctx context.Context, session *vtgateconn
}

func Init() error {
vtGateProxy = &VTGateProxy{}

// XXX maybe add connect timeout?
ctx, cancel := context.WithTimeout(context.Background(), *dialTimeout)
defer cancel()
Expand Down
Loading