diff --git a/config/authorization.go b/config/authorization.go index 8e313d6..470c3fc 100644 --- a/config/authorization.go +++ b/config/authorization.go @@ -12,48 +12,50 @@ import ( "gopkg.in/yaml.v3" ) -type HttpMethod string +// HTTPMethod is a wrapper aroun http.Method +type HTTPMethod string var ( - HttpMethodGet HttpMethod = "GET" - HttpMethodHead HttpMethod = "HEAD" - HttpMethodPost HttpMethod = "POST" - HttpMethodPut HttpMethod = "PUT" - HttpMethodDelete HttpMethod = "DELETE" - HttpMethodConnect HttpMethod = "CONNECT" - HttpMethodOptions HttpMethod = "OPTIONS" - HttpMethodTrace HttpMethod = "TRACE" - HttpMethodPatch HttpMethod = "PATCH" - HttpMethodAll HttpMethod = "ALL" - HttpMethodUnknown HttpMethod = "UNKNOWN" + HTTPMethodGet HTTPMethod = "GET" // HTTPMethodGet HTTP GET + HTTPMethodHead HTTPMethod = "HEAD" // HTTPMethodHead HTTP HEAD + HTTPMethodPost HTTPMethod = "POST" // HTTPMethodPost HTTP POST + HTTPMethodPut HTTPMethod = "PUT" // HTTPMethodPut HTTP PUT + HTTPMethodDelete HTTPMethod = "DELETE" // HTTPMethodDelete HTTP DELETE + HTTPMethodConnect HTTPMethod = "CONNECT" // HTTPMethodConnect HTTP CONNECT + HTTPMethodOptions HTTPMethod = "OPTIONS" // HTTPMethodOptions HTTP OPTIONS + HTTPMethodTrace HTTPMethod = "TRACE" // HTTPMethodTrace HTTP TRACE + HTTPMethodPatch HTTPMethod = "PATCH" // HTTPMethodPatch HTTP PATCH + HTTPMethodAll HTTPMethod = "ALL" // HTTPMethodAll HTTP All + HTTPMethodUnknown HTTPMethod = "UNKNOWN" // HTTPMethodUnknown when parsing fails ) -func ParseHttpMethod(method string) HttpMethod { +// ParseHTTPMethod translates the provided string to an HTTPMethod and makes sure the method is supported +func ParseHTTPMethod(method string) HTTPMethod { switch strings.ToUpper(strings.TrimSpace(method)) { case "GET": - return HttpMethodGet + return HTTPMethodGet case "HEAD": - return HttpMethodHead + return HTTPMethodHead case "POST": - return HttpMethodPost + return HTTPMethodPost case "PUT": - return HttpMethodPut + return HTTPMethodPut case "DELETE": - return HttpMethodDelete + return HTTPMethodDelete case "CONNECT": - return HttpMethodConnect + return HTTPMethodConnect case "OPTIONS": - return HttpMethodOptions + return HTTPMethodOptions case "TRACE": - return HttpMethodTrace + return HTTPMethodTrace case "PATCH": - return HttpMethodPatch + return HTTPMethodPatch case "ALL": - return HttpMethodAll + return HTTPMethodAll case "UNKNOWN": - return HttpMethodUnknown + return HTTPMethodUnknown default: - return HttpMethodUnknown + return HTTPMethodUnknown } } @@ -66,12 +68,13 @@ const ( type Authorization struct { ClientID string Allow bool - Endpoints map[HttpMethod][]*regexp.Regexp + Endpoints map[HTTPMethod][]*regexp.Regexp } +// NewAuthorization creates a new authorization func NewAuthorization() *Authorization { return &Authorization{ - Endpoints: make(map[HttpMethod][]*regexp.Regexp), + Endpoints: make(map[HTTPMethod][]*regexp.Regexp), } } @@ -82,6 +85,7 @@ var ( ErrInvalidMode = errors.New("mode is mandatory and should either be 'allow' or 'reject'") ) +// NewAuthorizationFromYaml Geneates a new authorization configration from the provided yaml content func NewAuthorizationFromYaml(contents []byte) (*Authorization, error) { auth := NewAuthorization() @@ -113,7 +117,7 @@ func NewAuthorizationFromYaml(contents []byte) (*Authorization, error) { for _, v := range paths.([]interface{}) { switch v.(type) { case string: - if err := auth.AppendPath(v.(string), ""); err != nil { + if err := auth.ConfigurePath(v.(string), ""); err != nil { slog.Warn("incompatible path detected", slog.Any("error", err)) continue } @@ -132,7 +136,7 @@ func NewAuthorizationFromYaml(contents []byte) (*Authorization, error) { methods = m.(string) } - if err := auth.AppendPath(path, methods); err != nil { + if err := auth.ConfigurePath(path, methods); err != nil { slog.Warn("incompatible path detected", slog.Any("error", err)) continue } @@ -155,8 +159,8 @@ func NewAuthorizationFromYaml(contents []byte) (*Authorization, error) { return auth, nil } -// IsPathAuthorized returns true if the provided path access should be granted -func (auth *Authorization) IsAllowed(path string, method HttpMethod) bool { +// IsAllowed returns true if the provided path access should be granted +func (auth *Authorization) IsAllowed(path string, method HTTPMethod) bool { endpoints, ok := auth.Endpoints[method] if ok { @@ -167,7 +171,7 @@ func (auth *Authorization) IsAllowed(path string, method HttpMethod) bool { } } - endpoints, ok = auth.Endpoints[HttpMethodAll] + endpoints, ok = auth.Endpoints[HTTPMethodAll] if !ok { return !auth.Allow } @@ -180,17 +184,18 @@ func (auth *Authorization) IsAllowed(path string, method HttpMethod) bool { return !auth.Allow } -func (auth *Authorization) AppendPath(path string, methods string) error { - supportedMethods := make([]HttpMethod, 0) +// ConfigurePath configures the provided path for the given methods +func (auth *Authorization) ConfigurePath(path string, methods string) error { + supportedMethods := make([]HTTPMethod, 0) lowercased := strings.ToLower(methods) if len(methods) == 0 || strings.Contains(lowercased, "all") { // If the user specifies all, we avoid injecting other method types - supportedMethods = append(supportedMethods, HttpMethodAll) + supportedMethods = append(supportedMethods, HTTPMethodAll) } else { for _, m := range strings.Split(lowercased, ",") { - method := ParseHttpMethod(m) - if method == HttpMethodUnknown { + method := ParseHTTPMethod(m) + if method == HTTPMethodUnknown { slog.Warn(fmt.Sprintf("http method '%s' is not a supported method and will be ignored for clientID '%s'", method, auth.ClientID)) continue } @@ -213,6 +218,7 @@ func (auth *Authorization) AppendPath(path string, methods string) error { return nil } +// LoadAllAuthorizations loads all the client authorization yaml files from the provided directory func LoadAllAuthorizations(dir string) (map[string]*Authorization, error) { fileInfo, err := os.Stat(dir) diff --git a/config/authorization_test.go b/config/authorization_test.go index b1ca730..6e5bd57 100644 --- a/config/authorization_test.go +++ b/config/authorization_test.go @@ -9,9 +9,9 @@ import ( func TestAppendSinglePath(t *testing.T) { auth := NewAuthorization() - auth.AppendPath("/Pokemon", "") + auth.ConfigurePath("/Pokemon", "") - endpoints, ok := auth.Endpoints[HttpMethodAll] + endpoints, ok := auth.Endpoints[HTTPMethodAll] assert.True(t, ok) assert.Len(t, endpoints, 1) } @@ -19,11 +19,11 @@ func TestAppendSinglePath(t *testing.T) { func TestAppendMultiplePaths(t *testing.T) { auth := NewAuthorization() - auth.AppendPath("/Pokemon", "") - auth.AppendPath("/Pokemon/Ditto", "") - auth.AppendPath("/Pokemon/Pikachu", "") + auth.ConfigurePath("/Pokemon", "") + auth.ConfigurePath("/Pokemon/Ditto", "") + auth.ConfigurePath("/Pokemon/Pikachu", "") - endpoints, ok := auth.Endpoints[HttpMethodAll] + endpoints, ok := auth.Endpoints[HTTPMethodAll] assert.True(t, ok) assert.Len(t, endpoints, 3) } @@ -31,19 +31,19 @@ func TestAppendMultiplePaths(t *testing.T) { func TestAppendMultipleMethods(t *testing.T) { auth := NewAuthorization() - auth.AppendPath("/Pokemon", "get, post") - auth.AppendPath("/Pokemon/Ditto", "get, post, options") - auth.AppendPath("/Pokemon/Pikachu", "post") + auth.ConfigurePath("/Pokemon", "get, post") + auth.ConfigurePath("/Pokemon/Ditto", "get, post, options") + auth.ConfigurePath("/Pokemon/Pikachu", "post") - endpoints, ok := auth.Endpoints[HttpMethodGet] + endpoints, ok := auth.Endpoints[HTTPMethodGet] assert.True(t, ok) assert.Len(t, endpoints, 2) - endpoints, ok = auth.Endpoints[HttpMethodPost] + endpoints, ok = auth.Endpoints[HTTPMethodPost] assert.True(t, ok) assert.Len(t, endpoints, 3) - endpoints, ok = auth.Endpoints[HttpMethodOptions] + endpoints, ok = auth.Endpoints[HTTPMethodOptions] assert.True(t, ok) assert.Len(t, endpoints, 1) } @@ -51,8 +51,8 @@ func TestAppendMultipleMethods(t *testing.T) { func TestAppendInvalidMethods(t *testing.T) { auth := NewAuthorization() - auth.AppendPath("/Pokemon", "notknown, notvalid") - auth.AppendPath("/Pokemon/Ditto", "woopsie") + auth.ConfigurePath("/Pokemon", "notknown, notvalid") + auth.ConfigurePath("/Pokemon/Ditto", "woopsie") assert.Len(t, auth.Endpoints, 0) } @@ -60,8 +60,8 @@ func TestAppendInvalidMethods(t *testing.T) { func TestAppendInvalidEnpoints(t *testing.T) { auth := NewAuthorization() - auth.AppendPath("[\\]", "get") - auth.AppendPath("[ab", "put") + auth.ConfigurePath("[\\]", "get") + auth.ConfigurePath("[ab", "put") assert.Len(t, auth.Endpoints, 0) } @@ -86,19 +86,19 @@ paths: auth, err := NewAuthorizationFromYaml([]byte(yml)) assert.NoError(t, err) - endpoints, ok := auth.Endpoints[HttpMethodGet] + endpoints, ok := auth.Endpoints[HTTPMethodGet] assert.True(t, ok) assert.Len(t, endpoints, 1) - endpoints, ok = auth.Endpoints[HttpMethodPost] + endpoints, ok = auth.Endpoints[HTTPMethodPost] assert.True(t, ok) assert.Len(t, endpoints, 3) - endpoints, ok = auth.Endpoints[HttpMethodDelete] + endpoints, ok = auth.Endpoints[HTTPMethodDelete] assert.True(t, ok) assert.Len(t, endpoints, 1) - endpoints, ok = auth.Endpoints[HttpMethodAll] + endpoints, ok = auth.Endpoints[HTTPMethodAll] assert.True(t, ok) assert.Len(t, endpoints, 2) @@ -216,10 +216,10 @@ paths: auth, err := NewAuthorizationFromYaml([]byte(yml)) assert.NoError(t, err) - assert.True(t, auth.IsAllowed("/api/pokemon/ditto", HttpMethodGet)) - assert.False(t, auth.IsAllowed("/api/encounter", HttpMethodGet)) - assert.True(t, auth.IsAllowed("/api/encounter", HttpMethodPut)) - assert.True(t, auth.IsAllowed("/api/pokemon/pikachu", HttpMethodPut)) + assert.True(t, auth.IsAllowed("/api/pokemon/ditto", HTTPMethodGet)) + assert.False(t, auth.IsAllowed("/api/encounter", HTTPMethodGet)) + assert.True(t, auth.IsAllowed("/api/encounter", HTTPMethodPut)) + assert.True(t, auth.IsAllowed("/api/pokemon/pikachu", HTTPMethodPut)) } func TestPathIsDisallowed(t *testing.T) { @@ -234,13 +234,13 @@ paths: auth, err := NewAuthorizationFromYaml([]byte(yml)) assert.NoError(t, err) - assert.True(t, auth.IsAllowed("/api/pokemon/ditto", HttpMethodGet)) - assert.False(t, auth.IsAllowed("/api/encounter", HttpMethodPut)) - assert.True(t, auth.IsAllowed("/api/encounter", HttpMethodGet)) - assert.False(t, auth.IsAllowed("/api/pokemon/pikachu", HttpMethodPut)) + assert.True(t, auth.IsAllowed("/api/pokemon/ditto", HTTPMethodGet)) + assert.False(t, auth.IsAllowed("/api/encounter", HTTPMethodPut)) + assert.True(t, auth.IsAllowed("/api/encounter", HTTPMethodGet)) + assert.False(t, auth.IsAllowed("/api/pokemon/pikachu", HTTPMethodPut)) } -func TestHttpMethodParsing(t *testing.T) { +func TestHTTPMethodParsing(t *testing.T) { yml := ` clientID: client mode: allow @@ -254,7 +254,7 @@ paths: assert.Len(t, auth.Endpoints, 9) } -func TestHttpMethodOptimization(t *testing.T) { +func TestHTTPMethodOptimization(t *testing.T) { yml := ` clientID: client mode: allow diff --git a/logging/logging.go b/logging/logging.go index 5fb3375..2cfa867 100644 --- a/logging/logging.go +++ b/logging/logging.go @@ -20,6 +20,7 @@ const ( KeyAllow = "request.allow" // KeyAllow is the logging key for the request outcome KeyClientID = "request.client.id" // KeyClientID is the logging key for the header identifier value KeyProtocol = "request.protocol" // KeyProtocol is the logging key for the GRPC protocol version + KeyReason = "reason" // KeyReason is the logging key for the deny reason ) // Setup configures the logging environment @@ -44,6 +45,7 @@ type Context struct { RequestContext interface{} } +// AuthV3LoggingContext creates a logging context from an AuthV3 CheckRequest func AuthV3LoggingContext(request *authv3.CheckRequest) *Context { httpAttrs := request.GetAttributes().GetRequest().GetHttp() return &Context{ @@ -57,10 +59,11 @@ func AuthV3LoggingContext(request *authv3.CheckRequest) *Context { } } +// AuthV2LoggingContext creates a logging context from an AuthV2 CheckRequest func AuthV2LoggingContext(request *authv2.CheckRequest) *Context { httpAttrs := request.GetAttributes().GetRequest().GetHttp() return &Context{ - Protocol: "V3", + Protocol: "V2", Host: httpAttrs.Host, Path: httpAttrs.Path, Method: httpAttrs.Method, @@ -76,11 +79,11 @@ func LogRequest(allow bool, reason string, context *Context) { msg := fmt.Sprintf("%s %s %s for '%s'", context.Method, context.Path, outcome, context.ClientID) if !allow { outcome = "denied" - msg = fmt.Sprintf("%s %s %s for '%s' (reason: %s)", context.Method, context.Path, outcome, context.ClientID, reason) } slog.Info(msg, slog.Bool(KeyAllow, allow), + slog.String(KeyReason, reason), slog.String(KeyHost, context.Host), slog.String(KeyPath, context.Path), slog.String(KeyMethod, context.Method), diff --git a/server/grpcv2.go b/server/grpcv2.go index a5317f6..e0ca418 100644 --- a/server/grpcv2.go +++ b/server/grpcv2.go @@ -72,7 +72,7 @@ func (s *GRPCAuthzServerV2) deny(request *authv2.CheckRequest, reason string) *a // Check implements gRPC v2 check request. func (s *GRPCAuthzServerV2) Check(_ context.Context, request *authv2.CheckRequest) (*authv2.CheckResponse, error) { attrs := request.GetAttributes() - method := config.HttpMethod(attrs.Request.Http.Method) + method := config.HTTPMethod(attrs.Request.Http.Method) // Determine whether to allow or deny the request. clientID, headerExists := attrs.GetRequest().GetHttp().GetHeaders()[s.AuthzHeader] @@ -94,7 +94,9 @@ func (s *GRPCAuthzServerV2) Check(_ context.Context, request *authv2.CheckReques reason = fmt.Sprintf("missing authz configuration header %s", s.AuthzHeader) } - logging.LogRequest(allowed, reason, logging.AuthV2LoggingContext(request)) + ctx := logging.AuthV2LoggingContext(request) + ctx.ClientID = clientID + logging.LogRequest(allowed, reason, ctx) if allowed { return s.allow(request), nil } diff --git a/server/grpcv3.go b/server/grpcv3.go index 14ac842..6bc2c4c 100644 --- a/server/grpcv3.go +++ b/server/grpcv3.go @@ -74,7 +74,7 @@ func (s *GRPCAuthzServerV3) deny(request *authv3.CheckRequest, reason string) *a // Check implements gRPC v3 check request. func (s *GRPCAuthzServerV3) Check(_ context.Context, request *authv3.CheckRequest) (*authv3.CheckResponse, error) { attrs := request.GetAttributes() - method := config.HttpMethod(attrs.Request.Http.Method) + method := config.HTTPMethod(attrs.Request.Http.Method) // Determine whether to allow or deny the request. clientID, headerExists := attrs.GetRequest().GetHttp().GetHeaders()[s.AuthzHeader] @@ -96,7 +96,9 @@ func (s *GRPCAuthzServerV3) Check(_ context.Context, request *authv3.CheckReques reason = fmt.Sprintf("missing authz configuration header %s", s.AuthzHeader) } - logging.LogRequest(allowed, reason, logging.AuthV3LoggingContext(request)) + ctx := logging.AuthV3LoggingContext(request) + ctx.ClientID = clientID + logging.LogRequest(allowed, reason, ctx) if allowed { return s.allow(request), nil } diff --git a/server/http.go b/server/http.go index 273d405..8369abb 100644 --- a/server/http.go +++ b/server/http.go @@ -12,6 +12,8 @@ import ( ) // HTTPAuthzServer implements an Envoy custom HTTP authorization filter +// Please note that HTTP authorization server has been disabled - old code can be found in tag 0.0.1-alpha +// HTTP Server is only kept for health check purposes type HTTPAuthzServer struct { httpServer *http.Server configuration *config.Configuration @@ -47,7 +49,6 @@ func (srv *HTTPAuthzServer) Start(wg *sync.WaitGroup, healthFunc func() (bool, s mux := http.NewServeMux() mux.HandleFunc("/healtz", handleHealth(healthFunc)) - mux.HandleFunc("/", handleCheck(srv.configuration.HTTPAuthZHeader, srv.configuration.Authorizations)) srv.httpServer = &http.Server{Handler: mux} srv.ready <- true // for testing @@ -79,39 +80,3 @@ func handleHealth(healthFunc func() (bool, string)) func(w http.ResponseWriter, response.Write([]byte(desc)) } } - -// Handles authorization requests -func handleCheck(authzHeader string, authorization map[string]*config.Authorization) func(w http.ResponseWriter, r *http.Request) { - return func(response http.ResponseWriter, request *http.Request) { - // body, err := io.ReadAll(request.Body) - // if err != nil { - // log.Printf("[HTTP] read body failed: %v", err) - // } - // l := fmt.Sprintf("%s %s%s, headers: %v, body: [%s]\n", request.Method, request.Host, request.URL, request.Header, truncate(string(body))) - - auth, ok := authorization[request.Header.Get(authzHeader)] - if !ok || !auth.IsAllowed(request.URL.String(), config.HttpMethod(request.Method)) { - deny(response) - - } else { - allow(response) - } - } -} - -func allow(response http.ResponseWriter) { - //log.Printf("[HTTP][allowed]: %s", l) - response.Header().Set(resultHeader, resultAllowed) - // response.Header().Set(overrideHeader, request.Header.Get(overrideHeader)) - // response.Header().Set(receivedHeader, l) - response.WriteHeader(http.StatusOK) -} - -func deny(response http.ResponseWriter) { - //log.Printf("[HTTP][allowed]: %s", l) - response.Header().Set(resultHeader, resultDenied) - // response.Header().Set(overrideHeader, request.Header.Get(overrideHeader)) - // response.Header().Set(receivedHeader, l) - response.WriteHeader(http.StatusForbidden) - //_, _ = response.Write([]byte(denyBody)) -} diff --git a/server/server_test.go b/server/server_test.go index a5b076d..0671f99 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -39,13 +39,22 @@ paths: methods: GET, PUT ` -var testCases = []struct { +const clientB = ` +clientID: clientB +mode: deny +paths: + - path: /pokemon/.*? +` + +type testCase struct { name string url string method string clientID string want int -}{ +} + +var testCases = []testCase{ { name: "Allow GET", url: "/pokemon/pikachu", @@ -53,130 +62,59 @@ var testCases = []struct { method: http.MethodGet, want: int(codes.OK), }, + { + name: "Allow PUT", + url: "/pokemon/tortank", + clientID: "clientA", + method: http.MethodPut, + want: int(codes.OK), + }, { name: "Deny DELETE", - url: "/pokemon/pikachu", + url: "/pokemon/ditto", clientID: "clientA", method: http.MethodDelete, want: int(codes.PermissionDenied), }, { - name: "Deny client", - url: "/pokemon/pikachu", - clientID: "clientB", - method: http.MethodGet, + name: "Deny URL", + url: "/berries", + clientID: "clientA", + method: http.MethodDelete, want: int(codes.PermissionDenied), }, -} - -var cases = []struct { - name string - isGRPCV3 bool - isGRPCV2 bool - header string - want int -}{ - { - name: "HTTP-allow", - header: "allow", - want: http.StatusOK, - }, - { - name: "HTTP-deny", - header: "deny", - want: http.StatusForbidden, - }, - { - name: "GRPCv3-allow", - isGRPCV3: true, - header: "allow", - want: int(codes.OK), - }, { - name: "GRPCv3-deny", - isGRPCV3: true, - header: "deny", + name: "Deny URL", + url: "/pokemon/pikachu", + clientID: "clientB", + method: http.MethodGet, want: int(codes.PermissionDenied), }, { - name: "GRPCv2-allow", - isGRPCV2: true, - header: "allow", + name: "Allow URL", + url: "/encounters", + clientID: "clientB", + method: http.MethodGet, want: int(codes.OK), }, { - name: "GRPCv2-deny", - isGRPCV2: true, - header: "deny", + name: "Deny Client", + url: "/gyms", + clientID: "clientC", + method: http.MethodGet, want: int(codes.PermissionDenied), }, } -func grpcV3Request(grpcV3Client authv3.AuthorizationClient, header string) (*authv3.CheckResponse, error) { - return grpcV3Client.Check(context.Background(), &authv3.CheckRequest{ - Attributes: &authv3.AttributeContext{ - Request: &authv3.AttributeContext_Request{ - Http: &authv3.AttributeContext_HttpRequest{ - Host: "localhost", - Path: "/check", - Headers: map[string]string{checkHeader: header}, - }, - }, - }, - }) -} - -func grpcV3PathRequest(grpcV3Client authv3.AuthorizationClient, clientID string, path string, method string) (*authv3.CheckResponse, error) { - return grpcV3Client.Check(context.Background(), &authv3.CheckRequest{ - Attributes: &authv3.AttributeContext{ - Request: &authv3.AttributeContext_Request{ - Http: &authv3.AttributeContext_HttpRequest{ - Host: "localhost", - Path: path, - Method: method, - Headers: map[string]string{checkHeader: clientID}, - }, - }, - }, - }) -} - -func grpcV2PathRequest(grpcV2Client authv2.AuthorizationClient, clientID string, path string, method string) (*authv2.CheckResponse, error) { - return grpcV2Client.Check(context.Background(), &authv2.CheckRequest{ - Attributes: &authv2.AttributeContext{ - Request: &authv2.AttributeContext_Request{ - Http: &authv2.AttributeContext_HttpRequest{ - Host: "localhost", - Path: path, - Method: method, - Headers: map[string]string{checkHeader: clientID}, - }, - }, - }, - }) -} - -func grpcV2Request(grpcV2Client authv2.AuthorizationClient, header string) (*authv2.CheckResponse, error) { - return grpcV2Client.Check(context.Background(), &authv2.CheckRequest{ - Attributes: &authv2.AttributeContext{ - Request: &authv2.AttributeContext_Request{ - Http: &authv2.AttributeContext_HttpRequest{ - Host: "localhost", - Path: "/check", - Headers: map[string]string{checkHeader: header}, - }, - }, - }, - }) -} - func TestExtAuthz(t *testing.T) { logging.Setup() authz := make(map[string]*config.Authorization) - ca, err := config.NewAuthorizationFromYaml([]byte(clientA)) - authz["clientA"] = ca + client, err := config.NewAuthorizationFromYaml([]byte(clientA)) + authz[client.ClientID] = client + client, err = config.NewAuthorizationFromYaml([]byte(clientB)) + authz[client.ClientID] = client server := NewJarlAuthzServer(&config.Configuration{ HTTPListenOn: "localhost:0", @@ -187,12 +125,8 @@ func TestExtAuthz(t *testing.T) { // Start the test server on random port. go server.Start() - // Prepare the HTTP request. + // Wait until HTTP Server is ready _ = <-server.httpServer.ready - httpReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost:%d/check", server.httpServer.port), nil) - if err != nil { - t.Fatalf(err.Error()) - } // Prepare the gRPC request. _ = <-server.grpcServer.ready @@ -204,68 +138,62 @@ func TestExtAuthz(t *testing.T) { grpcV3Client := authv3.NewAuthorizationClient(conn) grpcV2Client := authv2.NewAuthorizationClient(conn) - runExtendedTestCases(t, grpcV2Client, grpcV3Client, httpReq) + runTestCases(t, grpcV2Client, grpcV3Client) } -func runExtendedTestCases(t *testing.T, grpcV2Client authv2.AuthorizationClient, grpcV3Client authv3.AuthorizationClient, httpReq *http.Request) { - var got int +func runTestCases(t *testing.T, grpcV2Client authv2.AuthorizationClient, grpcV3Client authv3.AuthorizationClient) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - resp, err := grpcV3PathRequest(grpcV3Client, tc.clientID, tc.url, tc.method) - if err != nil { - t.Errorf(err.Error()) - } else { - got = int(resp.Status.Code) - } - if got != tc.want { - t.Errorf("'%s' want %d but got %d", tc.name, tc.want, got) - } - - respv2, err := grpcV2PathRequest(grpcV2Client, tc.clientID, tc.url, tc.method) - if err != nil { - t.Errorf(err.Error()) - } else { - got = int(respv2.Status.Code) - } - if got != tc.want { - t.Errorf("'%s' want %d but got %d", tc.name, tc.want, got) - } + runGrpcV2Request(t, tc, grpcV2Client) + runGrpcV3Request(t, tc, grpcV3Client) }) } } -func runTestCases(t *testing.T, grpcV2Client authv2.AuthorizationClient, grpcV3Client authv3.AuthorizationClient, httpReq *http.Request) { - httpClient := &http.Client{} - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - var got int - if tc.isGRPCV3 { - resp, err := grpcV3Request(grpcV3Client, tc.header) - if err != nil { - t.Errorf(err.Error()) - } else { - got = int(resp.Status.Code) - } - } else if tc.isGRPCV2 { - resp, err := grpcV2Request(grpcV2Client, tc.header) - if err != nil { - t.Errorf(err.Error()) - } else { - got = int(resp.Status.Code) - } - } else { - httpReq.Header.Set(checkHeader, tc.header) - resp, err := httpClient.Do(httpReq) - if err != nil { - t.Errorf(err.Error()) - } else { - got = resp.StatusCode - resp.Body.Close() - } - } - if got != tc.want { - t.Errorf("want %d but got %d", tc.want, got) - } - }) +func runGrpcV3Request(t *testing.T, tc testCase, grpcV3Client authv3.AuthorizationClient) { + resp, err := grpcV3Client.Check(context.Background(), &authv3.CheckRequest{ + Attributes: &authv3.AttributeContext{ + Request: &authv3.AttributeContext_Request{ + Http: &authv3.AttributeContext_HttpRequest{ + Host: "localhost", + Path: tc.url, + Method: tc.method, + Headers: map[string]string{checkHeader: tc.clientID}, + }, + }, + }, + }) + + if err != nil { + t.Errorf(err.Error()) + } + + if int(resp.Status.Code) != tc.want { + t.Errorf("'%s' want %d but got %d", tc.name, tc.want, int(resp.Status.Code)) + } + return +} + +func runGrpcV2Request(t *testing.T, tc testCase, grpcV2Client authv2.AuthorizationClient) { + resp, err := grpcV2Client.Check(context.Background(), &authv2.CheckRequest{ + Attributes: &authv2.AttributeContext{ + Request: &authv2.AttributeContext_Request{ + Http: &authv2.AttributeContext_HttpRequest{ + Host: "localhost", + Path: tc.url, + Method: tc.method, + Headers: map[string]string{checkHeader: tc.clientID}, + }, + }, + }, + }) + + if err != nil { + t.Errorf(err.Error()) + } + + if int(resp.Status.Code) != tc.want { + t.Errorf("'%s' want %d but got %d", tc.name, tc.want, int(resp.Status.Code)) } + return }