diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 5fd743ea3..9d2dd0c9f 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -34,9 +34,12 @@ jobs: PDCP_API_KEY: "${{ secrets.PDCP_API_KEY }}" - - name: Running example + - name: Testing Example - Simple run: go run . - working-directory: examples/ + working-directory: examples/simple/ + - name: Testing Example - Speed Control + run: go run . + working-directory: examples/speed_control/ - name: Integration Tests Linux, macOS if: runner.os == 'Linux' || runner.os == 'macOS' diff --git a/examples/example.go b/examples/simple/main.go similarity index 90% rename from examples/example.go rename to examples/simple/main.go index ae2bb4ed7..5b14cb04c 100644 --- a/examples/example.go +++ b/examples/simple/main.go @@ -16,7 +16,6 @@ func main() { options := runner.Options{ Methods: "GET", InputTargetHost: goflags.StringSlice{"scanme.sh", "projectdiscovery.io", "localhost"}, - //InputFile: "./targetDomains.txt", // path to file containing the target domains list OnResult: func(r runner.Result) { // handle error if r.Err != nil { diff --git a/examples/speed_control/main.go b/examples/speed_control/main.go new file mode 100644 index 000000000..5e34f6a71 --- /dev/null +++ b/examples/speed_control/main.go @@ -0,0 +1,99 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "net/http" + "time" + + "github.com/projectdiscovery/goflags" + "github.com/projectdiscovery/gologger" + "github.com/projectdiscovery/gologger/levels" + "github.com/projectdiscovery/httpx/runner" +) + +func main() { + gologger.DefaultLogger.SetMaxLevel(levels.LevelVerbose) // increase the verbosity (optional) + + // generate urls + var urls []string + for i := 0; i < 100; i++ { + urls = append(urls, fmt.Sprintf("https://scanme.sh/a=%d", i)) + } + + apiEndpoint := "127.0.0.1:31234" + + options := runner.Options{ + Methods: "GET", + InputTargetHost: goflags.StringSlice(urls), + Threads: 1, + HttpApiEndpoint: apiEndpoint, + OnResult: func(r runner.Result) { + // handle error + if r.Err != nil { + fmt.Printf("[Err] %s: %s\n", r.Input, r.Err) + return + } + fmt.Printf("%s %s %d\n", r.Input, r.Host, r.StatusCode) + }, + } + + // after 3 seconds increase the speed to 50 + time.AfterFunc(3*time.Second, func() { + client := &http.Client{} + + concurrencySettings := runner.Concurrency{Threads: 50} + requestBody, err := json.Marshal(concurrencySettings) + if err != nil { + log.Fatalf("Error creating request body: %v", err) + } + + req, err := http.NewRequest("PUT", fmt.Sprintf("http://%s/api/concurrency", apiEndpoint), bytes.NewBuffer(requestBody)) + if err != nil { + log.Fatalf("Error creating PUT request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + log.Fatalf("Error sending PUT request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Printf("Failed to update threads, status code: %d", resp.StatusCode) + } else { + log.Println("Threads updated to 50 successfully") + } + }) + + if err := options.ValidateOptions(); err != nil { + log.Fatal(err) + } + + httpxRunner, err := runner.New(&options) + if err != nil { + log.Fatal(err) + } + defer httpxRunner.Close() + + httpxRunner.RunEnumeration() + + // check the threads + req, err := http.Get(fmt.Sprintf("http://%s/api/concurrency", apiEndpoint)) + if err != nil { + log.Fatalf("Error creating GET request: %v", err) + } + var concurrencySettings runner.Concurrency + if err := json.NewDecoder(req.Body).Decode(&concurrencySettings); err != nil { + log.Fatalf("Error decoding response body: %v", err) + } + + if concurrencySettings.Threads == 50 { + log.Println("Threads are set to 50") + } else { + log.Fatalf("Fatal error: Threads are not set to 50, current value: %d", concurrencySettings.Threads) + } +} diff --git a/go.mod b/go.mod index abba67c0f..e96dece78 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,6 @@ require ( github.com/projectdiscovery/useragent v0.0.54 github.com/projectdiscovery/utils v0.1.1 github.com/projectdiscovery/wappalyzergo v0.1.4 - github.com/remeh/sizedwaitgroup v1.0.0 github.com/rs/xid v1.5.0 github.com/spaolacci/murmur3 v1.1.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index acc8c6aab..f70763486 100644 --- a/go.sum +++ b/go.sum @@ -62,6 +62,8 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj6 github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= @@ -268,8 +270,6 @@ github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utp github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M= github.com/refraction-networking/utls v1.5.4 h1:9k6EO2b8TaOGsQ7Pl7p9w6PUhx18/ZCeT0WNTZ7Uw4o= github.com/refraction-networking/utls v1.5.4/go.mod h1:SPuDbBmgLGp8s+HLNc83FuavwZCFoMmExj+ltUHiHUw= -github.com/remeh/sizedwaitgroup v1.0.0 h1:VNGGFwNo/R5+MJBf6yrsr110p0m4/OX4S3DCy7Kyl5E= -github.com/remeh/sizedwaitgroup v1.0.0/go.mod h1:3j2R4OIe/SeS6YDhICBy22RWjJC5eNCJ1V+9+NVNYlo= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= diff --git a/runner/apiendpoint.go b/runner/apiendpoint.go new file mode 100644 index 000000000..5faa4b93f --- /dev/null +++ b/runner/apiendpoint.go @@ -0,0 +1,73 @@ +// TODO: move this to internal package +package runner + +import ( + "encoding/json" + "net/http" +) + +type Concurrency struct { + Threads int `json:"threads"` +} + +// Server represents the HTTP server that handles the concurrency settings endpoints. +type Server struct { + addr string + config *Options +} + +// New creates a new instance of Server. +func NewServer(addr string, config *Options) *Server { + return &Server{ + addr: addr, + config: config, + } +} + +// Start initializes the server and its routes, then starts listening on the specified address. +func (s *Server) Start() error { + http.HandleFunc("/api/concurrency", s.handleConcurrency) + if err := http.ListenAndServe(s.addr, nil); err != nil { + return err + } + return nil +} + +// handleConcurrency routes the request based on its method to the appropriate handler. +func (s *Server) handleConcurrency(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + s.getSettings(w, r) + case http.MethodPut: + s.updateSettings(w, r) + default: + http.Error(w, "Unsupported HTTP method", http.StatusMethodNotAllowed) + } +} + +// GetSettings handles GET requests and returns the current concurrency settings +func (s *Server) getSettings(w http.ResponseWriter, _ *http.Request) { + concurrencySettings := Concurrency{ + Threads: s.config.Threads, + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(concurrencySettings); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +// UpdateSettings handles PUT requests to update the concurrency settings +func (s *Server) updateSettings(w http.ResponseWriter, r *http.Request) { + var newSettings Concurrency + if err := json.NewDecoder(r.Body).Decode(&newSettings); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if newSettings.Threads > 0 { + s.config.Threads = newSettings.Threads + } + + w.WriteHeader(http.StatusOK) +} diff --git a/runner/options.go b/runner/options.go index 2321c7888..fb5550ce5 100644 --- a/runner/options.go +++ b/runner/options.go @@ -34,6 +34,7 @@ import ( const ( two = 2 + defaultThreads = 50 DefaultResumeFile = "resume.cfg" DefaultOutputDirectory = "output" ) @@ -295,6 +296,7 @@ type Options struct { UseInstalledChrome bool TlsImpersonate bool DisableStdin bool + HttpApiEndpoint string NoScreenshotBytes bool NoHeadlessBody bool ScreenshotTimeout int @@ -385,7 +387,7 @@ func ParseOptions() *Options { ) flagSet.CreateGroup("rate-limit", "Rate-Limit", - flagSet.IntVarP(&options.Threads, "threads", "t", 50, "number of threads to use"), + flagSet.IntVarP(&options.Threads, "threads", "t", defaultThreads, "number of threads to use"), flagSet.IntVarP(&options.RateLimit, "rate-limit", "rl", 150, "maximum requests to send per second"), flagSet.IntVarP(&options.RateLimitMinute, "rate-limit-minute", "rlm", 0, "maximum number of requests to send per minute"), ) @@ -451,6 +453,7 @@ func ParseOptions() *Options { flagSet.BoolVar(&options.NoDecode, "no-decode", false, "avoid decoding body"), flagSet.BoolVarP(&options.TlsImpersonate, "tls-impersonate", "tlsi", false, "enable experimental client hello (ja3) tls randomization"), flagSet.BoolVar(&options.DisableStdin, "no-stdin", false, "Disable Stdin processing"), + flagSet.StringVarP(&options.HttpApiEndpoint, "http-api-endpoint", "hae", "", "experimental http api endpoint"), ) flagSet.CreateGroup("debug", "Debug", @@ -678,6 +681,11 @@ func (options *Options) ValidateOptions() error { return fmt.Errorf("invalid protocol: %s", options.Protocol) } + if options.Threads == 0 { + gologger.Info().Msgf("Threads automatically set to %d", defaultThreads) + options.Threads = defaultThreads + } + return nil } diff --git a/runner/runner.go b/runner/runner.go index e267015a5..3a03ff4bf 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -22,6 +22,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" "golang.org/x/exp/maps" @@ -52,7 +53,6 @@ import ( urlutil "github.com/projectdiscovery/utils/url" "github.com/projectdiscovery/ratelimit" - "github.com/remeh/sizedwaitgroup" // automatic fd max increase if running as root _ "github.com/projectdiscovery/fdmax/autofdmax" @@ -68,6 +68,7 @@ import ( fileutil "github.com/projectdiscovery/utils/file" pdhttputil "github.com/projectdiscovery/utils/http" iputil "github.com/projectdiscovery/utils/ip" + syncutil "github.com/projectdiscovery/utils/sync" wappalyzer "github.com/projectdiscovery/wappalyzergo" ) @@ -85,6 +86,7 @@ type Runner struct { browser *Browser errorPageClassifier *errorpageclassifier.ErrorPageClassifier pHashClusters []pHashCluster + httpApiEndpoint *Server } // picked based on try-fail but it seems to close to one it's used https://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html#c1992 @@ -364,6 +366,17 @@ func New(options *Options) (*Runner, error) { runner.errorPageClassifier = errorpageclassifier.New() + if options.HttpApiEndpoint != "" { + apiServer := NewServer(options.HttpApiEndpoint, options) + gologger.Info().Msgf("Listening api endpoint on: %s", options.HttpApiEndpoint) + runner.httpApiEndpoint = apiServer + go func() { + if err := apiServer.Start(); err != nil { + gologger.Error().Msgf("Failed to start API server: %s", err) + } + }() + } + return runner, nil } @@ -680,12 +693,12 @@ func (r *Runner) RunEnumeration() { } // output routine - wgoutput := sizedwaitgroup.New(2) - wgoutput.Add() + var wgoutput sync.WaitGroup output := make(chan Result) nextStep := make(chan Result) + wgoutput.Add(1) go func(output chan Result, nextSteps ...chan Result) { defer wgoutput.Done() @@ -1065,7 +1078,7 @@ func (r *Runner) RunEnumeration() { // HTML Summary // - needs output of previous routine // - separate goroutine due to incapability of go templates to render from file - wgoutput.Add() + wgoutput.Add(1) go func(output chan Result) { defer wgoutput.Done() @@ -1109,7 +1122,7 @@ func (r *Runner) RunEnumeration() { } }(nextStep) - wg := sizedwaitgroup.New(r.options.Threads) + wg, _ := syncutil.New(syncutil.WithSize(r.options.Threads)) processItem := func(k string) error { if r.options.resumeCfg != nil { @@ -1132,10 +1145,10 @@ func (r *Runner) RunEnumeration() { for _, p := range r.options.requestURIs { scanopts := r.scanopts.Clone() scanopts.RequestURI = p - r.process(k, &wg, r.hp, protocol, scanopts, output) + r.process(k, wg, r.hp, protocol, scanopts, output) } } else { - r.process(k, &wg, r.hp, protocol, &r.scanopts, output) + r.process(k, wg, r.hp, protocol, &r.scanopts, output) } return nil @@ -1224,11 +1237,18 @@ func (r *Runner) GetScanOpts() ScanOptions { return r.scanopts } -func (r *Runner) Process(t string, wg *sizedwaitgroup.SizedWaitGroup, protocol string, scanopts *ScanOptions, output chan Result) { +func (r *Runner) Process(t string, wg *syncutil.AdaptiveWaitGroup, protocol string, scanopts *ScanOptions, output chan Result) { r.process(t, wg, r.hp, protocol, scanopts, output) } -func (r *Runner) process(t string, wg *sizedwaitgroup.SizedWaitGroup, hp *httpx.HTTPX, protocol string, scanopts *ScanOptions, output chan Result) { +func (r *Runner) process(t string, wg *syncutil.AdaptiveWaitGroup, hp *httpx.HTTPX, protocol string, scanopts *ScanOptions, output chan Result) { + // attempts to set the workpool size to the number of threads + if r.options.Threads > 0 && wg.Size != r.options.Threads { + if err := wg.Resize(context.Background(), r.options.Threads); err != nil { + gologger.Error().Msgf("Could not resize workpool: %s\n", err) + } + } + protocols := []string{protocol} if scanopts.NoFallback || protocol == httpx.HTTPandHTTPS { protocols = []string{httpx.HTTPS, httpx.HTTP}