diff --git a/cmd/envbuilder/main.go b/cmd/envbuilder/main.go index 410e0897..e8dc2201 100644 --- a/cmd/envbuilder/main.go +++ b/cmd/envbuilder/main.go @@ -37,6 +37,15 @@ func envbuilderCmd() serpent.Command { Options: o.CLI(), Handler: func(inv *serpent.Invocation) error { o.SetDefaults() + var preExecs []func() + preExec := func() { + for _, fn := range preExecs { + fn() + } + preExecs = nil + } + defer preExec() // Ensure cleanup in case of error. + o.Logger = log.New(os.Stderr, o.Verbose) if o.CoderAgentURL != "" { if o.CoderAgentToken == "" { @@ -49,7 +58,9 @@ func envbuilderCmd() serpent.Command { coderLog, closeLogs, err := log.Coder(inv.Context(), u, o.CoderAgentToken) if err == nil { o.Logger = log.Wrap(o.Logger, coderLog) - defer closeLogs() + preExecs = append(preExecs, func() { + closeLogs() + }) // This adds the envbuilder subsystem. // If telemetry is enabled in a Coder deployment, // this will be reported and help us understand @@ -78,7 +89,7 @@ func envbuilderCmd() serpent.Command { return nil } - err := envbuilder.Run(inv.Context(), o) + err := envbuilder.Run(inv.Context(), o, preExec) if err != nil { o.Logger(log.LevelError, "error: %s", err) } diff --git a/envbuilder.go b/envbuilder.go index 683f6a54..94998165 100644 --- a/envbuilder.go +++ b/envbuilder.go @@ -84,7 +84,9 @@ type execArgsInfo struct { // Logger is the logf to use for all operations. // Filesystem is the filesystem to use for all operations. // Defaults to the host filesystem. -func Run(ctx context.Context, opts options.Options) error { +// preExec are any functions that should be called before exec'ing the init +// command. This is useful for ensuring that defers get run. +func Run(ctx context.Context, opts options.Options, preExec ...func()) error { var args execArgsInfo // Run in a separate function to ensure all defers run before we // setuid or exec. @@ -103,6 +105,9 @@ func Run(ctx context.Context, opts options.Options) error { } opts.Logger(log.LevelInfo, "=== Running the init command %s %+v as the %q user...", opts.InitCommand, args.InitArgs, args.UserInfo.user.Username) + for _, fn := range preExec { + fn() + } err = syscall.Exec(args.InitCommand, append([]string{args.InitCommand}, args.InitArgs...), args.Environ) if err != nil { diff --git a/go.mod b/go.mod index b3fa7843..9fa1d696 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/gliderlabs/ssh v0.3.7 github.com/go-git/go-billy/v5 v5.5.0 github.com/go-git/go-git/v5 v5.12.0 + github.com/google/go-cmp v0.6.0 github.com/google/go-containerregistry v0.20.1 github.com/google/uuid v1.6.0 github.com/hashicorp/go-multierror v1.1.1 @@ -149,7 +150,6 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/btree v1.1.2 // indirect - github.com/google/go-cmp v0.6.0 // indirect github.com/google/nftables v0.2.0 // indirect github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b // indirect github.com/gorilla/handlers v1.5.1 // indirect diff --git a/integration/integration_test.go b/integration/integration_test.go index 66dfe846..cfe1de49 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -23,6 +23,8 @@ import ( "testing" "time" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/envbuilder" "github.com/coder/envbuilder/devcontainer/features" "github.com/coder/envbuilder/internal/magicdir" @@ -58,6 +60,71 @@ const ( testImageUbuntu = "localhost:5000/envbuilder-test-ubuntu:latest" ) +func TestLogs(t *testing.T) { + t.Parallel() + + token := uuid.NewString() + logsDone := make(chan struct{}) + + logHandler := func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v2/buildinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"version": "v2.8.9"}`)) + return + case "/api/v2/workspaceagents/me/logs": + w.WriteHeader(http.StatusOK) + tokHdr := r.Header.Get(codersdk.SessionTokenHeader) + assert.Equal(t, token, tokHdr) + var req agentsdk.PatchLogs + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + for _, log := range req.Logs { + t.Logf("got log: %+v", log) + if strings.Contains(log.Output, "Running the init command") { + close(logsDone) + return + } + } + return + default: + t.Errorf("unexpected request to %s", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + return + } + } + logSrv := httptest.NewServer(http.HandlerFunc(logHandler)) + defer logSrv.Close() + + // Ensures that a Git repository with a devcontainer.json is cloned and built. + srv := gittest.CreateGitServer(t, gittest.Options{ + Files: map[string]string{ + "devcontainer.json": `{ + "build": { + "dockerfile": "Dockerfile" + }, + }`, + "Dockerfile": fmt.Sprintf(`FROM %s`, testImageUbuntu), + }, + }) + _, err := runEnvbuilder(t, runOpts{env: []string{ + envbuilderEnv("GIT_URL", srv.URL), + "CODER_AGENT_URL=" + logSrv.URL, + "CODER_AGENT_TOKEN=" + token, + }}) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + select { + case <-ctx.Done(): + t.Fatal("timed out waiting for logs") + case <-logsDone: + } +} + func TestInitScriptInitCommand(t *testing.T) { t.Parallel() diff --git a/log/coder.go b/log/coder.go index d8b4fe0d..d31092d5 100644 --- a/log/coder.go +++ b/log/coder.go @@ -6,6 +6,7 @@ import ( "fmt" "net/url" "os" + "sync" "time" "cdr.dev/slog" @@ -27,13 +28,14 @@ var ( minAgentAPIV2 = "v2.9" ) -// Coder establishes a connection to the Coder instance located at -// coderURL and authenticates using token. It then establishes a -// dRPC connection to the Agent API and begins sending logs. -// If the version of Coder does not support the Agent API, it will -// fall back to using the PatchLogs endpoint. -// The returned function is used to block until all logs are sent. -func Coder(ctx context.Context, coderURL *url.URL, token string) (Func, func(), error) { +// Coder establishes a connection to the Coder instance located at coderURL and +// authenticates using token. It then establishes a dRPC connection to the Agent +// API and begins sending logs. If the version of Coder does not support the +// Agent API, it will fall back to using the PatchLogs endpoint. The closer is +// used to close the logger and to wait at most logSendGracePeriod for logs to +// be sent. Cancelling the context will close the logs immediately without +// waiting for logs to be sent. +func Coder(ctx context.Context, coderURL *url.URL, token string) (logger Func, closer func(), err error) { // To troubleshoot issues, we need some way of logging. metaLogger := slog.Make(sloghuman.Sink(os.Stderr)) defer metaLogger.Sync() @@ -44,9 +46,19 @@ func Coder(ctx context.Context, coderURL *url.URL, token string) (Func, func(), } if semver.Compare(semver.MajorMinor(bi.Version), minAgentAPIV2) < 0 { metaLogger.Warn(ctx, "Detected Coder version incompatible with AgentAPI v2, falling back to deprecated API", slog.F("coder_version", bi.Version)) - sendLogs, flushLogs := sendLogsV1(ctx, client, metaLogger.Named("send_logs_v1")) - return sendLogs, flushLogs, nil + logger, closer = sendLogsV1(ctx, client, metaLogger.Named("send_logs_v1")) + return logger, closer, nil } + + // Create a new context so we can ensure the connection is torn down. + ctx, cancel := context.WithCancel(ctx) + defer func() { + if err != nil { + cancel() + } + }() + // Note that ctx passed to initRPC will be inherited by the + // underlying connection, nothing we can do about that here. dac, err := initRPC(ctx, client, metaLogger.Named("init_rpc")) if err != nil { // Logged externally @@ -54,8 +66,19 @@ func Coder(ctx context.Context, coderURL *url.URL, token string) (Func, func(), } ls := agentsdk.NewLogSender(metaLogger.Named("coder_log_sender")) metaLogger.Warn(ctx, "Sending logs via AgentAPI v2", slog.F("coder_version", bi.Version)) - sendLogs, doneFunc := sendLogsV2(ctx, dac, ls, metaLogger.Named("send_logs_v2")) - return sendLogs, doneFunc, nil + logger, loggerCloser := sendLogsV2(ctx, dac, ls, metaLogger.Named("send_logs_v2")) + var closeOnce sync.Once + closer = func() { + loggerCloser() + + closeOnce.Do(func() { + // Typically cancel would be after Close, but we want to be + // sure there's nothing that might block on Close. + cancel() + _ = dac.DRPCConn().Close() + }) + } + return logger, closer, nil } type coderLogSender interface { @@ -74,7 +97,7 @@ func initClient(coderURL *url.URL, token string) *agentsdk.Client { func initRPC(ctx context.Context, client *agentsdk.Client, l slog.Logger) (proto.DRPCAgentClient20, error) { var c proto.DRPCAgentClient20 var err error - retryCtx, retryCancel := context.WithTimeout(context.Background(), rpcConnectTimeout) + retryCtx, retryCancel := context.WithTimeout(ctx, rpcConnectTimeout) defer retryCancel() attempts := 0 for r := retry.New(100*time.Millisecond, time.Second); r.Wait(retryCtx); { @@ -95,65 +118,67 @@ func initRPC(ctx context.Context, client *agentsdk.Client, l slog.Logger) (proto // sendLogsV1 uses the PatchLogs endpoint to send logs. // This is deprecated, but required for backward compatibility with older versions of Coder. -func sendLogsV1(ctx context.Context, client *agentsdk.Client, l slog.Logger) (Func, func()) { +func sendLogsV1(ctx context.Context, client *agentsdk.Client, l slog.Logger) (logger Func, closer func()) { // nolint: staticcheck // required for backwards compatibility - sendLogs, flushLogs := agentsdk.LogsSender(agentsdk.ExternalLogSourceID, client.PatchLogs, slog.Logger{}) + sendLog, flushAndClose := agentsdk.LogsSender(agentsdk.ExternalLogSourceID, client.PatchLogs, slog.Logger{}) + var mu sync.Mutex return func(lvl Level, msg string, args ...any) { log := agentsdk.Log{ CreatedAt: time.Now(), Output: fmt.Sprintf(msg, args...), Level: codersdk.LogLevel(lvl), } - if err := sendLogs(ctx, log); err != nil { + mu.Lock() + defer mu.Unlock() + if err := sendLog(ctx, log); err != nil { l.Warn(ctx, "failed to send logs to Coder", slog.Error(err)) } }, func() { - if err := flushLogs(ctx); err != nil { + ctx, cancel := context.WithTimeout(ctx, logSendGracePeriod) + defer cancel() + if err := flushAndClose(ctx); err != nil { l.Warn(ctx, "failed to flush logs", slog.Error(err)) } } } // sendLogsV2 uses the v2 agent API to send logs. Only compatibile with coder versions >= 2.9. -func sendLogsV2(ctx context.Context, dest agentsdk.LogDest, ls coderLogSender, l slog.Logger) (Func, func()) { +func sendLogsV2(ctx context.Context, dest agentsdk.LogDest, ls coderLogSender, l slog.Logger) (logger Func, closer func()) { + sendCtx, sendCancel := context.WithCancel(ctx) done := make(chan struct{}) uid := uuid.New() go func() { defer close(done) - if err := ls.SendLoop(ctx, dest); err != nil { + if err := ls.SendLoop(sendCtx, dest); err != nil { if !errors.Is(err, context.Canceled) { l.Warn(ctx, "failed to send logs to Coder", slog.Error(err)) } } - - // Wait for up to 10 seconds for logs to finish sending. - sendCtx, sendCancel := context.WithTimeout(context.Background(), logSendGracePeriod) - defer sendCancel() - // Try once more to send any pending logs - if err := ls.SendLoop(sendCtx, dest); err != nil { - if !errors.Is(err, context.DeadlineExceeded) { - l.Warn(ctx, "failed to send remaining logs to Coder", slog.Error(err)) - } - } - ls.Flush(uid) - if err := ls.WaitUntilEmpty(sendCtx); err != nil { - if !errors.Is(err, context.DeadlineExceeded) { - l.Warn(ctx, "log sender did not empty", slog.Error(err)) - } - } }() - logFunc := func(l Level, msg string, args ...any) { - ls.Enqueue(uid, agentsdk.Log{ - CreatedAt: time.Now(), - Output: fmt.Sprintf(msg, args...), - Level: codersdk.LogLevel(l), - }) - } + var closeOnce sync.Once + return func(l Level, msg string, args ...any) { + ls.Enqueue(uid, agentsdk.Log{ + CreatedAt: time.Now(), + Output: fmt.Sprintf(msg, args...), + Level: codersdk.LogLevel(l), + }) + }, func() { + closeOnce.Do(func() { + // Trigger a flush and wait for logs to be sent. + ls.Flush(uid) + ctx, cancel := context.WithTimeout(ctx, logSendGracePeriod) + defer cancel() + err := ls.WaitUntilEmpty(ctx) + if err != nil { + l.Warn(ctx, "log sender did not empty", slog.Error(err)) + } - doneFunc := func() { - <-done - } + // Stop the send loop. + sendCancel() + }) - return logFunc, doneFunc + // Wait for the send loop to finish. + <-done + } } diff --git a/log/coder_internal_test.go b/log/coder_internal_test.go index 4895150e..8b8bb632 100644 --- a/log/coder_internal_test.go +++ b/log/coder_internal_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "math/rand" "net/http" "net/http/httptest" "net/url" @@ -38,10 +39,8 @@ func TestCoder(t *testing.T) { defer closeOnce.Do(func() { close(gotLogs) }) tokHdr := r.Header.Get(codersdk.SessionTokenHeader) assert.Equal(t, token, tokHdr) - var req agentsdk.PatchLogs - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + req, ok := decodeV1Logs(t, w, r) + if !ok { return } if assert.Len(t, req.Logs, 1) { @@ -54,15 +53,44 @@ func TestCoder(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - u, err := url.Parse(srv.URL) - require.NoError(t, err) - log, closeLog, err := Coder(ctx, u, token) - require.NoError(t, err) - defer closeLog() - log(LevelInfo, "hello %s", "world") + + logger, _ := newCoderLogger(ctx, t, srv.URL, token) + logger(LevelInfo, "hello %s", "world") <-gotLogs }) + t.Run("V1/Close", func(t *testing.T) { + t.Parallel() + + var got []agentsdk.Log + handler := func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v2/buildinfo" { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"version": "v2.8.9"}`)) + return + } + req, ok := decodeV1Logs(t, w, r) + if !ok { + return + } + got = append(got, req.Logs...) + } + srv := httptest.NewServer(http.HandlerFunc(handler)) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger, closer := newCoderLogger(ctx, t, srv.URL, uuid.NewString()) + logger(LevelInfo, "1") + logger(LevelInfo, "2") + closer() + logger(LevelInfo, "3") + require.Len(t, got, 2) + assert.Equal(t, "1", got[0].Output) + assert.Equal(t, "2", got[1].Output) + }) + t.Run("V1/ErrUnauthorized", func(t *testing.T) { t.Parallel() @@ -140,42 +168,31 @@ func TestCoder(t *testing.T) { require.Len(t, ld.logs, 10) }) - // In this test, we just stand up an endpoint that does not - // do dRPC. We'll try to connect, fail to websocket upgrade - // and eventually give up. - t.Run("V2/Err", func(t *testing.T) { + // In this test, we just fake out the DRPC server. + t.Run("V2/Close", func(t *testing.T) { t.Parallel() - token := uuid.NewString() - handlerDone := make(chan struct{}) - var closeOnce sync.Once - handler := func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/api/v2/buildinfo" { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"version": "v2.9.0"}`)) - return - } - defer closeOnce.Do(func() { close(handlerDone) }) - w.WriteHeader(http.StatusOK) - } - srv := httptest.NewServer(http.HandlerFunc(handler)) - defer srv.Close() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - u, err := url.Parse(srv.URL) - require.NoError(t, err) - _, _, err = Coder(ctx, u, token) - require.ErrorContains(t, err, "failed to WebSocket dial") - require.ErrorIs(t, err, context.DeadlineExceeded) - <-handlerDone + + ld := &fakeLogDest{t: t} + ls := agentsdk.NewLogSender(slogtest.Make(t, nil)) + logger, closer := sendLogsV2(ctx, ld, ls, slogtest.Make(t, nil)) + defer closer() + + logger(LevelInfo, "1") + logger(LevelInfo, "2") + closer() + logger(LevelInfo, "3") + + require.Len(t, ld.logs, 2) }) // In this test, we validate that a 401 error on the initial connect // results in a retry. When envbuilder initially attempts to connect // using the Coder agent token, the workspace build may not yet have // completed. - t.Run("V2Retry", func(t *testing.T) { + t.Run("V2/Retry", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -221,6 +238,99 @@ func TestCoder(t *testing.T) { }) } +//nolint:paralleltest // We need to replace a global timeout. +func TestCoderRPCTimeout(t *testing.T) { + // This timeout is picked with the current subtests in mind, it + // should not be changed without good reason. + testReplaceTimeout(t, &rpcConnectTimeout, 500*time.Millisecond) + + // In this test, we just stand up an endpoint that does not + // do dRPC. We'll try to connect, fail to websocket upgrade + // and eventually give up after rpcConnectTimeout. + t.Run("V2/Err", func(t *testing.T) { + t.Parallel() + + token := uuid.NewString() + handlerDone := make(chan struct{}) + handlerWait := make(chan struct{}) + var closeOnce sync.Once + handler := func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v2/buildinfo" { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"version": "v2.9.0"}`)) + return + } + defer closeOnce.Do(func() { close(handlerDone) }) + <-handlerWait + w.WriteHeader(http.StatusOK) + } + srv := httptest.NewServer(http.HandlerFunc(handler)) + defer srv.Close() + + ctx, cancel := context.WithTimeout(context.Background(), rpcConnectTimeout/2) + defer cancel() + u, err := url.Parse(srv.URL) + require.NoError(t, err) + _, _, err = Coder(ctx, u, token) + require.ErrorContains(t, err, "failed to WebSocket dial") + require.ErrorIs(t, err, context.DeadlineExceeded) + close(handlerWait) + <-handlerDone + }) + + t.Run("V2/Timeout", func(t *testing.T) { + t.Parallel() + + token := uuid.NewString() + handlerDone := make(chan struct{}) + handlerWait := make(chan struct{}) + var closeOnce sync.Once + handler := func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v2/buildinfo" { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"version": "v2.9.0"}`)) + return + } + defer closeOnce.Do(func() { close(handlerDone) }) + <-handlerWait + w.WriteHeader(http.StatusOK) + } + srv := httptest.NewServer(http.HandlerFunc(handler)) + defer srv.Close() + + ctx, cancel := context.WithTimeout(context.Background(), rpcConnectTimeout*2) + defer cancel() + u, err := url.Parse(srv.URL) + require.NoError(t, err) + _, _, err = Coder(ctx, u, token) + require.ErrorContains(t, err, "failed to WebSocket dial") + require.ErrorIs(t, err, context.DeadlineExceeded) + close(handlerWait) + <-handlerDone + }) +} + +func decodeV1Logs(t *testing.T, w http.ResponseWriter, r *http.Request) (agentsdk.PatchLogs, bool) { + t.Helper() + var req agentsdk.PatchLogs + err := json.NewDecoder(r.Body).Decode(&req) + if !assert.NoError(t, err) { + http.Error(w, err.Error(), http.StatusBadRequest) + return req, false + } + return req, true +} + +func newCoderLogger(ctx context.Context, t *testing.T, us string, token string) (Func, func()) { + t.Helper() + u, err := url.Parse(us) + require.NoError(t, err) + logger, closer, err := Coder(ctx, u, token) + require.NoError(t, err) + t.Cleanup(closer) + return logger, closer +} + type fakeLogDest struct { t testing.TB logs []*proto.Log @@ -231,3 +341,27 @@ func (d *fakeLogDest) BatchCreateLogs(ctx context.Context, request *proto.BatchC d.logs = append(d.logs, request.Logs...) return &proto.BatchCreateLogsResponse{}, nil } + +func testReplaceTimeout(t *testing.T, v *time.Duration, d time.Duration) { + t.Helper() + if isParallel(t) { + t.Fatal("cannot replace timeout in parallel test") + } + old := *v + *v = d + t.Cleanup(func() { *v = old }) +} + +func isParallel(t *testing.T) (ret bool) { + t.Helper() + // This is a hack to determine if the test is running in parallel + // via property of t.Setenv. + defer func() { + if r := recover(); r != nil { + ret = true + } + }() + // Random variable name to avoid collisions. + t.Setenv(fmt.Sprintf("__TEST_CHECK_IS_PARALLEL_%d", rand.Int()), "1") + return false +}