Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(header forwarding): add HeaderForwarder and plumbing #507

Merged
merged 9 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ GO_MOD_PACKAGES=./types/...
GO_FOLDERS=$(shell echo ${GO_PACKAGES} | sed -e "s/\.\///g" | sed -e "s/\/\.\.\.//g")
GO_MOD_FOLDERS=$(shell echo ${GO_MOD_PACKAGES} | sed -e "s/\.\///g" | sed -e "s/\/\.\.\.//g")
TEST_SCRIPT=go test ${GO_PACKAGES}
LINT_SETTINGS=golint,misspell,gocyclo,gocritic,whitespace,goconst,gocognit,bodyclose,unconvert,lll,unparam
LINT_SETTINGS=misspell,gocyclo,gocritic,whitespace,goconst,gocognit,bodyclose,unconvert,lll,unparam

build:
go build ./...

deps:
go get ./...

Expand Down
2 changes: 1 addition & 1 deletion asserter/asserter.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ func NewGenericRosettaClient(
ignoreRosettaSpecValidation: true,
}

//init default operation statuses for generic rosetta client
// init default operation statuses for generic rosetta client
InitOperationStatus(asserter)

return asserter, nil
Expand Down
2 changes: 1 addition & 1 deletion constructor/worker/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1848,7 +1848,7 @@ func TestHTTPRequestWorker(t *testing.T) {

w.Header().Set("Content-Type", test.contentType)
w.WriteHeader(test.statusCode)
fmt.Fprintf(w, test.response)
fmt.Fprint(w, test.response)
}))

