diff --git a/pkg/smokescreen/config.go b/pkg/smokescreen/config.go index 63897b13..274d670f 100644 --- a/pkg/smokescreen/config.go +++ b/pkg/smokescreen/config.go @@ -85,8 +85,13 @@ type Config struct { ProxyDialTimeout func(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error) // Custom handler to allow clients to modify reject responses + // Deprecated: RejectResponseHandler is deprecated.Please use RejectResponseHandlerWithCtx instead. RejectResponseHandler func(*http.Response) + // Custom handler to allow clients to modify reject responses + // In case RejectResponseHandler is set, this cannot be used. + RejectResponseHandlerWithCtx func(*SmokescreenContext, *http.Response) + // Custom handler to allow clients to modify successful CONNECT responses AcceptResponseHandler func(*SmokescreenContext, *http.Response) error @@ -418,6 +423,13 @@ func (config *Config) SetupTls(certFile, keyFile string, clientCAFiles []string) return nil } +func (config *Config) Validate() error { + if config.RejectResponseHandler != nil && config.RejectResponseHandlerWithCtx != nil { + return errors.New("RejectResponseHandler and RejectResponseHandlerWithCtx cannot be used together") + } + return nil +} + func (config *Config) populateClientCaMap(pemCerts []byte) (ok bool) { for len(pemCerts) > 0 { diff --git a/pkg/smokescreen/smokescreen.go b/pkg/smokescreen/smokescreen.go index 5dc95135..14295123 100644 --- a/pkg/smokescreen/smokescreen.go +++ b/pkg/smokescreen/smokescreen.go @@ -404,6 +404,9 @@ func rejectResponse(pctx *goproxy.ProxyCtx, err error) *http.Response { if sctx.cfg.RejectResponseHandler != nil { sctx.cfg.RejectResponseHandler(resp) } + if sctx.cfg.RejectResponseHandlerWithCtx != nil { + sctx.cfg.RejectResponseHandlerWithCtx(sctx, resp) + } return resp } @@ -733,9 +736,14 @@ func findListener(ip string, defaultPort uint16) (net.Listener, error) { func StartWithConfig(config *Config, quit <-chan interface{}) { config.Log.Println("starting") + var err error + + if err = config.Validate(); err != nil { + config.Log.Fatal("invalid config", err) + } + proxy := BuildProxy(config) listener := config.Listener - var err error if listener == nil { listener, err = findListener(config.Ip, config.Port) diff --git a/pkg/smokescreen/smokescreen_test.go b/pkg/smokescreen/smokescreen_test.go index 0f7b0120..89718145 100644 --- a/pkg/smokescreen/smokescreen_test.go +++ b/pkg/smokescreen/smokescreen_test.go @@ -1067,6 +1067,47 @@ func TestRejectResponseHandler(t *testing.T) { }) } +func TestRejectResponseHandlerWithCtx(t *testing.T) { + r := require.New(t) + testHeader := "TestRejectResponseHandlerWithCtxHeader" + t.Run("Testing custom reject response handler", func(t *testing.T) { + cfg, err := testConfig("test-local-srv") + + // set a custom RejectResponseHandler that will set a header on every reject response + cfg.RejectResponseHandlerWithCtx = func(_ *SmokescreenContext, resp *http.Response) { + resp.Header.Set(testHeader, "This header is added by the RejectResponseHandlerWithCtx") + } + r.NoError(err) + + proxySrv := proxyServer(cfg) + r.NoError(err) + defer proxySrv.Close() + + // Create a http.Client that uses our proxy + client, err := proxyClient(proxySrv.URL) + r.NoError(err) + + // Send a request that should be blocked + resp, err := client.Get("http://127.0.0.1") + r.NoError(err) + + // The RejectResponseHandlerWithCtx should set our custom header + h := resp.Header.Get(testHeader) + if h == "" { + t.Errorf("Expecting header %s to be set by RejectResponseHandler", testHeader) + } + // Send a request that should be allowed + resp, err = client.Get("http://example.com") + r.NoError(err) + + // The header set by our custom reject response handler should not be set + h = resp.Header.Get(testHeader) + if h != "" { + t.Errorf("Expecting header %s to not be set by RejectResponseHandler", testHeader) + } + }) +} + // Test that Smokescreen calls the custom accept response handler (if defined in the Config struct) // after every accepted request func TestAcceptResponseHandler(t *testing.T) { @@ -1494,6 +1535,38 @@ func TestMitm(t *testing.T) { }) } +func TestConfigValidate(t *testing.T) { + t.Run("Test invalid config", func(t *testing.T) { + conf := NewConfig() + conf.ConnectTimeout = 10 * time.Second + conf.ExitTimeout = 10 * time.Second + conf.AdditionalErrorMessageOnDeny = "Proxy denied" + conf.RejectResponseHandlerWithCtx = func(smokescreenContext *SmokescreenContext, response *http.Response) { + fmt.Println("RejectResponseHandlerWithCtx") + } + conf.RejectResponseHandler = func(response *http.Response) { + fmt.Println("RejectResponseHandler") + } + err := conf.Validate() + require.Error(t, err) + + }) + + t.Run("Test valid config", func(t *testing.T) { + conf := NewConfig() + conf.ConnectTimeout = 10 * time.Second + conf.ExitTimeout = 10 * time.Second + conf.AdditionalErrorMessageOnDeny = "Proxy denied" + + conf.RejectResponseHandler = func(response *http.Response) { + fmt.Println("RejectResponseHandler") + } + err := conf.Validate() + require.NoError(t, err) + + }) +} + func findCanonicalProxyDecision(logs []*logrus.Entry) *logrus.Entry { for _, entry := range logs { if entry.Message == CanonicalProxyDecision {