diff --git a/cmd/integration-test/http.go b/cmd/integration-test/http.go index b4f957e692..5b218aa53b 100644 --- a/cmd/integration-test/http.go +++ b/cmd/integration-test/http.go @@ -23,6 +23,7 @@ import ( logutil "github.com/projectdiscovery/utils/log" sliceutil "github.com/projectdiscovery/utils/slice" stringsutil "github.com/projectdiscovery/utils/strings" + unitutils "github.com/projectdiscovery/utils/unit" ) var httpTestcases = []TestCaseInfo{ @@ -509,7 +510,7 @@ func (h *httpPostMultipartBody) Execute(filePath string) error { var routerErr error router.POST("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - if err := r.ParseMultipartForm(1 * 1024); err != nil { + if err := r.ParseMultipartForm(unitutils.Mega); err != nil { routerErr = err return } diff --git a/cmd/nuclei/main.go b/cmd/nuclei/main.go index 40984074df..5aea16337c 100644 --- a/cmd/nuclei/main.go +++ b/cmd/nuclei/main.go @@ -38,6 +38,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/utils/monitor" errorutil "github.com/projectdiscovery/utils/errors" fileutil "github.com/projectdiscovery/utils/file" + unitutils "github.com/projectdiscovery/utils/unit" updateutils "github.com/projectdiscovery/utils/update" ) @@ -304,7 +305,7 @@ on extensive configurability, massive extensibility and ease of use.`) flagSet.StringVarP(&options.AttackType, "attack-type", "at", "", "type of payload combinations to perform (batteringram,pitchfork,clusterbomb)"), flagSet.StringVarP(&options.SourceIP, "source-ip", "sip", "", "source ip address to use for network scan"), flagSet.IntVarP(&options.ResponseReadSize, "response-size-read", "rsr", 0, "max response size to read in bytes"), - flagSet.IntVarP(&options.ResponseSaveSize, "response-size-save", "rss", 1*1024*1024, "max response size to read in bytes"), + flagSet.IntVarP(&options.ResponseSaveSize, "response-size-save", "rss", unitutils.Mega, "max response size to read in bytes"), flagSet.DurationVarP(&options.ResponseReadTimeout, "response-read-timeout", "rrt", time.Duration(5*time.Second), "response read timeout in seconds"), flagSet.CallbackVar(resetCallback, "reset", "reset removes all nuclei configuration and data files (including nuclei-templates)"), flagSet.BoolVarP(&options.TlsImpersonate, "tls-impersonate", "tlsi", false, "enable experimental client hello (ja3) tls randomization"), diff --git a/internal/pdcp/writer.go b/internal/pdcp/writer.go index a76d4ea31a..6e05410355 100644 --- a/internal/pdcp/writer.go +++ b/internal/pdcp/writer.go @@ -19,6 +19,7 @@ import ( "github.com/projectdiscovery/retryablehttp-go" pdcpauth "github.com/projectdiscovery/utils/auth/pdcp" errorutil "github.com/projectdiscovery/utils/errors" + unitutils "github.com/projectdiscovery/utils/unit" updateutils "github.com/projectdiscovery/utils/update" urlutil "github.com/projectdiscovery/utils/url" ) @@ -26,8 +27,8 @@ import ( const ( uploadEndpoint = "/v1/scans/import" appendEndpoint = "/v1/scans/%s/import" - flushTimer = time.Duration(1) * time.Minute - MaxChunkSize = 1024 * 1024 * 4 // 4 MB + flushTimer = time.Minute + MaxChunkSize = 4 * unitutils.Mega // 4 MB xidRe = `^[a-z0-9]{20}$` ) diff --git a/lib/multi.go b/lib/multi.go index a2149ddcd2..44c13ddfe6 100644 --- a/lib/multi.go +++ b/lib/multi.go @@ -128,6 +128,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOptsCtx(ctx context.Context, t return err } } + // create ephemeral nuclei objects/instances/types using base nuclei engine unsafeOpts, err := createEphemeralObjects(ctx, e.eng, tmpEngine.opts) if err != nil { diff --git a/lib/sdk.go b/lib/sdk.go index 2e23aa49cd..60f5255bcd 100644 --- a/lib/sdk.go +++ b/lib/sdk.go @@ -18,7 +18,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/hosterrorscache" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/interactsh" - "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolinit" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/headless/engine" "github.com/projectdiscovery/nuclei/v3/pkg/reporting" "github.com/projectdiscovery/nuclei/v3/pkg/templates" @@ -183,8 +183,7 @@ func (e *NucleiEngine) SignTemplate(tmplSigner *signer.TemplateSigner, data []by return buff.Bytes(), err } -// Close all resources used by nuclei engine -func (e *NucleiEngine) Close() { +func (e *NucleiEngine) closeInternal() { if e.interactshClient != nil { e.interactshClient.Close() } @@ -206,8 +205,6 @@ func (e *NucleiEngine) Close() { if e.rateLimiter != nil { e.rateLimiter.Stop() } - // close global shared resources - protocolstate.Close() if e.inputProvider != nil { e.inputProvider.Close() } @@ -219,6 +216,12 @@ func (e *NucleiEngine) Close() { } } +// Close all resources used by nuclei engine +func (e *NucleiEngine) Close() { + e.closeInternal() + protocolinit.Close() +} + // ExecuteCallbackWithCtx executes templates on targets and calls callback on each result(only if results are found) // enable matcher-status option if you expect this callback to be called for all results regardless if it matched or not func (e *NucleiEngine) ExecuteCallbackWithCtx(ctx context.Context, callback ...func(event *output.ResultEvent)) error { diff --git a/lib/sdk_private.go b/lib/sdk_private.go index ae61add221..3475200bd0 100644 --- a/lib/sdk_private.go +++ b/lib/sdk_private.go @@ -25,6 +25,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/hosterrorscache" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/interactsh" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolinit" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/http/httpclientpool" "github.com/projectdiscovery/nuclei/v3/pkg/reporting" "github.com/projectdiscovery/nuclei/v3/pkg/templates" @@ -34,7 +35,7 @@ import ( "github.com/projectdiscovery/ratelimit" ) -var sharedInit sync.Once = sync.Once{} +var sharedInit *sync.Once // applyRequiredDefaults to options func (e *NucleiEngine) applyRequiredDefaults(ctx context.Context) { @@ -117,6 +118,10 @@ func (e *NucleiEngine) init(ctx context.Context) error { e.parser = templates.NewParser() + if sharedInit == nil || protocolstate.ShouldInit() { + sharedInit = &sync.Once{} + } + sharedInit.Do(func() { _ = protocolinit.Init(e.opts) }) diff --git a/pkg/catalog/loader/loader.go b/pkg/catalog/loader/loader.go index c719f71a5a..e1d8371e35 100644 --- a/pkg/catalog/loader/loader.go +++ b/pkg/catalog/loader/loader.go @@ -7,6 +7,7 @@ import ( "os" "sort" "strings" + "sync" "github.com/logrusorgru/aurora" "github.com/pkg/errors" @@ -23,6 +24,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/workflows" "github.com/projectdiscovery/retryablehttp-go" errorutil "github.com/projectdiscovery/utils/errors" + sliceutil "github.com/projectdiscovery/utils/slice" stringsutil "github.com/projectdiscovery/utils/strings" urlutil "github.com/projectdiscovery/utils/url" ) @@ -425,10 +427,10 @@ func (store *Store) LoadTemplatesWithTags(templatesList, tags []string) []*templ store.logErroredTemplates(errs) templatePathMap := store.pathFilter.Match(includedTemplates) - loadedTemplates := make([]*templates.Template, 0, len(templatePathMap)) + loadedTemplates := sliceutil.NewSyncSlice[*templates.Template]() loadTemplate := func(tmpl *templates.Template) { - loadedTemplates = append(loadedTemplates, tmpl) + loadedTemplates.Append(tmpl) // increment signed/unsigned counters if tmpl.Verified { if tmpl.TemplateVerifier == "" { @@ -441,80 +443,89 @@ func (store *Store) LoadTemplatesWithTags(templatesList, tags []string) []*templ } } + var wgLoadTemplates sync.WaitGroup + for templatePath := range templatePathMap { - loaded, err := store.config.ExecutorOptions.Parser.LoadTemplate(templatePath, store.tagFilter, tags, store.config.Catalog) - if loaded || store.pathFilter.MatchIncluded(templatePath) { - parsed, err := templates.Parse(templatePath, store.preprocessor, store.config.ExecutorOptions) - if err != nil { - // exclude templates not compatible with offline matching from total runtime warning stats - if !errors.Is(err, templates.ErrIncompatibleWithOfflineMatching) { - stats.Increment(templates.RuntimeWarningsStats) - } - gologger.Warning().Msgf("Could not parse template %s: %s\n", templatePath, err) - } else if parsed != nil { - if !parsed.Verified && store.config.ExecutorOptions.Options.DisableUnsignedTemplates { - // skip unverified templates when prompted to - stats.Increment(templates.SkippedUnsignedStats) - continue - } - // if template has request signature like aws then only signed and verified templates are allowed - if parsed.UsesRequestSignature() && !parsed.Verified { - stats.Increment(templates.SkippedRequestSignatureStats) - continue - } - // DAST only templates - if store.config.ExecutorOptions.Options.DAST { - // check if the template is a DAST template - if parsed.IsFuzzing() { - loadTemplate(parsed) + wgLoadTemplates.Add(1) + go func(templatePath string) { + defer wgLoadTemplates.Done() + + loaded, err := store.config.ExecutorOptions.Parser.LoadTemplate(templatePath, store.tagFilter, tags, store.config.Catalog) + if loaded || store.pathFilter.MatchIncluded(templatePath) { + parsed, err := templates.Parse(templatePath, store.preprocessor, store.config.ExecutorOptions) + if err != nil { + // exclude templates not compatible with offline matching from total runtime warning stats + if !errors.Is(err, templates.ErrIncompatibleWithOfflineMatching) { + stats.Increment(templates.RuntimeWarningsStats) } - } else if len(parsed.RequestsHeadless) > 0 && !store.config.ExecutorOptions.Options.Headless { - // donot include headless template in final list if headless flag is not set - stats.Increment(templates.ExcludedHeadlessTmplStats) - if config.DefaultConfig.LogAllEvents { - gologger.Print().Msgf("[%v] Headless flag is required for headless template '%s'.\n", aurora.Yellow("WRN").String(), templatePath) + gologger.Warning().Msgf("Could not parse template %s: %s\n", templatePath, err) + } else if parsed != nil { + if !parsed.Verified && store.config.ExecutorOptions.Options.DisableUnsignedTemplates { + // skip unverified templates when prompted to + stats.Increment(templates.SkippedUnsignedStats) + return } - } else if len(parsed.RequestsCode) > 0 && !store.config.ExecutorOptions.Options.EnableCodeTemplates { - // donot include 'Code' protocol custom template in final list if code flag is not set - stats.Increment(templates.ExcludedCodeTmplStats) - if config.DefaultConfig.LogAllEvents { - gologger.Print().Msgf("[%v] Code flag is required for code protocol template '%s'.\n", aurora.Yellow("WRN").String(), templatePath) + // if template has request signature like aws then only signed and verified templates are allowed + if parsed.UsesRequestSignature() && !parsed.Verified { + stats.Increment(templates.SkippedRequestSignatureStats) + return } - } else if len(parsed.RequestsCode) > 0 && !parsed.Verified && len(parsed.Workflows) == 0 { - // donot include unverified 'Code' protocol custom template in final list - stats.Increment(templates.SkippedCodeTmplTamperedStats) - // these will be skipped so increment skip counter - stats.Increment(templates.SkippedUnsignedStats) - if config.DefaultConfig.LogAllEvents { - gologger.Print().Msgf("[%v] Tampered/Unsigned template at %v.\n", aurora.Yellow("WRN").String(), templatePath) - } - } else if parsed.IsFuzzing() && !store.config.ExecutorOptions.Options.DAST { - stats.Increment(templates.ExludedDastTmplStats) - if config.DefaultConfig.LogAllEvents { - gologger.Print().Msgf("[%v] -dast flag is required for DAST template '%s'.\n", aurora.Yellow("WRN").String(), templatePath) + // DAST only templates + if store.config.ExecutorOptions.Options.DAST { + // check if the template is a DAST template + if parsed.IsFuzzing() { + loadTemplate(parsed) + } + } else if len(parsed.RequestsHeadless) > 0 && !store.config.ExecutorOptions.Options.Headless { + // donot include headless template in final list if headless flag is not set + stats.Increment(templates.ExcludedHeadlessTmplStats) + if config.DefaultConfig.LogAllEvents { + gologger.Print().Msgf("[%v] Headless flag is required for headless template '%s'.\n", aurora.Yellow("WRN").String(), templatePath) + } + } else if len(parsed.RequestsCode) > 0 && !store.config.ExecutorOptions.Options.EnableCodeTemplates { + // donot include 'Code' protocol custom template in final list if code flag is not set + stats.Increment(templates.ExcludedCodeTmplStats) + if config.DefaultConfig.LogAllEvents { + gologger.Print().Msgf("[%v] Code flag is required for code protocol template '%s'.\n", aurora.Yellow("WRN").String(), templatePath) + } + } else if len(parsed.RequestsCode) > 0 && !parsed.Verified && len(parsed.Workflows) == 0 { + // donot include unverified 'Code' protocol custom template in final list + stats.Increment(templates.SkippedCodeTmplTamperedStats) + // these will be skipped so increment skip counter + stats.Increment(templates.SkippedUnsignedStats) + if config.DefaultConfig.LogAllEvents { + gologger.Print().Msgf("[%v] Tampered/Unsigned template at %v.\n", aurora.Yellow("WRN").String(), templatePath) + } + } else if parsed.IsFuzzing() && !store.config.ExecutorOptions.Options.DAST { + stats.Increment(templates.ExludedDastTmplStats) + if config.DefaultConfig.LogAllEvents { + gologger.Print().Msgf("[%v] -dast flag is required for DAST template '%s'.\n", aurora.Yellow("WRN").String(), templatePath) + } + } else { + loadTemplate(parsed) } - } else { - loadTemplate(parsed) } } - } - if err != nil { - if strings.Contains(err.Error(), templates.ErrExcluded.Error()) { - stats.Increment(templates.TemplatesExcludedStats) - if config.DefaultConfig.LogAllEvents { - gologger.Print().Msgf("[%v] %v\n", aurora.Yellow("WRN").String(), err.Error()) + if err != nil { + if strings.Contains(err.Error(), templates.ErrExcluded.Error()) { + stats.Increment(templates.TemplatesExcludedStats) + if config.DefaultConfig.LogAllEvents { + gologger.Print().Msgf("[%v] %v\n", aurora.Yellow("WRN").String(), err.Error()) + } + return } - continue + gologger.Warning().Msg(err.Error()) } - gologger.Warning().Msg(err.Error()) - } + }(templatePath) } - sort.SliceStable(loadedTemplates, func(i, j int) bool { - return loadedTemplates[i].Path < loadedTemplates[j].Path + wgLoadTemplates.Wait() + + sort.SliceStable(loadedTemplates.Slice, func(i, j int) bool { + return loadedTemplates.Slice[i].Path < loadedTemplates.Slice[j].Path }) - return loadedTemplates + return loadedTemplates.Slice } // IsHTTPBasedProtocolUsed returns true if http/headless protocol is being used for diff --git a/pkg/output/output.go b/pkg/output/output.go index 1ee710ff3b..449b73a2d1 100644 --- a/pkg/output/output.go +++ b/pkg/output/output.go @@ -33,6 +33,7 @@ import ( "github.com/projectdiscovery/utils/errkit" fileutil "github.com/projectdiscovery/utils/file" osutils "github.com/projectdiscovery/utils/os" + unitutils "github.com/projectdiscovery/utils/unit" urlutil "github.com/projectdiscovery/utils/url" ) @@ -449,7 +450,7 @@ func (w *StandardWriter) WriteFailure(wrappedEvent *InternalWrappedEvent) error return w.Write(data) } -var maxTemplateFileSizeForEncoding = 1024 * 1024 +var maxTemplateFileSizeForEncoding = unitutils.Mega func (w *StandardWriter) encodeTemplate(templatePath string) string { data, err := os.ReadFile(templatePath) diff --git a/pkg/protocols/common/automaticscan/automaticscan.go b/pkg/protocols/common/automaticscan/automaticscan.go index 2a7988cb04..a5e51c177a 100644 --- a/pkg/protocols/common/automaticscan/automaticscan.go +++ b/pkg/protocols/common/automaticscan/automaticscan.go @@ -32,13 +32,14 @@ import ( sliceutil "github.com/projectdiscovery/utils/slice" stringsutil "github.com/projectdiscovery/utils/strings" syncutil "github.com/projectdiscovery/utils/sync" + unitutils "github.com/projectdiscovery/utils/unit" wappalyzer "github.com/projectdiscovery/wappalyzergo" "gopkg.in/yaml.v2" ) const ( mappingFilename = "wappalyzer-mapping.yml" - maxDefaultBody = 4 * 1024 * 1024 // 4MB + maxDefaultBody = 4 * unitutils.Mega ) // Options contains configuration options for automatic scan service diff --git a/pkg/protocols/common/protocolinit/init.go b/pkg/protocols/common/protocolinit/init.go index 59f7dd40b0..c8268337f5 100644 --- a/pkg/protocols/common/protocolinit/init.go +++ b/pkg/protocols/common/protocolinit/init.go @@ -38,6 +38,5 @@ func Init(options *types.Options) error { } func Close() { - protocolstate.Dialer.Close() - protocolstate.Dialer = nil + protocolstate.Close() } diff --git a/pkg/protocols/common/protocolstate/state.go b/pkg/protocols/common/protocolstate/state.go index 89675e7696..cf6f5234ec 100644 --- a/pkg/protocols/common/protocolstate/state.go +++ b/pkg/protocols/common/protocolstate/state.go @@ -22,6 +22,10 @@ var ( Dialer *fastdialer.Dialer ) +func ShouldInit() bool { + return Dialer == nil +} + // Init creates the Dialer instance based on user configuration func Init(options *types.Options) error { if Dialer != nil { @@ -212,5 +216,6 @@ func Close() { Dialer.Close() Dialer = nil } + Dialer = nil StopActiveMemGuardian() } diff --git a/pkg/protocols/http/httpclientpool/clientpool.go b/pkg/protocols/http/httpclientpool/clientpool.go index 4200e25b78..b4ca5d40cf 100644 --- a/pkg/protocols/http/httpclientpool/clientpool.go +++ b/pkg/protocols/http/httpclientpool/clientpool.go @@ -16,7 +16,6 @@ import ( "golang.org/x/net/proxy" "golang.org/x/net/publicsuffix" - "github.com/projectdiscovery/fastdialer/fastdialer" "github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils" @@ -28,9 +27,6 @@ import ( ) var ( - // Dialer is a copy of the fastdialer from protocolstate - Dialer *fastdialer.Dialer - rawHttpClient *rawhttp.Client forceMaxRedirects int normalClient *retryablehttp.Client @@ -154,8 +150,8 @@ func GetRawHTTP(options *types.Options) *rawhttp.Client { rawHttpOptions.Proxy = types.ProxyURL } else if types.ProxySocksURL != "" { rawHttpOptions.Proxy = types.ProxySocksURL - } else if Dialer != nil { - rawHttpOptions.FastDialer = Dialer + } else if protocolstate.Dialer != nil { + rawHttpOptions.FastDialer = protocolstate.Dialer } rawHttpOptions.Timeout = GetHttpTimeout(options) rawHttpClient = rawhttp.NewClient(rawHttpOptions) @@ -175,10 +171,6 @@ func Get(options *types.Options, configuration *Configuration) (*retryablehttp.C func wrappedGet(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) { var err error - if Dialer == nil { - Dialer = protocolstate.Dialer - } - hash := configuration.Hash() if client, ok := clientPool.Get(hash); ok { return client, nil @@ -254,15 +246,15 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl transport := &http.Transport{ ForceAttemptHTTP2: options.ForceAttemptHTTP2, - DialContext: Dialer.Dial, + DialContext: protocolstate.Dialer.Dial, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { if options.TlsImpersonate { - return Dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil) + return protocolstate.Dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil) } if options.HasClientCertificates() || options.ForceAttemptHTTP2 { - return Dialer.DialTLSWithConfig(ctx, network, addr, tlsConfig) + return protocolstate.Dialer.DialTLSWithConfig(ctx, network, addr, tlsConfig) } - return Dialer.DialTLS(ctx, network, addr) + return protocolstate.Dialer.DialTLS(ctx, network, addr) }, MaxIdleConns: maxIdleConns, MaxIdleConnsPerHost: maxIdleConnsPerHost, diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go index 7b6c6cee24..7e5e069348 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -44,6 +44,7 @@ import ( "github.com/projectdiscovery/utils/reader" sliceutil "github.com/projectdiscovery/utils/slice" stringsutil "github.com/projectdiscovery/utils/strings" + unitutils "github.com/projectdiscovery/utils/unit" urlutil "github.com/projectdiscovery/utils/url" ) @@ -55,7 +56,7 @@ const ( ) var ( - MaxBodyRead = int64(10 * 1024 * 1024) // 10MB + MaxBodyRead = 10 * unitutils.Mega // ErrMissingVars is error occured when variables are missing ErrMissingVars = errkit.New("stop execution due to unresolved variables").SetKind(nucleierr.ErrTemplateLogic).Build() // ErrHttpEngineRequestDeadline is error occured when request deadline set by http request engine is exceeded @@ -597,7 +598,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa return requestErr } -const drainReqSize = int64(8 * 1024) +const drainReqSize = int64(8 * unitutils.Kilo) // executeRequest executes the actual generated request and returns error if occurred func (request *Request) executeRequest(input *contextargs.Context, generatedRequest *generatedRequest, previousEvent output.InternalEvent, hasInteractMatchers bool, processEvent protocols.OutputEventCallback, requestCount int) (err error) { @@ -842,7 +843,7 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ if input.MetaInput.CustomIP != "" { outputEvent["ip"] = input.MetaInput.CustomIP } else { - outputEvent["ip"] = httpclientpool.Dialer.GetDialedIP(hostname) + outputEvent["ip"] = protocolstate.Dialer.GetDialedIP(hostname) } if len(generatedRequest.interactshURLs) > 0 { @@ -875,15 +876,15 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ // define max body read limit maxBodylimit := MaxBodyRead // 10MB if request.MaxSize > 0 { - maxBodylimit = int64(request.MaxSize) + maxBodylimit = request.MaxSize } if request.options.Options.ResponseReadSize != 0 { - maxBodylimit = int64(request.options.Options.ResponseReadSize) + maxBodylimit = request.options.Options.ResponseReadSize } // respChain is http response chain that reads response body // efficiently by reusing buffers and does all decoding and optimizations - respChain := httpUtils.NewResponseChain(resp, maxBodylimit) + respChain := httpUtils.NewResponseChain(resp, int64(maxBodylimit)) defer respChain.Close() // reuse buffers // we only intend to log/save the final redirected response @@ -938,7 +939,7 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ if input.MetaInput.CustomIP != "" { outputEvent["ip"] = input.MetaInput.CustomIP } else { - outputEvent["ip"] = httpclientpool.Dialer.GetDialedIP(hostname) + outputEvent["ip"] = protocolstate.Dialer.GetDialedIP(hostname) } if request.options.Interactsh != nil { request.options.Interactsh.MakePlaceholders(generatedRequest.interactshURLs, outputEvent) diff --git a/pkg/protocols/offlinehttp/request.go b/pkg/protocols/offlinehttp/request.go index 4a440c167f..e913e02d8f 100644 --- a/pkg/protocols/offlinehttp/request.go +++ b/pkg/protocols/offlinehttp/request.go @@ -17,11 +17,12 @@ import ( templateTypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types" "github.com/projectdiscovery/utils/conversion" syncutil "github.com/projectdiscovery/utils/sync" + unitutils "github.com/projectdiscovery/utils/unit" ) var _ protocols.Request = &Request{} -const maxSize = 5 * 1024 * 1024 +const maxSize = 5 * unitutils.Mega // Type returns the type of the protocol request func (request *Request) Type() templateTypes.ProtocolType { diff --git a/pkg/protocols/protocols.go b/pkg/protocols/protocols.go index 68881fec5c..af7d2b4766 100644 --- a/pkg/protocols/protocols.go +++ b/pkg/protocols/protocols.go @@ -34,10 +34,11 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/scan" templateTypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types" "github.com/projectdiscovery/nuclei/v3/pkg/types" + unitutils "github.com/projectdiscovery/utils/unit" ) var ( - MaxTemplateFileSizeForEncoding = 1024 * 1024 + MaxTemplateFileSizeForEncoding = unitutils.Mega ) // Executer is an interface implemented any protocol based request executer. diff --git a/pkg/reporting/format/format_utils.go b/pkg/reporting/format/format_utils.go index 62a3d75fe8..d5ec0adfb1 100644 --- a/pkg/reporting/format/format_utils.go +++ b/pkg/reporting/format/format_utils.go @@ -12,6 +12,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/reporting/exporters/markdown/util" "github.com/projectdiscovery/nuclei/v3/pkg/types" "github.com/projectdiscovery/nuclei/v3/pkg/utils" + unitutils "github.com/projectdiscovery/utils/unit" ) // Summary returns a formatted built one line summary of the event @@ -71,7 +72,7 @@ func CreateReportDescription(event *output.ResultEvent, formatter ResultFormatte if event.Response != "" { var responseString string // If the response is larger than 5 kb, truncate it before writing. - maxKbSize := 5 * 1024 + maxKbSize := 5 * unitutils.Kilo if len(event.Response) > maxKbSize { responseString = event.Response[:maxKbSize] responseString += ".... Truncated ...." diff --git a/pkg/testutils/testutils.go b/pkg/testutils/testutils.go index 4ce40ace96..6e6f94d9eb 100644 --- a/pkg/testutils/testutils.go +++ b/pkg/testutils/testutils.go @@ -24,6 +24,7 @@ import ( protocolUtils "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils" "github.com/projectdiscovery/nuclei/v3/pkg/types" "github.com/projectdiscovery/nuclei/v3/pkg/utils" + unitutils "github.com/projectdiscovery/utils/unit" ) // Init initializes the protocols and their configurations @@ -203,7 +204,7 @@ func (m *MockOutputWriter) WriteFailure(wrappedEvent *output.InternalWrappedEven return m.Write(data) } -var maxTemplateFileSizeForEncoding = 1024 * 1024 +var maxTemplateFileSizeForEncoding = unitutils.Mega func (w *MockOutputWriter) encodeTemplate(templatePath string) string { data, err := os.ReadFile(templatePath) diff --git a/pkg/types/types.go b/pkg/types/types.go index 28f98f9198..f49da194cc 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -15,6 +15,7 @@ import ( errorutil "github.com/projectdiscovery/utils/errors" fileutil "github.com/projectdiscovery/utils/file" folderutil "github.com/projectdiscovery/utils/folder" + unitutils "github.com/projectdiscovery/utils/unit" ) var ( @@ -441,8 +442,8 @@ func DefaultOptions() *Options { Timeout: 5, Retries: 1, MaxHostError: 30, - ResponseReadSize: 10 * 1024 * 1024, - ResponseSaveSize: 1024 * 1024, + ResponseReadSize: 10 * unitutils.Mega, + ResponseSaveSize: unitutils.Mega, ResponseReadTimeout: 5 * time.Second, } } diff --git a/pkg/utils/monitor/monitor.go b/pkg/utils/monitor/monitor.go index 6441fdf20d..cefb9506b8 100644 --- a/pkg/utils/monitor/monitor.go +++ b/pkg/utils/monitor/monitor.go @@ -16,6 +16,7 @@ import ( "github.com/DataDog/gostackparse" "github.com/projectdiscovery/gologger" permissionutil "github.com/projectdiscovery/utils/permission" + unitutils "github.com/projectdiscovery/utils/unit" "github.com/rs/xid" ) @@ -118,7 +119,7 @@ func (s *Agent) monitorWorker(cancel context.CancelFunc) { // getStack returns full stack trace of the program var getStack = func(all bool) []byte { - for i := 1024 * 1024; ; i *= 2 { + for i := unitutils.Mega; ; i *= 2 { buf := make([]byte, i) if n := runtime.Stack(buf, all); n < i { return buf[:n-1] diff --git a/pkg/utils/stats/stats.go b/pkg/utils/stats/stats.go index 590608d1e2..f1edf2ea33 100644 --- a/pkg/utils/stats/stats.go +++ b/pkg/utils/stats/stats.go @@ -2,23 +2,22 @@ package stats import ( "fmt" - "sync" "sync/atomic" "github.com/logrusorgru/aurora" "github.com/projectdiscovery/gologger" + mapsutil "github.com/projectdiscovery/utils/maps" ) // Storage is a storage for storing statistics information // about the nuclei engine displaying it at user-defined intervals. type Storage struct { - data map[string]*storageDataItem - mutex *sync.RWMutex + data *mapsutil.SyncLockMap[string, *storageDataItem] } type storageDataItem struct { description string - value int64 + value atomic.Int64 } var Default *Storage @@ -59,38 +58,32 @@ func GetValue(name string) int64 { // New creates a new storage object func New() *Storage { - return &Storage{data: make(map[string]*storageDataItem), mutex: &sync.RWMutex{}} + data := mapsutil.NewSyncLockMap[string, *storageDataItem]() + return &Storage{data: data} } // NewEntry creates a new entry in the storage object func (s *Storage) NewEntry(name, description string) { - s.mutex.Lock() - s.data[name] = &storageDataItem{description: description, value: 0} - s.mutex.Unlock() + _ = s.data.Set(name, &storageDataItem{description: description, value: atomic.Int64{}}) } // Increment increments the value for a name string func (s *Storage) Increment(name string) { - s.mutex.RLock() - data, ok := s.data[name] - s.mutex.RUnlock() + data, ok := s.data.Get(name) if !ok { return } - - atomic.AddInt64(&data.value, 1) + data.value.Add(1) } // Display displays the stats for a name func (s *Storage) Display(name string) { - s.mutex.RLock() - data, ok := s.data[name] - s.mutex.RUnlock() + data, ok := s.data.Get(name) if !ok { return } - dataValue := atomic.LoadInt64(&data.value) + dataValue := data.value.Load() if dataValue == 0 { return // don't show for nil stats } @@ -98,14 +91,12 @@ func (s *Storage) Display(name string) { } func (s *Storage) DisplayAsWarning(name string) { - s.mutex.RLock() - data, ok := s.data[name] - s.mutex.RUnlock() + data, ok := s.data.Get(name) if !ok { return } - dataValue := atomic.LoadInt64(&data.value) + dataValue := data.value.Load() if dataValue == 0 { return // don't show for nil stats } @@ -115,14 +106,12 @@ func (s *Storage) DisplayAsWarning(name string) { // ForceDisplayWarning forces the display of a warning // regardless of current verbosity level func (s *Storage) ForceDisplayWarning(name string) { - s.mutex.RLock() - data, ok := s.data[name] - s.mutex.RUnlock() + data, ok := s.data.Get(name) if !ok { return } - dataValue := atomic.LoadInt64(&data.value) + dataValue := data.value.Load() if dataValue == 0 { return // don't show for nil stats } @@ -131,13 +120,10 @@ func (s *Storage) ForceDisplayWarning(name string) { // GetValue returns the value for a set variable func (s *Storage) GetValue(name string) int64 { - s.mutex.RLock() - data, ok := s.data[name] - s.mutex.RUnlock() + data, ok := s.data.Get(name) if !ok { return 0 } - dataValue := atomic.LoadInt64(&data.value) - return dataValue + return data.value.Load() }