Skip to content

Commit

Permalink
Done with the refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
fredjeck committed Apr 9, 2024
1 parent f2a7b67 commit c2432e6
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 271 deletions.
80 changes: 43 additions & 37 deletions config/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,48 +12,50 @@ import (
"gopkg.in/yaml.v3"
)

type HttpMethod string
// HTTPMethod is a wrapper aroun http.Method
type HTTPMethod string

var (
HttpMethodGet HttpMethod = "GET"
HttpMethodHead HttpMethod = "HEAD"
HttpMethodPost HttpMethod = "POST"
HttpMethodPut HttpMethod = "PUT"
HttpMethodDelete HttpMethod = "DELETE"
HttpMethodConnect HttpMethod = "CONNECT"
HttpMethodOptions HttpMethod = "OPTIONS"
HttpMethodTrace HttpMethod = "TRACE"
HttpMethodPatch HttpMethod = "PATCH"
HttpMethodAll HttpMethod = "ALL"
HttpMethodUnknown HttpMethod = "UNKNOWN"
HTTPMethodGet HTTPMethod = "GET" // HTTPMethodGet HTTP GET
HTTPMethodHead HTTPMethod = "HEAD" // HTTPMethodHead HTTP HEAD
HTTPMethodPost HTTPMethod = "POST" // HTTPMethodPost HTTP POST
HTTPMethodPut HTTPMethod = "PUT" // HTTPMethodPut HTTP PUT
HTTPMethodDelete HTTPMethod = "DELETE" // HTTPMethodDelete HTTP DELETE
HTTPMethodConnect HTTPMethod = "CONNECT" // HTTPMethodConnect HTTP CONNECT
HTTPMethodOptions HTTPMethod = "OPTIONS" // HTTPMethodOptions HTTP OPTIONS
HTTPMethodTrace HTTPMethod = "TRACE" // HTTPMethodTrace HTTP TRACE
HTTPMethodPatch HTTPMethod = "PATCH" // HTTPMethodPatch HTTP PATCH
HTTPMethodAll HTTPMethod = "ALL" // HTTPMethodAll HTTP All
HTTPMethodUnknown HTTPMethod = "UNKNOWN" // HTTPMethodUnknown when parsing fails
)

