diff --git a/runtime/deferral_test.go b/runtime/deferral_test.go index a8d1af3f42..c5cbccbfe9 100644 --- a/runtime/deferral_test.go +++ b/runtime/deferral_test.go @@ -67,6 +67,11 @@ const simpleDeferralContract = ` return <-r } + pub fun insert(_ id: String, _ r: @R): @R? { + let old <- self.rs.insert(key: id, <-r) + return <- old + } + destroy() { destroy self.rs } @@ -960,7 +965,7 @@ func TestRuntimeStorageDeferredResourceDictionaryValuesTransfer(t *testing.T) { func TestRuntimeStorageDeferredResourceDictionaryValuesRemoval(t *testing.T) { - // Test that `remove` function correctly loads the potentially deferred value + // Test that the `remove` function correctly loads the potentially deferred value runtime := NewInterpreterRuntime() @@ -1153,3 +1158,117 @@ func TestRuntimeStorageDeferredResourceDictionaryValuesDestruction(t *testing.T) loggedMessages, ) } + +func TestRuntimeStorageDeferredResourceDictionaryValuesInsertion(t *testing.T) { + + // Test that the `insert` function correctly loads the potentially deferred value + + runtime := NewInterpreterRuntime() + + contract := []byte(simpleDeferralContract) + + deployTx := []byte(fmt.Sprintf( + ` + transaction { + + prepare(signer: AuthAccount) { + signer.setCode(%s) + } + } + `, + ArrayValueFromBytes(contract).String(), + )) + + setupTx := []byte(` + import Test from 0x1 + + transaction { + + prepare(signer: AuthAccount) { + let c <- Test.createC() + c.rs["a"] <-! Test.createR(1) + c.rs["b"] <-! Test.createR(2) + signer.save(<-c, to: /storage/c) + } + } + `) + + borrowTx := []byte(` + import Test from 0x1 + + transaction { + + prepare(signer: AuthAccount) { + let c = signer.borrow<&Test.C>(from: /storage/c)! + + let e1 <- c.insert("c", <-Test.createR(3)) + assert(e1 == nil) + destroy e1 + + let e2 <- c.insert("a", <-Test.createR(1)) + assert(e2 != nil) + destroy e2 + } + } + `) + + loadTx := []byte(` + import Test from 0x1 + + transaction { + + prepare(signer: AuthAccount) { + let c <- signer.load<@Test.C>(from: /storage/c)! + let e1 <- c.insert("d", <-Test.createR(4)) + assert(e1 == nil) + destroy e1 + + let e2 <- c.insert("b", <-Test.createR(2)) + assert(e2 != nil) + destroy e2 + + destroy c + } + } + `) + + var accountCode []byte + var events []cadence.Event + var loggedMessages []string + + signer := common.BytesToAddress([]byte{0x1}) + + runtimeInterface := &testRuntimeInterface{ + resolveImport: func(_ Location) (bytes []byte, err error) { + return accountCode, nil + }, + storage: newTestStorage(nil, nil), + getSigningAccounts: func() []Address { + return []Address{signer} + }, + updateAccountCode: func(address Address, code []byte, checkPermission bool) (err error) { + accountCode = code + return nil + }, + emitEvent: func(event cadence.Event) { + events = append(events, event) + }, + log: func(message string) { + loggedMessages = append(loggedMessages, message) + }, + } + + nextTransactionLocation := newTransactionLocationGenerator() + + err := runtime.ExecuteTransaction(deployTx, nil, runtimeInterface, nextTransactionLocation()) + require.NoError(t, err) + + err = runtime.ExecuteTransaction(setupTx, nil, runtimeInterface, nextTransactionLocation()) + require.NoError(t, err) + + err = runtime.ExecuteTransaction(borrowTx, nil, runtimeInterface, nextTransactionLocation()) + require.NoError(t, err) + + err = runtime.ExecuteTransaction(loadTx, nil, runtimeInterface, nextTransactionLocation()) + require.NoError(t, err) +} diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index d23b60795d..cb107a2b9e 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -1893,7 +1893,8 @@ func (interpreter *Interpreter) VisitDictionaryExpression(expression *ast.Dictio // NOTE: important to convert in optional, as assignment to dictionary // is always considered as an optional - newDictionary.Insert(key, value) + locationRange := interpreter.locationRange(expression) + _ = newDictionary.Insert(interpreter, locationRange, key, value) } return Done{Result: newDictionary} diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 50f6ccff93..9b3b26cdbc 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -5061,7 +5061,7 @@ func NewDictionaryValueUnownedNonCopying(keysAndValues ...Value) *DictionaryValu } for i := 0; i < keysAndValuesCount; i += 2 { - result.Insert(keysAndValues[i], keysAndValues[i+1]) + _ = result.Insert(nil, LocationRange{}, keysAndValues[i], keysAndValues[i+1]) } return result @@ -5225,10 +5225,10 @@ func (v *DictionaryValue) Set(inter *Interpreter, locationRange LocationRange, k switch typedValue := value.(type) { case *SomeValue: - v.Insert(keyValue, typedValue.Value) + _ = v.Insert(inter, locationRange, keyValue, typedValue.Value) case NilValue: - v.Remove(inter, locationRange, keyValue) + _ = v.Remove(inter, locationRange, keyValue) return default: @@ -5281,8 +5281,14 @@ func (v *DictionaryValue) GetMember(_ *Interpreter, _ LocationRange, name string return NewHostFunctionValue( func(invocation Invocation) trampoline.Trampoline { keyValue := invocation.Arguments[0] - result := v.Remove(invocation.Interpreter, invocation.LocationRange, keyValue) - return trampoline.Done{Result: result} + + existingValue := v.Remove( + invocation.Interpreter, + invocation.LocationRange, + keyValue, + ) + + return trampoline.Done{Result: existingValue} }, ) @@ -5292,18 +5298,14 @@ func (v *DictionaryValue) GetMember(_ *Interpreter, _ LocationRange, name string keyValue := invocation.Arguments[0] newValue := invocation.Arguments[1] - existingValue := v.Insert(keyValue, newValue) - - var returnValue Value - if existingValue == nil { - returnValue = NilValue{} - } else { - returnValue = NewSomeValueOwningNonCopying(existingValue) - } + existingValue := v.Insert( + invocation.Interpreter, + invocation.LocationRange, + keyValue, + newValue, + ) - return trampoline.Done{ - Result: returnValue, - } + return trampoline.Done{Result: existingValue} }, ) @@ -5364,25 +5366,29 @@ func (v *DictionaryValue) Remove(inter *Interpreter, locationRange LocationRange } } -func (v *DictionaryValue) Insert(keyValue Value, value Value) (existingValue Value) { +func (v *DictionaryValue) Insert(inter *Interpreter, locationRange LocationRange, keyValue, value Value) OptionalValue { v.modified = true - key := dictionaryKey(keyValue) - existingValue, existed := v.Entries[key] + // Don't use `Entries` here: the value might be deferred and needs to be loaded + existingValue := v.Get(inter, locationRange, keyValue) - if !existed { - v.Keys.Append(keyValue) - } + key := dictionaryKey(keyValue) value.SetOwner(v.Owner) v.Entries[key] = value - if !existed { - return nil - } + switch existingValue := existingValue.(type) { + case *SomeValue: + return existingValue + + case NilValue: + v.Keys.Append(keyValue) + return existingValue - return existingValue + default: + panic(errors.NewUnreachableError()) + } } func (v *DictionaryValue) SetModified(modified bool) { diff --git a/runtime/interpreter/value_test.go b/runtime/interpreter/value_test.go index a94feca86c..f6de22cd6e 100644 --- a/runtime/interpreter/value_test.go +++ b/runtime/interpreter/value_test.go @@ -257,7 +257,7 @@ func TestSetOwnerDictionaryInsert(t *testing.T) { assert.Equal(t, &newOwner, dictionary.GetOwner()) assert.Equal(t, &oldOwner, value.GetOwner()) - dictionary.Insert(keyValue, value) + dictionary.Insert(nil, LocationRange{}, keyValue, value) assert.Equal(t, &newOwner, dictionary.GetOwner()) assert.Equal(t, &newOwner, value.GetOwner()) diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index 7181f9c9b1..a63ec1daf0 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -5091,6 +5091,8 @@ func TestInterpretDictionaryInsert(t *testing.T) { interpreter.NewStringValue("def"), interpreter.NewIntValueFromInt64(2), ).Copy().(*interpreter.DictionaryValue) expectedDict.Insert( + nil, + interpreter.LocationRange{}, interpreter.NewStringValue("abc"), interpreter.NewIntValueFromInt64(3), )