Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Task context aware #19

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion golang/pkg/rnr/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func taskDiff(path []string, old *pb.Task, new *pb.Task) []string {
}

func (j *Job) Poll(ctx context.Context) {
j.root.Poll()
j.root.Poll(ctx)

newProto := j.root.Proto(nil)
// Calculate diff and post state changes
Expand Down
6 changes: 3 additions & 3 deletions golang/pkg/rnr/job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ import (
var _ rnr.Task = &task{}

type task struct {
pollFn func()
pollFn func(context.Context)
}

func (t *task) Poll() { t.pollFn() }
func (t *task) Poll(ctx context.Context) { t.pollFn(ctx) }
func (t *task) SetState(pb.TaskState) {}
func (t *task) Proto(func(*pb.Task)) *pb.Task { return nil }
func (*task) GetChild(string) rnr.Task { return nil }
Expand All @@ -25,7 +25,7 @@ func TestJob(t *testing.T) {
var pollCount int

task := &task{
pollFn: func() {
pollFn: func(context.Context) {
pollCount++
t.Logf("poll %02d", pollCount)
},
Expand Down
6 changes: 4 additions & 2 deletions golang/pkg/rnr/task_nested.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rnr

import (
"context"
"fmt"
"sync"

Expand All @@ -9,6 +10,7 @@ import (
)

// Nested Task
var _ Task = &NestedTask{}

type NestedTaskCallback func(*NestedTask, *[]Task)

Expand Down Expand Up @@ -55,7 +57,7 @@ func (nt *NestedTask) Add(task Task) error {
return nil
}

func (nt *NestedTask) Poll() {
func (nt *NestedTask) Poll(ctx context.Context) {

if taskSchedState(nt.Proto(nil)) != RUNNING {
return
Expand Down Expand Up @@ -99,7 +101,7 @@ func (nt *NestedTask) Poll() {
// Poll a task iff it's running or it has its state changed recently
if state == RUNNING || pb.State != nt.oldState[child] {
nt.oldState[child] = pb.State
child.Poll()
child.Poll(ctx)
}

}
Expand Down
30 changes: 18 additions & 12 deletions golang/pkg/rnr/task_nested_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rnr

import (
"context"
"fmt"
"testing"

Expand Down Expand Up @@ -64,6 +65,7 @@ func TestNestedTask_GetChild(t *testing.T) {
}

func TestNestedTask_FailFirst(t *testing.T) {
ctx := context.Background()
nt := NewNestedTask("nested task test", NestedTaskOptions{Parallelism: 1, CompleteAll: false})
ct1 := newMockFailingTask("child 1")
ct2 := newMockTask("child 2")
Expand All @@ -74,11 +76,12 @@ func TestNestedTask_FailFirst(t *testing.T) {
nt.Add(ct2)
nt.SetState(pb.TaskState_RUNNING)

nt.Poll()
nt.Poll(ctx)
compareTaskStates(t, tasks, []pb.TaskState{pb.TaskState_FAILED, pb.TaskState_PENDING, pb.TaskState_FAILED})
}

func TestNestedTask_CompleteAllFail(t *testing.T) {
ctx := context.Background()
nt := NewNestedTask("nested task test", NestedTaskOptions{Parallelism: 1, CompleteAll: true})
ct1 := newMockFailingTask("child 1")
ct2 := newMockTask("child 2")
Expand All @@ -89,14 +92,15 @@ func TestNestedTask_CompleteAllFail(t *testing.T) {
nt.Add(ct2)
nt.SetState(pb.TaskState_RUNNING)

nt.Poll()
nt.Poll(ctx)
compareTaskStates(t, tasks, []pb.TaskState{pb.TaskState_FAILED, pb.TaskState_PENDING, pb.TaskState_RUNNING})

nt.Poll()
nt.Poll(ctx)
compareTaskStates(t, tasks, []pb.TaskState{pb.TaskState_FAILED, pb.TaskState_SUCCESS, pb.TaskState_FAILED})
}

func TestNestedTask_CompleteAllSuccess(t *testing.T) {
ctx := context.Background()
ct1 := newMockTask("child 1")
ct2 := newMockTask("child 2")
nt := NewNestedTask("nested task test", NestedTaskOptions{Parallelism: 1, CompleteAll: true})
Expand All @@ -107,14 +111,15 @@ func TestNestedTask_CompleteAllSuccess(t *testing.T) {
nt.Add(ct2)
nt.SetState(pb.TaskState_RUNNING)

nt.Poll()
nt.Poll(ctx)
compareTaskStates(t, tasks, []pb.TaskState{pb.TaskState_SUCCESS, pb.TaskState_PENDING, pb.TaskState_RUNNING})

nt.Poll()
nt.Poll(ctx)
compareTaskStates(t, tasks, []pb.TaskState{pb.TaskState_SUCCESS, pb.TaskState_SUCCESS, pb.TaskState_SUCCESS})
}

func TestNestedTask_CallbackInvoked(t *testing.T) {
ctx := context.Background()
childrenAdded := 0
nt := NewNestedTask("nested task test", NestedTaskOptions{
Parallelism: 1,
Expand All @@ -127,13 +132,13 @@ func TestNestedTask_CallbackInvoked(t *testing.T) {
})

nt.SetState(pb.TaskState_PENDING)
nt.Poll()
nt.Poll(ctx)
if childrenAdded != 0 {
t.Errorf("callback shouldn't be invoked for non-RUNNING nested task!")
}

nt.SetState(pb.TaskState_RUNNING)
nt.Poll()
nt.Poll(ctx)
if childrenAdded != 1 {
t.Errorf("nested task callback was not invoked for running task!")
}
Expand All @@ -142,14 +147,15 @@ func TestNestedTask_CallbackInvoked(t *testing.T) {
}

nt.SetState(pb.TaskState_RUNNING)
nt.Poll()
nt.Poll(ctx)
fmt.Println(nt.children[1].Proto(nil).Name)
if len(nt.children) != 2 {
t.Errorf("expected 2 children, got %d", len(nt.children))
}
}

func TestNextedTask_PollAfterStateChange(t *testing.T) {
ctx := context.Background()
ct := newMockTask("child 1")
ct.finalState = pb.TaskState_RUNNING // This will stay in RUNNING state unless state is changed externally
nt := NewNestedTask("nested task test", NestedTaskOptions{Parallelism: 1, CompleteAll: true})
Expand All @@ -159,14 +165,14 @@ func TestNextedTask_PollAfterStateChange(t *testing.T) {

// Initially, a task in in PENDING state. Poll() from its parent should transfer it to RUNNING.
oldPollCount := ct.pollCount
nt.Poll()
nt.Poll(ctx)
if oldPollCount+1 != ct.pollCount {
t.Errorf("task was not polled when transitioning from PENDING to RUNNING state")
}

// A task in RUNNING state should get Poll()-ed each time.
oldPollCount = ct.pollCount
nt.Poll()
nt.Poll(ctx)
if oldPollCount+1 != ct.pollCount {
t.Errorf("task was not polled when in RUNNING state")
}
Expand All @@ -175,7 +181,7 @@ func TestNextedTask_PollAfterStateChange(t *testing.T) {
nt.SetState(pb.TaskState_RUNNING)
ct.SetState(pb.TaskState_SUCCESS)
oldPollCount = ct.pollCount
nt.Poll()
nt.Poll(ctx)
if oldPollCount+1 != ct.pollCount {
t.Errorf("task was not polled when transitioning from RUNNING to SUCCESS state")
}
Expand All @@ -184,7 +190,7 @@ func TestNextedTask_PollAfterStateChange(t *testing.T) {
nt.SetState(pb.TaskState_RUNNING)
ct.SetState(pb.TaskState_SKIPPED)
oldPollCount = ct.pollCount
nt.Poll()
nt.Poll(ctx)
if oldPollCount+1 != ct.pollCount {
t.Errorf("task was not polled when transitioning from SUCCESS to SKIPPED state")
}
Expand Down
31 changes: 29 additions & 2 deletions golang/pkg/rnr/task_shell.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package rnr

import (
"context"
"fmt"
"os/exec"
"sync"

Expand All @@ -9,6 +11,7 @@ import (
)

// Shell Task
var _ Task = &ShellTask{}

type ShellTask struct {
pbMutex sync.Mutex
Expand All @@ -34,16 +37,40 @@ func NewShellTask(name, cmd string, args ...string) *ShellTask {
return ret
}

func (ct *ShellTask) Poll() {
func (ct *ShellTask) Poll(ctx context.Context) {
if ct.cmd.Process == nil {
// Not yet started, let's launch it first
go func() { ct.err <- ct.cmd.Run() }()
if err := ct.cmd.Start(); err != nil {
ct.pb.Message = fmt.Sprintf("failed to start: %v", err)
ct.pb.State = pb.TaskState_FAILED
return
}

go func() { ct.err <- ct.cmd.Wait() }()

ct.pb.State = pb.TaskState_RUNNING
ct.pb.Message = "Started"
return
}

if ct.pb.State != pb.TaskState_RUNNING {
return
}

select {
default:
// still running
case <-ctx.Done():
if ct.cmd.ProcessState != nil {
// process was already killed/finished
return
}

if err := ct.cmd.Process.Kill(); err != nil {
ct.pb.Message = fmt.Sprintf("cannot kill process: %v", err)
ct.pb.State = pb.TaskState_FAILED
}

case err := <-ct.err:
ct.pb.Message = "Exited"
// The process has finished
Expand Down
80 changes: 79 additions & 1 deletion golang/pkg/rnr/task_shell_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,84 @@
package rnr

import "testing"
import (
"context"
"flag"
"fmt"
"os"
"strconv"
"testing"
"time"

"github.com/mplzik/rnr/golang/pkg/pb"
)

func TestMain(m *testing.M) {
var (
runTests = flag.Bool("run-tests", true, "whether to run the tests")
exitCode = flag.Int("exit-code", 0, "exit code when not running the tests")
sleep = flag.Duration("sleep", 0*time.Second, "how long to sleep to simulate activity")
)

flag.Parse()

if *runTests {
os.Exit(m.Run())
}

t := time.NewTicker(10 * time.Millisecond)
a := time.After(*sleep)

var done bool

for !done {
select {
case <-a:
fmt.Println("DONE AFTER", *sleep)
done = true
case <-t.C:
fmt.Println("TICK")
}
}

os.Exit(*exitCode)
}

func TestShellTask(t *testing.T) {
run := func(ctx context.Context, exitCode int, expectedState pb.TaskState, sleep time.Duration) {
t.Helper()

task := NewShellTask("foo", os.Args[0], "-run-tests=false", "-exit-code", strconv.Itoa(exitCode), "-sleep", sleep.String())

for {
state := task.Proto(nil).State
if state == pb.TaskState_SUCCESS || state == pb.TaskState_FAILED {
break
}
task.Poll(ctx)
time.Sleep(50 * time.Millisecond)
}

if state := task.Proto(nil).State; state != expectedState {
t.Fatalf("expecting %v state, got %v", expectedState, state)
}

pb := task.Proto(nil)
t.Logf("task in state %v with message %q", pb.State, pb.Message)
}

ctx := context.Background()

run(ctx, 0, pb.TaskState_SUCCESS, 0)
run(ctx, 1, pb.TaskState_FAILED, 0)

ctx2, cancel := context.WithCancel(ctx)
cancel() // manually cancel the context
run(ctx2, 0, pb.TaskState_FAILED, 5*time.Second)

ctx3, cancel := context.WithCancel(ctx)
cancel() // manually cancel the context
run(ctx3, 1, pb.TaskState_FAILED, 5*time.Second)
}

func TestShellTask_GetChild(t *testing.T) {
c := NewShellTask("shell task test", "").GetChild("foo")
Expand Down
7 changes: 5 additions & 2 deletions golang/pkg/rnr/task_simple_callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
type CallbackFunc func(context.Context, *CallbackTask) (bool, error)

// CallbackTask
var _ Task = &CallbackTask{}

// CallbackTask implements a task with synchronously called callback.
// It returns a boolean indicating whether to transition into a final state and an error in case an error has happened. These values are used to best-effort-update the task's protobuf. If (false, nil) is supplied, the task state will be left untouched
Expand All @@ -34,13 +35,15 @@ func NewCallbackTask(name string, callback CallbackFunc) *CallbackTask {
}

// Poll synchronously calls the callback
func (ct *CallbackTask) Poll() {
func (ct *CallbackTask) Poll(ctx context.Context) {
if (taskSchedState(&ct.pb) != RUNNING) && (ct.oldState == ct.pb.GetState()) {
return
}

// TODO should we call the callback if the context was done?

ct.oldState = ct.pb.GetState()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()

ret, err := ct.callback(ctx, ct)
Expand Down
Loading