diff --git a/Makefile b/Makefile index 5505c548d..e65975371 100644 --- a/Makefile +++ b/Makefile @@ -68,7 +68,7 @@ build-all: GOOS=windows go build -o receptor.exe ./cmd/receptor-cl && \ GOOS=darwin go build -o receptor.app ./cmd/receptor-cl && \ go build example/*.go && \ - go build -o receptor --tags no_controlsvc,no_backends,no_services,no_tls_config,no_workceptor,no_cert_auth ./cmd/receptor-cl && \ + go build -o receptor --tags no_backends,no_services,no_tls_config,no_workceptor,no_cert_auth ./cmd/receptor-cl && \ go build -o receptor ./cmd/receptor-cl DIST := receptor_$(shell echo '$(VERSION)' | sed 's/^v//')_$(GOOS)_$(GOARCH) diff --git a/pkg/controlsvc/connect.go b/pkg/controlsvc/connect.go index 13f599c1e..db3a9912f 100644 --- a/pkg/controlsvc/connect.go +++ b/pkg/controlsvc/connect.go @@ -83,7 +83,7 @@ func (c *connectCommand) ControlFunc(_ context.Context, nc NetceptorForControlCo if err != nil { return nil, err } - err = cfo.BridgeConn("Connecting\n", rc, "connected service", nc.GetLogger()) + err = cfo.BridgeConn("Connecting\n", rc, "connected service", nc.GetLogger(), &Util{}) if err != nil { return nil, err } diff --git a/pkg/controlsvc/controlsvc.go b/pkg/controlsvc/controlsvc.go index 18e591bbc..90bac58c4 100644 --- a/pkg/controlsvc/controlsvc.go +++ b/pkg/controlsvc/controlsvc.go @@ -73,24 +73,20 @@ type Tlser interface { NewListener(inner net.Listener, config *tls.Config) net.Listener } -type Tls struct{} +type TLS struct{} -func (t *Tls) NewListener(inner net.Listener, config *tls.Config) net.Listener { +func (t *TLS) NewListener(inner net.Listener, config *tls.Config) net.Listener { return tls.NewListener(inner, config) } // SockControl implements the ControlFuncOperations interface that is passed back to control functions. type SockControl struct { - conn net.Conn - utils Utiler - io Copier + conn net.Conn } -func NewSockControl(conn net.Conn, utils Utiler, copier Copier) *SockControl { +func NewSockControl(conn net.Conn) *SockControl { return &SockControl{ - conn: conn, - utils: utils, - io: copier, + conn: conn, } } @@ -98,7 +94,7 @@ func (s *SockControl) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } -// WriteMessage attempts to write a message to a connection +// WriteMessage attempts to write a message to a connection. func (s *SockControl) WriteMessage(message string) error { if message != "" { _, err := s.conn.Write([]byte(message)) @@ -106,25 +102,26 @@ func (s *SockControl) WriteMessage(message string) error { return err } } + return nil } // BridgeConn bridges the socket to another socket. -func (s *SockControl) BridgeConn(message string, bc io.ReadWriteCloser, bcName string, logger *logger.ReceptorLogger) error { +func (s *SockControl) BridgeConn(message string, bc io.ReadWriteCloser, bcName string, logger *logger.ReceptorLogger, utils Utiler) error { if err := s.WriteMessage(message); err != nil { return err } - s.utils.BridgeConns(s.conn, "control service", bc, bcName, logger) + utils.BridgeConns(s.conn, "control service", bc, bcName, logger) return nil } // ReadFromConn copies from the socket to an io.Writer, until EOF. -func (s *SockControl) ReadFromConn(message string, out io.Writer) error { +func (s *SockControl) ReadFromConn(message string, out io.Writer, io Copier) error { if err := s.WriteMessage(message); err != nil { return err } - if _, err := s.io.Copy(out, s.conn); err != nil { + if _, err := io.Copy(out, s.conn); err != nil { return err } @@ -155,10 +152,9 @@ type Server struct { nc NetceptorForControlsvc controlFuncLock sync.RWMutex controlTypes map[string]ControlCommandType - // new stuff - serverUtils Utiler - serverNet Neter - serverTls Tlser + serverUtils Utiler + serverNet Neter + serverTLS Tlser } // New returns a new instance of a control service. @@ -169,7 +165,7 @@ func New(stdServices bool, nc NetceptorForControlsvc) *Server { controlTypes: make(map[string]ControlCommandType), serverUtils: &Util{}, serverNet: &Net{}, - serverTls: &Tls{}, + serverTLS: &TLS{}, } if stdServices { s.controlTypes["ping"] = &pingCommandType{} @@ -190,8 +186,8 @@ func (s *Server) SetServerNet(n Neter) { s.serverNet = n } -func (s *Server) SetServerTls(t Tlser) { - s.serverTls = t +func (s *Server) SetServerTLS(t Tlser) { + s.serverTLS = t } // MainInstance is the global instance of the control service instantiated by the command-line main() function. @@ -217,11 +213,13 @@ func errorNormal(nc NetceptorForControlsvc, logMessage string, err error) bool { if !strings.HasSuffix(err.Error(), normalCloseError) { nc.GetLogger().Error("%s: %s\n", logMessage, err) } + return true } func writeToConnWithLog(conn net.Conn, nc NetceptorForControlsvc, writeMessage string, logMessage string) bool { _, err := conn.Write([]byte(writeMessage)) + return errorNormal(nc, logMessage, err) } @@ -318,7 +316,7 @@ func (s *Server) RunControlSession(conn net.Conn) { } s.controlFuncLock.RUnlock() if ct != nil { - cfo := NewSockControl(conn, &Util{}, &SocketConnIO{}) + cfo := NewSockControl(conn) var cfr map[string]interface{} var cc ControlCommand @@ -435,7 +433,7 @@ func (s *Server) RunControlSvc(ctx context.Context, service string, tlscfg *tls. return fmt.Errorf("error listening on TCP socket: %s", err) } if tcptls != nil { - tli = s.serverTls.NewListener(tli, tcptls) + tli = s.serverTLS.NewListener(tli, tcptls) } } else { tli = nil @@ -469,12 +467,10 @@ func (s *Server) RunControlSvc(ctx context.Context, service string, tlscfg *tls. } }() for _, listener := range []net.Listener{uli, tli, li} { - if reflect.ValueOf(listener).IsNil() { + if listener == nil || reflect.ValueOf(listener).IsNil() { continue } - if listener != nil { - go s.ConnectionListener(ctx, listener) - } + go s.ConnectionListener(ctx, listener) } return nil diff --git a/pkg/controlsvc/controlsvc_test.go b/pkg/controlsvc/controlsvc_test.go index 81df8b787..c6fe49750 100644 --- a/pkg/controlsvc/controlsvc_test.go +++ b/pkg/controlsvc/controlsvc_test.go @@ -19,7 +19,7 @@ const ( writeToConnError = "write to conn write message err" ) -func printExpectedError(t *testing.T, err error) { +func printErrorMessage(t *testing.T, err error) { t.Errorf("expected error %s", err) } @@ -34,16 +34,13 @@ func TestConnectionListener(t *testing.T) { expectedError bool expectedCalls func(context.CancelFunc) }{ - { - name: "return from context error", - expectedError: true, - }, { name: "error accepting connection", expectedError: false, expectedCalls: func(ctxCancel context.CancelFunc) { mockListener.EXPECT().Accept().DoAndReturn(func() (net.Conn, error) { ctxCancel() + return nil, errors.New("terminated") }) mockNetceptor.EXPECT().GetLogger().Return(logger) @@ -56,19 +53,11 @@ func TestConnectionListener(t *testing.T) { ctx, ctxCancel := context.WithCancel(context.Background()) defer ctxCancel() - if testCase.expectedCalls != nil { - testCase.expectedCalls(ctxCancel) - } + testCase.expectedCalls(ctxCancel) s := controlsvc.New(false, mockNetceptor) - - if testCase.expectedError { - ctxCancel() - } - s.ConnectionListener(ctx, mockListener) }) } - } func TestSetupConnection(t *testing.T) { @@ -79,12 +68,10 @@ func TestSetupConnection(t *testing.T) { setupConnectionTestCases := []struct { name string - expectedError bool expectedCalls func() }{ { - name: "log error - setting timeout", - expectedError: true, + name: "log error - setting timeout", expectedCalls: func() { mockConn.EXPECT().SetDeadline(gomock.Any()).Return(errors.New("terminated")) mockNetceptor.EXPECT().GetLogger().Return(logger) @@ -92,8 +79,7 @@ func TestSetupConnection(t *testing.T) { }, }, { - name: "log error - tls handshake", - expectedError: true, + name: "log error - tls handshake", expectedCalls: func() { mockConn.EXPECT().SetDeadline(gomock.Any()).Return(nil) mockNetceptor.EXPECT().GetLogger().Return(logger) @@ -163,6 +149,9 @@ func TestRunControlSvc(t *testing.T) { { name: "no listeners error", expectedError: "no listeners specified", + expectedCalls: func() { + // empty func for testing + }, listeners: map[string]string{ "service": "", "unixSocket": "", @@ -173,16 +162,14 @@ func TestRunControlSvc(t *testing.T) { for _, testCase := range runControlSvcTestCases { t.Run(testCase.name, func(t *testing.T) { - if testCase.expectedCalls != nil { - testCase.expectedCalls() - } + testCase.expectedCalls() s := controlsvc.New(false, mockNetceptor) s.SetServerUtils(mockUnix) s.SetServerNet(mockNet) err := s.RunControlSvc(context.Background(), testCase.listeners["service"], &tls.Config{}, testCase.listeners["unixSocket"], os.FileMode(0o600), testCase.listeners["tcpListen"], &tls.Config{}) - if err == nil || err.Error() != testCase.expectedError { + if err.Error() != testCase.expectedError { t.Errorf("expected error %s, got %v", testCase.expectedError, err) } }) @@ -194,9 +181,7 @@ func TestSockControlRemoteAddr(t *testing.T) { mockCon := mock_controlsvc.NewMockConn(ctrl) mockAddr := mock_controlsvc.NewMockAddr(ctrl) - mockUtil := mock_controlsvc.NewMockUtiler(ctrl) - mockCopier := mock_controlsvc.NewMockCopier(ctrl) - sockControl := controlsvc.NewSockControl(mockCon, mockUtil, mockCopier) + sockControl := controlsvc.NewSockControl(mockCon) localhost := "127.0.0.1" @@ -212,41 +197,51 @@ func TestSockControlRemoteAddr(t *testing.T) { func TestSockControlWriteMessage(t *testing.T) { ctrl := gomock.NewController(t) mockCon := mock_controlsvc.NewMockConn(ctrl) - mockUtil := mock_controlsvc.NewMockUtiler(ctrl) - mockCopier := mock_controlsvc.NewMockCopier(ctrl) - - sockControl := controlsvc.NewSockControl(mockCon, mockUtil, mockCopier) + sockControl := controlsvc.NewSockControl(mockCon) writeMessageTestCases := []struct { name string message string + expectedError bool expectedCalls func() }{ { - name: "without message", - message: "", + name: "pass without message", + message: "", + expectedError: false, + expectedCalls: func() { + // empty func for testing + }, }, { - name: "with message", - message: "message", + name: "fail with message", + message: "message", + expectedError: true, expectedCalls: func() { mockCon.EXPECT().Write(gomock.Any()).Return(0, errors.New("cannot write message")) }, }, + { + name: "pass with message", + message: "message", + expectedError: false, + expectedCalls: func() { + mockCon.EXPECT().Write(gomock.Any()).Return(0, nil) + }, + }, } for _, testCase := range writeMessageTestCases { t.Run(testCase.name, func(t *testing.T) { - if testCase.expectedCalls != nil { - testCase.expectedCalls() - } + testCase.expectedCalls() err := sockControl.WriteMessage(testCase.message) - if testCase.message == "" && err != nil { - t.Errorf("should be nil") + if !testCase.expectedError && err != nil { + t.Errorf("write message ran unsuccessfully %s", err) } - if testCase.message != "" && err.Error() != "cannot write message" { - t.Errorf("%s %s", testCase.name, err) + + if testCase.expectedError && err.Error() != "cannot write message" { + printErrorMessage(t, err) } }) } @@ -256,10 +251,9 @@ func TestSockControlBridgeConn(t *testing.T) { ctrl := gomock.NewController(t) mockCon := mock_controlsvc.NewMockConn(ctrl) mockUtil := mock_controlsvc.NewMockUtiler(ctrl) - mockCopier := mock_controlsvc.NewMockCopier(ctrl) logger := logger.NewReceptorLogger("") - sockControl := controlsvc.NewSockControl(mockCon, mockUtil, mockCopier) + sockControl := controlsvc.NewSockControl(mockCon) bridgeConnTestCases := []struct { name string @@ -285,13 +279,13 @@ func TestSockControlBridgeConn(t *testing.T) { for _, testCase := range bridgeConnTestCases { t.Run(testCase.name, func(t *testing.T) { testCase.expectedCalls() - err := sockControl.BridgeConn(testCase.message, mockCon, "test", logger) + err := sockControl.BridgeConn(testCase.message, mockCon, "test", logger, mockUtil) if testCase.message == "" && err != nil { - t.Errorf("should be nil") + t.Errorf("bridge conn ran unsuccessfully") } if testCase.message != "" && err.Error() != "terminated" { - t.Errorf("stuff %v", err) + t.Errorf("write message error for bridge conn %v", err) } }) } @@ -300,10 +294,9 @@ func TestSockControlBridgeConn(t *testing.T) { func TestSockControlReadFromConn(t *testing.T) { ctrl := gomock.NewController(t) mockCon := mock_controlsvc.NewMockConn(ctrl) - mockUtil := mock_controlsvc.NewMockUtiler(ctrl) mockCopier := mock_controlsvc.NewMockCopier(ctrl) - sockControl := controlsvc.NewSockControl(mockCon, mockUtil, mockCopier) + sockControl := controlsvc.NewSockControl(mockCon) bridgeConnTestCases := []struct { name string @@ -344,16 +337,14 @@ func TestSockControlReadFromConn(t *testing.T) { for _, testCase := range bridgeConnTestCases { t.Run(testCase.name, func(t *testing.T) { testCase.expectedCalls() - err := sockControl.ReadFromConn(testCase.message, mockCon) + err := sockControl.ReadFromConn(testCase.message, mockCon, mockCopier) - if testCase.expectedError { - if err == nil && err.Error() != testCase.errorMessage { - printExpectedError(t, err) - } - } else { - if err != nil { - printExpectedError(t, err) - } + if testCase.expectedError && err.Error() != testCase.errorMessage { + printErrorMessage(t, err) + } + + if !testCase.expectedError && err != nil { + printErrorMessage(t, err) } }) } @@ -362,10 +353,7 @@ func TestSockControlReadFromConn(t *testing.T) { func TestSockControlWriteToConn(t *testing.T) { ctrl := gomock.NewController(t) mockCon := mock_controlsvc.NewMockConn(ctrl) - mockUtil := mock_controlsvc.NewMockUtiler(ctrl) - mockCopier := mock_controlsvc.NewMockCopier(ctrl) - - sockControl := controlsvc.NewSockControl(mockCon, mockUtil, mockCopier) + sockControl := controlsvc.NewSockControl(mockCon) bridgeConnTestCases := []struct { name string @@ -415,14 +403,12 @@ func TestSockControlWriteToConn(t *testing.T) { err := sockControl.WriteToConn(testCase.message, c) - if testCase.expectedError { - if err == nil && err.Error() != testCase.errorMessage { - printExpectedError(t, err) - } - } else { - if err != nil { - printExpectedError(t, err) - } + if testCase.expectedError && err.Error() != testCase.errorMessage { + printErrorMessage(t, err) + } + + if !testCase.expectedError && err != nil { + printErrorMessage(t, err) } }) } @@ -431,10 +417,7 @@ func TestSockControlWriteToConn(t *testing.T) { func TestSockControlClose(t *testing.T) { ctrl := gomock.NewController(t) mockCon := mock_controlsvc.NewMockConn(ctrl) - mockUtil := mock_controlsvc.NewMockUtiler(ctrl) - mockCopier := mock_controlsvc.NewMockCopier(ctrl) - - sockControl := controlsvc.NewSockControl(mockCon, mockUtil, mockCopier) + sockControl := controlsvc.NewSockControl(mockCon) errorMessage := "cannot close connection" @@ -442,7 +425,7 @@ func TestSockControlClose(t *testing.T) { err := sockControl.Close() if err == nil && err.Error() != errorMessage { - printExpectedError(t, err) + printErrorMessage(t, err) } } @@ -451,17 +434,15 @@ func TestAddControlFunc(t *testing.T) { mockCtrlCmd := mock_controlsvc.NewMockControlCommandType(ctrl) mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl) controlFuncTestsCases := []struct { - name string - input string - expectedError bool - errorMessage string - testCase func(msg string, err error) + name string + input string + errorMessage string + testCase func(msg string, err error) }{ { - name: "ping command", - input: "ping", - expectedError: true, - errorMessage: "control function named ping already exists", + name: "ping command", + input: "ping", + errorMessage: "control function named ping already exists", testCase: func(msg string, err error) { if msg != err.Error() { t.Errorf("expected error: %s, received: %s", msg, err) @@ -469,9 +450,8 @@ func TestAddControlFunc(t *testing.T) { }, }, { - name: "obliterate command", - input: "obliterate", - expectedError: false, + name: "obliterate command", + input: "obliterate", testCase: func(msg string, err error) { if err != nil { t.Errorf("error should be nil. received %s", err) @@ -505,17 +485,12 @@ func TestRunControlSession(t *testing.T) { runControlSessionTestCases := []struct { name string - message string - input chan []byte expectedCalls func() - expectedError bool - errorMessage string }{ { name: "logger warning - could not close connection", expectedCalls: func() { mockCon.EXPECT().Write(gomock.Any()).Return(0, nil) - // meh mockCon.EXPECT().Read(make([]byte, 1)).Return(0, io.EOF) mockCon.EXPECT().Close().Return(errors.New("test")) mockNetceptor.EXPECT().GetLogger().Return(logger) @@ -527,7 +502,6 @@ func TestRunControlSession(t *testing.T) { mockCon.EXPECT().Write(gomock.Any()).Return(0, errors.New("test")) mockCon.EXPECT().Close() }, - errorMessage: "Could not write in control service: test", }, { name: "logger debug - control service closed", @@ -559,24 +533,18 @@ func TestRunControlSession(t *testing.T) { func TestRunControlSessionTwo(t *testing.T) { ctrl := gomock.NewController(t) - mockCon := mock_controlsvc.NewMockConn(ctrl) mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl) logger := logger.NewReceptorLogger("") runControlSessionTestCases := []struct { name string - message string - input chan []byte expectedCalls func() - expectedError bool - errorMessage string commandByte []byte }{ { name: "command must be a string", expectedCalls: func() { mockNetceptor.EXPECT().NodeID() - mockCon.EXPECT().Write(gomock.Any()).Return(0, nil).AnyTimes() // don't know why mockNetceptor.EXPECT().GetLogger().Return(logger).Times(4) }, commandByte: []byte("{\"command\": 0}"), @@ -585,7 +553,6 @@ func TestRunControlSessionTwo(t *testing.T) { name: "JSON did not contain a command", expectedCalls: func() { mockNetceptor.EXPECT().NodeID() - mockCon.EXPECT().Write(gomock.Any()).Return(0, nil).AnyTimes() mockNetceptor.EXPECT().GetLogger().Return(logger).Times(4) }, commandByte: []byte("{}"), @@ -594,7 +561,6 @@ func TestRunControlSessionTwo(t *testing.T) { name: "command must be a string", expectedCalls: func() { mockNetceptor.EXPECT().NodeID() - mockCon.EXPECT().Write(gomock.Any()).Return(0, nil).AnyTimes() // don't know why mockNetceptor.EXPECT().GetLogger().Return(logger).Times(4) }, commandByte: []byte("{\"command\": \"echo\"}"), @@ -603,7 +569,6 @@ func TestRunControlSessionTwo(t *testing.T) { name: "tokens", expectedCalls: func() { mockNetceptor.EXPECT().NodeID() - mockCon.EXPECT().Write(gomock.Any()).Return(0, nil).AnyTimes() // don't know why mockNetceptor.EXPECT().GetLogger().Return(logger).Times(4) }, commandByte: []byte("a b"), @@ -612,7 +577,6 @@ func TestRunControlSessionTwo(t *testing.T) { name: "control types - reload", expectedCalls: func() { mockNetceptor.EXPECT().NodeID() - mockCon.EXPECT().Write(gomock.Any()).Return(0, nil).AnyTimes() // don't know why mockNetceptor.EXPECT().GetLogger().Return(logger).Times(6) }, commandByte: []byte("{\"command\": \"reload\"}"), @@ -621,7 +585,6 @@ func TestRunControlSessionTwo(t *testing.T) { name: "control types - no ping target", expectedCalls: func() { mockNetceptor.EXPECT().NodeID() - mockCon.EXPECT().Write(gomock.Any()).Return(0, nil).AnyTimes() // don't know why mockNetceptor.EXPECT().GetLogger().Return(logger).Times(5) }, commandByte: []byte("{\"command\": \"ping\"}"), @@ -637,7 +600,6 @@ func TestRunControlSessionTwo(t *testing.T) { go func() { pipeA.Write(testCase.commandByte) pipeA.Close() - }() go func() { io.ReadAll(pipeA) diff --git a/pkg/controlsvc/interfaces.go b/pkg/controlsvc/interfaces.go index 80539a39e..4f4c1bbde 100644 --- a/pkg/controlsvc/interfaces.go +++ b/pkg/controlsvc/interfaces.go @@ -36,8 +36,8 @@ type ControlCommand interface { // ControlFuncOperations provides callbacks for control services to take actions. type ControlFuncOperations interface { - BridgeConn(message string, bc io.ReadWriteCloser, bcName string, logger *logger.ReceptorLogger) error - ReadFromConn(message string, out io.Writer) error + BridgeConn(message string, bc io.ReadWriteCloser, bcName string, logger *logger.ReceptorLogger, utils Utiler) error + ReadFromConn(message string, out io.Writer, io Copier) error WriteToConn(message string, in chan []byte) error Close() error RemoteAddr() net.Addr diff --git a/pkg/workceptor/controlsvc.go b/pkg/workceptor/controlsvc.go index bfcf5b705..452d4804f 100644 --- a/pkg/workceptor/controlsvc.go +++ b/pkg/workceptor/controlsvc.go @@ -297,7 +297,7 @@ func (c *workceptorCommand) ControlFunc(ctx context.Context, nc controlsvc.Netce return nil, err } worker.UpdateBasicStatus(WorkStatePending, "Waiting for Input Data", 0) - err = cfo.ReadFromConn(fmt.Sprintf("Work unit created with ID %s. Send stdin data and EOF.\n", worker.ID()), stdin) + err = cfo.ReadFromConn(fmt.Sprintf("Work unit created with ID %s. Send stdin data and EOF.\n", worker.ID()), stdin, &controlsvc.SocketConnIO{}) if err != nil { worker.UpdateBasicStatus(WorkStateFailed, fmt.Sprintf("Error reading input data: %s", err), 0)