-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge commit '616d3f35b52cf5081bc9d3f4abc2b9b78e818370' into 20220221…
…_add-tenant-package
- Loading branch information
Showing
4 changed files
with
454 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
package tenant | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"net/http" | ||
"strings" | ||
|
||
"github.com/weaveworks/common/user" | ||
) | ||
|
||
var defaultResolver Resolver = NewSingleResolver() | ||
|
||
// WithDefaultResolver updates the resolver used for the package methods. | ||
func WithDefaultResolver(r Resolver) { | ||
defaultResolver = r | ||
} | ||
|
||
// TenantID returns exactly a single tenant ID from the context. It should be | ||
// used when a certain endpoint should only support exactly a single | ||
// tenant ID. It returns an error user.ErrNoOrgID if there is no tenant ID | ||
// supplied or user.ErrTooManyOrgIDs if there are multiple tenant IDs present. | ||
// | ||
// ignore stutter warning | ||
//nolint:golint | ||
func TenantID(ctx context.Context) (string, error) { | ||
return defaultResolver.TenantID(ctx) | ||
} | ||
|
||
// TenantIDs returns all tenant IDs from the context. It should return | ||
// normalized list of ordered and distinct tenant IDs (as produced by | ||
// NormalizeTenantIDs). | ||
// | ||
// ignore stutter warning | ||
//nolint:golint | ||
func TenantIDs(ctx context.Context) ([]string, error) { | ||
return defaultResolver.TenantIDs(ctx) | ||
} | ||
|
||
type Resolver interface { | ||
// TenantID returns exactly a single tenant ID from the context. It should be | ||
// used when a certain endpoint should only support exactly a single | ||
// tenant ID. It returns an error user.ErrNoOrgID if there is no tenant ID | ||
// supplied or user.ErrTooManyOrgIDs if there are multiple tenant IDs present. | ||
TenantID(context.Context) (string, error) | ||
|
||
// TenantIDs returns all tenant IDs from the context. It should return | ||
// normalized list of ordered and distinct tenant IDs (as produced by | ||
// NormalizeTenantIDs). | ||
TenantIDs(context.Context) ([]string, error) | ||
} | ||
|
||
// NewSingleResolver creates a tenant resolver, which restricts all requests to | ||
// be using a single tenant only. This allows a wider set of characters to be | ||
// used within the tenant ID and should not impose a breaking change. | ||
func NewSingleResolver() *SingleResolver { | ||
return &SingleResolver{} | ||
} | ||
|
||
type SingleResolver struct { | ||
} | ||
|
||
// containsUnsafePathSegments will return true if the string is a directory | ||
// reference like `.` and `..` or if any path separator character like `/` and | ||
// `\` can be found. | ||
func containsUnsafePathSegments(id string) bool { | ||
// handle the relative reference to current and parent path. | ||
if id == "." || id == ".." { | ||
return true | ||
} | ||
|
||
return strings.ContainsAny(id, "\\/") | ||
} | ||
|
||
var errInvalidTenantID = errors.New("invalid tenant ID") | ||
|
||
func (t *SingleResolver) TenantID(ctx context.Context) (string, error) { | ||
//lint:ignore faillint wrapper around upstream method | ||
id, err := user.ExtractOrgID(ctx) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
if containsUnsafePathSegments(id) { | ||
return "", errInvalidTenantID | ||
} | ||
|
||
return id, nil | ||
} | ||
|
||
func (t *SingleResolver) TenantIDs(ctx context.Context) ([]string, error) { | ||
orgID, err := t.TenantID(ctx) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return []string{orgID}, err | ||
} | ||
|
||
type MultiResolver struct { | ||
} | ||
|
||
// NewMultiResolver creates a tenant resolver, which allows request to have | ||
// multiple tenant ids submitted separated by a '|' character. This enforces | ||
// further limits on the character set allowed within tenants as detailed here: | ||
// https://cortexmetrics.io/docs/guides/limitations/#tenant-id-naming) | ||
func NewMultiResolver() *MultiResolver { | ||
return &MultiResolver{} | ||
} | ||
|
||
func (t *MultiResolver) TenantID(ctx context.Context) (string, error) { | ||
orgIDs, err := t.TenantIDs(ctx) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
if len(orgIDs) > 1 { | ||
return "", user.ErrTooManyOrgIDs | ||
} | ||
|
||
return orgIDs[0], nil | ||
} | ||
|
||
func (t *MultiResolver) TenantIDs(ctx context.Context) ([]string, error) { | ||
//lint:ignore faillint wrapper around upstream method | ||
orgID, err := user.ExtractOrgID(ctx) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
orgIDs := strings.Split(orgID, tenantIDsLabelSeparator) | ||
for _, orgID := range orgIDs { | ||
if err := ValidTenantID(orgID); err != nil { | ||
return nil, err | ||
} | ||
if containsUnsafePathSegments(orgID) { | ||
return nil, errInvalidTenantID | ||
} | ||
} | ||
|
||
return NormalizeTenantIDs(orgIDs), nil | ||
} | ||
|
||
// ExtractTenantIDFromHTTPRequest extracts a single TenantID through a given | ||
// resolver directly from a HTTP request. | ||
func ExtractTenantIDFromHTTPRequest(req *http.Request) (string, context.Context, error) { | ||
//lint:ignore faillint wrapper around upstream method | ||
_, ctx, err := user.ExtractOrgIDFromHTTPRequest(req) | ||
if err != nil { | ||
return "", nil, err | ||
} | ||
|
||
tenantID, err := defaultResolver.TenantID(ctx) | ||
if err != nil { | ||
return "", nil, err | ||
} | ||
|
||
return tenantID, ctx, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
package tenant | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/weaveworks/common/user" | ||
) | ||
|
||
func strptr(s string) *string { | ||
return &s | ||
} | ||
|
||
type resolverTestCase struct { | ||
name string | ||
headerValue *string | ||
errTenantID error | ||
errTenantIDs error | ||
tenantID string | ||
tenantIDs []string | ||
} | ||
|
||
func (tc *resolverTestCase) test(r Resolver) func(t *testing.T) { | ||
return func(t *testing.T) { | ||
|
||
ctx := context.Background() | ||
if tc.headerValue != nil { | ||
ctx = user.InjectOrgID(ctx, *tc.headerValue) | ||
} | ||
|
||
tenantID, err := r.TenantID(ctx) | ||
if tc.errTenantID != nil { | ||
assert.Equal(t, tc.errTenantID, err) | ||
} else { | ||
assert.NoError(t, err) | ||
assert.Equal(t, tc.tenantID, tenantID) | ||
} | ||
|
||
tenantIDs, err := r.TenantIDs(ctx) | ||
if tc.errTenantIDs != nil { | ||
assert.Equal(t, tc.errTenantIDs, err) | ||
} else { | ||
assert.NoError(t, err) | ||
assert.Equal(t, tc.tenantIDs, tenantIDs) | ||
} | ||
} | ||
} | ||
|
||
var commonResolverTestCases = []resolverTestCase{ | ||
{ | ||
name: "no-header", | ||
errTenantID: user.ErrNoOrgID, | ||
errTenantIDs: user.ErrNoOrgID, | ||
}, | ||
{ | ||
name: "empty", | ||
headerValue: strptr(""), | ||
tenantIDs: []string{""}, | ||
}, | ||
{ | ||
name: "single-tenant", | ||
headerValue: strptr("tenant-a"), | ||
tenantID: "tenant-a", | ||
tenantIDs: []string{"tenant-a"}, | ||
}, | ||
{ | ||
name: "parent-dir", | ||
headerValue: strptr(".."), | ||
errTenantID: errInvalidTenantID, | ||
errTenantIDs: errInvalidTenantID, | ||
}, | ||
{ | ||
name: "current-dir", | ||
headerValue: strptr("."), | ||
errTenantID: errInvalidTenantID, | ||
errTenantIDs: errInvalidTenantID, | ||
}, | ||
} | ||
|
||
func TestSingleResolver(t *testing.T) { | ||
r := NewSingleResolver() | ||
for _, tc := range append(commonResolverTestCases, []resolverTestCase{ | ||
{ | ||
name: "multi-tenant", | ||
headerValue: strptr("tenant-a|tenant-b"), | ||
tenantID: "tenant-a|tenant-b", | ||
tenantIDs: []string{"tenant-a|tenant-b"}, | ||
}, | ||
{ | ||
name: "containing-forward-slash", | ||
headerValue: strptr("forward/slash"), | ||
errTenantID: errInvalidTenantID, | ||
errTenantIDs: errInvalidTenantID, | ||
}, | ||
{ | ||
name: "containing-backward-slash", | ||
headerValue: strptr(`backward\slash`), | ||
errTenantID: errInvalidTenantID, | ||
errTenantIDs: errInvalidTenantID, | ||
}, | ||
}...) { | ||
t.Run(tc.name, tc.test(r)) | ||
} | ||
} | ||
|
||
func TestMultiResolver(t *testing.T) { | ||
r := NewMultiResolver() | ||
for _, tc := range append(commonResolverTestCases, []resolverTestCase{ | ||
{ | ||
name: "multi-tenant", | ||
headerValue: strptr("tenant-a|tenant-b"), | ||
errTenantID: user.ErrTooManyOrgIDs, | ||
tenantIDs: []string{"tenant-a", "tenant-b"}, | ||
}, | ||
{ | ||
name: "multi-tenant-wrong-order", | ||
headerValue: strptr("tenant-b|tenant-a"), | ||
errTenantID: user.ErrTooManyOrgIDs, | ||
tenantIDs: []string{"tenant-a", "tenant-b"}, | ||
}, | ||
{ | ||
name: "multi-tenant-duplicate-order", | ||
headerValue: strptr("tenant-b|tenant-b|tenant-a"), | ||
errTenantID: user.ErrTooManyOrgIDs, | ||
tenantIDs: []string{"tenant-a", "tenant-b"}, | ||
}, | ||
{ | ||
name: "multi-tenant-with-relative-path", | ||
headerValue: strptr("tenant-a|tenant-b|.."), | ||
errTenantID: errInvalidTenantID, | ||
errTenantIDs: errInvalidTenantID, | ||
}, | ||
{ | ||
name: "containing-forward-slash", | ||
headerValue: strptr("forward/slash"), | ||
errTenantID: &errTenantIDUnsupportedCharacter{pos: 7, tenantID: "forward/slash"}, | ||
errTenantIDs: &errTenantIDUnsupportedCharacter{pos: 7, tenantID: "forward/slash"}, | ||
}, | ||
{ | ||
name: "containing-backward-slash", | ||
headerValue: strptr(`backward\slash`), | ||
errTenantID: &errTenantIDUnsupportedCharacter{pos: 8, tenantID: "backward\\slash"}, | ||
errTenantIDs: &errTenantIDUnsupportedCharacter{pos: 8, tenantID: "backward\\slash"}, | ||
}, | ||
}...) { | ||
t.Run(tc.name, tc.test(r)) | ||
} | ||
} |
Oops, something went wrong.