Skip to content

Commit

Permalink
move acs_url input validation to rpc create and update methods (#46847)
Browse files Browse the repository at this point in the history
  • Loading branch information
flyinghermit authored Sep 25, 2024
1 parent 4c41eb3 commit b4b9ad9
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 217 deletions.
18 changes: 18 additions & 0 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -6686,6 +6686,16 @@ func (a *ServerWithRoles) CreateSAMLIdPServiceProvider(ctx context.Context, sp t
log.WithError(trace.NewAggregate(emitErr, err)).Warn("Failed to emit SAML IdP service provider created event.")
}

if err := services.ValidateAssertionConsumerServicesEndpoint(sp.GetACSURL()); err != nil {
return trace.Wrap(err)
}

if sp.GetEntityDescriptor() != "" {
if err := services.ValidateAndFilterEntityDescriptor(sp, services.SAMLACSInputStrictFilter); err != nil {
return trace.Wrap(err)
}
}

return trace.Wrap(err)
}

Expand Down Expand Up @@ -6717,6 +6727,14 @@ func (a *ServerWithRoles) UpdateSAMLIdPServiceProvider(ctx context.Context, sp t
log.WithError(trace.NewAggregate(emitErr, err)).Warn("Failed to emit SAML IdP service provider updated event.")
}

if err := services.ValidateAssertionConsumerServicesEndpoint(sp.GetACSURL()); err != nil {
return trace.Wrap(err)
}

if err := services.ValidateAndFilterEntityDescriptor(sp, services.SAMLACSInputStrictFilter); err != nil {
return trace.Wrap(err)
}

return trace.Wrap(err)
}

Expand Down
162 changes: 162 additions & 0 deletions lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5701,6 +5701,168 @@ func TestUpdateSAMLIdPServiceProvider(t *testing.T) {
}
}

func TestCreateSAMLIdPServiceProviderInvalidInputs(t *testing.T) {
ctx := context.Background()
srv := newTestTLSServer(t)
user, _ := createSAMLIdPTestUsers(t, srv.Auth())
client, err := srv.NewClient(TestUser(user))
require.NoError(t, err)

tests := []struct {
name string
entityDescriptor string
entityID string
acsURL string
errAssertion require.ErrorAssertionFunc
}{
{
name: "missing url scheme in acs input",
entityID: "sp",
acsURL: "sp",
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid scheme")
},
},
{
name: "missing url scheme for acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("sp", "sp"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid url scheme")
},
},
{
name: "http url scheme in acs",
entityID: "sp",
acsURL: "http://sp",
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid scheme")
},
},
{
name: "http url scheme for acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("sp", "http://sp"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unsupported ACS bindings")
},
},
{
name: "unsupported scheme in acs",
entityID: "sp",
acsURL: "gopher://sp",
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid scheme")
},
},
{
name: "unsupported scheme for acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("sp", "gopher://sp"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid url scheme")
},
},
{
name: "invalid character in acs",
entityID: "sp",
acsURL: "https://sp>",
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unsupported character")
},
},
{
name: "invalid character in acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("sp", "https://sp>"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unsupported ACS bindings")
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{
Name: "test",
}, types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: test.entityDescriptor,
EntityID: test.entityID,
ACSURL: test.acsURL,
})
require.NoError(t, err)

err = client.CreateSAMLIdPServiceProvider(ctx, sp)
test.errAssertion(t, err)
})
}
}

func TestUpdateSAMLIdPServiceProviderInvalidInputs(t *testing.T) {
ctx := context.Background()
srv := newTestTLSServer(t)
user, _ := createSAMLIdPTestUsers(t, srv.Auth())
client, err := srv.NewClient(TestUser(user))
require.NoError(t, err)

sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{
Name: "sp",
}, types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: services.NewSAMLTestSPMetadata("https://sp", "https://sp"),
})
require.NoError(t, err)

err = client.CreateSAMLIdPServiceProvider(ctx, sp)
require.NoError(t, err)

tests := []struct {
name string
entityDescriptor string
entityID string
acsURL string
errAssertion require.ErrorAssertionFunc
}{
{
name: "missing url scheme for acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("https://sp", "sp"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid url scheme")
},
},
{
name: "http url scheme for acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("https://sp", "http://sp"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unsupported ACS bindings")
},
},
{
name: "unsupported scheme for acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("https://sp", "gopher://sp"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid url scheme")
},
},
{
name: "invalid character in acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("https://sp", "https://sp>"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unsupported ACS bindings")
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{
Name: "sp",
}, types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: test.entityDescriptor,
})
require.NoError(t, err)

err = client.UpdateSAMLIdPServiceProvider(ctx, sp)
test.errAssertion(t, err)
})
}
}

func TestDeleteSAMLIdPServiceProvider(t *testing.T) {
ctx := context.Background()
srv := newTestTLSServer(t)
Expand Down
36 changes: 12 additions & 24 deletions lib/services/local/saml_idp_service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,10 @@ func (s *SAMLIdPServiceProviderService) GetSAMLIdPServiceProvider(ctx context.Co
// CreateSAMLIdPServiceProvider creates a new SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
if err := services.ValidateAssertionConsumerServicesEndpoint(sp.GetACSURL()); err != nil {
return trace.Wrap(err)
// logging instead of returning an error cause we do not want to break cache writes on a cluster
// that already has a service provider with unsupported characters/scheme in the acs_url.
s.log.Warn(err)
}

if sp.GetEntityDescriptor() == "" {
// fetchAndSetEntityDescriptor is expected to return error if it fails
// to fetch a valid entity descriptor.
Expand All @@ -126,7 +127,9 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context
}
}

if err := s.validateEntityDescriptor(sp); err != nil {
// we only verify if the entity ID field in the spec matches with the entity descriptor.
// filtering is done only for logging purpose.
if err := services.ValidateAndFilterEntityDescriptor(sp, services.SAMLACSInputPermissiveFilter); err != nil {
return trace.Wrap(err)
}

Expand Down Expand Up @@ -157,10 +160,14 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context
// UpdateSAMLIdPServiceProvider updates an existing SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
if err := services.ValidateAssertionConsumerServicesEndpoint(sp.GetACSURL()); err != nil {
return trace.Wrap(err)
// logging instead of returning an error cause we do not want to break cache writes on a cluster
// that already has a service provider with unsupported characters/scheme in the acs_url.
s.log.Warn(err)
}

if err := s.validateEntityDescriptor(sp); err != nil {
// we only verify if the entity ID field in the spec matches with the entity descriptor.
// filtering is done only for logging purpose.
if err := services.ValidateAndFilterEntityDescriptor(sp, services.SAMLACSInputPermissiveFilter); err != nil {
return trace.Wrap(err)
}

Expand Down Expand Up @@ -327,25 +334,6 @@ func (s *SAMLIdPServiceProviderService) embedAttributeMapping(sp types.SAMLIdPSe
return nil
}

// validateEntityDescriptor validates entity descriptor XML, entity ID and logs unsupported ACS bindings.
func (s *SAMLIdPServiceProviderService) validateEntityDescriptor(sp types.SAMLIdPServiceProvider) error {
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err)
}

if ed.EntityID != sp.GetEntityID() {
return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
}

// ensure any filtering related issues get logged
if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil {
return trace.BadParameter("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err)
}

return nil
}

// GetTeleportSPSSODescriptor returns Teleport embedded SPSSODescriptor and its index from a
// list of SPSSODescriptors. The correct SPSSODescriptor is determined by searching for
// AttributeConsumingService element with ServiceNames named teleport_saml_idp_service.
Expand Down
Loading

0 comments on commit b4b9ad9

Please sign in to comment.