Skip to content

Commit

Permalink
fix(sha256): Perform compression per block and utilize ROM instead of…
Browse files Browse the repository at this point in the history
… RAM when setting up the message block (#5760)

# Description

## Problem\*

Resolves #5761 

Resolution to performance blow-up found with sha256_var.

## Summary\*

### Issue

The crux of the blow-up was the result of calling `sha256_compression`
inside of the same loop where we build the message block. In the current
`sha256_var` algorithm we are looping over the entire message and
conditionally checking a msg byte pointer (the pointer into the msg
block) to determine whether we have filled up a msg block and should run
the sha compression. However, in a circuit this leads to us calling the
compression opcode `N` times where `N` is the size of the message.

We were also utilizing RAM to build our message block when we do not have
to do so. We can instead construct our block outside of the circuit and
verify that the block has been constructed as we expect with assertions
that just require ROM.

### Improvements

This PR produces a ~16x improvement in ACIR opcodes and a >13x improvement
in backend constraints for the following circuit:
```rust
fn main(foo: [u8; 95], toggle: bool) {
    let size: Field = 93 + toggle as Field * 2;
    let hash = std::sha256::sha256_var(foo, size as u64);
    println(f"{hash}");
}  
```
#### master
nargo info:
```
+---------+----------------------------+----------------------+--------------+-----------------+
| Package | Function                   | Expression Width     | ACIR Opcodes | Brillig Opcodes |
+---------+----------------------------+----------------------+--------------+-----------------+
| sha256  | main                       | Bounded { width: 4 } | 125852       | 243             |
+---------+----------------------------+----------------------+--------------+-----------------+
| sha256  | print_unconstrained        | N/A                  | N/A          | 230             |
+---------+----------------------------+----------------------+--------------+-----------------+
| sha256  | directive_integer_quotient | N/A                  | N/A          | 6               |
+---------+----------------------------+----------------------+--------------+-----------------+
| sha256  | directive_invert           | N/A                  | N/A          | 7               |
+---------+----------------------------+----------------------+--------------+-----------------+
```
bb gates:
```
{"functions": [
  {
        "acir_opcodes": 125852,
        "circuit_size": 597646,
```

#### This PR

Output of nargo info:
```
+----------------------------+----------------------------+----------------------+--------------+-----------------+
| Package                    | Function                   | Expression Width     | ACIR Opcodes | Brillig Opcodes |
+----------------------------+----------------------------+----------------------+--------------+-----------------+
| sha256_var_size_regression | main                       | Bounded { width: 4 } | 7768         | 1041            |
+----------------------------+----------------------------+----------------------+--------------+-----------------+
| sha256_var_size_regression | build_msg_block_iter       | N/A                  | N/A          | 299             |
+----------------------------+----------------------------+----------------------+--------------+-----------------+
| sha256_var_size_regression | pad_msg_block              | N/A                  | N/A          | 201             |
+----------------------------+----------------------------+----------------------+--------------+-----------------+
| sha256_var_size_regression | attach_len_to_msg_block    | N/A                  | N/A          | 298             |
+----------------------------+----------------------------+----------------------+--------------+-----------------+
| sha256_var_size_regression | print_unconstrained        | N/A                  | N/A          | 230             |
+----------------------------+----------------------------+----------------------+--------------+-----------------+
| sha256_var_size_regression | directive_integer_quotient | N/A                  | N/A          | 6               |
+----------------------------+----------------------------+----------------------+--------------+-----------------+
| sha256_var_size_regression | directive_invert           | N/A                  | N/A          | 7               |
+----------------------------+----------------------------+----------------------+--------------+-----------------+
```
bb gates output:
```
{"functions": [
  {
        "acir_opcodes": 7768,
        "circuit_size": 44663,
```

## Additional Context



## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [ ] I have tested the changes locally.
- [ ] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
vezenovm authored Aug 27, 2024
1 parent 6145877 commit c52dc1c
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 47 deletions.
237 changes: 190 additions & 47 deletions noir_stdlib/src/hash/sha256.nr
Original file line number Diff line number Diff line change
Expand Up @@ -17,82 +17,224 @@ pub fn digest<let N: u32>(msg: [u8; N]) -> [u8; 32] {
sha256_var(msg, N as u64)
}

// Pack a 64-byte message block into sixteen 32-bit words.
// Word `w` is assembled from bytes `4*w .. 4*w + 3`, most significant byte first.
fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] {
    let mut words: [u32; 16] = [0; 16];

    for word_idx in 0..16 {
        // Index of the first (most significant) byte of this word.
        let base = 4 * word_idx;
        let mut acc: Field = 0;
        for offset in 0..4 {
            acc = acc * 256 + msg[base + offset] as Field;
        }
        words[word_idx] = acc as u32;
    }

    words
}

// Unconstrained helper that copies message bytes into a 64-byte block.
//
// Starting at `msg[msg_start]`, every byte whose index is below `message_size` is written to
// `msg_block[msg_byte_ptr]`, and the pointer is reset to 0 each time it reaches 64.
// Bytes at or beyond `message_size` are skipped without moving the pointer.
// Returns the resulting block together with the final pointer value.
//
// NOTE(review): the loop runs up to `N`, not up to `msg_start + 64`. Whenever more than 64
// message bytes remain after `msg_start`, the later bytes overwrite slots written earlier in
// the same call, so the returned block is not simply the 64 bytes starting at `msg_start`.
// Confirm whether the caller relies on this carry-over before changing the bound.
unconstrained fn build_msg_block_iter<let N: u32>(
msg: [u8; N],
message_size: u64,
mut msg_block: [u8; 64],
msg_start: u32
) -> ([u8; 64], u64) {
let mut msg_byte_ptr: u64 = 0; // Message byte pointer
for k in msg_start..N {
if k as u64 < message_size {
msg_block[msg_byte_ptr] = msg[k];
msg_byte_ptr = msg_byte_ptr + 1;

if msg_byte_ptr == 64 {
// Block is full: wrap so the next byte lands in slot 0 again
msg_byte_ptr = 0;
}
}
}
(msg_block, msg_byte_ptr)
}

// Verify the block we are compressing was appropriately constructed.
//
// Walks `msg` from `msg_start` to `N` with a pointer into `msg_block` that wraps to 0 at 64.
// For every index below `message_size` it asserts that the block slot under the pointer
// equals the message byte and advances the pointer. For indices at or beyond `message_size`
// the pointer is left where it is and the slot under it is asserted to be 0 (when `N < 64`)
// or equal to `msg[k]` (otherwise).
// Returns the final pointer value.
//
// NOTE(review): because the loop runs up to `N` rather than `msg_start + 64`, a slot can be
// asserted against more than one message byte (e.g. slot 0 against both `msg[msg_start]` and
// `msg[msg_start + 64]`), which can only hold when those bytes are equal. Likewise, for
// `N >= 64` the slot at the stalled pointer is compared with every trailing `msg[k]`.
// Confirm this is intended for messages with non-uniform contents.
fn verify_msg_block<let N: u32>(
msg: [u8; N],
message_size: u64,
msg_block: [u8; 64],
msg_start: u32
) -> u64 {
let mut msg_byte_ptr: u64 = 0; // Message byte pointer
for k in msg_start..N {
if k as u64 < message_size {
assert_eq(msg_block[msg_byte_ptr], msg[k]);
msg_byte_ptr = msg_byte_ptr + 1;
if msg_byte_ptr == 64 {
// Enough to hash block
msg_byte_ptr = 0;
}
} else {
// Need to assert over the msg block in the else case as well
if N < 64 {
assert_eq(msg_block[msg_byte_ptr], 0);
} else {
assert_eq(msg_block[msg_byte_ptr], msg[k]);
}
}
}
msg_byte_ptr
}

global BLOCK_SIZE = 64;

