Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Sub graphs and Decorate #240

Closed
wants to merge 11 commits into from
88 changes: 75 additions & 13 deletions dig.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ type Container struct {

// Defer acyclic check on provide until Invoke.
deferAcyclicVerification bool

// Name of the container
srikrsna marked this conversation as resolved.
Show resolved Hide resolved
name string

// Sub graphs of the container
srikrsna marked this conversation as resolved.
Show resolved Hide resolved
children []*Container

// Parent is the container that spawned this
srikrsna marked this conversation as resolved.
Show resolved Hide resolved
parent *Container
}

// containerWriter provides write access to the Container's underlying data
Expand Down Expand Up @@ -332,15 +341,25 @@ func setRand(r *rand.Rand) Option {
}

func (c *Container) knownTypes() []reflect.Type {
srikrsna marked this conversation as resolved.
Show resolved Hide resolved
typeSet := make(map[reflect.Type]struct{}, len(c.providers))
for k := range c.providers {
typeSet[k.t] = struct{}{}
getKnowTypes := func(c *Container) []reflect.Type {
typeSet := make(map[reflect.Type]struct{}, len(c.providers))
for k := range c.providers {
typeSet[k.t] = struct{}{}
}

types := make([]reflect.Type, 0, len(typeSet))
for t := range typeSet {
types = append(types, t)
}

return types
}

types := make([]reflect.Type, 0, len(typeSet))
for t := range typeSet {
types = append(types, t)
types := make([]reflect.Type, 0)
for _, c := range append(c.children, c) {
types = append(types, getKnowTypes(c)...)
srikrsna marked this conversation as resolved.
Show resolved Hide resolved
}

sort.Sort(byTypeName(types))
return types
}
Expand All @@ -366,11 +385,23 @@ func (c *Container) submitGroupedValue(name string, t reflect.Type, v reflect.Va
}

func (c *Container) getValueProviders(name string, t reflect.Type) []provider {
return c.getProviders(key{name: name, t: t})
providers := c.getProviders(key{name: name, t: t})

for _, c := range c.children {
providers = append(providers, c.getValueProviders(name, t)...)
}

return providers
}

func (c *Container) getGroupProviders(name string, t reflect.Type) []provider {
return c.getProviders(key{group: name, t: t})
providers := c.getProviders(key{group: name, t: t})

for _, c := range c.children {
providers = append(providers, c.getGroupProviders(name, t)...)
}

return providers
}

func (c *Container) getProviders(k key) []provider {
Expand All @@ -382,6 +413,14 @@ func (c *Container) getProviders(k key) []provider {
return providers
}

func (c *Container) getRoot() *Container {
if c.parent == nil {
return c
}

return c.parent.getRoot()
}

// Provide teaches the container how to build values of one or more types and
// expresses their dependencies.
//
Expand Down Expand Up @@ -433,6 +472,7 @@ func (c *Container) Provide(constructor interface{}, opts ...ProvideOption) erro
// The function may return an error to indicate failure. The error will be
// returned to the caller as-is.
func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error {
cp := c.getRoot() // run invoke on root to get access to all the graphs
ftype := reflect.TypeOf(function)
if ftype == nil {
return errors.New("can't invoke an untyped nil")
Expand All @@ -446,20 +486,20 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error {
return err
}

if err := shallowCheckDependencies(c, pl); err != nil {
if err := shallowCheckDependencies(cp, pl); err != nil {
return errMissingDependencies{
Func: digreflect.InspectFunc(function),
Reason: err,
}
}

if !c.isVerifiedAcyclic {
if err := c.verifyAcyclic(); err != nil {
if !cp.isVerifiedAcyclic {
if err := cp.verifyAcyclic(); err != nil {
return err
}
}

args, err := pl.BuildList(c)
args, err := pl.BuildList(cp)
if err != nil {
return errArgumentsFailed{
Func: digreflect.InspectFunc(function),
Expand All @@ -479,6 +519,27 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error {
return nil
}

// Child returns a named child of this container. The child container has
// full access to the parent's types, and any types provided to the child
// will be made available to the parent.
//
// The name of the child is for observability purposes only. As such, it
// does not have to be unique across different children of the container.
func (c *Container) Child(name string) *Container {
child := &Container{
providers: make(map[key][]*node),
values: make(map[key]reflect.Value),
groups: make(map[key][]reflect.Value),
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
srikrsna marked this conversation as resolved.
Show resolved Hide resolved
name: name,
parent: c,
}

c.children = append(c.children, child)

return child
}

func (c *Container) verifyAcyclic() error {
visited := make(map[key]struct{})
for _, n := range c.nodes {
Expand Down Expand Up @@ -643,7 +704,7 @@ func (cv connectionVisitor) checkKey(k key, path string) error {
"cannot provide %v from %v: already provided by %v",
k, path, conflict)
}
if ps := cv.c.providers[k]; len(ps) > 0 {
if ps := cv.c.getRoot().getValueProviders(k.name, k.t); len(ps) > 0 {
srikrsna marked this conversation as resolved.
Show resolved Hide resolved
cons := make([]string, len(ps))
for i, p := range ps {
cons[i] = fmt.Sprint(p.Location())
Expand All @@ -653,6 +714,7 @@ func (cv connectionVisitor) checkKey(k key, path string) error {
"cannot provide %v from %v: already provided by %v",
k, path, strings.Join(cons, "; "))
}

return nil
}

Expand Down
Loading