Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/timur/msm precompute go #432

Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,27 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ssh-key: ${{ secrets.DEPLOY_KEY }}
- name: Setup Cache
id: cache
uses: actions/cache@v4
with:
path: |
~/.cargo/bin/
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
key: ${{ runner.os }}-cargo-${{ hashFiles('~/.cargo/bin/cargo-workspaces') }}
- name: Install cargo-workspaces
if: steps.cache.outputs.cache-hit != 'true'
run: cargo install cargo-workspaces
- name: Bump rust crate versions, commit, and tag
working-directory: wrappers/rust
# https://github.com/pksunkara/cargo-workspaces?tab=readme-ov-file#version
run: |
cargo install cargo-workspaces
git config user.name release-bot
git config user.email release-bot@ingonyama.com
cargo workspaces version ${{ inputs.releaseType }} -y --no-individual-tags -m "Bump rust crates' version"
- name: Create draft release
env:
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ ENV PATH="/root/.cargo/bin:${PATH}"

# Install Golang
ENV GOLANG_VERSION 1.21.1
RUN curl -L https://golang.org/dl/go${GOLANG_VERSION}.linux-amd64.tar.gz | tar -xz -C /usr/local
RUN curl -L https://go.dev/dl/go${GOLANG_VERSION}.linux-amd64.tar.gz | tar -xz -C /usr/local
ENV PATH="/usr/local/go/bin:${PATH}"

# Set the working directory in the container
Expand Down
16 changes: 16 additions & 0 deletions wrappers/golang/core/msm.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,19 @@ func MsmCheck(scalars HostOrDeviceSlice, points HostOrDeviceSlice, cfg *MSMConfi
cfg.arePointsOnDevice = points.IsOnDevice()
cfg.areResultsOnDevice = results.IsOnDevice()
}

func PrecomputeBasesCheck(points HostOrDeviceSlice, precompute_factor int32, output_bases HostOrDeviceSlice) {
outputBasesLength, pointsLength := output_bases.Len(), points.Len()
if outputBasesLength != pointsLength*int(precompute_factor) {
errorString := fmt.Sprintf(
"Precompute factor is probably incorrect: expected %d but got %d",
outputBasesLength/pointsLength,
precompute_factor,
)
panic(errorString)
}

if !output_bases.IsOnDevice() {
panic("Output bases are not on device")
}
jeremyfelder marked this conversation as resolved.
Show resolved Hide resolved
}
30 changes: 30 additions & 0 deletions wrappers/golang/curves/bls12377/g2_msm.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,33 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
err := (cr.CudaError)(__ret)
return err
}

func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError {
jeremyfelder marked this conversation as resolved.
Show resolved Hide resolved
jeremyfelder marked this conversation as resolved.
Show resolved Hide resolved
core.PrecomputeBasesCheck(points, precompute_factor, output_bases)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
}
cPoints := (*C.g2_affine_t)(pointsPointer)

cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precompute_factor)
c_C := (C.int)(_c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
jeremyfelder marked this conversation as resolved.
Show resolved Hide resolved

var outputBasesPointer unsafe.Pointer
if output_bases.IsOnDevice() {
outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer()
} else {
outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[G2Projective])[0])
}
cOutputBases := (*C.g2_affine_t)(outputBasesPointer)

__ret := C.bls12_377G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}
2 changes: 2 additions & 0 deletions wrappers/golang/curves/bls12377/include/msm.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>

#ifndef _BLS12_377_MSM_H
#define _BLS12_377_MSM_H
Expand All @@ -9,6 +10,7 @@ extern "C" {
#endif

cudaError_t bls12_377MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bls12_377PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);

#ifdef __cplusplus
}
Expand Down
30 changes: 30 additions & 0 deletions wrappers/golang/curves/bls12377/msm.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,33 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
err := (cr.CudaError)(__ret)
return err
}

func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precompute_factor, output_bases)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
cPoints := (*C.affine_t)(pointsPointer)

cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precompute_factor)
c_C := (C.int)(_c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

var outputBasesPointer unsafe.Pointer
if output_bases.IsOnDevice() {
outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer()
} else {
outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[Projective])[0])
}
cOutputBases := (*C.affine_t)(outputBasesPointer)

__ret := C.bls12_377PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}
30 changes: 30 additions & 0 deletions wrappers/golang/curves/bls12381/g2_msm.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,33 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
err := (cr.CudaError)(__ret)
return err
}

func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precompute_factor, output_bases)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
}
cPoints := (*C.g2_affine_t)(pointsPointer)

cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precompute_factor)
c_C := (C.int)(_c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

var outputBasesPointer unsafe.Pointer
if output_bases.IsOnDevice() {
outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer()
} else {
outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[G2Projective])[0])
}
cOutputBases := (*C.g2_affine_t)(outputBasesPointer)

__ret := C.bls12_381G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}
2 changes: 2 additions & 0 deletions wrappers/golang/curves/bls12381/include/msm.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>

#ifndef _BLS12_381_MSM_H
#define _BLS12_381_MSM_H
Expand All @@ -9,6 +10,7 @@ extern "C" {
#endif

cudaError_t bls12_381MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bls12_381PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);

#ifdef __cplusplus
}
Expand Down
30 changes: 30 additions & 0 deletions wrappers/golang/curves/bls12381/msm.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,33 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
err := (cr.CudaError)(__ret)
return err
}

func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precompute_factor, output_bases)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
cPoints := (*C.affine_t)(pointsPointer)

cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precompute_factor)
c_C := (C.int)(_c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

var outputBasesPointer unsafe.Pointer
if output_bases.IsOnDevice() {
outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer()
} else {
outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[Projective])[0])
}
cOutputBases := (*C.affine_t)(outputBasesPointer)

__ret := C.bls12_381PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}
30 changes: 30 additions & 0 deletions wrappers/golang/curves/bn254/g2_msm.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,33 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
err := (cr.CudaError)(__ret)
return err
}

func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precompute_factor, output_bases)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
}
cPoints := (*C.g2_affine_t)(pointsPointer)

cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precompute_factor)
c_C := (C.int)(_c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

var outputBasesPointer unsafe.Pointer
if output_bases.IsOnDevice() {
outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer()
} else {
outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[G2Projective])[0])
}
cOutputBases := (*C.g2_affine_t)(outputBasesPointer)

__ret := C.bn254G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}
2 changes: 2 additions & 0 deletions wrappers/golang/curves/bn254/include/msm.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>

#ifndef _BN254_MSM_H
#define _BN254_MSM_H
Expand All @@ -9,6 +10,7 @@ extern "C" {
#endif

cudaError_t bn254MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bn254PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);

#ifdef __cplusplus
}
Expand Down
30 changes: 30 additions & 0 deletions wrappers/golang/curves/bn254/msm.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,33 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
err := (cr.CudaError)(__ret)
return err
}

func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precompute_factor, output_bases)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
cPoints := (*C.affine_t)(pointsPointer)

cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precompute_factor)
c_C := (C.int)(_c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

var outputBasesPointer unsafe.Pointer
if output_bases.IsOnDevice() {
outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer()
} else {
outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[Projective])[0])
}
cOutputBases := (*C.affine_t)(outputBasesPointer)

__ret := C.bn254PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}
30 changes: 30 additions & 0 deletions wrappers/golang/curves/bw6761/g2_msm.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,33 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
err := (cr.CudaError)(__ret)
return err
}

func G2PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precompute_factor, output_bases)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
}
cPoints := (*C.g2_affine_t)(pointsPointer)

cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precompute_factor)
c_C := (C.int)(_c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

var outputBasesPointer unsafe.Pointer
if output_bases.IsOnDevice() {
outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer()
} else {
outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[G2Projective])[0])
}
cOutputBases := (*C.g2_affine_t)(outputBasesPointer)

__ret := C.bw6_761G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}
2 changes: 2 additions & 0 deletions wrappers/golang/curves/bw6761/include/msm.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>

#ifndef _BW6_761_MSM_H
#define _BW6_761_MSM_H
Expand All @@ -9,6 +10,7 @@ extern "C" {
#endif

cudaError_t bw6_761MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bw6_761PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);

#ifdef __cplusplus
}
Expand Down
30 changes: 30 additions & 0 deletions wrappers/golang/curves/bw6761/msm.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,33 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
err := (cr.CudaError)(__ret)
return err
}

func PrecomputeBases(points core.HostOrDeviceSlice, precompute_factor int32, _c int32, ctx *cr.DeviceContext, output_bases core.HostOrDeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precompute_factor, output_bases)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
cPoints := (*C.affine_t)(pointsPointer)

cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precompute_factor)
c_C := (C.int)(_c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

var outputBasesPointer unsafe.Pointer
if output_bases.IsOnDevice() {
outputBasesPointer = output_bases.(core.DeviceSlice).AsPointer()
} else {
outputBasesPointer = unsafe.Pointer(&output_bases.(core.HostSlice[Projective])[0])
}
cOutputBases := (*C.affine_t)(outputBasesPointer)

__ret := C.bw6_761PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, c_C, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}
Loading