From c52dc1c77aedf5a876a858cc5a942c29e868e9e6 Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Tue, 27 Aug 2024 09:09:49 -0400 Subject: [PATCH] fix(sha256): Perform compression per block and utilize ROM instead of RAM when setting up the message block (#5760) # Description ## Problem\* Resolves #5761 Resolution to performance blow-up found with sha256_var. ## Summary\* ### Issue The crux of the blow-up was the result of calling `sha256_compression` inside of the same loop where we build the message block. In the current `sha256_var` algorithm we are looping over the entire message and conditionally checking a msg byte pointer (the pointer into the msg block) to determine whether we have filled up a msg block and should run the sha compression. However, in a circuit this leads to us calling the compression opcode `N` times where `N` is the size of the message. We also were utilize RAM to build our message block when we do not have to do so. We can instead construct our block outside of the circuit and verify that the block has been constructed as we expect with assertion that just require ROM. ### Improvements This PR produces a ~16x improvement in ACIR opcodes a >13x improvement in backend constraints for the following circuit: ```rust fn main(foo: [u8; 95], toggle: bool) { let size: Field = 93 + toggle as Field * 2; let hash = std::sha256::sha256_var(foo, size as u64); println(f"{hash}"); } ``` #### master nargo info: ``` +---------+----------------------------+----------------------+--------------+-----------------+ | Package | Function | Expression Width | ACIR Opcodes | Brillig Opcodes | +---------+----------------------------+----------------------+--------------+-----------------+ | sha256 | main | Bounded { width: 4 } | 125852 | 243 | +---------+----------------------------+----------------------+--------------+-----------------+ | sha256 | print_unconstrained | N/A | N/A | 230 | +---------+----------------------------+----------------------+--------------+-----------------+ | sha256 | directive_integer_quotient | N/A | N/A | 6 | +---------+----------------------------+----------------------+--------------+-----------------+ | sha256 | directive_invert | N/A | N/A | 7 | +---------+----------------------------+----------------------+--------------+-----------------+ ``` bb gates: ``` {"functions": [ { "acir_opcodes": 125852, "circuit_size": 597646, ``` #### This PR Output of nargo info: ``` +----------------------------+----------------------------+----------------------+--------------+-----------------+ | Package | Function | Expression Width | ACIR Opcodes | Brillig Opcodes | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | main | Bounded { width: 4 } | 7768 | 1041 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | build_msg_block_iter | N/A | N/A | 299 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | pad_msg_block | N/A | N/A | 201 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | attach_len_to_msg_block | N/A | N/A | 298 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | print_unconstrained | N/A | N/A | 230 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | directive_integer_quotient | N/A | N/A | 6 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | directive_invert | N/A | N/A | 7 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ ``` bb gates output: ``` {"functions": [ { "acir_opcodes": 7768, "circuit_size": 44663, ``` ## Additional Context ## Documentation\* Check one: - [ ] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [ ] I have tested the changes locally. - [ ] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- noir_stdlib/src/hash/sha256.nr | 237 ++++++++++++++---- .../sha256_var_size_regression/Nargo.toml | 7 + .../sha256_var_size_regression/Prover.toml | 3 + .../sha256_var_size_regression/src/main.nr | 17 ++ .../Nargo.toml | 7 + .../Prover.toml | 2 + .../src/main.nr | 9 + 7 files changed, 235 insertions(+), 47 deletions(-) create mode 100644 test_programs/execution_success/sha256_var_size_regression/Nargo.toml create mode 100644 test_programs/execution_success/sha256_var_size_regression/Prover.toml create mode 100644 test_programs/execution_success/sha256_var_size_regression/src/main.nr create mode 100644 test_programs/execution_success/sha256_var_witness_const_regression/Nargo.toml create mode 100644 test_programs/execution_success/sha256_var_witness_const_regression/Prover.toml create mode 100644 test_programs/execution_success/sha256_var_witness_const_regression/src/main.nr diff --git a/noir_stdlib/src/hash/sha256.nr b/noir_stdlib/src/hash/sha256.nr index 5035be4b73e..55cdd984003 100644 --- a/noir_stdlib/src/hash/sha256.nr +++ b/noir_stdlib/src/hash/sha256.nr @@ -17,82 +17,224 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { sha256_var(msg, N as u64) } +// Convert 64-byte array to array of 16 u32s +fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] { + let mut msg32: [u32; 16] = [0; 16]; + + for i in 0..16 { + let mut msg_field: Field = 0; + for j in 0..4 { + msg_field = msg_field * 256 + msg[64 - 4*(i + 1) + j] as Field; + } + msg32[15 - i] = msg_field as u32; + } + + msg32 +} + +unconstrained fn build_msg_block_iter( + msg: [u8; N], + message_size: u64, + mut msg_block: [u8; 64], + msg_start: u32 +) -> ([u8; 64], u64) { + let mut msg_byte_ptr: u64 = 0; // Message byte pointer + for k in msg_start..N { + if k as u64 < message_size { + msg_block[msg_byte_ptr] = msg[k]; + msg_byte_ptr = msg_byte_ptr + 1; + + if msg_byte_ptr == 64 { + msg_byte_ptr = 0; + } + } + } + (msg_block, msg_byte_ptr) +} + +// Verify the block we are compressing was appropriately constructed +fn verify_msg_block( + msg: [u8; N], + message_size: u64, + msg_block: [u8; 64], + msg_start: u32 +) -> u64 { + let mut msg_byte_ptr: u64 = 0; // Message byte pointer + for k in msg_start..N { + if k as u64 < message_size { + assert_eq(msg_block[msg_byte_ptr], msg[k]); + msg_byte_ptr = msg_byte_ptr + 1; + if msg_byte_ptr == 64 { + // Enough to hash block + msg_byte_ptr = 0; + } + } else { + // Need to assert over the msg block in the else case as well + if N < 64 { + assert_eq(msg_block[msg_byte_ptr], 0); + } else { + assert_eq(msg_block[msg_byte_ptr], msg[k]); + } + } + } + msg_byte_ptr +} + +global BLOCK_SIZE = 64; + // Variable size SHA-256 hash pub fn sha256_var(msg: [u8; N], message_size: u64) -> [u8; 32] { - let mut msg_block: [u8; 64] = [0; 64]; + let num_blocks = N / BLOCK_SIZE; + let mut msg_block: [u8; BLOCK_SIZE] = [0; BLOCK_SIZE]; let mut h: [u32; 8] = [1779033703, 3144134277, 1013904242, 2773480762, 1359893119, 2600822924, 528734635, 1541459225]; // Intermediate hash, starting with the canonical initial value - let mut i: u64 = 0; // Message byte pointer - for k in 0..N { - if k as u64 < message_size { - // Populate msg_block - msg_block[i] = msg[k]; - i = i + 1; - if i == 64 { - // Enough to hash block - h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h); + let mut msg_byte_ptr = 0; // Pointer into msg_block - i = 0; - } + if num_blocks == 0 { + unsafe { + let (new_msg_block, new_msg_byte_ptr) = build_msg_block_iter(msg, message_size, msg_block, 0); + msg_block = new_msg_block; + msg_byte_ptr = new_msg_byte_ptr; + } + + if !crate::runtime::is_unconstrained() { + msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, 0); } } + + for i in 0..num_blocks { + unsafe { + let (new_msg_block, new_msg_byte_ptr) = build_msg_block_iter(msg, message_size, msg_block, BLOCK_SIZE * i); + msg_block = new_msg_block; + msg_byte_ptr = new_msg_byte_ptr; + } + if !crate::runtime::is_unconstrained() { + // Verify the block we are compressing was appropriately constructed + msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, BLOCK_SIZE * i); + } + + // Hash the block + h = sha256_compression(msg_u8_to_u32(msg_block), h); + } + + let last_block = msg_block; // Pad the rest such that we have a [u32; 2] block at the end representing the length - // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). - msg_block[i] = 1 << 7; - i = i + 1; + // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). + msg_block[msg_byte_ptr] = 1 << 7; + msg_byte_ptr = msg_byte_ptr + 1; + unsafe { + let (new_msg_block, new_msg_byte_ptr)= pad_msg_block(msg_block, msg_byte_ptr); + msg_block = new_msg_block; + if crate::runtime::is_unconstrained() { + msg_byte_ptr = new_msg_byte_ptr; + } + } + + if !crate::runtime::is_unconstrained() { + for i in 0..64 { + if i as u64 < msg_byte_ptr - 1 { + assert_eq(msg_block[i], last_block[i]); + } + } + assert_eq(msg_block[msg_byte_ptr - 1], 1 << 7); + + // If i >= 57, there aren't enough bits in the current message block to accomplish this, so + // the 1 and 0s fill up the current block, which we then compress accordingly. + // Not enough bits (64) to store length. Fill up with zeros. + for _i in 57..64 { + if msg_byte_ptr <= 63 & msg_byte_ptr >= 57 { + assert_eq(msg_block[msg_byte_ptr], 0); + msg_byte_ptr += 1; + } + } + } + + if msg_byte_ptr >= 57 { + h = sha256_compression(msg_u8_to_u32(msg_block), h); + + msg_byte_ptr = 0; + } + + unsafe { + msg_block = attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size); + } + + if !crate::runtime::is_unconstrained() { + if msg_byte_ptr != 0 { + for i in 0..64 { + if i as u64 < msg_byte_ptr - 1 { + assert_eq(msg_block[i], last_block[i]); + } + } + assert_eq(msg_block[msg_byte_ptr - 1], 1 << 7); + } + + let len = 8 * message_size; + let len_bytes = (len as Field).to_le_bytes(8); + // In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56). + for _ in 0..64 { + if msg_byte_ptr < 56 { + assert_eq(msg_block[msg_byte_ptr], 0); + msg_byte_ptr = msg_byte_ptr + 1; + } + } + + let mut block_idx = 0; + for i in 56..64 { + assert_eq(msg_block[63 - block_idx], len_bytes[i - 56]); + block_idx = block_idx + 1; + } + } + + hash_final_block(msg_block, h) +} + +unconstrained fn pad_msg_block( + mut msg_block: [u8; 64], + mut msg_byte_ptr: u64 +) -> ([u8; 64], u64) { // If i >= 57, there aren't enough bits in the current message block to accomplish this, so // the 1 and 0s fill up the current block, which we then compress accordingly. - if i >= 57 { + if msg_byte_ptr >= 57 { // Not enough bits (64) to store length. Fill up with zeros. - if i < 64 { - for _i in 57..64 { - if i <= 63 { - msg_block[i] = 0; - i += 1; + if msg_byte_ptr < 64 { + for _ in 57..64 { + if msg_byte_ptr <= 63 { + msg_block[msg_byte_ptr] = 0; + msg_byte_ptr += 1; } } } - h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h); - - i = 0; } + (msg_block, msg_byte_ptr) +} +unconstrained fn attach_len_to_msg_block( + mut msg_block: [u8; 64], + mut msg_byte_ptr: u64, + message_size: u64 +) -> [u8; 64] { let len = 8 * message_size; let len_bytes = (len as Field).to_le_bytes(8); for _i in 0..64 { - // In any case, fill blocks up with zeros until the last 64 (i.e. until i = 56). - if i < 56 { - msg_block[i] = 0; - i = i + 1; - } else if i < 64 { + // In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56). + if msg_byte_ptr < 56 { + msg_block[msg_byte_ptr] = 0; + msg_byte_ptr = msg_byte_ptr + 1; + } else if msg_byte_ptr < 64 { for j in 0..8 { msg_block[63 - j] = len_bytes[j]; } - i += 8; - } - } - hash_final_block(msg_block, h) -} - -// Convert 64-byte array to array of 16 u32s -fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] { - let mut msg32: [u32; 16] = [0; 16]; - - for i in 0..16 { - let mut msg_field: Field = 0; - for j in 0..4 { - msg_field = msg_field * 256 + msg[64 - 4*(i + 1) + j] as Field; + msg_byte_ptr += 8; } - msg32[15 - i] = msg_field as u32; } - - msg32 + msg_block } fn hash_final_block(msg_block: [u8; 64], mut state: [u32; 8]) -> [u8; 32] { let mut out_h: [u8; 32] = [0; 32]; // Digest as sequence of bytes // Hash final padded block - state = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), state); + state = sha256_compression(msg_u8_to_u32(msg_block), state); // Return final hash as byte array for j in 0..8 { @@ -104,3 +246,4 @@ fn hash_final_block(msg_block: [u8; 64], mut state: [u32; 8]) -> [u8; 32] { out_h } + diff --git a/test_programs/execution_success/sha256_var_size_regression/Nargo.toml b/test_programs/execution_success/sha256_var_size_regression/Nargo.toml new file mode 100644 index 00000000000..3e141ee5d5f --- /dev/null +++ b/test_programs/execution_success/sha256_var_size_regression/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "sha256_var_size_regression" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/execution_success/sha256_var_size_regression/Prover.toml b/test_programs/execution_success/sha256_var_size_regression/Prover.toml new file mode 100644 index 00000000000..df632a42858 --- /dev/null +++ b/test_programs/execution_success/sha256_var_size_regression/Prover.toml @@ -0,0 +1,3 @@ +enable = [true, false] +foo = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] +toggle = false diff --git a/test_programs/execution_success/sha256_var_size_regression/src/main.nr b/test_programs/execution_success/sha256_var_size_regression/src/main.nr new file mode 100644 index 00000000000..de1c2b23c5f --- /dev/null +++ b/test_programs/execution_success/sha256_var_size_regression/src/main.nr @@ -0,0 +1,17 @@ +global NUM_HASHES = 2; + +fn main(foo: [u8; 95], toggle: bool, enable: [bool; NUM_HASHES]) { + let mut result = [[0; 32]; NUM_HASHES]; + let mut const_result = [[0; 32]; NUM_HASHES]; + let size: Field = 93 + toggle as Field * 2; + for i in 0..NUM_HASHES { + if enable[i] { + result[i] = std::sha256::sha256_var(foo, size as u64); + const_result[i] = std::sha256::sha256_var(foo, 93); + } + } + + for i in 0..NUM_HASHES { + assert_eq(result[i], const_result[i]); + } +} diff --git a/test_programs/execution_success/sha256_var_witness_const_regression/Nargo.toml b/test_programs/execution_success/sha256_var_witness_const_regression/Nargo.toml new file mode 100644 index 00000000000..e8f3e6bbe64 --- /dev/null +++ b/test_programs/execution_success/sha256_var_witness_const_regression/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "sha256_var_witness_const_regression" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/execution_success/sha256_var_witness_const_regression/Prover.toml b/test_programs/execution_success/sha256_var_witness_const_regression/Prover.toml new file mode 100644 index 00000000000..7b91051c1a0 --- /dev/null +++ b/test_programs/execution_success/sha256_var_witness_const_regression/Prover.toml @@ -0,0 +1,2 @@ +input = [0, 0] +toggle = false \ No newline at end of file diff --git a/test_programs/execution_success/sha256_var_witness_const_regression/src/main.nr b/test_programs/execution_success/sha256_var_witness_const_regression/src/main.nr new file mode 100644 index 00000000000..97c4435d41d --- /dev/null +++ b/test_programs/execution_success/sha256_var_witness_const_regression/src/main.nr @@ -0,0 +1,9 @@ +fn main(input: [u8; 2], toggle: bool) { + let size: Field = 1 + toggle as Field; + assert(!toggle); + + let variable_sha = std::sha256::sha256_var(input, size as u64); + let constant_sha = std::sha256::sha256_var(input, 1); + + assert_eq(variable_sha, constant_sha); +}