-
Notifications
You must be signed in to change notification settings - Fork 0
/
mpint.c
2395 lines (2110 loc) · 81.2 KB
/
mpint.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include "defs.h"
#include "misc.h"
#include "puttymem.h"
#include "mpint.h"
#include "mpint_i.h"
/* Width of size_t in bits; used when building masks from size_t values. */
#define SIZE_T_BITS (CHAR_BIT * sizeof(size_t))

/*
 * Smaller and larger of two size_t values. These are used all over
 * this file for choosing loop bounds and allocation sizes.
 */
static inline size_t size_t_min(size_t a, size_t b)
{
    if (b < a)
        return b;
    return a;
}

static inline size_t size_t_max(size_t a, size_t b)
{
    if (b > a)
        return b;
    return a;
}
/*
 * Fetch word i of x, with bounds checking: any index at or beyond
 * x's allocated word count reads as 0. This lets callers iterate over
 * a common length even when their operands have different sizes.
 */
static inline BignumInt mp_word(mp_int *x, size_t i)
{
    if (i >= x->nw)
        return 0;
    return x->w[i];
}
/*
 * Allocate an mp_int with exactly nw words and set it to zero.
 *
 * The word array lives in the extra space of the same allocation as
 * the header (snew_plus / snew_plus_get_aux), which is why mp_free
 * frees only the header. nw must be nonzero: zero-word mp_ints are
 * not permitted.
 */
static mp_int *mp_make_sized(size_t nw)
{
    mp_int *fresh = snew_plus(mp_int, nw * sizeof(BignumInt));
    assert(nw); /* zero-word mp_ints are outlawed */
    fresh->nw = nw;
    fresh->w = snew_plus_get_aux(fresh);
    mp_clear(fresh);
    return fresh;
}
/*
 * Allocate a zeroed mp_int that can hold any value of up to maxbits
 * bits, i.e. ceil(maxbits / BIGNUM_INT_BITS) words.
 */
