Skip to content

Commit

Permalink
Additional optimization to axplusby_f32 using inline assembly. (#47)
Browse files Browse the repository at this point in the history
The function `axplusby_f32_v2` implements a fully optimized `rpt`
loop in inline assembly, taking advantage of the instruction `ldst64pace`.
  • Loading branch information
balancap authored Oct 13, 2023
1 parent 5d552e0 commit 9cd48f4
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 22 deletions.
10 changes: 8 additions & 2 deletions tessellate_ipu/core/vertex/intrinsics_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,21 @@
*/
namespace ipu {
/**
 * IPU hardware tag.
 *
 * Exposes a `hardware` member (and deliberately no `model` member) so that
 * templates guarded by `std::enable_if_t<IpuTag::hardware, bool>` are selected
 * for this tag, while those guarded by `IpuTag::model` are discarded by SFINAE.
 */
struct HardwareTag {
  static constexpr bool hardware = true;
};
/**
 * IPU model tag.
 *
 * Exposes a `model` member (and deliberately no `hardware` member) so that
 * templates guarded by `std::enable_if_t<IpuTag::model, bool>` are selected
 * for this tag, while those guarded by `IpuTag::hardware` are discarded by
 * SFINAE.
 */
struct ModelTag {
  static constexpr bool model = true;
};
}  // namespace ipu

// Compile-time selection of the IPU dispatch tag.
//  - IPU_TAG_TYPE:     the tag *type*, usable as a template argument.
//  - IPU_DISPATCH_TAG: a parenthesised tag *instance*, usable as a function
//                      argument for overload-based dispatch.
// `ipu::HardwareTag` is picked when `__IPU__` is defined, `ipu::ModelTag`
// otherwise.
#if defined(__IPU__)
#define IPU_TAG_TYPE ipu::HardwareTag
#else
#define IPU_TAG_TYPE ipu::ModelTag
#endif
#define IPU_DISPATCH_TAG (IPU_TAG_TYPE{})

Expand Down
8 changes: 5 additions & 3 deletions tessellate_ipu/core/vertex/tile_small_dot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ using namespace poplar;
/**
* @brief 2d rotation vertex.
*/
class Rotation2dVertex : public MultiVertex {
class [[poplar::constraint("elem(*inrow0) != elem(*outrow0)",
"elem(*inrow1) != elem(*outrow1)")]] Rotation2dVertex
: public MultiVertex {
public:
using T = float;
using T2 = float2;
Expand Down Expand Up @@ -46,8 +48,8 @@ class Rotation2dVertex : public MultiVertex {
T2* outrow0_ptr = reinterpret_cast<T2*>(outrow0.data()) + wstart;
T2* outrow1_ptr = reinterpret_cast<T2*>(outrow1.data()) + wstart;

rotation2d_f32(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, outrow1_ptr,
wsize, IPU_DISPATCH_TAG);
rotation2d_f32<IPU_TAG_TYPE>(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr,
outrow1_ptr, wsize);
return true;
}
};
94 changes: 79 additions & 15 deletions tessellate_ipu/core/vertex/tile_small_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,15 @@ inline void axplusby_f32_v0(float a, float b, const float2 *x, const float2 *y,
ipu::store_postinc(&z, zv, 1);
}
}

inline void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y,
float2 *z, rptsize_t nblocks) {
/**
* @brief z = a*x + b*y float32 implementation using `rpt` loop and `f32v2axpy`
*
* Compatible with IPU hardware and IPU model.
* 30% slower than inline assembly implementation.
*/
template <class IpuTag, std::enable_if_t<IpuTag::model, bool> = true>
inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y,
float2 *z, rptsize_t nblocks) {
// Necessary if using unsigned `nblocks`.
// __builtin_assume(nblocks < 4096);
using T2 = float2;
Expand Down Expand Up @@ -57,25 +63,83 @@ inline void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y,
}
}
}
/**
 * @brief z = a*x + b*y float32 implementation fully optimized in inline
 * assembly.
 *
 * Overload enabled only for tags exposing a true `hardware` member (see the
 * `std::enable_if_t` guard below).
 *
 * @param a Scalar applied to `x`; broadcast into both lanes of `av`.
 * @param b Scalar applied to `y`; stored in the TAS register.
 * @param x First input stream of float2 blocks.
 * @param y Second input stream of float2 blocks.
 * @param z Output stream of float2 blocks.
 * @param nblocks Repeat count handed to the `rpt` instruction.
 *
 * NOTE(review): the prologue issues three `ld64step` loads on each of `x` and
 * `y` before the `rpt` loop, and the loop body loads one more block of each
 * per iteration, so the asm appears to read a few blocks past `nblocks` on
 * both inputs -- confirm that callers provide readable memory there.
 */
