From 56c79b265d93ec2c423ac5820db389b6a6f22c49 Mon Sep 17 00:00:00 2001 From: Matt Fellows Date: Sat, 14 Jul 2018 19:30:51 +1000 Subject: [PATCH] fix(tests): remove race condition in service manager. Fixes #89 --- client/service_manager.go | 37 +++++++++++++++++++++----- client/service_test.go | 56 ++++++++++++++++++++++++++------------- 2 files changed, 69 insertions(+), 24 deletions(-) diff --git a/client/service_manager.go b/client/service_manager.go index 606c39e29..9d0bf2ae6 100644 --- a/client/service_manager.go +++ b/client/service_manager.go @@ -5,13 +5,14 @@ import ( "log" "os" "os/exec" + "sync" "time" ) // ServiceManager is the default implementation of the Service interface. type ServiceManager struct { Cmd string - processes map[int]*exec.Cmd + processMap processMap Args []string Env []string commandCompleteChan chan *exec.Cmd @@ -24,7 +25,7 @@ func (s *ServiceManager) Setup() { s.commandCreatedChan = make(chan *exec.Cmd) s.commandCompleteChan = make(chan *exec.Cmd) - s.processes = make(map[int]*exec.Cmd) + s.processMap = processMap{processes: make(map[int]*exec.Cmd)} // Listen for service create/kill go s.addServiceMonitor() @@ -38,7 +39,7 @@ func (s *ServiceManager) addServiceMonitor() { select { case p := <-s.commandCreatedChan: if p != nil && p.Process != nil { - s.processes[p.Process.Pid] = p + s.processMap.Set(p.Process.Pid, p) } } } @@ -53,7 +54,7 @@ func (s *ServiceManager) removeServiceMonitor() { case p = <-s.commandCompleteChan: if p != nil && p.Process != nil { p.Process.Signal(os.Interrupt) - delete(s.processes, p.Process.Pid) + s.processMap.Delete(p.Process.Pid) } } } @@ -62,7 +63,7 @@ func (s *ServiceManager) removeServiceMonitor() { // Stop a Service and returns the exit status. func (s *ServiceManager) Stop(pid int) (bool, error) { log.Println("[DEBUG] stopping service with pid", pid) - cmd := s.processes[pid] + cmd := s.processMap.Get(pid) // Remove service from registry go func() { @@ -96,7 +97,7 @@ func (s *ServiceManager) Stop(pid int) (bool, error) { // List all Service PIDs. func (s *ServiceManager) List() map[int]*exec.Cmd { log.Println("[DEBUG] listing services") - return s.processes + return s.processMap.processes } // Command executes the command @@ -151,3 +152,27 @@ func (s *ServiceManager) Start() *exec.Cmd { return cmd } + +type processMap struct { + sync.RWMutex + processes map[int]*exec.Cmd +} + +func (pm *processMap) Get(k int) *exec.Cmd { + pm.RLock() + defer pm.RUnlock() + v, _ := pm.processes[k] + return v +} + +func (pm *processMap) Set(k int, v *exec.Cmd) { + pm.Lock() + defer pm.Unlock() + pm.processes[k] = v +} + +func (pm *processMap) Delete(k int) { + pm.Lock() + defer pm.Unlock() + delete(pm.processes, k) +} diff --git a/client/service_test.go b/client/service_test.go index 34e4c264a..5b9be9ae6 100644 --- a/client/service_test.go +++ b/client/service_test.go @@ -48,22 +48,25 @@ func TestServiceManager_removeServiceMonitor(t *testing.T) { mgr := createServiceManager() cmd := fakeExecCommand("", true, "") cmd.Start() - mgr.processes = map[int]*exec.Cmd{ + mgr.processMap.processes = map[int]*exec.Cmd{ cmd.Process.Pid: cmd, } mgr.commandCompleteChan <- cmd var timeout = time.After(channelTimeout) for { + mgr.processMap.Lock() + defer mgr.processMap.Unlock() + select { case <-time.After(10 * time.Millisecond): - if len(mgr.processes) == 0 { + if len(mgr.processMap.processes) == 0 { return } case <-timeout: - if len(mgr.processes) != 0 { + if len(mgr.processMap.processes) != 0 { t.Fatalf(`Expected 1 command to be removed from the queue. Have %d - Timed out after 500millis`, len(mgr.processes)) + Timed out after 500millis`, len(mgr.processMap.processes)) } } } @@ -77,15 +80,20 @@ func TestServiceManager_addServiceMonitor(t *testing.T) { var timeout = time.After(channelTimeout) for { + select { case <-time.After(10 * time.Millisecond): - if len(mgr.processes) == 1 { + mgr.processMap.Lock() + defer mgr.processMap.Unlock() + if len(mgr.processMap.processes) == 1 { return } case <-timeout: - if len(mgr.processes) != 1 { + mgr.processMap.Lock() + defer mgr.processMap.Unlock() + if len(mgr.processMap.processes) != 1 { t.Fatalf(`Expected 1 command to be added to the queue, but got: %d. - Timed out after 500millis`, len(mgr.processes)) + Timed out after 500millis`, len(mgr.processMap.processes)) } return } @@ -99,16 +107,18 @@ func TestServiceManager_addServiceMonitorWithDeadJob(t *testing.T) { var timeout = time.After(channelTimeout) for { + select { case <-time.After(10 * time.Millisecond): - if len(mgr.processes) != 0 { + + if len(mgr.processMap.processes) != 0 { t.Fatalf(`Expected 0 command to be added to the queue, but got: %d. - Timed out after 5 attempts`, len(mgr.processes)) + Timed out after 5 attempts`, len(mgr.processMap.processes)) } case <-timeout: - if len(mgr.processes) != 0 { + if len(mgr.processMap.processes) != 0 { t.Fatalf(`Expected 0 command to be added to the queue, but got: %d. - Timed out after 50millis`, len(mgr.processes)) + Timed out after 50millis`, len(mgr.processMap.processes)) } return } @@ -119,20 +129,23 @@ func TestServiceManager_Stop(t *testing.T) { mgr := createServiceManager() cmd := fakeExecCommand("", true, "") cmd.Start() - mgr.processes = map[int]*exec.Cmd{ + mgr.processMap.processes = map[int]*exec.Cmd{ cmd.Process.Pid: cmd, } mgr.Stop(cmd.Process.Pid) var timeout = time.After(channelTimeout) for { + mgr.processMap.Lock() + defer mgr.processMap.Unlock() + select { case <-time.After(10 * time.Millisecond): - if len(mgr.processes) == 0 { + if len(mgr.processMap.processes) == 0 { return } case <-timeout: - if len(mgr.processes) != 0 { + if len(mgr.processMap.processes) != 0 { t.Fatalf(`Expected 1 command to be removed from the queue. Timed out after 500millis`) } @@ -148,7 +161,9 @@ func TestServiceManager_List(t *testing.T) { processes := map[int]*exec.Cmd{ cmd.Process.Pid: cmd, } - mgr.processes = processes + mgr.processMap.Lock() + mgr.processMap.processes = processes + mgr.processMap.Unlock() if !reflect.DeepEqual(processes, mgr.List()) { t.Fatalf("Expected mgr.List() to equal processes") @@ -161,15 +176,20 @@ func TestServiceManager_Start(t *testing.T) { var timeout = time.After(channelTimeout) for { + select { case <-time.After(10 * time.Millisecond): - if len(mgr.processes) == 1 { + mgr.processMap.Lock() + if len(mgr.processMap.processes) == 1 { + mgr.processMap.Unlock() return } case <-timeout: - if len(mgr.processes) != 1 { + mgr.processMap.Lock() + defer mgr.processMap.Unlock() + if len(mgr.processMap.processes) != 1 { t.Fatalf(`Expected 1 command to be added to the queue, but got: %d. - Timed out after 500millis`, len(mgr.processes)) + Timed out after 500millis`, len(mgr.processMap.processes)) } return }