diff --git a/app/resolver.go b/app/resolver.go index 3da9adc8d9..036d68310f 100644 --- a/app/resolver.go +++ b/app/resolver.go @@ -6,6 +6,11 @@ import ( "time" ) +var ( + tick = time.Tick + lookupIP = net.LookupIP +) + // Resolver periodically tries to resolve the IP addresses for a given // set of hostnames. type Resolver struct { @@ -25,22 +30,21 @@ type peer struct { // resolve to multiple IPs; it will repeatedly call // add with the same IP, expecting the target to dedupe. func NewResolver(peers []string, add func(string)) Resolver { - resolver := Resolver{ + r := Resolver{ quit: make(chan struct{}), add: add, peers: prepareNames(peers), } - - go resolver.loop() - return resolver + go r.loop() + return r } -func prepareNames(peers []string) []peer { +func prepareNames(strs []string) []peer { var results []peer - for _, p := range peers { - hostname, port, err := net.SplitHostPort(p) + for _, s := range strs { + hostname, port, err := net.SplitHostPort(s) if err != nil { - log.Printf("invalid address %s: %v", p, err) + log.Printf("invalid address %s: %v", s, err) continue } results = append(results, peer{hostname, port}) @@ -50,10 +54,10 @@ func prepareNames(peers []string) []peer { func (r Resolver) loop() { r.resolveHosts() + t := tick(time.Minute) for { - tick := time.Tick(1 * time.Minute) select { - case <-tick: + case <-t: r.resolveHosts() case <-r.quit: return @@ -63,7 +67,7 @@ func (r Resolver) loop() { func (r Resolver) resolveHosts() { for _, peer := range r.peers { - addrs, err := net.LookupIP(peer.hostname) + addrs, err := lookupIP(peer.hostname) if err != nil { log.Printf("lookup %s: %v", peer.hostname, err) continue @@ -74,13 +78,12 @@ func (r Resolver) resolveHosts() { if addr.To4() == nil { continue } - r.add(net.JoinHostPort(addr.String(), peer.port)) } } } -// Stop this resolver. +// Stop this Resolver. func (r Resolver) Stop() { - r.quit <- struct{}{} + close(r.quit) } diff --git a/app/resolver_test.go b/app/resolver_test.go new file mode 100644 index 0000000000..5c41d7d488 --- /dev/null +++ b/app/resolver_test.go @@ -0,0 +1,71 @@ +package main + +import ( + "net" + "runtime" + "testing" + "time" +) + +func TestResolver(t *testing.T) { + oldTick := tick + defer func() { tick = oldTick }() + c := make(chan time.Time) + tick = func(_ time.Duration) <-chan time.Time { return c } + + oldLookupIP := lookupIP + defer func() { lookupIP = oldLookupIP }() + ips := []net.IP{} + lookupIP = func(host string) ([]net.IP, error) { return ips, nil } + + port := ":80" + adds := make(chan string) + add := func(s string) { adds <- s } + r := NewResolver([]string{"symbolic.name" + port}, add) + + c <- time.Now() // trigger initial resolve, with no endpoints + select { + case <-time.After(time.Millisecond): + case s := <-adds: + t.Errorf("got unexpected add: %q", s) + } + + assertAdd := func(want string) { + select { + case have := <-adds: + if want != have { + _, _, line, _ := runtime.Caller(1) + t.Errorf("line %d: want %q, have %q", line, want, have) + } + case <-time.After(time.Millisecond): + t.Fatal("didn't get add in time") + } + } + + ip1 := "1.2.3.4" + ips = makeIPs(ip1) + c <- time.Now() // trigger a resolve + assertAdd(ip1 + port) // we want 1 add + + ip2 := "10.10.10.10" + ips = makeIPs(ip1, ip2) + c <- time.Now() // trigger another resolve, this time with 2 adds + assertAdd(ip1 + port) // first add + assertAdd(ip2 + port) // second add + + done := make(chan struct{}) + go func() { r.Stop(); close(done) }() + select { + case <-done: + case <-time.After(time.Millisecond): + t.Errorf("didn't Stop in time") + } +} + +func makeIPs(addrs ...string) []net.IP { + var ips []net.IP + for _, addr := range addrs { + ips = append(ips, net.ParseIP(addr)) + } + return ips +}