diff --git a/errors.go b/errors.go index 9c32a49..21f4e0d 100644 --- a/errors.go +++ b/errors.go @@ -14,6 +14,8 @@ package fsm +import "context" + // InvalidEventError is returned by FSM.Event() when the event cannot be called // in the current state. type InvalidEventError struct { @@ -80,8 +82,11 @@ func (e CanceledError) Error() string { // AsyncError is returned by FSM.Event() when a callback have initiated an // asynchronous state transition. +// +// Ctx indicates if the transition is done. type AsyncError struct { Err error + Ctx context.Context } func (e AsyncError) Error() string { diff --git a/fsm.go b/fsm.go index 7220d95..6c598cd 100644 --- a/fsm.go +++ b/fsm.go @@ -25,6 +25,7 @@ package fsm import ( + "context" "strings" "sync" ) @@ -32,6 +33,7 @@ import ( // transitioner is an interface for the FSM's transition function. type transitioner interface { transition(*FSM) error + cancelTransition(*FSM) error } // FSM is the state machine that holds the current state. @@ -62,6 +64,9 @@ type FSM struct { metadata map[string]interface{} metadataMu sync.RWMutex + + // cancelCtx cancels the context associated to asynchronous state transition. + cancelCtx context.CancelFunc } // EventDesc represents an event when initializing the FSM. @@ -135,6 +140,7 @@ func NewFSM(initial string, events []EventDesc, callbacks map[string]Callback) * transitions: make(map[eKey]string), callbacks: make(map[cKey]Callback), metadata: make(map[string]interface{}), + cancelCtx: nil, } // Build transition map and store sets of all events and states. @@ -355,11 +361,23 @@ func (f *FSM) Transition() error { return f.doTransition() } +// CancelTransition wraps transitioner.cancelTransition. +func (f *FSM) CancelTransition() error { + f.eventMu.Lock() + defer f.eventMu.Unlock() + return f.cancelTransition() +} + // doTransition wraps transitioner.transition. func (f *FSM) doTransition() error { return f.transitionerObj.transition(f) } +// cancelTransition wraps transitioner.cancelTransition. +func (f *FSM) cancelTransition() error { + return f.transitionerObj.cancelTransition(f) +} + // transitionerStruct is the default implementation of the transitioner // interface. Other implementations can be swapped in for testing. type transitionerStruct struct{} @@ -374,6 +392,24 @@ func (t transitionerStruct) transition(f *FSM) error { } f.transition() f.transition = nil + if f.cancelCtx != nil { + f.cancelCtx() + f.cancelCtx = nil + } + return nil +} + +// CancelTransition cancels an asynchrounous state change. +// +// The callback for leave_ must prviously have called Async on its +// event to have initiated an asynchronous state transition. +func (t transitionerStruct) cancelTransition(f *FSM) error { + if f.transition == nil { + return NotInTransitionError{} + } + f.cancelCtx() + f.cancelCtx = nil + f.transition = nil return nil } @@ -403,7 +439,9 @@ func (f *FSM) leaveStateCallbacks(e *Event) error { if e.canceled { return CanceledError{e.Err} } else if e.async { - return AsyncError{e.Err} + ctx, cancel := context.WithCancel(context.Background()) + f.cancelCtx = cancel + return AsyncError{e.Err, ctx} } } if fn, ok := f.callbacks[cKey{"", callbackLeaveState}]; ok { @@ -411,7 +449,9 @@ func (f *FSM) leaveStateCallbacks(e *Event) error { if e.canceled { return CanceledError{e.Err} } else if e.async { - return AsyncError{e.Err} + ctx, cancel := context.WithCancel(context.Background()) + f.cancelCtx = cancel + return AsyncError{e.Err, ctx} } } return nil diff --git a/fsm_test.go b/fsm_test.go index 851054d..b3965f9 100644 --- a/fsm_test.go +++ b/fsm_test.go @@ -29,6 +29,10 @@ func (t fakeTransitionerObj) transition(f *FSM) error { return &InternalError{} } +func (t fakeTransitionerObj) cancelTransition(f *FSM) error { + return nil +} + func TestSameState(t *testing.T) { fsm := NewFSM( "start", @@ -53,7 +57,7 @@ func TestSetState(t *testing.T) { ) fsm.SetState("start") if fsm.Current() != "start" { - t.Error("expected state to be 'walking'") + t.Error("expected state to be 'start'") } err := fsm.Event("walk") if err != nil { @@ -471,6 +475,140 @@ func TestAsyncTransitionNotInProgress(t *testing.T) { } } +func TestCancelAsyncTransitionGenericState(t *testing.T) { + fsm := NewFSM( + "start", + Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + }, + Callbacks{ + "leave_state": func(e *Event) { + e.Async() + }, + }, + ) + fsm.Event("run") + if fsm.Current() != "start" { + t.Error("expected state to be 'start'") + } + err := fsm.Event("run") + if e, ok := err.(InTransitionError); !ok && e.Event != "run" { + t.Error("expected 'InTransitionError' with correct state") + } + fsm.CancelTransition() + err = fsm.Event("run") + if _, ok := err.(AsyncError); !ok { + t.Error("expected 'AsyncError'") + } +} + +func TestCancelAsyncTransitionSpecificState(t *testing.T) { + fsm := NewFSM( + "start", + Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + }, + Callbacks{ + "leave_start": func(e *Event) { + e.Async() + }, + }, + ) + fsm.Event("run") + if fsm.Current() != "start" { + t.Error("expected state to be 'start'") + } + err := fsm.Event("run") + if e, ok := err.(InTransitionError); !ok && e.Event != "run" { + t.Error("expected 'InTransitionError' with correct state") + } + fsm.CancelTransition() + err = fsm.Event("run") + if _, ok := err.(AsyncError); !ok { + t.Error("expected 'AsyncError'") + } +} + +func TestContextWhenCancelAsyncTransition(t *testing.T) { + fsm := NewFSM( + "start", + Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + }, + Callbacks{ + "leave_start": func(e *Event) { + e.Async() + }, + }, + ) + err := fsm.Event("run") + asyncErr, ok := err.(AsyncError) + if !ok || asyncErr.Ctx == nil { + t.Error("expected context in 'AsyncError'") + } + defer asyncErr.Ctx.Done() + + fsm.CancelTransition() + timer := time.NewTimer(time.Millisecond * 100) + defer timer.Stop() + + select { + case <-timer.C: + t.Error("expected context has been done") + case _, ok := <-asyncErr.Ctx.Done(): + if ok { + t.Error("expected context has been done") + } + } +} + +func TestContextWhenFinishAsyncTransition(t *testing.T) { + fsm := NewFSM( + "start", + Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + }, + Callbacks{ + "leave_start": func(e *Event) { + e.Async() + }, + }, + ) + err := fsm.Event("run") + asyncErr, ok := err.(AsyncError) + if !ok || asyncErr.Ctx == nil { + t.Error("expected context in 'AsyncError'") + } + defer asyncErr.Ctx.Done() + + fsm.Transition() + timer := time.NewTimer(time.Millisecond * 100) + defer timer.Stop() + + select { + case <-timer.C: + t.Error("expected context has been done") + case _, ok := <-asyncErr.Ctx.Done(): + if ok { + t.Error("expected context has been done") + } + } +} + +func TestCancelAsyncTransitionNotInProgress(t *testing.T) { + fsm := NewFSM( + "start", + Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + }, + Callbacks{}, + ) + err := fsm.CancelTransition() + if _, ok := err.(NotInTransitionError); !ok { + t.Error("expected 'NotInTransitionError'") + } +} + func TestCallbackNoError(t *testing.T) { fsm := NewFSM( "start",