Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added --restart flag for bundle run command #1191

Merged
merged 3 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions bundle/run/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/databricks/cli/libs/log"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/fatih/color"
"golang.org/x/sync/errgroup"
)

// Default timeout for waiting for a job run to complete.
Expand Down Expand Up @@ -275,3 +276,42 @@ func (r *jobRunner) convertPythonParams(opts *Options) error {

return nil
}

func (r *jobRunner) Cancel(ctx context.Context) error {
w := r.bundle.WorkspaceClient()
jobID, err := strconv.ParseInt(r.job.ID, 10, 64)
if err != nil {
return fmt.Errorf("job ID is not an integer: %s", r.job.ID)
}

runs, err := w.Jobs.ListRunsAll(ctx, jobs.ListRunsRequest{
ActiveOnly: true,
JobId: jobID,
})

if err != nil {
return err
}

if len(runs) == 0 {
return nil
}

errGroup, errCtx := errgroup.WithContext(ctx)
for _, run := range runs {
runId := run.RunId
errGroup.Go(func() error {
wait, err := w.Jobs.CancelRun(errCtx, jobs.CancelRun{
RunId: runId,
})
if err != nil {
return err
}
// Waits for the Terminated or Skipped state
_, err = wait.GetWithTimeout(jobRunTimeout)
andrewnester marked this conversation as resolved.
Show resolved Hide resolved
return err
})
}

return errGroup.Wait()
}
79 changes: 79 additions & 0 deletions bundle/run/job_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package run

import (
"context"
"testing"
"time"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -47,3 +51,78 @@ func TestConvertPythonParams(t *testing.T) {
require.Contains(t, opts.Job.notebookParams, "__python_params")
require.Equal(t, opts.Job.notebookParams["__python_params"], `["param1","param2","param3"]`)
}

func TestJobRunnerCancel(t *testing.T) {
job := &resources.Job{
ID: "123",
}
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"test_job": job,
},
},
},
}

runner := jobRunner{key: "test", bundle: b, job: job}

m := mocks.NewMockWorkspaceClient(t)
b.SetWorkpaceClient(m.WorkspaceClient)

jobApi := m.GetMockJobsAPI()
jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{
ActiveOnly: true,
JobId: 123,
}).Return([]jobs.BaseRun{
{RunId: 1},
{RunId: 2},
}, nil)

mockWait := &jobs.WaitGetRunJobTerminatedOrSkipped[struct{}]{
Poll: func(time time.Duration, f func(j *jobs.Run)) (*jobs.Run, error) {
return nil, nil
},
}
jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{
RunId: 1,
}).Return(mockWait, nil)
jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{
RunId: 2,
}).Return(mockWait, nil)

err := runner.Cancel(context.Background())
require.NoError(t, err)
}

func TestJobRunnerCancelWithNoActiveRuns(t *testing.T) {
job := &resources.Job{
ID: "123",
}
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"test_job": job,
},
},
},
}

runner := jobRunner{key: "test", bundle: b, job: job}

m := mocks.NewMockWorkspaceClient(t)
b.SetWorkpaceClient(m.WorkspaceClient)

jobApi := m.GetMockJobsAPI()
jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{
ActiveOnly: true,
JobId: 123,
}).Return([]jobs.BaseRun{}, nil)

jobApi.AssertNotCalled(t, "CancelRun")

err := runner.Cancel(context.Background())
require.NoError(t, err)
}
15 changes: 15 additions & 0 deletions bundle/run/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,18 @@ func (r *pipelineRunner) Run(ctx context.Context, opts *Options) (output.RunOutp
time.Sleep(time.Second)
}
}

func (r *pipelineRunner) Cancel(ctx context.Context) error {
w := r.bundle.WorkspaceClient()
wait, err := w.Pipelines.Stop(ctx, pipelines.StopRequest{
PipelineId: r.pipeline.ID,
})

if err != nil {
return err
}

// Waits for the Idle state of the pipeline
_, err = wait.GetWithTimeout(jobRunTimeout)
andrewnester marked this conversation as resolved.
Show resolved Hide resolved
return err
}
49 changes: 49 additions & 0 deletions bundle/run/pipeline_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package run

import (
"context"
"testing"
"time"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/pipelines"
"github.com/stretchr/testify/require"
)

func TestPipelineRunnerCancel(t *testing.T) {
pipeline := &resources.Pipeline{
ID: "123",
}

b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Pipelines: map[string]*resources.Pipeline{
"test_pipeline": pipeline,
},
},
},
}

runner := pipelineRunner{key: "test", bundle: b, pipeline: pipeline}

m := mocks.NewMockWorkspaceClient(t)
b.SetWorkpaceClient(m.WorkspaceClient)

mockWait := &pipelines.WaitGetPipelineIdle[struct{}]{
Poll: func(time.Duration, func(*pipelines.GetPipelineResponse)) (*pipelines.GetPipelineResponse, error) {
return nil, nil
},
}

pipelineApi := m.GetMockPipelinesAPI()
pipelineApi.EXPECT().Stop(context.Background(), pipelines.StopRequest{
PipelineId: "123",
}).Return(mockWait, nil)

err := runner.Cancel(context.Background())
require.NoError(t, err)
}
3 changes: 3 additions & 0 deletions bundle/run/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ type Runner interface {

// Run the underlying worklow.
Run(ctx context.Context, opts *Options) (output.RunOutput, error)

// Cancel the underlying workflow.
Cancel(ctx context.Context) error
}

// Find locates a runner matching the specified argument.
Expand Down
11 changes: 11 additions & 0 deletions cmd/bundle/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ func newRunCommand() *cobra.Command {
runOptions.Define(cmd)

var noWait bool
var restart bool
cmd.Flags().BoolVar(&noWait, "no-wait", false, "Don't wait for the run to complete.")
cmd.Flags().BoolVar(&restart, "restart", false, "Restart the run if it is already running.")

cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
Expand Down Expand Up @@ -68,6 +70,15 @@ func newRunCommand() *cobra.Command {
}

runOptions.NoWait = noWait
if restart {
s := cmdio.Spinner(ctx)
s <- "Cancelling all runs"
err := runner.Cancel(ctx)
close(s)
if err != nil {
return err
}
}
output, err := runner.Run(ctx, &runOptions)
if err != nil {
return err
Expand Down
Loading