template <class IpuTag, std::enable_if_t<IpuTag::hardware, bool> = true>
inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y,
                         float2 *z, rptsize_t nblocks) {
  // Necessary if using unsigned `nblocks`.
  // __builtin_assume(nblocks < 4096);
  using T2 = float2;
  // Using TAS register for the scalar `b`.
  __ipu_and_ipumodel_tas tas;
  tas.put(b);

  T2 av = {a, a};
  // Explicit variables passed to inline assembly.
  // Easier to read + compiling on IPU model.
  T2 xv, yv, zv;
  // Written by `tapack` inside the asm before being read: output-only operand.
  uint2 tapaddr;
  // Inline assembly loop in order to use `ldst64pace` instruction.
  // Note: requires "unrolling" the beginning of the `f32v2axpy` pipeline.
  // TODO: investigate issue with inputs register re-use.
  //
  // Constraint notes:
  //  - `av` and `nblocks` are kept as read-write operands so the register
  //    allocator never shares their registers with an output.
  //  - `tapaddr` is an early-clobber output ("=&r"): it is produced by the asm
  //    itself, so no (indeterminate) initial value is read, and it cannot be
  //    allocated on top of an input register.
  //  - "memory" clobber: the asm loads through `x`/`y` and stores through `z`,
  //    which the operand list alone does not tell the compiler; without it,
  //    surrounding loads/stores could be reordered across or cached over the
  //    asm statement.
  asm volatile(
      R"(
ld64step %[xv], $m15, %[xptr]+=, 1
ld64step %[yv], $m15, %[yptr]+=, 1
{
ld64step %[xv], $m15, %[xptr]+=, 1
f32v2mul %[zv], %[xv], %[av]
}
{
ld64step %[yv], $m15, %[yptr]+=, 1
f32v2axpy %[zv], %[yv], %[zv]
}
{
ld64step %[xv], $m15, %[xptr]+=, 1
f32v2mul %[zv], %[xv], %[av]
}
{
ld64step %[yv], $m15, %[yptr]+=, 1
f32v2axpy %[zv], %[yv], %[zv]
}
tapack %[tapaddr], %[xptr], $mzero, %[zptr]
.align 8
{
rpt %[nb], 1
fnop
}
{
ldst64pace %[xv], %[zv], %[tapaddr]+=, $mzero, 0
f32v2mul %[zv], %[xv], %[av]
}
{
ld64step %[yv], $m15, %[yptr]+=, 1
f32v2axpy %[zv], %[yv], %[zv]
}
)"
      : [ xptr ] "+r"(x), [ yptr ] "+r"(y), [ av ] "+r"(av), [ xv ] "=r"(xv),
        [ yv ] "=r"(yv), [ zv ] "=r"(zv), [ tapaddr ] "=&r"(tapaddr),
        [ nb ] "+r"(nblocks)
      : [ zptr ] "r"(z)
      : "memory");
  // Note: explicit list of used registers not compiling on IPU model.
  // : "$a0:1", "$a2:3", "$a4:5", "$m4", "$m5"
}

/**
 * @brief Apply 2d rotation transform (float).
 *
 * Built from two `axplusby_f32` calls (z = a*x + b*y), i.e. per float2 block:
 *   outrow0 = cs[0] * inrow0 - cs[1] * inrow1
 *   outrow1 = cs[1] * inrow0 + cs[0] * inrow1
 * (`cs` presumably holds (cos, sin) of the rotation angle -- confirm with
 * callers.)
 *
 * Note: input rows are separated, allowing more flexibility
 * for functions/vertices using this base compute method.
 *
 * @tparam IpuTag Dispatch tag forwarded to `axplusby_f32` to pick the
 *         implementation (IPU hardware or IPU model).
 * @param cs Pair of rotation coefficients.
 * @param inrow0 First input row.
 * @param inrow1 Second input row.
 * @param outrow0 First output row.
 * @param outrow1 Second output row.
 * @param nblocks Number of float2 blocks, forwarded to `axplusby_f32`.
 */
template <class IpuTag>
inline void rotation2d_f32(float2 cs, const float2 *inrow0,
                           const float2 *inrow1, float2 *outrow0,
                           float2 *outrow1, rptsize_t nblocks) {
  // TODO: investigate using IPU AMP unit?
  axplusby_f32<IpuTag>(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks);
  // Inputs swapped so that `cs[0]` is again the first scalar:
  // outrow1 = cs[0] * inrow1 + cs[1] * inrow0.
  axplusby_f32<IpuTag>(cs[0], cs[1], inrow1, inrow0, outrow1, nblocks);
}
4 changes: 2 additions & 2 deletions tests/lax/test_tile_lax_small_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ def compute_fn(cs, row0, row1):
# Hardware cycle count bound.
start, end = np.asarray(start)[0], np.asarray(end)[0]
hw_cycle_count = end[0] - start[0]
# Observe on IPU Mk2 hw ~1916 cycles.
assert hw_cycle_count <= 2000
# Observe on IPU Mk2 hw ~1436 cycles.
assert hw_cycle_count <= 1500

0 comments on commit 9cd48f4

Please sign in to comment.