diff --git a/cmd/devp2p/crawl.go b/cmd/devp2p/crawl.go new file mode 100644 index 000000000000..92aaad72a372 --- /dev/null +++ b/cmd/devp2p/crawl.go @@ -0,0 +1,152 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package main + +import ( + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +type crawler struct { + input nodeSet + output nodeSet + disc *discover.UDPv4 + iters []enode.Iterator + inputIter enode.Iterator + ch chan *enode.Node + closed chan struct{} + + // settings + revalidateInterval time.Duration +} + +func newCrawler(input nodeSet, disc *discover.UDPv4, iters ...enode.Iterator) *crawler { + c := &crawler{ + input: input, + output: make(nodeSet, len(input)), + disc: disc, + iters: iters, + inputIter: enode.IterNodes(input.nodes()), + ch: make(chan *enode.Node), + closed: make(chan struct{}), + } + c.iters = append(c.iters, c.inputIter) + // Copy input to output initially. Any nodes that fail validation + // will be dropped from output during the run. + for id, n := range input { + c.output[id] = n + } + return c +} + +func (c *crawler) run(timeout time.Duration) nodeSet { + var ( + timeoutTimer = time.NewTimer(timeout) + timeoutCh <-chan time.Time + doneCh = make(chan enode.Iterator, len(c.iters)) + liveIters = len(c.iters) + ) + for _, it := range c.iters { + go c.runIterator(doneCh, it) + } + +loop: + for { + select { + case n := <-c.ch: + c.updateNode(n) + case it := <-doneCh: + if it == c.inputIter { + // Enable timeout when we're done revalidating the input nodes. + log.Info("Revalidation of input set is done", "len", len(c.input)) + if timeout > 0 { + timeoutCh = timeoutTimer.C + } + } + if liveIters--; liveIters == 0 { + break loop + } + case <-timeoutCh: + break loop + } + } + + close(c.closed) + for _, it := range c.iters { + it.Close() + } + for ; liveIters > 0; liveIters-- { + <-doneCh + } + return c.output +} + +func (c *crawler) runIterator(done chan<- enode.Iterator, it enode.Iterator) { + defer func() { done <- it }() + for it.Next() { + select { + case c.ch <- it.Node(): + case <-c.closed: + return + } + } +} + +func (c *crawler) updateNode(n *enode.Node) { + node, ok := c.output[n.ID()] + + // Skip validation of recently-seen nodes. + if ok && time.Since(node.LastCheck) < c.revalidateInterval { + return + } + + // Request the node record. + nn, err := c.disc.RequestENR(n) + node.LastCheck = truncNow() + if err != nil { + if node.Score == 0 { + // Node doesn't implement EIP-868. + log.Debug("Skipping node", "id", n.ID()) + return + } + node.Score /= 2 + } else { + node.N = nn + node.Seq = nn.Seq() + node.Score++ + if node.FirstResponse.IsZero() { + node.FirstResponse = node.LastCheck + } + node.LastResponse = node.LastCheck + } + + // Store/update node in output set. + if node.Score <= 0 { + log.Info("Removing node", "id", n.ID()) + delete(c.output, n.ID()) + } else { + log.Info("Updating node", "id", n.ID(), "seq", n.Seq(), "score", node.Score) + c.output[n.ID()] = node + } +} + +func truncNow() time.Time { + return time.Now().UTC().Truncate(1 * time.Second) +} diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go index ab5b874029df..9525bec66817 100644 --- a/cmd/devp2p/discv4cmd.go +++ b/cmd/devp2p/discv4cmd.go @@ -39,6 +39,7 @@ var ( discv4RequestRecordCommand, discv4ResolveCommand, discv4ResolveJSONCommand, + discv4CrawlCommand, }, } discv4PingCommand = cli.Command{ @@ -67,12 +68,25 @@ var ( Flags: []cli.Flag{bootnodesFlag}, ArgsUsage: "", } + discv4CrawlCommand = cli.Command{ + Name: "crawl", + Usage: "Updates a nodes.json file with random nodes found in the DHT", + Action: discv4Crawl, + Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag}, + } ) -var bootnodesFlag = cli.StringFlag{ - Name: "bootnodes", - Usage: "Comma separated nodes used for bootstrapping", -} +var ( + bootnodesFlag = cli.StringFlag{ + Name: "bootnodes", + Usage: "Comma separated nodes used for bootstrapping", + } + crawlTimeoutFlag = cli.DurationFlag{ + Name: "timeout", + Usage: "Time limit for the crawl.", + Value: 30 * time.Minute, + } +) func discv4Ping(ctx *cli.Context) error { n := getNodeArg(ctx) @@ -113,30 +127,48 @@ func discv4ResolveJSON(ctx *cli.Context) error { if ctx.NArg() < 1 { return fmt.Errorf("need nodes file as argument") } - disc := startV4(ctx) - defer disc.Close() - file := ctx.Args().Get(0) - - // Load existing nodes in file. - var nodes []*enode.Node - if common.FileExist(file) { - nodes = loadNodesJSON(file).nodes() + nodesFile := ctx.Args().Get(0) + inputSet := make(nodeSet) + if common.FileExist(nodesFile) { + inputSet = loadNodesJSON(nodesFile) } - // Add nodes from command line arguments. + + // Add extra nodes from command line arguments. + var nodeargs []*enode.Node for i := 1; i < ctx.NArg(); i++ { n, err := parseNode(ctx.Args().Get(i)) if err != nil { exit(err) } - nodes = append(nodes, n) + nodeargs = append(nodeargs, n) } - result := make(nodeSet, len(nodes)) - for _, n := range nodes { - n = disc.Resolve(n) - result[n.ID()] = nodeJSON{Seq: n.Seq(), N: n} + // Run the crawler. + disc := startV4(ctx) + defer disc.Close() + c := newCrawler(inputSet, disc, enode.IterNodes(nodeargs)) + c.revalidateInterval = 0 + output := c.run(0) + writeNodesJSON(nodesFile, output) + return nil +} + +func discv4Crawl(ctx *cli.Context) error { + if ctx.NArg() < 1 { + return fmt.Errorf("need nodes file as argument") + } + nodesFile := ctx.Args().First() + var inputSet nodeSet + if common.FileExist(nodesFile) { + inputSet = loadNodesJSON(nodesFile) } - writeNodesJSON(file, result) + + disc := startV4(ctx) + defer disc.Close() + c := newCrawler(inputSet, disc, disc.RandomNodes()) + c.revalidateInterval = 10 * time.Minute + output := c.run(ctx.Duration(crawlTimeoutFlag.Name)) + writeNodesJSON(nodesFile, output) return nil } diff --git a/cmd/devp2p/dnscmd.go b/cmd/devp2p/dnscmd.go index 74d70d3aaaf9..eb15764b04e2 100644 --- a/cmd/devp2p/dnscmd.go +++ b/cmd/devp2p/dnscmd.go @@ -109,7 +109,8 @@ func dnsSync(ctx *cli.Context) error { } def := treeToDefinition(url, t) def.Meta.LastModified = time.Now() - writeTreeDefinition(outdir, def) + writeTreeMetadata(outdir, def) + writeTreeNodes(outdir, def) return nil } @@ -151,7 +152,7 @@ func dnsSign(ctx *cli.Context) error { def = treeToDefinition(url, t) def.Meta.LastModified = time.Now() - writeTreeDefinition(defdir, def) + writeTreeMetadata(defdir, def) return nil } @@ -315,26 +316,28 @@ func ensureValidTreeSignature(t *dnsdisc.Tree, pubkey *ecdsa.PublicKey, sig stri return nil } -// writeTreeDefinition writes a DNS node tree definition to the given directory. -func writeTreeDefinition(directory string, def *dnsDefinition) { +// writeTreeMetadata writes a DNS node tree metadata file to the given directory. +func writeTreeMetadata(directory string, def *dnsDefinition) { metaJSON, err := json.MarshalIndent(&def.Meta, "", jsonIndent) if err != nil { exit(err) } - // Convert nodes. - nodes := make(nodeSet, len(def.Nodes)) - nodes.add(def.Nodes...) - // Write. if err := os.Mkdir(directory, 0744); err != nil && !os.IsExist(err) { exit(err) } - metaFile, nodesFile := treeDefinitionFiles(directory) - writeNodesJSON(nodesFile, nodes) + metaFile, _ := treeDefinitionFiles(directory) if err := ioutil.WriteFile(metaFile, metaJSON, 0644); err != nil { exit(err) } } +func writeTreeNodes(directory string, def *dnsDefinition) { + ns := make(nodeSet, len(def.Nodes)) + ns.add(def.Nodes...) + _, nodesFile := treeDefinitionFiles(directory) + writeNodesJSON(nodesFile, ns) +} + func treeDefinitionFiles(directory string) (string, string) { meta := filepath.Join(directory, "enrtree-info.json") nodes := filepath.Join(directory, "nodes.json") diff --git a/cmd/devp2p/main.go b/cmd/devp2p/main.go index c88fe6f6123e..6faa65093737 100644 --- a/cmd/devp2p/main.go +++ b/cmd/devp2p/main.go @@ -60,6 +60,7 @@ func init() { enrdumpCommand, discv4Command, dnsCommand, + nodesetCommand, } } diff --git a/cmd/devp2p/nodeset.go b/cmd/devp2p/nodeset.go index a4a05016e92e..2d86c3f65aba 100644 --- a/cmd/devp2p/nodeset.go +++ b/cmd/devp2p/nodeset.go @@ -21,7 +21,9 @@ import ( "encoding/json" "fmt" "io/ioutil" + "os" "sort" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/p2p/enode" @@ -36,6 +38,15 @@ type nodeSet map[enode.ID]nodeJSON type nodeJSON struct { Seq uint64 `json:"seq"` N *enode.Node `json:"record"` + + // The score tracks how many liveness checks were performed. It is incremented by one + // every time the node passes a check, and halved every time it doesn't. + Score int `json:"score,omitempty"` + // These two track the time of last successful contact. + FirstResponse time.Time `json:"firstResponse,omitempty"` + LastResponse time.Time `json:"lastResponse,omitempty"` + // This one tracks the time of our last attempt to contact the node. + LastCheck time.Time `json:"lastCheck,omitempty"` } func loadNodesJSON(file string) nodeSet { @@ -51,6 +62,10 @@ func writeNodesJSON(file string, nodes nodeSet) { if err != nil { exit(err) } + if file == "-" { + os.Stdout.Write(nodesJSON) + return + } if err := ioutil.WriteFile(file, nodesJSON, 0644); err != nil { exit(err) } diff --git a/cmd/devp2p/nodesetcmd.go b/cmd/devp2p/nodesetcmd.go new file mode 100644 index 000000000000..384a12da7c7e --- /dev/null +++ b/cmd/devp2p/nodesetcmd.go @@ -0,0 +1,193 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package main + +import ( + "fmt" + "net" + "time" + + "github.com/ethereum/go-ethereum/core/forkid" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" + "gopkg.in/urfave/cli.v1" +) + +var ( + nodesetCommand = cli.Command{ + Name: "nodeset", + Usage: "Node set tools", + Subcommands: []cli.Command{ + nodesetInfoCommand, + nodesetFilterCommand, + }, + } + nodesetInfoCommand = cli.Command{ + Name: "info", + Usage: "Shows statistics about a node set", + Action: nodesetInfo, + ArgsUsage: "", + } + nodesetFilterCommand = cli.Command{ + Name: "filter", + Usage: "Filters a node set", + Action: nodesetFilter, + ArgsUsage: " filters..", + + SkipFlagParsing: true, + } +) + +func nodesetInfo(ctx *cli.Context) error { + if ctx.NArg() < 1 { + return fmt.Errorf("need nodes file as argument") + } + + ns := loadNodesJSON(ctx.Args().First()) + fmt.Printf("Set contains %d nodes.\n", len(ns)) + return nil +} + +func nodesetFilter(ctx *cli.Context) error { + if ctx.NArg() < 1 { + return fmt.Errorf("need nodes file as argument") + } + ns := loadNodesJSON(ctx.Args().First()) + filter, err := andFilter(ctx.Args().Tail()) + if err != nil { + return err + } + + result := make(nodeSet) + for id, n := range ns { + if filter(n) { + result[id] = n + } + } + writeNodesJSON("-", result) + return nil +} + +type nodeFilter func(nodeJSON) bool + +type nodeFilterC struct { + narg int + fn func([]string) (nodeFilter, error) +} + +var filterFlags = map[string]nodeFilterC{ + "-ip": {1, ipFilter}, + "-min-age": {1, minAgeFilter}, + "-eth-network": {1, ethFilter}, + "-les-server": {0, lesFilter}, +} + +func parseFilters(args []string) ([]nodeFilter, error) { + var filters []nodeFilter + for len(args) > 0 { + fc, ok := filterFlags[args[0]] + if !ok { + return nil, fmt.Errorf("invalid filter %q", args[0]) + } + if len(args) < fc.narg { + return nil, fmt.Errorf("filter %q wants %d arguments, have %d", args[0], fc.narg, len(args)) + } + filter, err := fc.fn(args[1:]) + if err != nil { + return nil, fmt.Errorf("%s: %v", args[0], err) + } + filters = append(filters, filter) + args = args[fc.narg+1:] + } + return filters, nil +} + +func andFilter(args []string) (nodeFilter, error) { + checks, err := parseFilters(args) + if err != nil { + return nil, err + } + f := func(n nodeJSON) bool { + for _, filter := range checks { + if !filter(n) { + return false + } + } + return true + } + return f, nil +} + +func ipFilter(args []string) (nodeFilter, error) { + _, cidr, err := net.ParseCIDR(args[0]) + if err != nil { + return nil, err + } + f := func(n nodeJSON) bool { return cidr.Contains(n.N.IP()) } + return f, nil +} + +func minAgeFilter(args []string) (nodeFilter, error) { + minage, err := time.ParseDuration(args[0]) + if err != nil { + return nil, err + } + f := func(n nodeJSON) bool { + age := n.LastResponse.Sub(n.FirstResponse) + return age >= minage + } + return f, nil +} + +func ethFilter(args []string) (nodeFilter, error) { + var filter func(forkid.ID) error + switch args[0] { + case "mainnet": + filter = forkid.NewStaticFilter(params.MainnetChainConfig, params.MainnetGenesisHash) + case "rinkeby": + filter = forkid.NewStaticFilter(params.RinkebyChainConfig, params.RinkebyGenesisHash) + case "goerli": + filter = forkid.NewStaticFilter(params.GoerliChainConfig, params.GoerliGenesisHash) + case "ropsten": + filter = forkid.NewStaticFilter(params.TestnetChainConfig, params.TestnetGenesisHash) + default: + return nil, fmt.Errorf("unknown network %q", args[0]) + } + + f := func(n nodeJSON) bool { + var eth struct { + ForkID forkid.ID + _ []rlp.RawValue `rlp:"tail"` + } + if n.N.Load(enr.WithEntry("eth", ð)) != nil { + return false + } + return filter(eth.ForkID) == nil + } + return f, nil +} + +func lesFilter(args []string) (nodeFilter, error) { + f := func(n nodeJSON) bool { + var les struct { + _ []rlp.RawValue `rlp:"tail"` + } + return n.N.Load(enr.WithEntry("les", &les)) == nil + } + return f, nil +} diff --git a/core/forkid/forkid.go b/core/forkid/forkid.go index 8c1700879a5e..3845742ad9bb 100644 --- a/core/forkid/forkid.go +++ b/core/forkid/forkid.go @@ -80,7 +80,7 @@ func newID(config *params.ChainConfig, genesis common.Hash, head uint64) ID { return ID{Hash: checksumToBytes(hash), Next: next} } -// NewFilter creates an filter that returns if a fork ID should be rejected or not +// NewFilter creates a filter that returns if a fork ID should be rejected or not // based on the local chain's status. func NewFilter(chain *core.BlockChain) func(id ID) error { return newFilter( @@ -92,6 +92,12 @@ func NewFilter(chain *core.BlockChain) func(id ID) error { ) } +// NewStaticFilter creates a filter at block zero. +func NewStaticFilter(config *params.ChainConfig, genesis common.Hash) func(id ID) error { + head := func() uint64 { return 0 } + return newFilter(config, genesis, head) +} + // newFilter is the internal version of NewFilter, taking closures as its arguments // instead of a chain. The reason is to allow testing it without having to simulate // an entire blockchain. diff --git a/p2p/dial.go b/p2p/dial.go index 8dee5063f1d5..68e06cce5874 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -33,12 +33,7 @@ const ( // private networks. dialHistoryExpiration = inboundThrottleTime + 5*time.Second - // Discovery lookups are throttled and can only run - // once every few seconds. - lookupInterval = 4 * time.Second - - // If no peers are found for this amount of time, the initial bootnodes are - // attempted to be connected. + // If no peers are found for this amount of time, the initial bootnodes are dialed. fallbackInterval = 20 * time.Second // Endpoint resolution is throttled with bounded backoff. @@ -52,6 +47,10 @@ type NodeDialer interface { Dial(*enode.Node) (net.Conn, error) } +type nodeResolver interface { + Resolve(*enode.Node) *enode.Node +} + // TCPDialer implements the NodeDialer interface by using a net.Dialer to // create TCP connections to nodes in the network type TCPDialer struct { @@ -69,7 +68,6 @@ func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) { // of the main loop in Server.run. type dialstate struct { maxDynDials int - ntab discoverTable netrestrict *netutil.Netlist self enode.ID bootnodes []*enode.Node // default dials when there are no peers @@ -79,55 +77,23 @@ type dialstate struct { lookupRunning bool dialing map[enode.ID]connFlag lookupBuf []*enode.Node // current discovery lookup results - randomNodes []*enode.Node // filled from Table static map[enode.ID]*dialTask hist expHeap } -type discoverTable interface { - Close() - Resolve(*enode.Node) *enode.Node - LookupRandom() []*enode.Node - ReadRandomNodes([]*enode.Node) int -} - type task interface { Do(*Server) } -// A dialTask is generated for each node that is dialed. Its -// fields cannot be accessed while the task is running. -type dialTask struct { - flags connFlag - dest *enode.Node - lastResolved time.Time - resolveDelay time.Duration -} - -// discoverTask runs discovery table operations. -// Only one discoverTask is active at any time. -// discoverTask.Do performs a random lookup. -type discoverTask struct { - results []*enode.Node -} - -// A waitExpireTask is generated if there are no other tasks -// to keep the loop in Server.run ticking. -type waitExpireTask struct { - time.Duration -} - -func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate { +func newDialState(self enode.ID, maxdyn int, cfg *Config) *dialstate { s := &dialstate{ maxDynDials: maxdyn, - ntab: ntab, self: self, netrestrict: cfg.NetRestrict, log: cfg.Logger, static: make(map[enode.ID]*dialTask), dialing: make(map[enode.ID]connFlag), bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)), - randomNodes: make([]*enode.Node, maxdyn/2), } copy(s.bootnodes, cfg.BootstrapNodes) if s.log == nil { @@ -151,10 +117,6 @@ func (s *dialstate) removeStatic(n *enode.Node) { } func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task { - if s.start.IsZero() { - s.start = now - } - var newtasks []task addDial := func(flag connFlag, n *enode.Node) bool { if err := s.checkDial(n, peers); err != nil { @@ -166,20 +128,9 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti return true } - // Compute number of dynamic dials necessary at this point. - needDynDials := s.maxDynDials - for _, p := range peers { - if p.rw.is(dynDialedConn) { - needDynDials-- - } - } - for _, flag := range s.dialing { - if flag&dynDialedConn != 0 { - needDynDials-- - } + if s.start.IsZero() { + s.start = now } - - // Expire the dial history on every invocation. s.hist.expire(now) // Create dials for static nodes if they are not connected. @@ -194,6 +145,20 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti newtasks = append(newtasks, t) } } + + // Compute number of dynamic dials needed. + needDynDials := s.maxDynDials + for _, p := range peers { + if p.rw.is(dynDialedConn) { + needDynDials-- + } + } + for _, flag := range s.dialing { + if flag&dynDialedConn != 0 { + needDynDials-- + } + } + // If we don't have any peers whatsoever, try to dial a random bootnode. This // scenario is useful for the testnet (and private networks) where the discovery // table might be full of mostly bad peers, making it hard to find good ones. @@ -201,24 +166,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti bootnode := s.bootnodes[0] s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...) s.bootnodes = append(s.bootnodes, bootnode) - if addDial(dynDialedConn, bootnode) { needDynDials-- } } - // Use random nodes from the table for half of the necessary - // dynamic dials. - randomCandidates := needDynDials / 2 - if randomCandidates > 0 { - n := s.ntab.ReadRandomNodes(s.randomNodes) - for i := 0; i < randomCandidates && i < n; i++ { - if addDial(dynDialedConn, s.randomNodes[i]) { - needDynDials-- - } - } - } - // Create dynamic dials from random lookup results, removing tried - // items from the result buffer. + + // Create dynamic dials from discovery results. i := 0 for ; i < len(s.lookupBuf) && needDynDials > 0; i++ { if addDial(dynDialedConn, s.lookupBuf[i]) { @@ -226,10 +179,11 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti } } s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])] + // Launch a discovery lookup if more candidates are needed. if len(s.lookupBuf) < needDynDials && !s.lookupRunning { s.lookupRunning = true - newtasks = append(newtasks, &discoverTask{}) + newtasks = append(newtasks, &discoverTask{want: needDynDials - len(s.lookupBuf)}) } // Launch a timer to wait for the next node to expire if all @@ -279,6 +233,15 @@ func (s *dialstate) taskDone(t task, now time.Time) { } } +// A dialTask is generated for each node that is dialed. Its +// fields cannot be accessed while the task is running. +type dialTask struct { + flags connFlag + dest *enode.Node + lastResolved time.Time + resolveDelay time.Duration +} + func (t *dialTask) Do(srv *Server) { if t.dest.Incomplete() { if !t.resolve(srv) { @@ -304,8 +267,8 @@ func (t *dialTask) Do(srv *Server) { // discovery network with useless queries for nodes that don't exist. // The backoff delay resets when the node is found. func (t *dialTask) resolve(srv *Server) bool { - if srv.ntab == nil { - srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled") + if srv.staticNodeResolver == nil { + srv.log.Debug("Can't resolve node", "id", t.dest.ID(), "err", "discovery is disabled") return false } if t.resolveDelay == 0 { @@ -314,20 +277,20 @@ func (t *dialTask) resolve(srv *Server) bool { if time.Since(t.lastResolved) < t.resolveDelay { return false } - resolved := srv.ntab.Resolve(t.dest) + resolved := srv.staticNodeResolver.Resolve(t.dest) t.lastResolved = time.Now() if resolved == nil { t.resolveDelay *= 2 if t.resolveDelay > maxResolveDelay { t.resolveDelay = maxResolveDelay } - srv.log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay) + srv.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay) return false } // The node was found. t.resolveDelay = initialResolveDelay t.dest = resolved - srv.log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) + srv.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) return true } @@ -350,26 +313,34 @@ func (t *dialTask) String() string { return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP()) } +// discoverTask runs discovery table operations. +// Only one discoverTask is active at any time. +// discoverTask.Do performs a random lookup. +type discoverTask struct { + want int + results []*enode.Node +} + func (t *discoverTask) Do(srv *Server) { - // newTasks generates a lookup task whenever dynamic dials are - // necessary. Lookups need to take some time, otherwise the - // event loop spins too fast. - next := srv.lastLookup.Add(lookupInterval) - if now := time.Now(); now.Before(next) { - time.Sleep(next.Sub(now)) - } - srv.lastLookup = time.Now() - t.results = srv.ntab.LookupRandom() + t.results = enode.ReadNodes(srv.discmix, t.want) } func (t *discoverTask) String() string { - s := "discovery lookup" + s := "discovery query" if len(t.results) > 0 { s += fmt.Sprintf(" (%d results)", len(t.results)) + } else { + s += fmt.Sprintf(" (want %d)", t.want) } return s } +// A waitExpireTask is generated if there are no other tasks +// to keep the loop in Server.run ticking. +type waitExpireTask struct { + time.Duration +} + func (t waitExpireTask) Do(*Server) { time.Sleep(t.Duration) } diff --git a/p2p/dial_test.go b/p2p/dial_test.go index de8fc4a6e3e6..6189ec4d0b85 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -73,7 +73,7 @@ func runDialTest(t *testing.T, test dialtest) { t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v", i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running)) } - t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new))) + t.Logf("round %d (running %d) new tasks: %s", i, running, strings.TrimSpace(spew.Sdump(new))) // Time advances by 16 seconds on every round. vtime = vtime.Add(16 * time.Second) @@ -81,19 +81,11 @@ func runDialTest(t *testing.T, test dialtest) { } } -type fakeTable []*enode.Node - -func (t fakeTable) Self() *enode.Node { return new(enode.Node) } -func (t fakeTable) Close() {} -func (t fakeTable) LookupRandom() []*enode.Node { return nil } -func (t fakeTable) Resolve(*enode.Node) *enode.Node { return nil } -func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t) } - // This test checks that dynamic dials are launched from discovery results. func TestDialStateDynDial(t *testing.T) { config := &Config{Logger: testlog.Logger(t, log.LvlTrace)} runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, fakeTable{}, 5, config), + init: newDialState(enode.ID{}, 5, config), rounds: []round{ // A discovery query is launched. { @@ -102,7 +94,9 @@ func TestDialStateDynDial(t *testing.T) { {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, }, - new: []task{&discoverTask{}}, + new: []task{ + &discoverTask{want: 3}, + }, }, // Dynamic dials are launched when it completes. { @@ -188,7 +182,7 @@ func TestDialStateDynDial(t *testing.T) { }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)}, - &discoverTask{}, + &discoverTask{want: 2}, }, }, // Peer 7 is connected, but there still aren't enough dynamic peers @@ -218,7 +212,7 @@ func TestDialStateDynDial(t *testing.T) { &discoverTask{}, }, new: []task{ - &discoverTask{}, + &discoverTask{want: 2}, }, }, }, @@ -235,35 +229,37 @@ func TestDialStateDynDialBootnode(t *testing.T) { }, Logger: testlog.Logger(t, log.LvlTrace), } - table := fakeTable{ - newNode(uintID(4), nil), - newNode(uintID(5), nil), - newNode(uintID(6), nil), - newNode(uintID(7), nil), - newNode(uintID(8), nil), - } runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, table, 5, config), + init: newDialState(enode.ID{}, 5, config), rounds: []round{ - // 2 dynamic dials attempted, bootnodes pending fallback interval { new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - &discoverTask{}, + &discoverTask{want: 5}, }, }, - // No dials succeed, bootnodes still pending fallback interval { done: []task{ + &discoverTask{ + results: []*enode.Node{ + newNode(uintID(4), nil), + newNode(uintID(5), nil), + }, + }, + }, + new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, + &discoverTask{want: 3}, }, }, // No dials succeed, bootnodes still pending fallback interval {}, - // No dials succeed, 2 dynamic dials attempted and 1 bootnode too as fallback interval was reached + // 1 bootnode attempted as fallback interval was reached { + done: []task{ + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, + }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, }, @@ -275,15 +271,12 @@ func TestDialStateDynDialBootnode(t *testing.T) { }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, // No dials succeed, 3rd bootnode is attempted { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, @@ -293,115 +286,19 @@ func TestDialStateDynDialBootnode(t *testing.T) { { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - }, - new: []task{}, - }, - // Random dial succeeds, no more bootnodes are attempted - { - new: []task{ - &waitExpireTask{3 * time.Second}, - }, - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - }, - }, - }, - }) -} - -func TestDialStateDynDialFromTable(t *testing.T) { - // This table always returns the same random nodes - // in the order given below. - table := fakeTable{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), - newNode(uintID(4), nil), - newNode(uintID(5), nil), - newNode(uintID(6), nil), - newNode(uintID(7), nil), - newNode(uintID(8), nil), - } - - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}), - rounds: []round{ - // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. - { - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - &discoverTask{}, - }, - }, - // Dialing nodes 1,2 succeeds. Dials from the lookup are launched. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, &discoverTask{results: []*enode.Node{ - newNode(uintID(10), nil), - newNode(uintID(11), nil), - newNode(uintID(12), nil), + newNode(uintID(6), nil), }}, }, new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)}, - &discoverTask{}, - }, - }, - // Dialing nodes 3,4,5 fails. The dials from the lookup succeed. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)}, + &discoverTask{want: 4}, }, }, - // Waiting for expiry. No waitExpireTask is launched because the - // discovery query is still running. + // Random dial succeeds, no more bootnodes are attempted { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}}, - }, - }, - // Nodes 3,4 are not tried again because only the first two - // returned random nodes (nodes 1,2) are tried and they're - // already connected. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(6), nil)}}, }, }, }, @@ -416,11 +313,11 @@ func newNode(id enode.ID, ip net.IP) *enode.Node { return enode.SignNull(&r, id) } -// This test checks that candidates that do not match the netrestrict list are not dialed. +// // This test checks that candidates that do not match the netrestrict list are not dialed. func TestDialStateNetRestrict(t *testing.T) { // This table always returns the same random nodes // in the order given below. - table := fakeTable{ + nodes := []*enode.Node{ newNode(uintID(1), net.ParseIP("127.0.0.1")), newNode(uintID(2), net.ParseIP("127.0.0.2")), newNode(uintID(3), net.ParseIP("127.0.0.3")), @@ -434,12 +331,23 @@ func TestDialStateNetRestrict(t *testing.T) { restrict.Add("127.0.2.0/24") runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}), + init: newDialState(enode.ID{}, 10, &Config{NetRestrict: restrict}), rounds: []round{ { new: []task{ - &dialTask{flags: dynDialedConn, dest: table[4]}, - &discoverTask{}, + &discoverTask{want: 10}, + }, + }, + { + done: []task{ + &discoverTask{results: nodes}, + }, + new: []task{ + &dialTask{flags: dynDialedConn, dest: nodes[4]}, + &dialTask{flags: dynDialedConn, dest: nodes[5]}, + &dialTask{flags: dynDialedConn, dest: nodes[6]}, + &dialTask{flags: dynDialedConn, dest: nodes[7]}, + &discoverTask{want: 6}, }, }, }, @@ -459,7 +367,7 @@ func TestDialStateStaticDial(t *testing.T) { Logger: testlog.Logger(t, log.LvlTrace), } runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, fakeTable{}, 0, config), + init: newDialState(enode.ID{}, 0, config), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -544,7 +452,7 @@ func TestDialStateCache(t *testing.T) { Logger: testlog.Logger(t, log.LvlTrace), } runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, fakeTable{}, 0, config), + init: newDialState(enode.ID{}, 0, config), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -618,8 +526,8 @@ func TestDialResolve(t *testing.T) { Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}, } resolved := newNode(uintID(1), net.IP{127, 0, 55, 234}) - table := &resolveMock{answer: resolved} - state := newDialState(enode.ID{}, table, 0, config) + resolver := &resolveMock{answer: resolved} + state := newDialState(enode.ID{}, 0, config) // Check that the task is generated with an incomplete ID. dest := newNode(uintID(1), nil) @@ -630,10 +538,14 @@ func TestDialResolve(t *testing.T) { } // Now run the task, it should resolve the ID once. - srv := &Server{ntab: table, log: config.Logger, Config: *config} + srv := &Server{ + Config: *config, + log: config.Logger, + staticNodeResolver: resolver, + } tasks[0].Do(srv) - if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) { - t.Fatalf("wrong resolve calls, got %v", table.resolveCalls) + if !reflect.DeepEqual(resolver.calls, []*enode.Node{dest}) { + t.Fatalf("wrong resolve calls, got %v", resolver.calls) } // Report it as done to the dialer, which should update the static node record. @@ -666,18 +578,13 @@ func uintID(i uint32) enode.ID { return id } -// implements discoverTable for TestDialResolve +// for TestDialResolve type resolveMock struct { - resolveCalls []*enode.Node - answer *enode.Node + calls []*enode.Node + answer *enode.Node } func (t *resolveMock) Resolve(n *enode.Node) *enode.Node { - t.resolveCalls = append(t.resolveCalls, n) + t.calls = append(t.calls, n) return t.answer } - -func (t *resolveMock) Self() *enode.Node { return new(enode.Node) } -func (t *resolveMock) Close() {} -func (t *resolveMock) LookupRandom() []*enode.Node { return nil } -func (t *resolveMock) ReadRandomNodes(buf []*enode.Node) int { return 0 } diff --git a/p2p/discover/common.go b/p2p/discover/common.go index 3c080359fdf8..cef6a9fc4f9f 100644 --- a/p2p/discover/common.go +++ b/p2p/discover/common.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/netutil" ) +// UDPConn is a network connection on which discovery can operate. type UDPConn interface { ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) @@ -32,7 +33,7 @@ type UDPConn interface { LocalAddr() net.Addr } -// Config holds Table-related settings. +// Config holds settings for the discovery listener. type Config struct { // These settings are required and configure the UDP listener: PrivateKey *ecdsa.PrivateKey @@ -50,7 +51,7 @@ func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { } // ReadPacket is a packet that couldn't be handled. Those packets are sent to the unhandled -// channel if configured. +// channel if configured. This is exported for internal use, do not use this type. type ReadPacket struct { Data []byte Addr *net.UDPAddr diff --git a/p2p/discover/lookup.go b/p2p/discover/lookup.go new file mode 100644 index 000000000000..f988e0683808 --- /dev/null +++ b/p2p/discover/lookup.go @@ -0,0 +1,209 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "context" + + "github.com/ethereum/go-ethereum/p2p/enode" +) + +// lookup performs a network search for nodes close to the given target. It approaches the +// target by querying nodes that are closer to it on each iteration. The given target does +// not need to be an actual node identifier. +type lookup struct { + tab *Table + queryfunc func(*node) ([]*node, error) + replyCh chan []*node + cancelCh <-chan struct{} + asked, seen map[enode.ID]bool + result nodesByDistance + replyBuffer []*node + queries int +} + +type queryFunc func(*node) ([]*node, error) + +func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *lookup { + it := &lookup{ + tab: tab, + queryfunc: q, + asked: make(map[enode.ID]bool), + seen: make(map[enode.ID]bool), + result: nodesByDistance{target: target}, + replyCh: make(chan []*node, alpha), + cancelCh: ctx.Done(), + queries: -1, + } + // Don't query further if we hit ourself. + // Unlikely to happen often in practice. + it.asked[tab.self().ID()] = true + return it +} + +// run runs the lookup to completion and returns the closest nodes found. +func (it *lookup) run() []*enode.Node { + for it.advance() { + } + return unwrapNodes(it.result.entries) +} + +// advance advances the lookup until any new nodes have been found. +// It returns false when the lookup has ended. +func (it *lookup) advance() bool { + for it.startQueries() { + select { + case nodes := <-it.replyCh: + it.replyBuffer = it.replyBuffer[:0] + for _, n := range nodes { + if n != nil && !it.seen[n.ID()] { + it.seen[n.ID()] = true + it.result.push(n, bucketSize) + it.replyBuffer = append(it.replyBuffer, n) + } + } + it.queries-- + if len(it.replyBuffer) > 0 { + return true + } + case <-it.cancelCh: + it.shutdown() + } + } + return false +} + +func (it *lookup) shutdown() { + for it.queries > 0 { + <-it.replyCh + it.queries-- + } + it.queryfunc = nil + it.replyBuffer = nil +} + +func (it *lookup) startQueries() bool { + if it.queryfunc == nil { + return false + } + + // The first query returns nodes from the local table. + if it.queries == -1 { + it.tab.mutex.Lock() + closest := it.tab.closest(it.result.target, bucketSize, false) + it.tab.mutex.Unlock() + it.queries = 1 + it.replyCh <- closest.entries + return true + } + + // Ask the closest nodes that we haven't asked yet. + for i := 0; i < len(it.result.entries) && it.queries < alpha; i++ { + n := it.result.entries[i] + if !it.asked[n.ID()] { + it.asked[n.ID()] = true + it.queries++ + go it.query(n, it.replyCh) + } + } + // The lookup ends when no more nodes can be asked. + return it.queries > 0 +} + +func (it *lookup) query(n *node, reply chan<- []*node) { + fails := it.tab.db.FindFails(n.ID(), n.IP()) + r, err := it.queryfunc(n) + if err == errClosed { + // Avoid recording failures on shutdown. + reply <- nil + return + } else if len(r) == 0 { + fails++ + it.tab.db.UpdateFindFails(n.ID(), n.IP(), fails) + it.tab.log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err) + if fails >= maxFindnodeFailures { + it.tab.log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails) + it.tab.delete(n) + } + } else if fails > 0 { + // Reset failure counter because it counts _consecutive_ failures. + it.tab.db.UpdateFindFails(n.ID(), n.IP(), 0) + } + + // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll + // just remove those again during revalidation. + for _, n := range r { + it.tab.addSeenNode(n) + } + reply <- r +} + +// lookupIterator performs lookup operations and iterates over all seen nodes. +// When a lookup finishes, a new one is created through nextLookup. +type lookupIterator struct { + buffer []*node + nextLookup lookupFunc + ctx context.Context + cancel func() + lookup *lookup +} + +type lookupFunc func(ctx context.Context) *lookup + +func newLookupIterator(ctx context.Context, next lookupFunc) *lookupIterator { + ctx, cancel := context.WithCancel(ctx) + return &lookupIterator{ctx: ctx, cancel: cancel, nextLookup: next} +} + +// Node returns the current node. +func (it *lookupIterator) Node() *enode.Node { + if len(it.buffer) == 0 { + return nil + } + return unwrapNode(it.buffer[0]) +} + +// Next moves to the next node. +func (it *lookupIterator) Next() bool { + // Consume next node in buffer. + if len(it.buffer) > 0 { + it.buffer = it.buffer[1:] + } + // Advance the lookup to refill the buffer. + for len(it.buffer) == 0 { + if it.ctx.Err() != nil { + it.lookup = nil + it.buffer = nil + return false + } + if it.lookup == nil { + it.lookup = it.nextLookup(it.ctx) + continue + } + if !it.lookup.advance() { + it.lookup = nil + continue + } + it.buffer = it.lookup.replyBuffer + } + return true +} + +// Close ends the iterator. +func (it *lookupIterator) Close() { + it.cancel() +} diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index 2292055e160d..e35e48c5e69d 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -17,11 +17,14 @@ package discover import ( + "bytes" "crypto/ecdsa" "encoding/hex" + "errors" "fmt" "math/rand" "net" + "reflect" "sort" "sync" @@ -169,6 +172,28 @@ func hasDuplicates(slice []*node) bool { return false } +func checkNodesEqual(got, want []*enode.Node) error { + if reflect.DeepEqual(got, want) { + return nil + } + output := new(bytes.Buffer) + fmt.Fprintf(output, "got %d nodes:\n", len(got)) + for _, n := range got { + fmt.Fprintf(output, " %v %v\n", n.ID(), n) + } + fmt.Fprintf(output, "want %d:\n", len(want)) + for _, n := range want { + fmt.Fprintf(output, " %v %v\n", n.ID(), n) + } + return errors.New(output.String()) +} + +func sortByID(nodes []*enode.Node) { + sort.Slice(nodes, func(i, j int) bool { + return string(nodes[i].ID().Bytes()) < string(nodes[j].ID().Bytes()) + }) +} + func sortedByDistanceTo(distbase enode.ID, slice []*node) bool { return sort.SliceIsSorted(slice, func(i, j int) bool { return enode.DistCmp(distbase, slice[i].ID(), slice[j].ID()) < 0 diff --git a/p2p/discover/v4_udp_lookup_test.go b/p2p/discover/v4_lookup_test.go similarity index 75% rename from p2p/discover/v4_udp_lookup_test.go rename to p2p/discover/v4_lookup_test.go index bc1cdfb089ab..9b4042c5a276 100644 --- a/p2p/discover/v4_udp_lookup_test.go +++ b/p2p/discover/v4_lookup_test.go @@ -20,7 +20,6 @@ import ( "crypto/ecdsa" "fmt" "net" - "reflect" "sort" "testing" @@ -49,19 +48,7 @@ func TestUDPv4_Lookup(t *testing.T) { }() // Answer lookup packets. - for done := false; !done; { - done = test.waitPacketOut(func(p packetV4, to *net.UDPAddr, hash []byte) { - n, key := lookupTestnet.nodeByAddr(to) - switch p.(type) { - case *pingV4: - test.packetInFrom(nil, key, to, &pongV4{Expiration: futureExp, ReplyTok: hash}) - case *findnodeV4: - dist := enode.LogDist(n.ID(), lookupTestnet.target.id()) - nodes := lookupTestnet.nodesAtDistance(dist - 1) - test.packetInFrom(nil, key, to, &neighborsV4{Expiration: futureExp, Nodes: nodes}) - } - }) - } + serveTestnet(test, lookupTestnet) // Verify result nodes. results := <-resultC @@ -78,8 +65,94 @@ func TestUDPv4_Lookup(t *testing.T) { if !sortedByDistanceTo(lookupTestnet.target.id(), wrapNodes(results)) { t.Errorf("result set not sorted by distance to target") } - if !reflect.DeepEqual(results, lookupTestnet.closest(bucketSize)) { - t.Errorf("results aren't the closest %d nodes", bucketSize) + if err := checkNodesEqual(results, lookupTestnet.closest(bucketSize)); err != nil { + t.Errorf("results aren't the closest %d nodes\n%v", bucketSize, err) + } +} + +func TestUDPv4_LookupIterator(t *testing.T) { + t.Parallel() + test := newUDPTest(t) + defer test.close() + + // Seed table with initial nodes. + bootnodes := make([]*node, len(lookupTestnet.dists[256])) + for i := range lookupTestnet.dists[256] { + bootnodes[i] = wrapNode(lookupTestnet.node(256, i)) + } + fillTable(test.table, bootnodes) + go serveTestnet(test, lookupTestnet) + + // Create the iterator and collect the nodes it yields. + iter := test.udp.RandomNodes() + seen := make(map[enode.ID]*enode.Node) + for limit := lookupTestnet.len(); iter.Next() && len(seen) < limit; { + seen[iter.Node().ID()] = iter.Node() + } + iter.Close() + + // Check that all nodes in lookupTestnet were seen by the iterator. + results := make([]*enode.Node, 0, len(seen)) + for _, n := range seen { + results = append(results, n) + } + sortByID(results) + want := lookupTestnet.nodes() + if err := checkNodesEqual(results, want); err != nil { + t.Fatal(err) + } +} + +// TestUDPv4_LookupIteratorClose checks that lookupIterator ends when its Close +// method is called. +func TestUDPv4_LookupIteratorClose(t *testing.T) { + t.Parallel() + test := newUDPTest(t) + defer test.close() + + // Seed table with initial nodes. + bootnodes := make([]*node, len(lookupTestnet.dists[256])) + for i := range lookupTestnet.dists[256] { + bootnodes[i] = wrapNode(lookupTestnet.node(256, i)) + } + fillTable(test.table, bootnodes) + go serveTestnet(test, lookupTestnet) + + it := test.udp.RandomNodes() + if ok := it.Next(); !ok || it.Node() == nil { + t.Fatalf("iterator didn't return any node") + } + + it.Close() + + ncalls := 0 + for ; ncalls < 100 && it.Next(); ncalls++ { + if it.Node() == nil { + t.Error("iterator returned Node() == nil node after Next() == true") + } + } + t.Logf("iterator returned %d nodes after close", ncalls) + if it.Next() { + t.Errorf("Next() == true after close and %d more calls", ncalls) + } + if n := it.Node(); n != nil { + t.Errorf("iterator returned non-nil node after close and %d more calls", ncalls) + } +} + +func serveTestnet(test *udpTest, testnet *preminedTestnet) { + for done := false; !done; { + done = test.waitPacketOut(func(p packetV4, to *net.UDPAddr, hash []byte) { + n, key := testnet.nodeByAddr(to) + switch p.(type) { + case *pingV4: + test.packetInFrom(nil, key, to, &pongV4{Expiration: futureExp, ReplyTok: hash}) + case *findnodeV4: + dist := enode.LogDist(n.ID(), testnet.target.id()) + nodes := testnet.nodesAtDistance(dist - 1) + test.packetInFrom(nil, key, to, &neighborsV4{Expiration: futureExp, Nodes: nodes}) + } + }) } } @@ -148,6 +221,25 @@ type preminedTestnet struct { dists [hashBits + 1][]*ecdsa.PrivateKey } +func (tn *preminedTestnet) len() int { + n := 0 + for _, keys := range tn.dists { + n += len(keys) + } + return n +} + +func (tn *preminedTestnet) nodes() []*enode.Node { + result := make([]*enode.Node, 0, tn.len()) + for dist, keys := range tn.dists { + for index := range keys { + result = append(result, tn.node(dist, index)) + } + } + sortByID(result) + return result +} + func (tn *preminedTestnet) node(dist, index int) *enode.Node { key := tn.dists[dist][index] ip := net.IP{127, byte(dist >> 8), byte(dist), byte(index)} diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go index a8f7101b0594..bfb66fcb1967 100644 --- a/p2p/discover/v4_udp.go +++ b/p2p/discover/v4_udp.go @@ -19,6 +19,7 @@ package discover import ( "bytes" "container/list" + "context" "crypto/ecdsa" crand "crypto/rand" "errors" @@ -207,7 +208,8 @@ type UDPv4 struct { addReplyMatcher chan *replyMatcher gotreply chan reply - closing chan struct{} + closeCtx context.Context + cancelCloseCtx func() } // replyMatcher represents a pending reply. @@ -256,20 +258,23 @@ type reply struct { } func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { + closeCtx, cancel := context.WithCancel(context.Background()) t := &UDPv4{ conn: c, priv: cfg.PrivateKey, netrestrict: cfg.NetRestrict, localNode: ln, db: ln.Database(), - closing: make(chan struct{}), gotreply: make(chan reply), addReplyMatcher: make(chan *replyMatcher), + closeCtx: closeCtx, + cancelCloseCtx: cancel, log: cfg.Log, } if t.log == nil { t.log = log.Root() } + tab, err := newTable(t, ln.Database(), cfg.Bootnodes, t.log) if err != nil { return nil, err @@ -291,126 +296,13 @@ func (t *UDPv4) Self() *enode.Node { // Close shuts down the socket and aborts any running queries. func (t *UDPv4) Close() { t.closeOnce.Do(func() { - close(t.closing) + t.cancelCloseCtx() t.conn.Close() t.wg.Wait() t.tab.close() }) } -// ReadRandomNodes reads random nodes from the local table. -func (t *UDPv4) ReadRandomNodes(buf []*enode.Node) int { - return t.tab.ReadRandomNodes(buf) -} - -// LookupRandom finds random nodes in the network. -func (t *UDPv4) LookupRandom() []*enode.Node { - if t.tab.len() == 0 { - // All nodes were dropped, refresh. The very first query will hit this - // case and run the bootstrapping logic. - <-t.tab.refresh() - } - return t.lookupRandom() -} - -func (t *UDPv4) LookupPubkey(key *ecdsa.PublicKey) []*enode.Node { - if t.tab.len() == 0 { - // All nodes were dropped, refresh. The very first query will hit this - // case and run the bootstrapping logic. - <-t.tab.refresh() - } - return unwrapNodes(t.lookup(encodePubkey(key))) -} - -func (t *UDPv4) lookupRandom() []*enode.Node { - var target encPubkey - crand.Read(target[:]) - return unwrapNodes(t.lookup(target)) -} - -func (t *UDPv4) lookupSelf() []*enode.Node { - return unwrapNodes(t.lookup(encodePubkey(&t.priv.PublicKey))) -} - -// lookup performs a network search for nodes close to the given target. It approaches the -// target by querying nodes that are closer to it on each iteration. The given target does -// not need to be an actual node identifier. -func (t *UDPv4) lookup(targetKey encPubkey) []*node { - var ( - target = enode.ID(crypto.Keccak256Hash(targetKey[:])) - asked = make(map[enode.ID]bool) - seen = make(map[enode.ID]bool) - reply = make(chan []*node, alpha) - pendingQueries = 0 - result *nodesByDistance - ) - // Don't query further if we hit ourself. - // Unlikely to happen often in practice. - asked[t.Self().ID()] = true - - // Generate the initial result set. - t.tab.mutex.Lock() - result = t.tab.closest(target, bucketSize, false) - t.tab.mutex.Unlock() - - for { - // ask the alpha closest nodes that we haven't asked yet - for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ { - n := result.entries[i] - if !asked[n.ID()] { - asked[n.ID()] = true - pendingQueries++ - go t.lookupWorker(n, targetKey, reply) - } - } - if pendingQueries == 0 { - // we have asked all closest nodes, stop the search - break - } - select { - case nodes := <-reply: - for _, n := range nodes { - if n != nil && !seen[n.ID()] { - seen[n.ID()] = true - result.push(n, bucketSize) - } - } - case <-t.tab.closeReq: - return nil // shutdown, no need to continue. - } - pendingQueries-- - } - return result.entries -} - -func (t *UDPv4) lookupWorker(n *node, targetKey encPubkey, reply chan<- []*node) { - fails := t.db.FindFails(n.ID(), n.IP()) - r, err := t.findnode(n.ID(), n.addr(), targetKey) - if err == errClosed { - // Avoid recording failures on shutdown. - reply <- nil - return - } else if len(r) == 0 { - fails++ - t.db.UpdateFindFails(n.ID(), n.IP(), fails) - t.log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err) - if fails >= maxFindnodeFailures { - t.log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails) - t.tab.delete(n) - } - } else if fails > 0 { - // Reset failure counter because it counts _consecutive_ failures. - t.db.UpdateFindFails(n.ID(), n.IP(), 0) - } - - // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll - // just remove those again during revalidation. - for _, n := range r { - t.tab.addSeenNode(n) - } - reply <- r -} - // Resolve searches for a specific node with the given ID and tries to get the most recent // version of the node record for it. It returns n if the node could not be resolved. func (t *UDPv4) Resolve(n *enode.Node) *enode.Node { @@ -498,6 +390,45 @@ func (t *UDPv4) makePing(toaddr *net.UDPAddr) *pingV4 { } } +// LookupPubkey finds the closest nodes to the given public key. +func (t *UDPv4) LookupPubkey(key *ecdsa.PublicKey) []*enode.Node { + if t.tab.len() == 0 { + // All nodes were dropped, refresh. The very first query will hit this + // case and run the bootstrapping logic. + <-t.tab.refresh() + } + return t.newLookup(t.closeCtx, encodePubkey(key)).run() +} + +// RandomNodes is an iterator yielding nodes from a random walk of the DHT. +func (t *UDPv4) RandomNodes() enode.Iterator { + return newLookupIterator(t.closeCtx, t.newRandomLookup) +} + +// lookupRandom implements transport. +func (t *UDPv4) lookupRandom() []*enode.Node { + return t.newRandomLookup(t.closeCtx).run() +} + +// lookupSelf implements transport. +func (t *UDPv4) lookupSelf() []*enode.Node { + return t.newLookup(t.closeCtx, encodePubkey(&t.priv.PublicKey)).run() +} + +func (t *UDPv4) newRandomLookup(ctx context.Context) *lookup { + var target encPubkey + crand.Read(target[:]) + return t.newLookup(ctx, target) +} + +func (t *UDPv4) newLookup(ctx context.Context, targetKey encPubkey) *lookup { + target := enode.ID(crypto.Keccak256Hash(targetKey[:])) + it := newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) { + return t.findnode(n.ID(), n.addr(), targetKey) + }) + return it +} + // findnode sends a findnode request to the given node and waits until // the node has sent up to k neighbors. func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { @@ -575,7 +506,7 @@ func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchF select { case t.addReplyMatcher <- p: // loop will handle it - case <-t.closing: + case <-t.closeCtx.Done(): ch <- errClosed } return p @@ -589,7 +520,7 @@ func (t *UDPv4) handleReply(from enode.ID, fromIP net.IP, req packetV4) bool { case t.gotreply <- reply{from, fromIP, req, matched}: // loop will handle it return <-matched - case <-t.closing: + case <-t.closeCtx.Done(): return false } } @@ -635,7 +566,7 @@ func (t *UDPv4) loop() { resetTimeout() select { - case <-t.closing: + case <-t.closeCtx.Done(): for el := plist.Front(); el != nil; el = el.Next() { el.Value.(*replyMatcher).errc <- errClosed } diff --git a/p2p/enode/iter.go b/p2p/enode/iter.go new file mode 100644 index 000000000000..112b76d06a0e --- /dev/null +++ b/p2p/enode/iter.go @@ -0,0 +1,286 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package enode + +import ( + "sync" + "time" +) + +// Iterator represents a sequence of nodes. The Next method moves to the next node in the +// sequence. It returns false when the sequence has ended or the iterator is closed. Close +// may be called concurrently with Next and Node, and interrupts Next if it is blocked. +type Iterator interface { + Next() bool // moves to next node + Node() *Node // returns current node + Close() // ends the iterator +} + +// ReadNodes reads at most n nodes from the given iterator. The return value contains no +// duplicates and no nil values. To prevent looping indefinitely for small repeating node +// sequences, this function calls Next at most n times. +func ReadNodes(it Iterator, n int) []*Node { + seen := make(map[ID]*Node, n) + for i := 0; i < n && it.Next(); i++ { + // Remove duplicates, keeping the node with higher seq. + node := it.Node() + prevNode, ok := seen[node.ID()] + if ok && prevNode.Seq() > node.Seq() { + continue + } + seen[node.ID()] = node + } + result := make([]*Node, 0, len(seen)) + for _, node := range seen { + result = append(result, node) + } + return result +} + +// IterNodes makes an iterator which runs through the given nodes once. +func IterNodes(nodes []*Node) Iterator { + return &sliceIter{nodes: nodes, index: -1} +} + +// CycleNodes makes an iterator which cycles through the given nodes indefinitely. +func CycleNodes(nodes []*Node) Iterator { + return &sliceIter{nodes: nodes, index: -1, cycle: true} +} + +type sliceIter struct { + mu sync.Mutex + nodes []*Node + index int + cycle bool +} + +func (it *sliceIter) Next() bool { + it.mu.Lock() + defer it.mu.Unlock() + + if len(it.nodes) == 0 { + return false + } + it.index++ + if it.index == len(it.nodes) { + if it.cycle { + it.index = 0 + } else { + it.nodes = nil + return false + } + } + return true +} + +func (it *sliceIter) Node() *Node { + if len(it.nodes) == 0 { + return nil + } + return it.nodes[it.index] +} + +func (it *sliceIter) Close() { + it.mu.Lock() + defer it.mu.Unlock() + + it.nodes = nil +} + +// Filter wraps an iterator such that Next only returns nodes for which +// the 'check' function returns true. +func Filter(it Iterator, check func(*Node) bool) Iterator { + return &filterIter{it, check} +} + +type filterIter struct { + Iterator + check func(*Node) bool +} + +func (f *filterIter) Next() bool { + for f.Iterator.Next() { + if f.check(f.Node()) { + return true + } + } + return false +} + +// FairMix aggregates multiple node iterators. The mixer itself is an iterator which ends +// only when Close is called. Source iterators added via AddSource are removed from the +// mix when they end. +// +// The distribution of nodes returned by Next is approximately fair, i.e. FairMix +// attempts to draw from all sources equally often. However, if a certain source is slow +// and doesn't return a node within the configured timeout, a node from any other source +// will be returned. +// +// It's safe to call AddSource and Close concurrently with Next. +type FairMix struct { + wg sync.WaitGroup + fromAny chan *Node + timeout time.Duration + cur *Node + + mu sync.Mutex + closed chan struct{} + sources []*mixSource + last int +} + +type mixSource struct { + it Iterator + next chan *Node + timeout time.Duration +} + +// NewFairMix creates a mixer. +// +// The timeout specifies how long the mixer will wait for the next fairly-chosen source +// before giving up and taking a node from any other source. A good way to set the timeout +// is deciding how long you'd want to wait for a node on average. Passing a negative +// timeout makes the mixer completely fair. +func NewFairMix(timeout time.Duration) *FairMix { + m := &FairMix{ + fromAny: make(chan *Node), + closed: make(chan struct{}), + timeout: timeout, + } + return m +} + +// AddSource adds a source of nodes. +func (m *FairMix) AddSource(it Iterator) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed == nil { + return + } + m.wg.Add(1) + source := &mixSource{it, make(chan *Node), m.timeout} + m.sources = append(m.sources, source) + go m.runSource(m.closed, source) +} + +// Close shuts down the mixer and all current sources. +// Calling this is required to release resources associated with the mixer. +func (m *FairMix) Close() { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed == nil { + return + } + for _, s := range m.sources { + s.it.Close() + } + close(m.closed) + m.wg.Wait() + close(m.fromAny) + m.sources = nil + m.closed = nil +} + +// Next returns a node from a random source. +func (m *FairMix) Next() bool { + m.cur = nil + + var timeout <-chan time.Time + if m.timeout >= 0 { + timer := time.NewTimer(m.timeout) + timeout = timer.C + defer timer.Stop() + } + for { + source := m.pickSource() + if source == nil { + return m.nextFromAny() + } + select { + case n, ok := <-source.next: + if ok { + m.cur = n + source.timeout = m.timeout + return true + } + // This source has ended. + m.deleteSource(source) + case <-timeout: + source.timeout /= 2 + return m.nextFromAny() + } + } +} + +// Node returns the current node. +func (m *FairMix) Node() *Node { + return m.cur +} + +// nextFromAny is used when there are no sources or when the 'fair' choice +// doesn't turn up a node quickly enough. +func (m *FairMix) nextFromAny() bool { + n, ok := <-m.fromAny + if ok { + m.cur = n + } + return ok +} + +// pickSource chooses the next source to read from, cycling through them in order. +func (m *FairMix) pickSource() *mixSource { + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.sources) == 0 { + return nil + } + m.last = (m.last + 1) % len(m.sources) + return m.sources[m.last] +} + +// deleteSource deletes a source. +func (m *FairMix) deleteSource(s *mixSource) { + m.mu.Lock() + defer m.mu.Unlock() + + for i := range m.sources { + if m.sources[i] == s { + copy(m.sources[i:], m.sources[i+1:]) + m.sources[len(m.sources)-1] = nil + m.sources = m.sources[:len(m.sources)-1] + break + } + } +} + +// runSource reads a single source in a loop. +func (m *FairMix) runSource(closed chan struct{}, s *mixSource) { + defer m.wg.Done() + defer close(s.next) + for s.it.Next() { + n := s.it.Node() + select { + case s.next <- n: + case m.fromAny <- n: + case <-closed: + return + } + } +} diff --git a/p2p/enode/iter_test.go b/p2p/enode/iter_test.go new file mode 100644 index 000000000000..6009661f3ce6 --- /dev/null +++ b/p2p/enode/iter_test.go @@ -0,0 +1,291 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package enode + +import ( + "encoding/binary" + "runtime" + "sync/atomic" + "testing" + "time" + + "github.com/ethereum/go-ethereum/p2p/enr" +) + +func TestReadNodes(t *testing.T) { + nodes := ReadNodes(new(genIter), 10) + checkNodes(t, nodes, 10) +} + +// This test checks that ReadNodes terminates when reading N nodes from an iterator +// which returns less than N nodes in an endless cycle. +func TestReadNodesCycle(t *testing.T) { + iter := &callCountIter{ + Iterator: CycleNodes([]*Node{ + testNode(0, 0), + testNode(1, 0), + testNode(2, 0), + }), + } + nodes := ReadNodes(iter, 10) + checkNodes(t, nodes, 3) + if iter.count != 10 { + t.Fatalf("%d calls to Next, want %d", iter.count, 100) + } +} + +func TestFilterNodes(t *testing.T) { + nodes := make([]*Node, 100) + for i := range nodes { + nodes[i] = testNode(uint64(i), uint64(i)) + } + + it := Filter(IterNodes(nodes), func(n *Node) bool { + return n.Seq() >= 50 + }) + for i := 50; i < len(nodes); i++ { + if !it.Next() { + t.Fatal("Next returned false") + } + if it.Node() != nodes[i] { + t.Fatalf("iterator returned wrong node %v\nwant %v", it.Node(), nodes[i]) + } + } + if it.Next() { + t.Fatal("Next returned true after underlying iterator has ended") + } +} + +func checkNodes(t *testing.T, nodes []*Node, wantLen int) { + if len(nodes) != wantLen { + t.Errorf("slice has %d nodes, want %d", len(nodes), wantLen) + return + } + seen := make(map[ID]bool) + for i, e := range nodes { + if e == nil { + t.Errorf("nil node at index %d", i) + return + } + if seen[e.ID()] { + t.Errorf("slice has duplicate node %v", e.ID()) + return + } + seen[e.ID()] = true + } +} + +// This test checks fairness of FairMix in the happy case where all sources return nodes +// within the context's deadline. +func TestFairMix(t *testing.T) { + for i := 0; i < 500; i++ { + testMixerFairness(t) + } +} + +func testMixerFairness(t *testing.T) { + mix := NewFairMix(1 * time.Second) + mix.AddSource(&genIter{index: 1}) + mix.AddSource(&genIter{index: 2}) + mix.AddSource(&genIter{index: 3}) + defer mix.Close() + + nodes := ReadNodes(mix, 500) + checkNodes(t, nodes, 500) + + // Verify that the nodes slice contains an approximately equal number of nodes + // from each source. + d := idPrefixDistribution(nodes) + for _, count := range d { + if approxEqual(count, len(nodes)/3, 30) { + t.Fatalf("ID distribution is unfair: %v", d) + } + } +} + +// This test checks that FairMix falls back to an alternative source when +// the 'fair' choice doesn't return a node within the timeout. +func TestFairMixNextFromAll(t *testing.T) { + mix := NewFairMix(1 * time.Millisecond) + mix.AddSource(&genIter{index: 1}) + mix.AddSource(CycleNodes(nil)) + defer mix.Close() + + nodes := ReadNodes(mix, 500) + checkNodes(t, nodes, 500) + + d := idPrefixDistribution(nodes) + if len(d) > 1 || d[1] != len(nodes) { + t.Fatalf("wrong ID distribution: %v", d) + } +} + +// This test ensures FairMix works for Next with no sources. +func TestFairMixEmpty(t *testing.T) { + var ( + mix = NewFairMix(1 * time.Second) + testN = testNode(1, 1) + ch = make(chan *Node) + ) + defer mix.Close() + + go func() { + mix.Next() + ch <- mix.Node() + }() + + mix.AddSource(CycleNodes([]*Node{testN})) + if n := <-ch; n != testN { + t.Errorf("got wrong node: %v", n) + } +} + +// This test checks closing a source while Next runs. +func TestFairMixRemoveSource(t *testing.T) { + mix := NewFairMix(1 * time.Second) + source := make(blockingIter) + mix.AddSource(source) + + sig := make(chan *Node) + go func() { + <-sig + mix.Next() + sig <- mix.Node() + }() + + sig <- nil + runtime.Gosched() + source.Close() + + wantNode := testNode(0, 0) + mix.AddSource(CycleNodes([]*Node{wantNode})) + n := <-sig + + if len(mix.sources) != 1 { + t.Fatalf("have %d sources, want one", len(mix.sources)) + } + if n != wantNode { + t.Fatalf("mixer returned wrong node") + } +} + +type blockingIter chan struct{} + +func (it blockingIter) Next() bool { + <-it + return false +} + +func (it blockingIter) Node() *Node { + return nil +} + +func (it blockingIter) Close() { + close(it) +} + +func TestFairMixClose(t *testing.T) { + for i := 0; i < 20 && !t.Failed(); i++ { + testMixerClose(t) + } +} + +func testMixerClose(t *testing.T) { + mix := NewFairMix(-1) + mix.AddSource(CycleNodes(nil)) + mix.AddSource(CycleNodes(nil)) + + done := make(chan struct{}) + go func() { + defer close(done) + if mix.Next() { + t.Error("Next returned true") + } + }() + // This call is supposed to make it more likely that NextNode is + // actually executing by the time we call Close. + runtime.Gosched() + + mix.Close() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("Next didn't unblock on Close") + } + + mix.Close() // shouldn't crash +} + +func idPrefixDistribution(nodes []*Node) map[uint32]int { + d := make(map[uint32]int) + for _, node := range nodes { + id := node.ID() + d[binary.BigEndian.Uint32(id[:4])]++ + } + return d +} + +func approxEqual(x, y, ε int) bool { + if y > x { + x, y = y, x + } + return x-y > ε +} + +// genIter creates fake nodes with numbered IDs based on 'index' and 'gen' +type genIter struct { + node *Node + index, gen uint32 +} + +func (s *genIter) Next() bool { + index := atomic.LoadUint32(&s.index) + if index == ^uint32(0) { + s.node = nil + return false + } + s.node = testNode(uint64(index)<<32|uint64(s.gen), 0) + s.gen++ + return true +} + +func (s *genIter) Node() *Node { + return s.node +} + +func (s *genIter) Close() { + s.index = ^uint32(0) +} + +func testNode(id, seq uint64) *Node { + var nodeID ID + binary.BigEndian.PutUint64(nodeID[:], id) + r := new(enr.Record) + r.SetSeq(seq) + return SignNull(r, nodeID) +} + +// callCountIter counts calls to NextNode. +type callCountIter struct { + Iterator + count int +} + +func (it *callCountIter) Next() bool { + it.count++ + return it.Iterator.Next() +} diff --git a/p2p/protocol.go b/p2p/protocol.go index 9ce4c20203c7..fa23a087c281 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -54,6 +54,11 @@ type Protocol struct { // but returns nil, it is assumed that the protocol handshake is still running. PeerInfo func(id enode.ID) interface{} + // DialCandidates, if non-nil, is a way to tell Server about protocol-specific nodes + // that should be dialed. The server continuously reads nodes from the iterator and + // attempts to create connections to them. + DialCandidates enode.Iterator + // Attributes contains protocol specific information for the node record. Attributes []enr.Entry } diff --git a/p2p/server.go b/p2p/server.go index 692c9eb7d91d..246148741fd9 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -45,6 +45,11 @@ import ( const ( defaultDialTimeout = 15 * time.Second + // This is the fairness knob for the discovery mixer. When looking for peers, we'll + // wait this long for a single source of candidates before moving on and trying other + // sources. + discmixTimeout = 5 * time.Second + // Connectivity defaults. maxActiveDialTasks = 16 defaultMaxPendingPeers = 50 @@ -167,16 +172,20 @@ type Server struct { lock sync.Mutex // protects running running bool - nodedb *enode.DB - localnode *enode.LocalNode - ntab discoverTable listener net.Listener ourHandshake *protoHandshake - DiscV5 *discv5.Network loopWG sync.WaitGroup // loop, listenLoop peerFeed event.Feed log log.Logger + nodedb *enode.DB + localnode *enode.LocalNode + ntab *discover.UDPv4 + DiscV5 *discv5.Network + discmix *enode.FairMix + + staticNodeResolver nodeResolver + // Channels into the run loop. quit chan struct{} addstatic chan *enode.Node @@ -470,7 +479,7 @@ func (srv *Server) Start() (err error) { } dynPeers := srv.maxDialedConns() - dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config) + dialer := newDialState(srv.localnode.ID(), dynPeers, &srv.Config) srv.loopWG.Add(1) go srv.run(dialer) return nil @@ -521,6 +530,18 @@ func (srv *Server) setupLocalNode() error { } func (srv *Server) setupDiscovery() error { + srv.discmix = enode.NewFairMix(discmixTimeout) + + // Add protocol-specific discovery sources. + added := make(map[string]bool) + for _, proto := range srv.Protocols { + if proto.DialCandidates != nil && !added[proto.Name] { + srv.discmix.AddSource(proto.DialCandidates) + added[proto.Name] = true + } + } + + // Don't listen on UDP endpoint if DHT is disabled. if srv.NoDiscovery && !srv.DiscoveryV5 { return nil } @@ -562,7 +583,10 @@ func (srv *Server) setupDiscovery() error { return err } srv.ntab = ntab + srv.discmix.AddSource(ntab.RandomNodes()) + srv.staticNodeResolver = ntab } + // Discovery V5 if srv.DiscoveryV5 { var ntab *discv5.Network @@ -620,6 +644,7 @@ func (srv *Server) run(dialstate dialer) { srv.log.Info("Started P2P networking", "self", srv.localnode.Node().URLv4()) defer srv.loopWG.Done() defer srv.nodedb.Close() + defer srv.discmix.Close() var ( peers = make(map[enode.ID]*Peer) diff --git a/p2p/server_test.go b/p2p/server_test.go index e8bc627e1d30..383445c83388 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -233,8 +233,8 @@ func TestServerTaskScheduling(t *testing.T) { Config: Config{MaxPeers: 10}, localnode: enode.NewLocalNode(db, newkey()), nodedb: db, + discmix: enode.NewFairMix(0), quit: make(chan struct{}), - ntab: fakeTable{}, running: true, log: log.New(), } @@ -282,9 +282,9 @@ func TestServerManyTasks(t *testing.T) { quit: make(chan struct{}), localnode: enode.NewLocalNode(db, newkey()), nodedb: db, - ntab: fakeTable{}, running: true, log: log.New(), + discmix: enode.NewFairMix(0), } done = make(chan *testTask) start, end = 0, 0