From 22d49c0a0569f5834b31c35c96169c29b5ed5c0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mi=C5=82osz=20Szekiel?= <12242002+mszekiel@users.noreply.github.com> Date: Thu, 14 Sep 2023 10:06:40 +0200 Subject: [PATCH] feat: improve geolocation handling for middleware --- httpx/client_info.go | 22 ++++++++++++++++++++++ httpx/client_info_test.go | 32 ++++++++++++++++++++++++++++++++ otelx/semconv/context.go | 14 +++++--------- otelx/semconv/context_test.go | 17 +++++++++++++---- otelx/semconv/events.go | 28 ++++++++++++++++++++++------ 5 files changed, 94 insertions(+), 19 deletions(-) diff --git a/httpx/client_info.go b/httpx/client_info.go index 8d6c3a98..90a9866f 100644 --- a/httpx/client_info.go +++ b/httpx/client_info.go @@ -9,6 +9,12 @@ import ( "strings" ) +type GeoLocation struct { + City string + Region string + Country string +} + func GetClientIPAddressesWithoutInternalIPs(ipAddresses []string) (string, error) { var res string @@ -36,3 +42,19 @@ func ClientIP(r *http.Request) string { return r.RemoteAddr } } + +func ClientGeoLocation(r *http.Request) GeoLocation { + var clientGeoLocation GeoLocation + + if r.Header.Get("Cf-Ipcity") != "" { + clientGeoLocation.City = r.Header.Get("Cf-Ipcity") + } + if r.Header.Get("Cf-Region-Code") != "" { + clientGeoLocation.Region = r.Header.Get("Cf-Region-Code") + } + if r.Header.Get("Cf-Ipcountry") != "" { + clientGeoLocation.Country = r.Header.Get("Cf-Ipcountry") + } + + return clientGeoLocation +} diff --git a/httpx/client_info_test.go b/httpx/client_info_test.go index f3682f4d..d44d7ba6 100644 --- a/httpx/client_info_test.go +++ b/httpx/client_info_test.go @@ -58,3 +58,35 @@ func TestClientIP(t *testing.T) { assert.Equal(t, "1.0.0.4", ClientIP(req)) }) } + +func TestClientGeoLocation(t *testing.T) { + req := http.Request{ + Header: http.Header{}, + } + req.Header.Add("cf-ipcity", "Berlin") + req.Header.Add("cf-ipcountry", "Germany") + req.Header.Add("cf-region-code", "BE") + + t.Run("cf-ipcity", func(t *testing.T) { + req := req.Clone(context.Background()) + assert.Equal(t, "Berlin", ClientGeoLocation(req).City) + }) + + t.Run("cf-ipcountry", func(t *testing.T) { + req := req.Clone(context.Background()) + assert.Equal(t, "Germany", ClientGeoLocation(req).Country) + }) + + t.Run("cf-region-code", func(t *testing.T) { + req := req.Clone(context.Background()) + assert.Equal(t, "BE", ClientGeoLocation(req).Region) + }) + + t.Run("empty", func(t *testing.T) { + req := req.Clone(context.Background()) + req.Header.Del("cf-ipcity") + req.Header.Del("cf-ipcountry") + req.Header.Del("cf-region-code") + assert.Equal(t, GeoLocation{}, ClientGeoLocation(req)) + }) +} diff --git a/otelx/semconv/context.go b/otelx/semconv/context.go index eef390e6..f9ea4e06 100644 --- a/otelx/semconv/context.go +++ b/otelx/semconv/context.go @@ -36,16 +36,12 @@ func AttributesFromContext(ctx context.Context) []attribute.KeyValue { } func Middleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - var clientGeoLocation []string - if r.Header.Get("Cf-Ipcity") != "" { - clientGeoLocation = append(clientGeoLocation, r.Header.Get("Cf-Ipcity")) - } - if r.Header.Get("Cf-Ipcountry") != "" { - clientGeoLocation = append(clientGeoLocation, r.Header.Get("Cf-Ipcountry")) - } + ctx := ContextWithAttributes(r.Context(), - AttrClientIP(httpx.ClientIP(r)), - AttrGeoLocation(clientGeoLocation), + append( + AttrGeoLocation(httpx.ClientGeoLocation(r)), + AttrClientIP(httpx.ClientIP(r)), + )..., ) next(rw, r.WithContext(ctx)) diff --git a/otelx/semconv/context_test.go b/otelx/semconv/context_test.go index 979522b6..0484609f 100644 --- a/otelx/semconv/context_test.go +++ b/otelx/semconv/context_test.go @@ -10,6 +10,8 @@ import ( "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" "go.opentelemetry.io/otel/attribute" + + "github.com/ory/x/httpx" ) func TestAttributesFromContext(t *testing.T) { @@ -21,13 +23,20 @@ func TestAttributesFromContext(t *testing.T) { assert.Len(t, AttributesFromContext(ctx), 1) uid1, uid2 := uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4()) - ctx = ContextWithAttributes(ctx, AttrIdentityID(uid1), AttrClientIP("127.0.0.1"), AttrIdentityID(uid2), AttrGeoLocation([]string{"Berlin", "Germany"})) + location := httpx.GeoLocation{ + City: "Berlin", + Country: "Germany", + Region: "BE", + } + ctx = ContextWithAttributes(ctx, append(AttrGeoLocation(location), AttrIdentityID(uid1), AttrClientIP("127.0.0.1"), AttrIdentityID(uid2))...) attrs := AttributesFromContext(ctx) - assert.Len(t, attrs, 4, "should deduplicate") - assert.Equal(t, []attribute.KeyValue{ + assert.Len(t, attrs, 6, "should deduplicate") + assert.EqualValues(t, []attribute.KeyValue{ attribute.String(AttributeKeyNID.String(), nid.String()), + attribute.String(AttributeKeyGeoLocationCity.String(), "Berlin"), + attribute.String(AttributeKeyGeoLocationCountry.String(), "Germany"), + attribute.String(AttributeKeyGeoLocationRegion.String(), "BE"), attribute.String(AttributeKeyClientIP.String(), "127.0.0.1"), attribute.String(AttributeKeyIdentityID.String(), uid2.String()), - attribute.StringSlice(AttributeKeyGeoLocation.String(), []string{"Berlin", "Germany"}), }, attrs, "last duplicate attribute wins") } diff --git a/otelx/semconv/events.go b/otelx/semconv/events.go index 0e597164..215d37a7 100644 --- a/otelx/semconv/events.go +++ b/otelx/semconv/events.go @@ -7,6 +7,8 @@ package semconv import ( "github.com/gofrs/uuid" otelattr "go.opentelemetry.io/otel/attribute" + + "github.com/ory/x/httpx" ) type Event string @@ -22,10 +24,12 @@ func (a AttributeKey) String() string { } const ( - AttributeKeyIdentityID AttributeKey = "IdentityID" - AttributeKeyNID AttributeKey = "ProjectID" - AttributeKeyClientIP AttributeKey = "ClientIP" - AttributeKeyGeoLocation AttributeKey = "GeoLocation" + AttributeKeyIdentityID AttributeKey = "IdentityID" + AttributeKeyNID AttributeKey = "ProjectID" + AttributeKeyClientIP AttributeKey = "ClientIP" + AttributeKeyGeoLocationCity AttributeKey = "GeoLocationCity" + AttributeKeyGeoLocationRegion AttributeKey = "GeoLocationRegion" + AttributeKeyGeoLocationCountry AttributeKey = "GeoLocationCountry" ) func AttrIdentityID(val uuid.UUID) otelattr.KeyValue { @@ -40,6 +44,18 @@ func AttrClientIP(val string) otelattr.KeyValue { return otelattr.String(AttributeKeyClientIP.String(), val) } -func AttrGeoLocation(val []string) otelattr.KeyValue { - return otelattr.StringSlice(AttributeKeyGeoLocation.String(), val) +func AttrGeoLocation(val httpx.GeoLocation) []otelattr.KeyValue { + var geoLocationAttributes []otelattr.KeyValue + + if val.City != "" { + geoLocationAttributes = append(geoLocationAttributes, otelattr.String(AttributeKeyGeoLocationCity.String(), val.City)) + } + if val.Country != "" { + geoLocationAttributes = append(geoLocationAttributes, otelattr.String(AttributeKeyGeoLocationCountry.String(), val.Country)) + } + if val.Region != "" { + geoLocationAttributes = append(geoLocationAttributes, otelattr.String(AttributeKeyGeoLocationRegion.String(), val.Region)) + } + + return geoLocationAttributes }