Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

providers: group caching mechanism for okta provider #184

Merged
merged 2 commits into from
May 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
github.com/sirupsen/logrus v1.3.0
golang.org/x/net v0.0.0-20190311183353-d8887717615a // indirect
golang.org/x/oauth2 v0.0.0-20190130055435-99b60b757ec1
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4
google.golang.org/api v0.1.0
gopkg.in/yaml.v2 v2.2.2
)
17 changes: 8 additions & 9 deletions internal/auth/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import (
// CookieHTTPOnly - bool - set httponly cookie flag
// RequestTimeout - duration - overall request timeout
// AuthCodeSecret - string - the seed string for secure auth codes (optionally base64 encoded)
// GroupCacheProviderTTL - time.Duration - cache TTL for the group-cache provider used for on-demand group caching
// GroupsCacheRefreshTTL - time.Duratoin - cache TTL for the groups fillcache mechanism used to preemptively fill group caches
// PassHostHeader - bool - pass the request Host Header to upstream (default true)
// SkipProviderButton - bool - if true, will skip sign-in-page to directly reach the next step: oauth/start
// PassUserHeaders - bool (default true) - pass X-Forwarded-User and X-Forwarded-Email information to upstream
Expand Down Expand Up @@ -89,6 +91,7 @@ type Options struct {

AuthCodeSecret string `envconfig:"AUTH_CODE_SECRET"`

GroupCacheProviderTTL time.Duration `envconfig:"GROUP_CACHE_PROVIDER_TTL" default:"10m"`
GroupsCacheRefreshTTL time.Duration `envconfig:"GROUPS_CACHE_REFRESH_TTL" default:"10m"`
SessionLifetimeTTL time.Duration `envconfig:"SESSION_LIFETIME_TTL" default:"720h"`

Expand Down Expand Up @@ -301,7 +304,10 @@ func newProvider(o *Options) (providers.Provider, error) {
if err != nil {
return nil, err
}
singleFlightProvider = providers.NewSingleFlightProvider(oktaProvider)
tags := []string{"provider:okta"}

groupsCache := providers.NewGroupCache(oktaProvider, o.GroupCacheProviderTTL, oktaProvider.StatsdClient, tags)
singleFlightProvider = providers.NewSingleFlightProvider(groupsCache)
default:
return nil, fmt.Errorf("unimplemented provider: %q", o.Provider)
}
Expand Down Expand Up @@ -334,14 +340,7 @@ func AssignStatsdClient(opts *Options) func(*Authenticator) error {
"statsd client is running")

proxy.StatsdClient = StatsdClient
switch v := proxy.provider.(type) {
case *providers.GoogleProvider:
v.SetStatsdClient(StatsdClient)
case *providers.SingleFlightProvider:
v.AssignStatsdClient(StatsdClient)
default:
logger.Info("provider does not have statsd client")
}
proxy.provider.SetStatsdClient(StatsdClient)
return nil
}
}
125 changes: 125 additions & 0 deletions internal/auth/providers/group_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package providers

import (
"sort"
"strings"
"time"

"github.com/buzzfeed/sso/internal/pkg/groups"
"github.com/buzzfeed/sso/internal/pkg/sessions"
"github.com/datadog/datadog-go/statsd"
)

var (
// This is a compile-time check to make sure our types correctly implement the interface:
// https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae
_ Provider = &GroupCache{}
)

type Cache interface {
Get(key groups.CacheKey) (groups.CacheEntry, bool)
Set(key groups.CacheKey, val groups.CacheEntry)
Purge(key groups.CacheKey)
}

// GroupCache is designed to act as a provider while wrapping subsequent provider's functions,
// while also offering a caching mechanism (specifically used for group caching at the moment).
type GroupCache struct {
statsdClient *statsd.Client
provider Provider
cache Cache
}

// NewGroupCache returns a new GroupCache (which includes a LocalCache from the groups package)
func NewGroupCache(provider Provider, ttl time.Duration, statsdClient *statsd.Client, tags []string) *GroupCache {
return &GroupCache{
statsdClient: statsdClient,
provider: provider,
cache: groups.NewLocalCache(ttl, statsdClient, tags),
}
}

// SetStatsdClient calls the provider's SetStatsdClient function.
func (p *GroupCache) SetStatsdClient(statsdClient *statsd.Client) {
p.statsdClient = statsdClient
p.provider.SetStatsdClient(statsdClient)
}

// Data returns the provider Data
func (p *GroupCache) Data() *ProviderData {
return p.provider.Data()
}

// Redeem wraps the provider's Redeem function
func (p *GroupCache) Redeem(redirectURL, code string) (*sessions.SessionState, error) {
return p.provider.Redeem(redirectURL, code)
}

// ValidateSessionState wraps the provider's ValidateSessionState function.
func (p *GroupCache) ValidateSessionState(s *sessions.SessionState) bool {
return p.provider.ValidateSessionState(s)
}

// GetSignInURL wraps the provider's GetSignInURL function.
func (p *GroupCache) GetSignInURL(redirectURI, finalRedirect string) string {
return p.provider.GetSignInURL(redirectURI, finalRedirect)
}

