Skip to content

Commit

Permalink
audit background contexts (#14869)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmank88 authored Nov 6, 2024
1 parent c042b62 commit 9e899bb
Show file tree
Hide file tree
Showing 59 changed files with 519 additions and 625 deletions.
3 changes: 2 additions & 1 deletion core/capabilities/ccip/launcher/integration_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package launcher

import (
"context"
"testing"
"time"

Expand Down Expand Up @@ -115,7 +116,7 @@ type oracleCreatorPrints struct {
t *testing.T
}

func (o *oracleCreatorPrints) Create(_ uint32, config cctypes.OCR3ConfigWithMeta) (cctypes.CCIPOracle, error) {
func (o *oracleCreatorPrints) Create(ctx context.Context, _ uint32, config cctypes.OCR3ConfigWithMeta) (cctypes.CCIPOracle, error) {
pluginType := cctypes.PluginType(config.Config.PluginType)
o.t.Logf("Creating plugin oracle (pluginType: %s) with config %+v\n", pluginType, config)
return &oraclePrints{pluginType: pluginType, config: config, t: o.t}, nil
Expand Down
41 changes: 25 additions & 16 deletions core/capabilities/ccip/launcher/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ type launcher struct {
myP2PID ragep2ptypes.PeerID
lggr logger.Logger
homeChainReader ccipreader.HomeChain
stopChan chan struct{}
stopChan services.StopChan
// latestState is the latest capability registry state received from the syncer.
latestState registrysyncer.LocalRegistry
// regState is the latest capability registry state that we have successfully processed.
Expand Down Expand Up @@ -140,12 +140,16 @@ func (l *launcher) Start(context.Context) error {
func (l *launcher) monitor() {
defer l.wg.Done()
ticker := time.NewTicker(l.tickInterval)

ctx, cancel := l.stopChan.NewCtx()
defer cancel()

for {
select {
case <-l.stopChan:
case <-ctx.Done():
return
case <-ticker.C:
if err := l.tick(); err != nil {
if err := l.tick(ctx); err != nil {
l.lggr.Errorw("Failed to tick", "err", err)
}
}
Expand All @@ -154,7 +158,7 @@ func (l *launcher) monitor() {

// tick gets the latest registry state and processes the diff between the current and latest state.
// This may lead to starting or stopping OCR instances.
func (l *launcher) tick() error {
func (l *launcher) tick(ctx context.Context) error {
// Ensure that the home chain reader is healthy.
// For new jobs it may be possible that the home chain reader is not yet ready
// so we won't be able to fetch configs and start any OCR instances.
Expand All @@ -171,7 +175,7 @@ func (l *launcher) tick() error {
return fmt.Errorf("failed to diff capability registry states: %w", err)
}

err = l.processDiff(diffRes)
err = l.processDiff(ctx, diffRes)
if err != nil {
return fmt.Errorf("failed to process diff: %w", err)
}
Expand All @@ -183,17 +187,17 @@ func (l *launcher) tick() error {
// for any added OCR instances, it will launch them.
// for any removed OCR instances, it will shut them down.
// for any updated OCR instances, it will restart them with the new configuration.
func (l *launcher) processDiff(diff diffResult) error {
func (l *launcher) processDiff(ctx context.Context, diff diffResult) error {
err := l.processRemoved(diff.removed)
err = multierr.Append(err, l.processAdded(diff.added))
err = multierr.Append(err, l.processUpdate(diff.updated))
err = multierr.Append(err, l.processAdded(ctx, diff.added))
err = multierr.Append(err, l.processUpdate(ctx, diff.updated))

return err
}

// processUpdate will manage when configurations of an existing don are updated
// If new oracles are needed, they are created and started. Old ones will be shut down
func (l *launcher) processUpdate(updated map[registrysyncer.DonID]registrysyncer.DON) error {
func (l *launcher) processUpdate(ctx context.Context, updated map[registrysyncer.DonID]registrysyncer.DON) error {
l.lock.Lock()
defer l.lock.Unlock()

Expand All @@ -203,12 +207,13 @@ func (l *launcher) processUpdate(updated map[registrysyncer.DonID]registrysyncer
return fmt.Errorf("invariant violation: expected to find CCIP DON %d in the map of running deployments", don.ID)
}

latestConfigs, err := getConfigsForDon(l.homeChainReader, don)
latestConfigs, err := getConfigsForDon(ctx, l.homeChainReader, don)
if err != nil {
return err
}

newPlugins, err := updateDON(
ctx,
l.lggr,
l.myP2PID,
prevPlugins,
Expand All @@ -233,16 +238,17 @@ func (l *launcher) processUpdate(updated map[registrysyncer.DonID]registrysyncer

// processAdded is for when a new don is created. We know that all oracles
// must be created and started
func (l *launcher) processAdded(added map[registrysyncer.DonID]registrysyncer.DON) error {
func (l *launcher) processAdded(ctx context.Context, added map[registrysyncer.DonID]registrysyncer.DON) error {
l.lock.Lock()
defer l.lock.Unlock()

for donID, don := range added {
configs, err := getConfigsForDon(l.homeChainReader, don)
configs, err := getConfigsForDon(ctx, l.homeChainReader, don)
if err != nil {
return fmt.Errorf("failed to get current configs for don %d: %w", donID, err)
}
newPlugins, err := createDON(
ctx,
l.lggr,
l.myP2PID,
don,
Expand Down Expand Up @@ -300,6 +306,7 @@ func (l *launcher) processRemoved(removed map[registrysyncer.DonID]registrysynce
}

func updateDON(
ctx context.Context,
lggr logger.Logger,
p2pID ragep2ptypes.PeerID,
prevPlugins pluginRegistry,
Expand All @@ -318,7 +325,7 @@ func updateDON(
for _, c := range latestConfigs {
digest := c.ConfigDigest
if _, ok := prevPlugins[digest]; !ok {
oracle, err := oracleCreator.Create(don.ID, cctypes.OCR3ConfigWithMeta(c))
oracle, err := oracleCreator.Create(ctx, don.ID, cctypes.OCR3ConfigWithMeta(c))
if err != nil {
return nil, fmt.Errorf("failed to create CCIP oracle: %w for digest %x", err, digest)
}
Expand All @@ -335,6 +342,7 @@ func updateDON(
// createDON is a pure function that handles the case where a new DON is added to the capability registry.
// It returns up to 4 plugins that are later started.
func createDON(
ctx context.Context,
lggr logger.Logger,
p2pID ragep2ptypes.PeerID,
don registrysyncer.DON,
Expand All @@ -352,7 +360,7 @@ func createDON(
return nil, fmt.Errorf("digest does not match type %w", err)
}

oracle, err := oracleCreator.Create(don.ID, cctypes.OCR3ConfigWithMeta(config))
oracle, err := oracleCreator.Create(ctx, don.ID, cctypes.OCR3ConfigWithMeta(config))
if err != nil {
return nil, fmt.Errorf("failed to create CCIP oracle: %w for digest %x", err, digest)
}
Expand All @@ -363,16 +371,17 @@ func createDON(
}

func getConfigsForDon(
ctx context.Context,
homeChainReader ccipreader.HomeChain,
don registrysyncer.DON) ([]ccipreader.OCR3ConfigWithMeta, error) {
// this should be a retryable error.
commitOCRConfigs, err := homeChainReader.GetOCRConfigs(context.Background(), don.ID, uint8(cctypes.PluginTypeCCIPCommit))
commitOCRConfigs, err := homeChainReader.GetOCRConfigs(ctx, don.ID, uint8(cctypes.PluginTypeCCIPCommit))
if err != nil {
return nil, fmt.Errorf("failed to fetch OCR configs for CCIP commit plugin (don id: %d) from home chain config contract: %w",
don.ID, err)
}

execOCRConfigs, err := homeChainReader.GetOCRConfigs(context.Background(), don.ID, uint8(cctypes.PluginTypeCCIPExec))
execOCRConfigs, err := homeChainReader.GetOCRConfigs(ctx, don.ID, uint8(cctypes.PluginTypeCCIPExec))
if err != nil {
return nil, fmt.Errorf("failed to fetch OCR configs for CCIP exec plugin (don id: %d) from home chain config contract: %w",
don.ID, err)
Expand Down
43 changes: 23 additions & 20 deletions core/capabilities/ccip/launcher/launcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

cctypes "github.com/smartcontractkit/chainlink/v2/core/capabilities/ccip/types"
"github.com/smartcontractkit/chainlink/v2/core/capabilities/ccip/types/mocks"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils"

ragep2ptypes "github.com/smartcontractkit/libocr/ragep2p/types"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -113,7 +114,7 @@ func Test_createDON(t *testing.T) {
},
}, nil)
oracleCreator.EXPECT().Type().Return(cctypes.OracleTypeBootstrap).Once()
oracleCreator.EXPECT().Create(mock.Anything, mock.Anything).Return(mocks.NewCCIPOracle(t), nil).Twice()
oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.Anything).Return(mocks.NewCCIPOracle(t), nil).Twice()
},
false,
},
Expand Down Expand Up @@ -153,11 +154,11 @@ func Test_createDON(t *testing.T) {
},
}, nil)

oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit)
})).
Return(mocks.NewCCIPOracle(t), nil)
oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec)
})).
Return(mocks.NewCCIPOracle(t), nil)
Expand Down Expand Up @@ -212,11 +213,11 @@ func Test_createDON(t *testing.T) {
},
}, nil)

oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit)
})).
Return(mocks.NewCCIPOracle(t), nil).Twice()
oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec)
})).
Return(mocks.NewCCIPOracle(t), nil).Twice()
Expand All @@ -229,10 +230,11 @@ func Test_createDON(t *testing.T) {
if tt.expect != nil {
tt.expect(t, tt.args, tt.args.oracleCreator, tt.args.homeChainReader)
}
ctx := testutils.Context(t)

latestConfigs, err := getConfigsForDon(tt.args.homeChainReader, tt.args.don)
latestConfigs, err := getConfigsForDon(ctx, tt.args.homeChainReader, tt.args.don)
require.NoError(t, err)
_, err = createDON(tt.args.lggr, tt.args.p2pID, tt.args.don, tt.args.oracleCreator, latestConfigs)
_, err = createDON(ctx, tt.args.lggr, tt.args.p2pID, tt.args.don, tt.args.oracleCreator, latestConfigs)
if tt.wantErr {
require.Error(t, err)
} else {
Expand Down Expand Up @@ -304,11 +306,11 @@ func Test_updateDON(t *testing.T) {
ConfigDigest: utils.RandomBytes32(),
},
}, nil)
oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit)
})).
Return(mocks.NewCCIPOracle(t), nil)
oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec)
})).
Return(mocks.NewCCIPOracle(t), nil)
Expand Down Expand Up @@ -405,11 +407,11 @@ func Test_updateDON(t *testing.T) {
ConfigDigest: utils.RandomBytes32(),
},
}, nil)
oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit)
})).
Return(mocks.NewCCIPOracle(t), nil).Once()
oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec)
})).
Return(mocks.NewCCIPOracle(t), nil).Once()
Expand Down Expand Up @@ -472,11 +474,11 @@ func Test_updateDON(t *testing.T) {
ConfigDigest: utils.RandomBytes32(),
},
}, nil)
oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit)
})).
Return(mocks.NewCCIPOracle(t), nil).Twice()
oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec)
})).
Return(mocks.NewCCIPOracle(t), nil).Twice()
Expand All @@ -489,10 +491,11 @@ func Test_updateDON(t *testing.T) {
if tt.expect != nil {
tt.expect(t, tt.args, tt.args.oracleCreator, tt.args.homeChainReader)
}
ctx := testutils.Context(t)

latestConfigs, err := getConfigsForDon(tt.args.homeChainReader, tt.args.don)
latestConfigs, err := getConfigsForDon(ctx, tt.args.homeChainReader, tt.args.don)
require.NoError(t, err)
newPlugins, err := updateDON(tt.args.lggr, tt.args.p2pID, tt.args.prevPlugins, tt.args.don, tt.args.oracleCreator, latestConfigs)
newPlugins, err := updateDON(ctx, tt.args.lggr, tt.args.p2pID, tt.args.prevPlugins, tt.args.don, tt.args.oracleCreator, latestConfigs)
if (err != nil) != tt.wantErr {
t.Errorf("updateDON() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -602,11 +605,11 @@ func Test_launcher_processDiff(t *testing.T) {
commitOracle.On("Start").Return(nil)
execOracle := mocks.NewCCIPOracle(t)
execOracle.On("Start").Return(nil)
m.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
m.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit)
})).
Return(commitOracle, nil)
m.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
m.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec)
})).
Return(execOracle, nil)
Expand Down Expand Up @@ -679,11 +682,11 @@ func Test_launcher_processDiff(t *testing.T) {
commitOracle.On("Start").Return(nil)
execOracle := mocks.NewCCIPOracle(t)
execOracle.On("Start").Return(nil)
m.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
m.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit)
})).
Return(commitOracle, nil)
m.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
m.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool {
return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec)
})).
Return(execOracle, nil)
Expand Down Expand Up @@ -733,7 +736,7 @@ func Test_launcher_processDiff(t *testing.T) {
homeChainReader: tt.fields.homeChainReader,
oracleCreator: tt.fields.oracleCreator,
}
err := l.processDiff(tt.args.diff)
err := l.processDiff(testutils.Context(t), tt.args.diff)
if tt.wantErr {
require.Error(t, err)
} else {
Expand Down
3 changes: 1 addition & 2 deletions core/capabilities/ccip/oraclecreator/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (i *bootstrapOracleCreator) Type() cctypes.OracleType {
}

// Create implements types.OracleCreator.
func (i *bootstrapOracleCreator) Create(_ uint32, config cctypes.OCR3ConfigWithMeta) (cctypes.CCIPOracle, error) {
func (i *bootstrapOracleCreator) Create(ctx context.Context, _ uint32, config cctypes.OCR3ConfigWithMeta) (cctypes.CCIPOracle, error) {
// Assuming that the chain selector is referring to an evm chain for now.
// TODO: add an api that returns chain family.
// NOTE: this doesn't really matter for the bootstrap node, it doesn't do anything on-chain.
Expand All @@ -158,7 +158,6 @@ func (i *bootstrapOracleCreator) Create(_ uint32, config cctypes.OCR3ConfigWithM
oraclePeerIDs = append(oraclePeerIDs, n.P2pID)
}

ctx := context.Background()
rmnHomeReader, err := i.getRmnHomeReader(ctx, config)
if err != nil {
return nil, fmt.Errorf("failed to get RMNHome reader: %w", err)
Expand Down
Loading

0 comments on commit 9e899bb

Please sign in to comment.