From d5271ad7eaac6230edfa548520670856f36a51f1 Mon Sep 17 00:00:00 2001 From: Dan Salmon Date: Sat, 19 Aug 2023 22:12:56 -0500 Subject: [PATCH 1/5] move most of main to cmd --- cmd/s3scanner/args.go | 83 +++++++ cmd/s3scanner/s3scanner.go | 401 +++++++++++++++++++++++++++++++ main.go | 468 +------------------------------------ 3 files changed, 486 insertions(+), 466 deletions(-) create mode 100644 cmd/s3scanner/args.go create mode 100644 cmd/s3scanner/s3scanner.go diff --git a/cmd/s3scanner/args.go b/cmd/s3scanner/args.go new file mode 100644 index 0000000..2edbc07 --- /dev/null +++ b/cmd/s3scanner/args.go @@ -0,0 +1,83 @@ +package s3scanner + +import ( + "errors" + "fmt" + log "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "os" +) + +type ArgCollection struct { + BucketFile string + BucketName string + DoEnumerate bool + Json bool + ProviderFlag string + Threads int + UseMq bool + Verbose bool + Version bool + WriteToDB bool +} + +func (args ArgCollection) Validate() error { + // Validate: only 1 input flag is provided + numInputFlags := 0 + if args.UseMq { + numInputFlags += 1 + } + if args.BucketName != "" { + numInputFlags += 1 + } + if args.BucketFile != "" { + numInputFlags += 1 + } + if numInputFlags != 1 { + return errors.New("exactly one of: -bucket, -bucket-file, -mq required") + } + + return nil +} + +/* +validateConfig checks that the config file contains all necessary keys according to the args specified +*/ +func validateConfig(args ArgCollection) error { + expectedKeys := []string{} + configFileRequired := false + if args.ProviderFlag == "custom" { + configFileRequired = true + expectedKeys = append(expectedKeys, []string{"providers.custom.insecure", "providers.custom.endpoint_format", "providers.custom.regions", "providers.custom.address_style"}...) + } + if args.WriteToDB { + configFileRequired = true + expectedKeys = append(expectedKeys, []string{"db.uri"}...) + } + if args.UseMq { + configFileRequired = true + expectedKeys = append(expectedKeys, []string{"mq.queue_name", "mq.uri"}...) 
+ } + // User didn't give any arguments that require the config file + if !configFileRequired { + return nil + } + + // Try to find and read config file + if err := viper.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + log.Error("config file not found") + os.Exit(1) + } else { + panic(fmt.Errorf("fatal error config file: %w", err)) + } + } + + // Verify all expected keys are in the config file + for _, k := range expectedKeys { + if !viper.IsSet(k) { + return fmt.Errorf("config file missing key: %s", k) + } + } + return nil +} diff --git a/cmd/s3scanner/s3scanner.go b/cmd/s3scanner/s3scanner.go new file mode 100644 index 0000000..90d5bc2 --- /dev/null +++ b/cmd/s3scanner/s3scanner.go @@ -0,0 +1,401 @@ +package s3scanner + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "github.com/dustin/go-humanize" + "github.com/sa7mon/s3scanner/bucket" + "github.com/sa7mon/s3scanner/db" + log2 "github.com/sa7mon/s3scanner/log" + "github.com/sa7mon/s3scanner/mq" + "github.com/sa7mon/s3scanner/provider" + log "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "github.com/streadway/amqp" + "os" + "reflect" + "strings" + "sync" + "text/tabwriter" +) + +func failOnError(err error, msg string) { + if err != nil { + log.Fatalf("%s: %s", msg, err) + } +} + +func printResult(b *bucket.Bucket) { + if args.Json { + log.WithField("bucket", b).Info() + return + } + + if b.Exists == bucket.BucketNotExist { + log.Infof("not_exist | %s", b.Name) + return + } + + result := fmt.Sprintf("exists | %v | %v | %v", b.Name, b.Region, b.String()) + if b.ObjectsEnumerated { + result = fmt.Sprintf("%v | %v objects (%v)", result, len(b.Objects), humanize.Bytes(b.BucketSize)) + } + log.Info(result) +} + +func work(wg *sync.WaitGroup, buckets chan bucket.Bucket, provider provider.StorageProvider, enumerate bool, writeToDB bool) { + defer wg.Done() + for b1 := range buckets { + b, existsErr := provider.BucketExists(&b1) + if existsErr != nil { + log.Errorf("error | %s | %s", b.Name, existsErr.Error()) + continue + } + + if b.Exists == bucket.BucketNotExist { + printResult(b) + continue + } + + // Scan permissions + scanErr := provider.Scan(b, false) + if scanErr != nil { + log.WithFields(log.Fields{"bucket": b}).Error(scanErr) + } + + if enumerate && b.PermAllUsersRead == bucket.PermissionAllowed { + log.WithFields(log.Fields{"method": "main.work()", + "bucket_name": b.Name, "region": b.Region}).Debugf("enumerating objects...") + enumErr := provider.Enumerate(b) + if enumErr != nil { + log.Errorf("Error enumerating bucket '%s': %v\nEnumerated objects: %v", b.Name, enumErr, len(b.Objects)) + continue + } + } + printResult(b) + + if writeToDB { + dbErr := db.StoreBucket(b) + if dbErr != nil { + log.Error(dbErr) + } + } + } +} + +func mqwork(threadId int, wg *sync.WaitGroup, conn *amqp.Connection, provider provider.StorageProvider, queue string, threads int, + doEnumerate bool, writeToDB bool) { + _, once := os.LookupEnv("TEST_MQ") // If we're being tested, exit after one bucket is scanned + defer wg.Done() + + // Wrap the whole thing in a for (while) loop so if the mq server kills the channel, we start it up again + for { + ch, chErr := mq.Connect(conn, queue, threads, threadId) + if chErr != nil { + failOnError(chErr, "couldn't connect to message queue") + } + + msgs, consumeErr := ch.Consume(queue, fmt.Sprintf("%s_%v", queue, threadId), false, false, false, false, nil) + if consumeErr != nil { + log.Error(fmt.Errorf("failed to register a consumer: %w", consumeErr)) + return + } + 
+ for j := range msgs { + bucketToScan := bucket.Bucket{} + + unmarshalErr := json.Unmarshal(j.Body, &bucketToScan) + if unmarshalErr != nil { + log.Error(unmarshalErr) + } + + if !bucket.IsValidS3BucketName(bucketToScan.Name) { + log.Info(fmt.Sprintf("invalid | %s", bucketToScan.Name)) + failOnError(j.Ack(false), "failed to ack") + continue + } + + b, existsErr := provider.BucketExists(&bucketToScan) + if existsErr != nil { + log.WithFields(log.Fields{"bucket": b.Name, "step": "checkExists"}).Error(existsErr) + failOnError(j.Reject(false), "failed to reject") + } + if b.Exists == bucket.BucketNotExist { + // ack the message and skip to the next + log.Infof("not_exist | %s", b.Name) + failOnError(j.Ack(false), "failed to ack") + continue + } + + scanErr := provider.Scan(b, false) + if scanErr != nil { + log.WithFields(log.Fields{"bucket": b}).Error(scanErr) + failOnError(j.Reject(false), "failed to reject") + continue + } + + if doEnumerate { + if b.PermAllUsersRead != bucket.PermissionAllowed { + printResult(&bucketToScan) + failOnError(j.Ack(false), "failed to ack") + if writeToDB { + dbErr := db.StoreBucket(&bucketToScan) + if dbErr != nil { + log.Error(dbErr) + } + } + continue + } + + log.WithFields(log.Fields{"method": "main.mqwork()", + "bucket_name": b.Name, "region": b.Region}).Debugf("enumerating objects...") + + enumErr := provider.Enumerate(b) + if enumErr != nil { + log.Errorf("Error enumerating bucket '%s': %v\nEnumerated objects: %v", b.Name, enumErr, len(b.Objects)) + failOnError(j.Reject(false), "failed to reject") + } + } + + printResult(&bucketToScan) + ackErr := j.Ack(false) + if ackErr != nil { + // Acknowledge mq message. May fail if we've taken too long and the server has closed the channel + // If it has, we break and start at the top of the outer for-loop again which re-establishes a new + // channel + log.WithFields(log.Fields{"bucket": b}).Error(ackErr) + break + } + + // Write to database + if writeToDB { + dbErr := db.StoreBucket(&bucketToScan) + if dbErr != nil { + log.Error(dbErr) + } + } + if once { + return + } + } + } +} + +type flagSetting struct { + indentLevel int + category int +} + +const ( + CategoryInput int = 0 + CategoryOutput int = 1 + CategoryOptions int = 2 + CategoryDebug int = 3 +) + +var configPaths = []string{".", "/etc/s3scanner/", "$HOME/.s3scanner/"} + +var args = ArgCollection{} + +func usage() { + bufferCategoryInput := new(bytes.Buffer) + bufferCategoryOutput := new(bytes.Buffer) + bufferCategoryOptions := new(bytes.Buffer) + bufferCategoryDebug := new(bytes.Buffer) + categoriesWriters := map[int]*tabwriter.Writer{ + CategoryInput: tabwriter.NewWriter(bufferCategoryInput, 0, 0, 2, ' ', 0), + CategoryOutput: tabwriter.NewWriter(bufferCategoryOutput, 0, 0, 2, ' ', 0), + CategoryOptions: tabwriter.NewWriter(bufferCategoryOptions, 0, 0, 2, ' ', 0), + CategoryDebug: tabwriter.NewWriter(bufferCategoryDebug, 0, 0, 2, ' ', 0), + } + flag.VisitAll(func(f *flag.Flag) { + setting, ok := flagSettings[f.Name] + if !ok { + log.Errorf("flag is missing category: %s", f.Name) + os.Exit(1) + } + writer := categoriesWriters[setting.category] + + fmt.Fprintf(writer, "%s -%s\t", strings.Repeat(" ", setting.indentLevel), f.Name) // Two spaces before -; see next two comments. 
+ name, usage := flag.UnquoteUsage(f) + fmt.Fprintf(writer, " %s\t", name) + fmt.Fprint(writer, usage) + if !reflect.ValueOf(f.DefValue).IsZero() { + fmt.Fprintf(writer, " Default: %q", f.DefValue) + } + fmt.Fprint(writer, "\n") + }) + + // Output all the categories + categoriesWriters[CategoryInput].Flush() + categoriesWriters[CategoryOutput].Flush() + categoriesWriters[CategoryOptions].Flush() + categoriesWriters[CategoryDebug].Flush() + fmt.Fprint(flag.CommandLine.Output(), "INPUT: (1 required)\n", bufferCategoryInput.String()) + fmt.Fprint(flag.CommandLine.Output(), "\nOUTPUT:\n", bufferCategoryOutput.String()) + fmt.Fprint(flag.CommandLine.Output(), "\nOPTIONS:\n", bufferCategoryOptions.String()) + fmt.Fprint(flag.CommandLine.Output(), "\nDEBUG:\n", bufferCategoryDebug.String()) + + // Add config file description + quotedPaths := "" + for i, b := range configPaths { + if i != 0 { + quotedPaths += " " + } + quotedPaths += fmt.Sprintf("\"%s\"", b) + } + + fmt.Fprintf(flag.CommandLine.Output(), "\nIf config file is required these locations will be searched for config.yml: %s\n", + quotedPaths) +} + +var flagSettings = map[string]flagSetting{ + "provider": {category: CategoryOptions}, + "bucket": {category: CategoryInput}, + "bucket-file": {category: CategoryInput}, + "mq": {category: CategoryInput}, + "threads": {category: CategoryOptions}, + "verbose": {category: CategoryDebug}, + "version": {category: CategoryDebug}, + "db": {category: CategoryOutput}, + "json": {category: CategoryOutput}, + "enumerate": {category: CategoryOptions}, +} + +func Run(version string) { + // https://twin.sh/articles/39/go-concurrency-goroutines-worker-pools-and-throttling-made-simple + // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws#AnonymousCredentials + + viper.SetConfigName("config") // name of config file (without extension) + viper.SetConfigType("yml") // REQUIRED if the config file does not have the extension in the name + for _, p := range configPaths { + viper.AddConfigPath(p) + } + + flag.StringVar(&args.ProviderFlag, "provider", "aws", fmt.Sprintf( + "Object storage provider: %s - custom requires config file.", + strings.Join(provider.AllProviders, ", "))) + flag.StringVar(&args.BucketName, "bucket", "", "Name of bucket to check.") + flag.StringVar(&args.BucketFile, "bucket-file", "", "File of bucket names to check.") + flag.BoolVar(&args.UseMq, "mq", false, "Connect to RabbitMQ to get buckets. Requires config file key \"mq\".") + + flag.BoolVar(&args.WriteToDB, "db", false, "Save results to a Postgres database. 
Requires config file key \"db.uri\".") + flag.BoolVar(&args.Json, "json", false, "Print logs to stdout in JSON format instead of human-readable.") + + flag.BoolVar(&args.DoEnumerate, "enumerate", false, "Enumerate bucket objects (can be time-consuming).") + flag.IntVar(&args.Threads, "threads", 4, "Number of threads to scan with.") + flag.BoolVar(&args.Verbose, "verbose", false, "Enable verbose logging.") + flag.BoolVar(&args.Version, "version", false, "Print version") + + flag.Usage = usage + flag.Parse() + + if args.Version { + fmt.Println(version) + os.Exit(0) + } + + argsErr := args.Validate() + if argsErr != nil { + log.Error(argsErr) + os.Exit(1) + } + + // Configure logging + log.SetLevel(log.InfoLevel) + if args.Verbose { + log.SetLevel(log.DebugLevel) + } + log.SetOutput(os.Stdout) + if args.Json { + log.SetFormatter(&log2.NestedJSONFormatter{}) + } else { + log.SetFormatter(&log.TextFormatter{DisableTimestamp: true}) + } + + var p provider.StorageProvider + var err error + configErr := validateConfig(args) + if configErr != nil { + log.Error(configErr) + os.Exit(1) + } + if args.ProviderFlag == "custom" { + if viper.IsSet("providers.custom") { + log.Debug("found custom provider") + p, err = provider.NewCustomProvider( + viper.GetString("providers.custom.address_style"), + viper.GetBool("providers.custom.insecure"), + viper.GetStringSlice("providers.custom.regions"), + viper.GetString("providers.custom.endpoint_format")) + if err != nil { + log.Error(err) + os.Exit(1) + } + } + } else { + p, err = provider.NewProvider(args.ProviderFlag) + if err != nil { + log.Error(err) + os.Exit(1) + } + } + + // Setup database connection + if args.WriteToDB { + dbConfig := viper.GetString("db.uri") + log.Debugf("using database URI from config: %s", dbConfig) + dbErr := db.Connect(dbConfig, true) + if dbErr != nil { + log.Error(dbErr) + os.Exit(1) + } + } + + var wg sync.WaitGroup + + if !args.UseMq { + buckets := make(chan bucket.Bucket) + + for i := 0; i < args.Threads; i++ { + wg.Add(1) + go work(&wg, buckets, p, args.DoEnumerate, args.WriteToDB) + } + + if args.BucketFile != "" { + err := bucket.ReadFromFile(args.BucketFile, buckets) + close(buckets) + if err != nil { + log.Error(err) + os.Exit(1) + } + } else if args.BucketName != "" { + if !bucket.IsValidS3BucketName(args.BucketName) { + log.Info(fmt.Sprintf("invalid | %s", args.BucketName)) + os.Exit(0) + } + c := bucket.NewBucket(strings.ToLower(args.BucketName)) + buckets <- c + close(buckets) + } + + wg.Wait() + os.Exit(0) + } + + // Setup mq connection and spin off consumers + mqUri := viper.GetString("mq.uri") + mqName := viper.GetString("mq.queue_name") + conn, err := amqp.Dial(mqUri) + failOnError(err, fmt.Sprintf("failed to connect to AMQP URI '%s'", mqUri)) + defer conn.Close() + + for i := 0; i < args.Threads; i++ { + wg.Add(1) + go mqwork(i, &wg, conn, p, mqName, args.Threads, args.DoEnumerate, args.WriteToDB) + } + log.Printf("Waiting for messages. 
To exit press CTRL+C") + wg.Wait() +} diff --git a/main.go b/main.go index 2139a3b..8f3cb1b 100644 --- a/main.go +++ b/main.go @@ -1,473 +1,9 @@ package main -import ( - "bytes" - "encoding/json" - "errors" - "flag" - "fmt" - "github.com/dustin/go-humanize" - "github.com/sa7mon/s3scanner/bucket" - "github.com/sa7mon/s3scanner/db" - log2 "github.com/sa7mon/s3scanner/log" - "github.com/sa7mon/s3scanner/mq" - "github.com/sa7mon/s3scanner/provider" - log "github.com/sirupsen/logrus" - "github.com/spf13/viper" - "github.com/streadway/amqp" - "os" - "reflect" - "strings" - "sync" - "text/tabwriter" -) - -func failOnError(err error, msg string) { - if err != nil { - log.Fatalf("%s: %s", msg, err) - } -} - -func printResult(b *bucket.Bucket) { - if args.json { - log.WithField("bucket", b).Info() - return - } - - if b.Exists == bucket.BucketNotExist { - log.Infof("not_exist | %s", b.Name) - return - } - - result := fmt.Sprintf("exists | %v | %v | %v", b.Name, b.Region, b.String()) - if b.ObjectsEnumerated { - result = fmt.Sprintf("%v | %v objects (%v)", result, len(b.Objects), humanize.Bytes(b.BucketSize)) - } - log.Info(result) -} - -func work(wg *sync.WaitGroup, buckets chan bucket.Bucket, provider provider.StorageProvider, enumerate bool, writeToDB bool) { - defer wg.Done() - for b1 := range buckets { - b, existsErr := provider.BucketExists(&b1) - if existsErr != nil { - log.Errorf("error | %s | %s", b.Name, existsErr.Error()) - continue - } - - if b.Exists == bucket.BucketNotExist { - printResult(b) - continue - } - - // Scan permissions - scanErr := provider.Scan(b, false) - if scanErr != nil { - log.WithFields(log.Fields{"bucket": b}).Error(scanErr) - } - - if enumerate && b.PermAllUsersRead == bucket.PermissionAllowed { - log.WithFields(log.Fields{"method": "main.work()", - "bucket_name": b.Name, "region": b.Region}).Debugf("enumerating objects...") - enumErr := provider.Enumerate(b) - if enumErr != nil { - log.Errorf("Error enumerating bucket '%s': %v\nEnumerated objects: %v", b.Name, enumErr, len(b.Objects)) - continue - } - } - printResult(b) - - if writeToDB { - dbErr := db.StoreBucket(b) - if dbErr != nil { - log.Error(dbErr) - } - } - } -} - -func mqwork(threadId int, wg *sync.WaitGroup, conn *amqp.Connection, provider provider.StorageProvider, queue string, threads int, - doEnumerate bool, writeToDB bool) { - _, once := os.LookupEnv("TEST_MQ") // If we're being tested, exit after one bucket is scanned - defer wg.Done() - - // Wrap the whole thing in a for (while) loop so if the mq server kills the channel, we start it up again - for { - ch, chErr := mq.Connect(conn, queue, threads, threadId) - if chErr != nil { - failOnError(chErr, "couldn't connect to message queue") - } - - msgs, consumeErr := ch.Consume(queue, fmt.Sprintf("%s_%v", queue, threadId), false, false, false, false, nil) - if consumeErr != nil { - log.Error(fmt.Errorf("failed to register a consumer: %w", consumeErr)) - return - } - - for j := range msgs { - bucketToScan := bucket.Bucket{} - - unmarshalErr := json.Unmarshal(j.Body, &bucketToScan) - if unmarshalErr != nil { - log.Error(unmarshalErr) - } - - if !bucket.IsValidS3BucketName(bucketToScan.Name) { - log.Info(fmt.Sprintf("invalid | %s", bucketToScan.Name)) - failOnError(j.Ack(false), "failed to ack") - continue - } - - b, existsErr := provider.BucketExists(&bucketToScan) - if existsErr != nil { - log.WithFields(log.Fields{"bucket": b.Name, "step": "checkExists"}).Error(existsErr) - failOnError(j.Reject(false), "failed to reject") - } - if b.Exists == 
bucket.BucketNotExist { - // ack the message and skip to the next - log.Infof("not_exist | %s", b.Name) - failOnError(j.Ack(false), "failed to ack") - continue - } - - scanErr := provider.Scan(b, false) - if scanErr != nil { - log.WithFields(log.Fields{"bucket": b}).Error(scanErr) - failOnError(j.Reject(false), "failed to reject") - continue - } - - if doEnumerate { - if b.PermAllUsersRead != bucket.PermissionAllowed { - printResult(&bucketToScan) - failOnError(j.Ack(false), "failed to ack") - if writeToDB { - dbErr := db.StoreBucket(&bucketToScan) - if dbErr != nil { - log.Error(dbErr) - } - } - continue - } - - log.WithFields(log.Fields{"method": "main.mqwork()", - "bucket_name": b.Name, "region": b.Region}).Debugf("enumerating objects...") - - enumErr := provider.Enumerate(b) - if enumErr != nil { - log.Errorf("Error enumerating bucket '%s': %v\nEnumerated objects: %v", b.Name, enumErr, len(b.Objects)) - failOnError(j.Reject(false), "failed to reject") - } - } - - printResult(&bucketToScan) - ackErr := j.Ack(false) - if ackErr != nil { - // Acknowledge mq message. May fail if we've taken too long and the server has closed the channel - // If it has, we break and start at the top of the outer for-loop again which re-establishes a new - // channel - log.WithFields(log.Fields{"bucket": b}).Error(ackErr) - break - } - - // Write to database - if writeToDB { - dbErr := db.StoreBucket(&bucketToScan) - if dbErr != nil { - log.Error(dbErr) - } - } - if once { - return - } - } - } -} - -type flagSetting struct { - indentLevel int - category int -} - -type argCollection struct { - bucketFile string - bucketName string - doEnumerate bool - json bool - providerFlag string - threads int - useMq bool - verbose bool - version bool - writeToDB bool -} - -func (args argCollection) Validate() error { - // Validate: only 1 input flag is provided - numInputFlags := 0 - if args.useMq { - numInputFlags += 1 - } - if args.bucketName != "" { - numInputFlags += 1 - } - if args.bucketFile != "" { - numInputFlags += 1 - } - if numInputFlags != 1 { - return errors.New("exactly one of: -bucket, -bucket-file, -mq required") - } - - return nil -} - -/* -validateConfig checks that the config file contains all necessary keys according to the args specified -*/ -func validateConfig(args argCollection) error { - expectedKeys := []string{} - configFileRequired := false - if args.providerFlag == "custom" { - configFileRequired = true - expectedKeys = append(expectedKeys, []string{"providers.custom.insecure", "providers.custom.endpoint_format", "providers.custom.regions", "providers.custom.address_style"}...) - } - if args.writeToDB { - configFileRequired = true - expectedKeys = append(expectedKeys, []string{"db.uri"}...) - } - if args.useMq { - configFileRequired = true - expectedKeys = append(expectedKeys, []string{"mq.queue_name", "mq.uri"}...) 
- } - // User didn't give any arguments that require the config file - if !configFileRequired { - return nil - } - - // Try to find and read config file - if err := viper.ReadInConfig(); err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok { - log.Error("config file not found") - os.Exit(1) - } else { - panic(fmt.Errorf("fatal error config file: %w", err)) - } - } - - // Verify all expected keys are in the config file - for _, k := range expectedKeys { - if !viper.IsSet(k) { - return fmt.Errorf("config file missing key: %s", k) - } - } - return nil -} - -const ( - CategoryInput int = 0 - CategoryOutput int = 1 - CategoryOptions int = 2 - CategoryDebug int = 3 -) - -var configPaths = []string{".", "/etc/s3scanner/", "$HOME/.s3scanner/"} +import "github.com/sa7mon/s3scanner/cmd/s3scanner" var version = "dev" -var args = argCollection{} func main() { - // https://twin.sh/articles/39/go-concurrency-goroutines-worker-pools-and-throttling-made-simple - // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws#AnonymousCredentials - - viper.SetConfigName("config") // name of config file (without extension) - viper.SetConfigType("yml") // REQUIRED if the config file does not have the extension in the name - for _, p := range configPaths { - viper.AddConfigPath(p) - } - - flagSettings := make(map[string]flagSetting, 11) - flag.StringVar(&args.providerFlag, "provider", "aws", fmt.Sprintf( - "Object storage provider: %s - custom requires config file.", - strings.Join(provider.AllProviders, ", "))) - flagSettings["provider"] = flagSetting{category: CategoryOptions} - flag.StringVar(&args.bucketName, "bucket", "", "Name of bucket to check.") - flagSettings["bucket"] = flagSetting{category: CategoryInput} - flag.StringVar(&args.bucketFile, "bucket-file", "", "File of bucket names to check.") - flagSettings["bucket-file"] = flagSetting{category: CategoryInput} - flag.BoolVar(&args.useMq, "mq", false, "Connect to RabbitMQ to get buckets. Requires config file key \"mq\".") - flagSettings["mq"] = flagSetting{category: CategoryInput} - - flag.BoolVar(&args.writeToDB, "db", false, "Save results to a Postgres database. 
Requires config file key \"db.uri\".") - flagSettings["db"] = flagSetting{category: CategoryOutput} - flag.BoolVar(&args.json, "json", false, "Print logs to stdout in JSON format instead of human-readable.") - flagSettings["json"] = flagSetting{category: CategoryOutput} - - flag.BoolVar(&args.doEnumerate, "enumerate", false, "Enumerate bucket objects (can be time-consuming).") - flagSettings["enumerate"] = flagSetting{category: CategoryOptions} - flag.IntVar(&args.threads, "threads", 4, "Number of threads to scan with.") - flagSettings["threads"] = flagSetting{category: CategoryOptions} - flag.BoolVar(&args.verbose, "verbose", false, "Enable verbose logging.") - flagSettings["verbose"] = flagSetting{category: CategoryDebug} - flag.BoolVar(&args.version, "version", false, "Print version") - flagSettings["version"] = flagSetting{category: CategoryDebug} - - flag.Usage = func() { - bufferCategoryInput := new(bytes.Buffer) - bufferCategoryOutput := new(bytes.Buffer) - bufferCategoryOptions := new(bytes.Buffer) - bufferCategoryDebug := new(bytes.Buffer) - categoriesWriters := map[int]*tabwriter.Writer{ - CategoryInput: tabwriter.NewWriter(bufferCategoryInput, 0, 0, 2, ' ', 0), - CategoryOutput: tabwriter.NewWriter(bufferCategoryOutput, 0, 0, 2, ' ', 0), - CategoryOptions: tabwriter.NewWriter(bufferCategoryOptions, 0, 0, 2, ' ', 0), - CategoryDebug: tabwriter.NewWriter(bufferCategoryDebug, 0, 0, 2, ' ', 0), - } - flag.VisitAll(func(f *flag.Flag) { - setting, ok := flagSettings[f.Name] - if !ok { - log.Errorf("flag is missing category: %s", f.Name) - os.Exit(1) - } - writer := categoriesWriters[setting.category] - - fmt.Fprintf(writer, "%s -%s\t", strings.Repeat(" ", setting.indentLevel), f.Name) // Two spaces before -; see next two comments. - name, usage := flag.UnquoteUsage(f) - fmt.Fprintf(writer, " %s\t", name) - fmt.Fprint(writer, usage) - if !reflect.ValueOf(f.DefValue).IsZero() { - fmt.Fprintf(writer, " Default: %q", f.DefValue) - } - fmt.Fprint(writer, "\n") - }) - - // Output all the categories - categoriesWriters[CategoryInput].Flush() - categoriesWriters[CategoryOutput].Flush() - categoriesWriters[CategoryOptions].Flush() - categoriesWriters[CategoryDebug].Flush() - fmt.Fprint(flag.CommandLine.Output(), "INPUT: (1 required)\n", bufferCategoryInput.String()) - fmt.Fprint(flag.CommandLine.Output(), "\nOUTPUT:\n", bufferCategoryOutput.String()) - fmt.Fprint(flag.CommandLine.Output(), "\nOPTIONS:\n", bufferCategoryOptions.String()) - fmt.Fprint(flag.CommandLine.Output(), "\nDEBUG:\n", bufferCategoryDebug.String()) - - // Add config file description - quotedPaths := "" - for i, b := range configPaths { - if i != 0 { - quotedPaths += " " - } - quotedPaths += fmt.Sprintf("\"%s\"", b) - } - - fmt.Fprintf(flag.CommandLine.Output(), "\nIf config file is required these locations will be searched for config.yml: %s\n", - quotedPaths) - } - flag.Parse() - - if args.version { - fmt.Println(version) - os.Exit(0) - } - - argsErr := args.Validate() - if argsErr != nil { - log.Error(argsErr) - os.Exit(1) - } - - // Configure logging - log.SetLevel(log.InfoLevel) - if args.verbose { - log.SetLevel(log.DebugLevel) - } - log.SetOutput(os.Stdout) - if args.json { - log.SetFormatter(&log2.NestedJSONFormatter{}) - } else { - log.SetFormatter(&log.TextFormatter{DisableTimestamp: true}) - } - - var p provider.StorageProvider - var err error - configErr := validateConfig(args) - if configErr != nil { - log.Error(configErr) - os.Exit(1) - } - if args.providerFlag == "custom" { - if 
viper.IsSet("providers.custom") { - log.Debug("found custom provider") - p, err = provider.NewCustomProvider( - viper.GetString("providers.custom.address_style"), - viper.GetBool("providers.custom.insecure"), - viper.GetStringSlice("providers.custom.regions"), - viper.GetString("providers.custom.endpoint_format")) - if err != nil { - log.Error(err) - os.Exit(1) - } - } - } else { - p, err = provider.NewProvider(args.providerFlag) - if err != nil { - log.Error(err) - os.Exit(1) - } - } - - // Setup database connection - if args.writeToDB { - dbConfig := viper.GetString("db.uri") - log.Debugf("using database URI from config: %s", dbConfig) - dbErr := db.Connect(dbConfig, true) - if dbErr != nil { - log.Error(dbErr) - os.Exit(1) - } - } - - var wg sync.WaitGroup - - if !args.useMq { - buckets := make(chan bucket.Bucket) - - for i := 0; i < args.threads; i++ { - wg.Add(1) - go work(&wg, buckets, p, args.doEnumerate, args.writeToDB) - } - - if args.bucketFile != "" { - err := bucket.ReadFromFile(args.bucketFile, buckets) - close(buckets) - if err != nil { - log.Error(err) - os.Exit(1) - } - } else if args.bucketName != "" { - if !bucket.IsValidS3BucketName(args.bucketName) { - log.Info(fmt.Sprintf("invalid | %s", args.bucketName)) - os.Exit(0) - } - c := bucket.NewBucket(strings.ToLower(args.bucketName)) - buckets <- c - close(buckets) - } - - wg.Wait() - os.Exit(0) - } - - // Setup mq connection and spin off consumers - mqUri := viper.GetString("mq.uri") - mqName := viper.GetString("mq.queue_name") - conn, err := amqp.Dial(mqUri) - failOnError(err, fmt.Sprintf("failed to connect to AMQP URI '%s'", mqUri)) - defer conn.Close() - - for i := 0; i < args.threads; i++ { - wg.Add(1) - go mqwork(i, &wg, conn, p, mqName, args.threads, args.doEnumerate, args.writeToDB) - } - log.Printf("Waiting for messages. 
To exit press CTRL+C") - wg.Wait() + s3scanner.Run(version) } From d056b500e402eca45df76e207624e36476a1b3b5 Mon Sep 17 00:00:00 2001 From: Dan Salmon Date: Sat, 19 Aug 2023 22:36:52 -0500 Subject: [PATCH 2/5] move workers to own module --- cmd/s3scanner/s3scanner.go | 174 ++----------------------------------- worker/mq_worker.go | 118 +++++++++++++++++++++++++ worker/worker.go | 70 +++++++++++++++ 3 files changed, 194 insertions(+), 168 deletions(-) create mode 100644 worker/mq_worker.go create mode 100644 worker/worker.go diff --git a/cmd/s3scanner/s3scanner.go b/cmd/s3scanner/s3scanner.go index 90d5bc2..d1e5c9f 100644 --- a/cmd/s3scanner/s3scanner.go +++ b/cmd/s3scanner/s3scanner.go @@ -2,15 +2,13 @@ package s3scanner import ( "bytes" - "encoding/json" "flag" "fmt" - "github.com/dustin/go-humanize" "github.com/sa7mon/s3scanner/bucket" "github.com/sa7mon/s3scanner/db" log2 "github.com/sa7mon/s3scanner/log" - "github.com/sa7mon/s3scanner/mq" "github.com/sa7mon/s3scanner/provider" + "github.com/sa7mon/s3scanner/worker" log "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/streadway/amqp" @@ -21,168 +19,6 @@ import ( "text/tabwriter" ) -func failOnError(err error, msg string) { - if err != nil { - log.Fatalf("%s: %s", msg, err) - } -} - -func printResult(b *bucket.Bucket) { - if args.Json { - log.WithField("bucket", b).Info() - return - } - - if b.Exists == bucket.BucketNotExist { - log.Infof("not_exist | %s", b.Name) - return - } - - result := fmt.Sprintf("exists | %v | %v | %v", b.Name, b.Region, b.String()) - if b.ObjectsEnumerated { - result = fmt.Sprintf("%v | %v objects (%v)", result, len(b.Objects), humanize.Bytes(b.BucketSize)) - } - log.Info(result) -} - -func work(wg *sync.WaitGroup, buckets chan bucket.Bucket, provider provider.StorageProvider, enumerate bool, writeToDB bool) { - defer wg.Done() - for b1 := range buckets { - b, existsErr := provider.BucketExists(&b1) - if existsErr != nil { - log.Errorf("error | %s | %s", b.Name, existsErr.Error()) - continue - } - - if b.Exists == bucket.BucketNotExist { - printResult(b) - continue - } - - // Scan permissions - scanErr := provider.Scan(b, false) - if scanErr != nil { - log.WithFields(log.Fields{"bucket": b}).Error(scanErr) - } - - if enumerate && b.PermAllUsersRead == bucket.PermissionAllowed { - log.WithFields(log.Fields{"method": "main.work()", - "bucket_name": b.Name, "region": b.Region}).Debugf("enumerating objects...") - enumErr := provider.Enumerate(b) - if enumErr != nil { - log.Errorf("Error enumerating bucket '%s': %v\nEnumerated objects: %v", b.Name, enumErr, len(b.Objects)) - continue - } - } - printResult(b) - - if writeToDB { - dbErr := db.StoreBucket(b) - if dbErr != nil { - log.Error(dbErr) - } - } - } -} - -func mqwork(threadId int, wg *sync.WaitGroup, conn *amqp.Connection, provider provider.StorageProvider, queue string, threads int, - doEnumerate bool, writeToDB bool) { - _, once := os.LookupEnv("TEST_MQ") // If we're being tested, exit after one bucket is scanned - defer wg.Done() - - // Wrap the whole thing in a for (while) loop so if the mq server kills the channel, we start it up again - for { - ch, chErr := mq.Connect(conn, queue, threads, threadId) - if chErr != nil { - failOnError(chErr, "couldn't connect to message queue") - } - - msgs, consumeErr := ch.Consume(queue, fmt.Sprintf("%s_%v", queue, threadId), false, false, false, false, nil) - if consumeErr != nil { - log.Error(fmt.Errorf("failed to register a consumer: %w", consumeErr)) - return - } - - for j := range msgs { - bucketToScan 
:= bucket.Bucket{} - - unmarshalErr := json.Unmarshal(j.Body, &bucketToScan) - if unmarshalErr != nil { - log.Error(unmarshalErr) - } - - if !bucket.IsValidS3BucketName(bucketToScan.Name) { - log.Info(fmt.Sprintf("invalid | %s", bucketToScan.Name)) - failOnError(j.Ack(false), "failed to ack") - continue - } - - b, existsErr := provider.BucketExists(&bucketToScan) - if existsErr != nil { - log.WithFields(log.Fields{"bucket": b.Name, "step": "checkExists"}).Error(existsErr) - failOnError(j.Reject(false), "failed to reject") - } - if b.Exists == bucket.BucketNotExist { - // ack the message and skip to the next - log.Infof("not_exist | %s", b.Name) - failOnError(j.Ack(false), "failed to ack") - continue - } - - scanErr := provider.Scan(b, false) - if scanErr != nil { - log.WithFields(log.Fields{"bucket": b}).Error(scanErr) - failOnError(j.Reject(false), "failed to reject") - continue - } - - if doEnumerate { - if b.PermAllUsersRead != bucket.PermissionAllowed { - printResult(&bucketToScan) - failOnError(j.Ack(false), "failed to ack") - if writeToDB { - dbErr := db.StoreBucket(&bucketToScan) - if dbErr != nil { - log.Error(dbErr) - } - } - continue - } - - log.WithFields(log.Fields{"method": "main.mqwork()", - "bucket_name": b.Name, "region": b.Region}).Debugf("enumerating objects...") - - enumErr := provider.Enumerate(b) - if enumErr != nil { - log.Errorf("Error enumerating bucket '%s': %v\nEnumerated objects: %v", b.Name, enumErr, len(b.Objects)) - failOnError(j.Reject(false), "failed to reject") - } - } - - printResult(&bucketToScan) - ackErr := j.Ack(false) - if ackErr != nil { - // Acknowledge mq message. May fail if we've taken too long and the server has closed the channel - // If it has, we break and start at the top of the outer for-loop again which re-establishes a new - // channel - log.WithFields(log.Fields{"bucket": b}).Error(ackErr) - break - } - - // Write to database - if writeToDB { - dbErr := db.StoreBucket(&bucketToScan) - if dbErr != nil { - log.Error(dbErr) - } - } - if once { - return - } - } - } -} - type flagSetting struct { indentLevel int category int @@ -361,7 +197,7 @@ func Run(version string) { for i := 0; i < args.Threads; i++ { wg.Add(1) - go work(&wg, buckets, p, args.DoEnumerate, args.WriteToDB) + go worker.Work(&wg, buckets, p, args.DoEnumerate, args.WriteToDB, args.Json) } if args.BucketFile != "" { @@ -389,12 +225,14 @@ func Run(version string) { mqUri := viper.GetString("mq.uri") mqName := viper.GetString("mq.queue_name") conn, err := amqp.Dial(mqUri) - failOnError(err, fmt.Sprintf("failed to connect to AMQP URI '%s'", mqUri)) + if err != nil { + log.Fatalf("%s: %s", fmt.Sprintf("failed to connect to AMQP URI '%s'", mqUri), err) + } defer conn.Close() for i := 0; i < args.Threads; i++ { wg.Add(1) - go mqwork(i, &wg, conn, p, mqName, args.Threads, args.DoEnumerate, args.WriteToDB) + go worker.WorkMQ(i, &wg, conn, p, mqName, args.Threads, args.DoEnumerate, args.WriteToDB) } log.Printf("Waiting for messages. 
To exit press CTRL+C") wg.Wait() diff --git a/worker/mq_worker.go b/worker/mq_worker.go new file mode 100644 index 0000000..2b1c76b --- /dev/null +++ b/worker/mq_worker.go @@ -0,0 +1,118 @@ +package worker + +import ( + "encoding/json" + "fmt" + "github.com/sa7mon/s3scanner/bucket" + "github.com/sa7mon/s3scanner/db" + "github.com/sa7mon/s3scanner/mq" + "github.com/sa7mon/s3scanner/provider" + log "github.com/sirupsen/logrus" + "github.com/streadway/amqp" + "os" + "sync" +) + +func failOnError(err error, msg string) { + if err != nil { + log.Fatalf("%s: %s", msg, err) + } +} + +func WorkMQ(threadId int, wg *sync.WaitGroup, conn *amqp.Connection, provider provider.StorageProvider, queue string, + threads int, doEnumerate bool, writeToDB bool) { + _, once := os.LookupEnv("TEST_MQ") // If we're being tested, exit after one bucket is scanned + defer wg.Done() + + // Wrap the whole thing in a for (while) loop so if the mq server kills the channel, we start it up again + for { + ch, chErr := mq.Connect(conn, queue, threads, threadId) + if chErr != nil { + failOnError(chErr, "couldn't connect to message queue") + } + + msgs, consumeErr := ch.Consume(queue, fmt.Sprintf("%s_%v", queue, threadId), false, false, false, false, nil) + if consumeErr != nil { + log.Error(fmt.Errorf("failed to register a consumer: %w", consumeErr)) + return + } + + for j := range msgs { + bucketToScan := bucket.Bucket{} + + unmarshalErr := json.Unmarshal(j.Body, &bucketToScan) + if unmarshalErr != nil { + log.Error(unmarshalErr) + } + + if !bucket.IsValidS3BucketName(bucketToScan.Name) { + log.Info(fmt.Sprintf("invalid | %s", bucketToScan.Name)) + failOnError(j.Ack(false), "failed to ack") + continue + } + + b, existsErr := provider.BucketExists(&bucketToScan) + if existsErr != nil { + log.WithFields(log.Fields{"bucket": b.Name, "step": "checkExists"}).Error(existsErr) + failOnError(j.Reject(false), "failed to reject") + } + if b.Exists == bucket.BucketNotExist { + // ack the message and skip to the next + log.Infof("not_exist | %s", b.Name) + failOnError(j.Ack(false), "failed to ack") + continue + } + + scanErr := provider.Scan(b, false) + if scanErr != nil { + log.WithFields(log.Fields{"bucket": b}).Error(scanErr) + failOnError(j.Reject(false), "failed to reject") + continue + } + + if doEnumerate { + if b.PermAllUsersRead != bucket.PermissionAllowed { + printResult(&bucketToScan, false) + failOnError(j.Ack(false), "failed to ack") + if writeToDB { + dbErr := db.StoreBucket(&bucketToScan) + if dbErr != nil { + log.Error(dbErr) + } + } + continue + } + + log.WithFields(log.Fields{"method": "main.mqwork()", + "bucket_name": b.Name, "region": b.Region}).Debugf("enumerating objects...") + + enumErr := provider.Enumerate(b) + if enumErr != nil { + log.Errorf("Error enumerating bucket '%s': %v\nEnumerated objects: %v", b.Name, enumErr, len(b.Objects)) + failOnError(j.Reject(false), "failed to reject") + } + } + + printResult(&bucketToScan, false) + ackErr := j.Ack(false) + if ackErr != nil { + // Acknowledge mq message. 
May fail if we've taken too long and the server has closed the channel + // If it has, we break and start at the top of the outer for-loop again which re-establishes a new + // channel + log.WithFields(log.Fields{"bucket": b}).Error(ackErr) + break + } + + // Write to database + if writeToDB { + dbErr := db.StoreBucket(&bucketToScan) + if dbErr != nil { + log.Error(dbErr) + } + } + if once { + return + } + } + } +} diff --git a/worker/worker.go b/worker/worker.go new file mode 100644 index 0000000..4ae3f55 --- /dev/null +++ b/worker/worker.go @@ -0,0 +1,70 @@ +package worker + +import ( + "fmt" + "github.com/dustin/go-humanize" + "github.com/sa7mon/s3scanner/bucket" + "github.com/sa7mon/s3scanner/db" + "github.com/sa7mon/s3scanner/provider" + log "github.com/sirupsen/logrus" + "sync" +) + +func printResult(b *bucket.Bucket, json bool) { + if json { + log.WithField("bucket", b).Info() + return + } + + if b.Exists == bucket.BucketNotExist { + log.Infof("not_exist | %s", b.Name) + return + } + + result := fmt.Sprintf("exists | %v | %v | %v", b.Name, b.Region, b.String()) + if b.ObjectsEnumerated { + result = fmt.Sprintf("%v | %v objects (%v)", result, len(b.Objects), humanize.Bytes(b.BucketSize)) + } + log.Info(result) +} + +func Work(wg *sync.WaitGroup, buckets chan bucket.Bucket, provider provider.StorageProvider, doEnumerate bool, + writeToDB bool, json bool) { + defer wg.Done() + for b1 := range buckets { + b, existsErr := provider.BucketExists(&b1) + if existsErr != nil { + log.Errorf("error | %s | %s", b.Name, existsErr.Error()) + continue + } + + if b.Exists == bucket.BucketNotExist { + printResult(b, json) + continue + } + + // Scan permissions + scanErr := provider.Scan(b, false) + if scanErr != nil { + log.WithFields(log.Fields{"bucket": b}).Error(scanErr) + } + + if doEnumerate && b.PermAllUsersRead == bucket.PermissionAllowed { + log.WithFields(log.Fields{"method": "main.work()", + "bucket_name": b.Name, "region": b.Region}).Debugf("enumerating objects...") + enumErr := provider.Enumerate(b) + if enumErr != nil { + log.Errorf("Error enumerating bucket '%s': %v\nEnumerated objects: %v", b.Name, enumErr, len(b.Objects)) + continue + } + } + printResult(b, json) + + if writeToDB { + dbErr := db.StoreBucket(b) + if dbErr != nil { + log.Error(dbErr) + } + } + } +} From ad831f252716a1703dbdc40a707d0382c2b285e1 Mon Sep 17 00:00:00 2001 From: Dan Salmon Date: Sat, 19 Aug 2023 22:58:23 -0500 Subject: [PATCH 3/5] move tests --- Makefile | 5 +- .../s3scanner/s3scanner_test.go | 108 ++++-------------- worker/mq_worker.go | 20 ++-- worker/mq_worker_test.go | 55 +++++++++ worker/worker.go | 6 +- worker/worker_test.go | 24 ++++ 6 files changed, 119 insertions(+), 99 deletions(-) rename main_test.go => cmd/s3scanner/s3scanner_test.go (55%) create mode 100644 worker/mq_worker_test.go create mode 100644 worker/worker_test.go diff --git a/Makefile b/Makefile index ca119ba..5e7ac99 100644 --- a/Makefile +++ b/Makefile @@ -16,5 +16,8 @@ test: test-integration: TEST_DB=1 TEST_MQ=1 go test ./... +test-coverage: + TEST_DB=1 TEST_MQ=1 go test ./... -coverprofile cover.out && go tool cover -html=cover.out + upgrade: - go get -u ./... \ No newline at end of file + go get -u ./... 
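The new test-coverage target above wraps go test with a cover profile (TEST_DB=1 TEST_MQ=1 go test ./... -coverprofile cover.out). The renamed tests below exercise the worker package split out in PATCH 2/5. For orientation, here is a minimal, self-contained sketch of driving worker.Work directly, assembled from the calls that appear in this series; the provider flag, bucket name, and thread count are illustrative placeholders, not values taken from the project:

package main

import (
	"sync"

	"github.com/sa7mon/s3scanner/bucket"
	"github.com/sa7mon/s3scanner/provider"
	"github.com/sa7mon/s3scanner/worker"
)

func main() {
	// Assumed setup: the default "aws" provider, as used by Run() in cmd/s3scanner.
	p, err := provider.NewProvider("aws")
	if err != nil {
		panic(err)
	}

	buckets := make(chan bucket.Bucket)
	var wg sync.WaitGroup

	// Start a few workers. Work(wg, buckets, provider, doEnumerate, writeToDB, json)
	// mirrors the signature introduced in worker/worker.go.
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go worker.Work(&wg, buckets, p, false, false, false)
	}

	// Feed a single (illustrative) bucket and wait for the workers to drain the channel.
	buckets <- bucket.NewBucket("example-bucket")
	close(buckets)
	wg.Wait()
}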
diff --git a/main_test.go b/cmd/s3scanner/s3scanner_test.go similarity index 55% rename from main_test.go rename to cmd/s3scanner/s3scanner_test.go index 3a83f42..528e1da 100644 --- a/main_test.go +++ b/cmd/s3scanner/s3scanner_test.go @@ -1,71 +1,48 @@ -package main +package s3scanner import ( "bytes" - "encoding/json" "github.com/sa7mon/s3scanner/bucket" - "github.com/sa7mon/s3scanner/mq" - "github.com/sa7mon/s3scanner/provider" + "github.com/sa7mon/s3scanner/worker" log "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/writer" - "github.com/streadway/amqp" "github.com/stretchr/testify/assert" - "os" - "sync" "testing" ) -func publishBucket(ch *amqp.Channel, b bucket.Bucket) { - bucketBytes, err := json.Marshal(b) - if err != nil { - failOnError(err, "Failed to marshal bucket msg") - } - - err = ch.Publish( - "", - "test", - false, - false, - amqp.Publishing{Body: bucketBytes, DeliveryMode: amqp.Transient}, - ) - if err != nil { - failOnError(err, "Failed to publish to channel") - } -} - func TestArgCollection_Validate(t *testing.T) { - goodInputs := []argCollection{ + goodInputs := []ArgCollection{ { - bucketName: "asdf", - bucketFile: "", - useMq: false, + BucketName: "asdf", + BucketFile: "", + UseMq: false, }, { - bucketName: "", - bucketFile: "buckets.txt", - useMq: false, + BucketName: "", + BucketFile: "buckets.txt", + UseMq: false, }, { - bucketName: "", - bucketFile: "", - useMq: true, + BucketName: "", + BucketFile: "", + UseMq: true, }, } - tooManyInputs := []argCollection{ + tooManyInputs := []ArgCollection{ { - bucketName: "asdf", - bucketFile: "asdf", - useMq: false, + BucketName: "asdf", + BucketFile: "asdf", + UseMq: false, }, { - bucketName: "adsf", - bucketFile: "", - useMq: true, + BucketName: "adsf", + BucketFile: "", + UseMq: true, }, { - bucketName: "", - bucketFile: "asdf.txt", - useMq: true, + BucketName: "", + BucketFile: "asdf.txt", + UseMq: true, }, } @@ -83,45 +60,6 @@ func TestArgCollection_Validate(t *testing.T) { } } -func TestWork(t *testing.T) { - b := bucket.NewBucket("s3scanner-bucketsize") - aws, err := provider.NewProviderAWS() - assert.Nil(t, err) - b2, exErr := aws.BucketExists(&b) - assert.Nil(t, exErr) - - wg := sync.WaitGroup{} - wg.Add(1) - c := make(chan bucket.Bucket, 1) - c <- *b2 - close(c) - work(&wg, c, aws, true, false) -} - -func TestMqWork(t *testing.T) { - _, testMQ := os.LookupEnv("TEST_MQ") - if !testMQ { - t.Skip("TEST_MQ not enabled") - } - - aws, err := provider.NewProviderAWS() - assert.Nil(t, err) - - wg := sync.WaitGroup{} - wg.Add(1) - - conn, err := amqp.Dial("amqp://guest:guest@localhost:5672") - assert.Nil(t, err) - - // Connect to queue and add a test bucket - ch, err := mq.Connect(conn, "test", 1, 0) - assert.Nil(t, err) - publishBucket(ch, bucket.Bucket{Name: "mqtest"}) - - mqwork(0, &wg, conn, aws, "test", 1, - false, false) -} - func TestLogs(t *testing.T) { var buf bytes.Buffer log.AddHook(&writer.Hook{ // Send logs with level higher than warning to stderr @@ -171,7 +109,7 @@ func TestLogs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t2 *testing.T) { - printResult(&tt.b) + worker.PrintResult(&tt.b, false) assert.Contains(t2, buf.String(), tt.expected) }) } diff --git a/worker/mq_worker.go b/worker/mq_worker.go index 2b1c76b..2e097df 100644 --- a/worker/mq_worker.go +++ b/worker/mq_worker.go @@ -13,7 +13,7 @@ import ( "sync" ) -func failOnError(err error, msg string) { +func FailOnError(err error, msg string) { if err != nil { log.Fatalf("%s: %s", msg, err) } @@ -28,7 +28,7 @@ func 
WorkMQ(threadId int, wg *sync.WaitGroup, conn *amqp.Connection, provider pr for { ch, chErr := mq.Connect(conn, queue, threads, threadId) if chErr != nil { - failOnError(chErr, "couldn't connect to message queue") + FailOnError(chErr, "couldn't connect to message queue") } msgs, consumeErr := ch.Consume(queue, fmt.Sprintf("%s_%v", queue, threadId), false, false, false, false, nil) @@ -47,33 +47,33 @@ func WorkMQ(threadId int, wg *sync.WaitGroup, conn *amqp.Connection, provider pr if !bucket.IsValidS3BucketName(bucketToScan.Name) { log.Info(fmt.Sprintf("invalid | %s", bucketToScan.Name)) - failOnError(j.Ack(false), "failed to ack") + FailOnError(j.Ack(false), "failed to ack") continue } b, existsErr := provider.BucketExists(&bucketToScan) if existsErr != nil { log.WithFields(log.Fields{"bucket": b.Name, "step": "checkExists"}).Error(existsErr) - failOnError(j.Reject(false), "failed to reject") + FailOnError(j.Reject(false), "failed to reject") } if b.Exists == bucket.BucketNotExist { // ack the message and skip to the next log.Infof("not_exist | %s", b.Name) - failOnError(j.Ack(false), "failed to ack") + FailOnError(j.Ack(false), "failed to ack") continue } scanErr := provider.Scan(b, false) if scanErr != nil { log.WithFields(log.Fields{"bucket": b}).Error(scanErr) - failOnError(j.Reject(false), "failed to reject") + FailOnError(j.Reject(false), "failed to reject") continue } if doEnumerate { if b.PermAllUsersRead != bucket.PermissionAllowed { - printResult(&bucketToScan, false) - failOnError(j.Ack(false), "failed to ack") + PrintResult(&bucketToScan, false) + FailOnError(j.Ack(false), "failed to ack") if writeToDB { dbErr := db.StoreBucket(&bucketToScan) if dbErr != nil { @@ -89,11 +89,11 @@ func WorkMQ(threadId int, wg *sync.WaitGroup, conn *amqp.Connection, provider pr enumErr := provider.Enumerate(b) if enumErr != nil { log.Errorf("Error enumerating bucket '%s': %v\nEnumerated objects: %v", b.Name, enumErr, len(b.Objects)) - failOnError(j.Reject(false), "failed to reject") + FailOnError(j.Reject(false), "failed to reject") } } - printResult(&bucketToScan, false) + PrintResult(&bucketToScan, false) ackErr := j.Ack(false) if ackErr != nil { // Acknowledge mq message. 
May fail if we've taken too long and the server has closed the channel diff --git a/worker/mq_worker_test.go b/worker/mq_worker_test.go new file mode 100644 index 0000000..bc4e322 --- /dev/null +++ b/worker/mq_worker_test.go @@ -0,0 +1,55 @@ +package worker + +import ( + "encoding/json" + "github.com/sa7mon/s3scanner/bucket" + "github.com/sa7mon/s3scanner/mq" + "github.com/sa7mon/s3scanner/provider" + "github.com/streadway/amqp" + "github.com/stretchr/testify/assert" + "os" + "sync" + "testing" +) + +func publishBucket(ch *amqp.Channel, b bucket.Bucket) { + bucketBytes, err := json.Marshal(b) + if err != nil { + FailOnError(err, "Failed to marshal bucket msg") + } + + err = ch.Publish( + "", + "test", + false, + false, + amqp.Publishing{Body: bucketBytes, DeliveryMode: amqp.Transient}, + ) + if err != nil { + FailOnError(err, "Failed to publish to channel") + } +} + +func TestMqWork(t *testing.T) { + _, testMQ := os.LookupEnv("TEST_MQ") + if !testMQ { + t.Skip("TEST_MQ not enabled") + } + + aws, err := provider.NewProviderAWS() + assert.Nil(t, err) + + wg := sync.WaitGroup{} + wg.Add(1) + + conn, err := amqp.Dial("amqp://guest:guest@localhost:5672") + assert.Nil(t, err) + + // Connect to queue and add a test bucket + ch, err := mq.Connect(conn, "test", 1, 0) + assert.Nil(t, err) + publishBucket(ch, bucket.Bucket{Name: "mqtest"}) + + WorkMQ(0, &wg, conn, aws, "test", 1, + false, false) +} diff --git a/worker/worker.go b/worker/worker.go index 4ae3f55..97c477a 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -10,7 +10,7 @@ import ( "sync" ) -func printResult(b *bucket.Bucket, json bool) { +func PrintResult(b *bucket.Bucket, json bool) { if json { log.WithField("bucket", b).Info() return @@ -39,7 +39,7 @@ func Work(wg *sync.WaitGroup, buckets chan bucket.Bucket, provider provider.Stor } if b.Exists == bucket.BucketNotExist { - printResult(b, json) + PrintResult(b, json) continue } @@ -58,7 +58,7 @@ func Work(wg *sync.WaitGroup, buckets chan bucket.Bucket, provider provider.Stor continue } } - printResult(b, json) + PrintResult(b, json) if writeToDB { dbErr := db.StoreBucket(b) diff --git a/worker/worker_test.go b/worker/worker_test.go new file mode 100644 index 0000000..535cbee --- /dev/null +++ b/worker/worker_test.go @@ -0,0 +1,24 @@ +package worker + +import ( + "github.com/sa7mon/s3scanner/bucket" + "github.com/sa7mon/s3scanner/provider" + "github.com/stretchr/testify/assert" + "sync" + "testing" +) + +func TestWork(t *testing.T) { + b := bucket.NewBucket("s3scanner-bucketsize") + aws, err := provider.NewProviderAWS() + assert.Nil(t, err) + b2, exErr := aws.BucketExists(&b) + assert.Nil(t, exErr) + + wg := sync.WaitGroup{} + wg.Add(1) + c := make(chan bucket.Bucket, 1) + c <- *b2 + close(c) + Work(&wg, c, aws, true, false, false) +} From f4986f41c81b4d427cbd9c6fcd712e6163bf93e9 Mon Sep 17 00:00:00 2001 From: Dan Salmon Date: Sun, 20 Aug 2023 00:26:47 -0500 Subject: [PATCH 4/5] move and add tests --- cmd/s3scanner/args.go | 7 +- cmd/s3scanner/args_test.go | 87 ++++++++++++++++++++++++ cmd/s3scanner/s3scanner_test.go | 116 -------------------------------- worker/worker_test.go | 59 ++++++++++++++++ 4 files changed, 148 insertions(+), 121 deletions(-) create mode 100644 cmd/s3scanner/args_test.go diff --git a/cmd/s3scanner/args.go b/cmd/s3scanner/args.go index 2edbc07..9f2d75d 100644 --- a/cmd/s3scanner/args.go +++ b/cmd/s3scanner/args.go @@ -3,9 +3,7 @@ package s3scanner import ( "errors" "fmt" - log "github.com/sirupsen/logrus" "github.com/spf13/viper" - "os" ) type ArgCollection 
struct { @@ -66,10 +64,9 @@ func validateConfig(args ArgCollection) error { // Try to find and read config file if err := viper.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); ok { - log.Error("config file not found") - os.Exit(1) + return errors.New("config file not found") } else { - panic(fmt.Errorf("fatal error config file: %w", err)) + return err } } diff --git a/cmd/s3scanner/args_test.go b/cmd/s3scanner/args_test.go new file mode 100644 index 0000000..102f9af --- /dev/null +++ b/cmd/s3scanner/args_test.go @@ -0,0 +1,87 @@ +package s3scanner + +import ( + "errors" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestArgCollection_Validate(t *testing.T) { + goodInputs := []ArgCollection{ + { + BucketName: "asdf", + BucketFile: "", + UseMq: false, + }, + { + BucketName: "", + BucketFile: "buckets.txt", + UseMq: false, + }, + { + BucketName: "", + BucketFile: "", + UseMq: true, + }, + } + tooManyInputs := []ArgCollection{ + { + BucketName: "asdf", + BucketFile: "asdf", + UseMq: false, + }, + { + BucketName: "adsf", + BucketFile: "", + UseMq: true, + }, + { + BucketName: "", + BucketFile: "asdf.txt", + UseMq: true, + }, + } + + for _, v := range goodInputs { + err := v.Validate() + if err != nil { + t.Errorf("%v: %e", v, err) + } + } + for _, v := range tooManyInputs { + err := v.Validate() + if err == nil { + t.Errorf("expected error but did not find one: %v", v) + } + } +} + +func TestValidateConfig(t *testing.T) { + a := ArgCollection{ + DoEnumerate: false, + Json: false, + ProviderFlag: "custom", + UseMq: true, + WriteToDB: true, + } + viper.AddConfigPath("../../") + viper.SetConfigName("config") // name of config file (without extension) + viper.SetConfigType("yml") // REQUIRED if the config file does not have the extension in the name + err := validateConfig(a) + assert.Nil(t, err) +} + +func TestValidateConfig_NotFound(t *testing.T) { + a := ArgCollection{ + DoEnumerate: false, + Json: false, + ProviderFlag: "custom", + UseMq: true, + WriteToDB: true, + } + viper.SetConfigName("asdf") // won't be found + viper.SetConfigType("yml") + err := validateConfig(a) + assert.Equal(t, errors.New("config file not found"), err) +} diff --git a/cmd/s3scanner/s3scanner_test.go b/cmd/s3scanner/s3scanner_test.go index 528e1da..046ccfc 100644 --- a/cmd/s3scanner/s3scanner_test.go +++ b/cmd/s3scanner/s3scanner_test.go @@ -1,117 +1 @@ package s3scanner - -import ( - "bytes" - "github.com/sa7mon/s3scanner/bucket" - "github.com/sa7mon/s3scanner/worker" - log "github.com/sirupsen/logrus" - "github.com/sirupsen/logrus/hooks/writer" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestArgCollection_Validate(t *testing.T) { - goodInputs := []ArgCollection{ - { - BucketName: "asdf", - BucketFile: "", - UseMq: false, - }, - { - BucketName: "", - BucketFile: "buckets.txt", - UseMq: false, - }, - { - BucketName: "", - BucketFile: "", - UseMq: true, - }, - } - tooManyInputs := []ArgCollection{ - { - BucketName: "asdf", - BucketFile: "asdf", - UseMq: false, - }, - { - BucketName: "adsf", - BucketFile: "", - UseMq: true, - }, - { - BucketName: "", - BucketFile: "asdf.txt", - UseMq: true, - }, - } - - for _, v := range goodInputs { - err := v.Validate() - if err != nil { - t.Errorf("%v: %e", v, err) - } - } - for _, v := range tooManyInputs { - err := v.Validate() - if err == nil { - t.Errorf("expected error but did not find one: %v", v) - } - } -} - -func TestLogs(t *testing.T) { - var buf bytes.Buffer - log.AddHook(&writer.Hook{ 
// Send logs with level higher than warning to stderr - Writer: &buf, - LogLevels: []log.Level{ - log.PanicLevel, - log.FatalLevel, - log.ErrorLevel, - log.WarnLevel, - log.InfoLevel, - }, - }) - - tests := []struct { - name string - b bucket.Bucket - enum bool - expected string - }{ - {name: "enumerated, public-read, empty", b: bucket.Bucket{ - Name: "test-logging", - Exists: bucket.BucketExists, - ObjectsEnumerated: true, - NumObjects: 0, - BucketSize: 0, - PermAllUsersRead: bucket.PermissionAllowed, - }, enum: true, expected: "exists | test-logging | | AuthUsers: [] | AllUsers: [READ] | 0 objects (0 B)"}, - {name: "enumerated, closed", b: bucket.Bucket{ - Name: "enumerated-closed", - Exists: bucket.BucketExists, - ObjectsEnumerated: true, - NumObjects: 0, - BucketSize: 0, - PermAllUsersRead: bucket.PermissionDenied, - }, enum: true, expected: "exists | enumerated-closed | | AuthUsers: [] | AllUsers: [] | 0 objects (0 B)"}, - {name: "closed", b: bucket.Bucket{ - Name: "no-enumerate-closed", - Exists: bucket.BucketExists, - ObjectsEnumerated: false, - PermAllUsersRead: bucket.PermissionDenied, - }, enum: true, expected: "exists | no-enumerate-closed | | AuthUsers: [] | AllUsers: []"}, - {name: "no-enum-not-exist", b: bucket.Bucket{ - Name: "no-enum-not-exist", - Exists: bucket.BucketNotExist, - }, enum: false, expected: "not_exist | no-enum-not-exist"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t2 *testing.T) { - worker.PrintResult(&tt.b, false) - assert.Contains(t2, buf.String(), tt.expected) - }) - } - -} diff --git a/worker/worker_test.go b/worker/worker_test.go index 535cbee..0698aa2 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -1,8 +1,11 @@ package worker import ( + "bytes" "github.com/sa7mon/s3scanner/bucket" "github.com/sa7mon/s3scanner/provider" + log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/writer" "github.com/stretchr/testify/assert" "sync" "testing" @@ -22,3 +25,59 @@ func TestWork(t *testing.T) { close(c) Work(&wg, c, aws, true, false, false) } + +func TestLogs(t *testing.T) { + var buf bytes.Buffer + log.AddHook(&writer.Hook{ // Send logs with level higher than warning to stderr + Writer: &buf, + LogLevels: []log.Level{ + log.PanicLevel, + log.FatalLevel, + log.ErrorLevel, + log.WarnLevel, + log.InfoLevel, + }, + }) + + tests := []struct { + name string + b bucket.Bucket + enum bool + expected string + }{ + {name: "enumerated, public-read, empty", b: bucket.Bucket{ + Name: "test-logging", + Exists: bucket.BucketExists, + ObjectsEnumerated: true, + NumObjects: 0, + BucketSize: 0, + PermAllUsersRead: bucket.PermissionAllowed, + }, enum: true, expected: "exists | test-logging | | AuthUsers: [] | AllUsers: [READ] | 0 objects (0 B)"}, + {name: "enumerated, closed", b: bucket.Bucket{ + Name: "enumerated-closed", + Exists: bucket.BucketExists, + ObjectsEnumerated: true, + NumObjects: 0, + BucketSize: 0, + PermAllUsersRead: bucket.PermissionDenied, + }, enum: true, expected: "exists | enumerated-closed | | AuthUsers: [] | AllUsers: [] | 0 objects (0 B)"}, + {name: "closed", b: bucket.Bucket{ + Name: "no-enumerate-closed", + Exists: bucket.BucketExists, + ObjectsEnumerated: false, + PermAllUsersRead: bucket.PermissionDenied, + }, enum: true, expected: "exists | no-enumerate-closed | | AuthUsers: [] | AllUsers: []"}, + {name: "no-enum-not-exist", b: bucket.Bucket{ + Name: "no-enum-not-exist", + Exists: bucket.BucketNotExist, + }, enum: false, expected: "not_exist | no-enum-not-exist"}, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t2 *testing.T) {
+			PrintResult(&tt.b, false)
+			assert.Contains(t2, buf.String(), tt.expected)
+		})
+	}
+
+}

From 096f9e1582550902caf16a4815543123e0e5169f Mon Sep 17 00:00:00 2001
From: Dan Salmon
Date: Sun, 20 Aug 2023 00:36:46 -0500
Subject: [PATCH 5/5] reorder

---
 Makefile | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/Makefile b/Makefile
index 5e7ac99..e92f305 100644
--- a/Makefile
+++ b/Makefile
@@ -13,11 +13,11 @@ docker-image:
 test:
 	go test ./...
 
-test-integration:
-	TEST_DB=1 TEST_MQ=1 go test ./...
-
 test-coverage:
 	TEST_DB=1 TEST_MQ=1 go test ./... -coverprofile cover.out && go tool cover -html=cover.out
 
+test-integration:
+	TEST_DB=1 TEST_MQ=1 go test ./...
+
 upgrade:
 	go get -u ./...
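With the whole series applied, the repository entrypoint is reduced to the thin wrapper below; this block is reproduced from the main.go hunk in PATCH 1/5, with flag parsing, config validation, worker startup, and MQ handling now living in cmd/s3scanner and worker:

package main

import "github.com/sa7mon/s3scanner/cmd/s3scanner"

var version = "dev"

func main() {
	s3scanner.Run(version)
}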