Skip to content

Commit

Permalink
Informal index-based proof of the modular permutation transpose kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
mlochbaum committed Nov 8, 2024
1 parent 964cf02 commit 470bc42
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 25 deletions.
50 changes: 37 additions & 13 deletions docs/implementation/primitive/transpose.html
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ <h3 id="short-kernels"><a class="header" href="#short-kernels">Short kernels</a>
<line x1='613.7' x2='613.7' y1='10.2' y2='57.8'/>
<line x1='618.8' x2='618.8' y1='10.2' y2='74.8'/>
<line x1='623.9' x2='623.9' y1='10.2' y2='91.8'/>
<line x1='890.8' x2='860.2' y1='144.5' y2='144.5'/>
<line x1='890.8' x2='809.2' y1='139.4' y2='139.4'/>
<line x1='890.8' x2='860.2' y1='124.1' y2='124.1'/>
<line x1='890.8' x2='809.2' y1='129.2' y2='129.2'/>
<line x1='890.8' x2='758.2' y1='134.3' y2='134.3'/>
<line x1='890.8' x2='707.2' y1='129.2' y2='129.2'/>
<line x1='890.8' x2='656.2' y1='124.1' y2='124.1'/>
<path d='M93.5 153l8.5 17v-51l8.5 17'/>
<line x1='890.8' x2='707.2' y1='139.4' y2='139.4'/>
<line x1='890.8' x2='656.2' y1='144.5' y2='144.5'/>
<path d='M91.8 154.7l10.2 15.3v-51l10.2 15.3'/>
</g>
<g fill='currentColor'>
<text dy='0.32em' x='51' y='187'>A0</text>
Expand Down Expand Up @@ -887,24 +887,48 @@ <h3 id="short-kernels"><a class="header" href="#short-kernels">Short kernels</a>
<text dy='0.32em' x='884' y='442'>O6</text>
<text dy='0.32em' x='901' y='442'>P6</text>
</g>
<g stroke='currentColor' fill='none' opacity='0.4' stroke-width='12' stroke-linecap='round'>
<line x1='323' x2='459' y1='204' y2='340'/>
<circle cx='391' cy='272' r='61.2'/>
<path d='M248.2 438.6l-61.2 -30.6l61.2 -30.6m-61.2 30.6L612 408m-61.2 -30.6l61.2 30.6l-61.2 30.6'/>
<g stroke='currentColor' fill='none' opacity='0.4' stroke-linecap='round'>
<g stroke-width='14'>
<line x1='323' x2='459' y1='204' y2='340'/>
<circle cx='391' cy='272' r='61.2'/>
</g>
<g stroke-width='10'><path d='M248.2 438.6l-61.2 -30.6l61.2 -30.6m-61.2 30.6L612 408m-61.2 -30.6l61.2 30.6l-61.2 30.6'/></g>
</g>
</g>
</svg>