defer ts.Close()
Expand Down
2 changes: 2 additions & 0 deletions examples/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ func NewBlockchainRouter(
networkAPIController := server.NewNetworkAPIController(
networkAPIService,
asserter,
nil,
)

blockAPIService := services.NewBlockAPIService(network)
blockAPIController := server.NewBlockAPIController(
blockAPIService,
asserter,
nil,
)

return server.NewRouter(networkAPIController, blockAPIController)
Expand Down
2 changes: 1 addition & 1 deletion fetcher/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func (f *Fetcher) UnsafeBlock(
}

// Exit early if no need to fetch txs
if blockResponse.OtherTransactions == nil || len(blockResponse.OtherTransactions) == 0 {
if len(blockResponse.OtherTransactions) == 0 {
return blockResponse.Block, nil
}

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/dgraph-io/badger/v2 v2.2007.4
github.com/ethereum/go-ethereum v1.10.21
github.com/fatih/color v1.13.0
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.0
github.com/lucasjones/reggen v0.0.0-20180717132126-cdb49ff09d77
github.com/neilotoole/errgroup v0.1.6
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
Expand Down
43 changes: 43 additions & 0 deletions headerforwarder/context_headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2024 Coinbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package headerforwarder

import (
"context"
"net/http"

"github.com/google/uuid"
)

type contextKey string

const requestIDKey = contextKey("request_id")

func ContextWithRosettaID(ctx context.Context) context.Context {
return context.WithValue(ctx, requestIDKey, uuid.NewString())
}

func RosettaIDFromContext(ctx context.Context) string {
return ctx.Value(requestIDKey).(string)
}

func RosettaIDFromRequest(r *http.Request) string {
switch value := r.Context().Value(requestIDKey).(type) {
case string:
return value
default:
return ""
}
}
142 changes: 142 additions & 0 deletions headerforwarder/forwarder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright 2024 Coinbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package headerforwarder

import (
"net/http"
)

// HeaderExtractingTransport is a utility to help a rosetta server forward headers to and from
// native node requests. It implements several interfaces to achieve that:
// - http.RoundTripper: this can be used to create an http Client that will automatically save headers
// if necessary
// - func(http.Handler) http.Handler: this can be used to wrap an http.Handler to set headers
// on the response
//
// the headers can be requested later.
//
// TODO: this should expire entries after a certain amount of time
type HeaderForwarder struct {
requestHeaders map[string]http.Header
interestingHeaders []string
actualTransport http.RoundTripper
}

func NewHeaderForwarder(interestingHeaders []string, transport http.RoundTripper) *HeaderForwarder {
return &HeaderForwarder{
requestHeaders: make(map[string]http.Header),
interestingHeaders: interestingHeaders,
actualTransport: transport,
}
}

// RoundTrip implements http.RoundTripper and will be used to construct an http Client which
// saves the native node response headers if necessary.
func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := hf.actualTransport.RoundTrip(req)

if err == nil && hf.shouldRememberHeaders(req, resp) {
hf.rememberHeaders(req, resp)
}

return resp, err
}

// shouldRememberHeaders is called to determine if response headers should be remembered for a
// given request. Response headers will only be remembered if the request does not contain all of
// the interesting headers and the response contains at least one of the interesting headers.
//
// It should be noted that the request and response here are for a request to the native node,
// not a request to the Rosetta server.
func (hf *HeaderForwarder) shouldRememberHeaders(req *http.Request, resp *http.Response) bool {
requestHasAllHeaders := true
responseHasSomeHeaders := false

for _, interestingHeader := range hf.interestingHeaders {
_, requestHasHeader := req.Header[http.CanonicalHeaderKey(interestingHeader)]
_, responseHasHeader := resp.Header[http.CanonicalHeaderKey(interestingHeader)]

if !requestHasHeader {
requestHasAllHeaders = false
}

if responseHasHeader {
responseHasSomeHeaders = true
}
}

// only remember headers if the request does not contain all of the interesting headers and the
// response contains at least one
return !requestHasAllHeaders && responseHasSomeHeaders
}

// rememberHeaders is called to save the native node response headers. The request object
// here is a native node request (constructed by go-ethereum for geth-based rosetta implementations).
// The response object is a native node response.
func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Response) {
ctx := req.Context()
// rosettaRequestID := services.osettaIdFromContext(ctx)
potterbm-cb marked this conversation as resolved.
Show resolved Hide resolved
rosettaRequestID := RosettaIDFromContext(ctx)

// Only remember interesting headers
headersToRemember := make(http.Header)
for _, interestingHeader := range hf.interestingHeaders {
headersToRemember.Set(interestingHeader, resp.Header.Get(interestingHeader))
}

hf.requestHeaders[rosettaRequestID] = headersToRemember
}

// GetResponseHeaders returns any native node response headers that were recorded for a request ID.
func (hf *HeaderForwarder) getResponseHeaders(rosettaRequestID string) (http.Header, bool) {
headers, ok := hf.requestHeaders[rosettaRequestID]

// Delete the headers from the map after they are retrieved
// This is safe to call even if the key doesn't exist
delete(hf.requestHeaders, rosettaRequestID)

return headers, ok
}

// HeaderForwarderHandler will allow the next handler to serve the request, and then checks
// if there are any native node response headers recorded for the request. If there are, it will set
// those headers on the response
func (hf *HeaderForwarder) HeaderForwarderHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// add a unique ID to the request context, and make a new request for it
requestWithID := hf.WithRequestID(r)

// Serve the request
// NOTE: ResponseWriter::WriteHeader() WILL be called here, so we can't set headers after this happens
// We include a wrapper around the response writer that allows us to set headers just before
// WriteHeader is called
wrappedResponseWriter := NewResponseWriter(
w,
RosettaIDFromRequest(requestWithID),
hf.getResponseHeaders,
)
next.ServeHTTP(wrappedResponseWriter, requestWithID)
})
}

// WithRequestID adds a unique ID to the request context. A new request is returned that contains the
// new context
func (hf *HeaderForwarder) WithRequestID(req *http.Request) *http.Request {
ctx := req.Context()
ctxWithID := ContextWithRosettaID(ctx)
requestWithID := req.WithContext(ctxWithID)

return requestWithID
}
68 changes: 68 additions & 0 deletions headerforwarder/response_writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2024 Coinbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package headerforwarder

