Skip to content

Commit

Permalink
crs: first version of optimized pedersen hash
Browse files Browse the repository at this point in the history
Signed-off-by: Ignacio Hagopian <jsign.uy@gmail.com>
  • Loading branch information
jsign committed Sep 21, 2023
1 parent d4ee759 commit a828a95
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/crs/crs.zig
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ pub const CRS = struct {
};

fn deserialize_vkt_points() [DomainSize]Element {
var points: [crs_points.len]Element = undefined;
for (crs_points, 0..) |serialized_point, i| {
var points: [vkt_crs_points.len]Element = undefined;
for (vkt_crs_points, 0..) |serialized_point, i| {
var g_be_bytes: [32]u8 = undefined;
_ = std.fmt.hexToBytes(&g_be_bytes, serialized_point) catch unreachable;
points[i] = Element.fromBytes(g_be_bytes) catch unreachable;
Expand All @@ -47,12 +47,12 @@ fn deserialize_vkt_points() [DomainSize]Element {

test "crs is consistent" {
const crs = CRS.init();
try std.testing.expect(crs.Gs.len == crs_points.len);
try std.testing.expect(crs.Gs.len == vkt_crs_points.len);

// Reserialize Gs points and check they match with the original representation.
for (crs.Gs, 0..) |g, i| {
const got_point = std.fmt.bytesToHex(g.toBytes(), std.fmt.Case.lower);
const expected_point = crs_points[i];
const expected_point = vkt_crs_points[i];
try std.testing.expect(std.mem.eql(u8, &got_point, expected_point));
}

Expand All @@ -74,7 +74,7 @@ test "Gs cannot contain the generator" {
}
}

const crs_points = [_][]const u8{
const vkt_crs_points = [_][]const u8{
"01587ad1336675eb912550ec2a28eb8923b824b490dd2ba82e48f14590a298a0",
"6c6e607df0723edfff382fa914bfc38136f3300ab2e06fb97007b559fd323b82",
"326be3bebfd97ed9d0d4ca1b8bc47e036a24b129f1488110b71c2cae1463db8f",
Expand Down Expand Up @@ -332,3 +332,7 @@ const crs_points = [_][]const u8{
"3102a5884d3dce8d94a8cf6d5ab2d3a4c76ec8b00f4554caa68c028aedf5970f",
"3de2be346b539395b0c0de56a5ccca54a317f1b5c80107b0802af9a62276a4d8",
};

test "msm" {
_ = @import("msm.zig");
}
115 changes: 115 additions & 0 deletions src/crs/msm.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
const std = @import("std");
const Allocator = std.mem.Allocator;
const banderwagon = @import("../banderwagon/banderwagon.zig");
const Element = banderwagon.Element;
const Fr = banderwagon.Fr;

const PrecompMSM = struct {
allocator: Allocator,
b: usize,
basis_len: usize,
table: []const Element,

pub fn init(allocator: Allocator, basis: []const Element, b: usize) !PrecompMSM {
std.debug.assert(basis.len % b == 0);

const window_size = std.math.shl(usize, 1, b);
const num_windows = basis.len / b;
const table_num_elements = window_size * num_windows;
var table = try allocator.alloc(Element, table_num_elements);

for (0..num_windows) |w| {
const window_basis = basis[w * b .. (w + 1) * b];
fill_window(window_basis, table[w * window_size .. (w + 1) * window_size]);
}
return PrecompMSM{
.allocator = allocator,
.b = b,
.basis_len = basis.len,
.table = table,
};
}

pub fn deinit(self: PrecompMSM) void {
self.allocator.free(self.table);
}

pub fn msm(self: PrecompMSM, mont_scalars: []const Fr) !Element {
std.debug.assert(mont_scalars.len <= self.basis_len);

var scalars = try self.allocator.alloc(u256, mont_scalars.len);
defer self.allocator.free(scalars);
for (0..mont_scalars.len) |i| {
scalars[i] = mont_scalars[i].toInteger();
}

const window_size = std.math.shl(usize, 1, self.b);
const num_windows = self.basis_len / self.b;
var accum = Element.identity();
for (0..253) |k| {
accum.double(accum);
for (0..num_windows) |w| {
if (w * self.b < scalars.len) {
const window_scalars = scalars[w * self.b ..];
var table_idx: usize = 0;
for (0..self.b) |i| {
table_idx <<= 1;
if (i < window_scalars.len) {
table_idx |= @as(u1, @truncate((window_scalars[i] >> @as(u8, @intCast(252 - k)))));
}
}
const window_table = self.table[w * window_size .. (w + 1) * window_size];
accum.add(accum, window_table[table_idx]);
}
}
}

return accum;
}

fn fill_window(basis: []const Element, table: []Element) void {
if (basis.len == 0) {
table[0] = Element.identity();
return;
}
fill_window(basis[1..], table[0 .. table.len / 2]);
for (0..table.len / 2) |i| {
table[table.len / 2 + i].add(table[i], basis[0]);
}
}
};

test "correctness" {
const crs = @import("crs.zig");
const CRS = crs.CRS.init();
var test_allocator = std.testing.allocator;

const precomp = try PrecompMSM.init(test_allocator, &CRS.Gs, 8);
defer precomp.deinit();

var scalars: [crs.DomainSize]Fr = undefined;
for (0..scalars.len) |i| {
scalars[i] = Fr.fromInteger(i + 0x424242);
}

for (1..crs.DomainSize) |msm_length| {
const msm_scalars = scalars[0..msm_length];

var full_scalars: [crs.DomainSize]Fr = undefined;
for (0..full_scalars.len) |i| {
if (i < msm_length) {
full_scalars[i] = msm_scalars[i];
continue;
}
full_scalars[i] = Fr.zero();
}
std.debug.print("For {} ...", .{msm_length});
const exp = CRS.commit(full_scalars);
std.debug.print("ok", .{});

const got = try precomp.msm(msm_scalars);
std.debug.print(" ok\n", .{});

try std.testing.expect(Element.equal(exp, got));
}
}

0 comments on commit a828a95

Please sign in to comment.