From b0a93479520730d576e94e11716d4f9b603c6a31 Mon Sep 17 00:00:00 2001 From: Aaron Craelius Date: Sat, 29 May 2021 17:25:42 -0400 Subject: [PATCH] WIP --- core/abci/app_base.go | 67 +++++++ core/abci/config/config.go | 5 + core/abci/handler.go | 30 ++++ core/abci/handler_base.go | 63 +++++++ core/abci/header/header.go | 88 +++++++++ core/abci/middleware.go | 3 + core/app_config/app_config.go | 96 ++++++++-- core/container/container.go | 294 +++++++++++++++++++------------ core/container/container_test.go | 275 +++++++++++++++++++---------- core/store/kv/kv.go | 20 +++ core/store/sc/sc.go | 18 ++ core/store/ss/ss.go | 18 ++ core/tx/module/module.go | 23 +-- x/authn/app/middleware.go | 37 ++-- x/authn/module/module.go | 26 +-- 15 files changed, 792 insertions(+), 271 deletions(-) create mode 100644 core/abci/app_base.go create mode 100644 core/abci/config/config.go create mode 100644 core/abci/handler.go create mode 100644 core/abci/handler_base.go create mode 100644 core/abci/header/header.go create mode 100644 core/abci/middleware.go create mode 100644 core/store/kv/kv.go create mode 100644 core/store/sc/sc.go create mode 100644 core/store/ss/ss.go diff --git a/core/abci/app_base.go b/core/abci/app_base.go new file mode 100644 index 000000000000..815f0b23ede0 --- /dev/null +++ b/core/abci/app_base.go @@ -0,0 +1,67 @@ +package abci + +import ( + "context" + + "github.com/tendermint/tendermint/abci/types" +) + +type AppBase struct { + handler Handler + checkCtx context.Context + deliverCtx context.Context +} + +var _ types.Application = AppBase{} + +func (a AppBase) Info(info types.RequestInfo) types.ResponseInfo { + return a.handler.Info(a.checkCtx, info) +} + +func (a AppBase) SetOption(option types.RequestSetOption) types.ResponseSetOption { + return a.handler.SetOption(a.checkCtx, option) +} + +func (a AppBase) Query(query types.RequestQuery) types.ResponseQuery { + return a.handler.Query(a.checkCtx, query) +} + +func (a AppBase) CheckTx(tx types.RequestCheckTx) types.ResponseCheckTx { + return a.handler.CheckTx(a.checkCtx, tx) +} + +func (a AppBase) InitChain(chain types.RequestInitChain) types.ResponseInitChain { + return a.handler.InitChain(a.deliverCtx, chain) +} + +func (a AppBase) BeginBlock(block types.RequestBeginBlock) types.ResponseBeginBlock { + return a.handler.BeginBlock(a.deliverCtx, block) +} + +func (a AppBase) DeliverTx(tx types.RequestDeliverTx) types.ResponseDeliverTx { + return a.handler.DeliverTx(a.deliverCtx, tx) +} + +func (a AppBase) EndBlock(block types.RequestEndBlock) types.ResponseEndBlock { + return a.handler.EndBlock(a.deliverCtx, block) +} + +func (a AppBase) Commit() types.ResponseCommit { + return a.handler.Commit(a.deliverCtx) +} + +func (a AppBase) ListSnapshots(snapshots types.RequestListSnapshots) types.ResponseListSnapshots { + return a.handler.ListSnapshots(a.checkCtx, snapshots) +} + +func (a AppBase) OfferSnapshot(snapshot types.RequestOfferSnapshot) types.ResponseOfferSnapshot { + return a.handler.OfferSnapshot(a.checkCtx, snapshot) +} + +func (a AppBase) LoadSnapshotChunk(chunk types.RequestLoadSnapshotChunk) types.ResponseLoadSnapshotChunk { + return a.handler.LoadSnapshotChunk(a.checkCtx, chunk) +} + +func (a AppBase) ApplySnapshotChunk(chunk types.RequestApplySnapshotChunk) types.ResponseApplySnapshotChunk { + return a.handler.ApplySnapshotChunk(a.checkCtx, chunk) +} diff --git a/core/abci/config/config.go b/core/abci/config/config.go new file mode 100644 index 000000000000..fcb83e34f52b --- /dev/null +++ b/core/abci/config/config.go @@ -0,0 +1,5 @@ +package config + +type Config struct { + Middleware []interface{} +} diff --git a/core/abci/handler.go b/core/abci/handler.go new file mode 100644 index 000000000000..b827047e3c0c --- /dev/null +++ b/core/abci/handler.go @@ -0,0 +1,30 @@ +package abci + +import ( + "context" + + types "github.com/tendermint/tendermint/abci/types" +) + +type Handler interface { + // Info/Query Connection + Info(ctx context.Context, req types.RequestInfo) types.ResponseInfo // Return application info + SetOption(ctx context.Context, req types.RequestSetOption) types.ResponseSetOption // Set application option + Query(ctx context.Context, req types.RequestQuery) types.ResponseQuery // Query for state + + // Mempool Connection + CheckTx(ctx context.Context, req types.RequestCheckTx) types.ResponseCheckTx // Validate a tx for the mempool + + // Consensus Connection + InitChain(ctx context.Context, req types.RequestInitChain) types.ResponseInitChain // Initialize blockchain w validators/other info from TendermintCore + BeginBlock(ctx context.Context, req types.RequestBeginBlock) types.ResponseBeginBlock // Signals the beginning of a block + DeliverTx(ctx context.Context, req types.RequestDeliverTx) types.ResponseDeliverTx // Deliver a tx for full processing + EndBlock(ctx context.Context, req types.RequestEndBlock) types.ResponseEndBlock // Signals the end of a block, returns changes to the validator set + Commit(context.Context) types.ResponseCommit // Commit the state and return the application Merkle root hash + + // State Sync Connection + ListSnapshots(ctx context.Context, req types.RequestListSnapshots) types.ResponseListSnapshots // List available snapshots + OfferSnapshot(ctx context.Context, req types.RequestOfferSnapshot) types.ResponseOfferSnapshot // Offer a snapshot to the application + LoadSnapshotChunk(ctx context.Context, req types.RequestLoadSnapshotChunk) types.ResponseLoadSnapshotChunk // Load a snapshot chunk + ApplySnapshotChunk(ctx context.Context, req types.RequestApplySnapshotChunk) types.ResponseApplySnapshotChunk // Apply a shapshot chunk +} diff --git a/core/abci/handler_base.go b/core/abci/handler_base.go new file mode 100644 index 000000000000..800d07158ffd --- /dev/null +++ b/core/abci/handler_base.go @@ -0,0 +1,63 @@ +package abci + +import ( + "context" + + "github.com/tendermint/tendermint/abci/types" +) + +type HandlerBase struct{} + +func (h HandlerBase) Info(context.Context, types.RequestInfo) types.ResponseInfo { + return types.ResponseInfo{} +} + +func (h HandlerBase) SetOption(context.Context, types.RequestSetOption) types.ResponseSetOption { + return types.ResponseSetOption{} +} + +func (h HandlerBase) Query(context.Context, types.RequestQuery) types.ResponseQuery { + return types.ResponseQuery{} +} + +func (h HandlerBase) CheckTx(context.Context, types.RequestCheckTx) types.ResponseCheckTx { + return types.ResponseCheckTx{} +} + +func (h HandlerBase) InitChain(context.Context, types.RequestInitChain) types.ResponseInitChain { + return types.ResponseInitChain{} +} + +func (h HandlerBase) BeginBlock(context.Context, types.RequestBeginBlock) types.ResponseBeginBlock { + return types.ResponseBeginBlock{} +} + +func (h HandlerBase) DeliverTx(context.Context, types.RequestDeliverTx) types.ResponseDeliverTx { + return types.ResponseDeliverTx{} +} + +func (h HandlerBase) EndBlock(context.Context, types.RequestEndBlock) types.ResponseEndBlock { + return types.ResponseEndBlock{} +} + +func (h HandlerBase) Commit(context.Context) types.ResponseCommit { + return types.ResponseCommit{} +} + +func (h HandlerBase) ListSnapshots(context.Context, types.RequestListSnapshots) types.ResponseListSnapshots { + return types.ResponseListSnapshots{} +} + +func (h HandlerBase) OfferSnapshot(context.Context, types.RequestOfferSnapshot) types.ResponseOfferSnapshot { + return types.ResponseOfferSnapshot{} +} + +func (h HandlerBase) LoadSnapshotChunk(context.Context, types.RequestLoadSnapshotChunk) types.ResponseLoadSnapshotChunk { + return types.ResponseLoadSnapshotChunk{} +} + +func (h HandlerBase) ApplySnapshotChunk(context.Context, types.RequestApplySnapshotChunk) types.ResponseApplySnapshotChunk { + return types.ResponseApplySnapshotChunk{} +} + +var _ Handler = HandlerBase{} diff --git a/core/abci/header/header.go b/core/abci/header/header.go new file mode 100644 index 000000000000..54169619a09b --- /dev/null +++ b/core/abci/header/header.go @@ -0,0 +1,88 @@ +package header + +import ( + "context" + "fmt" + + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/tendermint/tendermint/abci/types" + + "github.com/cosmos/cosmos-sdk/core/abci" +) + +var Middleware abci.Middleware = func(handler abci.Handler) abci.Handler { + return &middleware{ + Handler: handler, + } +} + +type middleware struct { + abci.Handler + initialHeight int64 + lastBlockHeight int64 +} + +func (m *middleware) InitChain(ctx context.Context, req types.RequestInitChain) types.ResponseInitChain { + // On a new chain, we consider the init chain block height as 0, even though + // req.InitialHeight is 1 by default. + initHeader := tmproto.Header{ChainID: req.ChainId, Time: req.Time} + + // If req.InitialHeight is > 1, then we set the initial version in the + // stores. + if req.InitialHeight > 1 { + m.initialHeight = req.InitialHeight + initHeader = tmproto.Header{ChainID: req.ChainId, Height: req.InitialHeight, Time: req.Time} + } + + sdkCtx := sdk.UnwrapSDKContext(ctx) + sdkCtx = sdkCtx.WithBlockHeader(initHeader) + ctx = context.WithValue(ctx, sdk.SdkContextKey, sdkCtx) + + return m.Handler.InitChain(ctx, req) +} + +func (m *middleware) BeginBlock(ctx context.Context, req types.RequestBeginBlock) types.ResponseBeginBlock { + if err := m.validateHeight(req); err != nil { + panic(err) + } + + m.lastBlockHeight = req.Header.Height + + sdkCtx := sdk.UnwrapSDKContext(ctx) + sdkCtx = sdkCtx. + WithBlockHeader(req.Header). + WithBlockHeight(req.Header.Height) + ctx = context.WithValue(ctx, sdk.SdkContextKey, sdkCtx) + + return m.Handler.BeginBlock(ctx, req) +} + +func (m middleware) validateHeight(req types.RequestBeginBlock) error { + if req.Header.Height < 1 { + return fmt.Errorf("invalid height: %d", req.Header.Height) + } + + // expectedHeight holds the expected height to validate. + var expectedHeight int64 + if m.lastBlockHeight == 0 && m.initialHeight > 1 { + // In this case, we're validating the first block of the chain (no + // previous commit). The height we're expecting is the initial height. + expectedHeight = m.initialHeight + } else { + // This case can means two things: + // - either there was already a previous commit in the store, in which + // case we increment the version from there, + // - or there was no previous commit, and initial version was not set, + // in which case we start at version 1. + expectedHeight = m.lastBlockHeight + 1 + } + + if req.Header.Height != expectedHeight { + return fmt.Errorf("invalid height: %d; expected: %d", req.Header.Height, expectedHeight) + } + + return nil +} diff --git a/core/abci/middleware.go b/core/abci/middleware.go new file mode 100644 index 000000000000..f82f0f4bf86d --- /dev/null +++ b/core/abci/middleware.go @@ -0,0 +1,3 @@ +package abci + +type Middleware func(Handler) Handler diff --git a/core/app_config/app_config.go b/core/app_config/app_config.go index 7b655571b1a5..0c6c13ad08f8 100644 --- a/core/app_config/app_config.go +++ b/core/app_config/app_config.go @@ -1,24 +1,28 @@ package app_config import ( + "context" "fmt" "reflect" - container2 "github.com/cosmos/cosmos-sdk/core/container" + "github.com/tendermint/tendermint/abci/types" + + "github.com/cosmos/cosmos-sdk/core/abci" + + "github.com/cosmos/cosmos-sdk/core/container" "github.com/gogo/protobuf/proto" - "github.com/tendermint/tendermint/abci/types" codectypes "github.com/cosmos/cosmos-sdk/codec/types" "github.com/cosmos/cosmos-sdk/core/module" "github.com/cosmos/cosmos-sdk/core/module/app" ) -func Compose(config AppConfig, moduleRegistry *module.Registry) (types.Application, error) { +func Compose(config AppConfig, moduleRegistry *module.Registry) (abci.Handler, error) { interfaceRegistry := codectypes.NewInterfaceRegistry() - container := container2.NewContainer() + cont := container.NewContainer() modSet := &moduleSet{ - container: container, + container: cont, modMap: map[string]app.Handler{}, configMap: map[string]*ModuleConfig{}, } @@ -64,11 +68,65 @@ func Compose(config AppConfig, moduleRegistry *module.Registry) (types.Applicati } type moduleSet struct { - container *container2.Container + container *container.Container modMap map[string]app.Handler configMap map[string]*ModuleConfig } +func (ms *moduleSet) Info(ctx context.Context, req types.RequestInfo) types.ResponseInfo { + panic("implement me") +} + +func (ms *moduleSet) SetOption(ctx context.Context, req types.RequestSetOption) types.ResponseSetOption { + panic("implement me") +} + +func (ms *moduleSet) Query(ctx context.Context, req types.RequestQuery) types.ResponseQuery { + panic("implement me") +} + +func (ms *moduleSet) CheckTx(ctx context.Context, req types.RequestCheckTx) types.ResponseCheckTx { + panic("implement me") +} + +func (ms *moduleSet) InitChain(ctx context.Context, req types.RequestInitChain) types.ResponseInitChain { + panic("implement me") +} + +func (ms *moduleSet) BeginBlock(ctx context.Context, req types.RequestBeginBlock) types.ResponseBeginBlock { + panic("implement me") +} + +func (ms *moduleSet) DeliverTx(ctx context.Context, req types.RequestDeliverTx) types.ResponseDeliverTx { + panic("implement me") +} + +func (ms *moduleSet) EndBlock(ctx context.Context, req types.RequestEndBlock) types.ResponseEndBlock { + panic("implement me") +} + +func (ms *moduleSet) Commit(ctx context.Context) types.ResponseCommit { + panic("implement me") +} + +func (ms *moduleSet) ListSnapshots(ctx context.Context, req types.RequestListSnapshots) types.ResponseListSnapshots { + panic("implement me") +} + +func (ms *moduleSet) OfferSnapshot(ctx context.Context, req types.RequestOfferSnapshot) types.ResponseOfferSnapshot { + panic("implement me") +} + +func (ms *moduleSet) LoadSnapshotChunk(ctx context.Context, req types.RequestLoadSnapshotChunk) types.ResponseLoadSnapshotChunk { + panic("implement me") +} + +func (ms *moduleSet) ApplySnapshotChunk(ctx context.Context, req types.RequestApplySnapshotChunk) types.ResponseApplySnapshotChunk { + panic("implement me") +} + +var _ abci.Handler = &moduleSet{} + func (ms *moduleSet) addModule(interfaceRegistry codectypes.InterfaceRegistry, registry *module.Registry, config *ModuleConfig) error { ms.configMap[config.Name] = config @@ -97,16 +155,18 @@ func (ms *moduleSet) addModule(interfaceRegistry codectypes.InterfaceRegistry, r ctrTyp := ctrVal.Type() numIn := ctrTyp.NumIn() - var needs []container2.Key + var needs []container.Input for i := 1; i < numIn; i++ { argTy := ctrTyp.In(i) - needs = append(needs, container2.Key{ - Type: argTy, + needs = append(needs, container.Input{ + Key: container.Key{ + Type: argTy, + }, }) } numOut := ctrTyp.NumIn() - var provides []container2.Key + var provides []container.Output for i := 1; i < numOut; i++ { argTy := ctrTyp.Out(i) @@ -115,13 +175,17 @@ func (ms *moduleSet) addModule(interfaceRegistry codectypes.InterfaceRegistry, r continue } - provides = append(provides, container2.Key{ - Type: argTy, - }) + provides = append(provides, + container.Output{ + Key: container.Key{ + Type: argTy, + }, + }, + ) } - return ms.container.RegisterProvider(container2.Provider{ - Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { + return ms.container.RegisterProvider(container.Provider{ + Constructor: func(deps []reflect.Value, _ container.Scope) ([]reflect.Value, error) { args := []reflect.Value{reflect.ValueOf(msg)} args = append(args, deps...) res := ctrVal.Call(args) @@ -149,7 +213,7 @@ func (ms *moduleSet) addModule(interfaceRegistry codectypes.InterfaceRegistry, r }, Needs: needs, Provides: provides, - Scope: config.Name, + Scope: container.Scope(config.Name), }) } diff --git a/core/container/container.go b/core/container/container.go index ef1d32840eee..69bdc0f5fb1c 100644 --- a/core/container/container.go +++ b/core/container/container.go @@ -32,11 +32,11 @@ func NewContainer() *Container { } type Input struct { - Key Key + Key Optional bool } -type SecureOutput struct { +type Output struct { Key SecurityChecker SecurityChecker } @@ -48,7 +48,7 @@ type Key struct { type Scope string type node struct { - *Provider + Provider called bool values []reflect.Value err error @@ -60,20 +60,22 @@ type node struct { // be restricted to certain scopes based on SecurityCheckers. type Provider struct { // Constructor provides the dependencies - Constructor func(deps []reflect.Value) ([]reflect.Value, error) + Constructor func(deps []reflect.Value, scope Scope) ([]reflect.Value, error) // Needs are the keys for dependencies the constructor needs Needs []Input // Needs are the keys for dependencies the constructor provides - Provides []SecureOutput + Provides []Output // Scope is the scope within which the constructor runs Scope Scope + + IsScopeProvider bool } type scopeNode struct { - *ScopeProvider + Provider calledForScope map[Scope]bool valuesForScope map[Scope][]reflect.Value errsForScope map[Scope]error @@ -108,46 +110,75 @@ type secureValue struct { type SecurityChecker func(scope Scope) error -func (c *Container) RegisterProvider(provider *Provider) error { - n := &node{ - Provider: provider, - called: false, - } +func (c *Container) RegisterProvider(provider Provider) error { + if !provider.IsScopeProvider { + n := &node{ + Provider: provider, + called: false, + } - c.nodes = append(c.nodes, n) + c.nodes = append(c.nodes, n) - for _, key := range provider.Provides { - if c.providers[key.Key] != nil { - return fmt.Errorf("TODO") - } + for _, key := range provider.Provides { + if c.providers[key.Key] != nil { + return fmt.Errorf("TODO") + } - c.providers[key.Key] = n - } + if c.scopeProviders[key.Key] != nil { + return fmt.Errorf("TODO") + } - return nil -} + c.providers[key.Key] = n + } + } else { + n := &scopeNode{ + Provider: provider, + calledForScope: map[Scope]bool{}, + valuesForScope: map[Scope][]reflect.Value{}, + errsForScope: map[Scope]error{}, + } -func (c *Container) RegisterScopeProvider(provider *ScopeProvider) error { - n := &scopeNode{ - ScopeProvider: provider, - calledForScope: map[Scope]bool{}, - valuesForScope: map[Scope][]reflect.Value{}, - errsForScope: map[Scope]error{}, - } + c.scopeNodes = append(c.scopeNodes, n) - c.scopeNodes = append(c.scopeNodes, n) + for _, key := range provider.Provides { + if c.providers[key.Key] != nil { + return fmt.Errorf("TODO") + } - for _, key := range provider.Provides { - if c.scopeProviders[key] != nil { - return fmt.Errorf("TODO") + if c.scopeProviders[key.Key] != nil { + return fmt.Errorf("TODO") + } + + c.scopeProviders[key.Key] = n } - c.scopeProviders[key] = n + return nil } return nil } +//func (c *Container) RegisterScopeProvider(provider *ScopeProvider) error { +// n := &scopeNode{ +// ScopeProvider: provider, +// calledForScope: map[Scope]bool{}, +// valuesForScope: map[Scope][]reflect.Value{}, +// errsForScope: map[Scope]error{}, +// } +// +// c.scopeNodes = append(c.scopeNodes, n) +// +// for _, key := range provider.Provides { +// if c.scopeProviders[key] != nil { +// return fmt.Errorf("TODO") +// } +// +// c.scopeProviders[key] = n +// } +// +// return nil +//} + func (c *Container) resolve(scope Scope, input Input, stack map[interface{}]bool) (reflect.Value, error) { if scope != "" { if val, ok := c.scopedValues[scope][input.Key]; ok { @@ -181,7 +212,7 @@ func (c *Container) resolve(scope Scope, input Input, stack map[interface{}]bool deps = append(deps, res) } - res, err := provider.Constructor(scope, deps) + res, err := provider.Constructor(deps, scope) provider.calledForScope[scope] = true if err != nil { provider.errsForScope[scope] = err @@ -192,14 +223,14 @@ func (c *Container) resolve(scope Scope, input Input, stack map[interface{}]bool for i, val := range res { p := provider.Provides[i] - if _, ok := c.scopedValues[scope][p]; ok { + if _, ok := c.scopedValues[scope][p.Key]; ok { return reflect.Value{}, fmt.Errorf("value provided twice") } if c.scopedValues[scope] == nil { c.scopedValues[scope] = map[Key]reflect.Value{} } - c.scopedValues[scope][p] = val + c.scopedValues[scope][p.Key] = val } val, ok := c.scopedValues[scope][input.Key] @@ -241,6 +272,10 @@ func (c *Container) resolve(scope Scope, input Input, stack map[interface{}]bool return val, err } + if input.Optional { + return reflect.Zero(input.Type), nil + } + return reflect.Value{}, fmt.Errorf("no provider") } @@ -262,7 +297,7 @@ func (c *Container) execNode(provider *node, stack map[interface{}]bool) error { deps = append(deps, res) } - res, err := provider.Constructor(deps) + res, err := provider.Constructor(deps, "") provider.called = true if err != nil { provider.err = err @@ -328,7 +363,9 @@ func (StructArgs) isStructArgs() {} type isStructArgs interface{ isStructArgs() } -var isStructArgsTyp = reflect.TypeOf((*isStructArgs)(nil)) +var structArgsType = reflect.TypeOf(StructArgs{}) + +var isStructArgsTyp = reflect.TypeOf((*isStructArgs)(nil)).Elem() var scopeTyp = reflect.TypeOf(Scope("")) @@ -350,28 +387,37 @@ func TypeToInput(typ reflect.Type) ([]Input, InMarshaler, error) { for i := 0; i < nFields; i++ { field := typ.Field(i) - fieldInputs, m, err := TypeToInput(field.Type) - if err != nil { - return nil, nil, err - } + if field.Type == structArgsType { + marshalers = append(marshalers, inFieldMarshaler{ + n: 0, + inMarshaler: func(values []reflect.Value) reflect.Value { + return reflect.ValueOf(StructArgs{}) + }, + }) + } else { + fieldInputs, m, err := TypeToInput(field.Type) + if err != nil { + return nil, nil, err + } - optionalTag, ok := field.Tag.Lookup("optional") - if ok { - if len(fieldInputs) == 1 { - if optionalTag != "true" { - return nil, nil, fmt.Errorf("true is the only valid value for the optional tag, got %s", optionalTag) + optionalTag, ok := field.Tag.Lookup("optional") + if ok { + if len(fieldInputs) == 1 { + if optionalTag != "true" { + return nil, nil, fmt.Errorf("true is the only valid value for the optional tag, got %s", optionalTag) + } + fieldInputs[0].Optional = true + } else if len(fieldInputs) > 1 { + return nil, nil, fmt.Errorf("optional tag cannot be applied to nested StructArgs") } - fieldInputs[0].Optional = true - } else if len(fieldInputs) > 1 { - return nil, nil, fmt.Errorf("optional tag cannot be applied to nested StructArgs") } - } - res = append(res, fieldInputs...) - marshalers = append(marshalers, inFieldMarshaler{ - n: len(fieldInputs), - inMarshaler: m, - }) + res = append(res, fieldInputs...) + marshalers = append(marshalers, inFieldMarshaler{ + n: len(fieldInputs), + inMarshaler: m, + }) + } } return res, structMarshaler(typ, marshalers), nil @@ -388,10 +434,10 @@ func TypeToInput(typ reflect.Type) ([]Input, InMarshaler, error) { } } -func TypeToOutput(typ reflect.Type, securityContext func(scope Scope, tag string) error) ([]SecureOutput, OutMarshaler, error) { +func TypeToOutput(typ reflect.Type, securityContext func(scope Scope, tag string) error) ([]Output, OutMarshaler, error) { if typ.AssignableTo(isStructArgsTyp) && typ.Kind() == reflect.Struct { nFields := typ.NumField() - var res []SecureOutput + var res []Output var marshalers []OutMarshaler for i := 0; i < nFields; i++ { @@ -429,7 +475,7 @@ func TypeToOutput(typ reflect.Type, securityContext func(scope Scope, tag string } else if typ == scopeTyp { return nil, nil, fmt.Errorf("can't convert type %T to %T", Scope(""), Input{}) } else { - return []SecureOutput{{ + return []Output{{ Key: Key{ Type: typ, }, @@ -441,7 +487,7 @@ func TypeToOutput(typ reflect.Type, securityContext func(scope Scope, tag string func structMarshaler(typ reflect.Type, marshalers []inFieldMarshaler) func([]reflect.Value) reflect.Value { return func(values []reflect.Value) reflect.Value { - structInst := reflect.New(typ) + structInst := reflect.Zero(typ) for i, m := range marshalers { val := m.inMarshaler(values[:m.n]) @@ -458,90 +504,108 @@ func (c *Container) Provide(constructor interface{}) error { } func (c *Container) ProvideWithScope(constructor interface{}, scope Scope) error { - p, sp, err := ConstructorToProvider(constructor, scope, c.securityContext) + p, err := ConstructorToProvider(constructor, scope, c.securityContext) if err != nil { return err } - if p != nil { - return c.RegisterProvider(p) - } - - if sp != nil { - return c.RegisterScopeProvider(sp) - } - - return fmt.Errorf("unexpected case") + return c.RegisterProvider(p) } -func ConstructorToProvider(constructor interface{}, scope Scope, securityContext func(scope Scope, tag string) error) (*Provider, *ScopeProvider, error) { +func ConstructorToProvider(constructor interface{}, scope Scope, securityContext func(scope Scope, tag string) error) (Provider, error) { ctrTyp := reflect.TypeOf(constructor) if ctrTyp.Kind() != reflect.Func { - return nil, nil, fmt.Errorf("expected function got %T", constructor) + return Provider{}, fmt.Errorf("expected function got %T", constructor) } numIn := ctrTyp.NumIn() - numOut := ctrTyp.NumIn() + numOut := ctrTyp.NumOut() var scopeProvider bool + i := 0 if numIn >= 1 { if in0 := ctrTyp.In(0); in0 == scopeTyp { scopeProvider = true + i = 1 } } - if !scopeProvider { - var inputs []Input - var inMarshalers []inFieldMarshaler - for i := 0; i < numIn; i++ { - in, inMarshaler, err := TypeToInput(ctrTyp.In(i)) - if err != nil { - return nil, nil, err - } - inputs = append(inputs, in...) - inMarshalers = append(inMarshalers, inFieldMarshaler{ - n: len(in), - inMarshaler: inMarshaler, - }) + var inputs []Input + var inMarshalers []inFieldMarshaler + for ; i < numIn; i++ { + in, inMarshaler, err := TypeToInput(ctrTyp.In(i)) + if err != nil { + return Provider{}, err } + inputs = append(inputs, in...) + inMarshalers = append(inMarshalers, inFieldMarshaler{ + n: len(in), + inMarshaler: inMarshaler, + }) + } - var outputs []SecureOutput - var outMarshalers []OutMarshaler - for i := 0; i < numOut; i++ { - out, outMarshaler, err := TypeToOutput(ctrTyp.Out(i), securityContext) - if err != nil { - return nil, nil, err - } - outputs = append(outputs, out...) - outMarshalers = append(outMarshalers, outMarshaler) + var outputs []Output + var outMarshalers []OutMarshaler + for i := 0; i < numOut; i++ { + out, outMarshaler, err := TypeToOutput(ctrTyp.Out(i), securityContext) + if err != nil { + return Provider{}, err } + outputs = append(outputs, out...) + outMarshalers = append(outMarshalers, outMarshaler) + } - ctrVal := reflect.ValueOf(constructor) - provideCtr := func(deps []reflect.Value) ([]reflect.Value, error) { - inVals := make([]reflect.Value, numIn) - for i := 0; i < numIn; i++ { - m := inMarshalers[i] - inVals[i] = m.inMarshaler(deps[m.n:]) - deps = deps[:m.n] - } + ctrVal := reflect.ValueOf(constructor) + provideCtr := func(deps []reflect.Value, scope Scope) ([]reflect.Value, error) { + var inVals []reflect.Value - outVals := ctrVal.Call(inVals) + if scopeProvider { + inVals = append(inVals, reflect.ValueOf(scope)) + } - var provides []reflect.Value - for i := 0; i < numOut; i++ { - provides = append(provides, outMarshalers[i](outVals[i])...) - } + nInMarshalers := len(inMarshalers) + for i = 0; i < nInMarshalers; i++ { + m := inMarshalers[i] + inVals = append(inVals, m.inMarshaler(deps[:m.n])) + deps = deps[m.n:] + } + + outVals := ctrVal.Call(inVals) - return outVals, nil + var provides []reflect.Value + for i := 0; i < numOut; i++ { + provides = append(provides, outMarshalers[i](outVals[i])...) } - return &Provider{ - Constructor: provideCtr, - Needs: inputs, - Provides: outputs, - Scope: scope, - }, nil, nil - } else { + return outVals, nil + } + + return Provider{ + Constructor: provideCtr, + Needs: inputs, + Provides: outputs, + Scope: scope, + IsScopeProvider: scopeProvider, + }, nil +} +func (c *Container) Invoke(fn interface{}) error { + fnTyp := reflect.TypeOf(fn) + if fnTyp.Kind() != reflect.Func { + return fmt.Errorf("expected function got %T", fn) } + + numIn := fnTyp.NumIn() + in := make([]reflect.Value, numIn) + for i := 0; i < numIn; i++ { + val, err := c.Resolve("", Key{Type: fnTyp.In(i)}) + if err != nil { + return err + } + in[i] = val + } + + _ = reflect.ValueOf(fn).Call(in) + + return nil } diff --git a/core/container/container_test.go b/core/container/container_test.go index 5300a759d74e..0fda7f763d14 100644 --- a/core/container/container_test.go +++ b/core/container/container_test.go @@ -1,126 +1,215 @@ package container import ( - "reflect" "testing" "github.com/stretchr/testify/require" ) -type storeKey struct { +type ssKey struct { name string + db db +} + +type scKey struct { + name string + db db +} + +type kvStoreKey struct { + ssKey + scKey } type keeperA struct { - key storeKey + key kvStoreKey } type keeperB struct { - key storeKey + key kvStoreKey a keeperA } +type db struct{} + +func dbProvider() db { + return db{} +} + +func ssKeyProvider(scope Scope, db db) ssKey { + return ssKey{db: db, name: string(scope)} +} + +func scKeyProvider(scope Scope, db db) scKey { + return scKey{db: db, name: string(scope)} +} + +type kvStoreKeyInput struct { + StructArgs + SSKey ssKey + SCKey scKey +} + +func kvStoreKeyProvider(scope Scope, input kvStoreKeyInput) kvStoreKey { + return kvStoreKey{input.SSKey, input.SCKey} +} + +func keeperAProvider(key kvStoreKey) keeperA { + return keeperA{key: key} +} + +func keeperBProvider(key kvStoreKey, a keeperA) keeperB { + return keeperB{key, a} +} + func TestContainer(t *testing.T) { c := NewContainer() - require.NoError(t, c.RegisterProvider(Provider{ - Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { - return []reflect.Value{reflect.ValueOf(keeperA{deps[0].Interface().(storeKey)})}, nil - }, - Needs: []Key{ - { - Type: reflect.TypeOf(storeKey{}), - }, - }, - Provides: []Key{ - { - Type: reflect.TypeOf((*keeperA)(nil)), - }, - }, - Scope: "a", - })) - require.NoError(t, c.RegisterProvider(Provider{ - Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { - return []reflect.Value{reflect.ValueOf(keeperB{ - key: deps[0].Interface().(storeKey), - a: deps[1].Interface().(keeperA), - })}, nil - }, - Needs: []Input{ - { - Key: Key{ - Type: reflect.TypeOf(storeKey{}), + require.NoError(t, c.Provide(dbProvider)) + require.NoError(t, c.Provide(ssKeyProvider)) + require.NoError(t, c.Provide(scKeyProvider)) + require.NoError(t, c.Provide(kvStoreKeyProvider)) + require.NoError(t, c.ProvideWithScope(keeperAProvider, "a")) + require.NoError(t, c.ProvideWithScope(keeperBProvider, "b")) + require.NoError(t, c.Invoke(func(b keeperB) { + require.Equal(t, keeperB{ + key: kvStoreKey{ + ssKey: ssKey{ + name: "b", + db: db{}, }, - }, - { - Key: Key{ - Type: reflect.TypeOf((*keeperA)(nil)), + scKey: scKey{ + name: "b", + db: db{}, }, }, - }, - Provides: []SecureOutput{ - { - Key: Key{ - Type: reflect.TypeOf((*keeperB)(nil)), + a: keeperA{ + key: kvStoreKey{ + ssKey: ssKey{ + name: "a", + db: db{}, + }, + scKey: scKey{ + name: "a", + db: db{}, + }, }, }, - }, - Scope: "b", + }, b) })) - require.NoError(t, c.RegisterScopeProvider( - ScopeProvider{ - Constructor: func(scope Scope, deps []reflect.Value) ([]reflect.Value, error) { - return []reflect.Value{reflect.ValueOf(storeKey{name: scope})}, nil - }, - Needs: nil, - Provides: []Key{ - { - Type: reflect.TypeOf(storeKey{}), - }, - }, - }, - )) - - res, err := c.Resolve("b", Key{Type: reflect.TypeOf((*keeperB)(nil))}) - require.NoError(t, err) - b := res.Interface().(keeperB) - t.Logf("%+v", b) - require.Equal(t, "b", b.key.name) - require.Equal(t, "a", b.a.key.name) } func TestCycle(t *testing.T) { c := NewContainer() - require.NoError(t, c.RegisterProvider(Provider{ - Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { - return nil, nil - }, - Needs: []Key{ - { - Type: reflect.TypeOf((*keeperB)(nil)), - }, - }, - Provides: []Key{ - { - Type: reflect.TypeOf((*keeperA)(nil)), - }, - }, + require.NoError(t, c.Provide(func(a keeperA) keeperB { + return keeperB{} })) - require.NoError(t, c.RegisterProvider(Provider{ - Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { - return nil, nil - }, - Needs: []Key{ - { - Type: reflect.TypeOf((*keeperA)(nil)), - }, - }, - Provides: []Key{ - { - Type: reflect.TypeOf((*keeperB)(nil)), - }, - }, + require.NoError(t, c.Provide(func(a keeperB) keeperA { + return keeperA{} })) - - _, err := c.Resolve("b", Key{Type: reflect.TypeOf((*keeperB)(nil))}) - require.EqualError(t, err, "fatal: cycle detected") + require.EqualError(t, c.Invoke(func(a keeperA) {}), "fatal: cycle detected") } + +//func TestContainer(t *testing.T) { +// c := NewContainer() +// require.NoError(t, c.RegisterProvider(Provider{ +// Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { +// return []reflect.Value{reflect.ValueOf(keeperA{deps[0].Interface().(storeKey)})}, nil +// }, +// Needs: []Key{ +// { +// Type: reflect.TypeOf(storeKey{}), +// }, +// }, +// Provides: []Key{ +// { +// Type: reflect.TypeOf((*keeperA)(nil)), +// }, +// }, +// Scope: "a", +// })) +// require.NoError(t, c.RegisterProvider(Provider{ +// Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { +// return []reflect.Value{reflect.ValueOf(keeperB{ +// key: deps[0].Interface().(storeKey), +// a: deps[1].Interface().(keeperA), +// })}, nil +// }, +// Needs: []Input{ +// { +// Key: Key{ +// Type: reflect.TypeOf(storeKey{}), +// }, +// }, +// { +// Key: Key{ +// Type: reflect.TypeOf((*keeperA)(nil)), +// }, +// }, +// }, +// Provides: []Output{ +// { +// Key: Key{ +// Type: reflect.TypeOf((*keeperB)(nil)), +// }, +// }, +// }, +// Scope: "b", +// })) +// require.NoError(t, c.RegisterScopeProvider( +// ScopeProvider{ +// Constructor: func(scope Scope, deps []reflect.Value) ([]reflect.Value, error) { +// return []reflect.Value{reflect.ValueOf(storeKey{name: scope})}, nil +// }, +// Needs: nil, +// Provides: []Key{ +// { +// Type: reflect.TypeOf(storeKey{}), +// }, +// }, +// }, +// )) +// +// res, err := c.Resolve("b", Key{Type: reflect.TypeOf((*keeperB)(nil))}) +// require.NoError(t, err) +// b := res.Interface().(keeperB) +// t.Logf("%+v", b) +// require.Equal(t, "b", b.key.name) +// require.Equal(t, "a", b.a.key.name) +//} +// +//func TestCycle(t *testing.T) { +// c := NewContainer() +// require.NoError(t, c.RegisterProvider(Provider{ +// Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { +// return nil, nil +// }, +// Needs: []Key{ +// { +// Type: reflect.TypeOf((*keeperB)(nil)), +// }, +// }, +// Provides: []Key{ +// { +// Type: reflect.TypeOf((*keeperA)(nil)), +// }, +// }, +// })) +// require.NoError(t, c.RegisterProvider(Provider{ +// Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { +// return nil, nil +// }, +// Needs: []Key{ +// { +// Type: reflect.TypeOf((*keeperA)(nil)), +// }, +// }, +// Provides: []Key{ +// { +// Type: reflect.TypeOf((*keeperB)(nil)), +// }, +// }, +// })) +// +// _, err := c.Resolve("b", Key{Type: reflect.TypeOf((*keeperB)(nil))}) +// require.EqualError(t, err, "fatal: cycle detected") +//} diff --git a/core/store/kv/kv.go b/core/store/kv/kv.go new file mode 100644 index 000000000000..64d96d49e0a6 --- /dev/null +++ b/core/store/kv/kv.go @@ -0,0 +1,20 @@ +package kv + +import ( + "context" + + "github.com/cosmos/cosmos-sdk/core/container" + "github.com/cosmos/cosmos-sdk/core/store" + "github.com/cosmos/cosmos-sdk/core/store/sc" + "github.com/cosmos/cosmos-sdk/core/store/ss" +) + +type StoreKey struct{} + +func StoreKeyProvider(scope container.Scope, ssKey ss.StoreKey, scKey sc.StoreKey) StoreKey { + panic("TODO") +} + +func (StoreKey) Open(context.Context) store.KVStore { + panic("TODO") +} diff --git a/core/store/sc/sc.go b/core/store/sc/sc.go new file mode 100644 index 000000000000..628367f65a1b --- /dev/null +++ b/core/store/sc/sc.go @@ -0,0 +1,18 @@ +package sc + +import ( + "context" + + "github.com/cosmos/cosmos-sdk/core/container" + "github.com/cosmos/cosmos-sdk/core/store" +) + +type StoreKey struct{} + +func StoreKeyProvider(scope container.Scope) StoreKey { + panic("TODO") +} + +func (StoreKey) Open(context.Context) store.BasicKVStore { + panic("TODO") +} diff --git a/core/store/ss/ss.go b/core/store/ss/ss.go new file mode 100644 index 000000000000..278ebb7aa2c7 --- /dev/null +++ b/core/store/ss/ss.go @@ -0,0 +1,18 @@ +package ss + +import ( + "context" + + "github.com/cosmos/cosmos-sdk/core/container" + "github.com/cosmos/cosmos-sdk/core/store" +) + +type StoreKey struct{} + +func StoreKeyProvider(scope container.Scope) StoreKey { + panic("TODO") +} + +func (StoreKey) Open(context.Context) store.KVStore { + panic("TODO") +} diff --git a/core/tx/module/module.go b/core/tx/module/module.go index 02481439ec3e..d54e64394767 100644 --- a/core/tx/module/module.go +++ b/core/tx/module/module.go @@ -8,17 +8,18 @@ import ( tx2 "github.com/cosmos/cosmos-sdk/types/tx" + abci "github.com/tendermint/tendermint/abci/types" + "google.golang.org/grpc" + "github.com/cosmos/cosmos-sdk/codec" "github.com/cosmos/cosmos-sdk/codec/types" "github.com/cosmos/cosmos-sdk/core/module/app" "github.com/cosmos/cosmos-sdk/core/tx" - abci "github.com/tendermint/tendermint/abci/types" - "google.golang.org/grpc" ) -func init() { - app.RegisterAppModule(appModule{}) -} +//func init() { +// app.RegisterAppModule(appModule{}) +//} type appModule struct { Config *tx.Module @@ -49,7 +50,7 @@ func (a appModule) RegisterQueryServices(registrar grpc.ServiceRegistrar) {} func (a appModule) TxHandler(params app_config.TxHandlerParams) app_config.TxHandler { return txHandler{ - Module: a.Module, + Module: a.Config, msgRouter: params.MsgRouter, } } @@ -64,10 +65,10 @@ type MiddlewareRegistrar interface { type MiddlewareFactory func(config interface{}) Middleware -type Middleware interface { - OnCheckTx(ctx context.Context, tx tx2.Tx, req abci.RequestCheckTx, next TxHandler) (abci.ResponseCheckTx, error) - OnDeliverTx(ctx context.Context, tx tx2.Tx, req abci.RequestDeliverTx, next TxHandler) (abci.ResponseDeliverTx, error) -} +//type Middleware interface { +// OnCheckTx(ctx context.Context, tx tx2.Tx, req abci.RequestCheckTx, next TxHandler) (abci.ResponseCheckTx, error) +// OnDeliverTx(ctx context.Context, tx tx2.Tx, req abci.RequestDeliverTx, next TxHandler) (abci.ResponseDeliverTx, error) +//} type TxHandler interface { CheckTx(ctx context.Context, tx tx2.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) @@ -86,3 +87,5 @@ func (t txHandler) CheckTx(ctx context.Context, req abci.RequestCheckTx) (abci.R func (t txHandler) DeliverTx(ctx context.Context, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { panic("implement me") } + +type Middleware func(TxHandler) TxHandler diff --git a/x/authn/app/middleware.go b/x/authn/app/middleware.go index f358ec69e884..9acedae8d2a7 100644 --- a/x/authn/app/middleware.go +++ b/x/authn/app/middleware.go @@ -3,48 +3,45 @@ package app import ( "context" - "github.com/cosmos/cosmos-sdk/core/app_config" - abci "github.com/tendermint/tendermint/abci/types" - "github.com/cosmos/cosmos-sdk/core/module/app" + "github.com/cosmos/cosmos-sdk/core/tx/module" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/x/authn" ) type validateMemoMiddlewareHandler struct { + module.TxHandler *authn.ValidateMemoMiddleware } -func (v validateMemoMiddlewareHandler) validate(tx tx.Tx) error { - memoLength := len(tx.Body.Memo) - if uint64(memoLength) > v.MaxMemoCharacters { - return sdkerrors.Wrapf(sdkerrors.ErrMemoTooLarge, - "maximum number of characters is %d but received %d characters", - v.MaxMemoCharacters, memoLength, - ) - } - - return nil -} - -func (v validateMemoMiddlewareHandler) OnCheckTx(ctx context.Context, tx tx.Tx, req abci.RequestCheckTx, next app_config.TxHandler) (abci.ResponseCheckTx, error) { +func (v validateMemoMiddlewareHandler) CheckTx(ctx context.Context, tx tx.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) { err := v.validate(tx) if err != nil { return abci.ResponseCheckTx{}, err } - return next.CheckTx(ctx, tx, req) + return v.TxHandler.CheckTx(ctx, tx, req) } -func (v validateMemoMiddlewareHandler) OnDeliverTx(ctx context.Context, tx tx.Tx, req abci.RequestDeliverTx, next app_config.TxHandler) (abci.ResponseDeliverTx, error) { +func (v validateMemoMiddlewareHandler) DeliverTx(ctx context.Context, tx tx.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) { err := v.validate(tx) if err != nil { return abci.ResponseDeliverTx{}, err } - return next.DeliverTx(ctx, tx, req) + return v.TxHandler.DeliverTx(ctx, tx, req) } -var _ app.TxMiddleware = validateMemoMiddlewareHandler{} +func (v validateMemoMiddlewareHandler) validate(tx tx.Tx) error { + memoLength := len(tx.Body.Memo) + if uint64(memoLength) > v.MaxMemoCharacters { + return sdkerrors.Wrapf(sdkerrors.ErrMemoTooLarge, + "maximum number of characters is %d but received %d characters", + v.MaxMemoCharacters, memoLength, + ) + } + + return nil +} diff --git a/x/authn/module/module.go b/x/authn/module/module.go index f5ee367e0fc5..46343a74b874 100644 --- a/x/authn/module/module.go +++ b/x/authn/module/module.go @@ -1,19 +1,11 @@ package module -import ( - codec2 "github.com/cosmos/cosmos-sdk/codec" - "github.com/cosmos/cosmos-sdk/core/codec" - "github.com/cosmos/cosmos-sdk/core/module/app" - "github.com/cosmos/cosmos-sdk/core/store" - "github.com/cosmos/cosmos-sdk/x/authn" -) - -var _ codec.TypeProvider = Module{} - -func (m Module) RegisterTypes(registry codec.TypeRegistry) { - authn.RegisterTypes(registry) -} - -func (m Module) NewAppHandler(cdc codec2.Codec, storeKey store.KVStoreKey) app.Handler { - -} +//var _ codec.TypeProvider = Module{} +// +//func (m Module) RegisterTypes(registry codec.TypeRegistry) { +// authn.RegisterTypes(registry) +//} +// +//func (m Module) NewAppHandler(cdc codec2.Codec, storeKey store.KVStoreKey) app.Handler { +// +//}