// Variable size SHA-256 hash
pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> [u8; 32] {
let mut msg_block: [u8; 64] = [0; 64];
let num_blocks = N / BLOCK_SIZE;
let mut msg_block: [u8; BLOCK_SIZE] = [0; BLOCK_SIZE];
let mut h: [u32; 8] = [1779033703, 3144134277, 1013904242, 2773480762, 1359893119, 2600822924, 528734635, 1541459225]; // Intermediate hash, starting with the canonical initial value
let mut i: u64 = 0; // Message byte pointer
for k in 0..N {
if k as u64 < message_size {
// Populate msg_block
msg_block[i] = msg[k];
i = i + 1;
if i == 64 {
// Enough to hash block
h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h);
let mut msg_byte_ptr = 0; // Pointer into msg_block

i = 0;
}
if num_blocks == 0 {
unsafe {
let (new_msg_block, new_msg_byte_ptr) = build_msg_block_iter(msg, message_size, msg_block, 0);
msg_block = new_msg_block;
msg_byte_ptr = new_msg_byte_ptr;
}

if !crate::runtime::is_unconstrained() {
msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, 0);
}
}

for i in 0..num_blocks {
unsafe {
let (new_msg_block, new_msg_byte_ptr) = build_msg_block_iter(msg, message_size, msg_block, BLOCK_SIZE * i);
msg_block = new_msg_block;
msg_byte_ptr = new_msg_byte_ptr;
}
if !crate::runtime::is_unconstrained() {
// Verify the block we are compressing was appropriately constructed
msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, BLOCK_SIZE * i);
}

// Hash the block
h = sha256_compression(msg_u8_to_u32(msg_block), h);
}

let last_block = msg_block;
// Pad the rest such that we have a [u32; 2] block at the end representing the length
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
msg_block[i] = 1 << 7;
i = i + 1;
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
msg_block[msg_byte_ptr] = 1 << 7;
msg_byte_ptr = msg_byte_ptr + 1;
unsafe {
let (new_msg_block, new_msg_byte_ptr)= pad_msg_block(msg_block, msg_byte_ptr);
msg_block = new_msg_block;
if crate::runtime::is_unconstrained() {
msg_byte_ptr = new_msg_byte_ptr;
}
}

if !crate::runtime::is_unconstrained() {
for i in 0..64 {
if i as u64 < msg_byte_ptr - 1 {
assert_eq(msg_block[i], last_block[i]);
}
}
assert_eq(msg_block[msg_byte_ptr - 1], 1 << 7);

// If i >= 57, there aren't enough bits in the current message block to accomplish this, so
// the 1 and 0s fill up the current block, which we then compress accordingly.
// Not enough bits (64) to store length. Fill up with zeros.
for _i in 57..64 {
if msg_byte_ptr <= 63 & msg_byte_ptr >= 57 {
assert_eq(msg_block[msg_byte_ptr], 0);
msg_byte_ptr += 1;
}
}
}

if msg_byte_ptr >= 57 {
h = sha256_compression(msg_u8_to_u32(msg_block), h);

msg_byte_ptr = 0;
}

unsafe {
msg_block = attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size);
}

if !crate::runtime::is_unconstrained() {
if msg_byte_ptr != 0 {
for i in 0..64 {
if i as u64 < msg_byte_ptr - 1 {
assert_eq(msg_block[i], last_block[i]);
}
}
assert_eq(msg_block[msg_byte_ptr - 1], 1 << 7);
}

let len = 8 * message_size;
let len_bytes = (len as Field).to_le_bytes(8);
// In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56).
for _ in 0..64 {
if msg_byte_ptr < 56 {
assert_eq(msg_block[msg_byte_ptr], 0);
msg_byte_ptr = msg_byte_ptr + 1;
}
}

let mut block_idx = 0;
for i in 56..64 {
assert_eq(msg_block[63 - block_idx], len_bytes[i - 56]);
block_idx = block_idx + 1;
}
}

hash_final_block(msg_block, h)
}

