diff --git a/backend.go b/backend.go index fa14be5..8f1f9f5 100644 --- a/backend.go +++ b/backend.go @@ -36,11 +36,11 @@ func Backend() *backend { }, Paths: framework.PathAppend( []*framework.Path{ - pathConfig(b), - pathConfigLdap(b), - pathLogin(b), - pathGroups(b), - pathGroupsList(b), + b.pathConfig(), + b.pathConfigLdap(), + b.pathLogin(), + b.pathGroups(), + b.pathGroupsList(), }, ), } diff --git a/path_config.go b/path_config.go index a6b73f3..22c8de3 100644 --- a/path_config.go +++ b/path_config.go @@ -3,7 +3,6 @@ package kerberos import ( "context" "encoding/base64" - "errors" "fmt" "github.com/hashicorp/vault/sdk/framework" @@ -16,7 +15,7 @@ type kerberosConfig struct { ServiceAccount string `json:"service_account"` } -func pathConfig(b *backend) *framework.Path { +func (b *backend) pathConfig() *framework.Path { return &framework.Path{ Pattern: "config$", Fields: map[string]*framework.FieldSchema{ @@ -29,10 +28,16 @@ func pathConfig(b *backend) *framework.Path { Description: `Service Account`, }, }, - Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: b.pathConfigWrite, - logical.CreateOperation: b.pathConfigWrite, - logical.ReadOperation: b.pathConfigRead, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.pathConfigWrite, + }, + logical.CreateOperation: &framework.PathOperation{ + Callback: b.pathConfigWrite, + }, + logical.ReadOperation: &framework.PathOperation{ + Callback: b.pathConfigRead, + }, }, HelpSynopsis: confHelpSynopsis, @@ -58,22 +63,21 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { serviceAccount := data.Get("service_account").(string) if serviceAccount == "" { - return nil, errors.New("data does not contain service_account") + return logical.ErrorResponse("data does not contain service_account"), logical.ErrInvalidRequest } kt := data.Get("keytab").(string) if kt == "" { - return nil, errors.New("data does not contain keytab") + return logical.ErrorResponse("data does not contain keytab"), logical.ErrInvalidRequest } // Check that the keytab is valid by parsing with krb5go binary, err := base64.StdEncoding.DecodeString(kt) if err != nil { - return nil, fmt.Errorf("could not base64 decode keytab: %v", err) + return logical.ErrorResponse(fmt.Sprintf("could not base64 decode keytab: %v", err)), logical.ErrInvalidRequest } - _, err = keytab.Parse(binary) - if err != nil { - return nil, fmt.Errorf("invalid keytab: %v", err) + if _, err = keytab.Parse(binary); err != nil { + return logical.ErrorResponse(fmt.Sprintf("invalid keytab: %v", err)), logical.ErrInvalidRequest } config := &kerberosConfig{ diff --git a/path_config_ldap.go b/path_config_ldap.go index 4aea322..9bd0e0c 100644 --- a/path_config_ldap.go +++ b/path_config_ldap.go @@ -2,6 +2,7 @@ package kerberos import ( "context" + "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/ldaputil" "github.com/hashicorp/vault/sdk/helper/tokenutil" @@ -10,14 +11,18 @@ import ( const ldapConfPath = "config/ldap" -func pathConfigLdap(b *backend) *framework.Path { +func (b *backend) pathConfigLdap() *framework.Path { p := &framework.Path{ Pattern: ldapConfPath, Fields: ldaputil.ConfigFields(), - Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathConfigLdapRead, - logical.UpdateOperation: b.pathConfigLdapWrite, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: b.pathConfigLdapRead, + }, + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.pathConfigLdapWrite, + }, }, HelpSynopsis: pathConfigLdapHelpSyn, @@ -64,7 +69,7 @@ func (b *backend) pathConfigLdapRead(ctx context.Context, req *logical.Request, func (b *backend) pathConfigLdapWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { cfg, err := b.ConfigLdap(ctx, req) if err != nil { - return logical.ErrorResponse(err.Error()), nil + return nil, err } var prevLDAPCfg *ldaputil.ConfigEntry @@ -78,7 +83,7 @@ func (b *backend) pathConfigLdapWrite(ctx context.Context, req *logical.Request, newLdapCfg, err := ldaputil.NewConfigEntry(prevLDAPCfg, d) if err != nil { - return nil, err + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest } cfg.ConfigEntry = newLdapCfg diff --git a/path_config_test.go b/path_config_test.go index a8e0eb7..482d83a 100644 --- a/path_config_test.go +++ b/path_config_test.go @@ -24,8 +24,7 @@ func getTestBackend(t *testing.T) (logical.Backend, logical.Storage) { } b := Backend() - err := b.Setup(context.Background(), config) - if err != nil { + if err := b.Setup(context.Background(), config); err != nil { t.Fatalf("unable to create backend: %v", err) } @@ -101,11 +100,14 @@ func testConfigWriteError(t *testing.T, b logical.Backend, storage logical.Stora Data: data, } - _, err := b.HandleRequest(context.Background(), req) + resp, err := b.HandleRequest(context.Background(), req) if err == nil { t.Fatal("expected error") } - if !strings.HasPrefix(err.Error(), e) { + if err.Error() != "invalid request" { + t.Fatal("expected invalid request") + } + if !strings.HasPrefix(resp.Error().Error(), e) { t.Fatalf("got unexpected error: %v, expected %v", err.Error(), e) } } diff --git a/path_groups.go b/path_groups.go index 9462146..95ce814 100644 --- a/path_groups.go +++ b/path_groups.go @@ -8,12 +8,14 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) -func pathGroupsList(b *backend) *framework.Path { +func (b *backend) pathGroupsList() *framework.Path { return &framework.Path{ Pattern: "groups/?$", - Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ListOperation: b.pathGroupList, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ListOperation: &framework.PathOperation{ + Callback: b.pathGroupList, + }, }, HelpSynopsis: pathGroupHelpSyn, @@ -21,25 +23,31 @@ func pathGroupsList(b *backend) *framework.Path { } } -func pathGroups(b *backend) *framework.Path { +func (b *backend) pathGroups() *framework.Path { return &framework.Path{ Pattern: `groups/(?P.+)`, Fields: map[string]*framework.FieldSchema{ - "name": &framework.FieldSchema{ + "name": { Type: framework.TypeString, Description: "Name of the LDAP group.", }, - "policies": &framework.FieldSchema{ + "policies": { Type: framework.TypeCommaStringSlice, Description: "Comma-separated list of policies associated to the group.", }, }, - Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.DeleteOperation: b.pathGroupDelete, - logical.ReadOperation: b.pathGroupRead, - logical.UpdateOperation: b.pathGroupWrite, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.DeleteOperation: &framework.PathOperation{ + Callback: b.pathGroupDelete, + }, + logical.ReadOperation: &framework.PathOperation{ + Callback: b.pathGroupRead, + }, + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.pathGroupWrite, + }, }, HelpSynopsis: pathGroupHelpSyn, diff --git a/path_login.go b/path_login.go index e9a68e7..bc07773 100644 --- a/path_login.go +++ b/path_login.go @@ -12,25 +12,28 @@ import ( "github.com/hashicorp/vault/sdk/helper/ldaputil" "github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/logical" - "gopkg.in/jcmturner/gokrb5.v5/credentials" "gopkg.in/jcmturner/gokrb5.v5/gssapi" "gopkg.in/jcmturner/gokrb5.v5/keytab" "gopkg.in/jcmturner/gokrb5.v5/service" ) -func pathLogin(b *backend) *framework.Path { +func (b *backend) pathLogin() *framework.Path { return &framework.Path{ Pattern: "login$", Fields: map[string]*framework.FieldSchema{ - "authorization": &framework.FieldSchema{ + "authorization": { Type: framework.TypeString, Description: `SPNEGO Authorization header. Required.`, }, }, - Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathLoginGet, - logical.UpdateOperation: b.pathLogin, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: b.pathLoginGet, + }, + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.pathLoginUpdate, + }, }, } } @@ -47,27 +50,33 @@ func parseKeytab(stringKeytab string) (*keytab.Keytab, error) { return &kt, nil } -func spnegoKrb5Authenticate(kt keytab.Keytab, sa string, authorization []byte, remoteAddr string) (bool, *credentials.Credentials, error) { +func spnegoKrb5Authenticate(kt keytab.Keytab, sa string, authorization []byte, remoteAddr string) (*credentials.Credentials, error) { var spnego gssapi.SPNEGO - err := spnego.Unmarshal(authorization) - if err != nil || !spnego.Init { - return false, nil, fmt.Errorf("SPNEGO negotiation token is not a NegTokenInit: %v", err) + if err := spnego.Unmarshal(authorization); err != nil || !spnego.Init { + return nil, fmt.Errorf("SPNEGO negotiation token is not a NegTokenInit: %v", err) } if !spnego.NegTokenInit.MechTypes[0].Equal(gssapi.MechTypeOIDKRB5) && !spnego.NegTokenInit.MechTypes[0].Equal(gssapi.MechTypeOIDMSLegacyKRB5) { - return false, nil, errors.New("SPNEGO OID of MechToken is not of type KRB5") + return nil, errors.New("SPNEGO OID of MechToken is not of type KRB5") } var mt gssapi.MechToken - err = mt.Unmarshal(spnego.NegTokenInit.MechToken) - if err != nil { - return false, nil, fmt.Errorf("SPNEGO error unmarshaling MechToken: %v", err) + if err := mt.Unmarshal(spnego.NegTokenInit.MechToken); err != nil { + return nil, fmt.Errorf("SPNEGO error unmarshaling MechToken: %v", err) } if !mt.IsAPReq() { - return false, nil, errors.New("MechToken does not contain an AP_REQ - KRB_AP_ERR_MSG_TYPE") + return nil, errors.New("MechToken does not contain an AP_REQ - KRB_AP_ERR_MSG_TYPE") } - ok, creds, err := service.ValidateAPREQ(mt.APReq, kt, sa, remoteAddr, false) - return ok, &creds, err + // The first return value here is a boolean reflecting whether the request is valid; + // however, this value is redundant because if the error is nil, the request is valid, + // but if it's populated, the request is invalid. Hence, it's ignored here because we + // only need to error if the error is populated, and that knowledge can be encapsulated + // here. + _, creds, err := service.ValidateAPREQ(mt.APReq, kt, sa, remoteAddr, false) + if err != nil { + return nil, err + } + return &creds, nil } func (b *backend) pathLoginGet(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { @@ -79,18 +88,18 @@ func (b *backend) pathLoginGet(ctx context.Context, req *logical.Request, d *fra }, logical.CodedError(401, "authentication required") } -func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { +func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { kerbCfg, err := b.config(ctx, req.Storage) if err != nil { return nil, err } if kerbCfg == nil { - return nil, errors.New("Could not load backend configuration") + return nil, errors.New("backend kerberos not configured") } kt, err := parseKeytab(kerbCfg.Keytab) if err != nil { - return nil, fmt.Errorf("Could not load keytab: %v", err) + return nil, err } ldapCfg, err := b.ConfigLdap(ctx, req) @@ -119,7 +128,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew ldapConnection, err := ldapClient.DialLDAP(ldapCfg.ConfigEntry) if err != nil { - return nil, fmt.Errorf("Could not connect to LDAP: %v", err) + return nil, fmt.Errorf("could not connect to LDAP: %v", err) } if ldapConnection == nil { return nil, errors.New("invalid connection returned from LDAP dial") @@ -142,16 +151,12 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew } authorization, err := base64.StdEncoding.DecodeString(s[1]) if err != nil { - return nil, fmt.Errorf("Could not base64 decode authorization: %v", err) + return nil, fmt.Errorf("could not base64 decode authorization: %v", err) } - ok, creds, err := spnegoKrb5Authenticate(*kt, kerbCfg.ServiceAccount, authorization, req.Connection.RemoteAddr) - if !ok { - if err != nil { - return nil, err - } else { - return logical.ErrorResponse("Kerberos authentication failed"), nil - } + creds, err := spnegoKrb5Authenticate(*kt, kerbCfg.ServiceAccount, authorization, req.Connection.RemoteAddr) + if err != nil { + return nil, err } if len(ldapCfg.BindPassword) > 0 { @@ -167,10 +172,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew if err != nil { return nil, err } - - if b.Logger().IsDebug() { - b.Logger().Debug("auth/ldap: User BindDN fetched", "username", creds.Username, "binddn", userBindDN) - } + b.Logger().Debug("auth/ldap: User BindDN fetched", "username", creds.Username, "binddn", userBindDN) userDN, err := ldapClient.GetUserDN(ldapCfg.ConfigEntry, ldapConnection, userBindDN) if err != nil { @@ -181,9 +183,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew if err != nil { return nil, err } - if b.Logger().IsDebug() { - b.Logger().Debug("auth/ldap: Groups fetched from server", "num_server_groups", len(ldapGroups), "server_groups", ldapGroups) - } + b.Logger().Debug("auth/ldap: Groups fetched from server", "num_server_groups", len(ldapGroups), "server_groups", ldapGroups) var allGroups []string // Merge local and LDAP groups @@ -193,9 +193,15 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew var policies []string for _, groupName := range allGroups { group, err := b.Group(ctx, req.Storage, groupName) - if err == nil && group != nil { - policies = append(policies, group.Policies...) + if err != nil { + b.Logger().Warn(fmt.Sprintf("unable to retrieve %s: %s", groupName, err.Error())) + continue + } + if group == nil { + b.Logger().Warn(fmt.Sprintf("unable to find %s, does not currently exist", groupName)) + continue } + policies = append(policies, group.Policies...) } // Policies from each group may overlap policies = strutil.RemoveDuplicates(policies, true)