Skip to content

Commit

Permalink
chore: fix saml_internal_test
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 committed Oct 25, 2024
1 parent e190d7d commit 5171524
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 92 deletions.
2 changes: 1 addition & 1 deletion cmd/api/src/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ func (s authenticator) CreateSession(ctx context.Context, user model.User, authP
userSession.AuthProviderID = typedAuthProvider.ID
case model.OIDCProvider:
userSession.AuthProviderType = model.SessionAuthProviderOIDC
userSession.AuthProviderID = int32(typedAuthProvider.ID)
userSession.AuthProviderID = typedAuthProvider.ID
default:
return "", errors.New("invalid auth provider")
}
Expand Down
172 changes: 81 additions & 91 deletions cmd/api/src/api/saml/saml_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ import (
"testing"
"time"

"github.com/crewjam/saml"
"github.com/specterops/bloodhound/headers"

"github.com/specterops/bloodhound/src/api"
apimocks "github.com/specterops/bloodhound/src/api/mocks"
"github.com/specterops/bloodhound/src/auth/bhsaml"
"github.com/specterops/bloodhound/src/config"
"github.com/specterops/bloodhound/src/ctx"
Expand All @@ -36,131 +35,122 @@ import (
"github.com/specterops/bloodhound/src/database/types/null"
"github.com/specterops/bloodhound/src/model"
"github.com/specterops/bloodhound/src/serde"

"github.com/crewjam/saml"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

func TestProviderResource_createSessionFromAssertion(t *testing.T) {
const (
badUsername = "bad"
goodUsername = "good"
goodJWT = "fake"
)

func TestAuth_CreateSSOSession(t *testing.T) {
var (
goodUser = model.User{
PrincipalName: goodUsername,
username = "harls"
user = model.User{
PrincipalName: username,
SAMLProvider: &model.SAMLProvider{
Serial: model.Serial{
ID: 1,
},
Serial: model.Serial{ID: 1},
},

SAMLProviderID: null.Int32From(1),
}

mockCtrl = gomock.NewController(t)
mockDB = dbmocks.NewMockDatabase(mockCtrl)
mockAuthenticator = apimocks.NewMockAuthenticator(mockCtrl)
resource = ProviderResource{
db: mockDB,
authenticator: mockAuthenticator,
serviceProvider: bhsaml.ServiceProvider{
testAuthenticator = api.NewAuthenticator(config.Configuration{}, mockDB, dbmocks.NewMockAuthContextInitializer(mockCtrl))

resource = NewProviderResource(
mockDB,
config.Configuration{RootURL: serde.MustParseURL("https://example.com")},
bhsaml.ServiceProvider{
Config: model.SAMLProvider{
Serial: model.Serial{
ID: 1,
},
Serial: model.Serial{ID: 1},
},
},
cfg: config.Configuration{
RootURL: serde.MustParseURL("https://example.com"),
},
writeAPIErrorResponse: func(request *http.Request, response http.ResponseWriter, statusCode int, message string) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(statusCode, message, request), response)
},
}
)

defer mockCtrl.Finish()

var (
expires = time.Now().UTC().Add(time.Hour)
response = httptest.NewRecorder()
expectedCookieContent = fmt.Sprintf("token=fake; Path=/; Expires=%s; Secure; SameSite=Strict", expires.Format(http.TimeFormat))
func(request *http.Request, response http.ResponseWriter, statusCode int, message string) {},
)

testAssertion = &saml.Assertion{
AttributeStatements: []saml.AttributeStatement{
{
Attributes: []saml.Attribute{
{
FriendlyName: "uid",
Name: bhsaml.XMLSOAPClaimsEmailAddress,
NameFormat: bhsaml.ObjectIDAttributeNameFormat,
Values: []saml.AttributeValue{
{
Type: bhsaml.XMLTypeString,
Value: goodUsername,
},
},
},
},
},
},
AttributeStatements: []saml.AttributeStatement{{
Attributes: []saml.Attribute{{
FriendlyName: "uid",
Name: bhsaml.XMLSOAPClaimsEmailAddress,
NameFormat: bhsaml.ObjectIDAttributeNameFormat,
Values: []saml.AttributeValue{{
Type: bhsaml.XMLTypeString,
Value: username,
}},
}},
}},
}
)
defer mockCtrl.Finish()

httpRequest, _ := http.NewRequestWithContext(context.WithValue(context.TODO(), ctx.ValueKey, &ctx.Context{Host: &resource.cfg.RootURL.URL}), http.MethodPost, "http://localhost", nil)

// Test happy path
mockDB.EXPECT().LookupUser(gomock.Any(), goodUsername).Return(goodUser, nil)
mockAuthenticator.EXPECT().CreateSession(gomock.Any(), goodUser, gomock.Any()).Return(goodJWT, nil)
t.Run("successfully create sso session", func(t *testing.T) {
var (
response = httptest.NewRecorder()
expires = time.Now().UTC()
expectedCookieContent = fmt.Sprintf("token=.*; Path=/; Expires=%s; Secure; SameSite=Strict", expires.Format(http.TimeFormat))
)

resource.createSessionFromAssertion(httpRequest, response, expires, testAssertion)
mockDB.EXPECT().LookupUser(gomock.Any(), username).Return(user, nil)
mockDB.EXPECT().CreateUserSession(gomock.Any(), gomock.Any()).Return(model.UserSession{}, nil)

require.Equal(t, expectedCookieContent, response.Header().Get(headers.SetCookie.String()))
require.Equal(t, "https://example.com/ui", response.Header().Get(headers.Location.String()))
require.Equal(t, http.StatusFound, response.Code)
principalName, err := resource.getSAMLUserPrincipalNameFromAssertion(testAssertion)
require.Nil(t, err)

// Change the assertion statement attribute to the bad username to assert we get a 403
testAssertion.AttributeStatements[0].Attributes[0].Values[0].Value = badUsername
testAuthenticator.CreateSSOSession(httpRequest, response, principalName, resource.serviceProvider.Config)

mockDB.EXPECT().LookupUser(gomock.Any(), badUsername).Return(model.User{}, database.ErrNotFound)
require.Regexp(t, expectedCookieContent, response.Header().Get(headers.SetCookie.String()))
require.Equal(t, "https://example.com/ui", response.Header().Get(headers.Location.String()))
require.Equal(t, http.StatusFound, response.Code)
})

response = httptest.NewRecorder()
t.Run("Forbidden 403 if user isn't in db", func(t *testing.T) {
response := httptest.NewRecorder()
mockDB.EXPECT().LookupUser(gomock.Any(), username).Return(model.User{}, database.ErrNotFound)

resource.createSessionFromAssertion(httpRequest, response, expires, testAssertion)
require.Equal(t, http.StatusForbidden, response.Code)
principalName, err := resource.getSAMLUserPrincipalNameFromAssertion(testAssertion)
require.Nil(t, err)

// Change the db return to a user that isn't associated with a SAML Provider
mockDB.EXPECT().LookupUser(gomock.Any(), badUsername).Return(model.User{}, nil)
testAuthenticator.CreateSSOSession(httpRequest, response, principalName, resource.serviceProvider.Config)

response = httptest.NewRecorder()
require.Equal(t, http.StatusForbidden, response.Code)
})

resource.createSessionFromAssertion(httpRequest, response, expires, testAssertion)
require.Equal(t, http.StatusForbidden, response.Code)
t.Run("Forbidden 403 if user isn't associated with a SAML Provider", func(t *testing.T) {
response := httptest.NewRecorder()
mockDB.EXPECT().LookupUser(gomock.Any(), username).Return(model.User{}, nil)

// Change the db return to a user that isn't associated with this SAML Provider
mockDB.EXPECT().LookupUser(gomock.Any(), badUsername).Return(model.User{
SAMLProviderID: null.Int32From(2),
SAMLProvider: &model.SAMLProvider{
Serial: model.Serial{
ID: 2,
principalName, err := resource.getSAMLUserPrincipalNameFromAssertion(testAssertion)
require.Nil(t, err)

testAuthenticator.CreateSSOSession(httpRequest, response, principalName, resource.serviceProvider.Config)

require.Equal(t, http.StatusForbidden, response.Code)
})

t.Run("Forbidden 403 if user isn't associated with specified SAML Provider", func(t *testing.T) {
response := httptest.NewRecorder()
mockDB.EXPECT().LookupUser(gomock.Any(), username).Return(model.User{
SAMLProviderID: null.Int32From(2),
SAMLProvider: &model.SAMLProvider{
Serial: model.Serial{
ID: 2,
},
},
},
}, nil)
}, nil)

principalName, err := resource.getSAMLUserPrincipalNameFromAssertion(testAssertion)
require.Nil(t, err)

response = httptest.NewRecorder()
testAuthenticator.CreateSSOSession(httpRequest, response, principalName, resource.serviceProvider.Config)

resource.createSessionFromAssertion(httpRequest, response, expires, testAssertion)
require.Equal(t, http.StatusForbidden, response.Code)
require.Equal(t, http.StatusForbidden, response.Code)

// Remove the assertion statement attribute for the username
testAssertion.AttributeStatements[0].Attributes[0].Values = nil
})

response = httptest.NewRecorder()
t.Run("Correctly fails with SAML assertion error if assertion is invalid", func(t *testing.T) {
testAssertion.AttributeStatements[0].Attributes[0].Values = nil

resource.createSessionFromAssertion(httpRequest, response, expires, testAssertion)
require.Equal(t, http.StatusBadRequest, response.Code)
_, err := resource.getSAMLUserPrincipalNameFromAssertion(testAssertion)
require.ErrorIs(t, err, ErrorSAMLAssertion)
})
}

0 comments on commit 5171524

Please sign in to comment.