import (
"net/http"
)

// ResponseWriter is a wrapper around a http.ResponseWriter that allows us to set headers
// just before the WriteHeader function is called. These headers will be extracted from native node
// responses, and set on the rosetta response.
type ResponseWriter struct {
writer http.ResponseWriter
RosettaRequestID string
GetAdditionalHeaders func(string) (http.Header, bool)
}

func NewResponseWriter(
writer http.ResponseWriter,
rosettaRequestID string,
getAdditionalHeaders func(string) (http.Header, bool),
) *ResponseWriter {
return &ResponseWriter{
writer: writer,
RosettaRequestID: rosettaRequestID,
GetAdditionalHeaders: getAdditionalHeaders,
}
}

// Header passes through to the underlying ResponseWriter instance
func (hfrw *ResponseWriter) Header() http.Header {
return hfrw.writer.Header()
}

// Write passes through to the underlying ResponseWriter instance
func (hfrw *ResponseWriter) Write(b []byte) (int, error) {
return hfrw.writer.Write(b)
}

// WriteHeader will add any final extracted headers, and then pass through to the underlying ResponseWriter instance
func (hfrw *ResponseWriter) WriteHeader(statusCode int) {
hfrw.AddExtractedHeaders()
hfrw.writer.WriteHeader(statusCode)
}

func (hfrw *ResponseWriter) AddExtractedHeaders() {
headers, hasAdditionalHeaders := hfrw.GetAdditionalHeaders(hfrw.RosettaRequestID)

if hasAdditionalHeaders {
for key, values := range headers {
for _, value := range values {
hfrw.writer.Header().Add(key, value)
}
}
}
}
26 changes: 20 additions & 6 deletions server/api_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package server

import (
"context"
"encoding/json"
"net/http"
"strings"
Expand All @@ -28,18 +29,21 @@ import (
// A AccountAPIController binds http requests to an api service and writes the service results to
// the http response
type AccountAPIController struct {
service AccountAPIServicer
asserter *asserter.Asserter
service AccountAPIServicer
asserter *asserter.Asserter
contextFromRequest func(*http.Request) context.Context
}

// NewAccountAPIController creates a default api controller
func NewAccountAPIController(
s AccountAPIServicer,
asserter *asserter.Asserter,
contextFromRequest func(*http.Request) context.Context,
) Router {
return &AccountAPIController{
service: s,
asserter: asserter,
service: s,
asserter: asserter,
contextFromRequest: contextFromRequest,
}
}

Expand All @@ -61,6 +65,16 @@ func (c *AccountAPIController) Routes() Routes {
}
}

func (c *AccountAPIController) ContextFromRequest(r *http.Request) context.Context {
ctx := r.Context()

if c.contextFromRequest != nil {
ctx = c.contextFromRequest(r)
}

return ctx
}

// AccountBalance - Get an Account's Balance
func (c *AccountAPIController) AccountBalance(w http.ResponseWriter, r *http.Request) {
accountBalanceRequest := &types.AccountBalanceRequest{}
Expand All @@ -81,7 +95,7 @@ func (c *AccountAPIController) AccountBalance(w http.ResponseWriter, r *http.Req
return
}

result, serviceErr := c.service.AccountBalance(r.Context(), accountBalanceRequest)
result, serviceErr := c.service.AccountBalance(c.ContextFromRequest(r), accountBalanceRequest)
if serviceErr != nil {
EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w)

Expand Down Expand Up @@ -111,7 +125,7 @@ func (c *AccountAPIController) AccountCoins(w http.ResponseWriter, r *http.Reque
return
}

result, serviceErr := c.service.AccountCoins(r.Context(), accountCoinsRequest)
result, serviceErr := c.service.AccountCoins(c.ContextFromRequest(r), accountCoinsRequest)
if serviceErr != nil {
EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w)

Expand Down
Loading
Loading