diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 00000000..13db1887 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,61 @@ +// Package cache implements a generic, thread-safe, in-memory cache. +package cache + +import ( + "sync" + "time" +) + +const ( + defaultTTL = time.Minute +) + +// Cache is a generic, thread-safe, in-memory cache that stores a value with a +// TTL, after which the cache expires. +type Cache[T any] struct { + data T + expiry time.Time + ttl time.Duration + mu sync.Mutex +} + +// Option is a functional option argument to NewCache(). +type Option[T any] func(*Cache[T]) + +// WithTTL sets the the Cache time-to-live to ttl. +func WithTTL[T any](ttl time.Duration) Option[T] { + return func(c *Cache[T]) { + c.ttl = ttl + } +} + +// NewCache instantiates a Cache for type T with a default TTL of 1 minute. +func NewCache[T any](options ...Option[T]) *Cache[T] { + c := Cache[T]{ + ttl: defaultTTL, + } + for _, option := range options { + option(&c) + } + return &c +} + +// Set updates the value in the cache and sets the expiry to now+TTL. +func (c *Cache[T]) Set(value T) { + c.mu.Lock() + defer c.mu.Unlock() + c.data = value + c.expiry = time.Now().Add(c.ttl) +} + +// Get retrieves the value from the cache. If cache has expired, the second +// return value will be false. +func (c *Cache[T]) Get() (T, bool) { + c.mu.Lock() + defer c.mu.Unlock() + if time.Now().After(c.expiry) { + var zero T + return zero, false + } + return c.data, true +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 00000000..d614b0d7 --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,69 @@ +package cache_test + +import ( + "testing" + "time" + + "github.com/alecthomas/assert/v2" + "github.com/uselagoon/ssh-portal/internal/cache" +) + +func TestIntCache(t *testing.T) { + var testCases = map[string]struct { + input int + expect int + expired bool + }{ + "not expired": {input: 11, expect: 11}, + "expired": {input: 11, expired: true}, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + c := cache.NewCache[int](cache.WithTTL[int](time.Second)) + c.Set(tc.input) + if tc.expired { + time.Sleep(2 * time.Second) + _, ok := c.Get() + assert.False(tt, ok, name) + } else { + value, ok := c.Get() + assert.True(tt, ok, name) + assert.Equal(tt, tc.expect, value, name) + } + }) + } +} + +func TestMapCache(t *testing.T) { + var testCases = map[string]struct { + input map[string]string + expect map[string]string + expired bool + }{ + "expired": { + input: map[string]string{"foo": "bar"}, + expired: true, + }, + "not expired": { + input: map[string]string{"foo": "bar"}, + expect: map[string]string{"foo": "bar"}, + }, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + c := cache.NewCache[map[string]string]( + cache.WithTTL[map[string]string](time.Second), + ) + c.Set(tc.input) + if tc.expired { + time.Sleep(2 * time.Second) + _, ok := c.Get() + assert.False(tt, ok, name) + } else { + value, ok := c.Get() + assert.True(tt, ok, name) + assert.Equal(tt, tc.expect, value, name) + } + }) + } +} diff --git a/internal/keycloak/client.go b/internal/keycloak/client.go index 1d38f55f..197f4370 100644 --- a/internal/keycloak/client.go +++ b/internal/keycloak/client.go @@ -12,6 +12,7 @@ import ( "time" "github.com/MicahParks/keyfunc/v2" + "github.com/uselagoon/ssh-portal/internal/cache" oidcClient "github.com/zitadel/oidc/v3/pkg/client" "github.com/zitadel/oidc/v3/pkg/oidc" "golang.org/x/time/rate" @@ -21,23 +22,28 @@ const pkgName = "github.com/uselagoon/ssh-portal/internal/keycloak" // Client is a keycloak client. type Client struct { + baseURL *url.URL clientID string clientSecret string jwks *keyfunc.JWKS log *slog.Logger oidcConfig *oidc.DiscoveryConfiguration limiter *rate.Limiter + + // groupNameGroupIDMap cache + groupCache *cache.Cache[map[string]string] } // NewClient creates a new keycloak client for the lagoon realm. func NewClient(ctx context.Context, log *slog.Logger, keycloakURL, clientID, clientSecret string, rateLimit int) (*Client, error) { // discover OIDC config - issuerURL, err := url.Parse(keycloakURL) + baseURL, err := url.Parse(keycloakURL) if err != nil { return nil, fmt.Errorf("couldn't parse keycloak base URL %s: %v", keycloakURL, err) } + issuerURL := *baseURL issuerURL.Path = path.Join(issuerURL.Path, "auth/realms/lagoon") oidcConfig, err := oidcClient.Discover(ctx, issuerURL.String(), &http.Client{Timeout: 8 * time.Second}) @@ -50,11 +56,13 @@ func NewClient(ctx context.Context, log *slog.Logger, keycloakURL, clientID, return nil, fmt.Errorf("couldn't get keycloak lagoon realm JWKS: %v", err) } return &Client{ + baseURL: baseURL, clientID: clientID, clientSecret: clientSecret, jwks: jwks, log: log, oidcConfig: oidcConfig, limiter: rate.NewLimiter(rate.Limit(rateLimit), rateLimit), + groupCache: cache.NewCache[map[string]string](), }, nil } diff --git a/internal/keycloak/groups.go b/internal/keycloak/groups.go new file mode 100644 index 00000000..f68f35f7 --- /dev/null +++ b/internal/keycloak/groups.go @@ -0,0 +1,82 @@ +package keycloak + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "path" + + "golang.org/x/oauth2/clientcredentials" +) + +// Group represents a Keycloak Group. It holds the fields required when getting +// a list of groups from keycloak. +type Group struct { + ID string `json:"id"` + Name string `json:"name"` +} + +func (c *Client) httpClient(ctx context.Context) *http.Client { + cc := clientcredentials.Config{ + ClientID: c.clientID, + ClientSecret: c.clientSecret, + TokenURL: c.oidcConfig.TokenEndpoint, + } + return cc.Client(ctx) +} + +// rawGroups returns the raw JSON group representation from the Keycloak API. +func (c *Client) rawGroups(ctx context.Context) ([]byte, error) { + groupsURL := *c.baseURL + groupsURL.Path = path.Join(c.baseURL.Path, + "/auth/admin/realms/lagoon/groups") + req, err := http.NewRequestWithContext(ctx, "GET", groupsURL.String(), nil) + if err != nil { + return nil, fmt.Errorf("couldn't construct groups request: %v", err) + } + q := req.URL.Query() + q.Add("briefRepresentation", "true") + req.URL.RawQuery = q.Encode() + res, err := c.httpClient(ctx).Do(req) + if err != nil { + return nil, fmt.Errorf("couldn't get groups: %v", err) + } + defer res.Body.Close() + if res.StatusCode > 299 { + body, _ := io.ReadAll(res.Body) + return nil, fmt.Errorf("bad groups response: %d\n%s", res.StatusCode, body) + } + return io.ReadAll(res.Body) +} + +// GroupNameGroupIDMap returns a map of Keycloak Group names to Group IDs. +func (c *Client) GroupNameGroupIDMap( + ctx context.Context, +) (map[string]string, error) { + // rate limit keycloak API access + if err := c.limiter.Wait(ctx); err != nil { + return nil, fmt.Errorf("couldn't wait for limiter: %v", err) + } + // prefer to use cached value + if groupNameGroupIDMap, ok := c.groupCache.Get(); ok { + return groupNameGroupIDMap, nil + } + // otherwise get data from keycloak + data, err := c.rawGroups(ctx) + if err != nil { + return nil, fmt.Errorf("couldn't get groups from Keycloak API: %v", err) + } + var groups []Group + if err := json.Unmarshal(data, &groups); err != nil { + return nil, fmt.Errorf("couldn't unmarshal Keycloak groups: %v", err) + } + groupNameGroupIDMap := map[string]string{} + for _, group := range groups { + groupNameGroupIDMap[group.Name] = group.ID + } + // update cache + c.groupCache.Set(groupNameGroupIDMap) + return groupNameGroupIDMap, nil +} diff --git a/internal/keycloak/jwt.go b/internal/keycloak/jwt.go index 340f0d62..8dd46eec 100644 --- a/internal/keycloak/jwt.go +++ b/internal/keycloak/jwt.go @@ -1,47 +1,17 @@ package keycloak import ( - "encoding/json" "fmt" "github.com/golang-jwt/jwt/v5" "golang.org/x/oauth2" ) -type groupProjectIDs map[string][]int - -func (gpids *groupProjectIDs) UnmarshalJSON(data []byte) error { - // unmarshal the double-encoded group-pid attributes - var gpas []string - if err := json.Unmarshal(data, &gpas); err != nil { - return err - } - // convert the slice of encoded group-pid attributes into a slice of - // group-pid maps - var gpms []map[string][]int - for _, gpa := range gpas { - var gpm map[string][]int - if err := json.Unmarshal([]byte(gpa), &gpm); err != nil { - return err - } - gpms = append(gpms, gpm) - } - // flatten the slice of group-pid maps into a single map - *gpids = groupProjectIDs{} - for _, gpm := range gpms { - for k, v := range gpm { - (*gpids)[k] = v - } - } - return nil -} - // LagoonClaims contains the token claims used by Lagoon. type LagoonClaims struct { - RealmRoles []string `json:"realm_roles"` - UserGroups []string `json:"group_membership"` - GroupProjectIDs groupProjectIDs `json:"group_lagoon_project_ids"` - AuthorizedParty string `json:"azp"` + RealmRoles []string `json:"realm_roles"` + UserGroups []string `json:"group_membership"` + AuthorizedParty string `json:"azp"` jwt.RegisteredClaims clientID string `json:"-"` diff --git a/internal/keycloak/jwt_test.go b/internal/keycloak/jwt_test.go index af125d81..1b0c2870 100644 --- a/internal/keycloak/jwt_test.go +++ b/internal/keycloak/jwt_test.go @@ -29,13 +29,8 @@ func TestUnmarshalLagoonClaims(t *testing.T) { "{\"credentialtest-group1\":[1]}", "{\"ci-group\":[3,4,5,6,7,8,9,10,11,12,17,14,16,20,21,24,19,23,31]}"]}`), expect: &keycloak.LagoonClaims{ - RealmRoles: nil, - UserGroups: nil, - GroupProjectIDs: map[string][]int{ - "credentialtest-group1": {1}, - "ci-group": {3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 17, 14, 16, 20, 21, 24, - 19, 23, 31}, - }, + RealmRoles: nil, + UserGroups: nil, RegisteredClaims: jwt.RegisteredClaims{}, }, }, @@ -97,11 +92,6 @@ func TestUnmarshalLagoonClaims(t *testing.T) { UserGroups: []string{ "/ci-group/ci-group-owner", "/credentialtest-group1/credentialtest-group1-owner"}, - GroupProjectIDs: map[string][]int{ - "credentialtest-group1": {1}, - "ci-group": {3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 17, 14, 16, 20, 21, 24, - 19, 23, 31}, - }, AuthorizedParty: "service-api", RegisteredClaims: jwt.RegisteredClaims{ ID: "ba279e79-4f38-43ae-83e7-fe461aad59d1", diff --git a/internal/keycloak/userrolesandgroups.go b/internal/keycloak/userrolesandgroups.go index 68bc7f6d..d310f97f 100644 --- a/internal/keycloak/userrolesandgroups.go +++ b/internal/keycloak/userrolesandgroups.go @@ -12,16 +12,15 @@ import ( ) // UserRolesAndGroups queries Keycloak given the user UUID, and returns the -// user's realm roles, group memberships, and the project IDs associated with -// those groups. +// user's realm roles, and group memberships (by name, including subgroups). func (c *Client) UserRolesAndGroups(ctx context.Context, - userUUID *uuid.UUID) ([]string, []string, map[string][]int, error) { + userUUID *uuid.UUID) ([]string, []string, error) { // set up tracing ctx, span := otel.Tracer(pkgName).Start(ctx, "UserRolesAndGroups") defer span.End() // rate limit keycloak API access if err := c.limiter.Wait(ctx); err != nil { - return nil, nil, nil, fmt.Errorf("couldn't wait for limiter: %v", err) + return nil, nil, fmt.Errorf("couldn't wait for limiter: %v", err) } // get user token userConfig := oauth2.Config{ @@ -41,12 +40,12 @@ func (c *Client) UserRolesAndGroups(ctx context.Context, // https://www.keycloak.org/docs/latest/securing_apps/#_token-exchange oauth2.SetAuthURLParam("requested_subject", userUUID.String())) if err != nil { - return nil, nil, nil, fmt.Errorf("couldn't get user token: %v", err) + return nil, nil, fmt.Errorf("couldn't get user token: %v", err) } // parse and extract verified attributes claims, err := c.parseAccessToken(userToken, userUUID.String()) if err != nil { - return nil, nil, nil, fmt.Errorf("couldn't parse user access token: %v", err) + return nil, nil, fmt.Errorf("couldn't parse user access token: %v", err) } - return claims.RealmRoles, claims.UserGroups, claims.GroupProjectIDs, nil + return claims.RealmRoles, claims.UserGroups, nil } diff --git a/internal/lagoon/groupnameprojectids.go b/internal/lagoon/groupnameprojectids.go new file mode 100644 index 00000000..84de620a --- /dev/null +++ b/internal/lagoon/groupnameprojectids.go @@ -0,0 +1,75 @@ +// Package lagoon provides Lagoon-specific functionality which doesn't fit +// cleanly into the other service packages such as Keycloak or Lagoon DB. +package lagoon + +import ( + "context" + "fmt" + "strings" +) + +// DBService provides methods for querying the Lagoon API DB. +type DBService interface { + GroupIDProjectIDsMap(context.Context) (map[string][]int, error) +} + +// KeycloakService provides methods for querying the Keycloak API. +type KeycloakService interface { + GroupNameGroupIDMap(context.Context) (map[string]string, error) +} + +// given a nested user group name like "/foo-bar/foo-bar-owner", sanity check +// the format and return the top-level group name (between the separators). +func groupNameFromUserGroup(userGroup string) (string, error) { + parts := strings.Split(userGroup, `/`) + switch { + case len(parts) != 3: + return "", fmt.Errorf(`unknown user group format: %v`, userGroup) + case len(parts[0]) != 0: + return "", fmt.Errorf(`missing leading "/": %v`, userGroup) + case len(parts[1]) == 0: + return "", fmt.Errorf(`missing group name: %v`, userGroup) + case len(parts[2]) == 0: + return "", fmt.Errorf(`missing subgroup name: %v`, userGroup) + default: + return parts[1], nil + } +} + +// GroupNameProjectIDsMap generates a map of group names to project IDs for the +// groups the user is a member of. userGroups should be a slice of groups +// including subgroups in the format returned from +func GroupNameProjectIDsMap( + ctx context.Context, + ldb DBService, + k KeycloakService, + userGroups []string, +) (map[string][]int, error) { + // get the map of group names to group IDs + groupNameGroupIDMap, err := k.GroupNameGroupIDMap(ctx) + if err != nil { + return nil, fmt.Errorf("couldn't query keycloak groups: %v", err) + } + // get the group -> project memberships + groupIDProjectIDsMap, err := ldb.GroupIDProjectIDsMap(ctx) + if err != nil { + return nil, fmt.Errorf("couldn't query Lagoon DB group projects: %v", err) + } + groupNameProjectIDsMap := map[string][]int{} + for _, userGroup := range userGroups { + // carve out group name from the user group + groupName, err := groupNameFromUserGroup(userGroup) + if err != nil { + return nil, fmt.Errorf("couldn't parse user group: %v", err) + } + // for each user group, get the group ID + groupID, ok := groupNameGroupIDMap[groupName] + if !ok { + return nil, fmt.Errorf("couldn't get group ID for group: %v", groupName) + } + // use the group ID to find the group projects and map groupName to project + // IDs in the groupProjectIDs map + groupNameProjectIDsMap[groupName] = groupIDProjectIDsMap[groupID] + } + return groupNameProjectIDsMap, nil +} diff --git a/internal/lagoon/groupnameprojectids_test.go b/internal/lagoon/groupnameprojectids_test.go new file mode 100644 index 00000000..27a34b66 --- /dev/null +++ b/internal/lagoon/groupnameprojectids_test.go @@ -0,0 +1,89 @@ +package lagoon_test + +import ( + "context" + "testing" + + "github.com/alecthomas/assert/v2" + "github.com/uselagoon/ssh-portal/internal/lagoon" + "github.com/uselagoon/ssh-portal/internal/mock" + "go.uber.org/mock/gomock" +) + +var ( + groupNameGroupIDMap = map[string]string{ + "project-bs-demo": "89c2894b-5345-453d-839d-2c210fe9b18d", + "project-drupal-example": "948adf3d-f075-4659-925d-7d1d4a85f0ba", + "project-skip-test-project": "ea6bd1a8-a1e7-46cc-a62e-cca8dc27f5ed", + "project-test-drupal-example-simple": "0ce10b5d-72ca-40a5-a33f-056b8565521f", + "another-random-group": "7fd49076-5fc9-4b2f-9998-3a3eff731ec0", + } + groupIDProjectIDsMap = map[string][]int{ + "89c2894b-5345-453d-839d-2c210fe9b18d": {1, 23}, + "948adf3d-f075-4659-925d-7d1d4a85f0ba": {45}, + "ea6bd1a8-a1e7-46cc-a62e-cca8dc27f5ed": {6, 7, 8}, + "0ce10b5d-72ca-40a5-a33f-056b8565521f": {90}, + "7fd49076-5fc9-4b2f-9998-3a3eff731ec0": {2, 3}, + } +) + +func TestGroupNameProjectIDsMap(t *testing.T) { + var testCases = map[string]struct { + input []string + expect map[string][]int + expectError bool + }{ + "happy path": { + input: []string{ + "/project-bs-demo/project-as-demo-developer", + "/project-drupal-example/project-drupal-example-maintainer", + "/project-skip-test-project/project-skip-test-project-owner", + "/project-test-drupal-example-simple/project-test-drupal-example-simple-maintainer", + }, + expect: map[string][]int{ + "project-bs-demo": {1, 23}, + "project-drupal-example": {45}, + "project-skip-test-project": {6, 7, 8}, + "project-test-drupal-example-simple": {90}, + }, + }, + "invalid group name": { + input: []string{ + "/project-bs-demo/project-as-demo-developer", + "/project-drupal-example/project-drupal-example-maintainer", + "/project-skip-test-project/project-skip-test-project-owner", + "invalid-group/foo", + }, + expectError: true, + }, + "unknown group": { + input: []string{ + "/project-vandelay/project-as-demo-developer", + "/project-drupal-example/project-drupal-example-maintainer", + "/project-skip-test-project/project-skip-test-project-owner", + "/project-test-drupal-example-simple/project-test-drupal-example-simple-maintainer", + }, + expectError: true, + }, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + ctx := context.Background() + // set up mocks + ctrl := gomock.NewController(tt) + kcService := mock.NewMockKeycloakService(ctrl) + dbService := mock.NewMockDBService(ctrl) + // configure mocks + kcService.EXPECT().GroupNameGroupIDMap(ctx).Return(groupNameGroupIDMap, nil) + dbService.EXPECT().GroupIDProjectIDsMap(ctx).Return(groupIDProjectIDsMap, nil) + // test function + gnpids, err := lagoon.GroupNameProjectIDsMap(ctx, dbService, kcService, tc.input) + if tc.expectError { + assert.Error(tt, err, name) + } else { + assert.NoError(tt, err, name) + assert.Equal(tt, tc.expect, gnpids, name) + } + }) + } +} diff --git a/internal/lagoondb/client.go b/internal/lagoondb/client.go index 9880d17f..b4f6f3cd 100644 --- a/internal/lagoondb/client.go +++ b/internal/lagoondb/client.go @@ -35,6 +35,13 @@ type User struct { UUID *uuid.UUID `db:"uuid"` } +// groupProjectMapping maps Lagoon group ID to project ID. +// This type is only used for database unmarshalling. +type groupProjectMapping struct { + GroupID string `db:"group_id"` + ProjectID int `db:"project_id"` +} + // ErrNoResult is returned by client methods if there is no result. var ErrNoResult = errors.New("no rows in result set") @@ -55,7 +62,10 @@ func NewClient(ctx context.Context, dsn string) (*Client, error) { // EnvironmentByNamespaceName returns the Environment associated with the given // Namespace name (on Openshift this is the project name). -func (c *Client) EnvironmentByNamespaceName(ctx context.Context, name string) (*Environment, error) { +func (c *Client) EnvironmentByNamespaceName( + ctx context.Context, + name string, +) (*Environment, error) { // set up tracing ctx, span := otel.Tracer(pkgName).Start(ctx, "EnvironmentByNamespaceName") defer span.End() @@ -84,7 +94,10 @@ func (c *Client) EnvironmentByNamespaceName(ctx context.Context, name string) (* // UserBySSHFingerprint returns the User associated with the given // SSH fingerprint. -func (c *Client) UserBySSHFingerprint(ctx context.Context, fingerprint string) (*User, error) { +func (c *Client) UserBySSHFingerprint( + ctx context.Context, + fingerprint string, +) (*User, error) { // set up tracing ctx, span := otel.Tracer(pkgName).Start(ctx, "UserBySSHFingerprint") defer span.End() @@ -129,3 +142,28 @@ func (c *Client) SSHEndpointByEnvironmentID(ctx context.Context, } return ssh.Host, ssh.Port, nil } + +// GroupIDProjectIDsMap returns a map of Group (UU)IDs to Project IDs. +// This denotes Project Group membership in Lagoon. +func (c *Client) GroupIDProjectIDsMap( + ctx context.Context, +) (map[string][]int, error) { + var gpms []groupProjectMapping + err := c.db.SelectContext(ctx, &gpms, ` + SELECT group_id, project_id + FROM kc_group_projects`) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNoResult + } + return nil, err + } + groupIDProjectIDsMap := map[string][]int{} + // no need to check for duplicates here since the table has: + // UNIQUE KEY `group_project` (`group_id`,`project_id`) + for _, gpm := range gpms { + groupIDProjectIDsMap[gpm.GroupID] = + append(groupIDProjectIDsMap[gpm.GroupID], gpm.ProjectID) + } + return groupIDProjectIDsMap, nil +} diff --git a/internal/mock/lagoon_groupnameprojectids.go b/internal/mock/lagoon_groupnameprojectids.go new file mode 100644 index 00000000..b39e0660 --- /dev/null +++ b/internal/mock/lagoon_groupnameprojectids.go @@ -0,0 +1,95 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ../lagoon/groupnameprojectids.go +// +// Generated by this command: +// +// mockgen -source=../lagoon/groupnameprojectids.go -package=mock -destination=lagoon_groupnameprojectids.go -write_generate_directive +// + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +//go:generate mockgen -source=../lagoon/groupnameprojectids.go -package=mock -destination=lagoon_groupnameprojectids.go -write_generate_directive + +// MockDBService is a mock of DBService interface. +type MockDBService struct { + ctrl *gomock.Controller + recorder *MockDBServiceMockRecorder +} + +// MockDBServiceMockRecorder is the mock recorder for MockDBService. +type MockDBServiceMockRecorder struct { + mock *MockDBService +} + +// NewMockDBService creates a new mock instance. +func NewMockDBService(ctrl *gomock.Controller) *MockDBService { + mock := &MockDBService{ctrl: ctrl} + mock.recorder = &MockDBServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBService) EXPECT() *MockDBServiceMockRecorder { + return m.recorder +} + +// GroupIDProjectIDsMap mocks base method. +func (m *MockDBService) GroupIDProjectIDsMap(arg0 context.Context) (map[string][]int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GroupIDProjectIDsMap", arg0) + ret0, _ := ret[0].(map[string][]int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GroupIDProjectIDsMap indicates an expected call of GroupIDProjectIDsMap. +func (mr *MockDBServiceMockRecorder) GroupIDProjectIDsMap(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupIDProjectIDsMap", reflect.TypeOf((*MockDBService)(nil).GroupIDProjectIDsMap), arg0) +} + +// MockKeycloakService is a mock of KeycloakService interface. +type MockKeycloakService struct { + ctrl *gomock.Controller + recorder *MockKeycloakServiceMockRecorder +} + +// MockKeycloakServiceMockRecorder is the mock recorder for MockKeycloakService. +type MockKeycloakServiceMockRecorder struct { + mock *MockKeycloakService +} + +// NewMockKeycloakService creates a new mock instance. +func NewMockKeycloakService(ctrl *gomock.Controller) *MockKeycloakService { + mock := &MockKeycloakService{ctrl: ctrl} + mock.recorder = &MockKeycloakServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockKeycloakService) EXPECT() *MockKeycloakServiceMockRecorder { + return m.recorder +} + +// GroupNameGroupIDMap mocks base method. +func (m *MockKeycloakService) GroupNameGroupIDMap(arg0 context.Context) (map[string]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GroupNameGroupIDMap", arg0) + ret0, _ := ret[0].(map[string]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GroupNameGroupIDMap indicates an expected call of GroupNameGroupIDMap. +func (mr *MockKeycloakServiceMockRecorder) GroupNameGroupIDMap(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupNameGroupIDMap", reflect.TypeOf((*MockKeycloakService)(nil).GroupNameGroupIDMap), arg0) +} diff --git a/internal/rbac/usercansshtoenvironment.go b/internal/rbac/usercansshtoenvironment.go index 691ebb46..3d087dd2 100644 --- a/internal/rbac/usercansshtoenvironment.go +++ b/internal/rbac/usercansshtoenvironment.go @@ -42,7 +42,7 @@ func (p *Permission) UserCanSSHToEnvironment( env *lagoondb.Environment, realmRoles, userGroups []string, - groupProjectIDs map[string][]int, + groupNameProjectIDsMap map[string][]int, ) bool { // set up tracing _, span := otel.Tracer(pkgName).Start(ctx, "UserCanSSHToEnvironment") @@ -71,7 +71,7 @@ func (p *Permission) UserCanSSHToEnvironment( } // check if the user is a member of a group containing the project and has // the required role - for group, pids := range groupProjectIDs { + for groupName, pids := range groupNameProjectIDsMap { for _, pid := range pids { if pid == env.ProjectID { // user is in the same group as project, check if they have the @@ -79,7 +79,7 @@ func (p *Permission) UserCanSSHToEnvironment( var validGroups []string for _, role := range validRoles { validGroups = append(validGroups, - fmt.Sprintf("/%s/%s-%s", group, group, role)) + fmt.Sprintf("/%s/%s-%s", groupName, groupName, role)) } for _, userGroup := range userGroups { for _, validGroup := range validGroups { diff --git a/internal/sshportalapi/server.go b/internal/sshportalapi/server.go index 9e0b339d..d3f60b89 100644 --- a/internal/sshportalapi/server.go +++ b/internal/sshportalapi/server.go @@ -10,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/nats-io/nats.go" + "github.com/uselagoon/ssh-portal/internal/lagoon" "github.com/uselagoon/ssh-portal/internal/lagoondb" "github.com/uselagoon/ssh-portal/internal/rbac" ) @@ -21,13 +22,15 @@ const ( // LagoonDBService provides methods for querying the Lagoon API DB. type LagoonDBService interface { + lagoon.DBService EnvironmentByNamespaceName(context.Context, string) (*lagoondb.Environment, error) UserBySSHFingerprint(context.Context, string) (*lagoondb.User, error) } // KeycloakService provides methods for querying the Keycloak API. type KeycloakService interface { - UserRolesAndGroups(context.Context, *uuid.UUID) ([]string, []string, map[string][]int, error) + lagoon.KeycloakService + UserRolesAndGroups(context.Context, *uuid.UUID) ([]string, []string, error) } // ServeNATS sshportalapi NATS requests. diff --git a/internal/sshportalapi/sshportal.go b/internal/sshportalapi/sshportal.go index 3d9a2e25..7059b54d 100644 --- a/internal/sshportalapi/sshportal.go +++ b/internal/sshportalapi/sshportal.go @@ -8,6 +8,7 @@ import ( "github.com/nats-io/nats.go" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/uselagoon/ssh-portal/internal/lagoon" "github.com/uselagoon/ssh-portal/internal/lagoondb" "github.com/uselagoon/ssh-portal/internal/rbac" "go.opentelemetry.io/otel" @@ -45,11 +46,16 @@ var ( }) ) -func sshportal(ctx context.Context, log *slog.Logger, c *nats.EncodedConn, - p *rbac.Permission, l LagoonDBService, k KeycloakService) nats.Handler { +func sshportal( + ctx context.Context, + log *slog.Logger, + c *nats.EncodedConn, + p *rbac.Permission, + l LagoonDBService, + k KeycloakService, +) nats.Handler { return func(_, replySubject string, query *SSHAccessQuery) { var realmRoles, userGroups []string - var groupProjectIDs map[string][]int // set up tracing and update metrics ctx, span := otel.Tracer(pkgName).Start(ctx, SubjectSSHAccessQuery) defer span.End() @@ -101,23 +107,30 @@ func sshportal(ctx context.Context, log *slog.Logger, c *nats.EncodedConn, return } // get the user roles and groups - realmRoles, userGroups, groupProjectIDs, err = - k.UserRolesAndGroups(ctx, user.UUID) + realmRoles, userGroups, err = k.UserRolesAndGroups(ctx, user.UUID) if err != nil { - log.Error("couldn't query user roles and groups", + log.Error("couldn't query keycloak user roles and groups", slog.String("userUUID", user.UUID.String()), slog.Any("error", err)) return } + // generate the group name to project IDs map + groupNameProjectIDsMap, err := + lagoon.GroupNameProjectIDsMap(ctx, l, k, userGroups) + if err != nil { + log.Error("couldn't generate group name to project IDs map", + slog.Any("error", err)) + return + } log.Debug("keycloak user attributes", slog.Any("realmRoles", realmRoles), slog.Any("userGroups", userGroups), - slog.Any("groupProjectIDs", groupProjectIDs), + slog.Any("groupNameProjectIDsMap", groupNameProjectIDsMap), slog.String("userUUID", user.UUID.String()), ) // check permission - ok := p.UserCanSSHToEnvironment(ctx, env, realmRoles, userGroups, - groupProjectIDs) + ok := p.UserCanSSHToEnvironment( + ctx, env, realmRoles, userGroups, groupNameProjectIDsMap) var logMsg string if ok { logMsg = "SSH access authorized" diff --git a/internal/sshtoken/serve.go b/internal/sshtoken/serve.go index 34c435b9..03328bcd 100644 --- a/internal/sshtoken/serve.go +++ b/internal/sshtoken/serve.go @@ -11,6 +11,7 @@ import ( "github.com/gliderlabs/ssh" "github.com/uselagoon/ssh-portal/internal/keycloak" + "github.com/uselagoon/ssh-portal/internal/lagoon" "github.com/uselagoon/ssh-portal/internal/lagoondb" "github.com/uselagoon/ssh-portal/internal/rbac" ) @@ -20,18 +21,26 @@ const shutdownTimeout = 8 * time.Second // LagoonDBService provides methods for querying the Lagoon API DB. type LagoonDBService interface { + lagoon.DBService EnvironmentByNamespaceName(context.Context, string) (*lagoondb.Environment, error) UserBySSHFingerprint(context.Context, string) (*lagoondb.User, error) SSHEndpointByEnvironmentID(context.Context, int) (string, string, error) } // Serve contains the main ssh session logic -func Serve(ctx context.Context, log *slog.Logger, l net.Listener, - p *rbac.Permission, ldb *lagoondb.Client, - keycloakToken, keycloakPermission *keycloak.Client, - hostKeys [][]byte) error { +func Serve( + ctx context.Context, + log *slog.Logger, + l net.Listener, + p *rbac.Permission, + ldb *lagoondb.Client, + keycloakToken, + keycloakPermission *keycloak.Client, + hostKeys [][]byte, +) error { srv := ssh.Server{ - Handler: sessionHandler(log, p, keycloakToken, keycloakPermission, ldb), + Handler: sessionHandler( + log, p, keycloakToken, keycloakPermission, ldb), PublicKeyHandler: pubKeyAuth(log, ldb), } for _, hk := range hostKeys { diff --git a/internal/sshtoken/sessionhandler.go b/internal/sshtoken/sessionhandler.go index 227ddedf..fc018b37 100644 --- a/internal/sshtoken/sessionhandler.go +++ b/internal/sshtoken/sessionhandler.go @@ -10,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/uselagoon/ssh-portal/internal/lagoon" "github.com/uselagoon/ssh-portal/internal/lagoondb" "github.com/uselagoon/ssh-portal/internal/rbac" ) @@ -24,8 +25,8 @@ type KeycloakTokenService interface { // KeycloakUserInfoService provides methods for querying the Keycloak API for // permission information contained in service-api user tokens. type KeycloakUserInfoService interface { - UserRolesAndGroups(context.Context, *uuid.UUID) ([]string, []string, - map[string][]int, error) + lagoon.KeycloakService + UserRolesAndGroups(context.Context, *uuid.UUID) ([]string, []string, error) } var ( @@ -124,12 +125,17 @@ func tokenSession(s ssh.Session, log *slog.Logger, // the user has access to, returns an error message to the user with the SSH // endpoint to use for ssh shell access. If the user doesn't have access to the // environment a generic error message is returned. -func redirectSession(s ssh.Session, log *slog.Logger, - p *rbac.Permission, keycloakUserInfo KeycloakUserInfoService, - ldb LagoonDBService, uid *uuid.UUID) { +func redirectSession( + s ssh.Session, + log *slog.Logger, + p *rbac.Permission, + keycloakUserInfo KeycloakUserInfoService, + ldb LagoonDBService, + uid *uuid.UUID, +) { ctx := s.Context() // get the user roles and groups - realmRoles, userGroups, groupProjectIDs, err := + realmRoles, userGroups, err := keycloakUserInfo.UserRolesAndGroups(s.Context(), uid) if err != nil { log.Error("couldn't query user roles and groups", @@ -170,15 +176,22 @@ func redirectSession(s ssh.Session, log *slog.Logger, slog.String("namespaceName", s.User()), slog.String("projectName", env.ProjectName), ) + groupNameProjectIDsMap, err := + lagoon.GroupNameProjectIDsMap(ctx, ldb, keycloakUserInfo, userGroups) + if err != nil { + log.Error("couldn't generate group name to project IDs map", + slog.Any("error", err)) + return + } // check permission ok := p.UserCanSSHToEnvironment(s.Context(), env, realmRoles, - userGroups, groupProjectIDs) + userGroups, groupNameProjectIDsMap) if !ok { log.Info("user cannot SSH to environment") log.Debug("user permissions", slog.Any("realmRoles", realmRoles), slog.Any("userGroups", userGroups), - slog.Any("groupProjectIDs", groupProjectIDs)) + slog.Any("groupProjectIDs", groupNameProjectIDsMap)) _, err = fmt.Fprintf(s.Stderr(), "This SSH server does not provide shell access. SID: %s\r\n", ctx.SessionID()) @@ -192,7 +205,7 @@ func redirectSession(s ssh.Session, log *slog.Logger, log.Debug("user permissions", slog.Any("realmRoles", realmRoles), slog.Any("userGroups", userGroups), - slog.Any("groupProjectIDs", groupProjectIDs)) + slog.Any("groupProjectIDs", groupNameProjectIDsMap)) sshHost, sshPort, err := ldb.SSHEndpointByEnvironmentID(s.Context(), env.ID) if err != nil { if errors.Is(err, lagoondb.ErrNoResult) {