<p>For an odd width <code><span class='Value'>w</span></code>, the modular permutation works by moving through a representation where elements are stored along a wrapping diagonal: element <code><span class='Value'>i</span></code> gets position (vector index, index within vector) <code><span class='Value'>w</span><span class='Ligature'></span><span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span></code> where <code><span class='Value'>w</span></code> is the number of vectors and <code><span class='Value'>v</span></code> is the length of each. All <code><span class='Value'>w</span><span class='Function'>×</span><span class='Value'>v</span></code> positions are unique by the Chinese remainder theorem. The steps are symmetric around this representation, with a permutation and a shearing step on each side. Here are the steps when starting with a short width:</p>
<p>For an odd width <code><span class='Value'>w</span></code>, elements are stored in <code><span class='Value'>w</span></code> vectors of length <code><span class='Value'>v</span></code>. The modular permutation kernel works by moving through an ordering with elements are stored along a wrapping diagonal: element <code><span class='Value'>i</span></code> gets position <code><span class='Value'>w</span><span class='Ligature'></span><span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span></code> (vector index, index within vector). All <code><span class='Value'>w</span><span class='Function'>×</span><span class='Value'>v</span></code> positions are unique by the Chinese remainder theorem. Essentially, we move the single-step direction from horizontal (second axis) to diagonal (both axes) to vertical (first axis). So the steps are symmetric around this ordering, with a permutation and a shearing step on each side. Here's what to do when starting with a short width:</p>
<ul>
<li>Load contiguous rows into packed vectors</li>
<li>Permute each column by virtually reordering the registers (free)</li>
<li>Rotate each column by its index modulo <code><span class='Value'>w</span></code></li>
<li>Rotate each row by its index</li>
<li>Rotate each column forward by its index</li>
<li>Rotate each row backwards by its index</li>
<li>Permute each row with a shuffle (can be combined with previous)</li>
<li>Store each vector as part of a result row</li>
</ul>
<p>The shearing step is where most of the work happens because it's the only step that transfers elements between registers. It can be performed with <code><span class='Function'></span><span class='Number'>2</span><span class='Function'></span><span class='Modifier'></span><span class='Value'>w</span></code> steps, each one handling a fixed power of two smaller than <code><span class='Value'>w</span></code>. The step for <code><span class='Number'>2</span><span class='Function'></span><span class='Value'>i</span></code> rotates each column whose index has that bit set, by blending a given register with another whose index differs by <code><span class='Number'>2</span><span class='Function'></span><span class='Value'>i</span></code>.</p>
<p>The shearing step that rotates columns is where most of the work happens because it's the only step that transfers elements between registers. It can be performed with <code><span class='Function'></span><span class='Number'>2</span><span class='Function'></span><span class='Modifier'></span><span class='Value'>w</span></code> steps, each one handling a fixed power of two smaller than <code><span class='Value'>w</span></code>. The step for <code><span class='Number'>2</span><span class='Function'></span><span class='Value'>i</span></code> rotates each column whose index has that bit set, by blending a given register with another whose index differs by <code><span class='Number'>2</span><span class='Function'></span><span class='Value'>i</span></code>.</p>
<p>To explain why this works, and nail down the exact permutation operations used, we'll initially number the elements in index order. So the index of element <code><span class='Value'>i</span></code> in the initial kernel is <code><span class='Bracket'></span><span class='Function'></span><span class='Value'>i</span><span class='Function'>÷</span><span class='Value'>w</span><span class='Separator'>,</span> <span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Bracket'></span></code>, it should end up at <code><span class='Bracket'></span><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Separator'>,</span> <span class='Function'></span><span class='Value'>i</span><span class='Function'>÷</span><span class='Value'>w</span><span class='Bracket'></span></code> (note that computing an element's position from its value is the opposite of the normal array selection). The initial reshaping into vector registers puts element <code><span class='Value'>i</span></code> at <code><span class='Bracket'></span><span class='Function'></span><span class='Value'>i</span><span class='Function'>÷</span><span class='Value'>v</span><span class='Separator'>,</span> <span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span><span class='Bracket'></span></code>. Our register permutation will move row <code><span class='Value'>j</span></code> to <code><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>v</span><span class='Function'>×</span><span class='Value'>j</span></code>, thus sending element <code><span class='Value'>i</span></code> from row <code><span class='Function'></span><span class='Value'>i</span><span class='Function'>÷</span><span class='Value'>v</span></code> to <code><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>v</span><span class='Function'>×⌊</span><span class='Value'>i</span><span class='Function'>÷</span><span class='Value'>v</span></code>, but by the definition of <code><span class='Function'>|</span></code>, this is equal to <code><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Function'>-</span><span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span></code>! Now we rotate, we add the horizontal index to the vertical, modulo <code><span class='Value'>w</span></code>, moving from <code><span class='Bracket'></span><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Function'>-</span><span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span><span class='Separator'>,</span> <span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span><span class='Bracket'></span></code> to <code><span class='Bracket'></span><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Separator'>,</span> <span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span><span class='Bracket'></span></code>. The remaining steps are entirely symmetrical: performed in the same order, they would move from <code><span class='Bracket'></span><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Separator'>,</span> <span class='Function'></span><span class='Value'>i</span><span class='Function'>÷</span><span class='Value'>w</span><span class='Bracket'></span></code> to <code><span class='Bracket'></span><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Separator'>,</span> <span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span><span class='Bracket'></span></code>, affecting the second index instead of the first. Doing them backwards connects to the previous result. In sum, element <code><span class='Value'>i</span></code> moves through the following positions:</p>
<pre><span class='Bracket'></span><span class='Function'></span><span class='Value'>i</span><span class='Function'>÷</span><span class='Value'>w</span><span class='Separator'>,</span> <span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Bracket'></span>
<span class='Bracket'></span><span class='Function'></span><span class='Value'>i</span><span class='Function'>÷</span><span class='Value'>v</span><span class='Separator'>,</span> <span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span><span class='Bracket'></span> <span class='Comment'># w‿v⥊
</span><span class='Bracket'></span><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Function'>-</span><span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span><span class='Separator'>,</span> <span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span><span class='Bracket'></span> <span class='Comment'># (⍋w|v×↕w)⊏
</span><span class='Bracket'></span><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Separator'>,</span> <span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span><span class='Bracket'></span> <span class='Comment'># (-↕v)⌽˘⌾⍉
</span><span class='Bracket'></span><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Separator'>,</span> <span class='Value'>v</span><span class='Function'>|</span><span class='Value'>i</span><span class='Function'>-</span><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Bracket'></span> <span class='Comment'># (↕w)⌽˘
</span><span class='Bracket'></span><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>i</span><span class='Separator'>,</span> <span class='Function'></span><span class='Value'>i</span><span class='Function'>÷</span><span class='Value'>w</span><span class='Bracket'></span> <span class='Comment'># (v|w×↕v)⊸⊏˘
</span></pre>
<p>The BQN code for each step is in the comments. Below, it's all run on the indices <code><span class='Value'>i</span></code>:</p>
<a class="replLink" title="Open in the REPL" target="_blank" href="https://mlochbaum.github.io/BQN/try.html#code=d+KAv3bihpA34oC/MTYKKHZ8d8OX4oaVdiniirjiio/LmCAo4oaVdynijL3LmCAoLeKGlXYp4oy9y5jijL7ijYkgKOKNi3d8dsOX4oaVdyniio8gd+KAv3bipYog4oaVd8OXdg==">↗️</a><pre> <span class='Value'>w</span><span class='Ligature'></span><span class='Value'>v</span><span class='Gets'></span><span class='Number'>7</span><span class='Ligature'></span><span class='Number'>16</span>
⟨ 7 16 ⟩
<span class='Paren'>(</span><span class='Value'>v</span><span class='Function'>|</span><span class='Value'>w</span><span class='Function'>×↕</span><span class='Value'>v</span><span class='Paren'>)</span><span class='Modifier2'></span><span class='Function'></span><span class='Modifier'>˘</span> <span class='Paren'>(</span><span class='Function'></span><span class='Value'>w</span><span class='Paren'>)</span><span class='Function'></span><span class='Modifier'>˘</span> <span class='Paren'>(</span><span class='Function'>-↕</span><span class='Value'>v</span><span class='Paren'>)</span><span class='Function'></span><span class='Modifier'>˘</span><span class='Modifier2'></span><span class='Function'></span> <span class='Paren'>(</span><span class='Function'></span><span class='Value'>w</span><span class='Function'>|</span><span class='Value'>v</span><span class='Function'>×↕</span><span class='Value'>w</span><span class='Paren'>)</span><span class='Function'></span> <span class='Value'>w</span><span class='Ligature'></span><span class='Value'>v</span><span class='Function'></span> <span class='Function'></span><span class='Value'>w</span><span class='Function'>×</span><span class='Value'>v</span>
┌─
╵ 0 7 14 21 28 35 42 49 56 63 70 77 84 91 98 105
1 8 15 22 29 36 43 50 57 64 71 78 85 92 99 106
2 9 16 23 30 37 44 51 58 65 72 79 86 93 100 107
3 10 17 24 31 38 45 52 59 66 73 80 87 94 101 108
4 11 18 25 32 39 46 53 60 67 74 81 88 95 102 109
5 12 19 26 33 40 47 54 61 68 75 82 89 96 103 110
6 13 20 27 34 41 48 55 62 69 76 83 90 97 104 111
</pre>
<h3 id="cache-efficient-orderings"><a class="header" href="#cache-efficient-orderings">Cache-efficient orderings</a></h3>
<p>There's some amount of literature on addressing cache issues in transpose. The theory is that by writing in index order, for instance, writes are perfectly cached but reads are very poorly cached. However, in CBQN testing with SIMD kernels, the orderings discussed here were not useful: it was best to simply loop in source order. I think the main reason for this is that an AVX2 register is an entire half the size of a cache line, so revisiting the same line is not that useful, and it confuses whatever predictor is trying to make sure the right lines are available.</p>
<p>The basic cache-friendliness tool is blocking: for example, split the array into blocks of a few kilobytes, transposing each before moving to the next. Multi-layer caches could in theory (but not in practice, it seems) demand multiple layers of blocking, but a cache-oblivious layout—named for its ability to perform well regardless of what layers of cache exist—skips over this mess and adds layers fractally at all scales. Here's <a href="https://en.algorithmica.org/hpc/external-memory/oblivious/#matrix-transposition">one presentation</a> of this idea. A simpler recursive implementation is to just halve the longer side of the matrix at each step. At the other complexity extreme, Hilbert curves offer better locality and can be traced without recursion. A recent <a href="https://dl.acm.org/doi/10.1145/3555353">paper</a> with <a href="https://github.com/JoaoAlves95/HPC-Cache-Oblivious-Transposition">source code</a> offers a SIMD scheme to generate the curve that even the authors say is pointless for transpose because the overhead was low enough already. But it also explains the non-SIMD aspects well, if you have the time.</p>
Expand Down
Loading

0 comments on commit 470bc42

Please sign in to comment.