diff --git a/core/blockchain.go b/core/blockchain.go index fbb3a51af9..84a229f629 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -94,8 +94,9 @@ const ( diffLayerFreezerBlockLimit = 864000 // The number of diff layers that should be kept in disk. diffLayerPruneRecheckInterval = 1 * time.Second // The interval to prune unverified diff layers maxDiffQueueDist = 2048 // Maximum allowed distance from the chain head to queue diffLayers - maxDiffLimit = 2048 // Maximum number of unique diff layers a peer may have delivered + maxDiffLimit = 2048 // Maximum number of unique diff layers a peer may have responded maxDiffForkDist = 11 // Maximum allowed backward distance from the chain head + maxDiffLimitForBroadcast = 128 // Maximum number of unique diff layers a peer may have broadcasted // BlockChainVersion ensures that an incompatible database forces a resync from scratch. // @@ -2534,6 +2535,34 @@ func (bc *BlockChain) removeDiffLayers(diffHash common.Hash) { } } +func (bc *BlockChain) RemoveDiffPeer(pid string) { + bc.diffMux.Lock() + defer bc.diffMux.Unlock() + if invaliDiffHashes := bc.diffPeersToDiffHashes[pid]; invaliDiffHashes != nil { + for invalidDiffHash := range invaliDiffHashes { + lastDiffHash := false + if peers, ok := bc.diffHashToPeers[invalidDiffHash]; ok { + delete(peers, pid) + if len(peers) == 0 { + lastDiffHash = true + delete(bc.diffHashToPeers, invalidDiffHash) + } + } + if lastDiffHash { + affectedBlockHash := bc.diffHashToBlockHash[invalidDiffHash] + if diffs, exist := bc.blockHashToDiffLayers[affectedBlockHash]; exist { + delete(diffs, invalidDiffHash) + if len(diffs) == 0 { + delete(bc.blockHashToDiffLayers, affectedBlockHash) + } + } + delete(bc.diffHashToBlockHash, invalidDiffHash) + } + } + delete(bc.diffPeersToDiffHashes, pid) + } +} + func (bc *BlockChain) untrustedDiffLayerPruneLoop() { recheck := time.Tick(diffLayerPruneRecheckInterval) bc.wg.Add(1) @@ -2595,7 +2624,7 @@ func (bc *BlockChain) pruneDiffLayer() { } // Process received diff layers -func (bc *BlockChain) HandleDiffLayer(diffLayer *types.DiffLayer, pid string) error { +func (bc *BlockChain) HandleDiffLayer(diffLayer *types.DiffLayer, pid string, fulfilled bool) error { // Basic check currentHeight := bc.CurrentBlock().NumberU64() if diffLayer.Number > currentHeight && diffLayer.Number-currentHeight > maxDiffQueueDist { @@ -2610,6 +2639,13 @@ func (bc *BlockChain) HandleDiffLayer(diffLayer *types.DiffLayer, pid string) er bc.diffMux.Lock() defer bc.diffMux.Unlock() + if !fulfilled { + if len(bc.diffPeersToDiffHashes[pid]) > maxDiffLimitForBroadcast { + log.Error("too many accumulated diffLayers", "pid", pid) + return nil + } + } + if len(bc.diffPeersToDiffHashes[pid]) > maxDiffLimit { log.Error("too many accumulated diffLayers", "pid", pid) return nil @@ -2618,12 +2654,14 @@ func (bc *BlockChain) HandleDiffLayer(diffLayer *types.DiffLayer, pid string) er if _, alreadyHas := bc.diffPeersToDiffHashes[pid][diffLayer.DiffHash]; alreadyHas { return nil } - } else { - bc.diffPeersToDiffHashes[pid] = make(map[common.Hash]struct{}) } + bc.diffPeersToDiffHashes[pid] = make(map[common.Hash]struct{}) bc.diffPeersToDiffHashes[pid][diffLayer.DiffHash] = struct{}{} if _, exist := bc.diffNumToBlockHashes[diffLayer.Number]; !exist { bc.diffNumToBlockHashes[diffLayer.Number] = make(map[common.Hash]struct{}) + } + if len(bc.diffNumToBlockHashes[diffLayer.Number]) > 4 { + } bc.diffNumToBlockHashes[diffLayer.Number][diffLayer.BlockHash] = struct{}{} diff --git a/core/blockchain_diff_test.go b/core/blockchain_diff_test.go index 7df2612b35..14c5426bf6 100644 --- a/core/blockchain_diff_test.go +++ b/core/blockchain_diff_test.go @@ -143,7 +143,7 @@ func TestProcessDiffLayer(t *testing.T) { if err != nil { t.Errorf("failed to decode rawdata %v", err) } - lightBackend.Chain().HandleDiffLayer(diff, "testpid") + lightBackend.Chain().HandleDiffLayer(diff, "testpid", true) _, err = lightBackend.chain.insertChain([]*types.Block{block}, true) if err != nil { t.Errorf("failed to insert block %v", err) @@ -158,7 +158,7 @@ func TestProcessDiffLayer(t *testing.T) { bz, _ := rlp.EncodeToBytes(&latestAccount) diff.Accounts[0].Blob = bz - lightBackend.Chain().HandleDiffLayer(diff, "testpid") + lightBackend.Chain().HandleDiffLayer(diff, "testpid", true) _, err := lightBackend.chain.insertChain([]*types.Block{nextBlock}, true) if err != nil { @@ -216,8 +216,8 @@ func TestPruneDiffLayer(t *testing.T) { header := fullBackend.chain.GetHeaderByNumber(num) rawDiff := fullBackend.chain.GetDiffLayerRLP(header.Hash()) diff, _ := rawDataToDiffLayer(rawDiff) - fullBackend.Chain().HandleDiffLayer(diff, "testpid1") - fullBackend.Chain().HandleDiffLayer(diff, "testpid2") + fullBackend.Chain().HandleDiffLayer(diff, "testpid1", true) + fullBackend.Chain().HandleDiffLayer(diff, "testpid2", true) } fullBackend.chain.pruneDiffLayer() diff --git a/eth/handler_diff.go b/eth/handler_diff.go index ea310c38c2..34453e7762 100644 --- a/eth/handler_diff.go +++ b/eth/handler_diff.go @@ -35,6 +35,7 @@ func (h *diffHandler) RunPeer(peer *diff.Peer, hand diff.Handler) error { if err := peer.Handshake(h.diffSync); err != nil { return err } + defer h.chain.RemoveDiffPeer(peer.ID()) return (*handler)(h).runDiffExtension(peer, hand) } @@ -55,26 +56,34 @@ func (h *diffHandler) Handle(peer *diff.Peer, packet diff.Packet) error { // data packet for the local node to consume. switch packet := packet.(type) { case *diff.DiffLayersPacket: - diffs, err := packet.Unpack() - if err != nil { - return err - } - for _, d := range diffs { - if d != nil { - if err := d.Validate(); err != nil { - return err - } - } - } - for _, diff := range diffs { - err := h.chain.HandleDiffLayer(diff, peer.ID()) - if err != nil { - return err - } - } + return h.handleDiffLayerPackage(packet, peer.ID(), false) + + case *diff.FullDiffLayersPacket: + return h.handleDiffLayerPackage(&packet.DiffLayersPacket, peer.ID(), true) default: return fmt.Errorf("unexpected diff packet type: %T", packet) } return nil } + +func (h *diffHandler) handleDiffLayerPackage(packet *diff.DiffLayersPacket, pid string, fulfilled bool) error { + diffs, err := packet.Unpack() + if err != nil { + return err + } + for _, d := range diffs { + if d != nil { + if err := d.Validate(); err != nil { + return err + } + } + } + for _, diff := range diffs { + err := h.chain.HandleDiffLayer(diff, pid, fulfilled) + if err != nil { + return err + } + } + return nil +} diff --git a/eth/protocols/diff/handler.go b/eth/protocols/diff/handler.go index cf6828b18b..d07035fe9b 100644 --- a/eth/protocols/diff/handler.go +++ b/eth/protocols/diff/handler.go @@ -20,6 +20,8 @@ const ( maxDiffLayerServe = 1024 ) +var requestTracker = NewTracker(time.Minute) + // Handler is a callback to invoke from an outside runner after the boilerplate // exchanges have passed. type Handler func(peer *Peer) error @@ -139,8 +141,11 @@ func handleMessage(backend Backend, peer *Peer) error { if err := msg.Decode(res); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } - requestTracker.Fulfil(peer.id, peer.version, FullDiffLayerMsg, res.RequestId) - return backend.Handle(peer, &res.DiffLayersPacket) + if fulfilled := requestTracker.Fulfil(peer.id, peer.version, FullDiffLayerMsg, res.RequestId); fulfilled { + return backend.Handle(peer, res) + } else { + return fmt.Errorf("%w: %v", errUnexpectedMsg, msg.Code) + } default: return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code) } diff --git a/eth/protocols/diff/protocol.go b/eth/protocols/diff/protocol.go index 650ba4f51e..02474632a5 100644 --- a/eth/protocols/diff/protocol.go +++ b/eth/protocols/diff/protocol.go @@ -58,6 +58,7 @@ var ( errMsgTooLarge = errors.New("message too long") errDecode = errors.New("invalid message") errInvalidMsgCode = errors.New("invalid message code") + errUnexpectedMsg = errors.New("unexpected message code") errBadRequest = errors.New("bad request") errNoCapMsg = errors.New("miss cap message during handshake") ) diff --git a/eth/protocols/diff/tracker.go b/eth/protocols/diff/tracker.go index 754c41258b..7ee49e6ce2 100644 --- a/eth/protocols/diff/tracker.go +++ b/eth/protocols/diff/tracker.go @@ -17,10 +17,145 @@ package diff import ( + "container/list" + "fmt" + "sync" "time" - "github.com/ethereum/go-ethereum/p2p/tracker" + "github.com/ethereum/go-ethereum/log" ) -// requestTracker is a singleton tracker for request times. -var requestTracker = tracker.New(ProtocolName, time.Minute) +const ( + // maxTrackedPackets is a huge number to act as a failsafe on the number of + // pending requests the node will track. It should never be hit unless an + // attacker figures out a way to spin requests. + maxTrackedPackets = 10000 +) + +// request tracks sent network requests which have not yet received a response. +type request struct { + peer string + version uint // Protocol version + + reqCode uint64 // Protocol message code of the request + resCode uint64 // Protocol message code of the expected response + + time time.Time // Timestamp when the request was made + expire *list.Element // Expiration marker to untrack it +} + +type Tracker struct { + timeout time.Duration // Global timeout after which to drop a tracked packet + + pending map[uint64]*request // Currently pending requests + expire *list.List // Linked list tracking the expiration order + wake *time.Timer // Timer tracking the expiration of the next item + + lock sync.Mutex // Lock protecting from concurrent updates +} + +func NewTracker(timeout time.Duration) *Tracker { + return &Tracker{ + timeout: timeout, + pending: make(map[uint64]*request), + expire: list.New(), + } +} + +// Track adds a network request to the tracker to wait for a response to arrive +// or until the request it cancelled or times out. +func (t *Tracker) Track(peer string, version uint, reqCode uint64, resCode uint64, id uint64) { + t.lock.Lock() + defer t.lock.Unlock() + + // If there's a duplicate request, we've just random-collided (or more probably, + // we have a bug), report it. We could also add a metric, but we're not really + // expecting ourselves to be buggy, so a noisy warning should be enough. + if _, ok := t.pending[id]; ok { + log.Error("Network request id collision", "version", version, "code", reqCode, "id", id) + return + } + // If we have too many pending requests, bail out instead of leaking memory + if pending := len(t.pending); pending >= maxTrackedPackets { + log.Error("Request tracker exceeded allowance", "pending", pending, "peer", peer, "version", version, "code", reqCode) + return + } + // Id doesn't exist yet, start tracking it + t.pending[id] = &request{ + peer: peer, + version: version, + reqCode: reqCode, + resCode: resCode, + time: time.Now(), + expire: t.expire.PushBack(id), + } + + // If we've just inserted the first item, start the expiration timer + if t.wake == nil { + t.wake = time.AfterFunc(t.timeout, t.clean) + } +} + +// clean is called automatically when a preset time passes without a response +// being dleivered for the first network request. +func (t *Tracker) clean() { + t.lock.Lock() + defer t.lock.Unlock() + + // Expire anything within a certain threshold (might be no items at all if + // we raced with the delivery) + for t.expire.Len() > 0 { + // Stop iterating if the next pending request is still alive + var ( + head = t.expire.Front() + id = head.Value.(uint64) + req = t.pending[id] + ) + if time.Since(req.time) < t.timeout+5*time.Millisecond { + break + } + // Nope, dead, drop it + t.expire.Remove(head) + delete(t.pending, id) + } + t.schedule() +} + +// schedule starts a timer to trigger on the expiration of the first network +// packet. +func (t *Tracker) schedule() { + if t.expire.Len() == 0 { + t.wake = nil + return + } + t.wake = time.AfterFunc(time.Until(t.pending[t.expire.Front().Value.(uint64)].time.Add(t.timeout)), t.clean) +} + +// Fulfil fills a pending request, if any is available. +func (t *Tracker) Fulfil(peer string, version uint, code uint64, id uint64) bool { + t.lock.Lock() + defer t.lock.Unlock() + + // If it's a non existing request, track as stale response + req, ok := t.pending[id] + if !ok { + return false + } + // If the response is funky, it might be some active attack + if req.peer != peer || req.version != version || req.resCode != code { + log.Warn("Network response id collision", + "have", fmt.Sprintf("%s:/%d:%d", peer, version, code), + "want", fmt.Sprintf("%s:/%d:%d", peer, req.version, req.resCode), + ) + return false + } + // Everything matches, mark the request serviced + t.expire.Remove(req.expire) + delete(t.pending, id) + if req.expire.Prev() == nil { + if t.wake.Stop() { + t.schedule() + } + } + return true +}