Skip to content

Commit

Permalink
Merge branch 'master' into fix_mq_test
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-chi-bot authored May 18, 2022
2 parents a0439d6 + 359af18 commit 2c94029
Show file tree
Hide file tree
Showing 12 changed files with 364 additions and 16 deletions.
39 changes: 32 additions & 7 deletions dm/dm/master/openapi_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
package master

import (
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"net/http/httputil"

"github.com/pingcap/failpoint"

ginmiddleware "github.com/deepmap/oapi-codegen/pkg/gin-middleware"
"github.com/gin-gonic/gin"
Expand All @@ -39,9 +43,8 @@ const (
docJSONBasePath = "/api/v1/dm.json"
)

// redirectRequestToLeaderMW a middleware auto redirect request to leader.
// because the leader has some data in memory, only the leader can process the request.
func (s *Server) redirectRequestToLeaderMW() gin.HandlerFunc {
// reverseRequestToLeaderMW reverses request to leader.
func (s *Server) reverseRequestToLeaderMW(tlsCfg *tls.Config) gin.HandlerFunc {
return func(c *gin.Context) {
ctx2 := c.Request.Context()
isLeader, _ := s.isLeaderAndNeedForward(ctx2)
Expand All @@ -54,14 +57,36 @@ func (s *Server) redirectRequestToLeaderMW() gin.HandlerFunc {
_ = c.AbortWithError(http.StatusBadRequest, err)
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("http://%s%s", leaderOpenAPIAddr, c.Request.RequestURI))
c.AbortWithStatus(http.StatusTemporaryRedirect)

failpoint.Inject("MockNotSetTls", func() {
tlsCfg = nil
})
// simpleProxy just reverses to leader host
simpleProxy := httputil.ReverseProxy{
Director: func(req *http.Request) {
if tlsCfg != nil {
req.URL.Scheme = "https"
} else {
req.URL.Scheme = "http"
}
req.URL.Host = leaderOpenAPIAddr
req.Host = leaderOpenAPIAddr
},
}
if tlsCfg != nil {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = tlsCfg
simpleProxy.Transport = transport
}
log.L().Info("reverse request to leader", zap.String("Request URL", c.Request.URL.String()), zap.String("leader", leaderOpenAPIAddr), zap.Bool("hasTLS", tlsCfg != nil))
simpleProxy.ServeHTTP(c.Writer, c.Request)
c.Abort()
}
}
}

// InitOpenAPIHandles init openapi handlers.
func (s *Server) InitOpenAPIHandles() error {
func (s *Server) InitOpenAPIHandles(tlsCfg *tls.Config) error {
swagger, err := openapi.GetSwagger()
if err != nil {
return err
Expand All @@ -73,7 +98,7 @@ func (s *Server) InitOpenAPIHandles() error {
// middlewares
r.Use(gin.Recovery())
r.Use(openapi.ZapLogger(log.L().WithFields(zap.String("component", "openapi")).Logger))
r.Use(s.redirectRequestToLeaderMW())
r.Use(s.reverseRequestToLeaderMW(tlsCfg))
r.Use(terrorHTTPErrorHandler())
// use validation middleware to check all requests against the OpenAPI schema.
r.Use(ginmiddleware.OapiRequestValidator(swagger))
Expand Down
162 changes: 158 additions & 4 deletions dm/dm/master/openapi_view_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
package master

import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

Expand Down Expand Up @@ -292,7 +296,7 @@ func (s *OpenAPIViewSuite) TestClusterAPI() {
cancel1()
}

func (s *OpenAPIViewSuite) TestRedirectRequestToLeader() {
func (s *OpenAPIViewSuite) TestReverseRequestToLeader() {
ctx1, cancel1 := context.WithCancel(context.Background())
s1 := setupTestServer(ctx1, s.T())
defer func() {
Expand Down Expand Up @@ -334,9 +338,159 @@ func (s *OpenAPIViewSuite) TestRedirectRequestToLeader() {
s.Len(resultListSource.Data, 0)
s.Equal(0, resultListSource.Total)

// list source not from leader will get a redirect
result = testutil.NewRequest().Get(baseURL).GoWithHTTPHandler(s.T(), s2.openapiHandles)
s.Equal(http.StatusTemporaryRedirect, result.Code())
// list source from non-leader will get result too
result, err := HTTPTestWithTestResponseRecorder(testutil.NewRequest().Get(baseURL), s2.openapiHandles)
s.NoError(err)
s.Equal(http.StatusOK, result.Code())
var resultListSource2 openapi.GetSourceListResponse
s.NoError(result.UnmarshalBodyToObject(&resultListSource2))
s.Len(resultListSource2.Data, 0)
s.Equal(0, resultListSource2.Total)
}

func (s *OpenAPIViewSuite) TestReverseRequestToHttpsLeader() {
pwd, err := os.Getwd()
require.NoError(s.T(), err)
caPath := pwd + "/tls_for_test/ca.pem"
certPath := pwd + "/tls_for_test/dm.pem"
keyPath := pwd + "/tls_for_test/dm.key"

// master1
masterAddr1 := tempurl.Alloc()[len("http://"):]
peerAddr1 := tempurl.Alloc()[len("http://"):]
cfg1 := NewConfig()
require.NoError(s.T(), cfg1.Parse([]string{
"--name=dm-master-tls-1",
fmt.Sprintf("--data-dir=%s", s.T().TempDir()),
fmt.Sprintf("--master-addr=https://%s", masterAddr1),
fmt.Sprintf("--advertise-addr=https://%s", masterAddr1),
fmt.Sprintf("--peer-urls=https://%s", peerAddr1),
fmt.Sprintf("--advertise-peer-urls=https://%s", peerAddr1),
fmt.Sprintf("--initial-cluster=dm-master-tls-1=https://%s", peerAddr1),
"--ssl-ca=" + caPath,
"--ssl-cert=" + certPath,
"--ssl-key=" + keyPath,
}))
cfg1.OpenAPI = true
s1 := NewServer(cfg1)
ctx1, cancel1 := context.WithCancel(context.Background())
require.NoError(s.T(), s1.Start(ctx1))
defer func() {
cancel1()
s1.Close()
}()
// wait the first one become the leader
require.True(s.T(), utils.WaitSomething(30, 100*time.Millisecond, func() bool {
return s1.election.IsLeader() && s1.scheduler.Started()
}))

// master2
masterAddr2 := tempurl.Alloc()[len("http://"):]
peerAddr2 := tempurl.Alloc()[len("http://"):]
cfg2 := NewConfig()
require.NoError(s.T(), cfg2.Parse([]string{
"--name=dm-master-tls-2",
fmt.Sprintf("--data-dir=%s", s.T().TempDir()),
fmt.Sprintf("--master-addr=https://%s", masterAddr2),
fmt.Sprintf("--advertise-addr=https://%s", masterAddr2),
fmt.Sprintf("--peer-urls=https://%s", peerAddr2),
fmt.Sprintf("--advertise-peer-urls=https://%s", peerAddr2),
"--ssl-ca=" + caPath,
"--ssl-cert=" + certPath,
"--ssl-key=" + keyPath,
}))
cfg2.OpenAPI = true
cfg2.Join = s1.cfg.MasterAddr // join to an existing cluster
s2 := NewServer(cfg2)
ctx2, cancel2 := context.WithCancel(context.Background())
require.NoError(s.T(), s2.Start(ctx2))
defer func() {
cancel2()
s2.Close()
}()
// wait the second master ready
require.False(s.T(), utils.WaitSomething(30, 100*time.Millisecond, func() bool {
return s2.election.IsLeader()
}))

baseURL := "/api/v1/sources"
// list source from leader
result := testutil.NewRequest().Get(baseURL).GoWithHTTPHandler(s.T(), s1.openapiHandles)
s.Equal(http.StatusOK, result.Code())
var resultListSource openapi.GetSourceListResponse
s.NoError(result.UnmarshalBodyToObject(&resultListSource))
s.Len(resultListSource.Data, 0)
s.Equal(0, resultListSource.Total)

// with tls, list source not from leader will get result too
result, err = HTTPTestWithTestResponseRecorder(testutil.NewRequest().Get(baseURL), s2.openapiHandles)
s.NoError(err)
s.Equal(http.StatusOK, result.Code())
var resultListSource2 openapi.GetSourceListResponse
s.NoError(result.UnmarshalBodyToObject(&resultListSource2))
s.Len(resultListSource2.Data, 0)
s.Equal(0, resultListSource2.Total)

// without tls, list source not from leader will be 502
s.NoError(failpoint.Enable("github.com/pingcap/tiflow/dm/dm/master/MockNotSetTls", `return()`))
result, err = HTTPTestWithTestResponseRecorder(testutil.NewRequest().Get(baseURL), s2.openapiHandles)
s.NoError(err)
s.Equal(http.StatusBadGateway, result.Code())
s.NoError(failpoint.Disable("github.com/pingcap/tiflow/dm/dm/master/MockNotSetTls"))
}

// httptest.ResponseRecorder is not http.CloseNotifier, will panic when test reverse proxy.
// We need to implement the interface ourselves.
// ref: https://github.com/gin-gonic/gin/blob/ce20f107f5dc498ec7489d7739541a25dcd48463/context_test.go#L1747-L1765
type TestResponseRecorder struct {
*httptest.ResponseRecorder
closeChannel chan bool
}

func (r *TestResponseRecorder) CloseNotify() <-chan bool {
return r.closeChannel
}

func (r *TestResponseRecorder) closeClient() {
r.closeChannel <- true
}

func CreateTestResponseRecorder() *TestResponseRecorder {
return &TestResponseRecorder{
httptest.NewRecorder(),
make(chan bool, 1),
}
}

func HTTPTestWithTestResponseRecorder(r *testutil.RequestBuilder, handler http.Handler) (*testutil.CompletedRequest, error) {
if r == nil {
return nil, nil
}
if r.Error != nil {
return nil, r.Error
}
var bodyReader io.Reader
if r.Body != nil {
bodyReader = bytes.NewReader(r.Body)
}

req := httptest.NewRequest(r.Method, r.Path, bodyReader)
for h, v := range r.Headers {
req.Header.Add(h, v)
}
if host, ok := r.Headers["Host"]; ok {
req.Host = host
}
for _, c := range r.Cookies {
req.AddCookie(c)
}

rec := CreateTestResponseRecorder()
handler.ServeHTTP(rec, req)

return &testutil.CompletedRequest{
Recorder: rec.ResponseRecorder,
}, nil
}

func (s *OpenAPIViewSuite) TestOpenAPIWillNotStartInDefaultConfig() {
Expand Down
7 changes: 6 additions & 1 deletion dm/dm/master/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,12 @@ func (s *Server) Start(ctx context.Context) (err error) {
"/debug/": getDebugHandler(),
}
if s.cfg.OpenAPI {
if initOpenAPIErr := s.InitOpenAPIHandles(); initOpenAPIErr != nil {
// tls3 is used to openapi reverse proxy
tls3, err1 := toolutils.NewTLS(s.cfg.SSLCA, s.cfg.SSLCert, s.cfg.SSLKey, s.cfg.AdvertiseAddr, s.cfg.CertAllowedCN)
if err1 != nil {
return terror.ErrMasterTLSConfigNotValid.Delegate(err1)
}
if initOpenAPIErr := s.InitOpenAPIHandles(tls3.TLSConfig()); initOpenAPIErr != nil {
return terror.ErrOpenAPICommonError.Delegate(initOpenAPIErr)
}
userHandles["/api/v1/"] = s.openapiHandles
Expand Down
43 changes: 40 additions & 3 deletions dm/tests/openapi/client/openapi_source_check
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
import sys
import requests
import ssl

SOURCE1_NAME = "mysql-01"
SOURCE2_NAME = "mysql-02"
Expand All @@ -11,6 +12,10 @@ WORKER2_NAME = "worker2"
API_ENDPOINT = "http://127.0.0.1:8261/api/v1/sources"
API_ENDPOINT_NOT_LEADER = "http://127.0.0.1:8361/api/v1/sources"

API_ENDPOINT_HTTPS = "https://127.0.0.1:8261/api/v1/sources"
API_ENDPOINT_NOT_LEADER_HTTPS = "https://127.0.0.1:8361/api/v1/sources"



def create_source_failed():
resp = requests.post(url=API_ENDPOINT)
Expand Down Expand Up @@ -53,6 +58,23 @@ def create_source2_success():
print("create_source1_success resp=", resp.json())
assert resp.status_code == 201

def create_source_success_https(ssl_ca, ssl_cert, ssl_key):
req = {
"source": {
"case_sensitive": False,
"enable": True,
"enable_gtid": False,
"host": "127.0.0.1",
"password": "123456",
"port": 3306,
"source_name": SOURCE1_NAME,
"user": "root",
}
}
resp = requests.post(url=API_ENDPOINT_HTTPS, json=req, verify=ssl_ca, cert=(ssl_cert, ssl_key))
print("create_source_success_https resp=", resp.json())
assert resp.status_code == 201

def update_source1_without_password_success():
req = {
"source": {
Expand All @@ -76,6 +98,12 @@ def list_source_success(source_count):
print("list_source_by_openapi_success resp=", data)
assert data["total"] == int(source_count)

def list_source_success_https(source_count, ssl_ca, ssl_cert, ssl_key):
resp = requests.get(url=API_ENDPOINT_HTTPS, verify=ssl_ca, cert=(ssl_cert, ssl_key))
assert resp.status_code == 200
data = resp.json()
print("list_source_success_https resp=", data)
assert data["total"] == int(source_count)

def list_source_with_status_success(source_count, status_count):
resp = requests.get(url=API_ENDPOINT + "?with_status=true")
Expand All @@ -87,13 +115,19 @@ def list_source_with_status_success(source_count, status_count):
assert len(data["data"][i]["status_list"]) == int(status_count)


def list_source_with_redirect(source_count):
def list_source_with_reverse(source_count):
resp = requests.get(url=API_ENDPOINT_NOT_LEADER)
assert resp.status_code == 200
data = resp.json()
print("list_source_by_openapi_redirect resp=", data)
print("list_source_with_reverse resp=", data)
assert data["total"] == int(source_count)

def list_source_with_reverse_https(source_count, ssl_ca, ssl_cert, ssl_key):
resp = requests.get(url=API_ENDPOINT_NOT_LEADER_HTTPS, verify=ssl_ca, cert=(ssl_cert, ssl_key))
assert resp.status_code == 200
data = resp.json()
print("list_source_with_reverse_https resp=", data)
assert data["total"] == int(source_count)

def delete_source_success(source_name):
resp = requests.delete(url=API_ENDPOINT + "/" + source_name)
Expand Down Expand Up @@ -235,9 +269,12 @@ if __name__ == "__main__":
"create_source_failed": create_source_failed,
"create_source1_success": create_source1_success,
"create_source2_success": create_source2_success,
"create_source_success_https": create_source_success_https,
"update_source1_without_password_success": update_source1_without_password_success,
"list_source_success": list_source_success,
"list_source_with_redirect": list_source_with_redirect,
"list_source_success_https": list_source_success_https,
"list_source_with_reverse_https": list_source_with_reverse_https,
"list_source_with_reverse": list_source_with_reverse,
"list_source_with_status_success": list_source_with_status_success,
"delete_source_failed": delete_source_failed,
"delete_source_success": delete_source_success,
Expand Down
Loading

0 comments on commit 2c94029

Please sign in to comment.