Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tenant package #143

Merged
merged 4 commits into from
Mar 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions tenant/resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package tenant

import (
"context"
"errors"
"net/http"
"strings"

"github.com/weaveworks/common/user"
)

var defaultResolver Resolver = NewSingleResolver()

// WithDefaultResolver updates the resolver used for the package methods.
func WithDefaultResolver(r Resolver) {
defaultResolver = r
}

// TenantID returns exactly a single tenant ID from the context. It should be
// used when a certain endpoint should only support exactly a single
// tenant ID. It returns an error user.ErrNoOrgID if there is no tenant ID
// supplied or user.ErrTooManyOrgIDs if there are multiple tenant IDs present.
//
// ignore stutter warning
//nolint:revive
func TenantID(ctx context.Context) (string, error) {
return defaultResolver.TenantID(ctx)
}

// TenantIDs returns all tenant IDs from the context. It should return
// normalized list of ordered and distinct tenant IDs (as produced by
// NormalizeTenantIDs).
//
// ignore stutter warning
//nolint:revive
func TenantIDs(ctx context.Context) ([]string, error) {
return defaultResolver.TenantIDs(ctx)
}

type Resolver interface {
// TenantID returns exactly a single tenant ID from the context. It should be
// used when a certain endpoint should only support exactly a single
// tenant ID. It returns an error user.ErrNoOrgID if there is no tenant ID
// supplied or user.ErrTooManyOrgIDs if there are multiple tenant IDs present.
TenantID(context.Context) (string, error)

// TenantIDs returns all tenant IDs from the context. It should return
// normalized list of ordered and distinct tenant IDs (as produced by
// NormalizeTenantIDs).
TenantIDs(context.Context) ([]string, error)
}

// NewSingleResolver creates a tenant resolver, which restricts all requests to
// be using a single tenant only. This allows a wider set of characters to be
// used within the tenant ID and should not impose a breaking change.
func NewSingleResolver() *SingleResolver {
return &SingleResolver{}
}

type SingleResolver struct {
}

// containsUnsafePathSegments will return true if the string is a directory
// reference like `.` and `..` or if any path separator character like `/` and
// `\` can be found.
func containsUnsafePathSegments(id string) bool {
// handle the relative reference to current and parent path.
if id == "." || id == ".." {
return true
}

return strings.ContainsAny(id, "\\/")
}

var errInvalidTenantID = errors.New("invalid tenant ID")

func (t *SingleResolver) TenantID(ctx context.Context) (string, error) {
//lint:ignore faillint wrapper around upstream method
id, err := user.ExtractOrgID(ctx)
if err != nil {
return "", err
}

if containsUnsafePathSegments(id) {
return "", errInvalidTenantID
}

return id, nil
}

func (t *SingleResolver) TenantIDs(ctx context.Context) ([]string, error) {
orgID, err := t.TenantID(ctx)
if err != nil {
return nil, err
}
return []string{orgID}, err
}

type MultiResolver struct {
}

// NewMultiResolver creates a tenant resolver, which allows request to have
// multiple tenant ids submitted separated by a '|' character. This enforces
// further limits on the character set allowed within tenants as detailed here:
// https://cortexmetrics.io/docs/guides/limitations/#tenant-id-naming)
func NewMultiResolver() *MultiResolver {
return &MultiResolver{}
}

func (t *MultiResolver) TenantID(ctx context.Context) (string, error) {
orgIDs, err := t.TenantIDs(ctx)
if err != nil {
return "", err
}

if len(orgIDs) > 1 {
return "", user.ErrTooManyOrgIDs
}

return orgIDs[0], nil
}

func (t *MultiResolver) TenantIDs(ctx context.Context) ([]string, error) {
//lint:ignore faillint wrapper around upstream method
orgID, err := user.ExtractOrgID(ctx)
if err != nil {
return nil, err
}

orgIDs := strings.Split(orgID, tenantIDsLabelSeparator)
for _, orgID := range orgIDs {
if err := ValidTenantID(orgID); err != nil {
return nil, err
}
if containsUnsafePathSegments(orgID) {
return nil, errInvalidTenantID
}
}

return NormalizeTenantIDs(orgIDs), nil
}

