diff --git a/ast/term.go b/ast/term.go index 4664bc5dac..ce8ee4853d 100644 --- a/ast/term.go +++ b/ast/term.go @@ -1293,6 +1293,11 @@ func (arr *Array) Elem(i int) *Term { return arr.elems[i] } +// Set sets the element i of arr. +func (arr *Array) Set(i int, v *Term) { + arr.set(i, v) +} + // rehash updates the cached hash of arr. func (arr *Array) rehash() { arr.hash = 0 @@ -1306,6 +1311,7 @@ func (arr *Array) set(i int, v *Term) { arr.ground = arr.ground && v.IsGround() arr.elems[i] = v arr.hashs[i] = v.Value.Hash() + arr.rehash() } // Slice returns a slice of arr starting from i index to j. -1 @@ -2560,6 +2566,8 @@ func (obj *object) insert(k, v *Term) { } curr.value = v + + obj.rehash() return } } @@ -2584,6 +2592,19 @@ func (obj *object) insert(k, v *Term) { } } +func (obj *object) rehash() { + // obj.keys is considered truth, from which obj.hash and obj.elems are recalculated. + + obj.hash = 0 + obj.elems = make(map[int]*objectElem, len(obj.keys)) + + for _, elem := range obj.keys { + hash := elem.key.Hash() + obj.hash += hash + elem.value.Hash() + obj.elems[hash] = elem + } +} + func filterObject(o Value, filter Value) (Value, error) { if filter.Compare(Null{}) == 0 { return o, nil diff --git a/bundle/store.go b/bundle/store.go index 45bcf6e559..9a49f025e8 100644 --- a/bundle/store.go +++ b/bundle/store.go @@ -59,9 +59,25 @@ func metadataPath(name string) storage.Path { return append(BundlesBasePath, name, "manifest", "metadata") } +func read(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path) (interface{}, error) { + value, err := store.Read(ctx, txn, path) + if err != nil { + return nil, err + } + + if astValue, ok := value.(ast.Value); ok { + value, err = ast.JSON(astValue) + if err != nil { + return nil, err + } + } + + return value, nil +} + // ReadBundleNamesFromStore will return a list of bundle names which have had their metadata stored. func ReadBundleNamesFromStore(ctx context.Context, store storage.Store, txn storage.Transaction) ([]string, error) { - value, err := store.Read(ctx, txn, BundlesBasePath) + value, err := read(ctx, store, txn, BundlesBasePath) if err != nil { return nil, err } @@ -153,7 +169,7 @@ func eraseWasmModulesFromStore(ctx context.Context, store storage.Store, txn sto // ReadWasmMetadataFromStore will read Wasm module resolver metadata from the store. func ReadWasmMetadataFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) ([]WasmResolver, error) { path := wasmEntrypointsPath(name) - value, err := store.Read(ctx, txn, path) + value, err := read(ctx, store, txn, path) if err != nil { return nil, err } @@ -176,7 +192,7 @@ func ReadWasmMetadataFromStore(ctx context.Context, store storage.Store, txn sto // ReadWasmModulesFromStore will write Wasm module resolver metadata from the store. func ReadWasmModulesFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) (map[string][]byte, error) { path := wasmModulePath(name) - value, err := store.Read(ctx, txn, path) + value, err := read(ctx, store, txn, path) if err != nil { return nil, err } @@ -205,7 +221,7 @@ func ReadWasmModulesFromStore(ctx context.Context, store storage.Store, txn stor // If the bundle is not activated, this function will return // storage NotFound error. func ReadBundleRootsFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, name string) ([]string, error) { - value, err := store.Read(ctx, txn, rootsPath(name)) + value, err := read(ctx, store, txn, rootsPath(name)) if err != nil { return nil, err } @@ -235,7 +251,7 @@ func ReadBundleRevisionFromStore(ctx context.Context, store storage.Store, txn s } func readRevisionFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path) (string, error) { - value, err := store.Read(ctx, txn, path) + value, err := read(ctx, store, txn, path) if err != nil { return "", err } @@ -256,7 +272,7 @@ func ReadBundleMetadataFromStore(ctx context.Context, store storage.Store, txn s } func readMetadataFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path) (map[string]interface{}, error) { - value, err := store.Read(ctx, txn, path) + value, err := read(ctx, store, txn, path) if err != nil { return nil, suppressNotFound(err) } @@ -277,7 +293,7 @@ func ReadBundleEtagFromStore(ctx context.Context, store storage.Store, txn stora } func readEtagFromStore(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path) (string, error) { - value, err := store.Read(ctx, txn, path) + value, err := read(ctx, store, txn, path) if err != nil { return "", err } @@ -544,14 +560,7 @@ func activateDeltaBundles(opts *ActivateOpts, bundles map[string]*Bundle) error return err } - bs, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("corrupt manifest data: %w", err) - } - - var manifest Manifest - - err = util.UnmarshalJSON(bs, &manifest) + manifest, err := valueToManifest(value) if err != nil { return fmt.Errorf("corrupt manifest data: %w", err) } @@ -585,6 +594,30 @@ func activateDeltaBundles(opts *ActivateOpts, bundles map[string]*Bundle) error return nil } +func valueToManifest(v interface{}) (Manifest, error) { + if astV, ok := v.(ast.Value); ok { + var err error + v, err = ast.JSON(astV) + if err != nil { + return Manifest{}, err + } + } + + var manifest Manifest + + bs, err := json.Marshal(v) + if err != nil { + return Manifest{}, err + } + + err = util.UnmarshalJSON(bs, &manifest) + if err != nil { + return Manifest{}, err + } + + return manifest, nil +} + // erase bundles by name and roots. This will clear all policies and data at its roots and remove its // manifest from storage. func eraseBundles(ctx context.Context, store storage.Store, txn storage.Transaction, parserOpts ast.ParserOptions, names map[string]struct{}, roots map[string]struct{}) (map[string]*ast.Module, error) { diff --git a/bundle/store_test.go b/bundle/store_test.go index 7d6bae7d81..2d859f0aae 100644 --- a/bundle/store_test.go +++ b/bundle/store_test.go @@ -22,11 +22,12 @@ import ( "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/storage/disk" - inmem "github.com/open-policy-agent/opa/storage/inmem/test" + "github.com/open-policy-agent/opa/storage/inmem" + inmemtst "github.com/open-policy-agent/opa/storage/inmem/test" ) func TestManifestStoreLifecycleSingleBundle(t *testing.T) { - store := inmem.New() + store := inmemtst.New() ctx := context.Background() tb := Manifest{ Revision: "abc123", @@ -40,7 +41,7 @@ func TestManifestStoreLifecycleSingleBundle(t *testing.T) { } func TestManifestStoreLifecycleMultiBundle(t *testing.T) { - store := inmem.New() + store := inmemtst.New() ctx := context.Background() bundles := map[string]Manifest{ @@ -63,7 +64,7 @@ func TestManifestStoreLifecycleMultiBundle(t *testing.T) { } func TestLegacyManifestStoreLifecycle(t *testing.T) { - store := inmem.New() + store := inmemtst.New() ctx := context.Background() tb := Manifest{ Revision: "abc123", @@ -101,7 +102,7 @@ func TestLegacyManifestStoreLifecycle(t *testing.T) { } func TestMixedManifestStoreLifecycle(t *testing.T) { - store := inmem.New() + store := inmemtst.New() ctx := context.Background() bundles := map[string]Manifest{ "bundle1": { @@ -3483,101 +3484,117 @@ func TestDeltaBundleLazyModeWithDefaultRules(t *testing.T) { } func TestBundleLifecycle(t *testing.T) { - ctx := context.Background() - mockStore := mock.New() + tests := []struct { + note string + readAst bool + }{ + { + note: "read raw", + readAst: false, + }, + { + note: "read ast", + readAst: true, + }, + } - compiler := ast.NewCompiler() - m := metrics.New() + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + ctx := context.Background() + mockStore := mock.New(inmem.OptReturnASTValuesOnRead(tc.readAst)) - extraMods := map[string]*ast.Module{ - "mod1": ast.MustParseModule("package x\np = true"), - } + compiler := ast.NewCompiler() + m := metrics.New() - const mod2 = "package a\np = true" - mod3 := "package b\np = true" + extraMods := map[string]*ast.Module{ + "mod1": ast.MustParseModule("package x\np = true"), + } - bundles := map[string]*Bundle{ - "bundle1": { - Manifest: Manifest{ - Roots: &[]string{"a"}, - }, - Data: map[string]interface{}{ - "a": map[string]interface{}{ - "b": "foo", - }, - }, - Modules: []ModuleFile{ - { - Path: "a/policy.rego", - Raw: []byte(mod2), - Parsed: ast.MustParseModule(mod2), - }, - }, - Etag: "foo"}, - "bundle2": { - Manifest: Manifest{ - Roots: &[]string{"b", "c"}, - }, - Data: nil, - Modules: []ModuleFile{ - { - Path: "b/policy.rego", - Raw: []byte(mod3), - Parsed: ast.MustParseModule(mod3), + const mod2 = "package a\np = true" + mod3 := "package b\np = true" + + bundles := map[string]*Bundle{ + "bundle1": { + Manifest: Manifest{ + Roots: &[]string{"a"}, + }, + Data: map[string]interface{}{ + "a": map[string]interface{}{ + "b": "foo", + }, + }, + Modules: []ModuleFile{ + { + Path: "a/policy.rego", + Raw: []byte(mod2), + Parsed: ast.MustParseModule(mod2), + }, + }, + Etag: "foo"}, + "bundle2": { + Manifest: Manifest{ + Roots: &[]string{"b", "c"}, + }, + Data: nil, + Modules: []ModuleFile{ + { + Path: "b/policy.rego", + Raw: []byte(mod3), + Parsed: ast.MustParseModule(mod3), + }, + }, }, - }, - }, - } + } - txn := storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) + txn := storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) - err := Activate(&ActivateOpts{ - Ctx: ctx, - Store: mockStore, - Txn: txn, - Compiler: compiler, - Metrics: m, - Bundles: bundles, - ExtraModules: extraMods, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + err := Activate(&ActivateOpts{ + Ctx: ctx, + Store: mockStore, + Txn: txn, + Compiler: compiler, + Metrics: m, + Bundles: bundles, + ExtraModules: extraMods, + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - err = mockStore.Commit(ctx, txn) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + err = mockStore.Commit(ctx, txn) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - // Ensure the bundle was activated - txn = storage.NewTransactionOrDie(ctx, mockStore) - names, err := ReadBundleNamesFromStore(ctx, mockStore, txn) - if err != nil { - t.Fatal(err) - } + // Ensure the bundle was activated + txn = storage.NewTransactionOrDie(ctx, mockStore) + names, err := ReadBundleNamesFromStore(ctx, mockStore, txn) + if err != nil { + t.Fatal(err) + } - if len(names) != len(bundles) { - t.Fatalf("expected %d bundles in store, found %d", len(bundles), len(names)) - } - for _, name := range names { - if _, ok := bundles[name]; !ok { - t.Fatalf("unexpected bundle name found in store: %s", name) - } - } + if len(names) != len(bundles) { + t.Fatalf("expected %d bundles in store, found %d", len(bundles), len(names)) + } + for _, name := range names { + if _, ok := bundles[name]; !ok { + t.Fatalf("unexpected bundle name found in store: %s", name) + } + } - for bundleName, bundle := range bundles { - for modName := range bundle.ParsedModules(bundleName) { - if _, ok := compiler.Modules[modName]; !ok { - t.Fatalf("expected module %s from bundle %s to have been compiled", modName, bundleName) + for bundleName, bundle := range bundles { + for modName := range bundle.ParsedModules(bundleName) { + if _, ok := compiler.Modules[modName]; !ok { + t.Fatalf("expected module %s from bundle %s to have been compiled", modName, bundleName) + } + } } - } - } - actual, err := mockStore.Read(ctx, txn, storage.MustParsePath("/")) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - expectedRaw := ` + actual, err := mockStore.Read(ctx, txn, storage.MustParsePath("/")) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + expectedRaw := ` { "a": { "b": "foo" @@ -3602,270 +3619,282 @@ func TestBundleLifecycle(t *testing.T) { } } ` - expected := loadExpectedSortedResult(expectedRaw) - if !reflect.DeepEqual(expected, actual) { - t.Errorf("expected %v, got %v", expectedRaw, string(util.MustMarshalJSON(actual))) - } + assertEqual(t, tc.readAst, expectedRaw, actual) - // Ensure that the extra module was included - if _, ok := compiler.Modules["mod1"]; !ok { - t.Fatalf("expected extra module to be compiled") - } + // Ensure that the extra module was included + if _, ok := compiler.Modules["mod1"]; !ok { + t.Fatalf("expected extra module to be compiled") + } - // Stop the "read" transaction - mockStore.Abort(ctx, txn) + // Stop the "read" transaction + mockStore.Abort(ctx, txn) - txn = storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) + txn = storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) - err = Deactivate(&DeactivateOpts{ - Ctx: ctx, - Store: mockStore, - Txn: txn, - BundleNames: map[string]struct{}{"bundle1": {}, "bundle2": {}}, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + err = Deactivate(&DeactivateOpts{ + Ctx: ctx, + Store: mockStore, + Txn: txn, + BundleNames: map[string]struct{}{"bundle1": {}, "bundle2": {}}, + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - err = mockStore.Commit(ctx, txn) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + err = mockStore.Commit(ctx, txn) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - // Expect the store to have been cleared out after deactivating the bundles - txn = storage.NewTransactionOrDie(ctx, mockStore) - names, err = ReadBundleNamesFromStore(ctx, mockStore, txn) - if err != nil { - t.Fatal(err) - } + // Expect the store to have been cleared out after deactivating the bundles + txn = storage.NewTransactionOrDie(ctx, mockStore) + names, err = ReadBundleNamesFromStore(ctx, mockStore, txn) + if err != nil { + t.Fatal(err) + } - if len(names) != 0 { - t.Fatalf("expected 0 bundles in store, found %d", len(names)) - } + if len(names) != 0 { + t.Fatalf("expected 0 bundles in store, found %d", len(names)) + } - actual, err = mockStore.Read(ctx, txn, storage.MustParsePath("/")) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - expectedRaw = `{"system": {"bundles": {}}}` - expected = loadExpectedSortedResult(expectedRaw) - if !reflect.DeepEqual(expected, actual) { - t.Errorf("expected %v, got %v", expectedRaw, string(util.MustMarshalJSON(actual))) - } + actual, err = mockStore.Read(ctx, txn, storage.MustParsePath("/")) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + expectedRaw = `{"system": {"bundles": {}}}` + assertEqual(t, tc.readAst, expectedRaw, actual) - mockStore.AssertValid(t) + mockStore.AssertValid(t) + }) + } } func TestDeltaBundleLifecycle(t *testing.T) { - ctx := context.Background() - mockStore := mock.New() + tests := []struct { + note string + readAst bool + }{ + { + note: "read raw", + readAst: false, + }, + { + note: "read ast", + readAst: true, + }, + } - compiler := ast.NewCompiler() - m := metrics.New() + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + ctx := context.Background() + mockStore := mock.New(inmem.OptReturnASTValuesOnRead(tc.readAst)) - mod1 := "package a\np = true" - mod2 := "package b\np = true" + compiler := ast.NewCompiler() + m := metrics.New() - bundles := map[string]*Bundle{ - "bundle1": { - Manifest: Manifest{ - Roots: &[]string{"a"}, - }, - Data: map[string]interface{}{ - "a": map[string]interface{}{ - "b": "foo", - "e": map[string]interface{}{ - "f": "bar", + mod1 := "package a\np = true" + mod2 := "package b\np = true" + + bundles := map[string]*Bundle{ + "bundle1": { + Manifest: Manifest{ + Roots: &[]string{"a"}, }, - "x": []map[string]string{{"name": "john"}, {"name": "jane"}}, - }, - }, - Modules: []ModuleFile{ - { - Path: "a/policy.rego", - Raw: []byte(mod1), - Parsed: ast.MustParseModule(mod1), + Data: map[string]interface{}{ + "a": map[string]interface{}{ + "b": "foo", + "e": map[string]interface{}{ + "f": "bar", + }, + "x": []map[string]string{{"name": "john"}, {"name": "jane"}}, + }, + }, + Modules: []ModuleFile{ + { + Path: "a/policy.rego", + Raw: []byte(mod1), + Parsed: ast.MustParseModule(mod1), + }, + }, + Etag: "foo", }, - }, - Etag: "foo", - }, - "bundle2": { - Manifest: Manifest{ - Roots: &[]string{"b", "c"}, - }, - Data: nil, - Modules: []ModuleFile{ - { - Path: "b/policy.rego", - Raw: []byte(mod2), - Parsed: ast.MustParseModule(mod2), + "bundle2": { + Manifest: Manifest{ + Roots: &[]string{"b", "c"}, + }, + Data: nil, + Modules: []ModuleFile{ + { + Path: "b/policy.rego", + Raw: []byte(mod2), + Parsed: ast.MustParseModule(mod2), + }, + }, }, - }, - }, - } + } - txn := storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) + txn := storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) - err := Activate(&ActivateOpts{ - Ctx: ctx, - Store: mockStore, - Txn: txn, - Compiler: compiler, - Metrics: m, - Bundles: bundles, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + err := Activate(&ActivateOpts{ + Ctx: ctx, + Store: mockStore, + Txn: txn, + Compiler: compiler, + Metrics: m, + Bundles: bundles, + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - err = mockStore.Commit(ctx, txn) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + err = mockStore.Commit(ctx, txn) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - // Ensure the snapshot bundles were activated - txn = storage.NewTransactionOrDie(ctx, mockStore) - names, err := ReadBundleNamesFromStore(ctx, mockStore, txn) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + // Ensure the snapshot bundles were activated + txn = storage.NewTransactionOrDie(ctx, mockStore) + names, err := ReadBundleNamesFromStore(ctx, mockStore, txn) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - if len(names) != len(bundles) { - t.Fatalf("expected %d bundles in store, found %d", len(bundles), len(names)) - } - for _, name := range names { - if _, ok := bundles[name]; !ok { - t.Fatalf("unexpected bundle name found in store: %s", name) - } - } + if len(names) != len(bundles) { + t.Fatalf("expected %d bundles in store, found %d", len(bundles), len(names)) + } + for _, name := range names { + if _, ok := bundles[name]; !ok { + t.Fatalf("unexpected bundle name found in store: %s", name) + } + } - for bundleName, bundle := range bundles { - for modName := range bundle.ParsedModules(bundleName) { - if _, ok := compiler.Modules[modName]; !ok { - t.Fatalf("expected module %s from bundle %s to have been compiled", modName, bundleName) + for bundleName, bundle := range bundles { + for modName := range bundle.ParsedModules(bundleName) { + if _, ok := compiler.Modules[modName]; !ok { + t.Fatalf("expected module %s from bundle %s to have been compiled", modName, bundleName) + } + } } - } - } - // Stop the "read" transaction - mockStore.Abort(ctx, txn) + // Stop the "read" transaction + mockStore.Abort(ctx, txn) - // create a delta bundle and activate it + // create a delta bundle and activate it - // add a new object member - p1 := PatchOperation{ - Op: "upsert", - Path: "/a/c/d", - Value: []string{"foo", "bar"}, - } + // add a new object member + p1 := PatchOperation{ + Op: "upsert", + Path: "/a/c/d", + Value: []string{"foo", "bar"}, + } - // append value to array - p2 := PatchOperation{ - Op: "upsert", - Path: "/a/c/d/-", - Value: "baz", - } + // append value to array + p2 := PatchOperation{ + Op: "upsert", + Path: "/a/c/d/-", + Value: "baz", + } - // insert value in array - p3 := PatchOperation{ - Op: "upsert", - Path: "/a/x/1", - Value: map[string]string{"name": "alice"}, - } + // insert value in array + p3 := PatchOperation{ + Op: "upsert", + Path: "/a/x/1", + Value: map[string]string{"name": "alice"}, + } - // replace a value - p4 := PatchOperation{ - Op: "replace", - Path: "a/b", - Value: "bar", - } + // replace a value + p4 := PatchOperation{ + Op: "replace", + Path: "a/b", + Value: "bar", + } - // remove a value - p5 := PatchOperation{ - Op: "remove", - Path: "a/e", - } + // remove a value + p5 := PatchOperation{ + Op: "remove", + Path: "a/e", + } - // add a new object with an escaped character in the path - p6 := PatchOperation{ - Op: "upsert", - Path: "a/y/~0z", - Value: []int{1, 2, 3}, - } + // add a new object with an escaped character in the path + p6 := PatchOperation{ + Op: "upsert", + Path: "a/y/~0z", + Value: []int{1, 2, 3}, + } - // add a new object root - p7 := PatchOperation{ - Op: "upsert", - Path: "/c/d", - Value: []string{"foo", "bar"}, - } + // add a new object root + p7 := PatchOperation{ + Op: "upsert", + Path: "/c/d", + Value: []string{"foo", "bar"}, + } - deltaBundles := map[string]*Bundle{ - "bundle1": { - Manifest: Manifest{ - Revision: "delta-1", - Roots: &[]string{"a"}, - }, - Patch: Patch{Data: []PatchOperation{p1, p2, p3, p4, p5, p6}}, - Etag: "bar", - }, - "bundle2": { - Manifest: Manifest{ - Revision: "delta-2", - Roots: &[]string{"b", "c"}, - }, - Patch: Patch{Data: []PatchOperation{p7}}, - Etag: "baz", - }, - "bundle3": { - Manifest: Manifest{ - Roots: &[]string{"d"}, - }, - Data: map[string]interface{}{ - "d": map[string]interface{}{ - "e": "foo", + deltaBundles := map[string]*Bundle{ + "bundle1": { + Manifest: Manifest{ + Revision: "delta-1", + Roots: &[]string{"a"}, + }, + Patch: Patch{Data: []PatchOperation{p1, p2, p3, p4, p5, p6}}, + Etag: "bar", }, - }, - }, - } + "bundle2": { + Manifest: Manifest{ + Revision: "delta-2", + Roots: &[]string{"b", "c"}, + }, + Patch: Patch{Data: []PatchOperation{p7}}, + Etag: "baz", + }, + "bundle3": { + Manifest: Manifest{ + Roots: &[]string{"d"}, + }, + Data: map[string]interface{}{ + "d": map[string]interface{}{ + "e": "foo", + }, + }, + }, + } - txn = storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) + txn = storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) - err = Activate(&ActivateOpts{ - Ctx: ctx, - Store: mockStore, - Txn: txn, - Compiler: compiler, - Metrics: m, - Bundles: deltaBundles, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + err = Activate(&ActivateOpts{ + Ctx: ctx, + Store: mockStore, + Txn: txn, + Compiler: compiler, + Metrics: m, + Bundles: deltaBundles, + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - err = mockStore.Commit(ctx, txn) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + err = mockStore.Commit(ctx, txn) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - // check the modules from the snapshot bundles are on the compiler - for bundleName, bundle := range bundles { - for modName := range bundle.ParsedModules(bundleName) { - if _, ok := compiler.Modules[modName]; !ok { - t.Fatalf("expected module %s from bundle %s to have been compiled", modName, bundleName) + // check the modules from the snapshot bundles are on the compiler + for bundleName, bundle := range bundles { + for modName := range bundle.ParsedModules(bundleName) { + if _, ok := compiler.Modules[modName]; !ok { + t.Fatalf("expected module %s from bundle %s to have been compiled", modName, bundleName) + } + } } - } - } - // Ensure the patches were applied - txn = storage.NewTransactionOrDie(ctx, mockStore) + // Ensure the patches were applied + txn = storage.NewTransactionOrDie(ctx, mockStore) - actual, err := mockStore.Read(ctx, txn, storage.MustParsePath("/")) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + actual, err := mockStore.Read(ctx, txn, storage.MustParsePath("/")) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - expectedRaw := ` + expectedRaw := ` { "a": { "b": "bar", @@ -3904,91 +3933,105 @@ func TestDeltaBundleLifecycle(t *testing.T) { } }` - expected := loadExpectedSortedResult(expectedRaw) - if !reflect.DeepEqual(expected, actual) { - t.Errorf("expected %v, got %v", expectedRaw, string(util.MustMarshalJSON(actual))) - } + assertEqual(t, tc.readAst, expectedRaw, actual) - // Stop the "read" transaction - mockStore.Abort(ctx, txn) + // Stop the "read" transaction + mockStore.Abort(ctx, txn) - mockStore.AssertValid(t) + mockStore.AssertValid(t) + }) + } } func TestDeltaBundleActivate(t *testing.T) { - - ctx := context.Background() - mockStore := mock.New() - - compiler := ast.NewCompiler() - m := metrics.New() - - // create a delta bundle - p1 := PatchOperation{ - Op: "upsert", - Path: "/a/c/d", - Value: []string{"foo", "bar"}, - } - - deltaBundles := map[string]*Bundle{ - "bundle1": { - Manifest: Manifest{ - Revision: "delta", - Roots: &[]string{"a"}, - }, - Patch: Patch{Data: []PatchOperation{p1}}, - Etag: "foo", + tests := []struct { + note string + readAst bool + }{ + { + note: "read raw", + readAst: false, + }, + { + note: "read ast", + readAst: true, }, } - txn := storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + ctx := context.Background() + mockStore := mock.New(inmem.OptReturnASTValuesOnRead(tc.readAst)) - err := Activate(&ActivateOpts{ - Ctx: ctx, - Store: mockStore, - Txn: txn, - Compiler: compiler, - Metrics: m, - Bundles: deltaBundles, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + compiler := ast.NewCompiler() + m := metrics.New() - err = mockStore.Commit(ctx, txn) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + // create a delta bundle + p1 := PatchOperation{ + Op: "upsert", + Path: "/a/c/d", + Value: []string{"foo", "bar"}, + } - // Ensure the delta bundle was activated - txn = storage.NewTransactionOrDie(ctx, mockStore) - names, err := ReadBundleNamesFromStore(ctx, mockStore, txn) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + deltaBundles := map[string]*Bundle{ + "bundle1": { + Manifest: Manifest{ + Revision: "delta", + Roots: &[]string{"a"}, + }, + Patch: Patch{Data: []PatchOperation{p1}}, + Etag: "foo", + }, + } - if len(names) != len(deltaBundles) { - t.Fatalf("expected %d bundles in store, found %d", len(deltaBundles), len(names)) - } + txn := storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) - for _, name := range names { - if _, ok := deltaBundles[name]; !ok { - t.Fatalf("unexpected bundle name found in store: %s", name) - } - } + err := Activate(&ActivateOpts{ + Ctx: ctx, + Store: mockStore, + Txn: txn, + Compiler: compiler, + Metrics: m, + Bundles: deltaBundles, + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - // Stop the "read" transaction - mockStore.Abort(ctx, txn) + err = mockStore.Commit(ctx, txn) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - // Ensure the patches were applied - txn = storage.NewTransactionOrDie(ctx, mockStore) + // Ensure the delta bundle was activated + txn = storage.NewTransactionOrDie(ctx, mockStore) + names, err := ReadBundleNamesFromStore(ctx, mockStore, txn) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } - actual, err := mockStore.Read(ctx, txn, storage.MustParsePath("/")) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + if len(names) != len(deltaBundles) { + t.Fatalf("expected %d bundles in store, found %d", len(deltaBundles), len(names)) + } - expectedRaw := ` + for _, name := range names { + if _, ok := deltaBundles[name]; !ok { + t.Fatalf("unexpected bundle name found in store: %s", name) + } + } + + // Stop the "read" transaction + mockStore.Abort(ctx, txn) + + // Ensure the patches were applied + txn = storage.NewTransactionOrDie(ctx, mockStore) + + actual, err := mockStore.Read(ctx, txn, storage.MustParsePath("/")) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + expectedRaw := ` { "a": { "c": { @@ -4008,15 +4051,30 @@ func TestDeltaBundleActivate(t *testing.T) { } } ` - expected := loadExpectedSortedResult(expectedRaw) - if !reflect.DeepEqual(expected, actual) { - t.Errorf("expected %v, got %v", expectedRaw, string(util.MustMarshalJSON(actual))) + assertEqual(t, tc.readAst, expectedRaw, actual) + + // Stop the "read" transaction + mockStore.Abort(ctx, txn) + + mockStore.AssertValid(t) + }) } +} - // Stop the "read" transaction - mockStore.Abort(ctx, txn) +func assertEqual(t *testing.T, expectAst bool, expected string, actual interface{}) { + t.Helper() - mockStore.AssertValid(t) + if expectAst { + exp := ast.MustParseTerm(expected) + if ast.Compare(exp, actual) != 0 { + t.Errorf("expected:\n\n%v\n\ngot:\n\n%v", expected, actual) + } + } else { + exp := loadExpectedSortedResult(expected) + if !reflect.DeepEqual(exp, actual) { + t.Errorf("expected:\n\n%v\n\ngot:\n\n%v", expected, actual) + } + } } func TestDeltaBundleBadManifest(t *testing.T) { @@ -4122,6 +4180,20 @@ func TestDeltaBundleBadManifest(t *testing.T) { } func TestEraseData(t *testing.T) { + storeReadModes := []struct { + note string + readAst bool + }{ + { + note: "read raw", + readAst: false, + }, + { + note: "read ast", + readAst: true, + }, + } + ctx := context.Background() cases := []struct { note string @@ -4182,37 +4254,38 @@ func TestEraseData(t *testing.T) { }, } - for _, tc := range cases { - t.Run(tc.note, func(t *testing.T) { - mockStore := mock.NewWithData(tc.initialData) - txn := storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) + for _, rm := range storeReadModes { + t.Run(rm.note, func(t *testing.T) { + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + mockStore := mock.NewWithData(tc.initialData, inmem.OptReturnASTValuesOnRead(rm.readAst)) + txn := storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) - roots := map[string]struct{}{} - for _, root := range tc.roots { - roots[root] = struct{}{} - } + roots := map[string]struct{}{} + for _, root := range tc.roots { + roots[root] = struct{}{} + } - err := eraseData(ctx, mockStore, txn, roots) - if !tc.expectErr && err != nil { - t.Fatalf("unepected error: %s", err) - } else if tc.expectErr && err == nil { - t.Fatalf("expected error, got: %s", err) - } + err := eraseData(ctx, mockStore, txn, roots) + if !tc.expectErr && err != nil { + t.Fatalf("unepected error: %s", err) + } else if tc.expectErr && err == nil { + t.Fatalf("expected error, got: %s", err) + } - err = mockStore.Commit(ctx, txn) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - mockStore.AssertValid(t) + err = mockStore.Commit(ctx, txn) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + mockStore.AssertValid(t) - txn = storage.NewTransactionOrDie(ctx, mockStore) - actual, err := mockStore.Read(ctx, txn, storage.MustParsePath("/")) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - expected := loadExpectedSortedResult(tc.expected) - if !reflect.DeepEqual(expected, actual) { - t.Errorf("expected %v, got %v", tc.expected, actual) + txn = storage.NewTransactionOrDie(ctx, mockStore) + actual, err := mockStore.Read(ctx, txn, storage.MustParsePath("/")) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + assertEqual(t, rm.readAst, tc.expected, actual) + }) } }) } @@ -4344,6 +4417,20 @@ func TestErasePolicies(t *testing.T) { } func TestWriteData(t *testing.T) { + storeReadModes := []struct { + note string + readAst bool + }{ + { + note: "read raw", + readAst: false, + }, + { + note: "read ast", + readAst: true, + }, + } + ctx := context.Background() cases := []struct { note string @@ -4428,32 +4515,33 @@ func TestWriteData(t *testing.T) { }, } - for _, tc := range cases { - t.Run(tc.note, func(t *testing.T) { - mockStore := mock.NewWithData(tc.existingData) - txn := storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) + for _, rm := range storeReadModes { + t.Run(rm.note, func(t *testing.T) { + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + mockStore := mock.NewWithData(tc.existingData, inmem.OptReturnASTValuesOnRead(rm.readAst)) + txn := storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) - err := writeData(ctx, mockStore, txn, tc.roots, tc.data) - if !tc.expectErr && err != nil { - t.Fatalf("unepected error: %s", err) - } else if tc.expectErr && err == nil { - t.Fatalf("expected error, got: %s", err) - } + err := writeData(ctx, mockStore, txn, tc.roots, tc.data) + if !tc.expectErr && err != nil { + t.Fatalf("unepected error: %s", err) + } else if tc.expectErr && err == nil { + t.Fatalf("expected error, got: %s", err) + } - err = mockStore.Commit(ctx, txn) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - mockStore.AssertValid(t) + err = mockStore.Commit(ctx, txn) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + mockStore.AssertValid(t) - txn = storage.NewTransactionOrDie(ctx, mockStore) - actual, err := mockStore.Read(ctx, txn, storage.MustParsePath("/")) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - expected := loadExpectedSortedResult(tc.expected) - if !reflect.DeepEqual(expected, actual) { - t.Errorf("expected %v, got %v", tc.expected, actual) + txn = storage.NewTransactionOrDie(ctx, mockStore) + actual, err := mockStore.Read(ctx, txn, storage.MustParsePath("/")) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + assertEqual(t, rm.readAst, tc.expected, actual) + }) } }) } @@ -4933,3 +5021,182 @@ func TestHasRootsOverlap(t *testing.T) { }) } } + +func TestBundleStoreHelpers(t *testing.T) { + storeReadModes := []struct { + note string + readAst bool + }{ + { + note: "read raw", + readAst: false, + }, + { + note: "read ast", + readAst: true, + }, + } + + ctx := context.Background() + + bundles := map[string]*Bundle{ + "bundle1": { + Manifest: Manifest{ + Roots: &[]string{}, + }, + }, + "bundle2": { + Manifest: Manifest{ + Roots: &[]string{"a"}, + Revision: "foo", + Metadata: map[string]interface{}{ + "a": "b", + }, + WasmResolvers: []WasmResolver{ + { + Entrypoint: "foo/bar", + Module: "m.wasm", + }, + }, + }, + Etag: "bar", + WasmModules: []WasmModuleFile{ + { + Path: "/m.wasm", + Raw: []byte("d2FzbS1tb2R1bGU="), + }, + }, + }, + } + + for _, srm := range storeReadModes { + t.Run(srm.note, func(t *testing.T) { + mockStore := mock.NewWithData(nil, inmem.OptReturnASTValuesOnRead(srm.readAst)) + txn := storage.NewTransactionOrDie(ctx, mockStore, storage.WriteParams) + c := ast.NewCompiler() + m := metrics.New() + + err := Activate(&ActivateOpts{ + Ctx: ctx, + Store: mockStore, + Txn: txn, + Compiler: c, + Metrics: m, + Bundles: bundles, + }) + + if err != nil { + t.Fatal(err) + } + + // Bundle names + + if names, err := ReadBundleNamesFromStore(ctx, mockStore, txn); err != nil { + t.Fatalf("unexpected error: %s", err) + } else if len(names) != len(bundles) { + t.Errorf("expected bundle names:\n\n%v\n\nin store, found\n\n%v", bundles, names) + } else { + for _, name := range names { + if _, ok := bundles[name]; !ok { + t.Errorf("expected bundle names:\n\n%v\n\nin store, found\n\n%v", bundles, names) + } + } + } + + // Etag + + if etag, err := ReadBundleEtagFromStore(ctx, mockStore, txn, "bundle1"); err != nil { + t.Fatalf("unexpected error: %s", err) + } else if etag != "" { + t.Errorf("expected empty etag but got %s", etag) + } + + if etag, err := ReadBundleEtagFromStore(ctx, mockStore, txn, "bundle2"); err != nil { + t.Fatalf("unexpected error: %s", err) + } else if exp := "bar"; etag != exp { + t.Errorf("expected etag %s but got %s", exp, etag) + } + + // Revision + + if rev, err := ReadBundleRevisionFromStore(ctx, mockStore, txn, "bundle1"); err != nil { + t.Fatalf("unexpected error: %s", err) + } else if rev != "" { + t.Errorf("expected empty revision but got %s", rev) + } + + if rev, err := ReadBundleRevisionFromStore(ctx, mockStore, txn, "bundle2"); err != nil { + t.Fatalf("unexpected error: %s", err) + } else if exp := "foo"; rev != exp { + t.Errorf("expected revision %s but got %s", exp, rev) + } + + // Roots + + if roots, err := ReadBundleRootsFromStore(ctx, mockStore, txn, "bundle1"); err != nil { + t.Fatalf("unexpected error: %s", err) + } else if len(roots) != 0 { + t.Errorf("expected empty roots but got %v", roots) + } + + if roots, err := ReadBundleRootsFromStore(ctx, mockStore, txn, "bundle2"); err != nil { + t.Fatalf("unexpected error: %s", err) + } else if exp := *bundles["bundle2"].Manifest.Roots; !reflect.DeepEqual(exp, roots) { + t.Errorf("expected roots %v but got %v", exp, roots) + } + + // Bundle metadata + + if meta, err := ReadBundleMetadataFromStore(ctx, mockStore, txn, "bundle1"); err != nil { + t.Fatalf("unexpected error: %s", err) + } else if len(meta) != 0 { + t.Errorf("expected empty metadata but got %v", meta) + } + + if meta, err := ReadBundleMetadataFromStore(ctx, mockStore, txn, "bundle2"); err != nil { + t.Fatalf("unexpected error: %s", err) + } else if exp := bundles["bundle2"].Manifest.Metadata; !reflect.DeepEqual(exp, meta) { + t.Errorf("expected metadata %v but got %v", exp, meta) + } + + // Wasm metadata + + if _, err := ReadWasmMetadataFromStore(ctx, mockStore, txn, "bundle1"); err == nil { + t.Fatalf("expected error but got nil") + } else if exp, act := "storage_not_found_error: /bundles/bundle1/manifest/wasm: document does not exist", err.Error(); !strings.Contains(act, exp) { + t.Fatalf("expected error:\n\n%s\n\nbut got:\n\n%v", exp, act) + } + + if resolvers, err := ReadWasmMetadataFromStore(ctx, mockStore, txn, "bundle2"); err != nil { + t.Fatalf("unexpected error: %s", err) + } else if exp := bundles["bundle2"].Manifest.WasmResolvers; !reflect.DeepEqual(exp, resolvers) { + t.Errorf("expected wasm metadata:\n\n%v\n\nbut got:\n\n%v", exp, resolvers) + } + + // Wasm modules + + if _, err := ReadWasmModulesFromStore(ctx, mockStore, txn, "bundle1"); err == nil { + t.Fatalf("expected error but got nil") + } else if exp, act := "storage_not_found_error: /bundles/bundle1/wasm: document does not exist", err.Error(); !strings.Contains(act, exp) { + t.Fatalf("expected error:\n\n%s\n\nbut got:\n\n%v", exp, act) + } + + if modules, err := ReadWasmModulesFromStore(ctx, mockStore, txn, "bundle2"); err != nil { + t.Fatalf("unexpected error: %s", err) + } else if exp := bundles["bundle2"].WasmModules; len(exp) != len(modules) { + t.Errorf("expected wasm modules:\n\n%v\n\nbut got:\n\n%v", exp, modules) + } else { + for _, exp := range bundles["bundle2"].WasmModules { + act := modules[exp.Path] + if act == nil { + t.Errorf("expected wasm module %s but got nil", exp.Path) + } + if !bytes.Equal(exp.Raw, act) { + t.Errorf("expected wasm module %s to have raw data:\n\n%v\n\nbut got:\n\n%v", exp.Path, exp.Raw, act) + } + } + } + + }) + } +} diff --git a/cmd/bench.go b/cmd/bench.go index 0a9315a451..74bc4073c5 100644 --- a/cmd/bench.go +++ b/cmd/bench.go @@ -126,6 +126,7 @@ The optional "gobench" output format conforms to the Go Benchmark Data Format. addTargetFlag(benchCommand.Flags(), params.target) addV0CompatibleFlag(benchCommand.Flags(), ¶ms.v0Compatible, false) addV1CompatibleFlag(benchCommand.Flags(), ¶ms.v1Compatible, false) + addReadAstValuesFromStoreFlag(benchCommand.Flags(), ¶ms.ReadAstValuesFromStore, false) // Shared benchmark flags addCountFlag(benchCommand.Flags(), ¶ms.count, "benchmark") diff --git a/cmd/eval.go b/cmd/eval.go index 49e809331b..89bfead2af 100644 --- a/cmd/eval.go +++ b/cmd/eval.go @@ -36,44 +36,45 @@ import ( ) type evalCommandParams struct { - capabilities *capabilitiesFlag - coverage bool - partial bool - unknowns []string - disableInlining []string - shallowInlining bool - disableIndexing bool - disableEarlyExit bool - strictBuiltinErrors bool - showBuiltinErrors bool - dataPaths repeatedStringFlag - inputPath string - imports repeatedStringFlag - pkg string - stdin bool - stdinInput bool - explain *util.EnumFlag - metrics bool - instrument bool - ignore []string - outputFormat *util.EnumFlag - profile bool - profileCriteria repeatedStringFlag - profileLimit intFlag - count int - prettyLimit intFlag - fail bool - failDefined bool - bundlePaths repeatedStringFlag - schema *schemaFlags - target *util.EnumFlag - timeout time.Duration - optimizationLevel int - entrypoints repeatedStringFlag - strict bool - v0Compatible bool - v1Compatible bool - traceVarValues bool + capabilities *capabilitiesFlag + coverage bool + partial bool + unknowns []string + disableInlining []string + shallowInlining bool + disableIndexing bool + disableEarlyExit bool + strictBuiltinErrors bool + showBuiltinErrors bool + dataPaths repeatedStringFlag + inputPath string + imports repeatedStringFlag + pkg string + stdin bool + stdinInput bool + explain *util.EnumFlag + metrics bool + instrument bool + ignore []string + outputFormat *util.EnumFlag + profile bool + profileCriteria repeatedStringFlag + profileLimit intFlag + count int + prettyLimit intFlag + fail bool + failDefined bool + bundlePaths repeatedStringFlag + schema *schemaFlags + target *util.EnumFlag + timeout time.Duration + optimizationLevel int + entrypoints repeatedStringFlag + strict bool + v0Compatible bool + v1Compatible bool + traceVarValues bool + ReadAstValuesFromStore bool } func (p *evalCommandParams) regoVersion() ast.RegoVersion { @@ -344,6 +345,7 @@ access. addStrictFlag(evalCommand.Flags(), ¶ms.strict, false) addV0CompatibleFlag(evalCommand.Flags(), ¶ms.v0Compatible, false) addV1CompatibleFlag(evalCommand.Flags(), ¶ms.v1Compatible, false) + addReadAstValuesFromStoreFlag(evalCommand.Flags(), ¶ms.ReadAstValuesFromStore, false) RootCommand.AddCommand(evalCommand) } @@ -550,6 +552,7 @@ func setupEval(args []string, params evalCommandParams) (*evalContext, error) { rego.Query(query), rego.Runtime(info), rego.SetRegoVersion(params.regoVersion()), + rego.StoreReadAST(params.ReadAstValuesFromStore), } evalArgs := []rego.EvalOption{ diff --git a/cmd/eval_test.go b/cmd/eval_test.go index d942f9c4cc..ffeb1348a2 100755 --- a/cmd/eval_test.go +++ b/cmd/eval_test.go @@ -1045,6 +1045,48 @@ func TestEvalWithBundleDuplicateFileNames(t *testing.T) { }) } +func TestEvalWithReadASTValuesFromStore(t *testing.T) { + // Note: This test is a bit of a hack. It's difficult to discern whether AST values were actually read from the store. + // This just ensures that we don't get any unexpected errors when enabling the flag. + + tests := []struct { + note string + readAst bool + }{ + { + note: "read raw data from store", + readAst: false, + }, + { + note: "read AST values from store", + readAst: true, + }, + } + + files := map[string]string{ + "test.rego": ` + package test + p = 1`, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + test.WithTempFS(files, func(path string) { + params := newEvalCommandParams() + params.dataPaths = newrepeatedStringFlag([]string{path}) + params.ReadAstValuesFromStore = tc.readAst + + var buf bytes.Buffer + + defined, err := eval([]string{"data.test.p"}, params, &buf) + if !defined || err != nil { + t.Fatalf("Unexpected undefined or error: %v", err) + } + }) + }) + } +} + func TestEvalWithStrictBuiltinErrors(t *testing.T) { params := newEvalCommandParams() params.strictBuiltinErrors = true diff --git a/cmd/flags.go b/cmd/flags.go index 51d8711d8c..c0e65b5219 100644 --- a/cmd/flags.go +++ b/cmd/flags.go @@ -165,6 +165,10 @@ func addV1CompatibleFlag(fs *pflag.FlagSet, v1Compatible *bool, value bool) { fs.BoolVar(v1Compatible, "v1-compatible", value, "opt-in to OPA features and behaviors that are enabled by default in OPA v1.0") } +func addReadAstValuesFromStoreFlag(fs *pflag.FlagSet, readAstValuesFromStore *bool, value bool) { + fs.BoolVar(readAstValuesFromStore, "optimize-store-for-read-speed", value, "optimize default in-memory store for read speed. Has possible negative impact on memory footprint and write speed. See https://www.openpolicyagent.org/docs/latest/policy-performance/#storage-optimization for more details.") +} + func addE2EFlag(fs *pflag.FlagSet, e2e *bool, value bool) { fs.BoolVar(e2e, "e2e", value, "run benchmarks against a running OPA server") } diff --git a/cmd/run.go b/cmd/run.go index 7078dfe071..9e8cb0eac3 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -242,6 +242,7 @@ See https://godoc.org/crypto/tls#pkg-constants for more information. addConfigOverrides(runCommand.Flags(), &cmdParams.rt.ConfigOverrides) addConfigOverrideFiles(runCommand.Flags(), &cmdParams.rt.ConfigOverrideFiles) addBundleModeFlag(runCommand.Flags(), &cmdParams.rt.BundleMode, false) + addReadAstValuesFromStoreFlag(runCommand.Flags(), &cmdParams.rt.ReadAstValuesFromStore, false) runCommand.Flags().BoolVar(&cmdParams.skipVersionCheck, "skip-version-check", false, "disables anonymous version reporting (see: https://www.openpolicyagent.org/docs/latest/privacy)") err := runCommand.Flags().MarkDeprecated("skip-version-check", "\"skip-version-check\" is deprecated. Use \"disable-telemetry\" instead") diff --git a/docs/content/cli.md b/docs/content/cli.md index 668f06db0f..ee040c9640 100755 --- a/docs/content/cli.md +++ b/docs/content/cli.md @@ -40,30 +40,31 @@ opa bench [flags] ### Options ``` - --benchmem report memory allocations with benchmark results (default true) - -b, --bundle string set bundle file(s) or directory path(s). This flag can be repeated. - -c, --config-file string set path of configuration file - --count int number of times to repeat each benchmark (default 1) - -d, --data string set policy or data file(s). This flag can be repeated. - --e2e run benchmarks against a running OPA server - --fail exits with non-zero exit code on undefined/empty result and errors (default true) - -f, --format {json,pretty,gobench} set output format (default pretty) - -h, --help help for bench - --ignore strings set file and directory names to ignore during loading (e.g., '.*' excludes hidden files) - --import string set query import(s). This flag can be repeated. - -i, --input string set input file path - --metrics report query performance metrics (default true) - --package string set query package - -p, --partial perform partial evaluation - -s, --schema string set schema file path or directory path - --shutdown-grace-period int set the time (in seconds) that the server will wait to gracefully shut down. This flag is valid in 'e2e' mode only. (default 10) - --shutdown-wait-period int set the time (in seconds) that the server will wait before initiating shutdown. This flag is valid in 'e2e' mode only. - --stdin read query from stdin - -I, --stdin-input read input document from stdin - -t, --target {rego,wasm} set the runtime to exercise (default rego) - -u, --unknowns stringArray set paths to treat as unknown during partial evaluation (default [input]) - --v0-compatible opt-in to OPA features and behaviors prior to the OPA v1.0 release. Takes precedence over --v1-compatible - --v1-compatible opt-in to OPA features and behaviors that are enabled by default in OPA v1.0 + --benchmem report memory allocations with benchmark results (default true) + -b, --bundle string set bundle file(s) or directory path(s). This flag can be repeated. + -c, --config-file string set path of configuration file + --count int number of times to repeat each benchmark (default 1) + -d, --data string set policy or data file(s). This flag can be repeated. + --e2e run benchmarks against a running OPA server + --fail exits with non-zero exit code on undefined/empty result and errors (default true) + -f, --format {json,pretty,gobench} set output format (default pretty) + -h, --help help for bench + --ignore strings set file and directory names to ignore during loading (e.g., '.*' excludes hidden files) + --import string set query import(s). This flag can be repeated. + -i, --input string set input file path + --metrics report query performance metrics (default true) + --optimize-store-for-read-speed optimize default in-memory store for read speed. Has possible negative impact on memory footprint and write speed. See https://www.openpolicyagent.org/docs/latest/policy-performance/#storage-optimization for more details. + --package string set query package + -p, --partial perform partial evaluation + -s, --schema string set schema file path or directory path + --shutdown-grace-period int set the time (in seconds) that the server will wait to gracefully shut down. This flag is valid in 'e2e' mode only. (default 10) + --shutdown-wait-period int set the time (in seconds) that the server will wait before initiating shutdown. This flag is valid in 'e2e' mode only. + --stdin read query from stdin + -I, --stdin-input read input document from stdin + -t, --target {rego,wasm} set the runtime to exercise (default rego) + -u, --unknowns stringArray set paths to treat as unknown during partial evaluation (default [input]) + --v0-compatible opt-in to OPA features and behaviors prior to the OPA v1.0 release. Takes precedence over --v1-compatible + --v1-compatible opt-in to OPA features and behaviors that are enabled by default in OPA v1.0 ``` ____ @@ -556,6 +557,7 @@ opa eval [flags] --instrument enable query instrumentation metrics (implies --metrics) --metrics report query performance metrics -O, --optimize int set optimization level + --optimize-store-for-read-speed optimize default in-memory store for read speed. Has possible negative impact on memory footprint and write speed. See https://www.openpolicyagent.org/docs/latest/policy-performance/#storage-optimization for more details. --package string set query package -p, --partial perform partial evaluation --pretty-limit int set limit after which pretty output gets truncated (default 80) @@ -913,6 +915,7 @@ opa run [flags] --log-timestamp-format string set log timestamp format (OPA_LOG_TIMESTAMP_FORMAT environment variable) -m, --max-errors int set the number of errors to allow before compilation fails early (default 10) --min-tls-version {1.0,1.1,1.2,1.3} set minimum TLS version to be used by OPA's server (default 1.2) + --optimize-store-for-read-speed optimize default in-memory store for read speed. Has possible negative impact on memory footprint and write speed. See https://www.openpolicyagent.org/docs/latest/policy-performance/#storage-optimization for more details. --pprof enables pprof endpoints --ready-timeout int wait (in seconds) for configured plugins before starting server (value <= 0 disables ready check) --scope string scope to use for bundle signature verification diff --git a/docs/content/policy-performance.md b/docs/content/policy-performance.md index 3152d1ee3b..4b170ef24f 100644 --- a/docs/content/policy-performance.md +++ b/docs/content/policy-performance.md @@ -975,6 +975,22 @@ into call sites. In addition, more aggressive inlining is applied within rules. [copy propagation](https://en.wikipedia.org/wiki/Copy_propagation) and inlining of certain negated statements that would otherwise generate support rules. +## Storage Optimization + +### In-Memory Store Read Optimization + +During normal operation, data values read from storage are converted to an AST representation that is used during policy evaluation. +This conversion can be expensive both in execution time and in memory usage, especially for large data values. +The default in-memory store can be configured to optimize for read speed by precomputing the AST representation of data values during storage write operations. +This removes the time spent converting raw data values to AST during policy evaluation, improving performance. + +The memory footprint of the store will increase, as processed AST values generally take up more space in memory than the corresponding raw data values, but overall memory usage of OPA might remain more stable over time, as pre-converted data is shared across evaluations and isn't recomputed for each evaluation, which can cause spikes in memory usage. +Storage write operations will be slower due to the additional processing required to precompute the AST representation of data values. This can impact startup time and bundle loading/updates, especially for large data values. + +This feature can be enabled for `opa run`, `opa eval`, and `opa bench` by setting the `--optimize-store-for-read-speed` flag. + +Users are recommended to do performance testing to determine the optimal configuration for their use case. + ## Key Takeaways For high-performance use cases: diff --git a/internal/storage/mock/mock.go b/internal/storage/mock/mock.go index 5ecdb61c58..33c69f669d 100644 --- a/internal/storage/mock/mock.go +++ b/internal/storage/mock/mock.go @@ -45,6 +45,7 @@ func (t *Transaction) safeToUse() bool { // Store is a mock storage.Store implementation for use in testing. type Store struct { inmem storage.Store + storeOpts []inmem.Opt baseData map[string]interface{} Transactions []*Transaction Reads []*ReadCall @@ -69,16 +70,19 @@ type WriteCall struct { } // New creates a new mock Store -func New() *Store { - s := &Store{} +func New(opt ...inmem.Opt) *Store { + s := &Store{ + storeOpts: opt, + } s.Reset() return s } // NewWithData creates a store with some initial data -func NewWithData(data map[string]interface{}) *Store { +func NewWithData(data map[string]interface{}, opt ...inmem.Opt) *Store { s := &Store{ - baseData: data, + baseData: data, + storeOpts: opt, } s.Reset() return s @@ -90,9 +94,9 @@ func (s *Store) Reset() { s.Reads = []*ReadCall{} s.Writes = []*WriteCall{} if s.baseData != nil { - s.inmem = inmem.NewFromObject(s.baseData) + s.inmem = inmem.NewFromObjectWithOpts(s.baseData, s.storeOpts...) } else { - s.inmem = inmem.New() + s.inmem = inmem.NewWithOpts(s.storeOpts...) } } diff --git a/plugins/bundle/plugin_test.go b/plugins/bundle/plugin_test.go index fbab3ec219..4650689f73 100644 --- a/plugins/bundle/plugin_test.go +++ b/plugins/bundle/plugin_test.go @@ -34,7 +34,8 @@ import ( "github.com/open-policy-agent/opa/plugins" "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/storage/disk" - inmem "github.com/open-policy-agent/opa/storage/inmem/test" + "github.com/open-policy-agent/opa/storage/inmem" + inmemtst "github.com/open-policy-agent/opa/storage/inmem/test" "github.com/open-policy-agent/opa/util" "github.com/open-policy-agent/opa/util/test" ) @@ -111,6 +112,50 @@ func TestPluginOneShot(t *testing.T) { } } +func TestPluginOneShotWithAstStore(t *testing.T) { + + ctx := context.Background() + store := inmem.NewWithOpts(inmem.OptRoundTripOnWrite(false), inmem.OptReturnASTValuesOnRead(true)) + manager := getTestManagerWithOpts(nil, store) + plugin := New(&Config{}, manager) + bundleName := "test-bundle" + plugin.status[bundleName] = &Status{Name: bundleName, Metrics: metrics.New()} + plugin.downloaders[bundleName] = download.New(download.Config{}, plugin.manager.Client(""), bundleName) + + ensurePluginState(t, plugin, plugins.StateNotReady) + + b := bundle.Bundle{ + Manifest: bundle.Manifest{Revision: "quickbrownfaux"}, + Data: util.MustUnmarshalJSON([]byte(`{"foo": {"bar": 1, "baz": "qux"}}`)).(map[string]interface{}), + Etag: "foo", + } + + b.Manifest.Init() + + plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b, Metrics: metrics.New(), Size: snapshotBundleSize}) + + ensurePluginState(t, plugin, plugins.StateOK) + + if status, ok := plugin.status[bundleName]; !ok { + t.Fatalf("Expected to find status for %s, found nil", bundleName) + } else if status.Type != bundle.SnapshotBundleType { + t.Fatalf("expected snapshot bundle but got %v", status.Type) + } else if status.Size != snapshotBundleSize { + t.Fatalf("expected snapshot bundle size %d but got %d", snapshotBundleSize, status.Size) + } + + txn := storage.NewTransactionOrDie(ctx, manager.Store) + defer manager.Store.Abort(ctx, txn) + + data, err := manager.Store.Read(ctx, txn, storage.Path{}) + expData := ast.MustParseTerm(`{"foo": {"bar": 1, "baz": "qux"}, "system": {"bundles": {"test-bundle": {"etag": "foo", "manifest": {"revision": "quickbrownfaux", "roots": [""]}}}}}`) + if err != nil { + t.Fatal(err) + } else if ast.Compare(data, expData) != 0 { + t.Fatalf("Bad data content. Exp:\n%v\n\nGot:\n\n%v", expData, data) + } +} + func TestPluginOneShotV1Compatible(t *testing.T) { // Note: modules are parsed before passed to plugin, so any expected errors must be triggered by the compiler stage. tests := []struct { @@ -403,7 +448,7 @@ corge contains 1 if { t.Run(tc.note, func(t *testing.T) { ctx := context.Background() managerPopts := ast.ParserOptions{RegoVersion: tc.managerRegoVersion} - manager, err := plugins.New(nil, "test-instance-id", inmem.New(), + manager, err := plugins.New(nil, "test-instance-id", inmemtst.New(), plugins.WithParserOptions(managerPopts)) if err != nil { t.Fatal(err) @@ -740,49 +785,65 @@ func TestPluginOneShotWithAuthzSchemaVerificationNonDefaultAuthzPath(t *testing. } func TestPluginStartLazyLoadInMem(t *testing.T) { - ctx := context.Background() - - module := "package authz\n\ncorge=1" - - // setup fake http server with mock bundle - mockBundle1 := bundle.Bundle{ - Data: map[string]interface{}{"p": "x1"}, - Modules: []bundle.ModuleFile{ - { - URL: "/bar/policy.rego", - Path: "/bar/policy.rego", - Parsed: ast.MustParseModule(module), - Raw: []byte(module), - }, + readMode := []struct { + note string + readAst bool + }{ + { + note: "read raw", + readAst: false, }, - Manifest: bundle.Manifest{ - Roots: &[]string{"p", "authz"}, + { + note: "read ast", + readAst: true, }, } - s1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - err := bundle.NewWriter(w).Write(mockBundle1) - if err != nil { - t.Fatal(err) - } - })) + for _, rm := range readMode { + t.Run(rm.note, func(t *testing.T) { + ctx := context.Background() - mockBundle2 := bundle.Bundle{ - Data: map[string]interface{}{"q": "x2"}, - Modules: []bundle.ModuleFile{}, - Manifest: bundle.Manifest{ - Roots: &[]string{"q"}, - }, - } + module := "package authz\n\ncorge=1" - s2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - err := bundle.NewWriter(w).Write(mockBundle2) - if err != nil { - t.Fatal(err) - } - })) + // setup fake http server with mock bundle + mockBundle1 := bundle.Bundle{ + Data: map[string]interface{}{"p": "x1"}, + Modules: []bundle.ModuleFile{ + { + URL: "/bar/policy.rego", + Path: "/bar/policy.rego", + Parsed: ast.MustParseModule(module), + Raw: []byte(module), + }, + }, + Manifest: bundle.Manifest{ + Roots: &[]string{"p", "authz"}, + }, + } - config := []byte(fmt.Sprintf(`{ + s1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + err := bundle.NewWriter(w).Write(mockBundle1) + if err != nil { + t.Fatal(err) + } + })) + + mockBundle2 := bundle.Bundle{ + Data: map[string]interface{}{"q": "x2"}, + Modules: []bundle.ModuleFile{}, + Manifest: bundle.Manifest{ + Roots: &[]string{"q"}, + }, + } + + s2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + err := bundle.NewWriter(w).Write(mockBundle2) + if err != nil { + t.Fatal(err) + } + })) + + config := []byte(fmt.Sprintf(`{ "services": { "default": { "url": %q @@ -793,89 +854,115 @@ func TestPluginStartLazyLoadInMem(t *testing.T) { } }`, s1.URL, s2.URL)) - manager := getTestManagerWithOpts(config) - defer manager.Stop(ctx) + manager := getTestManagerWithOpts(config, inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(rm.readAst))) + defer manager.Stop(ctx) - var mode plugins.TriggerMode = "manual" + var mode plugins.TriggerMode = "manual" - plugin := New(&Config{ - Bundles: map[string]*Source{ - "test-1": { - Service: "default", - SizeLimitBytes: int64(bundle.DefaultSizeLimitBytes), - Config: download.Config{Trigger: &mode}, - }, - "test-2": { - Service: "acmecorp", - SizeLimitBytes: int64(bundle.DefaultSizeLimitBytes), - Config: download.Config{Trigger: &mode}, - }, - }, - }, manager) + plugin := New(&Config{ + Bundles: map[string]*Source{ + "test-1": { + Service: "default", + SizeLimitBytes: int64(bundle.DefaultSizeLimitBytes), + Config: download.Config{Trigger: &mode}, + }, + "test-2": { + Service: "acmecorp", + SizeLimitBytes: int64(bundle.DefaultSizeLimitBytes), + Config: download.Config{Trigger: &mode}, + }, + }, + }, manager) - statusCh := make(chan map[string]*Status) + statusCh := make(chan map[string]*Status) - // register for bundle updates to observe changes and start the plugin - plugin.RegisterBulkListener("test-case", func(st map[string]*Status) { - statusCh <- st - }) + // register for bundle updates to observe changes and start the plugin + plugin.RegisterBulkListener("test-case", func(st map[string]*Status) { + statusCh <- st + }) - err := plugin.Start(ctx) - if err != nil { - t.Fatal(err) - } + err := plugin.Start(ctx) + if err != nil { + t.Fatal(err) + } - // manually trigger bundle download on all configured bundles - go func() { - _ = plugin.Trigger(ctx) - }() + // manually trigger bundle download on all configured bundles + go func() { + _ = plugin.Trigger(ctx) + }() - // wait for bundle update and then assert on data content - <-statusCh - <-statusCh + // wait for bundle update and then assert on data content + <-statusCh + <-statusCh - result, err := storage.ReadOne(ctx, manager.Store, storage.Path{"p"}) - if err != nil { - t.Fatal(err) - } + result, err := storage.ReadOne(ctx, manager.Store, storage.Path{"p"}) + if err != nil { + t.Fatal(err) + } - if !reflect.DeepEqual(result, mockBundle1.Data["p"]) { - t.Fatalf("expected data to be %v but got %v", mockBundle1.Data, result) - } + if rm.readAst { + expected, _ := ast.InterfaceToValue(mockBundle1.Data["p"]) + if ast.Compare(result, expected) != 0 { + t.Fatalf("expected data to be %v but got %v", expected, result) + } + } else { + if !reflect.DeepEqual(result, mockBundle1.Data["p"]) { + t.Fatalf("expected data to be %v but got %v", mockBundle1.Data, result) + } + } - result, err = storage.ReadOne(ctx, manager.Store, storage.Path{"q"}) - if err != nil { - t.Fatal(err) - } + result, err = storage.ReadOne(ctx, manager.Store, storage.Path{"q"}) + if err != nil { + t.Fatal(err) + } - if !reflect.DeepEqual(result, mockBundle2.Data["q"]) { - t.Fatalf("expected data to be %v but got %v", mockBundle2.Data, result) - } + if rm.readAst { + expected, _ := ast.InterfaceToValue(mockBundle2.Data["q"]) + if ast.Compare(result, expected) != 0 { + t.Fatalf("expected data to be %v but got %v", expected, result) + } + } else { + if !reflect.DeepEqual(result, mockBundle2.Data["q"]) { + t.Fatalf("expected data to be %v but got %v", mockBundle2.Data, result) + } + } - txn := storage.NewTransactionOrDie(ctx, manager.Store) - defer manager.Store.Abort(ctx, txn) + txn := storage.NewTransactionOrDie(ctx, manager.Store) + defer manager.Store.Abort(ctx, txn) - ids, err := manager.Store.ListPolicies(ctx, txn) - if err != nil { - t.Fatal(err) - } else if len(ids) != 1 { - t.Fatal("Expected 1 policy") - } + ids, err := manager.Store.ListPolicies(ctx, txn) + if err != nil { + t.Fatal(err) + } else if len(ids) != 1 { + t.Fatal("Expected 1 policy") + } - bs, err := manager.Store.GetPolicy(ctx, txn, ids[0]) - exp := []byte("package authz\n\ncorge=1") - if err != nil { - t.Fatal(err) - } else if !bytes.Equal(bs, exp) { - t.Fatalf("Bad policy content. Exp:\n%v\n\nGot:\n\n%v", string(exp), string(bs)) - } + bs, err := manager.Store.GetPolicy(ctx, txn, ids[0]) + exp := []byte("package authz\n\ncorge=1") + if err != nil { + t.Fatal(err) + } else if !bytes.Equal(bs, exp) { + t.Fatalf("Bad policy content. Exp:\n%v\n\nGot:\n\n%v", string(exp), string(bs)) + } - data, err := manager.Store.Read(ctx, txn, storage.Path{}) - expData := util.MustUnmarshalJSON([]byte(`{"p": "x1", "q": "x2", "system": {"bundles": {"test-1": {"etag": "", "manifest": {"revision": "", "roots": ["p", "authz"]}}, "test-2": {"etag": "", "manifest": {"revision": "", "roots": ["q"]}}}}}`)) - if err != nil { - t.Fatal(err) - } else if !reflect.DeepEqual(data, expData) { - t.Fatalf("Bad data content. Exp:\n%v\n\nGot:\n\n%v", expData, data) + data, err := manager.Store.Read(ctx, txn, storage.Path{}) + if err != nil { + t.Fatal(err) + } + + expected := `{"p": "x1", "q": "x2", "system": {"bundles": {"test-1": {"etag": "", "manifest": {"revision": "", "roots": ["p", "authz"]}}, "test-2": {"etag": "", "manifest": {"revision": "", "roots": ["q"]}}}}}` + if rm.readAst { + expData := ast.MustParseTerm(expected) + if ast.Compare(data, expData) != 0 { + t.Fatalf("Bad data content. Exp:\n%v\n\nGot:\n\n%v", expData, data) + } + } else { + expData := util.MustUnmarshalJSON([]byte(expected)) + if !reflect.DeepEqual(data, expData) { + t.Fatalf("Bad data content. Exp:\n%v\n\nGot:\n\n%v", expData, data) + } + } + }) } } @@ -1081,6 +1168,104 @@ func TestPluginOneShotDeltaBundle(t *testing.T) { } } +func TestPluginOneShotDeltaBundleWithAstStore(t *testing.T) { + + ctx := context.Background() + store := inmem.NewWithOpts(inmem.OptRoundTripOnWrite(false), inmem.OptReturnASTValuesOnRead(true)) + manager := getTestManagerWithOpts(nil, store) + plugin := New(&Config{}, manager) + bundleName := "test-bundle" + plugin.status[bundleName] = &Status{Name: bundleName, Metrics: metrics.New()} + plugin.downloaders[bundleName] = download.New(download.Config{}, plugin.manager.Client(""), bundleName) + + ensurePluginState(t, plugin, plugins.StateNotReady) + + module := "package a\n\ncorge=1" + + b := bundle.Bundle{ + Manifest: bundle.Manifest{Revision: "quickbrownfaux", Roots: &[]string{"a"}}, + Data: map[string]interface{}{ + "a": map[string]interface{}{ + "baz": "qux", + }, + }, + Modules: []bundle.ModuleFile{ + { + Path: "a/policy.rego", + Parsed: ast.MustParseModule(module), + Raw: []byte(module), + }, + }, + } + + plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b, Metrics: metrics.New()}) + + ensurePluginState(t, plugin, plugins.StateOK) + + // simulate a delta bundle download + + // replace a value + p1 := bundle.PatchOperation{ + Op: "replace", + Path: "a/baz", + Value: "bux", + } + + // add a new object member + p2 := bundle.PatchOperation{ + Op: "upsert", + Path: "/a/foo", + Value: []interface{}{"hello", "world"}, + } + + b2 := bundle.Bundle{ + Manifest: bundle.Manifest{Revision: "delta", Roots: &[]string{"a"}}, + Patch: bundle.Patch{Data: []bundle.PatchOperation{p1, p2}}, + Etag: "foo", + } + + plugin.process(ctx, bundleName, download.Update{Bundle: &b2, Metrics: metrics.New(), Size: deltaBundleSize}) + + ensurePluginState(t, plugin, plugins.StateOK) + + if status, ok := plugin.status[bundleName]; !ok { + t.Fatalf("Expected to find status for %s, found nil", bundleName) + } else if status.Type != bundle.DeltaBundleType { + t.Fatalf("expected delta bundle but got %v", status.Type) + } else if status.Size != deltaBundleSize { + t.Fatalf("expected delta bundle size %d but got %d", deltaBundleSize, status.Size) + } + + txn := storage.NewTransactionOrDie(ctx, manager.Store) + defer manager.Store.Abort(ctx, txn) + + ids, err := manager.Store.ListPolicies(ctx, txn) + if err != nil { + t.Fatal(err) + } + if len(ids) != 1 { + t.Fatalf("Expected 1 policy, got %d", len(ids)) + } + + bs, err := manager.Store.GetPolicy(ctx, txn, ids[0]) + if err != nil { + t.Fatal(err) + } + exp := []byte("package a\n\ncorge=1") + if !bytes.Equal(bs, exp) { + t.Fatalf("Bad policy content. Exp:\n%v\n\nGot:\n\n%v", string(exp), string(bs)) + } + + data, err := manager.Store.Read(ctx, txn, storage.Path{}) + if err != nil { + t.Fatal(err) + } + expData := ast.MustParseTerm(`{"a": {"baz": "bux", "foo": ["hello", "world"]}, "system": {"bundles": {"test-bundle": {"etag": "foo", "manifest": {"revision": "delta", "roots": ["a"]}}}}}`) + if ast.Compare(data, expData) != 0 { + t.Fatalf("Bad data content. Exp:\n%#v\n\nGot:\n\n%#v", expData, data) + } +} + func TestPluginStart(t *testing.T) { ctx := context.Background() @@ -1318,7 +1503,7 @@ corge contains 1 if { popts := ast.ParserOptions{RegoVersion: regoVersion} ctx := context.Background() - manager, err := plugins.New(nil, "test-instance-id", inmem.New(), plugins.WithParserOptions(popts)) + manager, err := plugins.New(nil, "test-instance-id", inmemtst.New(), plugins.WithParserOptions(popts)) if err != nil { t.Fatal("unexpected error:", err) } @@ -1594,7 +1779,7 @@ corge contains 1 if { t.Run(tc.note, func(t *testing.T) { ctx := context.Background() managerPopts := ast.ParserOptions{RegoVersion: tc.managerRegoVersion} - manager, err := plugins.New(nil, "test-instance-id", inmem.New(), + manager, err := plugins.New(nil, "test-instance-id", inmemtst.New(), plugins.WithParserOptions(managerPopts)) if err != nil { t.Fatal("unexpected error:", err) @@ -2127,7 +2312,7 @@ corge contains 2 if { popts := ast.ParserOptions{RegoVersion: regoVersion} ctx := context.Background() - manager, err := plugins.New(nil, "test-instance-id", inmem.New(), + manager, err := plugins.New(nil, "test-instance-id", inmemtst.New(), plugins.WithParserOptions(popts)) if err != nil { t.Fatal("unexpected error:", err) @@ -2334,7 +2519,7 @@ corge contains 1 if { t.Run(tc.note, func(t *testing.T) { ctx := context.Background() managerPopts := ast.ParserOptions{RegoVersion: tc.managerRegoVersion} - manager, err := plugins.New(nil, "test-instance-id", inmem.New(), + manager, err := plugins.New(nil, "test-instance-id", inmemtst.New(), plugins.WithParserOptions(managerPopts)) if err != nil { t.Fatal("unexpected error:", err) @@ -3399,135 +3584,163 @@ p contains x if { x = 1 }` } func TestPluginActivateScopedBundle(t *testing.T) { - - ctx := context.Background() - manager := getTestManager() - plugin := Plugin{ - manager: manager, - status: map[string]*Status{}, - etags: map[string]string{}, - downloaders: map[string]Loader{}, + readMode := []struct { + note string + readAst bool + }{ + { + note: "read raw", + readAst: false, + }, + { + note: "read ast", + readAst: true, + }, } - bundleName := "test-bundle" - plugin.status[bundleName] = &Status{Name: bundleName} - plugin.downloaders[bundleName] = download.New(download.Config{}, plugin.manager.Client(""), bundleName) - // Transact test data and policies that represent data coming from - // _outside_ the bundle. The test will verify that data _outside_ - // the bundle is both not erased and is overwritten appropriately. - // - // The test data claims a/{a1-6} where even paths are policy and - // odd paths are raw JSON. - if err := storage.Txn(ctx, manager.Store, storage.WriteParams, func(txn storage.Transaction) error { + for _, rm := range readMode { + t.Run(rm.note, func(t *testing.T) { + ctx := context.Background() + manager := getTestManagerWithOpts(nil, inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(rm.readAst))) + plugin := Plugin{ + manager: manager, + status: map[string]*Status{}, + etags: map[string]string{}, + downloaders: map[string]Loader{}, + } + bundleName := "test-bundle" + plugin.status[bundleName] = &Status{Name: bundleName} + plugin.downloaders[bundleName] = download.New(download.Config{}, plugin.manager.Client(""), bundleName) - externalData := map[string]interface{}{"a": map[string]interface{}{"a1": "x1", "a3": "x2", "a5": "x3"}} + // Transact test data and policies that represent data coming from + // _outside_ the bundle. The test will verify that data _outside_ + // the bundle is both not erased and is overwritten appropriately. + // + // The test data claims a/{a1-6} where even paths are policy and + // odd paths are raw JSON. + if err := storage.Txn(ctx, manager.Store, storage.WriteParams, func(txn storage.Transaction) error { - if err := manager.Store.Write(ctx, txn, storage.AddOp, storage.Path{}, externalData); err != nil { - return err - } - if err := manager.Store.UpsertPolicy(ctx, txn, "some/id1", []byte(`package a.a2`)); err != nil { - return err - } - if err := manager.Store.UpsertPolicy(ctx, txn, "some/id2", []byte(`package a.a4`)); err != nil { - return err - } - return manager.Store.UpsertPolicy(ctx, txn, "some/id3", []byte(`package a.a6`)) - }); err != nil { - t.Fatal(err) - } + externalData := map[string]interface{}{"a": map[string]interface{}{"a1": "x1", "a3": "x2", "a5": "x3"}} - // Activate a bundle that is scoped to a/a1 and a/a2. This will - // erase and overwrite the external data at these paths but leave - // a3-6 untouched. - module := "package a.a2\n\nbar=1" - - b := bundle.Bundle{ - Manifest: bundle.Manifest{Revision: "quickbrownfaux", Roots: &[]string{"a/a1", "a/a2"}}, - Data: map[string]interface{}{ - "a": map[string]interface{}{ - "a1": "foo", - }, - }, - Modules: []bundle.ModuleFile{ - { - Path: "bundle/id1", - Parsed: ast.MustParseModule(module), - Raw: []byte(module), - }, - }, - } + if err := manager.Store.Write(ctx, txn, storage.AddOp, storage.Path{}, externalData); err != nil { + return err + } + if err := manager.Store.UpsertPolicy(ctx, txn, "some/id1", []byte(`package a.a2`)); err != nil { + return err + } + if err := manager.Store.UpsertPolicy(ctx, txn, "some/id2", []byte(`package a.a4`)); err != nil { + return err + } + return manager.Store.UpsertPolicy(ctx, txn, "some/id3", []byte(`package a.a6`)) + }); err != nil { + t.Fatal(err) + } - b.Manifest.Init() + // Activate a bundle that is scoped to a/a1 and a/a2. This will + // erase and overwrite the external data at these paths but leave + // a3-6 untouched. + module := "package a.a2\n\nbar=1" - plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b}) + b := bundle.Bundle{ + Manifest: bundle.Manifest{Revision: "quickbrownfaux", Roots: &[]string{"a/a1", "a/a2"}}, + Data: map[string]interface{}{ + "a": map[string]interface{}{ + "a1": "foo", + }, + }, + Modules: []bundle.ModuleFile{ + { + Path: "bundle/id1", + Parsed: ast.MustParseModule(module), + Raw: []byte(module), + }, + }, + } - // Ensure a/a3-6 are intact. a1-2 are overwritten by bundle, and - // that the manifest has been written to storage. - expData := util.MustUnmarshalJSON([]byte(`{"a1": "foo", "a3": "x2", "a5": "x3"}`)) - expIDs := []string{filepath.Join(bundleName, "bundle/id1"), "some/id2", "some/id3"} - validateStoreState(ctx, t, manager.Store, "/a", expData, expIDs, bundleName, "quickbrownfaux", nil) + b.Manifest.Init() - // Activate a bundle that is scoped to a/a3 ad a/a6. Include a function - // inside package a.a4 that we can depend on outside of the bundle scope to - // exercise the compile check with remaining modules. - module = "package a.a4\n\nbar=1\n\nfunc(x) = x" + plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b}) - b = bundle.Bundle{ - Manifest: bundle.Manifest{Revision: "quickbrownfaux-2", Roots: &[]string{"a/a3", "a/a4"}, - Metadata: map[string]interface{}{ - "a": map[string]interface{}{ - "a1": "deadbeef", + // Ensure a/a3-6 are intact. a1-2 are overwritten by bundle, and + // that the manifest has been written to storage. + exp := `{"a1": "foo", "a3": "x2", "a5": "x3"}` + var expData interface{} + if rm.readAst { + expData = ast.MustParseTerm(exp).Value + } else { + expData = util.MustUnmarshalJSON([]byte(exp)) + } + expIDs := []string{filepath.Join(bundleName, "bundle/id1"), "some/id2", "some/id3"} + validateStoreState(ctx, t, manager.Store, "/a", expData, expIDs, bundleName, "quickbrownfaux", nil) + + // Activate a bundle that is scoped to a/a3 ad a/a6. Include a function + // inside package a.a4 that we can depend on outside of the bundle scope to + // exercise the compile check with remaining modules. + module = "package a.a4\n\nbar=1\n\nfunc(x) = x" + + b = bundle.Bundle{ + Manifest: bundle.Manifest{Revision: "quickbrownfaux-2", Roots: &[]string{"a/a3", "a/a4"}, + Metadata: map[string]interface{}{ + "a": map[string]interface{}{ + "a1": "deadbeef", + }, + }, }, - }, - }, - Data: map[string]interface{}{ - "a": map[string]interface{}{ - "a3": "foo", - }, - }, - Modules: []bundle.ModuleFile{ - { - Path: "bundle/id2", - Parsed: ast.MustParseModule(module), - Raw: []byte(module), - }, - }, - } - - b.Manifest.Init() - plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b}) + Data: map[string]interface{}{ + "a": map[string]interface{}{ + "a3": "foo", + }, + }, + Modules: []bundle.ModuleFile{ + { + Path: "bundle/id2", + Parsed: ast.MustParseModule(module), + Raw: []byte(module), + }, + }, + } - // Ensure a/a5-a6 are intact. a3 and a4 are overwritten by bundle. - expData = util.MustUnmarshalJSON([]byte(`{"a3": "foo", "a5": "x3"}`)) - expIDs = []string{filepath.Join(bundleName, "bundle/id2"), "some/id3"} - validateStoreState(ctx, t, manager.Store, "/a", expData, expIDs, bundleName, "quickbrownfaux-2", - map[string]interface{}{ - "a": map[string]interface{}{"a1": "deadbeef"}, - }) + b.Manifest.Init() + plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b}) - // Upsert policy outside of bundle scope that depends on bundle. - if err := storage.Txn(ctx, manager.Store, storage.WriteParams, func(txn storage.Transaction) error { - return manager.Store.UpsertPolicy(ctx, txn, "not_scoped", []byte("package not_scoped\np { data.a.a4.func(1) = 1 }")) - }); err != nil { - t.Fatal(err) - } + // Ensure a/a5-a6 are intact. a3 and a4 are overwritten by bundle. + exp = `{"a3": "foo", "a5": "x3"}` + if rm.readAst { + expData = ast.MustParseTerm(exp).Value + } else { + expData = util.MustUnmarshalJSON([]byte(exp)) + } + expIDs = []string{filepath.Join(bundleName, "bundle/id2"), "some/id3"} + validateStoreState(ctx, t, manager.Store, "/a", expData, expIDs, bundleName, "quickbrownfaux-2", + map[string]interface{}{ + "a": map[string]interface{}{"a1": "deadbeef"}, + }) - b = bundle.Bundle{ - Manifest: bundle.Manifest{Revision: "quickbrownfaux-3", Roots: &[]string{"a/a3", "a/a4"}}, - Data: map[string]interface{}{}, - Modules: []bundle.ModuleFile{}, - } + // Upsert policy outside of bundle scope that depends on bundle. + if err := storage.Txn(ctx, manager.Store, storage.WriteParams, func(txn storage.Transaction) error { + return manager.Store.UpsertPolicy(ctx, txn, "not_scoped", []byte("package not_scoped\np { data.a.a4.func(1) = 1 }")) + }); err != nil { + t.Fatal(err) + } - b.Manifest.Init() - plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b}) + b = bundle.Bundle{ + Manifest: bundle.Manifest{Revision: "quickbrownfaux-3", Roots: &[]string{"a/a3", "a/a4"}}, + Data: map[string]interface{}{}, + Modules: []bundle.ModuleFile{}, + } - // Ensure bundle activation failed by checking that previous revision is - // still active. - expIDs = []string{filepath.Join(bundleName, "bundle/id2"), "not_scoped", "some/id3"} - validateStoreState(ctx, t, manager.Store, "/a", expData, expIDs, bundleName, "quickbrownfaux-2", - map[string]interface{}{ - "a": map[string]interface{}{"a1": "deadbeef"}, + b.Manifest.Init() + plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b}) + + // Ensure bundle activation failed by checking that previous revision is + // still active. + expIDs = []string{filepath.Join(bundleName, "bundle/id2"), "not_scoped", "some/id3"} + validateStoreState(ctx, t, manager.Store, "/a", expData, expIDs, bundleName, "quickbrownfaux-2", + map[string]interface{}{ + "a": map[string]interface{}{"a1": "deadbeef"}, + }) }) + } } func TestPluginSetCompilerOnContext(t *testing.T) { @@ -3596,7 +3809,7 @@ func getTestManager() *plugins.Manager { } func getTestManagerWithOpts(config []byte, stores ...storage.Store) *plugins.Manager { - store := inmem.New() + store := inmemtst.New() if len(stores) == 1 { store = stores[0] } @@ -4089,6 +4302,111 @@ func TestUpgradeLegacyBundleToMultiBundleNewBundles(t *testing.T) { } } +func TestLegacyBundleDataRead(t *testing.T) { + readModes := []struct { + note string + readAst bool + }{ + { + note: "read raw", + readAst: false, + }, + { + note: "read ast", + readAst: true, + }, + } + + for _, rm := range readModes { + t.Run(rm.note, func(t *testing.T) { + ctx := context.Background() + manager := getTestManagerWithOpts(nil, inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(rm.readAst))) + + plugin := Plugin{ + manager: manager, + status: map[string]*Status{}, + etags: map[string]string{}, + downloaders: map[string]Loader{}, + } + + bundleName := "test-bundle" + plugin.status[bundleName] = &Status{Name: bundleName} + plugin.downloaders[bundleName] = download.New(download.Config{}, plugin.manager.Client(""), bundleName) + + tsURLBase := "/opa-test/" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, tsURLBase) { + t.Fatalf("Invalid request URL path: %s, expected prefix %s", r.URL.Path, tsURLBase) + } + fmt.Fprintln(w, "") // Note: this is an invalid bundle and will fail the download + })) + defer ts.Close() + + serviceName := "test-svc" + err := manager.Reconfigure(&config.Config{ + Services: []byte(fmt.Sprintf("{\"%s\":{ \"url\": \"%s\"}}", serviceName, ts.URL+tsURLBase)), + }) + if err != nil { + t.Fatalf("Error configuring plugin manager: %s", err) + } + + var delay int64 = 10 + triggerPolling := plugins.TriggerPeriodic + downloadConf := download.Config{Polling: download.PollingConfig{MinDelaySeconds: &delay, MaxDelaySeconds: &delay}, Trigger: &triggerPolling} + + // Start with a "legacy" style config for a single bundle + plugin.config = Config{ + Bundles: map[string]*Source{ + bundleName: { + Config: downloadConf, + Service: serviceName, + }, + }, + Name: bundleName, + Service: serviceName, + Prefix: nil, + } + + module := "package a.a1\n\nbar=1" + + b := bundle.Bundle{ + Manifest: bundle.Manifest{Revision: "quickbrownfaux", Roots: &[]string{"a/a1", "a/a2"}}, + Data: map[string]interface{}{ + "a": map[string]interface{}{ + "a2": "foo", + }, + }, + Modules: []bundle.ModuleFile{ + { + Path: "bundle/id1", + Parsed: ast.MustParseModule(module), + Raw: []byte(module), + }, + }, + } + + b.Manifest.Init() + + if plugin.config.IsMultiBundle() { + t.Fatalf("Expected plugin to be in non-multi bundle config mode") + } + + plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b}) + + exp := `{"a2": "foo"}` + var expData interface{} + if rm.readAst { + expData = ast.MustParseTerm(exp).Value + } else { + expData = util.MustUnmarshalJSON([]byte(exp)) + } + + expIDs := []string{"bundle/id1"} + validateStoreState(ctx, t, manager.Store, "/a", expData, expIDs, bundleName, "quickbrownfaux", nil) + }) + } +} + func TestSaveBundleToDiskNew(t *testing.T) { manager := getTestManager() @@ -4250,7 +4568,7 @@ func TestLoadBundleFromDisk(t *testing.T) { func TestLoadBundleFromDiskV1Compatible(t *testing.T) { popts := ast.ParserOptions{RegoVersion: ast.RegoV1} - manager, err := plugins.New(nil, "test-instance-id", inmem.New(), plugins.WithParserOptions(popts)) + manager, err := plugins.New(nil, "test-instance-id", inmemtst.New(), plugins.WithParserOptions(popts)) if err != nil { t.Fatal("unexpected error:", err) } @@ -4573,7 +4891,7 @@ p contains 7 if { f.Close() - manager, err := plugins.New(nil, "test-instance-id", inmem.New(), plugins.WithParserOptions(popts)) + manager, err := plugins.New(nil, "test-instance-id", inmemtst.New(), plugins.WithParserOptions(popts)) if err != nil { t.Fatal("unexpected error:", err) } @@ -4887,7 +5205,7 @@ p contains 7 if { f.Close() managerPopts := ast.ParserOptions{RegoVersion: tc.managerRegoVersion} - manager, err := plugins.New(nil, "test-instance-id", inmem.New(), + manager, err := plugins.New(nil, "test-instance-id", inmemtst.New(), plugins.WithParserOptions(managerPopts)) if err != nil { t.Fatal("unexpected error:", err) @@ -5089,7 +5407,7 @@ p contains 7 if { "test.rego": tc.module, }, func(dir string) { - manager, err := plugins.New(nil, "test-instance-id", inmem.New(), plugins.WithParserOptions(popts)) + manager, err := plugins.New(nil, "test-instance-id", inmemtst.New(), plugins.WithParserOptions(popts)) if err != nil { t.Fatal("unexpected error:", err) } @@ -5380,7 +5698,7 @@ p contains 7 if { }, func(dir string) { managerPopts := ast.ParserOptions{RegoVersion: tc.managerRegoVersion} - manager, err := plugins.New(nil, "test-instance-id", inmem.New(), + manager, err := plugins.New(nil, "test-instance-id", inmemtst.New(), plugins.WithParserOptions(managerPopts)) if err != nil { t.Fatal("unexpected error:", err) @@ -6339,8 +6657,14 @@ func validateStoreState(ctx context.Context, t *testing.T, store storage.Store, return err } - if !reflect.DeepEqual(value, expData) { - return fmt.Errorf("Expected %v but got %v", expData, value) + if expAst, ok := expData.(ast.Value); ok { + if ast.Compare(value, expAst) != 0 { + return fmt.Errorf("expected %v but got %v", expAst, value) + } + } else { + if !reflect.DeepEqual(value, expData) { + return fmt.Errorf("expected %v but got %v", expData, value) + } } ids, err := store.ListPolicies(ctx, txn) @@ -6352,24 +6676,24 @@ func validateStoreState(ctx context.Context, t *testing.T, store storage.Store, sort.Strings(expIDs) if !reflect.DeepEqual(ids, expIDs) { - return fmt.Errorf("Expected ids %v but got %v", expIDs, ids) + return fmt.Errorf("expected ids %v but got %v", expIDs, ids) } rev, err := bundle.ReadBundleRevisionFromStore(ctx, store, txn, expBundleName) if err != nil { - return fmt.Errorf("Unexpected error when reading bundle revision from store: %s", err) + return fmt.Errorf("unexpected error when reading bundle revision from store: %s", err) } if rev != expBundleRev { - return fmt.Errorf("Unexpected revision found on bundle: %s", rev) + return fmt.Errorf("unexpected revision found on bundle: %s", rev) } metadata, err := bundle.ReadBundleMetadataFromStore(ctx, store, txn, expBundleName) if err != nil { - return fmt.Errorf("Unexpected error when reading bundle metadata from store: %s", err) + return fmt.Errorf("unexpected error when reading bundle metadata from store: %s", err) } if !reflect.DeepEqual(expMetadata, metadata) { - return fmt.Errorf("Unexpected metadata found on bundle: %v", metadata) + return fmt.Errorf("unexpected metadata found on bundle: %v", metadata) } return nil diff --git a/rego/rego.go b/rego/rego.go index 91465b003d..64b4b9b93e 100644 --- a/rego/rego.go +++ b/rego/rego.go @@ -579,6 +579,7 @@ type Rego struct { compiler *ast.Compiler store storage.Store ownStore bool + ownStoreReadAst bool txn storage.Transaction metrics metrics.Metrics queryTracers []topdown.QueryTracer @@ -1007,6 +1008,15 @@ func Store(s storage.Store) func(r *Rego) { } } +// StoreReadAST returns an argument that sets whether the store should eagerly convert data to AST values. +// +// Only applicable when no store has been set on the Rego object through the Store option. +func StoreReadAST(enabled bool) func(r *Rego) { + return func(r *Rego) { + r.ownStoreReadAst = enabled + } +} + // Transaction returns an argument that sets the transaction to use for storage // layer operations. // @@ -1266,7 +1276,7 @@ func New(options ...func(r *Rego)) *Rego { } if r.store == nil { - r.store = inmem.New() + r.store = inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(r.ownStoreReadAst)) r.ownStore = true } else { r.ownStore = false diff --git a/runtime/runtime.go b/runtime/runtime.go index 9f218c87e6..d39c9efb5f 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -23,6 +23,7 @@ import ( "github.com/fsnotify/fsnotify" "github.com/gorilla/mux" + "github.com/open-policy-agent/opa/storage/inmem" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/exporters/otlp/otlptrace" "go.opentelemetry.io/otel/propagation" @@ -53,7 +54,6 @@ import ( "github.com/open-policy-agent/opa/server" "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/storage/disk" - "github.com/open-policy-agent/opa/storage/inmem" "github.com/open-policy-agent/opa/tracing" "github.com/open-policy-agent/opa/util" "github.com/open-policy-agent/opa/version" @@ -240,6 +240,12 @@ type Params struct { // CipherSuites specifies the list of enabled TLS 1.0–1.2 cipher suites CipherSuites *[]uint16 + + // ReadAstValuesFromStore controls whether the storage layer should return AST values when reading from the store. + // This is an eager conversion, that comes with an upfront performance cost when updating the store (e.g. bundle updates). + // Evaluation performance is affected in that data doesn't need to be converted to AST during evaluation. + // Only applicable when using the default in-memory store, and not when used together with the DiskStorage option. + ReadAstValuesFromStore bool } func (p *Params) regoVersion() ast.RegoVersion { @@ -400,7 +406,8 @@ func NewRuntime(ctx context.Context, params Params) (*Runtime, error) { return nil, fmt.Errorf("initialize disk store: %w", err) } } else { - store = inmem.NewWithOpts(inmem.OptRoundTripOnWrite(false)) + store = inmem.NewWithOpts(inmem.OptRoundTripOnWrite(false), + inmem.OptReturnASTValuesOnRead(params.ReadAstValuesFromStore)) } traceExporter, tracerProvider, _, err := internal_tracing.Init(ctx, config, params.ID) diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index 44e6c034bb..e8bfcd0fb4 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -34,14 +34,37 @@ import ( ) func TestRuntimeProcessWatchEvents(t *testing.T) { - testRuntimeProcessWatchEvents(t, false) -} + tests := []struct { + note string + asBundle bool + readAst bool + }{ + { + note: "no bundle, read raw data", + }, + { + note: "no bundle, read ast", + readAst: true, + }, + { + note: "bundle, read raw data", + asBundle: true, + }, + { + note: "bundle, read ast", + asBundle: true, + readAst: true, + }, + } -func TestRuntimeProcessWatchEventsWithBundle(t *testing.T) { - testRuntimeProcessWatchEvents(t, true) + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + testRuntimeProcessWatchEvents(t, tc.asBundle, tc.readAst) + }) + } } -func testRuntimeProcessWatchEvents(t *testing.T, asBundle bool) { +func testRuntimeProcessWatchEvents(t *testing.T, asBundle bool, readAst bool) { t.Helper() ctx := context.Background() @@ -60,6 +83,7 @@ func testRuntimeProcessWatchEvents(t *testing.T, asBundle bool) { params := NewParams() params.Paths = []string{rootDir} params.BundleMode = asBundle + params.ReadAstValuesFromStore = readAst rt, err := NewRuntime(ctx, params) if err != nil { @@ -110,9 +134,18 @@ func testRuntimeProcessWatchEvents(t *testing.T, asBundle bool) { } rt.Store.Abort(ctx, txn) - if reflect.DeepEqual(val, expected) { - return // success + + if readAst { + exp, _ := ast.InterfaceToValue(expected) + if ast.Compare(val, exp) == 0 { + return // success + } + } else { + if reflect.DeepEqual(val, expected) { + return // success + } } + } t.Fatalf("Did not see expected change in %v, last value: %v, buf: %v", maxWaitTime, val, buf.String()) diff --git a/server/server_test.go b/server/server_test.go index a5819b681b..8e80f77117 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2636,6 +2636,74 @@ func TestBundleNoRoots(t *testing.T) { } } +func TestDataUpdate(t *testing.T) { + tests := []struct { + note string + readAst bool + }{ + { + note: "read raw data", + }, + { + note: "read ast data", + readAst: true, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + f := newFixtureWithStore(t, inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(tc.readAst))) + + // PUT data + + putData := `{"a":1,"b":2, "c": 3}` + err := f.v1(http.MethodPut, "/data/x", putData, 204, "") + if err != nil { + t.Fatal(err) + } + + req := newReqV1(http.MethodGet, "/data/x", "") + f.reset() + f.server.Handler.ServeHTTP(f.recorder, req) + + var result types.DataResponseV1 + + if err := util.NewJSONDecoder(f.recorder.Body).Decode(&result); err != nil { + t.Fatalf("Unexpected JSON decode error: %v", err) + } + + var expected interface{} + if err := util.UnmarshalJSON([]byte(putData), &expected); err != nil { + t.Fatalf("Unexpected JSON decode error: %v", err) + } + if result.Result == nil || !reflect.DeepEqual(*result.Result, expected) { + t.Fatalf("Expected %v but got: %v", expected, *result.Result) + } + + // DELETE data + + if err := f.v1(http.MethodDelete, "/data/x/b", "", 204, ""); err != nil { + t.Fatal("Unexpected error:", err) + } + + req = newReqV1(http.MethodGet, "/data/x", "") + f.reset() + f.server.Handler.ServeHTTP(f.recorder, req) + + if err := util.NewJSONDecoder(f.recorder.Body).Decode(&result); err != nil { + t.Fatalf("Unexpected JSON decode error: %v", err) + } + + if err := util.UnmarshalJSON([]byte(`{"a":1,"c": 3}`), &expected); err != nil { + t.Fatalf("Unexpected JSON decode error: %v", err) + } + if result.Result == nil || !reflect.DeepEqual(*result.Result, expected) { + t.Fatalf("Expected %v but got: %v", expected, *result.Result) + } + }) + } +} + func TestDataGetExplainFull(t *testing.T) { f := newFixture(t) diff --git a/storage/inmem/ast.go b/storage/inmem/ast.go new file mode 100644 index 0000000000..5a8a6743fa --- /dev/null +++ b/storage/inmem/ast.go @@ -0,0 +1,314 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package inmem + +import ( + "fmt" + "strconv" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/storage/internal/errors" + "github.com/open-policy-agent/opa/storage/internal/ptr" +) + +type updateAST struct { + path storage.Path // data path modified by update + remove bool // indicates whether update removes the value at path + value ast.Value // value to add/replace at path (ignored if remove is true) +} + +func (u *updateAST) Path() storage.Path { + return u.path +} + +func (u *updateAST) Remove() bool { + return u.remove +} + +func (u *updateAST) Set(v interface{}) { + if v, ok := v.(ast.Value); ok { + u.value = v + } else { + panic("illegal value type") // FIXME: do conversion? + } +} + +func (u *updateAST) Value() interface{} { + return u.value +} + +func (u *updateAST) Relative(path storage.Path) dataUpdate { + cpy := *u + cpy.path = cpy.path[len(path):] + return &cpy +} + +func (u *updateAST) Apply(v interface{}) interface{} { + if len(u.path) == 0 { + return u.value + } + + data, ok := v.(ast.Value) + if !ok { + panic(fmt.Errorf("illegal value type %T, expected ast.Value", v)) + } + + if u.remove { + newV, err := removeInAst(data, u.path) + if err != nil { + panic(err) + } + return newV + } + + // If we're not removing, we're replacing (adds are turned into replaces during updateAST creation). + newV, err := setInAst(data, u.path, u.value) + if err != nil { + panic(err) + } + return newV +} + +func newUpdateAST(data interface{}, op storage.PatchOp, path storage.Path, idx int, value ast.Value) (*updateAST, error) { + + switch data.(type) { + case ast.Null, ast.Boolean, ast.Number, ast.String: + return nil, errors.NewNotFoundError(path) + } + + switch data := data.(type) { + case ast.Object: + return newUpdateObjectAST(data, op, path, idx, value) + + case *ast.Array: + return newUpdateArrayAST(data, op, path, idx, value) + } + + return nil, &storage.Error{ + Code: storage.InternalErr, + Message: "invalid data value encountered", + } +} + +func newUpdateArrayAST(data *ast.Array, op storage.PatchOp, path storage.Path, idx int, value ast.Value) (*updateAST, error) { + + if idx == len(path)-1 { + if path[idx] == "-" || path[idx] == strconv.Itoa(data.Len()) { + if op != storage.AddOp { + return nil, invalidPatchError("%v: invalid patch path", path) + } + + cpy := data.Copy() + cpy = cpy.Append(ast.NewTerm(value)) + return &updateAST{path[:len(path)-1], false, cpy}, nil + } + + pos, err := ptr.ValidateASTArrayIndex(data, path[idx], path) + if err != nil { + return nil, err + } + + switch op { + case storage.AddOp: + var results []*ast.Term + for i := 0; i < data.Len(); i++ { + if i == pos { + results = append(results, ast.NewTerm(value)) + } + results = append(results, data.Elem(i)) + } + + return &updateAST{path[:len(path)-1], false, ast.NewArray(results...)}, nil + + case storage.RemoveOp: + var results []*ast.Term + for i := 0; i < data.Len(); i++ { + if i != pos { + results = append(results, data.Elem(i)) + } + } + return &updateAST{path[:len(path)-1], false, ast.NewArray(results...)}, nil + + default: + var results []*ast.Term + for i := 0; i < data.Len(); i++ { + if i == pos { + results = append(results, ast.NewTerm(value)) + } else { + results = append(results, data.Elem(i)) + } + } + + return &updateAST{path[:len(path)-1], false, ast.NewArray(results...)}, nil + } + } + + pos, err := ptr.ValidateASTArrayIndex(data, path[idx], path) + if err != nil { + return nil, err + } + + return newUpdateAST(data.Elem(pos).Value, op, path, idx+1, value) +} + +func newUpdateObjectAST(data ast.Object, op storage.PatchOp, path storage.Path, idx int, value ast.Value) (*updateAST, error) { + key := ast.StringTerm(path[idx]) + val := data.Get(key) + + if idx == len(path)-1 { + switch op { + case storage.ReplaceOp, storage.RemoveOp: + if val == nil { + return nil, errors.NewNotFoundError(path) + } + } + return &updateAST{path, op == storage.RemoveOp, value}, nil + } + + if val != nil { + return newUpdateAST(val.Value, op, path, idx+1, value) + } + + return nil, errors.NewNotFoundError(path) +} + +func interfaceToValue(v interface{}) (ast.Value, error) { + if v, ok := v.(ast.Value); ok { + return v, nil + } + return ast.InterfaceToValue(v) +} + +// setInAst updates the value in the AST at the given path with the given value. +// Values can only be replaced in arrays, not added. +// Values for new keys can be added to objects +func setInAst(data ast.Value, path storage.Path, value ast.Value) (ast.Value, error) { + if len(path) == 0 { + return data, nil + } + + switch data := data.(type) { + case ast.Object: + return setInAstObject(data, path, value) + case *ast.Array: + return setInAstArray(data, path, value) + default: + return nil, fmt.Errorf("illegal value type %T, expected ast.Object or ast.Array", data) + } +} + +func setInAstObject(obj ast.Object, path storage.Path, value ast.Value) (ast.Value, error) { + key := ast.StringTerm(path[0]) + + if len(path) == 1 { + obj.Insert(key, ast.NewTerm(value)) + return obj, nil + } + + child := obj.Get(key) + newChild, err := setInAst(child.Value, path[1:], value) + if err != nil { + return nil, err + } + obj.Insert(key, ast.NewTerm(newChild)) + return obj, nil +} + +func setInAstArray(arr *ast.Array, path storage.Path, value ast.Value) (ast.Value, error) { + idx, err := strconv.Atoi(path[0]) + if err != nil { + return nil, fmt.Errorf("illegal array index %v: %v", path[0], err) + } + + if idx < 0 || idx >= arr.Len() { + return arr, nil + } + + if len(path) == 1 { + arr.Set(idx, ast.NewTerm(value)) + return arr, nil + } + + child := arr.Elem(idx) + newChild, err := setInAst(child.Value, path[1:], value) + if err != nil { + return nil, err + } + arr.Set(idx, ast.NewTerm(newChild)) + return arr, nil +} + +func removeInAst(value ast.Value, path storage.Path) (ast.Value, error) { + if len(path) == 0 { + return value, nil + } + + switch value := value.(type) { + case ast.Object: + return removeInAstObject(value, path) + case *ast.Array: + return removeInAstArray(value, path) + default: + return nil, fmt.Errorf("illegal value type %T, expected ast.Object or ast.Array", value) + } +} + +func removeInAstObject(obj ast.Object, path storage.Path) (ast.Value, error) { + key := ast.StringTerm(path[0]) + + if len(path) == 1 { + var items [][2]*ast.Term + // Note: possibly expensive operation for large data. + obj.Foreach(func(k *ast.Term, v *ast.Term) { + if k.Equal(key) { + return + } + items = append(items, [2]*ast.Term{k, v}) + }) + return ast.NewObject(items...), nil + } + + if child := obj.Get(key); child != nil { + updatedChild, err := removeInAst(child.Value, path[1:]) + if err != nil { + return nil, err + } + obj.Insert(key, ast.NewTerm(updatedChild)) + } + + return obj, nil +} + +func removeInAstArray(arr *ast.Array, path storage.Path) (ast.Value, error) { + idx, err := strconv.Atoi(path[0]) + if err != nil { + // We expect the path to be valid at this point. + return arr, nil + } + + if idx < 0 || idx >= arr.Len() { + return arr, err + } + + if len(path) == 1 { + var elems []*ast.Term + // Note: possibly expensive operation for large data. + for i := 0; i < arr.Len(); i++ { + if i == idx { + continue + } + elems = append(elems, arr.Elem(i)) + } + return ast.NewArray(elems...), nil + } + + updatedChild, err := removeInAst(arr.Elem(idx).Value, path[1:]) + if err != nil { + return nil, err + } + arr.Set(idx, ast.NewTerm(updatedChild)) + return arr, nil +} diff --git a/storage/inmem/ast_test.go b/storage/inmem/ast_test.go new file mode 100644 index 0000000000..44c2a080fd --- /dev/null +++ b/storage/inmem/ast_test.go @@ -0,0 +1,200 @@ +// Copyright 2024 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package inmem + +import ( + "testing" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/storage" +) + +func TestSetInAst(t *testing.T) { + tests := []struct { + note string + value string + path string + newValue string + expected string + }{ + { + note: "zero length path", + value: `{}`, + path: "/", + newValue: "42", + expected: "{}", + }, + { + note: "set object key", + value: `{"a": 1, "b": 2, "c": 3}`, + path: "/b", + newValue: "42", + expected: `{"a": 1, "b": 42, "c": 3}`, + }, + { + note: "set nested object key", + value: `{"a": {"b": 1, "c": 2, "d": 3}, "b": 4}`, + path: "/a/c", + newValue: "42", + expected: `{"a": {"b": 1, "c": 42, "d": 3}, "b": 4}`, + }, + // new keys can be added to objects + { + note: "add object key", + value: `{"a": 1, "b": 2, "c": 3}`, + path: "/d", + newValue: "42", + expected: `{"a": 1, "b": 2, "c": 3, "d": 42}`, + }, + { + note: "add nested object key", + value: `{"a": {"b": 1, "c": 2, "d": 3}, "b": 4}`, + path: "/a/e", + newValue: "42", + expected: `{"a": {"b": 1, "c": 2, "d": 3, "e": 42}, "b": 4}`, + }, + + { + note: "set array element", + value: `[1, 2, 3]`, + path: "/1", + newValue: "42", + expected: `[1, 42, 3]`, + }, + { + note: "set nested array element", + value: `[[1, 2], [3, 4], [5, 6]]`, + path: "/1/0", + newValue: "42", + expected: `[[1, 2], [42, 4], [5, 6]]`, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + value := ast.MustParseTerm(tc.value).Value + path := storage.MustParsePath(tc.path) + newValue := ast.MustParseTerm(tc.newValue).Value + expected := ast.MustParseTerm(tc.expected).Value + + result, err := setInAst(value, path, newValue) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if expected.Compare(result) != 0 { + t.Fatalf("Expected:\n\n%v\n\nbut got:\n\n%v", expected, result) + } + + if result.Hash() != expected.Hash() { + t.Fatalf("Expected hash:\n\n%v\n\nbut got:\n\n%v", expected.Hash(), result.Hash()) + } + }) + } +} + +func TestRemoveInAst(t *testing.T) { + tests := []struct { + note string + value string + path string + expected string + }{ + { + note: "zero length path (no-op)", + value: `{"a": 1, "b": 2, "c": 3}`, + path: "/", + expected: `{"a": 1, "b": 2, "c": 3}`, + }, + { + note: "remove object key", + value: `{"a": 1, "b": 2, "c": 3}`, + path: "/b", + expected: `{"a": 1, "c": 3}`, + }, + { + note: "remove object key, no hit", + value: `{"a": 1, "b": 2, "c": 3}`, + path: "/d", + expected: `{"a": 1, "b": 2, "c": 3}`, + }, + { + note: "remove nested object key", + value: `{"a": {"b": 1, "c": 2, "d": 3}, "b": 4}`, + path: "/a/c", + expected: `{"a": {"b": 1, "d": 3}, "b": 4}`, + }, + { + note: "remove nested object key, no hit", + value: `{"a": {"b": 1, "c": 2, "d": 3}, "b": 4}`, + path: "/a/e", + expected: `{"a": {"b": 1, "c": 2, "d": 3}, "b": 4}`, + }, + + { + note: "remove array element", + value: `[1, 2, 3]`, + path: "/1", + expected: `[1, 3]`, + }, + { + note: "remove array element, no hit (over)", + value: `[1, 2, 3]`, + path: "/4", + expected: `[1, 2, 3]`, + }, + { + note: "remove array element, no hit (under)", + value: `[1, 2, 3]`, + path: "/-1", + expected: `[1, 2, 3]`, + }, + { + note: "remove nested array element", + value: `[[1, 2], [3, 4], [5, 6]]`, + path: "/1/0", + expected: `[[1, 2], [4], [5, 6]]`, + }, + { + note: "remove nested array element, no hit", + value: `[[1, 2], [3, 4], [5, 6]]`, + path: "/1/2", + expected: `[[1, 2], [3, 4], [5, 6]]`, + }, + { + note: "remove array element nested inside object", + value: `{"a": [1, 2, 3], "b": [4, 5, 6]}`, + path: "/a/1", + expected: `{"a": [1, 3], "b": [4, 5, 6]}`, + }, + { + note: "remove object key nested inside array", + value: `[{"a": 1, "b": 2}, {"a": 3, "b": 4}]`, + path: "/1/a", + expected: `[{"a": 1, "b": 2}, {"b": 4}]`, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + value := ast.MustParseTerm(tc.value).Value + path := storage.MustParsePath(tc.path) + expected := ast.MustParseTerm(tc.expected).Value + + result, err := removeInAst(value, path) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if expected.Compare(result) != 0 { + t.Fatalf("Expected:\n\n%v\n\nbut got:\n\n%v", expected, result) + } + + if result.Hash() != expected.Hash() { + t.Fatalf("Expected hash:\n\n%v\n\nbut got:\n\n%v", expected.Hash(), result.Hash()) + } + }) + } +} diff --git a/storage/inmem/inmem.go b/storage/inmem/inmem.go index b6433795a3..9f5b8ba258 100644 --- a/storage/inmem/inmem.go +++ b/storage/inmem/inmem.go @@ -24,6 +24,7 @@ import ( "sync" "sync/atomic" + "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/merge" "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/util" @@ -37,16 +38,22 @@ func New() storage.Store { // NewWithOpts returns an empty in-memory store, with extra options passed. func NewWithOpts(opts ...Opt) storage.Store { s := &store{ - data: map[string]interface{}{}, - triggers: map[*handle]storage.TriggerConfig{}, - policies: map[string][]byte{}, - roundTripOnWrite: true, + triggers: map[*handle]storage.TriggerConfig{}, + policies: map[string][]byte{}, + roundTripOnWrite: true, + returnASTValuesOnRead: false, } for _, opt := range opts { opt(s) } + if s.returnASTValuesOnRead { + s.data = ast.NewObject() + } else { + s.data = map[string]interface{}{} + } + return s } @@ -55,7 +62,7 @@ func NewFromObject(data map[string]interface{}) storage.Store { return NewFromObjectWithOpts(data) } -// NewFromObject returns a new in-memory store from the supplied data object, with the +// NewFromObjectWithOpts returns a new in-memory store from the supplied data object, with the // options passed. func NewFromObjectWithOpts(data map[string]interface{}, opts ...Opt) storage.Store { db := NewWithOpts(opts...) @@ -94,13 +101,18 @@ type store struct { rmu sync.RWMutex // reader-writer lock wmu sync.Mutex // writer lock xid uint64 // last generated transaction id - data map[string]interface{} // raw data + data interface{} // raw or AST data policies map[string][]byte // raw policies triggers map[*handle]storage.TriggerConfig // registered triggers // roundTripOnWrite, if true, means that every call to Write round trips the // data through JSON before adding the data to the store. Defaults to true. roundTripOnWrite bool + + // returnASTValuesOnRead, if true, means that the store will eagerly convert data to AST values, + // and return them on Read. + // FIXME: naming(?) + returnASTValuesOnRead bool } type handle struct { @@ -295,7 +307,13 @@ func (db *store) Read(_ context.Context, txn storage.Transaction, path storage.P if err != nil { return nil, err } - return underlying.Read(path) + + v, err := underlying.Read(path) + if err != nil { + return nil, err + } + + return v, nil } func (db *store) Write(_ context.Context, txn storage.Transaction, op storage.PatchOp, path storage.Path, value interface{}) error { @@ -327,11 +345,45 @@ func (h *handle) Unregister(_ context.Context, txn storage.Transaction) { } func (db *store) runOnCommitTriggers(ctx context.Context, txn storage.Transaction, event storage.TriggerEvent) { + if db.returnASTValuesOnRead && len(db.triggers) > 0 { + // FIXME: Not very performant for large data. + + dataEvents := make([]storage.DataEvent, 0, len(event.Data)) + + for _, dataEvent := range event.Data { + if astData, ok := dataEvent.Data.(ast.Value); ok { + jsn, err := ast.ValueToInterface(astData, illegalResolver{}) + if err != nil { + panic(err) + } + dataEvents = append(dataEvents, storage.DataEvent{ + Path: dataEvent.Path, + Data: jsn, + Removed: dataEvent.Removed, + }) + } else { + dataEvents = append(dataEvents, dataEvent) + } + } + + event = storage.TriggerEvent{ + Policy: event.Policy, + Data: dataEvents, + Context: event.Context, + } + } + for _, t := range db.triggers { t.OnCommit(ctx, txn, event) } } +type illegalResolver struct{} + +func (illegalResolver) Resolve(ref ast.Ref) (interface{}, error) { + return nil, fmt.Errorf("illegal value: %v", ref) +} + func (db *store) underlying(txn storage.Transaction) (*transaction, error) { underlying, ok := txn.(*transaction) if !ok { diff --git a/storage/inmem/inmem_test.go b/storage/inmem/inmem_test.go index 0d36aa9357..f0aecf32cd 100644 --- a/storage/inmem/inmem_test.go +++ b/storage/inmem/inmem_test.go @@ -11,6 +11,7 @@ import ( "reflect" "testing" + "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/bundle" "github.com/open-policy-agent/opa/internal/file/archive" storageerrors "github.com/open-policy-agent/opa/storage/internal/errors" @@ -69,133 +70,199 @@ func TestInMemoryRead(t *testing.T) { } -func TestInMemoryWrite(t *testing.T) { +func TestInMemoryReadAst(t *testing.T) { - tests := []struct { - note string - op string - path string - value string - expected error - getPath string - getExpected interface{} + data := loadSmallTestData() + + var tests = []struct { + path string + expected interface{} }{ - {"add root", "add", "/", `{"a": [1]}`, nil, "/", `{"a": [1]}`}, - {"add", "add", "/newroot", `{"a": [[1]]}`, nil, "/newroot", `{"a": [[1]]}`}, - {"add arr", "add", "/a/1", `"x"`, nil, "/a", `[1,"x",2,3,4]`}, - {"add arr/arr", "add", "/h/1/2", `"x"`, nil, "/h", `[[1,2,3], [2,3,"x",4]]`}, - {"add obj/arr", "add", "/d/e/1", `"x"`, nil, "/d", `{"e": ["bar", "x", "baz"]}`}, - {"add obj", "add", "/b/vNew", `"x"`, nil, "/b", `{"v1": "hello", "v2": "goodbye", "vNew": "x"}`}, - {"add obj (existing)", "add", "/b/v2", `"x"`, nil, "/b", `{"v1": "hello", "v2": "x"}`}, - - {"append arr", "add", "/a/-", `"x"`, nil, "/a", `[1,2,3,4,"x"]`}, - {"append arr-2", "add", "/a/4", `"x"`, nil, "/a", `[1,2,3,4,"x"]`}, - {"append obj/arr", "add", `/c/0/x/-`, `"x"`, nil, "/c/0/x", `[true,false,"foo","x"]`}, - {"append obj/arr-2", "add", `/c/0/x/3`, `"x"`, nil, "/c/0/x", `[true,false,"foo","x"]`}, - {"append arr/arr", "add", `/h/0/-`, `"x"`, nil, `/h/0/3`, `"x"`}, - {"append arr/arr-2", "add", `/h/0/3`, `"x"`, nil, `/h/0/3`, `"x"`}, - {"append err", "remove", "/c/0/x/-", "", invalidPatchError("/c/0/x/-: invalid patch path"), "", nil}, - {"append err-2", "replace", "/c/0/x/-", "", invalidPatchError("/c/0/x/-: invalid patch path"), "", nil}, - - {"remove", "remove", "/a", "", nil, "/a", storageerrors.NewNotFoundError(storage.MustParsePath("/a"))}, - {"remove arr", "remove", "/a/1", "", nil, "/a", "[1,3,4]"}, - {"remove obj/arr", "remove", "/c/0/x/1", "", nil, "/c/0/x", `[true,"foo"]`}, - {"remove arr/arr", "remove", "/h/0/1", "", nil, "/h/0", "[1,3]"}, - {"remove obj", "remove", "/b/v2", "", nil, "/b", `{"v1": "hello"}`}, - - {"replace root", "replace", "/", `{"a": [1]}`, nil, "/", `{"a": [1]}`}, - {"replace", "replace", "/a", "1", nil, "/a", "1"}, - {"replace obj", "replace", "/b/v1", "1", nil, "/b", `{"v1": 1, "v2": "goodbye"}`}, - {"replace array", "replace", "/a/1", "999", nil, "/a", "[1,999,3,4]"}, - - {"err: bad root type", "add", "/", "[1,2,3]", invalidPatchError(rootMustBeObjectMsg), "", nil}, - {"err: remove root", "remove", "/", "", invalidPatchError(rootCannotBeRemovedMsg), "", nil}, - {"err: add arr (non-integer)", "add", "/a/foo", "1", storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/a/foo"), storageerrors.ArrayIndexTypeMsg), "", nil}, - {"err: add arr (non-integer)", "add", "/a/3.14", "1", storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/a/3.14"), storageerrors.ArrayIndexTypeMsg), "", nil}, - {"err: add arr (out of range)", "add", "/a/5", "1", storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/a/5"), storageerrors.OutOfRangeMsg), "", nil}, - {"err: add arr (out of range)", "add", "/a/-1", "1", storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/a/-1"), storageerrors.OutOfRangeMsg), "", nil}, - {"err: add arr (missing root)", "add", "/dead/beef/0", "1", storageerrors.NewNotFoundError(storage.MustParsePath("/dead/beef/0")), "", nil}, - {"err: add non-coll", "add", "/a/1/2", "1", storageerrors.NewNotFoundError(storage.MustParsePath("/a/1/2")), "", nil}, - {"err: append (missing)", "add", `/dead/beef/-`, "1", storageerrors.NewNotFoundError(storage.MustParsePath("/dead/beef/-")), "", nil}, - {"err: append obj/arr", "add", `/c/0/deadbeef/-`, `"x"`, storageerrors.NewNotFoundError(storage.MustParsePath("/c/0/deadbeef/-")), "", nil}, - {"err: append arr/arr (out of range)", "add", `/h/9999/-`, `"x"`, storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/h/9999/-"), storageerrors.OutOfRangeMsg), "", nil}, - {"err: append append+add", "add", `/a/-/b/-`, `"x"`, storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath(`/a/-/b/-`), storageerrors.ArrayIndexTypeMsg), "", nil}, - {"err: append arr/arr (non-array)", "add", `/b/v1/-`, "1", storageerrors.NewNotFoundError(storage.MustParsePath("/b/v1/-")), "", nil}, - {"err: remove missing", "remove", "/dead/beef/0", "", storageerrors.NewNotFoundError(storage.MustParsePath("/dead/beef/0")), "", nil}, - {"err: remove obj (missing)", "remove", "/b/deadbeef", "", storageerrors.NewNotFoundError(storage.MustParsePath("/b/deadbeef")), "", nil}, - {"err: replace root (missing)", "replace", "/deadbeef", "1", storageerrors.NewNotFoundError(storage.MustParsePath("/deadbeef")), "", nil}, - {"err: replace missing", "replace", "/dead/beef/1", "1", storageerrors.NewNotFoundError(storage.MustParsePath("/dead/beef/1")), "", nil}, + {"/a/0", ast.Number("1")}, + {"/a/3", ast.Number("4")}, + {"/b/v1", ast.String("hello")}, + {"/b/v2", ast.String("goodbye")}, + {"/c/0/x/1", ast.Boolean(false)}, + {"/c/0/y/0", ast.Null{}}, + {"/c/0/y/1", ast.Number("3.14159")}, + {"/d/e/1", ast.String("baz")}, + {"/d/e", ast.NewArray(ast.StringTerm("bar"), ast.StringTerm("baz"))}, + {"/c/0/z", ast.NewObject(ast.Item(ast.StringTerm("p"), ast.BooleanTerm(true)), ast.Item(ast.StringTerm("q"), ast.BooleanTerm(false)))}, + {"/a/0/beef", storageerrors.NewNotFoundError(storage.MustParsePath("/a/0/beef"))}, + {"/d/100", storageerrors.NewNotFoundError(storage.MustParsePath("/d/100"))}, + {"/dead/beef", storageerrors.NewNotFoundError(storage.MustParsePath("/dead/beef"))}, + {"/a/str", storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/a/str"), storageerrors.ArrayIndexTypeMsg)}, + {"/a/100", storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/a/100"), storageerrors.OutOfRangeMsg)}, + {"/a/-1", storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/a/-1"), storageerrors.OutOfRangeMsg)}, } + store := NewFromObjectWithOpts(data, OptReturnASTValuesOnRead(true)) ctx := context.Background() - for i, tc := range tests { - data := loadSmallTestData() - store := NewFromObject(data) - - // Perform patch and check result - value := loadExpectedSortedResult(tc.value) - - var op storage.PatchOp - switch tc.op { - case "add": - op = storage.AddOp - case "remove": - op = storage.RemoveOp - case "replace": - op = storage.ReplaceOp + for idx, tc := range tests { + result, err := storage.ReadOne(ctx, store, storage.MustParsePath(tc.path)) + switch e := tc.expected.(type) { + case error: + if err == nil { + t.Errorf("Test case %d: expected error for %v but got %v", idx+1, tc.path, result) + } else if !reflect.DeepEqual(err, tc.expected) { + t.Errorf("Test case %d: unexpected error for %v: %v, expected: %v", idx+1, tc.path, err, e) + } default: - panic(fmt.Sprintf("illegal value: %v", tc.op)) - } - - err := storage.WriteOne(ctx, store, op, storage.MustParsePath(tc.path), value) - if tc.expected == nil { if err != nil { - t.Errorf("Test case %d (%v): unexpected patch error: %v", i+1, tc.note, err) - continue - } - } else { - if err == nil { - t.Errorf("Test case %d (%v): expected patch error, but got nil instead", i+1, tc.note) - continue + t.Errorf("Test case %d: expected success for %v but got %v", idx+1, tc.path, err) } - if !reflect.DeepEqual(err, tc.expected) { - t.Errorf("Test case %d (%v): expected patch error %v but got: %v", i+1, tc.note, tc.expected, err) - continue + if ast.Compare(result, tc.expected) != 0 { + t.Errorf("Test case %d: expected %f but got %f", idx+1, tc.expected, result) } } + } +} - if tc.getPath == "" { - continue - } - - // Perform get and verify result - result, err := storage.ReadOne(ctx, store, storage.MustParsePath(tc.getPath)) - switch expected := tc.getExpected.(type) { - case error: - if err == nil { - t.Errorf("Test case %d (%v): expected get error but got: %v", i+1, tc.note, result) - continue - } - if !reflect.DeepEqual(err, expected) { - t.Errorf("Test case %d (%v): expected get error %v but got: %v", i+1, tc.note, expected, err) - continue - } - case string: - if err != nil { - t.Errorf("Test case %d (%v): unexpected get error: %v", i+1, tc.note, err) - continue +func TestInMemoryWrite(t *testing.T) { + readValueType := []struct { + note string + ast bool + }{ + {"raw", false}, + {"ast", true}, + } + + for _, rvt := range readValueType { + t.Run(rvt.note, func(t *testing.T) { + tests := []struct { + note string + op string + path string + value string + expected error + getPath string + getExpected interface{} + }{ + {"add root", "add", "/", `{"a": [1]}`, nil, "/", `{"a": [1]}`}, + {"add", "add", "/newroot", `{"a": [[1]]}`, nil, "/newroot", `{"a": [[1]]}`}, + {"add arr", "add", "/a/1", `"x"`, nil, "/a", `[1,"x",2,3,4]`}, + {"add arr/arr", "add", "/h/1/2", `"x"`, nil, "/h", `[[1,2,3], [2,3,"x",4]]`}, + {"add obj/arr", "add", "/d/e/1", `"x"`, nil, "/d", `{"e": ["bar", "x", "baz"]}`}, + {"add obj", "add", "/b/vNew", `"x"`, nil, "/b", `{"v1": "hello", "v2": "goodbye", "vNew": "x"}`}, + {"add obj (existing)", "add", "/b/v2", `"x"`, nil, "/b", `{"v1": "hello", "v2": "x"}`}, + + {"append arr", "add", "/a/-", `"x"`, nil, "/a", `[1,2,3,4,"x"]`}, + {"append arr-2", "add", "/a/4", `"x"`, nil, "/a", `[1,2,3,4,"x"]`}, + {"append obj/arr", "add", `/c/0/x/-`, `"x"`, nil, "/c/0/x", `[true,false,"foo","x"]`}, + {"append obj/arr-2", "add", `/c/0/x/3`, `"x"`, nil, "/c/0/x", `[true,false,"foo","x"]`}, + {"append arr/arr", "add", `/h/0/-`, `"x"`, nil, `/h/0/3`, `"x"`}, + {"append arr/arr-2", "add", `/h/0/3`, `"x"`, nil, `/h/0/3`, `"x"`}, + {"append err", "remove", "/c/0/x/-", "", invalidPatchError("/c/0/x/-: invalid patch path"), "", nil}, + {"append err-2", "replace", "/c/0/x/-", "", invalidPatchError("/c/0/x/-: invalid patch path"), "", nil}, + + {"remove", "remove", "/a", "", nil, "/a", storageerrors.NewNotFoundError(storage.MustParsePath("/a"))}, + {"remove arr", "remove", "/a/1", "", nil, "/a", "[1,3,4]"}, + {"remove obj/arr", "remove", "/c/0/x/1", "", nil, "/c/0/x", `[true,"foo"]`}, + {"remove arr/arr", "remove", "/h/0/1", "", nil, "/h/0", "[1,3]"}, + {"remove obj", "remove", "/b/v2", "", nil, "/b", `{"v1": "hello"}`}, + + {"replace root", "replace", "/", `{"a": [1]}`, nil, "/", `{"a": [1]}`}, + {"replace", "replace", "/a", "1", nil, "/a", "1"}, + {"replace obj", "replace", "/b/v1", "1", nil, "/b", `{"v1": 1, "v2": "goodbye"}`}, + {"replace array", "replace", "/a/1", "999", nil, "/a", "[1,999,3,4]"}, + + {"err: bad root type", "add", "/", "[1,2,3]", invalidPatchError(rootMustBeObjectMsg), "", nil}, + {"err: remove root", "remove", "/", "", invalidPatchError(rootCannotBeRemovedMsg), "", nil}, + {"err: add arr (non-integer)", "add", "/a/foo", "1", storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/a/foo"), storageerrors.ArrayIndexTypeMsg), "", nil}, + {"err: add arr (non-integer)", "add", "/a/3.14", "1", storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/a/3.14"), storageerrors.ArrayIndexTypeMsg), "", nil}, + {"err: add arr (out of range)", "add", "/a/5", "1", storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/a/5"), storageerrors.OutOfRangeMsg), "", nil}, + {"err: add arr (out of range)", "add", "/a/-1", "1", storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/a/-1"), storageerrors.OutOfRangeMsg), "", nil}, + {"err: add arr (missing root)", "add", "/dead/beef/0", "1", storageerrors.NewNotFoundError(storage.MustParsePath("/dead/beef/0")), "", nil}, + {"err: add non-coll", "add", "/a/1/2", "1", storageerrors.NewNotFoundError(storage.MustParsePath("/a/1/2")), "", nil}, + {"err: append (missing)", "add", `/dead/beef/-`, "1", storageerrors.NewNotFoundError(storage.MustParsePath("/dead/beef/-")), "", nil}, + {"err: append obj/arr", "add", `/c/0/deadbeef/-`, `"x"`, storageerrors.NewNotFoundError(storage.MustParsePath("/c/0/deadbeef/-")), "", nil}, + {"err: append arr/arr (out of range)", "add", `/h/9999/-`, `"x"`, storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath("/h/9999/-"), storageerrors.OutOfRangeMsg), "", nil}, + {"err: append append+add", "add", `/a/-/b/-`, `"x"`, storageerrors.NewNotFoundErrorWithHint(storage.MustParsePath(`/a/-/b/-`), storageerrors.ArrayIndexTypeMsg), "", nil}, + {"err: append arr/arr (non-array)", "add", `/b/v1/-`, "1", storageerrors.NewNotFoundError(storage.MustParsePath("/b/v1/-")), "", nil}, + {"err: remove missing", "remove", "/dead/beef/0", "", storageerrors.NewNotFoundError(storage.MustParsePath("/dead/beef/0")), "", nil}, + {"err: remove obj (missing)", "remove", "/b/deadbeef", "", storageerrors.NewNotFoundError(storage.MustParsePath("/b/deadbeef")), "", nil}, + {"err: replace root (missing)", "replace", "/deadbeef", "1", storageerrors.NewNotFoundError(storage.MustParsePath("/deadbeef")), "", nil}, + {"err: replace missing", "replace", "/dead/beef/1", "1", storageerrors.NewNotFoundError(storage.MustParsePath("/dead/beef/1")), "", nil}, } - e := loadExpectedResult(expected) + ctx := context.Background() - if !reflect.DeepEqual(result, e) { - t.Errorf("Test case %d (%v): expected get result %v but got: %v", i+1, tc.note, e, result) + for i, tc := range tests { + data := loadSmallTestData() + store := NewFromObjectWithOpts(data, OptReturnASTValuesOnRead(rvt.ast)) + + // Perform patch and check result + value := loadExpectedSortedResult(tc.value) + + var op storage.PatchOp + switch tc.op { + case "add": + op = storage.AddOp + case "remove": + op = storage.RemoveOp + case "replace": + op = storage.ReplaceOp + default: + panic(fmt.Sprintf("illegal value: %v", tc.op)) + } + + err := storage.WriteOne(ctx, store, op, storage.MustParsePath(tc.path), value) + if tc.expected == nil { + if err != nil { + t.Errorf("Test case %d (%v): unexpected patch error: %v", i+1, tc.note, err) + continue + } + } else { + if err == nil { + t.Errorf("Test case %d (%v): expected patch error, but got nil instead", i+1, tc.note) + continue + } + if !reflect.DeepEqual(err, tc.expected) { + t.Errorf("Test case %d (%v): expected patch error %v but got: %v", i+1, tc.note, tc.expected, err) + continue + } + } + + if tc.getPath == "" { + continue + } + + // Perform get and verify result + result, err := storage.ReadOne(ctx, store, storage.MustParsePath(tc.getPath)) + switch expected := tc.getExpected.(type) { + case error: + if err == nil { + t.Errorf("Test case %d (%v): expected get error but got: %v", i+1, tc.note, result) + continue + } + if !reflect.DeepEqual(err, expected) { + t.Errorf("Test case %d (%v): expected get error %v but got: %v", i+1, tc.note, expected, err) + continue + } + case string: + if err != nil { + t.Errorf("Test case %d (%v): unexpected get error: %v", i+1, tc.note, err) + continue + } + + if rvt.ast { + e := ast.MustParseTerm(expected) + + if ast.Compare(result, e.Value) != 0 { + t.Errorf("Test case %d (%v): expected get result %v but got: %v", i+1, tc.note, e, result) + } + } else { + e := loadExpectedResult(expected) + + if !reflect.DeepEqual(result, e) { + t.Errorf("Test case %d (%v): expected get result %v but got: %v", i+1, tc.note, e, result) + } + } + } } - } - + }) } - } func TestInMemoryWriteOfStruct(t *testing.T) { @@ -243,6 +310,53 @@ func TestInMemoryWriteOfStruct(t *testing.T) { } } +func TestInMemoryWriteOfStructAst(t *testing.T) { + type B struct { + Bar int `json:"bar"` + } + + type A struct { + Foo *B `json:"foo"` + } + + cases := map[string]struct { + value interface{} + expected string + }{ + "nested struct": {A{&B{10}}, `{"foo": {"bar": 10 } }`}, + "pointer to nested struct": {&A{&B{10}}, `{"foo": {"bar": 10 } }`}, + "pointer to pointer to nested struct": { + func() interface{} { + a := &A{&B{10}} + return &a + }(), `{"foo": {"bar": 10 } }`}, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + store := NewWithOpts(OptReturnASTValuesOnRead(true)) + ctx := context.Background() + + // Written non-AST values are expected to be converted to AST values + err := storage.WriteOne(ctx, store, storage.AddOp, storage.MustParsePath("/x"), tc.value) + if err != nil { + t.Fatal(err) + } + + actual, err := storage.ReadOne(ctx, store, storage.MustParsePath("/x")) + if err != nil { + t.Fatal(err) + } + + // We expect the result to be an AST value + expected := ast.MustParseTerm(tc.expected) + if ast.Compare(expected.Value, actual) != 0 { + t.Errorf("expected %v, got %v", tc.expected, actual) + } + }) + } +} + func TestInMemoryTxnMultipleWrites(t *testing.T) { ctx := context.Background() @@ -325,34 +439,71 @@ func TestInMemoryTxnMultipleWrites(t *testing.T) { } } -func TestTruncateNoExistingPath(t *testing.T) { +func TestInMemoryTxnMultipleWritesAst(t *testing.T) { + ctx := context.Background() - store := NewFromObject(map[string]interface{}{}) + store := NewFromObjectWithOpts(loadSmallTestData(), OptReturnASTValuesOnRead(true)) txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) - var archiveFiles = map[string]string{ - "/a/b/c/data.json": "[1,2,3]", + // Perform a sequence of writes and then verify the read results are the + // same for the writer during the transaction and the reader after the + // commit. + writes := []struct { + op storage.PatchOp + path string + value string + }{ + {storage.AddOp, "/a/-", "[]"}, + {storage.AddOp, "/a/4/-", "1"}, + {storage.AddOp, "/a/4/-", "2"}, + {storage.AddOp, "/a/4/2", "3"}, + {storage.AddOp, "/b/foo", "{}"}, + {storage.AddOp, "/b/foo/bar", "{}"}, + {storage.AddOp, "/b/foo/bar/baz", "1"}, + {storage.AddOp, "/arr", "[]"}, + {storage.AddOp, "/arr/-", "1"}, + {storage.AddOp, "/arr/0", "2"}, + {storage.AddOp, "/arr/2", "3"}, + {storage.AddOp, "/c/0/x/-", "0"}, + {storage.AddOp, "/_", "null"}, // introduce new txn.log head + {storage.AddOp, "/c/0", `"new c[0]"`}, + {storage.AddOp, "/c/1", `"new c[1]"`}, + {storage.AddOp, "/_head", "1"}, + {storage.AddOp, "/_head", "2"}, // invalidate the txn.log head + {storage.AddOp, "/d/f", `{"g": {"h": 0}}`}, + {storage.AddOp, "/d/f/g/i", `{"j": 1}`}, } - files := make([][2]string, 0, len(archiveFiles)) - for name, content := range archiveFiles { - files = append(files, [2]string{name, content}) + reads := []struct { + path string + expected string + }{ + {"/a", `[1,2,3,4,[1,2,3]]`}, + {"/b/foo", `{"bar": {"baz": 1}}`}, + {"/arr", `[2,1,3]`}, + {"/c/0", `"new c[0]"`}, + {"/c/1", `"new c[1]"`}, + {"/d/f", `{"g": {"h": 0, "i": {"j": 1}}}`}, + {"/d", `{"e": ["bar", "baz"], "f": {"g":{"h": 0, "i": {"j": 1}}}}`}, + {"/h/1/2", "4"}, } - buf := archive.MustWriteTarGz(files) - b, err := bundle.NewReader(buf).WithLazyLoadingMode(true).Read() - if err != nil { - t.Fatal(err) + for _, w := range writes { + var jsn interface{} + if w.value != "" { + jsn = util.MustUnmarshalJSON([]byte(w.value)) + } + if err := store.Write(ctx, txn, w.op, storage.MustParsePath(w.path), jsn); err != nil { + t.Fatalf("Unexpected write error on %v: %v", w, err) + } } - iterator := bundle.NewIterator(b.Raw) - - params := storage.WriteParams - params.BasePaths = []string{""} - - err = store.Truncate(ctx, txn, params, iterator) - if err != nil { - t.Fatalf("Unexpected truncate error: %v", err) + for _, r := range reads { + exp := ast.MustParseTerm(r.expected) + result, err := store.Read(ctx, txn, storage.MustParsePath(r.path)) + if err != nil || ast.Compare(exp.Value, result) != 0 { + t.Fatalf("Expected writer's read %v to be %v but got: %v (err: %v)", r.path, exp, result, err) + } } if err := store.Commit(ctx, txn); err != nil { @@ -361,12 +512,67 @@ func TestTruncateNoExistingPath(t *testing.T) { txn = storage.NewTransactionOrDie(ctx, store) - actual, err := store.Read(ctx, txn, storage.MustParsePath("/")) - if err != nil { - t.Fatal(err) + for _, r := range reads { + exp := ast.MustParseTerm(r.expected) + result, err := store.Read(ctx, txn, storage.MustParsePath(r.path)) + if err != nil || ast.Compare(exp.Value, result) != 0 { + t.Fatalf("Expected reader's read %v to be %v but got: %v (err: %v)", r.path, exp, result, err) + } } +} - expected := ` +func TestTruncateNoExistingPath(t *testing.T) { + cases := []struct { + note string + ast bool + }{ + {"raw", false}, + {"ast", true}, + } + + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + ctx := context.Background() + store := NewFromObjectWithOpts(map[string]interface{}{}, OptReturnASTValuesOnRead(tc.ast)) + txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) + + var archiveFiles = map[string]string{ + "/a/b/c/data.json": "[1,2,3]", + } + + files := make([][2]string, 0, len(archiveFiles)) + for name, content := range archiveFiles { + files = append(files, [2]string{name, content}) + } + + buf := archive.MustWriteTarGz(files) + b, err := bundle.NewReader(buf).WithLazyLoadingMode(true).Read() + if err != nil { + t.Fatal(err) + } + + iterator := bundle.NewIterator(b.Raw) + + params := storage.WriteParams + params.BasePaths = []string{""} + + err = store.Truncate(ctx, txn, params, iterator) + if err != nil { + t.Fatalf("Unexpected truncate error: %v", err) + } + + if err := store.Commit(ctx, txn); err != nil { + t.Fatalf("Unexpected commit error: %v", err) + } + + txn = storage.NewTransactionOrDie(ctx, store) + + actual, err := store.Read(ctx, txn, storage.MustParsePath("/")) + if err != nil { + t.Fatal(err) + } + + expected := ` { "a": { "b": { @@ -375,10 +581,20 @@ func TestTruncateNoExistingPath(t *testing.T) { } } ` - jsn := util.MustUnmarshalJSON([]byte(expected)) - - if !reflect.DeepEqual(jsn, actual) { - t.Fatalf("Expected reader's read to be %v but got: %v", jsn, actual) + if tc.ast { + exp := ast.MustParseTerm(expected) + + if ast.Compare(exp.Value, actual) != 0 { + t.Fatalf("Expected reader's read to be %v but got: %v", exp, actual) + } + } else { + jsn := util.MustUnmarshalJSON([]byte(expected)) + + if !reflect.DeepEqual(jsn, actual) { + t.Fatalf("Expected reader's read to be %v but got: %v", jsn, actual) + } + } + }) } } @@ -480,14 +696,18 @@ func TestTruncate(t *testing.T) { } } -func TestTruncateDataMergeError(t *testing.T) { +func TestTruncateAst(t *testing.T) { ctx := context.Background() - store := NewFromObject(map[string]interface{}{}) + store := NewFromObjectWithOpts(map[string]interface{}{}, OptReturnASTValuesOnRead(true)) txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) var archiveFiles = map[string]string{ - "/a/b/data.json": `{"c": "foo"}`, - "/data.json": `{"a": {"b": {"c": "bar"}}}`, + "/a/b/c/data.json": "[1,2,3]", + "/a/b/d/data.json": "true", + "/data.json": `{"x": {"y": true}, "a": {"b": {"z": true}}}`, + "/a/b/y/data.yaml": `foo: 1`, + "/policy.rego": "package foo\n p = 1", + "/roles/policy.rego": "package bar\n p = 1", } files := make([][2]string, 0, len(archiveFiles)) @@ -503,108 +723,247 @@ func TestTruncateDataMergeError(t *testing.T) { iterator := bundle.NewIterator(b.Raw) - err = store.Truncate(ctx, txn, storage.WriteParams, iterator) - if err == nil { - t.Fatal("Expected truncate error but got nil") + params := storage.WriteParams + params.BasePaths = []string{""} + + err = store.Truncate(ctx, txn, params, iterator) + if err != nil { + t.Fatalf("Unexpected truncate error: %v", err) } - expected := "failed to insert data file from path a/b" - if err.Error() != expected { - t.Fatalf("Expected error %v but got %v", expected, err.Error()) + if err := store.Commit(ctx, txn); err != nil { + t.Fatalf("Unexpected commit error: %v", err) } -} -func TestTruncateBadRootWrite(t *testing.T) { - ctx := context.Background() - store := NewFromObject(map[string]interface{}{}) - txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) + txn = storage.NewTransactionOrDie(ctx, store) - var archiveFiles = map[string]string{ - "/a/b/d/data.json": "true", - "/data.json": "[1,2,3]", - "/roles/policy.rego": "package bar\n p = 1", + actual, err := store.Read(ctx, txn, storage.MustParsePath("/")) + if err != nil { + t.Fatal(err) } - files := make([][2]string, 0, len(archiveFiles)) - for name, content := range archiveFiles { - files = append(files, [2]string{name, content}) + expected := ` +{ + "a": { + "b": { + "c": [1,2,3], + "d": true, + "y": { + "foo": 1 + }, + "z": true + } + }, + "x": { + "y": true } +} +` + exp := ast.MustParseTerm(expected) - buf := archive.MustWriteTarGz(files) - b, err := bundle.NewReader(buf).WithLazyLoadingMode(true).Read() + if ast.Compare(exp.Value, actual) != 0 { + t.Fatalf("Expected reader's read to be %v but got: %v", exp, actual) + } + + store.Abort(ctx, txn) + + txn = storage.NewTransactionOrDie(ctx, store) + ids, err := store.ListPolicies(ctx, txn) if err != nil { t.Fatal(err) } - iterator := bundle.NewIterator(b.Raw) + expectedIDs := map[string]struct{}{"policy.rego": {}, "roles/policy.rego": {}} + + for _, id := range ids { + if _, ok := expectedIDs[id]; !ok { + t.Fatalf("Expected list policies to contain %v but got: %v", id, expectedIDs) + } + } - err = store.Truncate(ctx, txn, storage.WriteParams, iterator) - if err == nil { - t.Fatal("Expected truncate error but got nil") + bs, err := store.GetPolicy(ctx, txn, "policy.rego") + expectedBytes := []byte("package foo\n p = 1") + if err != nil || !reflect.DeepEqual(expectedBytes, bs) { + t.Fatalf("Expected get policy to return %v but got: %v (err: %v)", expectedBytes, bs, err) } - expected := "storage_invalid_patch_error: root must be object" - if err.Error() != expected { - t.Fatalf("Expected error %v but got %v", expected, err.Error()) + bs, err = store.GetPolicy(ctx, txn, "roles/policy.rego") + expectedBytes = []byte("package bar\n p = 1") + if err != nil || !reflect.DeepEqual(expectedBytes, bs) { + t.Fatalf("Expected get policy to return %v but got: %v (err: %v)", expectedBytes, bs, err) } } -func TestInMemoryTxnWriteFailures(t *testing.T) { - - ctx := context.Background() - store := NewFromObject(loadSmallTestData()) - txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) - - writes := []struct { - op storage.PatchOp - path string - value string - errCode string +func TestTruncateDataMergeError(t *testing.T) { + cases := []struct { + note string + ast bool }{ - {storage.RemoveOp, "/c/0/y", "", ""}, - {storage.RemoveOp, "/c/0/y", "", storage.NotFoundErr}, - {storage.ReplaceOp, "/c/0/y/0", "", storage.NotFoundErr}, - {storage.AddOp, "/new", `{"foo": "bar"}`, ""}, - {storage.AddOp, "/a/0/beef", "", storage.NotFoundErr}, - {storage.AddOp, "/arr", `[1,2,3]`, ""}, - {storage.AddOp, "/arr/0/foo", "", storage.NotFoundErr}, - {storage.AddOp, "/arr/4", "", storage.NotFoundErr}, + {"raw", false}, + {"ast", true}, } - for _, w := range writes { - var jsn interface{} - if w.value != "" { - jsn = util.MustUnmarshalJSON([]byte(w.value)) - } - err := store.Write(ctx, txn, w.op, storage.MustParsePath(w.path), jsn) - if (w.errCode == "" && err != nil) || (err == nil && w.errCode != "") { - t.Fatalf("Expected errCode %q but got: %v", w.errCode, err) - } + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + ctx := context.Background() + store := NewFromObjectWithOpts(map[string]interface{}{}, OptReturnASTValuesOnRead(tc.ast)) + txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) + + var archiveFiles = map[string]string{ + "/a/b/data.json": `{"c": "foo"}`, + "/data.json": `{"a": {"b": {"c": "bar"}}}`, + } + + files := make([][2]string, 0, len(archiveFiles)) + for name, content := range archiveFiles { + files = append(files, [2]string{name, content}) + } + + buf := archive.MustWriteTarGz(files) + b, err := bundle.NewReader(buf).WithLazyLoadingMode(true).Read() + if err != nil { + t.Fatal(err) + } + + iterator := bundle.NewIterator(b.Raw) + + err = store.Truncate(ctx, txn, storage.WriteParams, iterator) + if err == nil { + t.Fatal("Expected truncate error but got nil") + } + + expected := "failed to insert data file from path a/b" + if err.Error() != expected { + t.Fatalf("Expected error %v but got %v", expected, err.Error()) + } + }) } } -func TestInMemoryTxnReadFailures(t *testing.T) { +func TestTruncateBadRootWrite(t *testing.T) { + cases := []struct { + note string + ast bool + }{ + {"raw", false}, + {"ast", true}, + } - ctx := context.Background() - store := NewFromObject(loadSmallTestData()) - txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + ctx := context.Background() + store := NewFromObjectWithOpts(map[string]interface{}{}, OptReturnASTValuesOnRead(tc.ast)) + txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) - if err := store.Write(ctx, txn, storage.RemoveOp, storage.MustParsePath("/c/0/y"), nil); err != nil { - t.Fatalf("Unexpected write error: %v", err) + var archiveFiles = map[string]string{ + "/a/b/d/data.json": "true", + "/data.json": "[1,2,3]", + "/roles/policy.rego": "package bar\n p = 1", + } + + files := make([][2]string, 0, len(archiveFiles)) + for name, content := range archiveFiles { + files = append(files, [2]string{name, content}) + } + + buf := archive.MustWriteTarGz(files) + b, err := bundle.NewReader(buf).WithLazyLoadingMode(true).Read() + if err != nil { + t.Fatal(err) + } + + iterator := bundle.NewIterator(b.Raw) + + err = store.Truncate(ctx, txn, storage.WriteParams, iterator) + if err == nil { + t.Fatal("Expected truncate error but got nil") + } + + expected := "storage_invalid_patch_error: root must be object" + if err.Error() != expected { + t.Fatalf("Expected error %v but got %v", expected, err.Error()) + } + }) } +} - if result, err := store.Read(ctx, txn, storage.MustParsePath("/c/0/y/0")); !storage.IsNotFound(err) { - t.Fatalf("Expected NotFoundErr for /c/0/y/0 but got: %v (err: %v)", result, err) +func TestInMemoryTxnWriteFailures(t *testing.T) { + cases := []struct { + note string + ast bool + }{ + {"raw", false}, + {"ast", true}, } - if result, err := store.Read(ctx, txn, storage.MustParsePath("/c/0/y")); !storage.IsNotFound(err) { - t.Fatalf("Expected NotFoundErr for /c/0/y but got: %v (err: %v)", result, err) + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + ctx := context.Background() + store := NewFromObjectWithOpts(loadSmallTestData(), OptReturnASTValuesOnRead(tc.ast)) + txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) + + writes := []struct { + op storage.PatchOp + path string + value string + errCode string + }{ + {storage.RemoveOp, "/c/0/y", "", ""}, + {storage.RemoveOp, "/c/0/y", "", storage.NotFoundErr}, + {storage.ReplaceOp, "/c/0/y/0", "", storage.NotFoundErr}, + {storage.AddOp, "/new", `{"foo": "bar"}`, ""}, + {storage.AddOp, "/a/0/beef", "", storage.NotFoundErr}, + {storage.AddOp, "/arr", `[1,2,3]`, ""}, + {storage.AddOp, "/arr/0/foo", "", storage.NotFoundErr}, + {storage.AddOp, "/arr/4", "", storage.NotFoundErr}, + } + + for _, w := range writes { + var jsn interface{} + if w.value != "" { + jsn = util.MustUnmarshalJSON([]byte(w.value)) + } + err := store.Write(ctx, txn, w.op, storage.MustParsePath(w.path), jsn) + if (w.errCode == "" && err != nil) || (err == nil && w.errCode != "") { + t.Fatalf("Expected errCode %q but got: %v", w.errCode, err) + } + } + }) } +} - if result, err := store.Read(ctx, txn, storage.MustParsePath("/a/0/beef")); !storage.IsNotFound(err) { - t.Fatalf("Expected NotFoundErr for /c/0/y but got: %v (err: %v)", result, err) +func TestInMemoryTxnReadFailures(t *testing.T) { + cases := []struct { + note string + ast bool + }{ + {"raw", false}, + {"ast", true}, } + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + ctx := context.Background() + store := NewFromObjectWithOpts(loadSmallTestData(), OptReturnASTValuesOnRead(tc.ast)) + txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) + + if err := store.Write(ctx, txn, storage.RemoveOp, storage.MustParsePath("/c/0/y"), nil); err != nil { + t.Fatalf("Unexpected write error: %v", err) + } + + if result, err := store.Read(ctx, txn, storage.MustParsePath("/c/0/y/0")); !storage.IsNotFound(err) { + t.Fatalf("Expected NotFoundErr for /c/0/y/0 but got: %v (err: %v)", result, err) + } + + if result, err := store.Read(ctx, txn, storage.MustParsePath("/c/0/y")); !storage.IsNotFound(err) { + t.Fatalf("Expected NotFoundErr for /c/0/y but got: %v (err: %v)", result, err) + } + + if result, err := store.Read(ctx, txn, storage.MustParsePath("/a/0/beef")); !storage.IsNotFound(err) { + t.Fatalf("Expected NotFoundErr for /c/0/y but got: %v (err: %v)", result, err) + } + }) + } } func TestInMemoryTxnBadWrite(t *testing.T) { @@ -612,7 +971,7 @@ func TestInMemoryTxnBadWrite(t *testing.T) { store := NewFromObject(loadSmallTestData()) txn := storage.NewTransactionOrDie(ctx, store) if err := store.Write(ctx, txn, storage.RemoveOp, storage.MustParsePath("/a"), nil); !storage.IsInvalidTransaction(err) { - t.Fatalf("Expected InvalidPatchErr but got: %v", err) + t.Fatalf("Expected InvalidTransactionErr but got: %v", err) } } @@ -711,65 +1070,83 @@ func TestInMemoryTxnPolicies(t *testing.T) { } func TestInMemoryTriggers(t *testing.T) { - - ctx := context.Background() - store := NewFromObject(loadSmallTestData()) - writeTxn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) - readTxn := storage.NewTransactionOrDie(ctx, store) - - _, err := store.Register(ctx, readTxn, storage.TriggerConfig{ - OnCommit: func(context.Context, storage.Transaction, storage.TriggerEvent) {}, - }) - - if err == nil || !storage.IsInvalidTransaction(err) { - t.Fatalf("Expected transaction error: %v", err) + cases := []struct { + note string + ast bool + }{ + {"raw", false}, + {"ast", true}, } - store.Abort(ctx, readTxn) + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + ctx := context.Background() + store := NewFromObjectWithOpts(loadSmallTestData(), OptReturnASTValuesOnRead(tc.ast)) + writeTxn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) + readTxn := storage.NewTransactionOrDie(ctx, store) - var event storage.TriggerEvent - modifiedPath := storage.MustParsePath("/a") - expectedValue := "hello" + _, err := store.Register(ctx, readTxn, storage.TriggerConfig{ + OnCommit: func(context.Context, storage.Transaction, storage.TriggerEvent) {}, + }) - _, err = store.Register(ctx, writeTxn, storage.TriggerConfig{ - OnCommit: func(ctx context.Context, txn storage.Transaction, evt storage.TriggerEvent) { - result, err := store.Read(ctx, txn, modifiedPath) - if err != nil || !reflect.DeepEqual(result, expectedValue) { - t.Fatalf("Expected result to be hello for trigger read but got: %v (err: %v)", result, err) + if err == nil || !storage.IsInvalidTransaction(err) { + t.Fatalf("Expected transaction error: %v", err) } - event = evt - }, - }) - if err != nil { - t.Fatalf("Failed to register callback: %v", err) - } - if err := store.Write(ctx, writeTxn, storage.ReplaceOp, modifiedPath, expectedValue); err != nil { - t.Fatalf("Unexpected write error: %v", err) - } + store.Abort(ctx, readTxn) + + var event storage.TriggerEvent + modifiedPath := storage.MustParsePath("/a") + expectedValue := "hello" + + _, err = store.Register(ctx, writeTxn, storage.TriggerConfig{ + OnCommit: func(ctx context.Context, txn storage.Transaction, evt storage.TriggerEvent) { + result, err := store.Read(ctx, txn, modifiedPath) + if tc.ast { + expAstValue := ast.String(expectedValue) + if err != nil || ast.Compare(expAstValue, result) != 0 { + t.Fatalf("Expected result to be %v for trigger read but got: %v (err: %v)", expectedValue, result, err) + } + } else { + if err != nil || !reflect.DeepEqual(result, expectedValue) { + t.Fatalf("Expected result to be %v for trigger read but got: %v (err: %v)", expectedValue, result, err) + } + } + event = evt + }, + }) + if err != nil { + t.Fatalf("Failed to register callback: %v", err) + } - id := "test" - data := []byte("package abc") - if err := store.UpsertPolicy(ctx, writeTxn, id, data); err != nil { - t.Fatalf("Unexpected upsert error: %v", err) - } + if err := store.Write(ctx, writeTxn, storage.ReplaceOp, modifiedPath, expectedValue); err != nil { + t.Fatalf("Unexpected write error: %v", err) + } - if err := store.Commit(ctx, writeTxn); err != nil { - t.Fatalf("Unexpected commit error: %v", err) - } + id := "test" + data := []byte("package abc") + if err := store.UpsertPolicy(ctx, writeTxn, id, data); err != nil { + t.Fatalf("Unexpected upsert error: %v", err) + } - if event.IsZero() || !event.PolicyChanged() || !event.DataChanged() { - t.Fatalf("Expected policy and data change but got: %v", event) - } + if err := store.Commit(ctx, writeTxn); err != nil { + t.Fatalf("Unexpected commit error: %v", err) + } - expData := storage.DataEvent{Path: modifiedPath, Data: expectedValue, Removed: false} - if d := event.Data[0]; !reflect.DeepEqual(expData, d) { - t.Fatalf("Expected data event %v, got %v", expData, d) - } + if event.IsZero() || !event.PolicyChanged() || !event.DataChanged() { + t.Fatalf("Expected policy and data change but got: %v", event) + } + + expData := storage.DataEvent{Path: modifiedPath, Data: expectedValue, Removed: false} + if d := event.Data[0]; !reflect.DeepEqual(expData, d) { + t.Fatalf("Expected data event %v, got %v", expData, d) + } - expPolicy := storage.PolicyEvent{ID: id, Data: data, Removed: false} - if p := event.Policy[0]; !reflect.DeepEqual(expPolicy, p) { - t.Fatalf("Expected policy event %v, got %v", expPolicy, p) + expPolicy := storage.PolicyEvent{ID: id, Data: data, Removed: false} + if p := event.Policy[0]; !reflect.DeepEqual(expPolicy, p) { + t.Fatalf("Expected policy event %v, got %v", expPolicy, p) + } + }) } } diff --git a/storage/inmem/opts.go b/storage/inmem/opts.go index fb8dc8e2bf..2239fc73a3 100644 --- a/storage/inmem/opts.go +++ b/storage/inmem/opts.go @@ -23,3 +23,15 @@ func OptRoundTripOnWrite(enabled bool) Opt { s.roundTripOnWrite = enabled } } + +// OptReturnASTValuesOnRead sets whether data values added to the store should be +// eagerly converted to AST values, which are then returned on read. +// +// When enabled, this feature does not sanity check data before converting it to AST values, +// which may result in panics if the data is not valid. Callers should ensure that passed data +// can be serialized to AST values; otherwise, it's recommended to also enable OptRoundTripOnWrite. +func OptReturnASTValuesOnRead(enabled bool) Opt { + return func(s *store) { + s.returnASTValuesOnRead = enabled + } +} diff --git a/storage/inmem/txn.go b/storage/inmem/txn.go index 3a61018291..d3252e8822 100644 --- a/storage/inmem/txn.go +++ b/storage/inmem/txn.go @@ -9,6 +9,7 @@ import ( "encoding/json" "strconv" + "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/deepcopy" "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/storage/internal/errors" @@ -76,13 +77,13 @@ func (txn *transaction) Write(op storage.PatchOp, path storage.Path, value inter } for curr := txn.updates.Front(); curr != nil; { - update := curr.Value.(*update) + update := curr.Value.(dataUpdate) // Check if new update masks existing update exactly. In this case, the // existing update can be removed and no other updates have to be // visited (because no two updates overlap.) - if update.path.Equal(path) { - if update.remove { + if update.Path().Equal(path) { + if update.Remove() { if op != storage.AddOp { return errors.NewNotFoundError(path) } @@ -94,7 +95,7 @@ func (txn *transaction) Write(op storage.PatchOp, path storage.Path, value inter // Check if new update masks existing update. In this case, the // existing update has to be removed but other updates may overlap, so // we must continue. - if update.path.HasPrefix(path) { + if update.Path().HasPrefix(path) { remove := curr curr = curr.Next() txn.updates.Remove(remove) @@ -103,23 +104,23 @@ func (txn *transaction) Write(op storage.PatchOp, path storage.Path, value inter // Check if new update modifies existing update. In this case, the // existing update is mutated. - if path.HasPrefix(update.path) { - if update.remove { + if path.HasPrefix(update.Path()) { + if update.Remove() { return errors.NewNotFoundError(path) } - suffix := path[len(update.path):] - newUpdate, err := newUpdate(update.value, op, suffix, 0, value) + suffix := path[len(update.Path()):] + newUpdate, err := txn.db.newUpdate(update.Value(), op, suffix, 0, value) if err != nil { return err } - update.value = newUpdate.Apply(update.value) + update.Set(newUpdate.Apply(update.Value())) return nil } curr = curr.Next() } - update, err := newUpdate(txn.db.data, op, path, 0, value) + update, err := txn.db.newUpdate(txn.db.data, op, path, 0, value) if err != nil { return err } @@ -132,72 +133,115 @@ func (txn *transaction) updateRoot(op storage.PatchOp, value interface{}) error if op == storage.RemoveOp { return invalidPatchError(rootCannotBeRemovedMsg) } - if _, ok := value.(map[string]interface{}); !ok { - return invalidPatchError(rootMustBeObjectMsg) + + var update any + if txn.db.returnASTValuesOnRead { + valueAST, err := interfaceToValue(value) + if err != nil { + return err + } + if _, ok := valueAST.(ast.Object); !ok { + return invalidPatchError(rootMustBeObjectMsg) + } + + update = &updateAST{ + path: storage.Path{}, + remove: false, + value: valueAST, + } + } else { + if _, ok := value.(map[string]interface{}); !ok { + return invalidPatchError(rootMustBeObjectMsg) + } + + update = &updateRaw{ + path: storage.Path{}, + remove: false, + value: value, + } } + txn.updates.Init() - txn.updates.PushFront(&update{ - path: storage.Path{}, - remove: false, - value: value, - }) + txn.updates.PushFront(update) return nil } func (txn *transaction) Commit() (result storage.TriggerEvent) { result.Context = txn.context for curr := txn.updates.Front(); curr != nil; curr = curr.Next() { - action := curr.Value.(*update) - updated := action.Apply(txn.db.data) - txn.db.data = updated.(map[string]interface{}) + action := curr.Value.(dataUpdate) + txn.db.data = action.Apply(txn.db.data) result.Data = append(result.Data, storage.DataEvent{ - Path: action.path, - Data: action.value, - Removed: action.remove, + Path: action.Path(), + Data: action.Value(), + Removed: action.Remove(), }) } - for id, update := range txn.policies { - if update.remove { + for id, upd := range txn.policies { + if upd.remove { delete(txn.db.policies, id) } else { - txn.db.policies[id] = update.value + txn.db.policies[id] = upd.value } result.Policy = append(result.Policy, storage.PolicyEvent{ ID: id, - Data: update.value, - Removed: update.remove, + Data: upd.value, + Removed: upd.remove, }) } return result } +func pointer(v interface{}, path storage.Path) (interface{}, error) { + if v, ok := v.(ast.Value); ok { + return ptr.ValuePtr(v, path) + } + return ptr.Ptr(v, path) +} + +func deepcpy(v interface{}) interface{} { + if v, ok := v.(ast.Value); ok { + var cpy ast.Value + + switch data := v.(type) { + case ast.Object: + cpy = data.Copy() + case *ast.Array: + cpy = data.Copy() + } + + return cpy + } + return deepcopy.DeepCopy(v) +} + func (txn *transaction) Read(path storage.Path) (interface{}, error) { if !txn.write { - return ptr.Ptr(txn.db.data, path) + return pointer(txn.db.data, path) } - merge := []*update{} + var merge []dataUpdate for curr := txn.updates.Front(); curr != nil; curr = curr.Next() { - update := curr.Value.(*update) + upd := curr.Value.(dataUpdate) - if path.HasPrefix(update.path) { - if update.remove { + if path.HasPrefix(upd.Path()) { + if upd.Remove() { return nil, errors.NewNotFoundError(path) } - return ptr.Ptr(update.value, path[len(update.path):]) + return pointer(upd.Value(), path[len(upd.Path()):]) } - if update.path.HasPrefix(path) { - merge = append(merge, update) + if upd.Path().HasPrefix(path) { + merge = append(merge, upd) } } - data, err := ptr.Ptr(txn.db.data, path) + data, err := pointer(txn.db.data, path) if err != nil { return nil, err @@ -207,7 +251,7 @@ func (txn *transaction) Read(path storage.Path) (interface{}, error) { return data, nil } - cpy := deepcopy.DeepCopy(data) + cpy := deepcpy(data) for _, update := range merge { cpy = update.Relative(path).Apply(cpy) @@ -266,15 +310,44 @@ func (txn *transaction) DeletePolicy(id string) error { return nil } +type dataUpdate interface { + Path() storage.Path + Remove() bool + Apply(interface{}) interface{} + Relative(path storage.Path) dataUpdate + Set(interface{}) + Value() interface{} +} + // update contains state associated with an update to be applied to the // in-memory data store. -type update struct { +type updateRaw struct { path storage.Path // data path modified by update remove bool // indicates whether update removes the value at path value interface{} // value to add/replace at path (ignored if remove is true) } -func newUpdate(data interface{}, op storage.PatchOp, path storage.Path, idx int, value interface{}) (*update, error) { +func (db *store) newUpdate(data interface{}, op storage.PatchOp, path storage.Path, idx int, value interface{}) (dataUpdate, error) { + if db.returnASTValuesOnRead { + astData, err := interfaceToValue(data) + if err != nil { + return nil, err + } + astValue, err := interfaceToValue(value) + if err != nil { + return nil, err + } + return newUpdateAST(astData, op, path, idx, astValue) + } + return newUpdateRaw(data, op, path, idx, value) +} + +func newUpdateRaw(data interface{}, op storage.PatchOp, path storage.Path, idx int, value interface{}) (dataUpdate, error) { + + switch data.(type) { + case nil, bool, json.Number, string: + return nil, errors.NewNotFoundError(path) + } switch data := data.(type) { case map[string]interface{}: @@ -282,9 +355,6 @@ func newUpdate(data interface{}, op storage.PatchOp, path storage.Path, idx int, case []interface{}: return newUpdateArray(data, op, path, idx, value) - - case nil, bool, json.Number, string: - return nil, errors.NewNotFoundError(path) } return nil, &storage.Error{ @@ -293,7 +363,7 @@ func newUpdate(data interface{}, op storage.PatchOp, path storage.Path, idx int, } } -func newUpdateArray(data []interface{}, op storage.PatchOp, path storage.Path, idx int, value interface{}) (*update, error) { +func newUpdateArray(data []interface{}, op storage.PatchOp, path storage.Path, idx int, value interface{}) (dataUpdate, error) { if idx == len(path)-1 { if path[idx] == "-" || path[idx] == strconv.Itoa(len(data)) { @@ -303,7 +373,7 @@ func newUpdateArray(data []interface{}, op storage.PatchOp, path storage.Path, i cpy := make([]interface{}, len(data)+1) copy(cpy, data) cpy[len(data)] = value - return &update{path[:len(path)-1], false, cpy}, nil + return &updateRaw{path[:len(path)-1], false, cpy}, nil } pos, err := ptr.ValidateArrayIndex(data, path[idx], path) @@ -317,19 +387,19 @@ func newUpdateArray(data []interface{}, op storage.PatchOp, path storage.Path, i copy(cpy[:pos], data[:pos]) copy(cpy[pos+1:], data[pos:]) cpy[pos] = value - return &update{path[:len(path)-1], false, cpy}, nil + return &updateRaw{path[:len(path)-1], false, cpy}, nil case storage.RemoveOp: cpy := make([]interface{}, len(data)-1) copy(cpy[:pos], data[:pos]) copy(cpy[pos:], data[pos+1:]) - return &update{path[:len(path)-1], false, cpy}, nil + return &updateRaw{path[:len(path)-1], false, cpy}, nil default: cpy := make([]interface{}, len(data)) copy(cpy, data) cpy[pos] = value - return &update{path[:len(path)-1], false, cpy}, nil + return &updateRaw{path[:len(path)-1], false, cpy}, nil } } @@ -338,10 +408,10 @@ func newUpdateArray(data []interface{}, op storage.PatchOp, path storage.Path, i return nil, err } - return newUpdate(data[pos], op, path, idx+1, value) + return newUpdateRaw(data[pos], op, path, idx+1, value) } -func newUpdateObject(data map[string]interface{}, op storage.PatchOp, path storage.Path, idx int, value interface{}) (*update, error) { +func newUpdateObject(data map[string]interface{}, op storage.PatchOp, path storage.Path, idx int, value interface{}) (dataUpdate, error) { if idx == len(path)-1 { switch op { @@ -350,16 +420,25 @@ func newUpdateObject(data map[string]interface{}, op storage.PatchOp, path stora return nil, errors.NewNotFoundError(path) } } - return &update{path, op == storage.RemoveOp, value}, nil + return &updateRaw{path, op == storage.RemoveOp, value}, nil } if data, ok := data[path[idx]]; ok { - return newUpdate(data, op, path, idx+1, value) + return newUpdateRaw(data, op, path, idx+1, value) } return nil, errors.NewNotFoundError(path) } -func (u *update) Apply(data interface{}) interface{} { + +func (u *updateRaw) Remove() bool { + return u.remove +} + +func (u *updateRaw) Path() storage.Path { + return u.path +} + +func (u *updateRaw) Apply(data interface{}) interface{} { if len(u.path) == 0 { return u.value } @@ -389,7 +468,15 @@ func (u *update) Apply(data interface{}) interface{} { return data } -func (u *update) Relative(path storage.Path) *update { +func (u *updateRaw) Set(v interface{}) { + u.value = v +} + +func (u *updateRaw) Value() interface{} { + return u.value +} + +func (u *updateRaw) Relative(path storage.Path) dataUpdate { cpy := *u cpy.path = cpy.path[len(path):] return &cpy diff --git a/storage/internal/ptr/ptr.go b/storage/internal/ptr/ptr.go index 56772f7976..14adbd682e 100644 --- a/storage/internal/ptr/ptr.go +++ b/storage/internal/ptr/ptr.go @@ -8,6 +8,7 @@ package ptr import ( "strconv" + "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/storage/internal/errors" ) @@ -36,6 +37,32 @@ func Ptr(data interface{}, path storage.Path) (interface{}, error) { return node, nil } +func ValuePtr(data ast.Value, path storage.Path) (ast.Value, error) { + node := data + for i := range path { + key := path[i] + switch curr := node.(type) { + case ast.Object: + keyTerm := ast.StringTerm(key) + val := curr.Get(keyTerm) + if val == nil { + return nil, errors.NewNotFoundError(path) + } + node = val.Value + case *ast.Array: + pos, err := ValidateASTArrayIndex(curr, key, path) + if err != nil { + return nil, err + } + node = curr.Elem(pos).Value + default: + return nil, errors.NewNotFoundError(path) + } + } + + return node, nil +} + func ValidateArrayIndex(arr []interface{}, s string, path storage.Path) (int, error) { idx, ok := isInt(s) if !ok { @@ -44,6 +71,14 @@ func ValidateArrayIndex(arr []interface{}, s string, path storage.Path) (int, er return inRange(idx, arr, path) } +func ValidateASTArrayIndex(arr *ast.Array, s string, path storage.Path) (int, error) { + idx, ok := isInt(s) + if !ok { + return 0, errors.NewNotFoundErrorWithHint(path, errors.ArrayIndexTypeMsg) + } + return inRange(idx, arr, path) +} + // ValidateArrayIndexForWrite also checks that `s` is a valid way to address an // array element like `ValidateArrayIndex`, but returns a `resource_conflict` error // if it is not. @@ -60,8 +95,18 @@ func isInt(s string) (int, bool) { return idx, err == nil } -func inRange(i int, arr []interface{}, path storage.Path) (int, error) { - if i < 0 || i >= len(arr) { +func inRange(i int, arr interface{}, path storage.Path) (int, error) { + + var arrLen int + + switch v := arr.(type) { + case []interface{}: + arrLen = len(v) + case *ast.Array: + arrLen = v.Len() + } + + if i < 0 || i >= arrLen { return 0, errors.NewNotFoundErrorWithHint(path, errors.OutOfRangeMsg) } return i, nil diff --git a/storage/storage.go b/storage/storage.go index 1e290c50bb..2f8a39c597 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -6,6 +6,8 @@ package storage import ( "context" + + "github.com/open-policy-agent/opa/ast" ) // NewTransactionOrDie is a helper function to create a new transaction. If the @@ -78,6 +80,11 @@ func MakeDir(ctx context.Context, store Store, txn Transaction, path Path) error if _, ok := node.(map[string]interface{}); ok { return nil } + + if _, ok := node.(ast.Object); ok { + return nil + } + return writeConflictError(path) } @@ -118,6 +125,9 @@ func NonEmpty(ctx context.Context, store Store, txn Transaction) func([]string) if _, ok := val.(map[string]interface{}); ok { return false, nil } + if _, ok := val.(ast.Object); ok { + return false, nil + } return true, nil } }