diff --git a/internal/app/machined/pkg/controllers/secrets/api.go b/internal/app/machined/pkg/controllers/secrets/api.go index c6e745cb47..c4c679e056 100644 --- a/internal/app/machined/pkg/controllers/secrets/api.go +++ b/internal/app/machined/pkg/controllers/secrets/api.go @@ -374,7 +374,7 @@ func (ctrl *APIController) generateJoin(ctx context.Context, r controller.Runtim var ca []byte - ca, serverCert.Crt, err = remoteGen.Identity(serverCSR) + ca, serverCert.Crt, err = remoteGen.IdentityContext(ctx, serverCSR) if err != nil { return fmt.Errorf("failed to sign API server CSR: %w", err) } @@ -387,7 +387,7 @@ func (ctrl *APIController) generateJoin(ctx context.Context, r controller.Runtim return fmt.Errorf("failed to generate API client CSR: %w", err) } - _, clientCert.Crt, err = remoteGen.Identity(clientCSR) + _, clientCert.Crt, err = remoteGen.IdentityContext(ctx, clientCSR) if err != nil { return fmt.Errorf("failed to sign API client CSR: %w", err) } diff --git a/pkg/grpc/gen/remote.go b/pkg/grpc/gen/remote.go index 9e35026aee..4c405cc296 100644 --- a/pkg/grpc/gen/remote.go +++ b/pkg/grpc/gen/remote.go @@ -7,12 +7,11 @@ package gen import ( "context" "fmt" - "log" "strings" - "sync" "time" "github.com/talos-systems/crypto/x509" + "github.com/talos-systems/go-retry/retry" "google.golang.org/grpc" "github.com/talos-systems/talos/pkg/grpc/middleware/auth/basic" @@ -29,11 +28,6 @@ func init() { // RemoteGenerator represents the OS identity generator. type RemoteGenerator struct { - done chan struct{} - creds basic.Credentials - - // connMu protects conn & client - connMu sync.Mutex conn *grpc.ClientConn client securityapi.SecurityServiceClient } @@ -44,45 +38,51 @@ func NewRemoteGenerator(token string, endpoints []string) (g *RemoteGenerator, e return nil, fmt.Errorf("at least one root of trust endpoint is required") } - g = &RemoteGenerator{ - done: make(chan struct{}), - creds: basic.NewTokenCredentials(token), - } - - if err = g.SetEndpoints(endpoints); err != nil { - return nil, err - } - - return g, nil -} + g = &RemoteGenerator{} -// SetEndpoints updates the list of endpoints to talk to. -func (g *RemoteGenerator) SetEndpoints(endpoints []string) error { - conn, err := basic.NewConnection(fmt.Sprintf("%s:///%s", trustdResolverScheme, strings.Join(endpoints, ",")), g.creds) + conn, err := basic.NewConnection(fmt.Sprintf("%s:///%s", trustdResolverScheme, strings.Join(endpoints, ",")), basic.NewTokenCredentials(token)) if err != nil { - return err - } - - g.connMu.Lock() - defer g.connMu.Unlock() - - if g.conn != nil { - g.conn.Close() //nolint:errcheck + return nil, err } g.conn = conn g.client = securityapi.NewSecurityServiceClient(g.conn) - return nil + return g, nil } // Identity creates an identity certificate via the security API. func (g *RemoteGenerator) Identity(csr *x509.CertificateSigningRequest) (ca, crt []byte, err error) { + return g.IdentityContext(context.Background(), csr) +} + +// IdentityContext creates an identity certificate via the security API. +func (g *RemoteGenerator) IdentityContext(ctx context.Context, csr *x509.CertificateSigningRequest) (ca, crt []byte, err error) { req := &securityapi.CertificateRequest{ Csr: csr.X509CertificateRequestPEM, } - ca, crt, err = g.poll(req) + ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + err = retry.Exponential(5*time.Minute, + retry.WithAttemptTimeout(30*time.Second), + retry.WithUnits(5*time.Second), + retry.WithJitter(100*time.Millisecond), + ).RetryWithContext(ctx, func(ctx context.Context) error { + var resp *securityapi.CertificateResponse + + resp, err = g.client.Certificate(ctx, req) + if err != nil { + return retry.ExpectedError(err) + } + + ca = resp.Ca + crt = resp.Crt + + return nil + }) + if err != nil { return nil, nil, err } @@ -92,49 +92,5 @@ func (g *RemoteGenerator) Identity(csr *x509.CertificateSigningRequest) (ca, crt // Close closes the gRPC client connection. func (g *RemoteGenerator) Close() error { - g.done <- struct{}{} - return g.conn.Close() } - -func (g *RemoteGenerator) certificate(in *securityapi.CertificateRequest) (resp *securityapi.CertificateResponse, err error) { - g.connMu.Lock() - defer g.connMu.Unlock() - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - return g.client.Certificate(ctx, in) -} - -func (g *RemoteGenerator) poll(in *securityapi.CertificateRequest) (ca, crt []byte, err error) { - // TODO: rewrite with retry package - timeout := time.NewTimer(time.Minute * 5) - defer timeout.Stop() - - tick := time.NewTicker(time.Second * 5) - defer tick.Stop() - - for { - select { - case <-timeout.C: - return nil, nil, fmt.Errorf("timeout waiting for certificate") - case <-tick.C: - var resp *securityapi.CertificateResponse - - resp, err = g.certificate(in) - if err != nil { - log.Println(err) - - continue - } - - ca = resp.Ca - crt = resp.Crt - - return ca, crt, nil - case <-g.done: - return nil, nil, nil - } - } -}