Skip to content

Commit

Permalink
feat: add iac progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
sundowndev committed Apr 27, 2021
1 parent 7a52ef3 commit 03833f4
Show file tree
Hide file tree
Showing 13 changed files with 121 additions and 30 deletions.
24 changes: 17 additions & 7 deletions pkg/cmd/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,23 @@ func scanRun(opts *pkg.ScanOptions) error {
providerLibrary := terraform.NewProviderLibrary()
supplierLibrary := resource.NewSupplierLibrary()

progress := globaloutput.NewProgress()
scanProgress := globaloutput.NewProgress()
scanProgress.SetOptions(globaloutput.ProgressOptions{
LoadingText: "Scanning resources",
FinishedText: "Scanned resources",
ShowCount: true,
})

iacProgress := globaloutput.NewProgress()
iacProgress.SetOptions(globaloutput.ProgressOptions{
LoadingText: "Scanning states",
FinishedText: "Scanned states",
ShowCount: true,
})

resourceSchemaRepository := resource.NewSchemaRepository()

err := remote.Activate(opts.To, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository)
err := remote.Activate(opts.To, alerter, providerLibrary, supplierLibrary, scanProgress, resourceSchemaRepository)
if err != nil {
return err
}
Expand All @@ -165,28 +177,26 @@ func scanRun(opts *pkg.ScanOptions) error {

scanner := pkg.NewScanner(supplierLibrary.Suppliers(), alerter, resourceSchemaRepository)

iacSupplier, err := supplier.GetIACSupplier(opts.From, providerLibrary, opts.BackendOptions, resourceSchemaRepository)
iacSupplier, err := supplier.GetIACSupplier(opts.From, providerLibrary, opts.BackendOptions, resourceSchemaRepository, iacProgress)
if err != nil {
return err
}

resFactory := terraform.NewTerraformResourceFactory(providerLibrary)

ctl := pkg.NewDriftCTL(scanner, iacSupplier, alerter, resFactory, opts, resourceSchemaRepository)
ctl := pkg.NewDriftCTL(scanner, iacSupplier, alerter, resFactory, opts, resourceSchemaRepository, scanProgress)

go func() {
<-c
logrus.Warn("Detected interrupt, cleanup ...")
ctl.Stop()
}()

progress.Start()
analysis, err := ctl.Run()
progress.Stop()

if err != nil {
return err
}
scanProgress.Stop()

err = selectedOutput.Write(analysis)
if err != nil {
Expand Down
7 changes: 6 additions & 1 deletion pkg/driftctl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pkg
import (
"fmt"

globaloutput "github.com/cloudskiff/driftctl/pkg/output"
"github.com/jmespath/go-jmespath"
"github.com/sirupsen/logrus"

Expand Down Expand Up @@ -36,9 +37,10 @@ type DriftCTL struct {
filter *jmespath.JMESPath
resourceFactory resource.ResourceFactory
strictMode bool
progress globaloutput.Progress
}

func NewDriftCTL(remoteSupplier resource.Supplier, iacSupplier resource.Supplier, alerter *alerter.Alerter, resFactory resource.ResourceFactory, opts *ScanOptions, resourceSchemaRepository resource.SchemaRepositoryInterface) *DriftCTL {
func NewDriftCTL(remoteSupplier resource.Supplier, iacSupplier resource.Supplier, alerter *alerter.Alerter, resFactory resource.ResourceFactory, opts *ScanOptions, resourceSchemaRepository resource.SchemaRepositoryInterface, progress globaloutput.Progress) *DriftCTL {
return &DriftCTL{
remoteSupplier,
iacSupplier,
Expand All @@ -47,6 +49,7 @@ func NewDriftCTL(remoteSupplier resource.Supplier, iacSupplier resource.Supplier
opts.Filter,
resFactory,
opts.StrictMode,
progress,
}
}

Expand Down Expand Up @@ -136,6 +139,8 @@ func (d DriftCTL) scan() (remoteResources []resource.Resource, resourcesFromStat
return nil, nil, err
}

d.progress.Start()

logrus.Info("Start scanning cloud provider")
remoteResources, err = d.remoteSupplier.Resources()
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion pkg/driftctl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/cloudskiff/driftctl/pkg/output"
"github.com/jmespath/go-jmespath"
"github.com/r3labs/diff/v2"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -82,9 +83,12 @@ func runTest(t *testing.T, cases TestCases) {
c.mocks(resourceFactory)
}

progress := &output.MockProgress{}
progress.On("Start").Return()

driftctl := pkg.NewDriftCTL(remoteSupplier, stateSupplier, testAlerter, resourceFactory, &pkg.ScanOptions{
Filter: filter,
}, repo)
}, repo, progress)

analysis, err := driftctl.Run()

Expand Down
7 changes: 5 additions & 2 deletions pkg/iac/supplier/supplier.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"

"github.com/cloudskiff/driftctl/pkg/iac/terraform/state/backend"
"github.com/cloudskiff/driftctl/pkg/output"
"github.com/cloudskiff/driftctl/pkg/terraform"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
Expand All @@ -28,7 +29,9 @@ func IsSupplierSupported(supplierKey string) bool {
return false
}

func GetIACSupplier(configs []config.SupplierConfig, library *terraform.ProviderLibrary, backendOpts *backend.Options, resourceSchemaRepository resource.SchemaRepositoryInterface) (resource.Supplier, error) {
func GetIACSupplier(configs []config.SupplierConfig, library *terraform.ProviderLibrary, backendOpts *backend.Options, resourceSchemaRepository resource.SchemaRepositoryInterface, progress output.Progress) (resource.Supplier, error) {
progress.Start()

chainSupplier := resource.NewChainSupplier()
for _, config := range configs {
if !IsSupplierSupported(config.Key) {
Expand All @@ -39,7 +42,7 @@ func GetIACSupplier(configs []config.SupplierConfig, library *terraform.Provider
var err error
switch config.Key {
case state.TerraformStateReaderSupplier:
supplier, err = state.NewReader(config, library, backendOpts, resourceSchemaRepository)
supplier, err = state.NewReader(config, library, backendOpts, resourceSchemaRepository, progress)
default:
return nil, errors.Errorf("Unsupported supplier '%s'", config.Key)
}
Expand Down
14 changes: 10 additions & 4 deletions pkg/iac/supplier/supplier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (

"github.com/cloudskiff/driftctl/pkg/iac/config"
"github.com/cloudskiff/driftctl/pkg/iac/terraform/state/backend"
"github.com/cloudskiff/driftctl/pkg/output"
"github.com/cloudskiff/driftctl/pkg/terraform"
"github.com/cloudskiff/driftctl/test/resource"
"github.com/stretchr/testify/assert"
)

func TestGetIACSupplier(t *testing.T) {
Expand Down Expand Up @@ -83,11 +85,15 @@ func TestGetIACSupplier(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
progress := &output.MockProgress{}
progress.On("Start").Return().Times(1)

repo := resource.InitFakeSchemaRepository("aws", "3.19.0")
_, err := GetIACSupplier(tt.args.config, terraform.NewProviderLibrary(), tt.args.options, repo)
if tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
t.Errorf("GetIACSupplier() error = %v, wantErr %v", err, tt.wantErr)
return
_, err := GetIACSupplier(tt.args.config, terraform.NewProviderLibrary(), tt.args.options, repo, progress)
if tt.wantErr != nil {
assert.EqualError(t, err, tt.wantErr.Error())
} else {
assert.NoError(t, err)
}
})
}
Expand Down
1 change: 0 additions & 1 deletion pkg/iac/terraform/state/backend/s3_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ type S3Backend struct {
}

func NewS3Reader(path string) (*S3Backend, error) {

backend := S3Backend{}
bucketPath := strings.Split(path, "/")
if len(bucketPath) < 2 {
Expand Down
8 changes: 6 additions & 2 deletions pkg/iac/terraform/state/terraform_state_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/cloudskiff/driftctl/pkg/iac/config"
"github.com/cloudskiff/driftctl/pkg/iac/terraform/state/backend"
"github.com/cloudskiff/driftctl/pkg/iac/terraform/state/enumerator"
"github.com/cloudskiff/driftctl/pkg/output"
"github.com/cloudskiff/driftctl/pkg/remote/deserializer"
"github.com/cloudskiff/driftctl/pkg/resource"

Expand All @@ -30,15 +31,16 @@ type TerraformStateReader struct {
deserializers []deserializer.CTYDeserializer
backendOptions *backend.Options
resourceSchemaRepository resource.SchemaRepositoryInterface
progress output.Progress
}

func (r *TerraformStateReader) initReader() error {
r.enumerator = enumerator.GetEnumerator(r.config)
return nil
}

func NewReader(config config.SupplierConfig, library *terraform.ProviderLibrary, backendOpts *backend.Options, resourceSchemaRepository resource.SchemaRepositoryInterface) (*TerraformStateReader, error) {
reader := TerraformStateReader{library: library, config: config, deserializers: iac.Deserializers(), backendOptions: backendOpts, resourceSchemaRepository: resourceSchemaRepository}
func NewReader(config config.SupplierConfig, library *terraform.ProviderLibrary, backendOpts *backend.Options, resourceSchemaRepository resource.SchemaRepositoryInterface, progress output.Progress) (*TerraformStateReader, error) {
reader := TerraformStateReader{library: library, config: config, deserializers: iac.Deserializers(), backendOptions: backendOpts, resourceSchemaRepository: resourceSchemaRepository, progress: progress}
err := reader.initReader()
if err != nil {
return nil, err
Expand Down Expand Up @@ -213,6 +215,7 @@ func (r *TerraformStateReader) decode(values map[string][]cty.Value) ([]resource
}

func (r *TerraformStateReader) Resources() ([]resource.Resource, error) {
defer r.progress.Stop()

if r.enumerator == nil {
return r.retrieveForState(r.config.Path)
Expand All @@ -227,6 +230,7 @@ func (r *TerraformStateReader) retrieveForState(path string) ([]resource.Resourc
"path": r.config.Path,
"backend": r.config.Backend,
}).Debug("Reading resources from state")
r.progress.Inc()
values, err := r.retrieve()
if err != nil {
return nil, err
Expand Down
14 changes: 10 additions & 4 deletions pkg/iac/terraform/state/terraform_state_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,16 @@ func TestTerraformStateReader_AWS_Resources(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
progress := &output.MockProgress{}
progress.On("Inc").Return().Times(1)
progress.On("Stop").Return().Times(1)

shouldUpdate := tt.dirName == *goldenfile.Update

var realProvider *aws.AWSTerraformProvider

if shouldUpdate {
var err error
progress := &output.MockProgress{}
progress.On("Inc").Return()
realProvider, err = aws.NewAWSTerraformProvider(progress)
if err != nil {
t.Fatal(err)
Expand All @@ -125,6 +127,7 @@ func TestTerraformStateReader_AWS_Resources(t *testing.T) {
library: library,
deserializers: iac.Deserializers(),
resourceSchemaRepository: repo,
progress: progress,
}

got, err := r.Resources()
Expand Down Expand Up @@ -174,14 +177,16 @@ func TestTerraformStateReader_Github_Resources(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
progress := &output.MockProgress{}
progress.On("Inc").Return().Times(1)
progress.On("Stop").Return().Times(1)

shouldUpdate := tt.dirName == *goldenfile.Update

var realProvider *github.GithubTerraformProvider

if shouldUpdate {
var err error
progress := &output.MockProgress{}
progress.On("Inc").Return()
realProvider, err = github.NewGithubTerraformProvider(progress)
if err != nil {
t.Fatal(err)
Expand All @@ -202,6 +207,7 @@ func TestTerraformStateReader_Github_Resources(t *testing.T) {
},
library: library,
deserializers: iac.Deserializers(),
progress: progress,
}

got, err := r.Resources()
Expand Down
5 changes: 5 additions & 0 deletions pkg/output/mock_Progress.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions pkg/output/printer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ func Printf(format string, args ...interface{}) {
globalPrinter.Printf(format, args...)
}

func Flush() {
globalPrinter.Printf(" \r")
}

type Printer interface {
Printf(format string, args ...interface{})
}
Expand Down
Loading

0 comments on commit 03833f4

Please sign in to comment.