diff --git a/internal/migrate/stage_contract.go b/internal/migrate/stage_contract.go index e95c3df5d..3cc5312eb 100644 --- a/internal/migrate/stage_contract.go +++ b/internal/migrate/stage_contract.go @@ -128,7 +128,7 @@ func stageContract( } } else if err != nil { logger.Error(validator.prettyPrintError(err, common.StringLocation(contract.Location))) - return nil, fmt.Errorf("errors were found while validating the contract code, and your contract HAS NOT been staged, you can use the --skip-validation flag to bypass this check") + return nil, fmt.Errorf("errors were found while attempting to perform preliminary validation of the contract code, and your contract HAS NOT been staged, however you can use the --skip-validation flag to bypass this check & stage the contract anyway") } else { logger.Info("No issues found while validating contract code\n") logger.Info("DISCLAIMER: Pre-staging validation checks are not exhaustive and do not guarantee the contract will work as expected, please monitor the status of your contract using the `flow migrate is-validated` command\n") diff --git a/internal/migrate/staging_validator.go b/internal/migrate/staging_validator.go index c65e83921..393048885 100644 --- a/internal/migrate/staging_validator.go +++ b/internal/migrate/staging_validator.go @@ -153,6 +153,11 @@ func (v *stagingValidator) ValidateContractUpdate( interpreterProgram, v.elaborations, ) + chainId, ok := chainIdMap[v.flow.Network().Name] + if !ok { + return fmt.Errorf("unsupported network: %s", v.flow.Network().Name) + } + validator.WithUserDefinedTypeChangeChecker(newUserDefinedTypeChangeCheckerFunc(chainId)) err = validator.Validate() if err != nil { @@ -438,3 +443,30 @@ func (a *accountContractNamesProviderImpl) GetAccountContractNames( ) ([]string, error) { return a.resolverFunc(address) } + +// TEMPORARY: this is not exported by flow-go and should be removed once it is +// This is for a quick fix to get the validator working +func newUserDefinedTypeChangeCheckerFunc( + chainID flow.ChainID, +) func(oldTypeID common.TypeID, newTypeID common.TypeID) (checked, valid bool) { + + typeChangeRules := map[common.TypeID]common.TypeID{} + + compositeTypeRules := migrations.NewCompositeTypeConversionRules(chainID) + for typeID, newStaticType := range compositeTypeRules { + typeChangeRules[typeID] = newStaticType.ID() + } + + interfaceTypeRules := migrations.NewInterfaceTypeConversionRules(chainID) + for typeID, newStaticType := range interfaceTypeRules { + typeChangeRules[typeID] = newStaticType.ID() + } + + return func(oldTypeID common.TypeID, newTypeID common.TypeID) (checked, valid bool) { + expectedNewTypeID, found := typeChangeRules[oldTypeID] + if found { + return true, expectedNewTypeID == newTypeID + } + return false, false + } +}