// RefreshSessionIfNeeded wraps the provider's RefreshSessionIfNeeded function.
func (p *GroupCache) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
return p.provider.RefreshSessionIfNeeded(s)
}

// ValidateGroupMembership wraps the provider's ValidateGroupMembership around calls to check local cache for group membership information.
func (p *GroupCache) ValidateGroupMembership(email string, allowedGroups []string, accessToken string) ([]string, error) {
// Create a cache key and check to see if it's in the cache. If not, call the provider's
// ValidateGroupMembership function and cache the result.
sort.Strings(allowedGroups)
key := groups.CacheKey{
Email: email,
AllowedGroups: strings.Join(allowedGroups, ","),
}

val, ok := p.cache.Get(key)
if ok {
p.statsdClient.Incr("provider.groupcache",
[]string{
"action:ValidateGroupMembership",
"cache:hit",
}, 1.0)
return val.ValidGroups, nil
}

// The key isn't in the cache, so pass the call on to the subsequent provider
p.statsdClient.Incr("provider.groupcache",
[]string{
"action:ValidateGroupMembership",
"cache:miss",
}, 1.0)

validGroups, err := p.provider.ValidateGroupMembership(email, allowedGroups, accessToken)
if err != nil {
return nil, err
}

entry := groups.CacheEntry{
ValidGroups: validGroups,
}
p.cache.Set(key, entry)
return validGroups, nil
}

// Revoke wraps the provider's Revoke function.
func (p *GroupCache) Revoke(s *sessions.SessionState) error {
return p.provider.Revoke(s)
}

// RefreshAccessToken wraps the provider's RefreshAccessToken function.
func (p *GroupCache) RefreshAccessToken(refreshToken string) (string, time.Duration, error) {
return p.provider.RefreshAccessToken(refreshToken)
}

// Stop calls the providers stop function.
func (p *GroupCache) Stop() {
p.provider.Stop()
}
164 changes: 164 additions & 0 deletions internal/auth/providers/group_cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package providers

import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"testing"
"time"

"github.com/buzzfeed/sso/internal/pkg/groups"
"github.com/buzzfeed/sso/internal/pkg/testutil"
"github.com/datadog/datadog-go/statsd"
)

func newTestProviderServer(body []byte, code int) (*url.URL, *httptest.Server) {
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(code)
rw.Write(body)
}))
u, _ := url.Parse(s.URL)
return u, s
}

// We define an Okta provider because the cache being tested here is currently
// only used by the Okta provider.
func newTestProvider(providerData *ProviderData, t *testing.T) *OktaProvider {
if providerData == nil {
providerData = &ProviderData{
ProviderName: "",
ClientID: "",
ClientSecret: "",
SignInURL: &url.URL{},
RedeemURL: &url.URL{},
RevokeURL: &url.URL{},
ProfileURL: &url.URL{},
ValidateURL: &url.URL{},
Scope: ""}
}
provider, err := NewOktaProvider(providerData, "test.okta.com", "default")
if err != nil {
t.Fatalf("new okta provider returns unexpected error: %q", err)
}
return provider
}

func TestCachedGroupsAreNotUsed(t *testing.T) {

type serverResp struct {
Groups []string `json:"groups"`
}

//set up the test server
provider := newTestProvider(nil, t)
resp := serverResp{
Groups: []string{"allowedGroup1", "allowedGroup2"},
}
body, err := json.Marshal(resp)
testutil.Equal(t, nil, err)
var server *httptest.Server
provider.ProfileURL, server = newTestProviderServer(body, http.StatusOK)
defer server.Close()

// set up the cache
ttl := time.Millisecond * 10
statsdClient, _ := statsd.New("127.0.0.1:8125")
tags := []string{"tags:test"}
GroupsCache := NewGroupCache(provider, ttl, statsdClient, tags)

// The below cached `MatchedGroups` should not be returned because the list of
// allowed groups we pass in are different to the cached `AllowedGroups`. It should instead
// make a call to the Provider (our test server).
cacheKey := groups.CacheKey{
Email: "email@test.com",
AllowedGroups: "allowedGroup1",
}
cacheData := groups.CacheEntry{
ValidGroups: []string{"allowedGroup1", "allowedGroup2", "allowedGroup3"},
}
GroupsCache.cache.Set(cacheKey, cacheData)

// If the groups stored in the `serverResp` struct are returned, it means the
// cache was skipped because the allowedGroups that we pass in are different to
// those stored in the cache. This is the outcome we expect in this test.
email := "email@test.com"
actualAllowedGroups := []string{"allowedGroup2", "allowedGroup1"}
accessToken := "123456"
actualMatchedGroups, err := GroupsCache.ValidateGroupMembership(email, actualAllowedGroups, accessToken)
if err != nil {
t.Fatalf("unexpected error caused while validating group membership: %q", err)
}
if !reflect.DeepEqual(actualMatchedGroups, actualAllowedGroups) {
t.Logf("expected groups to match: %q", actualAllowedGroups)
t.Logf(" actual groups returned: %q", actualMatchedGroups)
if reflect.DeepEqual(actualMatchedGroups, cacheData.ValidGroups) {
t.Fatalf("It looks like the groups in the cache were returned. In this case, the cache should have been skipped")
}
t.Fatalf("Unexpected groups returned.")
}

// We want to test that the groups returned are *now* cached, so we change the resp
// that the test server will send. If it matches this new response, we know the cache
// was skipped, which we do not expect to happen.
resp = serverResp{
Groups: []string{"allowedGroup1", "allowedGroup2"},
}
body, err = json.Marshal(resp)
testutil.Equal(t, nil, err)
provider.ProfileURL, server = newTestProviderServer(body, http.StatusOK)
defer server.Close()

actualMatchedGroups, err = GroupsCache.ValidateGroupMembership(email, actualAllowedGroups, accessToken)
if err != nil {
t.Fatalf("unexpected error caused while validating group membership: %q", err)
}
if !reflect.DeepEqual(actualMatchedGroups, actualAllowedGroups) {
t.Logf("expected groups to match: %q", actualAllowedGroups)
t.Logf(" actual groups returned: %q", actualMatchedGroups)
if reflect.DeepEqual(actualMatchedGroups, resp.Groups) {
t.Fatalf("It looks like the cache was skipped, and the provider was called. In this case, the cache should have been used")
}
}

}

