Skip to content

Commit

Permalink
code maintenance
Browse files Browse the repository at this point in the history
  • Loading branch information
tyrannosaurus-becks committed Aug 6, 2019
1 parent 499a627 commit 1c744d6
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 75 deletions.
10 changes: 5 additions & 5 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ func Backend() *backend {
},
Paths: framework.PathAppend(
[]*framework.Path{
pathConfig(b),
pathConfigLdap(b),
pathLogin(b),
pathGroups(b),
pathGroupsList(b),
b.pathConfig(),
b.pathConfigLdap(),
b.pathLogin(),
b.pathGroups(),
b.pathGroupsList(),
},
),
}
Expand Down
28 changes: 16 additions & 12 deletions path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package kerberos
import (
"context"
"encoding/base64"
"errors"
"fmt"

"github.com/hashicorp/vault/sdk/framework"
Expand All @@ -16,7 +15,7 @@ type kerberosConfig struct {
ServiceAccount string `json:"service_account"`
}

func pathConfig(b *backend) *framework.Path {
func (b *backend) pathConfig() *framework.Path {
return &framework.Path{
Pattern: "config$",
Fields: map[string]*framework.FieldSchema{
Expand All @@ -29,10 +28,16 @@ func pathConfig(b *backend) *framework.Path {
Description: `Service Account`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathConfigWrite,
logical.CreateOperation: b.pathConfigWrite,
logical.ReadOperation: b.pathConfigRead,
Operations: map[logical.Operation]framework.OperationHandler{
logical.UpdateOperation: &framework.PathOperation{
Callback: b.pathConfigWrite,
},
logical.CreateOperation: &framework.PathOperation{
Callback: b.pathConfigWrite,
},
logical.ReadOperation: &framework.PathOperation{
Callback: b.pathConfigRead,
},
},

HelpSynopsis: confHelpSynopsis,
Expand All @@ -58,22 +63,21 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data
func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
serviceAccount := data.Get("service_account").(string)
if serviceAccount == "" {
return nil, errors.New("data does not contain service_account")
return logical.ErrorResponse("data does not contain service_account"), logical.ErrInvalidRequest
}

kt := data.Get("keytab").(string)
if kt == "" {
return nil, errors.New("data does not contain keytab")
return logical.ErrorResponse("data does not contain keytab"), logical.ErrInvalidRequest
}

// Check that the keytab is valid by parsing with krb5go
binary, err := base64.StdEncoding.DecodeString(kt)
if err != nil {
return nil, fmt.Errorf("could not base64 decode keytab: %v", err)
return logical.ErrorResponse(fmt.Sprintf("could not base64 decode keytab: %v", err)), logical.ErrInvalidRequest
}
_, err = keytab.Parse(binary)
if err != nil {
return nil, fmt.Errorf("invalid keytab: %v", err)
if _, err = keytab.Parse(binary); err != nil {
return logical.ErrorResponse(fmt.Sprintf("invalid keytab: %v", err)), logical.ErrInvalidRequest
}

config := &kerberosConfig{
Expand Down
17 changes: 11 additions & 6 deletions path_config_ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package kerberos

import (
"context"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/ldaputil"
"github.com/hashicorp/vault/sdk/helper/tokenutil"
Expand All @@ -10,14 +11,18 @@ import (

const ldapConfPath = "config/ldap"

func pathConfigLdap(b *backend) *framework.Path {
func (b *backend) pathConfigLdap() *framework.Path {
p := &framework.Path{
Pattern: ldapConfPath,
Fields: ldaputil.ConfigFields(),

Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathConfigLdapRead,
logical.UpdateOperation: b.pathConfigLdapWrite,
Operations: map[logical.Operation]framework.OperationHandler{
logical.ReadOperation: &framework.PathOperation{
Callback: b.pathConfigLdapRead,
},
logical.UpdateOperation: &framework.PathOperation{
Callback: b.pathConfigLdapWrite,
},
},

HelpSynopsis: pathConfigLdapHelpSyn,
Expand Down Expand Up @@ -64,7 +69,7 @@ func (b *backend) pathConfigLdapRead(ctx context.Context, req *logical.Request,
func (b *backend) pathConfigLdapWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
cfg, err := b.ConfigLdap(ctx, req)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
return nil, err
}

var prevLDAPCfg *ldaputil.ConfigEntry
Expand All @@ -78,7 +83,7 @@ func (b *backend) pathConfigLdapWrite(ctx context.Context, req *logical.Request,

newLdapCfg, err := ldaputil.NewConfigEntry(prevLDAPCfg, d)
if err != nil {
return nil, err
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
cfg.ConfigEntry = newLdapCfg

Expand Down
10 changes: 6 additions & 4 deletions path_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ func getTestBackend(t *testing.T) (logical.Backend, logical.Storage) {
}

b := Backend()
err := b.Setup(context.Background(), config)
if err != nil {
if err := b.Setup(context.Background(), config); err != nil {
t.Fatalf("unable to create backend: %v", err)
}

Expand Down Expand Up @@ -101,11 +100,14 @@ func testConfigWriteError(t *testing.T, b logical.Backend, storage logical.Stora
Data: data,
}

_, err := b.HandleRequest(context.Background(), req)
resp, err := b.HandleRequest(context.Background(), req)
if err == nil {
t.Fatal("expected error")
}
if !strings.HasPrefix(err.Error(), e) {
if err.Error() != "invalid request" {
t.Fatal("expected invalid request")
}
if !strings.HasPrefix(resp.Error().Error(), e) {
t.Fatalf("got unexpected error: %v, expected %v", err.Error(), e)
}
}
Expand Down
28 changes: 18 additions & 10 deletions path_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,46 @@ import (
"github.com/hashicorp/vault/sdk/logical"
)

func pathGroupsList(b *backend) *framework.Path {
func (b *backend) pathGroupsList() *framework.Path {
return &framework.Path{
Pattern: "groups/?$",

Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathGroupList,
Operations: map[logical.Operation]framework.OperationHandler{
logical.ListOperation: &framework.PathOperation{
Callback: b.pathGroupList,
},
},

HelpSynopsis: pathGroupHelpSyn,
HelpDescription: pathGroupHelpDesc,
}
}

func pathGroups(b *backend) *framework.Path {
func (b *backend) pathGroups() *framework.Path {
return &framework.Path{
Pattern: `groups/(?P<name>.+)`,
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the LDAP group.",
},

"policies": &framework.FieldSchema{
"policies": {
Type: framework.TypeCommaStringSlice,
Description: "Comma-separated list of policies associated to the group.",
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
logical.DeleteOperation: b.pathGroupDelete,
logical.ReadOperation: b.pathGroupRead,
logical.UpdateOperation: b.pathGroupWrite,
Operations: map[logical.Operation]framework.OperationHandler{
logical.DeleteOperation: &framework.PathOperation{
Callback: b.pathGroupDelete,
},
logical.ReadOperation: &framework.PathOperation{
Callback: b.pathGroupRead,
},
logical.UpdateOperation: &framework.PathOperation{
Callback: b.pathGroupWrite,
},
},

HelpSynopsis: pathGroupHelpSyn,
Expand Down
82 changes: 44 additions & 38 deletions path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,28 @@ import (
"github.com/hashicorp/vault/sdk/helper/ldaputil"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"

"gopkg.in/jcmturner/gokrb5.v5/credentials"
"gopkg.in/jcmturner/gokrb5.v5/gssapi"
"gopkg.in/jcmturner/gokrb5.v5/keytab"
"gopkg.in/jcmturner/gokrb5.v5/service"
)

func pathLogin(b *backend) *framework.Path {
func (b *backend) pathLogin() *framework.Path {
return &framework.Path{
Pattern: "login$",
Fields: map[string]*framework.FieldSchema{
"authorization": &framework.FieldSchema{
"authorization": {
Type: framework.TypeString,
Description: `SPNEGO Authorization header. Required.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathLoginGet,
logical.UpdateOperation: b.pathLogin,
Operations: map[logical.Operation]framework.OperationHandler{
logical.ReadOperation: &framework.PathOperation{
Callback: b.pathLoginGet,
},
logical.UpdateOperation: &framework.PathOperation{
Callback: b.pathLoginUpdate,
},
},
}
}
Expand All @@ -47,27 +50,33 @@ func parseKeytab(stringKeytab string) (*keytab.Keytab, error) {
return &kt, nil
}

func spnegoKrb5Authenticate(kt keytab.Keytab, sa string, authorization []byte, remoteAddr string) (bool, *credentials.Credentials, error) {
func spnegoKrb5Authenticate(kt keytab.Keytab, sa string, authorization []byte, remoteAddr string) (*credentials.Credentials, error) {
var spnego gssapi.SPNEGO
err := spnego.Unmarshal(authorization)
if err != nil || !spnego.Init {
return false, nil, fmt.Errorf("SPNEGO negotiation token is not a NegTokenInit: %v", err)
if err := spnego.Unmarshal(authorization); err != nil || !spnego.Init {
return nil, fmt.Errorf("SPNEGO negotiation token is not a NegTokenInit: %v", err)
}
if !spnego.NegTokenInit.MechTypes[0].Equal(gssapi.MechTypeOIDKRB5) && !spnego.NegTokenInit.MechTypes[0].Equal(gssapi.MechTypeOIDMSLegacyKRB5) {
return false, nil, errors.New("SPNEGO OID of MechToken is not of type KRB5")
return nil, errors.New("SPNEGO OID of MechToken is not of type KRB5")
}

var mt gssapi.MechToken
err = mt.Unmarshal(spnego.NegTokenInit.MechToken)
if err != nil {
return false, nil, fmt.Errorf("SPNEGO error unmarshaling MechToken: %v", err)
if err := mt.Unmarshal(spnego.NegTokenInit.MechToken); err != nil {
return nil, fmt.Errorf("SPNEGO error unmarshaling MechToken: %v", err)
}
if !mt.IsAPReq() {
return false, nil, errors.New("MechToken does not contain an AP_REQ - KRB_AP_ERR_MSG_TYPE")
return nil, errors.New("MechToken does not contain an AP_REQ - KRB_AP_ERR_MSG_TYPE")
}

ok, creds, err := service.ValidateAPREQ(mt.APReq, kt, sa, remoteAddr, false)
return ok, &creds, err
// The first return value here is a boolean reflecting whether the request is valid;
// however, this value is redundant because if the error is nil, the request is valid,
// but if it's populated, the request is invalid. Hence, it's ignored here because we
// only need to error if the error is populated, and that knowledge can be encapsulated
// here.
_, creds, err := service.ValidateAPREQ(mt.APReq, kt, sa, remoteAddr, false)
if err != nil {
return nil, err
}
return &creds, nil
}

func (b *backend) pathLoginGet(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
Expand All @@ -79,18 +88,18 @@ func (b *backend) pathLoginGet(ctx context.Context, req *logical.Request, d *fra
}, logical.CodedError(401, "authentication required")
}

func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
kerbCfg, err := b.config(ctx, req.Storage)
if err != nil {
return nil, err
}
if kerbCfg == nil {
return nil, errors.New("Could not load backend configuration")
return nil, errors.New("backend kerberos not configured")
}

kt, err := parseKeytab(kerbCfg.Keytab)
if err != nil {
return nil, fmt.Errorf("Could not load keytab: %v", err)
return nil, err
}

ldapCfg, err := b.ConfigLdap(ctx, req)
Expand Down Expand Up @@ -119,7 +128,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew

ldapConnection, err := ldapClient.DialLDAP(ldapCfg.ConfigEntry)
if err != nil {
return nil, fmt.Errorf("Could not connect to LDAP: %v", err)
return nil, fmt.Errorf("could not connect to LDAP: %v", err)
}
if ldapConnection == nil {
return nil, errors.New("invalid connection returned from LDAP dial")
Expand All @@ -142,16 +151,12 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
}
authorization, err := base64.StdEncoding.DecodeString(s[1])
if err != nil {
return nil, fmt.Errorf("Could not base64 decode authorization: %v", err)
return nil, fmt.Errorf("could not base64 decode authorization: %v", err)
}

ok, creds, err := spnegoKrb5Authenticate(*kt, kerbCfg.ServiceAccount, authorization, req.Connection.RemoteAddr)
if !ok {
if err != nil {
return nil, err
} else {
return logical.ErrorResponse("Kerberos authentication failed"), nil
}
creds, err := spnegoKrb5Authenticate(*kt, kerbCfg.ServiceAccount, authorization, req.Connection.RemoteAddr)
if err != nil {
return nil, err
}

if len(ldapCfg.BindPassword) > 0 {
Expand All @@ -167,10 +172,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
if err != nil {
return nil, err
}

if b.Logger().IsDebug() {
b.Logger().Debug("auth/ldap: User BindDN fetched", "username", creds.Username, "binddn", userBindDN)
}
b.Logger().Debug("auth/ldap: User BindDN fetched", "username", creds.Username, "binddn", userBindDN)

userDN, err := ldapClient.GetUserDN(ldapCfg.ConfigEntry, ldapConnection, userBindDN)
if err != nil {
Expand All @@ -181,9 +183,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
if err != nil {
return nil, err
}
if b.Logger().IsDebug() {
b.Logger().Debug("auth/ldap: Groups fetched from server", "num_server_groups", len(ldapGroups), "server_groups", ldapGroups)
}
b.Logger().Debug("auth/ldap: Groups fetched from server", "num_server_groups", len(ldapGroups), "server_groups", ldapGroups)

var allGroups []string
// Merge local and LDAP groups
Expand All @@ -193,9 +193,15 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
var policies []string
for _, groupName := range allGroups {
group, err := b.Group(ctx, req.Storage, groupName)
if err == nil && group != nil {
policies = append(policies, group.Policies...)
if err != nil {
b.Logger().Warn(fmt.Sprintf("unable to retrieve %s: %s", groupName, err.Error()))
continue
}
if group == nil {
b.Logger().Warn(fmt.Sprintf("unable to find %s, does not currently exist", groupName))
continue
}
policies = append(policies, group.Policies...)
}
// Policies from each group may overlap
policies = strutil.RemoveDuplicates(policies, true)
Expand Down

0 comments on commit 1c744d6

Please sign in to comment.