-
-
Notifications
You must be signed in to change notification settings - Fork 43
/
lib.rs
96 lines (81 loc) · 3.12 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
//! Constantine
//! Copyright (c) 2018-2019 Status Research & Development GmbH
//! Copyright (c) 2020-Present Mamy André-Ratsimbazafy
//! Licensed and distributed under either of
//! * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
//! * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
//! at your option. This file may not be copied, modified, or distributed except according to those terms.
//! Implementation of the ZK Accel Layer using Constantine as a backend
//! See https://github.com/privacy-scaling-explorations/halo2/issues/216
use constantine_core::Threadpool;
use constantine_sys::*;
use ::core::mem::MaybeUninit;
use std::mem;
use halo2_middleware::halo2curves::bn256;
use halo2_middleware::zal::traits::MsmAccel;
use halo2_middleware::halo2curves::CurveAffine;
/// ZK Accel Layer engine that runs multi-scalar multiplications on a
/// Constantine [`Threadpool`].
#[derive(Debug)]
pub struct CttEngine(Threadpool);

impl CttEngine {
    /// Builds an engine backed by a freshly created Constantine `Threadpool`
    /// configured with `num_threads`.
    ///
    /// NOTE(review): how `num_threads == 0` is interpreted is decided by
    /// `constantine_core::Threadpool::new` — confirm before relying on it.
    #[inline(always)]
    pub fn new(num_threads: usize) -> Self {
        let pool = Threadpool::new(num_threads);
        CttEngine(pool)
    }
}
/// Scalar-side descriptor handed out by the caching API of `MsmAccel`.
///
/// Holds a borrowed slice of the curve's scalar-field elements for lifetime `'c`.
#[derive(Clone, Copy, Debug)]
pub struct CttMsmCoeffsDesc<'c, C>
where
    C: CurveAffine,
{
    raw: &'c [C::Scalar],
}
/// Base-point-side descriptor handed out by the caching API of `MsmAccel`.
///
/// Holds a borrowed slice of affine curve points for lifetime `'b`.
#[derive(Clone, Copy, Debug)]
pub struct CttMsmBaseDesc<'b, C>
where
    C: CurveAffine,
{
    raw: &'b [C],
}
impl MsmAccel<bn256::G1Affine> for CttEngine {
    /// Computes the multi-scalar multiplication `Σ coeffs[i] · bases[i]` over
    /// BN254 G1 using Constantine's variable-time parallel MSM on this
    /// engine's threadpool.
    ///
    /// # Panics
    /// Panics if `coeffs` and `bases` have different lengths.
    fn msm(&self, coeffs: &[bn256::Fr], bases: &[bn256::G1Affine]) -> bn256::G1 {
        // The inputs are handed to Constantine by reinterpreting halo2curves
        // element pointers as Constantine element pointers. Unlike
        // `mem::transmute`, raw pointer casts get no compile-time size check,
        // so enforce matching element sizes here.
        const _: () = assert!(
            mem::size_of::<bn256::Fr>() == mem::size_of::<bn254_snarks_fr>(),
            "bn256::Fr and bn254_snarks_fr must have the same size"
        );
        const _: () = assert!(
            mem::size_of::<bn256::G1Affine>() == mem::size_of::<bn254_snarks_g1_aff>(),
            "bn256::G1Affine and bn254_snarks_g1_aff must have the same size"
        );

        assert_eq!(
            coeffs.len(),
            bases.len(),
            "msm: coeffs and bases must have the same length"
        );

        let mut result = MaybeUninit::<bn254_snarks_g1_prj>::uninit();

        // SAFETY: `coeffs` and `bases` are live slices of exactly `bases.len()`
        // elements (lengths asserted equal above), and each element has the
        // same size as the Constantine type it is cast to (const-checked above).
        // This relies on halo2curves and Constantine sharing the same in-memory
        // representation (limb order and Montgomery form) for these types —
        // size is checked, representation is not. `result.as_mut_ptr()` is a
        // valid, writable out-pointer.
        unsafe {
            ctt_bn254_snarks_g1_prj_multi_scalar_mul_fr_coefs_vartime_parallel(
                self.0.get_private_context(),
                result.as_mut_ptr(),
                coeffs.as_ptr() as *const bn254_snarks_fr,
                bases.as_ptr() as *const bn254_snarks_g1_aff,
                bases.len(),
            );
        }

        // SAFETY: assumes the Constantine MSM fully writes its `r` out-parameter
        // (its out-parameter convention), so `result` is initialized. The
        // projective point is reinterpreted as `bn256::G1` under the same
        // shared-representation assumption as above. `transmute` itself
        // statically enforces that both types have the same size.
        unsafe { mem::transmute::<bn254_snarks_g1_prj, bn256::G1>(result.assume_init()) }
    }

    // Caching API
    // -------------------------------------------------
    // Descriptors currently just borrow the caller's data; every cached
    // variant therefore forwards straight to `msm`.

    type CoeffsDescriptor<'c> = CttMsmCoeffsDesc<'c, bn256::G1Affine>;
    type BaseDescriptor<'b> = CttMsmBaseDesc<'b, bn256::G1Affine>;

    /// Wraps `coeffs` in a descriptor for reuse across MSMs.
    ///
    /// Hook for backend-specific preprocessing; this backend performs none
    /// and only borrows the slice.
    fn get_coeffs_descriptor<'c>(&self, coeffs: &'c [bn256::Fr]) -> Self::CoeffsDescriptor<'c> {
        Self::CoeffsDescriptor { raw: coeffs }
    }

    /// Wraps `base` in a descriptor for reuse across MSMs.
    ///
    /// Hook for backend-specific preprocessing; this backend performs none
    /// and only borrows the slice.
    fn get_base_descriptor<'b>(&self, base: &'b [bn256::G1Affine]) -> Self::BaseDescriptor<'b> {
        Self::BaseDescriptor { raw: base }
    }

    /// MSM with previously described scalars and plain bases.
    ///
    /// # Panics
    /// Panics if the described scalars and `base` differ in length.
    fn msm_with_cached_scalars(
        &self,
        coeffs: &Self::CoeffsDescriptor<'_>,
        base: &[bn256::G1Affine],
    ) -> bn256::G1 {
        self.msm(coeffs.raw, base)
    }

    /// MSM with plain scalars and previously described bases.
    ///
    /// # Panics
    /// Panics if `coeffs` and the described bases differ in length.
    fn msm_with_cached_base(
        &self,
        coeffs: &[bn256::Fr],
        base: &Self::BaseDescriptor<'_>,
    ) -> bn256::G1 {
        self.msm(coeffs, base.raw)
    }

    /// MSM where both scalars and bases were previously described.
    ///
    /// # Panics
    /// Panics if the described scalars and bases differ in length.
    fn msm_with_cached_inputs(
        &self,
        coeffs: &Self::CoeffsDescriptor<'_>,
        base: &Self::BaseDescriptor<'_>,
    ) -> bn256::G1 {
        self.msm(coeffs.raw, base.raw)
    }
}