Skip to content

Commit

Permalink
Fix underconstrained signals (#89)
Browse files Browse the repository at this point in the history
* fix: aes-gcm passing

* chore: delete yarn

* fix: aes underconstrained circuits

* feat: move to finite field ops

* fix underconstrained vars at more places

* add circom test CI

* fix tests
  • Loading branch information
lonerapier authored Sep 28, 2024
1 parent 0384531 commit f0b1dcc
Show file tree
Hide file tree
Showing 15 changed files with 332 additions and 1,359 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
name: circom

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
circom:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Use Node.js
uses: actions/setup-node@v4
with:
node-version: '20'

- name: Install dependencies
run: |
npm install
npm install -g snarkjs
- name: Download and install Circom
run: |
CIRCOM_VERSION=2.1.9
curl -L https://github.com/iden3/circom/releases/download/v$CIRCOM_VERSION/circom-linux-amd64 -o circom
chmod +x circom
sudo mv circom /usr/local/bin/
circom --version
- name: Run tests
run: npm run test
36 changes: 20 additions & 16 deletions circuits/aes-gcm/aes-gcm.circom
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ include "gctr.circom";


/// AES-GCM with 128 bit key authenticated encryption according to: https://nvlpubs.nist.gov/nistpubs/legacy/sp/nistspecialpublication800-38d.pdf
///
///
/// Parameters:
/// l: length of the plaintext
///
Expand All @@ -21,7 +21,7 @@ include "gctr.circom";
/// Outputs:
/// cipherText: encrypted ciphertext
/// authTag: authentication tag
///
///
template AESGCM(l) {
// Inputs
signal input key[16]; // 128-bit key
Expand Down Expand Up @@ -54,7 +54,7 @@ template AESGCM(l) {
}
component J0WordIncrementer = IncrementWord();
J0WordIncrementer.in <== J0builder.blocks[0][3];

component J0WordIncrementer2 = IncrementWord();
J0WordIncrementer2.in <== J0WordIncrementer.out;

Expand All @@ -81,14 +81,14 @@ template AESGCM(l) {
}
var ghashblocks = 1 + blockCount + 1; // blocksize is 16 bytes

//
//
// A => 1 => length of AAD (always at most 128 bits)
// 0^v => padding bytes, none for v
// C => l\16+1 => number of ciphertext blocks
// 0^u => padding bytes, u value
// len(A) => u64
// len(b) => u64 (together, 1 block)
//
//
signal ghashMessage[ghashblocks][4][4];

// set aad as first block
Expand All @@ -104,7 +104,7 @@ template AESGCM(l) {
for (var i=0; i<blockCount; i++) {
ghashMessage[i+1] <== ciphertextBlocks.blocks[i];
}

// length of aad = 128 = 0x80 as 64 bit number
ghashMessage[ghashblocks-1][0] <== [0x00, 0x00, 0x00, 0x00];
ghashMessage[ghashblocks-1][1] <== [0x00, 0x00, 0x00, 0x80];
Expand All @@ -126,30 +126,34 @@ template AESGCM(l) {
hashKeyToStream.blocks[0] <== cipherH.cipher;
ghash.HashKey <== hashKeyToStream.stream;
// S = GHASHH (A || 0^v || C || 0^u || [len(A)] || [len(C)]).
component msgToStream = ToStream(ghashblocks, 16);
msgToStream.blocks <== ghashMessage;
ghash.msg <== msgToStream.stream;
component selectedBlocksToStream[ghashblocks];
for (var i = 0 ; i<ghashblocks ; i++) {
ghash.msg[i] <== ToStream(1, 16)([ghashMessage[i]]);
}
// ghash.msg <== msgToStream.stream;
// In Steps 4 and 5, the AAD and the ciphertext are each appended with the minimum number of
// ‘0’ bits, possibly none, so that the bit lengths of the resulting strings are multiples of the block
// size. The concatenation of these strings is appended with the 64-bit representations of the
// lengths of the AAD and the ciphertext, and the GHASH function is applied to the result to
// produce a single output block.

// TODO: Check the endianness
log("ghash bytes"); // BUG: Currently 0.
var bytes[16];
// TODO: this is underconstrained too
// log("ghash bytes"); // BUG: Currently 0.
signal bytes[16];
signal tagBytes[16 * 8] <== BytesToBits(16)(ghash.tag);
for(var i = 0; i < 16; i++) {
var byteValue = 0;
var sum=1;
for(var j = 0; j<8; j++) {
var bitIndex = i*8+j;
byteValue += ghash.tag[bitIndex]*sum;
var bitIndex = i*8+j;
byteValue += tagBytes[bitIndex]*sum;
sum = sum*sum;
}
log(byteValue);
bytes[i] = byteValue;
// log(byteValue);
bytes[i] <== byteValue;
}
log("end ghash bytes");
// log("end ghash bytes");

// Step 6: Let T = MSBt(GCTRK(J0, S))
component gctrT = GCTR(16, 4);
Expand Down
170 changes: 170 additions & 0 deletions circuits/aes-gcm/aes/ff.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// from: https://github.com/crema-labs/aes-circom/tree/main/circuits
pragma circom 2.1.9;

include "circomlib/circuits/bitify.circom";

// Finite field addition, the signal variable plus a compile-time constant
template FieldAddConst(c) {
signal input in[8];
// control bit, if 0, then do not perform addition
signal input control;
signal output out[8];

for (var i=0; i<8; i++) {
if(c & (1<<i) != 0) {
// XOR operation
out[i] <== in[i] + control - 2 * in[i] * control;
} else {
out[i] <== in[i];
}
}
}

// Finite field multiplication by 2 operation for AES. This involves left-shifting 'input' by 1 (input << 1),
// and then XORing with 0x1B if the most significate bit is 1. This is because the irreducible polynomial
// for AES's finite field (GF(2^8)) is x^8 + x^4 + x^3 + x + 1.
template FieldMul2() {
signal input in;
signal output out;

signal inBits[8];
inBits <== Num2Bits(8)(in);

component reduce = FieldAddConst(0x1b);
reduce.in[0] <== 0;
for (var i = 1; i < 8; i++) {
reduce.in[i] <== inBits[i-1];
}
reduce.control <== inBits[7];
out <== Bits2Num(8)(reduce.out);
}

// Finite field multiplication by 3 operation for AES. This involves (input << 1) ⊕ input and then XORing
// with 0x1B if the most significate bit is 1.
template FieldMul3() {
signal input in;
signal output out;

signal inBits[8] <== Num2Bits(8)(in);

component reduce = FieldAddConst(0x1b);
reduce.in[0] <== inBits[0];
for (var i = 1; i < 8; i++) {
reduce.in[i] <== inBits[i-1] + inBits[i] - 2 * inBits[i-1] * inBits[i];
}
reduce.control <== inBits[7];
out <== Bits2Num(8)(reduce.out);
}

// Determine the parity (odd or even) of an integer that can be accommodated within 'nBits' bits.
template IsOdd(nBits) {
signal input in;
signal output out;
if (nBits == 1) {
out <== in;
} else {
signal bits[nBits] <== Num2Bits(nBits)(in);
out <== bits[0];
}
}

// Finite field multiplication.
template FieldMul() {
signal input a;
signal input b;
signal inBits[2][8];
signal output out;

inBits[0] <== Num2Bits(8)(a);
inBits[1] <== Num2Bits(8)(b);

// List of finite field elements obtained by successively doubling, starting from 1.
var power[15] = [0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a];

signal mulMatrix[8][8];
var outLinesLc[8];
for (var i = 0; i < 8; i++) {
outLinesLc[i] = 0;
}
// Apply elementary multiplication
for (var i = 0; i < 8; i++) {
for (var j = 0; j < 8; j++) {
mulMatrix[i][j] <== inBits[0][i] * inBits[1][j];
for (var t = 0; t < 8; t++) {
if (power[i+j] & (1 << t) != 0) {
outLinesLc[t] += mulMatrix[i][j];
}
}
}
}
signal outBitsUnreduced[8];
signal outBits[8];
for (var i = 0; i < 8; i++) {
outBitsUnreduced[i] <== outLinesLc[i];
// Each element in 'outLinesLc' is incremented by a known constant number of
// elements from 'mulMatrix', less than 31.
outBits[i] <== IsOdd(6)(outBitsUnreduced[i]);
}

out <== Bits2Num(8)(outBits);
}

// Finite Field Inversion. Specially, if the input is 0, the output is also 0.
template FieldInv() {
signal input in;
signal output out;

var inv[256] = [0x00, 0x01, 0x8d, 0xf6, 0xcb, 0x52, 0x7b, 0xd1, 0xe8, 0x4f, 0x29, 0xc0, 0xb0, 0xe1, 0xe5, 0xc7,
0x74, 0xb4, 0xaa, 0x4b, 0x99, 0x2b, 0x60, 0x5f, 0x58, 0x3f, 0xfd, 0xcc, 0xff, 0x40, 0xee, 0xb2,
0x3a, 0x6e, 0x5a, 0xf1, 0x55, 0x4d, 0xa8, 0xc9, 0xc1, 0x0a, 0x98, 0x15, 0x30, 0x44, 0xa2, 0xc2,
0x2c, 0x45, 0x92, 0x6c, 0xf3, 0x39, 0x66, 0x42, 0xf2, 0x35, 0x20, 0x6f, 0x77, 0xbb, 0x59, 0x19,
0x1d, 0xfe, 0x37, 0x67, 0x2d, 0x31, 0xf5, 0x69, 0xa7, 0x64, 0xab, 0x13, 0x54, 0x25, 0xe9, 0x09,
0xed, 0x5c, 0x05, 0xca, 0x4c, 0x24, 0x87, 0xbf, 0x18, 0x3e, 0x22, 0xf0, 0x51, 0xec, 0x61, 0x17,
0x16, 0x5e, 0xaf, 0xd3, 0x49, 0xa6, 0x36, 0x43, 0xf4, 0x47, 0x91, 0xdf, 0x33, 0x93, 0x21, 0x3b,
0x79, 0xb7, 0x97, 0x85, 0x10, 0xb5, 0xba, 0x3c, 0xb6, 0x70, 0xd0, 0x06, 0xa1, 0xfa, 0x81, 0x82,
0x83, 0x7e, 0x7f, 0x80, 0x96, 0x73, 0xbe, 0x56, 0x9b, 0x9e, 0x95, 0xd9, 0xf7, 0x02, 0xb9, 0xa4,
0xde, 0x6a, 0x32, 0x6d, 0xd8, 0x8a, 0x84, 0x72, 0x2a, 0x14, 0x9f, 0x88, 0xf9, 0xdc, 0x89, 0x9a,
0xfb, 0x7c, 0x2e, 0xc3, 0x8f, 0xb8, 0x65, 0x48, 0x26, 0xc8, 0x12, 0x4a, 0xce, 0xe7, 0xd2, 0x62,
0x0c, 0xe0, 0x1f, 0xef, 0x11, 0x75, 0x78, 0x71, 0xa5, 0x8e, 0x76, 0x3d, 0xbd, 0xbc, 0x86, 0x57,
0x0b, 0x28, 0x2f, 0xa3, 0xda, 0xd4, 0xe4, 0x0f, 0xa9, 0x27, 0x53, 0x04, 0x1b, 0xfc, 0xac, 0xe6,
0x7a, 0x07, 0xae, 0x63, 0xc5, 0xdb, 0xe2, 0xea, 0x94, 0x8b, 0xc4, 0xd5, 0x9d, 0xf8, 0x90, 0x6b,
0xb1, 0x0d, 0xd6, 0xeb, 0xc6, 0x0e, 0xcf, 0xad, 0x08, 0x4e, 0xd7, 0xe3, 0x5d, 0x50, 0x1e, 0xb3,
0x5b, 0x23, 0x38, 0x34, 0x68, 0x46, 0x03, 0x8c, 0xdd, 0x9c, 0x7d, 0xa0, 0xcd, 0x1a, 0x41, 0x1c];

// Obtain an unchecked result from a lookup table
out <-- inv[in];
// Compute the product of the input and output, expected to be 1
signal checkRes <== FieldMul()(in, out);
// For the special case when the input is 0, both input and output should be 0
signal isZeroIn <== IsZero()(in);
signal isZeroOut <== IsZero()(out);
signal checkZero <== isZeroIn * isZeroOut;
// Ensure that either the product is 1 or both input and output are 0, satisfying at least one condition
(1 - checkRes) * (1 - checkZero) === 0;
}

// AffineTransform required by the S-box computation.
template AffineTransform() {
signal input inBits[8];
signal output outBits[8];

var matrix[8][8] = [[1, 0, 0, 0, 1, 1, 1, 1],
[1, 1, 0, 0, 0, 1, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 1, 0, 0, 0],
[0, 1, 1, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 1, 1]];
var offset[8] = [1, 1, 0, 0, 0, 1, 1, 0];
for (var i = 0; i < 8; i++) {
var lc = 0;
for (var j = 0; j < 8; j++) {
if (matrix[i][j] == 1) {
lc += inBits[j];
}
}
lc += offset[i];
outBits[i] <== IsOdd(3)(lc);
}
}
26 changes: 11 additions & 15 deletions circuits/aes-gcm/aes/key_expansion.circom
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ include "utils.circom";
// @inputs key: array of nk*4 bytes representing the key
// @outputs keyExpanded: array of (nr+1)*4 words i.e for AES 128, 192, 256 it will be 44, 52, 60 words
template KeyExpansion(nk,nr) {
assert(nk == 4 || nk == 6 || nk == 8 );
assert(nk == 4 || nk == 6 || nk == 8 );
signal input key[nk * 4];

var totalWords = (4 * (nr + 1));
var effectiveRounds = nk == 4 ? 10 : totalWords\nk;

Expand All @@ -55,21 +55,19 @@ template KeyExpansion(nk,nr) {
keyExpanded[i][j] <== key[(4 * i) + j];
}
}

component nextRound[effectiveRounds];

for (var round = 1; round <= effectiveRounds; round++) {
var outputWordLen = round == effectiveRounds ? 4 : nk;
nextRound[round - 1] = NextRound(nk, outputWordLen);
var outputWordLen = round == effectiveRounds ? 4 : nk;
nextRound[round - 1] = NextRound(nk, outputWordLen, round);

for (var i = 0; i < nk; i++) {
for (var j = 0; j < 4; j++) {
nextRound[round - 1].key[i][j] <== keyExpanded[(round * nk) + i - nk][j];
}
}

nextRound[round - 1].round <== round;

for (var i = 0; i < outputWordLen; i++) {
for (var j = 0; j < 4; j++) {
keyExpanded[(round * nk) + i][j] <== nextRound[round - 1].nextKey[i][j];
Expand All @@ -80,22 +78,20 @@ template KeyExpansion(nk,nr) {

// @param nk: number of keys which can be 4, 6, 8
// @param o: number of output words which can be 4 or nk
template NextRound(nk, o){
signal input key[nk][4];
signal input round;
template NextRound(nk, o, round){
signal input key[nk][4];
signal output nextKey[o][4];

component rotateWord = Rotate(1, 4);
for (var i = 0; i < 4; i++) {
rotateWord.bytes[i] <== key[nk - 1][i];
}

component substituteWord[2];
substituteWord[0] = SubstituteWord();
substituteWord[0].bytes <== rotateWord.rotated;

component rcon = RCon();
rcon.round <== round;
component rcon = RCon(round);

component xorWord[o + 1];
xorWord[0] = XorWord();
Expand All @@ -114,7 +110,7 @@ template NextRound(nk, o){
xorWord[i+1].bytes1 <== nextKey[i-1];
}
xorWord[i+1].bytes2 <== key[i];

for (var j = 0; j < 4; j++) {
nextKey[i][j] <== xorWord[i+1].out[j];
}
Expand Down
Loading

0 comments on commit f0b1dcc

Please sign in to comment.