Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v14] Added support to select database roles from tsh #36528

Merged
merged 3 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,707 changes: 882 additions & 825 deletions api/client/proto/authservice.pb.go

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions api/proto/teleport/legacy/client/proto/authservice.proto
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ message RouteToDatabase {
string Username = 3 [(gogoproto.jsontag) = "username,omitempty"];
// Database is an optional database name to embed.
string Database = 4 [(gogoproto.jsontag) = "database,omitempty"];
// Roles is an optional list of database roles to embed.
repeated string Roles = 5 [(gogoproto.jsontag) = "roles,omitempty"];
}

// RouteToWindowsDesktop combines parameters for windows desktop routing information.
Expand Down
2 changes: 2 additions & 0 deletions api/proto/teleport/legacy/types/events/events.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3696,6 +3696,8 @@ message RouteToDatabase {
string Username = 3 [(gogoproto.jsontag) = "username,omitempty"];
// Database is an optional database name to embed.
string Database = 4 [(gogoproto.jsontag) = "database,omitempty"];
// Roles is an optional list of database roles to embed.
repeated string Roles = 5 [(gogoproto.jsontag) = "roles,omitempty"];
}

// AccessRequestResourceSearch is emitted when a user searches for resources as
Expand Down
682 changes: 366 additions & 316 deletions api/types/events/events.pb.go

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions api/utils/slices.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package utils

import (
"slices"
"strings"
)

Expand Down Expand Up @@ -81,3 +82,20 @@ func DeduplicateAny[T any](in []T, compare func(T, T) bool) []T {
}
return out
}

// ContainSameUniqueElements returns true if the input slices contain the same
// unique elements. Ordering and duplicates are ignored.
func ContainSameUniqueElements[S ~[]E, E comparable](s1, s2 S) bool {
s1Dedup := Deduplicate(s1)
s2Dedup := Deduplicate(s2)

if len(s1Dedup) != len(s2Dedup) {
return false
}
for i := range s1Dedup {
if !slices.Contains(s2Dedup, s1Dedup[i]) {
return false
}
}
return true
}
52 changes: 52 additions & 0 deletions api/utils/slices_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,55 @@ func TestDeduplicateAny(t *testing.T) {
})
}
}

