Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add reporeader interface and an example extractor for python #215

Merged
merged 8 commits into from
Jun 16, 2023
Merged
14 changes: 13 additions & 1 deletion cmd/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"golang.org/x/exp/maps"
"gopkg.in/yaml.v3"

"github.com/Azure/draft/pkg/reporeader"
"github.com/Azure/draft/pkg/reporeader/readers"
"github.com/manifoldco/promptui"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -56,6 +58,7 @@ type createCmd struct {

templateWriter templatewriter.TemplateWriter
templateVariableRecorder config.TemplateVariableRecorder
repoReader reporeader.RepoReader
}

func newCreateCmd() *cobra.Command {
Expand Down Expand Up @@ -130,6 +133,7 @@ func (cc *createCmd) run() error {
} else {
cc.templateWriter = &writers.LocalFSWriter{}
}
cc.repoReader = &readers.LocalFSReader{}
davidgamero marked this conversation as resolved.
Show resolved Hide resolved

detectedLangDraftConfig, languageName, err := cc.detectLanguage()
if err != nil {
Expand Down Expand Up @@ -253,8 +257,16 @@ func (cc *createCmd) generateDockerfile(langConfig *config.DraftConfig, lowerLan
return errors.New("supported languages were loaded incorrectly")
}

// Extract language-specific defaults from repo
extractedDefaults, err := cc.supportedLangs.ExtractDefaults(lowerLang, cc.repoReader)
if err != nil {
return err
}
for _, d := range extractedDefaults {
langConfig.VariableDefaults = append(langConfig.VariableDefaults, d)
}

var inputs map[string]string
var err error
if cc.createConfig.LanguageVariables == nil {
inputs, err = prompts.RunPromptsFromConfigWithSkips(langConfig, maps.Keys(flagVariablesMap))
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions pkg/filematches/filematches.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package filematches

import (
"errors"
"io/ioutil"
"log"
"os"
"path/filepath"
Expand Down Expand Up @@ -44,7 +43,7 @@ func (f *FileMatches) walkFunc(path string, info os.FileInfo, err error) error {

// TODO: maybe generalize this function in the future
func isValidK8sFile(filePath string) bool {
fileContents, err := ioutil.ReadFile(filePath)
fileContents, err := os.ReadFile(filePath)
if err != nil {
log.Fatal(err)
}
Expand Down
32 changes: 32 additions & 0 deletions pkg/languages/defaults/python.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package defaults

import (
"fmt"

"github.com/Azure/draft/pkg/reporeader"
)

type PythonExtractor struct {
}

// ReadDefaults reads the default values for the language from the repo files
func (p PythonExtractor) ReadDefaults(r reporeader.RepoReader) (map[string]string, error) {
extractedValues := make(map[string]string)
files, err := r.FindFiles(".", []string{"*.py"}, 0)
if err != nil {
return nil, fmt.Errorf("error finding python files: %v", err)
}
if len(files) > 0 {
extractedValues["ENTRYPOINT"] = files[0]
davidgamero marked this conversation as resolved.
Show resolved Hide resolved
}

return extractedValues, nil
}

func (p PythonExtractor) MatchesLanguage(lowerlang string) bool {
return lowerlang == "python"
}

func (p PythonExtractor) GetName() string { return "python" }

var _ reporeader.VariableExtractor = &PythonExtractor{}
117 changes: 117 additions & 0 deletions pkg/languages/defaults/python_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package defaults

import (
"reflect"
"testing"

"github.com/Azure/draft/pkg/reporeader"
)

func TestPythonExtractor_MatchesLanguage(t *testing.T) {
type args struct {
lowerlang string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "lowercase python",
args: args{
lowerlang: "python",
},
want: true,
},
{
name: "shouldn't match go",
args: args{
lowerlang: "go",
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := PythonExtractor{}
if got := p.MatchesLanguage(tt.args.lowerlang); got != tt.want {
t.Errorf("MatchesLanguage() = %v, want %v", got, tt.want)
}
})
}
}

func TestPythonExtractor_ReadDefaults(t *testing.T) {
type args struct {
r reporeader.RepoReader
}
tests := []struct {
name string
args args
want map[string]string
wantErr bool
}{
{
name: "extract first python file as entrypoint",
args: args{
r: reporeader.TestRepoReader{
Files: map[string][]byte{
"foo.py": []byte("print('hello world')"),
"bar.py": []byte("print('hello world')"),
},
},
},
want: map[string]string{
"ENTRYPOINT": "foo.py",
},
wantErr: false,
}, {
name: "no extraction if no python files",
args: args{
r: reporeader.TestRepoReader{
Files: map[string][]byte{
"foo.notpy": []byte("print('hello world')"),
"bar": []byte("print('hello world')"),
},
},
},
want: map[string]string{},
wantErr: false,
},
{
name: "empty extraction with no files",
args: args{
r: reporeader.TestRepoReader{
Files: map[string][]byte{},
},
},
want: map[string]string{},
wantErr: false,
},
{
name: "ignore files below depth root depth",
args: args{
r: reporeader.TestRepoReader{
Files: map[string][]byte{
"dir/foo.py": []byte("print('hello world')"),
},
},
},
want: map[string]string{},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := PythonExtractor{}
got, err := p.ReadDefaults(tt.args.r)
if (err != nil) != tt.wantErr {
t.Errorf("ReadDefaults() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ReadDefaults() got = %v, want %v", got, tt.want)
}
})
}
}
41 changes: 40 additions & 1 deletion pkg/languages/languages.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ import (
"io/fs"
"path"

log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"gopkg.in/yaml.v3"

"github.com/Azure/draft/pkg/languages/defaults"
"github.com/Azure/draft/pkg/reporeader"
log "github.com/sirupsen/logrus"

"github.com/Azure/draft/pkg/config"
"github.com/Azure/draft/pkg/embedutils"
"github.com/Azure/draft/pkg/osutil"
Expand Down Expand Up @@ -113,3 +116,39 @@ func CreateLanguagesFromEmbedFS(dockerfileTemplates embed.FS, dest string) *Lang

return l
}

func (l *Languages) ExtractDefaults(lowerLang string, r reporeader.RepoReader) ([]config.BuilderVarDefault, error) {
extractors := []reporeader.VariableExtractor{
&defaults.PythonExtractor{},
}
extractedValues := make(map[string]string)
var extractedDefaults []config.BuilderVarDefault
if r == nil {
log.Debugf("no repo reader provided, returning empty list of defaults")
return extractedDefaults, nil
}
for _, extractor := range extractors {
if extractor.MatchesLanguage(lowerLang) {
newDefaults, err := extractor.ReadDefaults(r)
if err != nil {
return nil, fmt.Errorf("error reading defaults for language %s: %v", lowerLang, err)
}
for k, v := range newDefaults {
if _, ok := extractedValues[k]; ok {
log.Debugf("duplicate default %s for language %s with extractor %s", k, lowerLang, extractor.GetName())
}
extractedValues[k] = v
log.Debugf("extracted default %s=%s with extractor:%s", k, v, extractor.GetName())
}
}
}

for k, v := range extractedValues {
extractedDefaults = append(extractedDefaults, config.BuilderVarDefault{
Name: k,
Value: v,
})
}

return extractedDefaults, nil
}
70 changes: 70 additions & 0 deletions pkg/reporeader/readers/localfsreader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package readers

import (
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"

"github.com/Azure/draft/pkg/reporeader"
)

type LocalFSReader struct {
}

type LocalFileFinder struct {
Patterns []string
FoundFiles []string
MaxDepth int
}

func (l *LocalFileFinder) walkFunc(path string, info os.DirEntry, err error) error {
if err != nil {
return err
}

// Skip directories that are too deep
if info.IsDir() && strings.Count(path, string(os.PathSeparator)) > l.MaxDepth {
fmt.Println("skip", path)
return fs.SkipDir
}

if info.IsDir() {
return nil
}

for _, pattern := range l.Patterns {
if matched, err := filepath.Match(pattern, filepath.Base(path)); err != nil {
return err
} else if matched {
l.FoundFiles = append(l.FoundFiles, path)
}
}
return nil
}

func (r *LocalFSReader) FindFiles(path string, patterns []string, maxDepth int) ([]string, error) {
l := LocalFileFinder{
Patterns: patterns,
MaxDepth: maxDepth,
}
err := filepath.WalkDir(path, l.walkFunc)
if err != nil {
return nil, err
}
return l.FoundFiles, nil
}

var _ reporeader.RepoReader = &LocalFSReader{}

func (r *LocalFSReader) Exists(path string) bool {
if _, err := os.Stat(path); !os.IsNotExist(err) {
return true
}
return false
}

func (r *LocalFSReader) ReadFile(path string) ([]byte, error) {
return os.ReadFile(path)
}
62 changes: 62 additions & 0 deletions pkg/reporeader/reporeader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package reporeader

import (
"path/filepath"
"strings"
)

type RepoReader interface {
Exists(path string) bool
ReadFile(path string) ([]byte, error)
// FindFiles returns a list of files that match the given patterns searching up to
// maxDepth nested sub-directories. maxDepth of 0 limits files to the root dir.
FindFiles(path string, patterns []string, maxDepth int) ([]string, error)
}

// VariableExtractor is an interface that can be implemented for extracting variables from a repo's files
type VariableExtractor interface {
ReadDefaults(r RepoReader) (map[string]string, error)
MatchesLanguage(lowerlang string) bool
GetName() string
}

// TestRepoReader is a RepoReader that can be used for testing, and takes a list of relative file paths with their contents
type TestRepoReader struct {
Files map[string][]byte
}

func (r TestRepoReader) Exists(path string) bool {
if r.Files != nil {
_, ok := r.Files[path]
return ok
}
return false
}

func (r TestRepoReader) ReadFile(path string) ([]byte, error) {
if r.Files != nil {
return r.Files[path], nil
}
return nil, nil
}

func (r TestRepoReader) FindFiles(path string, patterns []string, maxDepth int) ([]string, error) {
var files []string
if r.Files == nil {
return files, nil
}
for k := range r.Files {
for _, pattern := range patterns {
if matched, err := filepath.Match(pattern, filepath.Base(k)); err != nil {
return nil, err
} else if matched {
splitPath := strings.Split(k, string(filepath.Separator))
fileDepth := len(splitPath) - 1
if fileDepth <= maxDepth {
files = append(files, k)
}
}
}
}
return files, nil
}