Skip to content

Commit

Permalink
parallel container startup with deferred values
Browse files Browse the repository at this point in the history
This commit tries to be as unobtrusive as possible, attaching new behavior to existing types where possible rather than building out new infrastructure.

constructorNode returns a deferred value when called. On the first call, it asks paramList to start building an arg slice, which may also be deferred.

Once the arg slice is resolved, constructorNode schedules its constructor function to be called. Once it's called, it resolves its own deferral.

Multiple paramSingles can observe the same constructorNode before it's ready. If there's an error, they may all see the same error, which is a change in behavior.

There are two schedulers: synchronous and parallel. The synchronous scheduler returns things in the same order as before. The parallel may not (and the tests that rely on shuffle order will fail). The scheduler needs to be flushed after deferred values are created. The synchronous scheduler does nothing on when flushing, but the parallel scheduler runs a pool of goroutines to resolve constructors.

Calls to dig functions always happen on the same goroutine as Scope.Invoke(). Calls to constructor functions can happen on pooled goroutines.

The choice of scheduler is up to the Scope. Whether constructor functions are safe to call in parallel seems most logically to be a property of the scope, and the scope is passed down the constructor/param call chain.
  • Loading branch information
xandris committed Jan 27, 2022
1 parent f478a90 commit cce342d
Show file tree
Hide file tree
Showing 12 changed files with 609 additions and 91 deletions.
84 changes: 57 additions & 27 deletions constructor.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,18 @@ type constructorNode struct {
// id uniquely identifies the constructor that produces a node.
id dot.CtorID

// Whether this node is already building its paramList and calling the constructor
calling bool

// Whether the constructor owned by this node was already called.
called bool

// Type information about constructor parameters.
paramList paramList

// The result of calling the constructor
deferred deferred

// Type information about constructor results.
resultList resultList

Expand Down Expand Up @@ -121,42 +127,66 @@ func (n *constructorNode) String() string {
return fmt.Sprintf("deps: %v, ctor: %v", n.paramList, n.ctype)
}

// Call calls this constructor if it hasn't already been called and
// injects any values produced by it into the provided container.
func (n *constructorNode) Call(c containerStore) error {
if n.called {
return nil
// Call calls this constructor if it hasn't already been called and injects any values produced by it into the container
// passed to newConstructorNode.
//
// If constructorNode has a unresolved deferred already in the process of building, it will return that one. If it has
// already been successfully called, it will return an already-resolved deferred. Together these mean it will try the
// call again if it failed last time.
//
// On failure, the returned pointer is not guaranteed to stay in a failed state; another call will reset it back to its
// zero value; don't store the returned pointer. (It will still call each observer only once.)
func (n *constructorNode) Call(c containerStore) *deferred {
if n.calling || n.called {
return &n.deferred
}

n.calling = true
n.deferred = deferred{}

if err := shallowCheckDependencies(c, n.paramList); err != nil {
return errMissingDependencies{
n.deferred.resolve(errMissingDependencies{
Func: n.location,
Reason: err,
}
})
}

args, err := n.paramList.BuildList(c)
if err != nil {
return errArgumentsFailed{
Func: n.location,
Reason: err,
var args []reflect.Value
d := n.paramList.BuildList(c, &args)

d.observe(func(err error) {
if err != nil {
n.calling = false
n.deferred.resolve(errArgumentsFailed{
Func: n.location,
Reason: err,
})
return
}
}

receiver := newStagingContainerWriter()
results := c.invoker()(reflect.ValueOf(n.ctor), args)
if err := n.resultList.ExtractList(receiver, results); err != nil {
return errConstructorFailed{Func: n.location, Reason: err}
}

// Commit the result to the original container that this constructor
// was supplied to. The provided constructor is only used for a view of
// the rest of the graph to instantiate the dependencies of this
// container.
receiver.Commit(n.s)
n.called = true

return nil
var results []reflect.Value

c.scheduler().schedule(func() {
results = c.invoker()(reflect.ValueOf(n.ctor), args)
}).observe(func(_ error) {
n.calling = false
receiver := newStagingContainerWriter()
if err := n.resultList.ExtractList(receiver, results); err != nil {
n.deferred.resolve(errConstructorFailed{Func: n.location, Reason: err})
return
}

// Commit the result to the original container that this constructor
// was supplied to. The provided container is only used for a view of
// the rest of the graph to instantiate the dependencies of this
// container.
receiver.Commit(n.s)
n.called = true
n.deferred.resolve(nil)
})
})

return &n.deferred
}

// stagingContainerWriter is a containerWriter that records the changes that
Expand Down
8 changes: 6 additions & 2 deletions constructor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ func TestNodeAlreadyCalled(t *testing.T) {
require.False(t, n.called, "node must not have been called")

c := New()
require.NoError(t, n.Call(c.scope), "invoke failed")
d := n.Call(c.scope)
c.scope.sched.flush()
require.NoError(t, d.err, "invoke failed")
require.True(t, n.called, "node must be called")
require.NoError(t, n.Call(c.scope), "calling again should be okay")
d = n.Call(c.scope)
c.scope.sched.flush()
require.NoError(t, d.err, "calling again should be okay")
}
26 changes: 26 additions & 0 deletions container.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ type containerStore interface {

// Returns invokerFn function to use when calling arguments.
invoker() invokerFn

// Returns the scheduler to use for this scope.
scheduler() scheduler
}

// New constructs a Container.
Expand Down Expand Up @@ -208,6 +211,29 @@ func dryInvoker(fn reflect.Value, _ []reflect.Value) []reflect.Value {
return results
}

type maxConcurrencyOption int

// MaxConcurrency run constructors in this container with a fixed pool of executor
// goroutines. max is the number of goroutines to start.
func MaxConcurrency(max int) Option {
return maxConcurrencyOption(max)
}

func (m maxConcurrencyOption) applyOption(container *Container) {
container.scope.sched = &parallelScheduler{concurrency: int(m)}
}

type unboundedConcurrency struct{}

// UnboundedConcurrency run constructors in this container as concurrently as possible.
// Go's resource limits like GOMAXPROCS will inherently limit how much can happen in
// parallel.
var UnboundedConcurrency Option = unboundedConcurrency{}

func (u unboundedConcurrency) applyOption(container *Container) {
container.scope.sched = &unboundedScheduler{}
}

// String representation of the entire Container
func (c *Container) String() string {
return c.scope.String()
Expand Down
104 changes: 104 additions & 0 deletions deferred.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package dig

type observer func(error)

// A deferred is an observable future result that may fail. Its zero value is unresolved and has no observers. It can
// be resolved once, at which point every observer will be called.
type deferred struct {
observers []observer
settled bool
err error
}

// alreadyResolved is a deferred that has already been resolved with a nil error.
var alreadyResolved = deferred{settled: true}

// failedDeferred returns a deferred that is resolved with the given error.
func failedDeferred(err error) *deferred {
return &deferred{settled: true, err: err}
}

// observe registers an observer to receive a callback when this deferred is resolved. It will be called at most one
// time. If this deferred is already resolved, the observer is called immediately, before observe returns.
func (d *deferred) observe(obs observer) {
if d.settled {
obs(d.err)
return
}

d.observers = append(d.observers, obs)
}

// resolve sets the status of this deferred and notifies all observers if it's not already resolved.
func (d *deferred) resolve(err error) {
if d.settled {
return
}

d.settled = true
d.err = err
for _, obs := range d.observers {
obs(err)
}
d.observers = nil
}

// then returns a new deferred that is either resolved with the same error as this deferred, or any error returned from
// the supplied function. The supplied function is only called if this deferred is resolved without error.
func (d *deferred) then(res func() error) *deferred {
d2 := new(deferred)
d.observe(func(err error) {
if err != nil {
d2.resolve(err)
return
}
d2.resolve(res())
})
return d2
}

// catch maps any error from this deferred using the supplied function. The supplied function is only called if this
// deferred is resolved with an error. If the supplied function returns a nil error, the new deferred will resolve
// successfully.
func (d *deferred) catch(rej func(error) error) *deferred {
d2 := new(deferred)
d.observe(func(err error) {
if err != nil {
err = rej(err)
}
d2.resolve(err)
})
return d2
}

// whenAll returns a new deferred that resolves when all the supplied deferreds resolve. It resolves with the first
// error reported by any deferred, or nil if they all succeed.
func whenAll(others ...*deferred) *deferred {
if len(others) == 0 {
return &alreadyResolved
}

d := new(deferred)
count := len(others)

onResolved := func(err error) {
if d.settled {
return
}

if err != nil {
d.resolve(err)
}

count--
if count == 0 {
d.resolve(nil)
}
}

for _, other := range others {
other.observe(onResolved)
}

return d
}
89 changes: 89 additions & 0 deletions dig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"math/rand"
"os"
"reflect"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -3566,3 +3567,91 @@ func TestEndToEndSuccessWithAliases(t *testing.T) {
})

}

func TestConcurrency(t *testing.T) {
// Ensures providers will run at the same time
t.Run("TestMaxConcurrency", func(t *testing.T) {
t.Parallel()

type (
A int
B int
C int
)

var (
timer = time.NewTimer(10 * time.Second)
max int32 = 3
done = make(chan struct{})
running int32 = 0
waitForUs = func() error {
if atomic.AddInt32(&running, 1) == max {
close(done)
}
select {
case <-timer.C:
return errors.New("timeout expired")
case <-done:
return nil
}
}
c = digtest.New(t, dig.MaxConcurrency(int(max)))
)

c.RequireProvide(func() (A, error) { return 0, waitForUs() })
c.RequireProvide(func() (B, error) { return 1, waitForUs() })
c.RequireProvide(func() (C, error) { return 2, waitForUs() })

c.RequireInvoke(func(a A, b B, c C) {
require.Equal(t, a, A(0))
require.Equal(t, b, B(1))
require.Equal(t, c, C(2))
require.Equal(t, running, int32(3))
})
})

t.Run("TestUnboundConcurrency", func(t *testing.T) {
t.Parallel()

var (
timer = time.NewTimer(10 * time.Second)
max int32 = 20
done = make(chan struct{})
running int32 = 0
waitForUs = func() error {
if atomic.AddInt32(&running, 1) >= max {
close(done)
}
select {
case <-timer.C:
return errors.New("timeout expired")
case <-done:
return nil
}
}
c = digtest.New(t, dig.UnboundedConcurrency)
expected []int
)

for i := 0; i < int(max); i++ {
i := i
expected = append(expected, i)
type out struct {
dig.Out

Value int `group:"a"`
}
c.RequireProvide(func() (out, error) { return out{Value: i}, waitForUs() })
}

type in struct {
dig.In

Values []int `group:"a"`
}

c.RequireInvoke(func(i in) {
require.ElementsMatch(t, expected, i.Values)
})
})
}
Loading

0 comments on commit cce342d

Please sign in to comment.