Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM/Openapi: use reverse proxy instead of redirect #5390

Merged
merged 15 commits into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions dm/dm/master/openapi_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
package master

import (
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"net/http/httputil"

"github.com/pingcap/failpoint"

ginmiddleware "github.com/deepmap/oapi-codegen/pkg/gin-middleware"
"github.com/gin-gonic/gin"
Expand All @@ -39,9 +43,8 @@ const (
docJSONBasePath = "/api/v1/dm.json"
)

// redirectRequestToLeaderMW a middleware auto redirect request to leader.
// because the leader has some data in memory, only the leader can process the request.
func (s *Server) redirectRequestToLeaderMW() gin.HandlerFunc {
// reverseRequestToLeaderMW reverses request to leader.
func (s *Server) reverseRequestToLeaderMW(tlsCfg *tls.Config) gin.HandlerFunc {
return func(c *gin.Context) {
ctx2 := c.Request.Context()
isLeader, _ := s.isLeaderAndNeedForward(ctx2)
Expand All @@ -54,14 +57,36 @@ func (s *Server) redirectRequestToLeaderMW() gin.HandlerFunc {
_ = c.AbortWithError(http.StatusBadRequest, err)
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("http://%s%s", leaderOpenAPIAddr, c.Request.RequestURI))
c.AbortWithStatus(http.StatusTemporaryRedirect)

failpoint.Inject("MockNotSetTls", func() {
tlsCfg = nil
})
// simpleProxy just reverses to leader host
simpleProxy := httputil.ReverseProxy{
Director: func(req *http.Request) {
if tlsCfg != nil {
req.URL.Scheme = "https"
} else {
req.URL.Scheme = "http"
}
req.URL.Host = leaderOpenAPIAddr
req.Host = leaderOpenAPIAddr
},
}
if tlsCfg != nil {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = tlsCfg
simpleProxy.Transport = transport
}
log.L().Info("reverse request to leader", zap.String("Request URL", c.Request.URL.String()), zap.String("leader", leaderOpenAPIAddr), zap.Bool("hasTLS", tlsCfg != nil))
simpleProxy.ServeHTTP(c.Writer, c.Request)
c.Abort()
}
}
}

// InitOpenAPIHandles init openapi handlers.
func (s *Server) InitOpenAPIHandles() error {
func (s *Server) InitOpenAPIHandles(tlsCfg *tls.Config) error {
swagger, err := openapi.GetSwagger()
if err != nil {
return err
Expand All @@ -73,7 +98,7 @@ func (s *Server) InitOpenAPIHandles() error {
// middlewares
r.Use(gin.Recovery())
r.Use(openapi.ZapLogger(log.L().WithFields(zap.String("component", "openapi")).Logger))
r.Use(s.redirectRequestToLeaderMW())
r.Use(s.reverseRequestToLeaderMW(tlsCfg))
r.Use(terrorHTTPErrorHandler())
// use validation middleware to check all requests against the OpenAPI schema.
r.Use(ginmiddleware.OapiRequestValidator(swagger))
Expand Down
162 changes: 158 additions & 4 deletions dm/dm/master/openapi_view_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
package master

import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

Expand Down Expand Up @@ -292,7 +296,7 @@ func (s *OpenAPIViewSuite) TestClusterAPI() {
cancel1()
}

func (s *OpenAPIViewSuite) TestRedirectRequestToLeader() {
func (s *OpenAPIViewSuite) TestReverseRequestToLeader() {
ctx1, cancel1 := context.WithCancel(context.Background())
s1 := setupTestServer(ctx1, s.T())
defer func() {
Expand Down Expand Up @@ -334,9 +338,159 @@ func (s *OpenAPIViewSuite) TestRedirectRequestToLeader() {
s.Len(resultListSource.Data, 0)
s.Equal(0, resultListSource.Total)

// list source not from leader will get a redirect
result = testutil.NewRequest().Get(baseURL).GoWithHTTPHandler(s.T(), s2.openapiHandles)
s.Equal(http.StatusTemporaryRedirect, result.Code())
// list source from non-leader will get result too
result, err := HTTPTestWithTestResponseRecorder(testutil.NewRequest().Get(baseURL), s2.openapiHandles)
s.NoError(err)
s.Equal(http.StatusOK, result.Code())
var resultListSource2 openapi.GetSourceListResponse
s.NoError(result.UnmarshalBodyToObject(&resultListSource2))
s.Len(resultListSource2.Data, 0)
s.Equal(0, resultListSource2.Total)
}

func (s *OpenAPIViewSuite) TestReverseRequestToHttpsLeader() {
pwd, err := os.Getwd()
require.NoError(s.T(), err)
caPath := pwd + "/tls_for_test/ca.pem"
certPath := pwd + "/tls_for_test/dm.pem"
keyPath := pwd + "/tls_for_test/dm.key"

// master1
masterAddr1 := tempurl.Alloc()[len("http://"):]
peerAddr1 := tempurl.Alloc()[len("http://"):]
cfg1 := NewConfig()
require.NoError(s.T(), cfg1.Parse([]string{
"--name=dm-master-tls-1",
fmt.Sprintf("--data-dir=%s", s.T().TempDir()),
fmt.Sprintf("--master-addr=https://%s", masterAddr1),
fmt.Sprintf("--advertise-addr=https://%s", masterAddr1),
fmt.Sprintf("--peer-urls=https://%s", peerAddr1),
fmt.Sprintf("--advertise-peer-urls=https://%s", peerAddr1),
fmt.Sprintf("--initial-cluster=dm-master-tls-1=https://%s", peerAddr1),
"--ssl-ca=" + caPath,
"--ssl-cert=" + certPath,
"--ssl-key=" + keyPath,
}))
cfg1.OpenAPI = true
s1 := NewServer(cfg1)
ctx1, cancel1 := context.WithCancel(context.Background())
require.NoError(s.T(), s1.Start(ctx1))
defer func() {
cancel1()
s1.Close()
}()
// wait the first one become the leader
require.True(s.T(), utils.WaitSomething(30, 100*time.Millisecond, func() bool {
return s1.election.IsLeader() && s1.scheduler.Started()
}))

// master2
masterAddr2 := tempurl.Alloc()[len("http://"):]
peerAddr2 := tempurl.Alloc()[len("http://"):]
cfg2 := NewConfig()
require.NoError(s.T(), cfg2.Parse([]string{
"--name=dm-master-tls-2",
fmt.Sprintf("--data-dir=%s", s.T().TempDir()),
fmt.Sprintf("--master-addr=https://%s", masterAddr2),
fmt.Sprintf("--advertise-addr=https://%s", masterAddr2),
fmt.Sprintf("--peer-urls=https://%s", peerAddr2),
fmt.Sprintf("--advertise-peer-urls=https://%s", peerAddr2),
"--ssl-ca=" + caPath,
"--ssl-cert=" + certPath,
"--ssl-key=" + keyPath,
}))
cfg2.OpenAPI = true
cfg2.Join = s1.cfg.MasterAddr // join to an existing cluster
s2 := NewServer(cfg2)
ctx2, cancel2 := context.WithCancel(context.Background())
require.NoError(s.T(), s2.Start(ctx2))
defer func() {
cancel2()
s2.Close()
}()
// wait the second master ready
require.False(s.T(), utils.WaitSomething(30, 100*time.Millisecond, func() bool {
return s2.election.IsLeader()
}))

baseURL := "/api/v1/sources"
// list source from leader
result := testutil.NewRequest().Get(baseURL).GoWithHTTPHandler(s.T(), s1.openapiHandles)
s.Equal(http.StatusOK, result.Code())
var resultListSource openapi.GetSourceListResponse
s.NoError(result.UnmarshalBodyToObject(&resultListSource))
s.Len(resultListSource.Data, 0)
s.Equal(0, resultListSource.Total)

// with tls, list source not from leader will get result too
result, err = HTTPTestWithTestResponseRecorder(testutil.NewRequest().Get(baseURL), s2.openapiHandles)
s.NoError(err)
s.Equal(http.StatusOK, result.Code())
var resultListSource2 openapi.GetSourceListResponse
s.NoError(result.UnmarshalBodyToObject(&resultListSource2))
s.Len(resultListSource2.Data, 0)
s.Equal(0, resultListSource2.Total)

// without tls, list source not from leader will be 502
s.NoError(failpoint.Enable("github.com/pingcap/tiflow/dm/dm/master/MockNotSetTls", `return()`))
result, err = HTTPTestWithTestResponseRecorder(testutil.NewRequest().Get(baseURL), s2.openapiHandles)
s.NoError(err)
s.Equal(http.StatusBadGateway, result.Code())
s.NoError(failpoint.Disable("github.com/pingcap/tiflow/dm/dm/master/MockNotSetTls"))
}

// httptest.ResponseRecorder is not http.CloseNotifier, will panic when test reverse proxy.
// We need to implement the interface ourselves.
// ref: https://github.com/gin-gonic/gin/blob/ce20f107f5dc498ec7489d7739541a25dcd48463/context_test.go#L1747-L1765
type TestResponseRecorder struct {
*httptest.ResponseRecorder
closeChannel chan bool
}

func (r *TestResponseRecorder) CloseNotify() <-chan bool {
return r.closeChannel
}

func (r *TestResponseRecorder) closeClient() {
r.closeChannel <- true
}

func CreateTestResponseRecorder() *TestResponseRecorder {
return &TestResponseRecorder{
httptest.NewRecorder(),
make(chan bool, 1),
}
}

func HTTPTestWithTestResponseRecorder(r *testutil.RequestBuilder, handler http.Handler) (*testutil.CompletedRequest, error) {
if r == nil {
return nil, nil
}
if r.Error != nil {
return nil, r.Error
}
var bodyReader io.Reader
if r.Body != nil {
bodyReader = bytes.NewReader(r.Body)
}

req := httptest.NewRequest(r.Method, r.Path, bodyReader)
for h, v := range r.Headers {
req.Header.Add(h, v)
}
if host, ok := r.Headers["Host"]; ok {
req.Host = host
}
for _, c := range r.Cookies {
req.AddCookie(c)
}

rec := CreateTestResponseRecorder()
handler.ServeHTTP(rec, req)

return &testutil.CompletedRequest{
Recorder: rec.ResponseRecorder,
}, nil
}

func (s *OpenAPIViewSuite) TestOpenAPIWillNotStartInDefaultConfig() {
Expand Down
7 changes: 6 additions & 1 deletion dm/dm/master/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,12 @@ func (s *Server) Start(ctx context.Context) (err error) {
"/debug/": getDebugHandler(),
}
if s.cfg.OpenAPI {
if initOpenAPIErr := s.InitOpenAPIHandles(); initOpenAPIErr != nil {
// tls3 is used to openapi reverse proxy
tls3, err1 := toolutils.NewTLS(s.cfg.SSLCA, s.cfg.SSLCert, s.cfg.SSLKey, s.cfg.AdvertiseAddr, s.cfg.CertAllowedCN)
if err1 != nil {
return terror.ErrMasterTLSConfigNotValid.Delegate(err1)
}
if initOpenAPIErr := s.InitOpenAPIHandles(tls3.TLSConfig()); initOpenAPIErr != nil {
return terror.ErrOpenAPICommonError.Delegate(initOpenAPIErr)
}
userHandles["/api/v1/"] = s.openapiHandles
Expand Down
43 changes: 40 additions & 3 deletions dm/tests/openapi/client/openapi_source_check
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
import sys
import requests
import ssl

SOURCE1_NAME = "mysql-01"
SOURCE2_NAME = "mysql-02"
Expand All @@ -11,6 +12,10 @@ WORKER2_NAME = "worker2"
API_ENDPOINT = "http://127.0.0.1:8261/api/v1/sources"
API_ENDPOINT_NOT_LEADER = "http://127.0.0.1:8361/api/v1/sources"

API_ENDPOINT_HTTPS = "https://127.0.0.1:8261/api/v1/sources"
API_ENDPOINT_NOT_LEADER_HTTPS = "https://127.0.0.1:8361/api/v1/sources"



def create_source_failed():
resp = requests.post(url=API_ENDPOINT)
Expand Down Expand Up @@ -53,6 +58,23 @@ def create_source2_success():
print("create_source1_success resp=", resp.json())
assert resp.status_code == 201

def create_source_success_https(ssl_ca, ssl_cert, ssl_key):
req = {
"source": {
"case_sensitive": False,
"enable": True,
"enable_gtid": False,
"host": "127.0.0.1",
"password": "123456",
"port": 3306,
"source_name": SOURCE1_NAME,
"user": "root",
}
}
resp = requests.post(url=API_ENDPOINT_HTTPS, json=req, verify=ssl_ca, cert=(ssl_cert, ssl_key))
print("create_source_success_https resp=", resp.json())
assert resp.status_code == 201

def update_source1_without_password_success():
req = {
"source": {
Expand All @@ -76,6 +98,12 @@ def list_source_success(source_count):
print("list_source_by_openapi_success resp=", data)
assert data["total"] == int(source_count)

def list_source_success_https(source_count, ssl_ca, ssl_cert, ssl_key):
resp = requests.get(url=API_ENDPOINT_HTTPS, verify=ssl_ca, cert=(ssl_cert, ssl_key))
assert resp.status_code == 200
data = resp.json()
print("list_source_success_https resp=", data)
assert data["total"] == int(source_count)

def list_source_with_status_success(source_count, status_count):
resp = requests.get(url=API_ENDPOINT + "?with_status=true")
Expand All @@ -87,13 +115,19 @@ def list_source_with_status_success(source_count, status_count):
assert len(data["data"][i]["status_list"]) == int(status_count)


def list_source_with_redirect(source_count):
def list_source_with_reverse(source_count):
resp = requests.get(url=API_ENDPOINT_NOT_LEADER)
assert resp.status_code == 200
data = resp.json()
print("list_source_by_openapi_redirect resp=", data)
print("list_source_with_reverse resp=", data)
assert data["total"] == int(source_count)

def list_source_with_reverse_https(source_count, ssl_ca, ssl_cert, ssl_key):
resp = requests.get(url=API_ENDPOINT_NOT_LEADER_HTTPS, verify=ssl_ca, cert=(ssl_cert, ssl_key))
assert resp.status_code == 200
data = resp.json()
print("list_source_with_reverse_https resp=", data)
assert data["total"] == int(source_count)

def delete_source_success(source_name):
resp = requests.delete(url=API_ENDPOINT + "/" + source_name)
Expand Down Expand Up @@ -235,9 +269,12 @@ if __name__ == "__main__":
"create_source_failed": create_source_failed,
"create_source1_success": create_source1_success,
"create_source2_success": create_source2_success,
"create_source_success_https": create_source_success_https,
"update_source1_without_password_success": update_source1_without_password_success,
"list_source_success": list_source_success,
"list_source_with_redirect": list_source_with_redirect,
"list_source_success_https": list_source_success_https,
"list_source_with_reverse_https": list_source_with_reverse_https,
"list_source_with_reverse": list_source_with_reverse,
"list_source_with_status_success": list_source_with_status_success,
"delete_source_failed": delete_source_failed,
"delete_source_success": delete_source_success,
Expand Down
Loading