Skip to content

Commit

Permalink
chore: synchronize workspaces
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Jul 11, 2024
1 parent 8b7a470 commit d6c955d
Show file tree
Hide file tree
Showing 11 changed files with 124 additions and 24 deletions.
40 changes: 16 additions & 24 deletions handler/oauth2/strategy_hmacsha.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ package oauth2

import (
"context"
"fmt"
"strings"
"time"

"github.com/ory/x/errorsx"

"github.com/ory/fosite"
enigma "github.com/ory/fosite/token/hmac"
_ "unsafe"
)

type HMACSHAStrategy struct {
Expand All @@ -22,36 +22,28 @@ type HMACSHAStrategy struct {
fosite.RefreshTokenLifespanProvider
fosite.AuthorizeCodeLifespanProvider
}
prefix *string
}

type HMACPrefixFunc func(ctx context.Context, h *HMACSHAStrategy, part string) string

func (h *HMACSHAStrategy) AccessTokenSignature(ctx context.Context, token string) string {
return h.Enigma.Signature(token)
}

func (h *HMACSHAStrategy) RefreshTokenSignature(ctx context.Context, token string) string {
return h.Enigma.Signature(token)
}

func (h *HMACSHAStrategy) AuthorizeCodeSignature(ctx context.Context, token string) string {
return h.Enigma.Signature(token)
}

func (h *HMACSHAStrategy) getPrefix(part string) string {
if h.prefix == nil {
prefix := "ory_%s_"
h.prefix = &prefix
} else if len(*h.prefix) == 0 {
return ""
}

return fmt.Sprintf(*h.prefix, part)
}

