Skip to content

Commit

Permalink
Merge pull request #441 from uselagoon/db-groups
Browse files Browse the repository at this point in the history
Use Lagoon API DB to determine project group membership
  • Loading branch information
smlx authored May 10, 2024
2 parents e3ee154 + b66df8a commit f1df832
Show file tree
Hide file tree
Showing 16 changed files with 596 additions and 82 deletions.
61 changes: 61 additions & 0 deletions internal/cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Package cache implements a generic, thread-safe, in-memory cache.
package cache

import (
"sync"
"time"
)

const (
defaultTTL = time.Minute
)

// Cache is a generic, thread-safe, in-memory cache that stores a value with a
// TTL, after which the cache expires.
type Cache[T any] struct {
data T
expiry time.Time
ttl time.Duration
mu sync.Mutex
}

// Option is a functional option argument to NewCache().
type Option[T any] func(*Cache[T])

// WithTTL sets the the Cache time-to-live to ttl.
func WithTTL[T any](ttl time.Duration) Option[T] {
return func(c *Cache[T]) {
c.ttl = ttl
}
}

// NewCache instantiates a Cache for type T with a default TTL of 1 minute.
func NewCache[T any](options ...Option[T]) *Cache[T] {
c := Cache[T]{
ttl: defaultTTL,
}
for _, option := range options {
option(&c)
}
return &c
}

// Set updates the value in the cache and sets the expiry to now+TTL.
func (c *Cache[T]) Set(value T) {
c.mu.Lock()
defer c.mu.Unlock()
c.data = value
c.expiry = time.Now().Add(c.ttl)
}

// Get retrieves the value from the cache. If cache has expired, the second
// return value will be false.
func (c *Cache[T]) Get() (T, bool) {
c.mu.Lock()
defer c.mu.Unlock()
if time.Now().After(c.expiry) {
var zero T
return zero, false
}
return c.data, true
}
69 changes: 69 additions & 0 deletions internal/cache/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package cache_test

import (
"testing"
"time"

"github.com/alecthomas/assert/v2"
"github.com/uselagoon/ssh-portal/internal/cache"
)

func TestIntCache(t *testing.T) {
var testCases = map[string]struct {
input int
expect int
expired bool
}{
"not expired": {input: 11, expect: 11},
"expired": {input: 11, expired: true},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
c := cache.NewCache[int](cache.WithTTL[int](time.Second))
c.Set(tc.input)
if tc.expired {
time.Sleep(2 * time.Second)
_, ok := c.Get()
assert.False(tt, ok, name)
} else {
value, ok := c.Get()
assert.True(tt, ok, name)
assert.Equal(tt, tc.expect, value, name)
}
})
}
}

func TestMapCache(t *testing.T) {
var testCases = map[string]struct {
input map[string]string
expect map[string]string
expired bool
}{
"expired": {
input: map[string]string{"foo": "bar"},
expired: true,
},
"not expired": {
input: map[string]string{"foo": "bar"},
expect: map[string]string{"foo": "bar"},
},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
c := cache.NewCache[map[string]string](
cache.WithTTL[map[string]string](time.Second),
)
c.Set(tc.input)
if tc.expired {
time.Sleep(2 * time.Second)
_, ok := c.Get()
assert.False(tt, ok, name)
} else {
value, ok := c.Get()
assert.True(tt, ok, name)
assert.Equal(tt, tc.expect, value, name)
}
})
}
}
10 changes: 9 additions & 1 deletion internal/keycloak/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/MicahParks/keyfunc/v2"
"github.com/uselagoon/ssh-portal/internal/cache"
oidcClient "github.com/zitadel/oidc/v3/pkg/client"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/time/rate"
Expand All @@ -21,23 +22,28 @@ const pkgName = "github.com/uselagoon/ssh-portal/internal/keycloak"

// Client is a keycloak client.
type Client struct {
baseURL *url.URL
clientID string
clientSecret string
jwks *keyfunc.JWKS
log *slog.Logger
oidcConfig *oidc.DiscoveryConfiguration
limiter *rate.Limiter

// groupNameGroupIDMap cache
groupCache *cache.Cache[map[string]string]
}

