Skip to content

Commit

Permalink
feat: optimize sha2 implementation (#4441)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves <!-- Link to GitHub Issue -->

## Summary\*

We're currently performing byte decompositions in the sha2 functions
through repeated division. This would be more efficient if we just did
the full byte decomposition at once and then iterate through the
results.

I've also removed some noop casts.

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
TomAFrench authored Mar 13, 2024
1 parent 9ca1a60 commit 80373d6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
22 changes: 12 additions & 10 deletions noir_stdlib/src/sha256.nr
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] {
let mut msg32: [u32; 16] = [0; 16];

for i in 0..16 {
let mut msg_field: Field = 0;
for j in 0..4 {
msg32[15 - i] = (msg32[15 - i] << 8) + msg[64 - 4*(i + 1) + j] as u32;
msg_field = msg_field * 256 + msg[64 - 4*(i + 1) + j] as Field;
}
msg32[15 - i] = msg_field as u32;
}

msg32
Expand All @@ -21,7 +23,7 @@ pub fn digest<N>(msg: [u8; N]) -> [u8; 32] {
let mut i: u64 = 0; // Message byte pointer
for k in 0..N {
// Populate msg_block
msg_block[i as Field] = msg[k];
msg_block[i] = msg[k];
i = i + 1;
if i == 64 {
// Enough to hash block
Expand All @@ -32,7 +34,7 @@ pub fn digest<N>(msg: [u8; N]) -> [u8; 32] {
}
// Pad the rest such that we have a [u32; 2] block at the end representing the length
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
msg_block[i as Field] = 1 << 7;
msg_block[i] = 1 << 7;
i = i + 1;
// If i >= 57, there aren't enough bits in the current message block to accomplish this, so
// the 1 and 0s fill up the current block, which we then compress accordingly.
Expand All @@ -41,7 +43,7 @@ pub fn digest<N>(msg: [u8; N]) -> [u8; 32] {
if i < 64 {
for _i in 57..64 {
if i <= 63 {
msg_block[i as Field] = 0;
msg_block[i] = 0;
i += 1;
}
}
Expand All @@ -51,16 +53,16 @@ pub fn digest<N>(msg: [u8; N]) -> [u8; 32] {
i = 0;
}

let len = 8 * msg.len();
let len_bytes = (len as Field).to_le_bytes(8);
for _i in 0..64 {
// In any case, fill blocks up with zeros until the last 64 (i.e. until i = 56).
if i < 56 {
msg_block[i as Field] = 0;
msg_block[i] = 0;
i = i + 1;
} else if i < 64 {
let mut len = 8 * msg.len();
for j in 0..8 {
msg_block[63 - j] = len as u8;
len >>= 8;
msg_block[63 - j] = len_bytes[j];
}
i += 8;
}
Expand All @@ -70,9 +72,9 @@ pub fn digest<N>(msg: [u8; N]) -> [u8; 32] {

// Return final hash as byte array
for j in 0..8 {
let h_bytes = (h[7 - j] as Field).to_le_bytes(4);
for k in 0..4 {
out_h[31 - 4*j - k] = h[7 - j] as u8;
h[7-j] >>= 8;
out_h[31 - 4*j - k] = h_bytes[k];
}
}

Expand Down
22 changes: 12 additions & 10 deletions noir_stdlib/src/sha512.nr
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,11 @@ fn msg_u8_to_u64(msg: [u8; 128]) -> [u64; 16] {
let mut msg64: [u64; 16] = [0; 16];

for i in 0..16 {
let mut msg_field: Field = 0;
for j in 0..8 {
msg64[15 - i] = (msg64[15 - i] << 8) + msg[128 - 8*(i + 1) + j] as u64;
msg_field = msg_field * 256 + msg[128 - 8*(i + 1) + j] as Field;
}
msg64[15 - i] = msg_field as u64;
}

msg64
Expand All @@ -94,7 +96,7 @@ pub fn digest<N>(msg: [u8; N]) -> [u8; 64] {
let mut i: u64 = 0; // Message byte pointer
for k in 0..msg.len() {
// Populate msg_block
msg_block[i as Field] = msg[k];
msg_block[i] = msg[k];
i = i + 1;
if i == 128 {
// Enough to hash block
Expand All @@ -108,7 +110,7 @@ pub fn digest<N>(msg: [u8; N]) -> [u8; 64] {
}
// Pad the rest such that we have a [u64; 2] block at the end representing the length
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
msg_block[i as Field] = 1 << 7;
msg_block[i] = 1 << 7;
i += 1;
// If i >= 113, there aren't enough bits in the current message block to accomplish this, so
// the 1 and 0s fill up the current block, which we then compress accordingly.
Expand All @@ -117,7 +119,7 @@ pub fn digest<N>(msg: [u8; N]) -> [u8; 64] {
if i < 128 {
for _i in 113..128 {
if i <= 127 {
msg_block[i as Field] = 0;
msg_block[i] = 0;
i += 1;
}
}
Expand All @@ -130,16 +132,16 @@ pub fn digest<N>(msg: [u8; N]) -> [u8; 64] {
i = 0;
}

let len = 8 * msg.len();
let len_bytes = (len as Field).to_le_bytes(16);
for _i in 0..128 {
// In any case, fill blocks up with zeros until the last 128 (i.e. until i = 112).
if i < 112 {
msg_block[i as Field] = 0;
msg_block[i] = 0;
i += 1;
} else if i < 128 {
let mut len = 8 * msg.len();
for j in 0..16 {
msg_block[127 - j] = len as u8;
len >>= 8;
msg_block[127 - j] = len_bytes[j];
}
i += 16; // Done.
}
Expand All @@ -151,9 +153,9 @@ pub fn digest<N>(msg: [u8; N]) -> [u8; 64] {
}
// Return final hash as byte array
for j in 0..8 {
let h_bytes = (h[7 - j] as Field).to_le_bytes(8);
for k in 0..8 {
out_h[63 - 8*j - k] = h[7 - j] as u8;
h[7-j] >>= 8;
out_h[63 - 8*j - k] = h_bytes[k];
}
}

Expand Down

0 comments on commit 80373d6

Please sign in to comment.