diff --git a/cmd/devp2p/crawl.go b/cmd/devp2p/crawl.go
new file mode 100644
index 000000000000..92aaad72a372
--- /dev/null
+++ b/cmd/devp2p/crawl.go
@@ -0,0 +1,152 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of go-ethereum.
+//
+// go-ethereum is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// go-ethereum is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with go-ethereum. If not, see .
+
+package main
+
+import (
+ "time"
+
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/discover"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+type crawler struct {
+ input nodeSet
+ output nodeSet
+ disc *discover.UDPv4
+ iters []enode.Iterator
+ inputIter enode.Iterator
+ ch chan *enode.Node
+ closed chan struct{}
+
+ // settings
+ revalidateInterval time.Duration
+}
+
+func newCrawler(input nodeSet, disc *discover.UDPv4, iters ...enode.Iterator) *crawler {
+ c := &crawler{
+ input: input,
+ output: make(nodeSet, len(input)),
+ disc: disc,
+ iters: iters,
+ inputIter: enode.IterNodes(input.nodes()),
+ ch: make(chan *enode.Node),
+ closed: make(chan struct{}),
+ }
+ c.iters = append(c.iters, c.inputIter)
+ // Copy input to output initially. Any nodes that fail validation
+ // will be dropped from output during the run.
+ for id, n := range input {
+ c.output[id] = n
+ }
+ return c
+}
+
+func (c *crawler) run(timeout time.Duration) nodeSet {
+ var (
+ timeoutTimer = time.NewTimer(timeout)
+ timeoutCh <-chan time.Time
+ doneCh = make(chan enode.Iterator, len(c.iters))
+ liveIters = len(c.iters)
+ )
+ for _, it := range c.iters {
+ go c.runIterator(doneCh, it)
+ }
+
+loop:
+ for {
+ select {
+ case n := <-c.ch:
+ c.updateNode(n)
+ case it := <-doneCh:
+ if it == c.inputIter {
+ // Enable timeout when we're done revalidating the input nodes.
+ log.Info("Revalidation of input set is done", "len", len(c.input))
+ if timeout > 0 {
+ timeoutCh = timeoutTimer.C
+ }
+ }
+ if liveIters--; liveIters == 0 {
+ break loop
+ }
+ case <-timeoutCh:
+ break loop
+ }
+ }
+
+ close(c.closed)
+ for _, it := range c.iters {
+ it.Close()
+ }
+ for ; liveIters > 0; liveIters-- {
+ <-doneCh
+ }
+ return c.output
+}
+
+func (c *crawler) runIterator(done chan<- enode.Iterator, it enode.Iterator) {
+ defer func() { done <- it }()
+ for it.Next() {
+ select {
+ case c.ch <- it.Node():
+ case <-c.closed:
+ return
+ }
+ }
+}
+
+func (c *crawler) updateNode(n *enode.Node) {
+ node, ok := c.output[n.ID()]
+
+ // Skip validation of recently-seen nodes.
+ if ok && time.Since(node.LastCheck) < c.revalidateInterval {
+ return
+ }
+
+ // Request the node record.
+ nn, err := c.disc.RequestENR(n)
+ node.LastCheck = truncNow()
+ if err != nil {
+ if node.Score == 0 {
+ // Node doesn't implement EIP-868.
+ log.Debug("Skipping node", "id", n.ID())
+ return
+ }
+ node.Score /= 2
+ } else {
+ node.N = nn
+ node.Seq = nn.Seq()
+ node.Score++
+ if node.FirstResponse.IsZero() {
+ node.FirstResponse = node.LastCheck
+ }
+ node.LastResponse = node.LastCheck
+ }
+
+ // Store/update node in output set.
+ if node.Score <= 0 {
+ log.Info("Removing node", "id", n.ID())
+ delete(c.output, n.ID())
+ } else {
+ log.Info("Updating node", "id", n.ID(), "seq", n.Seq(), "score", node.Score)
+ c.output[n.ID()] = node
+ }
+}
+
+func truncNow() time.Time {
+ return time.Now().UTC().Truncate(1 * time.Second)
+}
diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go
index ab5b874029df..9525bec66817 100644
--- a/cmd/devp2p/discv4cmd.go
+++ b/cmd/devp2p/discv4cmd.go
@@ -39,6 +39,7 @@ var (
discv4RequestRecordCommand,
discv4ResolveCommand,
discv4ResolveJSONCommand,
+ discv4CrawlCommand,
},
}
discv4PingCommand = cli.Command{
@@ -67,12 +68,25 @@ var (
Flags: []cli.Flag{bootnodesFlag},
ArgsUsage: "",
}
+ discv4CrawlCommand = cli.Command{
+ Name: "crawl",
+ Usage: "Updates a nodes.json file with random nodes found in the DHT",
+ Action: discv4Crawl,
+ Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag},
+ }
)
-var bootnodesFlag = cli.StringFlag{
- Name: "bootnodes",
- Usage: "Comma separated nodes used for bootstrapping",
-}
+var (
+ bootnodesFlag = cli.StringFlag{
+ Name: "bootnodes",
+ Usage: "Comma separated nodes used for bootstrapping",
+ }
+ crawlTimeoutFlag = cli.DurationFlag{
+ Name: "timeout",
+ Usage: "Time limit for the crawl.",
+ Value: 30 * time.Minute,
+ }
+)
func discv4Ping(ctx *cli.Context) error {
n := getNodeArg(ctx)
@@ -113,30 +127,48 @@ func discv4ResolveJSON(ctx *cli.Context) error {
if ctx.NArg() < 1 {
return fmt.Errorf("need nodes file as argument")
}
- disc := startV4(ctx)
- defer disc.Close()
- file := ctx.Args().Get(0)
-
- // Load existing nodes in file.
- var nodes []*enode.Node
- if common.FileExist(file) {
- nodes = loadNodesJSON(file).nodes()
+ nodesFile := ctx.Args().Get(0)
+ inputSet := make(nodeSet)
+ if common.FileExist(nodesFile) {
+ inputSet = loadNodesJSON(nodesFile)
}
- // Add nodes from command line arguments.
+
+ // Add extra nodes from command line arguments.
+ var nodeargs []*enode.Node
for i := 1; i < ctx.NArg(); i++ {
n, err := parseNode(ctx.Args().Get(i))
if err != nil {
exit(err)
}
- nodes = append(nodes, n)
+ nodeargs = append(nodeargs, n)
}
- result := make(nodeSet, len(nodes))
- for _, n := range nodes {
- n = disc.Resolve(n)
- result[n.ID()] = nodeJSON{Seq: n.Seq(), N: n}
+ // Run the crawler.
+ disc := startV4(ctx)
+ defer disc.Close()
+ c := newCrawler(inputSet, disc, enode.IterNodes(nodeargs))
+ c.revalidateInterval = 0
+ output := c.run(0)
+ writeNodesJSON(nodesFile, output)
+ return nil
+}
+
+func discv4Crawl(ctx *cli.Context) error {
+ if ctx.NArg() < 1 {
+ return fmt.Errorf("need nodes file as argument")
+ }
+ nodesFile := ctx.Args().First()
+ var inputSet nodeSet
+ if common.FileExist(nodesFile) {
+ inputSet = loadNodesJSON(nodesFile)
}
- writeNodesJSON(file, result)
+
+ disc := startV4(ctx)
+ defer disc.Close()
+ c := newCrawler(inputSet, disc, disc.RandomNodes())
+ c.revalidateInterval = 10 * time.Minute
+ output := c.run(ctx.Duration(crawlTimeoutFlag.Name))
+ writeNodesJSON(nodesFile, output)
return nil
}
diff --git a/cmd/devp2p/dnscmd.go b/cmd/devp2p/dnscmd.go
index 74d70d3aaaf9..eb15764b04e2 100644
--- a/cmd/devp2p/dnscmd.go
+++ b/cmd/devp2p/dnscmd.go
@@ -109,7 +109,8 @@ func dnsSync(ctx *cli.Context) error {
}
def := treeToDefinition(url, t)
def.Meta.LastModified = time.Now()
- writeTreeDefinition(outdir, def)
+ writeTreeMetadata(outdir, def)
+ writeTreeNodes(outdir, def)
return nil
}
@@ -151,7 +152,7 @@ func dnsSign(ctx *cli.Context) error {
def = treeToDefinition(url, t)
def.Meta.LastModified = time.Now()
- writeTreeDefinition(defdir, def)
+ writeTreeMetadata(defdir, def)
return nil
}
@@ -315,26 +316,28 @@ func ensureValidTreeSignature(t *dnsdisc.Tree, pubkey *ecdsa.PublicKey, sig stri
return nil
}
-// writeTreeDefinition writes a DNS node tree definition to the given directory.
-func writeTreeDefinition(directory string, def *dnsDefinition) {
+// writeTreeMetadata writes a DNS node tree metadata file to the given directory.
+func writeTreeMetadata(directory string, def *dnsDefinition) {
metaJSON, err := json.MarshalIndent(&def.Meta, "", jsonIndent)
if err != nil {
exit(err)
}
- // Convert nodes.
- nodes := make(nodeSet, len(def.Nodes))
- nodes.add(def.Nodes...)
- // Write.
if err := os.Mkdir(directory, 0744); err != nil && !os.IsExist(err) {
exit(err)
}
- metaFile, nodesFile := treeDefinitionFiles(directory)
- writeNodesJSON(nodesFile, nodes)
+ metaFile, _ := treeDefinitionFiles(directory)
if err := ioutil.WriteFile(metaFile, metaJSON, 0644); err != nil {
exit(err)
}
}
+func writeTreeNodes(directory string, def *dnsDefinition) {
+ ns := make(nodeSet, len(def.Nodes))
+ ns.add(def.Nodes...)
+ _, nodesFile := treeDefinitionFiles(directory)
+ writeNodesJSON(nodesFile, ns)
+}
+
func treeDefinitionFiles(directory string) (string, string) {
meta := filepath.Join(directory, "enrtree-info.json")
nodes := filepath.Join(directory, "nodes.json")
diff --git a/cmd/devp2p/main.go b/cmd/devp2p/main.go
index c88fe6f6123e..6faa65093737 100644
--- a/cmd/devp2p/main.go
+++ b/cmd/devp2p/main.go
@@ -60,6 +60,7 @@ func init() {
enrdumpCommand,
discv4Command,
dnsCommand,
+ nodesetCommand,
}
}
diff --git a/cmd/devp2p/nodeset.go b/cmd/devp2p/nodeset.go
index a4a05016e92e..2d86c3f65aba 100644
--- a/cmd/devp2p/nodeset.go
+++ b/cmd/devp2p/nodeset.go
@@ -21,7 +21,9 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
+ "os"
"sort"
+ "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/p2p/enode"
@@ -36,6 +38,15 @@ type nodeSet map[enode.ID]nodeJSON
type nodeJSON struct {
Seq uint64 `json:"seq"`
N *enode.Node `json:"record"`
+
+ // The score tracks how many liveness checks were performed. It is incremented by one
+ // every time the node passes a check, and halved every time it doesn't.
+ Score int `json:"score,omitempty"`
+ // These two track the time of last successful contact.
+ FirstResponse time.Time `json:"firstResponse,omitempty"`
+ LastResponse time.Time `json:"lastResponse,omitempty"`
+ // This one tracks the time of our last attempt to contact the node.
+ LastCheck time.Time `json:"lastCheck,omitempty"`
}
func loadNodesJSON(file string) nodeSet {
@@ -51,6 +62,10 @@ func writeNodesJSON(file string, nodes nodeSet) {
if err != nil {
exit(err)
}
+ if file == "-" {
+ os.Stdout.Write(nodesJSON)
+ return
+ }
if err := ioutil.WriteFile(file, nodesJSON, 0644); err != nil {
exit(err)
}
diff --git a/cmd/devp2p/nodesetcmd.go b/cmd/devp2p/nodesetcmd.go
new file mode 100644
index 000000000000..384a12da7c7e
--- /dev/null
+++ b/cmd/devp2p/nodesetcmd.go
@@ -0,0 +1,193 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of go-ethereum.
+//
+// go-ethereum is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// go-ethereum is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with go-ethereum. If not, see .
+
+package main
+
+import (
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/ethereum/go-ethereum/core/forkid"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+ "github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/rlp"
+ "gopkg.in/urfave/cli.v1"
+)
+
+var (
+ nodesetCommand = cli.Command{
+ Name: "nodeset",
+ Usage: "Node set tools",
+ Subcommands: []cli.Command{
+ nodesetInfoCommand,
+ nodesetFilterCommand,
+ },
+ }
+ nodesetInfoCommand = cli.Command{
+ Name: "info",
+ Usage: "Shows statistics about a node set",
+ Action: nodesetInfo,
+ ArgsUsage: "",
+ }
+ nodesetFilterCommand = cli.Command{
+ Name: "filter",
+ Usage: "Filters a node set",
+ Action: nodesetFilter,
+ ArgsUsage: " filters..",
+
+ SkipFlagParsing: true,
+ }
+)
+
+func nodesetInfo(ctx *cli.Context) error {
+ if ctx.NArg() < 1 {
+ return fmt.Errorf("need nodes file as argument")
+ }
+
+ ns := loadNodesJSON(ctx.Args().First())
+ fmt.Printf("Set contains %d nodes.\n", len(ns))
+ return nil
+}
+
+func nodesetFilter(ctx *cli.Context) error {
+ if ctx.NArg() < 1 {
+ return fmt.Errorf("need nodes file as argument")
+ }
+ ns := loadNodesJSON(ctx.Args().First())
+ filter, err := andFilter(ctx.Args().Tail())
+ if err != nil {
+ return err
+ }
+
+ result := make(nodeSet)
+ for id, n := range ns {
+ if filter(n) {
+ result[id] = n
+ }
+ }
+ writeNodesJSON("-", result)
+ return nil
+}
+
+type nodeFilter func(nodeJSON) bool
+
+type nodeFilterC struct {
+ narg int
+ fn func([]string) (nodeFilter, error)
+}
+
+var filterFlags = map[string]nodeFilterC{
+ "-ip": {1, ipFilter},
+ "-min-age": {1, minAgeFilter},
+ "-eth-network": {1, ethFilter},
+ "-les-server": {0, lesFilter},
+}
+
+func parseFilters(args []string) ([]nodeFilter, error) {
+ var filters []nodeFilter
+ for len(args) > 0 {
+ fc, ok := filterFlags[args[0]]
+ if !ok {
+ return nil, fmt.Errorf("invalid filter %q", args[0])
+ }
+ if len(args) < fc.narg {
+ return nil, fmt.Errorf("filter %q wants %d arguments, have %d", args[0], fc.narg, len(args))
+ }
+ filter, err := fc.fn(args[1:])
+ if err != nil {
+ return nil, fmt.Errorf("%s: %v", args[0], err)
+ }
+ filters = append(filters, filter)
+ args = args[fc.narg+1:]
+ }
+ return filters, nil
+}
+
+func andFilter(args []string) (nodeFilter, error) {
+ checks, err := parseFilters(args)
+ if err != nil {
+ return nil, err
+ }
+ f := func(n nodeJSON) bool {
+ for _, filter := range checks {
+ if !filter(n) {
+ return false
+ }
+ }
+ return true
+ }
+ return f, nil
+}
+
+func ipFilter(args []string) (nodeFilter, error) {
+ _, cidr, err := net.ParseCIDR(args[0])
+ if err != nil {
+ return nil, err
+ }
+ f := func(n nodeJSON) bool { return cidr.Contains(n.N.IP()) }
+ return f, nil
+}
+
+func minAgeFilter(args []string) (nodeFilter, error) {
+ minage, err := time.ParseDuration(args[0])
+ if err != nil {
+ return nil, err
+ }
+ f := func(n nodeJSON) bool {
+ age := n.LastResponse.Sub(n.FirstResponse)
+ return age >= minage
+ }
+ return f, nil
+}
+
+func ethFilter(args []string) (nodeFilter, error) {
+ var filter func(forkid.ID) error
+ switch args[0] {
+ case "mainnet":
+ filter = forkid.NewStaticFilter(params.MainnetChainConfig, params.MainnetGenesisHash)
+ case "rinkeby":
+ filter = forkid.NewStaticFilter(params.RinkebyChainConfig, params.RinkebyGenesisHash)
+ case "goerli":
+ filter = forkid.NewStaticFilter(params.GoerliChainConfig, params.GoerliGenesisHash)
+ case "ropsten":
+ filter = forkid.NewStaticFilter(params.TestnetChainConfig, params.TestnetGenesisHash)
+ default:
+ return nil, fmt.Errorf("unknown network %q", args[0])
+ }
+
+ f := func(n nodeJSON) bool {
+ var eth struct {
+ ForkID forkid.ID
+ _ []rlp.RawValue `rlp:"tail"`
+ }
+ if n.N.Load(enr.WithEntry("eth", ð)) != nil {
+ return false
+ }
+ return filter(eth.ForkID) == nil
+ }
+ return f, nil
+}
+
+func lesFilter(args []string) (nodeFilter, error) {
+ f := func(n nodeJSON) bool {
+ var les struct {
+ _ []rlp.RawValue `rlp:"tail"`
+ }
+ return n.N.Load(enr.WithEntry("les", &les)) == nil
+ }
+ return f, nil
+}
diff --git a/core/forkid/forkid.go b/core/forkid/forkid.go
index 8c1700879a5e..3845742ad9bb 100644
--- a/core/forkid/forkid.go
+++ b/core/forkid/forkid.go
@@ -80,7 +80,7 @@ func newID(config *params.ChainConfig, genesis common.Hash, head uint64) ID {
return ID{Hash: checksumToBytes(hash), Next: next}
}
-// NewFilter creates an filter that returns if a fork ID should be rejected or not
+// NewFilter creates a filter that returns if a fork ID should be rejected or not
// based on the local chain's status.
func NewFilter(chain *core.BlockChain) func(id ID) error {
return newFilter(
@@ -92,6 +92,12 @@ func NewFilter(chain *core.BlockChain) func(id ID) error {
)
}
+// NewStaticFilter creates a filter at block zero.
+func NewStaticFilter(config *params.ChainConfig, genesis common.Hash) func(id ID) error {
+ head := func() uint64 { return 0 }
+ return newFilter(config, genesis, head)
+}
+
// newFilter is the internal version of NewFilter, taking closures as its arguments
// instead of a chain. The reason is to allow testing it without having to simulate
// an entire blockchain.
diff --git a/p2p/dial.go b/p2p/dial.go
index 8dee5063f1d5..68e06cce5874 100644
--- a/p2p/dial.go
+++ b/p2p/dial.go
@@ -33,12 +33,7 @@ const (
// private networks.
dialHistoryExpiration = inboundThrottleTime + 5*time.Second
- // Discovery lookups are throttled and can only run
- // once every few seconds.
- lookupInterval = 4 * time.Second
-
- // If no peers are found for this amount of time, the initial bootnodes are
- // attempted to be connected.
+ // If no peers are found for this amount of time, the initial bootnodes are dialed.
fallbackInterval = 20 * time.Second
// Endpoint resolution is throttled with bounded backoff.
@@ -52,6 +47,10 @@ type NodeDialer interface {
Dial(*enode.Node) (net.Conn, error)
}
+type nodeResolver interface {
+ Resolve(*enode.Node) *enode.Node
+}
+
// TCPDialer implements the NodeDialer interface by using a net.Dialer to
// create TCP connections to nodes in the network
type TCPDialer struct {
@@ -69,7 +68,6 @@ func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) {
// of the main loop in Server.run.
type dialstate struct {
maxDynDials int
- ntab discoverTable
netrestrict *netutil.Netlist
self enode.ID
bootnodes []*enode.Node // default dials when there are no peers
@@ -79,55 +77,23 @@ type dialstate struct {
lookupRunning bool
dialing map[enode.ID]connFlag
lookupBuf []*enode.Node // current discovery lookup results
- randomNodes []*enode.Node // filled from Table
static map[enode.ID]*dialTask
hist expHeap
}
-type discoverTable interface {
- Close()
- Resolve(*enode.Node) *enode.Node
- LookupRandom() []*enode.Node
- ReadRandomNodes([]*enode.Node) int
-}
-
type task interface {
Do(*Server)
}
-// A dialTask is generated for each node that is dialed. Its
-// fields cannot be accessed while the task is running.
-type dialTask struct {
- flags connFlag
- dest *enode.Node
- lastResolved time.Time
- resolveDelay time.Duration
-}
-
-// discoverTask runs discovery table operations.
-// Only one discoverTask is active at any time.
-// discoverTask.Do performs a random lookup.
-type discoverTask struct {
- results []*enode.Node
-}
-
-// A waitExpireTask is generated if there are no other tasks
-// to keep the loop in Server.run ticking.
-type waitExpireTask struct {
- time.Duration
-}
-
-func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate {
+func newDialState(self enode.ID, maxdyn int, cfg *Config) *dialstate {
s := &dialstate{
maxDynDials: maxdyn,
- ntab: ntab,
self: self,
netrestrict: cfg.NetRestrict,
log: cfg.Logger,
static: make(map[enode.ID]*dialTask),
dialing: make(map[enode.ID]connFlag),
bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)),
- randomNodes: make([]*enode.Node, maxdyn/2),
}
copy(s.bootnodes, cfg.BootstrapNodes)
if s.log == nil {
@@ -151,10 +117,6 @@ func (s *dialstate) removeStatic(n *enode.Node) {
}
func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
- if s.start.IsZero() {
- s.start = now
- }
-
var newtasks []task
addDial := func(flag connFlag, n *enode.Node) bool {
if err := s.checkDial(n, peers); err != nil {
@@ -166,20 +128,9 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
return true
}
- // Compute number of dynamic dials necessary at this point.
- needDynDials := s.maxDynDials
- for _, p := range peers {
- if p.rw.is(dynDialedConn) {
- needDynDials--
- }
- }
- for _, flag := range s.dialing {
- if flag&dynDialedConn != 0 {
- needDynDials--
- }
+ if s.start.IsZero() {
+ s.start = now
}
-
- // Expire the dial history on every invocation.
s.hist.expire(now)
// Create dials for static nodes if they are not connected.
@@ -194,6 +145,20 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
newtasks = append(newtasks, t)
}
}
+
+ // Compute number of dynamic dials needed.
+ needDynDials := s.maxDynDials
+ for _, p := range peers {
+ if p.rw.is(dynDialedConn) {
+ needDynDials--
+ }
+ }
+ for _, flag := range s.dialing {
+ if flag&dynDialedConn != 0 {
+ needDynDials--
+ }
+ }
+
// If we don't have any peers whatsoever, try to dial a random bootnode. This
// scenario is useful for the testnet (and private networks) where the discovery
// table might be full of mostly bad peers, making it hard to find good ones.
@@ -201,24 +166,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
bootnode := s.bootnodes[0]
s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...)
s.bootnodes = append(s.bootnodes, bootnode)
-
if addDial(dynDialedConn, bootnode) {
needDynDials--
}
}
- // Use random nodes from the table for half of the necessary
- // dynamic dials.
- randomCandidates := needDynDials / 2
- if randomCandidates > 0 {
- n := s.ntab.ReadRandomNodes(s.randomNodes)
- for i := 0; i < randomCandidates && i < n; i++ {
- if addDial(dynDialedConn, s.randomNodes[i]) {
- needDynDials--
- }
- }
- }
- // Create dynamic dials from random lookup results, removing tried
- // items from the result buffer.
+
+ // Create dynamic dials from discovery results.
i := 0
for ; i < len(s.lookupBuf) && needDynDials > 0; i++ {
if addDial(dynDialedConn, s.lookupBuf[i]) {
@@ -226,10 +179,11 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
}
}
s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])]
+
// Launch a discovery lookup if more candidates are needed.
if len(s.lookupBuf) < needDynDials && !s.lookupRunning {
s.lookupRunning = true
- newtasks = append(newtasks, &discoverTask{})
+ newtasks = append(newtasks, &discoverTask{want: needDynDials - len(s.lookupBuf)})
}
// Launch a timer to wait for the next node to expire if all
@@ -279,6 +233,15 @@ func (s *dialstate) taskDone(t task, now time.Time) {
}
}
+// A dialTask is generated for each node that is dialed. Its
+// fields cannot be accessed while the task is running.
+type dialTask struct {
+ flags connFlag
+ dest *enode.Node
+ lastResolved time.Time
+ resolveDelay time.Duration
+}
+
func (t *dialTask) Do(srv *Server) {
if t.dest.Incomplete() {
if !t.resolve(srv) {
@@ -304,8 +267,8 @@ func (t *dialTask) Do(srv *Server) {
// discovery network with useless queries for nodes that don't exist.
// The backoff delay resets when the node is found.
func (t *dialTask) resolve(srv *Server) bool {
- if srv.ntab == nil {
- srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
+ if srv.staticNodeResolver == nil {
+ srv.log.Debug("Can't resolve node", "id", t.dest.ID(), "err", "discovery is disabled")
return false
}
if t.resolveDelay == 0 {
@@ -314,20 +277,20 @@ func (t *dialTask) resolve(srv *Server) bool {
if time.Since(t.lastResolved) < t.resolveDelay {
return false
}
- resolved := srv.ntab.Resolve(t.dest)
+ resolved := srv.staticNodeResolver.Resolve(t.dest)
t.lastResolved = time.Now()
if resolved == nil {
t.resolveDelay *= 2
if t.resolveDelay > maxResolveDelay {
t.resolveDelay = maxResolveDelay
}
- srv.log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
+ srv.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay)
return false
}
// The node was found.
t.resolveDelay = initialResolveDelay
t.dest = resolved
- srv.log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
+ srv.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
return true
}
@@ -350,26 +313,34 @@ func (t *dialTask) String() string {
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP())
}
+// discoverTask runs discovery table operations.
+// Only one discoverTask is active at any time.
+// discoverTask.Do performs a random lookup.
+type discoverTask struct {
+ want int
+ results []*enode.Node
+}
+
func (t *discoverTask) Do(srv *Server) {
- // newTasks generates a lookup task whenever dynamic dials are
- // necessary. Lookups need to take some time, otherwise the
- // event loop spins too fast.
- next := srv.lastLookup.Add(lookupInterval)
- if now := time.Now(); now.Before(next) {
- time.Sleep(next.Sub(now))
- }
- srv.lastLookup = time.Now()
- t.results = srv.ntab.LookupRandom()
+ t.results = enode.ReadNodes(srv.discmix, t.want)
}
func (t *discoverTask) String() string {
- s := "discovery lookup"
+ s := "discovery query"
if len(t.results) > 0 {
s += fmt.Sprintf(" (%d results)", len(t.results))
+ } else {
+ s += fmt.Sprintf(" (want %d)", t.want)
}
return s
}
+// A waitExpireTask is generated if there are no other tasks
+// to keep the loop in Server.run ticking.
+type waitExpireTask struct {
+ time.Duration
+}
+
func (t waitExpireTask) Do(*Server) {
time.Sleep(t.Duration)
}
diff --git a/p2p/dial_test.go b/p2p/dial_test.go
index de8fc4a6e3e6..6189ec4d0b85 100644
--- a/p2p/dial_test.go
+++ b/p2p/dial_test.go
@@ -73,7 +73,7 @@ func runDialTest(t *testing.T, test dialtest) {
t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v",
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
}
- t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new)))
+ t.Logf("round %d (running %d) new tasks: %s", i, running, strings.TrimSpace(spew.Sdump(new)))
// Time advances by 16 seconds on every round.
vtime = vtime.Add(16 * time.Second)
@@ -81,19 +81,11 @@ func runDialTest(t *testing.T, test dialtest) {
}
}
-type fakeTable []*enode.Node
-
-func (t fakeTable) Self() *enode.Node { return new(enode.Node) }
-func (t fakeTable) Close() {}
-func (t fakeTable) LookupRandom() []*enode.Node { return nil }
-func (t fakeTable) Resolve(*enode.Node) *enode.Node { return nil }
-func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t) }
-
// This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) {
config := &Config{Logger: testlog.Logger(t, log.LvlTrace)}
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, fakeTable{}, 5, config),
+ init: newDialState(enode.ID{}, 5, config),
rounds: []round{
// A discovery query is launched.
{
@@ -102,7 +94,9 @@ func TestDialStateDynDial(t *testing.T) {
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
},
- new: []task{&discoverTask{}},
+ new: []task{
+ &discoverTask{want: 3},
+ },
},
// Dynamic dials are launched when it completes.
{
@@ -188,7 +182,7 @@ func TestDialStateDynDial(t *testing.T) {
},
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)},
- &discoverTask{},
+ &discoverTask{want: 2},
},
},
// Peer 7 is connected, but there still aren't enough dynamic peers
@@ -218,7 +212,7 @@ func TestDialStateDynDial(t *testing.T) {
&discoverTask{},
},
new: []task{
- &discoverTask{},
+ &discoverTask{want: 2},
},
},
},
@@ -235,35 +229,37 @@ func TestDialStateDynDialBootnode(t *testing.T) {
},
Logger: testlog.Logger(t, log.LvlTrace),
}
- table := fakeTable{
- newNode(uintID(4), nil),
- newNode(uintID(5), nil),
- newNode(uintID(6), nil),
- newNode(uintID(7), nil),
- newNode(uintID(8), nil),
- }
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, table, 5, config),
+ init: newDialState(enode.ID{}, 5, config),
rounds: []round{
- // 2 dynamic dials attempted, bootnodes pending fallback interval
{
new: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
- &discoverTask{},
+ &discoverTask{want: 5},
},
},
- // No dials succeed, bootnodes still pending fallback interval
{
done: []task{
+ &discoverTask{
+ results: []*enode.Node{
+ newNode(uintID(4), nil),
+ newNode(uintID(5), nil),
+ },
+ },
+ },
+ new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
+ &discoverTask{want: 3},
},
},
// No dials succeed, bootnodes still pending fallback interval
{},
- // No dials succeed, 2 dynamic dials attempted and 1 bootnode too as fallback interval was reached
+ // 1 bootnode attempted as fallback interval was reached
{
+ done: []task{
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
+ },
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
},
@@ -275,15 +271,12 @@ func TestDialStateDynDialBootnode(t *testing.T) {
},
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
// No dials succeed, 3rd bootnode is attempted
{
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
@@ -293,115 +286,19 @@ func TestDialStateDynDialBootnode(t *testing.T) {
{
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
- },
- new: []task{},
- },
- // Random dial succeeds, no more bootnodes are attempted
- {
- new: []task{
- &waitExpireTask{3 * time.Second},
- },
- peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
- },
- done: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- },
- },
- },
- })
-}
-
-func TestDialStateDynDialFromTable(t *testing.T) {
- // This table always returns the same random nodes
- // in the order given below.
- table := fakeTable{
- newNode(uintID(1), nil),
- newNode(uintID(2), nil),
- newNode(uintID(3), nil),
- newNode(uintID(4), nil),
- newNode(uintID(5), nil),
- newNode(uintID(6), nil),
- newNode(uintID(7), nil),
- newNode(uintID(8), nil),
- }
-
- runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}),
- rounds: []round{
- // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
- {
- new: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
- &discoverTask{},
- },
- },
- // Dialing nodes 1,2 succeeds. Dials from the lookup are launched.
- {
- peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
- },
- done: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
&discoverTask{results: []*enode.Node{
- newNode(uintID(10), nil),
- newNode(uintID(11), nil),
- newNode(uintID(12), nil),
+ newNode(uintID(6), nil),
}},
},
new: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)},
- &discoverTask{},
- },
- },
- // Dialing nodes 3,4,5 fails. The dials from the lookup succeed.
- {
- peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
- },
- done: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)},
+ &discoverTask{want: 4},
},
},
- // Waiting for expiry. No waitExpireTask is launched because the
- // discovery query is still running.
+ // Random dial succeeds, no more bootnodes are attempted
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
- },
- },
- // Nodes 3,4 are not tried again because only the first two
- // returned random nodes (nodes 1,2) are tried and they're
- // already connected.
- {
- peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
+ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(6), nil)}},
},
},
},
@@ -416,11 +313,11 @@ func newNode(id enode.ID, ip net.IP) *enode.Node {
return enode.SignNull(&r, id)
}
-// This test checks that candidates that do not match the netrestrict list are not dialed.
+// // This test checks that candidates that do not match the netrestrict list are not dialed.
func TestDialStateNetRestrict(t *testing.T) {
// This table always returns the same random nodes
// in the order given below.
- table := fakeTable{
+ nodes := []*enode.Node{
newNode(uintID(1), net.ParseIP("127.0.0.1")),
newNode(uintID(2), net.ParseIP("127.0.0.2")),
newNode(uintID(3), net.ParseIP("127.0.0.3")),
@@ -434,12 +331,23 @@ func TestDialStateNetRestrict(t *testing.T) {
restrict.Add("127.0.2.0/24")
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}),
+ init: newDialState(enode.ID{}, 10, &Config{NetRestrict: restrict}),
rounds: []round{
{
new: []task{
- &dialTask{flags: dynDialedConn, dest: table[4]},
- &discoverTask{},
+ &discoverTask{want: 10},
+ },
+ },
+ {
+ done: []task{
+ &discoverTask{results: nodes},
+ },
+ new: []task{
+ &dialTask{flags: dynDialedConn, dest: nodes[4]},
+ &dialTask{flags: dynDialedConn, dest: nodes[5]},
+ &dialTask{flags: dynDialedConn, dest: nodes[6]},
+ &dialTask{flags: dynDialedConn, dest: nodes[7]},
+ &discoverTask{want: 6},
},
},
},
@@ -459,7 +367,7 @@ func TestDialStateStaticDial(t *testing.T) {
Logger: testlog.Logger(t, log.LvlTrace),
}
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, fakeTable{}, 0, config),
+ init: newDialState(enode.ID{}, 0, config),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
@@ -544,7 +452,7 @@ func TestDialStateCache(t *testing.T) {
Logger: testlog.Logger(t, log.LvlTrace),
}
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, fakeTable{}, 0, config),
+ init: newDialState(enode.ID{}, 0, config),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
@@ -618,8 +526,8 @@ func TestDialResolve(t *testing.T) {
Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}},
}
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
- table := &resolveMock{answer: resolved}
- state := newDialState(enode.ID{}, table, 0, config)
+ resolver := &resolveMock{answer: resolved}
+ state := newDialState(enode.ID{}, 0, config)
// Check that the task is generated with an incomplete ID.
dest := newNode(uintID(1), nil)
@@ -630,10 +538,14 @@ func TestDialResolve(t *testing.T) {
}
// Now run the task, it should resolve the ID once.
- srv := &Server{ntab: table, log: config.Logger, Config: *config}
+ srv := &Server{
+ Config: *config,
+ log: config.Logger,
+ staticNodeResolver: resolver,
+ }
tasks[0].Do(srv)
- if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) {
- t.Fatalf("wrong resolve calls, got %v", table.resolveCalls)
+ if !reflect.DeepEqual(resolver.calls, []*enode.Node{dest}) {
+ t.Fatalf("wrong resolve calls, got %v", resolver.calls)
}
// Report it as done to the dialer, which should update the static node record.
@@ -666,18 +578,13 @@ func uintID(i uint32) enode.ID {
return id
}
-// implements discoverTable for TestDialResolve
+// for TestDialResolve
type resolveMock struct {
- resolveCalls []*enode.Node
- answer *enode.Node
+ calls []*enode.Node
+ answer *enode.Node
}
func (t *resolveMock) Resolve(n *enode.Node) *enode.Node {
- t.resolveCalls = append(t.resolveCalls, n)
+ t.calls = append(t.calls, n)
return t.answer
}
-
-func (t *resolveMock) Self() *enode.Node { return new(enode.Node) }
-func (t *resolveMock) Close() {}
-func (t *resolveMock) LookupRandom() []*enode.Node { return nil }
-func (t *resolveMock) ReadRandomNodes(buf []*enode.Node) int { return 0 }
diff --git a/p2p/discover/common.go b/p2p/discover/common.go
index 3c080359fdf8..cef6a9fc4f9f 100644
--- a/p2p/discover/common.go
+++ b/p2p/discover/common.go
@@ -25,6 +25,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/netutil"
)
+// UDPConn is a network connection on which discovery can operate.
type UDPConn interface {
ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
@@ -32,7 +33,7 @@ type UDPConn interface {
LocalAddr() net.Addr
}
-// Config holds Table-related settings.
+// Config holds settings for the discovery listener.
type Config struct {
// These settings are required and configure the UDP listener:
PrivateKey *ecdsa.PrivateKey
@@ -50,7 +51,7 @@ func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
}
// ReadPacket is a packet that couldn't be handled. Those packets are sent to the unhandled
-// channel if configured.
+// channel if configured. This is exported for internal use, do not use this type.
type ReadPacket struct {
Data []byte
Addr *net.UDPAddr
diff --git a/p2p/discover/lookup.go b/p2p/discover/lookup.go
new file mode 100644
index 000000000000..f988e0683808
--- /dev/null
+++ b/p2p/discover/lookup.go
@@ -0,0 +1,209 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package discover
+
+import (
+ "context"
+
+ "github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+// lookup performs a network search for nodes close to the given target. It approaches the
+// target by querying nodes that are closer to it on each iteration. The given target does
+// not need to be an actual node identifier.
+type lookup struct {
+ tab *Table
+ queryfunc func(*node) ([]*node, error)
+ replyCh chan []*node
+ cancelCh <-chan struct{}
+ asked, seen map[enode.ID]bool
+ result nodesByDistance
+ replyBuffer []*node
+ queries int
+}
+
+type queryFunc func(*node) ([]*node, error)
+
+func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *lookup {
+ it := &lookup{
+ tab: tab,
+ queryfunc: q,
+ asked: make(map[enode.ID]bool),
+ seen: make(map[enode.ID]bool),
+ result: nodesByDistance{target: target},
+ replyCh: make(chan []*node, alpha),
+ cancelCh: ctx.Done(),
+ queries: -1,
+ }
+ // Don't query further if we hit ourself.
+ // Unlikely to happen often in practice.
+ it.asked[tab.self().ID()] = true
+ return it
+}
+
+// run runs the lookup to completion and returns the closest nodes found.
+func (it *lookup) run() []*enode.Node {
+ for it.advance() {
+ }
+ return unwrapNodes(it.result.entries)
+}
+
+// advance advances the lookup until any new nodes have been found.
+// It returns false when the lookup has ended.
+func (it *lookup) advance() bool {
+ for it.startQueries() {
+ select {
+ case nodes := <-it.replyCh:
+ it.replyBuffer = it.replyBuffer[:0]
+ for _, n := range nodes {
+ if n != nil && !it.seen[n.ID()] {
+ it.seen[n.ID()] = true
+ it.result.push(n, bucketSize)
+ it.replyBuffer = append(it.replyBuffer, n)
+ }
+ }
+ it.queries--
+ if len(it.replyBuffer) > 0 {
+ return true
+ }
+ case <-it.cancelCh:
+ it.shutdown()
+ }
+ }
+ return false
+}
+
+func (it *lookup) shutdown() {
+ for it.queries > 0 {
+ <-it.replyCh
+ it.queries--
+ }
+ it.queryfunc = nil
+ it.replyBuffer = nil
+}
+
+func (it *lookup) startQueries() bool {
+ if it.queryfunc == nil {
+ return false
+ }
+
+ // The first query returns nodes from the local table.
+ if it.queries == -1 {
+ it.tab.mutex.Lock()
+ closest := it.tab.closest(it.result.target, bucketSize, false)
+ it.tab.mutex.Unlock()
+ it.queries = 1
+ it.replyCh <- closest.entries
+ return true
+ }
+
+ // Ask the closest nodes that we haven't asked yet.
+ for i := 0; i < len(it.result.entries) && it.queries < alpha; i++ {
+ n := it.result.entries[i]
+ if !it.asked[n.ID()] {
+ it.asked[n.ID()] = true
+ it.queries++
+ go it.query(n, it.replyCh)
+ }
+ }
+ // The lookup ends when no more nodes can be asked.
+ return it.queries > 0
+}
+
+func (it *lookup) query(n *node, reply chan<- []*node) {
+ fails := it.tab.db.FindFails(n.ID(), n.IP())
+ r, err := it.queryfunc(n)
+ if err == errClosed {
+ // Avoid recording failures on shutdown.
+ reply <- nil
+ return
+ } else if len(r) == 0 {
+ fails++
+ it.tab.db.UpdateFindFails(n.ID(), n.IP(), fails)
+ it.tab.log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err)
+ if fails >= maxFindnodeFailures {
+ it.tab.log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails)
+ it.tab.delete(n)
+ }
+ } else if fails > 0 {
+ // Reset failure counter because it counts _consecutive_ failures.
+ it.tab.db.UpdateFindFails(n.ID(), n.IP(), 0)
+ }
+
+ // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll
+ // just remove those again during revalidation.
+ for _, n := range r {
+ it.tab.addSeenNode(n)
+ }
+ reply <- r
+}
+
+// lookupIterator performs lookup operations and iterates over all seen nodes.
+// When a lookup finishes, a new one is created through nextLookup.
+type lookupIterator struct {
+ buffer []*node
+ nextLookup lookupFunc
+ ctx context.Context
+ cancel func()
+ lookup *lookup
+}
+
+type lookupFunc func(ctx context.Context) *lookup
+
+func newLookupIterator(ctx context.Context, next lookupFunc) *lookupIterator {
+ ctx, cancel := context.WithCancel(ctx)
+ return &lookupIterator{ctx: ctx, cancel: cancel, nextLookup: next}
+}
+
+// Node returns the current node.
+func (it *lookupIterator) Node() *enode.Node {
+ if len(it.buffer) == 0 {
+ return nil
+ }
+ return unwrapNode(it.buffer[0])
+}
+
+// Next moves to the next node.
+func (it *lookupIterator) Next() bool {
+ // Consume next node in buffer.
+ if len(it.buffer) > 0 {
+ it.buffer = it.buffer[1:]
+ }
+ // Advance the lookup to refill the buffer.
+ for len(it.buffer) == 0 {
+ if it.ctx.Err() != nil {
+ it.lookup = nil
+ it.buffer = nil
+ return false
+ }
+ if it.lookup == nil {
+ it.lookup = it.nextLookup(it.ctx)
+ continue
+ }
+ if !it.lookup.advance() {
+ it.lookup = nil
+ continue
+ }
+ it.buffer = it.lookup.replyBuffer
+ }
+ return true
+}
+
+// Close ends the iterator.
+func (it *lookupIterator) Close() {
+ it.cancel()
+}
diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go
index 2292055e160d..e35e48c5e69d 100644
--- a/p2p/discover/table_util_test.go
+++ b/p2p/discover/table_util_test.go
@@ -17,11 +17,14 @@
package discover
import (
+ "bytes"
"crypto/ecdsa"
"encoding/hex"
+ "errors"
"fmt"
"math/rand"
"net"
+ "reflect"
"sort"
"sync"
@@ -169,6 +172,28 @@ func hasDuplicates(slice []*node) bool {
return false
}
+func checkNodesEqual(got, want []*enode.Node) error {
+ if reflect.DeepEqual(got, want) {
+ return nil
+ }
+ output := new(bytes.Buffer)
+ fmt.Fprintf(output, "got %d nodes:\n", len(got))
+ for _, n := range got {
+ fmt.Fprintf(output, " %v %v\n", n.ID(), n)
+ }
+ fmt.Fprintf(output, "want %d:\n", len(want))
+ for _, n := range want {
+ fmt.Fprintf(output, " %v %v\n", n.ID(), n)
+ }
+ return errors.New(output.String())
+}
+
+func sortByID(nodes []*enode.Node) {
+ sort.Slice(nodes, func(i, j int) bool {
+ return string(nodes[i].ID().Bytes()) < string(nodes[j].ID().Bytes())
+ })
+}
+
func sortedByDistanceTo(distbase enode.ID, slice []*node) bool {
return sort.SliceIsSorted(slice, func(i, j int) bool {
return enode.DistCmp(distbase, slice[i].ID(), slice[j].ID()) < 0
diff --git a/p2p/discover/v4_udp_lookup_test.go b/p2p/discover/v4_lookup_test.go
similarity index 75%
rename from p2p/discover/v4_udp_lookup_test.go
rename to p2p/discover/v4_lookup_test.go
index bc1cdfb089ab..9b4042c5a276 100644
--- a/p2p/discover/v4_udp_lookup_test.go
+++ b/p2p/discover/v4_lookup_test.go
@@ -20,7 +20,6 @@ import (
"crypto/ecdsa"
"fmt"
"net"
- "reflect"
"sort"
"testing"
@@ -49,19 +48,7 @@ func TestUDPv4_Lookup(t *testing.T) {
}()
// Answer lookup packets.
- for done := false; !done; {
- done = test.waitPacketOut(func(p packetV4, to *net.UDPAddr, hash []byte) {
- n, key := lookupTestnet.nodeByAddr(to)
- switch p.(type) {
- case *pingV4:
- test.packetInFrom(nil, key, to, &pongV4{Expiration: futureExp, ReplyTok: hash})
- case *findnodeV4:
- dist := enode.LogDist(n.ID(), lookupTestnet.target.id())
- nodes := lookupTestnet.nodesAtDistance(dist - 1)
- test.packetInFrom(nil, key, to, &neighborsV4{Expiration: futureExp, Nodes: nodes})
- }
- })
- }
+ serveTestnet(test, lookupTestnet)
// Verify result nodes.
results := <-resultC
@@ -78,8 +65,94 @@ func TestUDPv4_Lookup(t *testing.T) {
if !sortedByDistanceTo(lookupTestnet.target.id(), wrapNodes(results)) {
t.Errorf("result set not sorted by distance to target")
}
- if !reflect.DeepEqual(results, lookupTestnet.closest(bucketSize)) {
- t.Errorf("results aren't the closest %d nodes", bucketSize)
+ if err := checkNodesEqual(results, lookupTestnet.closest(bucketSize)); err != nil {
+ t.Errorf("results aren't the closest %d nodes\n%v", bucketSize, err)
+ }
+}
+
+func TestUDPv4_LookupIterator(t *testing.T) {
+ t.Parallel()
+ test := newUDPTest(t)
+ defer test.close()
+
+ // Seed table with initial nodes.
+ bootnodes := make([]*node, len(lookupTestnet.dists[256]))
+ for i := range lookupTestnet.dists[256] {
+ bootnodes[i] = wrapNode(lookupTestnet.node(256, i))
+ }
+ fillTable(test.table, bootnodes)
+ go serveTestnet(test, lookupTestnet)
+
+ // Create the iterator and collect the nodes it yields.
+ iter := test.udp.RandomNodes()
+ seen := make(map[enode.ID]*enode.Node)
+ for limit := lookupTestnet.len(); iter.Next() && len(seen) < limit; {
+ seen[iter.Node().ID()] = iter.Node()
+ }
+ iter.Close()
+
+ // Check that all nodes in lookupTestnet were seen by the iterator.
+ results := make([]*enode.Node, 0, len(seen))
+ for _, n := range seen {
+ results = append(results, n)
+ }
+ sortByID(results)
+ want := lookupTestnet.nodes()
+ if err := checkNodesEqual(results, want); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// TestUDPv4_LookupIteratorClose checks that lookupIterator ends when its Close
+// method is called.
+func TestUDPv4_LookupIteratorClose(t *testing.T) {
+ t.Parallel()
+ test := newUDPTest(t)
+ defer test.close()
+
+ // Seed table with initial nodes.
+ bootnodes := make([]*node, len(lookupTestnet.dists[256]))
+ for i := range lookupTestnet.dists[256] {
+ bootnodes[i] = wrapNode(lookupTestnet.node(256, i))
+ }
+ fillTable(test.table, bootnodes)
+ go serveTestnet(test, lookupTestnet)
+
+ it := test.udp.RandomNodes()
+ if ok := it.Next(); !ok || it.Node() == nil {
+ t.Fatalf("iterator didn't return any node")
+ }
+
+ it.Close()
+
+ ncalls := 0
+ for ; ncalls < 100 && it.Next(); ncalls++ {
+ if it.Node() == nil {
+ t.Error("iterator returned Node() == nil node after Next() == true")
+ }
+ }
+ t.Logf("iterator returned %d nodes after close", ncalls)
+ if it.Next() {
+ t.Errorf("Next() == true after close and %d more calls", ncalls)
+ }
+ if n := it.Node(); n != nil {
+ t.Errorf("iterator returned non-nil node after close and %d more calls", ncalls)
+ }
+}
+
+func serveTestnet(test *udpTest, testnet *preminedTestnet) {
+ for done := false; !done; {
+ done = test.waitPacketOut(func(p packetV4, to *net.UDPAddr, hash []byte) {
+ n, key := testnet.nodeByAddr(to)
+ switch p.(type) {
+ case *pingV4:
+ test.packetInFrom(nil, key, to, &pongV4{Expiration: futureExp, ReplyTok: hash})
+ case *findnodeV4:
+ dist := enode.LogDist(n.ID(), testnet.target.id())
+ nodes := testnet.nodesAtDistance(dist - 1)
+ test.packetInFrom(nil, key, to, &neighborsV4{Expiration: futureExp, Nodes: nodes})
+ }
+ })
}
}
@@ -148,6 +221,25 @@ type preminedTestnet struct {
dists [hashBits + 1][]*ecdsa.PrivateKey
}
+func (tn *preminedTestnet) len() int {
+ n := 0
+ for _, keys := range tn.dists {
+ n += len(keys)
+ }
+ return n
+}
+
+func (tn *preminedTestnet) nodes() []*enode.Node {
+ result := make([]*enode.Node, 0, tn.len())
+ for dist, keys := range tn.dists {
+ for index := range keys {
+ result = append(result, tn.node(dist, index))
+ }
+ }
+ sortByID(result)
+ return result
+}
+
func (tn *preminedTestnet) node(dist, index int) *enode.Node {
key := tn.dists[dist][index]
ip := net.IP{127, byte(dist >> 8), byte(dist), byte(index)}
diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go
index a8f7101b0594..bfb66fcb1967 100644
--- a/p2p/discover/v4_udp.go
+++ b/p2p/discover/v4_udp.go
@@ -19,6 +19,7 @@ package discover
import (
"bytes"
"container/list"
+ "context"
"crypto/ecdsa"
crand "crypto/rand"
"errors"
@@ -207,7 +208,8 @@ type UDPv4 struct {
addReplyMatcher chan *replyMatcher
gotreply chan reply
- closing chan struct{}
+ closeCtx context.Context
+ cancelCloseCtx func()
}
// replyMatcher represents a pending reply.
@@ -256,20 +258,23 @@ type reply struct {
}
func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
+ closeCtx, cancel := context.WithCancel(context.Background())
t := &UDPv4{
conn: c,
priv: cfg.PrivateKey,
netrestrict: cfg.NetRestrict,
localNode: ln,
db: ln.Database(),
- closing: make(chan struct{}),
gotreply: make(chan reply),
addReplyMatcher: make(chan *replyMatcher),
+ closeCtx: closeCtx,
+ cancelCloseCtx: cancel,
log: cfg.Log,
}
if t.log == nil {
t.log = log.Root()
}
+
tab, err := newTable(t, ln.Database(), cfg.Bootnodes, t.log)
if err != nil {
return nil, err
@@ -291,126 +296,13 @@ func (t *UDPv4) Self() *enode.Node {
// Close shuts down the socket and aborts any running queries.
func (t *UDPv4) Close() {
t.closeOnce.Do(func() {
- close(t.closing)
+ t.cancelCloseCtx()
t.conn.Close()
t.wg.Wait()
t.tab.close()
})
}
-// ReadRandomNodes reads random nodes from the local table.
-func (t *UDPv4) ReadRandomNodes(buf []*enode.Node) int {
- return t.tab.ReadRandomNodes(buf)
-}
-
-// LookupRandom finds random nodes in the network.
-func (t *UDPv4) LookupRandom() []*enode.Node {
- if t.tab.len() == 0 {
- // All nodes were dropped, refresh. The very first query will hit this
- // case and run the bootstrapping logic.
- <-t.tab.refresh()
- }
- return t.lookupRandom()
-}
-
-func (t *UDPv4) LookupPubkey(key *ecdsa.PublicKey) []*enode.Node {
- if t.tab.len() == 0 {
- // All nodes were dropped, refresh. The very first query will hit this
- // case and run the bootstrapping logic.
- <-t.tab.refresh()
- }
- return unwrapNodes(t.lookup(encodePubkey(key)))
-}
-
-func (t *UDPv4) lookupRandom() []*enode.Node {
- var target encPubkey
- crand.Read(target[:])
- return unwrapNodes(t.lookup(target))
-}
-
-func (t *UDPv4) lookupSelf() []*enode.Node {
- return unwrapNodes(t.lookup(encodePubkey(&t.priv.PublicKey)))
-}
-
-// lookup performs a network search for nodes close to the given target. It approaches the
-// target by querying nodes that are closer to it on each iteration. The given target does
-// not need to be an actual node identifier.
-func (t *UDPv4) lookup(targetKey encPubkey) []*node {
- var (
- target = enode.ID(crypto.Keccak256Hash(targetKey[:]))
- asked = make(map[enode.ID]bool)
- seen = make(map[enode.ID]bool)
- reply = make(chan []*node, alpha)
- pendingQueries = 0
- result *nodesByDistance
- )
- // Don't query further if we hit ourself.
- // Unlikely to happen often in practice.
- asked[t.Self().ID()] = true
-
- // Generate the initial result set.
- t.tab.mutex.Lock()
- result = t.tab.closest(target, bucketSize, false)
- t.tab.mutex.Unlock()
-
- for {
- // ask the alpha closest nodes that we haven't asked yet
- for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
- n := result.entries[i]
- if !asked[n.ID()] {
- asked[n.ID()] = true
- pendingQueries++
- go t.lookupWorker(n, targetKey, reply)
- }
- }
- if pendingQueries == 0 {
- // we have asked all closest nodes, stop the search
- break
- }
- select {
- case nodes := <-reply:
- for _, n := range nodes {
- if n != nil && !seen[n.ID()] {
- seen[n.ID()] = true
- result.push(n, bucketSize)
- }
- }
- case <-t.tab.closeReq:
- return nil // shutdown, no need to continue.
- }
- pendingQueries--
- }
- return result.entries
-}
-
-func (t *UDPv4) lookupWorker(n *node, targetKey encPubkey, reply chan<- []*node) {
- fails := t.db.FindFails(n.ID(), n.IP())
- r, err := t.findnode(n.ID(), n.addr(), targetKey)
- if err == errClosed {
- // Avoid recording failures on shutdown.
- reply <- nil
- return
- } else if len(r) == 0 {
- fails++
- t.db.UpdateFindFails(n.ID(), n.IP(), fails)
- t.log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err)
- if fails >= maxFindnodeFailures {
- t.log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails)
- t.tab.delete(n)
- }
- } else if fails > 0 {
- // Reset failure counter because it counts _consecutive_ failures.
- t.db.UpdateFindFails(n.ID(), n.IP(), 0)
- }
-
- // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll
- // just remove those again during revalidation.
- for _, n := range r {
- t.tab.addSeenNode(n)
- }
- reply <- r
-}
-
// Resolve searches for a specific node with the given ID and tries to get the most recent
// version of the node record for it. It returns n if the node could not be resolved.
func (t *UDPv4) Resolve(n *enode.Node) *enode.Node {
@@ -498,6 +390,45 @@ func (t *UDPv4) makePing(toaddr *net.UDPAddr) *pingV4 {
}
}
+// LookupPubkey finds the closest nodes to the given public key.
+func (t *UDPv4) LookupPubkey(key *ecdsa.PublicKey) []*enode.Node {
+ if t.tab.len() == 0 {
+ // All nodes were dropped, refresh. The very first query will hit this
+ // case and run the bootstrapping logic.
+ <-t.tab.refresh()
+ }
+ return t.newLookup(t.closeCtx, encodePubkey(key)).run()
+}
+
+// RandomNodes is an iterator yielding nodes from a random walk of the DHT.
+func (t *UDPv4) RandomNodes() enode.Iterator {
+ return newLookupIterator(t.closeCtx, t.newRandomLookup)
+}
+
+// lookupRandom implements transport.
+func (t *UDPv4) lookupRandom() []*enode.Node {
+ return t.newRandomLookup(t.closeCtx).run()
+}
+
+// lookupSelf implements transport.
+func (t *UDPv4) lookupSelf() []*enode.Node {
+ return t.newLookup(t.closeCtx, encodePubkey(&t.priv.PublicKey)).run()
+}
+
+func (t *UDPv4) newRandomLookup(ctx context.Context) *lookup {
+ var target encPubkey
+ crand.Read(target[:])
+ return t.newLookup(ctx, target)
+}
+
+func (t *UDPv4) newLookup(ctx context.Context, targetKey encPubkey) *lookup {
+ target := enode.ID(crypto.Keccak256Hash(targetKey[:]))
+ it := newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) {
+ return t.findnode(n.ID(), n.addr(), targetKey)
+ })
+ return it
+}
+
// findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors.
func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
@@ -575,7 +506,7 @@ func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchF
select {
case t.addReplyMatcher <- p:
// loop will handle it
- case <-t.closing:
+ case <-t.closeCtx.Done():
ch <- errClosed
}
return p
@@ -589,7 +520,7 @@ func (t *UDPv4) handleReply(from enode.ID, fromIP net.IP, req packetV4) bool {
case t.gotreply <- reply{from, fromIP, req, matched}:
// loop will handle it
return <-matched
- case <-t.closing:
+ case <-t.closeCtx.Done():
return false
}
}
@@ -635,7 +566,7 @@ func (t *UDPv4) loop() {
resetTimeout()
select {
- case <-t.closing:
+ case <-t.closeCtx.Done():
for el := plist.Front(); el != nil; el = el.Next() {
el.Value.(*replyMatcher).errc <- errClosed
}
diff --git a/p2p/enode/iter.go b/p2p/enode/iter.go
new file mode 100644
index 000000000000..112b76d06a0e
--- /dev/null
+++ b/p2p/enode/iter.go
@@ -0,0 +1,286 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package enode
+
+import (
+ "sync"
+ "time"
+)
+
+// Iterator represents a sequence of nodes. The Next method moves to the next node in the
+// sequence. It returns false when the sequence has ended or the iterator is closed. Close
+// may be called concurrently with Next and Node, and interrupts Next if it is blocked.
+type Iterator interface {
+ Next() bool // moves to next node
+ Node() *Node // returns current node
+ Close() // ends the iterator
+}
+
+// ReadNodes reads at most n nodes from the given iterator. The return value contains no
+// duplicates and no nil values. To prevent looping indefinitely for small repeating node
+// sequences, this function calls Next at most n times.
+func ReadNodes(it Iterator, n int) []*Node {
+ seen := make(map[ID]*Node, n)
+ for i := 0; i < n && it.Next(); i++ {
+ // Remove duplicates, keeping the node with higher seq.
+ node := it.Node()
+ prevNode, ok := seen[node.ID()]
+ if ok && prevNode.Seq() > node.Seq() {
+ continue
+ }
+ seen[node.ID()] = node
+ }
+ result := make([]*Node, 0, len(seen))
+ for _, node := range seen {
+ result = append(result, node)
+ }
+ return result
+}
+
+// IterNodes makes an iterator which runs through the given nodes once.
+func IterNodes(nodes []*Node) Iterator {
+ return &sliceIter{nodes: nodes, index: -1}
+}
+
+// CycleNodes makes an iterator which cycles through the given nodes indefinitely.
+func CycleNodes(nodes []*Node) Iterator {
+ return &sliceIter{nodes: nodes, index: -1, cycle: true}
+}
+
+type sliceIter struct {
+ mu sync.Mutex
+ nodes []*Node
+ index int
+ cycle bool
+}
+
+func (it *sliceIter) Next() bool {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+
+ if len(it.nodes) == 0 {
+ return false
+ }
+ it.index++
+ if it.index == len(it.nodes) {
+ if it.cycle {
+ it.index = 0
+ } else {
+ it.nodes = nil
+ return false
+ }
+ }
+ return true
+}
+
+func (it *sliceIter) Node() *Node {
+ if len(it.nodes) == 0 {
+ return nil
+ }
+ return it.nodes[it.index]
+}
+
+func (it *sliceIter) Close() {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+
+ it.nodes = nil
+}
+
+// Filter wraps an iterator such that Next only returns nodes for which
+// the 'check' function returns true.
+func Filter(it Iterator, check func(*Node) bool) Iterator {
+ return &filterIter{it, check}
+}
+
+type filterIter struct {
+ Iterator
+ check func(*Node) bool
+}
+
+func (f *filterIter) Next() bool {
+ for f.Iterator.Next() {
+ if f.check(f.Node()) {
+ return true
+ }
+ }
+ return false
+}
+
+// FairMix aggregates multiple node iterators. The mixer itself is an iterator which ends
+// only when Close is called. Source iterators added via AddSource are removed from the
+// mix when they end.
+//
+// The distribution of nodes returned by Next is approximately fair, i.e. FairMix
+// attempts to draw from all sources equally often. However, if a certain source is slow
+// and doesn't return a node within the configured timeout, a node from any other source
+// will be returned.
+//
+// It's safe to call AddSource and Close concurrently with Next.
+type FairMix struct {
+ wg sync.WaitGroup
+ fromAny chan *Node
+ timeout time.Duration
+ cur *Node
+
+ mu sync.Mutex
+ closed chan struct{}
+ sources []*mixSource
+ last int
+}
+
+type mixSource struct {
+ it Iterator
+ next chan *Node
+ timeout time.Duration
+}
+
+// NewFairMix creates a mixer.
+//
+// The timeout specifies how long the mixer will wait for the next fairly-chosen source
+// before giving up and taking a node from any other source. A good way to set the timeout
+// is deciding how long you'd want to wait for a node on average. Passing a negative
+// timeout makes the mixer completely fair.
+func NewFairMix(timeout time.Duration) *FairMix {
+ m := &FairMix{
+ fromAny: make(chan *Node),
+ closed: make(chan struct{}),
+ timeout: timeout,
+ }
+ return m
+}
+
+// AddSource adds a source of nodes.
+func (m *FairMix) AddSource(it Iterator) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.closed == nil {
+ return
+ }
+ m.wg.Add(1)
+ source := &mixSource{it, make(chan *Node), m.timeout}
+ m.sources = append(m.sources, source)
+ go m.runSource(m.closed, source)
+}
+
+// Close shuts down the mixer and all current sources.
+// Calling this is required to release resources associated with the mixer.
+func (m *FairMix) Close() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.closed == nil {
+ return
+ }
+ for _, s := range m.sources {
+ s.it.Close()
+ }
+ close(m.closed)
+ m.wg.Wait()
+ close(m.fromAny)
+ m.sources = nil
+ m.closed = nil
+}
+
+// Next returns a node from a random source.
+func (m *FairMix) Next() bool {
+ m.cur = nil
+
+ var timeout <-chan time.Time
+ if m.timeout >= 0 {
+ timer := time.NewTimer(m.timeout)
+ timeout = timer.C
+ defer timer.Stop()
+ }
+ for {
+ source := m.pickSource()
+ if source == nil {
+ return m.nextFromAny()
+ }
+ select {
+ case n, ok := <-source.next:
+ if ok {
+ m.cur = n
+ source.timeout = m.timeout
+ return true
+ }
+ // This source has ended.
+ m.deleteSource(source)
+ case <-timeout:
+ source.timeout /= 2
+ return m.nextFromAny()
+ }
+ }
+}
+
+// Node returns the current node.
+func (m *FairMix) Node() *Node {
+ return m.cur
+}
+
+// nextFromAny is used when there are no sources or when the 'fair' choice
+// doesn't turn up a node quickly enough.
+func (m *FairMix) nextFromAny() bool {
+ n, ok := <-m.fromAny
+ if ok {
+ m.cur = n
+ }
+ return ok
+}
+
+// pickSource chooses the next source to read from, cycling through them in order.
+func (m *FairMix) pickSource() *mixSource {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if len(m.sources) == 0 {
+ return nil
+ }
+ m.last = (m.last + 1) % len(m.sources)
+ return m.sources[m.last]
+}
+
+// deleteSource deletes a source.
+func (m *FairMix) deleteSource(s *mixSource) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ for i := range m.sources {
+ if m.sources[i] == s {
+ copy(m.sources[i:], m.sources[i+1:])
+ m.sources[len(m.sources)-1] = nil
+ m.sources = m.sources[:len(m.sources)-1]
+ break
+ }
+ }
+}
+
+// runSource reads a single source in a loop.
+func (m *FairMix) runSource(closed chan struct{}, s *mixSource) {
+ defer m.wg.Done()
+ defer close(s.next)
+ for s.it.Next() {
+ n := s.it.Node()
+ select {
+ case s.next <- n:
+ case m.fromAny <- n:
+ case <-closed:
+ return
+ }
+ }
+}
diff --git a/p2p/enode/iter_test.go b/p2p/enode/iter_test.go
new file mode 100644
index 000000000000..6009661f3ce6
--- /dev/null
+++ b/p2p/enode/iter_test.go
@@ -0,0 +1,291 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package enode
+
+import (
+ "encoding/binary"
+ "runtime"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/p2p/enr"
+)
+
+func TestReadNodes(t *testing.T) {
+ nodes := ReadNodes(new(genIter), 10)
+ checkNodes(t, nodes, 10)
+}
+
+// This test checks that ReadNodes terminates when reading N nodes from an iterator
+// which returns less than N nodes in an endless cycle.
+func TestReadNodesCycle(t *testing.T) {
+ iter := &callCountIter{
+ Iterator: CycleNodes([]*Node{
+ testNode(0, 0),
+ testNode(1, 0),
+ testNode(2, 0),
+ }),
+ }
+ nodes := ReadNodes(iter, 10)
+ checkNodes(t, nodes, 3)
+ if iter.count != 10 {
+ t.Fatalf("%d calls to Next, want %d", iter.count, 100)
+ }
+}
+
+func TestFilterNodes(t *testing.T) {
+ nodes := make([]*Node, 100)
+ for i := range nodes {
+ nodes[i] = testNode(uint64(i), uint64(i))
+ }
+
+ it := Filter(IterNodes(nodes), func(n *Node) bool {
+ return n.Seq() >= 50
+ })
+ for i := 50; i < len(nodes); i++ {
+ if !it.Next() {
+ t.Fatal("Next returned false")
+ }
+ if it.Node() != nodes[i] {
+ t.Fatalf("iterator returned wrong node %v\nwant %v", it.Node(), nodes[i])
+ }
+ }
+ if it.Next() {
+ t.Fatal("Next returned true after underlying iterator has ended")
+ }
+}
+
+func checkNodes(t *testing.T, nodes []*Node, wantLen int) {
+ if len(nodes) != wantLen {
+ t.Errorf("slice has %d nodes, want %d", len(nodes), wantLen)
+ return
+ }
+ seen := make(map[ID]bool)
+ for i, e := range nodes {
+ if e == nil {
+ t.Errorf("nil node at index %d", i)
+ return
+ }
+ if seen[e.ID()] {
+ t.Errorf("slice has duplicate node %v", e.ID())
+ return
+ }
+ seen[e.ID()] = true
+ }
+}
+
+// This test checks fairness of FairMix in the happy case where all sources return nodes
+// within the context's deadline.
+func TestFairMix(t *testing.T) {
+ for i := 0; i < 500; i++ {
+ testMixerFairness(t)
+ }
+}
+
+func testMixerFairness(t *testing.T) {
+ mix := NewFairMix(1 * time.Second)
+ mix.AddSource(&genIter{index: 1})
+ mix.AddSource(&genIter{index: 2})
+ mix.AddSource(&genIter{index: 3})
+ defer mix.Close()
+
+ nodes := ReadNodes(mix, 500)
+ checkNodes(t, nodes, 500)
+
+ // Verify that the nodes slice contains an approximately equal number of nodes
+ // from each source.
+ d := idPrefixDistribution(nodes)
+ for _, count := range d {
+ if approxEqual(count, len(nodes)/3, 30) {
+ t.Fatalf("ID distribution is unfair: %v", d)
+ }
+ }
+}
+
+// This test checks that FairMix falls back to an alternative source when
+// the 'fair' choice doesn't return a node within the timeout.
+func TestFairMixNextFromAll(t *testing.T) {
+ mix := NewFairMix(1 * time.Millisecond)
+ mix.AddSource(&genIter{index: 1})
+ mix.AddSource(CycleNodes(nil))
+ defer mix.Close()
+
+ nodes := ReadNodes(mix, 500)
+ checkNodes(t, nodes, 500)
+
+ d := idPrefixDistribution(nodes)
+ if len(d) > 1 || d[1] != len(nodes) {
+ t.Fatalf("wrong ID distribution: %v", d)
+ }
+}
+
+// This test ensures FairMix works for Next with no sources.
+func TestFairMixEmpty(t *testing.T) {
+ var (
+ mix = NewFairMix(1 * time.Second)
+ testN = testNode(1, 1)
+ ch = make(chan *Node)
+ )
+ defer mix.Close()
+
+ go func() {
+ mix.Next()
+ ch <- mix.Node()
+ }()
+
+ mix.AddSource(CycleNodes([]*Node{testN}))
+ if n := <-ch; n != testN {
+ t.Errorf("got wrong node: %v", n)
+ }
+}
+
+// This test checks closing a source while Next runs.
+func TestFairMixRemoveSource(t *testing.T) {
+ mix := NewFairMix(1 * time.Second)
+ source := make(blockingIter)
+ mix.AddSource(source)
+
+ sig := make(chan *Node)
+ go func() {
+ <-sig
+ mix.Next()
+ sig <- mix.Node()
+ }()
+
+ sig <- nil
+ runtime.Gosched()
+ source.Close()
+
+ wantNode := testNode(0, 0)
+ mix.AddSource(CycleNodes([]*Node{wantNode}))
+ n := <-sig
+
+ if len(mix.sources) != 1 {
+ t.Fatalf("have %d sources, want one", len(mix.sources))
+ }
+ if n != wantNode {
+ t.Fatalf("mixer returned wrong node")
+ }
+}
+
+type blockingIter chan struct{}
+
+func (it blockingIter) Next() bool {
+ <-it
+ return false
+}
+
+func (it blockingIter) Node() *Node {
+ return nil
+}
+
+func (it blockingIter) Close() {
+ close(it)
+}
+
+func TestFairMixClose(t *testing.T) {
+ for i := 0; i < 20 && !t.Failed(); i++ {
+ testMixerClose(t)
+ }
+}
+
+func testMixerClose(t *testing.T) {
+ mix := NewFairMix(-1)
+ mix.AddSource(CycleNodes(nil))
+ mix.AddSource(CycleNodes(nil))
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ if mix.Next() {
+ t.Error("Next returned true")
+ }
+ }()
+ // This call is supposed to make it more likely that NextNode is
+ // actually executing by the time we call Close.
+ runtime.Gosched()
+
+ mix.Close()
+ select {
+ case <-done:
+ case <-time.After(3 * time.Second):
+ t.Fatal("Next didn't unblock on Close")
+ }
+
+ mix.Close() // shouldn't crash
+}
+
+func idPrefixDistribution(nodes []*Node) map[uint32]int {
+ d := make(map[uint32]int)
+ for _, node := range nodes {
+ id := node.ID()
+ d[binary.BigEndian.Uint32(id[:4])]++
+ }
+ return d
+}
+
+func approxEqual(x, y, ε int) bool {
+ if y > x {
+ x, y = y, x
+ }
+ return x-y > ε
+}
+
+// genIter creates fake nodes with numbered IDs based on 'index' and 'gen'
+type genIter struct {
+ node *Node
+ index, gen uint32
+}
+
+func (s *genIter) Next() bool {
+ index := atomic.LoadUint32(&s.index)
+ if index == ^uint32(0) {
+ s.node = nil
+ return false
+ }
+ s.node = testNode(uint64(index)<<32|uint64(s.gen), 0)
+ s.gen++
+ return true
+}
+
+func (s *genIter) Node() *Node {
+ return s.node
+}
+
+func (s *genIter) Close() {
+ s.index = ^uint32(0)
+}
+
+func testNode(id, seq uint64) *Node {
+ var nodeID ID
+ binary.BigEndian.PutUint64(nodeID[:], id)
+ r := new(enr.Record)
+ r.SetSeq(seq)
+ return SignNull(r, nodeID)
+}
+
+// callCountIter counts calls to NextNode.
+type callCountIter struct {
+ Iterator
+ count int
+}
+
+func (it *callCountIter) Next() bool {
+ it.count++
+ return it.Iterator.Next()
+}
diff --git a/p2p/protocol.go b/p2p/protocol.go
index 9ce4c20203c7..fa23a087c281 100644
--- a/p2p/protocol.go
+++ b/p2p/protocol.go
@@ -54,6 +54,11 @@ type Protocol struct {
// but returns nil, it is assumed that the protocol handshake is still running.
PeerInfo func(id enode.ID) interface{}
+ // DialCandidates, if non-nil, is a way to tell Server about protocol-specific nodes
+ // that should be dialed. The server continuously reads nodes from the iterator and
+ // attempts to create connections to them.
+ DialCandidates enode.Iterator
+
// Attributes contains protocol specific information for the node record.
Attributes []enr.Entry
}
diff --git a/p2p/server.go b/p2p/server.go
index 692c9eb7d91d..246148741fd9 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -45,6 +45,11 @@ import (
const (
defaultDialTimeout = 15 * time.Second
+ // This is the fairness knob for the discovery mixer. When looking for peers, we'll
+ // wait this long for a single source of candidates before moving on and trying other
+ // sources.
+ discmixTimeout = 5 * time.Second
+
// Connectivity defaults.
maxActiveDialTasks = 16
defaultMaxPendingPeers = 50
@@ -167,16 +172,20 @@ type Server struct {
lock sync.Mutex // protects running
running bool
- nodedb *enode.DB
- localnode *enode.LocalNode
- ntab discoverTable
listener net.Listener
ourHandshake *protoHandshake
- DiscV5 *discv5.Network
loopWG sync.WaitGroup // loop, listenLoop
peerFeed event.Feed
log log.Logger
+ nodedb *enode.DB
+ localnode *enode.LocalNode
+ ntab *discover.UDPv4
+ DiscV5 *discv5.Network
+ discmix *enode.FairMix
+
+ staticNodeResolver nodeResolver
+
// Channels into the run loop.
quit chan struct{}
addstatic chan *enode.Node
@@ -470,7 +479,7 @@ func (srv *Server) Start() (err error) {
}
dynPeers := srv.maxDialedConns()
- dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config)
+ dialer := newDialState(srv.localnode.ID(), dynPeers, &srv.Config)
srv.loopWG.Add(1)
go srv.run(dialer)
return nil
@@ -521,6 +530,18 @@ func (srv *Server) setupLocalNode() error {
}
func (srv *Server) setupDiscovery() error {
+ srv.discmix = enode.NewFairMix(discmixTimeout)
+
+ // Add protocol-specific discovery sources.
+ added := make(map[string]bool)
+ for _, proto := range srv.Protocols {
+ if proto.DialCandidates != nil && !added[proto.Name] {
+ srv.discmix.AddSource(proto.DialCandidates)
+ added[proto.Name] = true
+ }
+ }
+
+ // Don't listen on UDP endpoint if DHT is disabled.
if srv.NoDiscovery && !srv.DiscoveryV5 {
return nil
}
@@ -562,7 +583,10 @@ func (srv *Server) setupDiscovery() error {
return err
}
srv.ntab = ntab
+ srv.discmix.AddSource(ntab.RandomNodes())
+ srv.staticNodeResolver = ntab
}
+
// Discovery V5
if srv.DiscoveryV5 {
var ntab *discv5.Network
@@ -620,6 +644,7 @@ func (srv *Server) run(dialstate dialer) {
srv.log.Info("Started P2P networking", "self", srv.localnode.Node().URLv4())
defer srv.loopWG.Done()
defer srv.nodedb.Close()
+ defer srv.discmix.Close()
var (
peers = make(map[enode.ID]*Peer)
diff --git a/p2p/server_test.go b/p2p/server_test.go
index e8bc627e1d30..383445c83388 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -233,8 +233,8 @@ func TestServerTaskScheduling(t *testing.T) {
Config: Config{MaxPeers: 10},
localnode: enode.NewLocalNode(db, newkey()),
nodedb: db,
+ discmix: enode.NewFairMix(0),
quit: make(chan struct{}),
- ntab: fakeTable{},
running: true,
log: log.New(),
}
@@ -282,9 +282,9 @@ func TestServerManyTasks(t *testing.T) {
quit: make(chan struct{}),
localnode: enode.NewLocalNode(db, newkey()),
nodedb: db,
- ntab: fakeTable{},
running: true,
log: log.New(),
+ discmix: enode.NewFairMix(0),
}
done = make(chan *testTask)
start, end = 0, 0