Skip to content

Commit

Permalink
Merge commit '616d3f35b52cf5081bc9d3f4abc2b9b78e818370' into 20220221…
Browse files Browse the repository at this point in the history
…_add-tenant-package
  • Loading branch information
simonswine committed Feb 21, 2022
2 parents 19921f8 + 616d3f3 commit 1ad3f0a
Show file tree
Hide file tree
Showing 4 changed files with 454 additions and 0 deletions.
158 changes: 158 additions & 0 deletions tenant/resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package tenant

import (
"context"
"errors"
"net/http"
"strings"

"github.com/weaveworks/common/user"
)

var defaultResolver Resolver = NewSingleResolver()

// WithDefaultResolver updates the resolver used for the package methods.
func WithDefaultResolver(r Resolver) {
defaultResolver = r
}

// TenantID returns exactly a single tenant ID from the context. It should be
// used when a certain endpoint should only support exactly a single
// tenant ID. It returns an error user.ErrNoOrgID if there is no tenant ID
// supplied or user.ErrTooManyOrgIDs if there are multiple tenant IDs present.
//
// ignore stutter warning
//nolint:golint
func TenantID(ctx context.Context) (string, error) {
return defaultResolver.TenantID(ctx)
}

// TenantIDs returns all tenant IDs from the context. It should return
// normalized list of ordered and distinct tenant IDs (as produced by
// NormalizeTenantIDs).
//
// ignore stutter warning
//nolint:golint
func TenantIDs(ctx context.Context) ([]string, error) {
return defaultResolver.TenantIDs(ctx)
}

type Resolver interface {
// TenantID returns exactly a single tenant ID from the context. It should be
// used when a certain endpoint should only support exactly a single
// tenant ID. It returns an error user.ErrNoOrgID if there is no tenant ID
// supplied or user.ErrTooManyOrgIDs if there are multiple tenant IDs present.
TenantID(context.Context) (string, error)

// TenantIDs returns all tenant IDs from the context. It should return
// normalized list of ordered and distinct tenant IDs (as produced by
// NormalizeTenantIDs).
TenantIDs(context.Context) ([]string, error)
}

// NewSingleResolver creates a tenant resolver, which restricts all requests to
// be using a single tenant only. This allows a wider set of characters to be
// used within the tenant ID and should not impose a breaking change.
func NewSingleResolver() *SingleResolver {
return &SingleResolver{}
}

type SingleResolver struct {
}

// containsUnsafePathSegments will return true if the string is a directory
// reference like `.` and `..` or if any path separator character like `/` and
// `\` can be found.
func containsUnsafePathSegments(id string) bool {
// handle the relative reference to current and parent path.
if id == "." || id == ".." {
return true
}

return strings.ContainsAny(id, "\\/")
}

var errInvalidTenantID = errors.New("invalid tenant ID")

func (t *SingleResolver) TenantID(ctx context.Context) (string, error) {
//lint:ignore faillint wrapper around upstream method
id, err := user.ExtractOrgID(ctx)
if err != nil {
return "", err
}

if containsUnsafePathSegments(id) {
return "", errInvalidTenantID
}

return id, nil
}

func (t *SingleResolver) TenantIDs(ctx context.Context) ([]string, error) {
orgID, err := t.TenantID(ctx)
if err != nil {
return nil, err
}
return []string{orgID}, err
}

type MultiResolver struct {
}

// NewMultiResolver creates a tenant resolver, which allows request to have
// multiple tenant ids submitted separated by a '|' character. This enforces
// further limits on the character set allowed within tenants as detailed here:
// https://cortexmetrics.io/docs/guides/limitations/#tenant-id-naming)
func NewMultiResolver() *MultiResolver {
return &MultiResolver{}
}

func (t *MultiResolver) TenantID(ctx context.Context) (string, error) {
orgIDs, err := t.TenantIDs(ctx)
if err != nil {
return "", err
}

if len(orgIDs) > 1 {
return "", user.ErrTooManyOrgIDs
}

return orgIDs[0], nil
}

func (t *MultiResolver) TenantIDs(ctx context.Context) ([]string, error) {
//lint:ignore faillint wrapper around upstream method
orgID, err := user.ExtractOrgID(ctx)
if err != nil {
return nil, err
}

orgIDs := strings.Split(orgID, tenantIDsLabelSeparator)
for _, orgID := range orgIDs {
if err := ValidTenantID(orgID); err != nil {
return nil, err
}
if containsUnsafePathSegments(orgID) {
return nil, errInvalidTenantID
}
}

return NormalizeTenantIDs(orgIDs), nil
}

