Skip to content

Commit

Permalink
[v14] Added support to select database roles from tsh (#36528)
Browse files Browse the repository at this point in the history
* Added support to select database roles from `tsh`. (#35867)

* Support selecting db roles.

* add active roles in tsh db ls

* add select db role ut

* add a warning when connecting to older version

* review comments

* chnage slices import

* --db-roles to --db-role

* fix ut

* fix build and ut
  • Loading branch information
greedy52 authored Jan 11, 2024
1 parent c1fe64a commit 4827c34
Show file tree
Hide file tree
Showing 18 changed files with 2,291 additions and 1,665 deletions.
1,707 changes: 882 additions & 825 deletions api/client/proto/authservice.pb.go

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions api/proto/teleport/legacy/client/proto/authservice.proto
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ message RouteToDatabase {
string Username = 3 [(gogoproto.jsontag) = "username,omitempty"];
// Database is an optional database name to embed.
string Database = 4 [(gogoproto.jsontag) = "database,omitempty"];
// Roles is an optional list of database roles to embed.
repeated string Roles = 5 [(gogoproto.jsontag) = "roles,omitempty"];
}

// RouteToWindowsDesktop combines parameters for windows desktop routing information.
Expand Down
2 changes: 2 additions & 0 deletions api/proto/teleport/legacy/types/events/events.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3718,6 +3718,8 @@ message RouteToDatabase {
string Username = 3 [(gogoproto.jsontag) = "username,omitempty"];
// Database is an optional database name to embed.
string Database = 4 [(gogoproto.jsontag) = "database,omitempty"];
// Roles is an optional list of database roles to embed.
repeated string Roles = 5 [(gogoproto.jsontag) = "roles,omitempty"];
}

// AccessRequestResourceSearch is emitted when a user searches for resources as
Expand Down
1,573 changes: 811 additions & 762 deletions api/types/events/events.pb.go

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions api/utils/slices.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package utils

import (
"strings"

"golang.org/x/exp/slices"
)

// JoinStrings returns a string that is all the elements in the slice `T[]` joined by `sep`
Expand Down Expand Up @@ -81,3 +83,20 @@ func DeduplicateAny[T any](in []T, compare func(T, T) bool) []T {
}
return out
}

// ContainSameUniqueElements returns true if the input slices contain the same
// unique elements. Ordering and duplicates are ignored.
func ContainSameUniqueElements[S ~[]E, E comparable](s1, s2 S) bool {
s1Dedup := Deduplicate(s1)
s2Dedup := Deduplicate(s2)

if len(s1Dedup) != len(s2Dedup) {
return false
}
for i := range s1Dedup {
if !slices.Contains(s2Dedup, s1Dedup[i]) {
return false
}
}
return true
}
52 changes: 52 additions & 0 deletions api/utils/slices_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,55 @@ func TestDeduplicateAny(t *testing.T) {
})
}
}

