diff --git a/app_user_test.go b/app_user_test.go index 3623d08..d060727 100644 --- a/app_user_test.go +++ b/app_user_test.go @@ -366,8 +366,7 @@ func TestCreateUser(t *testing.T) { _, err = psg.CreateUser(createUserBody) require.NotNil(t, err) - expectedErrorText := "email: cannot be blank; phone: cannot be blank." - badRequestAsserts(t, err, expectedErrorText) + assert.Equal(t, "At least one of args.Email or args.Phone is required.", err.Error()) }) t.Run("Error: unauthorized", func(t *testing.T) { diff --git a/auth.go b/auth.go index 4fa1def..69674cd 100644 --- a/auth.go +++ b/auth.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/golang-jwt/jwt" + gojwt "github.com/golang-jwt/jwt" "github.com/lestrrat-go/jwx/v2/jwk" ) @@ -38,8 +39,8 @@ func newAuth(appID string, client *ClientWithResponses) (*auth, error) { } // CreateMagicLink creates a Magic Link for your app. -func (a *auth) CreateMagicLink(createMagicLinkBody CreateMagicLinkBody) (*MagicLink, error) { - res, err := a.client.CreateMagicLinkWithResponse(context.Background(), a.appID, createMagicLinkBody) +func (a *auth) CreateMagicLink(args CreateMagicLinkBody) (*MagicLink, error) { + res, err := a.client.CreateMagicLinkWithResponse(context.Background(), a.appID, args) if err != nil { return nil, err } @@ -52,13 +53,17 @@ func (a *auth) CreateMagicLink(createMagicLinkBody CreateMagicLinkBody) (*MagicL } // ValidateJWT validates the JWT and returns the user ID. -func (a *auth) ValidateJWT(authToken string) (string, error) { - parsedToken, err := jwt.Parse(authToken, a.getPublicKey) +func (a *auth) ValidateJWT(jwt string) (string, error) { + if jwt == "" { + return "", errors.New("jwt is required.") + } + + parsedToken, err := gojwt.Parse(jwt, a.getPublicKey) if err != nil { return "", err } - claims, ok := parsedToken.Claims.(jwt.MapClaims) + claims, ok := parsedToken.Claims.(gojwt.MapClaims) if !ok { return "", errors.New("failed to extract claims from JWT") } diff --git a/authentication.go b/authentication.go index d1638cd..b32d03a 100644 --- a/authentication.go +++ b/authentication.go @@ -55,6 +55,10 @@ func (a *App) AuthenticateRequestWithCookie(r *http.Request) (string, error) { // // Deprecated: use `Passage.Auth.ValidateJWT` instead. func (a *App) ValidateAuthToken(authToken string) (string, bool) { + if authToken == "" { + return "", false + } + parsedToken, err := jwt.Parse(authToken, a.Auth.getPublicKey) if err != nil { return "", false diff --git a/user.go b/user.go index 05f1afa..8d79cd2 100644 --- a/user.go +++ b/user.go @@ -2,6 +2,7 @@ package passage import ( "context" + "errors" "net/http" "strings" ) @@ -24,6 +25,10 @@ func newUser(appID string, client *ClientWithResponses) *user { // Get retrieves a user's object using their user ID. func (u *user) Get(userID string) (*PassageUser, error) { + if userID == "" { + return nil, errors.New("userID is required.") + } + res, err := u.client.GetUserWithResponse(context.Background(), u.appID, userID) if err != nil { return nil, err @@ -38,6 +43,10 @@ func (u *user) Get(userID string) (*PassageUser, error) { // GetByIdentifier retrieves a user's object using their user identifier. func (u *user) GetByIdentifier(identifier string) (*PassageUser, error) { + if identifier == "" { + return nil, errors.New("identifier is required.") + } + limit := 1 lowerIdentifier := strings.ToLower(identifier) res, err := u.client.ListPaginatedUsersWithResponse( @@ -72,6 +81,10 @@ func (u *user) GetByIdentifier(identifier string) (*PassageUser, error) { // Activate activates a user using their user ID. func (u *user) Activate(userID string) (*PassageUser, error) { + if userID == "" { + return nil, errors.New("userID is required.") + } + res, err := u.client.ActivateUserWithResponse(context.Background(), u.appID, userID) if err != nil { return nil, err @@ -86,6 +99,10 @@ func (u *user) Activate(userID string) (*PassageUser, error) { // Deactivate deactivates a user using their user ID. func (u *user) Deactivate(userID string) (*PassageUser, error) { + if userID == "" { + return nil, errors.New("userID is required.") + } + res, err := u.client.DeactivateUserWithResponse(context.Background(), u.appID, userID) if err != nil { return nil, err @@ -100,6 +117,10 @@ func (u *user) Deactivate(userID string) (*PassageUser, error) { // Update updates a user. func (u *user) Update(userID string, options UpdateUserOptions) (*PassageUser, error) { + if userID == "" { + return nil, errors.New("userID is required.") + } + res, err := u.client.UpdateUserWithResponse(context.Background(), u.appID, userID, options) if err != nil { return nil, err @@ -114,6 +135,10 @@ func (u *user) Update(userID string, options UpdateUserOptions) (*PassageUser, e // Create creates a user. func (u *user) Create(args CreateUserArgs) (*PassageUser, error) { + if args.Email == "" && args.Phone == "" { + return nil, errors.New("At least one of args.Email or args.Phone is required.") + } + res, err := u.client.CreateUserWithResponse(context.Background(), u.appID, args) if err != nil { return nil, err @@ -128,6 +153,10 @@ func (u *user) Create(args CreateUserArgs) (*PassageUser, error) { // Delete deletes a user using their user ID. func (u *user) Delete(userID string) error { + if userID == "" { + return errors.New("userID is required.") + } + res, err := u.client.DeleteUserWithResponse(context.Background(), u.appID, userID) if err != nil { return err @@ -142,6 +171,10 @@ func (u *user) Delete(userID string) error { // ListDevices retrieves a user's webauthn devices using their user ID. func (u *user) ListDevices(userID string) ([]WebAuthnDevices, error) { + if userID == "" { + return nil, errors.New("userID is required.") + } + res, err := u.client.ListUserDevicesWithResponse(context.Background(), u.appID, userID) if err != nil { return nil, err @@ -156,6 +189,14 @@ func (u *user) ListDevices(userID string) ([]WebAuthnDevices, error) { // RevokeDevice revokes user's webauthn device using their user ID and the device ID. func (u *user) RevokeDevice(userID string, deviceID string) error { + if userID == "" { + return errors.New("userID is required.") + } + + if deviceID == "" { + return errors.New("deviceID is required.") + } + res, err := u.client.DeleteUserDevicesWithResponse(context.Background(), u.appID, userID, deviceID) if err != nil { return err @@ -170,6 +211,10 @@ func (u *user) RevokeDevice(userID string, deviceID string) error { // RevokeRefreshTokens revokes all of a user's Refresh Tokens using their User ID. func (u *user) RevokeRefreshTokens(userID string) error { + if userID == "" { + return errors.New("userID is required.") + } + res, err := u.client.RevokeUserRefreshTokensWithResponse(context.Background(), u.appID, userID) if err != nil { return err diff --git a/user_test.go b/user_test.go index 33a59da..6636266 100644 --- a/user_test.go +++ b/user_test.go @@ -367,8 +367,7 @@ func TestCreate(t *testing.T) { _, err = psg.User.Create(createUserBody) require.NotNil(t, err) - expectedMessage := "email: cannot be blank; phone: cannot be blank." - passageBadRequestAsserts(t, err, expectedMessage) + assert.Equal(t, "At least one of args.Email or args.Phone is required.", err.Error()) }) t.Run("Error: unauthorized", func(t *testing.T) {