Skip to content

Commit

Permalink
Export SmokescreenContext type (#200)
Browse files Browse the repository at this point in the history
* export SmokescreenContext type

* also export AclDecision

* ResolvedAddr too

* consistent caps

* Update pkg/smokescreen/smokescreen.go

Co-authored-by: jjiang-stripe <55402658+jjiang-stripe@users.noreply.github.com>

* export Decision

---------

Co-authored-by: jjiang-stripe <55402658+jjiang-stripe@users.noreply.github.com>
  • Loading branch information
cmoresco-stripe and jjiang-stripe authored Aug 4, 2023
1 parent 48069eb commit c86310d
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 44 deletions.
2 changes: 1 addition & 1 deletion pkg/smokescreen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ type Config struct {
RejectResponseHandler func(*http.Response)

// Custom handler to allow clients to modify successful CONNECT responses
AcceptResponseHandler func(*smokescreenContext, *http.Response) error
AcceptResponseHandler func(*SmokescreenContext, *http.Response) error

// UnsafeAllowPrivateRanges inverts the default behavior, telling smokescreen to allow private IP
// ranges by default (exempting loopback and unicast ranges)
Expand Down
84 changes: 42 additions & 42 deletions pkg/smokescreen/smokescreen.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,17 @@ const (

type ipType int

type aclDecision struct {
type ACLDecision struct {
reason, role, project, outboundHost string
resolvedAddr *net.TCPAddr
ResolvedAddr *net.TCPAddr
allow bool
enforceWouldDeny bool
}

type smokescreenContext struct {
type SmokescreenContext struct {
cfg *Config
start time.Time
decision *aclDecision
Decision *ACLDecision
proxyType string
logger *logrus.Entry
requestedHost string
Expand Down Expand Up @@ -246,17 +246,17 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, fmt.Errorf("dialContext missing required *goproxy.ProxyCtx")
}

sctx, ok := pctx.UserData.(*smokescreenContext)
sctx, ok := pctx.UserData.(*SmokescreenContext)
if !ok {
return nil, fmt.Errorf("dialContext missing required *smokescreenContext")
return nil, fmt.Errorf("dialContext missing required *SmokescreenContext")
}
d := sctx.decision
d := sctx.Decision

// If an address hasn't been resolved, does not match the original outboundHost,
// or is not tcp we must re-resolve it before establishing the connection.
if d.resolvedAddr == nil || d.outboundHost != addr || network != "tcp" {
if d.ResolvedAddr == nil || d.outboundHost != addr || network != "tcp" {
var err error
d.resolvedAddr, d.reason, err = safeResolve(sctx.cfg, network, addr)
d.ResolvedAddr, d.reason, err = safeResolve(sctx.cfg, network, addr)
if err != nil {
if _, ok := err.(denyError); ok {
sctx.cfg.Log.WithFields(
Expand All @@ -279,9 +279,9 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {

start := time.Now()
if sctx.cfg.ProxyDialTimeout == nil {
conn, err = net.DialTimeout(network, d.resolvedAddr.String(), sctx.cfg.ConnectTimeout)
conn, err = net.DialTimeout(network, d.ResolvedAddr.String(), sctx.cfg.ConnectTimeout)
} else {
conn, err = sctx.cfg.ProxyDialTimeout(ctx, network, d.resolvedAddr.String(), sctx.cfg.ConnectTimeout)
conn, err = sctx.cfg.ProxyDialTimeout(ctx, network, d.ResolvedAddr.String(), sctx.cfg.ConnectTimeout)
}
connTime := time.Since(start)

Expand Down Expand Up @@ -332,7 +332,7 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
// HTTPErrorHandler allows returning a custom error response when smokescreen
// fails to connect to the proxy target.
func HTTPErrorHandler(w io.WriteCloser, pctx *goproxy.ProxyCtx, err error) {
sctx := pctx.UserData.(*smokescreenContext)
sctx := pctx.UserData.(*SmokescreenContext)
resp := rejectResponse(pctx, err)

if err := resp.Write(w); err != nil {
Expand All @@ -345,7 +345,7 @@ func HTTPErrorHandler(w io.WriteCloser, pctx *goproxy.ProxyCtx, err error) {
}

func rejectResponse(pctx *goproxy.ProxyCtx, err error) *http.Response {
sctx := pctx.UserData.(*smokescreenContext)
sctx := pctx.UserData.(*SmokescreenContext)

var msg, status string
var code int
Expand Down Expand Up @@ -411,7 +411,7 @@ func configureTransport(tr *http.Transport, cfg *Config) {
}
}

func newContext(cfg *Config, proxyType string, req *http.Request) *smokescreenContext {
func newContext(cfg *Config, proxyType string, req *http.Request) *SmokescreenContext {
start := time.Now()

logger := cfg.Log.WithFields(logrus.Fields{
Expand All @@ -423,7 +423,7 @@ func newContext(cfg *Config, proxyType string, req *http.Request) *smokescreenCo
LogFieldTraceID: req.Header.Get(traceHeader),
})

return &smokescreenContext{
return &SmokescreenContext{
cfg: cfg,
logger: logger,
proxyType: proxyType,
Expand Down Expand Up @@ -462,7 +462,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
// proxy requests we are able to specify the request during the call to OnResponse().
sctx := newContext(config, httpProxy, req)

// Attach smokescreenContext to goproxy.ProxyCtx
// Attach SmokescreenContext to goproxy.ProxyCtx
pctx.UserData = sctx

// Delete Smokescreen specific headers before goproxy forwards the request
Expand All @@ -482,16 +482,16 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
}

sctx.logger.WithField("url", req.RequestURI).Debug("received HTTP proxy request")
sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, req, destination)
sctx.Decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, req, destination)

// Returning any kind of response in this handler is goproxy's way of short circuiting
// the request. The original request will never be sent, and goproxy will invoke our
// response filter attached via the OnResponse() handler.
if pctx.Error != nil {
return req, rejectResponse(pctx, pctx.Error)
}
if !sctx.decision.allow {
return req, rejectResponse(pctx, denyError{errors.New(sctx.decision.reason)})
if !sctx.Decision.allow {
return req, rejectResponse(pctx, denyError{errors.New(sctx.Decision.reason)})
}

// Call the custom request handler if it exists
Expand Down Expand Up @@ -539,9 +539,9 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
// function will be called again with the previously returned response, which will
// simply trigger the logHTTP function and return.
proxy.OnResponse().DoFunc(func(resp *http.Response, pctx *goproxy.ProxyCtx) *http.Response {
sctx := pctx.UserData.(*smokescreenContext)
sctx := pctx.UserData.(*SmokescreenContext)

if resp != nil && pctx.Error == nil && sctx.decision.allow {
if resp != nil && pctx.Error == nil && sctx.Decision.allow {
if resp.Header.Get(errorHeader) != "" {
resp.Header.Del(errorHeader)
}
Expand All @@ -564,9 +564,9 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
// The goproxy OnResponse() function above is only called for non-https responses.
if config.AcceptResponseHandler != nil {
proxy.ConnectRespHandler = func(pctx *goproxy.ProxyCtx, resp *http.Response) error {
sctx, ok := pctx.UserData.(*smokescreenContext)
sctx, ok := pctx.UserData.(*SmokescreenContext)
if !ok {
return fmt.Errorf("goproxy ProxyContext missing required UserData *smokescreenContext")
return fmt.Errorf("goproxy ProxyContext missing required UserData *SmokescreenContext")
}
return config.AcceptResponseHandler(sctx, resp)
}
Expand All @@ -576,7 +576,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
}

func logProxy(config *Config, pctx *goproxy.ProxyCtx) {
sctx := pctx.UserData.(*smokescreenContext)
sctx := pctx.UserData.(*SmokescreenContext)

fields := logrus.Fields{}

Expand All @@ -589,8 +589,8 @@ func logProxy(config *Config, pctx *goproxy.ProxyCtx) {
}
}

decision := sctx.decision
if sctx.decision != nil {
decision := sctx.Decision
if sctx.Decision != nil {
fields[LogFieldRole] = decision.role
fields[LogFieldProject] = decision.project
}
Expand All @@ -609,7 +609,7 @@ func logProxy(config *Config, pctx *goproxy.ProxyCtx) {
fields[LogFieldContentLength] = pctx.Resp.ContentLength
}

if sctx.decision != nil {
if sctx.Decision != nil {
fields[LogFieldDecisionReason] = decision.reason
fields[LogFieldEnforceWouldDeny] = decision.enforceWouldDeny
fields[LogFieldAllow] = decision.allow
Expand All @@ -633,7 +633,7 @@ func logProxy(config *Config, pctx *goproxy.ProxyCtx) {
}

func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) {
sctx := pctx.UserData.(*smokescreenContext)
sctx := pctx.UserData.(*SmokescreenContext)

// Check if requesting role is allowed to talk to remote
destination, err := hostport.New(pctx.Req.Host, false)
Expand All @@ -644,13 +644,13 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) {

// checkIfRequestShouldBeProxied can return an error if either the resolved address is disallowed,
// or if there is a DNS resolution failure.
sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, pctx.Req, destination)
sctx.Decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, pctx.Req, destination)
if pctx.Error != nil {
// DNS resolution failure
return "", pctx.Error
}
if !sctx.decision.allow {
return "", denyError{errors.New(sctx.decision.reason)}
if !sctx.Decision.allow {
return "", denyError{errors.New(sctx.Decision.reason)}
}

// Call the custom request handler if it exists
Expand Down Expand Up @@ -881,7 +881,7 @@ func getRole(config *Config, req *http.Request) (string, error) {
}
}

func checkIfRequestShouldBeProxied(config *Config, req *http.Request, destination hostport.HostPort) (*aclDecision, time.Duration, error) {
func checkIfRequestShouldBeProxied(config *Config, req *http.Request, destination hostport.HostPort) (*ACLDecision, time.Duration, error) {
decision := checkACLsForRequest(config, req, destination)

var lookupTime time.Duration
Expand All @@ -898,15 +898,15 @@ func checkIfRequestShouldBeProxied(config *Config, req *http.Request, destinatio
decision.allow = false
decision.enforceWouldDeny = true
} else {
decision.resolvedAddr = resolved
decision.ResolvedAddr = resolved
}
}

return decision, lookupTime, nil
}

func checkACLsForRequest(config *Config, req *http.Request, destination hostport.HostPort) *aclDecision {
decision := &aclDecision{
func checkACLsForRequest(config *Config, req *http.Request, destination hostport.HostPort) *ACLDecision {
decision := &ACLDecision{
outboundHost: destination.String(),
}

Expand All @@ -932,9 +932,9 @@ func checkACLsForRequest(config *Config, req *http.Request, destination hostport
return decision
}

aclDecision, err := config.EgressACL.Decide(role, destination.Host)
decision.project = aclDecision.Project
decision.reason = aclDecision.Reason
ACLDecision, err := config.EgressACL.Decide(role, destination.Host)
decision.project = ACLDecision.Project
decision.reason = ACLDecision.Reason
if err != nil {
config.Log.WithFields(logrus.Fields{
"error": err,
Expand All @@ -947,11 +947,11 @@ func checkACLsForRequest(config *Config, req *http.Request, destination hostport

tags := map[string]string{
"role": decision.role,
"def_rule": fmt.Sprintf("%t", aclDecision.Default),
"project": aclDecision.Project,
"def_rule": fmt.Sprintf("%t", ACLDecision.Default),
"project": ACLDecision.Project,
}

switch aclDecision.Result {
switch ACLDecision.Result {
case acl.Deny:
decision.enforceWouldDeny = true
config.MetricsClient.IncrWithTags("acl.deny", tags, 1)
Expand All @@ -970,7 +970,7 @@ func checkACLsForRequest(config *Config, req *http.Request, destination hostport
config.Log.WithFields(logrus.Fields{
"role": role,
"destination": destination.Host,
"action": aclDecision.Result.String(),
"action": ACLDecision.Result.String(),
}).Warn("Unknown ACL action")
decision.reason = "Internal error"
config.MetricsClient.IncrWithTags("acl.unknown_error", tags, 1)
Expand Down
2 changes: 1 addition & 1 deletion pkg/smokescreen/smokescreen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ func TestAcceptResponseHandler(t *testing.T) {
cfg, err := testConfig("test-local-srv")

// set a custom AcceptResponseHandler that will set a header on every reject response
cfg.AcceptResponseHandler = func(_ *smokescreenContext, resp *http.Response) error {
cfg.AcceptResponseHandler = func(_ *SmokescreenContext, resp *http.Response) error {
resp.Header.Set(testHeader, "This header is added by the AcceptResponseHandler")
return nil
}
Expand Down

0 comments on commit c86310d

Please sign in to comment.