mp_int *mp_new(size_t maxbits)
{
    return mp_make_sized((maxbits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS);
}
/*
 * Make a new mp_int holding the value n. It is given exactly as many
 * words as are needed to hold a uintmax_t.
 */
mp_int *mp_from_integer(uintmax_t n)
{
    size_t nw = (sizeof(n) + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES;
    mp_int *result = mp_make_sized(nw);
    for (size_t word = 0; word < result->nw; word++) {
        size_t shift = word * BIGNUM_INT_BITS;
        result->w[word] = n >> shift;
    }
    return result;
}
/* Capacity of x in bytes (word count times bytes per word). */
size_t mp_max_bytes(mp_int *x)
{
    size_t nwords = x->nw;
    return nwords * BIGNUM_INT_BYTES;
}

/* Capacity of x in bits (word count times bits per word). */
size_t mp_max_bits(mp_int *x)
{
    size_t nwords = x->nw;
    return nwords * BIGNUM_INT_BITS;
}
/*
 * Release an mp_int. Its value words and then its header (word
 * count and word pointer) are wiped with smemclr before the single
 * allocation is freed.
 */
void mp_free(mp_int *x)
{
    mp_clear(x);              /* wipe the value */
    smemclr(x, sizeof *x);    /* wipe the header fields */
    sfree(x);
}
/*
 * Print x to fp as "<prefix>0x<hex><suffix>". Every byte of x's
 * capacity is printed, most significant first, in upper-case hex,
 * with leading zeroes kept.
 */
void mp_dump(FILE *fp, const char *prefix, mp_int *x, const char *suffix)
{
    fprintf(fp, "%s0x", prefix);
    size_t remaining = mp_max_bytes(x);
    while (remaining > 0) {
        remaining--;
        fprintf(fp, "%02X", mp_get_byte(x, remaining));
    }
    fputs(suffix, fp);
}
/*
 * Copy src's value into dest. If src has more words than dest, the
 * excess high words are dropped. If it has fewer, dest's remaining
 * high words are zeroed. memmove is used, so overlapping storage is
 * tolerated.
 */
void mp_copy_into(mp_int *dest, mp_int *src)
{
    size_t shared = dest->nw < src->nw ? dest->nw : src->nw;
    size_t leftover = dest->nw - shared;
    memmove(dest->w, src->w, shared * sizeof(BignumInt));
    smemclr(dest->w + shared, leftover * sizeof(BignumInt));
}
/*
 * The constant-time selection idiom used throughout this file:
 *
 * Negating a 0/1 flag gives a mask word that is all 0s (flag == 0)
 * or all 1s (flag == 1). To choose between words a and b without any
 * data-dependent branch, XOR them to obtain their difference, AND
 * that difference with the mask (so it becomes 0 when flag == 0),
 * and XOR the result back into a. That leaves a unchanged when
 * flag == 0 and turns it into b when flag == 1.
 *
 * Later code uses this trick without re-explaining it.
 *
 * mp_select_into sets dest = which ? src1 : src0, using only the low
 * bit of 'which'. Sources shorter than dest are zero-extended.
 */
void mp_select_into(mp_int *dest, mp_int *src0, mp_int *src1,
                    unsigned which)
{
    BignumInt mask = -(BignumInt)(which & 1);
    for (size_t i = 0; i < dest->nw; i++) {
        BignumInt w0 = mp_word(src0, i);
        BignumInt delta = (mp_word(src1, i) ^ w0) & mask;
        dest->w[i] = w0 ^ delta;
    }
}
/*
 * Exchange the values of x0 and x1 if the low bit of 'swap' is 1, or
 * leave both unchanged if it is 0, without branching on 'swap'.
 * Both operands must have the same word count (checked by assertion).
 */
void mp_cond_swap(mp_int *x0, mp_int *x1, unsigned swap)
{
    assert(x0->nw == x1->nw);
    /* NOTE(review): 'volatile' presumably stops the compiler from
     * turning the mask back into a branch on 'swap' — confirm intent. */
    volatile BignumInt mask = -(BignumInt)(1 & swap);
    for (size_t i = 0; i < x0->nw; i++) {
        /* diff is the XOR difference when swapping, 0 otherwise;
         * XORing it into both words exchanges them or does nothing. */
        BignumInt diff = (x0->w[i] ^ x1->w[i]) & mask;
        x0->w[i] ^= diff;
        x1->w[i] ^= diff;
    }
}
/* Set every word of x to zero (via smemclr). */
void mp_clear(mp_int *x)
{
    size_t nbytes = x->nw * sizeof(BignumInt);
    smemclr(x->w, nbytes);
}

/*
 * Zero x if the low bit of 'clear' is 1, or leave it alone if it is 0.
 * The decision is applied as a mask, not a branch.
 */
void mp_cond_clear(mp_int *x, unsigned clear)
{
    /* all 1s (keep) when clear == 0, all 0s (wipe) when clear == 1 */
    BignumInt keep = ~-(BignumInt)(clear & 1);
    for (size_t i = 0; i < x->nw; i++)
        x->w[i] = x->w[i] & keep;
}
/*
 * Shared implementation of mp_from_bytes_le and mp_from_bytes_be.
 *
 * Byte i of the result (counting from the least significant, i = 0)
 * is read from bytes.ptr[m*i + c], so the caller chooses the
 * traversal. (m, c) = (1, 0) walks forwards. (m, c) = (-1, len-1)
 * walks backwards, which relies on unsigned wraparound of m*i + c.
 * The result always has at least one word, even for empty input.
 */
static mp_int *mp_from_bytes_int(ptrlen bytes, size_t m, size_t c)
{
    const unsigned char *src = (const unsigned char *)bytes.ptr;
    size_t nw = size_t_max(
        (bytes.len + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES, 1);
    mp_int *result = mp_make_sized(nw);
    for (size_t i = 0; i < bytes.len; i++) {
        size_t word = i / BIGNUM_INT_BYTES;
        size_t bitpos = 8 * (i % BIGNUM_INT_BYTES);
        result->w[word] |= (BignumInt)src[m*i + c] << bitpos;
    }
    return result;
}

/* Interpret a byte string as an unsigned little-endian integer. */
mp_int *mp_from_bytes_le(ptrlen bytes)
{
    return mp_from_bytes_int(bytes, 1, 0);
}

/* Interpret a byte string as an unsigned big-endian integer. */
mp_int *mp_from_bytes_be(ptrlen bytes)
{
    return mp_from_bytes_int(bytes, -1, bytes.len - 1);
}
/* Make a new nw-word mp_int whose words are copied from w[0..nw-1]. */
static mp_int *mp_from_words(size_t nw, const BignumInt *w)
{
    mp_int *result = mp_make_sized(nw);
    memcpy(result->w, w, nw * sizeof(BignumInt));
    return result;
}
/*
 * Parse a string of decimal digits using Horner's rule. For each digit,
 * the running value is first multiplied by 10 (skipped for the first
 * digit) and then the digit is added.
 *
 * No validation is done: each character c contributes (c - '0').
 * The output is sized from the digit count alone.
 */
mp_int *mp_from_decimal_pl(ptrlen decimal)
{
    const char *digits = (const char *)decimal.ptr;
    size_t ndigits = decimal.len;

    /* 196/59 is a continued-fraction convergent of log2(10) lying just
     * above it, so 196*ndigits/59 bounds the bit length of any
     * ndigits-digit number. First make sure 196*ndigits can't overflow. */
    assert(ndigits < (~(size_t)0) / 196);
    size_t bits = 196 * ndigits / 59;

    /* bits/BIGNUM_INT_BITS + 1 words always suffices, and is never 0. */
    mp_int *result = mp_make_sized(bits / BIGNUM_INT_BITS + 1);

    for (size_t i = 0; i < ndigits; i++) {
        if (i > 0)
            mp_mul_integer_into(result, result, 10);
        mp_add_integer_into(result, result, digits[i] - '0');
    }
    return result;
}

/* Parse a NUL-terminated decimal string; see mp_from_decimal_pl. */
mp_int *mp_from_decimal(const char *decimal)
{
    ptrlen pl = ptrlen_from_asciz(decimal);
    return mp_from_decimal_pl(pl);
}
/*
 * Hex-to-binary conversion: _algorithmically_ simpler than decimal
 * (none of those multiplications by 10), but there's some fiddly
 * bit-twiddling needed to process each hex digit without diverging
 * control flow depending on whether it's a letter or a number.
 *
 * Parses hex.len hex digits, most significant first, in either case.
 * The result has exactly enough words for 4*hex.len bits (at least
 * one word). There is no special handling of a "0x" prefix, and
 * characters that are not hex digits are not rejected: each is
 * mapped to some value in 0..15.
 */
mp_int *mp_from_hex_pl(ptrlen hex)
{
    /* guard the hex.len * 4 below against size_t overflow */
    assert(hex.len <= (~(size_t)0) / 4);
    size_t bits = hex.len * 4;
    size_t words = (bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
    words = size_t_max(words, 1);
    mp_int *x = mp_make_sized(words);
    /* nibble 0 is the least significant digit, i.e. the last char */
    for (size_t nibble = 0; nibble < hex.len; nibble++) {
        BignumInt digit = ((char *)hex.ptr)[hex.len-1 - nibble];
        /*
         * lmask is all 1s iff 'a' <= digit <= 'f', else all 0s. If
         * digit < 'a', digit-'a' wraps round and sets the top bit;
         * if digit > 'f', 'f'-digit does. So the top bit of the OR is
         * 1 exactly when digit is out of range; shifting it down,
         * negating and complementing turns that into the mask.
         * umask does the same for 'A'..'F'.
         */
        BignumInt lmask = ~-((BignumInt)((digit-'a')|('f'-digit))
                             >> (BIGNUM_INT_BITS-1));
        BignumInt umask = ~-((BignumInt)((digit-'A')|('F'-digit))
                             >> (BIGNUM_INT_BITS-1));
        /* Start by assuming a decimal digit, then select the
         * lower-case or upper-case letter value if a mask fires. */
        BignumInt digitval = digit - '0';
        digitval ^= (digitval ^ (digit - 'a' + 10)) & lmask;
        digitval ^= (digitval ^ (digit - 'A' + 10)) & umask;
        digitval &= 0xF; /* at least be _slightly_ nice about weird input */
        size_t word_idx = nibble / (BIGNUM_INT_BYTES*2);
        size_t nibble_within_word = nibble % (BIGNUM_INT_BYTES*2);
        x->w[word_idx] |= digitval << (nibble_within_word * 4);
    }
    return x;
}
/* Parse a NUL-terminated hex string; see mp_from_hex_pl. */
mp_int *mp_from_hex(const char *hex)
{
    return mp_from_hex_pl(ptrlen_from_asciz(hex));
}
/* Return a newly allocated duplicate of x, with the same word count. */
mp_int *mp_copy(mp_int *x)
{
    mp_int *dup = mp_from_words(x->nw, x->w);
    return dup;
}
/*
 * Return byte number 'byte' of x, counting from the least
 * significant (byte 0). Positions beyond x's capacity read as 0.
 */
uint8_t mp_get_byte(mp_int *x, size_t byte)
{
    size_t word = byte / BIGNUM_INT_BYTES;
    size_t shift = 8 * (byte % BIGNUM_INT_BYTES);
    BignumInt w = mp_word(x, word);
    return (uint8_t)((w >> shift) & 0xFF);
}

/*
 * Return bit number 'bit' of x (0 = least significant) as 0 or 1.
 * Positions beyond x's capacity read as 0.
 */
unsigned mp_get_bit(mp_int *x, size_t bit)
{
    size_t word = bit / BIGNUM_INT_BITS;
    size_t shift = bit % BIGNUM_INT_BITS;
    return (unsigned)((mp_word(x, word) >> shift) & 1);
}
/*
 * Return x's value as a uintmax_t. Words are accumulated from the
 * most significant down, so if x's value does not fit, only its
 * low-order bits survive.
 */
uintmax_t mp_get_integer(mp_int *x)
{
    uintmax_t acc = 0;
    size_t i = x->nw;
    while (i > 0) {
        i--;
        /* Shift by one whole word in two halves: a single shift by
         * BIGNUM_INT_BITS would be undefined if that equalled the
         * width of uintmax_t. */
        acc <<= BIGNUM_INT_BITS / 2;
        acc <<= BIGNUM_INT_BITS / 2;
        acc |= x->w[i];
    }
    return acc;
}
/*
 * Set bit number 'bit' of x to the low bit of val. The bit must be
 * within x's capacity (checked by assertion).
 */
void mp_set_bit(mp_int *x, size_t bit, unsigned val)
{
    size_t word = bit / BIGNUM_INT_BITS;
    assert(word < x->nw);
    unsigned shift = bit % BIGNUM_INT_BITS;
    BignumInt clear_mask = ~((BignumInt)1 << shift);
    BignumInt new_bit = (BignumInt)(val & 1) << shift;
    x->w[word] = (x->w[word] & clear_mask) | new_bit;
}
/*
 * Branch-free "is nonzero" test: returns 1 if n != 0 and 0 if n == 0.
 * Used here and there to normalise arbitrary words to a 0/1 flag.
 */
static inline unsigned normalise_to_1(BignumInt n)
{
    /* Shift right one place so the top bit is clear, ORing the old
     * bit 0 back in so the value is still nonzero exactly when n was. */
    BignumInt folded = (n & 1) | (n >> 1);
    /* For nonzero folded (< 2^(BITS-1)), -folded has its top bit set;
     * for zero it does not. */
    return (BignumInt)(-folded) >> (BIGNUM_INT_BITS - 1);
}
/*
 * 64-bit counterpart of normalise_to_1: returns 1 if n is nonzero and
 * 0 if it is zero, without a data-dependent branch.
 */
static inline unsigned normalise_to_1_u64(uint64_t n)
{
    /* Clear the top bit while preserving "nonzero-ness". */
    uint64_t folded = (n & 1) | (n >> 1);
    /* Negating a value in 1..2^63-1 sets bit 63; negating 0 doesn't. */
    uint64_t negated = -folded;
    return (unsigned)(negated >> 63);
}
/*
 * Find the highest nonzero word in a number. Returns the index of the
 * word in x->w, and also a pair of output uint64_t in which that word
 * appears in the high one shifted left by 'shift_wanted' bits, the
 * words immediately below it occupy the space to the right, and the
 * words below _that_ fill up the low one.
 *
 * If there is no nonzero word at all, the passed-by-reference output
 * variables retain their original values.
 *
 * Any of index, hi and lo may be NULL if that output isn't wanted.
 * Every word of x is visited regardless of its value, and the outputs
 * are updated by masking rather than branching, so the control flow
 * does not depend on where the highest nonzero word is.
 */
static inline void mp_find_highest_nonzero_word_pair(
    mp_int *x, size_t shift_wanted, size_t *index,
    uint64_t *hi, uint64_t *lo)
{
    uint64_t curr_hi = 0, curr_lo = 0;
    for (size_t curr_index = 0; curr_index < x->nw; curr_index++) {
        BignumInt curr_word = x->w[curr_index];
        /* 1 if this word is nonzero, else 0 */
        unsigned indicator = normalise_to_1(curr_word);
        /* Slide the (curr_hi:curr_lo) window down by one word and put
         * the current word into curr_hi, shifted by shift_wanted. The
         * BIGNUM_INT_BITS < 64 tests avoid shifting a uint64_t by 64,
         * which would be undefined. */
        curr_lo = (BIGNUM_INT_BITS < 64 ? (curr_lo >> BIGNUM_INT_BITS) : 0) |
            (curr_hi << (64 - BIGNUM_INT_BITS));
        curr_hi = (BIGNUM_INT_BITS < 64 ? (curr_hi >> BIGNUM_INT_BITS) : 0) |
            ((uint64_t)curr_word << shift_wanted);
        /* Record this window only if the word was nonzero. After the
         * loop, the outputs therefore reflect the highest such word. */
        if (hi) *hi ^= (curr_hi ^ *hi ) & -(uint64_t)indicator;
        if (lo) *lo ^= (curr_lo ^ *lo ) & -(uint64_t)indicator;
        if (index) *index ^= (curr_index ^ *index) & -(size_t) indicator;
    }
}
/*
 * Return the number of significant bits in x, i.e. one more than the
 * index of its highest set bit, or 0 if x is zero. Computed without
 * branching on x's value.
 */
size_t mp_get_nbits(mp_int *x)
{
    /* Sentinel values in case there are no bits set at all: we
     * imagine that there's a word at position -1 (i.e. the topmost
     * fraction word) which is all 1s, because that way, we handle a
     * zero input by considering its highest set bit to be the top one
     * of that word, i.e. just below the units digit, i.e. at bit
     * index -1, i.e. so we'll return 0 on output. */
    size_t hiword_index = -(size_t)1;
    uint64_t hiword64 = ~(BignumInt)0;
    /*
     * Find the highest nonzero word and its index.
     */
    mp_find_highest_nonzero_word_pair(x, 0, &hiword_index, &hiword64, NULL);
    BignumInt hiword = hiword64; /* in case BignumInt is a narrower type */
    /*
     * Find the index of the highest set bit within hiword.
     *
     * This is a branch-free binary search. For i = BIGNUM_INT_BITS/2,
     * /4, ..., 1: if hiword >> i is nonzero, the top bit is at least
     * i places higher, so keep the shifted word and add i to the index.
     * (This relies on BIGNUM_INT_BITS == 1 << BIGNUM_INT_BITS_BITS,
     * as the final multiplication-by-shifting below also does.)
     */
    BignumInt hibit_index = 0;
    for (size_t i = (1 << (BIGNUM_INT_BITS_BITS-1)); i != 0; i >>= 1) {
        BignumInt shifted_word = hiword >> i;
        /* shifted_word has its top bit clear, so its negation has the
         * top bit set iff shifted_word is nonzero */
        BignumInt indicator =
            (BignumInt)(-shifted_word) >> (BIGNUM_INT_BITS-1);
        hiword ^= (shifted_word ^ hiword ) & -indicator;
        hibit_index += i & -(size_t)indicator;
    }
    /*
     * Put together the result: word index * word size + bit index + 1.
     * For the zero sentinel this wraps round to exactly 0.
     */
    return (hiword_index << BIGNUM_INT_BITS_BITS) + hibit_index + 1;
}
/*
 * Shared code between the hex and decimal output functions to get rid
 * of leading zeroes on the output string. The idea is that we wrote
 * out a fixed number of digits and a trailing \0 byte into 'buf', and
 * now we want to shift it all left so that the first nonzero digit
 * moves to buf[0] (or, if there are no nonzero digits at all, we move
 * up by 'maxtrim', so that we return 0 as "0" instead of "").
 *
 * buf:     the digit string, bufsize bytes long including its \0.
 * maxtrim: the most leading characters that may be removed. The
 *          callers in this file pass one less than the digit count,
 *          so at least one digit always survives.
 *
 * Neither phase branches on the buffer contents. Both run a number of
 * iterations fixed by maxtrim and bufsize, and use masks to decide
 * what to change.
 */
static void trim_leading_zeroes(char *buf, size_t bufsize, size_t maxtrim)
{
    size_t trim = maxtrim;
    /*
     * Look for the first character not equal to '0', to find the
     * shift count.
     */
    if (trim > 0) {
        for (size_t pos = trim; pos-- > 0 ;) {
            uint8_t diff = buf[pos] ^ '0';
            /* mask is all 1s if buf[pos] == '0' (diff-1 wraps round and
             * sets the top bit), else all 0s. So trim becomes pos at
             * every non-'0', and scanning downwards leaves it at the
             * first one. */
            size_t mask = -((((size_t)diff) - 1) >> (SIZE_T_BITS - 1));
            trim ^= (trim ^ pos) & ~mask;
        }
    }
    /*
     * Now do the shift, in log n passes each of which does a
     * conditional shift by 2^i bytes if bit i is set in the shift
     * count.
     *
     * Each pass conditionally swaps ubuf[i] with ubuf[i+d] for
     * ascending i. This moves every byte d places towards the start,
     * and the displaced leading bytes end up beyond the terminator.
     */
    uint8_t *ubuf = (uint8_t *)buf;
    for (size_t logd = 0; bufsize >> logd; logd++) {
        uint8_t mask = -(uint8_t)((trim >> logd) & 1);
        size_t d = (size_t)1 << logd;
        for (size_t i = 0; i+d < bufsize; i++) {
            uint8_t diff = mask & (ubuf[i] ^ ubuf[i+d]);
            ubuf[i] ^= diff;
            ubuf[i+d] ^= diff;
        }
    }
}
/*
 * Binary to decimal conversion. Our strategy here is to extract each
 * decimal digit by finding the input number's residue mod 10, then
 * subtract that off to give an exact multiple of 10, which then means
 * you can safely divide by 10 by means of shifting right one bit and
 * then multiplying by the inverse of 5 mod 2^n.
 *
 * Returns a NUL-terminated decimal string with leading zeroes removed
 * ("0" for zero), in a buffer allocated with snewn that the caller
 * owns. x_orig is not modified; the work is done on a copy. The number
 * of loop iterations depends only on x's word count, not its value.
 */
char *mp_get_decimal(mp_int *x_orig)
{
    /* x: working copy that is repeatedly divided by 10;
     * y: scratch space for the halved value */
    mp_int *x = mp_copy(x_orig), *y = mp_make_sized(x->nw);
    /*
     * The inverse of 5 mod 2^lots is 0xccccccccccccccccccccd, for an
     * appropriate number of 'c's. Manually construct an integer the
     * right size.
     */
    mp_int *inv5 = mp_make_sized(x->nw);
    assert(BIGNUM_INT_BITS % 8 == 0);
    for (size_t i = 0; i < inv5->nw; i++)
        inv5->w[i] = BIGNUM_INT_MASK / 5 * 4; /* 0xCC...CC */
    inv5->w[0]++;                             /* ...CCD */
    /*
     * 146/485 is an upper bound (and also a continued-fraction
     * convergent) of log10(2), so this is a conservative estimate of
     * the number of decimal digits needed to store a value that fits
     * in this many binary bits.
     */
    assert(x->nw < (~(size_t)1) / (146 * BIGNUM_INT_BITS));
    /* +2: one guaranteed digit of slack plus the terminating \0 */
    size_t bufsize = size_t_max(x->nw * (146 * BIGNUM_INT_BITS) / 485, 1) + 2;
    char *outbuf = snewn(bufsize, char);
    outbuf[bufsize - 1] = '\0';
    /*
     * Loop over the number generating digits from the least
     * significant upwards, so that we write to outbuf in reverse
     * order.
     */
    for (size_t pos = bufsize - 1; pos-- > 0 ;) {
        /*
         * Find the current residue mod 10. We do this by first
         * summing the bytes of the number, with all but the lowest
         * one multiplied by 6 (because 256^i == 6 mod 10 for all
         * i>0). That gives us a single word congruent mod 10 to the
         * input number, and then we reduce it further by manual
         * multiplication and shifting, just in case the compiler
         * target implements the C division operator in a way that has
         * input-dependent timing.
         */
        uint32_t low_digit = 0, maxval = 0, mult = 1;
        for (size_t i = 0; i < x->nw; i++) {
            for (unsigned j = 0; j < BIGNUM_INT_BYTES; j++) {
                low_digit += mult * (0xFF & (x->w[i] >> (8*j)));
                maxval += mult * 0xFF;
                mult = 6;
            }
            /*
             * For _really_ big numbers, prevent overflow of low_digit
             * by periodically folding the top half of the accumulator
             * into the bottom half, using the same rule 'multiply by
             * 6 when shifting down by one or more whole bytes'. maxval
             * tracks the largest value low_digit could have reached,
             * so this decision depends only on x's size.
             */
            if (maxval > UINT32_MAX - (6 * 0xFF * BIGNUM_INT_BYTES)) {
                low_digit = (low_digit & 0xFFFF) + 6 * (low_digit >> 16);
                maxval = (maxval & 0xFFFF) + 6 * (maxval >> 16);
            }
        }
        /*
         * Final reduction of low_digit. We multiply by 2^32 / 10
         * (that's the constant 0x19999999) to get a 64-bit value
         * whose top 32 bits are the approximate quotient
         * low_digit/10; then we subtract off 10 times that; and
         * finally we do one last trial subtraction of 10 by adding 6
         * (which sets bit 4 if the number was just over 10) and then
         * testing bit 4.
         */
        low_digit -= 10 * ((0x19999999ULL * low_digit) >> 32);
        low_digit -= 10 * ((low_digit + 6) >> 4);
        assert(low_digit < 10); /* make sure we did reduce fully */
        outbuf[pos] = '0' + low_digit;
        /*
         * Now subtract off that digit, divide by 2 (using a right
         * shift) and by 5 (using the modular inverse), to get the
         * next output digit into the units position.
         */
        mp_sub_integer_into(x, x, low_digit);
        mp_rshift_fixed_into(y, x, 1);
        mp_mul_into(x, y, inv5);
    }
    mp_free(x);
    mp_free(y);
    mp_free(inv5);
    /* keep at least one digit, so zero comes out as "0" */
    trim_leading_zeroes(outbuf, bufsize, bufsize - 2);
    return outbuf;
}
/*
 * Binary to hex conversion. Reasonably simple (only a spot of bit
 * twiddling to choose whether to output a digit or a letter for each
 * nibble).
 *
 * letter_offset is the distance from the character ('0' + 10) to the
 * desired letter for digit value 10, which selects lower or upper case.
 * Returns a NUL-terminated string from snewn, with leading zeroes
 * removed but at least one digit kept. The caller owns the buffer.
 */
static char *mp_get_hex_internal(mp_int *x, uint8_t letter_offset)
{
    /* two hex digits per byte of capacity */
    size_t nibbles = x->nw * BIGNUM_INT_BYTES * 2;
    size_t bufsize = nibbles + 1;
    char *outbuf = snewn(bufsize, char);
    outbuf[nibbles] = '\0';
    for (size_t nibble = 0; nibble < nibbles; nibble++) {
        size_t word_idx = nibble / (BIGNUM_INT_BYTES*2);
        size_t nibble_within_word = nibble % (BIGNUM_INT_BYTES*2);
        uint8_t digitval = 0xF & (x->w[word_idx] >> (nibble_within_word * 4));
        /* digitval + 6 reaches bit 4 exactly when digitval >= 10, so
         * mask is 0xFF for letters and 0 for decimal digits */
        uint8_t mask = -((digitval + 6) >> 4);
        char digit = digitval + '0' + (letter_offset & mask);
        /* nibble 0 is least significant, so fill from the right */
        outbuf[nibbles-1 - nibble] = digit;
    }
    trim_leading_zeroes(outbuf, bufsize, nibbles - 1);
    return outbuf;
}
/* Lower-case hex representation of x (no "0x" prefix). */
char *mp_get_hex(mp_int *x)
{
    return mp_get_hex_internal(x, 'a' - ('0'+10));
}
/* Upper-case hex representation of x (no "0x" prefix). */
char *mp_get_hex_uppercase(mp_int *x)
{
    return mp_get_hex_internal(x, 'A' - ('0'+10));
}
/*
 * Routines for reading and writing the SSH-1 and SSH-2 wire formats
 * for multiprecision integers, declared in marshal.h.
 *
 * These cannot avoid control flow that depends on the number's true
 * bit length, because the wire format makes the number of output
 * bytes depend on it.
 */

/*
 * Write x in SSH-1 format: a uint16 bit count followed by the value
 * as big-endian bytes, using just enough bytes for that many bits.
 */
void BinarySink_put_mp_ssh1(BinarySink *bs, mp_int *x)
{
    size_t nbits = mp_get_nbits(x);
    assert(nbits < 0x10000);    /* must fit the uint16 length field */
    put_uint16(bs, nbits);
    for (size_t left = (nbits + 7) / 8; left > 0; left--)
        put_byte(bs, mp_get_byte(x, left - 1));
}
/*
 * Write x in SSH-2 'mpint' format (RFC 4251 section 5): a uint32 byte
 * count followed by the value as big-endian bytes, using the minimum
 * number of bytes.
 *
 * For a nonzero value with nbits significant bits, that is
 * (nbits + 8) / 8 bytes. The extra bit guarantees the first byte has
 * its top bit clear, so the unsigned value is not read as negative.
 *
 * Zero is special: the RFC requires it to be an empty string. The
 * general formula would emit a single 0x00 byte, a non-minimal
 * encoding that BinarySource_get_mp_ssh2 rejects as BSE_INVALID.
 */
void BinarySink_put_mp_ssh2(BinarySink *bs, mp_int *x)
{
    size_t nbits = mp_get_nbits(x);
    size_t bytes = nbits ? (nbits + 8) / 8 : 0;
    put_uint32(bs, bytes);
    for (size_t i = bytes; i-- > 0 ;)
        put_byte(bs, mp_get_byte(x, i));
}
/*
 * Read an SSH-1 format integer: a uint16 bit count followed by
 * ceil(bitcount/8) big-endian bytes.
 *
 * If reading fails, 0 is returned and the source's error state
 * reports why. If the value needs more bits than the header claims,
 * the source's error is set to BSE_INVALID and 0 is returned.
 */
mp_int *BinarySource_get_mp_ssh1(BinarySource *src)
{
    unsigned bitc = get_uint16(src);
    ptrlen bytes = get_data(src, (bitc + 7) / 8);
    if (get_err(src))
        return mp_from_integer(0);

    mp_int *toret = mp_from_bytes_be(bytes);
    /* The SSH-1.5 spec permits a header bit count _larger_ than the
     * real one; only an understatement is an error. */
    if (mp_get_nbits(toret) <= bitc)
        return toret;

    src->err = BSE_INVALID;
    mp_free(toret);
    return mp_from_integer(0);
}
/*
 * Read an SSH-2 'mpint': a length-prefixed string of big-endian bytes.
 *
 * Only non-negative, minimally encoded values are accepted. The first
 * byte may not have its top bit set, since that would denote a
 * negative number. A leading 0x00 byte is allowed only when the next
 * byte has its top bit set, so a lone 0x00 byte is rejected too.
 * Violations set BSE_INVALID and return 0. A failed read returns 0,
 * and the source's error state reports why.
 */
mp_int *BinarySource_get_mp_ssh2(BinarySource *src)
{
    ptrlen bytes = get_string(src);
    if (get_err(src))
        return mp_from_integer(0);

    const unsigned char *p = bytes.ptr;
    if (bytes.len > 0) {
        unsigned negative = p[0] & 0x80;
        unsigned redundant_zero =
            p[0] == 0 && (bytes.len == 1 || !(p[1] & 0x80));
        if (negative || redundant_zero) {
            src->err = BSE_INVALID;
            return mp_from_integer(0);
        }
    }
    return mp_from_bytes_be(bytes);
}
/*
 * Build an mp_int whose word array aliases words [offset, offset+len)
 * of 'in'. This makes it easy to operate on just the low or high part
 * of a number, e.g. adding starting at a high word, or reducing mod
 * 2^(n*BIGNUM_INT_BITS).
 *
 * Convention in this file: an mp_int held by value is always an alias
 * of this kind, so its words never need freeing. An 'mp_int *' has an
 * owner who knows whether it must be freed or merely points at an
 * alias.
 *
 * offset and len are clamped to 'in', so the result is always valid
 * but may be shorter than requested (even empty).
 */
static mp_int mp_make_alias(mp_int *in, size_t offset, size_t len)
{
    size_t start = size_t_min(offset, in->nw);
    size_t avail = in->nw - start;
    mp_int toret;
    toret.nw = size_t_min(len, avail);
    toret.w = in->w + start;
    return toret;
}
/*
 * Carve a len-word alias off the front of a scratch pool, and advance
 * 'pool' past it so later allocations don't overlap it. This is a
 * special case of mp_make_alias, used so recursive or iterative work
 * can preallocate one large mp_int instead of repeatedly calling
 * malloc and free.
 *
 * There is no matching free. The intended pattern is mark/release:
 * copy the pool mp_int by value, allocate from the copy, and discard
 * the copy when done, leaving the original pool value untouched.
 */
static mp_int mp_alloc_from_scratch(mp_int *pool, size_t len)
{
    assert(len <= pool->nw);
    mp_int chunk = mp_make_alias(pool, 0, len);
    mp_int rest = mp_make_alias(pool, len, pool->nw);
    *pool = rest;
    return chunk;
}
/*
 * Core of most add/subtract code in this file.
 *
 * Computes rw words of a + ((b & b_and) ^ b_xor) + carry, writing them
 * to w_out (which may be NULL if only the carry is wanted), and returns
 * the final carry.
 *
 * - Addition: b_and = all 1s, b_xor = 0, carry = 0.
 * - Subtraction: b_xor = all 1s too (bit-flipping b) and carry = 1,
 *   which turns the one's complement into two's complement.
 * - Conditional add/subtract: make b_and all 1s or all 0s according to
 *   the condition. When b_and is 0, b's value is ignored entirely.
 */
static BignumCarry mp_add_masked_into(
    BignumInt *w_out, size_t rw, mp_int *a, mp_int *b,
    BignumInt b_and, BignumInt b_xor, BignumCarry carry)
{
    for (size_t i = 0; i < rw; i++) {
        BignumInt aword = mp_word(a, i);
        BignumInt bword = (mp_word(b, i) & b_and) ^ b_xor;
        BignumInt sum;
        BignumADC(sum, carry, aword, bword, carry);
        if (w_out != NULL)
            w_out[i] = sum;
    }
    return carry;
}
/*
 * r = a + b, truncated to r's size, returning the carry out of the
 * top word (which the public mp_add_into discards).
 */
static inline BignumCarry mp_add_into_internal(mp_int *r, mp_int *a, mp_int *b)
{
    const BignumInt all_ones = ~(BignumInt)0;
    return mp_add_masked_into(r->w, r->nw, a, b, all_ones, 0, 0);
}

/* r = a + b, truncated to r's size. */
void mp_add_into(mp_int *r, mp_int *a, mp_int *b)
{
    (void)mp_add_into_internal(r, a, b);
}

/* r = a - b, computed as a + ~b + 1 and truncated to r's size. */
void mp_sub_into(mp_int *r, mp_int *a, mp_int *b)
{
    const BignumInt all_ones = ~(BignumInt)0;
    mp_add_masked_into(r->w, r->nw, a, b, all_ones, all_ones, 1);
}
/*
 * Bitwise operations over r's full width. Inputs shorter than r are
 * treated as zero-extended, and r may alias either input, since each
 * input word is read before the output word at that index is written.
 */

/* r = a & b */
void mp_and_into(mp_int *r, mp_int *a, mp_int *b)
{
    for (size_t i = 0; i < r->nw; i++)
        r->w[i] = mp_word(a, i) & mp_word(b, i);
}

/* r = a | b */
void mp_or_into(mp_int *r, mp_int *a, mp_int *b)
{
    for (size_t i = 0; i < r->nw; i++)
        r->w[i] = mp_word(a, i) | mp_word(b, i);
}

/* r = a ^ b */
void mp_xor_into(mp_int *r, mp_int *a, mp_int *b)
{
    for (size_t i = 0; i < r->nw; i++)
        r->w[i] = mp_word(a, i) ^ mp_word(b, i);
}

/* r = a & ~b ("bit clear") */
void mp_bic_into(mp_int *r, mp_int *a, mp_int *b)
{
    for (size_t i = 0; i < r->nw; i++)
        r->w[i] = mp_word(a, i) & ~mp_word(b, i);
}
/*
 * r = yes ? -x : x, modulo 2^(r's bit capacity), without branching on
 * yes. yes is expected to be 0 or 1. It serves directly as the initial
 * carry, and its negation as the XOR mask, so the two's-complement
 * negation (~x + 1) happens only when yes == 1.
 */
static void mp_cond_negate(mp_int *r, mp_int *x, unsigned yes)
{
    BignumInt flip = -(BignumInt)yes;
    BignumCarry carry = yes;
    for (size_t i = 0; i < r->nw; i++) {
        BignumInt word = mp_word(x, i) ^ flip;
        BignumADC(r->w[i], carry, 0, word, carry);
    }
}
/*
 * Like mp_add_masked_into, but the masked operand is a C integer b
 * rather than an mp_int. It computes rw words of a + masked(b) + carry,
 * writes them to w_out (if non-NULL), and returns the final carry.
 *
 * Each word of b is split out by shifting, and then XORed with b_xor
 * and ANDed with b_and. Note that this order is the reverse of
 * mp_add_masked_into. The results agree whenever b_and is all 1s, or
 * b_and == b_xor, which covers the callers in this file.
 *
 * Words at or beyond the bit width of b are zero. The range test must
 * compare the shift count with b's width in BITS. Comparing with
 * BIGNUM_INT_BYTES, as this code once did, discarded every word of b
 * above the lowest whenever BignumInt was narrower than uintmax_t. It
 * must also stop before the full width, since shifting a uintmax_t by
 * its own width is undefined behaviour.
 */
static BignumCarry mp_add_masked_integer_into(
    BignumInt *w_out, size_t rw, mp_int *a, uintmax_t b,
    BignumInt b_and, BignumInt b_xor, BignumCarry carry)
{
    for (size_t i = 0; i < rw; i++) {
        BignumInt aword = mp_word(a, i);
        size_t shift = i * BIGNUM_INT_BITS;
        BignumInt bword = shift < CHAR_BIT * sizeof(b) ? b >> shift : 0;
        BignumInt out;
        bword = (bword ^ b_xor) & b_and;
        BignumADC(out, carry, aword, bword, carry);
        if (w_out)
            w_out[i] = out;
    }
    return carry;
}
/* r = a + n, truncated to r's size. */
void mp_add_integer_into(mp_int *r, mp_int *a, uintmax_t n)
{
    const BignumInt all_ones = ~(BignumInt)0;
    mp_add_masked_integer_into(r->w, r->nw, a, n, all_ones, 0, 0);
}

/* r = a - n (computed as a + ~n + 1), truncated to r's size. */
void mp_sub_integer_into(mp_int *r, mp_int *a, uintmax_t n)
{
    const BignumInt all_ones = ~(BignumInt)0;
    mp_add_masked_integer_into(r->w, r->nw, a, n, all_ones, all_ones, 1);
}
/*
 * Sets r to a + n << (word_index * BIGNUM_INT_BITS), treating
 * word_index as secret data.
 *
 * Every word of r is processed whatever word_index is. Where n's low
 * bits land is chosen by masking, so control flow does not depend on
 * word_index. The result is truncated to r's size.
 */
static void mp_add_integer_into_shifted_by_words(
    mp_int *r, mp_int *a, uintmax_t n, size_t word_index)
{
    unsigned indicator = 0;
    BignumCarry carry = 0;
    for (size_t i = 0; i < r->nw; i++) {
        /* indicator becomes 1 when we reach the index that the least
         * significant bits of n want to be placed at, and it stays 1
         * thereafter. */
        indicator |= 1 ^ normalise_to_1(i ^ word_index);
        /* If indicator is 1, we add the low bits of n into r, and
         * shift n down. If it's 0, we add zero bits into r, and
         * leave n alone. */
        BignumInt bword = n & -(BignumInt)indicator;
        /* NOTE(review): the "< 64" test assumes uintmax_t is 64 bits
         * wide (avoiding an undefined full-width shift); confirm. */
        uintmax_t new_n = (BIGNUM_INT_BITS < 64 ? n >> BIGNUM_INT_BITS : 0);
        n ^= (n ^ new_n) & -(uintmax_t)indicator;
        BignumInt aword = mp_word(a, i);
        BignumInt out;
        BignumADC(out, carry, aword, bword, carry);
        r->w[i] = out;
    }
}
/*
 * r = a * n for a 16-bit multiplier n. The product must fit in r: a
 * nonzero carry out of the top word trips an assertion. r may alias a,
 * since each word of a is read before the matching word of r is written.
 */
void mp_mul_integer_into(mp_int *r, mp_int *a, uint16_t n)
{
    const BignumInt mult = n;
    BignumInt carry = 0;
    for (size_t i = 0; i < r->nw; i++) {
        BignumInt aword = mp_word(a, i);
        BignumMULADD(carry, r->w[i], aword, mult, carry);
    }
    assert(carry == 0);
}
/* r = a + (yes ? b : 0), using only the low bit of yes, branch-free. */
void mp_cond_add_into(mp_int *r, mp_int *a, mp_int *b, unsigned yes)
{
    BignumInt keep_b = -(BignumInt)(yes & 1);
    mp_add_masked_into(r->w, r->nw, a, b, keep_b, 0, 0);
}

/* r = a - (yes ? b : 0), using only the low bit of yes, branch-free. */
void mp_cond_sub_into(mp_int *r, mp_int *a, mp_int *b, unsigned yes)
{
    BignumInt mask = -(BignumInt)(yes & 1);
    /* With b_and = b_xor = mask, b becomes ~b when yes and 0 when not.
     * The two's-complement +1 is likewise only applied when yes. */
    mp_add_masked_into(r->w, r->nw, a, b, mask, mask, mask & 1);
}
/*
 * Unsigned "higher or same" comparison: returns 1 if a >= b, else 0.
 * It computes a - b over the wider operand's length without storing
 * the result. The carry out is 1 exactly when no borrow occurred.
 */
unsigned mp_cmp_hs(mp_int *a, mp_int *b)
{
    const BignumInt all_ones = ~(BignumInt)0;
    size_t len = a->nw > b->nw ? a->nw : b->nw;
    return mp_add_masked_into(NULL, len, a, b, all_ones, all_ones, 1);
}
/*
 * Return 1 if x >= n, else 0, without data-dependent control flow.
 *
 * This computes x - n word by word (as x + ~n + 1) and returns the
 * final carry, which is 1 exactly when no borrow occurred.
 *
 * The loop must cover BOTH operands. x is zero-extended via mp_word,
 * and n is split into words until its bit width is exhausted. Looping
 * over only x->nw words would silently ignore the high part of n when
 * n is wider than x. For example, a one-word x with 32-bit BignumInt
 * would compare as >= 2^32 whenever it was >= 0. The iteration count
 * depends only on sizes, never on values.
 */
unsigned mp_hs_integer(mp_int *x, uintmax_t n)
{
    BignumInt carry = 1;
    size_t nbits = CHAR_BIT * sizeof(n);
    size_t nwords = (nbits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
    for (size_t i = 0, e = size_t_max(x->nw, nwords); i < e; i++) {
        size_t shift = i * BIGNUM_INT_BITS;
        BignumInt nword = shift < nbits ? n >> shift : 0;
        BignumInt dummy_out;
        BignumADC(dummy_out, carry, mp_word(x, i), ~nword, carry);
        (void)dummy_out;
    }
    return carry;
}
/*
 * Equality test: returns 1 if a == b, else 0. The inputs are XORed
 * word by word over the longer length (shorter inputs zero-extended),
 * the differences are ORed together, and the result is turned into a
 * flag with normalise_to_1, so there is no branch on the values.
 */
unsigned mp_cmp_eq(mp_int *a, mp_int *b)
{
    size_t len = size_t_max(a->nw, b->nw);
    BignumInt accum = 0;
    for (size_t i = 0; i < len; i++)
        accum |= mp_word(a, i) ^ mp_word(b, i);
    /* accum is zero exactly when every word matched */
    return normalise_to_1(accum) ^ 1;
}
/*
 * Return 1 if x == n, else 0, without data-dependent control flow.
 *
 * Word-wise XOR differences are ORed together and normalised to a
 * flag. The loop covers enough words for both x (zero-extended via
 * mp_word) and n. Looping over only x->nw words would ignore the high
 * part of an n wider than x, so e.g. a zero one-word x with 32-bit
 * BignumInt would compare equal to 2^32. The iteration count depends
 * only on sizes, never on values.
 */
unsigned mp_eq_integer(mp_int *x, uintmax_t n)
{
    BignumInt diff = 0;
    size_t nbits = CHAR_BIT * sizeof(n);
    size_t nwords = (nbits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
    for (size_t i = 0, e = size_t_max(x->nw, nwords); i < e; i++) {
        size_t shift = i * BIGNUM_INT_BITS;
        BignumInt nword = shift < nbits ? n >> shift : 0;
        diff |= mp_word(x, i) ^ nword;
    }
    return 1 ^ normalise_to_1(diff); /* return 1 if diff _is_ zero */
}
/* r = -a modulo 2^(r's bit capacity), computed as 0 - a. */
void mp_neg_into(mp_int *r, mp_int *a)
{
    /* A zero-word mp_int reads as 0 at every index via mp_word, so its
     * word pointer is never dereferenced. */
    mp_int zero;
    zero.nw = 0;
    zero.w = NULL;
    mp_sub_into(r, &zero, a);
}
/*
 * Return a new mp_int holding x + y. It is one word wider than the
 * wider input, so the sum cannot be truncated.
 */
mp_int *mp_add(mp_int *x, mp_int *y)
{
    size_t nw = size_t_max(x->nw, y->nw) + 1;
    mp_int *sum = mp_make_sized(nw);
    mp_add_into(sum, x, y);
    return sum;
}

/*
 * Return a new mp_int holding x - y, as wide as the wider input. The
 * result wraps modulo that size if y > x.
 */
mp_int *mp_sub(mp_int *x, mp_int *y)
{
    size_t nw = size_t_max(x->nw, y->nw);
    mp_int *diff = mp_make_sized(nw);
    mp_sub_into(diff, x, y);
    return diff;
}

/* Return a new mp_int, as wide as a, holding -a modulo that size. */
mp_int *mp_neg(mp_int *a)
{
    mp_int *negated = mp_make_sized(a->nw);
    mp_neg_into(negated, a);
    return negated;
}
/*
 * Schoolbook O(N^2) multiply-accumulate: r <- r + a*b, truncated to
 * r's length. Partial products that would land at or beyond r->nw
 * are never computed, and carries propagate up to the top of r.
 */
static void mp_mul_add_simple(mp_int *r, mp_int *a, mp_int *b)
{
    size_t rw = r->nw;
    for (size_t ai = 0; ai < a->nw && ai < rw; ai++) {
        BignumInt adata = a->w[ai], carry = 0;
        size_t ri = ai;
        /* add adata * b into r, starting at word ai */
        for (size_t bi = 0; bi < b->nw && ri < rw; bi++, ri++)
            BignumMULADD2(carry, r->w[ri], adata, b->w[bi], r->w[ri], carry);
        /* ripple the remaining carry through the rest of r */
        for (; ri < rw; ri++)
            BignumADC(r->w[ri], carry, carry, r->w[ri], 0);
    }
}
/* Input size (in words) at which multiplication switches from the
 * schoolbook method to Karatsuba. Can be overridden with -D for testing. */
#ifndef KARATSUBA_THRESHOLD
#define KARATSUBA_THRESHOLD 24
#endif

/*
 * Deliberately generous bound on the scratch space (in words) that the
 * recursive multiply needs for an input of n words.
 *
 * Reasoning: the Karatsuba branch of mp_mul_internal uses the most
 * space. It allocates (a0+a1) and (b0+b1), each a little over n/2
 * words, plus room for their product, a little over n words. It then
 * recurses on a multiplication of about that size.
 *
 * With exact halving this gives M(n) = 2n + M(n/2), solved by
 * M(n) = 4n. Allowing for the extra word or two w at each step, and
 * supposing M(n) <= kn, gives kn = 2n + w + k(n/2 + w), i.e.
 * kn/2 = 2n + (k+1)w. Scratch space is only needed once n is large,
 * so 2n + (k+1)w can be bounded by 3n, giving k = 6.
 *
 * So 6n words are claimed to suffice. mp_mul_internal asserts this at
 * every level of recursion.
 */
static inline size_t mp_mul_scratchspace_unary(size_t n)
{
    const size_t words_per_input_word = 6;
    return words_per_input_word * n;
}
/*
 * Scratch space needed to multiply an aw-word by a bw-word number into
 * an rw-word result. The effective input length is the longer operand,
 * capped at rw, since words of the product beyond r are never computed.
 */
static size_t mp_mul_scratchspace(size_t rw, size_t aw, size_t bw)
{
    size_t longest = size_t_max(aw, bw);
    return mp_mul_scratchspace_unary(size_t_min(rw, longest));
}
static void mp_mul_internal(mp_int *r, mp_int *a, mp_int *b, mp_int scratch)
{
size_t inlen = size_t_min(r->nw, size_t_max(a->nw, b->nw));
assert(scratch.nw >= mp_mul_scratchspace_unary(inlen));
mp_clear(r);
if (inlen < KARATSUBA_THRESHOLD || a->nw == 0 || b->nw == 0) {
/*
* The input numbers are too small to bother optimising. Go
* straight to the simple primitive approach.
*/
mp_mul_add_simple(r, a, b);
return;
}
/*
* Karatsuba divide-and-conquer algorithm. We cut each input in
* half, so that it's expressed as two big 'digits' in a giant
* base D:
*
* a = a_1 D + a_0
* b = b_1 D + b_0
*
* Then the product is of course
*
* ab = a_1 b_1 D^2 + (a_1 b_0 + a_0 b_1) D + a_0 b_0
*
* and we compute the three coefficients by recursively calling