From fa219d9c95a2233b8b01878ea861155fa806991a Mon Sep 17 00:00:00 2001 From: Jeremy Felder Date: Sun, 10 Mar 2024 08:57:35 +0200 Subject: [PATCH 1/8] Fix release flow with deploy key and caching (#425) ## Describe the changes This PR fixes the release flow action --- .github/workflows/release.yml | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1f8ebd69f..af8f024b4 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,11 +20,27 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + ssh-key: ${{ secrets.DEPLOY_KEY }} + - name: Setup Cache + id: cache + uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + key: ${{ runner.os }}-cargo-${{ hashFiles('~/.cargo/bin/cargo-workspaces') }} + - name: Install cargo-workspaces + if: steps.cache.outputs.cache-hit != 'true' + run: cargo install cargo-workspaces - name: Bump rust crate versions, commit, and tag working-directory: wrappers/rust # https://github.com/pksunkara/cargo-workspaces?tab=readme-ov-file#version run: | - cargo install cargo-workspaces + git config user.name release-bot + git config user.email release-bot@ingonyama.com cargo workspaces version ${{ inputs.releaseType }} -y --no-individual-tags -m "Bump rust crates' version" - name: Create draft release env: From 08ec0b1ff6483d61aaad184a25438c74a9e8f19d Mon Sep 17 00:00:00 2001 From: hhh_QC <52317293+cyl19970726@users.noreply.github.com> Date: Sun, 10 Mar 2024 16:47:08 +0800 Subject: [PATCH 2/8] update go install source in Dockerfile (#428) --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 38ce58675..a97b9f59b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,7 @@ ENV PATH="/root/.cargo/bin:${PATH}" # Install Golang ENV GOLANG_VERSION 1.21.1 -RUN curl -L https://golang.org/dl/go${GOLANG_VERSION}.linux-amd64.tar.gz | tar -xz -C /usr/local +RUN curl -L https://go.dev/dl/go${GOLANG_VERSION}.linux-amd64.tar.gz | tar -xz -C /usr/local ENV PATH="/usr/local/go/bin:${PATH}" # Set the working directory in the container From b83def1a3e02f1f51e45e2a87cf3c945a66d2cb7 Mon Sep 17 00:00:00 2001 From: nonam3e Date: Mon, 11 Mar 2024 17:42:15 +0000 Subject: [PATCH 3/8] precompute bases template init --- wrappers/golang/core/msm.go | 16 ++++++++++ .../internal/generator/templates/msm.go.tmpl | 30 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/wrappers/golang/core/msm.go b/wrappers/golang/core/msm.go index 9606c856c..b85e70942 100644 --- a/wrappers/golang/core/msm.go +++ b/wrappers/golang/core/msm.go @@ -99,3 +99,19 @@ func MsmCheck(scalars HostOrDeviceSlice, points HostOrDeviceSlice, cfg *MSMConfi cfg.arePointsOnDevice = points.IsOnDevice() cfg.areResultsOnDevice = results.IsOnDevice() } + +func PrecomputeBasesCheck(points HostOrDeviceSlice, precompute_factor int32, output_bases HostOrDeviceSlice) { + outputBasesLength, pointsLength := output_bases.Len(), points.Len() + if outputBasesLength != pointsLength*int(precompute_factor) { + errorString := fmt.Sprintf( + "Precompute factor is probably incorrect: expected %d but got %d", + outputBasesLength/pointsLength, + precompute_factor, + ) + panic(errorString) + } + + if !output_bases.IsOnDevice() { + panic("Output bases are not on device") + } +} diff --git a/wrappers/golang/internal/generator/templates/msm.go.tmpl b/wrappers/golang/internal/generator/templates/msm.go.tmpl index f86f4b02b..ae49dbcd3 100644 --- a/wrappers/golang/internal/generator/templates/msm.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm.go.tmpl @@ -51,3 +51,33 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr err := (cr.CudaError)(__ret) return err } + +func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precompute_factor, output_bases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[{{if .IsG2}}G2{{end}}Affine])[0]) + } + cPoints := (*C.{{if .IsG2}}g2_{{end}}affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precompute_factor) + c_C := (C.int)(_c) + cPointsIsOnDevice = (C._Bool)(points.IsOnDevice()) + cCtx := (*C.MSMConfig)(unsafe.Pointer(ctx)) + + var outputBasesPointer unsafe.Pointer + if output_bases.IsOnDevice() { + outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() + } else { + outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[{{if .IsG2}}G2{{end}}Projective])[0]) + } + cOutputBases = (*C.{{if .IsG2}}g2_{{end}}projective_t)(outputBasesPointer) + + __ret := C.{{.Curve}}{{if .IsG2}}G2{{end}}PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} From e89dabb91765bbf4e2fa16723336787cde9c2fc9 Mon Sep 17 00:00:00 2001 From: nonam3e Date: Tue, 12 Mar 2024 08:59:50 +0000 Subject: [PATCH 4/8] precompute bases template --- wrappers/golang/internal/generator/templates/include/msm.h.tmpl | 1 + wrappers/golang/internal/generator/templates/msm.go.tmpl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/wrappers/golang/internal/generator/templates/include/msm.h.tmpl b/wrappers/golang/internal/generator/templates/include/msm.h.tmpl index 82acaf8fa..699a4d6d2 100644 --- a/wrappers/golang/internal/generator/templates/include/msm.h.tmpl +++ b/wrappers/golang/internal/generator/templates/include/msm.h.tmpl @@ -9,6 +9,7 @@ extern "C" { #endif cudaError_t {{.Curve}}MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out); +cudaError_t {{.Curve}}PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases); #ifdef __cplusplus } diff --git a/wrappers/golang/internal/generator/templates/msm.go.tmpl b/wrappers/golang/internal/generator/templates/msm.go.tmpl index ae49dbcd3..2f03d23f6 100644 --- a/wrappers/golang/internal/generator/templates/msm.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm.go.tmpl @@ -67,7 +67,7 @@ func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precomp cPrecomputeFactor := (C.int)(precompute_factor) c_C := (C.int)(_c) cPointsIsOnDevice = (C._Bool)(points.IsOnDevice()) - cCtx := (*C.MSMConfig)(unsafe.Pointer(ctx)) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) var outputBasesPointer unsafe.Pointer if output_bases.IsOnDevice() { From 739c1fcaf9cb01edbab75f45c71ac2691afda092 Mon Sep 17 00:00:00 2001 From: nonam3e Date: Tue, 12 Mar 2024 09:27:25 +0000 Subject: [PATCH 5/8] precompute bases template --- wrappers/golang/internal/generator/templates/include/msm.h.tmpl | 1 + wrappers/golang/internal/generator/templates/msm.go.tmpl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/wrappers/golang/internal/generator/templates/include/msm.h.tmpl b/wrappers/golang/internal/generator/templates/include/msm.h.tmpl index 699a4d6d2..9154512fc 100644 --- a/wrappers/golang/internal/generator/templates/include/msm.h.tmpl +++ b/wrappers/golang/internal/generator/templates/include/msm.h.tmpl @@ -1,5 +1,6 @@ #include #include "../../include/types.h" +#include #ifndef _{{toUpper .Curve}}_MSM_H #define _{{toUpper .Curve}}_MSM_H diff --git a/wrappers/golang/internal/generator/templates/msm.go.tmpl b/wrappers/golang/internal/generator/templates/msm.go.tmpl index 2f03d23f6..bf6bea35f 100644 --- a/wrappers/golang/internal/generator/templates/msm.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm.go.tmpl @@ -67,7 +67,7 @@ func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precomp cPrecomputeFactor := (C.int)(precompute_factor) c_C := (C.int)(_c) cPointsIsOnDevice = (C._Bool)(points.IsOnDevice()) - cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) + cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) var outputBasesPointer unsafe.Pointer if output_bases.IsOnDevice() { From a3a5d36424f68e07a6c07e5f5de3623c4f8ea678 Mon Sep 17 00:00:00 2001 From: nonam3e Date: Tue, 12 Mar 2024 10:02:00 +0000 Subject: [PATCH 6/8] precompute bases curve generated --- wrappers/golang/curves/bls12377/g2_msm.go | 30 +++++++++++++++++++ wrappers/golang/curves/bls12377/include/msm.h | 2 ++ wrappers/golang/curves/bls12377/msm.go | 30 +++++++++++++++++++ wrappers/golang/curves/bls12381/g2_msm.go | 30 +++++++++++++++++++ wrappers/golang/curves/bls12381/include/msm.h | 2 ++ wrappers/golang/curves/bls12381/msm.go | 30 +++++++++++++++++++ wrappers/golang/curves/bn254/g2_msm.go | 30 +++++++++++++++++++ wrappers/golang/curves/bn254/include/msm.h | 2 ++ wrappers/golang/curves/bn254/msm.go | 30 +++++++++++++++++++ wrappers/golang/curves/bw6761/g2_msm.go | 30 +++++++++++++++++++ wrappers/golang/curves/bw6761/include/msm.h | 2 ++ wrappers/golang/curves/bw6761/msm.go | 30 +++++++++++++++++++ .../internal/generator/templates/msm.go.tmpl | 4 +-- 13 files changed, 250 insertions(+), 2 deletions(-) diff --git a/wrappers/golang/curves/bls12377/g2_msm.go b/wrappers/golang/curves/bls12377/g2_msm.go index 28fb546b2..f3b32c93f 100644 --- a/wrappers/golang/curves/bls12377/g2_msm.go +++ b/wrappers/golang/curves/bls12377/g2_msm.go @@ -49,3 +49,33 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c err := (cr.CudaError)(__ret) return err } + +func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precompute_factor, output_bases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0]) + } + cPoints := (*C.g2_affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precompute_factor) + c_C := (C.int)(_c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + + var outputBasesPointer unsafe.Pointer + if output_bases.IsOnDevice() { + outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() + } else { + outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[G2Projective])[0]) + } + cOutputBases := (*C.g2_affine_t)(outputBasesPointer) + + __ret := C.bls12_377G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bls12377/include/msm.h b/wrappers/golang/curves/bls12377/include/msm.h index cf84172e8..e1cfec67c 100644 --- a/wrappers/golang/curves/bls12377/include/msm.h +++ b/wrappers/golang/curves/bls12377/include/msm.h @@ -1,5 +1,6 @@ #include #include "../../include/types.h" +#include #ifndef _BLS12_377_MSM_H #define _BLS12_377_MSM_H @@ -9,6 +10,7 @@ extern "C" { #endif cudaError_t bls12_377MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out); +cudaError_t bls12_377PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bls12377/msm.go b/wrappers/golang/curves/bls12377/msm.go index bad05683e..286b9fb51 100644 --- a/wrappers/golang/curves/bls12377/msm.go +++ b/wrappers/golang/curves/bls12377/msm.go @@ -47,3 +47,33 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor err := (cr.CudaError)(__ret) return err } + +func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precompute_factor, output_bases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0]) + } + cPoints := (*C.affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precompute_factor) + c_C := (C.int)(_c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + + var outputBasesPointer unsafe.Pointer + if output_bases.IsOnDevice() { + outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() + } else { + outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[Projective])[0]) + } + cOutputBases := (*C.affine_t)(outputBasesPointer) + + __ret := C.bls12_377PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bls12381/g2_msm.go b/wrappers/golang/curves/bls12381/g2_msm.go index 4ecae93c3..b3eb01d65 100644 --- a/wrappers/golang/curves/bls12381/g2_msm.go +++ b/wrappers/golang/curves/bls12381/g2_msm.go @@ -49,3 +49,33 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c err := (cr.CudaError)(__ret) return err } + +func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precompute_factor, output_bases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0]) + } + cPoints := (*C.g2_affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precompute_factor) + c_C := (C.int)(_c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + + var outputBasesPointer unsafe.Pointer + if output_bases.IsOnDevice() { + outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() + } else { + outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[G2Projective])[0]) + } + cOutputBases := (*C.g2_affine_t)(outputBasesPointer) + + __ret := C.bls12_381G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bls12381/include/msm.h b/wrappers/golang/curves/bls12381/include/msm.h index dc21a96a7..6f950eddb 100644 --- a/wrappers/golang/curves/bls12381/include/msm.h +++ b/wrappers/golang/curves/bls12381/include/msm.h @@ -1,5 +1,6 @@ #include #include "../../include/types.h" +#include #ifndef _BLS12_381_MSM_H #define _BLS12_381_MSM_H @@ -9,6 +10,7 @@ extern "C" { #endif cudaError_t bls12_381MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out); +cudaError_t bls12_381PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bls12381/msm.go b/wrappers/golang/curves/bls12381/msm.go index 10d9f2860..8ab675190 100644 --- a/wrappers/golang/curves/bls12381/msm.go +++ b/wrappers/golang/curves/bls12381/msm.go @@ -47,3 +47,33 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor err := (cr.CudaError)(__ret) return err } + +func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precompute_factor, output_bases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0]) + } + cPoints := (*C.affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precompute_factor) + c_C := (C.int)(_c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + + var outputBasesPointer unsafe.Pointer + if output_bases.IsOnDevice() { + outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() + } else { + outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[Projective])[0]) + } + cOutputBases := (*C.affine_t)(outputBasesPointer) + + __ret := C.bls12_381PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bn254/g2_msm.go b/wrappers/golang/curves/bn254/g2_msm.go index b8a9a2b78..68602d5da 100644 --- a/wrappers/golang/curves/bn254/g2_msm.go +++ b/wrappers/golang/curves/bn254/g2_msm.go @@ -49,3 +49,33 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c err := (cr.CudaError)(__ret) return err } + +func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precompute_factor, output_bases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0]) + } + cPoints := (*C.g2_affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precompute_factor) + c_C := (C.int)(_c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + + var outputBasesPointer unsafe.Pointer + if output_bases.IsOnDevice() { + outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() + } else { + outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[G2Projective])[0]) + } + cOutputBases := (*C.g2_affine_t)(outputBasesPointer) + + __ret := C.bn254G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bn254/include/msm.h b/wrappers/golang/curves/bn254/include/msm.h index 664032aa0..67a2c8c5e 100644 --- a/wrappers/golang/curves/bn254/include/msm.h +++ b/wrappers/golang/curves/bn254/include/msm.h @@ -1,5 +1,6 @@ #include #include "../../include/types.h" +#include #ifndef _BN254_MSM_H #define _BN254_MSM_H @@ -9,6 +10,7 @@ extern "C" { #endif cudaError_t bn254MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out); +cudaError_t bn254PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bn254/msm.go b/wrappers/golang/curves/bn254/msm.go index 5ffea6126..8d9978661 100644 --- a/wrappers/golang/curves/bn254/msm.go +++ b/wrappers/golang/curves/bn254/msm.go @@ -47,3 +47,33 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor err := (cr.CudaError)(__ret) return err } + +func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precompute_factor, output_bases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0]) + } + cPoints := (*C.affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precompute_factor) + c_C := (C.int)(_c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + + var outputBasesPointer unsafe.Pointer + if output_bases.IsOnDevice() { + outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() + } else { + outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[Projective])[0]) + } + cOutputBases := (*C.affine_t)(outputBasesPointer) + + __ret := C.bn254PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bw6761/g2_msm.go b/wrappers/golang/curves/bw6761/g2_msm.go index 8d9a320ad..cc63e825b 100644 --- a/wrappers/golang/curves/bw6761/g2_msm.go +++ b/wrappers/golang/curves/bw6761/g2_msm.go @@ -49,3 +49,33 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c err := (cr.CudaError)(__ret) return err } + +func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precompute_factor, output_bases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0]) + } + cPoints := (*C.g2_affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precompute_factor) + c_C := (C.int)(_c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + + var outputBasesPointer unsafe.Pointer + if output_bases.IsOnDevice() { + outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() + } else { + outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[G2Projective])[0]) + } + cOutputBases := (*C.g2_affine_t)(outputBasesPointer) + + __ret := C.bw6_761G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/curves/bw6761/include/msm.h b/wrappers/golang/curves/bw6761/include/msm.h index 8101cb5c0..166f0ee8b 100644 --- a/wrappers/golang/curves/bw6761/include/msm.h +++ b/wrappers/golang/curves/bw6761/include/msm.h @@ -1,5 +1,6 @@ #include #include "../../include/types.h" +#include #ifndef _BW6_761_MSM_H #define _BW6_761_MSM_H @@ -9,6 +10,7 @@ extern "C" { #endif cudaError_t bw6_761MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out); +cudaError_t bw6_761PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bw6761/msm.go b/wrappers/golang/curves/bw6761/msm.go index 9bd2a8ce8..078ad8f66 100644 --- a/wrappers/golang/curves/bw6761/msm.go +++ b/wrappers/golang/curves/bw6761/msm.go @@ -47,3 +47,33 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor err := (cr.CudaError)(__ret) return err } + +func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precompute_factor, output_bases) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0]) + } + cPoints := (*C.affine_t)(pointsPointer) + + cPointsLen := (C.int)(points.Len()) + cPrecomputeFactor := (C.int)(precompute_factor) + c_C := (C.int)(_c) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) + cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + + var outputBasesPointer unsafe.Pointer + if output_bases.IsOnDevice() { + outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() + } else { + outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[Projective])[0]) + } + cOutputBases := (*C.affine_t)(outputBasesPointer) + + __ret := C.bw6_761PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + err := (cr.CudaError)(__ret) + return err +} diff --git a/wrappers/golang/internal/generator/templates/msm.go.tmpl b/wrappers/golang/internal/generator/templates/msm.go.tmpl index bf6bea35f..4465340a7 100644 --- a/wrappers/golang/internal/generator/templates/msm.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm.go.tmpl @@ -66,7 +66,7 @@ func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precomp cPointsLen := (C.int)(points.Len()) cPrecomputeFactor := (C.int)(precompute_factor) c_C := (C.int)(_c) - cPointsIsOnDevice = (C._Bool)(points.IsOnDevice()) + cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) var outputBasesPointer unsafe.Pointer @@ -75,7 +75,7 @@ func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precomp } else { outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[{{if .IsG2}}G2{{end}}Projective])[0]) } - cOutputBases = (*C.{{if .IsG2}}g2_{{end}}projective_t)(outputBasesPointer) + cOutputBases := (*C.{{if .IsG2}}g2_{{end}}affine_t)(outputBasesPointer) __ret := C.{{.Curve}}{{if .IsG2}}G2{{end}}PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) err := (cr.CudaError)(__ret) From 2003b30de6e23571067313ce9b1cadc5d64b645c Mon Sep 17 00:00:00 2001 From: nonam3e Date: Tue, 12 Mar 2024 13:07:04 +0000 Subject: [PATCH 7/8] review changes --- wrappers/golang/core/msm.go | 12 ++++-------- wrappers/golang/curves/bls12377/g2_msm.go | 19 ++++++++----------- wrappers/golang/curves/bls12377/msm.go | 19 ++++++++----------- wrappers/golang/curves/bls12381/g2_msm.go | 19 ++++++++----------- wrappers/golang/curves/bls12381/msm.go | 19 ++++++++----------- wrappers/golang/curves/bn254/g2_msm.go | 19 ++++++++----------- wrappers/golang/curves/bn254/msm.go | 19 ++++++++----------- wrappers/golang/curves/bw6761/g2_msm.go | 19 ++++++++----------- wrappers/golang/curves/bw6761/msm.go | 19 ++++++++----------- .../internal/generator/templates/msm.go.tmpl | 19 ++++++++----------- 10 files changed, 76 insertions(+), 107 deletions(-) diff --git a/wrappers/golang/core/msm.go b/wrappers/golang/core/msm.go index b85e70942..063a46659 100644 --- a/wrappers/golang/core/msm.go +++ b/wrappers/golang/core/msm.go @@ -100,18 +100,14 @@ func MsmCheck(scalars HostOrDeviceSlice, points HostOrDeviceSlice, cfg *MSMConfi cfg.areResultsOnDevice = results.IsOnDevice() } -func PrecomputeBasesCheck(points HostOrDeviceSlice, precompute_factor int32, output_bases HostOrDeviceSlice) { - outputBasesLength, pointsLength := output_bases.Len(), points.Len() - if outputBasesLength != pointsLength*int(precompute_factor) { +func PrecomputeBasesCheck(points HostOrDeviceSlice, precomputeFactor int32, outputBases DeviceSlice) { + outputBasesLength, pointsLength := outputBases.Len(), points.Len() + if outputBasesLength != pointsLength*int(precomputeFactor) { errorString := fmt.Sprintf( "Precompute factor is probably incorrect: expected %d but got %d", outputBasesLength/pointsLength, - precompute_factor, + precomputeFactor, ) panic(errorString) } - - if !output_bases.IsOnDevice() { - panic("Output bases are not on device") - } } diff --git a/wrappers/golang/curves/bls12377/g2_msm.go b/wrappers/golang/curves/bls12377/g2_msm.go index f3b32c93f..4bafef02c 100644 --- a/wrappers/golang/curves/bls12377/g2_msm.go +++ b/wrappers/golang/curves/bls12377/g2_msm.go @@ -50,8 +50,8 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c return err } -func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { - core.PrecomputeBasesCheck(points, precompute_factor, output_bases) +func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) var pointsPointer unsafe.Pointer if points.IsOnDevice() { @@ -62,20 +62,17 @@ func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _ cPoints := (*C.g2_affine_t)(pointsPointer) cPointsLen := (C.int)(points.Len()) - cPrecomputeFactor := (C.int)(precompute_factor) - c_C := (C.int)(_c) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) - cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) var outputBasesPointer unsafe.Pointer - if output_bases.IsOnDevice() { - outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() - } else { - outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[G2Projective])[0]) - } + + outputBasesPointer = outputBases.AsPointer() cOutputBases := (*C.g2_affine_t)(outputBasesPointer) - __ret := C.bls12_377G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + __ret := C.bls12_377G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) err := (cr.CudaError)(__ret) return err } diff --git a/wrappers/golang/curves/bls12377/msm.go b/wrappers/golang/curves/bls12377/msm.go index 286b9fb51..f14d861fc 100644 --- a/wrappers/golang/curves/bls12377/msm.go +++ b/wrappers/golang/curves/bls12377/msm.go @@ -48,8 +48,8 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor return err } -func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { - core.PrecomputeBasesCheck(points, precompute_factor, output_bases) +func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) var pointsPointer unsafe.Pointer if points.IsOnDevice() { @@ -60,20 +60,17 @@ func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c cPoints := (*C.affine_t)(pointsPointer) cPointsLen := (C.int)(points.Len()) - cPrecomputeFactor := (C.int)(precompute_factor) - c_C := (C.int)(_c) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) - cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) var outputBasesPointer unsafe.Pointer - if output_bases.IsOnDevice() { - outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() - } else { - outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[Projective])[0]) - } + + outputBasesPointer = outputBases.AsPointer() cOutputBases := (*C.affine_t)(outputBasesPointer) - __ret := C.bls12_377PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + __ret := C.bls12_377PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) err := (cr.CudaError)(__ret) return err } diff --git a/wrappers/golang/curves/bls12381/g2_msm.go b/wrappers/golang/curves/bls12381/g2_msm.go index b3eb01d65..61d507342 100644 --- a/wrappers/golang/curves/bls12381/g2_msm.go +++ b/wrappers/golang/curves/bls12381/g2_msm.go @@ -50,8 +50,8 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c return err } -func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { - core.PrecomputeBasesCheck(points, precompute_factor, output_bases) +func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) var pointsPointer unsafe.Pointer if points.IsOnDevice() { @@ -62,20 +62,17 @@ func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _ cPoints := (*C.g2_affine_t)(pointsPointer) cPointsLen := (C.int)(points.Len()) - cPrecomputeFactor := (C.int)(precompute_factor) - c_C := (C.int)(_c) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) - cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) var outputBasesPointer unsafe.Pointer - if output_bases.IsOnDevice() { - outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() - } else { - outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[G2Projective])[0]) - } + + outputBasesPointer = outputBases.AsPointer() cOutputBases := (*C.g2_affine_t)(outputBasesPointer) - __ret := C.bls12_381G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + __ret := C.bls12_381G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) err := (cr.CudaError)(__ret) return err } diff --git a/wrappers/golang/curves/bls12381/msm.go b/wrappers/golang/curves/bls12381/msm.go index 8ab675190..bb6331d83 100644 --- a/wrappers/golang/curves/bls12381/msm.go +++ b/wrappers/golang/curves/bls12381/msm.go @@ -48,8 +48,8 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor return err } -func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { - core.PrecomputeBasesCheck(points, precompute_factor, output_bases) +func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) var pointsPointer unsafe.Pointer if points.IsOnDevice() { @@ -60,20 +60,17 @@ func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c cPoints := (*C.affine_t)(pointsPointer) cPointsLen := (C.int)(points.Len()) - cPrecomputeFactor := (C.int)(precompute_factor) - c_C := (C.int)(_c) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) - cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) var outputBasesPointer unsafe.Pointer - if output_bases.IsOnDevice() { - outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() - } else { - outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[Projective])[0]) - } + + outputBasesPointer = outputBases.AsPointer() cOutputBases := (*C.affine_t)(outputBasesPointer) - __ret := C.bls12_381PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + __ret := C.bls12_381PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) err := (cr.CudaError)(__ret) return err } diff --git a/wrappers/golang/curves/bn254/g2_msm.go b/wrappers/golang/curves/bn254/g2_msm.go index 68602d5da..4e47bb9d1 100644 --- a/wrappers/golang/curves/bn254/g2_msm.go +++ b/wrappers/golang/curves/bn254/g2_msm.go @@ -50,8 +50,8 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c return err } -func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { - core.PrecomputeBasesCheck(points, precompute_factor, output_bases) +func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) var pointsPointer unsafe.Pointer if points.IsOnDevice() { @@ -62,20 +62,17 @@ func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _ cPoints := (*C.g2_affine_t)(pointsPointer) cPointsLen := (C.int)(points.Len()) - cPrecomputeFactor := (C.int)(precompute_factor) - c_C := (C.int)(_c) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) - cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) var outputBasesPointer unsafe.Pointer - if output_bases.IsOnDevice() { - outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() - } else { - outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[G2Projective])[0]) - } + + outputBasesPointer = outputBases.AsPointer() cOutputBases := (*C.g2_affine_t)(outputBasesPointer) - __ret := C.bn254G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + __ret := C.bn254G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) err := (cr.CudaError)(__ret) return err } diff --git a/wrappers/golang/curves/bn254/msm.go b/wrappers/golang/curves/bn254/msm.go index 8d9978661..6f061bd78 100644 --- a/wrappers/golang/curves/bn254/msm.go +++ b/wrappers/golang/curves/bn254/msm.go @@ -48,8 +48,8 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor return err } -func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { - core.PrecomputeBasesCheck(points, precompute_factor, output_bases) +func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) var pointsPointer unsafe.Pointer if points.IsOnDevice() { @@ -60,20 +60,17 @@ func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c cPoints := (*C.affine_t)(pointsPointer) cPointsLen := (C.int)(points.Len()) - cPrecomputeFactor := (C.int)(precompute_factor) - c_C := (C.int)(_c) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) - cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) var outputBasesPointer unsafe.Pointer - if output_bases.IsOnDevice() { - outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() - } else { - outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[Projective])[0]) - } + + outputBasesPointer = outputBases.AsPointer() cOutputBases := (*C.affine_t)(outputBasesPointer) - __ret := C.bn254PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + __ret := C.bn254PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) err := (cr.CudaError)(__ret) return err } diff --git a/wrappers/golang/curves/bw6761/g2_msm.go b/wrappers/golang/curves/bw6761/g2_msm.go index cc63e825b..1782f400e 100644 --- a/wrappers/golang/curves/bw6761/g2_msm.go +++ b/wrappers/golang/curves/bw6761/g2_msm.go @@ -50,8 +50,8 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c return err } -func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { - core.PrecomputeBasesCheck(points, precompute_factor, output_bases) +func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) var pointsPointer unsafe.Pointer if points.IsOnDevice() { @@ -62,20 +62,17 @@ func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _ cPoints := (*C.g2_affine_t)(pointsPointer) cPointsLen := (C.int)(points.Len()) - cPrecomputeFactor := (C.int)(precompute_factor) - c_C := (C.int)(_c) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) - cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) var outputBasesPointer unsafe.Pointer - if output_bases.IsOnDevice() { - outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() - } else { - outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[G2Projective])[0]) - } + + outputBasesPointer = outputBases.AsPointer() cOutputBases := (*C.g2_affine_t)(outputBasesPointer) - __ret := C.bw6_761G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + __ret := C.bw6_761G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) err := (cr.CudaError)(__ret) return err } diff --git a/wrappers/golang/curves/bw6761/msm.go b/wrappers/golang/curves/bw6761/msm.go index 078ad8f66..771e4d5f2 100644 --- a/wrappers/golang/curves/bw6761/msm.go +++ b/wrappers/golang/curves/bw6761/msm.go @@ -48,8 +48,8 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor return err } -func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { - core.PrecomputeBasesCheck(points, precompute_factor, output_bases) +func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) var pointsPointer unsafe.Pointer if points.IsOnDevice() { @@ -60,20 +60,17 @@ func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c cPoints := (*C.affine_t)(pointsPointer) cPointsLen := (C.int)(points.Len()) - cPrecomputeFactor := (C.int)(precompute_factor) - c_C := (C.int)(_c) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) - cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) var outputBasesPointer unsafe.Pointer - if output_bases.IsOnDevice() { - outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() - } else { - outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[Projective])[0]) - } + + outputBasesPointer = outputBases.AsPointer() cOutputBases := (*C.affine_t)(outputBasesPointer) - __ret := C.bw6_761PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + __ret := C.bw6_761PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) err := (cr.CudaError)(__ret) return err } diff --git a/wrappers/golang/internal/generator/templates/msm.go.tmpl b/wrappers/golang/internal/generator/templates/msm.go.tmpl index 4465340a7..4b69aa9c9 100644 --- a/wrappers/golang/internal/generator/templates/msm.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm.go.tmpl @@ -52,8 +52,8 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr return err } -func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError { - core.PrecomputeBasesCheck(points, precompute_factor, output_bases) +func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError { + core.PrecomputeBasesCheck(points, precomputeFactor, outputBases) var pointsPointer unsafe.Pointer if points.IsOnDevice() { @@ -64,20 +64,17 @@ func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precomp cPoints := (*C.{{if .IsG2}}g2_{{end}}affine_t)(pointsPointer) cPointsLen := (C.int)(points.Len()) - cPrecomputeFactor := (C.int)(precompute_factor) - c_C := (C.int)(_c) + cPrecomputeFactor := (C.int)(precomputeFactor) + cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) - cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) + cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) var outputBasesPointer unsafe.Pointer - if output_bases.IsOnDevice() { - outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer() - } else { - outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[{{if .IsG2}}G2{{end}}Projective])[0]) - } + + outputBasesPointer = outputBases.AsPointer() cOutputBases := (*C.{{if .IsG2}}g2_{{end}}affine_t)(outputBasesPointer) - __ret := C.{{.Curve}}{{if .IsG2}}G2{{end}}PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases) + __ret := C.{{.Curve}}{{if .IsG2}}G2{{end}}PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) err := (cr.CudaError)(__ret) return err } From 2e07c4a00df1929854d4eba2577a1c9e355df49d Mon Sep 17 00:00:00 2001 From: nonam3e Date: Wed, 13 Mar 2024 16:01:11 +0000 Subject: [PATCH 8/8] add precopute tests --- wrappers/golang/core/msm.go | 2 +- wrappers/golang/curves/bls12377/g2_msm.go | 4 +- .../golang/curves/bls12377/g2_msm_test.go | 43 +++++++++++++++++++ wrappers/golang/curves/bls12377/msm.go | 4 +- wrappers/golang/curves/bls12377/msm_test.go | 43 +++++++++++++++++++ wrappers/golang/curves/bls12381/g2_msm.go | 4 +- .../golang/curves/bls12381/g2_msm_test.go | 43 +++++++++++++++++++ wrappers/golang/curves/bls12381/msm.go | 4 +- wrappers/golang/curves/bls12381/msm_test.go | 43 +++++++++++++++++++ wrappers/golang/curves/bn254/g2_msm.go | 4 +- wrappers/golang/curves/bn254/g2_msm_test.go | 43 +++++++++++++++++++ wrappers/golang/curves/bn254/msm.go | 4 +- wrappers/golang/curves/bn254/msm_test.go | 43 +++++++++++++++++++ wrappers/golang/curves/bw6761/g2_msm.go | 4 +- wrappers/golang/curves/bw6761/g2_msm_test.go | 43 +++++++++++++++++++ wrappers/golang/curves/bw6761/msm.go | 4 +- wrappers/golang/curves/bw6761/msm_test.go | 43 +++++++++++++++++++ .../internal/generator/templates/msm.go.tmpl | 4 +- .../generator/templates/msm_test.go.tmpl | 43 +++++++++++++++++++ 19 files changed, 397 insertions(+), 28 deletions(-) diff --git a/wrappers/golang/core/msm.go b/wrappers/golang/core/msm.go index 063a46659..6229cf805 100644 --- a/wrappers/golang/core/msm.go +++ b/wrappers/golang/core/msm.go @@ -76,7 +76,7 @@ func GetDefaultMSMConfig() MSMConfig { } func MsmCheck(scalars HostOrDeviceSlice, points HostOrDeviceSlice, cfg *MSMConfig, results HostOrDeviceSlice) { - scalarsLength, pointsLength, resultsLength := scalars.Len(), points.Len(), results.Len() + scalarsLength, pointsLength, resultsLength := scalars.Len(), points.Len()/int(cfg.PrecomputeFactor), results.Len() if scalarsLength%pointsLength != 0 { errorString := fmt.Sprintf( "Number of points %d does not divide the number of scalars %d", diff --git a/wrappers/golang/curves/bls12377/g2_msm.go b/wrappers/golang/curves/bls12377/g2_msm.go index 4bafef02c..1380ab6fe 100644 --- a/wrappers/golang/curves/bls12377/g2_msm.go +++ b/wrappers/golang/curves/bls12377/g2_msm.go @@ -66,10 +66,8 @@ func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) - - var outputBasesPointer unsafe.Pointer - outputBasesPointer = outputBases.AsPointer() + outputBasesPointer := outputBases.AsPointer() cOutputBases := (*C.g2_affine_t)(outputBasesPointer) __ret := C.bls12_377G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) diff --git a/wrappers/golang/curves/bls12377/g2_msm_test.go b/wrappers/golang/curves/bls12377/g2_msm_test.go index eb8432d64..a02afc57d 100644 --- a/wrappers/golang/curves/bls12377/g2_msm_test.go +++ b/wrappers/golang/curves/bls12377/g2_msm_test.go @@ -136,6 +136,49 @@ func TestMSMG2Batch(t *testing.T) { } } +func TestPrecomputeBaseG2(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := G2GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p G2Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = G2Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[G2Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMG2SkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bls12377/msm.go b/wrappers/golang/curves/bls12377/msm.go index f14d861fc..4af06a31d 100644 --- a/wrappers/golang/curves/bls12377/msm.go +++ b/wrappers/golang/curves/bls12377/msm.go @@ -64,10 +64,8 @@ func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c in cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) - - var outputBasesPointer unsafe.Pointer - outputBasesPointer = outputBases.AsPointer() + outputBasesPointer := outputBases.AsPointer() cOutputBases := (*C.affine_t)(outputBasesPointer) __ret := C.bls12_377PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) diff --git a/wrappers/golang/curves/bls12377/msm_test.go b/wrappers/golang/curves/bls12377/msm_test.go index 5c41051fa..88fa03619 100644 --- a/wrappers/golang/curves/bls12377/msm_test.go +++ b/wrappers/golang/curves/bls12377/msm_test.go @@ -107,6 +107,49 @@ func TestMSMBatch(t *testing.T) { } } +func TestPrecomputeBase(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMSkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bls12381/g2_msm.go b/wrappers/golang/curves/bls12381/g2_msm.go index 61d507342..1a82c0e4d 100644 --- a/wrappers/golang/curves/bls12381/g2_msm.go +++ b/wrappers/golang/curves/bls12381/g2_msm.go @@ -66,10 +66,8 @@ func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) - - var outputBasesPointer unsafe.Pointer - outputBasesPointer = outputBases.AsPointer() + outputBasesPointer := outputBases.AsPointer() cOutputBases := (*C.g2_affine_t)(outputBasesPointer) __ret := C.bls12_381G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) diff --git a/wrappers/golang/curves/bls12381/g2_msm_test.go b/wrappers/golang/curves/bls12381/g2_msm_test.go index 9ce775127..75abb4507 100644 --- a/wrappers/golang/curves/bls12381/g2_msm_test.go +++ b/wrappers/golang/curves/bls12381/g2_msm_test.go @@ -136,6 +136,49 @@ func TestMSMG2Batch(t *testing.T) { } } +func TestPrecomputeBaseG2(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := G2GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p G2Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = G2Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[G2Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMG2SkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bls12381/msm.go b/wrappers/golang/curves/bls12381/msm.go index bb6331d83..bd51db0cd 100644 --- a/wrappers/golang/curves/bls12381/msm.go +++ b/wrappers/golang/curves/bls12381/msm.go @@ -64,10 +64,8 @@ func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c in cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) - - var outputBasesPointer unsafe.Pointer - outputBasesPointer = outputBases.AsPointer() + outputBasesPointer := outputBases.AsPointer() cOutputBases := (*C.affine_t)(outputBasesPointer) __ret := C.bls12_381PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) diff --git a/wrappers/golang/curves/bls12381/msm_test.go b/wrappers/golang/curves/bls12381/msm_test.go index 7f2d49358..7c05038b5 100644 --- a/wrappers/golang/curves/bls12381/msm_test.go +++ b/wrappers/golang/curves/bls12381/msm_test.go @@ -107,6 +107,49 @@ func TestMSMBatch(t *testing.T) { } } +func TestPrecomputeBase(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMSkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bn254/g2_msm.go b/wrappers/golang/curves/bn254/g2_msm.go index 4e47bb9d1..1dc546525 100644 --- a/wrappers/golang/curves/bn254/g2_msm.go +++ b/wrappers/golang/curves/bn254/g2_msm.go @@ -66,10 +66,8 @@ func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) - - var outputBasesPointer unsafe.Pointer - outputBasesPointer = outputBases.AsPointer() + outputBasesPointer := outputBases.AsPointer() cOutputBases := (*C.g2_affine_t)(outputBasesPointer) __ret := C.bn254G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) diff --git a/wrappers/golang/curves/bn254/g2_msm_test.go b/wrappers/golang/curves/bn254/g2_msm_test.go index c01b3e745..d1f49040b 100644 --- a/wrappers/golang/curves/bn254/g2_msm_test.go +++ b/wrappers/golang/curves/bn254/g2_msm_test.go @@ -136,6 +136,49 @@ func TestMSMG2Batch(t *testing.T) { } } +func TestPrecomputeBaseG2(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := G2GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p G2Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = G2Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[G2Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMG2SkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bn254/msm.go b/wrappers/golang/curves/bn254/msm.go index 6f061bd78..1d773d643 100644 --- a/wrappers/golang/curves/bn254/msm.go +++ b/wrappers/golang/curves/bn254/msm.go @@ -64,10 +64,8 @@ func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c in cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) - - var outputBasesPointer unsafe.Pointer - outputBasesPointer = outputBases.AsPointer() + outputBasesPointer := outputBases.AsPointer() cOutputBases := (*C.affine_t)(outputBasesPointer) __ret := C.bn254PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) diff --git a/wrappers/golang/curves/bn254/msm_test.go b/wrappers/golang/curves/bn254/msm_test.go index 39834dfbd..37b5799ec 100644 --- a/wrappers/golang/curves/bn254/msm_test.go +++ b/wrappers/golang/curves/bn254/msm_test.go @@ -107,6 +107,49 @@ func TestMSMBatch(t *testing.T) { } } +func TestPrecomputeBase(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMSkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bw6761/g2_msm.go b/wrappers/golang/curves/bw6761/g2_msm.go index 1782f400e..7e22e32ac 100644 --- a/wrappers/golang/curves/bw6761/g2_msm.go +++ b/wrappers/golang/curves/bw6761/g2_msm.go @@ -66,10 +66,8 @@ func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) - - var outputBasesPointer unsafe.Pointer - outputBasesPointer = outputBases.AsPointer() + outputBasesPointer := outputBases.AsPointer() cOutputBases := (*C.g2_affine_t)(outputBasesPointer) __ret := C.bw6_761G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) diff --git a/wrappers/golang/curves/bw6761/g2_msm_test.go b/wrappers/golang/curves/bw6761/g2_msm_test.go index 388e7dbc7..bca18c7ff 100644 --- a/wrappers/golang/curves/bw6761/g2_msm_test.go +++ b/wrappers/golang/curves/bw6761/g2_msm_test.go @@ -109,6 +109,49 @@ func TestMSMG2Batch(t *testing.T) { } } +func TestPrecomputeBaseG2(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := G2GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p G2Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = G2Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[G2Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMG2SkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/curves/bw6761/msm.go b/wrappers/golang/curves/bw6761/msm.go index 771e4d5f2..a5146755d 100644 --- a/wrappers/golang/curves/bw6761/msm.go +++ b/wrappers/golang/curves/bw6761/msm.go @@ -64,10 +64,8 @@ func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c in cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) - - var outputBasesPointer unsafe.Pointer - outputBasesPointer = outputBases.AsPointer() + outputBasesPointer := outputBases.AsPointer() cOutputBases := (*C.affine_t)(outputBasesPointer) __ret := C.bw6_761PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) diff --git a/wrappers/golang/curves/bw6761/msm_test.go b/wrappers/golang/curves/bw6761/msm_test.go index 415f88cbe..302b93ccc 100644 --- a/wrappers/golang/curves/bw6761/msm_test.go +++ b/wrappers/golang/curves/bw6761/msm_test.go @@ -107,6 +107,49 @@ func TestMSMBatch(t *testing.T) { } } +func TestPrecomputeBase(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSMSkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { diff --git a/wrappers/golang/internal/generator/templates/msm.go.tmpl b/wrappers/golang/internal/generator/templates/msm.go.tmpl index 4b69aa9c9..80e79f7e4 100644 --- a/wrappers/golang/internal/generator/templates/msm.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm.go.tmpl @@ -68,10 +68,8 @@ func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precomp cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) - - var outputBasesPointer unsafe.Pointer - outputBasesPointer = outputBases.AsPointer() + outputBasesPointer := outputBases.AsPointer() cOutputBases := (*C.{{if .IsG2}}g2_{{end}}affine_t)(outputBasesPointer) __ret := C.{{.Curve}}{{if .IsG2}}G2{{end}}PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases) diff --git a/wrappers/golang/internal/generator/templates/msm_test.go.tmpl b/wrappers/golang/internal/generator/templates/msm_test.go.tmpl index 9c3bbd200..e7c127948 100644 --- a/wrappers/golang/internal/generator/templates/msm_test.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm_test.go.tmpl @@ -156,6 +156,49 @@ func TestMSM{{if .IsG2}}G2{{end}}Batch(t *testing.T) { } } +func TestPrecomputeBase{{if .IsG2}}G2{{end}}(t *testing.T) { + cfg := GetDefaultMSMConfig() + const precomputeFactor = 8 + for _, power := range []int{10, 16} { + for _, batchSize := range []int{1, 3, 16} { + size := 1 << power + totalSize := size * batchSize + scalars := GenerateScalars(totalSize) + points := {{if .IsG2}}G2{{end}}GenerateAffinePoints(totalSize) + + var precomputeOut core.DeviceSlice + _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed") + + e = {{if .IsG2}}G2{{end}}PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut) + assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed") + + var p {{if .IsG2}}G2{{end}}Projective + var out core.DeviceSlice + _, e = out.Malloc(batchSize*p.Size(), p.Size()) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + + cfg.PrecomputeFactor = precomputeFactor + + e = {{if .IsG2}}G2{{end}}Msm(scalars, precomputeOut, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[{{if .IsG2}}G2{{end}}Projective], batchSize) + outHost.CopyFromDevice(&out) + out.Free() + precomputeOut.Free() + + // Check with gnark-crypto + for i := 0; i < batchSize; i++ { + scalarsSlice := scalars[i*size : (i+1)*size] + pointsSlice := points[i*size : (i+1)*size] + out := outHost[i] + assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalarsSlice, pointsSlice, out)) + } + } + } +} + + func TestMSM{{if .IsG2}}G2{{end}}SkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {