Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sort in memory to Arrays library #4846

Merged
merged 30 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d78ae17
Migrate 'arrays'
RenanSouza2 Nov 12, 2023
52b3bd8
fix findUpperBound and add findLowerBound
Amxx Nov 17, 2023
9a1411b
add memory variants
Amxx Nov 20, 2023
c4726c8
Merge branch 'master' into feature/array-bound-with-duplicates
Amxx Jan 18, 2024
a6ec616
fix merge
Amxx Jan 18, 2024
23e8db9
fix lint
Amxx Jan 18, 2024
c72591f
minimize change
Amxx Jan 18, 2024
9162e42
add changeset
Amxx Jan 18, 2024
ed1de5b
Apply suggestions from code review
Amxx Jan 18, 2024
4c1c7f4
add Arrays.sort
Amxx Jan 18, 2024
abd07b6
add sort test
Amxx Jan 18, 2024
1e89815
fix lint
Amxx Jan 18, 2024
b73989d
codespell
Amxx Jan 19, 2024
79bf367
add fuzzing tests for Arrays.sort
Amxx Jan 19, 2024
f2d49ef
add unsafeMemoryAccess tests
Amxx Jan 22, 2024
c043453
Merge branch 'master' into feature/quicksort
Amxx Jan 29, 2024
c75de32
fix lint
Amxx Jan 29, 2024
f823bee
lint
Amxx Jan 30, 2024
c90f12b
Update contracts/utils/Arrays.sol
Amxx Feb 2, 2024
180a969
Update contracts/utils/Arrays.sol
Amxx Feb 2, 2024
708972f
Apply suggestions from code review
Amxx Feb 2, 2024
533e6cd
Merge branch 'master' into feature/quicksort
Amxx Feb 2, 2024
3f1f0a5
Update contracts/utils/Arrays.sol
ernestognw Feb 2, 2024
5a0ad7f
Add comments to `_quickSort`
ernestognw Feb 3, 2024
a8e6f54
Lint
ernestognw Feb 3, 2024
7600291
Update contracts/utils/Arrays.sol
Amxx Feb 5, 2024
6f163d2
Update contracts/utils/Arrays.sol
Amxx Feb 5, 2024
8983066
cache the pivot and improve doc
Amxx Feb 5, 2024
8704763
Apply suggestions from code review
Amxx Feb 5, 2024
4336e2e
Merge branch 'master' into feature/quicksort
Amxx Feb 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/dirty-cobras-smile.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---

