Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract functionality to detect if the CLI is running on DBR #1889

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions bundle/config/mutator/configure_wsfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ import (
"strings"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/filer"
"github.com/databricks/cli/libs/vfs"
)

const envDatabricksRuntimeVersion = "DATABRICKS_RUNTIME_VERSION"

type configureWSFS struct{}

func ConfigureWSFS() bundle.Mutator {
Expand All @@ -32,7 +30,7 @@ func (m *configureWSFS) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagno
}

// The executable must be running on DBR.
if _, ok := env.Lookup(ctx, envDatabricksRuntimeVersion); !ok {
if !dbr.RunsOnRuntime(ctx) {
return nil
}

Expand Down
65 changes: 65 additions & 0 deletions bundle/config/mutator/configure_wsfs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package mutator_test

import (
"context"
"runtime"
"testing"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config/mutator"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/vfs"
"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/stretchr/testify/assert"
)

func mockBundleForConfigureWSFS(t *testing.T, syncRootPath string) *bundle.Bundle {
// The native path of the sync root on Windows will never match the /Workspace prefix,
// so the test case for nominal behavior will always fail.
if runtime.GOOS == "windows" {
t.Skip("this test is not applicable on Windows")
}

b := &bundle.Bundle{
SyncRoot: vfs.MustNew(syncRootPath),
}

w := mocks.NewMockWorkspaceClient(t)
w.WorkspaceClient.Config = &config.Config{}
b.SetWorkpaceClient(w.WorkspaceClient)

return b
}

func TestConfigureWSFS_SkipsIfNotWorkspacePrefix(t *testing.T) {
b := mockBundleForConfigureWSFS(t, "/foo")
originalSyncRoot := b.SyncRoot

ctx := context.Background()
diags := bundle.Apply(ctx, b, mutator.ConfigureWSFS())
assert.Empty(t, diags)
assert.Equal(t, originalSyncRoot, b.SyncRoot)
}

func TestConfigureWSFS_SkipsIfNotRunningOnRuntime(t *testing.T) {
b := mockBundleForConfigureWSFS(t, "/Workspace/foo")
originalSyncRoot := b.SyncRoot

ctx := context.Background()
ctx = dbr.MockRuntime(ctx, false)
diags := bundle.Apply(ctx, b, mutator.ConfigureWSFS())
assert.Empty(t, diags)
assert.Equal(t, originalSyncRoot, b.SyncRoot)
}

func TestConfigureWSFS_SwapSyncRoot(t *testing.T) {
b := mockBundleForConfigureWSFS(t, "/Workspace/foo")
originalSyncRoot := b.SyncRoot

ctx := context.Background()
ctx = dbr.MockRuntime(ctx, true)
diags := bundle.Apply(ctx, b, mutator.ConfigureWSFS())
assert.Empty(t, diags)
assert.NotEqual(t, originalSyncRoot, b.SyncRoot)
}
4 changes: 4 additions & 0 deletions cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/databricks/cli/internal/build"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/log"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -73,6 +74,9 @@ func New(ctx context.Context) *cobra.Command {
// get the context back
ctx = cmd.Context()

// Detect if the CLI is running on DBR and store this on the context.
ctx = dbr.DetectRuntime(ctx)

// Configure our user agent with the command that's about to be executed.
ctx = withCommandInUserAgent(ctx, cmd)
ctx = withCommandExecIdInUserAgent(ctx)
Expand Down
49 changes: 49 additions & 0 deletions libs/dbr/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package dbr

import "context"

// key is a package-local type to use for context keys.
//
// Using an unexported type for context keys prevents key collisions across
// packages since external packages cannot create values of this type.
type key int

const (
// dbrKey is the context key for the detection result.
// The value of 1 is arbitrary and can be any number.
// Other keys in the same package must have different values.
dbrKey = key(1)
)

// DetectRuntime detects whether or not the current
// process is running inside a Databricks Runtime environment.
// It return a new context with the detection result set.
func DetectRuntime(ctx context.Context) context.Context {
if v := ctx.Value(dbrKey); v != nil {
panic("dbr.DetectRuntime called twice on the same context")
}
return context.WithValue(ctx, dbrKey, detect(ctx))
}

// MockRuntime is a helper function to mock the detection result.
// It returns a new context with the detection result set.
func MockRuntime(ctx context.Context, b bool) context.Context {
if v := ctx.Value(dbrKey); v != nil {
panic("dbr.MockRuntime called twice on the same context")
}
return context.WithValue(ctx, dbrKey, b)
}

// RunsOnRuntime returns the detection result from the context.
// It expects a context returned by [DetectRuntime] or [MockRuntime].
//
// We store this value in a context to avoid having to use either
// a global variable, passing a boolean around everywhere, or
// performing the same detection multiple times.
func RunsOnRuntime(ctx context.Context) bool {
pietern marked this conversation as resolved.
Show resolved Hide resolved
v := ctx.Value(dbrKey)
if v == nil {
panic("dbr.RunsOnRuntime called without calling dbr.DetectRuntime first")
}
return v.(bool)
}
59 changes: 59 additions & 0 deletions libs/dbr/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package dbr

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func TestContext_DetectRuntimePanics(t *testing.T) {
ctx := context.Background()

// Run detection.
ctx = DetectRuntime(ctx)

// Expect a panic if the detection is run twice.
assert.Panics(t, func() {
ctx = DetectRuntime(ctx)
})
}

func TestContext_MockRuntimePanics(t *testing.T) {
ctx := context.Background()

// Run detection.
ctx = MockRuntime(ctx, true)

// Expect a panic if the mock function is run twice.
assert.Panics(t, func() {
MockRuntime(ctx, true)
})
}

func TestContext_RunsOnRuntimePanics(t *testing.T) {
ctx := context.Background()

// Expect a panic if the detection is not run.
assert.Panics(t, func() {
RunsOnRuntime(ctx)
})
}

func TestContext_RunsOnRuntime(t *testing.T) {
ctx := context.Background()

// Run detection.
ctx = DetectRuntime(ctx)

// Expect no panic because detection has run.
assert.NotPanics(t, func() {
RunsOnRuntime(ctx)
})
}

func TestContext_RunsOnRuntimeWithMock(t *testing.T) {
ctx := context.Background()
assert.True(t, RunsOnRuntime(MockRuntime(ctx, true)))
assert.False(t, RunsOnRuntime(MockRuntime(ctx, false)))
}
35 changes: 35 additions & 0 deletions libs/dbr/detect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package dbr

import (
"context"
"os"
"runtime"

"github.com/databricks/cli/libs/env"
)

// Dereference [os.Stat] to allow mocking in tests.
var statFunc = os.Stat

// detect returns true if the current process is running on a Databricks Runtime.
// Its return value is meant to be cached in the context.
func detect(ctx context.Context) bool {
// Databricks Runtime implies Linux.
// Return early on other operating systems.
if runtime.GOOS != "linux" {
return false
}

// Databricks Runtime always has the DATABRICKS_RUNTIME_VERSION environment variable set.
if value, ok := env.Lookup(ctx, "DATABRICKS_RUNTIME_VERSION"); !ok || value == "" {
return false
}

// Expect to see a "/databricks" directory.
if fi, err := statFunc("/databricks"); err != nil || !fi.IsDir() {
return false
}

// All checks passed.
return true
}
83 changes: 83 additions & 0 deletions libs/dbr/detect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package dbr

import (
"context"
"io/fs"
"runtime"
"testing"

"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/fakefs"
"github.com/stretchr/testify/assert"
)

func requireLinux(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("skipping test on %s", runtime.GOOS)
}
}

func configureStatFunc(t *testing.T, fi fs.FileInfo, err error) {
originalFunc := statFunc
statFunc = func(name string) (fs.FileInfo, error) {
assert.Equal(t, "/databricks", name)
return fi, err
}

t.Cleanup(func() {
statFunc = originalFunc
})
}

func TestDetect_NotLinux(t *testing.T) {
if runtime.GOOS == "linux" {
t.Skip("skipping test on Linux OS")
}

ctx := context.Background()
assert.False(t, detect(ctx))
}

func TestDetect_Env(t *testing.T) {
requireLinux(t)

// Configure other checks to pass.
configureStatFunc(t, fakefs.FileInfo{FakeDir: true}, nil)

t.Run("empty", func(t *testing.T) {
ctx := env.Set(context.Background(), "DATABRICKS_RUNTIME_VERSION", "")
assert.False(t, detect(ctx))
})

t.Run("non-empty cluster", func(t *testing.T) {
ctx := env.Set(context.Background(), "DATABRICKS_RUNTIME_VERSION", "15.4")
assert.True(t, detect(ctx))
})

t.Run("non-empty serverless", func(t *testing.T) {
ctx := env.Set(context.Background(), "DATABRICKS_RUNTIME_VERSION", "client.1.13")
assert.True(t, detect(ctx))
})
}

func TestDetect_Stat(t *testing.T) {
requireLinux(t)

// Configure other checks to pass.
ctx := env.Set(context.Background(), "DATABRICKS_RUNTIME_VERSION", "non-empty")

t.Run("error", func(t *testing.T) {
configureStatFunc(t, nil, fs.ErrNotExist)
assert.False(t, detect(ctx))
})

t.Run("not a directory", func(t *testing.T) {
configureStatFunc(t, fakefs.FileInfo{}, nil)
assert.False(t, detect(ctx))
})

t.Run("directory", func(t *testing.T) {
configureStatFunc(t, fakefs.FileInfo{FakeDir: true}, nil)
assert.True(t, detect(ctx))
})
}
55 changes: 55 additions & 0 deletions libs/fakefs/fakefs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package fakefs

import (
"io/fs"
"time"
)

// DirEntry is a fake implementation of [fs.DirEntry].
type DirEntry struct {
FileInfo
}

func (entry DirEntry) Type() fs.FileMode {
typ := fs.ModePerm
if entry.FakeDir {
typ |= fs.ModeDir
}
return typ
}

func (entry DirEntry) Info() (fs.FileInfo, error) {
return entry.FileInfo, nil
}

// FileInfo is a fake implementation of [fs.FileInfo].
type FileInfo struct {
FakeName string
FakeSize int64
FakeDir bool
FakeMode fs.FileMode
}

func (info FileInfo) Name() string {
return info.FakeName
}

func (info FileInfo) Size() int64 {
return info.FakeSize
}

func (info FileInfo) Mode() fs.FileMode {
return info.FakeMode
}

func (info FileInfo) ModTime() time.Time {
return time.Now()
}

func (info FileInfo) IsDir() bool {
return info.FakeDir
}

func (info FileInfo) Sys() any {
return nil
}
Loading