Skip to content

Commit

Permalink
Cancel taskrun using entrypoint binary
Browse files Browse the repository at this point in the history
Taskruns are now cancelled by the entrypoint binary, which is driven
by a new flag called 'cancel_file'. This removes the need to delete
the pods to cancel a taskrun, so the logs on the pods of cancelled
taskruns can still be examined. Part of work on issue tektoncd#3238

Signed-off-by: Arash Deshmeh <adeshmeh@ca.ibm.com>
  • Loading branch information
adshmh authored and vdemeester committed Aug 31, 2022
1 parent 05f28f2 commit 676970b
Show file tree
Hide file tree
Showing 10 changed files with 900 additions and 65 deletions.
8 changes: 6 additions & 2 deletions cmd/entrypoint/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package main

import (
"context"
"encoding/json"
"flag"
"fmt"
Expand Down Expand Up @@ -51,6 +52,7 @@ var (
onError = flag.String("on_error", "", "Set to \"continue\" to ignore an error and continue when a container terminates with a non-zero exit code."+
" Set to \"stopAndFail\" to declare a failure with a step error and stop executing the rest of the steps.")
stepMetadataDir = flag.String("step_metadata_dir", "", "If specified, create directory to store the step metadata e.g. /tekton/steps/<step-name>/")
cancelFile = flag.String("cancel_file", "", "Path indicating task should be cancelled")
)

const (
Expand All @@ -60,7 +62,7 @@ const (

func checkForBreakpointOnFailure(e entrypoint.Entrypointer, breakpointExitPostFile string) {
if e.BreakpointOnFailure {
if waitErr := e.Waiter.Wait(breakpointExitPostFile, false, false); waitErr != nil {
if waitErr := e.Waiter.Wait(context.Background(), breakpointExitPostFile, false, false); waitErr != nil {
log.Println("error occurred while waiting for " + breakpointExitPostFile + " : " + waitErr.Error())
}
// get exitcode from .breakpointexit
Expand Down Expand Up @@ -148,6 +150,7 @@ func main() {
BreakpointOnFailure: *breakpointOnFailure,
OnError: *onError,
StepMetadataDir: *stepMetadataDir,
CancelFile: *cancelFile,
}

// Copy any creds injected by the controller into the $HOME directory of the current
Expand All @@ -156,7 +159,8 @@ func main() {
log.Printf("non-fatal error copying credentials: %q", err)
}

if err := e.Go(); err != nil {
ctx := context.Background()
if err := e.Go(ctx); err != nil {
breakpointExitPostFile := e.PostFile + breakpointExitSuffix
switch t := err.(type) {
case skipError:
Expand Down
11 changes: 9 additions & 2 deletions cmd/entrypoint/waiter.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"fmt"
"os"
"time"
Expand Down Expand Up @@ -31,11 +32,17 @@ func (rw *realWaiter) setWaitPollingInterval(pollingInterval time.Duration) *rea
//
// If a file of the same name with a ".err" extension exists then this Wait
// will end with a skipError.
func (rw *realWaiter) Wait(file string, expectContent bool, breakpointOnFailure bool) error {
func (rw *realWaiter) Wait(ctx context.Context, file string, expectContent bool, breakpointOnFailure bool) error {
if file == "" {
return nil
}
for ; ; time.Sleep(rw.waitPollingInterval) {
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(rw.waitPollingInterval):
}

if info, err := os.Stat(file); err == nil {
if !expectContent || info.Size() > 0 {
return nil
Expand Down
68 changes: 61 additions & 7 deletions cmd/entrypoint/waiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ limitations under the License.
package main

import (
"context"
"errors"
"io/ioutil"
"os"
"strings"
"testing"
"time"
)

const testWaitPollingInterval = 10 * time.Millisecond
const testWaitPollingInterval = 25 * time.Millisecond

func TestRealWaiterWaitMissingFile(t *testing.T) {
// Create a temp file and then immediately delete it to get
Expand All @@ -37,8 +39,9 @@ func TestRealWaiterWaitMissingFile(t *testing.T) {
os.Remove(tmp.Name())
rw := realWaiter{}
doneCh := make(chan struct{})
ctx := context.Background()
go func() {
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(tmp.Name(), false, false)
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(ctx, tmp.Name(), false, false)
if err != nil {
t.Errorf("error waiting on tmp file %q", tmp.Name())
}
Expand All @@ -65,8 +68,9 @@ func TestRealWaiterWaitWithFile(t *testing.T) {
defer os.Remove(tmp.Name())
rw := realWaiter{}
doneCh := make(chan struct{})
ctx := context.Background()
go func() {
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(tmp.Name(), false, false)
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(ctx, tmp.Name(), false, false)
if err != nil {
t.Errorf("error waiting on tmp file %q", tmp.Name())
}
Expand All @@ -89,8 +93,9 @@ func TestRealWaiterWaitMissingContent(t *testing.T) {
defer os.Remove(tmp.Name())
rw := realWaiter{}
doneCh := make(chan struct{})
ctx := context.Background()
go func() {
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(tmp.Name(), true, false)
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(ctx, tmp.Name(), true, false)
if err != nil {
t.Errorf("error waiting on tmp file %q", tmp.Name())
}
Expand All @@ -116,8 +121,9 @@ func TestRealWaiterWaitWithContent(t *testing.T) {
defer os.Remove(tmp.Name())
rw := realWaiter{}
doneCh := make(chan struct{})
ctx := context.Background()
go func() {
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(tmp.Name(), true, false)
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(ctx, tmp.Name(), true, false)
if err != nil {
t.Errorf("error waiting on tmp file %q", tmp.Name())
}
Expand All @@ -144,9 +150,10 @@ func TestRealWaiterWaitWithErrorWaitfile(t *testing.T) {
defer os.Remove(tmp.Name())
rw := realWaiter{}
doneCh := make(chan struct{})
ctx := context.Background()
go func() {
// error of type skipError is returned after encountering a error waitfile
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(tmpFileName, false, false)
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(ctx, tmpFileName, false, false)
if err == nil {
t.Errorf("expected skipError upon encounter error waitfile")
}
Expand Down Expand Up @@ -175,9 +182,10 @@ func TestRealWaiterWaitWithBreakpointOnFailure(t *testing.T) {
defer os.Remove(tmp.Name())
rw := realWaiter{}
doneCh := make(chan struct{})
ctx := context.Background()
go func() {
// When breakpoint on failure is enabled skipError shouldn't be returned for a error waitfile
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(tmpFileName, false, true)
err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(ctx, tmpFileName, false, true)
if err != nil {
t.Errorf("error waiting on tmp file %q", tmp.Name())
}
Expand All @@ -191,3 +199,49 @@ func TestRealWaiterWaitWithBreakpointOnFailure(t *testing.T) {
t.Errorf("expected Wait() to have detected a non-zero file size by now")
}
}

// TestRealWaiterWaitWithCancel verifies that Wait stops polling and reports
// context.Canceled when its context is cancelled while the awaited file is
// still missing.
func TestRealWaiterWaitWithCancel(t *testing.T) {
	rw := realWaiter{}
	doneCh := make(chan struct{})
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() {
		// Always signal completion, even if the assertions below fail.
		defer close(doneCh)
		// This test never creates "does_not_exist", so cancellation is the
		// only expected way out of Wait. errors.Is(nil, ...) is false, which
		// makes a nil return a failure instead of a silent pass.
		err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(ctx, "does_not_exist", true, false)
		if !errors.Is(err, context.Canceled) {
			t.Errorf("expected context.Canceled from Wait(), got %v", err)
		}
	}()
	// Give Wait a chance to start polling before cancelling it.
	time.Sleep(testWaitPollingInterval)
	cancel()

	select {
	case <-doneCh:
		// Success
	case <-time.After(2 * testWaitPollingInterval):
		t.Errorf("expected Wait() to have exited by now")
	}
}

// TestRealWaiterWaitWithDeadline verifies that Wait stops polling and reports
// context.DeadlineExceeded when its context's deadline passes while the
// awaited file is still missing.
func TestRealWaiterWaitWithDeadline(t *testing.T) {
	rw := realWaiter{}
	doneCh := make(chan struct{})
	ctx, cancel := context.WithTimeout(context.Background(), testWaitPollingInterval)
	defer cancel()
	go func() {
		// Always signal completion, even if the assertions below fail.
		defer close(doneCh)
		// This test never creates "does_not_exist", so the deadline is the
		// only expected way out of Wait. errors.Is(nil, ...) is false, which
		// makes a nil return a failure instead of a silent pass.
		err := rw.setWaitPollingInterval(testWaitPollingInterval).Wait(ctx, "does_not_exist", true, false)
		if !errors.Is(err, context.DeadlineExceeded) {
			t.Errorf("expected context.DeadlineExceeded from Wait(), got %v", err)
		}
	}()
	// Sleep past the deadline; it is expected to have fired by the time
	// cancel() runs, in which case cancel() only releases resources.
	time.Sleep(2 * testWaitPollingInterval)
	cancel()

	select {
	case <-doneCh:
		// Success
	case <-time.After(2 * testWaitPollingInterval):
		t.Errorf("expected Wait() to have exited by now")
	}
}
42 changes: 36 additions & 6 deletions pkg/entrypoint/entrypointer.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,15 @@ type Entrypointer struct {
OnError string
// StepMetadataDir is the directory for a step where the step related metadata can be stored
StepMetadataDir string

// CancelFile is the file that causes the task to be cancelled.
CancelFile string
}

// Waiter encapsulates waiting for files to exist.
type Waiter interface {
// Wait blocks until the specified file exists.
Wait(file string, expectContent bool, breakpointOnFailure bool) error
Wait(ctx context.Context, file string, expectContent bool, breakpointOnFailure bool) error
}

// Runner encapsulates running commands.
Expand All @@ -101,7 +104,7 @@ type PostWriter interface {

// Go optionally waits for a file, runs the command, and writes a
// post file.
func (e Entrypointer) Go() error {
func (e Entrypointer) Go(ctx context.Context) error {
prod, _ := zap.NewProduction()
logger := prod.Sugar()

Expand All @@ -114,7 +117,7 @@ func (e Entrypointer) Go() error {
}()

for _, f := range e.WaitFiles {
if err := e.Waiter.Wait(f, e.WaitFileContent, e.BreakpointOnFailure); err != nil {
if err := e.Waiter.Wait(ctx, f, e.WaitFileContent, e.BreakpointOnFailure); err != nil {
// An error happened while waiting, so we bail
// *but* we write postfile to make next steps bail too.
// In case of breakpoint on failure do not write post file.
Expand Down Expand Up @@ -142,19 +145,46 @@ func (e Entrypointer) Go() error {
}

if err == nil {
ctx := context.Background()
var cancel context.CancelFunc
if e.Timeout != nil && *e.Timeout != time.Duration(0) {
ctx, cancel = context.WithTimeout(ctx, *e.Timeout)
defer cancel()
} else {
ctx, cancel = context.WithCancel(ctx)
}
defer cancel()

errChan := make(chan error, 1)
go func() {
errChan <- e.Runner.Run(ctx, e.Command...)
cancel()
}()

var cancelled bool
if e.CancelFile != "" {
if err := e.Waiter.Wait(ctx, e.CancelFile, true, e.BreakpointOnFailure); err != nil {
return err
}
if ctx.Err() == nil {
cancel()
cancelled = true
}
}
err = e.Runner.Run(ctx, e.Command...)
err = <-errChan

if err == context.DeadlineExceeded {
output = append(output, v1beta1.PipelineResourceResult{
Key: "Reason",
Value: "TimeoutExceeded",
ResultType: v1beta1.InternalTektonResultType,
})
} else if cancelled {
// Waiter has found the cancel file: Cancel the run
cancel()
output = append(output, v1beta1.PipelineResourceResult{
Key: "Reason",
Value: "Cancelled",
ResultType: v1beta1.InternalTektonResultType,
})
}
}

Expand Down
Loading

0 comments on commit 676970b

Please sign in to comment.