diff --git a/saga/inMemorySagaLog.go b/saga/inMemorySagaLog.go index 49912510f..f4c2f0c34 100644 --- a/saga/inMemorySagaLog.go +++ b/saga/inMemorySagaLog.go @@ -27,41 +27,41 @@ func MakeInMemorySaga() Saga { } func (log *inMemorySagaLog) LogMessage(msg sagaMessage) error { - fmt.Println(fmt.Sprintf("Saga %s: %s %s", msg.sagaId, msg.msgType.String(), msg.taskId)) + log.mutex.Lock() + defer log.mutex.Unlock() + + fmt.Println(fmt.Sprintf("Saga %s: %s %s", msg.sagaId, msg.msgType.String(), msg.taskId)) sagaId := msg.sagaId - var err error - log.mutex.Lock() msgs, ok := log.sagas[sagaId] if !ok { return errors.New(fmt.Sprintf("Saga: %s is not Started yet.", msg.sagaId)) } log.sagas[sagaId] = append(msgs, msg) - log.mutex.Unlock() - return err + return nil } func (log *inMemorySagaLog) StartSaga(sagaId string, job []byte) error { log.mutex.Lock() + defer log.mutex.Unlock() fmt.Println(fmt.Sprintf("Start Saga %s", sagaId)) startMsg := MakeStartSagaMessage(sagaId, job) log.sagas[sagaId] = []sagaMessage{startMsg} - log.mutex.Unlock() - return nil } func (log *inMemorySagaLog) GetMessages(sagaId string) ([]sagaMessage, error) { log.mutex.RLock() + defer log.mutex.RUnlock() + msgs, ok := log.sagas[sagaId] - log.mutex.RUnlock() if ok { return msgs, nil @@ -69,3 +69,19 @@ func (log *inMemorySagaLog) GetMessages(sagaId string) ([]sagaMessage, error) { return nil, nil } } + +/* + * Returns all Sagas Started since this InMemory Saga was created + */ +func (log *inMemorySagaLog) GetActiveSagas() ([]string, error) { + log.mutex.RLock() + defer log.mutex.RUnlock() + + keys := make([]string, 0, len(log.sagas)) + + for key, _ := range log.sagas { + keys = append(keys, key) + } + + return keys, nil +} diff --git a/saga/saga.go b/saga/saga.go index 32da9000a..29e4d0e0b 100644 --- a/saga/saga.go +++ b/saga/saga.go @@ -1,22 +1,5 @@ package saga -type SagaRecoveryType int - -/* - * Saga Recovery Types define how to interpret SagaState in RecoveryMode. - * - * ForwardRecovery: all tasks in the saga must be executed at least once. - * tasks MUST BE idempotent - * - * RollbackRecovery: if Saga is Aborted or in unsafe state, compensating - * tasks for all started tasks need to be executed. - * compensating tasks MUST BE idempotent. - */ -const ( - BackwardRecovery SagaRecoveryType = iota - ForwardRecovery -) - /* * Saga Object which provides all Saga Functionality * Implementations of SagaLog should provide a factory method @@ -56,26 +39,6 @@ func (s Saga) StartSaga(sagaId string, job []byte) (*SagaState, error) { return state, nil } -/* - * logs the specified message durably to the SagaLog & updates internal state if its a valid state transition - */ -func (s Saga) logMessage(state *SagaState, msg sagaMessage) (*SagaState, error) { - - //verify that the applied message results in a valid state - newState, err := updateSagaState(state, msg) - if err != nil { - return nil, err - } - - //try durably storing the message - err = s.log.LogMessage(msg) - if err != nil { - return nil, err - } - - return newState, nil -} - /* * Log an End Saga Message to the log, returns updated SagaState * Returns the resulting SagaState or an error if it fails @@ -148,3 +111,48 @@ func (s Saga) StartCompensatingTask(state *SagaState, taskId string, data []byte func (s Saga) EndCompensatingTask(state *SagaState, taskId string, results []byte) (*SagaState, error) { return s.logMessage(state, MakeEndCompTaskMessage(state.sagaId, taskId, results)) } + +/* + * Should be called at Saga Creation time. + * Returns a Slice of In Progress SagaIds + */ +func (s Saga) Startup() ([]string, error) { + + ids, err := s.log.GetActiveSagas() + if err != nil { + return nil, err + } + + return ids, nil +} + +/* + * Recovers SagaState by reading all logged messages from the log. + * Utilizes the specified recoveryType to determine if Saga needs to be + * Aborted or can proceed safely. + * + * Returns the current SagaState + */ +func (s Saga) RecoverSagaState(sagaId string, recoveryType SagaRecoveryType) (*SagaState, error) { + return recoverState(sagaId, s, recoveryType) +} + +/* + * logs the specified message durably to the SagaLog & updates internal state if its a valid state transition + */ +func (s Saga) logMessage(state *SagaState, msg sagaMessage) (*SagaState, error) { + + //verify that the applied message results in a valid state + newState, err := updateSagaState(state, msg) + if err != nil { + return nil, err + } + + //try durably storing the message + err = s.log.LogMessage(msg) + if err != nil { + return nil, err + } + + return newState, nil +} diff --git a/saga/sagaRecovery.go b/saga/sagaRecovery.go new file mode 100644 index 000000000..62bf22a56 --- /dev/null +++ b/saga/sagaRecovery.go @@ -0,0 +1,101 @@ +package saga + +import ( + "fmt" +) + +type SagaRecoveryType int + +/* + * Saga Recovery Types define how to interpret SagaState in RecoveryMode. + * + * ForwardRecovery: all tasks in the saga must be executed at least once. + * tasks MUST BE idempotent + * + * RollbackRecovery: if Saga is Aborted or in unsafe state, compensating + * tasks for all started tasks need to be executed. + * compensating tasks MUST BE idempotent. + */ +const ( + RollbackRecovery SagaRecoveryType = iota + ForwardRecovery +) + +/* + * Recovers SagaState from SagaLog messages + */ +func recoverState(sagaId string, saga Saga, recoveryType SagaRecoveryType) (*SagaState, error) { + + // Get Logged Messages For this Saga from the Log. + msgs, err := saga.log.GetMessages(sagaId) + if err != nil { + return nil, err + } + + if msgs == nil || len(msgs) == 0 { + return nil, nil + } + + // Reconstruct Saga State from Logged Messages + startMsg := msgs[0] + if startMsg.msgType != StartSaga { + return nil, fmt.Errorf("InvalidMessages: first message must be StartSaga") + } + + state, err := makeSagaState(sagaId, startMsg.data) + if err != nil { + return nil, err + } + + for _, msg := range msgs { + state, err = updateSagaState(state, msg) + if err != nil { + return nil, err + } + } + + // Check if we can safely proceed forward based on recovery method + // RollbackRecovery must check if in a SafeState, + // ForwardRecovery can always make progress + switch recoveryType { + + case RollbackRecovery: + + // if Saga is not in a safe state we must abort the saga + // And compensating tasks should start + if !isSagaInSafeState(state) { + state, err = saga.AbortSaga(state) + if err != nil { + return nil, err + } + } + + case ForwardRecovery: + + } + + return state, nil +} + +/* + * Returns true if saga is in a safe state, i.e. execution can pick up where + * it left off. This is only used in RollbackRecovery + * + * A Saga is in a Safe State if all StartedTasks also have EndTask Messages + * A Saga is also in a Safe State if the Saga has been aborted and compensating + * actions have started to be applied. + */ +func isSagaInSafeState(state *SagaState) bool { + + if state.IsSagaAborted() { + return true + } + + for taskId, _ := range state.taskState { + if state.IsTaskStarted(taskId) && !state.IsTaskCompleted(taskId) { + return false + } + } + + return true +} diff --git a/saga/sagaRecovery_test.go b/saga/sagaRecovery_test.go new file mode 100644 index 000000000..485f8aa43 --- /dev/null +++ b/saga/sagaRecovery_test.go @@ -0,0 +1,245 @@ +package saga + +import ( + "bytes" + "errors" + "github.com/golang/mock/gomock" + "testing" +) + +func TestRecoverState_GetMessagesReturnsError(t *testing.T) { + + sagaId := "sagaId" + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + sagaLogMock := NewMockSagaLog(mockCtrl) + sagaLogMock.EXPECT().GetMessages(sagaId).Return(nil, errors.New("test error")) + saga := MakeSaga(sagaLogMock) + + state, err := recoverState(sagaId, saga, RollbackRecovery) + + if err == nil { + t.Error("Expected GetMessages return error to cause recoverState to return an error") + } + + if state != nil { + t.Error("Expected returned SagaState to be nil when Error occurs") + } +} + +func TestRecoverState_GetMessagesReturnsEmptyList(t *testing.T) { + sagaId := "sagaId" + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + sagaLogMock := NewMockSagaLog(mockCtrl) + sagaLogMock.EXPECT().GetMessages(sagaId).Return(nil, nil) + saga := MakeSaga(sagaLogMock) + + state, err := recoverState(sagaId, saga, RollbackRecovery) + + if state != nil { + t.Error("Expected returned SagaState to be nil when no messages in SagaLog") + } + + if err != nil { + t.Error("Expect Error to be nil when no messages returned from SagaLog") + } +} + +func TestRecoverState_MissingStartMessage(t *testing.T) { + sagaId := "sagaId" + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + msgs := []sagaMessage{ + MakeEndSagaMessage(sagaId), + MakeStartSagaMessage(sagaId, nil), + } + + sagaLogMock := NewMockSagaLog(mockCtrl) + sagaLogMock.EXPECT().GetMessages(sagaId).Return(msgs, nil) + saga := MakeSaga(sagaLogMock) + + state, err := recoverState(sagaId, saga, RollbackRecovery) + + if err == nil { + t.Error("Expected error when StartSaga is not first message") + } + + if state != nil { + t.Error("Expect sagaState to be nil when error occurs") + } +} + +func TestRecoverState_UpdateSagaStateFails(t *testing.T) { + sagaId := "sagaId" + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + msgs := []sagaMessage{ + MakeStartSagaMessage(sagaId, nil), + MakeEndTaskMessage(sagaId, "task1", nil), + MakeEndSagaMessage(sagaId), + } + + sagaLogMock := NewMockSagaLog(mockCtrl) + sagaLogMock.EXPECT().GetMessages(sagaId).Return(msgs, nil) + saga := MakeSaga(sagaLogMock) + + state, err := recoverState(sagaId, saga, RollbackRecovery) + + if err == nil { + t.Error("Expected error when StartSaga is not first message") + } + + if state != nil { + t.Error("Expect sagaState to be nil when error occurs") + } +} + +func TestRecoverState_SuccessfulForwardRecovery(t *testing.T) { + sagaId := "sagaId" + taskId := "taskId" + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + msgs := []sagaMessage{ + MakeStartSagaMessage(sagaId, []byte{4, 5, 6}), + MakeStartTaskMessage(sagaId, taskId, []byte{1, 2, 3}), + } + + sagaLogMock := NewMockSagaLog(mockCtrl) + sagaLogMock.EXPECT().GetMessages(sagaId).Return(msgs, nil) + saga := MakeSaga(sagaLogMock) + + state, err := recoverState(sagaId, saga, ForwardRecovery) + + if err != nil { + t.Errorf("Expected error to be nil %s", err) + } + if state == nil { + t.Error("Expected state to reflect supplied messages") + } + + if state.SagaId() != sagaId { + t.Error("Expected SagaState to have same SagaId") + } + + if !bytes.Equal(state.Job(), []byte{4, 5, 6}) { + t.Error("Expected SagaState Job to match StartMessage data") + } + + if !state.IsTaskStarted(taskId) { + t.Error("Expected SagaState to have task started") + } + + if !bytes.Equal(state.GetStartTaskData(taskId), []byte{1, 2, 3}) { + t.Error("Expected SagaState to have data associatd with starttask") + } + + if state.IsTaskCompleted(taskId) { + t.Error("Expected SagaState to have task not completed") + } + + if state.GetEndTaskData(taskId) != nil { + t.Error("Expected no data associated with end task") + } + + if state.IsCompTaskStarted(taskId) { + t.Error("Expected SagaState to have comptask not started") + } + + if state.GetStartCompTaskData(taskId) != nil { + t.Error("Expected no data associated with start comp task") + } + + if state.IsCompTaskCompleted(taskId) { + t.Error("Expected SagaState to have comptask not completed") + } + + if state.GetEndCompTaskData(taskId) != nil { + t.Error("Expected no data associated with end comp task") + } + + if state.IsSagaCompleted() { + t.Error("Expected SagaState to not be completed") + } + + if state.IsSagaAborted() { + t.Error("Expected SagaState to not be aborted") + } +} + +func TestRecoverState_SuccessfulRollbackRecovery(t *testing.T) { + sagaId := "sagaId" + taskId := "taskId" + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + msgs := []sagaMessage{ + MakeStartSagaMessage(sagaId, []byte{4, 5, 6}), + MakeStartTaskMessage(sagaId, taskId, []byte{1, 2, 3}), + } + + sagaLogMock := NewMockSagaLog(mockCtrl) + sagaLogMock.EXPECT().GetMessages(sagaId).Return(msgs, nil) + sagaLogMock.EXPECT().LogMessage(MakeAbortSagaMessage(sagaId)) + saga := MakeSaga(sagaLogMock) + + state, err := recoverState(sagaId, saga, RollbackRecovery) + + if err != nil { + t.Errorf("Expected error to be nil %s", err) + } + if state == nil { + t.Error("Expected state to reflect supplied messages") + } + + if !state.IsSagaAborted() { + t.Error("Expected Saga to be Aborted, not in Safe State") + } +} + +func TestSafeState_AbortedSaga(t *testing.T) { + state := initializeSagaState() + state.sagaAborted = true + state.taskState["task1"] = TaskStarted | CompTaskStarted + + safeState := isSagaInSafeState(state) + + if !safeState { + t.Error("Expected Aborted Saga to be in Safe State") + } +} + +func TestSafeState_MissingEndTask(t *testing.T) { + state := initializeSagaState() + state.taskState["task1"] = TaskStarted + state.taskState["task2"] = TaskStarted | TaskCompleted + + safeState := isSagaInSafeState(state) + + if safeState { + t.Error("Expected Saga to be in unsafe state") + } +} + +func TestSafeState_Safe(t *testing.T) { + state := initializeSagaState() + state.taskState["task1"] = TaskStarted | TaskCompleted + state.taskState["task2"] = TaskStarted | TaskCompleted + + safeState := isSagaInSafeState(state) + + if !safeState { + t.Error("Expected Saga to be in safe state") + } +} diff --git a/saga/saga_test.go b/saga/saga_test.go index fa403ece8..118cb9485 100644 --- a/saga/saga_test.go +++ b/saga/saga_test.go @@ -18,17 +18,15 @@ func TestStartSaga(t *testing.T) { sagaLogMock := NewMockSagaLog(mockCtrl) sagaLogMock.EXPECT().StartSaga(id, job) - s := Saga{ - log: sagaLogMock, - } - + s := MakeSaga(sagaLogMock) state, err := s.StartSaga(id, job) + if err != nil { - t.Error(fmt.Sprintf("Expected StartSaga to not return an error")) + t.Error("Expected StartSaga to not return an error") } if state.SagaId() != id { - t.Error(fmt.Sprintf("Expected state.SagaId to equal 'testSaga'")) + t.Error("Expected state.SagaId to equal 'testSaga'") } } @@ -42,16 +40,14 @@ func TestStartSagaLogError(t *testing.T) { sagaLogMock := NewMockSagaLog(mockCtrl) sagaLogMock.EXPECT().StartSaga(id, job).Return(errors.New("Failed to Log StartSaga")) - s := Saga{ - log: sagaLogMock, - } + s := MakeSaga(sagaLogMock) state, err := s.StartSaga(id, job) if err == nil { - t.Error(fmt.Sprintf("Expected StartSaga to return error if SagaLog fails to log request")) + t.Error("Expected StartSaga to return error if SagaLog fails to log request") } if state != nil { - t.Error(fmt.Sprintf("Expected returned state to be nil when error occurs")) + t.Error("Expected returned state to be nil when error occurs") } } @@ -65,15 +61,12 @@ func TestEndSaga(t *testing.T) { sagaLogMock.EXPECT().StartSaga("testSaga", nil) sagaLogMock.EXPECT().LogMessage(entry) - s := Saga{ - log: sagaLogMock, - } - + s := MakeSaga(sagaLogMock) state, err := s.StartSaga("testSaga", nil) state, err = s.EndSaga(state) if err != nil { - t.Error(fmt.Sprintf("Expected EndSaga to not return an error")) + t.Error("Expected EndSaga to not return an error") } } @@ -87,14 +80,12 @@ func TestEndSagaLogError(t *testing.T) { sagaLogMock.EXPECT().StartSaga("testSaga", nil) sagaLogMock.EXPECT().LogMessage(entry).Return(errors.New("Failed to Log EndSaga Message")) - s := Saga{ - log: sagaLogMock, - } + s := MakeSaga(sagaLogMock) state, err := s.StartSaga("testSaga", nil) state, err = s.EndSaga(state) if err == nil { - t.Error(fmt.Sprintf("Expected EndSaga to not return an error when write to SagaLog Fails")) + t.Error("Expected EndSaga to not return an error when write to SagaLog Fails") } } @@ -108,14 +99,12 @@ func TestAbortSaga(t *testing.T) { sagaLogMock.EXPECT().StartSaga("testSaga", nil) sagaLogMock.EXPECT().LogMessage(entry) - s := Saga{ - log: sagaLogMock, - } + s := MakeSaga(sagaLogMock) state, err := s.StartSaga("testSaga", nil) state, err = s.AbortSaga(state) if err != nil { - t.Error(fmt.Sprintf("Expected AbortSaga to not return an error")) + t.Error("Expected AbortSaga to not return an error") } } @@ -129,14 +118,13 @@ func TestAbortSagaLogError(t *testing.T) { sagaLogMock.EXPECT().StartSaga("testSaga", nil) sagaLogMock.EXPECT().LogMessage(entry).Return(errors.New("Failed to Log AbortSaga Message")) - s := Saga{ - log: sagaLogMock, - } + s := MakeSaga(sagaLogMock) + state, err := s.StartSaga("testSaga", nil) state, err = s.AbortSaga(state) if err == nil { - t.Error(fmt.Sprintf("Expected AbortSaga to not return an error when write to SagaLog Fails")) + t.Error("Expected AbortSaga to not return an error when write to SagaLog Fails") } } @@ -150,14 +138,12 @@ func TestStartTask(t *testing.T) { sagaLogMock.EXPECT().StartSaga("testSaga", nil) sagaLogMock.EXPECT().LogMessage(entry) - s := Saga{ - log: sagaLogMock, - } + s := MakeSaga(sagaLogMock) state, err := s.StartSaga("testSaga", nil) state, err = s.StartTask(state, "task1", nil) if err != nil { - t.Error(fmt.Sprintf("Expected StartTask to not return an error")) + t.Error("Expected StartTask to not return an error") } } @@ -171,14 +157,13 @@ func TestStartTaskLogError(t *testing.T) { sagaLogMock.EXPECT().StartSaga("testSaga", nil) sagaLogMock.EXPECT().LogMessage(entry).Return(errors.New("Failed to Log StartTask Message")) - s := Saga{ - log: sagaLogMock, - } + s := MakeSaga(sagaLogMock) + state, err := s.StartSaga("testSaga", nil) state, err = s.StartTask(state, "task1", nil) if err == nil { - t.Error(fmt.Sprintf("Expected StartTask to not return an error when write to SagaLog Fails")) + t.Error("Expected StartTask to not return an error when write to SagaLog Fails") } } @@ -193,15 +178,14 @@ func TestEndTask(t *testing.T) { sagaLogMock.EXPECT().LogMessage(MakeStartTaskMessage("testSaga", "task1", nil)) sagaLogMock.EXPECT().LogMessage(entry) - s := Saga{ - log: sagaLogMock, - } + s := MakeSaga(sagaLogMock) + state, err := s.StartSaga("testSaga", nil) state, err = s.StartTask(state, "task1", nil) state, err = s.EndTask(state, "task1", nil) if err != nil { - t.Error(fmt.Sprintf("Expected EndTask to not return an error")) + t.Error("Expected EndTask to not return an error") } } @@ -216,15 +200,14 @@ func TestEndTaskLogError(t *testing.T) { sagaLogMock.EXPECT().LogMessage(MakeStartTaskMessage("testSaga", "task1", nil)) sagaLogMock.EXPECT().LogMessage(entry).Return(errors.New("Failed to Log EndTask Message")) - s := Saga{ - log: sagaLogMock, - } + s := MakeSaga(sagaLogMock) + state, err := s.StartSaga("testSaga", nil) state, err = s.StartTask(state, "task1", nil) state, err = s.EndTask(state, "task1", nil) if err == nil { - t.Error(fmt.Sprintf("Expected EndTask to not return an error when write to SagaLog Fails")) + t.Error("Expected EndTask to not return an error when write to SagaLog Fails") } } @@ -240,16 +223,15 @@ func TestStartCompTask(t *testing.T) { sagaLogMock.EXPECT().LogMessage(MakeAbortSagaMessage("testSaga")) sagaLogMock.EXPECT().LogMessage(entry) - s := Saga{ - log: sagaLogMock, - } + s := MakeSaga(sagaLogMock) + state, err := s.StartSaga("testSaga", nil) state, err = s.StartTask(state, "task1", nil) state, err = s.AbortSaga(state) state, err = s.StartCompensatingTask(state, "task1", nil) if err != nil { - t.Error(fmt.Sprintf("Expected StartCompensatingTask to not return an error")) + t.Error("Expected StartCompensatingTask to not return an error") } } @@ -265,16 +247,15 @@ func TestStartCompTaskLogError(t *testing.T) { sagaLogMock.EXPECT().LogMessage(MakeAbortSagaMessage("testSaga")) sagaLogMock.EXPECT().LogMessage(entry).Return(errors.New("Failed to Log StartCompTask Message")) - s := Saga{ - log: sagaLogMock, - } + s := MakeSaga(sagaLogMock) + state, err := s.StartSaga("testSaga", nil) state, err = s.StartTask(state, "task1", nil) state, err = s.AbortSaga(state) state, err = s.StartCompensatingTask(state, "task1", nil) if err == nil { - t.Error(fmt.Sprintf("Expected StartCompTask to not return an error when write to SagaLog Fails")) + t.Error("Expected StartCompTask to not return an error when write to SagaLog Fails") } } @@ -291,9 +272,8 @@ func TestEndCompTask(t *testing.T) { sagaLogMock.EXPECT().LogMessage(MakeStartCompTaskMessage("testSaga", "task1", nil)) sagaLogMock.EXPECT().LogMessage(entry) - s := Saga{ - log: sagaLogMock, - } + s := MakeSaga(sagaLogMock) + state, err := s.StartSaga("testSaga", nil) state, err = s.StartTask(state, "task1", nil) state, err = s.AbortSaga(state) @@ -301,7 +281,7 @@ func TestEndCompTask(t *testing.T) { state, err = s.EndCompensatingTask(state, "task1", nil) if err != nil { - t.Error(fmt.Sprintf("Expected EndCompensatingTask to not return an error")) + t.Error("Expected EndCompensatingTask to not return an error") } } @@ -318,9 +298,8 @@ func TestEndCompTaskLogError(t *testing.T) { sagaLogMock.EXPECT().LogMessage(MakeStartCompTaskMessage("testSaga", "task1", nil)) sagaLogMock.EXPECT().LogMessage(entry).Return(errors.New("Failed to Log EndCompTask Message")) - s := Saga{ - log: sagaLogMock, - } + s := MakeSaga(sagaLogMock) + state, err := s.StartSaga("testSaga", nil) state, err = s.StartTask(state, "task1", nil) state, err = s.AbortSaga(state) @@ -328,6 +307,107 @@ func TestEndCompTaskLogError(t *testing.T) { state, err = s.EndCompensatingTask(state, "task1", nil) if err == nil { - t.Error(fmt.Sprintf("Expected EndCompTask to not return an error when write to SagaLog Fails")) + t.Error("Expected EndCompTask to not return an error when write to SagaLog Fails") + } +} + +func TestStartup_ReturnsError(t *testing.T) { + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + sagaLogMock := NewMockSagaLog(mockCtrl) + sagaLogMock.EXPECT().GetActiveSagas().Return(nil, errors.New("test error")) + + s := MakeSaga(sagaLogMock) + ids, err := s.Startup() + + if err == nil { + t.Error("Expected error to not be nil") + } + if ids != nil { + t.Error("ids should be null when error is returned") + } +} + +func TestStartup_ReturnsIds(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + sagaLogMock := NewMockSagaLog(mockCtrl) + sagaLogMock.EXPECT().GetActiveSagas().Return([]string{"saga1", "saga2", "saga3"}, nil) + + s := MakeSaga(sagaLogMock) + ids, err := s.Startup() + + if err != nil { + t.Error(fmt.Sprintf("unexpected error returned %s", err)) + } + if ids == nil { + t.Error("expected is to be returned") + } + + expectedIds := make(map[string]bool) + expectedIds["saga1"] = true + expectedIds["saga2"] = true + expectedIds["saga3"] = true + + for _, id := range ids { + if !expectedIds[id] { + t.Error(fmt.Sprintf("unexpectedId returend %s", id)) + } + } +} + +func TestRecoverSagaState(t *testing.T) { + + sagaId := "saga1" + taskId := "task1" + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + msgs := []sagaMessage{ + MakeStartSagaMessage(sagaId, nil), + MakeStartTaskMessage(sagaId, taskId, nil), + MakeEndTaskMessage(sagaId, taskId, nil), + MakeEndSagaMessage(sagaId), + } + + sagaLogMock := NewMockSagaLog(mockCtrl) + sagaLogMock.EXPECT().GetMessages(sagaId).Return(msgs, nil) + + s := MakeSaga(sagaLogMock) + state, err := s.RecoverSagaState(sagaId, ForwardRecovery) + + if err != nil { + t.Error(fmt.Sprintf("unexpected error returned %s", err)) + } + if state == nil { + t.Error("expected returned state to not be nil") + } + + if !state.IsSagaCompleted() { + t.Error("expected returned saga state to be completed saga") + } +} + +func TestRecoverSagaState_ReturnsError(t *testing.T) { + sagaId := "saga1" + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + sagaLogMock := NewMockSagaLog(mockCtrl) + sagaLogMock.EXPECT().GetMessages(sagaId).Return(nil, errors.New("test error")) + + s := MakeSaga(sagaLogMock) + state, err := s.RecoverSagaState(sagaId, RollbackRecovery) + + if err == nil { + t.Error("expeceted error to not be nil") + } + + if state != nil { + t.Error("expected returned state to be nil when error occurs") } } diff --git a/saga/sagalog.go b/saga/sagalog.go index 9f649f578..0cab7ae1e 100644 --- a/saga/sagalog.go +++ b/saga/sagalog.go @@ -24,4 +24,12 @@ type SagaLog interface { * specified saga. */ GetMessages(sagaId string) ([]sagaMessage, error) + + /* + * Returns a list of all in progress sagaIds. + * This MUST include all not completed sagaIds. + * It may also included completed sagas + * Returns an error if it fails. + */ + GetActiveSagas() ([]string, error) } diff --git a/saga/sagalog_mock.go b/saga/sagalog_mock.go index 8045d42ea..319eabcb9 100644 --- a/saga/sagalog_mock.go +++ b/saga/sagalog_mock.go @@ -58,3 +58,14 @@ func (_m *MockSagaLog) GetMessages(sagaId string) ([]sagaMessage, error) { func (_mr *_MockSagaLogRecorder) GetMessages(arg0 interface{}) *gomock.Call { return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMessages", arg0) } + +func (_m *MockSagaLog) GetActiveSagas() ([]string, error) { + ret := _m.ctrl.Call(_m, "GetActiveSagas") + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockSagaLogRecorder) GetActiveSagas() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "GetActiveSagas") +} diff --git a/sched/demo/main.go b/sched/demo/main.go index 803e8ef66..bf084b9cf 100644 --- a/sched/demo/main.go +++ b/sched/demo/main.go @@ -2,13 +2,15 @@ package main import ( "fmt" - "sync" msg "github.com/scootdev/scoot/messages" - saga "github.com/scootdev/scoot/saga" + s "github.com/scootdev/scoot/saga" ci "github.com/scootdev/scoot/sched/clusterimplementations" cm "github.com/scootdev/scoot/sched/clustermembership" distributor "github.com/scootdev/scoot/sched/distributor" + + "os" + "sync" ) /* demo code */ @@ -20,6 +22,7 @@ func main() { workCh := make(chan msg.Job) distributor := &distributor.RoundRobin{} + saga := s.MakeInMemorySaga() var wg sync.WaitGroup wg.Add(2) @@ -30,21 +33,50 @@ func main() { }() go func() { - scheduleWork(workCh, cluster, distributor) + scheduleWork(workCh, cluster, distributor, saga) wg.Done() }() wg.Wait() + + ids, err := saga.Startup() + + // we are using an in memory saga here if we can't get the active sagas something is + // very wrong just exit the program. + if err != nil { + fmt.Println("ERROR getting active sagas ", err) + os.Exit(2) + } + + completedSagas := 0 + + for _, sagaId := range ids { + + sagaState, err := saga.RecoverSagaState(sagaId, s.ForwardRecovery) + if err != nil { + // For now just print error in actual scheduler we'd want to retry multiple times, + // before putting it on a deadletter queue + fmt.Println(fmt.Sprintf("ERROR recovering saga state for %s: %s", sagaId, err)) + } + + // all Sagas are expected to be completed + if !sagaState.IsSagaCompleted() { + fmt.Println(fmt.Sprintf("Expected all Sagas to be Completed %s is not", sagaId)) + } else { + completedSagas++ + } + } + + fmt.Println("Jobs Completed:", completedSagas) } func scheduleWork( workCh <-chan msg.Job, cluster cm.Cluster, - distributor distributor.Distributor) { + distributor distributor.Distributor, + saga s.Saga) { var wg sync.WaitGroup - saga := saga.MakeInMemorySaga() - for work := range workCh { node := distributor.DistributeWork(work, cluster)