func TestContainSameUniqueElements(t *testing.T) {
tests := []struct {
name string
s1 []string
s2 []string
check require.BoolAssertionFunc
}{
{
name: "empty",
s1: nil,
s2: []string{},
check: require.True,
},
{
name: "same",
s1: []string{"a", "b", "c"},
s2: []string{"a", "b", "c"},
check: require.True,
},
{
name: "same with different order",
s1: []string{"b", "c", "a"},
s2: []string{"a", "b", "c"},
check: require.True,
},
{
name: "same with duplicates",
s1: []string{"a", "a", "b", "c"},
s2: []string{"c", "c", "a", "b", "c", "c"},
check: require.True,
},
{
name: "different",
s1: []string{"a", "b"},
s2: []string{"a", "b", "c"},
check: require.False,
},
{
name: "different (same length)",
s1: []string{"d", "a", "b"},
s2: []string{"a", "b", "c"},
check: require.False,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
test.check(t, ContainSameUniqueElements(test.s1, test.s2))
})
}
}
36 changes: 23 additions & 13 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,9 @@ type certRequest struct {
// dbName is the optional database name which, if provided, will be used
// as a default database.
dbName string
// dbRoles is the optional list of database roles which, if provided, will
// be used instead of all database roles granted for the target database.
dbRoles []string
// mfaVerified is the UUID of an MFA device when this certRequest was
// created immediately after an MFA check.
mfaVerified string
Expand Down Expand Up @@ -2071,6 +2074,7 @@ func (a *Server) GenerateDatabaseTestCert(req DatabaseTestCertRequest) ([]byte,
dbProtocol: req.RouteToDatabase.Protocol,
dbUser: req.RouteToDatabase.Username,
dbName: req.RouteToDatabase.Database,
dbRoles: req.RouteToDatabase.Roles,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -2623,6 +2627,7 @@ func generateCert(a *Server, req certRequest, caType types.CertAuthType) (*proto
Protocol: req.dbProtocol,
Username: req.dbUser,
Database: req.dbName,
Roles: req.dbRoles,
},
DatabaseNames: dbNames,
DatabaseUsers: dbUsers,
Expand Down Expand Up @@ -5474,21 +5479,26 @@ func (a *Server) isMFARequired(ctx context.Context, checker services.AccessCheck
return nil, trace.Wrap(notFoundErr)
}

autoCreate, _, err := checker.CheckDatabaseRoles(db)
if err != nil {
autoCreate, err := checker.DatabaseAutoUserMode(db)
switch {
case errors.Is(err, services.ErrSessionMFARequired):
noMFAAccessErr = err
case err != nil:
return nil, trace.Wrap(err)
default:
dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{
Database: db,
DatabaseUser: t.Database.Username,
DatabaseName: t.Database.GetDatabase(),
AutoCreateUser: autoCreate.IsEnabled(),
})
noMFAAccessErr = checker.CheckAccess(
db,
services.AccessState{},
dbRoleMatchers...,
)
}
dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{
Database: db,
DatabaseUser: t.Database.Username,
DatabaseName: t.Database.GetDatabase(),
AutoCreateUser: autoCreate.IsEnabled(),
})
noMFAAccessErr = checker.CheckAccess(
db,
services.AccessState{},
dbRoleMatchers...,
)

case *proto.IsMFARequiredRequest_WindowsDesktop:
desktops, err := a.GetWindowsDesktops(ctx, types.WindowsDesktopFilter{Name: t.WindowsDesktop.GetWindowsDesktop()})
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -3182,6 +3182,7 @@ func (a *ServerWithRoles) generateUserCerts(ctx context.Context, req proto.UserC
dbProtocol: req.RouteToDatabase.Protocol,
dbUser: req.RouteToDatabase.Username,
dbName: req.RouteToDatabase.Database,
dbRoles: req.RouteToDatabase.Roles,
appName: req.RouteToApp.Name,
appSessionID: req.RouteToApp.SessionID,
appPublicAddr: req.RouteToApp.PublicAddr,
Expand Down
48 changes: 43 additions & 5 deletions lib/services/access_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,14 @@ type AccessChecker interface {
// is allowed to use.
CheckDatabaseNamesAndUsers(ttl time.Duration, overrideTTL bool) (names []string, users []string, err error)

// CheckDatabaseRoles returns whether a user should be auto-created in the
// database and a list of database roles to assign.
CheckDatabaseRoles(types.Database) (mode types.CreateDatabaseUserMode, roles []string, err error)
// DatabaseAutoUserMode returns whether a user should be auto-created in
// the database.
DatabaseAutoUserMode(types.Database) (types.CreateDatabaseUserMode, error)

// CheckDatabaseRoles returns a list of database roles to assign, when
// auto-user provisioning is enabled. If no user-requested roles, all
// allowed roles are returned.
CheckDatabaseRoles(database types.Database, userRequestedRoles []string) (roles []string, err error)

// CheckImpersonate checks whether current user is allowed to impersonate
// users and roles
Expand Down Expand Up @@ -515,9 +520,42 @@ func (a *accessChecker) Traits() wrappers.Traits {
return a.info.Traits
}

// DatabaseAutoUserMode returns whether a user should be auto-created in
// the database.
func (a *accessChecker) DatabaseAutoUserMode(database types.Database) (types.CreateDatabaseUserMode, error) {
mode, _, err := a.checkDatabaseRoles(database)
return mode, trace.Wrap(err)
}

// CheckDatabaseRoles returns whether a user should be auto-created in the
// database and a list of database roles to assign.
func (a *accessChecker) CheckDatabaseRoles(database types.Database) (mode types.CreateDatabaseUserMode, roles []string, err error) {
func (a *accessChecker) CheckDatabaseRoles(database types.Database, userRequestedRoles []string) ([]string, error) {
mode, allowedRoles, err := a.checkDatabaseRoles(database)
if err != nil {
return nil, trace.Wrap(err)
}

switch {
case !mode.IsEnabled():
return []string{}, nil

// If user requested a list of roles, make sure all requested roles are
// allowed.
case len(userRequestedRoles) > 0:
for _, requestedRole := range userRequestedRoles {
if !slices.Contains(allowedRoles, requestedRole) {
return nil, trace.AccessDenied("access to database role %q denied", requestedRole)
}
}
return userRequestedRoles, nil

// If user does not provide any roles, use all allowed roles from roleset.
default:
return allowedRoles, nil
}
}

func (a *accessChecker) checkDatabaseRoles(database types.Database) (types.CreateDatabaseUserMode, []string, error) {
// First, collect roles from this roleset that have create database user mode set.
var autoCreateRoles RoleSet
for _, role := range a.RoleSet {
Expand Down Expand Up @@ -570,7 +608,7 @@ func (a *accessChecker) EnumerateDatabaseUsers(database types.Database, extraUse
// When auto-user provisioning is enabled, only Teleport username is allowed.
if database.SupportsAutoUsers() && database.GetAdminUser().Name != "" {
result := NewEnumerationResult()
autoUser, _, err := a.CheckDatabaseRoles(database)
autoUser, err := a.DatabaseAutoUserMode(database)
if err != nil {
return result, trace.Wrap(err)
} else if autoUser.IsEnabled() {
Expand Down
61 changes: 56 additions & 5 deletions lib/services/role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4354,10 +4354,27 @@ func TestCheckDatabaseRoles(t *testing.T) {
},
}

// roleD has a bad label expression.
roleD := &types.RoleV6{
Metadata: types.Metadata{Name: "roleD", Namespace: apidefaults.Namespace},
Spec: types.RoleSpecV6{
Options: types.RoleOptions{
CreateDatabaseUser: types.NewBoolOption(true),
},
Allow: types.RoleConditions{
DatabaseLabelsExpression: `a bad expression`,
DatabaseRoles: []string{"reader"},
},
},
}

tests := []struct {
name string
roleSet RoleSet
inDatabaseLabels map[string]string
inRequestedRoles []string
outModeError bool
outRolesError bool
outCreateUser bool
outRoles []string
}{
Expand All @@ -4366,7 +4383,7 @@ func TestCheckDatabaseRoles(t *testing.T) {
roleSet: RoleSet{roleA},
inDatabaseLabels: map[string]string{"app": "metrics"},
outCreateUser: false,
outRoles: []string(nil),
outRoles: []string{},
},
{
name: "database doesn't match",
Expand Down Expand Up @@ -4396,6 +4413,29 @@ func TestCheckDatabaseRoles(t *testing.T) {
outCreateUser: true,
outRoles: []string{"reader"},
},
{
name: "connect to metrics database, requested writer role",
roleSet: RoleSet{roleA, roleB, roleC},
inDatabaseLabels: map[string]string{"app": "metrics"},
inRequestedRoles: []string{"writer"},
outCreateUser: true,
outRoles: []string{"writer"},
},
{
name: "requested role denied",
roleSet: RoleSet{roleA, roleB, roleC},
inDatabaseLabels: map[string]string{"app": "metrics", "env": "prod"},
inRequestedRoles: []string{"writer"},
outCreateUser: true,
outRolesError: true,
},
{
name: "check fails",
roleSet: RoleSet{roleD},
inDatabaseLabels: map[string]string{"app": "metrics"},
outModeError: true,
outRolesError: true,
},
}

for _, test := range tests {
Expand All @@ -4410,10 +4450,21 @@ func TestCheckDatabaseRoles(t *testing.T) {
})
require.NoError(t, err)

create, roles, err := accessChecker.CheckDatabaseRoles(database)
require.NoError(t, err)
require.Equal(t, test.outCreateUser, create.IsEnabled())
require.Equal(t, test.outRoles, roles)
create, err := accessChecker.DatabaseAutoUserMode(database)
if test.outModeError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.outCreateUser, create.IsEnabled())
}

roles, err := accessChecker.CheckDatabaseRoles(database, test.inRequestedRoles)
if test.outRolesError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.outRoles, roles)
}
})
}
}
Expand Down
12 changes: 10 additions & 2 deletions lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,15 @@ func (s *Server) handleConnection(ctx context.Context, clientConn net.Conn) erro
s.log.Debug("LoginIP is not set (Proxy Service has to be updated). Rate limiting is disabled.")
}

