diff --git a/pkg/protocols/network/request.go b/pkg/protocols/network/request.go index e66524760e..c896b6551b 100644 --- a/pkg/protocols/network/request.go +++ b/pkg/protocols/network/request.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/url" + "os" "strings" "time" @@ -265,7 +266,7 @@ func (request *Request) executeRequestWithPayloads(variables map[string]interfac } if input.Read > 0 { - buffer, err := reader.ConnReadNWithTimeout(conn, int64(input.Read), DefaultReadTimeout) + buffer, err := ConnReadNWithTimeout(conn, int64(input.Read), DefaultReadTimeout) if err != nil { return errorutil.NewWithErr(err).Msgf("could not read response from connection") } @@ -315,7 +316,7 @@ func (request *Request) executeRequestWithPayloads(variables map[string]interfac bufferSize = -1 } - final, err := reader.ConnReadNWithTimeout(conn, int64(bufferSize), DefaultReadTimeout) + final, err := ConnReadNWithTimeout(conn, int64(bufferSize), DefaultReadTimeout) if err != nil { request.options.Output.Request(request.options.TemplatePath, address, request.Type().String(), err) return errors.Wrap(err, "could not read from server") @@ -412,3 +413,27 @@ func getAddress(toTest string) (string, error) { } return toTest, nil } + +func ConnReadNWithTimeout(conn net.Conn, n int64, timeout time.Duration) ([]byte, error) { + if timeout == 0 { + timeout = DefaultReadTimeout + } + if n == -1 { + // if n is -1 then read all available data from connection + return reader.ConnReadNWithTimeout(conn, -1, timeout) + } else if n == 0 { + n = 4096 // default buffer size + } + b := make([]byte, n) + _ = conn.SetDeadline(time.Now().Add(timeout)) + count, err := conn.Read(b) + _ = conn.SetDeadline(time.Time{}) + if err != nil && os.IsTimeout(err) && count > 0 { + // in case of timeout with some value read, return the value + return b[:count], nil + } + if err != nil { + return nil, err + } + return b[:count], nil +}