Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MSM - supporting all window sizes #534

Merged
merged 16 commits into from
Jun 17, 2024
Merged

Conversation

HadarIngonyama
Copy link
Contributor

This PR enables using MSM with any value of c.

Note: default c isn't necessarily optimal, the user is expected to choose c and the precomputation factor that give the best results for the relevant case.

@HadarIngonyama HadarIngonyama changed the title Msm/phase 2 generalization MSM - supporting all window sizes Jun 5, 2024
@HadarIngonyama
Copy link
Contributor Author

Fixed precomputation tests in rust and go. Precomputaion function interface has changed as a result.

Comment on lines 37 to 47
AreScalarsOnDevice bool

/// True if scalars are in Montgomery form and false otherwise. Default value: true.
AreScalarsMontgomeryForm bool

arePointsOnDevice bool
ArePointsOnDevice bool

/// True if coordinates of points are in Montgomery form and false otherwise. Default value: true.
ArePointsMontgomeryForm bool

areResultsOnDevice bool
AreResultsOnDevice bool
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should remain private as the XXXCheck functions should manipulate them internally based on the slice arguments. Its less error prone for the user

e = msm.PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
cfg.PrecomputeFactor = precomputeFactor
cfg.PointsSize = int32(points.Len())
cfg.ArePointsOnDevice = points.IsOnDevice()
Copy link
Collaborator

@jeremyfelder jeremyfelder Jun 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the previous comment, this should be checked and updated in core.PrecomputeBasesCheck which is called in msm.PrecomputeBases

Same for cfg.PointsSize, it should be private and updated in core.PrecomputeBasesCheck

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, will fix, what about the rust test? is it important there too?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there as well. PrecomputeFactor is fine to update manually as it is public but points_size is not public so it shouldn't be updated manually

{
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= num_of_threads) { return; }
if (tid >= nof_threads) { return; }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be consistent on formatting; see line 99

Suggested change
if (tid >= nof_threads) { return; }
if (tid >= nof_threads) return;

0, // bitsize
10, // large_bucket_factor
batch_size, // batch_size
false, // are_scalars_on_device
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why don't you use scalar_d if you already copied them to device? intentional? if yes, no need to allocate scalars_d and copy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just a test file, I copy to have all options available, I don't use scalar_d because in zprize you need to include scalar transfer time

@@ -726,62 +745,90 @@ namespace msm {
NUM_BLOCKS = (nof_bms_in_batch + NUM_THREADS - 1) / NUM_THREADS;
big_triangle_sum_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(buckets, final_results, nof_bms_in_batch, c);
} else {
// the recursive reduction algorithm works with 2 types of reduction that can run on parallel streams
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any chance you move all this reduction logic to a 'reduction_phase()' function? thanks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in that case we should also have scalar_splitting_phase() sorting_phase() accumulation_phase() and so on, if we want this refactoring let's do it in a different PR

bool are_bases_on_device,
device_context::DeviceContext& ctx,
A* output_bases)
cudaError_t precompute_msm_bases(A* bases, int msm_size, MSMConfig& config, A* output_bases)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After discussing with the team, in order to prevent a breaking change right now, we think its best to:

  1. create a second function that will call the old function after computing c
  2. add a deprecation comment on the old function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, why do we need to call the old function? I will just give my new precompute function a different name and also keep the old one as deprecated.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just so you don't implement twice.

@LeonHibnik LeonHibnik merged commit 8936d9c into main Jun 17, 2024
25 checks passed
@LeonHibnik LeonHibnik deleted the msm/phase_2_generalization branch June 17, 2024 12:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants