Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use optimized rotation2d kernel in Jacobi eigh update first step. #48

Merged
merged 1 commit into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 10 additions & 25 deletions tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <poplar/Vertex.hpp>

#include "intrinsics_utils.hpp"
#include "tile_small_dot.hpp"

using namespace poplar;

Expand Down Expand Up @@ -78,7 +79,7 @@ class JacobiSymSchur2 : public Vertex {
}
};

template <typename T>
template <class IpuTag, typename T>
void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
T* qcol_updated, T* cs, unsigned p, unsigned q,
unsigned short wstart,
Expand All @@ -89,40 +90,23 @@ void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
const T Apq = pcol[q];
const T App = pcol[p];
const T Aqq = qcol[q];

// Schur2 decomposition.
const T2 cs_vec = sym_schur2(App, Aqq, Apq);
const T& c = cs_vec[0];
const T& s = cs_vec[1];
cs[0] = c;
cs[1] = s;

// Worker load: start + end vectorized indexes.
constexpr unsigned ptr_step = 1;
const IndexType wsize = wend - wstart;

// pcol, qcol and results pointers.
const float2* ptr_pcol = reinterpret_cast<const float2*>(pcol) + wstart;
const float2* ptr_qcol = reinterpret_cast<const float2*>(qcol) + wstart;
float2* ptr_pcol_updated = reinterpret_cast<float2*>(pcol_updated) + wstart;
float2* ptr_qcol_updated = reinterpret_cast<float2*>(qcol_updated) + wstart;

const T2 cvec = T2{c, c};
const T2 svec = T2{s, s};

// Easier to vectorized + parallelize if start with normal update first.
for (IndexType idx = 0; idx != wsize; ++idx) {
// TODO: investigate assembly?
const T2 pvec = ipu::load_postinc(&ptr_pcol, 1);
const T2 qvec = ipu::load_postinc(&ptr_qcol, 1);

const T2 pvec_updated = cvec * pvec - svec * qvec;
const T2 qvec_updated = svec * pvec + cvec * qvec;

ipu::store_postinc(&ptr_pcol_updated, pvec_updated, 1);
ipu::store_postinc(&ptr_qcol_updated, qvec_updated, 1);
}

// Apply Schur2 cs rotation to p/q columns (optimized kernel).
rotation2d_f32<IpuTag>(cs_vec, ptr_pcol, ptr_qcol, ptr_pcol_updated,
ptr_qcol_updated, wsize);
// Update main values App, Apq, Aqq
pcol_updated[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq;
qcol_updated[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq;
Expand All @@ -137,7 +121,9 @@ void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
* See: Gene H. Golub, Charles F. Van Loan, MATRIX COMPUTATIONS, 3rd edition,
* Johns Hopkins Chapter 8.
*/
class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
class [[poplar::constraint(
"elem(*pcol) != elem(*pcol_updated)",
"elem(*qcol) != elem(*qcol_updated)")]] JacobiUpdateFirstStep
: public MultiVertex {
public:
using T = float;
Expand Down Expand Up @@ -179,15 +165,15 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep

if (p <= q) {
// Proper ordering of p and q already.
jacob_update_first_step(
jacob_update_first_step<IPU_TAG_TYPE>(
pcol.data() + INDEX_PREFIX, qcol.data() + INDEX_PREFIX,
pcol_updated.data() + INDEX_PREFIX,
qcol_updated.data() + INDEX_PREFIX, cs.data(), p, q, wstart, wend);
rotset_sorted[0] = p;
rotset_sorted[1] = q;
} else {
// Swap p and q columns as q < p
jacob_update_first_step(
jacob_update_first_step<IPU_TAG_TYPE>(
qcol.data() + INDEX_PREFIX, pcol.data() + INDEX_PREFIX,
qcol_updated.data() + INDEX_PREFIX,
pcol_updated.data() + INDEX_PREFIX, cs.data(), q, p, wstart, wend);
Expand Down Expand Up @@ -229,7 +215,6 @@ class JacobiUpdateSecondStep : public MultiVertex {
// Size of the index prefix in pcol and qcol.
constexpr int INDEX_PREFIX = 2;
// Worker load: start + end vectorized indexes.
constexpr unsigned ptr_step = 1;
const IndexType wstart = worker_offsets[wid];
const IndexType wend = worker_offsets[wid + 1];
const IndexType wsize = wend - wstart;
Expand Down
1 change: 1 addition & 0 deletions tessellate_ipu/core/vertex/tile_small_dot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class [[poplar::constraint("elem(*inrow0) != elem(*outrow0)",
T2* outrow0_ptr = reinterpret_cast<T2*>(outrow0.data()) + wstart;
T2* outrow1_ptr = reinterpret_cast<T2*>(outrow1.data()) + wstart;

// Passing IPU model/hardware tag type.
rotation2d_f32<IPU_TAG_TYPE>(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr,
outrow1_ptr, wsize);
return true;
Expand Down
4 changes: 2 additions & 2 deletions tests/linalg/test_tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def jacobi_sym_schur2_fn(pq, pcol, qcol):
# assert False

def test__jacobi_update_first_step_vertex__benchmark_performance(self):
N = 128
N = 256
tiles = (0,)
pq = np.array([3, N // 2], dtype=np.uint32)
pcol = np.random.randn(1, N).astype(np.float32)
Expand All @@ -123,7 +123,7 @@ def jacobi_update_first_step_fn(pq, pcol, qcol):

start, end = np.asarray(start)[0], np.asarray(end)[0]
qr_correction_cycle_count = end[0] - start[0]
assert qr_correction_cycle_count <= 1700
assert qr_correction_cycle_count <= 1550
# print("CYCLE count:", qr_correction_cycle_count)
# assert False

Expand Down