// ExtractTenantIDFromHTTPRequest extracts a single TenantID through a given
// resolver directly from a HTTP request.
func ExtractTenantIDFromHTTPRequest(req *http.Request) (string, context.Context, error) {
//lint:ignore faillint wrapper around upstream method
_, ctx, err := user.ExtractOrgIDFromHTTPRequest(req)
if err != nil {
return "", nil, err
}

tenantID, err := defaultResolver.TenantID(ctx)
if err != nil {
return "", nil, err
}

return tenantID, ctx, nil
}
149 changes: 149 additions & 0 deletions tenant/resolver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package tenant

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/weaveworks/common/user"
)

func strptr(s string) *string {
return &s
}

type resolverTestCase struct {
name string
headerValue *string
errTenantID error
errTenantIDs error
tenantID string
tenantIDs []string
}

func (tc *resolverTestCase) test(r Resolver) func(t *testing.T) {
return func(t *testing.T) {

ctx := context.Background()
if tc.headerValue != nil {
ctx = user.InjectOrgID(ctx, *tc.headerValue)
}

tenantID, err := r.TenantID(ctx)
if tc.errTenantID != nil {
assert.Equal(t, tc.errTenantID, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tc.tenantID, tenantID)
}

tenantIDs, err := r.TenantIDs(ctx)
if tc.errTenantIDs != nil {
assert.Equal(t, tc.errTenantIDs, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tc.tenantIDs, tenantIDs)
}
}
}

var commonResolverTestCases = []resolverTestCase{
{
name: "no-header",
errTenantID: user.ErrNoOrgID,
errTenantIDs: user.ErrNoOrgID,
},
{
name: "empty",
headerValue: strptr(""),
tenantIDs: []string{""},
},
{
name: "single-tenant",
headerValue: strptr("tenant-a"),
tenantID: "tenant-a",
tenantIDs: []string{"tenant-a"},
},
{
name: "parent-dir",
headerValue: strptr(".."),
errTenantID: errInvalidTenantID,
errTenantIDs: errInvalidTenantID,
},
{
name: "current-dir",
headerValue: strptr("."),
errTenantID: errInvalidTenantID,
errTenantIDs: errInvalidTenantID,
},
}

func TestSingleResolver(t *testing.T) {
r := NewSingleResolver()
for _, tc := range append(commonResolverTestCases, []resolverTestCase{
{
name: "multi-tenant",
headerValue: strptr("tenant-a|tenant-b"),
tenantID: "tenant-a|tenant-b",
tenantIDs: []string{"tenant-a|tenant-b"},
},
{
name: "containing-forward-slash",
headerValue: strptr("forward/slash"),
errTenantID: errInvalidTenantID,
errTenantIDs: errInvalidTenantID,
},
{
name: "containing-backward-slash",
headerValue: strptr(`backward\slash`),
errTenantID: errInvalidTenantID,
errTenantIDs: errInvalidTenantID,
},
}...) {
t.Run(tc.name, tc.test(r))
}
}

func TestMultiResolver(t *testing.T) {
r := NewMultiResolver()
for _, tc := range append(commonResolverTestCases, []resolverTestCase{
{
name: "multi-tenant",
headerValue: strptr("tenant-a|tenant-b"),
errTenantID: user.ErrTooManyOrgIDs,
tenantIDs: []string{"tenant-a", "tenant-b"},
},
{
name: "multi-tenant-wrong-order",
headerValue: strptr("tenant-b|tenant-a"),
errTenantID: user.ErrTooManyOrgIDs,
tenantIDs: []string{"tenant-a", "tenant-b"},
},
{
name: "multi-tenant-duplicate-order",
headerValue: strptr("tenant-b|tenant-b|tenant-a"),
errTenantID: user.ErrTooManyOrgIDs,
tenantIDs: []string{"tenant-a", "tenant-b"},
},
{
name: "multi-tenant-with-relative-path",
headerValue: strptr("tenant-a|tenant-b|.."),
errTenantID: errInvalidTenantID,
errTenantIDs: errInvalidTenantID,
},
{
name: "containing-forward-slash",
headerValue: strptr("forward/slash"),
errTenantID: &errTenantIDUnsupportedCharacter{pos: 7, tenantID: "forward/slash"},
errTenantIDs: &errTenantIDUnsupportedCharacter{pos: 7, tenantID: "forward/slash"},
},
{
name: "containing-backward-slash",
headerValue: strptr(`backward\slash`),
errTenantID: &errTenantIDUnsupportedCharacter{pos: 8, tenantID: "backward\\slash"},
errTenantIDs: &errTenantIDUnsupportedCharacter{pos: 8, tenantID: "backward\\slash"},
},
}...) {
t.Run(tc.name, tc.test(r))
}
}
Loading

0 comments on commit 1ad3f0a

Please sign in to comment.