diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
index 0cb74c0..e30de5f 100644
--- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
+++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
@@ -3,6 +3,7 @@
 #include <poplar/Vertex.hpp>

 #include "intrinsics_utils.hpp"
+#include "tile_small_dot.hpp"

 using namespace poplar;

@@ -78,7 +79,7 @@ class JacobiSymSchur2 : public Vertex {
   }
 };

-template
+template
 void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
                              T* qcol_updated, T* cs, unsigned p, unsigned q,
                              unsigned short wstart,
@@ -89,16 +90,13 @@ void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
   const T Apq = pcol[q];
   const T App = pcol[p];
   const T Aqq = qcol[q];
-
   // Schur2 decomposition.
   const T2 cs_vec = sym_schur2(App, Aqq, Apq);
   const T& c = cs_vec[0];
   const T& s = cs_vec[1];
   cs[0] = c;
   cs[1] = s;

-  // Worker load: start + end vectorized indexes.
-  constexpr unsigned ptr_step = 1;
   const IndexType wsize = wend - wstart;

   // pcol, qcol and results pointers.
@@ -106,23 +104,9 @@ void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
   const float2* ptr_qcol = reinterpret_cast<const float2*>(qcol) + wstart;
   float2* ptr_pcol_updated = reinterpret_cast<float2*>(pcol_updated) + wstart;
   float2* ptr_qcol_updated = reinterpret_cast<float2*>(qcol_updated) + wstart;
-
-  const T2 cvec = T2{c, c};
-  const T2 svec = T2{s, s};
-
-  // Easier to vectorized + parallelize if start with normal update first.
-  for (IndexType idx = 0; idx != wsize; ++idx) {
-    // TODO: investigate assembly?
-    const T2 pvec = ipu::load_postinc(&ptr_pcol, 1);
-    const T2 qvec = ipu::load_postinc(&ptr_qcol, 1);
-
-    const T2 pvec_updated = cvec * pvec - svec * qvec;
-    const T2 qvec_updated = svec * pvec + cvec * qvec;
-
-    ipu::store_postinc(&ptr_pcol_updated, pvec_updated, 1);
-    ipu::store_postinc(&ptr_qcol_updated, qvec_updated, 1);
-  }
-
+  // Apply Schur2 cs rotation to p/q columns (optimized kernel).
+  rotation2d_f32(cs_vec, ptr_pcol, ptr_qcol, ptr_pcol_updated,
+                 ptr_qcol_updated, wsize);
   // Update main values App, Apq, Aqq
   pcol_updated[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq;
   qcol_updated[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq;
@@ -137,7 +121,9 @@ void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
  * See: Gene H. Golub, Charles F. Van Loan, MATRIX COMPUTATIONS, 3rd edition,
  * Johns Hopkins Chapter 8.
  */
-class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
+class [[poplar::constraint(
+    "elem(*pcol) != elem(*pcol_updated)",
+    "elem(*qcol) != elem(*qcol_updated)")]] JacobiUpdateFirstStep
     : public MultiVertex {
  public:
   using T = float;
@@ -179,7 +165,7 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep

     if (p <= q) {
       // Proper ordering of p and q already.
-      jacob_update_first_step(
+      jacob_update_first_step(
           pcol.data() + INDEX_PREFIX, qcol.data() + INDEX_PREFIX,
           pcol_updated.data() + INDEX_PREFIX,
           qcol_updated.data() + INDEX_PREFIX, cs.data(), p, q, wstart, wend);
@@ -187,7 +173,7 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
       rotset_sorted[1] = q;
     } else {
       // Swap p and q columns as q < p
-      jacob_update_first_step(
+      jacob_update_first_step(
           qcol.data() + INDEX_PREFIX, pcol.data() + INDEX_PREFIX,
           qcol_updated.data() + INDEX_PREFIX,
           pcol_updated.data() + INDEX_PREFIX, cs.data(), q, p, wstart, wend);
@@ -229,7 +215,6 @@ class JacobiUpdateSecondStep : public MultiVertex {
     // Size of the index prefix in pcol and qcol.
     constexpr int INDEX_PREFIX = 2;
     // Worker load: start + end vectorized indexes.
-    constexpr unsigned ptr_step = 1;
     const IndexType wstart = worker_offsets[wid];
     const IndexType wend = worker_offsets[wid + 1];
     const IndexType wsize = wend - wstart;
diff --git a/tessellate_ipu/core/vertex/tile_small_dot.cpp b/tessellate_ipu/core/vertex/tile_small_dot.cpp
index 2391ca1..531bbe7 100644
--- a/tessellate_ipu/core/vertex/tile_small_dot.cpp
+++ b/tessellate_ipu/core/vertex/tile_small_dot.cpp
@@ -48,6 +48,7 @@ class [[poplar::constraint("elem(*inrow0) != elem(*outrow0)",
     T2* outrow0_ptr = reinterpret_cast<T2*>(outrow0.data()) + wstart;
     T2* outrow1_ptr = reinterpret_cast<T2*>(outrow1.data()) + wstart;

+    // Passing IPU model/hardware tag type.
     rotation2d_f32(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr,
                    outrow1_ptr, wsize);
     return true;
diff --git a/tests/linalg/test_tile_linalg_jacobi.py b/tests/linalg/test_tile_linalg_jacobi.py
index 01f43b1..e54dd6c 100644
--- a/tests/linalg/test_tile_linalg_jacobi.py
+++ b/tests/linalg/test_tile_linalg_jacobi.py
@@ -100,7 +100,7 @@ def jacobi_sym_schur2_fn(pq, pcol, qcol):
         # assert False

     def test__jacobi_update_first_step_vertex__benchmark_performance(self):
-        N = 128
+        N = 256
         tiles = (0,)
         pq = np.array([3, N // 2], dtype=np.uint32)
         pcol = np.random.randn(1, N).astype(np.float32)
@@ -123,7 +123,7 @@ def jacobi_update_first_step_fn(pq, pcol, qcol):

         start, end = np.asarray(start)[0], np.asarray(end)[0]
         qr_correction_cycle_count = end[0] - start[0]
-        assert qr_correction_cycle_count <= 1700
+        assert qr_correction_cycle_count <= 1550
         # print("CYCLE count:", qr_correction_cycle_count)
         # assert False
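
Note: the inline loop removed from jacob_update_first_step above applies a 2x2 Givens rotation to each (pcol[i], qcol[i]) pair. The shared rotation2d_f32 kernel declared in tile_small_dot.hpp (not shown in this diff) is expected to compute the same result with vectorized float2 loads/stores on the IPU. A minimal scalar sketch of that contract, using hypothetical names and assuming only the arithmetic visible in the removed loop:

    #include <cstddef>

    // Scalar reference for the rotation applied by rotation2d_f32 (a sketch,
    // not the actual optimized IPU kernel): for every index i,
    //   p_out[i] = c * p[i] - s * q[i]
    //   q_out[i] = s * p[i] + c * q[i]
    void rotation2d_reference(float c, float s, const float* p, const float* q,
                              float* p_out, float* q_out, std::size_t size) {
      for (std::size_t i = 0; i != size; ++i) {
        p_out[i] = c * p[i] - s * q[i];
        q_out[i] = s * p[i] + c * q[i];
      }
    }

Factoring the rotation into tile_small_dot.hpp lets the same kernel back both the Rotation2d vertex in tile_small_dot.cpp and the Jacobi first-step update, which is presumably what allows the tighter cycle budget asserted in the updated benchmark (<= 1550 cycles at N = 256).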