diff --git a/circuits/aes-gcm/gctr.circom b/circuits/aes-gcm/gctr.circom index 395751a..72ff609 100644 --- a/circuits/aes-gcm/gctr.circom +++ b/circuits/aes-gcm/gctr.circom @@ -38,70 +38,85 @@ template GCTR(INPUT_LEN, nk) { signal input plainText[INPUT_LEN]; signal output cipherText[INPUT_LEN]; + // number of 128 bit blocks in the plaintext var nBlocks = INPUT_LEN / 128; + // size of the last block var lastBlockSize = INPUT_LEN % 128; + // total number of bits in the plaintext blocks + var bitblocks = 128 * nBlocks; - component toBlocks = ToBlocks(INPUT_LEN); - for (var i = 0; i < nBlocks * 128; i++) { - toBlocks.stream[i] <== plainText[i]; - } - + // last block of plaintext signal tempLastBlock[lastBlockSize]; - for (var i = 0; i < lastBlockSize; i++) { + for (var i = 0; i < lastBlockSize -1; i++) { tempLastBlock[i] <== plainText[nBlocks * 128 + i]; } - // intermediate signal - signal cipherBlocks[nBlocks][4][4]; - component AddCipher[nBlocks]; + // generate plaintext blocks + // note to not use the last block of plaintext + component plainTextBlocks = ToBlocks(INPUT_LEN); + plainTextBlocks.stream <== plainText; + // Step 1: Generate counter blocks - signal counterBlocks[nBlocks][128]; + // signal incCounterBlocks[nBlocks][128]; + component counterBlocks[nBlocks]; + counterBlocks[1] <== ToBlocks(128); + counterBlocks[1].stream <== initialCounterBlock; + component inc32[nBlocks]; - counterBlocks[1] <== initialCounterBlock; // For i = 2 to nBlocks, let CBi = inc32(CBi-1). for (var i = 2; i < nBlocks; i++) { inc32[i] = Increment32(); - inc32[i].in <== counterBlocks[i - 1]; - counterBlocks[i] <== inc32[i].out; + inc32[i].in <== incCounterBlocks[i - 1]; + incCounterBlocks[i] <== inc32[i].out; } + // Convert blocks to stream + component toStream = ToStream(nBlocks, bitblocks); // Step 2: Encrypt each counter block with the key component aes[nBlocks]; + component AddCipher[nBlocks]; for (var i = 1; i < nBlocks -1; i++) { + // convert counter block to blocks type + counterBlocks[i] = ToBlocks(128); + counterBlocks[i].stream <== incCounterBlocks[i]; + // encrypt counter block aes[i] = Cipher(nk); aes[i].key <== key; - aes[i].block <== counterBlocks[i]; // TODO(WJ 2024-09-10): need to turn these into blocks + aes[i].block <== counterBlocks[i].blocks[0]; // XOR cipher text with input block AddCipher[i] = AddCipher(); - AddCipher[i].state <== toBlocks.blocks[i]; + AddCipher[i].state <== plainTextBlocks.blocks[i]; AddCipher[i].cipher <== aes[i].cipher; // set output block - cipherBlocks[i] <== AddCipher[i].newState; + toStream.blocks[i] <== AddCipher[i].newState; } // Step 3: Handle the last block separately // Y* = X* ⊕ MSBlen(X*) (CIPH_K (CB_n*)) + // convert last counter block to blocks + counterBlocks[nBlocks] = ToBlocks(128); + counterBlocks[nBlocks].stream <== incCounterBlocks[nBlocks]; // encrypt the last counter block aes[nBlocks] = Cipher(nk); aes[nBlocks].key <== key; - aes[nBlocks].block <== counterBlocks[nBlocks]; + aes[nBlocks].block <== counterBlocks[nBlocks].blocks[0]; // XOR the cipher with the last chunk of un padded plaintext + component aesCipherToStream = ToStream(1, 128); component addLastCipher = XorMultiple(2, lastBlockSize); for (var i = 0; i < lastBlockSize; i++) { - addLastCipher.inputs[0][i] <== aes[nBlocks].cipher[i]; + // convert cipher to stream + aesCipherToStream.blocks[0] <== aes[nBlocks].cipher; + addLastCipher.inputs[0][i] <== aesCipherToStream.stream[i]; addLastCipher.inputs[1][i] <== tempLastBlock[i]; } - var bitblocks = 128 * nBlocks; - // Convert blocks to stream - component toStream = ToStream(nBlocks, bitblocks); - toStream.blocks <== cipherBlocks; + for (var i = 0; i < bitblocks; i++) { cipherText[i] <== toStream.stream[i]; cipherText[bitblocks + i] <== addLastCipher.out[i]; diff --git a/circuits/aes-gcm/helper_functions.circom b/circuits/aes-gcm/helper_functions.circom index 24d50f3..e0a7cc1 100644 --- a/circuits/aes-gcm/helper_functions.circom +++ b/circuits/aes-gcm/helper_functions.circom @@ -300,8 +300,8 @@ template IndexSelector(total) { out <== calcTotal.sum; } -// reverse the bit order in an n-bit array -template ReverseBitsArray(n) { +// reverse the order in an n-bit array +template ReverseArray(n) { signal input in[n]; signal output out[n]; @@ -355,10 +355,87 @@ template Increment32() { } // TODO(WJ 2024-09-09): Check if this bit-reversal is needed. - component reverseBits = ReverseBitsArray(32); + component reverseBits = ReverseArray(32); reverseBits.in <== incrementedBits; // Copy the incremented bits to the output for (var i = 0; i < 32; i++) { out[96 + i] <== reverseBits.out[i]; } -} \ No newline at end of file +} + +/// IncrementingFunction increments the integer represented by the 32 least significant bits of the input 16-byte block +/// and returns the result. +template Increment32Block() { + signal input in[4][4]; + signal output out[4][4]; + + log("input:"); + log(in[0][0], in[0][1], in[0][2], in[0][3]); + log(in[1][0], in[1][1], in[1][2], in[1][3]); + log(in[2][0], in[2][1], in[2][2], in[2][3]); + log(in[3][0], in[3][1], in[3][2], in[3][3]); + // Copy the left-most 12 bytes unchanged + for (var i = 0; i < 3; i++) { + for (var j = 0; j < 4; j++) { + out[i][j] <== in[i][j]; + } + } + + // Convert the last 4 bytes to an 32 bit number + // signal bits[32]; + component bits2num = Bits2Num(32); + component byte2bits[4]; + for (var i = 0; i < 4; i++) { + byte2bits[i] = Num2Bits(8); + byte2bits[i].in <== in[3][i]; + for (var j = 0; j < 8; j++) { + bits2num.in[i * 8 + j] <== byte2bits[i].out[j]; + } + } + // TODO: handle overflow + signal incremented <== bits2num.out + 1; + + // Convert the incremented integer back to binary + component num2bits = Num2Bits(32); + num2bits.in <== incremented; + signal incrementedBits[32]; + for (var i = 0; i < 32; i++) { + incrementedBits[i] <== num2bits.out[i]; + } + + + // Convert the incremented bits back to four bytes and assign to out + component bits2byte[4]; + signal outBytes[4]; + for (var i = 0; i < 4; i++) { + bits2byte[i] = Bits2Num(8); + for (var j = 0; j < 8; j++) { + bits2byte[i].in[j] <== incrementedBits[i * 8 + j]; + } + outBytes[i] <== bits2byte[i].out; + } + log("outBytes:"); + log(outBytes[0], outBytes[1], outBytes[2], outBytes[3]); + out[3][0] <== outBytes[3]; + out[3][1] <== outBytes[2]; + out[3][2] <== outBytes[1]; + out[3][3] <== outBytes[0]; +} + +// Idea: try to increment the word by 1 +// template IncrementWord() { +// signal input in[4]; +// signal output out[4]; + +// // Convert the 4 bytes to a 32-bit number +// signal num <== in[0] * 0x1000000 + in[1] * 0x10000 + in[2] * 0x100 + in[3]; + +// // Increment the number +// signal incremented <== num + 1; + +// // Convert the incremented number back to 4 bytes +// out[0] <== (incremented >> 24) & 0xFF; +// out[1] <== (incremented >> 16) & 0xFF; +// out[2] <== (incremented >> 8) & 0xFF; +// out[3] <== incremented & 0xFF; +// } diff --git a/circuits/test/increment.test.ts b/circuits/test/increment.test.ts index 13d233a..f23e908 100644 --- a/circuits/test/increment.test.ts +++ b/circuits/test/increment.test.ts @@ -17,4 +17,78 @@ describe("Increment", () => { } ); }); -}); \ No newline at end of file +}); + +describe("Increment32Block", () => { + let circuit: WitnessTester<["in"], ["out"]>; + it("should increment the block input 0", async () => { + circuit = await circomkit.WitnessTester(`Increment32Block`, { + file: "aes-gcm/helper_functions", + template: "Increment32Block", + }); + await circuit.expectPass( + { + in: [ + [0x00, 0x00, 0x00, 0x00], + [0x00, 0x00, 0x00, 0x00], + [0x00, 0x00, 0x00, 0x00], + [0x00, 0x00, 0x00, 0x00] + ], + }, + { + out: [ + [0x00, 0x00, 0x00, 0x00], + [0x00, 0x00, 0x00, 0x00], + [0x00, 0x00, 0x00, 0x00], + [0x00, 0x00, 0x00, 0x01] + ], + } + ); + }); + it("should increment the block input 1", async () => { + circuit = await circomkit.WitnessTester(`Increment32Block`, { + file: "aes-gcm/helper_functions", + template: "Increment32Block", + }); + await circuit.expectPass( + { + in: [ + [0x00, 0x00, 0x00, 0x00], + [0x00, 0x00, 0x00, 0x00], + [0x00, 0x00, 0x00, 0x00], + [0x00, 0x00, 0x00, 0x01] + ], + }, + { + out: [ + [0x00, 0x00, 0x00, 0x00], + [0x00, 0x00, 0x00, 0x00], + [0x00, 0x00, 0x00, 0x00], + [0x00, 0x00, 0x00, 0x02] + ], + } + ); + }); +}); + +// describe("IncrementWord", () => { +// let circuit: WitnessTester<["in"], ["out"]>; +// it("should increment the word input", async () => { +// circuit = await circomkit.WitnessTester(`Increment32Block`, { +// file: "aes-gcm/helper_functions", +// template: "IncrementWord", +// }); +// await circuit.expectPass( +// { +// in: [ +// [0x00, 0x00, 0x00, 0x00], +// ], +// }, +// { +// out: [ +// [0x00, 0x00, 0x00, 0x01] +// ], +// } +// ); +// }); +// }); \ No newline at end of file