diff --git a/auth/range_perm_cache.go b/auth/range_perm_cache.go index a6989aaa2ee..691b65ba38e 100644 --- a/auth/range_perm_cache.go +++ b/auth/range_perm_cache.go @@ -38,18 +38,16 @@ func getMergedPerms(tx backend.BatchTx, userName string) *unifiedRangePermission for _, perm := range role.KeyPermission { var ivl adt.Interval - var rangeEnd string + var rangeEnd []byte - if len(perm.RangeEnd) == 1 && perm.RangeEnd[0] == 0 { - rangeEnd = "" - } else { - rangeEnd = string(perm.RangeEnd) + if len(perm.RangeEnd) != 1 || perm.RangeEnd[0] != 0 { + rangeEnd = perm.RangeEnd } if len(perm.RangeEnd) != 0 { - ivl = adt.NewStringAffineInterval(string(perm.Key), string(rangeEnd)) + ivl = adt.NewBytesAffineInterval(perm.Key, rangeEnd) } else { - ivl = adt.NewStringAffinePoint(string(perm.Key)) + ivl = adt.NewBytesAffinePoint(perm.Key) } switch perm.PermType { @@ -72,12 +70,12 @@ func getMergedPerms(tx backend.BatchTx, userName string) *unifiedRangePermission } } -func checkKeyInterval(cachedPerms *unifiedRangePermissions, key, rangeEnd string, permtyp authpb.Permission_Type) bool { - if len(rangeEnd) == 1 && rangeEnd[0] == '\x00' { - rangeEnd = "" +func checkKeyInterval(cachedPerms *unifiedRangePermissions, key, rangeEnd []byte, permtyp authpb.Permission_Type) bool { + if len(rangeEnd) == 1 && rangeEnd[0] == 0 { + rangeEnd = nil } - ivl := adt.NewStringAffineInterval(key, rangeEnd) + ivl := adt.NewBytesAffineInterval(key, rangeEnd) switch permtyp { case authpb.READ: return cachedPerms.readPerms.Contains(ivl) @@ -89,8 +87,8 @@ func checkKeyInterval(cachedPerms *unifiedRangePermissions, key, rangeEnd string return false } -func checkKeyPoint(cachedPerms *unifiedRangePermissions, key string, permtyp authpb.Permission_Type) bool { - pt := adt.NewStringAffinePoint(key) +func checkKeyPoint(cachedPerms *unifiedRangePermissions, key []byte, permtyp authpb.Permission_Type) bool { + pt := adt.NewBytesAffinePoint(key) switch permtyp { case authpb.READ: return cachedPerms.readPerms.Intersects(pt) @@ -102,7 +100,7 @@ func checkKeyPoint(cachedPerms *unifiedRangePermissions, key string, permtyp aut return false } -func (as *authStore) isRangeOpPermitted(tx backend.BatchTx, userName string, key, rangeEnd string, permtyp authpb.Permission_Type) bool { +func (as *authStore) isRangeOpPermitted(tx backend.BatchTx, userName string, key, rangeEnd []byte, permtyp authpb.Permission_Type) bool { // assumption: tx is Lock()ed _, ok := as.rangePermCache[userName] if !ok { diff --git a/auth/range_perm_cache_test.go b/auth/range_perm_cache_test.go index 8629e73a48a..fd8df6a9e0c 100644 --- a/auth/range_perm_cache_test.go +++ b/auth/range_perm_cache_test.go @@ -24,23 +24,23 @@ import ( func TestRangePermission(t *testing.T) { tests := []struct { perms []adt.Interval - begin string - end string + begin []byte + end []byte want bool }{ { - []adt.Interval{adt.NewStringAffineInterval("a", "c"), adt.NewStringAffineInterval("x", "z")}, - "a", "z", + []adt.Interval{adt.NewBytesAffineInterval([]byte("a"), []byte("c")), adt.NewBytesAffineInterval([]byte("x"), []byte("z"))}, + []byte("a"), []byte("z"), false, }, { - []adt.Interval{adt.NewStringAffineInterval("a", "f"), adt.NewStringAffineInterval("c", "d"), adt.NewStringAffineInterval("f", "z")}, - "a", "z", + []adt.Interval{adt.NewBytesAffineInterval([]byte("a"), []byte("f")), adt.NewBytesAffineInterval([]byte("c"), []byte("d")), adt.NewBytesAffineInterval([]byte("f"), []byte("z"))}, + []byte("a"), []byte("z"), true, }, { - []adt.Interval{adt.NewStringAffineInterval("a", "d"), adt.NewStringAffineInterval("a", "b"), adt.NewStringAffineInterval("c", "f")}, - "a", "f", + []adt.Interval{adt.NewBytesAffineInterval([]byte("a"), []byte("d")), adt.NewBytesAffineInterval([]byte("a"), []byte("b")), adt.NewBytesAffineInterval([]byte("c"), []byte("f"))}, + []byte("a"), []byte("f"), true, }, } diff --git a/auth/store.go b/auth/store.go index 4cec8e3f458..33a388a7964 100644 --- a/auth/store.go +++ b/auth/store.go @@ -749,7 +749,7 @@ func (as *authStore) isOpPermitted(userName string, revision uint64, key, rangeE return nil } - if as.isRangeOpPermitted(tx, userName, string(key), string(rangeEnd), permTyp) { + if as.isRangeOpPermitted(tx, userName, key, rangeEnd, permTyp) { return nil } diff --git a/pkg/adt/interval_tree.go b/pkg/adt/interval_tree.go index 9c5afb3f098..9769771ea4f 100644 --- a/pkg/adt/interval_tree.go +++ b/pkg/adt/interval_tree.go @@ -15,6 +15,7 @@ package adt import ( + "bytes" "math" ) @@ -558,3 +559,32 @@ func (v Int64Comparable) Compare(c Comparable) int { } return 0 } + +// BytesAffineComparable treats empty byte arrays as > all other byte arrays +type BytesAffineComparable []byte + +func (b BytesAffineComparable) Compare(c Comparable) int { + bc := c.(BytesAffineComparable) + + if len(b) == 0 { + if len(bc) == 0 { + return 0 + } + return 1 + } + if len(bc) == 0 { + return -1 + } + + return bytes.Compare(b, bc) +} + +func NewBytesAffineInterval(begin, end []byte) Interval { + return Interval{BytesAffineComparable(begin), BytesAffineComparable(end)} +} +func NewBytesAffinePoint(b []byte) Interval { + be := make([]byte, len(b)+1) + copy(be, b) + be[len(b)] = 0 + return NewBytesAffineInterval(b, be) +}