Skip to content

Commit

Permalink
Merge pull request #232 from stripe/saurabhbhatia/add-reject-handler
Browse files Browse the repository at this point in the history
Add Support for Reject Handler with SmokescreenContext
  • Loading branch information
saurabhbhatia-stripe authored Oct 8, 2024
2 parents dab4bde + 04ce070 commit f6f8191
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 1 deletion.
12 changes: 12 additions & 0 deletions pkg/smokescreen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@ type Config struct {
ProxyDialTimeout func(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error)

// Custom handler to allow clients to modify reject responses
// Deprecated: RejectResponseHandler is deprecated.Please use RejectResponseHandlerWithCtx instead.
RejectResponseHandler func(*http.Response)

// Custom handler to allow clients to modify reject responses
// In case RejectResponseHandler is set, this cannot be used.
RejectResponseHandlerWithCtx func(*SmokescreenContext, *http.Response)

// Custom handler to allow clients to modify successful CONNECT responses
AcceptResponseHandler func(*SmokescreenContext, *http.Response) error

Expand Down Expand Up @@ -418,6 +423,13 @@ func (config *Config) SetupTls(certFile, keyFile string, clientCAFiles []string)
return nil
}

func (config *Config) Validate() error {
if config.RejectResponseHandler != nil && config.RejectResponseHandlerWithCtx != nil {
return errors.New("RejectResponseHandler and RejectResponseHandlerWithCtx cannot be used together")
}
return nil
}

func (config *Config) populateClientCaMap(pemCerts []byte) (ok bool) {

for len(pemCerts) > 0 {
Expand Down
10 changes: 9 additions & 1 deletion pkg/smokescreen/smokescreen.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,9 @@ func rejectResponse(pctx *goproxy.ProxyCtx, err error) *http.Response {
if sctx.cfg.RejectResponseHandler != nil {
sctx.cfg.RejectResponseHandler(resp)
}
if sctx.cfg.RejectResponseHandlerWithCtx != nil {
sctx.cfg.RejectResponseHandlerWithCtx(sctx, resp)
}
return resp
}

Expand Down Expand Up @@ -733,9 +736,14 @@ func findListener(ip string, defaultPort uint16) (net.Listener, error) {

func StartWithConfig(config *Config, quit <-chan interface{}) {
config.Log.Println("starting")
var err error

if err = config.Validate(); err != nil {
config.Log.Fatal("invalid config", err)
}

proxy := BuildProxy(config)
listener := config.Listener
var err error

if listener == nil {
listener, err = findListener(config.Ip, config.Port)
Expand Down
73 changes: 73 additions & 0 deletions pkg/smokescreen/smokescreen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,47 @@ func TestRejectResponseHandler(t *testing.T) {
})
}

func TestRejectResponseHandlerWithCtx(t *testing.T) {
r := require.New(t)
testHeader := "TestRejectResponseHandlerWithCtxHeader"
t.Run("Testing custom reject response handler", func(t *testing.T) {
cfg, err := testConfig("test-local-srv")

// set a custom RejectResponseHandler that will set a header on every reject response
cfg.RejectResponseHandlerWithCtx = func(_ *SmokescreenContext, resp *http.Response) {
resp.Header.Set(testHeader, "This header is added by the RejectResponseHandlerWithCtx")
}
r.NoError(err)

proxySrv := proxyServer(cfg)
r.NoError(err)
defer proxySrv.Close()

// Create a http.Client that uses our proxy
client, err := proxyClient(proxySrv.URL)
r.NoError(err)

// Send a request that should be blocked
resp, err := client.Get("http://127.0.0.1")
r.NoError(err)

// The RejectResponseHandlerWithCtx should set our custom header
h := resp.Header.Get(testHeader)
if h == "" {
t.Errorf("Expecting header %s to be set by RejectResponseHandler", testHeader)
}
// Send a request that should be allowed
resp, err = client.Get("http://example.com")
r.NoError(err)

// The header set by our custom reject response handler should not be set
h = resp.Header.Get(testHeader)
if h != "" {
t.Errorf("Expecting header %s to not be set by RejectResponseHandler", testHeader)
}
})
}

// Test that Smokescreen calls the custom accept response handler (if defined in the Config struct)
// after every accepted request
func TestAcceptResponseHandler(t *testing.T) {
Expand Down Expand Up @@ -1494,6 +1535,38 @@ func TestMitm(t *testing.T) {
})
}

func TestConfigValidate(t *testing.T) {
t.Run("Test invalid config", func(t *testing.T) {
conf := NewConfig()
conf.ConnectTimeout = 10 * time.Second
conf.ExitTimeout = 10 * time.Second
conf.AdditionalErrorMessageOnDeny = "Proxy denied"
conf.RejectResponseHandlerWithCtx = func(smokescreenContext *SmokescreenContext, response *http.Response) {
fmt.Println("RejectResponseHandlerWithCtx")
}
conf.RejectResponseHandler = func(response *http.Response) {
fmt.Println("RejectResponseHandler")
}
err := conf.Validate()
require.Error(t, err)

})

t.Run("Test valid config", func(t *testing.T) {
conf := NewConfig()
conf.ConnectTimeout = 10 * time.Second
conf.ExitTimeout = 10 * time.Second
conf.AdditionalErrorMessageOnDeny = "Proxy denied"

conf.RejectResponseHandler = func(response *http.Response) {
fmt.Println("RejectResponseHandler")
}
err := conf.Validate()
require.NoError(t, err)

})
}

func findCanonicalProxyDecision(logs []*logrus.Entry) *logrus.Entry {
for _, entry := range logs {
if entry.Message == CanonicalProxyDecision {
Expand Down

0 comments on commit f6f8191

Please sign in to comment.