diff --git a/cmd/util/ledger/migrations/cadence.go b/cmd/util/ledger/migrations/cadence.go index b78d7e04d1f..bc969201a1d 100644 --- a/cmd/util/ledger/migrations/cadence.go +++ b/cmd/util/ledger/migrations/cadence.go @@ -193,6 +193,9 @@ func NewCadence1ContractsMigrations( stagedContracts []StagedContract, ) []ledger.Migration { + stagedContractsMigration := NewStagedContractsMigration(chainID) + stagedContractsMigration.RegisterContractUpdates(stagedContracts) + return []ledger.Migration{ NewAccountBasedMigration( log, @@ -211,7 +214,7 @@ func NewCadence1ContractsMigrations( log, nWorker, []AccountBasedMigration{ - NewStagedContractsMigration(stagedContracts), + stagedContractsMigration, }, ), } diff --git a/cmd/util/ledger/migrations/cadence_values_migration_test.go b/cmd/util/ledger/migrations/cadence_values_migration_test.go index d620b9bc8c9..0c139d7dc2c 100644 --- a/cmd/util/ledger/migrations/cadence_values_migration_test.go +++ b/cmd/util/ledger/migrations/cadence_values_migration_test.go @@ -7,8 +7,6 @@ import ( "sync" "testing" - "github.com/onflow/flow-go/fvm/environment" - "github.com/rs/zerolog" _ "github.com/glebarez/go-sqlite" @@ -22,6 +20,7 @@ import ( "github.com/onflow/flow-go/cmd/util/ledger/reporters" "github.com/onflow/flow-go/cmd/util/ledger/util" + "github.com/onflow/flow-go/fvm/environment" "github.com/onflow/flow-go/ledger" "github.com/onflow/flow-go/model/flow" ) @@ -99,7 +98,7 @@ func TestCadenceValuesMigration(t *testing.T) { checkReporters(t, rwf, address) // Check error logs. - require.Nil(t, logWriter.logs) + require.Empty(t, logWriter.logs) } // TODO: diff --git a/cmd/util/ledger/migrations/change_contract_code_migration.go b/cmd/util/ledger/migrations/change_contract_code_migration.go index 0c7021a3c57..863113a29d5 100644 --- a/cmd/util/ledger/migrations/change_contract_code_migration.go +++ b/cmd/util/ledger/migrations/change_contract_code_migration.go @@ -1,15 +1,11 @@ package migrations import ( - "context" "fmt" - "strings" - "sync" coreContracts "github.com/onflow/flow-core-contracts/lib/go/contracts" ftContracts "github.com/onflow/flow-ft/lib/go/contracts" nftContracts "github.com/onflow/flow-nft/lib/go/contracts" - "github.com/rs/zerolog" sdk "github.com/onflow/flow-go-sdk" @@ -17,159 +13,34 @@ import ( evm "github.com/onflow/flow-go/fvm/evm/stdlib" "github.com/onflow/flow-go/fvm/systemcontracts" - "github.com/onflow/flow-go/ledger" - "github.com/onflow/flow-go/ledger/common/convert" "github.com/onflow/flow-go/model/flow" ) type ChangeContractCodeMigration struct { - log zerolog.Logger - mutex sync.RWMutex - contracts map[common.Address]map[flow.RegisterID]string + *StagedContractsMigration } var _ AccountBasedMigration = (*ChangeContractCodeMigration)(nil) -func (d *ChangeContractCodeMigration) Close() error { - d.mutex.RLock() - defer d.mutex.RUnlock() - - if len(d.contracts) > 0 { - var sb strings.Builder - sb.WriteString("failed to find all contract registers that need to be changed:\n") - for address, contracts := range d.contracts { - _, _ = fmt.Fprintf(&sb, "- address: %s\n", address) - for registerID := range contracts { - _, _ = fmt.Fprintf(&sb, " - %s\n", flow.RegisterIDContractName(registerID)) - } - } - return fmt.Errorf(sb.String()) - } - - return nil -} - -func (d *ChangeContractCodeMigration) InitMigration( - log zerolog.Logger, - _ []*ledger.Payload, - _ int, -) error { - d.log = log. - With(). - Str("migration", "ChangeContractCodeMigration"). - Logger() - - return nil -} - -func (d *ChangeContractCodeMigration) MigrateAccount( - _ context.Context, - address common.Address, - payloads []*ledger.Payload, -) ([]*ledger.Payload, error) { - - contracts, ok := (func() (map[flow.RegisterID]string, bool) { - d.mutex.Lock() - defer d.mutex.Unlock() - - contracts, ok := d.contracts[address] - - // remove address from set of addresses - // to keep track of which addresses are left to change - delete(d.contracts, address) - - return contracts, ok - })() - - if !ok { - // no contracts to change on this address - return payloads, nil - } - - for payloadIndex, payload := range payloads { - key, err := payload.Key() - if err != nil { - return nil, err - } - - registerID, err := convert.LedgerKeyToRegisterID(key) - if err != nil { - return nil, err - } - - newContract, ok := contracts[registerID] - if !ok { - // not a contract register, or - // not interested in this contract - continue - } - - // change contract code - payloads[payloadIndex] = ledger.NewPayload( - key, - []byte(newContract), - ) - - // TODO: maybe log diff between old and new - - // remove contract from list of contracts to change - // to keep track of which contracts are left to change - delete(contracts, registerID) - } - - if len(contracts) > 0 { - var sb strings.Builder - _, _ = fmt.Fprintf(&sb, "failed to find all contract registers that need to be changed for address %s:\n", address) - for registerID := range contracts { - _, _ = fmt.Fprintf(&sb, "- %s\n", flow.RegisterIDContractName(registerID)) - } - return nil, fmt.Errorf(sb.String()) - } - - return payloads, nil -} - -func (d *ChangeContractCodeMigration) RegisterContractChange( - address common.Address, - contractName string, - newContractCode string, -) ( - previousNewContractCode string, -) { - d.mutex.Lock() - defer d.mutex.Unlock() - - if d.contracts == nil { - d.contracts = map[common.Address]map[flow.RegisterID]string{} - } - - if _, ok := d.contracts[address]; !ok { - d.contracts[address] = map[flow.RegisterID]string{} +func NewChangeContractCodeMigration(chainID flow.ChainID) *ChangeContractCodeMigration { + return &ChangeContractCodeMigration{ + StagedContractsMigration: NewStagedContractsMigration(chainID). + // TODO: + //WithContractUpdateValidation(). + WithName("ChangeContractCodeMigration"), } - - registerID := flow.ContractRegisterID(flow.ConvertAddress(address), contractName) - - previousNewContractCode = d.contracts[address][registerID] - - d.contracts[address][registerID] = newContractCode - - return -} - -type SystemContractChange struct { - Address common.Address - ContractName string - NewContractCode string } func NewSystemContractChange( systemContract systemcontracts.SystemContract, newContractCode []byte, -) SystemContractChange { - return SystemContractChange{ - Address: common.Address(systemContract.Address), - ContractName: systemContract.Name, - NewContractCode: string(newContractCode), +) StagedContract { + return StagedContract{ + Address: common.Address(systemContract.Address), + Contract: Contract{ + Name: systemContract.Name, + Code: newContractCode, + }, } } @@ -185,7 +56,7 @@ type SystemContractChangesOptions struct { EVM EVMContractChange } -func SystemContractChanges(chainID flow.ChainID, options SystemContractChangesOptions) []SystemContractChange { +func SystemContractChanges(chainID flow.ChainID, options SystemContractChangesOptions) []StagedContract { systemContracts := systemcontracts.SystemContractsForChain(chainID) var stakingCollectionAddress, stakingProxyAddress common.Address @@ -211,7 +82,7 @@ func SystemContractChanges(chainID flow.ChainID, options SystemContractChangesOp fungibleTokenMetadataViewsAddress := common.Address(systemContracts.FungibleToken.Address) fungibleTokenSwitchboardAddress := common.Address(systemContracts.FungibleToken.Address) - contractChanges := []SystemContractChange{ + contractChanges := []StagedContract{ // epoch related contracts NewSystemContractChange( systemContracts.Epoch, @@ -269,35 +140,41 @@ func SystemContractChanges(chainID flow.ChainID, options SystemContractChangesOp ), ), { - Address: stakingCollectionAddress, - ContractName: "FlowStakingCollection", - NewContractCode: string(coreContracts.FlowStakingCollection( - systemContracts.FungibleToken.Address.HexWithPrefix(), - systemContracts.FlowToken.Address.HexWithPrefix(), - systemContracts.IDTableStaking.Address.HexWithPrefix(), - stakingProxyAddress.HexWithPrefix(), - lockedTokensAddress.HexWithPrefix(), - systemContracts.FlowStorageFees.Address.HexWithPrefix(), - systemContracts.ClusterQC.Address.HexWithPrefix(), - systemContracts.DKG.Address.HexWithPrefix(), - systemContracts.Epoch.Address.HexWithPrefix(), - )), + Address: stakingCollectionAddress, + Contract: Contract{ + Name: "FlowStakingCollection", + Code: coreContracts.FlowStakingCollection( + systemContracts.FungibleToken.Address.HexWithPrefix(), + systemContracts.FlowToken.Address.HexWithPrefix(), + systemContracts.IDTableStaking.Address.HexWithPrefix(), + stakingProxyAddress.HexWithPrefix(), + lockedTokensAddress.HexWithPrefix(), + systemContracts.FlowStorageFees.Address.HexWithPrefix(), + systemContracts.ClusterQC.Address.HexWithPrefix(), + systemContracts.DKG.Address.HexWithPrefix(), + systemContracts.Epoch.Address.HexWithPrefix(), + ), + }, }, { - Address: stakingProxyAddress, - ContractName: "StakingProxy", - NewContractCode: string(coreContracts.FlowStakingProxy()), + Address: stakingProxyAddress, + Contract: Contract{ + Name: "StakingProxy", + Code: coreContracts.FlowStakingProxy(), + }, }, { - Address: lockedTokensAddress, - ContractName: "LockedTokens", - NewContractCode: string(coreContracts.FlowLockedTokens( - systemContracts.FungibleToken.Address.HexWithPrefix(), - systemContracts.FlowToken.Address.HexWithPrefix(), - systemContracts.IDTableStaking.Address.HexWithPrefix(), - stakingProxyAddress.HexWithPrefix(), - systemContracts.FlowStorageFees.Address.HexWithPrefix(), - )), + Address: lockedTokensAddress, + Contract: Contract{ + Name: "LockedTokens", + Code: coreContracts.FlowLockedTokens( + systemContracts.FungibleToken.Address.HexWithPrefix(), + systemContracts.FlowToken.Address.HexWithPrefix(), + systemContracts.IDTableStaking.Address.HexWithPrefix(), + stakingProxyAddress.HexWithPrefix(), + systemContracts.FlowStorageFees.Address.HexWithPrefix(), + ), + }, }, // token related contracts @@ -327,14 +204,16 @@ func SystemContractChanges(chainID flow.ChainID, options SystemContractChangesOp ), ), { - Address: fungibleTokenMetadataViewsAddress, - ContractName: "FungibleTokenMetadataViews", - NewContractCode: string(ftContracts.FungibleTokenMetadataViews( - // Use `Hex()`, since this method adds the prefix. - systemContracts.FungibleToken.Address.Hex(), - systemContracts.MetadataViews.Address.Hex(), - systemContracts.ViewResolver.Address.Hex(), - )), + Address: fungibleTokenMetadataViewsAddress, + Contract: Contract{ + Name: "FungibleTokenMetadataViews", + Code: ftContracts.FungibleTokenMetadataViews( + // Use `Hex()`, since this method adds the prefix. + systemContracts.FungibleToken.Address.Hex(), + systemContracts.MetadataViews.Address.Hex(), + systemContracts.ViewResolver.Address.Hex(), + ), + }, }, // NFT related contracts @@ -361,12 +240,14 @@ func SystemContractChanges(chainID flow.ChainID, options SystemContractChangesOp if chainID != flow.Emulator { contractChanges = append( contractChanges, - SystemContractChange{ - Address: fungibleTokenSwitchboardAddress, - ContractName: "FungibleTokenSwitchboard", - NewContractCode: string(ftContracts.FungibleTokenSwitchboard( - systemContracts.FungibleToken.Address.HexWithPrefix(), - )), + StagedContract{ + Address: fungibleTokenSwitchboardAddress, + Contract: Contract{ + Name: "FungibleTokenSwitchboard", + Code: ftContracts.FungibleTokenSwitchboard( + systemContracts.FungibleToken.Address.HexWithPrefix(), + ), + }, }, ) } @@ -406,13 +287,9 @@ func NewSystemContactsMigration( chainID flow.ChainID, options SystemContractChangesOptions, ) *ChangeContractCodeMigration { - migration := &ChangeContractCodeMigration{} + migration := NewChangeContractCodeMigration(chainID) for _, change := range SystemContractChanges(chainID, options) { - migration.RegisterContractChange( - change.Address, - change.ContractName, - change.NewContractCode, - ) + migration.RegisterContractChange(change) } return migration } diff --git a/cmd/util/ledger/migrations/change_contract_code_migration_test.go b/cmd/util/ledger/migrations/change_contract_code_migration_test.go index 1f499115dda..d23aefbc802 100644 --- a/cmd/util/ledger/migrations/change_contract_code_migration_test.go +++ b/cmd/util/ledger/migrations/change_contract_code_migration_test.go @@ -23,6 +23,26 @@ func newContractPayload(address common.Address, contractName string, contract [] ) } +const contractA = ` +access(all) contract A { + access(all) fun foo() {} +}` + +const updatedContractA = ` +access(all) contract A { + access(all) fun bar() {} +}` + +const contractB = ` +access(all) contract B { + access(all) fun foo() {} +}` + +const updatedContractB = ` +access(all) contract B { + access(all) fun bar() {} +}` + func TestChangeContractCodeMigration(t *testing.T) { t.Parallel() @@ -37,7 +57,7 @@ func TestChangeContractCodeMigration(t *testing.T) { t.Run("no contracts", func(t *testing.T) { t.Parallel() - migration := migrations.ChangeContractCodeMigration{} + migration := migrations.NewChangeContractCodeMigration(flow.Emulator) log := zerolog.New(zerolog.NewTestWriter(t)) err := migration.InitMigration(log, nil, 0) require.NoError(t, err) @@ -55,20 +75,20 @@ func TestChangeContractCodeMigration(t *testing.T) { t.Run("1 contract - dont migrate", func(t *testing.T) { t.Parallel() - migration := migrations.ChangeContractCodeMigration{} + migration := migrations.NewChangeContractCodeMigration(flow.Emulator) log := zerolog.New(zerolog.NewTestWriter(t)) err := migration.InitMigration(log, nil, 0) require.NoError(t, err) payloads, err := migration.MigrateAccount(ctx, address1, []*ledger.Payload{ - newContractPayload(address1, "A", []byte("A")), + newContractPayload(address1, "A", []byte(contractA)), }, ) require.NoError(t, err) require.Len(t, payloads, 1) - require.Equal(t, []byte("A"), []byte(payloads[0].Value())) + require.Equal(t, []byte(contractA), []byte(payloads[0].Value())) err = migration.Close() require.NoError(t, err) @@ -77,22 +97,30 @@ func TestChangeContractCodeMigration(t *testing.T) { t.Run("1 contract - migrate", func(t *testing.T) { t.Parallel() - migration := migrations.ChangeContractCodeMigration{} + migration := migrations.NewChangeContractCodeMigration(flow.Emulator) log := zerolog.New(zerolog.NewTestWriter(t)) err := migration.InitMigration(log, nil, 0) require.NoError(t, err) - migration.RegisterContractChange(address1, "A", "B") + migration.RegisterContractChange( + migrations.StagedContract{ + Address: address1, + Contract: migrations.Contract{ + Name: "A", + Code: []byte(updatedContractA), + }, + }, + ) payloads, err := migration.MigrateAccount(ctx, address1, []*ledger.Payload{ - newContractPayload(address1, "A", []byte("A")), + newContractPayload(address1, "A", []byte(contractA)), }, ) require.NoError(t, err) require.Len(t, payloads, 1) - require.Equal(t, []byte("B"), []byte(payloads[0].Value())) + require.Equal(t, updatedContractA, string(payloads[0].Value())) err = migration.Close() require.NoError(t, err) @@ -101,24 +129,32 @@ func TestChangeContractCodeMigration(t *testing.T) { t.Run("2 contracts - migrate 1", func(t *testing.T) { t.Parallel() - migration := migrations.ChangeContractCodeMigration{} + migration := migrations.NewChangeContractCodeMigration(flow.Emulator) log := zerolog.New(zerolog.NewTestWriter(t)) err := migration.InitMigration(log, nil, 0) require.NoError(t, err) - migration.RegisterContractChange(address1, "A", "B") + migration.RegisterContractChange( + migrations.StagedContract{ + Address: address1, + Contract: migrations.Contract{ + Name: "A", + Code: []byte(updatedContractA), + }, + }, + ) payloads, err := migration.MigrateAccount(ctx, address1, []*ledger.Payload{ - newContractPayload(address1, "A", []byte("A")), - newContractPayload(address1, "B", []byte("A")), + newContractPayload(address1, "A", []byte(contractA)), + newContractPayload(address1, "B", []byte(contractB)), }, ) require.NoError(t, err) require.Len(t, payloads, 2) - require.Equal(t, []byte("B"), []byte(payloads[0].Value())) - require.Equal(t, []byte("A"), []byte(payloads[1].Value())) + require.Equal(t, []byte(updatedContractA), []byte(payloads[0].Value())) + require.Equal(t, []byte(contractB), []byte(payloads[1].Value())) err = migration.Close() require.NoError(t, err) @@ -127,25 +163,41 @@ func TestChangeContractCodeMigration(t *testing.T) { t.Run("2 contracts - migrate 2", func(t *testing.T) { t.Parallel() - migration := migrations.ChangeContractCodeMigration{} + migration := migrations.NewChangeContractCodeMigration(flow.Emulator) log := zerolog.New(zerolog.NewTestWriter(t)) err := migration.InitMigration(log, nil, 0) require.NoError(t, err) - migration.RegisterContractChange(address1, "A", "B") - migration.RegisterContractChange(address1, "B", "B") + migration.RegisterContractChange( + migrations.StagedContract{ + Address: address1, + Contract: migrations.Contract{ + Name: "A", + Code: []byte(updatedContractA), + }, + }, + ) + migration.RegisterContractChange( + migrations.StagedContract{ + Address: address1, + Contract: migrations.Contract{ + Name: "B", + Code: []byte(updatedContractB), + }, + }, + ) payloads, err := migration.MigrateAccount(ctx, address1, []*ledger.Payload{ - newContractPayload(address1, "A", []byte("A")), - newContractPayload(address1, "B", []byte("A")), + newContractPayload(address1, "A", []byte(contractA)), + newContractPayload(address1, "B", []byte(contractB)), }, ) require.NoError(t, err) require.Len(t, payloads, 2) - require.Equal(t, []byte("B"), []byte(payloads[0].Value())) - require.Equal(t, []byte("B"), []byte(payloads[1].Value())) + require.Equal(t, []byte(updatedContractA), []byte(payloads[0].Value())) + require.Equal(t, []byte(updatedContractB), []byte(payloads[1].Value())) err = migration.Close() require.NoError(t, err) @@ -154,24 +206,32 @@ func TestChangeContractCodeMigration(t *testing.T) { t.Run("2 contracts on different accounts - migrate 1", func(t *testing.T) { t.Parallel() - migration := migrations.ChangeContractCodeMigration{} + migration := migrations.NewChangeContractCodeMigration(flow.Emulator) log := zerolog.New(zerolog.NewTestWriter(t)) err := migration.InitMigration(log, nil, 0) require.NoError(t, err) - migration.RegisterContractChange(address1, "A", "B") + migration.RegisterContractChange( + migrations.StagedContract{ + Address: address1, + Contract: migrations.Contract{ + Name: "A", + Code: []byte(updatedContractA), + }, + }, + ) payloads, err := migration.MigrateAccount(ctx, address1, []*ledger.Payload{ - newContractPayload(address1, "A", []byte("A")), - newContractPayload(address2, "A", []byte("A")), + newContractPayload(address1, "A", []byte(contractA)), + newContractPayload(address2, "A", []byte(contractA)), }, ) require.NoError(t, err) require.Len(t, payloads, 2) - require.Equal(t, []byte("B"), []byte(payloads[0].Value())) - require.Equal(t, []byte("A"), []byte(payloads[1].Value())) + require.Equal(t, []byte(updatedContractA), []byte(payloads[0].Value())) + require.Equal(t, []byte(contractA), []byte(payloads[1].Value())) err = migration.Close() require.NoError(t, err) @@ -180,17 +240,33 @@ func TestChangeContractCodeMigration(t *testing.T) { t.Run("not all contracts on one account migrated", func(t *testing.T) { t.Parallel() - migration := migrations.ChangeContractCodeMigration{} + migration := migrations.NewChangeContractCodeMigration(flow.Emulator) log := zerolog.New(zerolog.NewTestWriter(t)) err := migration.InitMigration(log, nil, 0) require.NoError(t, err) - migration.RegisterContractChange(address1, "A", "B") - migration.RegisterContractChange(address1, "B", "B") + migration.RegisterContractChange( + migrations.StagedContract{ + Address: address1, + Contract: migrations.Contract{ + Name: "A", + Code: []byte(updatedContractA), + }, + }, + ) + migration.RegisterContractChange( + migrations.StagedContract{ + Address: address1, + Contract: migrations.Contract{ + Name: "B", + Code: []byte(updatedContractB), + }, + }, + ) _, err = migration.MigrateAccount(ctx, address1, []*ledger.Payload{ - newContractPayload(address1, "A", []byte("A")), + newContractPayload(address1, "A", []byte(contractA)), }, ) @@ -200,16 +276,24 @@ func TestChangeContractCodeMigration(t *testing.T) { t.Run("not all accounts migrated", func(t *testing.T) { t.Parallel() - migration := migrations.ChangeContractCodeMigration{} + migration := migrations.NewChangeContractCodeMigration(flow.Emulator) log := zerolog.New(zerolog.NewTestWriter(t)) err := migration.InitMigration(log, nil, 0) require.NoError(t, err) - migration.RegisterContractChange(address2, "A", "B") + migration.RegisterContractChange( + migrations.StagedContract{ + Address: address2, + Contract: migrations.Contract{ + Name: "A", + Code: []byte(updatedContractA), + }, + }, + ) _, err = migration.MigrateAccount(ctx, address1, []*ledger.Payload{ - newContractPayload(address1, "A", []byte("A")), + newContractPayload(address1, "A", []byte(contractA)), }, ) diff --git a/cmd/util/ledger/migrations/migrator_runtime.go b/cmd/util/ledger/migrations/migrator_runtime.go index 80660019628..4e747694b63 100644 --- a/cmd/util/ledger/migrations/migrator_runtime.go +++ b/cmd/util/ledger/migrations/migrator_runtime.go @@ -118,6 +118,12 @@ type migratorRuntime struct { ContractNamesProvider stdlib.AccountContractNamesProvider } +var _ stdlib.AccountContractNamesProvider = &migratorRuntime{} + func (mr *migratorRuntime) GetReadOnlyStorage() *runtime.Storage { return runtime.NewStorage(util.NewPayloadsReadonlyLedger(mr.Snapshot), util.NopMemoryGauge{}) } + +func (mr *migratorRuntime) GetAccountContractNames(address common.Address) ([]string, error) { + return mr.Accounts.GetContractNames(flow.Address(address)) +} diff --git a/cmd/util/ledger/migrations/staged_contracts_migration.go b/cmd/util/ledger/migrations/staged_contracts_migration.go index dac3f9bc6a1..8fd7dcb8dd5 100644 --- a/cmd/util/ledger/migrations/staged_contracts_migration.go +++ b/cmd/util/ledger/migrations/staged_contracts_migration.go @@ -14,6 +14,7 @@ import ( "github.com/onflow/cadence/runtime" "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/interpreter" "github.com/onflow/cadence/runtime/old_parser" "github.com/onflow/cadence/runtime/stdlib" @@ -21,14 +22,18 @@ import ( "github.com/onflow/flow-go/ledger" "github.com/onflow/flow-go/ledger/common/convert" "github.com/onflow/flow-go/model/flow" + + coreContracts "github.com/onflow/flow-core-contracts/lib/go/contracts" ) type StagedContractsMigration struct { - log zerolog.Logger - mutex sync.RWMutex - contracts map[common.Address]map[flow.RegisterID]Contract - contractsByLocation map[common.Location][]byte - stagedContracts []StagedContract + name string + chainID flow.ChainID + log zerolog.Logger + mutex sync.RWMutex + stagedContracts map[common.Address]map[flow.RegisterID]Contract + contractsByLocation map[common.Location][]byte + enableUpdateValidation bool } type StagedContract struct { @@ -43,22 +48,33 @@ type Contract struct { var _ AccountBasedMigration = &StagedContractsMigration{} -func NewStagedContractsMigration(stagedContracts []StagedContract) *StagedContractsMigration { +func NewStagedContractsMigration(chainID flow.ChainID) *StagedContractsMigration { return &StagedContractsMigration{ - stagedContracts: stagedContracts, - contracts: map[common.Address]map[flow.RegisterID]Contract{}, + name: "StagedContractsMigration", + chainID: chainID, + stagedContracts: map[common.Address]map[flow.RegisterID]Contract{}, contractsByLocation: map[common.Location][]byte{}, } } +func (m *StagedContractsMigration) WithContractUpdateValidation() *StagedContractsMigration { + m.enableUpdateValidation = true + return m +} + +func (m *StagedContractsMigration) WithName(name string) *StagedContractsMigration { + m.name = name + return m +} + func (m *StagedContractsMigration) Close() error { m.mutex.RLock() defer m.mutex.RUnlock() - if len(m.contracts) > 0 { + if len(m.stagedContracts) > 0 { var sb strings.Builder sb.WriteString("failed to find all contract registers that need to be changed:\n") - for address, contracts := range m.contracts { + for address, contracts := range m.stagedContracts { _, _ = fmt.Fprintf(&sb, "- address: %s\n", address) for registerID := range contracts { _, _ = fmt.Fprintf(&sb, " - %s\n", flow.RegisterIDContractName(registerID)) @@ -78,47 +94,57 @@ func (m *StagedContractsMigration) InitMigration( ) error { m.log = log. With(). - Str("migration", "StagedContractsMigration"). + Str("migration", m.name). Logger() - m.registerContractUpdates() + // Manually register burner contract + burnerLocation := common.AddressLocation{ + Name: "Burner", + Address: common.Address(m.chainID.Chain().ServiceAddress()), + } + m.contractsByLocation[burnerLocation] = coreContracts.Burner() return nil } -// registerContractUpdates prepares the contract updates as a map for easy lookup. -func (m *StagedContractsMigration) registerContractUpdates() { - for _, contractChange := range m.stagedContracts { - m.registerContractChange(contractChange) +// RegisterContractUpdates prepares the contract updates as a map for easy lookup. +func (m *StagedContractsMigration) RegisterContractUpdates(stagedContracts []StagedContract) { + for _, contractChange := range stagedContracts { + m.RegisterContractChange(contractChange) } } -func (m *StagedContractsMigration) registerContractChange(change StagedContract) { +func (m *StagedContractsMigration) RegisterContractChange(change StagedContract) { + m.mutex.Lock() + defer m.mutex.Unlock() + address := change.Address - if _, ok := m.contracts[address]; !ok { - m.contracts[address] = map[flow.RegisterID]Contract{} + if _, ok := m.stagedContracts[address]; !ok { + m.stagedContracts[address] = map[flow.RegisterID]Contract{} } - registerID := flow.ContractRegisterID(flow.ConvertAddress(address), change.Name) + name := change.Name - _, exist := m.contracts[address][registerID] + registerID := flow.ContractRegisterID(flow.ConvertAddress(address), name) + + _, exist := m.stagedContracts[address][registerID] if exist { // Staged multiple updates for the same contract. // Overwrite the previous update. m.log.Warn().Msgf( "existing staged update found for contract %s.%s. Previous update will be overwritten.", address.HexWithPrefix(), - change.Name, + name, ) } - m.contracts[address][registerID] = change.Contract + m.stagedContracts[address][registerID] = change.Contract location := common.AddressLocation{ - Name: change.Name, + Name: name, Address: address, } - m.contractsByLocation[location] = change.Contract.Code + m.contractsByLocation[location] = change.Code } func (m *StagedContractsMigration) contractUpdatesForAccount( @@ -127,11 +153,11 @@ func (m *StagedContractsMigration) contractUpdatesForAccount( m.mutex.Lock() defer m.mutex.Unlock() - contracts, ok := m.contracts[address] + contracts, ok := m.stagedContracts[address] // remove address from set of addresses // to keep track of which addresses are left to change - delete(m.contracts, address) + delete(m.stagedContracts, address) return contracts, ok } @@ -148,10 +174,17 @@ func (m *StagedContractsMigration) MigrateAccount( return payloads, nil } + elaborations := map[common.Location]*sema.Elaboration{} + config := util.RuntimeInterfaceConfig{ GetContractCodeFunc: func(location runtime.Location) ([]byte, error) { return m.contractsByLocation[location], nil }, + GetOrLoadProgramListener: func(location runtime.Location, program *interpreter.Program, err error) { + if err == nil { + elaborations[location] = program.Elaboration + } + }, } mr, err := newMigratorRuntime(address, payloads, config) @@ -181,7 +214,17 @@ func (m *StagedContractsMigration) MigrateAccount( newCode := updatedContract.Code oldCode := payload.Value() - err = m.checkUpdateValidity(mr, address, name, newCode, oldCode) + if m.enableUpdateValidation { + err = CheckContractUpdateValidity( + mr, + address, + name, + newCode, + oldCode, + elaborations, + ) + } + if err != nil { m.log.Error().Err(err). Msgf( @@ -214,12 +257,13 @@ func (m *StagedContractsMigration) MigrateAccount( return payloads, nil } -func (m *StagedContractsMigration) checkUpdateValidity( +func CheckContractUpdateValidity( mr *migratorRuntime, address common.Address, contractName string, newCode []byte, oldCode ledger.Value, + elaborations map[common.Location]*sema.Elaboration, ) error { location := common.AddressLocation{ Name: contractName, @@ -248,8 +292,7 @@ func (m *StagedContractsMigration) checkUpdateValidity( mr.ContractNamesProvider, oldProgram, newProgram.Program, - // TODO: - map[common.Location]*sema.Elaboration{}, + elaborations, ) return validator.Validate() diff --git a/cmd/util/ledger/migrations/staged_contracts_migration_test.go b/cmd/util/ledger/migrations/staged_contracts_migration_test.go index c8ec47f3e93..fff537830e7 100644 --- a/cmd/util/ledger/migrations/staged_contracts_migration_test.go +++ b/cmd/util/ledger/migrations/staged_contracts_migration_test.go @@ -65,7 +65,8 @@ func TestStagedContractsMigration(t *testing.T) { }, } - migration := NewStagedContractsMigration(stagedContracts) + migration := NewStagedContractsMigration(flow.Emulator) + migration.RegisterContractUpdates(stagedContracts) logWriter := &logWriter{} log := zerolog.New(logWriter) @@ -104,7 +105,9 @@ func TestStagedContractsMigration(t *testing.T) { }, } - migration := NewStagedContractsMigration(stagedContracts) + migration := NewStagedContractsMigration(flow.Emulator). + WithContractUpdateValidation() + migration.RegisterContractUpdates(stagedContracts) logWriter := &logWriter{} log := zerolog.New(logWriter) @@ -145,7 +148,9 @@ func TestStagedContractsMigration(t *testing.T) { }, } - migration := NewStagedContractsMigration(stagedContracts) + migration := NewStagedContractsMigration(flow.Emulator). + WithContractUpdateValidation() + migration.RegisterContractUpdates(stagedContracts) logWriter := &logWriter{} log := zerolog.New(logWriter) @@ -196,7 +201,9 @@ func TestStagedContractsMigration(t *testing.T) { }, } - migration := NewStagedContractsMigration(stagedContracts) + migration := NewStagedContractsMigration(flow.Emulator). + WithContractUpdateValidation() + migration.RegisterContractUpdates(stagedContracts) logWriter := &logWriter{} log := zerolog.New(logWriter) @@ -240,7 +247,8 @@ func TestStagedContractsMigration(t *testing.T) { }, } - migration := NewStagedContractsMigration(stagedContracts) + migration := NewStagedContractsMigration(flow.Emulator) + migration.RegisterContractUpdates(stagedContracts) logWriter := &logWriter{} log := zerolog.New(logWriter) @@ -301,13 +309,15 @@ func TestStagedContractsMigration(t *testing.T) { }, } - migration := NewStagedContractsMigration(stagedContracts) + migration := NewStagedContractsMigration(flow.Emulator) logWriter := &logWriter{} log := zerolog.New(logWriter) err := migration.InitMigration(log, nil, 0) require.NoError(t, err) + migration.RegisterContractUpdates(stagedContracts) + payloads, err := migration.MigrateAccount(ctx, address1, []*ledger.Payload{ newContractPayload(address1, "A", []byte(oldCode)), @@ -385,7 +395,8 @@ func TestStagedContractsWithImports(t *testing.T) { }, } - migration := NewStagedContractsMigration(stagedContracts) + migration := NewStagedContractsMigration(flow.Emulator) + migration.RegisterContractUpdates(stagedContracts) logWriter := &logWriter{} log := zerolog.New(logWriter) @@ -444,7 +455,9 @@ func TestStagedContractsWithImports(t *testing.T) { }, } - migration := NewStagedContractsMigration(stagedContracts) + migration := NewStagedContractsMigration(flow.Emulator). + WithContractUpdateValidation() + migration.RegisterContractUpdates(stagedContracts) logWriter := &logWriter{} log := zerolog.New(logWriter) @@ -520,7 +533,9 @@ func TestStagedContractsWithImports(t *testing.T) { }, } - migration := NewStagedContractsMigration(stagedContracts) + migration := NewStagedContractsMigration(flow.Emulator). + WithContractUpdateValidation() + migration.RegisterContractUpdates(stagedContracts) logWriter := &logWriter{} log := zerolog.New(logWriter) diff --git a/cmd/util/ledger/util/migration_runtime_interface.go b/cmd/util/ledger/util/migration_runtime_interface.go index 67291c2dba1..529b1a351d8 100644 --- a/cmd/util/ledger/util/migration_runtime_interface.go +++ b/cmd/util/ledger/util/migration_runtime_interface.go @@ -117,7 +117,14 @@ func (m MigrationRuntimeInterface) GetAccountContractCode( func (m MigrationRuntimeInterface) GetOrLoadProgram( location runtime.Location, load func() (*interpreter.Program, error), -) (*interpreter.Program, error) { +) (program *interpreter.Program, err error) { + + defer func() { + if m.GetOrLoadProgramListener != nil { + m.GetOrLoadProgramListener(location, program, err) + } + }() + if m.GetOrLoadProgramFunc != nil { return m.GetOrLoadProgramFunc(location, load) } @@ -309,6 +316,8 @@ type RuntimeInterfaceConfig struct { // GetOrLoadProgramFunc allows for injecting extra logic GetOrLoadProgramFunc func(location runtime.Location, load func() (*interpreter.Program, error)) (*interpreter.Program, error) + GetOrLoadProgramListener func(runtime.Location, *interpreter.Program, error) + // GetContractCodeFunc allows for injecting extra logic for code lookup GetContractCodeFunc func(location runtime.Location) ([]byte, error) }