diff --git a/cmd/http.go b/cmd/http.go index 8a65839c4..e1d4e4947 100644 --- a/cmd/http.go +++ b/cmd/http.go @@ -16,7 +16,7 @@ var httpCmd = &cobra.Command{ plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: dkplugin.Handshake, Plugins: map[string]plugin.Plugin{ - "executor": &dkplugin.ExecutorPlugin{Executor: &http.HTTP{}}, + "executor": &dkplugin.ExecutorPlugin{Executor: http.New()}, }, // A non-nil value here enables gRPC serving for this plugin... diff --git a/plugin/http/http.go b/plugin/http/http.go index f82343073..3e1c9072a 100644 --- a/plugin/http/http.go +++ b/plugin/http/http.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/ioutil" "log" "net/http" @@ -31,6 +32,14 @@ const ( // HTTP process http request type HTTP struct { + clientPool map[string]http.Client +} + +// New +func New() *HTTP { + return &HTTP{ + clientPool: make(map[string]http.Client), + } } // Execute Process method of the plugin @@ -96,18 +105,31 @@ func (s *HTTP) ExecuteImpl(args *types.ExecuteRequest) ([]byte, error) { log.Printf("request %#v\n\n", req) } - client, warns := createClient(args.Config) - for _, warn := range warns { - output.Write([]byte(fmt.Sprintf("Warning: %s.\n", warn.Error()))) + // get client from pool + var ( + client http.Client + ok bool + ) + + cc := args.Config["timeout"] + args.Config["tlsRootCAsFile"] + args.Config["tlsCertificateFile"] + args.Config["tlsCertificateKeyFile"] + + if client, ok = s.clientPool[cc]; !ok { + var warns []error + client, warns = createClient(args.Config) + for _, warn := range warns { + _, _ = output.Write([]byte(fmt.Sprintf("Warning: %s.\n", warn.Error()))) + } + s.clientPool[cc] = client } + // do request resp, err := client.Do(req) if err != nil { return output.Bytes(), err } defer resp.Body.Close() - out, err := ioutil.ReadAll(resp.Body) + out, err := io.ReadAll(resp.Body) if err != nil { return output.Bytes(), err } diff --git a/plugin/http/http_test.go b/plugin/http/http_test.go index f8eb68cec..e8ee6eba8 100644 --- a/plugin/http/http_test.go +++ b/plugin/http/http_test.go @@ -3,7 +3,7 @@ package http import ( "bytes" "fmt" - "io/ioutil" + "io" "log" "net/http" "net/http/httptest" @@ -43,13 +43,13 @@ func newTestServer() *httptest.Server { // Echo POST body back to request case "/echo": if r.Method == http.MethodPost { - in, err := ioutil.ReadAll(r.Body) + in, err := io.ReadAll(r.Body) if err != nil { w.WriteHeader(500) return } r.Body.Close() - w.Write(in) + _, _ = w.Write(in) w.WriteHeader(200) return } @@ -81,7 +81,7 @@ func TestExecute(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - http := &HTTP{} + http := New() pa := &types.ExecuteRequest{ JobName: tt.name, Config: tt.config, @@ -113,7 +113,7 @@ func TestNoVerifyPeer(t *testing.T) { "tlsNoVerifyPeer": "true", }, } - http := &HTTP{} + http := New() output, _ := http.Execute(pa, nil) fmt.Println(string(output.Output)) fmt.Println(output.Error) @@ -133,7 +133,7 @@ func TestClientSSLCert(t *testing.T) { "tlsCertificateKeyFile": "testdata/badssl.com-client-key-decrypted.pem", }, } - http := &HTTP{} + http := New() output, _ := http.Execute(pa, nil) fmt.Println(string(output.Output)) fmt.Println(output.Error) @@ -152,7 +152,7 @@ func TestRootCA(t *testing.T) { "tlsRootCAsFile": "testdata/badssl-ca-untrusted-root.crt", }, } - http := &HTTP{} + http := New() output, _ := http.Execute(pa, nil) fmt.Println(string(output.Output)) fmt.Println(output.Error)