func TestContainSameUniqueElements(t *testing.T) {
tests := []struct {
name string
s1 []string
s2 []string
check require.BoolAssertionFunc
}{
{
name: "empty",
s1: nil,
s2: []string{},
check: require.True,
},
{
name: "same",
s1: []string{"a", "b", "c"},
s2: []string{"a", "b", "c"},
check: require.True,
},
{
name: "same with different order",
s1: []string{"b", "c", "a"},
s2: []string{"a", "b", "c"},
check: require.True,
},
{
name: "same with duplicates",
s1: []string{"a", "a", "b", "c"},
s2: []string{"c", "c", "a", "b", "c", "c"},
check: require.True,
},
{
name: "different",
s1: []string{"a", "b"},
s2: []string{"a", "b", "c"},
check: require.False,
},
{
name: "different (same length)",
s1: []string{"d", "a", "b"},
s2: []string{"a", "b", "c"},
check: require.False,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
test.check(t, ContainSameUniqueElements(test.s1, test.s2))
})
}
}
36 changes: 23 additions & 13 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,9 @@ type certRequest struct {
// dbName is the optional database name which, if provided, will be used
// as a default database.
dbName string
// dbRoles is the optional list of database roles which, if provided, will
// be used instead of all database roles granted for the target database.
dbRoles []string
// mfaVerified is the UUID of an MFA device when this certRequest was
// created immediately after an MFA check.
mfaVerified string
Expand Down Expand Up @@ -2069,6 +2072,7 @@ func (a *Server) GenerateDatabaseTestCert(req DatabaseTestCertRequest) ([]byte,
dbProtocol: req.RouteToDatabase.Protocol,
dbUser: req.RouteToDatabase.Username,
dbName: req.RouteToDatabase.Database,
dbRoles: req.RouteToDatabase.Roles,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -2620,6 +2624,7 @@ func generateCert(a *Server, req certRequest, caType types.CertAuthType) (*proto
Protocol: req.dbProtocol,
Username: req.dbUser,
Database: req.dbName,
Roles: req.dbRoles,
},
DatabaseNames: dbNames,
DatabaseUsers: dbUsers,
Expand Down Expand Up @@ -5470,21 +5475,26 @@ func (a *Server) isMFARequired(ctx context.Context, checker services.AccessCheck
return nil, trace.Wrap(notFoundErr)
}

autoCreate, _, err := checker.CheckDatabaseRoles(db)
if err != nil {
autoCreate, err := checker.DatabaseAutoUserMode(db)
switch {
case errors.Is(err, services.ErrSessionMFARequired):
noMFAAccessErr = err
case err != nil:
return nil, trace.Wrap(err)
default:
dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{
Database: db,
DatabaseUser: t.Database.Username,
DatabaseName: t.Database.GetDatabase(),
AutoCreateUser: autoCreate.IsEnabled(),
})
noMFAAccessErr = checker.CheckAccess(
db,
services.AccessState{},
dbRoleMatchers...,
)
}
dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{
Database: db,
DatabaseUser: t.Database.Username,
DatabaseName: t.Database.GetDatabase(),
AutoCreateUser: autoCreate.IsEnabled(),
})
noMFAAccessErr = checker.CheckAccess(
db,
services.AccessState{},
dbRoleMatchers...,
)

case *proto.IsMFARequiredRequest_WindowsDesktop:
desktops, err := a.GetWindowsDesktops(ctx, types.WindowsDesktopFilter{Name: t.WindowsDesktop.GetWindowsDesktop()})
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -3171,6 +3171,7 @@ func (a *ServerWithRoles) generateUserCerts(ctx context.Context, req proto.UserC
dbProtocol: req.RouteToDatabase.Protocol,
dbUser: req.RouteToDatabase.Username,
dbName: req.RouteToDatabase.Database,
dbRoles: req.RouteToDatabase.Roles,
appName: req.RouteToApp.Name,
appSessionID: req.RouteToApp.SessionID,
appPublicAddr: req.RouteToApp.PublicAddr,
Expand Down
48 changes: 43 additions & 5 deletions lib/services/access_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,14 @@ type AccessChecker interface {
// is allowed to use.
CheckDatabaseNamesAndUsers(ttl time.Duration, overrideTTL bool) (names []string, users []string, err error)

// CheckDatabaseRoles returns whether a user should be auto-created in the
// database and a list of database roles to assign.
CheckDatabaseRoles(types.Database) (mode types.CreateDatabaseUserMode, roles []string, err error)
// DatabaseAutoUserMode returns whether a user should be auto-created in
// the database.
DatabaseAutoUserMode(types.Database) (types.CreateDatabaseUserMode, error)

// CheckDatabaseRoles returns a list of database roles to assign, when
// auto-user provisioning is enabled. If no user-requested roles, all
// allowed roles are returned.
CheckDatabaseRoles(database types.Database, userRequestedRoles []string) (roles []string, err error)

// CheckImpersonate checks whether current user is allowed to impersonate
// users and roles
Expand Down Expand Up @@ -515,9 +520,42 @@ func (a *accessChecker) Traits() wrappers.Traits {
return a.info.Traits
}

// DatabaseAutoUserMode returns whether a user should be auto-created in
// the database.
func (a *accessChecker) DatabaseAutoUserMode(database types.Database) (types.CreateDatabaseUserMode, error) {
mode, _, err := a.checkDatabaseRoles(database)
return mode, trace.Wrap(err)
}

// CheckDatabaseRoles returns whether a user should be auto-created in the
// database and a list of database roles to assign.
func (a *accessChecker) CheckDatabaseRoles(database types.Database) (mode types.CreateDatabaseUserMode, roles []string, err error) {
func (a *accessChecker) CheckDatabaseRoles(database types.Database, userRequestedRoles []string) ([]string, error) {
mode, allowedRoles, err := a.checkDatabaseRoles(database)
if err != nil {
return nil, trace.Wrap(err)
}

switch {
case !mode.IsEnabled():
return []string{}, nil

// If user requested a list of roles, make sure all requested roles are
// allowed.
case len(userRequestedRoles) > 0:
for _, requestedRole := range userRequestedRoles {
if !slices.Contains(allowedRoles, requestedRole) {
return nil, trace.AccessDenied("access to database role %q denied", requestedRole)
}
}
return userRequestedRoles, nil

// If user does not provide any roles, use all allowed roles from roleset.
default:
return allowedRoles, nil
}
}

func (a *accessChecker) checkDatabaseRoles(database types.Database) (types.CreateDatabaseUserMode, []string, error) {
// First, collect roles from this roleset that have create database user mode set.
var autoCreateRoles RoleSet
for _, role := range a.RoleSet {
Expand Down Expand Up @@ -570,7 +608,7 @@ func (a *accessChecker) EnumerateDatabaseUsers(database types.Database, extraUse
// When auto-user provisioning is enabled, only Teleport username is allowed.
if database.SupportsAutoUsers() && database.GetAdminUser().Name != "" {
result := NewEnumerationResult()
autoUser, _, err := a.CheckDatabaseRoles(database)
autoUser, err := a.DatabaseAutoUserMode(database)
if err != nil {
return result, trace.Wrap(err)
} else if autoUser.IsEnabled() {
Expand Down
61 changes: 56 additions & 5 deletions lib/services/role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4354,10 +4354,27 @@ func TestCheckDatabaseRoles(t *testing.T) {
},
}

// roleD has a bad label expression.
roleD := &types.RoleV6{
Metadata: types.Metadata{Name: "roleD", Namespace: apidefaults.Namespace},
Spec: types.RoleSpecV6{
Options: types.RoleOptions{
CreateDatabaseUser: types.NewBoolOption(true),
},
Allow: types.RoleConditions{
DatabaseLabelsExpression: `a bad expression`,
DatabaseRoles: []string{"reader"},
},
},
}

tests := []struct {
name string
roleSet RoleSet
inDatabaseLabels map[string]string
inRequestedRoles []string
outModeError bool
outRolesError bool
outCreateUser bool
outRoles []string
}{
Expand All @@ -4366,7 +4383,7 @@ func TestCheckDatabaseRoles(t *testing.T) {
roleSet: RoleSet{roleA},
inDatabaseLabels: map[string]string{"app": "metrics"},
outCreateUser: false,
outRoles: []string(nil),
outRoles: []string{},
},
{
name: "database doesn't match",
Expand Down Expand Up @@ -4396,6 +4413,29 @@ func TestCheckDatabaseRoles(t *testing.T) {
outCreateUser: true,
outRoles: []string{"reader"},
},
{
name: "connect to metrics database, requested writer role",
roleSet: RoleSet{roleA, roleB, roleC},
inDatabaseLabels: map[string]string{"app": "metrics"},
inRequestedRoles: []string{"writer"},
outCreateUser: true,
outRoles: []string{"writer"},
},
{
name: "requested role denied",
roleSet: RoleSet{roleA, roleB, roleC},
inDatabaseLabels: map[string]string{"app": "metrics", "env": "prod"},
inRequestedRoles: []string{"writer"},
outCreateUser: true,
outRolesError: true,
},
{
name: "check fails",
roleSet: RoleSet{roleD},
inDatabaseLabels: map[string]string{"app": "metrics"},
outModeError: true,
outRolesError: true,
},
}

for _, test := range tests {
Expand All @@ -4410,10 +4450,21 @@ func TestCheckDatabaseRoles(t *testing.T) {
})
require.NoError(t, err)

create, roles, err := accessChecker.CheckDatabaseRoles(database)
require.NoError(t, err)
require.Equal(t, test.outCreateUser, create.IsEnabled())
require.Equal(t, test.outRoles, roles)
create, err := accessChecker.DatabaseAutoUserMode(database)
if test.outModeError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.outCreateUser, create.IsEnabled())
}

roles, err := accessChecker.CheckDatabaseRoles(database, test.inRequestedRoles)
if test.outRolesError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.outRoles, roles)
}
})
}
}
Expand Down
12 changes: 10 additions & 2 deletions lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,15 @@ func (s *Server) handleConnection(ctx context.Context, clientConn net.Conn) erro
s.log.Debug("LoginIP is not set (Proxy Service has to be updated). Rate limiting is disabled.")
}

