diff --git a/store/v2/commitment/iavl/tree.go b/store/v2/commitment/iavl/tree.go index 5503218e66a1..5047e8ef6ed4 100644 --- a/store/v2/commitment/iavl/tree.go +++ b/store/v2/commitment/iavl/tree.go @@ -65,6 +65,13 @@ func (t *IavlTree) WorkingHash() []byte { // LoadVersion loads the state at the given version. func (t *IavlTree) LoadVersion(version uint64) error { + _, err := t.tree.LoadVersion(int64(version)) + return err +} + +// LoadVersionForOverwriting loads the state at the given version. +// Any versions greater than targetVersion will be deleted. +func (t *IavlTree) LoadVersionForOverwriting(version uint64) error { return t.tree.LoadVersionForOverwriting(int64(version)) } diff --git a/store/v2/commitment/iavlv2/tree.go b/store/v2/commitment/iavlv2/tree.go index 997a0a60cc1a..14b7967a6c78 100644 --- a/store/v2/commitment/iavlv2/tree.go +++ b/store/v2/commitment/iavlv2/tree.go @@ -64,6 +64,10 @@ func (t *Tree) LoadVersion(version uint64) error { return t.tree.LoadVersion(int64(version)) } +func (t *Tree) LoadVersionForOverwriting(version uint64) error { + return t.LoadVersion(version) // TODO: implement overwriting +} + func (t *Tree) Commit() ([]byte, uint64, error) { h, v, err := t.tree.SaveVersion() return h, uint64(v), err diff --git a/store/v2/commitment/mem/tree.go b/store/v2/commitment/mem/tree.go index cbc28ce7d9ae..bf0e95bfa9c1 100644 --- a/store/v2/commitment/mem/tree.go +++ b/store/v2/commitment/mem/tree.go @@ -34,6 +34,10 @@ func (t *Tree) LoadVersion(version uint64) error { return nil } +func (t *Tree) LoadVersionForOverwriting(version uint64) error { + return nil +} + func (t *Tree) Commit() ([]byte, uint64, error) { return nil, 0, nil } diff --git a/store/v2/commitment/store.go b/store/v2/commitment/store.go index e9f2ee8379c7..5219255f95ca 100644 --- a/store/v2/commitment/store.go +++ b/store/v2/commitment/store.go @@ -87,7 +87,16 @@ func (c *CommitStore) LoadVersion(targetVersion uint64) error { for storeKey := range c.multiTrees { storeKeys = append(storeKeys, storeKey) } - return c.loadVersion(targetVersion, storeKeys) + return c.loadVersion(targetVersion, storeKeys, false) +} + +func (c *CommitStore) LoadVersionForOverwriting(targetVersion uint64) error { + storeKeys := make([]string, 0, len(c.multiTrees)) + for storeKey := range c.multiTrees { + storeKeys = append(storeKeys, storeKey) + } + + return c.loadVersion(targetVersion, storeKeys, true) } // LoadVersionAndUpgrade implements store.UpgradeableStore. @@ -133,10 +142,10 @@ func (c *CommitStore) LoadVersionAndUpgrade(targetVersion uint64, upgrades *core return err } - return c.loadVersion(targetVersion, newStoreKeys) + return c.loadVersion(targetVersion, newStoreKeys, true) } -func (c *CommitStore) loadVersion(targetVersion uint64, storeKeys []string) error { +func (c *CommitStore) loadVersion(targetVersion uint64, storeKeys []string, overrideAfter bool) error { // Rollback the metadata to the target version. latestVersion, err := c.GetLatestVersion() if err != nil { @@ -154,8 +163,14 @@ func (c *CommitStore) loadVersion(targetVersion uint64, storeKeys []string) erro } for _, storeKey := range storeKeys { - if err := c.multiTrees[storeKey].LoadVersion(targetVersion); err != nil { - return err + if overrideAfter { + if err := c.multiTrees[storeKey].LoadVersionForOverwriting(targetVersion); err != nil { + return err + } + } else { + if err := c.multiTrees[storeKey].LoadVersion(targetVersion); err != nil { + return err + } } } diff --git a/store/v2/commitment/tree.go b/store/v2/commitment/tree.go index f57eabd20724..58a8b20beff2 100644 --- a/store/v2/commitment/tree.go +++ b/store/v2/commitment/tree.go @@ -25,6 +25,7 @@ type Tree interface { Version() uint64 LoadVersion(version uint64) error + LoadVersionForOverwriting(version uint64) error Commit() ([]byte, uint64, error) SetInitialVersion(version uint64) error GetProof(version uint64, key []byte) (*ics23.CommitmentProof, error) diff --git a/store/v2/database.go b/store/v2/database.go index 27d0973ec18e..e3361d731024 100644 --- a/store/v2/database.go +++ b/store/v2/database.go @@ -50,6 +50,10 @@ type Committer interface { // LoadVersion loads the tree at the given version. LoadVersion(targetVersion uint64) error + // LoadVersionForOverwriting loads the tree at the given version. + // Any versions greater than targetVersion will be deleted. + LoadVersionForOverwriting(targetVersion uint64) error + // Commit commits the working tree to the database. Commit(version uint64) (*proof.CommitInfo, error) diff --git a/store/v2/mock/db_mock.go b/store/v2/mock/db_mock.go index 9b962affb102..ba65f2baf243 100644 --- a/store/v2/mock/db_mock.go +++ b/store/v2/mock/db_mock.go @@ -158,6 +158,20 @@ func (mr *MockStateCommitterMockRecorder) LoadVersionAndUpgrade(version, upgrade return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadVersionAndUpgrade", reflect.TypeOf((*MockStateCommitter)(nil).LoadVersionAndUpgrade), version, upgrades) } +// LoadVersionForOverwriting mocks base method. +func (m *MockStateCommitter) LoadVersionForOverwriting(targetVersion uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadVersionForOverwriting", targetVersion) + ret0, _ := ret[0].(error) + return ret0 +} + +// LoadVersionForOverwriting indicates an expected call of LoadVersionForOverwriting. +func (mr *MockStateCommitterMockRecorder) LoadVersionForOverwriting(targetVersion any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadVersionForOverwriting", reflect.TypeOf((*MockStateCommitter)(nil).LoadVersionForOverwriting), targetVersion) +} + // PausePruning mocks base method. func (m *MockStateCommitter) PausePruning(pause bool) { m.ctrl.T.Helper() diff --git a/store/v2/root/store.go b/store/v2/root/store.go index 59363e2fb35b..b40baef6424e 100644 --- a/store/v2/root/store.go +++ b/store/v2/root/store.go @@ -250,7 +250,7 @@ func (s *Store) LoadLatestVersion() error { return err } - return s.loadVersion(lv, nil) + return s.loadVersion(lv, nil, false) } func (s *Store) LoadVersion(version uint64) error { @@ -259,7 +259,16 @@ func (s *Store) LoadVersion(version uint64) error { defer s.telemetry.MeasureSince(now, "root_store", "load_version") } - return s.loadVersion(version, nil) + return s.loadVersion(version, nil, false) +} + +func (s *Store) LoadVersionForOverwriting(version uint64) error { + if s.telemetry != nil { + now := time.Now() + defer s.telemetry.MeasureSince(now, "root_store", "load_version_for_overwriting") + } + + return s.loadVersion(version, nil, true) } // LoadVersionAndUpgrade implements the UpgradeableStore interface. @@ -278,7 +287,7 @@ func (s *Store) LoadVersionAndUpgrade(version uint64, upgrades *corestore.StoreU return errors.New("cannot upgrade while migrating") } - if err := s.loadVersion(version, upgrades); err != nil { + if err := s.loadVersion(version, upgrades, true); err != nil { return err } @@ -294,12 +303,18 @@ func (s *Store) LoadVersionAndUpgrade(version uint64, upgrades *corestore.StoreU return nil } -func (s *Store) loadVersion(v uint64, upgrades *corestore.StoreUpgrades) error { +func (s *Store) loadVersion(v uint64, upgrades *corestore.StoreUpgrades, overrideAfter bool) error { s.logger.Debug("loading version", "version", v) if upgrades == nil { - if err := s.stateCommitment.LoadVersion(v); err != nil { - return fmt.Errorf("failed to load SC version %d: %w", v, err) + if !overrideAfter { + if err := s.stateCommitment.LoadVersion(v); err != nil { + return fmt.Errorf("failed to load SC version %d: %w", v, err) + } + } else { + if err := s.stateCommitment.LoadVersionForOverwriting(v); err != nil { + return fmt.Errorf("failed to load SC version %d: %w", v, err) + } } } else { // if upgrades are provided, we need to load the version and apply the upgrades diff --git a/store/v2/root/store_test.go b/store/v2/root/store_test.go index 59a490b11b00..8bb6b5604e2d 100644 --- a/store/v2/root/store_test.go +++ b/store/v2/root/store_test.go @@ -256,6 +256,74 @@ func (s *RootStoreTestSuite) TestLoadVersion() { s.Require().NoError(err) s.Require().Equal([]byte("val003"), val) + // attempt to write and commit a few changesets + for v := 4; v <= 5; v++ { + val := fmt.Sprintf("overwritten_val%03d", v) // overwritten_val004, overwritten_val005 + + cs := corestore.NewChangeset(uint64(v)) + cs.Add(testStoreKeyBytes, []byte("key"), []byte(val), false) + + _, err := s.rootStore.Commit(cs) + s.Require().Error(err) + } + + // ensure the latest version is correct + latest, err = s.rootStore.GetLatestVersion() + s.Require().NoError(err) + s.Require().Equal(uint64(3), latest) // should have stayed at 3 after failed commits + + // query state and ensure values returned are based on the loaded version + _, ro, err = s.rootStore.StateLatest() + s.Require().NoError(err) + + reader, err = ro.GetReader(testStoreKeyBytes) + s.Require().NoError(err) + val, err = reader.Get([]byte("key")) + s.Require().NoError(err) + s.Require().Equal([]byte("val003"), val) +} + +func (s *RootStoreTestSuite) TestLoadVersionForOverwriting() { + // write and commit a few changesets + for v := uint64(1); v <= 5; v++ { + val := fmt.Sprintf("val%03d", v) // val001, val002, ..., val005 + + cs := corestore.NewChangeset(v) + cs.Add(testStoreKeyBytes, []byte("key"), []byte(val), false) + + commitHash, err := s.rootStore.Commit(cs) + s.Require().NoError(err) + s.Require().NotNil(commitHash) + } + + // ensure the latest version is correct + latest, err := s.rootStore.GetLatestVersion() + s.Require().NoError(err) + s.Require().Equal(uint64(5), latest) + + // attempt to load a non-existent version + err = s.rootStore.LoadVersionForOverwriting(6) + s.Require().Error(err) + + // attempt to load a previously committed version + err = s.rootStore.LoadVersionForOverwriting(3) + s.Require().NoError(err) + + // ensure the latest version is correct + latest, err = s.rootStore.GetLatestVersion() + s.Require().NoError(err) + s.Require().Equal(uint64(3), latest) + + // query state and ensure values returned are based on the loaded version + _, ro, err := s.rootStore.StateLatest() + s.Require().NoError(err) + + reader, err := ro.GetReader(testStoreKeyBytes) + s.Require().NoError(err) + val, err := reader.Get([]byte("key")) + s.Require().NoError(err) + s.Require().Equal([]byte("val003"), val) + // attempt to write and commit a few changesets for v := 4; v <= 5; v++ { val := fmt.Sprintf("overwritten_val%03d", v) // overwritten_val004, overwritten_val005 diff --git a/store/v2/store.go b/store/v2/store.go index 124d7de579a1..bf967d0f78a6 100644 --- a/store/v2/store.go +++ b/store/v2/store.go @@ -30,6 +30,10 @@ type RootStore interface { // LoadVersion loads the RootStore to the given version. LoadVersion(version uint64) error + // LoadVersionForOverwriting loads the state at the given version. + // Any versions greater than targetVersion will be deleted. + LoadVersionForOverwriting(version uint64) error + // LoadLatestVersion behaves identically to LoadVersion except it loads the // latest version implicitly. LoadLatestVersion() error