Skip to content

Commit

Permalink
Merge pull request #19 from tenetxyz/support-composite-keys-keyswithv…
Browse files Browse the repository at this point in the history
…alue

feat(world): add support for composite keys in KeysWithValue module
  • Loading branch information
dhvanipa authored Jul 23, 2023
2 parents 9e0f64a + 1a167df commit 29a9dba
Show file tree
Hide file tree
Showing 11 changed files with 986 additions and 105 deletions.
4 changes: 2 additions & 2 deletions docs/pages/world/modules.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ Internally, it works by installing a [hook](https://v2.mud.dev/store/advanced-fe

#### **`KeysWithValueModule`**

The [`KeysWithValueModule`](https://github.com/latticexyz/mud/blob/main/packages/world/src/modules/keyswithvalue/KeysWithValueModule.sol) allows for querying for all keys that have a given value in a table. It can only be used on tables that have one key (eg: ECS components).
The [`KeysWithValueModule`](https://github.com/latticexyz/mud/blob/main/packages/world/src/modules/keyswithvalue/KeysWithValueModule.sol) allows for querying for all keys that have a given value in a table. It can only be used on tables that have up to five keys.

As an example, this can be used to ask for all NFTs owned by a specific user (if the table models `NFT ID → Owner`), or all units on a specific position (if the ECS component is modeled as `entity → Position`)

Expand All @@ -106,7 +106,7 @@ Using `getKeysWithValue` to retrieve all NFTs owned by a specific address:
import { getKeysWithValue } from "@latticexyz/world/src/modules/keyswithvalue/getKeysWithValue.sol";
import { Owners, OwnersId } from "../codegen/tables/Owners.sol";
// get all nfts (as bytes, need to convert to uint256) owned by address 0x42
bytes32[] memory keysWithValue = getKeysWithValue(world, OwnersId, Owners.encode(address(42)));
bytes32[][] memory keysWithValue = getKeysWithValue(world, OwnersId, Owners.encode(address(42)));
```

Internally, it works by installing a [hook](/store/advanced-features#storage-hooks) that maintains an array of all keys in the table.
Expand Down
2 changes: 1 addition & 1 deletion examples/minimal/packages/contracts/test/CounterTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ contract CounterTest is MudTest {
function testKeysWithValue() public {
bytes32 key = SingletonKey;
uint32 counter = CounterTable.get(key);
bytes32[] memory keysWithValue = getKeysWithValue(CounterTableTableId, CounterTable.encode(counter));
bytes32[][] memory keysWithValue = getKeysWithValue(CounterTableTableId, CounterTable.encode(counter));
assertEq(keysWithValue.length, 1);
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
[
{
"inputs": [
{
"internalType": "uint256",
"name": "length",
"type": "uint256"
}
],
"name": "SchemaLib_InvalidLength",
"type": "error"
},
{
"inputs": [],
"name": "SchemaLib_StaticTypeAfterDynamicType",
"type": "error"
},
{
"inputs": [
{
Expand Down
6 changes: 5 additions & 1 deletion packages/world/mud.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ export default mudConfig({
valueHash: "bytes32",
},
schema: {
keysWithValue: "bytes32[]", // For now only supports 1 key per value
keys0: "bytes32[]",
keys1: "bytes32[]",
keys2: "bytes32[]",
keys3: "bytes32[]",
keys4: "bytes32[]",
},
tableIdArgument: true,
},
Expand Down
2 changes: 1 addition & 1 deletion packages/world/src/Tables.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { SystemRegistry, SystemRegistryTableId } from "./modules/core/tables/Sys
import { SystemHooks, SystemHooksTableId } from "./modules/core/tables/SystemHooks.sol";
import { ResourceType, ResourceTypeTableId } from "./modules/core/tables/ResourceType.sol";
import { FunctionSelectors, FunctionSelectorsTableId } from "./modules/core/tables/FunctionSelectors.sol";
import { KeysWithValue } from "./modules/keyswithvalue/tables/KeysWithValue.sol";
import { KeysWithValue, KeysWithValueData } from "./modules/keyswithvalue/tables/KeysWithValue.sol";
import { KeysInTable, KeysInTableData, KeysInTableTableId } from "./modules/keysintable/tables/KeysInTable.sol";
import { UsedKeysIndex, UsedKeysIndexTableId } from "./modules/keysintable/tables/UsedKeysIndex.sol";
import { UniqueEntity } from "./modules/uniqueentity/tables/UniqueEntity.sol";
Expand Down
20 changes: 6 additions & 14 deletions packages/world/src/modules/keysintable/query.sol
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,6 @@ struct QueryFragment {
bytes value;
}

function valuesToTuples(bytes32[] memory flatArr) pure returns (bytes32[][] memory arr) {
arr = new bytes32[][](flatArr.length);
for (uint256 i; i < flatArr.length; i++) {
arr[i] = new bytes32[](1);
arr[i][0] = flatArr[i];
}
}

/**
* Helper function to check whether a given key passes a given query fragment.
*
Expand All @@ -45,7 +37,7 @@ function passesQueryFragment(bytes32[] memory keyTuple, QueryFragment memory fra

if (fragment.queryType == QueryType.HasValue) {
// Key must have the given value
return ArrayLib.includes(valuesToTuples(getKeysWithValue(fragment.tableId, fragment.value)), keyTuple);
return ArrayLib.includes(getKeysWithValue(fragment.tableId, fragment.value), keyTuple);
}

if (fragment.queryType == QueryType.Not) {
Expand All @@ -55,7 +47,7 @@ function passesQueryFragment(bytes32[] memory keyTuple, QueryFragment memory fra

if (fragment.queryType == QueryType.NotValue) {
// Key must not have the given value
return !ArrayLib.includes(valuesToTuples(getKeysWithValue(fragment.tableId, fragment.value)), keyTuple);
return !ArrayLib.includes(getKeysWithValue(fragment.tableId, fragment.value), keyTuple);
}

return false;
Expand All @@ -78,7 +70,7 @@ function passesQueryFragment(

if (fragment.queryType == QueryType.HasValue) {
// Key must be have the given value
return ArrayLib.includes(valuesToTuples(getKeysWithValue(store, fragment.tableId, fragment.value)), keyTuple);
return ArrayLib.includes(getKeysWithValue(store, fragment.tableId, fragment.value), keyTuple);
}

if (fragment.queryType == QueryType.Not) {
Expand All @@ -88,7 +80,7 @@ function passesQueryFragment(

if (fragment.queryType == QueryType.NotValue) {
// Key must not have the given value
return !ArrayLib.includes(valuesToTuples(getKeysWithValue(store, fragment.tableId, fragment.value)), keyTuple);
return !ArrayLib.includes(getKeysWithValue(store, fragment.tableId, fragment.value), keyTuple);
}

return false;
Expand All @@ -108,7 +100,7 @@ function query(QueryFragment[] memory fragments) view returns (bytes32[][] memor
// Create the first interim result
keyTuples = fragments[0].queryType == QueryType.Has
? getKeysInTable(fragments[0].tableId)
: valuesToTuples(getKeysWithValue(fragments[0].tableId, fragments[0].value));
: getKeysWithValue(fragments[0].tableId, fragments[0].value);

for (uint256 i = 1; i < fragments.length; i++) {
bytes32[][] memory result = new bytes32[][](0);
Expand Down Expand Up @@ -143,7 +135,7 @@ function query(IStore store, QueryFragment[] memory fragments) view returns (byt
// Create the first interim result
keyTuples = fragments[0].queryType == QueryType.Has
? getKeysInTable(store, fragments[0].tableId)
: valuesToTuples(getKeysWithValue(store, fragments[0].tableId, fragments[0].value));
: getKeysWithValue(store, fragments[0].tableId, fragments[0].value);

for (uint256 i = 1; i < fragments.length; i++) {
bytes32[][] memory result = new bytes32[][](0);
Expand Down
55 changes: 44 additions & 11 deletions packages/world/src/modules/keyswithvalue/KeysWithValueHook.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { IBaseWorld } from "../../interfaces/IBaseWorld.sol";
import { ResourceSelector } from "../../ResourceSelector.sol";

import { MODULE_NAMESPACE } from "./constants.sol";
import { KeysWithValue } from "./tables/KeysWithValue.sol";
import { KeysWithValue, KeysWithValueData } from "./tables/KeysWithValue.sol";
import { ArrayLib } from "../utils/ArrayLib.sol";
import { getTargetTableSelector } from "../utils/getTargetTableSelector.sol";

Expand All @@ -19,7 +19,7 @@ import { getTargetTableSelector } from "../utils/getTargetTableSelector.sol";
* and then replicate logic from solecs's Set.sol.
* (See https://github.com/latticexyz/mud/issues/444)
*
* Note: if a table with composite keys is used, only the first key is indexed
* Note: this module only supports up to 5 composite keys.
*/
contract KeysWithValueHook is IStoreHook {
using ArrayLib for bytes32[];
Expand All @@ -29,45 +29,78 @@ contract KeysWithValueHook is IStoreHook {
return IBaseWorld(StoreSwitch.getStoreAddress());
}

function handleSet(bytes32 tableId, bytes32 valueHash, bytes32[] memory key) internal {
if (key.length > 0) {
KeysWithValue.pushKeys0(tableId, valueHash, key[0]);
if (key.length > 1) {
KeysWithValue.pushKeys1(tableId, valueHash, key[1]);
if (key.length > 2) {
KeysWithValue.pushKeys2(tableId, valueHash, key[2]);
if (key.length > 3) {
KeysWithValue.pushKeys3(tableId, valueHash, key[3]);
if (key.length > 4) {
KeysWithValue.pushKeys4(tableId, valueHash, key[4]);
}
}
}
}
}
}

function onSetRecord(bytes32 sourceTableId, bytes32[] memory key, bytes memory data) public {
bytes32 targetTableId = getTargetTableSelector(MODULE_NAMESPACE, sourceTableId);

// Get the previous value
bytes32 previousValue = keccak256(_world().getRecord(sourceTableId, key));

// Remove the key from the list of keys with the previous value
_removeKeyFromList(targetTableId, key[0], previousValue);
_removeKeyFromList(targetTableId, key, previousValue);

// Push the key to the list of keys with the new value
KeysWithValue.push(targetTableId, keccak256(data), key[0]);
handleSet(targetTableId, keccak256(data), key);
}

function onBeforeSetField(bytes32 sourceTableId, bytes32[] memory key, uint8, bytes memory) public {
// Remove the key from the list of keys with the previous value
bytes32 previousValue = keccak256(_world().getRecord(sourceTableId, key));
bytes32 targetTableId = getTargetTableSelector(MODULE_NAMESPACE, sourceTableId);
_removeKeyFromList(targetTableId, key[0], previousValue);
_removeKeyFromList(targetTableId, key, previousValue);
}

function onAfterSetField(bytes32 sourceTableId, bytes32[] memory key, uint8, bytes memory) public {
// Add the key to the list of keys with the new value
bytes32 newValue = keccak256(_world().getRecord(sourceTableId, key));
bytes32 targetTableId = getTargetTableSelector(MODULE_NAMESPACE, sourceTableId);
KeysWithValue.push(targetTableId, newValue, key[0]);
handleSet(targetTableId, newValue, key);
}

function onDeleteRecord(bytes32 sourceTableId, bytes32[] memory key) public {
// Remove the key from the list of keys with the previous value
bytes32 previousValue = keccak256(_world().getRecord(sourceTableId, key));
bytes32 targetTableId = getTargetTableSelector(MODULE_NAMESPACE, sourceTableId);
_removeKeyFromList(targetTableId, key[0], previousValue);
_removeKeyFromList(targetTableId, key, previousValue);
}

function _removeKeyFromList(bytes32 targetTableId, bytes32 key, bytes32 valueHash) internal {
// Get the keys with the previous value excluding the current key
bytes32[] memory keysWithPreviousValue = KeysWithValue.get(targetTableId, valueHash).filter(key);
function _removeKeyFromList(bytes32 targetTableId, bytes32[] memory key, bytes32 valueHash) internal {
// Get the keys with the previous value
KeysWithValueData memory keysWithPreviousValue = KeysWithValue.get(targetTableId, valueHash);
if (key.length > 0) {
keysWithPreviousValue.keys0 = keysWithPreviousValue.keys0.filter(key[0]);
if (key.length > 1) {
keysWithPreviousValue.keys1 = keysWithPreviousValue.keys1.filter(key[1]);
if (key.length > 2) {
keysWithPreviousValue.keys2 = keysWithPreviousValue.keys2.filter(key[2]);
if (key.length > 3) {
keysWithPreviousValue.keys3 = keysWithPreviousValue.keys3.filter(key[3]);
if (key.length > 4) {
keysWithPreviousValue.keys4 = keysWithPreviousValue.keys4.filter(key[4]);
}
}
}
}
}

if (keysWithPreviousValue.length == 0) {
if (keysWithPreviousValue.keys0.length == 0) {
// Delete the list of keys in this table
KeysWithValue.deleteRecord(targetTableId, valueHash);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { getTargetTableSelector } from "../utils/getTargetTableSelector.sol";
* from value to list of keys with this value. This mapping is stored in a table registered
* by the module at the `targetTableId` provided in the install methods arguments.
*
* Note: if a table with composite keys is used, only the first key is indexed
* Note: this module only supports up to 5 composite keys.
*
* Note: this module currently expects to be `delegatecalled` via World.installRootModule.
* Support for installing it via `World.installModule` depends on `World.callFrom` being implemented.
Expand Down
60 changes: 54 additions & 6 deletions packages/world/src/modules/keyswithvalue/getKeysWithValue.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
pragma solidity >=0.8.0;

import { IStore } from "@latticexyz/store/src/IStore.sol";
import { Schema } from "@latticexyz/store/src/Schema.sol";
import { StoreSwitch } from "@latticexyz/store/src/StoreSwitch.sol";

import { MODULE_NAMESPACE } from "./constants.sol";
import { KeysWithValue } from "./tables/KeysWithValue.sol";
Expand All @@ -13,12 +15,35 @@ import { getTargetTableSelector } from "../utils/getTargetTableSelector.sol";
* Note: this util can only be called within the context of a Store (e.g. from a System or Module).
* For usage outside of a Store, use the overload that takes an explicit store argument.
*/
function getKeysWithValue(bytes32 tableId, bytes memory value) view returns (bytes32[] memory keysWithValue) {
function getKeysWithValue(bytes32 tableId, bytes memory value) view returns (bytes32[][] memory keyTuples) {
// Get the corresponding reverse mapping table
bytes32 keysWithValueTableId = getTargetTableSelector(MODULE_NAMESPACE, tableId);
bytes32 valueHash = keccak256(value);

// Get the keys with the given value
keysWithValue = KeysWithValue.get(keysWithValueTableId, keccak256(value));
Schema schema = StoreSwitch.getKeySchema(tableId);
uint256 numFields = schema.numFields();
uint256 length = KeysWithValue.lengthKeys0(keysWithValueTableId, valueHash);
keyTuples = new bytes32[][](length);

for (uint256 i; i < length; i++) {
keyTuples[i] = new bytes32[](numFields); // the length of the key tuple depends on the schema

if (numFields > 0) {
keyTuples[i][0] = KeysWithValue.getItemKeys0(keysWithValueTableId, valueHash, i);
if (numFields > 1) {
keyTuples[i][1] = KeysWithValue.getItemKeys1(keysWithValueTableId, valueHash, i);
if (numFields > 2) {
keyTuples[i][2] = KeysWithValue.getItemKeys2(keysWithValueTableId, valueHash, i);
if (numFields > 3) {
keyTuples[i][3] = KeysWithValue.getItemKeys3(keysWithValueTableId, valueHash, i);
if (numFields > 4) {
keyTuples[i][4] = KeysWithValue.getItemKeys4(keysWithValueTableId, valueHash, i);
}
}
}
}
}
}
}

/**
Expand All @@ -28,10 +53,33 @@ function getKeysWithValue(
IStore store,
bytes32 tableId,
bytes memory value
) view returns (bytes32[] memory keysWithValue) {
) view returns (bytes32[][] memory keyTuples) {
// Get the corresponding reverse mapping table
bytes32 keysWithValueTableId = getTargetTableSelector(MODULE_NAMESPACE, tableId);
bytes32 valueHash = keccak256(value);

Schema schema = store.getKeySchema(tableId);
uint256 numFields = schema.numFields();
uint256 length = KeysWithValue.lengthKeys0(store, keysWithValueTableId, valueHash);
keyTuples = new bytes32[][](length);

for (uint256 i; i < length; i++) {
keyTuples[i] = new bytes32[](numFields); // the length of the key tuple depends on the schema

// Get the keys with the given value
keysWithValue = KeysWithValue.get(store, keysWithValueTableId, keccak256(value));
if (numFields > 0) {
keyTuples[i][0] = KeysWithValue.getItemKeys0(store, keysWithValueTableId, valueHash, i);
if (numFields > 1) {
keyTuples[i][1] = KeysWithValue.getItemKeys1(store, keysWithValueTableId, valueHash, i);
if (numFields > 2) {
keyTuples[i][2] = KeysWithValue.getItemKeys2(store, keysWithValueTableId, valueHash, i);
if (numFields > 3) {
keyTuples[i][3] = KeysWithValue.getItemKeys3(store, keysWithValueTableId, valueHash, i);
if (numFields > 4) {
keyTuples[i][4] = KeysWithValue.getItemKeys4(store, keysWithValueTableId, valueHash, i);
}
}
}
}
}
}
}
Loading

0 comments on commit 29a9dba

Please sign in to comment.