-
Notifications
You must be signed in to change notification settings - Fork 0
/
ModMultiply.m
55 lines (51 loc) · 1.85 KB
/
ModMultiply.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
% MODMULTIPLY  Compute mod(a * b, m) without losing precision to uint64
% saturation.
%
%   result = ModMultiply(a, b, m)
%
% Private helper that performs only minimal input validation. The operands
% are expected to be already reduced (a < m and b < m): the "x - (m - x)"
% doubling steps below only yield the correct residue under that condition.
%
% The approach follows the double-and-add modular multiplication shown at
% https://en.wikipedia.org/wiki/Modular_arithmetic#Example_implementations
% with vectorized steps and fast paths:
%   * the product a*b does not saturate          -> a single mod()
%   * squaring (a == b) with a small residue     -> square min(a, m - a)
%   * one operand equal to 2 (common in LucasPrime) -> 2*x - m directly
%
% Inputs:
%   a, b - multiplicands (uint64 expected; assumed < m)
%   m    - modulus (uint64 expected)
% Output:
%   result - mod(a * b, m)
function result = ModMultiply(a, b, m)
uint64Max = intmax('uint64');
halfRange = uint64(2^63);              % values >= this overflow when doubled
maxSquareRoot = uint64(intmax('uint32')); % largest x with x*x <= uint64Max

% uint64 arithmetic saturates at uint64Max, so any product below that value
% is exact. A product exactly equal to uint64Max takes the slow path below,
% which is still correct.
c = a * b;
if c ~= uint64Max
    result = mod(c, m);
    return;
elseif a == b
    % a^2 (mod m) equals (m - a)^2 (mod m), so square the smaller of the two
    assert(m > a);
    aSmall = min(m - a, a);
    if aSmall <= maxSquareRoot
        % aSmall^2 cannot overflow uint64 here
        result = mod(aSmall * aSmall, m);
        return;
    end
elseif a == 2 || b == 2
    % 2*x saturated, so 2*x >= m; with x < m the residue is 2*x - m,
    % written as x - (m - x) to stay inside the uint64 range.
    if a == 2
        x = b;
    else
        x = a;
    end
    result = x - (m - x);
    return;
end

% General case: double-and-add over the set bits of b.
% shift(k) is b >> (k-1), so oddIdx lists the positions of b's set bits and
% seqA(k) holds a * 2^(k-1) (mod m).
shift = bitshift(b, 0:-1:-63);
oddIdx = find(mod(shift, 2) == 1);
stopIdx = oddIdx(end);
seqA = zeros(1, stopIdx, 'uint64');
seqA(1) = a;
for idx = 2 : stopIdx
    prev = seqA(idx-1);
    if prev >= halfRange
        % 2*prev would saturate; since prev < m, 2*prev - m is the residue
        seqA(idx) = prev - (m - prev);
    else
        seqA(idx) = mod(bitshift(prev, 1), m);
    end
end

% The native sum saturates; a total equal to uint64Max may have overflowed,
% so fall back to term-by-term modular addition in that case.
result = sum(seqA(oddIdx), 'native');
if result == uint64Max
    result = uint64(0);
    for idx = oddIdx
        result = ModAdd(result, seqA(idx), m);
    end
else
    result = mod(result, m);
end
end