From 6fa6a16559a1c1d0e90de4a3277a3db6e73cb9fc Mon Sep 17 00:00:00 2001 From: nic-chen <33000667+nic-chen@users.noreply.github.com> Date: Fri, 25 Dec 2020 07:56:57 +0800 Subject: [PATCH] fix: when enable or disable existing SSL, an error occurred (#1064) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: when enable or disable existing SSL, an error occurred * fix: keep the code uniform * fix: lint * fix: keep the code uniform * fix: license * chore: add comment to note that json.Marshal and json.Unmarshal may cause the precision loss Co-authored-by: 琚致远 Co-authored-by: YuanSheng Wang --- api/internal/handler/ssl/ssl.go | 10 +++- api/internal/utils/consts/api_error.go | 26 ++++++++- api/internal/utils/consts/api_error_test.go | 62 +++++++++++++++++++++ api/internal/utils/utils.go | 12 ++++ api/internal/utils/utils_test.go | 18 ++++++ api/test/e2e/ssl_test.go | 8 +++ 6 files changed, 130 insertions(+), 6 deletions(-) create mode 100644 api/internal/utils/consts/api_error_test.go diff --git a/api/internal/handler/ssl/ssl.go b/api/internal/handler/ssl/ssl.go index 19009954f8..36d8efa4f5 100644 --- a/api/internal/handler/ssl/ssl.go +++ b/api/internal/handler/ssl/ssl.go @@ -86,7 +86,11 @@ func (h *Handler) Get(c droplet.Context) (interface{}, error) { } //format respond - ssl := ret.(*entity.SSL) + ssl := &entity.SSL{} + err = utils.ObjectClone(ret, ssl) + if err != nil { + return handler.SpecCodeResponse(err), err + } ssl.Key = "" ssl.Keys = nil @@ -160,9 +164,9 @@ func (h *Handler) List(c droplet.Context) (interface{}, error) { //format respond var list []interface{} - var ssl *entity.SSL for _, item := range ret.Rows { - ssl = item.(*entity.SSL) + ssl := &entity.SSL{} + _ = utils.ObjectClone(item, ssl) ssl.Key = "" ssl.Keys = nil list = append(list, ssl) diff --git a/api/internal/utils/consts/api_error.go b/api/internal/utils/consts/api_error.go index 608ad0ef05..4e2a0c0299 100644 --- a/api/internal/utils/consts/api_error.go +++ b/api/internal/utils/consts/api_error.go @@ -17,8 +17,10 @@ package consts import ( - "github.com/gin-gonic/gin" "net/http" + "strings" + + "github.com/gin-gonic/gin" ) type WrapperHandle func(c *gin.Context) (interface{}, error) @@ -27,7 +29,21 @@ func ErrorWrapper(handle WrapperHandle) gin.HandlerFunc { return func(c *gin.Context) { data, err := handle(c) if err != nil { - apiError := err.(*ApiError) + apiError, ok := err.(*ApiError) + if !ok { + errMsg := err.Error() + if strings.Contains(errMsg, "required") || + strings.Contains(errMsg, "conflicted") || + strings.Contains(errMsg, "invalid") || + strings.Contains(errMsg, "missing") || + strings.Contains(errMsg, "validate failed") { + apiError = InvalidParam(errMsg) + } else if strings.Contains(errMsg, "not found") { + apiError = NotFound(errMsg) + } else { + apiError = SystemError(errMsg) + } + } c.JSON(apiError.Status, apiError) return } @@ -53,5 +69,9 @@ func InvalidParam(message string) *ApiError { } func SystemError(message string) *ApiError { - return &ApiError{500, 500, message} + return &ApiError{500, 10001, message} +} + +func NotFound(message string) *ApiError { + return &ApiError{404, 10002, message} } diff --git a/api/internal/utils/consts/api_error_test.go b/api/internal/utils/consts/api_error_test.go new file mode 100644 index 0000000000..9be241e4ea --- /dev/null +++ b/api/internal/utils/consts/api_error_test.go @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package consts + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func performRequest(r http.Handler, method, path string) *httptest.ResponseRecorder { + req := httptest.NewRequest(method, path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + return w +} + +func TestRequestLogHandler(t *testing.T) { + r := gin.New() + r.GET("/", ErrorWrapper(func(c *gin.Context) (interface{}, error) { + return nil, nil + })) + r.GET("/notfound", ErrorWrapper(func(c *gin.Context) (interface{}, error) { + return nil, fmt.Errorf("data not found") + })) + r.GET("/invalid", ErrorWrapper(func(c *gin.Context) (interface{}, error) { + return nil, fmt.Errorf("schema validate failed") + })) + r.GET("/error", ErrorWrapper(func(c *gin.Context) (interface{}, error) { + return nil, fmt.Errorf("internal system error") + })) + + w := performRequest(r, "GET", "/") + assert.Equal(t, 200, w.Code) + + w = performRequest(r, "GET", "/notfound") + assert.Equal(t, 404, w.Code) + + w = performRequest(r, "GET", "/invalid") + assert.Equal(t, 400, w.Code) + + w = performRequest(r, "GET", "/error") + assert.Equal(t, 500, w.Code) +} diff --git a/api/internal/utils/utils.go b/api/internal/utils/utils.go index ad477925af..b8d192c70b 100644 --- a/api/internal/utils/utils.go +++ b/api/internal/utils/utils.go @@ -17,6 +17,7 @@ package utils import ( + "encoding/json" "errors" "fmt" "net" @@ -97,6 +98,17 @@ func InterfaceToString(val interface{}) string { return str } +// Note: json.Marshal and json.Unmarshal may cause the precision loss +func ObjectClone(origin, copy interface{}) error { + byt, err := json.Marshal(origin) + if err != nil { + return err + } + + err = json.Unmarshal(byt, copy) + return err +} + func GenLabelMap(label string) (map[string]string, error) { var err = errors.New("malformed label") mp := make(map[string]string) diff --git a/api/internal/utils/utils_test.go b/api/internal/utils/utils_test.go index 0d8427de61..8da295f83a 100644 --- a/api/internal/utils/utils_test.go +++ b/api/internal/utils/utils_test.go @@ -44,6 +44,24 @@ func TestSumIPs_with_nil(t *testing.T) { assert.Equal(t, uint16(0), total) } +func TestObjectClone(t *testing.T) { + type test struct { + Str string + Num int + } + + origin := &test{Str: "a", Num: 1} + copy := &test{} + err := ObjectClone(origin, copy) + assert.Nil(t, err) + assert.Equal(t, origin, copy) + + // change value of the copy, should not change value of origin + copy.Num = 2 + assert.NotEqual(t, copy.Num, origin.Num) + assert.Equal(t, 1, origin.Num) +} + func TestGenLabelMap(t *testing.T) { expectedErr := errors.New("malformed label") mp, err := GenLabelMap("l1") diff --git a/api/test/e2e/ssl_test.go b/api/test/e2e/ssl_test.go index add638074a..9dc3b80565 100644 --- a/api/test/e2e/ssl_test.go +++ b/api/test/e2e/ssl_test.go @@ -101,6 +101,14 @@ func TestSSL_Basic(t *testing.T) { Headers: map[string]string{"Authorization": token}, ExpectStatus: http.StatusOK, }, + { + Desc: "get the route just created to trigger removing `key`", + Object: ManagerApiExpect(t), + Method: http.MethodGet, + Path: "/apisix/admin/routes/r1", + Headers: map[string]string{"Authorization": token}, + ExpectStatus: http.StatusOK, + }, { Desc: "hit the route just created using HTTPS", Object: APISIXHTTPSExpect(t),