// ExtractTenantIDFromHTTPRequest extracts a single TenantID through a given
// resolver directly from a HTTP request.
func ExtractTenantIDFromHTTPRequest(req *http.Request) (string, context.Context, error) {
//lint:ignore faillint wrapper around upstream method
_, ctx, err := user.ExtractOrgIDFromHTTPRequest(req)
if err != nil {
return "", nil, err
}

tenantID, err := defaultResolver.TenantID(ctx)
if err != nil {
return "", nil, err
}

return tenantID, ctx, nil
}
149 changes: 149 additions & 0 deletions tenant/resolver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package tenant

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/weaveworks/common/user"
)

func strptr(s string) *string {
return &s
}

type resolverTestCase struct {
name string
headerValue *string
errTenantID error
errTenantIDs error
tenantID string
tenantIDs []string
}

func (tc *resolverTestCase) test(r Resolver) func(t *testing.T) {
return func(t *testing.T) {

ctx := context.Background()
if tc.headerValue != nil {
ctx = user.InjectOrgID(ctx, *tc.headerValue)
}

tenantID, err := r.TenantID(ctx)
if tc.errTenantID != nil {
assert.Equal(t, tc.errTenantID, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tc.tenantID, tenantID)
}

tenantIDs, err := r.TenantIDs(ctx)
if tc.errTenantIDs != nil {
assert.Equal(t, tc.errTenantIDs, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tc.tenantIDs, tenantIDs)
}
}
}

var commonResolverTestCases = []resolverTestCase{
{
name: "no-header",
errTenantID: user.ErrNoOrgID,
errTenantIDs: user.ErrNoOrgID,
},
{
name: "empty",
headerValue: strptr(""),
tenantIDs: []string{""},
},
{
name: "single-tenant",
headerValue: strptr("tenant-a"),
tenantID: "tenant-a",
tenantIDs: []string{"tenant-a"},
},
{
name: "parent-dir",
headerValue: strptr(".."),
errTenantID: errInvalidTenantID,
errTenantIDs: errInvalidTenantID,
},
{
name: "current-dir",
headerValue: strptr("."),
errTenantID: errInvalidTenantID,
errTenantIDs: errInvalidTenantID,
},
}

func TestSingleResolver(t *testing.T) {
r := NewSingleResolver()
for _, tc := range append(commonResolverTestCases, []resolverTestCase{
{
name: "multi-tenant",
headerValue: strptr("tenant-a|tenant-b"),
tenantID: "tenant-a|tenant-b",
tenantIDs: []string{"tenant-a|tenant-b"},
},
{
name: "containing-forward-slash",
headerValue: strptr("forward/slash"),
errTenantID: errInvalidTenantID,
errTenantIDs: errInvalidTenantID,
},
{
name: "containing-backward-slash",
headerValue: strptr(`backward\slash`),
errTenantID: errInvalidTenantID,
errTenantIDs: errInvalidTenantID,
},
}...) {
t.Run(tc.name, tc.test(r))
}
}

func TestMultiResolver(t *testing.T) {
r := NewMultiResolver()
for _, tc := range append(commonResolverTestCases, []resolverTestCase{
{
name: "multi-tenant",
headerValue: strptr("tenant-a|tenant-b"),
errTenantID: user.ErrTooManyOrgIDs,
tenantIDs: []string{"tenant-a", "tenant-b"},
},
{
name: "multi-tenant-wrong-order",
headerValue: strptr("tenant-b|tenant-a"),
errTenantID: user.ErrTooManyOrgIDs,
tenantIDs: []string{"tenant-a", "tenant-b"},
},
{
name: "multi-tenant-duplicate-order",
headerValue: strptr("tenant-b|tenant-b|tenant-a"),
errTenantID: user.ErrTooManyOrgIDs,
tenantIDs: []string{"tenant-a", "tenant-b"},
},
{
name: "multi-tenant-with-relative-path",
headerValue: strptr("tenant-a|tenant-b|.."),
errTenantID: errInvalidTenantID,
errTenantIDs: errInvalidTenantID,
},
{
name: "containing-forward-slash",
headerValue: strptr("forward/slash"),
errTenantID: &errTenantIDUnsupportedCharacter{pos: 7, tenantID: "forward/slash"},
errTenantIDs: &errTenantIDUnsupportedCharacter{pos: 7, tenantID: "forward/slash"},
},
{
name: "containing-backward-slash",
headerValue: strptr(`backward\slash`),
errTenantID: &errTenantIDUnsupportedCharacter{pos: 8, tenantID: "backward\\slash"},
errTenantIDs: &errTenantIDUnsupportedCharacter{pos: 8, tenantID: "backward\\slash"},
},
}...) {
t.Run(tc.name, tc.test(r))
}
}
Loading