Skip to content

Commit

Permalink
allow passing context to GetCertificate
Browse files Browse the repository at this point in the history
Add a new GetCertificateWithContext to allow passing an existing
context.  Also export the default dialer method, so that callers that
override TailscaledDialer can fall back to this if other methods fail.

Signed-off-by: Will Norris <will@tailscale.com>
  • Loading branch information
willnorris committed May 17, 2024
1 parent 28a91b6 commit bbccfbf
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions tscert.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ var (

// TailscaledDialer is the DialContext func that connects to the local machine's
// tailscaled or equivalent.
TailscaledDialer = defaultDialer
TailscaledDialer = DialLocalAPI
)

func defaultDialer(ctx context.Context, network, addr string) (net.Conn, error) {
// DialLocalAPI connects to the LocalAPI server of the tailscaled instance on the machine.
func DialLocalAPI(ctx context.Context, network, addr string) (net.Conn, error) {
if addr != "local-tailscaled.sock:80" {
return nil, fmt.Errorf("unexpected URL address %q", addr)
}
Expand Down Expand Up @@ -236,13 +237,21 @@ func CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err e
//
// It returns a cached certificate from disk if it's still valid.
//
// It's the right signature to use as the value of
// tls.Config.GetCertificate.
// It's the right signature to use as the value of tls.Config.GetCertificate.
func GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
return GetCertificateWithContext(context.Background(), hi)
}

// GetCertificateWithContext fetches a TLS certificate for the TLS ClientHello in hi.
//
// It returns a cached certificate from disk if it's still valid.
//
// Use GetCertificate instead if a value for tls.Config.GetCertificate is needed.
func GetCertificateWithContext(ctx context.Context, hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
if hi == nil || hi.ServerName == "" {
return nil, errors.New("no SNI ServerName")
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()

name := hi.ServerName
Expand Down

0 comments on commit bbccfbf

Please sign in to comment.