From 3cab285f11b6cafced19dd42756dca821a89dda7 Mon Sep 17 00:00:00 2001 From: Jihui Nie <56172920+jihuin@users.noreply.github.com> Date: Wed, 3 Aug 2022 19:23:18 -0700 Subject: [PATCH] fix: Allow registering multiple functions with one server for local testing. (#143) * Allow registering multiple functions with one server for local testing. * Allow registering multiple functions with one server for local testing. * Allow registering multiple functions with one server for local testing. * Let RegisterXXXFunctionContext call registry.Default().RegisterXXX * fix some nits --- funcframework/framework.go | 121 +++++++++++---------- funcframework/framework_test.go | 166 ++++++++++++++++++++++------- internal/registry/registry.go | 64 +++++++++-- internal/registry/registry_test.go | 151 +++++++++++++++++++++----- 4 files changed, 382 insertions(+), 120 deletions(-) diff --git a/funcframework/framework.go b/funcframework/framework.go index ea483f6..2959365 100644 --- a/funcframework/framework.go +++ b/funcframework/framework.go @@ -39,10 +39,6 @@ const ( fnErrorMessageStderrTmpl = "Function error: %v" ) -var ( - handler http.Handler -) - // recoverPanic recovers from a panic in a consistent manner. panicSrc should // describe what was happening when the panic was encountered, for example // "user function execution". w is an http.ResponseWriter to write a generic @@ -86,72 +82,92 @@ func RegisterEventFunction(path string, fn interface{}) { // RegisterHTTPFunctionContext registers fn as an HTTP function. func RegisterHTTPFunctionContext(ctx context.Context, path string, fn func(http.ResponseWriter, *http.Request)) error { - server, err := wrapHTTPFunction(path, fn) - if err == nil { - handler = server - } - return err + funcName := fmt.Sprintf("function_at_path_%q", path) + return registry.Default().RegisterHTTP(funcName, fn, registry.WithPath(path)) } // RegisterEventFunctionContext registers fn as an event function. The function must have two arguments, a // context.Context and a struct type depending on the event, and return an error. If fn has the // wrong signature, RegisterEventFunction returns an error. func RegisterEventFunctionContext(ctx context.Context, path string, fn interface{}) error { - server, err := wrapEventFunction(path, fn) - if err == nil { - handler = server - } - return err + funcName := fmt.Sprintf("function_at_path_%q", path) + return registry.Default().RegisterEvent(funcName, fn, registry.WithPath(path)) } // RegisterCloudEventFunctionContext registers fn as an cloudevent function. func RegisterCloudEventFunctionContext(ctx context.Context, path string, fn func(context.Context, cloudevents.Event) error) error { - server, err := wrapCloudEventFunction(ctx, path, fn) - if err == nil { - handler = server - } - return err + funcName := fmt.Sprintf("function_at_path_%q", path) + return registry.Default().RegisterCloudEvent(funcName, fn, registry.WithPath(path)) } // Start serves an HTTP server with registered function(s). func Start(port string) error { - // If FUNCTION_TARGET, try to start with that registered function - // If not set, assume non-declarative functions. - target := os.Getenv("FUNCTION_TARGET") + server, err := initServer() + if err != nil { + return err + } + return http.ListenAndServe(":"+port, server) +} - // Check if we have a function resource set, and if so, log progress. - if os.Getenv("K_SERVICE") == "" { - fmt.Printf("Serving function: %s\n", target) +func initServer() (*http.ServeMux, error) { + server := http.NewServeMux() + + // If FUNCTION_TARGET is set, only serve this target function at path "/". + // If not set, serve all functions at the registered paths. + if target := os.Getenv("FUNCTION_TARGET"); len(target) > 0 { + fn, ok := registry.Default().GetRegisteredFunction(target) + if !ok { + return nil, fmt.Errorf("no matching function found with name: %q", target) + } + h, err := wrapFunction(fn) + if err != nil { + return nil, fmt.Errorf("failed to serve function %q: %v", target, err) + } + server.Handle("/", h) + return server, nil } - // Check if there's a registered function, and use if possible - if fn, ok := registry.Default().GetRegisteredFunction(target); ok { - ctx := context.Background() - if fn.HTTPFn != nil { - server, err := wrapHTTPFunction("/", fn.HTTPFn) - if err != nil { - return fmt.Errorf("unexpected error in registerHTTPFunction: %v", err) - } - handler = server - } else if fn.CloudEventFn != nil { - server, err := wrapCloudEventFunction(ctx, "/", fn.CloudEventFn) - if err != nil { - return fmt.Errorf("unexpected error in registerCloudEventFunction: %v", err) - } - handler = server + fns := registry.Default().GetAllFunctions() + for funcName, fn := range fns { + h, err := wrapFunction(fn) + if err != nil { + return nil, fmt.Errorf("failed to serve function %q: %v", funcName, err) } + server.Handle(fn.Path, h) } + return server, nil +} - if handler == nil { - return fmt.Errorf("no matching function found with name: %q", target) +func wrapFunction(fn registry.RegisteredFunction) (http.Handler, error) { + // Check if we have a function resource set, and if so, log progress. + if os.Getenv("K_SERVICE") == "" { + fmt.Printf("Serving function %s\n", fn.Name) } - return http.ListenAndServe(":"+port, handler) + if fn.HTTPFn != nil { + handler, err := wrapHTTPFunction(fn.HTTPFn) + if err != nil { + return nil, fmt.Errorf("unexpected error in wrapHTTPFunction: %v", err) + } + return handler, nil + } else if fn.CloudEventFn != nil { + handler, err := wrapCloudEventFunction(context.Background(), fn.CloudEventFn) + if err != nil { + return nil, fmt.Errorf("unexpected error in wrapCloudEventFunction: %v", err) + } + return handler, nil + } else if fn.EventFn != nil { + handler, err := wrapEventFunction(fn.EventFn) + if err != nil { + return nil, fmt.Errorf("unexpected error in wrapEventFunction: %v", err) + } + return handler, nil + } + return nil, fmt.Errorf("missing function entry in %v", fn) } -func wrapHTTPFunction(path string, fn func(http.ResponseWriter, *http.Request)) (http.Handler, error) { - h := http.NewServeMux() - h.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { +func wrapHTTPFunction(fn func(http.ResponseWriter, *http.Request)) (http.Handler, error) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // TODO(b/111823046): Remove following once Cloud Functions does not need flushing the logs anymore. if os.Getenv("K_SERVICE") != "" { // Force flush of logs after every function trigger when running on GCF. @@ -160,17 +176,15 @@ func wrapHTTPFunction(path string, fn func(http.ResponseWriter, *http.Request)) } defer recoverPanic(w, "user function execution") fn(w, r) - }) - return h, nil + }), nil } -func wrapEventFunction(path string, fn interface{}) (http.Handler, error) { - h := http.NewServeMux() +func wrapEventFunction(fn interface{}) (http.Handler, error) { err := validateEventFunction(fn) if err != nil { return nil, err } - h.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if os.Getenv("K_SERVICE") != "" { // Force flush of logs after every function trigger when running on GCF. defer fmt.Println() @@ -184,11 +198,10 @@ func wrapEventFunction(path string, fn interface{}) (http.Handler, error) { } handleEventFunction(w, r, fn) - }) - return h, nil + }), nil } -func wrapCloudEventFunction(ctx context.Context, path string, fn func(context.Context, cloudevents.Event) error) (http.Handler, error) { +func wrapCloudEventFunction(ctx context.Context, fn func(context.Context, cloudevents.Event) error) (http.Handler, error) { p, err := cloudevents.NewHTTP() if err != nil { return nil, fmt.Errorf("failed to create protocol: %v", err) diff --git a/funcframework/framework_test.go b/funcframework/framework_test.go index 243d950..4072b3e 100644 --- a/funcframework/framework_test.go +++ b/funcframework/framework_test.go @@ -32,15 +32,17 @@ import ( "github.com/google/go-cmp/cmp" ) -func TestHTTPFunction(t *testing.T) { +func TestRegisterHTTPFunctionContext(t *testing.T) { tests := []struct { name string + path string fn func(w http.ResponseWriter, r *http.Request) wantStatus int // defaults to http.StatusOK wantResp string }{ { name: "helloworld", + path: "/TestRegisterHTTPFunctionContext_helloworld", fn: func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "Hello World!") }, @@ -48,6 +50,7 @@ func TestHTTPFunction(t *testing.T) { }, { name: "panic in function", + path: "/TestRegisterHTTPFunctionContext_panic", fn: func(w http.ResponseWriter, r *http.Request) { panic("intentional panic for test") }, @@ -58,16 +61,18 @@ func TestHTTPFunction(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - h, err := wrapHTTPFunction("/", tc.fn) - defer func() { handler = nil }() - if err != nil { - t.Fatalf("registerHTTPFunction(): %v", err) + if err := RegisterHTTPFunctionContext(context.Background(), tc.path, tc.fn); err != nil { + t.Fatalf("RegisterHTTPFunctionContext(): %v", err) } - srv := httptest.NewServer(h) + server, err := initServer() + if err != nil { + t.Fatalf("initServer(): %v", err) + } + srv := httptest.NewServer(server) defer srv.Close() - resp, err := http.Get(srv.URL) + resp, err := http.Get(srv.URL + tc.path) if err != nil { t.Fatalf("http.Get: %v", err) } @@ -76,7 +81,7 @@ func TestHTTPFunction(t *testing.T) { tc.wantStatus = http.StatusOK } if resp.StatusCode != tc.wantStatus { - t.Errorf("TestHTTPFunction status code: got %d, want: %d", resp.StatusCode, tc.wantStatus) + t.Errorf("unexpected status code: got %d, want: %d", resp.StatusCode, tc.wantStatus) } defer resp.Body.Close() @@ -101,9 +106,10 @@ type eventData struct { Data string `json:"data"` } -func TestEventFunction(t *testing.T) { +func TestRegisterEventFunctionContext(t *testing.T) { var tests = []struct { name string + path string body []byte fn interface{} status int @@ -114,6 +120,7 @@ func TestEventFunction(t *testing.T) { }{ { name: "valid function", + path: "/TestRegisterEventFunctionContext_valid", body: []byte(`{"id": 12345,"name": "custom"}`), fn: func(c context.Context, s customStruct) error { if s.ID != 12345 { @@ -129,6 +136,7 @@ func TestEventFunction(t *testing.T) { }, { name: "incorrect type", + path: "/TestRegisterEventFunctionContext_incorrect", body: []byte(`{"id": 12345,"name": 123}`), fn: func(c context.Context, s customStruct) error { return nil @@ -138,6 +146,7 @@ func TestEventFunction(t *testing.T) { }, { name: "erroring function", + path: "/TestRegisterEventFunctionContext_erroring", body: []byte(`{"id": 12345,"name": "custom"}`), fn: func(c context.Context, s customStruct) error { return fmt.Errorf("TestEventFunction(erroring function): this error should fire") @@ -149,6 +158,7 @@ func TestEventFunction(t *testing.T) { }, { name: "panicking function", + path: "/TestRegisterEventFunctionContext_panicking", body: []byte(`{"id": 12345,"name": "custom"}`), fn: func(c context.Context, s customStruct) error { panic("intential panic for test") @@ -159,6 +169,7 @@ func TestEventFunction(t *testing.T) { }, { name: "pubsub event", + path: "/TestRegisterEventFunctionContext_pubsub1", body: []byte(`{ "context": { "eventId": "1234567", @@ -187,6 +198,7 @@ func TestEventFunction(t *testing.T) { }, { name: "pubsub legacy event", + path: "/TestRegisterEventFunctionContext_pubsub2", body: []byte(`{ "eventId": "1234567", "timestamp": "2019-11-04T23:01:10.112Z", @@ -213,6 +225,7 @@ func TestEventFunction(t *testing.T) { }, { name: "cloudevent", + path: "/TestRegisterEventFunctionContext_cloudevent", body: []byte(`{ "data": { "bucket": "some-bucket", @@ -283,9 +296,8 @@ func TestEventFunction(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - h, err := wrapEventFunction("/", tc.fn) - if err != nil { - t.Fatalf("registerEventFunction(): %v", err) + if err := RegisterEventFunctionContext(context.Background(), tc.path, tc.fn); err != nil { + t.Fatalf("RegisterEventFunctionContext(): %v", err) } // Capture stderr for the duration of the test case. This includes @@ -295,10 +307,14 @@ func TestEventFunction(t *testing.T) { os.Stderr = w defer func() { os.Stderr = origStderrPipe }() - srv := httptest.NewServer(h) + server, err := initServer() + if err != nil { + t.Fatalf("initServer(): %v", err) + } + srv := httptest.NewServer(server) defer srv.Close() - req, err := http.NewRequest("POST", srv.URL, bytes.NewBuffer(tc.body)) + req, err := http.NewRequest("POST", srv.URL+tc.path, bytes.NewBuffer(tc.body)) if err != nil { t.Fatalf("error creating HTTP request for test: %v", err) } @@ -350,7 +366,7 @@ func TestEventFunction(t *testing.T) { } } -func TestCloudEventFunction(t *testing.T) { +func TestRegisterCloudEventFunctionContext(t *testing.T) { cloudeventsJSON := []byte(`{ "specversion" : "1.0", "type" : "com.github.pull.create", @@ -370,6 +386,7 @@ func TestCloudEventFunction(t *testing.T) { var tests = []struct { name string + path string body []byte fn func(context.Context, cloudevents.Event) error status int @@ -380,6 +397,7 @@ func TestCloudEventFunction(t *testing.T) { }{ { name: "binary cloudevent", + path: "/TestRegisterCloudEventFunctionContext_binary", body: []byte(""), fn: func(ctx context.Context, e cloudevents.Event) error { if e.String() != testCE.String() { @@ -402,6 +420,7 @@ func TestCloudEventFunction(t *testing.T) { }, { name: "structured cloudevent", + path: "/TestRegisterCloudEventFunctionContext_structured", body: cloudeventsJSON, fn: func(ctx context.Context, e cloudevents.Event) error { if e.String() != testCE.String() { @@ -417,6 +436,7 @@ func TestCloudEventFunction(t *testing.T) { }, { name: "background event", + path: "/TestRegisterCloudEventFunctionContext_background", body: []byte(`{ "context": { "eventId": "aaaaaa-1111-bbbb-2222-cccccccccccc", @@ -492,6 +512,7 @@ func TestCloudEventFunction(t *testing.T) { }, { name: "panic returns 500", + path: "/TestRegisterCloudEventFunctionContext_panic", body: cloudeventsJSON, fn: func(ctx context.Context, e cloudevents.Event) error { panic("intentional panic for test") @@ -503,6 +524,7 @@ func TestCloudEventFunction(t *testing.T) { }, { name: "error returns 500", + path: "/TestRegisterCloudEventFunctionContext_error", body: cloudeventsJSON, fn: func(ctx context.Context, e cloudevents.Event) error { return fmt.Errorf("error for test") @@ -518,11 +540,8 @@ func TestCloudEventFunction(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - h, err := wrapCloudEventFunction(ctx, "/", tc.fn) - defer func() { handler = nil }() - if err != nil { - t.Fatalf("registerCloudEventFunction(): %v", err) + if err := RegisterCloudEventFunctionContext(context.Background(), tc.path, tc.fn); err != nil { + t.Fatalf("RegisterCloudEventFunctionContext(): %v", err) } // Capture stderr for the duration of the test case. This includes @@ -532,10 +551,14 @@ func TestCloudEventFunction(t *testing.T) { os.Stderr = w defer func() { os.Stderr = origStderrPipe }() - srv := httptest.NewServer(h) + server, err := initServer() + if err != nil { + t.Fatalf("initServer(): %v", err) + } + srv := httptest.NewServer(server) defer srv.Close() - req, err := http.NewRequest("POST", srv.URL, bytes.NewBuffer(tc.body)) + req, err := http.NewRequest("POST", srv.URL+tc.path, bytes.NewBuffer(tc.body)) if err != nil { t.Fatalf("error creating HTTP request for test: %v", err) } @@ -586,49 +609,67 @@ func TestCloudEventFunction(t *testing.T) { func TestDeclarativeFunctionHTTP(t *testing.T) { funcName := "httpfunc" + funcResp := "Hello World!" os.Setenv("FUNCTION_TARGET", funcName) + defer os.Unsetenv("FUNCTION_TARGET") + // Verify RegisterHTTPFunctionContext and functions.HTTP don't conflict. if err := RegisterHTTPFunctionContext(context.Background(), "/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "Hello World!") }); err != nil { - t.Fatalf("registerHTTPFunction(): %v", err) + t.Fatalf("RegisterHTTPFunctionContext(): %v", err) } - defer func() { handler = nil }() - // register functions + defer registry.Default().DeleteRegisteredFunction("function_at_path_\"/\"") + // Register functions. functions.HTTP(funcName, func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "Hello World!") + fmt.Fprint(w, funcResp) }) - if _, ok := registry.Default().GetRegisteredFunction(funcName); !ok { - t.Fatalf("could not get registered function: %s", funcName) + t.Fatalf("could not get registered function: %q", funcName) } - srv := httptest.NewServer(handler) + server, err := initServer() + if err != nil { + t.Fatalf("initServer(): %v", err) + } + srv := httptest.NewServer(server) defer srv.Close() - if _, err := http.Get(srv.URL); err != nil { - t.Fatalf("could not make HTTP GET request to function: %s", err) + resp, err := http.Get(srv.URL) + if err != nil { + t.Fatalf("could not make HTTP GET request to function: %q", err) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ioutil.ReadAll: %v", err) + } + if got := strings.TrimSpace(string(body)); got != funcResp { + t.Errorf("unexpected http response: got %q; want: %q", got, funcResp) } } func TestDeclarativeFunctionCloudEvent(t *testing.T) { funcName := "cloudeventfunc" os.Setenv("FUNCTION_TARGET", funcName) + defer os.Unsetenv("FUNCTION_TARGET") + // Verify RegisterCloudEventFunctionContext and functions.CloudEvent don't conflict. if err := RegisterCloudEventFunctionContext(context.Background(), "/", dummyCloudEvent); err != nil { t.Fatalf("registerHTTPFunction(): %v", err) } - + defer registry.Default().DeleteRegisteredFunction("function_at_path_\"/\"") // register functions functions.CloudEvent(funcName, dummyCloudEvent) - - //cleanup global var - defer func() { handler = nil }() if _, ok := registry.Default().GetRegisteredFunction(funcName); !ok { t.Fatalf("could not get registered function: %s", funcName) } - srv := httptest.NewServer(handler) + server, err := initServer() + if err != nil { + t.Fatalf("initServer(): %v", err) + } + srv := httptest.NewServer(server) defer srv.Close() if _, err := http.Get(srv.URL); err != nil { @@ -639,6 +680,7 @@ func TestDeclarativeFunctionCloudEvent(t *testing.T) { func TestFunctionsNotRegisteredError(t *testing.T) { funcName := "HelloWorld" os.Setenv("FUNCTION_TARGET", funcName) + defer os.Unsetenv("FUNCTION_TARGET") wantErr := fmt.Sprintf("no matching function found with name: %q", funcName) @@ -650,3 +692,55 @@ func TestFunctionsNotRegisteredError(t *testing.T) { func dummyCloudEvent(ctx context.Context, e cloudevents.Event) error { return nil } + +func TestServeMultipleFunctions(t *testing.T) { + fns := []struct { + name string + fn func(w http.ResponseWriter, r *http.Request) + wantResp string + }{ + { + name: "fn1", + fn: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Hello Foo!") + }, + wantResp: "Hello Foo!", + }, + { + name: "fn2", + fn: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Hello Bar!") + }, + wantResp: "Hello Bar!", + }, + } + + // Register functions. + for _, f := range fns { + functions.HTTP(f.name, f.fn) + if _, ok := registry.Default().GetRegisteredFunction(f.name); !ok { + t.Fatalf("could not get registered function: %s", f.name) + } + } + + server, err := initServer() + if err != nil { + t.Fatalf("initServer(): %v", err) + } + srv := httptest.NewServer(server) + defer srv.Close() + + for _, f := range fns { + resp, err := http.Get(srv.URL + "/" + f.name) + if err != nil { + t.Fatalf("could not make HTTP GET request to function: %s", err) + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ioutil.ReadAll: %v", err) + } + if got := strings.TrimSpace(string(body)); got != f.wantResp { + t.Errorf("unexpected http response: got %q; want: %q", got, f.wantResp) + } + } +} diff --git a/internal/registry/registry.go b/internal/registry/registry.go index a33f863..7b5ba94 100644 --- a/internal/registry/registry.go +++ b/internal/registry/registry.go @@ -12,8 +12,19 @@ import ( // registered with the registry. type RegisteredFunction struct { Name string // The name of the function + Path string // The serving path of the function CloudEventFn func(context.Context, cloudevents.Event) error // Optional: The user's CloudEvent function HTTPFn func(http.ResponseWriter, *http.Request) // Optional: The user's HTTP function + EventFn interface{} // Optional: The user's Event function +} + +// Option is an option used when registering a function. +type Option func(*RegisteredFunction) + +func WithPath(path string) Option { + return func(fn *RegisteredFunction) { + fn.Path = path + } } // Registry is a registry of functions. @@ -35,28 +46,59 @@ func New() *Registry { } // RegisterHTTP a HTTP function with a given name -func (r *Registry) RegisterHTTP(name string, fn func(http.ResponseWriter, *http.Request)) error { +func (r *Registry) RegisterHTTP(name string, fn func(http.ResponseWriter, *http.Request), options ...Option) error { if _, ok := r.functions[name]; ok { - return fmt.Errorf("function name already registered: %s", name) + return fmt.Errorf("function name already registered: %q", name) } - r.functions[name] = RegisteredFunction{ + function := RegisteredFunction{ Name: name, + Path: "/" + name, CloudEventFn: nil, HTTPFn: fn, + EventFn: nil, + } + for _, o := range options { + o(&function) } + r.functions[name] = function return nil } // RegistryCloudEvent a CloudEvent function with a given name -func (r *Registry) RegisterCloudEvent(name string, fn func(context.Context, cloudevents.Event) error) error { +func (r *Registry) RegisterCloudEvent(name string, fn func(context.Context, cloudevents.Event) error, options ...Option) error { if _, ok := r.functions[name]; ok { - return fmt.Errorf("function name already registered: %s", name) + return fmt.Errorf("function name already registered: %q", name) } - r.functions[name] = RegisteredFunction{ + function := RegisteredFunction{ Name: name, + Path: "/" + name, CloudEventFn: fn, HTTPFn: nil, + EventFn: nil, + } + for _, o := range options { + o(&function) + } + r.functions[name] = function + return nil +} + +// RegistryCloudEvent a Event function with a given name +func (r *Registry) RegisterEvent(name string, fn interface{}, options ...Option) error { + if _, ok := r.functions[name]; ok { + return fmt.Errorf("function name already registered: %q", name) + } + function := RegisteredFunction{ + Name: name, + Path: "/" + name, + CloudEventFn: nil, + HTTPFn: nil, + EventFn: fn, } + for _, o := range options { + o(&function) + } + r.functions[name] = function return nil } @@ -65,3 +107,13 @@ func (r *Registry) GetRegisteredFunction(name string) (RegisteredFunction, bool) fn, ok := r.functions[name] return fn, ok } + +// GetAllFunctions returns all the registered functions. +func (r *Registry) GetAllFunctions() map[string]RegisteredFunction { + return r.functions +} + +// DeleteRegisteredFunction deletes a registered function. +func (r *Registry) DeleteRegisteredFunction(name string) { + delete(r.functions, name) +} diff --git a/internal/registry/registry_test.go b/internal/registry/registry_test.go index edb2730..a903a52 100644 --- a/internal/registry/registry_test.go +++ b/internal/registry/registry_test.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -23,32 +23,137 @@ import ( ) func TestRegisterHTTP(t *testing.T) { - registry := New() - registry.RegisterHTTP("httpfn", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "Hello World!") - }) - - fn, ok := registry.GetRegisteredFunction("httpfn") - if !ok { - t.Fatalf("Expected function to be registered") + testCases := []struct { + name string + option Option + wantName string + wantPath string + }{ + { + name: "hello", + wantName: "hello", + wantPath: "/hello", + }, + { + name: "withPath", + option: WithPath("/world"), + wantName: "withPath", + wantPath: "/world", + }, } - if fn.Name != "httpfn" { - t.Errorf("Expected function name to be 'httpfn', got %s", fn.Name) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + registry := New() + + httpfn := func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "Hello World!") } + if tc.option != nil { + registry.RegisterHTTP(tc.name, httpfn, tc.option) + } else { + registry.RegisterHTTP(tc.name, httpfn) + } + + fn, ok := registry.GetRegisteredFunction(tc.name) + if !ok { + t.Fatalf("Expected function to be registered") + } + if fn.Name != tc.wantName { + t.Errorf("Expected function name to be %s, got %s", tc.wantName, fn.Name) + } + if fn.Path != tc.wantPath { + t.Errorf("Expected function path to be %s, got %s", tc.wantPath, fn.Path) + } + }) } } -func TestRegisterCE(t *testing.T) { - registry := New() - registry.RegisterCloudEvent("cefn", func(context.Context, cloudevents.Event) error { - return nil - }) +func TestRegisterCloudEvent(t *testing.T) { + testCases := []struct { + name string + option Option + wantName string + wantPath string + }{ + { + name: "hello", + wantName: "hello", + wantPath: "/hello", + }, + { + name: "withPath", + option: WithPath("/world"), + wantName: "withPath", + wantPath: "/world", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + registry := New() - fn, ok := registry.GetRegisteredFunction("cefn") - if !ok { - t.Fatalf("Expected function to be registered") + cefn := func(context.Context, cloudevents.Event) error { return nil} + if tc.option != nil { + registry.RegisterCloudEvent(tc.name, cefn, tc.option) + } else { + registry.RegisterCloudEvent(tc.name, cefn) + } + + fn, ok := registry.GetRegisteredFunction(tc.name) + if !ok { + t.Fatalf("Expected function to be registered") + } + if fn.Name != tc.wantName { + t.Errorf("Expected function name to be %s, got %s", tc.wantName, fn.Name) + } + if fn.Path != tc.wantPath { + t.Errorf("Expected function path to be %s, got %s", tc.wantPath, fn.Path) + } + }) } - if fn.Name != "cefn" { - t.Errorf("Expected function name to be 'cefn', got %s", fn.Name) +} + +func TestRegisterEvent(t *testing.T) { + testCases := []struct { + name string + option Option + wantName string + wantPath string + }{ + { + name: "hello", + wantName: "hello", + wantPath: "/hello", + }, + { + name: "withPath", + option: WithPath("/world"), + wantName: "withPath", + wantPath: "/world", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + registry := New() + + eventfn := func() {} + if tc.option != nil { + registry.RegisterEvent(tc.name, eventfn, tc.option) + } else { + registry.RegisterEvent(tc.name, eventfn) + } + + fn, ok := registry.GetRegisteredFunction(tc.name) + if !ok { + t.Fatalf("Expected function to be registered") + } + if fn.Name != tc.wantName { + t.Errorf("Expected function name to be %s, got %s", tc.wantName, fn.Name) + } + if fn.Path != tc.wantPath { + t.Errorf("Expected function path to be %s, got %s", tc.wantPath, fn.Path) + } + }) } } @@ -59,9 +164,7 @@ func TestRegisterMultipleFunctions(t *testing.T) { }); err != nil { t.Error("Expected \"multifn1\" function to be registered") } - if err := registry.RegisterHTTP("multifn2", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "Hello World 2!") - }); err != nil { + if err := registry.RegisterEvent("multifn2", func() {}); err != nil { t.Error("Expected \"multifn2\" function to be registered") } if err := registry.RegisterCloudEvent("multifn3", func(context.Context, cloudevents.Event) error {