Skip to content

Commit

Permalink
Remove TAS register + axpy intrinsic wrapper class (#53)
Browse files Browse the repository at this point in the history
Replacing the use with `ipu::AMP<T>`, which is much more general
and should be able to properly model any IPU AMP unit instruction.
  • Loading branch information
balancap committed Oct 20, 2023
1 parent acabc7d commit 5f247c4
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 78 deletions.
53 changes: 3 additions & 50 deletions tessellate_ipu/core/vertex/intrinsics_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,6 @@ ALWAYS_INLINE float ld32(const T* address, unsigned offset) {
return result;
}

// Thin wrapper exposing the TAS builtins behind a small object interface.
// Both members forward directly to the corresponding `__builtin_ipu_*` call.
// NOTE(review): this variant sits before the `#else`, so it is presumably the
// branch compiled for real IPU hardware -- confirm against the enclosing `#if`.
struct __ipu_and_ipumodel_tas {
// Forwards `v` to `__builtin_ipu_put_tas`.
void put(float v) { __builtin_ipu_put_tas(v); }
// Forwards `x` and `y` to `__builtin_ipu_f32v2axpy` and returns its result.
float2 f32v2axpy(float2 const& x, float2 const& y) {
return __builtin_ipu_f32v2axpy(x, y);
}
};

#else

#include <limits>
Expand Down Expand Up @@ -152,60 +145,20 @@ IpuVector<T, N> fma(IpuVector<T, N> const& x, IpuVector<T, N> const& y,

} // namespace ipu

// Reflect IPU's AXPY semantics in a way that is IPUModel compatible
// IPU-only usage:
// __builtin_ipu_put_tas(v);
// z_prev = __builtin_ipu_f32v2axpy(x, y)
//
// IPUModel-compatible usage:
// __ipu_and_ipumodel_tas tas;
// tas.put(v);
// z_prev = tas.f32v2axpy(x, y)
//
// https://docs.graphcore.ai/projects/poplar-api/en/latest/ipu_intrinsics/ipu_builtins.html#_CPPv423__builtin_ipu_f32v2axpy6float26float2
// Software emulation of a TAS scalar combined with a delayed `f32v2axpy`.
// The object stores the scalar set by `put` and the result of the most recent
// `f32v2axpy` call; each call returns the result of the call before it.
struct __ipu_and_ipumodel_tas {
  float tas;    // Scalar set by `put`; starts at 0.
  float2 prev;  // Result of the latest `f32v2axpy` call; starts at {0, 0}.

  __ipu_and_ipumodel_tas() : tas{0}, prev{0, 0} {}

  // Record the scalar used by subsequent `f32v2axpy` calls.
  void put(float v) { tas = v; }

  // Compute `tas * y + x` element-wise on the two lanes, keep it for the next
  // call, and return the value that was kept by the previous call (zeros on
  // the very first call).
  float2 f32v2axpy(float2 const& x, float2 const& y) {
    const float2 delayed = prev;
    // TODO: understand ordering!?
    // The alternative operand order would be `tas * x[i] + y[i]`.
    prev = float2{tas * y[0] + x[0], tas * y[1] + x[1]};
    return delayed;
  }
};

// And give useful error messages when people port from IPU to IPUModel, e.g.
/* clang-format off */ // need these error messages on one line
/*
/workspaces/tessellate-ipu/tessellate/tile/vertex/intrinsics_utils.hpp:166:3: error: static_assert failed due to requirement '__ipu_false<IpuVector<float, 2>>()': *** Replace __builtin_ipu_f32v2axpy with __ipu_and_ipumodel_tas for TAS handling on IPUModel.
static_assert(__ipu_false<T>(), "*** Replace __builtin_ipu_f32v2axpy with __ipu_and_ipumodel_tas for TAS handling on IPUModel.");
^ ~~~~~~~~~~~~~~~~
/workspaces/tessellate-ipu/tessellate/tile/vertex/tile_qr_vertex.cpp:231:12: note: in instantiation of function template specialization '__builtin_ipu_f32v2axpy<IpuVector<float, 2>>' requested here
rout = __builtin_ipu_f32v2axpy(rtmp, rtmp);
*/
// And give useful error messages when people port from IPU to IPUModel.
// Compile-time `false` whose value depends on the template parameter `T`.
// Being dependent, a `static_assert` on it is only evaluated when the
// enclosing template is actually instantiated.
template <typename T>
constexpr bool __ipu_false() {
  using SelfComparison = std::is_same<T, T>;
  return SelfComparison::value == false;
}

// Placeholder for the `__builtin_ipu_put_tas` builtin on IPUModel builds.
// Any instantiation fails to compile: `__ipu_false<T>()` is false for every
// `T`, so the static_asserts below always fire and their messages point the
// caller at the replacement API.
template <typename T>
void __builtin_ipu_put_tas(T v) {
static_assert(__ipu_false<T>(), "*** Replace __builtin_ipu_put_tas with __ipu_and_ipumodel_tas for TAS handling on IPUModel.");
static_assert(__ipu_false<T>(), "*** Please use `ipu::AMP` class for TAS handling on IPUModel.");
}

// Placeholder for the `__builtin_ipu_f32v2axpy` builtin on IPUModel builds.
// Any instantiation fails to compile: `__ipu_false<T>()` is false for every
// `T`, so the static_asserts below always fire and their messages point the
// caller at the replacement API.
template <typename T>
T __builtin_ipu_f32v2axpy(T const& x, T const& y) {
static_assert(__ipu_false<T>(), "*** Replace __builtin_ipu_f32v2axpy with __ipu_and_ipumodel_tas for TAS handling on IPUModel.");
static_assert(__ipu_false<T>(), "*** Please use `ipu::AMP::axpy` for `f32v2axpy` intrinsic on IPUModel.");
// Never reached in a successful build; present so the function has a return.
return T{};
}
// clang-format on
Expand Down
20 changes: 9 additions & 11 deletions tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <poplar/Vertex.hpp>

#include "intrinsics_utils.hpp"
#include "ipu_amp.hpp"

using namespace poplar;

Expand Down Expand Up @@ -162,9 +163,10 @@ class [[poplar::constraint(

// Set the $TAS register with the proper scale.
const T s = -scale1[0] * scale2[0];
// __builtin_ipu_put_tas(s);
__ipu_and_ipumodel_tas tas;
tas.put(s);
// Basic AMP usage with TAS + axpy instruction.
// AMP code using this abstraction is compatible with IPU hw & model.
ipu::AMP<T> amp;
amp.tas(s);

// Nothing to do in this worker thread.
if (wstart == wend) {
Expand All @@ -183,20 +185,16 @@ class [[poplar::constraint(
vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step);
// TODO: use ld2x64pace + tapack instructions.
for (IndexType idx = 1; idx != wsize; ++idx) {
rtmp = tas.f32v2axpy(xin, vin);
// rtmp = __builtin_ipu_f32v2axpy(xin, vin);
rtmp = amp.axpy(vin, xin);
// Grouping here seems to help the compiler optimising loads?
xin = ipu::load_postinc(&ptr_inxdata_f2, ptr_step);
vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step);
rout = tas.f32v2axpy(rtmp, rtmp);
// rout = __builtin_ipu_f32v2axpy(rtmp, rtmp);
rout = amp.axpy(rtmp, rtmp);
ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step);
}
// Finish the loop, getting the last computation.
// rtmp = __builtin_ipu_f32v2axpy(xin, vin);
// rout = __builtin_ipu_f32v2axpy(rtmp, rtmp);
rtmp = tas.f32v2axpy(xin, vin);
rout = tas.f32v2axpy(rtmp, rtmp);
rtmp = amp.axpy(vin, xin);
rout = amp.axpy(rtmp, rtmp);
ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step);

return true;
Expand Down
2 changes: 1 addition & 1 deletion tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated,
* See: Gene H. Golub, Charles F. Van Loan, MATRIX COMPUTATIONS, 3rd edition,
* Johns Hopkins Chapter 8.
*/
class [[poplar::constraint(
class [[poplar::constraint(
"elem(*vpcol) != elem(*vpcol_out)",
"elem(*vqcol) != elem(*vqcol_out)")]] JacobiUpdateEigenvectors
: public MultiVertex {
Expand Down
24 changes: 11 additions & 13 deletions tessellate_ipu/core/vertex/tile_qr_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <poplar/Vertex.hpp>

#include "intrinsics_utils.hpp"
#include "ipu_amp.hpp"

using namespace poplar;

Expand Down Expand Up @@ -165,8 +166,8 @@ float QRCorrectionVectorVertex::shared_partial_sqnorms[6] = {-1};
* NOTE: poplar::constraint here to make sure x and v are not part of the same
* memory bank, allowing simultaneous loads (see `ld2x64pace` instruction).
*/
class [[poplar::constraint(
"elem(*x) != elem(*v)")]] QRHouseholderRowUpdateVertex
class [
[poplar::constraint("elem(*x) != elem(*v)")]] QRHouseholderRowUpdateVertex
: public MultiVertex {
public:
using T = float;
Expand Down Expand Up @@ -199,9 +200,10 @@ class [[poplar::constraint(

// Set the $TAS register with the proper scale.
const T s = -scale1[0] * scale2[0];
// __builtin_ipu_put_tas(s);
__ipu_and_ipumodel_tas tas;
tas.put(s);
// Basic AMP usage with TAS + axpy instruction.
// AMP code using this abstraction is compatible with IPU hw & model.
ipu::AMP<T> amp;
amp.tas(s);

// Nothing to do in this worker thread.
if (wstart == wend) {
Expand All @@ -220,20 +222,16 @@ class [[poplar::constraint(
vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step);
// TODO: use ld2x64pace + tapack instructions.
for (IndexType idx = 1; idx != wsize; ++idx) {
rtmp = tas.f32v2axpy(xin, vin);
// rtmp = __builtin_ipu_f32v2axpy(xin, vin);
rtmp = amp.axpy(vin, xin);
// Grouping here seems to help the compiler optimising loads?
xin = ipu::load_postinc(&ptr_inxdata_f2, ptr_step);
vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step);
rout = tas.f32v2axpy(rtmp, rtmp);
// rout = __builtin_ipu_f32v2axpy(rtmp, rtmp);
rout = amp.axpy(rtmp, rtmp);
ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step);
}
// Finish the loop, getting the last computation.
// rtmp = __builtin_ipu_f32v2axpy(xin, vin);
// rout = __builtin_ipu_f32v2axpy(rtmp, rtmp);
rtmp = tas.f32v2axpy(xin, vin);
rout = tas.f32v2axpy(rtmp, rtmp);
rtmp = amp.axpy(vin, xin);
rout = amp.axpy(rtmp, rtmp);
ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step);

return true;
Expand Down
7 changes: 4 additions & 3 deletions tessellate_ipu/core/vertex/tile_small_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y,
// Necessary if using unsigned `nblocks`.
// __builtin_assume(nblocks < 4096);
using T2 = float2;
// Using TAS register for the scalar `b`.
__ipu_and_ipumodel_tas tas;
tas.put(b);
// Basic AMP usage with TAS + axpy instruction.
ipu::AMP<float> amp;
amp.tas(b);


T2 av = {a, a};
// Explicit variables passed to inline assembly.
Expand Down

0 comments on commit 5f247c4

Please sign in to comment.