Skip to content

Commit

Permalink
refactor newHTTPClient func
Browse files Browse the repository at this point in the history
* extract common newHTTPClient func out to its own package
* add test for testing root CAs in the constructor.
* test certs are set to expired in 10 years

Signed-off-by: Rui Yang <ruiya@vmware.com>
  • Loading branch information
Rui Yang committed Oct 8, 2022
1 parent c2859da commit 5410c7a
Show file tree
Hide file tree
Showing 16 changed files with 328 additions and 160 deletions.
44 changes: 7 additions & 37 deletions connector/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,21 @@ package github

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"regexp"
"strconv"
"strings"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/github"

"github.com/dexidp/dex/connector"
groups_pkg "github.com/dexidp/dex/pkg/groups"
"github.com/dexidp/dex/pkg/httpclient"
"github.com/dexidp/dex/pkg/log"
)

Expand Down Expand Up @@ -106,7 +102,7 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
g.rootCA = c.RootCA

var err error
if g.httpClient, err = newHTTPClient(g.rootCA); err != nil {
if g.httpClient, err = httpclient.NewHTTPClient([]string{g.rootCA}, false); err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %v", err)
}
}
Expand Down Expand Up @@ -208,34 +204,6 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}

// newHTTPClient returns a new HTTP client that trusts the custom declared rootCA cert.
func newHTTPClient(rootCA string) (*http.Client, error) {
tlsConfig := tls.Config{RootCAs: x509.NewCertPool()}
rootCABytes, err := os.ReadFile(rootCA)
if err != nil {
return nil, fmt.Errorf("failed to read root-ca: %v", err)
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) {
return nil, fmt.Errorf("no certs found in root CA file %q", rootCA)
}

return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tlsConfig,
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}, nil
}

func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
Expand Down Expand Up @@ -356,9 +324,11 @@ func formatTeamName(org string, team string) string {

// groupsForOrgs enforces org and team constraints on user authorization
// Cases in which user is authorized:
// N orgs, no teams: user is member of at least 1 org
// N orgs, M teams per org: user is member of any team from at least 1 org
// N-1 orgs, M teams per org, 1 org with no teams: user is member of any team
//
// N orgs, no teams: user is member of at least 1 org
// N orgs, M teams per org: user is member of any team from at least 1 org
// N-1 orgs, M teams per org, 1 org with no teams: user is member of any team
//
// from at least 1 org, or member of org with no teams
func (c *githubConnector) groupsForOrgs(ctx context.Context, client *http.Client, userName string) ([]string, error) {
groups := make([]string, 0)
Expand Down
42 changes: 2 additions & 40 deletions connector/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,17 @@ package oauth

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"os"
"strings"
"time"

"golang.org/x/oauth2"

"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/httpclient"
"github.com/dexidp/dex/pkg/log"
)

Expand Down Expand Up @@ -112,48 +108,14 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
emailVerifiedKey: emailVerifiedKey,
}

oauthConn.httpClient, err = newHTTPClient(c.RootCAs, c.InsecureSkipVerify)
oauthConn.httpClient, err = httpclient.NewHTTPClient(c.RootCAs, c.InsecureSkipVerify)
if err != nil {
return nil, err
}

return oauthConn, err
}

func newHTTPClient(rootCAs []string, insecureSkipVerify bool) (*http.Client, error) {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}

tlsConfig := tls.Config{RootCAs: pool, InsecureSkipVerify: insecureSkipVerify}
for _, rootCA := range rootCAs {
rootCABytes, err := os.ReadFile(rootCA)
if err != nil {
return nil, fmt.Errorf("failed to read root-ca: %v", err)
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) {
return nil, fmt.Errorf("no certs found in root CA file %q", rootCA)
}
}

return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tlsConfig,
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}, nil
}

func (c *oauthConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
Expand Down
41 changes: 2 additions & 39 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,19 @@ package oidc

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"

"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/httpclient"
"github.com/dexidp/dex/pkg/log"
)

Expand Down Expand Up @@ -119,7 +116,7 @@ func knownBrokenAuthHeaderProvider(issuerURL string) bool {
// Open returns a connector which can be used to login users through an upstream
// OpenID Connect provider.
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) {
httpClient, err := newHTTPClient(c.RootCAs, c.InsecureSkipVerify)
httpClient, err := httpclient.NewHTTPClient(c.RootCAs, c.InsecureSkipVerify)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -188,40 +185,6 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
}, nil
}