// Update database roles. It needs to be done here after engine is
// dispatched so the engine can propagate the error message to the client.
if sessionCtx.AutoCreateUserMode.IsEnabled() {
sessionCtx.DatabaseRoles, err = sessionCtx.Checker.CheckDatabaseRoles(sessionCtx.Database, sessionCtx.Identity.RouteToDatabase.Roles)
if err != nil {
return trace.Wrap(err)
}
}

err = engine.HandleConnection(ctx, sessionCtx)
if err != nil {
connectionDiagnosticID := sessionCtx.Identity.ConnectionDiagnosticID
Expand Down Expand Up @@ -1121,7 +1130,7 @@ func (s *Server) authorize(ctx context.Context) (*common.Session, error) {
return nil, trace.Wrap(err)
}

autoCreate, databaseRoles, err := authContext.Checker.CheckDatabaseRoles(database)
autoCreate, err := authContext.Checker.DatabaseAutoUserMode(database)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -1139,7 +1148,6 @@ func (s *Server) authorize(ctx context.Context) (*common.Session, error) {
AutoCreateUserMode: autoCreate,
DatabaseUser: identity.RouteToDatabase.Username,
DatabaseName: identity.RouteToDatabase.Database,
DatabaseRoles: databaseRoles,
AuthContext: authContext,
Checker: authContext.Checker,
StartupParameters: make(map[string]string),
Expand Down
Loading
Loading