func (h *HMACSHAStrategy) trimPrefix(token, part string) string {
return strings.TrimPrefix(token, h.getPrefix(part))
func (h *HMACSHAStrategy) trimPrefix(ctx context.Context, token, part string) string {
return strings.TrimPrefix(token, getPrefix(ctx, h, part))
}

func (h *HMACSHAStrategy) setPrefix(token, part string) string {
return h.getPrefix(part) + token
func (h *HMACSHAStrategy) setPrefix(ctx context.Context, token, part string) string {
return getPrefix(ctx, h, part) + token
}

func (h *HMACSHAStrategy) GenerateAccessToken(ctx context.Context, _ fosite.Requester) (token string, signature string, err error) {
Expand All @@ -60,7 +52,7 @@ func (h *HMACSHAStrategy) GenerateAccessToken(ctx context.Context, _ fosite.Requ
return "", "", err
}

return h.setPrefix(token, "at"), sig, nil
return h.setPrefix(ctx, token, "at"), sig, nil
}

func (h *HMACSHAStrategy) ValidateAccessToken(ctx context.Context, r fosite.Requester, token string) (err error) {
Expand All @@ -73,7 +65,7 @@ func (h *HMACSHAStrategy) ValidateAccessToken(ctx context.Context, r fosite.Requ
return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Access token expired at '%s'.", exp))
}

return h.Enigma.Validate(ctx, h.trimPrefix(token, "at"))
return h.Enigma.Validate(ctx, h.trimPrefix(ctx, token, "at"))
}

func (h *HMACSHAStrategy) GenerateRefreshToken(ctx context.Context, _ fosite.Requester) (token string, signature string, err error) {
Expand All @@ -82,21 +74,21 @@ func (h *HMACSHAStrategy) GenerateRefreshToken(ctx context.Context, _ fosite.Req
return "", "", err
}

return h.setPrefix(token, "rt"), sig, nil
return h.setPrefix(ctx, token, "rt"), sig, nil
}

func (h *HMACSHAStrategy) ValidateRefreshToken(ctx context.Context, r fosite.Requester, token string) (err error) {
var exp = r.GetSession().GetExpiresAt(fosite.RefreshToken)
if exp.IsZero() {
// Unlimited lifetime
return h.Enigma.Validate(ctx, h.trimPrefix(token, "rt"))
return h.Enigma.Validate(ctx, h.trimPrefix(ctx, token, "rt"))
}

if !exp.IsZero() && exp.Before(time.Now().UTC()) {
return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Refresh token expired at '%s'.", exp))
}

return h.Enigma.Validate(ctx, h.trimPrefix(token, "rt"))
return h.Enigma.Validate(ctx, h.trimPrefix(ctx, token, "rt"))
}

func (h *HMACSHAStrategy) GenerateAuthorizeCode(ctx context.Context, _ fosite.Requester) (token string, signature string, err error) {
Expand All @@ -105,7 +97,7 @@ func (h *HMACSHAStrategy) GenerateAuthorizeCode(ctx context.Context, _ fosite.Re
return "", "", err
}

return h.setPrefix(token, "ac"), sig, nil
return h.setPrefix(ctx, token, "ac"), sig, nil
}

func (h *HMACSHAStrategy) ValidateAuthorizeCode(ctx context.Context, r fosite.Requester, token string) (err error) {
Expand All @@ -118,5 +110,5 @@ func (h *HMACSHAStrategy) ValidateAuthorizeCode(ctx context.Context, r fosite.Re
return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Authorize code expired at '%s'.", exp))
}

return h.Enigma.Validate(ctx, h.trimPrefix(token, "ac"))
return h.Enigma.Validate(ctx, h.trimPrefix(ctx, token, "ac"))
}
14 changes: 14 additions & 0 deletions handler/oauth2/unsafe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package oauth2

import (
"context"
"fmt"
_ "unsafe"
)

var _ HMACPrefixFunc = getPrefix

//go:linkname getPrefix
func getPrefix(ctx context.Context, h *HMACSHAStrategy, part string) string {
return fmt.Sprintf("ory_%s_", part)
}
12 changes: 12 additions & 0 deletions linkname/a/foo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package a

import (
_ "unsafe"
)

func Foo() string {
return Fooer()
}

//go:linkname Fooer github.com/ory/fosite/linkname/a.Fooer
func Fooer() string
12 changes: 12 additions & 0 deletions linkname/a/foo_i.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//go:build !oel

package a

import (
_ "unsafe"
)

//go:linkname Fooer2 github.com/ory/fosite/linkname/a.Fooer
func Fooer2() string {
return "foo"
}
1 change: 1 addition & 0 deletions linkname/b/b.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package b
17 changes: 17 additions & 0 deletions linkname/b/bar.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//go:build oel

package b

import (
"github.com/ory/fosite/linkname/a"
_ "unsafe"
)

//go:linkname Fooer2 github.com/ory/fosite/linkname/a.Fooer
func Fooer2() string {
return "bar"
}

func Foo() string {
return a.Foo()
}
22 changes: 22 additions & 0 deletions linkname/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package main

import (
"fmt"
"time"
_ "unsafe"
)

//go:linkname pf fmt.Printf
func pf(format string, a ...any) (n int, err error) {
panic("")
return 0, nil
}

//go:linkname timeNow time.Now
func timeNow() time.Time {
return time.Date(2040, 1, 1, 0, 0, 0, 0, time.UTC)
}

func main() {
fmt.Printf("now: %v", time.Now())
}
8 changes: 8 additions & 0 deletions linkname/v2/a/a.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package a

import _ "unsafe"

// fooer is an internal function in package a
func fooer() string {
return "foo"
}
9 changes: 9 additions & 0 deletions linkname/v2/b/bb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package b

import (
_ "unsafe"
_ "v2/a"
) // required for go:linkname

//go:linkname Fooer v2/a.fooer
func Fooer() string
3 changes: 3 additions & 0 deletions linkname/v2/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module v2

go 1.22
10 changes: 10 additions & 0 deletions linkname/v2/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package main

import (
"fmt"
"v2/b" // Import the overriding package to ensure it is included
)

func main() {
fmt.Printf(b.Fooer()) // This should call the overridden function from package b
}

0 comments on commit d6c955d

Please sign in to comment.