Skip to content

Commit

Permalink
routing: avoid modifying AssumeChannelValid in unit tests
Browse files Browse the repository at this point in the history
This produces a race condition when reading AssumeChannelValid from a
different goroutine. Instead we isolate the test cases and initial
AssumeChannelValid properly.
  • Loading branch information
cfromknecht committed Feb 18, 2021
1 parent f7c5236 commit 250bc85
Showing 1 changed file with 39 additions and 18 deletions.
57 changes: 39 additions & 18 deletions routing/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ func (c *testCtx) RestartRouter() error {
func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGraphInstance) (
*testCtx, func(), error) {

return createTestCtxFromGraphInstanceAssumeValid(
startingHeight, graphInstance, false,
)
}

func createTestCtxFromGraphInstanceAssumeValid(startingHeight uint32,
graphInstance *testGraphInstance, assumeValid bool) (*testCtx, func(), error) {

// We'll initialize an instance of the channel router with mock
// versions of the chain and channel notifier. As we don't need to test
// any p2p functionality, the peer send and switch send messages won't
Expand Down Expand Up @@ -126,8 +134,9 @@ func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGr
next := atomic.AddUint64(&uniquePaymentID, 1)
return next, nil
},
PathFindingConfig: pathFindingConfig,
Clock: clock.NewTestClock(time.Unix(1, 0)),
PathFindingConfig: pathFindingConfig,
Clock: clock.NewTestClock(time.Unix(1, 0)),
AssumeChannelValid: assumeValid,
})
if err != nil {
return nil, nil, fmt.Errorf("unable to create router %v", err)
Expand Down Expand Up @@ -2034,6 +2043,15 @@ func TestPruneChannelGraphStaleEdges(t *testing.T) {
func TestPruneChannelGraphDoubleDisabled(t *testing.T) {
t.Parallel()

t.Run("no_assumechannelvalid", func(t *testing.T) {
testPruneChannelGraphDoubleDisabled(t, false)
})
t.Run("assumechannelvalid", func(t *testing.T) {
testPruneChannelGraphDoubleDisabled(t, true)
})
}

func testPruneChannelGraphDoubleDisabled(t *testing.T, assumeValid bool) {
// We'll create the following test graph so that only the last channel
// is pruned. We'll use a fresh timestamp to ensure they're not pruned
// according to that heuristic.
Expand Down Expand Up @@ -2125,34 +2143,37 @@ func TestPruneChannelGraphDoubleDisabled(t *testing.T) {
defer testGraph.cleanUp()

const startingHeight = 100
ctx, cleanUp, err := createTestCtxFromGraphInstance(
startingHeight, testGraph,
ctx, cleanUp, err := createTestCtxFromGraphInstanceAssumeValid(
startingHeight, testGraph, assumeValid,
)
if err != nil {
t.Fatalf("unable to create test context: %v", err)
}
defer cleanUp()

// All the channels should exist within the graph before pruning them.
assertChannelsPruned(t, ctx.graph, testChannels)

// If we attempt to prune them without AssumeChannelValid being set,
// none should be pruned.
if err := ctx.router.pruneZombieChans(); err != nil {
t.Fatalf("unable to prune zombie channels: %v", err)
// All the channels should exist within the graph before pruning them
// when not using AssumeChannelValid, otherwise we should have pruned
// the last channel on startup.
if !assumeValid {
assertChannelsPruned(t, ctx.graph, testChannels)
} else {
prunedChannel := testChannels[len(testChannels)-1].ChannelID
assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel)
}

assertChannelsPruned(t, ctx.graph, testChannels)

// Now that AssumeChannelValid is set, we'll prune the graph again and
// the last channel should be the only one pruned.
ctx.router.cfg.AssumeChannelValid = true
if err := ctx.router.pruneZombieChans(); err != nil {
t.Fatalf("unable to prune zombie channels: %v", err)
}

prunedChannel := testChannels[len(testChannels)-1].ChannelID
assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel)
// If we attempted to prune them without AssumeChannelValid being set,
// none should be pruned. Otherwise the last channel should still be
// pruned.
if !assumeValid {
assertChannelsPruned(t, ctx.graph, testChannels)
} else {
prunedChannel := testChannels[len(testChannels)-1].ChannelID
assertChannelsPruned(t, ctx.graph, testChannels, prunedChannel)
}
}

// TestFindPathFeeWeighting tests that the findPath method will properly prefer
Expand Down

0 comments on commit 250bc85

Please sign in to comment.