func ParseHttpMethod(method string) HttpMethod {
// ParseHTTPMethod translates the provided string to an HTTPMethod and makes sure the method is supported
func ParseHTTPMethod(method string) HTTPMethod {
switch strings.ToUpper(strings.TrimSpace(method)) {
case "GET":
return HttpMethodGet
return HTTPMethodGet
case "HEAD":
return HttpMethodHead
return HTTPMethodHead
case "POST":
return HttpMethodPost
return HTTPMethodPost
case "PUT":
return HttpMethodPut
return HTTPMethodPut
case "DELETE":
return HttpMethodDelete
return HTTPMethodDelete
case "CONNECT":
return HttpMethodConnect
return HTTPMethodConnect
case "OPTIONS":
return HttpMethodOptions
return HTTPMethodOptions
case "TRACE":
return HttpMethodTrace
return HTTPMethodTrace
case "PATCH":
return HttpMethodPatch
return HTTPMethodPatch
case "ALL":
return HttpMethodAll
return HTTPMethodAll
case "UNKNOWN":
return HttpMethodUnknown
return HTTPMethodUnknown
default:
return HttpMethodUnknown
return HTTPMethodUnknown
}
}

Expand All @@ -66,12 +68,13 @@ const (
type Authorization struct {
ClientID string
Allow bool
Endpoints map[HttpMethod][]*regexp.Regexp
Endpoints map[HTTPMethod][]*regexp.Regexp
}

// NewAuthorization creates a new authorization
func NewAuthorization() *Authorization {
return &Authorization{
Endpoints: make(map[HttpMethod][]*regexp.Regexp),
Endpoints: make(map[HTTPMethod][]*regexp.Regexp),
}
}

Expand All @@ -82,6 +85,7 @@ var (
ErrInvalidMode = errors.New("mode is mandatory and should either be 'allow' or 'reject'")
)

// NewAuthorizationFromYaml Geneates a new authorization configration from the provided yaml content
func NewAuthorizationFromYaml(contents []byte) (*Authorization, error) {
auth := NewAuthorization()

Expand Down Expand Up @@ -113,7 +117,7 @@ func NewAuthorizationFromYaml(contents []byte) (*Authorization, error) {
for _, v := range paths.([]interface{}) {
switch v.(type) {
case string:
if err := auth.AppendPath(v.(string), ""); err != nil {
if err := auth.ConfigurePath(v.(string), ""); err != nil {
slog.Warn("incompatible path detected", slog.Any("error", err))
continue
}
Expand All @@ -132,7 +136,7 @@ func NewAuthorizationFromYaml(contents []byte) (*Authorization, error) {
methods = m.(string)
}

if err := auth.AppendPath(path, methods); err != nil {
if err := auth.ConfigurePath(path, methods); err != nil {
slog.Warn("incompatible path detected", slog.Any("error", err))
continue
}
Expand All @@ -155,8 +159,8 @@ func NewAuthorizationFromYaml(contents []byte) (*Authorization, error) {
return auth, nil
}

// IsPathAuthorized returns true if the provided path access should be granted
func (auth *Authorization) IsAllowed(path string, method HttpMethod) bool {
// IsAllowed returns true if the provided path access should be granted
func (auth *Authorization) IsAllowed(path string, method HTTPMethod) bool {

endpoints, ok := auth.Endpoints[method]
if ok {
Expand All @@ -167,7 +171,7 @@ func (auth *Authorization) IsAllowed(path string, method HttpMethod) bool {
}
}

endpoints, ok = auth.Endpoints[HttpMethodAll]
endpoints, ok = auth.Endpoints[HTTPMethodAll]
if !ok {
return !auth.Allow
}
Expand All @@ -180,17 +184,18 @@ func (auth *Authorization) IsAllowed(path string, method HttpMethod) bool {
return !auth.Allow
}

func (auth *Authorization) AppendPath(path string, methods string) error {
supportedMethods := make([]HttpMethod, 0)
// ConfigurePath configures the provided path for the given methods
func (auth *Authorization) ConfigurePath(path string, methods string) error {
supportedMethods := make([]HTTPMethod, 0)
lowercased := strings.ToLower(methods)

if len(methods) == 0 || strings.Contains(lowercased, "all") {
// If the user specifies all, we avoid injecting other method types
supportedMethods = append(supportedMethods, HttpMethodAll)
supportedMethods = append(supportedMethods, HTTPMethodAll)
} else {
for _, m := range strings.Split(lowercased, ",") {
method := ParseHttpMethod(m)
if method == HttpMethodUnknown {
method := ParseHTTPMethod(m)
if method == HTTPMethodUnknown {
slog.Warn(fmt.Sprintf("http method '%s' is not a supported method and will be ignored for clientID '%s'", method, auth.ClientID))
continue
}
Expand All @@ -213,6 +218,7 @@ func (auth *Authorization) AppendPath(path string, methods string) error {
return nil
}

// LoadAllAuthorizations loads all the client authorization yaml files from the provided directory
func LoadAllAuthorizations(dir string) (map[string]*Authorization, error) {

fileInfo, err := os.Stat(dir)
Expand Down
60 changes: 30 additions & 30 deletions config/authorization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,59 @@ import (
func TestAppendSinglePath(t *testing.T) {
auth := NewAuthorization()

auth.AppendPath("/Pokemon", "")
auth.ConfigurePath("/Pokemon", "")

endpoints, ok := auth.Endpoints[HttpMethodAll]
endpoints, ok := auth.Endpoints[HTTPMethodAll]
assert.True(t, ok)
assert.Len(t, endpoints, 1)
}

func TestAppendMultiplePaths(t *testing.T) {
auth := NewAuthorization()

auth.AppendPath("/Pokemon", "")
auth.AppendPath("/Pokemon/Ditto", "")
auth.AppendPath("/Pokemon/Pikachu", "")
auth.ConfigurePath("/Pokemon", "")
auth.ConfigurePath("/Pokemon/Ditto", "")
auth.ConfigurePath("/Pokemon/Pikachu", "")

endpoints, ok := auth.Endpoints[HttpMethodAll]
endpoints, ok := auth.Endpoints[HTTPMethodAll]
assert.True(t, ok)
assert.Len(t, endpoints, 3)
}

func TestAppendMultipleMethods(t *testing.T) {
auth := NewAuthorization()

auth.AppendPath("/Pokemon", "get, post")
auth.AppendPath("/Pokemon/Ditto", "get, post, options")
auth.AppendPath("/Pokemon/Pikachu", "post")
auth.ConfigurePath("/Pokemon", "get, post")
auth.ConfigurePath("/Pokemon/Ditto", "get, post, options")
auth.ConfigurePath("/Pokemon/Pikachu", "post")

endpoints, ok := auth.Endpoints[HttpMethodGet]
endpoints, ok := auth.Endpoints[HTTPMethodGet]
assert.True(t, ok)
assert.Len(t, endpoints, 2)

endpoints, ok = auth.Endpoints[HttpMethodPost]
endpoints, ok = auth.Endpoints[HTTPMethodPost]
assert.True(t, ok)
assert.Len(t, endpoints, 3)

endpoints, ok = auth.Endpoints[HttpMethodOptions]
endpoints, ok = auth.Endpoints[HTTPMethodOptions]
assert.True(t, ok)
assert.Len(t, endpoints, 1)
}

func TestAppendInvalidMethods(t *testing.T) {
auth := NewAuthorization()

auth.AppendPath("/Pokemon", "notknown, notvalid")
auth.AppendPath("/Pokemon/Ditto", "woopsie")
auth.ConfigurePath("/Pokemon", "notknown, notvalid")
auth.ConfigurePath("/Pokemon/Ditto", "woopsie")

assert.Len(t, auth.Endpoints, 0)
}

func TestAppendInvalidEnpoints(t *testing.T) {
auth := NewAuthorization()

auth.AppendPath("[\\]", "get")
auth.AppendPath("[ab", "put")
auth.ConfigurePath("[\\]", "get")
auth.ConfigurePath("[ab", "put")

assert.Len(t, auth.Endpoints, 0)
}
Expand All @@ -86,19 +86,19 @@ paths:
auth, err := NewAuthorizationFromYaml([]byte(yml))
assert.NoError(t, err)

endpoints, ok := auth.Endpoints[HttpMethodGet]
endpoints, ok := auth.Endpoints[HTTPMethodGet]
assert.True(t, ok)
assert.Len(t, endpoints, 1)

endpoints, ok = auth.Endpoints[HttpMethodPost]
endpoints, ok = auth.Endpoints[HTTPMethodPost]
assert.True(t, ok)
assert.Len(t, endpoints, 3)

endpoints, ok = auth.Endpoints[HttpMethodDelete]
endpoints, ok = auth.Endpoints[HTTPMethodDelete]
assert.True(t, ok)
assert.Len(t, endpoints, 1)

endpoints, ok = auth.Endpoints[HttpMethodAll]
endpoints, ok = auth.Endpoints[HTTPMethodAll]
assert.True(t, ok)
assert.Len(t, endpoints, 2)

Expand Down Expand Up @@ -216,10 +216,10 @@ paths:

auth, err := NewAuthorizationFromYaml([]byte(yml))
assert.NoError(t, err)
assert.True(t, auth.IsAllowed("/api/pokemon/ditto", HttpMethodGet))
assert.False(t, auth.IsAllowed("/api/encounter", HttpMethodGet))
assert.True(t, auth.IsAllowed("/api/encounter", HttpMethodPut))
assert.True(t, auth.IsAllowed("/api/pokemon/pikachu", HttpMethodPut))
assert.True(t, auth.IsAllowed("/api/pokemon/ditto", HTTPMethodGet))
assert.False(t, auth.IsAllowed("/api/encounter", HTTPMethodGet))
assert.True(t, auth.IsAllowed("/api/encounter", HTTPMethodPut))
assert.True(t, auth.IsAllowed("/api/pokemon/pikachu", HTTPMethodPut))
}

func TestPathIsDisallowed(t *testing.T) {
Expand All @@ -234,13 +234,13 @@ paths:

auth, err := NewAuthorizationFromYaml([]byte(yml))
assert.NoError(t, err)
assert.True(t, auth.IsAllowed("/api/pokemon/ditto", HttpMethodGet))
assert.False(t, auth.IsAllowed("/api/encounter", HttpMethodPut))
assert.True(t, auth.IsAllowed("/api/encounter", HttpMethodGet))
assert.False(t, auth.IsAllowed("/api/pokemon/pikachu", HttpMethodPut))
assert.True(t, auth.IsAllowed("/api/pokemon/ditto", HTTPMethodGet))
assert.False(t, auth.IsAllowed("/api/encounter", HTTPMethodPut))
assert.True(t, auth.IsAllowed("/api/encounter", HTTPMethodGet))
assert.False(t, auth.IsAllowed("/api/pokemon/pikachu", HTTPMethodPut))
}

func TestHttpMethodParsing(t *testing.T) {
func TestHTTPMethodParsing(t *testing.T) {
yml := `
clientID: client
mode: allow
Expand All @@ -254,7 +254,7 @@ paths:
assert.Len(t, auth.Endpoints, 9)
}

func TestHttpMethodOptimization(t *testing.T) {
func TestHTTPMethodOptimization(t *testing.T) {
yml := `
clientID: client
mode: allow
Expand Down
7 changes: 5 additions & 2 deletions logging/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
KeyAllow = "request.allow" // KeyAllow is the logging key for the request outcome
KeyClientID = "request.client.id" // KeyClientID is the logging key for the header identifier value
KeyProtocol = "request.protocol" // KeyProtocol is the logging key for the GRPC protocol version
KeyReason = "reason" // KeyReason is the logging key for the deny reason
)

// Setup configures the logging environment
Expand All @@ -44,6 +45,7 @@ type Context struct {
RequestContext interface{}
}

// AuthV3LoggingContext creates a logging context from an AuthV3 CheckRequest
func AuthV3LoggingContext(request *authv3.CheckRequest) *Context {
httpAttrs := request.GetAttributes().GetRequest().GetHttp()
return &Context{
Expand All @@ -57,10 +59,11 @@ func AuthV3LoggingContext(request *authv3.CheckRequest) *Context {
}
}

// AuthV2LoggingContext creates a logging context from an AuthV2 CheckRequest
func AuthV2LoggingContext(request *authv2.CheckRequest) *Context {
httpAttrs := request.GetAttributes().GetRequest().GetHttp()
return &Context{
Protocol: "V3",
Protocol: "V2",
Host: httpAttrs.Host,
Path: httpAttrs.Path,
Method: httpAttrs.Method,
Expand All @@ -76,11 +79,11 @@ func LogRequest(allow bool, reason string, context *Context) {
msg := fmt.Sprintf("%s %s %s for '%s'", context.Method, context.Path, outcome, context.ClientID)
if !allow {
outcome = "denied"
msg = fmt.Sprintf("%s %s %s for '%s' (reason: %s)", context.Method, context.Path, outcome, context.ClientID, reason)
}

slog.Info(msg,
slog.Bool(KeyAllow, allow),
slog.String(KeyReason, reason),
slog.String(KeyHost, context.Host),
slog.String(KeyPath, context.Path),
slog.String(KeyMethod, context.Method),
Expand Down
6 changes: 4 additions & 2 deletions server/grpcv2.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (s *GRPCAuthzServerV2) deny(request *authv2.CheckRequest, reason string) *a
// Check implements gRPC v2 check request.
func (s *GRPCAuthzServerV2) Check(_ context.Context, request *authv2.CheckRequest) (*authv2.CheckResponse, error) {
attrs := request.GetAttributes()
method := config.HttpMethod(attrs.Request.Http.Method)
method := config.HTTPMethod(attrs.Request.Http.Method)
// Determine whether to allow or deny the request.
clientID, headerExists := attrs.GetRequest().GetHttp().GetHeaders()[s.AuthzHeader]

Expand All @@ -94,7 +94,9 @@ func (s *GRPCAuthzServerV2) Check(_ context.Context, request *authv2.CheckReques
reason = fmt.Sprintf("missing authz configuration header %s", s.AuthzHeader)
}

logging.LogRequest(allowed, reason, logging.AuthV2LoggingContext(request))
ctx := logging.AuthV2LoggingContext(request)
ctx.ClientID = clientID
logging.LogRequest(allowed, reason, ctx)
if allowed {
return s.allow(request), nil
}
Expand Down
6 changes: 4 additions & 2 deletions server/grpcv3.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (s *GRPCAuthzServerV3) deny(request *authv3.CheckRequest, reason string) *a
// Check implements gRPC v3 check request.
func (s *GRPCAuthzServerV3) Check(_ context.Context, request *authv3.CheckRequest) (*authv3.CheckResponse, error) {
attrs := request.GetAttributes()
method := config.HttpMethod(attrs.Request.Http.Method)
method := config.HTTPMethod(attrs.Request.Http.Method)
// Determine whether to allow or deny the request.
clientID, headerExists := attrs.GetRequest().GetHttp().GetHeaders()[s.AuthzHeader]

Expand All @@ -96,7 +96,9 @@ func (s *GRPCAuthzServerV3) Check(_ context.Context, request *authv3.CheckReques
reason = fmt.Sprintf("missing authz configuration header %s", s.AuthzHeader)
}

logging.LogRequest(allowed, reason, logging.AuthV3LoggingContext(request))
ctx := logging.AuthV3LoggingContext(request)
ctx.ClientID = clientID
logging.LogRequest(allowed, reason, ctx)
if allowed {
return s.allow(request), nil
}
Expand Down
Loading

0 comments on commit c2432e6

Please sign in to comment.