diff --git a/pkg/httpclient/client.go b/pkg/httpclient/client.go
index 64aefa5..11bbf26 100644
--- a/pkg/httpclient/client.go
+++ b/pkg/httpclient/client.go
@@ -17,47 +17,51 @@ type Header struct {
 
 func MakeRequest(c *fasthttp.Client, url string, maxRetries uint, timeout uint, headers ...Header) ([]byte, error) {
 	var (
-		req  *fasthttp.Request
-		resp *fasthttp.Response
+		req      *fasthttp.Request
+		respBody []byte
+		err      error
 	)
 	retries := int(maxRetries)
 	for i := retries; i >= 0; i-- {
 		req = fasthttp.AcquireRequest()
-		defer fasthttp.ReleaseRequest(req)
 		req.Header.SetMethod(fasthttp.MethodGet)
 
 		for _, header := range headers {
-			req.Header.Set(header.Key, header.Value)
+			if header.Key != "" {
+				req.Header.Set(header.Key, header.Value)
+			}
 		}
 
 		req.Header.Set(fasthttp.HeaderUserAgent, getUserAgent())
+		req.Header.Set("Accept", "*/*")
 		req.SetRequestURI(url)
-
-		resp = fasthttp.AcquireResponse()
-		defer fasthttp.ReleaseResponse(resp)
-
-		if err := c.DoTimeout(req, resp, time.Second*time.Duration(timeout)); err != nil {
-			fasthttp.ReleaseRequest(req)
-			if retries == 0 {
-				return nil, err
-			}
-		}
-
-		if resp.Body() == nil {
-			if retries == 0 {
-				return nil, ErrNilResponse
-			}
-		}
-		// url responded with 503, so try again
-		if resp.StatusCode() == 503 {
-			continue
+		respBody, err = doReq(c, req, timeout)
+		if err == nil {
+			goto done
 		}
-
-		goto done
 	}
 done:
+	if err != nil {
+		return nil, err
+	}
+	return respBody, nil
+}
+
+// doReq handles http requests
+func doReq(c *fasthttp.Client, req *fasthttp.Request, timeout uint) ([]byte, error) {
+	resp := fasthttp.AcquireResponse()
+	defer fasthttp.ReleaseResponse(resp)
+	defer fasthttp.ReleaseRequest(req)
+	if err := c.DoTimeout(req, resp, time.Second*time.Duration(timeout)); err != nil {
+		return nil, err
+	}
 	if resp.StatusCode() != 200 {
 		return nil, ErrNon200Response
 	}
+
+	if resp.Body() == nil {
+		return nil, ErrNilResponse
+	}
+
 	return resp.Body(), nil
 }
diff --git a/pkg/providers/urlscan/urlscan.go b/pkg/providers/urlscan/urlscan.go
index 0a102cc..46b58a4 100644
--- a/pkg/providers/urlscan/urlscan.go
+++ b/pkg/providers/urlscan/urlscan.go
@@ -8,6 +8,7 @@ import (
 	"github.com/lc/gau/v2/pkg/httpclient"
 	"github.com/lc/gau/v2/pkg/providers"
 	"github.com/sirupsen/logrus"
+	"strings"
 )
 
 const (
@@ -55,7 +56,6 @@ paginate:
 		if err != nil {
 			return fmt.Errorf("failed to fetch urlscan: %s", err)
 		}
-
 		var result apiResponse
 		decoder := jsoniter.NewDecoder(bytes.NewReader(resp))
 		decoder.UseNumber()
@@ -72,7 +72,7 @@ paginate:
 		total := len(result.Results)
 		for i, res := range result.Results {
-			if res.Page.Domain == domain {
+			if res.Page.Domain == domain || (c.config.IncludeSubdomains && strings.HasSuffix(res.Page.Domain, domain)) {
 				results <- res.Page.URL
 			}
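
The diff funnels all request/response handling through doReq and, in the urlscan provider, widens the domain filter to optionally accept subdomains. Below is a minimal standalone sketch of that new filter condition, using only the standard library; the `matches` helper and the sample domains are hypothetical and not part of the gau codebase, which inlines this check in the provider loop.

```go
// Standalone sketch (not part of the diff): how the widened urlscan
// matching condition behaves for a target domain of "example.com".
package main

import (
	"fmt"
	"strings"
)

// matches mirrors the added condition: accept an exact domain match, or,
// when subdomain results are requested, any result domain ending in the target.
func matches(resultDomain, domain string, includeSubdomains bool) bool {
	return resultDomain == domain ||
		(includeSubdomains && strings.HasSuffix(resultDomain, domain))
}

func main() {
	for _, d := range []string{"example.com", "api.example.com", "notexample.com"} {
		fmt.Printf("%-17s -> %v\n", d, matches(d, "example.com", true))
	}
	// example.com       -> true
	// api.example.com   -> true
	// notexample.com    -> true
}
```

As the last case shows, a bare strings.HasSuffix check also accepts lookalike domains that merely end in the target string; requiring a "." + domain suffix would restrict matches to true subdomains.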