diff --git a/hamt/hamt.go b/hamt/hamt.go
index 74a8d2759..21cea4be6 100644
--- a/hamt/hamt.go
+++ b/hamt/hamt.go
@@ -24,6 +24,7 @@ import (
 	"context"
 	"fmt"
 	"os"
+	"sync"
 
 	format "github.com/ipfs/go-unixfs"
 	"github.com/ipfs/go-unixfs/internal"
@@ -372,14 +373,11 @@ func (ds *Shard) EnumLinksAsync(ctx context.Context) <-chan format.LinkResult {
 	go func() {
 		defer close(linkResults)
 		defer cancel()
-		getLinks := makeAsyncTrieGetLinks(ds.dserv, linkResults)
-		cset := cid.NewSet()
-		rootNode, err := ds.Node()
-		if err != nil {
-			emitResult(ctx, linkResults, format.LinkResult{Link: nil, Err: err})
-			return
-		}
-		err = dag.Walk(ctx, getLinks, rootNode.Cid(), cset.Visit, dag.Concurrent())
+
+		err := parallelWalkDepth(ctx, ds, ds.dserv, func(formattedLink *ipld.Link) error {
+			emitResult(ctx, linkResults, format.LinkResult{Link: formattedLink, Err: nil})
+			return nil
+		})
 		if err != nil {
 			emitResult(ctx, linkResults, format.LinkResult{Link: nil, Err: err})
 		}
@@ -387,6 +385,210 @@ func (ds *Shard) EnumLinksAsync(ctx context.Context) <-chan format.LinkResult {
 	return linkResults
 }
 
+// listCidShardUnion holds the sub-shards discovered while walking one shard:
+// links are children known only by CID, shards are children already held as
+// *Shard values.
+type listCidShardUnion struct {
+	links  []cid.Cid
+	shards []*Shard
+}
+
+// walkLinks iterates over the direct children of ds. Every child that carries
+// a value is passed to processLinkValues as a link named after its key; every
+// child that is itself a shard is collected in the returned listCidShardUnion
+// so the caller can descend into it. Any error (from processLinkValues, from
+// reading a child link, or an unsupported link type) is returned immediately.
+func (ds *Shard) walkLinks(processLinkValues func(formattedLink *ipld.Link) error) (*listCidShardUnion, error) {
+	res := &listCidShardUnion{}
+
+	for idx, lnk := range ds.childer.links {
+		if nextShard := ds.childer.children[idx]; nextShard == nil {
+			lnkLinkType, err := ds.childLinkType(lnk)
+			if err != nil {
+				return nil, err
+			}
+
+			switch lnkLinkType {
+			case shardValueLink:
+				sv, err := ds.makeShardValue(lnk)
+				if err != nil {
+					return nil, err
+				}
+				formattedLink := sv.val
+				formattedLink.Name = sv.key
+
+				if err := processLinkValues(formattedLink); err != nil {
+					return nil, err
+				}
+			case shardLink:
+				res.links = append(res.links, lnk.Cid)
+			default:
+				return nil, fmt.Errorf("unsupported shard link type")
+			}
+
+		} else {
+			if nextShard.val != nil {
+				formattedLink := &ipld.Link{
+					Name: nextShard.key,
+					Size: nextShard.val.Size,
+					Cid:  nextShard.val.Cid,
+				}
+				if err := processLinkValues(formattedLink); err != nil {
+					return nil, err
+				}
+			} else {
+				res.shards = append(res.shards, nextShard)
+			}
+		}
+	}
+	return res, nil
+}
+
+// parallelWalkDepth walks the HAMT starting at root with a pool of worker
+// goroutines and calls processShardValues for every value link it finds.
+// Sub-shards known only by CID are fetched through dserv; a CID the visited
+// set rejects is skipped. processShardValues is called from the workers,
+// possibly concurrently. The walk returns nil once all queued work is done,
+// the first error a worker reports, or ctx.Err() when ctx is done.
+func parallelWalkDepth(ctx context.Context, root *Shard, dserv ipld.DAGService, processShardValues func(formattedLink *ipld.Link) error) error {
+	const concurrency = 32
+	visit := cid.NewSet().Visit
+
+	type shardCidUnion struct {
+		cid   cid.Cid
+		shard *Shard
+	}
+
+	feed := make(chan *shardCidUnion)
+	out := make(chan *listCidShardUnion)
+	done := make(chan struct{})
+
+	var visitlk sync.Mutex
+	var wg sync.WaitGroup
+
+	errChan := make(chan error)
+	fetchersCtx, cancel := context.WithCancel(ctx)
+	// Deferred calls run in reverse order: close(feed), cancel(), wg.Wait().
+	defer wg.Wait()
+	defer cancel()
+
+	// reportErr hands err to the coordinating loop below, unless the walk has
+	// already been cancelled and nobody is listening any more.
+	reportErr := func(err error) {
+		select {
+		case errChan <- err:
+		case <-fetchersCtx.Done():
+		}
+	}
+
+	for i := 0; i < concurrency; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for cdepth := range feed {
+				var shouldVisit bool
+
+				if cdepth.shard != nil {
+					shouldVisit = true
+				} else {
+					visitlk.Lock()
+					shouldVisit = visit(cdepth.cid)
+					visitlk.Unlock()
+				}
+
+				if shouldVisit {
+					var nextShard *Shard
+					if cdepth.shard != nil {
+						nextShard = cdepth.shard
+					} else {
+						// Fetch with fetchersCtx, not ctx, so that the deferred
+						// cancel aborts in-flight fetches and wg.Wait does not
+						// have to sit them out after an early return.
+						nd, err := dserv.Get(fetchersCtx, cdepth.cid)
+						if err != nil {
+							reportErr(err)
+							return
+						}
+						nextShard, err = NewHamtFromDag(dserv, nd)
+						if err != nil {
+							reportErr(err)
+							return
+						}
+					}
+
+					nextLinks, err := nextShard.walkLinks(processShardValues)
+					if err != nil {
+						reportErr(err)
+						return
+					}
+
+					select {
+					case out <- nextLinks:
+					case <-fetchersCtx.Done():
+						return
+					}
+				}
+				select {
+				case done <- struct{}{}:
+				case <-fetchersCtx.Done():
+				}
+			}
+		}()
+	}
+	defer close(feed)
+
+	send := feed
+	var todoQueue []*shardCidUnion
+	var inProgress int
+
+	next := &shardCidUnion{
+		shard: root,
+	}
+
+	// enqueue schedules cd: it becomes the pending item if there is none,
+	// re-enabling the send case, and is appended to todoQueue otherwise.
+	enqueue := func(cd *shardCidUnion) {
+		if next == nil {
+			next = cd
+			send = feed
+		} else {
+			todoQueue = append(todoQueue, cd)
+		}
+	}
+
+	for {
+		select {
+		case send <- next:
+			inProgress++
+			if len(todoQueue) > 0 {
+				next = todoQueue[0]
+				todoQueue = todoQueue[1:]
+			} else {
+				next = nil
+				// A nil channel never becomes ready, which disables this case.
+				send = nil
+			}
+		case <-done:
+			inProgress--
+			if inProgress == 0 && next == nil {
+				return nil
+			}
+		case linksDepth := <-out:
+			for _, c := range linksDepth.links {
+				enqueue(&shardCidUnion{cid: c})
+			}
+			for _, shard := range linksDepth.shards {
+				enqueue(&shardCidUnion{shard: shard})
+			}
+		case err := <-errChan:
+			return err
+
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	}
+}
+
 // makeAsyncTrieGetLinks builds a getLinks function that can be used with EnumerateChildrenAsync
 // to iterate a HAMT shard. It takes an IPLD Dag Service to fetch nodes, and a call back that will get called
 // on all links to leaf nodes in a HAMT tree, so they can be collected for an EnumLinks operation