diff --git a/hamt/hamt.go b/hamt/hamt.go index 5d7221002..7dac3b18e 100644 --- a/hamt/hamt.go +++ b/hamt/hamt.go @@ -445,45 +445,52 @@ func parallelWalkDepth(ctx context.Context, root *Shard, dserv ipld.DAGService, visitSet := cid.NewSet() visit := visitSet.Visit - type shardCidUnion struct { - cid cid.Cid - shard *Shard - } - // Setup synchronization grp, errGrpCtx := errgroup.WithContext(ctx) // Input and output queues for workers. - feed := make(chan *shardCidUnion) + feed := make(chan *listCidShardUnion) out := make(chan *listCidShardUnion) done := make(chan struct{}) for i := 0; i < concurrency; i++ { grp.Go(func() error { for shardOrCID := range feed { - var shouldVisit bool + for _, nextShard := range shardOrCID.shards { + nextLinks, err := nextShard.walkLinks(processShardValues) + if err != nil { + return err + } + + select { + case out <- nextLinks: + case <-errGrpCtx.Done(): + return nil + } + } + + var linksToVisit []cid.Cid + for _, nextLink := range shardOrCID.links { + var shouldVisit bool - if shardOrCID.shard != nil { - shouldVisit = true - } else { visitlk.Lock() - shouldVisit = visit(shardOrCID.cid) + shouldVisit = visit(nextLink) visitlk.Unlock() + + if shouldVisit { + linksToVisit = append(linksToVisit, nextLink) + } } - if shouldVisit { - var nextShard *Shard - if shardOrCID.shard != nil { - nextShard = shardOrCID.shard - } else { - nd, err := dserv.Get(ctx, shardOrCID.cid) - if err != nil { - return err - } - nextShard, err = NewHamtFromDag(dserv, nd) - if err != nil { - return err - } + chNodes := dserv.GetMany(errGrpCtx, linksToVisit) + for optNode := range chNodes { + if optNode.Err != nil { + return optNode.Err + } + + nextShard, err := NewHamtFromDag(dserv, optNode.Node) + if err != nil { + return err } nextLinks, err := nextShard.walkLinks(processShardValues) @@ -497,6 +504,7 @@ func parallelWalkDepth(ctx context.Context, root *Shard, dserv ipld.DAGService, return nil } } + select { case done <- struct{}{}: case <-errGrpCtx.Done(): @@ -507,11 +515,11 @@ func parallelWalkDepth(ctx context.Context, root *Shard, dserv ipld.DAGService, } send := feed - var todoQueue []*shardCidUnion + var todoQueue []*listCidShardUnion var inProgress int - next := &shardCidUnion{ - shard: root, + next := &listCidShardUnion{ + shards: []*Shard{root}, } dispatcherLoop: @@ -532,29 +540,11 @@ dispatcherLoop: break dispatcherLoop } case nextNodes := <-out: - for _, c := range nextNodes.links { - shardOrCid := &shardCidUnion{ - cid: c, - } - - if next == nil { - next = shardOrCid - send = feed - } else { - todoQueue = append(todoQueue, shardOrCid) - } - } - for _, shard := range nextNodes.shards { - shardOrCid := &shardCidUnion{ - shard: shard, - } - - if next == nil { - next = shardOrCid - send = feed - } else { - todoQueue = append(todoQueue, shardOrCid) - } + if next == nil { + next = nextNodes + send = feed + } else { + todoQueue = append(todoQueue, nextNodes) } case <-errGrpCtx.Done(): break dispatcherLoop