func newHTTPClient(rootCAs []string, insecureSkipVerify bool) (*http.Client, error) {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}

tlsConfig := tls.Config{RootCAs: pool, InsecureSkipVerify: insecureSkipVerify}
for _, rootCA := range rootCAs {
rootCABytes, err := os.ReadFile(rootCA)
if err != nil {
return nil, fmt.Errorf("failed to read root-ca: %v", err)
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) {
return nil, fmt.Errorf("no certs found in root CA file %q", rootCA)
}
}

return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tlsConfig,
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}, nil
}

var (
_ connector.CallbackConnector = (*oidcConnector)(nil)
_ connector.RefreshConnector = (*oidcConnector)(nil)
Expand Down
46 changes: 7 additions & 39 deletions connector/openshift/openshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,17 @@ package openshift

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"time"

"golang.org/x/oauth2"

"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/groups"
"github.com/dexidp/dex/pkg/httpclient"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage/kubernetes/k8sapi"
)
Expand Down Expand Up @@ -67,7 +63,12 @@ type user struct {
// Open returns a connector which can be used to login users through an upstream
// OpenShift OAuth2 provider.
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) {
httpClient, err := newHTTPClient(c.InsecureCA, c.RootCA)
var rootCAs []string
if c.RootCA != "" {
rootCAs = append(rootCAs, c.RootCA)
}

httpClient, err := httpclient.NewHTTPClient(rootCAs, c.InsecureCA)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
}
Expand Down Expand Up @@ -262,36 +263,3 @@ func validateAllowedGroups(userGroups, allowedGroups []string) bool {

return len(matchingGroups) != 0
}

// newHTTPClient returns a new HTTP client
func newHTTPClient(insecureCA bool, rootCA string) (*http.Client, error) {
tlsConfig := tls.Config{}
if insecureCA {
tlsConfig = tls.Config{InsecureSkipVerify: true}
} else if rootCA != "" {
tlsConfig = tls.Config{RootCAs: x509.NewCertPool()}
rootCABytes, err := os.ReadFile(rootCA)
if err != nil {
return nil, fmt.Errorf("failed to read root-ca: %w", err)
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) {
return nil, fmt.Errorf("no certs found in root CA file %q", rootCA)
}
}

return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tlsConfig,
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}, nil
}
11 changes: 6 additions & 5 deletions connector/openshift/openshift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"golang.org/x/oauth2"

"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/httpclient"
"github.com/dexidp/dex/storage/kubernetes/k8sapi"
)

Expand Down Expand Up @@ -70,7 +71,7 @@ func TestGetUser(t *testing.T) {
_, err = http.NewRequest("GET", hostURL.String(), nil)
expectNil(t, err)

h, err := newHTTPClient(true, "")
h, err := httpclient.NewHTTPClient(nil, true)

expectNil(t, err)

Expand Down Expand Up @@ -128,7 +129,7 @@ func TestVerifyGroup(t *testing.T) {
_, err = http.NewRequest("GET", hostURL.String(), nil)
expectNil(t, err)

h, err := newHTTPClient(true, "")
h, err := httpclient.NewHTTPClient(nil, true)

expectNil(t, err)

Expand Down Expand Up @@ -164,7 +165,7 @@ func TestCallbackIdentity(t *testing.T) {
req, err := http.NewRequest("GET", hostURL.String(), nil)
expectNil(t, err)

h, err := newHTTPClient(true, "")
h, err := httpclient.NewHTTPClient(nil, true)

expectNil(t, err)

Expand Down Expand Up @@ -198,7 +199,7 @@ func TestRefreshIdentity(t *testing.T) {
})
defer s.Close()

h, err := newHTTPClient(true, "")
h, err := httpclient.NewHTTPClient(nil, true)
expectNil(t, err)

oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{
Expand Down Expand Up @@ -237,7 +238,7 @@ func TestRefreshIdentityFailure(t *testing.T) {
})
defer s.Close()

h, err := newHTTPClient(true, "")
h, err := httpclient.NewHTTPClient(nil, true)
expectNil(t, err)

oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{
Expand Down
Loading

0 comments on commit 5410c7a

Please sign in to comment.