// NewClient creates a new keycloak client for the lagoon realm.
func NewClient(ctx context.Context, log *slog.Logger, keycloakURL, clientID,
clientSecret string, rateLimit int) (*Client, error) {
// discover OIDC config
issuerURL, err := url.Parse(keycloakURL)
baseURL, err := url.Parse(keycloakURL)
if err != nil {
return nil, fmt.Errorf("couldn't parse keycloak base URL %s: %v",
keycloakURL, err)
}
issuerURL := *baseURL
issuerURL.Path = path.Join(issuerURL.Path, "auth/realms/lagoon")
oidcConfig, err := oidcClient.Discover(ctx, issuerURL.String(),
&http.Client{Timeout: 8 * time.Second})
Expand All @@ -50,11 +56,13 @@ func NewClient(ctx context.Context, log *slog.Logger, keycloakURL, clientID,
return nil, fmt.Errorf("couldn't get keycloak lagoon realm JWKS: %v", err)
}
return &Client{
baseURL: baseURL,
clientID: clientID,
clientSecret: clientSecret,
jwks: jwks,
log: log,
oidcConfig: oidcConfig,
limiter: rate.NewLimiter(rate.Limit(rateLimit), rateLimit),
groupCache: cache.NewCache[map[string]string](),
}, nil
}
82 changes: 82 additions & 0 deletions internal/keycloak/groups.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package keycloak

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"path"

"golang.org/x/oauth2/clientcredentials"
)

// Group represents a Keycloak Group. It holds the fields required when getting
// a list of groups from keycloak.
type Group struct {
ID string `json:"id"`
Name string `json:"name"`
}

func (c *Client) httpClient(ctx context.Context) *http.Client {
cc := clientcredentials.Config{
ClientID: c.clientID,
ClientSecret: c.clientSecret,
TokenURL: c.oidcConfig.TokenEndpoint,
}
return cc.Client(ctx)
}

// rawGroups returns the raw JSON group representation from the Keycloak API.
func (c *Client) rawGroups(ctx context.Context) ([]byte, error) {
groupsURL := *c.baseURL
groupsURL.Path = path.Join(c.baseURL.Path,
"/auth/admin/realms/lagoon/groups")
req, err := http.NewRequestWithContext(ctx, "GET", groupsURL.String(), nil)
if err != nil {
return nil, fmt.Errorf("couldn't construct groups request: %v", err)
}
q := req.URL.Query()
q.Add("briefRepresentation", "true")
req.URL.RawQuery = q.Encode()
res, err := c.httpClient(ctx).Do(req)
if err != nil {
return nil, fmt.Errorf("couldn't get groups: %v", err)
}
defer res.Body.Close()
if res.StatusCode > 299 {
body, _ := io.ReadAll(res.Body)
return nil, fmt.Errorf("bad groups response: %d\n%s", res.StatusCode, body)
}
return io.ReadAll(res.Body)
}

// GroupNameGroupIDMap returns a map of Keycloak Group names to Group IDs.
func (c *Client) GroupNameGroupIDMap(
ctx context.Context,
) (map[string]string, error) {
// rate limit keycloak API access
if err := c.limiter.Wait(ctx); err != nil {
return nil, fmt.Errorf("couldn't wait for limiter: %v", err)
}
// prefer to use cached value
if groupNameGroupIDMap, ok := c.groupCache.Get(); ok {
return groupNameGroupIDMap, nil
}
// otherwise get data from keycloak
data, err := c.rawGroups(ctx)
if err != nil {
return nil, fmt.Errorf("couldn't get groups from Keycloak API: %v", err)
}
var groups []Group
if err := json.Unmarshal(data, &groups); err != nil {
return nil, fmt.Errorf("couldn't unmarshal Keycloak groups: %v", err)
}
groupNameGroupIDMap := map[string]string{}
for _, group := range groups {
groupNameGroupIDMap[group.Name] = group.ID
}
// update cache
c.groupCache.Set(groupNameGroupIDMap)
return groupNameGroupIDMap, nil
}
36 changes: 3 additions & 33 deletions internal/keycloak/jwt.go
Original file line number Diff line number Diff line change
@@ -1,47 +1,17 @@
package keycloak

import (
"encoding/json"
"fmt"

"github.com/golang-jwt/jwt/v5"
"golang.org/x/oauth2"
)

type groupProjectIDs map[string][]int

