Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add parameter guards #115

Merged
merged 3 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions app_user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,7 @@ func TestCreateUser(t *testing.T) {
_, err = psg.CreateUser(createUserBody)

require.NotNil(t, err)
expectedErrorText := "email: cannot be blank; phone: cannot be blank."
badRequestAsserts(t, err, expectedErrorText)
assert.Equal(t, "At least one of args.Email or args.Phone is required.", err.Error())
})

t.Run("Error: unauthorized", func(t *testing.T) {
Expand Down
15 changes: 10 additions & 5 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"

"github.com/golang-jwt/jwt"
gojwt "github.com/golang-jwt/jwt"
"github.com/lestrrat-go/jwx/v2/jwk"
)

Expand Down Expand Up @@ -38,8 +39,8 @@ func newAuth(appID string, client *ClientWithResponses) (*auth, error) {
}

// CreateMagicLink creates a Magic Link for your app.
func (a *auth) CreateMagicLink(createMagicLinkBody CreateMagicLinkBody) (*MagicLink, error) {
res, err := a.client.CreateMagicLinkWithResponse(context.Background(), a.appID, createMagicLinkBody)
func (a *auth) CreateMagicLink(args CreateMagicLinkBody) (*MagicLink, error) {
res, err := a.client.CreateMagicLinkWithResponse(context.Background(), a.appID, args)
if err != nil {
return nil, err
}
Expand All @@ -52,13 +53,17 @@ func (a *auth) CreateMagicLink(createMagicLinkBody CreateMagicLinkBody) (*MagicL
}

// ValidateJWT validates the JWT and returns the user ID.
func (a *auth) ValidateJWT(authToken string) (string, error) {
parsedToken, err := jwt.Parse(authToken, a.getPublicKey)
func (a *auth) ValidateJWT(jwt string) (string, error) {
if jwt == "" {
return "", errors.New("jwt is required.")
}

parsedToken, err := gojwt.Parse(jwt, a.getPublicKey)
if err != nil {
return "", err
}

claims, ok := parsedToken.Claims.(jwt.MapClaims)
claims, ok := parsedToken.Claims.(gojwt.MapClaims)
if !ok {
return "", errors.New("failed to extract claims from JWT")
}
Expand Down
4 changes: 4 additions & 0 deletions authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ func (a *App) AuthenticateRequestWithCookie(r *http.Request) (string, error) {
//
// Deprecated: use `Passage.Auth.ValidateJWT` instead.
func (a *App) ValidateAuthToken(authToken string) (string, bool) {
if authToken == "" {
return "", false
}

parsedToken, err := jwt.Parse(authToken, a.Auth.getPublicKey)
if err != nil {
return "", false
Expand Down
45 changes: 45 additions & 0 deletions user.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package passage

import (
"context"
"errors"
"net/http"
"strings"
)
Expand All @@ -24,6 +25,10 @@ func newUser(appID string, client *ClientWithResponses) *user {

// Get retrieves a user's object using their user ID.
func (u *user) Get(userID string) (*PassageUser, error) {
if userID == "" {
return nil, errors.New("userID is required.")
}

res, err := u.client.GetUserWithResponse(context.Background(), u.appID, userID)
if err != nil {
return nil, err
Expand All @@ -38,6 +43,10 @@ func (u *user) Get(userID string) (*PassageUser, error) {

// GetByIdentifier retrieves a user's object using their user identifier.
func (u *user) GetByIdentifier(identifier string) (*PassageUser, error) {
if identifier == "" {
return nil, errors.New("identifier is required.")
}

limit := 1
lowerIdentifier := strings.ToLower(identifier)
res, err := u.client.ListPaginatedUsersWithResponse(
Expand Down Expand Up @@ -72,6 +81,10 @@ func (u *user) GetByIdentifier(identifier string) (*PassageUser, error) {

// Activate activates a user using their user ID.
func (u *user) Activate(userID string) (*PassageUser, error) {
if userID == "" {
return nil, errors.New("userID is required.")
}

res, err := u.client.ActivateUserWithResponse(context.Background(), u.appID, userID)
if err != nil {
return nil, err
Expand All @@ -86,6 +99,10 @@ func (u *user) Activate(userID string) (*PassageUser, error) {

// Deactivate deactivates a user using their user ID.
func (u *user) Deactivate(userID string) (*PassageUser, error) {
if userID == "" {
return nil, errors.New("userID is required.")
}

res, err := u.client.DeactivateUserWithResponse(context.Background(), u.appID, userID)
if err != nil {
return nil, err
Expand All @@ -100,6 +117,10 @@ func (u *user) Deactivate(userID string) (*PassageUser, error) {

// Update updates a user.
func (u *user) Update(userID string, options UpdateUserOptions) (*PassageUser, error) {
if userID == "" {
return nil, errors.New("userID is required.")
}

res, err := u.client.UpdateUserWithResponse(context.Background(), u.appID, userID, options)
if err != nil {
return nil, err
Expand All @@ -114,6 +135,10 @@ func (u *user) Update(userID string, options UpdateUserOptions) (*PassageUser, e

// Create creates a user.
func (u *user) Create(args CreateUserArgs) (*PassageUser, error) {
if args.Email == "" && args.Phone == "" {
return nil, errors.New("At least one of args.Email or args.Phone is required.")
}

res, err := u.client.CreateUserWithResponse(context.Background(), u.appID, args)
if err != nil {
return nil, err
Expand All @@ -128,6 +153,10 @@ func (u *user) Create(args CreateUserArgs) (*PassageUser, error) {

// Delete deletes a user using their user ID.
func (u *user) Delete(userID string) error {
if userID == "" {
return errors.New("userID is required.")
}

res, err := u.client.DeleteUserWithResponse(context.Background(), u.appID, userID)
if err != nil {
return err
Expand All @@ -142,6 +171,10 @@ func (u *user) Delete(userID string) error {

// ListDevices retrieves a user's webauthn devices using their user ID.
func (u *user) ListDevices(userID string) ([]WebAuthnDevices, error) {
if userID == "" {
return nil, errors.New("userID is required.")
}

res, err := u.client.ListUserDevicesWithResponse(context.Background(), u.appID, userID)
if err != nil {
return nil, err
Expand All @@ -156,6 +189,14 @@ func (u *user) ListDevices(userID string) ([]WebAuthnDevices, error) {

// RevokeDevice revokes user's webauthn device using their user ID and the device ID.
func (u *user) RevokeDevice(userID string, deviceID string) error {
if userID == "" {
return errors.New("userID is required.")
}

if deviceID == "" {
return errors.New("deviceID is required.")
}

res, err := u.client.DeleteUserDevicesWithResponse(context.Background(), u.appID, userID, deviceID)
if err != nil {
return err
Expand All @@ -170,6 +211,10 @@ func (u *user) RevokeDevice(userID string, deviceID string) error {

// RevokeRefreshTokens revokes all of a user's Refresh Tokens using their User ID.
func (u *user) RevokeRefreshTokens(userID string) error {
if userID == "" {
return errors.New("userID is required.")
}

res, err := u.client.RevokeUserRefreshTokensWithResponse(context.Background(), u.appID, userID)
if err != nil {
return err
Expand Down
3 changes: 1 addition & 2 deletions user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,7 @@ func TestCreate(t *testing.T) {
_, err = psg.User.Create(createUserBody)

require.NotNil(t, err)
expectedMessage := "email: cannot be blank; phone: cannot be blank."
passageBadRequestAsserts(t, err, expectedMessage)
assert.Equal(t, "At least one of args.Email or args.Phone is required.", err.Error())
})

t.Run("Error: unauthorized", func(t *testing.T) {
Expand Down
Loading