`Arrays`: add a `sort` function.
75 changes: 73 additions & 2 deletions contracts/utils/Arrays.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,66 @@ import {Math} from "./math/Math.sol";
library Arrays {
using StorageSlot for bytes32;

/**
* @dev Sort an array (in memory) in increasing order.
*
* This function does the sorting "in place", meaning that it overrides the input. The object is returned for
* convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array.
*
* NOTE: this function's cost is `O(n · log(n))` in average and `O(n²)` in the worst case, with n the length of the
* array. Using it in view functions that are executed through `eth_call` is safe, but one should be very careful
* when executing this as part of a transaction. If the array being sorted is too large, the sort operation may
* consume more gas than is available in a block, leading to potential DoS.
*/
function sort(uint256[] memory array) internal pure returns (uint256[] memory) {
_quickSort(array, 0, array.length);
return array;
}

/**
* @dev Performs a quick sort on an array in memory. The array is sorted in increasing order.
* This private implementation assumes that `i <= j` and that `j < array.length`.
Amxx marked this conversation as resolved.
Show resolved Hide resolved
*/
function _quickSort(uint256[] memory array, uint256 i, uint256 j) private pure {
unchecked {
// Can't overflow given `i <= j`
if (j - i < 2) return;

// Use first element as pivot
// i = pivot index
uint256 index = i;

for (uint256 k = i + 1; k < j; ++k) {
// pivot > array[k]
// Unsafe access is safe given `j < array.length` and `k < j`.
if (unsafeMemoryAccess(array, i) > unsafeMemoryAccess(array, k)) {
_swap(array, ++index, k);
}
}

// Swap pivot into place
_swap(array, i, index);

_quickSort(array, i, index); // Sort the left side of the pivot
_quickSort(array, index + 1, j); // Sort the right side of the pivot
}
}

/**
* @dev Swaps the elements at positions `i` and `j` in the `arr` array.
*/
function _swap(uint256[] memory arr, uint256 i, uint256 j) private pure {
assembly {
let start := add(arr, 0x20) // Pointer to the first element of the array
let pos_i := add(start, mul(i, 0x20))
let pos_j := add(start, mul(j, 0x20))
let val_i := mload(pos_i)
let val_j := mload(pos_j)
mstore(pos_i, val_j)
mstore(pos_j, val_i)
}
}

/**
* @dev Searches a sorted `array` and returns the first index that contains
* a value greater or equal to `element`. If no such index exists (i.e. all
Expand Down Expand Up @@ -238,7 +298,7 @@ library Arrays {
*
* WARNING: Only use if you are certain `pos` is lower than the array length.
*/
function unsafeMemoryAccess(uint256[] memory arr, uint256 pos) internal pure returns (uint256 res) {
function unsafeMemoryAccess(address[] memory arr, uint256 pos) internal pure returns (address res) {
assembly {
res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
}
Expand All @@ -249,7 +309,18 @@ library Arrays {
*
* WARNING: Only use if you are certain `pos` is lower than the array length.
*/
function unsafeMemoryAccess(address[] memory arr, uint256 pos) internal pure returns (address res) {
function unsafeMemoryAccess(bytes32[] memory arr, uint256 pos) internal pure returns (bytes32 res) {
assembly {
res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
}
}

/**
* @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
*
* WARNING: Only use if you are certain `pos` is lower than the array length.
*/
function unsafeMemoryAccess(uint256[] memory arr, uint256 pos) internal pure returns (uint256 res) {
assembly {
res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
}
Expand Down
4 changes: 2 additions & 2 deletions scripts/generate/templates/Checkpoints.t.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ const header = `\
pragma solidity ^0.8.20;

import {Test} from "forge-std/Test.sol";
import {SafeCast} from "../../../contracts/utils/math/SafeCast.sol";
import {Checkpoints} from "../../../contracts/utils/structs/Checkpoints.sol";
import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol";
import {Checkpoints} from "@openzeppelin/contracts/utils/structs/Checkpoints.sol";
`;

/* eslint-disable max-len */
Expand Down
15 changes: 15 additions & 0 deletions test/utils/Arrays.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// SPDX-License-Identifier: MIT

pragma solidity ^0.8.20;

import {Test} from "forge-std/Test.sol";
import {Arrays} from "@openzeppelin/contracts/utils/Arrays.sol";

contract ArraysTest is Test {
function testSort(uint256[] memory values) public {
Arrays.sort(values);
for (uint256 i = 1; i < values.length; ++i) {
assertLe(values[i - 1], values[i]);
}
}
}
105 changes: 84 additions & 21 deletions test/utils/Arrays.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,56 @@ const upperBound = (array, value) => {
return i == -1 ? array.length : i;
};

// By default, js "sort" cast to string and then sort in alphabetical order. Use this to sort numbers.
const compareNumbers = (a, b) => (a > b ? 1 : a < b ? -1 : 0);

const hasDuplicates = array => array.some((v, i) => array.indexOf(v) != i);

describe('Arrays', function () {
const fixture = async () => {
return { mock: await ethers.deployContract('$Arrays') };
};

beforeEach(async function () {
Object.assign(this, await loadFixture(fixture));
});

describe('sort', function () {
for (const length of [0, 1, 2, 8, 32, 128]) {
it(`sort array of length ${length}`, async function () {
this.elements = randomArray(generators.uint256, length);
this.expected = Array.from(this.elements).sort(compareNumbers);
});

if (length > 1) {
it(`sort array of length ${length} (identical elements)`, async function () {
this.elements = Array(length).fill(generators.uint256());
this.expected = this.elements;
});

it(`sort array of length ${length} (already sorted)`, async function () {
this.elements = randomArray(generators.uint256, length).sort(compareNumbers);
this.expected = this.elements;
});

it(`sort array of length ${length} (sorted in reverse order)`, async function () {
this.elements = randomArray(generators.uint256, length).sort(compareNumbers).reverse();
this.expected = Array.from(this.elements).reverse();
});

it(`sort array of length ${length} (almost sorted)`, async function () {
this.elements = randomArray(generators.uint256, length).sort(compareNumbers);
this.expected = Array.from(this.elements);
// rotate (move the last element to the front) for an almost sorted effect
this.elements.unshift(this.elements.pop());
});
}
}
afterEach(async function () {
expect(await this.mock.$sort(this.elements)).to.deep.equal(this.expected);
});
});

describe('search', function () {
for (const [title, { array, tests }] of Object.entries({
'Even number of elements': {
Expand Down Expand Up @@ -74,7 +121,7 @@ describe('Arrays', function () {
})) {
describe(title, function () {
const fixture = async () => {
return { mock: await ethers.deployContract('Uint256ArraysMock', [array]) };
return { instance: await ethers.deployContract('Uint256ArraysMock', [array]) };
};

beforeEach(async function () {
Expand All @@ -86,20 +133,20 @@ describe('Arrays', function () {
it('[deprecated] findUpperBound', async function () {
// findUpperBound does not support duplicated
if (hasDuplicates(array)) {
expect(await this.mock.findUpperBound(input)).to.be.equal(upperBound(array, input) - 1);
expect(await this.instance.findUpperBound(input)).to.equal(upperBound(array, input) - 1);
} else {
expect(await this.mock.findUpperBound(input)).to.be.equal(lowerBound(array, input));
expect(await this.instance.findUpperBound(input)).to.equal(lowerBound(array, input));
}
});

it('lowerBound', async function () {
expect(await this.mock.lowerBound(input)).to.be.equal(lowerBound(array, input));
expect(await this.mock.lowerBoundMemory(array, input)).to.be.equal(lowerBound(array, input));
expect(await this.instance.lowerBound(input)).to.equal(lowerBound(array, input));
expect(await this.instance.lowerBoundMemory(array, input)).to.equal(lowerBound(array, input));
});

it('upperBound', async function () {
expect(await this.mock.upperBound(input)).to.be.equal(upperBound(array, input));
expect(await this.mock.upperBoundMemory(array, input)).to.be.equal(upperBound(array, input));
expect(await this.instance.upperBound(input)).to.equal(upperBound(array, input));
expect(await this.instance.upperBoundMemory(array, input)).to.equal(upperBound(array, input));
});
});
}
Expand All @@ -108,28 +155,44 @@ describe('Arrays', function () {
});

describe('unsafeAccess', function () {
for (const [title, { artifact, elements }] of Object.entries({
for (const [type, { artifact, elements }] of Object.entries({
address: { artifact: 'AddressArraysMock', elements: randomArray(generators.address, 10) },
bytes32: { artifact: 'Bytes32ArraysMock', elements: randomArray(generators.bytes32, 10) },
uint256: { artifact: 'Uint256ArraysMock', elements: randomArray(generators.uint256, 10) },
})) {
describe(title, function () {
const fixture = async () => {
return { mock: await ethers.deployContract(artifact, [elements]) };
};
describe(type, function () {
describe('storage', function () {
const fixture = async () => {
return { instance: await ethers.deployContract(artifact, [elements]) };
};

beforeEach(async function () {
Object.assign(this, await loadFixture(fixture));
});
beforeEach(async function () {
Object.assign(this, await loadFixture(fixture));
});

for (const i in elements) {
it(`unsafeAccess within bounds #${i}`, async function () {
expect(await this.mock.unsafeAccess(i)).to.equal(elements[i]);
for (const i in elements) {
it(`unsafeAccess within bounds #${i}`, async function () {
expect(await this.instance.unsafeAccess(i)).to.equal(elements[i]);
});
}

it('unsafeAccess outside bounds', async function () {
await expect(this.instance.unsafeAccess(elements.length)).to.not.be.rejected;
});
}
});

describe('memory', function () {
const fragment = `$unsafeMemoryAccess(${type}[] arr, uint256 pos)`;

it('unsafeAccess outside bounds', async function () {
await expect(this.mock.unsafeAccess(elements.length)).to.not.be.rejected;
for (const i in elements) {
it(`unsafeMemoryAccess within bounds #${i}`, async function () {
expect(await this.mock[fragment](elements, i)).to.equal(elements[i]);
});
}

it('unsafeMemoryAccess outside bounds', async function () {
await expect(this.mock[fragment](elements, elements.length)).to.not.be.rejected;
});
});
});
}
Expand Down
1 change: 0 additions & 1 deletion test/utils/Base64.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
pragma solidity ^0.8.20;

import {Test} from "forge-std/Test.sol";

import {Base64} from "@openzeppelin/contracts/utils/Base64.sol";

contract Base64Test is Test {
Expand Down
4 changes: 2 additions & 2 deletions test/utils/structs/Checkpoints.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
pragma solidity ^0.8.20;

import {Test} from "forge-std/Test.sol";
import {SafeCast} from "../../../contracts/utils/math/SafeCast.sol";
import {Checkpoints} from "../../../contracts/utils/structs/Checkpoints.sol";
import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol";
import {Checkpoints} from "@openzeppelin/contracts/utils/structs/Checkpoints.sol";

contract CheckpointsTrace224Test is Test {
using Checkpoints for Checkpoints.Trace224;
Expand Down
Loading