// Update database roles. It needs to be done here after engine is
// dispatched so the engine can propagate the error message to the client.
if sessionCtx.AutoCreateUserMode.IsEnabled() {
sessionCtx.DatabaseRoles, err = sessionCtx.Checker.CheckDatabaseRoles(sessionCtx.Database, sessionCtx.Identity.RouteToDatabase.Roles)
if err != nil {
return trace.Wrap(err)
}
}

err = engine.HandleConnection(ctx, sessionCtx)
if err != nil {
connectionDiagnosticID := sessionCtx.Identity.ConnectionDiagnosticID
Expand Down Expand Up @@ -1121,7 +1130,7 @@ func (s *Server) authorize(ctx context.Context) (*common.Session, error) {
return nil, trace.Wrap(err)
}

autoCreate, databaseRoles, err := authContext.Checker.CheckDatabaseRoles(database)
autoCreate, err := authContext.Checker.DatabaseAutoUserMode(database)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -1139,7 +1148,6 @@ func (s *Server) authorize(ctx context.Context) (*common.Session, error) {
AutoCreateUserMode: autoCreate,
DatabaseUser: identity.RouteToDatabase.Username,
DatabaseName: identity.RouteToDatabase.Database,
DatabaseRoles: databaseRoles,
AuthContext: authContext,
Checker: authContext.Checker,
StartupParameters: make(map[string]string),
Expand Down
Loading

0 comments on commit 4827c34

Please sign in to comment.