unconstrained fn pad_msg_block<let N: u32>(
mut msg_block: [u8; 64],
mut msg_byte_ptr: u64
) -> ([u8; 64], u64) {
// If i >= 57, there aren't enough bits in the current message block to accomplish this, so
// the 1 and 0s fill up the current block, which we then compress accordingly.
if i >= 57 {
if msg_byte_ptr >= 57 {
// Not enough bits (64) to store length. Fill up with zeros.
if i < 64 {
for _i in 57..64 {
if i <= 63 {
msg_block[i] = 0;
i += 1;
if msg_byte_ptr < 64 {
for _ in 57..64 {
if msg_byte_ptr <= 63 {
msg_block[msg_byte_ptr] = 0;
msg_byte_ptr += 1;
}
}
}
h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h);

i = 0;
}
(msg_block, msg_byte_ptr)
}

unconstrained fn attach_len_to_msg_block<let N: u32>(
mut msg_block: [u8; 64],
mut msg_byte_ptr: u64,
message_size: u64
) -> [u8; 64] {
let len = 8 * message_size;
let len_bytes = (len as Field).to_le_bytes(8);
for _i in 0..64 {
// In any case, fill blocks up with zeros until the last 64 (i.e. until i = 56).
if i < 56 {
msg_block[i] = 0;
i = i + 1;
} else if i < 64 {
// In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56).
if msg_byte_ptr < 56 {
msg_block[msg_byte_ptr] = 0;
msg_byte_ptr = msg_byte_ptr + 1;
} else if msg_byte_ptr < 64 {
for j in 0..8 {
msg_block[63 - j] = len_bytes[j];
}
i += 8;
}
}
hash_final_block(msg_block, h)
}

// Convert 64-byte array to array of 16 u32s
fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] {
let mut msg32: [u32; 16] = [0; 16];

for i in 0..16 {
let mut msg_field: Field = 0;
for j in 0..4 {
msg_field = msg_field * 256 + msg[64 - 4*(i + 1) + j] as Field;
msg_byte_ptr += 8;
}
msg32[15 - i] = msg_field as u32;
}

msg32
msg_block
}

fn hash_final_block(msg_block: [u8; 64], mut state: [u32; 8]) -> [u8; 32] {
let mut out_h: [u8; 32] = [0; 32]; // Digest as sequence of bytes

// Hash final padded block
state = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), state);
state = sha256_compression(msg_u8_to_u32(msg_block), state);

// Return final hash as byte array
for j in 0..8 {
Expand All @@ -104,3 +246,4 @@ fn hash_final_block(msg_block: [u8; 64], mut state: [u32; 8]) -> [u8; 32] {

out_h
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "sha256_var_size_regression"
type = "bin"
authors = [""]
compiler_version = ">=0.33.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
enable = [true, false]
foo = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
toggle = false
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
global NUM_HASHES = 2;

// Regression program: for each enabled slot, hashes `foo` with `sha256_var` once using a size
// derived from the `toggle` input and once using the literal size 93, then asserts that the
// two result arrays agree slot by slot.
fn main(foo: [u8; 95], toggle: bool, enable: [bool; NUM_HASHES]) {
let mut result = [[0; 32]; NUM_HASHES];
let mut const_result = [[0; 32]; NUM_HASHES];
// Evaluates to 93 when `toggle` is false and 95 when it is true
let size: Field = 93 + toggle as Field * 2;
for i in 0..NUM_HASHES {
// Disabled slots keep their zero-initialised digests in both arrays
if enable[i] {
result[i] = std::sha256::sha256_var(foo, size as u64);
const_result[i] = std::sha256::sha256_var(foo, 93);
}
}

for i in 0..NUM_HASHES {
assert_eq(result[i], const_result[i]);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "sha256_var_witness_const_regression"
type = "bin"
authors = [""]
compiler_version = ">=0.33.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
input = [0, 0]
toggle = false
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Regression program: checks that `sha256_var` returns the same digest when the message size
// is computed from an input value as when it is passed as the literal 1.
fn main(input: [u8; 2], toggle: bool) {
// `size` depends on `toggle`; the assert below forces `toggle` to be false, so it equals 1
let size: Field = 1 + toggle as Field;
assert(!toggle);

let variable_sha = std::sha256::sha256_var(input, size as u64);
let constant_sha = std::sha256::sha256_var(input, 1);

assert_eq(variable_sha, constant_sha);
}

0 comments on commit c52dc1c

Please sign in to comment.