From 70fde821540fb11e2d6caa5a84e1f48f7f082203 Mon Sep 17 00:00:00 2001 From: Billy Zha Date: Fri, 3 Mar 2023 14:55:51 +0800 Subject: [PATCH] refactor: clone default http transport for registry client (#840) Signed-off-by: Billy Zha --- cmd/oras/internal/option/remote.go | 47 +++-------- cmd/oras/internal/option/remote_test.go | 103 ++++++++++++------------ internal/net/net.go | 9 ++- 3 files changed, 70 insertions(+), 89 deletions(-) diff --git a/cmd/oras/internal/option/remote.go b/cmd/oras/internal/option/remote.go index 99e5d23c9..05fbe58d2 100644 --- a/cmd/oras/internal/option/remote.go +++ b/cmd/oras/internal/option/remote.go @@ -25,7 +25,6 @@ import ( "os" "strconv" "strings" - "time" "github.com/spf13/pflag" "oras.land/oras-go/v2/registry/remote" @@ -49,7 +48,6 @@ type Remote struct { Password string resolveFlag []string - resolveDialContext func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) applyDistributionSpec bool distributionSpec distributionSpec headerFlags []string @@ -131,9 +129,9 @@ func (opts *Remote) readPassword() (err error) { } // parseResolve parses resolve flag. -func (opts *Remote) parseResolve() error { +func (opts *Remote) parseResolve(baseDial onet.DialFunc) (onet.DialFunc, error) { if len(opts.resolveFlag) == 0 { - return nil + return baseDial, nil } formatError := func(param, message string) error { @@ -144,32 +142,29 @@ func (opts *Remote) parseResolve() error { parts := strings.SplitN(r, ":", 4) length := len(parts) if length < 3 { - return formatError(r, "expecting host:port:address[:address_port]") + return nil, formatError(r, "expecting host:port:address[:address_port]") } host := parts[0] hostPort, err := strconv.Atoi(parts[1]) if err != nil { - return formatError(r, "expecting uint64 host port") + return nil, formatError(r, "expecting uint64 host port") } // ipv6 zone is not parsed address := net.ParseIP(parts[2]) if address == nil { - return formatError(r, "invalid IP address") + return nil, formatError(r, "invalid IP address") } addressPort := hostPort if length > 3 { addressPort, err = strconv.Atoi(parts[3]) if err != nil { - return formatError(r, "expecting uint64 address port") + return nil, formatError(r, "expecting uint64 address port") } } dialer.Add(host, hostPort, address, addressPort) } - opts.resolveDialContext = func(base *net.Dialer) func(context.Context, string, string) (net.Conn, error) { - dialer.Dialer = base - return dialer.DialContext - } - return nil + dialer.BaseDialContext = baseDial + return dialer.DialContext, nil } // tlsConfig assembles the tls config. @@ -193,29 +188,13 @@ func (opts *Remote) authClient(registry string, debug bool) (client *auth.Client if err != nil { return nil, err } - if err := opts.parseResolve(); err != nil { + baseTransport := http.DefaultTransport.(*http.Transport).Clone() + baseTransport.TLSClientConfig = config + dialContext, err := opts.parseResolve(baseTransport.DialContext) + if err != nil { return nil, err } - resolveDialContext := opts.resolveDialContext - if resolveDialContext == nil { - resolveDialContext = func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { - return dialer.DialContext - } - } - // default value are derived from http.DefaultTransport - baseTransport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: resolveDialContext(&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }), - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: config, - } + baseTransport.DialContext = dialContext client = &auth.Client{ Client: &http.Client{ // http.RoundTripper with a retry using the DefaultPolicy diff --git a/cmd/oras/internal/option/remote_test.go b/cmd/oras/internal/option/remote_test.go index 099c70c04..532669203 100644 --- a/cmd/oras/internal/option/remote_test.go +++ b/cmd/oras/internal/option/remote_test.go @@ -231,6 +231,55 @@ func TestRemote_NewRepository(t *testing.T) { } } +func TestRemote_NewRepository_Retry(t *testing.T) { + caPath := filepath.Join(t.TempDir(), "oras-test.pem") + if err := os.WriteFile(caPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ts.Certificate().Raw}), 0644); err != nil { + t.Fatalf("unexpected error: %v", err) + } + retries, count := 3, 0 + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count++ + if count < retries { + http.Error(w, "error", http.StatusTooManyRequests) + return + } + json.NewEncoder(w).Encode(testTagList) + })) + defer ts.Close() + opts := struct { + Remote + Common + }{ + Remote{ + CACertFilePath: caPath, + }, + Common{}, + } + + uri, err := url.ParseRequestURI(ts.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + repo, err := opts.NewRepository(uri.Host+"/"+testRepo, opts.Common) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if err = repo.Tags(context.Background(), "", func(got []string) error { + want := []string{"tag"} + if len(got) != len(testTagList.Tags) || !reflect.DeepEqual(got, want) { + return fmt.Errorf("expect: %v, got: %v", testTagList.Tags, got) + } + return nil + }); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if count != retries { + t.Errorf("expected %d retries, got %d", retries, count) + } +} + func TestRemote_isPlainHttp_localhost(t *testing.T) { opts := Remote{PlainHTTP: false} got := opts.isPlainHttp("localhost") @@ -286,7 +335,7 @@ func TestRemote_parseResolve_err(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.opts.parseResolve(); err == nil { + if _, err := tt.opts.parseResolve(nil); err == nil { t.Errorf("Expecting error in Remote.parseResolve()") } }) @@ -309,7 +358,7 @@ func TestRemote_parseResolve(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.opts.parseResolve(); err != nil { + if _, err := tt.opts.parseResolve(nil); err != nil { t.Errorf("Remote.parseResolve() error = %v", err) } }) @@ -416,53 +465,3 @@ func TestRemote_parseCustomHeaders(t *testing.T) { }) } } - -func TestRemote_NewRepository_Retry(t *testing.T) { - caPath := filepath.Join(t.TempDir(), "oras-test.pem") - if err := os.WriteFile(caPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ts.Certificate().Raw}), 0644); err != nil { - t.Fatalf("unexpected error: %v", err) - } - retries, count := 3, 0 - ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - count++ - if count < retries { - http.Error(w, "error", http.StatusTooManyRequests) - return - } - json.NewEncoder(w).Encode(testTagList) - })) - defer ts.Close() - - opts := struct { - Remote - Common - }{ - Remote{ - CACertFilePath: caPath, - }, - Common{}, - } - - uri, err := url.ParseRequestURI(ts.URL) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - repo, err := opts.NewRepository(uri.Host+"/"+testRepo, opts.Common) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if err = repo.Tags(context.Background(), "", func(got []string) error { - want := []string{"tag"} - if len(got) != len(testTagList.Tags) || !reflect.DeepEqual(got, want) { - return fmt.Errorf("expect: %v, got: %v", testTagList.Tags, got) - } - return nil - }); err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if count != retries { - t.Errorf("expected %d retries, got %d", retries, count) - } -} diff --git a/internal/net/net.go b/internal/net/net.go index ad42d2f13..f080f99d4 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -21,10 +21,13 @@ import ( "net" ) +// DialFunc is the function type for http.DialContext. +type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + // Dialer struct provides dialing function with predefined DNS resolves. type Dialer struct { - *net.Dialer - resolve map[string]string + BaseDialContext DialFunc + resolve map[string]string } // Add adds an entry for DNS resolve. @@ -41,5 +44,5 @@ func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Con if resolved, ok := d.resolve[addr]; ok { addr = resolved } - return d.Dialer.DialContext(ctx, network, addr) + return d.BaseDialContext(ctx, network, addr) }