Skip to content

Commit

Permalink
Merge pull request #34 from deweppro/develop
Browse files Browse the repository at this point in the history
fix jwt - use keys pool
  • Loading branch information
markus621 authored Apr 20, 2023
2 parents a6c4e3f + 7e51885 commit f7aefc7
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 102 deletions.
162 changes: 82 additions & 80 deletions plugins/auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,40 +26,46 @@ const (
type (
// ConfigJWT jwt config model
ConfigJWT struct {
JWT ConfigJWTItem `yaml:"jwt"`
JWT []ConfigJWTItem `yaml:"jwt"`
}

ConfigJWTItem struct {
Key string `yaml:"key"`
TTL time.Duration `yaml:"ttl"`
Algorithm string `yaml:"alg"`
Cookie string `yaml:"cookie"`
ID string `yaml:"id"`
Key string `yaml:"key"`
Algorithm string `yaml:"alg"`
}
)

func (v *ConfigJWT) Default() {
if len(v.JWT.Key) == 0 {
v.JWT = ConfigJWTItem{
Key: random.String(32),
TTL: time.Hour * 24,
Algorithm: JWTAlgHS256,
Cookie: "jwt",
if len(v.JWT) == 0 {
for i := 0; i < 5; i++ {
v.JWT = append(v.JWT, ConfigJWTItem{
ID: random.String(6),
Key: random.String(32),
Algorithm: JWTAlgHS256,
})
}
}
}

func (v *ConfigJWT) Validate() error {
if len(v.JWT.Key) < 32 {
return fmt.Errorf("jwt key less than 32 characters")
if len(v.JWT) == 0 {
return fmt.Errorf("jwt config is empty")
}
if len(v.JWT.Cookie) == 0 {
return fmt.Errorf("jwt cookie name is empty")
}
switch v.JWT.Algorithm {
case JWTAlgHS256, JWTAlgHS384, JWTAlgHS512:
default:
return fmt.Errorf("jwt algorithm not supported")
for _, vv := range v.JWT {
if len(vv.ID) == 0 {
return fmt.Errorf("jwt key id is empty")
}
if len(vv.Key) < 32 {
return fmt.Errorf("jwt key less than 32 characters")
}
switch vv.Algorithm {
case JWTAlgHS256, JWTAlgHS384, JWTAlgHS512:
default:
return fmt.Errorf("jwt algorithm not supported")
}
}

return nil
}

Expand All @@ -75,57 +81,85 @@ func WithJWT() plugins.Plugin {

//easyjson:json
type JWTHeader struct {
Kid string `json:"kid"`
Alg string `json:"alg"`
IssuedAt int64 `json:"iat"`
ExpiresAt int64 `json:"eat"`
}

type (
JWT interface {
Sign(payload interface{}) (string, error)
Extend(token string) (string, error)
Sign(payload interface{}, ttl time.Duration) (string, error)
Verify(token string, payload interface{}) (*JWTHeader, error)
CookieName() string
}

_jwt struct {
pool map[string]*_jwtPoolItem
}

_jwtPoolItem struct {
conf ConfigJWTItem
hash func() hash.Hash
key []byte
}
)

func newJWT(c ConfigJWTItem) (JWT, error) {
var h func() hash.Hash
func newJWT(conf []ConfigJWTItem) (JWT, error) {
obj := &_jwt{pool: make(map[string]*_jwtPoolItem)}

for _, c := range conf {
var h func() hash.Hash
switch c.Algorithm {
case JWTAlgHS256:
h = sha256.New
case JWTAlgHS384:
h = sha512.New384
case JWTAlgHS512:
h = sha512.New
default:
return nil, fmt.Errorf("jwt algorithm not supported")
}
obj.pool[c.ID] = &_jwtPoolItem{conf: c, hash: h, key: []byte(c.Key)}
}

return obj, nil
}

switch c.Algorithm {
case JWTAlgHS256:
h = sha256.New
case JWTAlgHS384:
h = sha512.New384
case JWTAlgHS512:
h = sha512.New
default:
return nil, fmt.Errorf("jwt algorithm not supported")
func (v *_jwt) randPool() (*_jwtPoolItem, error) {
for _, p := range v.pool {
return p, nil
}
return nil, fmt.Errorf("jwt pool is empty")
}

return &_jwt{conf: c, hash: h, key: []byte(c.Key)}, nil
func (v *_jwt) getPool(id string) (*_jwtPoolItem, error) {
p, ok := v.pool[id]
if ok {
return p, nil
}
return nil, fmt.Errorf("jwt pool not found")
}

func (v *_jwt) calcHash(data []byte) ([]byte, error) {
mac := hmac.New(v.hash, v.key)
func (v *_jwt) calcHash(hash func() hash.Hash, key []byte, data []byte) ([]byte, error) {
mac := hmac.New(hash, key)
if _, err := mac.Write(data); err != nil {
return nil, err
}
result := mac.Sum(nil)
return result, nil
}

func (v *_jwt) Sign(payload interface{}) (string, error) {
func (v *_jwt) Sign(payload interface{}, ttl time.Duration) (string, error) {
pool, err := v.randPool()
if err != nil {
return "", err
}

h, err := (&JWTHeader{
Alg: v.conf.Algorithm,
Kid: pool.conf.ID,
Alg: pool.conf.Algorithm,
IssuedAt: time.Now().Unix(),
ExpiresAt: time.Now().Add(v.conf.TTL).Unix(),
ExpiresAt: time.Now().Add(ttl).Unix(),
}).MarshalJSON()
if err != nil {
return "", err
Expand All @@ -138,7 +172,7 @@ func (v *_jwt) Sign(payload interface{}) (string, error) {
}
result += "." + base64.StdEncoding.EncodeToString(p)

s, err := v.calcHash([]byte(result))
s, err := v.calcHash(pool.hash, pool.key, []byte(result))
if err != nil {
return "", err
}
Expand All @@ -147,39 +181,6 @@ func (v *_jwt) Sign(payload interface{}) (string, error) {
return result, nil
}

func (v *_jwt) Extend(token string) (string, error) {
data := strings.Split(token, ".")
if len(data) != 3 {
return "", fmt.Errorf("invalid jwt format")
}

h, err := base64.StdEncoding.DecodeString(data[0])
if err != nil {
return "", err
}
header := &JWTHeader{}
if err = header.UnmarshalJSON(h); err != nil {
return "", err
}

header.ExpiresAt = time.Now().Add(v.conf.TTL).Unix()

h, err = header.MarshalJSON()
if err != nil {
return "", err
}
data[0] = base64.StdEncoding.EncodeToString(h)

sig, err := v.calcHash([]byte(data[0] + "." + data[1]))
if err != nil {
return "", err
}

data[2] = base64.StdEncoding.EncodeToString(sig)

return strings.Join(data, "."), nil
}

func (v *_jwt) Verify(token string, payload interface{}) (*JWTHeader, error) {
data := strings.Split(token, ".")
if len(data) != 3 {
Expand All @@ -195,7 +196,12 @@ func (v *_jwt) Verify(token string, payload interface{}) (*JWTHeader, error) {
return nil, err
}

if header.Alg != v.conf.Algorithm {
pool, err := v.getPool(header.Kid)
if err != nil {
return nil, err
}

if header.Alg != pool.conf.Algorithm {
return nil, fmt.Errorf("invalid jwt algorithm")
}
if header.ExpiresAt < time.Now().Unix() {
Expand All @@ -206,7 +212,7 @@ func (v *_jwt) Verify(token string, payload interface{}) (*JWTHeader, error) {
if err != nil {
return nil, err
}
actual, err := v.calcHash([]byte(data[0] + "." + data[1]))
actual, err := v.calcHash(pool.hash, pool.key, []byte(data[0]+"."+data[1]))
if err != nil {
return nil, err
}
Expand All @@ -225,7 +231,3 @@ func (v *_jwt) Verify(token string, payload interface{}) (*JWTHeader, error) {

return header, nil
}

func (v *_jwt) CookieName() string {
return v.conf.Cookie
}
9 changes: 8 additions & 1 deletion plugins/auth/jwt_easyjson.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 2 additions & 7 deletions plugins/auth/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,16 @@ func TestUnit_newJWT(t *testing.T) {
require.NoError(t, err)

payload1 := demoJwtPayload{ID: 159}
token, err := j.Sign(&payload1)
token, err := j.Sign(&payload1, time.Hour)
require.NoError(t, err)

payload2 := demoJwtPayload{}
head1, err := j.Verify(token, &payload2)
require.NoError(t, err)

require.Equal(t, payload1, payload2)
<-time.After(time.Second)

token, err = j.Extend(token)
require.NoError(t, err)

head2, err := j.Verify(token, &payload2)
require.NoError(t, err)

require.NotEqual(t, head1, head2)
require.Equal(t, head1, head2)
}
Loading

0 comments on commit f7aefc7

Please sign in to comment.