Skip to content

Commit

Permalink
Use optimized rotation2d kernel in Jacobi eigh update first step. (#48
Browse files Browse the repository at this point in the history
)

Switching to optimized kernel giving a 2x speedup on the vertex.
  • Loading branch information
balancap authored Oct 13, 2023
1 parent 9cd48f4 commit 5063716
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 27 deletions.
35 changes: 10 additions & 25 deletions tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <poplar/Vertex.hpp>

#include "intrinsics_utils.hpp"
#include "tile_small_dot.hpp"

using namespace poplar;

Expand Down Expand Up @@ -78,7 +79,7 @@ class JacobiSymSchur2 : public Vertex {
}
};

template <typename T>
template <class IpuTag, typename T>
void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
T* qcol_updated, T* cs, unsigned p, unsigned q,
unsigned short wstart,
Expand All @@ -89,40 +90,23 @@ void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
const T Apq = pcol[q];
const T App = pcol[p];
const T Aqq = qcol[q];

// Schur2 decomposition.
const T2 cs_vec = sym_schur2(App, Aqq, Apq);
const T& c = cs_vec[0];
const T& s = cs_vec[1];
cs[0] = c;
cs[1] = s;

// Worker load: start + end vectorized indexes.
constexpr unsigned ptr_step = 1;
const IndexType wsize = wend - wstart;

// pcol, qcol and results pointers.
const float2* ptr_pcol = reinterpret_cast<const float2*>(pcol) + wstart;
const float2* ptr_qcol = reinterpret_cast<const float2*>(qcol) + wstart;
float2* ptr_pcol_updated = reinterpret_cast<float2*>(pcol_updated) + wstart;
float2* ptr_qcol_updated = reinterpret_cast<float2*>(qcol_updated) + wstart;

const T2 cvec = T2{c, c};
const T2 svec = T2{s, s};

// Easier to vectorized + parallelize if start with normal update first.
for (IndexType idx = 0; idx != wsize; ++idx) {
// TODO: investigate assembly?
const T2 pvec = ipu::load_postinc(&ptr_pcol, 1);
const T2 qvec = ipu::load_postinc(&ptr_qcol, 1);

const T2 pvec_updated = cvec * pvec - svec * qvec;
const T2 qvec_updated = svec * pvec + cvec * qvec;

ipu::store_postinc(&ptr_pcol_updated, pvec_updated, 1);
ipu::store_postinc(&ptr_qcol_updated, qvec_updated, 1);
}

// Apply Schur2 cs rotation to p/q columns (optimized kernel).
rotation2d_f32<IpuTag>(cs_vec, ptr_pcol, ptr_qcol, ptr_pcol_updated,
ptr_qcol_updated, wsize);
// Update main values App, Apq, Aqq
pcol_updated[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq;
qcol_updated[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq;
Expand All @@ -137,7 +121,9 @@ void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
* See: Gene H. Golub, Charles F. Van Loan, MATRIX COMPUTATIONS, 3rd edition,
* Johns Hopkins Chapter 8.
*/
class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
class [[poplar::constraint(
"elem(*pcol) != elem(*pcol_updated)",
"elem(*qcol) != elem(*qcol_updated)")]] JacobiUpdateFirstStep
: public MultiVertex {
public:
using T = float;
Expand Down Expand Up @@ -179,15 +165,15 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep

if (p <= q) {
// Proper ordering of p and q already.
jacob_update_first_step(
jacob_update_first_step<IPU_TAG_TYPE>(
pcol.data() + INDEX_PREFIX, qcol.data() + INDEX_PREFIX,
pcol_updated.data() + INDEX_PREFIX,
qcol_updated.data() + INDEX_PREFIX, cs.data(), p, q, wstart, wend);
rotset_sorted[0] = p;
rotset_sorted[1] = q;
} else {
// Swap p and q columns as q < p
jacob_update_first_step(
jacob_update_first_step<IPU_TAG_TYPE>(
qcol.data() + INDEX_PREFIX, pcol.data() + INDEX_PREFIX,
qcol_updated.data() + INDEX_PREFIX,
pcol_updated.data() + INDEX_PREFIX, cs.data(), q, p, wstart, wend);
Expand Down Expand Up @@ -229,7 +215,6 @@ class JacobiUpdateSecondStep : public MultiVertex {
// Size of the index prefix in pcol and qcol.
constexpr int INDEX_PREFIX = 2;
// Worker load: start + end vectorized indexes.
constexpr unsigned ptr_step = 1;
const IndexType wstart = worker_offsets[wid];
const IndexType wend = worker_offsets[wid + 1];
const IndexType wsize = wend - wstart;
Expand Down
1 change: 1 addition & 0 deletions tessellate_ipu/core/vertex/tile_small_dot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class [[poplar::constraint("elem(*inrow0) != elem(*outrow0)",
T2* outrow0_ptr = reinterpret_cast<T2*>(outrow0.data()) + wstart;
T2* outrow1_ptr = reinterpret_cast<T2*>(outrow1.data()) + wstart;

// Passing IPU model/hardware tag type.
rotation2d_f32<IPU_TAG_TYPE>(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr,
outrow1_ptr, wsize);
return true;
Expand Down
4 changes: 2 additions & 2 deletions tests/linalg/test_tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def jacobi_sym_schur2_fn(pq, pcol, qcol):
# assert False

def test__jacobi_update_first_step_vertex__benchmark_performance(self):
N = 128
N = 256
tiles = (0,)
pq = np.array([3, N // 2], dtype=np.uint32)
pcol = np.random.randn(1, N).astype(np.float32)
Expand All @@ -123,7 +123,7 @@ def jacobi_update_first_step_fn(pq, pcol, qcol):

start, end = np.asarray(start)[0], np.asarray(end)[0]
qr_correction_cycle_count = end[0] - start[0]
assert qr_correction_cycle_count <= 1700
assert qr_correction_cycle_count <= 1550
# print("CYCLE count:", qr_correction_cycle_count)
# assert False

Expand Down

0 comments on commit 5063716

Please sign in to comment.