diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1f8ebd69f..af8f024b4 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,11 +20,27 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + ssh-key: ${{ secrets.DEPLOY_KEY }} + - name: Setup Cache + id: cache + uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + key: ${{ runner.os }}-cargo-${{ hashFiles('~/.cargo/bin/cargo-workspaces') }} + - name: Install cargo-workspaces + if: steps.cache.outputs.cache-hit != 'true' + run: cargo install cargo-workspaces - name: Bump rust crate versions, commit, and tag working-directory: wrappers/rust # https://github.com/pksunkara/cargo-workspaces?tab=readme-ov-file#version run: | - cargo install cargo-workspaces + git config user.name release-bot + git config user.email release-bot@ingonyama.com cargo workspaces version ${{ inputs.releaseType }} -y --no-individual-tags -m "Bump rust crates' version" - name: Create draft release env: diff --git a/Dockerfile b/Dockerfile index 38ce58675..a97b9f59b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,7 @@ ENV PATH="/root/.cargo/bin:${PATH}" # Install Golang ENV GOLANG_VERSION 1.21.1 -RUN curl -L https://golang.org/dl/go${GOLANG_VERSION}.linux-amd64.tar.gz | tar -xz -C /usr/local +RUN curl -L https://go.dev/dl/go${GOLANG_VERSION}.linux-amd64.tar.gz | tar -xz -C /usr/local ENV PATH="/usr/local/go/bin:${PATH}" # Set the working directory in the container diff --git a/wrappers/golang/core/msm.go b/wrappers/golang/core/msm.go index 9606c856c..6229cf805 100644 --- a/wrappers/golang/core/msm.go +++ b/wrappers/golang/core/msm.go @@ -76,7 +76,7 @@ func GetDefaultMSMConfig() MSMConfig { } func MsmCheck(scalars HostOrDeviceSlice, points HostOrDeviceSlice, cfg *MSMConfig, results HostOrDeviceSlice) { - scalarsLength, pointsLength, resultsLength := scalars.Len(), points.Len(), results.Len() + scalarsLength, pointsLength, resultsLength := scalars.Len(), points.Len()/int(cfg.PrecomputeFactor), results.Len() if scalarsLength%pointsLength != 0 { errorString := fmt.Sprintf( "Number of points %d does not divide the number of scalars %d", @@ -99,3 +99,15 @@ func MsmCheck(scalars HostOrDeviceSlice, points HostOrDeviceSlice, cfg *MSMConfi cfg.arePointsOnDevice = points.IsOnDevice() cfg.areResultsOnDevice = results.IsOnDevice() } + +func PrecomputeBasesCheck(points HostOrDeviceSlice, precomputeFactor int32, outputBases DeviceSlice) { + outputBasesLength, pointsLength := outputBases.Len(), points.Len() + if outputBasesLength != pointsLength*int(precomputeFactor) { + errorString := fmt.Sprintf( + "Precompute factor is probably incorrect: expected %d but got %d", + outputBasesLength/pointsLength, + precomputeFactor, + ) + panic(errorString) + } +} diff --git a/wrappers/golang/curves/bls12377/g2_msm.go b/wrappers/golang/curves/bls12377/g2_msm.go index 28fb546b2..1380ab6fe 100644 --- a/wrappers/golang/curves/bls12377/g2_msm.go +++ b/wrappers/golang/curves/bls12377/g2_msm.go @@ -49,3 +49,28 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c err := (cr.CudaError)(__ret) return err } + +func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0]) + } + cPoints := (*C.g2_affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) + + outputBasesPointer := outputBases.AsPointer() + cOutputBases := (*C.g2_affine_t)(outputBasesPointer) + + __ret := C.bls12_377G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bls12377/g2_msm_test.go b/wrappers/golang/curves/bls12377/g2_msm_test.go index eb8432d64..a02afc57d 100644 --- a/wrappers/golang/curves/bls12377/g2_msm_test.go +++ b/wrappers/golang/curves/bls12377/g2_msm_test.go @@ -136,6 +136,49 @@ func TestMSMG2Batch(t *testing.T) { } } +func TestPrecomputeBaseG2(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := G2GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p G2Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = G2Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[G2Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMG2SkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bls12377/include/msm.h b/wrappers/golang/curves/bls12377/include/msm.h index cf84172e8..e1cfec67c 100644 --- a/wrappers/golang/curves/bls12377/include/msm.h +++ b/wrappers/golang/curves/bls12377/include/msm.h @@ -1,5 +1,6 @@ #include #include "../../include/types.h" +#include #ifndef _BLS12_377_MSM_H #define _BLS12_377_MSM_H @@ -9,6 +10,7 @@ extern "C" { #endif cudaError_t bls12_377MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out); +cudaError_t bls12_377PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bls12377/msm.go b/wrappers/golang/curves/bls12377/msm.go index bad05683e..4af06a31d 100644 --- a/wrappers/golang/curves/bls12377/msm.go +++ b/wrappers/golang/curves/bls12377/msm.go @@ -47,3 +47,28 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor err := (cr.CudaError)(__ret) return err } + +func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0]) + } + cPoints := (*C.affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) + + outputBasesPointer := outputBases.AsPointer() + cOutputBases := (*C.affine_t)(outputBasesPointer) + + __ret := C.bls12_377PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bls12377/msm_test.go b/wrappers/golang/curves/bls12377/msm_test.go index 5c41051fa..88fa03619 100644 --- a/wrappers/golang/curves/bls12377/msm_test.go +++ b/wrappers/golang/curves/bls12377/msm_test.go @@ -107,6 +107,49 @@ func TestMSMBatch(t *testing.T) { } } +func TestPrecomputeBase(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMSkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bls12381/g2_msm.go b/wrappers/golang/curves/bls12381/g2_msm.go index 4ecae93c3..1a82c0e4d 100644 --- a/wrappers/golang/curves/bls12381/g2_msm.go +++ b/wrappers/golang/curves/bls12381/g2_msm.go @@ -49,3 +49,28 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c err := (cr.CudaError)(__ret) return err } + +func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0]) + } + cPoints := (*C.g2_affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) + + outputBasesPointer := outputBases.AsPointer() + cOutputBases := (*C.g2_affine_t)(outputBasesPointer) + + __ret := C.bls12_381G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bls12381/g2_msm_test.go b/wrappers/golang/curves/bls12381/g2_msm_test.go index 9ce775127..75abb4507 100644 --- a/wrappers/golang/curves/bls12381/g2_msm_test.go +++ b/wrappers/golang/curves/bls12381/g2_msm_test.go @@ -136,6 +136,49 @@ func TestMSMG2Batch(t *testing.T) { } } +func TestPrecomputeBaseG2(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := G2GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p G2Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = G2Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[G2Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMG2SkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bls12381/include/msm.h b/wrappers/golang/curves/bls12381/include/msm.h index dc21a96a7..6f950eddb 100644 --- a/wrappers/golang/curves/bls12381/include/msm.h +++ b/wrappers/golang/curves/bls12381/include/msm.h @@ -1,5 +1,6 @@ #include #include "../../include/types.h" +#include #ifndef _BLS12_381_MSM_H #define _BLS12_381_MSM_H @@ -9,6 +10,7 @@ extern "C" { #endif cudaError_t bls12_381MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out); +cudaError_t bls12_381PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bls12381/msm.go b/wrappers/golang/curves/bls12381/msm.go index 10d9f2860..bd51db0cd 100644 --- a/wrappers/golang/curves/bls12381/msm.go +++ b/wrappers/golang/curves/bls12381/msm.go @@ -47,3 +47,28 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor err := (cr.CudaError)(__ret) return err } + +func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0]) + } + cPoints := (*C.affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) + + outputBasesPointer := outputBases.AsPointer() + cOutputBases := (*C.affine_t)(outputBasesPointer) + + __ret := C.bls12_381PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bls12381/msm_test.go b/wrappers/golang/curves/bls12381/msm_test.go index 7f2d49358..7c05038b5 100644 --- a/wrappers/golang/curves/bls12381/msm_test.go +++ b/wrappers/golang/curves/bls12381/msm_test.go @@ -107,6 +107,49 @@ func TestMSMBatch(t *testing.T) { } } +func TestPrecomputeBase(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMSkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bn254/g2_msm.go b/wrappers/golang/curves/bn254/g2_msm.go index b8a9a2b78..1dc546525 100644 --- a/wrappers/golang/curves/bn254/g2_msm.go +++ b/wrappers/golang/curves/bn254/g2_msm.go @@ -49,3 +49,28 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c err := (cr.CudaError)(__ret) return err } + +func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0]) + } + cPoints := (*C.g2_affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) + + outputBasesPointer := outputBases.AsPointer() + cOutputBases := (*C.g2_affine_t)(outputBasesPointer) + + __ret := C.bn254G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bn254/g2_msm_test.go b/wrappers/golang/curves/bn254/g2_msm_test.go index c01b3e745..d1f49040b 100644 --- a/wrappers/golang/curves/bn254/g2_msm_test.go +++ b/wrappers/golang/curves/bn254/g2_msm_test.go @@ -136,6 +136,49 @@ func TestMSMG2Batch(t *testing.T) { } } +func TestPrecomputeBaseG2(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := G2GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p G2Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = G2Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[G2Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMG2SkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bn254/include/msm.h b/wrappers/golang/curves/bn254/include/msm.h index 664032aa0..67a2c8c5e 100644 --- a/wrappers/golang/curves/bn254/include/msm.h +++ b/wrappers/golang/curves/bn254/include/msm.h @@ -1,5 +1,6 @@ #include #include "../../include/types.h" +#include #ifndef _BN254_MSM_H #define _BN254_MSM_H @@ -9,6 +10,7 @@ extern "C" { #endif cudaError_t bn254MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out); +cudaError_t bn254PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bn254/msm.go b/wrappers/golang/curves/bn254/msm.go index 5ffea6126..1d773d643 100644 --- a/wrappers/golang/curves/bn254/msm.go +++ b/wrappers/golang/curves/bn254/msm.go @@ -47,3 +47,28 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor err := (cr.CudaError)(__ret) return err } + +func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0]) + } + cPoints := (*C.affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) + + outputBasesPointer := outputBases.AsPointer() + cOutputBases := (*C.affine_t)(outputBasesPointer) + + __ret := C.bn254PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bn254/msm_test.go b/wrappers/golang/curves/bn254/msm_test.go index 39834dfbd..37b5799ec 100644 --- a/wrappers/golang/curves/bn254/msm_test.go +++ b/wrappers/golang/curves/bn254/msm_test.go @@ -107,6 +107,49 @@ func TestMSMBatch(t *testing.T) { } } +func TestPrecomputeBase(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMSkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bw6761/g2_msm.go b/wrappers/golang/curves/bw6761/g2_msm.go index 8d9a320ad..7e22e32ac 100644 --- a/wrappers/golang/curves/bw6761/g2_msm.go +++ b/wrappers/golang/curves/bw6761/g2_msm.go @@ -49,3 +49,28 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c err := (cr.CudaError)(__ret) return err } + +func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0]) + } + cPoints := (*C.g2_affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) + + outputBasesPointer := outputBases.AsPointer() + cOutputBases := (*C.g2_affine_t)(outputBasesPointer) + + __ret := C.bw6_761G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bw6761/g2_msm_test.go b/wrappers/golang/curves/bw6761/g2_msm_test.go index 388e7dbc7..bca18c7ff 100644 --- a/wrappers/golang/curves/bw6761/g2_msm_test.go +++ b/wrappers/golang/curves/bw6761/g2_msm_test.go @@ -109,6 +109,49 @@ func TestMSMG2Batch(t *testing.T) { } } +func TestPrecomputeBaseG2(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := G2GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p G2Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = G2Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[G2Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMG2SkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bw6761/include/msm.h b/wrappers/golang/curves/bw6761/include/msm.h index 8101cb5c0..166f0ee8b 100644 --- a/wrappers/golang/curves/bw6761/include/msm.h +++ b/wrappers/golang/curves/bw6761/include/msm.h @@ -1,5 +1,6 @@ #include #include "../../include/types.h" +#include #ifndef _BW6_761_MSM_H #define _BW6_761_MSM_H @@ -9,6 +10,7 @@ extern "C" { #endif cudaError_t bw6_761MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out); +cudaError_t bw6_761PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bw6761/msm.go b/wrappers/golang/curves/bw6761/msm.go index 9bd2a8ce8..a5146755d 100644 --- a/wrappers/golang/curves/bw6761/msm.go +++ b/wrappers/golang/curves/bw6761/msm.go @@ -47,3 +47,28 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor err := (cr.CudaError)(__ret) return err } + +func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0]) + } + cPoints := (*C.affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) + + outputBasesPointer := outputBases.AsPointer() + cOutputBases := (*C.affine_t)(outputBasesPointer) + + __ret := C.bw6_761PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bw6761/msm_test.go b/wrappers/golang/curves/bw6761/msm_test.go index 415f88cbe..302b93ccc 100644 --- a/wrappers/golang/curves/bw6761/msm_test.go +++ b/wrappers/golang/curves/bw6761/msm_test.go @@ -107,6 +107,49 @@ func TestMSMBatch(t *testing.T) { } } +func TestPrecomputeBase(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMSkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/internal/generator/templates/include/msm.h.tmpl b/wrappers/golang/internal/generator/templates/include/msm.h.tmpl index 82acaf8fa..9154512fc 100644 --- a/wrappers/golang/internal/generator/templates/include/msm.h.tmpl +++ b/wrappers/golang/internal/generator/templates/include/msm.h.tmpl @@ -1,5 +1,6 @@ #include #include "../../include/types.h" +#include #ifndef _{{toUpper .Curve}}_MSM_H #define _{{toUpper .Curve}}_MSM_H @@ -9,6 +10,7 @@ extern "C" { #endif cudaError_t {{.Curve}}MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out); +cudaError_t {{.Curve}}PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases); #ifdef __cplusplus } diff --git a/wrappers/golang/internal/generator/templates/msm.go.tmpl b/wrappers/golang/internal/generator/templates/msm.go.tmpl index f86f4b02b..80e79f7e4 100644 --- a/wrappers/golang/internal/generator/templates/msm.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm.go.tmpl @@ -51,3 +51,28 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr err := (cr.CudaError)(__ret) return err } + +func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[{{if .IsG2}}G2{{end}}Affine])[0]) + } + cPoints := (*C.{{if .IsG2}}g2_{{end}}affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) + + outputBasesPointer := outputBases.AsPointer() + cOutputBases := (*C.{{if .IsG2}}g2_{{end}}affine_t)(outputBasesPointer) + + __ret := C.{{.Curve}}{{if .IsG2}}G2{{end}}PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/internal/generator/templates/msm_test.go.tmpl b/wrappers/golang/internal/generator/templates/msm_test.go.tmpl index 9c3bbd200..e7c127948 100644 --- a/wrappers/golang/internal/generator/templates/msm_test.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm_test.go.tmpl @@ -156,6 +156,49 @@ func TestMSM{{if .IsG2}}G2{{end}}Batch(t *testing.T) { } } +func TestPrecomputeBase{{if .IsG2}}G2{{end}}(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := {{if .IsG2}}G2{{end}}GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = {{if .IsG2}}G2{{end}}PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p {{if .IsG2}}G2{{end}}Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = {{if .IsG2}}G2{{end}}Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[{{if .IsG2}}G2{{end}}Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSM{{if .IsG2}}G2{{end}}SkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {