diff --git a/changelog/fragments/1674071325-Fix-issue-where-its-possible-for-a-component-to-receive-a-unit-without-a-config.yaml b/changelog/fragments/1674071325-Fix-issue-where-its-possible-for-a-component-to-receive-a-unit-without-a-config.yaml new file mode 100644 index 00000000000..508224eff17 --- /dev/null +++ b/changelog/fragments/1674071325-Fix-issue-where-its-possible-for-a-component-to-receive-a-unit-without-a-config.yaml @@ -0,0 +1,31 @@ +# Kind can be one of: +# - breaking-change: a change to previously-documented behavior +# - deprecation: functionality that is being removed in a later release +# - bug-fix: fixes a problem in a previous version +# - enhancement: extends functionality but does not break or fix existing behavior +# - feature: new functionality +# - known-issue: problems that we are aware of in a given version +# - security: impacts on the security of a product or a user’s deployment. +# - upgrade: important information for someone upgrading from a prior version +# - other: does not fit into any of the other categories +kind: feature + +# Change summary; a 80ish characters long description of the change. +summary: Fix issue where its possible for a component to receive a unit without a config + +# Long description; in case the summary is not enough to describe the change +# this field accommodate a description without length limits. +#description: + +# Affected component; a word indicating the component this changeset affects. +component: + +# PR number; optional; the PR number that added the changeset. +# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added. +# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number. +# Please provide it if you are adding a fragment for a different PR. +pr: 2138 + +# Issue number; optional; the GitHub issue related to this changeset (either closes or is part of). +# If not present is automatically filled by the tooling with the issue linked to the PR number. +issue: 2086 diff --git a/internal/pkg/agent/control/v1/proto/control_v1.pb.go b/internal/pkg/agent/control/v1/proto/control_v1.pb.go index fd3902a4a6e..a99fd51b1ab 100644 --- a/internal/pkg/agent/control/v1/proto/control_v1.pb.go +++ b/internal/pkg/agent/control/v1/proto/control_v1.pb.go @@ -16,10 +16,11 @@ package proto import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) const ( diff --git a/internal/pkg/agent/control/v1/proto/control_v1_grpc.pb.go b/internal/pkg/agent/control/v1/proto/control_v1_grpc.pb.go index 6264d2d81bf..43e62f56985 100644 --- a/internal/pkg/agent/control/v1/proto/control_v1_grpc.pb.go +++ b/internal/pkg/agent/control/v1/proto/control_v1_grpc.pb.go @@ -1,3 +1,7 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.2.0 @@ -8,6 +12,7 @@ package proto import ( context "context" + grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" diff --git a/internal/pkg/agent/control/v2/cproto/control_v2.pb.go b/internal/pkg/agent/control/v2/cproto/control_v2.pb.go index 3e998191173..b203219b5fa 100644 --- a/internal/pkg/agent/control/v2/cproto/control_v2.pb.go +++ b/internal/pkg/agent/control/v2/cproto/control_v2.pb.go @@ -11,11 +11,12 @@ package cproto import ( + reflect "reflect" + sync "sync" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" timestamppb "google.golang.org/protobuf/types/known/timestamppb" - reflect "reflect" - sync "sync" ) const ( diff --git a/internal/pkg/agent/control/v2/cproto/control_v2_grpc.pb.go b/internal/pkg/agent/control/v2/cproto/control_v2_grpc.pb.go index f7c377c84eb..86ad29e7f8a 100644 --- a/internal/pkg/agent/control/v2/cproto/control_v2_grpc.pb.go +++ b/internal/pkg/agent/control/v2/cproto/control_v2_grpc.pb.go @@ -1,3 +1,7 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.2.0 @@ -8,6 +12,7 @@ package cproto import ( context "context" + grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" diff --git a/pkg/component/fake/component/main.go b/pkg/component/fake/component/main.go index 2f10148357f..a64484f4783 100644 --- a/pkg/component/fake/component/main.go +++ b/pkg/component/fake/component/main.go @@ -12,6 +12,7 @@ import ( "io" "os" "os/signal" + "strconv" "syscall" "time" @@ -46,7 +47,7 @@ func main() { } func run() error { - logger := zerolog.New(os.Stderr).With().Timestamp().Logger() + logger := zerolog.New(os.Stderr).Level(zerolog.TraceLevel).With().Timestamp().Logger() ver := client.VersionInfo{ Name: fake, Version: "1.0", @@ -347,7 +348,8 @@ type fakeInput struct { state client.UnitState stateMsg string - canceller context.CancelFunc + canceller context.CancelFunc + killerCanceller context.CancelFunc } func newFakeInput(logger zerolog.Logger, logLevel client.UnitLogLevel, manager *stateManager, unit *client.Unit, cfg *proto.UnitExpectedConfig) (*fakeInput, error) { @@ -399,7 +401,7 @@ func newFakeInput(logger zerolog.Logger, logLevel client.UnitLogLevel, manager * } }() i.canceller = cancel - + i.parseConfig(cfg) return i, nil } @@ -429,6 +431,7 @@ func (f *fakeInput) Update(u *client.Unit) error { return fmt.Errorf("unit type changed with the same unit ID: %s", config.Type) } + f.parseConfig(config) state, stateMsg, err := getStateFromConfig(config) if err != nil { return fmt.Errorf("unit config parsing error: %w", err) @@ -440,6 +443,65 @@ func (f *fakeInput) Update(u *client.Unit) error { return nil } +func (f *fakeInput) parseConfig(config *proto.UnitExpectedConfig) { + // handle a case for killing the component when the pid of the component + // matches the current running PID + cfg := config.Source.AsMap() + killPIDRaw, kill := cfg["kill"] + if kill { + f.maybeKill(killPIDRaw) + } + + // handle a case where random killing of the component is enabled + _, killOnInterval := cfg["kill_on_interval"] + f.logger.Trace().Bool("kill_on_interval", killOnInterval).Msg("kill_on_interval config set value") + if killOnInterval { + f.logger.Info().Msg("starting interval killer") + f.runKiller() + } else { + f.logger.Info().Msg("stopping interval killer") + f.stopKiller() + } +} + +func (f *fakeInput) maybeKill(pidRaw interface{}) { + if killPID, ok := pidRaw.(string); ok { + if pid, err := strconv.Atoi(killPID); err == nil { + if pid == os.Getpid() { + f.logger.Warn().Msg("killing from config pid") + os.Exit(1) + } + } + } +} + +func (f *fakeInput) runKiller() { + if f.killerCanceller != nil { + // already running + return + } + ctx, canceller := context.WithCancel(context.Background()) + f.killerCanceller = canceller + go func() { + t := time.NewTimer(500 * time.Millisecond) + defer t.Stop() + select { + case <-ctx.Done(): + return + case <-t.C: + f.logger.Warn().Msg("killer performing kill") + os.Exit(1) + } + }() +} + +func (f *fakeInput) stopKiller() { + if f.killerCanceller != nil { + f.killerCanceller() + f.killerCanceller = nil + } +} + type stateSetterAction struct { input *fakeInput } diff --git a/pkg/component/fake/shipper/listener.go b/pkg/component/fake/shipper/listener.go index ce4d6a99a41..a888021d158 100644 --- a/pkg/component/fake/shipper/listener.go +++ b/pkg/component/fake/shipper/listener.go @@ -10,6 +10,7 @@ import ( "fmt" "net" "os" + "path/filepath" "strings" ) @@ -21,6 +22,13 @@ func createListener(path string) (net.Listener, error) { if _, err := os.Stat(path); !os.IsNotExist(err) { os.Remove(path) } + dir := filepath.Dir(path) + if _, err := os.Stat(dir); os.IsNotExist(err) { + err := os.MkdirAll(dir, 0750) + if err != nil { + return nil, err + } + } lis, err := net.Listen("unix", path) if err != nil { return nil, err diff --git a/pkg/component/runtime/command.go b/pkg/component/runtime/command.go index 94fc4428b74..778ce38331e 100644 --- a/pkg/component/runtime/command.go +++ b/pkg/component/runtime/command.go @@ -151,7 +151,7 @@ func (c *CommandRuntime) Run(ctx context.Context, comm Communicator) error { sendExpected := c.state.syncExpected(&newComp) changed := c.state.syncUnits(&newComp) if sendExpected || c.state.unsettled() { - comm.CheckinExpected(c.state.toCheckinExpected()) + comm.CheckinExpected(c.state.toCheckinExpected(), nil) } if changed { c.sendObserved() @@ -177,7 +177,7 @@ func (c *CommandRuntime) Run(ctx context.Context, comm Communicator) error { sendExpected = true } if sendExpected { - comm.CheckinExpected(c.state.toCheckinExpected()) + comm.CheckinExpected(c.state.toCheckinExpected(), checkin) } if changed { c.sendObserved() @@ -331,10 +331,6 @@ func (c *CommandRuntime) start(comm Communicator) error { c.lastCheckin = time.Time{} c.missedCheckins = 0 - // Ensure there is no pending checkin expected message buffered to avoid sending the new process - // the expected state of the previous process: https://github.com/elastic/beats/issues/34137 - comm.ClearPendingCheckinExpected() - proc, err := process.Start(path, process.WithArgs(args), process.WithEnv(env), diff --git a/pkg/component/runtime/conn_info_server_test.go b/pkg/component/runtime/conn_info_server_test.go index 43b7937eb99..8d445e0a5db 100644 --- a/pkg/component/runtime/conn_info_server_test.go +++ b/pkg/component/runtime/conn_info_server_test.go @@ -57,7 +57,7 @@ func (c *mockCommunicator) WriteConnInfo(w io.Writer, services ...client.Service return nil } -func (c *mockCommunicator) CheckinExpected(expected *proto.CheckinExpected) { +func (c *mockCommunicator) CheckinExpected(expected *proto.CheckinExpected, observed *proto.CheckinObserved) { } func (c *mockCommunicator) ClearPendingCheckinExpected() { diff --git a/pkg/component/runtime/manager_test.go b/pkg/component/runtime/manager_test.go index 2e97fb22121..7d204731aaf 100644 --- a/pkg/component/runtime/manager_test.go +++ b/pkg/component/runtime/manager_test.go @@ -11,6 +11,7 @@ import ( "fmt" "os" "path/filepath" + "regexp" "runtime" "testing" "time" @@ -63,7 +64,7 @@ func TestManager_SimpleComponentErr(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -165,7 +166,7 @@ func TestManager_FakeInput_StartStop(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -193,8 +194,9 @@ func TestManager_FakeInput_StartStop(t *testing.T) { }, Units: []component.Unit{ { - ID: "fake-input", - Type: client.UnitTypeInput, + ID: "fake-input", + Type: client.UnitTypeInput, + LogLevel: client.UnitLogLevelTrace, Config: component.MustExpectedConfig(map[string]interface{}{ "type": "fake", "state": int(client.UnitStateHealthy), @@ -290,7 +292,7 @@ func TestManager_FakeInput_BadUnitToGood(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -318,8 +320,9 @@ func TestManager_FakeInput_BadUnitToGood(t *testing.T) { }, Units: []component.Unit{ { - ID: "fake-input", - Type: client.UnitTypeInput, + ID: "fake-input", + Type: client.UnitTypeInput, + LogLevel: client.UnitLogLevelTrace, Config: component.MustExpectedConfig(map[string]interface{}{ "type": "fake", "state": int(client.UnitStateHealthy), @@ -360,8 +363,9 @@ func TestManager_FakeInput_BadUnitToGood(t *testing.T) { updatedComp.Units = make([]component.Unit, len(comp.Units)) copy(updatedComp.Units, comp.Units) updatedComp.Units[1] = component.Unit{ - ID: "bad-input", - Type: client.UnitTypeInput, + ID: "bad-input", + Type: client.UnitTypeInput, + LogLevel: client.UnitLogLevelTrace, Config: component.MustExpectedConfig(map[string]interface{}{ "type": "fake", "state": int(client.UnitStateHealthy), @@ -461,7 +465,7 @@ func TestManager_FakeInput_GoodUnitToBad(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -489,8 +493,9 @@ func TestManager_FakeInput_GoodUnitToBad(t *testing.T) { }, Units: []component.Unit{ { - ID: "fake-input", - Type: client.UnitTypeInput, + ID: "fake-input", + Type: client.UnitTypeInput, + LogLevel: client.UnitLogLevelTrace, Config: component.MustExpectedConfig(map[string]interface{}{ "type": "fake", "state": int(client.UnitStateHealthy), @@ -498,8 +503,9 @@ func TestManager_FakeInput_GoodUnitToBad(t *testing.T) { }), }, { - ID: "good-input", - Type: client.UnitTypeInput, + ID: "good-input", + Type: client.UnitTypeInput, + LogLevel: client.UnitLogLevelTrace, Config: component.MustExpectedConfig(map[string]interface{}{ "type": "fake", "state": int(client.UnitStateHealthy), @@ -616,7 +622,7 @@ func TestManager_FakeInput_Configure(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -644,8 +650,9 @@ func TestManager_FakeInput_Configure(t *testing.T) { }, Units: []component.Unit{ { - ID: "fake-input", - Type: client.UnitTypeInput, + ID: "fake-input", + Type: client.UnitTypeInput, + LogLevel: client.UnitLogLevelTrace, Config: component.MustExpectedConfig(map[string]interface{}{ "type": "fake", "state": int(client.UnitStateHealthy), @@ -742,7 +749,7 @@ func TestManager_FakeInput_RemoveUnit(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -770,8 +777,9 @@ func TestManager_FakeInput_RemoveUnit(t *testing.T) { }, Units: []component.Unit{ { - ID: "fake-input-0", - Type: client.UnitTypeInput, + ID: "fake-input-0", + Type: client.UnitTypeInput, + LogLevel: client.UnitLogLevelTrace, Config: component.MustExpectedConfig(map[string]interface{}{ "type": "fake", "state": int(client.UnitStateHealthy), @@ -779,8 +787,9 @@ func TestManager_FakeInput_RemoveUnit(t *testing.T) { }), }, { - ID: "fake-input-1", - Type: client.UnitTypeInput, + ID: "fake-input-1", + Type: client.UnitTypeInput, + LogLevel: client.UnitLogLevelTrace, Config: component.MustExpectedConfig(map[string]interface{}{ "type": "fake", "state": int(client.UnitStateHealthy), @@ -900,7 +909,7 @@ func TestManager_FakeInput_ActionState(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -928,8 +937,9 @@ func TestManager_FakeInput_ActionState(t *testing.T) { }, Units: []component.Unit{ { - ID: "fake-input", - Type: client.UnitTypeInput, + ID: "fake-input", + Type: client.UnitTypeInput, + LogLevel: client.UnitLogLevelTrace, Config: component.MustExpectedConfig(map[string]interface{}{ "type": "fake", "state": int(client.UnitStateHealthy), @@ -1030,7 +1040,7 @@ func TestManager_FakeInput_Restarts(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -1058,8 +1068,9 @@ func TestManager_FakeInput_Restarts(t *testing.T) { }, Units: []component.Unit{ { - ID: "fake-input", - Type: client.UnitTypeInput, + ID: "fake-input", + Type: client.UnitTypeInput, + LogLevel: client.UnitLogLevelTrace, Config: component.MustExpectedConfig(map[string]interface{}{ "type": "fake", "state": int(client.UnitStateHealthy), @@ -1097,6 +1108,8 @@ func TestManager_FakeInput_Restarts(t *testing.T) { // force the input to exit and it should be restarted if !killed { killed = true + + t.Log("triggering kill through action") actionCtx, actionCancel := context.WithTimeout(context.Background(), 500*time.Millisecond) _, err := m.PerformAction(actionCtx, comp, comp.Units[0], "kill", nil) actionCancel() @@ -1162,6 +1175,304 @@ LOOP: require.NoError(t, err) } +func TestManager_FakeInput_Restarts_ConfigKill(t *testing.T) { + testPaths(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ai, _ := info.NewAgentInfo(true) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + require.NoError(t, err) + errCh := make(chan error) + go func() { + err := m.Run(ctx) + if errors.Is(err, context.Canceled) { + err = nil + } + errCh <- err + }() + + waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second) + defer waitCancel() + if err := m.WaitForReady(waitCtx); err != nil { + require.NoError(t, err) + } + + // adjust input spec to allow restart + cmdSpec := *fakeInputSpec.Command + cmdSpec.RestartMonitoringPeriod = 1 * time.Second + cmdSpec.MaxRestartsPerPeriod = 10 + inputSpec := fakeInputSpec + inputSpec.Command = &cmdSpec + + binaryPath := testBinary(t, "component") + comp := component.Component{ + ID: "fake-default", + InputSpec: &component.InputRuntimeSpec{ + InputType: "fake", + BinaryName: "", + BinaryPath: binaryPath, + Spec: inputSpec, + }, + Units: []component.Unit{ + { + ID: "fake-input", + Type: client.UnitTypeInput, + LogLevel: client.UnitLogLevelTrace, + Config: component.MustExpectedConfig(map[string]interface{}{ + "type": "fake", + "state": int(client.UnitStateHealthy), + "message": "Fake Healthy", + }), + }, + }, + } + + subCtx, subCancel := context.WithCancel(context.Background()) + defer subCancel() + subErrCh := make(chan error) + go func() { + killed := false + + sub := m.Subscribe(subCtx, "fake-default") + for { + select { + case <-subCtx.Done(): + return + case state := <-sub.Ch(): + t.Logf("component state changed: %+v", state) + if state.State == client.UnitStateFailed { + if !killed { + subErrCh <- fmt.Errorf("component failed: %s", state.Message) + } + } else { + unit, ok := state.Units[ComponentUnitKey{UnitType: client.UnitTypeInput, UnitID: "fake-input"}] + if ok { + if unit.State == client.UnitStateFailed { + if !killed { + subErrCh <- fmt.Errorf("unit failed: %s", unit.Message) + } + } else if unit.State == client.UnitStateHealthy { + // force the input to exit and it should be restarted + if !killed { + killed = true + + r := regexp.MustCompile(`pid \'(?P\d+)\'`) + rp := r.FindStringSubmatch(state.Message) + t.Logf("triggering kill through config on pid %s", rp) + comp.Units[0].Config = component.MustExpectedConfig(map[string]interface{}{ + "type": "fake", + "state": int(client.UnitStateHealthy), + "message": "Fake Healthy", + "kill": rp[1], + }) + err := m.Update([]component.Component{comp}) + if err != nil { + subErrCh <- err + } + } else { + // got back to healthy after kill + subErrCh <- nil + } + } else if unit.State == client.UnitStateStarting || unit.State == client.UnitStateStopped { + // acceptable + } else { + // unknown state that should not have occurred + subErrCh <- fmt.Errorf("unit reported unexpected state: %v", unit.State) + } + } else { + subErrCh <- errors.New("unit missing: fake-input") + } + } + } + } + }() + + defer drainErrChan(errCh) + defer drainErrChan(subErrCh) + + startTimer := time.NewTimer(100 * time.Millisecond) + defer startTimer.Stop() + select { + case <-startTimer.C: + err = m.Update([]component.Component{comp}) + require.NoError(t, err) + case err := <-errCh: + t.Fatalf("failed early: %s", err) + } + + endTimer := time.NewTimer(1 * time.Minute) + defer endTimer.Stop() +LOOP: + for { + select { + case <-endTimer.C: + t.Fatalf("timed out after 1 minute") + case err := <-errCh: + require.NoError(t, err) + case err := <-subErrCh: + require.NoError(t, err) + break LOOP + } + } + + subCancel() + cancel() + + err = <-errCh + require.NoError(t, err) +} + +func TestManager_FakeInput_KeepsRestarting(t *testing.T) { + testPaths(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ai, _ := info.NewAgentInfo(true) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + require.NoError(t, err) + errCh := make(chan error) + go func() { + err := m.Run(ctx) + if errors.Is(err, context.Canceled) { + err = nil + } + errCh <- err + }() + + waitCtx, waitCancel := context.WithTimeout(ctx, 1*time.Second) + defer waitCancel() + if err := m.WaitForReady(waitCtx); err != nil { + require.NoError(t, err) + } + + // adjust input spec to allow restart + cmdSpec := *fakeInputSpec.Command + cmdSpec.RestartMonitoringPeriod = 1 * time.Second + cmdSpec.MaxRestartsPerPeriod = 10 + inputSpec := fakeInputSpec + inputSpec.Command = &cmdSpec + + binaryPath := testBinary(t, "component") + comp := component.Component{ + ID: "fake-default", + InputSpec: &component.InputRuntimeSpec{ + InputType: "fake", + BinaryName: "", + BinaryPath: binaryPath, + Spec: inputSpec, + }, + Units: []component.Unit{ + { + ID: "fake-input", + Type: client.UnitTypeInput, + LogLevel: client.UnitLogLevelTrace, + Config: component.MustExpectedConfig(map[string]interface{}{ + "type": "fake", + "state": int(client.UnitStateHealthy), + "message": "Fake Healthy", + "kill_on_interval": true, + }), + }, + }, + } + + subCtx, subCancel := context.WithCancel(context.Background()) + defer subCancel() + subErrCh := make(chan error) + go func() { + lastStoppedCount := 0 + stoppedCount := 0 + + sub := m.Subscribe(subCtx, "fake-default") + for { + select { + case <-subCtx.Done(): + return + case state := <-sub.Ch(): + t.Logf("component state changed: %+v", state) + if state.State == client.UnitStateFailed { + // should not go failed because we allow restart per period + subErrCh <- fmt.Errorf("component failed: %s", state.Message) + } else { + unit, ok := state.Units[ComponentUnitKey{UnitType: client.UnitTypeInput, UnitID: "fake-input"}] + if ok { + if unit.State == client.UnitStateFailed { + // unit should not be failed because we allow restart per period + subErrCh <- fmt.Errorf("unit failed: %s", unit.Message) + } else if unit.State == client.UnitStateHealthy { + if lastStoppedCount != stoppedCount { + lastStoppedCount = stoppedCount + + // send new config on each healthy report + comp.Units[0].Config = component.MustExpectedConfig(map[string]interface{}{ + "type": "fake", + "state": int(client.UnitStateHealthy), + "message": fmt.Sprintf("Fake Healthy %d", lastStoppedCount), + "kill_on_interval": true, + }) + err := m.Update([]component.Component{comp}) + if err != nil { + subErrCh <- err + } + } + if stoppedCount >= 3 { + // got stopped 3 times and got back to healthy + subErrCh <- nil + } + } else if unit.State == client.UnitStateStarting { + // acceptable + } else if unit.State == client.UnitStateStopped { + stoppedCount += 1 + } else { + // unknown state that should not have occurred + subErrCh <- fmt.Errorf("unit reported unexpected state: %v", unit.State) + } + } else { + subErrCh <- errors.New("unit missing: fake-input") + } + } + } + } + }() + + defer drainErrChan(errCh) + defer drainErrChan(subErrCh) + + startTimer := time.NewTimer(100 * time.Millisecond) + defer startTimer.Stop() + select { + case <-startTimer.C: + err = m.Update([]component.Component{comp}) + require.NoError(t, err) + case err := <-errCh: + t.Fatalf("failed early: %s", err) + } + + endTimer := time.NewTimer(1 * time.Minute) + defer endTimer.Stop() +LOOP: + for { + select { + case <-endTimer.C: + t.Fatalf("timed out after 1 minute") + case err := <-errCh: + require.NoError(t, err) + case err := <-subErrCh: + require.NoError(t, err) + break LOOP + } + } + + subCancel() + cancel() + + err = <-errCh + require.NoError(t, err) +} + func TestManager_FakeInput_RestartsOnMissedCheckins(t *testing.T) { testPaths(t) @@ -1169,7 +1480,7 @@ func TestManager_FakeInput_RestartsOnMissedCheckins(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -1290,7 +1601,7 @@ func TestManager_FakeInput_InvalidAction(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -1414,7 +1725,7 @@ func TestManager_FakeInput_MultiComponent(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -1626,7 +1937,7 @@ func TestManager_FakeInput_LogLevel(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -1778,7 +2089,7 @@ func TestManager_FakeShipper(t *testing.T) { defer cancel() ai, _ := info.NewAgentInfo(true) - m, err := NewManager(newErrorLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) + m, err := NewManager(newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) go func() { @@ -2033,11 +2344,12 @@ LOOP: require.NoError(t, err) } -func newErrorLogger(t *testing.T) *logger.Logger { +func newDebugLogger(t *testing.T) *logger.Logger { t.Helper() loggerCfg := logger.DefaultLoggingConfig() - loggerCfg.Level = logp.ErrorLevel + loggerCfg.Level = logp.DebugLevel + loggerCfg.ToStderr = true log, err := logger.NewFromConfig("", loggerCfg, false) require.NoError(t, err) diff --git a/pkg/component/runtime/runtime_comm.go b/pkg/component/runtime/runtime_comm.go index 57f1032db5b..386f67f6500 100644 --- a/pkg/component/runtime/runtime_comm.go +++ b/pkg/component/runtime/runtime_comm.go @@ -33,9 +33,11 @@ type Communicator interface { // to the provided services. WriteConnInfo(w io.Writer, services ...client.Service) error // CheckinExpected sends the expected state to the component. - CheckinExpected(expected *proto.CheckinExpected) - // ClearPendingCheckinExpected clears eny pending checkin expected messages. - ClearPendingCheckinExpected() + // + // observed is the observed message received from the component and what was used to compute the provided + // expected message. In the case that `CheckinExpected` is being called from a configuration change resulting + // in a previously observed message not being present then `nil` should be passed in for observed. + CheckinExpected(expected *proto.CheckinExpected, observed *proto.CheckinObserved) // CheckinObserved receives the observed state from the component. CheckinObserved() <-chan *proto.CheckinObserved } @@ -54,11 +56,13 @@ type runtimeComm struct { checkinDone chan bool checkinLock sync.RWMutex - checkinExpectedLock sync.Mutex - checkinExpected chan *proto.CheckinExpected - + checkinExpected chan *proto.CheckinExpected checkinObserved chan *proto.CheckinObserved + initCheckinObserved *proto.CheckinObserved + initCheckinExpectedCh chan *proto.CheckinExpected + initCheckinObservedMx sync.Mutex + actionsConn bool actionsDone chan bool actionsLock sync.RWMutex @@ -88,7 +92,7 @@ func newRuntimeComm(logger *logger.Logger, listenAddr string, ca *authority.Cert token: token.String(), cert: pair, checkinConn: true, - checkinExpected: make(chan *proto.CheckinExpected, 1), // size of 1 channel to keep the latest expected checkin state + checkinExpected: make(chan *proto.CheckinExpected, 10), // size of 10 gives a buffer for expected, only last is used checkinObserved: make(chan *proto.CheckinObserved), actionsConn: true, actionsRequest: make(chan *proto.ActionRequest), @@ -131,7 +135,7 @@ func (c *runtimeComm) WriteConnInfo(w io.Writer, services ...client.Service) err return nil } -func (c *runtimeComm) CheckinExpected(expected *proto.CheckinExpected) { +func (c *runtimeComm) CheckinExpected(expected *proto.CheckinExpected, observed *proto.CheckinObserved) { if c.agentInfo != nil && c.agentInfo.AgentID() != "" { expected.AgentInfo = &proto.CheckinAgentInfo{ Id: c.agentInfo.AgentID(), @@ -142,29 +146,48 @@ func (c *runtimeComm) CheckinExpected(expected *proto.CheckinExpected) { expected.AgentInfo = nil } - // Lock to avoid race if this function is called from the different go routines - c.checkinExpectedLock.Lock() - - // Empty the channel - c.ClearPendingCheckinExpected() + // we need to determine if the communicator is currently in the initial observed message path + // in the case that it is we send the expected state over a different channel + c.initCheckinObservedMx.Lock() + initObserved := c.initCheckinObserved + expectedCh := c.initCheckinExpectedCh + if initObserved != nil { + // the next call to `CheckinExpected` must be from the initial `CheckinObserved` message + if observed != initObserved { + // not the initial observed message; we don't send it + c.initCheckinObservedMx.Unlock() + return + } + // it is the expected from the initial observed message + // clear the initial state + c.initCheckinObserved = nil + c.initCheckinExpectedCh = nil + c.initCheckinObservedMx.Unlock() + expectedCh <- expected + return + } + c.initCheckinObservedMx.Unlock() - // Put the new expected state in + // not in the initial observed message path; send it over the standard channel c.checkinExpected <- expected - - c.checkinExpectedLock.Unlock() -} - -func (c *runtimeComm) ClearPendingCheckinExpected() { - select { - case <-c.checkinExpected: - default: - } } func (c *runtimeComm) CheckinObserved() <-chan *proto.CheckinObserved { return c.checkinObserved } +// latestCheckinExpected ensures that the latest expected checkin is used +func (c *runtimeComm) latestCheckinExpected(exp *proto.CheckinExpected) *proto.CheckinExpected { + latest := exp + for { + select { + case latest = <-c.checkinExpected: + default: + return latest + } + } +} + func (c *runtimeComm) checkin(server proto.ElasticAgent_CheckinV2Server, init *proto.CheckinObserved) error { c.checkinLock.Lock() if c.checkinDone != nil { @@ -190,12 +213,30 @@ func (c *runtimeComm) checkin(server proto.ElasticAgent_CheckinV2Server, init *p c.checkinLock.Unlock() }() + initExp := make(chan *proto.CheckinExpected) recvDone := make(chan bool) sendDone := make(chan bool) go func() { defer func() { close(sendDone) }() + + // initial startup waits for the first expected message from the dedicated initExp channel + select { + case <-checkinDone: + return + case <-recvDone: + return + case expected := <-initExp: + err := server.Send(expected) + if err != nil { + if reportableErr(err) { + c.logger.Debugf("check-in stream failed to send initial expected state: %s", err) + } + return + } + } + for { var expected *proto.CheckinExpected select { @@ -204,6 +245,7 @@ func (c *runtimeComm) checkin(server proto.ElasticAgent_CheckinV2Server, init *p case <-recvDone: return case expected = <-c.checkinExpected: + expected = c.latestCheckinExpected(expected) } err := server.Send(expected) @@ -216,6 +258,18 @@ func (c *runtimeComm) checkin(server proto.ElasticAgent_CheckinV2Server, init *p } }() + // at this point the client is connected, and it has sent it's first initial checkin + // the initial expected message must come before the sender goroutine will send any other + // expected messages. `CheckinExpected` method will also drop any expected messages that do not + // match the observed message to ensure that the expected that we receive is from the initial + // observed state. + c.initCheckinObservedMx.Lock() + c.initCheckinObserved = init + c.initCheckinExpectedCh = initExp + c.latestCheckinExpected(nil) // clears all queued expected messages + c.initCheckinObservedMx.Unlock() + + // send the initial message (manager then calls `CheckinExpected` method with the result) c.checkinObserved <- init go func() { diff --git a/pkg/component/runtime/service.go b/pkg/component/runtime/service.go index a032d9abd06..c480f3dbe72 100644 --- a/pkg/component/runtime/service.go +++ b/pkg/component/runtime/service.go @@ -114,7 +114,6 @@ func (s *ServiceRuntime) Run(ctx context.Context, comm Communicator) (err error) // Initial state on start lastCheckin = time.Time{} missedCheckins = 0 - comm.ClearPendingCheckinExpected() checkinTimer.Stop() cisStop() @@ -206,7 +205,7 @@ func (s *ServiceRuntime) stop(ctx context.Context, comm Communicator, lastChecki if checkedIn { s.log.Debugf("send stopping state to %s service", name) s.state.forceExpectedState(client.UnitStateStopping) - comm.CheckinExpected(s.state.toCheckinExpected()) + comm.CheckinExpected(s.state.toCheckinExpected(), nil) } else { s.log.Debugf("%s service had never checked in, proceed to uninstall", name) } @@ -251,7 +250,7 @@ func (s *ServiceRuntime) processNewComp(newComp component.Component, comm Commun sendExpected := s.state.syncExpected(&newComp) changed := s.state.syncUnits(&newComp) if sendExpected || s.state.unsettled() { - comm.CheckinExpected(s.state.toCheckinExpected()) + comm.CheckinExpected(s.state.toCheckinExpected(), nil) } if changed { s.sendObserved() @@ -288,7 +287,7 @@ func (s *ServiceRuntime) processCheckin(checkin *proto.CheckinObserved, comm Com sendExpected = true } if sendExpected { - comm.CheckinExpected(s.state.toCheckinExpected()) + comm.CheckinExpected(s.state.toCheckinExpected(), checkin) } if changed { s.sendObserved()