Skip to content

Commit

Permalink
statemachine: Implements conditional fault in task execution
Browse files Browse the repository at this point in the history
  • Loading branch information
joelrebel committed May 19, 2023
1 parent d8b7317 commit 2df6da3
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 9 deletions.
5 changes: 5 additions & 0 deletions internal/statemachine/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package statemachine

import (
sw "github.com/filanov/stateswitch"
"github.com/metal-toolbox/flasher/internal/model"
)

// MockTaskHandler implements the TaskTransitioner interface
Expand Down Expand Up @@ -38,3 +39,7 @@ func (h *MockTaskHandler) TaskSuccessful(_ sw.StateSwitch, _ sw.TransitionArgs)
func (h *MockTaskHandler) PublishStatus(_ sw.StateSwitch, _ sw.TransitionArgs) error {
return nil
}

func (h *MockTaskHandler) ConditionalFault(_ *HandlerContext, _ *model.Task, _ sw.TransitionType) error {
return nil
}
41 changes: 41 additions & 0 deletions internal/statemachine/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package statemachine
import (
"context"
"fmt"
"time"

sw "github.com/filanov/stateswitch"
"github.com/metal-toolbox/flasher/internal/model"
Expand All @@ -28,6 +29,7 @@ var (
ErrInvalidTransitionHandler = errors.New("expected a valid transitionHandler{} type")
ErrInvalidtaskHandlerContext = errors.New("expected a HandlerContext{} type")
ErrTaskTransition = errors.New("error in task transition")
errConditionFault = errors.New("condition induced fault")
)

// Publisher defines methods to publish task information.
Expand Down Expand Up @@ -293,13 +295,52 @@ func (m *TaskStateMachine) TransitionSuccess(task *model.Task, tctx *HandlerCont
return m.sm.Run(TransitionTypeTaskSuccess, task, tctx)
}

// ConditionalFault is invoked before each transition to induce a fault if specified.
func (m *TaskStateMachine) ConditionalFault(handlerCtx *HandlerContext, task *model.Task, transitionType sw.TransitionType) error {
if task.Fault == nil {
return nil
}

if task.Fault.Panic {
panic("condition induced panic..")
}

if task.Fault.FailAt == string(transitionType) {
return errors.Wrap(errConditionFault, string(transitionType))
}

if task.Fault.ExecuteWithDelay > 0 {
handlerCtx.Logger.WithField("delay", task.Fault.ExecuteWithDelay.Seconds()).Warn("condition induced delayed execution..")
time.Sleep(task.Fault.ExecuteWithDelay)

// reset delay duration
task.Fault.ExecuteWithDelay = 0
}

return nil
}

// Run executes the transitions in the expected order while handling any failures.
func (m *TaskStateMachine) Run(task *model.Task, tctx *HandlerContext) error {
var err error

var finalTransition sw.TransitionType

for _, transitionType := range m.transitions {
// conditionally fault
if errFault := m.ConditionalFault(tctx, task, transitionType); errFault != nil {
// include error in task
task.Status = errFault.Error()

// run transition failed handler
if txErr := m.TransitionFailed(task, tctx); txErr != nil {
err = errors.Wrap(errFault, string(TransitionTypeActionFailed)+": "+txErr.Error())
}

return err
}

// execute transition
err = m.sm.Run(transitionType, task, tctx)
if err != nil {
err = errors.Wrap(err, string(transitionType))
Expand Down
103 changes: 94 additions & 9 deletions internal/statemachine/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ package statemachine
import (
"net"
"testing"
"time"

sw "github.com/filanov/stateswitch"
"github.com/google/uuid"
cptypes "github.com/metal-toolbox/conditionorc/pkg/types"
"github.com/metal-toolbox/flasher/internal/model"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -58,10 +61,10 @@ var (
}
)

func newTaskFixture(t *testing.T, state string) *model.Task {
func newTaskFixture(t *testing.T, state string, fault *cptypes.Fault) *model.Task {
t.Helper()

task := &model.Task{}
task := &model.Task{Fault: fault}
if err := task.SetState(sw.State(state)); err != nil {
t.Fatal(err)
}
Expand All @@ -81,7 +84,7 @@ func Test_NewTaskStateMachine(t *testing.T) {
}{
{
"new task statemachine is created",
newTaskFixture(t, string(model.StatePending)),
newTaskFixture(t, string(model.StatePending), nil),
},
}

Expand Down Expand Up @@ -110,47 +113,47 @@ func Test_Transitions(t *testing.T) {
}{
{
"Pending to Active",
newTaskFixture(t, string(model.StatePending)),
newTaskFixture(t, string(model.StatePending), nil),
[]sw.TransitionType{TransitionTypeActive},
string(model.StateActive),
false,
"",
},
{
"Active to Success",
newTaskFixture(t, string(model.StateActive)),
newTaskFixture(t, string(model.StateActive), nil),
[]sw.TransitionType{TransitionTypeRun},
string(model.StateSucceeded),
false,
"",
},
{
"Queued to Success - run all transitions",
newTaskFixture(t, string(model.StatePending)),
newTaskFixture(t, string(model.StatePending), nil),
[]sw.TransitionType{}, // with this not defined, the statemachine defaults to the configured transitions.
string(model.StateSucceeded),
false,
"",
},
{
"Queued to Failed",
newTaskFixture(t, string(model.StateActive)),
newTaskFixture(t, string(model.StateActive), nil),
[]sw.TransitionType{TransitionTypeTaskFail},
string(model.StateFailed),
true,
"",
},
{
"Active to Failed",
newTaskFixture(t, string(model.StatePending)),
newTaskFixture(t, string(model.StatePending), nil),
[]sw.TransitionType{TransitionTypeTaskFail},
string(model.StateFailed),
true,
"",
},
{
"Success to Active fails - invalid transition",
newTaskFixture(t, string(model.StatePending)),
newTaskFixture(t, string(model.StatePending), nil),
[]sw.TransitionType{TransitionTypeTaskSuccess},
string(model.StateFailed),
true,
Expand Down Expand Up @@ -192,3 +195,85 @@ func Test_Transitions(t *testing.T) {
})
}
}

func Test_ConditionalFaultWithTransitions(t *testing.T) {
tests := []struct {
name string
task *model.Task
runTransition []sw.TransitionType
expectedState string
expectError bool
expectPanic bool
expectDelay bool
}{
{
"condition induced error",
newTaskFixture(t, string(model.StateActive), &cptypes.Fault{FailAt: "plan"}),
[]sw.TransitionType{TransitionTypePlan},
string(model.StateFailed),
true,
false,
false,
},
{
"condition induced panic",
newTaskFixture(t, string(model.StateActive), &cptypes.Fault{Panic: true}),
[]sw.TransitionType{TransitionTypePlan},
string(model.StateFailed),
true,
true,
false,
},
{
"condition induced delay",
newTaskFixture(t, string(model.StateActive), &cptypes.Fault{ExecuteWithDelay: 30 * time.Millisecond}),
[]sw.TransitionType{TransitionTypePlan},
string(model.StateActive),
false,
false,
true,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// init task handler context
tctx := &HandlerContext{Task: tc.task, Logger: logrus.NewEntry(logrus.New())}
handler := &MockTaskHandler{}

// init new state machine
m, err := NewTaskStateMachine(handler)
if err != nil {
t.Fatal(err)
}

// set transition to perform based on test case
if len(tc.runTransition) > 0 {
m.SetTransitionOrder(tc.runTransition)
}

if tc.expectPanic {
assert.Panics(t, func() {
_ = m.Run(tc.task, tctx)
})

} else {
start := time.Now()

// run transition
err = m.Run(tc.task, tctx)
if err != nil {
if !tc.expectError {
t.Fatal(err)
}
}

if tc.expectDelay {
assert.GreaterOrEqual(t, time.Since(start), tc.task.Fault.ExecuteWithDelay)
}

assert.Equal(t, tc.expectedState, string(tc.task.State()))
}
})
}
}

0 comments on commit 2df6da3

Please sign in to comment.