From 7e93adce69fc03388034678d653af0f9e5c3d9e7 Mon Sep 17 00:00:00 2001
From: Bruce Riley
Date: Fri, 7 Jun 2024 13:28:59 -0500
Subject: [PATCH 1/2] Node: Minor tweaks and spy improvement

---
 node/cmd/spy/spy.go                 | 33 +---------
 node/pkg/accountant/submit_obs.go   | 17 +----
 node/pkg/common/channel_utils.go    | 20 ++++++
 node/pkg/common/guardianset.go      | 11 +++-
 node/pkg/common/guardianset_test.go |  2 +-
 node/pkg/p2p/p2p.go                 | 98 ++++++++++++++++-------------
 node/pkg/processor/broadcast.go     |  3 +-
 node/pkg/processor/cleanup.go       |  8 +--
 node/pkg/processor/observation.go   |  8 +--
 node/pkg/processor/processor.go     |  2 +-
 10 files changed, 95 insertions(+), 107 deletions(-)
 create mode 100644 node/pkg/common/channel_utils.go

diff --git a/node/cmd/spy/spy.go b/node/cmd/spy/spy.go
index ae616ced6c..e39a3d3f99 100644
--- a/node/cmd/spy/spy.go
+++ b/node/cmd/spy/spy.go
@@ -340,12 +340,6 @@ func runSpy(cmd *cobra.Command, args []string) {
 	// Outbound gossip message queue
 	sendC := make(chan []byte)
 
-	// Inbound observations
-	obsvC := make(chan *common.MsgWithTimeStamp[gossipv1.SignedObservation], 1024)
-
-	// Inbound observation requests
-	obsvReqC := make(chan *gossipv1.ObservationRequest, 1024)
-
 	// Inbound signed VAAs
 	signedInC := make(chan *gossipv1.SignedVAAWithQuorum, 1024)
 
@@ -370,29 +364,6 @@ func runSpy(cmd *cobra.Command, args []string) {
 		}
 	}
 
-	// Ignore observations
-	go func() {
-		for {
-			select {
-			case <-rootCtx.Done():
-				return
-			case <-obsvC:
-			}
-		}
-	}()
-
-	// Ignore observation requests
-	// Note: without this, the whole program hangs on observation requests
-	go func() {
-		for {
-			select {
-			case <-rootCtx.Done():
-				return
-			case <-obsvReqC:
-			}
-		}
-	}()
-
 	// Log signed VAAs
 	go func() {
 		for {
@@ -422,8 +393,8 @@ func runSpy(cmd *cobra.Command, args []string) {
 	components.Port = *p2pPort
 
 	if err := supervisor.Run(ctx, "p2p",
-		p2p.Run(obsvC,
-			obsvReqC,
+		p2p.Run(nil, // Ignore incoming observations.
+			nil, // Ignore observation requests.
 			nil,
 			sendC,
 			signedInC,
diff --git a/node/pkg/accountant/submit_obs.go b/node/pkg/accountant/submit_obs.go
index a21ba58466..7da1695170 100644
--- a/node/pkg/accountant/submit_obs.go
+++ b/node/pkg/accountant/submit_obs.go
@@ -67,7 +67,7 @@ func (acct *Accountant) handleBatch(ctx context.Context, subChan chan *common.Me
 	ctx, cancel := context.WithTimeout(ctx, delayInMS)
 	defer cancel()
 
-	msgs, err := readFromChannel[*common.MessagePublication](ctx, subChan, batchSize)
+	msgs, err := common.ReadFromChannelWithTimeout[*common.MessagePublication](ctx, subChan, batchSize)
 	if err != nil && !errors.Is(err, context.DeadlineExceeded) {
 		return fmt.Errorf("failed to read messages from channel for %s: %w", tag, err)
 	}
@@ -95,21 +95,6 @@ func (acct *Accountant) handleBatch(ctx context.Context, subChan chan *common.Me
 	return nil
 }
 
-// readFromChannel reads events from the channel until a timeout occurs or the batch is full, and returns them.
-func readFromChannel[T any](ctx context.Context, ch <-chan T, count int) ([]T, error) {
-	out := make([]T, 0, count)
-	for len(out) < count {
-		select {
-		case <-ctx.Done():
-			return out, ctx.Err()
-		case msg := <-ch:
-			out = append(out, msg)
-		}
-	}
-
-	return out, nil
-}
-
 // removeCompleted drops any messages that are no longer in the pending transfer map. This is to handle the case where the contract reports
 // that a transfer is committed while it is in the channel. There is no point in submitting the observation once the transfer is committed.
 func (acct *Accountant) removeCompleted(msgs []*common.MessagePublication) []*common.MessagePublication {
diff --git a/node/pkg/common/channel_utils.go b/node/pkg/common/channel_utils.go
new file mode 100644
index 0000000000..60c3ff522c
--- /dev/null
+++ b/node/pkg/common/channel_utils.go
@@ -0,0 +1,20 @@
+package common
+
+import (
+	"context"
+)
+
+// ReadFromChannelWithTimeout reads events from the channel until a timeout occurs or maxCount events have been read, and returns them.
+func ReadFromChannelWithTimeout[T any](ctx context.Context, ch <-chan T, maxCount int) ([]T, error) {
+	out := make([]T, 0, maxCount)
+	for len(out) < maxCount {
+		select {
+		case <-ctx.Done():
+			return out, ctx.Err()
+		case msg := <-ch:
+			out = append(out, msg)
+		}
+	}
+
+	return out, nil
+}
diff --git a/node/pkg/common/guardianset.go b/node/pkg/common/guardianset.go
index 77f14dd0c8..ea153a90d4 100644
--- a/node/pkg/common/guardianset.go
+++ b/node/pkg/common/guardianset.go
@@ -54,8 +54,8 @@ type GuardianSet struct {
 	// On-chain set index
 	Index uint32
 
-	// Quorum value for this set of keys
-	Quorum int
+	// quorum value for this set of keys
+	quorum int
 
 	// A map from address to index. Testing showed that, on average, a map is almost three times faster than a sequential search of the key slice.
 	// Testing also showed that the map was twice as fast as using a sorted slice and `slices.BinarySearchFunc`. That being said, on a 4GHz CPU,
@@ -63,6 +63,11 @@ type GuardianSet struct {
 	keyMap map[common.Address]int
 }
 
+// Quorum returns the current quorum value.
+func (gs *GuardianSet) Quorum() int {
+	return gs.quorum
+}
+
 func NewGuardianSet(keys []common.Address, index uint32) *GuardianSet {
 	keyMap := map[common.Address]int{}
 	for idx, key := range keys {
@@ -71,7 +76,7 @@ func NewGuardianSet(keys []common.Address, index uint32) *GuardianSet {
 	return &GuardianSet{
 		Keys:   keys,
 		Index:  index,
-		Quorum: vaa.CalculateQuorum(len(keys)),
+		quorum: vaa.CalculateQuorum(len(keys)),
 		keyMap: keyMap,
 	}
 }
diff --git a/node/pkg/common/guardianset_test.go b/node/pkg/common/guardianset_test.go
index de936630b0..0112c38eee 100644
--- a/node/pkg/common/guardianset_test.go
+++ b/node/pkg/common/guardianset_test.go
@@ -34,7 +34,7 @@ func TestNewGuardianSet(t *testing.T) {
 	gs := NewGuardianSet(keys, 1)
 	assert.True(t, reflect.DeepEqual(keys, gs.Keys))
 	assert.Equal(t, uint32(1), gs.Index)
-	assert.Equal(t, vaa.CalculateQuorum(len(keys)), gs.Quorum)
+	assert.Equal(t, vaa.CalculateQuorum(len(keys)), gs.Quorum())
 }
 
 func TestKeyIndex(t *testing.T) {
diff --git a/node/pkg/p2p/p2p.go b/node/pkg/p2p/p2p.go
index 8afb4bddf4..e28d1b41eb 100644
--- a/node/pkg/p2p/p2p.go
+++ b/node/pkg/p2p/p2p.go
@@ -590,7 +590,9 @@ func Run(
 				}
 
 				// Send to local observation request queue (the loopback message is ignored)
-				obsvReqC <- msg
+				if obsvReqC != nil {
+					obsvReqC <- msg
+				}
 
 				err = th.Publish(ctx, b)
 				p2pMessagesSent.Inc()
@@ -699,59 +701,65 @@
 					}()
 				}
 			case *gossipv1.GossipMessage_SignedObservation:
-				if err := common.PostMsgWithTimestamp[gossipv1.SignedObservation](m.SignedObservation, obsvC); err == nil {
-					p2pMessagesReceived.WithLabelValues("observation").Inc()
-				} else {
-					if components.WarnChannelOverflow {
-						logger.Warn("Ignoring SignedObservation because obsvC full", zap.String("hash", hex.EncodeToString(m.SignedObservation.Hash)))
+				if obsvC != nil {
+					if err := common.PostMsgWithTimestamp[gossipv1.SignedObservation](m.SignedObservation, obsvC); err == nil {
+						p2pMessagesReceived.WithLabelValues("observation").Inc()
+					} else {
+						if components.WarnChannelOverflow {
logger.Warn("Ignoring SignedObservation because obsvC full", zap.String("hash", hex.EncodeToString(m.SignedObservation.Hash))) + } + p2pReceiveChannelOverflow.WithLabelValues("observation").Inc() } - p2pReceiveChannelOverflow.WithLabelValues("observation").Inc() } case *gossipv1.GossipMessage_SignedVaaWithQuorum: - select { - case signedInC <- m.SignedVaaWithQuorum: - p2pMessagesReceived.WithLabelValues("signed_vaa_with_quorum").Inc() - default: - if components.WarnChannelOverflow { - // TODO do not log this in production - var hexStr string - if vaa, err := vaa.Unmarshal(m.SignedVaaWithQuorum.Vaa); err == nil { - hexStr = vaa.HexDigest() + if signedInC != nil { + select { + case signedInC <- m.SignedVaaWithQuorum: + p2pMessagesReceived.WithLabelValues("signed_vaa_with_quorum").Inc() + default: + if components.WarnChannelOverflow { + // TODO do not log this in production + var hexStr string + if vaa, err := vaa.Unmarshal(m.SignedVaaWithQuorum.Vaa); err == nil { + hexStr = vaa.HexDigest() + } + logger.Warn("Ignoring SignedVaaWithQuorum because signedInC full", zap.String("hash", hexStr)) } - logger.Warn("Ignoring SignedVaaWithQuorum because signedInC full", zap.String("hash", hexStr)) + p2pReceiveChannelOverflow.WithLabelValues("signed_vaa_with_quorum").Inc() } - p2pReceiveChannelOverflow.WithLabelValues("signed_vaa_with_quorum").Inc() } case *gossipv1.GossipMessage_SignedObservationRequest: - s := m.SignedObservationRequest - gs := gst.Get() - if gs == nil { - if logger.Level().Enabled(zapcore.DebugLevel) { - logger.Debug("dropping SignedObservationRequest - no guardian set", zap.Any("value", s), zap.String("from", envelope.GetFrom().String())) - } - break - } - r, err := processSignedObservationRequest(s, gs) - if err != nil { - p2pMessagesReceived.WithLabelValues("invalid_signed_observation_request").Inc() - if logger.Level().Enabled(zapcore.DebugLevel) { - logger.Debug("invalid signed observation request received", - zap.Error(err), - zap.Any("payload", msg.Message), - zap.Any("value", s), - zap.Binary("raw", envelope.Data), - zap.String("from", envelope.GetFrom().String())) - } - } else { - if logger.Level().Enabled(zapcore.DebugLevel) { - logger.Debug("valid signed observation request received", zap.Any("value", r), zap.String("from", envelope.GetFrom().String())) + if obsvReqC != nil { + s := m.SignedObservationRequest + gs := gst.Get() + if gs == nil { + if logger.Level().Enabled(zapcore.DebugLevel) { + logger.Debug("dropping SignedObservationRequest - no guardian set", zap.Any("value", s), zap.String("from", envelope.GetFrom().String())) + } + break } + r, err := processSignedObservationRequest(s, gs) + if err != nil { + p2pMessagesReceived.WithLabelValues("invalid_signed_observation_request").Inc() + if logger.Level().Enabled(zapcore.DebugLevel) { + logger.Debug("invalid signed observation request received", + zap.Error(err), + zap.Any("payload", msg.Message), + zap.Any("value", s), + zap.Binary("raw", envelope.Data), + zap.String("from", envelope.GetFrom().String())) + } + } else { + if logger.Level().Enabled(zapcore.DebugLevel) { + logger.Debug("valid signed observation request received", zap.Any("value", r), zap.String("from", envelope.GetFrom().String())) + } - select { - case obsvReqC <- r: - p2pMessagesReceived.WithLabelValues("signed_observation_request").Inc() - default: - p2pReceiveChannelOverflow.WithLabelValues("signed_observation_request").Inc() + select { + case obsvReqC <- r: + p2pMessagesReceived.WithLabelValues("signed_observation_request").Inc() + default: + 
p2pReceiveChannelOverflow.WithLabelValues("signed_observation_request").Inc() + } } } case *gossipv1.GossipMessage_SignedChainGovernorConfig: diff --git a/node/pkg/processor/broadcast.go b/node/pkg/processor/broadcast.go index 40ecfa665a..641332de33 100644 --- a/node/pkg/processor/broadcast.go +++ b/node/pkg/processor/broadcast.go @@ -8,7 +8,6 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" ethcommon "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" "google.golang.org/protobuf/proto" node_common "github.com/certusone/wormhole/node/pkg/common" @@ -43,7 +42,7 @@ func (p *Processor) broadcastSignature( ) { digest := o.SigningDigest() obsv := gossipv1.SignedObservation{ - Addr: crypto.PubkeyToAddress(p.gk.PublicKey).Bytes(), + Addr: p.ourAddr.Bytes(), Hash: digest.Bytes(), Signature: signature, TxHash: txhash, diff --git a/node/pkg/processor/cleanup.go b/node/pkg/processor/cleanup.go index 31120409dc..bcc5e6cd6d 100644 --- a/node/pkg/processor/cleanup.go +++ b/node/pkg/processor/cleanup.go @@ -115,7 +115,7 @@ func (p *Processor) handleCleanup(ctx context.Context) { } hasSigs := len(s.signatures) - quorum := hasSigs >= gs.Quorum + quorum := hasSigs >= gs.Quorum() var chain vaa.ChainID if s.ourObservation != nil { @@ -128,7 +128,7 @@ func (p *Processor) handleCleanup(ctx context.Context) { zap.String("digest", hash), zap.Duration("delta", delta), zap.Int("have_sigs", hasSigs), - zap.Int("required_sigs", gs.Quorum), + zap.Int("required_sigs", gs.Quorum()), zap.Bool("quorum", quorum), zap.Stringer("emitter_chain", chain), ) @@ -245,8 +245,8 @@ func (p *Processor) handleCleanup(ctx context.Context) { zap.String("digest", hash), zap.Duration("delta", delta), zap.Int("have_sigs", hasSigs), - zap.Int("required_sigs", p.gs.Quorum), - zap.Bool("quorum", hasSigs >= p.gs.Quorum), + zap.Int("required_sigs", p.gs.Quorum()), + zap.Bool("quorum", hasSigs >= p.gs.Quorum()), ) } delete(p.state.signatures, hash) diff --git a/node/pkg/processor/observation.go b/node/pkg/processor/observation.go index ca96083e34..e1545ce49c 100644 --- a/node/pkg/processor/observation.go +++ b/node/pkg/processor/observation.go @@ -228,7 +228,7 @@ func (p *Processor) handleObservation(ctx context.Context, obs *node_common.MsgW // Hence, if len(s.signatures) < quorum, then there is definitely no quorum and we can return early to save additional computation, // but if len(s.signatures) >= quorum, there is not necessarily quorum for the active guardian set. // We will later check for quorum again after assembling the VAA for a particular guardian set. 
-	if len(s.signatures) < gs.Quorum {
+	if len(s.signatures) < gs.Quorum() {
 		// no quorum yet, we're done here
 		if p.logger.Level().Enabled(zapcore.DebugLevel) {
 			p.logger.Debug("quorum not yet met",
@@ -250,13 +250,13 @@ func (p *Processor) handleObservation(ctx context.Context, obs *node_common.MsgW
 			zap.Any("set", gs.KeysAsHexStrings()),
 			zap.Uint32("index", gs.Index),
 			zap.Bools("aggregation", agg),
-			zap.Int("required_sigs", gs.Quorum),
+			zap.Int("required_sigs", gs.Quorum()),
 			zap.Int("have_sigs", len(sigsVaaFormat)),
-			zap.Bool("quorum", len(sigsVaaFormat) >= gs.Quorum),
+			zap.Bool("quorum", len(sigsVaaFormat) >= gs.Quorum()),
 		)
 	}
 
-	if len(sigsVaaFormat) >= gs.Quorum {
+	if len(sigsVaaFormat) >= gs.Quorum() {
 		// we have reached quorum *with the active guardian set*
 		s.ourObservation.HandleQuorum(sigsVaaFormat, hash, p)
 	} else {
diff --git a/node/pkg/processor/processor.go b/node/pkg/processor/processor.go
index 8264ee1841..b1e88268ef 100644
--- a/node/pkg/processor/processor.go
+++ b/node/pkg/processor/processor.go
@@ -223,7 +223,7 @@ func (p *Processor) Run(ctx context.Context) error {
 			p.logger.Info("guardian set updated",
 				zap.Strings("set", p.gs.KeysAsHexStrings()),
 				zap.Uint32("index", p.gs.Index),
-				zap.Int("quorum", p.gs.Quorum),
+				zap.Int("quorum", p.gs.Quorum()),
 			)
 			p.gst.Set(p.gs)
 		case k := <-p.msgC:

From ed0137bfc64f73e70929034e615bc1cf1d3c2818 Mon Sep 17 00:00:00 2001
From: Bruce Riley
Date: Sat, 8 Jun 2024 08:40:35 -0500
Subject: [PATCH 2/2] Add tests

---
 node/pkg/common/channel_utils_test.go | 80 +++++++++++++++++++++++++++
 1 file changed, 80 insertions(+)
 create mode 100644 node/pkg/common/channel_utils_test.go

diff --git a/node/pkg/common/channel_utils_test.go b/node/pkg/common/channel_utils_test.go
new file mode 100644
index 0000000000..cfa524755c
--- /dev/null
+++ b/node/pkg/common/channel_utils_test.go
@@ -0,0 +1,80 @@
+package common
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+const myDelay = time.Millisecond * 100
+const myMaxSize = 2
+const myQueueSize = myMaxSize * 10
+
+func TestReadFromChannelWithTimeout_NoData(t *testing.T) {
+	ctx := context.Background()
+	myChan := make(chan int, myQueueSize)
+
+	// No data should timeout.
+	timeout, cancel := context.WithTimeout(ctx, myDelay)
+	defer cancel()
+	observations, err := ReadFromChannelWithTimeout[int](timeout, myChan, myMaxSize)
+	assert.Equal(t, err, context.DeadlineExceeded)
+	assert.Equal(t, 0, len(observations))
+}
+
+func TestReadFromChannelWithTimeout_SomeData(t *testing.T) {
+	ctx := context.Background()
+	myChan := make(chan int, myQueueSize)
+	myChan <- 1
+
+	// Some data but not enough to fill a message should timeout and return the data.
+	timeout, cancel := context.WithTimeout(ctx, myDelay)
+	defer cancel()
+	observations, err := ReadFromChannelWithTimeout[int](timeout, myChan, myMaxSize)
+	assert.Equal(t, err, context.DeadlineExceeded)
+	require.Equal(t, 1, len(observations))
+	assert.Equal(t, 1, observations[0])
+}
+
+func TestReadFromChannelWithTimeout_JustEnoughData(t *testing.T) {
+	ctx := context.Background()
+	myChan := make(chan int, myQueueSize)
+	myChan <- 1
+	myChan <- 2
+
+	// Just enough data should return the data and no error.
+	timeout, cancel := context.WithTimeout(ctx, myDelay)
+	defer cancel()
+	observations, err := ReadFromChannelWithTimeout[int](timeout, myChan, myMaxSize)
+	assert.NoError(t, err)
+	require.Equal(t, 2, len(observations))
+	assert.Equal(t, 1, observations[0])
+	assert.Equal(t, 2, observations[1])
+}
+
+func TestReadFromChannelWithTimeout_TooMuchData(t *testing.T) {
+	ctx := context.Background()
+	myChan := make(chan int, myQueueSize)
+	myChan <- 1
+	myChan <- 2
+	myChan <- 3
+
+	// If there is more data than will fit, it should immediately return a full message, then timeout and return the remainder.
+	timeout, cancel := context.WithTimeout(ctx, myDelay)
+	defer cancel()
+	observations, err := ReadFromChannelWithTimeout[int](timeout, myChan, myMaxSize)
+	assert.NoError(t, err)
+	require.Equal(t, 2, len(observations))
+	assert.Equal(t, 1, observations[0])
+	assert.Equal(t, 2, observations[1])
+
+	timeout2, cancel2 := context.WithTimeout(ctx, myDelay)
+	defer cancel2()
+	observations, err = ReadFromChannelWithTimeout[int](timeout2, myChan, myMaxSize)
+	assert.Equal(t, err, context.DeadlineExceeded)
+	require.Equal(t, 1, len(observations))
+	assert.Equal(t, 3, observations[0])
+}
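
Editorial note, not part of the patches above: the batching pattern that submit_obs.go now relies on is worth seeing end to end. The sketch below shows a caller draining a channel with the new common.ReadFromChannelWithTimeout helper; the import path and signature come from the diff, while the element type (int), batch size, and one-second timeout are illustrative assumptions.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/certusone/wormhole/node/pkg/common"
)

// drainBatch reads up to ten items from ch, waiting at most one second.
// context.DeadlineExceeded simply means the batch came back partially full.
func drainBatch(ctx context.Context, ch <-chan int) ([]int, error) {
	timeout, cancel := context.WithTimeout(ctx, time.Second)
	defer cancel()

	batch, err := common.ReadFromChannelWithTimeout[int](timeout, ch, 10)
	if err != nil && !errors.Is(err, context.DeadlineExceeded) {
		return nil, err
	}
	return batch, nil
}

func main() {
	ch := make(chan int, 16)
	ch <- 1
	ch <- 2

	// Only two items are queued, so this returns after the timeout with a partial batch.
	batch, err := drainBatch(context.Background(), ch)
	if err != nil {
		panic(err)
	}
	fmt.Printf("drained %d items: %v\n", len(batch), batch)
}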
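A related note on the GuardianSet change: since the Quorum field is now unexported, downstream code reads the value through the new Quorum() accessor, and the value is computed once in NewGuardianSet via vaa.CalculateQuorum. A minimal sketch of the updated call-site pattern is below; the guardian address is a placeholder, while NewGuardianSet, Index, and Quorum() are taken from the diff.

package main

import (
	"fmt"

	node_common "github.com/certusone/wormhole/node/pkg/common"
	ethcommon "github.com/ethereum/go-ethereum/common"
)

func main() {
	// A single placeholder guardian key; real deployments use the on-chain set.
	keys := []ethcommon.Address{
		ethcommon.HexToAddress("0x0000000000000000000000000000000000000001"),
	}

	gs := node_common.NewGuardianSet(keys, 0)

	// gs.Quorum was a struct field before this patch; it is now a method.
	fmt.Printf("guardian set %d needs %d signatures for quorum\n", gs.Index, gs.Quorum())
}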
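Finally, a short standalone illustration of why the nil checks added in p2p.Run matter: in Go, a send on a nil channel blocks forever, so a consumer such as the spy can only pass nil for feeds it ignores if every send site first guards against nil (and uses a non-blocking send when the channel is full, mirroring the overflow metrics in the diff). This is a generic sketch of that pattern, not code from the repository:

package main

import "fmt"

// forward mimics the guarded, non-blocking send pattern used in p2p.Run.
func forward(out chan<- string, msg string) string {
	if out == nil {
		return "dropped: consumer opted out" // nil means the caller does not want this feed
	}
	select {
	case out <- msg:
		return "delivered"
	default:
		return "dropped: channel full" // would be counted as a channel overflow
	}
}

func main() {
	fmt.Println(forward(nil, "observation")) // safe: no goroutine needed to drain an unused feed

	ch := make(chan string, 1)
	fmt.Println(forward(ch, "observation"))
	fmt.Println(forward(ch, "observation")) // buffer already full, message is dropped
}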