diff --git a/x-pack/filebeat/input/etw/input.go b/x-pack/filebeat/input/etw/input.go index 050fcf6ddf9b..b5b331b3c923 100644 --- a/x-pack/filebeat/input/etw/input.go +++ b/x-pack/filebeat/input/etw/input.go @@ -22,6 +22,7 @@ import ( "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/elastic-agent-libs/mapstr" + "golang.org/x/sync/errgroup" "golang.org/x/sys/windows" ) @@ -66,6 +67,7 @@ type etwInput struct { log *logp.Logger config config etwSession *etw.Session + publisher stateless.Publisher operator sessionOperator } @@ -105,10 +107,13 @@ func (e *etwInput) Run(ctx input.Context, publisher stateless.Publisher) error { if err != nil { return fmt.Errorf("error initializing ETW session: %w", err) } + e.etwSession.Callback = e.consumeEvent + e.publisher = publisher // Set up logger with session information e.log = ctx.Logger.With("session", e.etwSession.Name) e.log.Info("Starting " + inputName + " input") + defer e.log.Info(inputName + " input stopped") // Handle realtime session creation or attachment if e.etwSession.Realtime { @@ -125,71 +130,31 @@ func (e *etwInput) Run(ctx input.Context, publisher stateless.Publisher) error { if err != nil { return fmt.Errorf("realtime session could not be created: %w", err) } - e.log.Debug("created session") + e.log.Debug("created new session") } } - // Defer the cleanup and closing of resources - var wg sync.WaitGroup - var once sync.Once - // Create an error channel to communicate errors from the goroutine - errChan := make(chan error, 1) + stopConsumer := sync.OnceFunc(e.Close) + defer stopConsumer() - defer func() { - once.Do(e.Close) - e.log.Info(inputName + " input stopped") + // Stop the consumer upon input cancellation (shutdown). + go func() { + <-ctx.Cancelation.Done() + stopConsumer() }() - // eventReceivedCallback processes each ETW event - eventReceivedCallback := func(record *etw.EventRecord) uintptr { - if record == nil { - e.log.Error("received null event record") - return 1 - } - - e.log.Debugf("received event %d with length %d", record.EventHeader.EventDescriptor.Id, record.UserDataLength) - - data, err := etw.GetEventProperties(record) - if err != nil { - e.log.Errorw("failed to read event properties", "error", err) - return 1 - } - - evt := buildEvent(data, record.EventHeader, e.etwSession, e.config) - publisher.Publish(evt) - - return 0 - } - - // Set the callback function for the ETW session - e.etwSession.Callback = eventReceivedCallback - // Start a goroutine to consume ETW events - wg.Add(1) - go func() { - defer wg.Done() - e.log.Debug("starting to listen ETW events") + g := new(errgroup.Group) + g.Go(func() error { + e.log.Debug("starting ETW consumer") + defer e.log.Debug("stopped ETW consumer") if err = e.operator.startConsumer(e.etwSession); err != nil { - errChan <- fmt.Errorf("failed to start consumer: %w", err) // Send error to channel - return + return fmt.Errorf("failed running ETW consumer: %w", err) } - e.log.Debug("stopped to read ETW events from session") - errChan <- nil - }() + return nil + }) - // We ensure resources are closed when receiving a cancellation signal - go func() { - <-ctx.Cancelation.Done() - once.Do(e.Close) - }() - - wg.Wait() // Ensure all goroutines have finished before closing - close(errChan) - if err, ok := <-errChan; ok && err != nil { - return err - } - - return nil + return g.Wait() } var ( @@ -271,6 +236,26 @@ func convertFileTimeToGoTime(fileTime64 uint64) time.Time { return time.Unix(0, fileTime.Nanoseconds()).UTC() } +func (e *etwInput) consumeEvent(record *etw.EventRecord) uintptr { + if record == nil { + e.log.Error("received null event record") + return 1 + } + + e.log.Debugf("received event with ID %d and user-data length %d", record.EventHeader.EventDescriptor.Id, record.UserDataLength) + + data, err := etw.GetEventProperties(record) + if err != nil { + e.log.Errorw("failed to read event properties", "error", err) + return 1 + } + + evt := buildEvent(data, record.EventHeader, e.etwSession, e.config) + e.publisher.Publish(evt) + + return 0 +} + // Close stops the ETW session and logs the outcome. func (e *etwInput) Close() { if err := e.operator.stopSession(e.etwSession); err != nil { diff --git a/x-pack/filebeat/input/etw/input_test.go b/x-pack/filebeat/input/etw/input_test.go index af1fa36d4bd5..fd2673278d37 100644 --- a/x-pack/filebeat/input/etw/input_test.go +++ b/x-pack/filebeat/input/etw/input_test.go @@ -107,7 +107,8 @@ func Test_RunEtwInput_AttachToExistingSessionError(t *testing.T) { mockSession := &etw.Session{ Name: "MySession", Realtime: true, - NewSession: false} + NewSession: false, + } return mockSession, nil } // Setup the mock behavior for AttachToExistingSession @@ -146,7 +147,8 @@ func Test_RunEtwInput_CreateRealtimeSessionError(t *testing.T) { mockSession := &etw.Session{ Name: "MySession", Realtime: true, - NewSession: true} + NewSession: true, + } return mockSession, nil } // Setup the mock behavior for AttachToExistingSession @@ -189,7 +191,8 @@ func Test_RunEtwInput_StartConsumerError(t *testing.T) { mockSession := &etw.Session{ Name: "MySession", Realtime: true, - NewSession: true} + NewSession: true, + } return mockSession, nil } // Setup the mock behavior for AttachToExistingSession @@ -232,7 +235,7 @@ func Test_RunEtwInput_StartConsumerError(t *testing.T) { // Run test err := etwInput.Run(inputCtx, nil) - assert.EqualError(t, err, "failed to start consumer: mock error") + assert.EqualError(t, err, "failed running ETW consumer: mock error") } func Test_RunEtwInput_Success(t *testing.T) { @@ -244,7 +247,8 @@ func Test_RunEtwInput_Success(t *testing.T) { mockSession := &etw.Session{ Name: "MySession", Realtime: true, - NewSession: true} + NewSession: true, + } return mockSession, nil } // Setup the mock behavior for AttachToExistingSession @@ -471,7 +475,6 @@ func Test_buildEvent(t *testing.T) { assert.Equal(t, tt.expected["event.severity"], mapEv["event.severity"]) assert.Equal(t, tt.expected["log.file.path"], mapEv["log.file.path"]) assert.Equal(t, tt.expected["log.level"], mapEv["log.level"]) - }) } } @@ -495,7 +498,7 @@ func Test_convertFileTimeToGoTime(t *testing.T) { { name: "TestActualDate", fileTime: 133515900000000000, // February 05, 2024, 7:00:00 AM - want: time.Date(2024, 02, 05, 7, 0, 0, 0, time.UTC), + want: time.Date(2024, 0o2, 0o5, 7, 0, 0, 0, time.UTC), }, } diff --git a/x-pack/libbeat/reader/etw/session.go b/x-pack/libbeat/reader/etw/session.go index 3216ff3af050..9d78d279de2d 100644 --- a/x-pack/libbeat/reader/etw/session.go +++ b/x-pack/libbeat/reader/etw/session.go @@ -229,10 +229,8 @@ func (s *Session) StartConsumer() error { // Open an ETW trace processing handle for consuming events // from an ETW real-time trace session or an ETW log file. s.traceHandler, err = s.openTrace(&elf) - switch { case err == nil: - // Handle specific errors for trace opening. case errors.Is(err, ERROR_BAD_PATHNAME): return fmt.Errorf("invalid log source when opening trace: %w", err) @@ -241,10 +239,10 @@ func (s *Session) StartConsumer() error { default: return fmt.Errorf("failed to open trace: %w", err) } + // Process the trace. This function blocks until processing ends. if err := s.processTrace(&s.traceHandler, 1, nil, nil); err != nil { return fmt.Errorf("failed to process trace: %w", err) } - return nil }