// maybe we can skip this test, as the above test technically tests for the same thing, but just in a slightly more obfuscated way.
func TestCachedGroupsAreUsed(t *testing.T) {
provider := newTestProvider(nil, t)

// set up the cache
ttl := time.Millisecond * 10
statsdClient, _ := statsd.New("127.0.0.1:8125")
tags := []string{"tags:test"}
GroupsCache := NewGroupCache(provider, ttl, statsdClient, tags)

// In this case, the below `MatchedGroups` should be returned because the list of
// allowed groups are pass in match them.
cacheKey := groups.CacheKey{
Email: "email@test.com",
AllowedGroups: "allowedGroup1,allowedGroup2,allowedGroup3",
}
cacheData := groups.CacheEntry{
ValidGroups: []string{"allowedGroup1", "allowedGroup2", "allowedGroup3"},
}
GroupsCache.cache.Set(cacheKey, cacheData)

// We haven't set up a test server in this case because we pass in a list of allowed groups
// that match the allowed groups in the cache, so the cached matched groups should be used
// If the cache is skipped, an error will probably be returned when it tries to call the
// provider endpoint.
email := "email@test.com"
actualAllowedGroups := []string{"allowedGroup1", "allowedGroup2", "allowedGroup3"}
accessToken := "123456"
actualMatchedGroups, err := GroupsCache.ValidateGroupMembership(email, actualAllowedGroups, accessToken)
if err != nil {
t.Fatalf("unexpected error caused while validating group membership: %q", err)
}
if !reflect.DeepEqual(actualMatchedGroups, actualAllowedGroups) {
t.Logf("expected groups to match: %q", actualAllowedGroups)
t.Logf(" actual groups returned: %q", actualMatchedGroups)
t.Fatalf("unexpected groups returned")
}
}
6 changes: 5 additions & 1 deletion internal/auth/providers/okta.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ func NewOktaProvider(p *ProviderData, OrgURL, providerServerID string) (*OktaPro
return oktaProvider, nil
}

// Sets the providers StatsdClient
func (p *OktaProvider) SetStatsdClient(statsdClient *statsd.Client) {
p.StatsdClient = statsdClient
}

// ValidateSessionState attempts to validate the session state's access token.
func (p *OktaProvider) ValidateSessionState(s *sessions.SessionState) bool {
if s.AccessToken == "" {
Expand Down Expand Up @@ -323,7 +328,6 @@ func (p *OktaProvider) ValidateGroupMembership(email string, allowedGroups []str
if len(allowedGroups) == 0 {
return []string{}, nil
}

userinfo, err := p.GetUserProfile(accessToken)
if err != nil {
return nil, err
Expand Down
6 changes: 6 additions & 0 deletions internal/auth/providers/provider_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@ import (

log "github.com/buzzfeed/sso/internal/pkg/logging"
"github.com/buzzfeed/sso/internal/pkg/sessions"
"github.com/datadog/datadog-go/statsd"
)

// SetStatsdClient fulfills the Provider interface
func (p *ProviderData) SetStatsdClient(*statsd.Client) {
return
}

// Redeem takes in a redirect url and code and calls the redeem url endpoint, returning a session state if a valid
// access token is redeemed.
func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
Expand Down
2 changes: 2 additions & 0 deletions internal/auth/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/buzzfeed/sso/internal/pkg/sessions"
"github.com/datadog/datadog-go/statsd"
)

var (
Expand Down Expand Up @@ -36,6 +37,7 @@ const (

// Provider is an interface exposing functions necessary to authenticate with a given provider.
type Provider interface {
SetStatsdClient(*statsd.Client)
Data() *ProviderData
Redeem(string, string) (*sessions.SessionState, error)
ValidateSessionState(*sessions.SessionState) bool
Expand Down
Loading