diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index e345dca0b2..cf055f832a 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -3,9 +3,13 @@ package oidc import ( "context" + "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" + "io/ioutil" + "net" "net/http" "net/url" "strings" @@ -34,7 +38,14 @@ type Config struct { Scopes []string `json:"scopes"` // defaults to "profile" and "email" - // Override the value of email_verified to true in the returned claims + // Optional list of whitelisted domains when using Google + // If this field is nonempty, only users from a listed domain will be allowed to log in + HostedDomains []string `json:"hostedDomains"` + + // Certificates for SSL validation + RootCAs []string `json:"rootCAs"` + + // Override the value of email_verifed to true in the returned claims InsecureSkipEmailVerified bool `json:"insecureSkipEmailVerified"` // InsecureEnableGroups enables groups claims. This is disabled by default until https://github.com/dexidp/dex/issues/1065 is resolved @@ -45,6 +56,9 @@ type Config struct { // processing requests from this Client, with the values appearing in order of preference. AcrValues []string `json:"acrValues"` + // Disable certificate verification + InsecureSkipVerify bool `json:"insecureSkipVerify"` + // GetUserInfo uses the userinfo endpoint to get additional claims for // the token. This is especially useful where upstreams return "thin" // id tokens @@ -105,7 +119,13 @@ func knownBrokenAuthHeaderProvider(issuerURL string) bool { // Open returns a connector which can be used to login users through an upstream // OpenID Connect provider. func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { + httpClient, err := newHTTPClient(c.RootCAs, c.InsecureSkipVerify) + if err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(context.Background()) + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) provider, err := oidc.NewProvider(ctx, c.Issuer) if err != nil { @@ -152,6 +172,8 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e ), logger: logger, cancel: cancel, + httpClient: httpClient, + hostedDomains: c.HostedDomains, insecureSkipEmailVerified: c.InsecureSkipEmailVerified, insecureEnableGroups: c.InsecureEnableGroups, acrValues: c.AcrValues, @@ -166,6 +188,40 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e }, nil } +func newHTTPClient(rootCAs []string, insecureSkipVerify bool) (*http.Client, error) { + pool, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + + tlsConfig := tls.Config{RootCAs: pool, InsecureSkipVerify: insecureSkipVerify} + for _, rootCA := range rootCAs { + rootCABytes, err := ioutil.ReadFile(rootCA) + if err != nil { + return nil, fmt.Errorf("failed to read root-ca: %v", err) + } + if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) { + return nil, fmt.Errorf("no certs found in root CA file %q", rootCA) + } + } + + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tlsConfig, + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + }, nil +} + var ( _ connector.CallbackConnector = (*oidcConnector)(nil) _ connector.RefreshConnector = (*oidcConnector)(nil) @@ -178,6 +234,8 @@ type oidcConnector struct { verifier *oidc.IDTokenVerifier cancel context.CancelFunc logger log.Logger + httpClient *http.Client + hostedDomains []string insecureSkipEmailVerified bool insecureEnableGroups bool acrValues []string @@ -238,7 +296,10 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} } - token, err := c.oauth2Config.Exchange(r.Context(), q.Get("code")) + + ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient) + + token, err := c.oauth2Config.Exchange(ctx, q.Get("code")) if err != nil { return identity, fmt.Errorf("oidc: failed to get token: %v", err) }