From 7944f8adbaa8964eaf0993a005dfce085bd345c7 Mon Sep 17 00:00:00 2001
From: Anand Swaminathan
Date: Wed, 2 Dec 2020 00:03:17 -0800
Subject: [PATCH] Handle nodeexec not found for Task (#208)

---
 .../pkg/controller/nodes/task/handler.go      |   6 +-
 .../pkg/controller/nodes/task/handler_test.go | 142 ++++++++++++++++++
 .../pkg/controller/nodes/task/transformer.go  |  30 ++--
 3 files changed, 166 insertions(+), 12 deletions(-)

diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go
index c7c2512b9f..67cc6aeb08 100644
--- a/flytepropeller/pkg/controller/nodes/task/handler.go
+++ b/flytepropeller/pkg/controller/nodes/task/handler.go
@@ -667,9 +667,13 @@ func (t Handler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, r
 	}
 	taskExecID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID()
 	evRecorder := nCtx.EventsRecorder()
+	nodeExecutionID, err := getParentNodeExecIDForTask(&taskExecID, nCtx.ExecutionContext())
+	if err != nil {
+		return err
+	}
 	if err := evRecorder.RecordTaskEvent(ctx, &event.TaskExecutionEvent{
 		TaskId:                taskExecID.TaskId,
-		ParentNodeExecutionId: taskExecID.NodeExecutionId,
+		ParentNodeExecutionId: nodeExecutionID,
 		RetryAttempt:          nCtx.CurrentAttempt(),
 		Phase:                 core.TaskExecution_ABORTED,
 		OccurredAt:            ptypes.TimestampNow(),
diff --git a/flytepropeller/pkg/controller/nodes/task/handler_test.go b/flytepropeller/pkg/controller/nodes/task/handler_test.go
index f46a5bf078..e257d65a17 100644
--- a/flytepropeller/pkg/controller/nodes/task/handler_test.go
+++ b/flytepropeller/pkg/controller/nodes/task/handler_test.go
@@ -1349,6 +1349,148 @@ func Test_task_Abort(t *testing.T) {
 	}
 }
 
+func Test_task_Abort_v1(t *testing.T) {
+	createNodeCtx := func(ev *fakeBufferedTaskEventRecorder) *nodeMocks.NodeExecutionContext {
+		wfExecID := &core.WorkflowExecutionIdentifier{
+			Project: "project",
+			Domain:  "domain",
+			Name:    "name",
+		}
+
+		nodeID := "n1"
+
+		nm := &nodeMocks.NodeExecutionMetadata{}
+		nm.OnGetAnnotations().Return(map[string]string{})
+		nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{
+			NodeId:      nodeID,
+			ExecutionId: wfExecID,
+		})
+		nm.OnGetK8sServiceAccount().Return("service-account")
+		nm.OnGetLabels().Return(map[string]string{})
+		nm.OnGetNamespace().Return("namespace")
+		nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"})
+		nm.OnGetOwnerReference().Return(v12.OwnerReference{
+			Kind: "sample",
+			Name: "name",
+		})
+
+		taskID := &core.Identifier{}
+		tr := &nodeMocks.TaskReader{}
+		tr.OnGetTaskID().Return(taskID)
+		tr.OnGetTaskType().Return("x")
+
+		ns := &flyteMocks.ExecutableNodeStatus{}
+		ns.OnGetDataDir().Return(storage.DataReference("data-dir"))
+		ns.OnGetOutputDir().Return(storage.DataReference("output-dir"))
+
+		res := &v1.ResourceRequirements{}
+		n := &flyteMocks.ExecutableNode{}
+		ma := 5
+		n.OnGetRetryStrategy().Return(&v1alpha1.RetryStrategy{MinAttempts: &ma})
+		n.OnGetResources().Return(res)
+
+		ir := &ioMocks.InputReader{}
+		nCtx := &nodeMocks.NodeExecutionContext{}
+		nCtx.OnNodeExecutionMetadata().Return(nm)
+		nCtx.OnNode().Return(n)
+		nCtx.OnInputReader().Return(ir)
+		ds, err := storage.NewDataStore(
+			&storage.Config{
+				Type: storage.TypeMemory,
+			},
+			promutils.NewTestScope(),
+		)
+		assert.NoError(t, err)
+		nCtx.OnDataStore().Return(ds)
+		nCtx.OnCurrentAttempt().Return(uint32(1))
+		nCtx.OnTaskReader().Return(tr)
+		nCtx.OnMaxDatasetSizeBytes().Return(int64(1))
+		nCtx.OnNodeStatus().Return(ns)
+		nCtx.OnNodeID().Return("n1")
+		nCtx.OnEnqueueOwnerFunc().Return(nil)
+		nCtx.OnEventsRecorder().Return(ev)
+
+		executionContext := &mocks.ExecutionContext{}
+		executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
+		executionContext.OnGetParentInfo().Return(nil)
+		executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion1)
+		nCtx.OnExecutionContext().Return(executionContext)
+
+		nCtx.OnRawOutputPrefix().Return("s3://sandbox/")
+		nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"}))
+
+		st := bytes.NewBuffer([]byte{})
+		a := 45
+		type test struct {
+			A int
+		}
+		cod := codex.GobStateCodec{}
+		assert.NoError(t, cod.Encode(test{A: a}, st))
+		nr := &nodeMocks.NodeStateReader{}
+		nr.OnGetTaskNodeState().Return(handler.TaskNodeState{
+			PluginState: st.Bytes(),
+		})
+		nCtx.OnNodeStateReader().Return(nr)
+		return nCtx
+	}
+
+	noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope())
+
+	type fields struct {
+		defaultPluginCallback func() pluginCore.Plugin
+	}
+	type args struct {
+		ev *fakeBufferedTaskEventRecorder
+	}
+	tests := []struct {
+		name        string
+		fields      fields
+		args        args
+		wantErr     bool
+		abortCalled bool
+	}{
+		{"no-plugin", fields{defaultPluginCallback: func() pluginCore.Plugin {
+			return nil
+		}}, args{nil}, true, false},
+
+		{"abort-fails", fields{defaultPluginCallback: func() pluginCore.Plugin {
+			p := &pluginCoreMocks.Plugin{}
+			p.On("GetID").Return("id")
+			p.On("Abort", mock.Anything, mock.Anything).Return(fmt.Errorf("error"))
+			return p
+		}}, args{nil}, true, true},
+		{"abort-success", fields{defaultPluginCallback: func() pluginCore.Plugin {
+			p := &pluginCoreMocks.Plugin{}
+			p.On("GetID").Return("id")
+			p.On("Abort", mock.Anything, mock.Anything).Return(nil)
+			return p
+		}}, args{ev: &fakeBufferedTaskEventRecorder{}}, false, true},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			m := tt.fields.defaultPluginCallback()
+			tk := Handler{
+				defaultPlugin:   m,
+				resourceManager: noopRm,
+			}
+			nCtx := createNodeCtx(tt.args.ev)
+			if err := tk.Abort(context.TODO(), nCtx, "reason"); (err != nil) != tt.wantErr {
+				t.Errorf("Handler.Abort() error = %v, wantErr %v", err, tt.wantErr)
+			}
+			c := 0
+			if tt.abortCalled {
+				c = 1
+				if !tt.wantErr {
+					assert.Len(t, tt.args.ev.evs, 1)
+				}
+			}
+			if m != nil {
+				m.(*pluginCoreMocks.Plugin).AssertNumberOfCalls(t, "Abort", c)
+			}
+		})
+	}
+}
+
 func Test_task_Finalize(t *testing.T) {
 
 	wfExecID := &core.WorkflowExecutionIdentifier{
diff --git a/flytepropeller/pkg/controller/nodes/task/transformer.go b/flytepropeller/pkg/controller/nodes/task/transformer.go
index f638da6c37..5c16d1aa35 100644
--- a/flytepropeller/pkg/controller/nodes/task/transformer.go
+++ b/flytepropeller/pkg/controller/nodes/task/transformer.go
@@ -53,6 +53,22 @@ func trimErrorMessage(original string, maxLength int) string {
 	return original[0:maxLength/2] + original[len(original)-maxLength/2:]
 }
 
+func getParentNodeExecIDForTask(taskExecID *core.TaskExecutionIdentifier, execContext executors.ExecutionContext) (*core.NodeExecutionIdentifier, error) {
+	nodeExecutionID := &core.NodeExecutionIdentifier{
+		ExecutionId: taskExecID.NodeExecutionId.ExecutionId,
+	}
+	if execContext.GetEventVersion() != v1alpha1.EventVersion0 {
+		currentNodeUniqueID, err := common.GenerateUniqueID(execContext.GetParentInfo(), taskExecID.NodeExecutionId.NodeId)
+		if err != nil {
+			return nil, err
+		}
+		nodeExecutionID.NodeId = currentNodeUniqueID
+	} else {
+		nodeExecutionID.NodeId = taskExecID.NodeExecutionId.NodeId
+	}
+	return nodeExecutionID, nil
+}
+
 func ToTaskExecutionEvent(taskExecID *core.TaskExecutionIdentifier, in io.InputFilePaths, out io.OutputFilePaths, info pluginCore.PhaseInfo,
 	nodeExecutionMetadata handler.NodeExecutionMetadata, execContext executors.ExecutionContext) (*event.TaskExecutionEvent, error) {
 	// Transitions to a new phase
@@ -66,17 +82,9 @@ func ToTaskExecutionEvent(taskExecID *core.TaskExecutionIdentifier, in io.InputF
 		}
 	}
 
-	nodeExecutionID := &core.NodeExecutionIdentifier{
-		ExecutionId: taskExecID.NodeExecutionId.ExecutionId,
-	}
-	if execContext.GetEventVersion() != v1alpha1.EventVersion0 {
-		currentNodeUniqueID, err := common.GenerateUniqueID(execContext.GetParentInfo(), taskExecID.NodeExecutionId.NodeId)
-		if err != nil {
-			return nil, err
-		}
-		nodeExecutionID.NodeId = currentNodeUniqueID
-	} else {
-		nodeExecutionID.NodeId = taskExecID.NodeExecutionId.NodeId
+	nodeExecutionID, err := getParentNodeExecIDForTask(taskExecID, execContext)
+	if err != nil {
+		return nil, err
 	}
 	tev := &event.TaskExecutionEvent{
 		TaskId:                taskExecID.TaskId,