Skip to content

Commit

Permalink
[v15] fix: Enforce device trust on OSS processes (#46946)
Browse files Browse the repository at this point in the history
* fix: Enforce device trust on OSS processes

* Fix TestSessionController_AcquireSessionContext
  • Loading branch information
codingllama committed Sep 27, 2024
1 parent 92d8b5d commit 9a298ef
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 31 deletions.
20 changes: 2 additions & 18 deletions lib/devicetrust/authz/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@
package authz

import (
"sync"

"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/devicetrust/config"
dtconfig "github.com/gravitational/teleport/lib/devicetrust/config"
"github.com/gravitational/teleport/lib/tlsca"
)

Expand Down Expand Up @@ -73,9 +71,7 @@ func VerifySSHUser(dt *types.DeviceTrust, cert *ssh.Certificate) error {
}

func verifyDeviceExtensions(dt *types.DeviceTrust, username string, verified bool) error {
mode := config.GetEffectiveMode(dt)
maybeLogModeMismatch(mode, dt)

mode := dtconfig.GetEnforcementMode(dt)
switch {
case mode != constants.DeviceTrustModeRequired:
return nil // OK, extensions not enforced.
Expand All @@ -88,15 +84,3 @@ func verifyDeviceExtensions(dt *types.DeviceTrust, username string, verified boo
return nil
}
}

var logModeOnce sync.Once

func maybeLogModeMismatch(effective string, dt *types.DeviceTrust) {
if dt == nil || dt.Mode == "" || effective == dt.Mode {
return
}

logModeOnce.Do(func() {
log.Warnf("Device Trust: mode %q requires Teleport Enterprise. Using effective mode %q.", dt.Mode, effective)
})
}
6 changes: 3 additions & 3 deletions lib/devicetrust/authz/authz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ func runVerifyUserTest(t *testing.T, method string, verify func(dt *types.Device
assertErr: assertNoErr,
},
{
name: "OSS mode never enforced",
name: "OSS mode=required (Enterprise Auth)",
buildType: modules.BuildOSS,
dt: &types.DeviceTrust{
Mode: constants.DeviceTrustModeRequired, // Invalid for OSS, treated as "off".
Mode: constants.DeviceTrustModeRequired,
},
ext: userWithoutExtensions,
assertErr: assertNoErr,
assertErr: assertDeniedErr,
},
{
name: "Enterprise mode=off",
Expand Down
12 changes: 12 additions & 0 deletions lib/devicetrust/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ func GetEffectiveMode(dt *types.DeviceTrust) string {
return dt.Mode
}

// GetEnforcementMode returns the configured device trust mode, disregarding the
// provenance of the binary if the mode is set.
// Used for device enforcement checks. Guarantees that OSS binaries paired with
// an Enterprise Auth will correctly enforce device trust.
func GetEnforcementMode(dt *types.DeviceTrust) string {
// If absent use the defaults from GetEffectiveMode.
if dt == nil || dt.Mode == "" {
return GetEffectiveMode(dt)
}
return dt.Mode
}

// ValidateConfigAgainstModules verifies the device trust configuration against
// the current modules.
// This method exists to provide feedback to users about invalid configurations,
Expand Down
58 changes: 58 additions & 0 deletions lib/devicetrust/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import (
)

func TestValidateConfigAgainstModules(t *testing.T) {
// Don't t.Parallel, depends on modules.SetTestModules.

type testCase struct {
name string
buildType string
Expand Down Expand Up @@ -110,3 +112,59 @@ func TestValidateConfigAgainstModules(t *testing.T) {
})
}
}

func TestGetEnforcementMode(t *testing.T) {
// Don't t.Parallel, depends on modules.SetTestModules.

tests := []struct {
name string
buildType string
dt *types.DeviceTrust
want string
}{
{
name: "OSS default",
buildType: modules.BuildOSS,
want: constants.DeviceTrustModeOff,
},
{
name: "Enterprise default",
buildType: modules.BuildEnterprise,
want: constants.DeviceTrustModeOptional,
},
{
name: "dt.Mode empty",
buildType: modules.BuildEnterprise,
dt: &types.DeviceTrust{
Mode: "",
},
want: constants.DeviceTrustModeOptional,
},
{
name: "dt.Mode set",
buildType: modules.BuildEnterprise,
dt: &types.DeviceTrust{
Mode: constants.DeviceTrustModeRequired,
},
want: constants.DeviceTrustModeRequired,
},
{
name: "OSS node with Ent Auth",
buildType: modules.BuildOSS,
dt: &types.DeviceTrust{
Mode: constants.DeviceTrustModeRequired,
},
want: constants.DeviceTrustModeRequired,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
modules.SetTestModules(t, &modules.TestModules{
TestBuildType: test.buildType,
})

got := dtconfig.GetEnforcementMode(test.dt)
assert.Equal(t, test.want, got, "dtconfig.GetEnforcementMode mismatch")
})
}
}
19 changes: 9 additions & 10 deletions lib/srv/session_control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ func TestSessionController_AcquireSessionContext(t *testing.T) {
}
return idCtx
}
assertTrustedDeviceRequired := func(t *testing.T, _ context.Context, err error, _ *eventstest.MockRecorderEmitter) {
assert.ErrorContains(t, err, "device", "AcquireSessionContext returned an unexpected error")
assert.True(t, trace.IsAccessDenied(err), "AcquireSessionContext returned an error other than trace.AccessDeniedError: %T", err)
}

cases := []struct {
name string
Expand Down Expand Up @@ -451,22 +455,17 @@ func TestSessionController_AcquireSessionContext(t *testing.T) {
},
},
{
name: "device extensions not enforced for OSS",
cfg: cfgWithDeviceMode(constants.DeviceTrustModeRequired),
identity: minimalIdentity,
assertion: func(t *testing.T, _ context.Context, err error, _ *eventstest.MockRecorderEmitter) {
assert.NoError(t, err, "AcquireSessionContext returned an unexpected error")
},
name: "device extensions enforced for OSS",
cfg: cfgWithDeviceMode(constants.DeviceTrustModeRequired),
identity: minimalIdentity,
assertion: assertTrustedDeviceRequired,
},
{
name: "device extensions enforced for Enterprise",
buildType: modules.BuildEnterprise,
cfg: cfgWithDeviceMode(constants.DeviceTrustModeRequired),
identity: minimalIdentity,
assertion: func(t *testing.T, _ context.Context, err error, _ *eventstest.MockRecorderEmitter) {
assert.ErrorContains(t, err, "device", "AcquireSessionContext returned an unexpected error")
assert.True(t, trace.IsAccessDenied(err), "AcquireSessionContext returned an error other than trace.AccessDeniedError: %T", err)
},
assertion: assertTrustedDeviceRequired,
},
{
name: "device extensions valid for Enterprise",
Expand Down

0 comments on commit 9a298ef

Please sign in to comment.