func (gpids *groupProjectIDs) UnmarshalJSON(data []byte) error {
// unmarshal the double-encoded group-pid attributes
var gpas []string
if err := json.Unmarshal(data, &gpas); err != nil {
return err
}
// convert the slice of encoded group-pid attributes into a slice of
// group-pid maps
var gpms []map[string][]int
for _, gpa := range gpas {
var gpm map[string][]int
if err := json.Unmarshal([]byte(gpa), &gpm); err != nil {
return err
}
gpms = append(gpms, gpm)
}
// flatten the slice of group-pid maps into a single map
*gpids = groupProjectIDs{}
for _, gpm := range gpms {
for k, v := range gpm {
(*gpids)[k] = v
}
}
return nil
}

// LagoonClaims contains the token claims used by Lagoon.
type LagoonClaims struct {
RealmRoles []string `json:"realm_roles"`
UserGroups []string `json:"group_membership"`
GroupProjectIDs groupProjectIDs `json:"group_lagoon_project_ids"`
AuthorizedParty string `json:"azp"`
RealmRoles []string `json:"realm_roles"`
UserGroups []string `json:"group_membership"`
AuthorizedParty string `json:"azp"`
jwt.RegisteredClaims

clientID string `json:"-"`
Expand Down
14 changes: 2 additions & 12 deletions internal/keycloak/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,8 @@ func TestUnmarshalLagoonClaims(t *testing.T) {
"{\"credentialtest-group1\":[1]}",
"{\"ci-group\":[3,4,5,6,7,8,9,10,11,12,17,14,16,20,21,24,19,23,31]}"]}`),
expect: &keycloak.LagoonClaims{
RealmRoles: nil,
UserGroups: nil,
GroupProjectIDs: map[string][]int{
"credentialtest-group1": {1},
"ci-group": {3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 17, 14, 16, 20, 21, 24,
19, 23, 31},
},
RealmRoles: nil,
UserGroups: nil,
RegisteredClaims: jwt.RegisteredClaims{},
},
},
Expand Down Expand Up @@ -97,11 +92,6 @@ func TestUnmarshalLagoonClaims(t *testing.T) {
UserGroups: []string{
"/ci-group/ci-group-owner",
"/credentialtest-group1/credentialtest-group1-owner"},
GroupProjectIDs: map[string][]int{
"credentialtest-group1": {1},
"ci-group": {3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 17, 14, 16, 20, 21, 24,
19, 23, 31},
},
AuthorizedParty: "service-api",
RegisteredClaims: jwt.RegisteredClaims{
ID: "ba279e79-4f38-43ae-83e7-fe461aad59d1",
Expand Down
13 changes: 6 additions & 7 deletions internal/keycloak/userrolesandgroups.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@ import (
)

// UserRolesAndGroups queries Keycloak given the user UUID, and returns the
// user's realm roles, group memberships, and the project IDs associated with
// those groups.
// user's realm roles, and group memberships (by name, including subgroups).
func (c *Client) UserRolesAndGroups(ctx context.Context,
userUUID *uuid.UUID) ([]string, []string, map[string][]int, error) {
userUUID *uuid.UUID) ([]string, []string, error) {
// set up tracing
ctx, span := otel.Tracer(pkgName).Start(ctx, "UserRolesAndGroups")
defer span.End()
// rate limit keycloak API access
if err := c.limiter.Wait(ctx); err != nil {
return nil, nil, nil, fmt.Errorf("couldn't wait for limiter: %v", err)
return nil, nil, fmt.Errorf("couldn't wait for limiter: %v", err)
}
// get user token
userConfig := oauth2.Config{
Expand All @@ -41,12 +40,12 @@ func (c *Client) UserRolesAndGroups(ctx context.Context,
// https://www.keycloak.org/docs/latest/securing_apps/#_token-exchange
oauth2.SetAuthURLParam("requested_subject", userUUID.String()))
if err != nil {
return nil, nil, nil, fmt.Errorf("couldn't get user token: %v", err)
return nil, nil, fmt.Errorf("couldn't get user token: %v", err)
}
// parse and extract verified attributes
claims, err := c.parseAccessToken(userToken, userUUID.String())
if err != nil {
return nil, nil, nil, fmt.Errorf("couldn't parse user access token: %v", err)
return nil, nil, fmt.Errorf("couldn't parse user access token: %v", err)
}
return claims.RealmRoles, claims.UserGroups, claims.GroupProjectIDs, nil
return claims.RealmRoles, claims.UserGroups, nil
}
Loading

0 comments on commit f1df832

Please sign in to comment.