Skip to content

Commit

Permalink
Cancel taskrun using entrypoint binary
Browse files Browse the repository at this point in the history
The cancellation of taskruns is now done through the entrypoint binary
through a new flag called 'cancel_file'. This removes the need for
deleting the pods to cancel a taskrun, allowing examination of the logs
on the pods from cancelled taskruns. Part of work on issue tektoncd#3238

Signed-off-by: Arash Deshmeh <adeshmeh@ca.ibm.com>
  • Loading branch information
adshmh committed Feb 24, 2022
1 parent 9c529d3 commit e12f80f
Show file tree
Hide file tree
Showing 10 changed files with 830 additions and 51 deletions.
5 changes: 4 additions & 1 deletion cmd/entrypoint/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package main

import (
"context"
"encoding/json"
"flag"
"fmt"
Expand Down Expand Up @@ -49,6 +50,7 @@ var (
onError = flag.String("on_error", "", "Set to \"continue\" to ignore an error and continue when a container terminates with a non-zero exit code."+
" Set to \"stopAndFail\" to declare a failure with a step error and stop executing the rest of the steps.")
stepMetadataDir = flag.String("step_metadata_dir", "", "If specified, create directory to store the step metadata e.g. /tekton/steps/<step-name>/")
cancelFile = flag.String("cancel_file", "", "Path indicating task should be cancelled")
)

const (
Expand All @@ -58,7 +60,7 @@ const (

func checkForBreakpointOnFailure(e entrypoint.Entrypointer, breakpointExitPostFile string) {
if e.BreakpointOnFailure {
if waitErr := e.Waiter.Wait(breakpointExitPostFile, false, false); waitErr != nil {
if waitErr := e.Waiter.Wait(context.Background(), breakpointExitPostFile, false, false); waitErr != nil {
log.Println("error occurred while waiting for " + breakpointExitPostFile + " : " + waitErr.Error())
}
// get exitcode from .breakpointexit
Expand Down Expand Up @@ -136,6 +138,7 @@ func main() {
BreakpointOnFailure: *breakpointOnFailure,
OnError: *onError,
StepMetadataDir: *stepMetadataDir,
CancelFile: *cancelFile,
}

// Copy any creds injected by the controller into the $HOME directory of the current
Expand Down
6 changes: 5 additions & 1 deletion cmd/entrypoint/waiter.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"fmt"
"os"
"time"
Expand Down Expand Up @@ -31,11 +32,14 @@ func (rw *realWaiter) setWaitPollingInterval(pollingInterval time.Duration) *rea
//
// If a file of the same name with a ".err" extension exists then this Wait
// will end with a skipError.
func (rw *realWaiter) Wait(file string, expectContent bool, breakpointOnFailure bool) error {
func (rw *realWaiter) Wait(ctx context.Context, file string, expectContent bool, breakpointOnFailure bool) error {
if file == "" {
return nil
}
for ; ; time.Sleep(rw.waitPollingInterval) {
if ctx.Err() != nil {
return nil
}
if info, err := os.Stat(file); err == nil {
if !expectContent || info.Size() > 0 {
return nil
Expand Down
38 changes: 31 additions & 7 deletions cmd/entrypoint/waiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ limitations under the License.
package main

import (
"context"
"io/ioutil"
"os"
"strings"
"testing"
"time"
)

const testWaitPollingInterval = 10 * time.Millisecond
const testWaitPollingInterval = 15 * time.Millisecond

func TestRealWaiterWaitMissingFile(t *testing.T) {
// Create a temp file and then immediately delete it to get
Expand All @@ -38,7 +39,7 @@ func TestRealWaiterWaitMissingFile(t *testing.T) {
rw := realWaiter{}
doneCh := make(chan struct{})
go func() {
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(tmp.Name(), false, false)
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(context.Background(), tmp.Name(), false, false)
if err != nil {
t.Errorf("error waiting on tmp file %q", tmp.Name())
}
Expand All @@ -61,7 +62,7 @@ func TestRealWaiterWaitWithFile(t *testing.T) {
rw := realWaiter{}
doneCh := make(chan struct{})
go func() {
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(tmp.Name(), false, false)
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(context.Background(), tmp.Name(), false, false)
if err != nil {
t.Errorf("error waiting on tmp file %q", tmp.Name())
}
Expand All @@ -84,7 +85,7 @@ func TestRealWaiterWaitMissingContent(t *testing.T) {
rw := realWaiter{}
doneCh := make(chan struct{})
go func() {
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(tmp.Name(), true, false)
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(context.Background(), tmp.Name(), true, false)
if err != nil {
t.Errorf("error waiting on tmp file %q", tmp.Name())
}
Expand All @@ -107,7 +108,7 @@ func TestRealWaiterWaitWithContent(t *testing.T) {
rw := realWaiter{}
doneCh := make(chan struct{})
go func() {
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(tmp.Name(), true, false)
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(context.Background(), tmp.Name(), true, false)
if err != nil {
t.Errorf("error waiting on tmp file %q", tmp.Name())
}
Expand Down Expand Up @@ -135,7 +136,7 @@ func TestRealWaiterWaitWithErrorWaitfile(t *testing.T) {
doneCh := make(chan struct{})
go func() {
// error of type skipError is returned after encountering a error waitfile
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(tmpFileName, false, false)
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(context.Background(), tmpFileName, false, false)
if err == nil {
t.Errorf("expected skipError upon encounter error waitfile")
}
Expand Down Expand Up @@ -165,7 +166,7 @@ func TestRealWaiterWaitWithBreakpointOnFailure(t *testing.T) {
doneCh := make(chan struct{})
go func() {
// When breakpoint on failure is enabled skipError shouldn't be returned for a error waitfile
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(tmpFileName, false, true)
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(context.Background(), tmpFileName, false, true)
if err != nil {
t.Errorf("error waiting on tmp file %q", tmp.Name())
}
Expand All @@ -178,3 +179,26 @@ func TestRealWaiterWaitWithBreakpointOnFailure(t *testing.T) {
t.Errorf("expected Wait() to have detected a non-zero file size by now")
}
}

func TestRealWaiterWaitWithCancel(t *testing.T) {
rw := realWaiter{}
doneCh := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(ctx, "does_not_exist", true, false)
if err != nil {
t.Error("error waiting")
}
close(doneCh)
}()
time.Sleep(testWaitPollingInterval)
cancel()

select {
case <-doneCh:
// Success
case <-time.After(2 * testWaitPollingInterval):
t.Errorf("expected Wait() to have exited by now")
}
}
41 changes: 37 additions & 4 deletions pkg/entrypoint/entrypointer.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,15 @@ type Entrypointer struct {
OnError string
// StepMetadataDir is the directory for a step where the step related metadata can be stored
StepMetadataDir string

// CancelFile is the file that causes the task to be cancelled.
CancelFile string
}

// Waiter encapsulates waiting for files to exist.
type Waiter interface {
// Wait blocks until the specified file exists.
Wait(file string, expectContent bool, breakpointOnFailure bool) error
Wait(ctx context.Context, file string, expectContent bool, breakpointOnFailure bool) error
}

// Runner encapsulates running commands.
Expand Down Expand Up @@ -114,7 +117,7 @@ func (e Entrypointer) Go() error {
}()

for _, f := range e.WaitFiles {
if err := e.Waiter.Wait(f, e.WaitFileContent, e.BreakpointOnFailure); err != nil {
if err := e.Waiter.Wait(context.Background(), f, e.WaitFileContent, e.BreakpointOnFailure); err != nil {
// An error happened while waiting, so we bail
// *but* we write postfile to make next steps bail too.
// In case of breakpoint on failure do not write post file.
Expand Down Expand Up @@ -146,15 +149,45 @@ func (e Entrypointer) Go() error {
var cancel context.CancelFunc
if e.Timeout != nil && *e.Timeout != time.Duration(0) {
ctx, cancel = context.WithTimeout(ctx, *e.Timeout)
defer cancel()
} else {
ctx, cancel = context.WithCancel(ctx)
}
defer cancel()

runChan := make(chan error)
errChan := make(chan error, 1)
go func() {
errChan <- e.Runner.Run(ctx, e.Command...)
close(runChan)
cancel()
}()

var shouldCancel bool
if e.CancelFile != "" {
if err := e.Waiter.Wait(ctx, e.CancelFile, true, e.BreakpointOnFailure); err != nil {
return err
}
if ctx.Err() == nil {
shouldCancel = true
}
} else {
<-runChan
}
err = e.Runner.Run(ctx, e.Command...)
err = <-errChan
if err == context.DeadlineExceeded {
output = append(output, v1beta1.PipelineResourceResult{
Key: "Reason",
Value: "TimeoutExceeded",
ResultType: v1beta1.InternalTektonResultType,
})
} else if shouldCancel {
// Waiter has found the cancel file: Cancel the run
cancel()
output = append(output, v1beta1.PipelineResourceResult{
Key: "Reason",
Value: "Cancelled",
ResultType: v1beta1.InternalTektonResultType,
})
}
}

Expand Down
76 changes: 70 additions & 6 deletions pkg/entrypoint/entrypointer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ func TestEntrypointer(t *testing.T) {
desc, entrypoint, postFile, stepDir, stepDirLink string
waitFiles, args []string
breakpointOnFailure bool
cancelFile string
noCancel bool
shouldCancel bool
waiter Waiter
runner Runner
}{{
desc: "do nothing",
}, {
Expand Down Expand Up @@ -160,6 +165,18 @@ func TestEntrypointer(t *testing.T) {
}, {
desc: "breakpointOnFailure to wait or not to wait ",
breakpointOnFailure: true,
}, {
desc: "Runner completes if not cancelled",
cancelFile: ".",
waiter: &contextWaiter{duration: 30 * time.Millisecond},
runner: &fakeLongRunner{duration: 10 * time.Millisecond},
noCancel: true,
}, {
desc: "Runner can be cancelled",
cancelFile: ".",
waiter: &contextWaiter{duration: 10 * time.Millisecond},
runner: &fakeLongRunner{duration: 30 * time.Millisecond},
shouldCancel: true,
}} {
t.Run(c.desc, func(t *testing.T) {
fw, fr, fpw := &fakeWaiter{}, &fakeRunner{}, &fakePostWriter{}
Expand All @@ -171,12 +188,25 @@ func TestEntrypointer(t *testing.T) {
terminationPath = terminationFile.Name()
defer os.Remove(terminationFile.Name())
}
var waiter Waiter
if c.waiter != nil {
waiter = c.waiter
} else {
waiter = fw
}
var runner Runner
if c.runner != nil {
runner = c.runner
} else {
runner = fr
}
err := Entrypointer{
Command: append([]string{c.entrypoint}, c.args...),
WaitFiles: c.waitFiles,
PostFile: c.postFile,
Waiter: fw,
Runner: fr,
CancelFile: c.cancelFile,
Waiter: waiter,
Runner: runner,
PostWriter: fpw,
TerminationPath: terminationPath,
Timeout: &timeout,
Expand All @@ -199,7 +229,7 @@ func TestEntrypointer(t *testing.T) {
}

wantArgs := append([]string{c.entrypoint}, c.args...)
if len(wantArgs) != 0 {
if c.entrypoint != "" || len(c.args) > 0 {
if fr.args == nil {
t.Error("Wanted command to be run, got nil")
} else if !reflect.DeepEqual(*fr.args, wantArgs) {
Expand Down Expand Up @@ -229,15 +259,23 @@ func TestEntrypointer(t *testing.T) {
var entries []v1alpha1.PipelineResourceResult
if err := json.Unmarshal(fileContents, &entries); err == nil {
var found = false
var cancelled = false
for _, result := range entries {
if result.Key == "StartedAt" {
found = true
break
} else if result.Key == "Reason" && result.Value == "Cancelled" {
cancelled = true
}
}
if !found {
t.Error("Didn't find the startedAt entry")
}
if c.noCancel && cancelled {
t.Error("Didn't expect cancel but run was cancelled")
}
if c.shouldCancel && !cancelled {
t.Error("Expected run to be cancelled but it wasn't")
}
}
} else if !os.IsNotExist(err) {
t.Error("Wanted termination file written, got nil")
Expand Down Expand Up @@ -349,9 +387,24 @@ func TestEntrypointer_OnError(t *testing.T) {
}
}

type contextWaiter struct {
duration time.Duration
}

func (c contextWaiter) Wait(ctx context.Context, _ string, _ bool, _ bool) error {
for {
select {
case <-time.After(c.duration):
return nil
case <-ctx.Done():
return nil
}
}
}

type fakeWaiter struct{ waited []string }

func (f *fakeWaiter) Wait(file string, _ bool, _ bool) error {
func (f *fakeWaiter) Wait(ctx context.Context, file string, _ bool, _ bool) error {
f.waited = append(f.waited, file)
return nil
}
Expand Down Expand Up @@ -380,7 +433,7 @@ func (f *fakePostWriter) Write(file, content string) {

type fakeErrorWaiter struct{ waited *string }

func (f *fakeErrorWaiter) Wait(file string, expectContent bool, breakpointOnFailure bool) error {
func (f *fakeErrorWaiter) Wait(ctx context.Context, file string, expectContent bool, breakpointOnFailure bool) error {
f.waited = &file
return errors.New("waiter failed")
}
Expand Down Expand Up @@ -418,3 +471,14 @@ func (f *fakeExitErrorRunner) Run(ctx context.Context, args ...string) error {
f.args = &args
return exec.Command("ls", "/bogus/path").Run()
}

type fakeLongRunner struct{ duration time.Duration }

func (f *fakeLongRunner) Run(ctx context.Context, _ ...string) error {
select {
case <-time.After(f.duration):
return nil
case <-ctx.Done():
return nil
}
}
Loading

0 comments on commit e12f80f

Please sign in to comment.