diff --git a/main.go b/main.go index c1f7a62f..299051c6 100644 --- a/main.go +++ b/main.go @@ -49,7 +49,6 @@ func main() { runnerConfigOptions := job.NewConfigOptionsWithFlags() jobsGlobalConfig := job.NewGlobalConfigWithFlags() otaConfig := ota.NewConfigWithFlags() - countryCheckerConfig := utils.NewCountryCheckerConfigWithFlags() updaterMode, destinationPath := config.NewUpdaterOptionsWithFlags() prometheusOn, prometheusListenAddress := metrics.NewOptionsWithFlags() pprof := flag.String("pprof", utils.GetEnvStringDefault("GO_PPROF_ENDPOINT", ""), "enable pprof") @@ -100,8 +99,7 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - metrics.InitOrFail(ctx, logger, *prometheusOn, *prometheusListenAddress, jobsGlobalConfig.ClientID, - utils.CheckCountryOrFail(ctx, logger, countryCheckerConfig, jobsGlobalConfig.GetProxyParams(logger, nil))) + metrics.InitOrFail(ctx, logger, *prometheusOn, *prometheusListenAddress, jobsGlobalConfig.ClientID, "") job.NewRunner(runnerConfigOptions, jobsGlobalConfig, newReporter(*logFormat, *lessStats, logger)).Run(ctx, logger) } diff --git a/src/job/base.go b/src/job/base.go index 38a83e70..125c7c38 100644 --- a/src/job/base.go +++ b/src/job/base.go @@ -26,7 +26,12 @@ package job import ( "context" "flag" + "io" "math/rand" + "net/http" + "net/url" + "os" + "path/filepath" "time" "github.com/google/uuid" @@ -43,9 +48,11 @@ type GlobalConfig struct { ClientID string UserID string - ProxyURLs string - LocalAddr string - Interface string + proxyURLs string + proxylist string + defaultProxyProto string + localAddr string + iface string SkipEncrypted bool EnablePrimitiveJobs bool ScaleFactor float64 @@ -62,11 +69,13 @@ func NewGlobalConfigWithFlags() *GlobalConfig { flag.StringVar(&res.UserID, "user-id", utils.GetEnvStringDefault("USER_ID", ""), "user id for optional metrics") - flag.StringVar(&res.ProxyURLs, "proxy", utils.GetEnvStringDefault("SYSTEM_PROXY", ""), + flag.StringVar(&res.proxyURLs, "proxy", utils.GetEnvStringDefault("SYSTEM_PROXY", ""), "system proxy to set by default (can be a comma-separated list or a template)") - flag.StringVar(&res.LocalAddr, "local-address", utils.GetEnvStringDefault("LOCAL_ADDRESS", ""), + flag.StringVar(&res.proxylist, "proxylist", "", "file or url to read a list of proxies from") + flag.StringVar(&res.defaultProxyProto, "default-proxy-proto", "socks5", "protocol to fallback to if proxy contains only address") + flag.StringVar(&res.localAddr, "local-address", utils.GetEnvStringDefault("LOCAL_ADDRESS", ""), "specify ip address of local interface to use") - flag.StringVar(&res.Interface, "interface", utils.GetEnvStringDefault("NETWORK_INTERFACE", ""), + flag.StringVar(&res.iface, "interface", utils.GetEnvStringDefault("NETWORK_INTERFACE", ""), "specify which interface to bind to for attacks (ignored on windows)") flag.BoolVar(&res.SkipEncrypted, "skip-encrypted", utils.GetEnvBoolDefault("SKIP_ENCRYPTED", false), "set to true if you want to only run plaintext jobs from the config for security considerations") @@ -91,16 +100,61 @@ func NewGlobalConfigWithFlags() *GlobalConfig { func (g GlobalConfig) GetProxyParams(logger *zap.Logger, data any) utils.ProxyParams { return utils.ProxyParams{ - URLs: templates.ParseAndExecute(logger, g.ProxyURLs, data), - LocalAddr: templates.ParseAndExecute(logger, g.LocalAddr, data), - Interface: templates.ParseAndExecute(logger, g.Interface, data), + URLs: templates.ParseAndExecute(logger, g.proxyURLs, data), + DefaultProto: g.defaultProxyProto, + LocalAddr: templates.ParseAndExecute(logger, g.localAddr, data), + Interface: templates.ParseAndExecute(logger, g.iface, data), } } +func (g *GlobalConfig) initProxylist(ctx context.Context) error { + if g.proxyURLs != "" || g.proxylist == "" { + return nil + } + proxylist, err := readProxylist(ctx, g.proxylist) + if err != nil { + return err + } + g.proxyURLs = string(proxylist) + return nil +} + +func readProxylist(ctx context.Context, path string) ([]byte, error) { + proxylistURL, err := url.ParseRequestURI(path) + // absolute paths can be interpreted as a URL with no schema, need to check for that explicitly + if err != nil || filepath.IsAbs(path) { + res, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + return res, nil + } + + const requestTimeout = 20 * time.Second + + ctx, cancel := context.WithTimeout(ctx, requestTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxylistURL.String(), nil) + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return io.ReadAll(resp.Body) +} + // Job comment for linter type Job = func(ctx context.Context, args config.Args, globalConfig *GlobalConfig, a *metrics.Accumulator, logger *zap.Logger) (data any, err error) // Get job by type name +// //nolint:cyclop // The string map alternative is orders of magnitude slower func Get(t string) Job { switch t { diff --git a/src/job/config/defaultconfig.go b/src/job/config/defaultconfig.go index 92d25fad..acf57be0 100644 --- a/src/job/config/defaultconfig.go +++ b/src/job/config/defaultconfig.go @@ -5,6 +5,7 @@ import ( ) // DefaultConfig is the config embedded into the app that it will use if not able to fetch any other config +// //nolint:lll // Makes no sense splitting this into multiple lines var DefaultConfig = `` diff --git a/src/job/runner.go b/src/job/runner.go index b69ea08b..406e7d75 100644 --- a/src/job/runner.go +++ b/src/job/runner.go @@ -84,6 +84,10 @@ func NewRunner(cfgOptions *ConfigOptions, globalJobsCfg *GlobalConfig, reporter // Run the runner and block until Stop() is called func (r *Runner) Run(ctx context.Context, logger *zap.Logger) { + if err := r.globalJobsCfg.initProxylist(ctx); err != nil { + logger.Warn("failed to init proxylist", zap.Error(err)) + } + ctx = context.WithValue(ctx, templates.ContextKey("goos"), runtime.GOOS) ctx = context.WithValue(ctx, templates.ContextKey("goarch"), runtime.GOARCH) ctx = context.WithValue(ctx, templates.ContextKey("version"), ota.Version) diff --git a/src/utils/countrychecker.go b/src/utils/countrychecker.go deleted file mode 100644 index 0f3050c0..00000000 --- a/src/utils/countrychecker.go +++ /dev/null @@ -1,142 +0,0 @@ -package utils - -import ( - "context" - "encoding/json" - "flag" - "fmt" - "net" - "strings" - "time" - - "github.com/valyala/fasthttp" - "go.uber.org/zap" -) - -type CountryCheckerConfig struct { - countryBlackListCSV string - strict bool - interval time.Duration - maxRetries int -} - -// NewGlobalConfigWithFlags returns a GlobalConfig initialized with command line flags. -func NewCountryCheckerConfigWithFlags() *CountryCheckerConfig { - const maxFetchRetries = 3 - - var res CountryCheckerConfig - - flag.StringVar(&res.countryBlackListCSV, "country-list", GetEnvStringDefault("COUNTRY_LIST", "Ukraine"), "comma-separated list of countries") - flag.BoolVar(&res.strict, "strict-country-check", GetEnvBoolDefault("STRICT_COUNTRY_CHECK", false), - "enable strict country check; will also exit if IP can't be determined") - flag.IntVar(&res.maxRetries, "country-check-retries", GetEnvIntDefault("COUNTRY_CHECK_RETRIES", maxFetchRetries), - "how much retries should be made when checking the country") - flag.DurationVar(&res.interval, "country-check-interval", GetEnvDurationDefault("COUNTRY_CHECK_INTERVAL", 0), - "run country check in background with a regular interval") - - return &res -} - -// CheckCountryOrFail checks the country of client origin by IP and exits the program if it is in the blacklist. -func CheckCountryOrFail(ctx context.Context, logger *zap.Logger, cfg *CountryCheckerConfig, proxyParams ProxyParams) string { - if cfg.interval != 0 { - go func() { - ticker := time.NewTicker(cfg.interval) - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - _ = ckeckCountryOnce(ctx, logger, cfg, proxyParams) - } - } - }() - } - - return ckeckCountryOnce(ctx, logger, cfg, proxyParams) -} - -func ckeckCountryOnce(ctx context.Context, logger *zap.Logger, cfg *CountryCheckerConfig, proxyParams ProxyParams) string { - country, ip, err := getCountry(ctx, logger, proxyParams, cfg.maxRetries) - if err != nil { - if cfg.strict { - logger.Fatal("country strict check failed", zap.Error(err)) - } - - return "" - } - - logger.Info("location info", zap.String("country", country), zap.String("ip", ip)) - - if strings.Contains(cfg.countryBlackListCSV, country) { - logger.Warn("you might need to enable VPN.") - - if cfg.strict { - logger.Fatal("country strict check failed", zap.String("country", country)) - } - } - - return country -} - -func getCountry(ctx context.Context, logger *zap.Logger, proxyParams ProxyParams, maxFetchRetries int) (country, ip string, err error) { - counter := Counter{Count: maxFetchRetries} - backoffController := BackoffController{BackoffConfig: DefaultBackoffConfig()} - - for counter.Next() { - logger.Info("checking IP address,", zap.Int("iter", counter.iter)) - - if country, ip, err = fetchLocationInfo(ctx, proxyParams); err != nil { - logger.Warn("error fetching location info", zap.Error(err)) - Sleep(ctx, backoffController.Increment().GetTimeout()) - } else { - return - } - } - - return "", "", fmt.Errorf("couldn't get location info in %d tries", maxFetchRetries) -} - -func fetchLocationInfo(ctx context.Context, proxyParams ProxyParams) (country, ip string, err error) { - const ( - ipCheckerURI = "https://api.myip.com/" - requestTimeout = 3 * time.Second - ) - - proxyFunc := GetProxyFunc(ctx, proxyParams, "http") - - client := &fasthttp.Client{ - MaxConnDuration: requestTimeout, - ReadTimeout: requestTimeout, - WriteTimeout: requestTimeout, - MaxIdleConnDuration: requestTimeout, - Dial: func(addr string) (net.Conn, error) { - return proxyFunc("tcp", addr) - }, - } - - req, resp := fasthttp.AcquireRequest(), fasthttp.AcquireResponse() - defer func() { - fasthttp.ReleaseRequest(req) - fasthttp.ReleaseResponse(resp) - }() - - req.SetRequestURI(ipCheckerURI) - req.Header.SetMethod(fasthttp.MethodGet) - - if err := client.Do(req, resp); err != nil { - return "", "", err - } - - ipInfo := struct { - Country string `json:"country"` - IP string `json:"ip"` - }{} - - if err := json.Unmarshal(resp.Body(), &ipInfo); err != nil { - return "", "", err - } - - return ipInfo.Country, ipInfo.IP, nil -} diff --git a/src/utils/proxy.go b/src/utils/proxy.go index 9d758633..9f84d355 100644 --- a/src/utils/proxy.go +++ b/src/utils/proxy.go @@ -17,10 +17,11 @@ import ( type ProxyFunc func(network, addr string) (net.Conn, error) type ProxyParams struct { - URLs string - LocalAddr string - Interface string - Timeout time.Duration + URLs string + DefaultProto string + LocalAddr string + Interface string + Timeout time.Duration } // this won't work for udp payloads but if people use proxies they might not want to have their ip exposed @@ -31,13 +32,18 @@ func GetProxyFunc(ctx context.Context, params ProxyParams, protocol string) Prox return proxy.FromEnvironmentUsing(direct).Dial } - proxies := strings.Split(params.URLs, ",") + proxies := strings.Fields(strings.ReplaceAll(params.URLs, ",", " ")) // We need to dial new proxy on each call return func(network, addr string) (net.Conn, error) { - u, err := url.Parse(proxies[rand.Intn(len(proxies))]) //nolint:gosec // Cryptographically secure random not required + selected := proxies[rand.Intn(len(proxies))] //nolint:gosec // Cryptographically secure random not required + u, err := url.Parse(selected) if err != nil { - return nil, fmt.Errorf("error building proxy %v: %w", u.String(), err) + selected = params.DefaultProto + "://" + selected + u, err = url.Parse(selected) + if err != nil { + return nil, fmt.Errorf("error building proxy %v: %w", selected, err) + } } switch u.Scheme {