Skip to content

Commit

Permalink
Improved the search to be concurrent
Browse files Browse the repository at this point in the history
- Add support for querying multiple sources concurrently
- Return the combined results from all sources
- If any errors occur, return the first error and discard the results

[source/multi.go]
- Add `sync.WaitGroup` for waiting on all source queries to finish
- Add `errs` and `resChan` channels for collecting errors and results from sources
- Add goroutines for querying sources, collecting errors, and collecting results
- If any errors occurred, return the first error and discard the results
- Return the combined results from all sources

Signed-off-by: naveensrinivasan <172697+naveensrinivasan@users.noreply.github.com>
  • Loading branch information
naveensrinivasan authored and jkjell committed Nov 16, 2023
1 parent 40c7ed5 commit 002d897
Showing 1 changed file with 53 additions and 7 deletions.
60 changes: 53 additions & 7 deletions source/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

package source

import "context"
import (
"context"
"sync"
)

type MultiSource struct {
sources []Sourcer
Expand All @@ -24,16 +27,59 @@ func NewMultiSource(sources ...Sourcer) *MultiSource {
return &MultiSource{sources}
}

// Search concurrently queries all sources and returns the combined results.
func (s *MultiSource) Search(ctx context.Context, collectionName string, subjectDigests, attestations []string) ([]CollectionEnvelope, error) {
results := make([]CollectionEnvelope, 0)
for _, source := range s.sources {
res, err := source.Search(ctx, collectionName, subjectDigests, attestations)
if err != nil {
return results, err
results := []CollectionEnvelope{}
errors := []error{}

errs := make(chan error) // Channel for collecting errors from each source
resChan := make(chan []CollectionEnvelope) // Channel for collecting results from each source

errdone := make(chan bool) // Signal channel indicating when error collection is done
readerDone := make(chan bool) // Signal channel indicating when result collection is done

// Goroutine for collecting results from the result channel
go func() {
for item := range resChan {
results = append(results, item...)
}
readerDone <- true
}()

results = append(results, res...)
// Goroutine for collecting errors from the error channel
go func() {
for err := range errs {
errors = append(errors, err)
}
errdone <- true
}()

var wg sync.WaitGroup // WaitGroup for waiting on all source queries to finish
for _, source := range s.sources {
source := source
wg.Add(1)
// Goroutine for querying a source and collecting the results or error
go func(src Sourcer) {
defer wg.Done()
res, err := src.Search(ctx, collectionName, subjectDigests, attestations)
if err != nil {
errs <- err
} else {
resChan <- res
}
}(source)
}
wg.Wait() // Wait for all source queries to finish
close(resChan) // Close the result channel
close(errs) // Close the error channel

<-errdone // Wait for error collection to finish
<-readerDone // Wait for result collection to finish

// If any errors occurred, return the first error and discard the results
if len(errors) > 0 {
return nil, errors[0]
}
// Return the combined results from all sources
return results, nil
}

0 comments on commit 002d897

Please sign in to comment.