Skip to content

Commit

Permalink
refactor: clone default http transport for registry client (#840)
Browse files Browse the repository at this point in the history
Signed-off-by: Billy Zha <jinzha1@microsoft.com>
  • Loading branch information
qweeah authored Mar 3, 2023
1 parent d93544d commit 70fde82
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 89 deletions.
47 changes: 13 additions & 34 deletions cmd/oras/internal/option/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"os"
"strconv"
"strings"
"time"

"github.com/spf13/pflag"
"oras.land/oras-go/v2/registry/remote"
Expand All @@ -49,7 +48,6 @@ type Remote struct {
Password string

resolveFlag []string
resolveDialContext func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error)
applyDistributionSpec bool
distributionSpec distributionSpec
headerFlags []string
Expand Down Expand Up @@ -131,9 +129,9 @@ func (opts *Remote) readPassword() (err error) {
}

// parseResolve parses resolve flag.
func (opts *Remote) parseResolve() error {
func (opts *Remote) parseResolve(baseDial onet.DialFunc) (onet.DialFunc, error) {
if len(opts.resolveFlag) == 0 {
return nil
return baseDial, nil
}

formatError := func(param, message string) error {
Expand All @@ -144,32 +142,29 @@ func (opts *Remote) parseResolve() error {
parts := strings.SplitN(r, ":", 4)
length := len(parts)
if length < 3 {
return formatError(r, "expecting host:port:address[:address_port]")
return nil, formatError(r, "expecting host:port:address[:address_port]")
}
host := parts[0]
hostPort, err := strconv.Atoi(parts[1])
if err != nil {
return formatError(r, "expecting uint64 host port")
return nil, formatError(r, "expecting uint64 host port")
}
// ipv6 zone is not parsed
address := net.ParseIP(parts[2])
if address == nil {
return formatError(r, "invalid IP address")
return nil, formatError(r, "invalid IP address")
}
addressPort := hostPort
if length > 3 {
addressPort, err = strconv.Atoi(parts[3])
if err != nil {
return formatError(r, "expecting uint64 address port")
return nil, formatError(r, "expecting uint64 address port")
}
}
dialer.Add(host, hostPort, address, addressPort)
}
opts.resolveDialContext = func(base *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
dialer.Dialer = base
return dialer.DialContext
}
return nil
dialer.BaseDialContext = baseDial
return dialer.DialContext, nil
}

// tlsConfig assembles the tls config.
Expand All @@ -193,29 +188,13 @@ func (opts *Remote) authClient(registry string, debug bool) (client *auth.Client
if err != nil {
return nil, err
}
if err := opts.parseResolve(); err != nil {
baseTransport := http.DefaultTransport.(*http.Transport).Clone()
baseTransport.TLSClientConfig = config
dialContext, err := opts.parseResolve(baseTransport.DialContext)
if err != nil {
return nil, err
}
resolveDialContext := opts.resolveDialContext
if resolveDialContext == nil {
resolveDialContext = func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
return dialer.DialContext
}
}
// default value are derived from http.DefaultTransport
baseTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: resolveDialContext(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}),
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: config,
}
baseTransport.DialContext = dialContext
client = &auth.Client{
Client: &http.Client{
// http.RoundTripper with a retry using the DefaultPolicy
Expand Down
103 changes: 51 additions & 52 deletions cmd/oras/internal/option/remote_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,55 @@ func TestRemote_NewRepository(t *testing.T) {
}
}

func TestRemote_NewRepository_Retry(t *testing.T) {
caPath := filepath.Join(t.TempDir(), "oras-test.pem")
if err := os.WriteFile(caPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ts.Certificate().Raw}), 0644); err != nil {
t.Fatalf("unexpected error: %v", err)
}
retries, count := 3, 0
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count++
if count < retries {
http.Error(w, "error", http.StatusTooManyRequests)
return
}
json.NewEncoder(w).Encode(testTagList)
}))
defer ts.Close()
opts := struct {
Remote
Common
}{
Remote{
CACertFilePath: caPath,
},
Common{},
}

uri, err := url.ParseRequestURI(ts.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
repo, err := opts.NewRepository(uri.Host+"/"+testRepo, opts.Common)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if err = repo.Tags(context.Background(), "", func(got []string) error {
want := []string{"tag"}
if len(got) != len(testTagList.Tags) || !reflect.DeepEqual(got, want) {
return fmt.Errorf("expect: %v, got: %v", testTagList.Tags, got)
}
return nil
}); err != nil {
t.Fatalf("unexpected error: %v", err)
}

if count != retries {
t.Errorf("expected %d retries, got %d", retries, count)
}
}

func TestRemote_isPlainHttp_localhost(t *testing.T) {
opts := Remote{PlainHTTP: false}
got := opts.isPlainHttp("localhost")
Expand Down Expand Up @@ -286,7 +335,7 @@ func TestRemote_parseResolve_err(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.opts.parseResolve(); err == nil {
if _, err := tt.opts.parseResolve(nil); err == nil {
t.Errorf("Expecting error in Remote.parseResolve()")
}
})
Expand All @@ -309,7 +358,7 @@ func TestRemote_parseResolve(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.opts.parseResolve(); err != nil {
if _, err := tt.opts.parseResolve(nil); err != nil {
t.Errorf("Remote.parseResolve() error = %v", err)
}
})
Expand Down Expand Up @@ -416,53 +465,3 @@ func TestRemote_parseCustomHeaders(t *testing.T) {
})
}
}

func TestRemote_NewRepository_Retry(t *testing.T) {
caPath := filepath.Join(t.TempDir(), "oras-test.pem")
if err := os.WriteFile(caPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ts.Certificate().Raw}), 0644); err != nil {
t.Fatalf("unexpected error: %v", err)
}
retries, count := 3, 0
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count++
if count < retries {
http.Error(w, "error", http.StatusTooManyRequests)
return
}
json.NewEncoder(w).Encode(testTagList)
}))
defer ts.Close()

opts := struct {
Remote
Common
}{
Remote{
CACertFilePath: caPath,
},
Common{},
}

uri, err := url.ParseRequestURI(ts.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
repo, err := opts.NewRepository(uri.Host+"/"+testRepo, opts.Common)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if err = repo.Tags(context.Background(), "", func(got []string) error {
want := []string{"tag"}
if len(got) != len(testTagList.Tags) || !reflect.DeepEqual(got, want) {
return fmt.Errorf("expect: %v, got: %v", testTagList.Tags, got)
}
return nil
}); err != nil {
t.Fatalf("unexpected error: %v", err)
}

if count != retries {
t.Errorf("expected %d retries, got %d", retries, count)
}
}
9 changes: 6 additions & 3 deletions internal/net/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ import (
"net"
)

// DialFunc is the function type for http.DialContext.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)

// Dialer struct provides dialing function with predefined DNS resolves.
type Dialer struct {
*net.Dialer
resolve map[string]string
BaseDialContext DialFunc
resolve map[string]string
}

// Add adds an entry for DNS resolve.
Expand All @@ -41,5 +44,5 @@ func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Con
if resolved, ok := d.resolve[addr]; ok {
addr = resolved
}
return d.Dialer.DialContext(ctx, network, addr)
return d.BaseDialContext(ctx, network, addr)
}

0 comments on commit 70fde82

Please sign in to comment.