Skip to content

Commit

Permalink
fix: roundtripper was ignoring WithRoundTripper option
Browse files Browse the repository at this point in the history
  • Loading branch information
clambin committed Dec 11, 2022
1 parent ae51baa commit b695946
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 32 deletions.
24 changes: 12 additions & 12 deletions httpclient/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@ import (
"github.com/prometheus/client_golang/prometheus"
)

// RoundTripperMetrics contains Prometheus metrics to capture during API calls. Each metric is expected to have two labels:
// roundTripperMetrics contains Prometheus metrics to capture during API calls. Each metric is expected to have two labels:
// the first will contain the application issuing the request. The second will contain the Path of the request.
type RoundTripperMetrics struct {
type roundTripperMetrics struct {
latency *prometheus.SummaryVec // measures latency of an API call
errors *prometheus.CounterVec // measures any errors returned by an API call
cache *prometheus.CounterVec // measures number of times the cache has been consulted
hits *prometheus.CounterVec // measures the number of times the cache was used
}

func newMetrics(namespace, subsystem, application string) *RoundTripperMetrics {
return &RoundTripperMetrics{
func newMetrics(namespace, subsystem, application string) *roundTripperMetrics {
return &roundTripperMetrics{
latency: prometheus.NewSummaryVec(prometheus.SummaryOpts{
Name: prometheus.BuildFQName(namespace, subsystem, "api_latency"),
Help: "latency of Reporter API calls",
Expand All @@ -38,25 +38,25 @@ func newMetrics(namespace, subsystem, application string) *RoundTripperMetrics {
}
}

var _ prometheus.Collector = &RoundTripperMetrics{}
var _ prometheus.Collector = &roundTripperMetrics{}

// Describe implements the prometheus.Collector interface so clients can register RoundTripperMetrics as a whole
func (pm *RoundTripperMetrics) Describe(ch chan<- *prometheus.Desc) {
// Describe implements the prometheus.Collector interface so clients can register roundTripperMetrics as a whole
func (pm *roundTripperMetrics) Describe(ch chan<- *prometheus.Desc) {
pm.latency.Describe(ch)
pm.errors.Describe(ch)
pm.cache.Describe(ch)
pm.hits.Describe(ch)
}

// Collect implements the prometheus.Collector interface so clients can register RoundTripperMetrics as a whole
func (pm *RoundTripperMetrics) Collect(ch chan<- prometheus.Metric) {
// Collect implements the prometheus.Collector interface so clients can register roundTripperMetrics as a whole
func (pm *roundTripperMetrics) Collect(ch chan<- prometheus.Metric) {
pm.latency.Collect(ch)
pm.errors.Collect(ch)
pm.cache.Collect(ch)
pm.hits.Collect(ch)
}

func (pm *RoundTripperMetrics) reportErrors(err error, labelValues ...string) {
func (pm *roundTripperMetrics) reportErrors(err error, labelValues ...string) {
if pm == nil {
return
}
Expand All @@ -68,14 +68,14 @@ func (pm *RoundTripperMetrics) reportErrors(err error, labelValues ...string) {
pm.errors.WithLabelValues(labelValues...).Add(value)
}

func (pm *RoundTripperMetrics) makeLatencyTimer(labelValues ...string) (timer *prometheus.Timer) {
func (pm *roundTripperMetrics) makeLatencyTimer(labelValues ...string) (timer *prometheus.Timer) {
if pm != nil {
timer = prometheus.NewTimer(pm.latency.WithLabelValues(labelValues...))
}
return
}

func (pm *RoundTripperMetrics) reportCache(hit bool, labelValues ...string) {
func (pm *roundTripperMetrics) reportCache(hit bool, labelValues ...string) {
if pm == nil {
return
}
Expand Down
6 changes: 3 additions & 3 deletions httpclient/roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
// It implements the http.RoundTripper interface.
type RoundTripper struct {
roundTripper http.RoundTripper
metrics *RoundTripperMetrics
metrics *roundTripperMetrics
cache *Cache
}

Expand Down Expand Up @@ -41,7 +41,7 @@ func (r *RoundTripper) RoundTrip(request *http.Request) (response *http.Response
path := request.URL.Path
timer := r.metrics.makeLatencyTimer(path, request.Method)

response, err = http.DefaultTransport.RoundTrip(request)
response, err = r.roundTripper.RoundTrip(request)

if timer != nil {
timer.ObserveDuration()
Expand Down Expand Up @@ -102,7 +102,7 @@ func (o WithCache) apply(r *RoundTripper) {
r.cache = NewCache(o.Table, o.DefaultExpiry, o.CleanupInterval)
}

// WithRoundTripper assigns a final RoundTripper to the chain. Use this to control the final HTTP exchange's behaviour
// WithRoundTripper assigns a final RoundTripper of the chain. Use this to control the final HTTP exchange's behaviour
// (e.g. using a proxy to make the HTTP call).
type WithRoundTripper struct {
RoundTripper http.RoundTripper
Expand Down
35 changes: 18 additions & 17 deletions httpclient/roundtripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,34 +70,25 @@ func TestRoundTripper_Collect(t *testing.T) {
r := httpclient.NewRoundTripper(
httpclient.WithCache{},
httpclient.WithRoundTripperMetrics{Application: "foo"},
httpclient.WithRoundTripper{RoundTripper: &stubbedRoundTripper{}},
)
registry := prometheus.NewRegistry()
registry.MustRegister(r)

c := &http.Client{Transport: r}
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.Path != "/" {
http.Error(w, "invalid path", http.StatusNotFound)
return
}
_, _ = w.Write([]byte("Hello"))
http.Error(w, "this is not the server you're looking for", http.StatusNotFound)
}))
defer s.Close()

for i := 0; i < 2; i++ {
req, _ := http.NewRequest(http.MethodGet, s.URL+"/", nil)
resp, err := c.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "Hello", string(body))
_ = resp.Body.Close()
}

req, _ := http.NewRequest(http.MethodGet, s.URL+"/invalid", nil)
req, _ := http.NewRequest(http.MethodGet, s.URL+"/", nil)
resp, err := c.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)

req, _ = http.NewRequest(http.MethodGet, s.URL+"/invalid", nil)
resp, err = c.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)

metrics, err := registry.Gather()
Expand All @@ -117,3 +108,13 @@ func TestRoundTripper_Collect(t *testing.T) {
}
}
}

type stubbedRoundTripper struct{}

func (r *stubbedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
statusCode := http.StatusNotFound
if req.URL.Path == "/" {
statusCode = http.StatusOK
}
return &http.Response{StatusCode: statusCode}, nil
}

0 comments on commit b695946

Please sign in to comment.