// Karatsuba.cry
/*
* Copyright (c) 2017 Galois, Inc.
* Distributed under the terms of the BSD3 license (see LICENSE file)
*
* This module defines a Karatsuba multiplier, which is polymorphic over the
* bitwidth of the constituent limbs and the size of the inputs.
*
* Note the trick of using `make_atleast` in the recursive case to teach the
* typechecker that the current bitwidth is long enough for the split multiplier
* to be sensible. The minimum width for which the splitting recursive case
* works out happens to be 6. This is the smallest width for which the
* divide-in-half strategy leaves enough bits left over for the addition
* steps not to overflow.
*
* Note that if one wishes to simulate using a standard CPU to perform Karatsuba
* multiplies, the limb width should be set to ONE LESS THAN the standard
* machine word size, e.g. 63 bits for a 64-bit machine. This ensures that
* the intermediate additions will fit into a machine word. Note that the
* final `2*(limb+1)`-width additions will also have to be assembled manually
* using "add with carry" operations or similar.
*/
module Karatsuba where
// Entry point of the divide-and-conquer Karatsuba multiplier: multiplies two
// n-bit words into a 2*n-bit result.
//
// If the argument width `n` is at least the limb width, the operands are
// handed to the recursive `splitmult` multiplier; otherwise both operands are
// zero-extended to 2*n bits and a standard multiply is used.
//
// `make_atleast` widens the operands to `max limb n` bits so that the
// `n >= 6` requirement of `splitmult` typechecks. In the branch where it is
// used we have n >= limb, so no padding is actually added and the enclosing
// `take` drops nothing.
kmult : {limb,n} (fin n, fin limb, limb >= 6, n >= 1) => [n] -> [n] -> [2*n]
kmult x y = if wide then recursive else direct
  where
    // Bitwidth large enough to hold both `n` and `limb` for the comparison.
    type cmpw = max (width limb) (width n)

    wide : Bit
    wide = (`n : [cmpw]) >= `limb

    recursive : [2*n]
    recursive = take (splitmult`{limb} (make_atleast`{limb} x) (make_atleast`{limb} y))

    direct : [2*n]
    direct = (zero # x) * (zero # y)
// Widen a sequence to at least `n` elements by prepending zeros.
//
// The result has length `max n m`: if the input already has `n` or more
// elements, nothing is prepended. As used by `kmult`, this operation should
// always be a no-op; its purpose is to teach the typechecker the inequality
// needed to call `splitmult`.
make_atleast : {n, m, a} (fin n, fin m, Zero a) => [m]a -> [max n m]a
make_atleast x = padding # x
  where
    // Exactly the number of leading zeros needed to reach `max n m`.
    padding : [max n m - m]a
    padding = zero
// Recursive step of Karatsuba's multiplication for words of at least 6 bits.
//
// Each input is split into its `hi = n/2` high bits and `low = n - hi` low
// bits, i.e. x = xh * 2^^low + xl and y = yh * 2^^low + yl. The result is
// then assembled as
//
//   hh * 2^^(2*low)  +  cross * 2^^low  +  ll
//
// where hh = xh*yh, ll = xl*yl and cross = (xh+xl)*(yh+yl) - hh - ll,
// so that only three (shorter) `kmult` multiplies are performed.
splitmult : {limb,n} (fin n, fin limb, limb >= 6, n >= 6)
         => [n] -> [n] -> [2*n]
splitmult x y = outer + middle
  where
    type hi  = n/2
    type low = n - hi

    (xh, xl) = splitAt`{hi} x
    (yh, yl) = splitAt`{hi} y

    // Product of the high halves.
    hh : [2*hi]
    hh = kmult`{limb,hi} xh yh

    // Product of the low halves.
    ll : [2*low]
    ll = kmult`{limb,low} xl yl

    // Half sums, one bit wider than `low` so the additions cannot overflow.
    xsum : [low+1]
    xsum = (zext xh) + (zext xl)

    ysum : [low+1]
    ysum = (zext yh) + (zext yl)

    // Middle term, obtained from a single multiply of the half sums.
    cross : [2*(low+1)]
    cross = (kmult`{limb,low+1} xsum ysum) - (zext hh) - (zext ll)

    // `hh` occupies the top 2*hi bits and `ll` the bottom 2*low bits.
    outer : [2*n]
    outer = hh # ll

    // `cross` shifted left by `low` bits and zero-padded on top to 2*n bits.
    middle : [2*n]
    middle = zero # cross # (zero : [low])
// Verify that Karatsuba's algorithm computes the correct answer for some
// fixed settings of the parameters: `splitmult` must agree with an ordinary
// multiply of the zero-extended operands.
//
// SMT solvers have great difficulty with these proofs; the original author
// reports that only small bit sizes returned successful proofs.

// Shared statement of correctness, polymorphic in the limb and input widths.
splitmult_correct : {limb,n} (fin n, fin limb, limb >= 6, n >= 6)
                 => [n] -> [n] -> Bit
splitmult_correct x y = zext x * zext y == splitmult`{limb} x y

property splitmult_correct_tiny   = splitmult_correct`{limb=7,  n=9}

property splitmult_correct_small  = splitmult_correct`{limb=7,  n=11}

property splitmult_correct_medium = splitmult_correct`{limb=7,  n=17}

property splitmult_correct_large  = splitmult_correct`{limb=7,  n=59}

property splitmult_correct_huge   = splitmult_correct`{limb=63, n=511}