From bbccfbf48933da858ca2f459342f2457be107029 Mon Sep 17 00:00:00 2001 From: Will Norris Date: Fri, 17 May 2024 15:24:02 -0700 Subject: [PATCH] allow passing context to GetCertificate Add a new GetCertificateWithContext to allow passing an existing context. Also export the default dialer method, so that callers that override TailscaledDialer can fall back to this if other methods fail. Signed-off-by: Will Norris --- tscert.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tscert.go b/tscert.go index 225f354..52df122 100644 --- a/tscert.go +++ b/tscert.go @@ -35,10 +35,11 @@ var ( // TailscaledDialer is the DialContext func that connects to the local machine's // tailscaled or equivalent. - TailscaledDialer = defaultDialer + TailscaledDialer = DialLocalAPI ) -func defaultDialer(ctx context.Context, network, addr string) (net.Conn, error) { +// DialLocalAPI connects to the LocalAPI server of the tailscaled instance on the machine. +func DialLocalAPI(ctx context.Context, network, addr string) (net.Conn, error) { if addr != "local-tailscaled.sock:80" { return nil, fmt.Errorf("unexpected URL address %q", addr) } @@ -236,13 +237,21 @@ func CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err e // // It returns a cached certificate from disk if it's still valid. // -// It's the right signature to use as the value of -// tls.Config.GetCertificate. +// It's the right signature to use as the value of tls.Config.GetCertificate. func GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { + return GetCertificateWithContext(context.Background(), hi) +} + +// GetCertificateWithContext fetches a TLS certificate for the TLS ClientHello in hi. +// +// It returns a cached certificate from disk if it's still valid. +// +// Use GetCertificate instead if a value for tls.Config.GetCertificate is needed. +func GetCertificateWithContext(ctx context.Context, hi *tls.ClientHelloInfo) (*tls.Certificate, error) { if hi == nil || hi.ServerName == "" { return nil, errors.New("no SNI ServerName") } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() name := hi.ServerName