Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Oct 12, 2023
1 parent 6455dd7 commit f60fcac
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 38 deletions.
10 changes: 8 additions & 2 deletions tessellate_ipu/core/vertex/intrinsics_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,21 @@
*/
namespace ipu {
/** IPU hardware tag. */
struct HardwareTag {};
struct HardwareTag {
static constexpr bool hardware = true;
};
/** IPU model tag. */
struct ModelTag {};
struct ModelTag {
static constexpr bool model = true;
};
} // namespace ipu

// IPU dispatch tag preprocessor.
//
// Selects one of the tag types from the `ipu` namespace at preprocessing time
// and exposes it in two forms:
//   - `IPU_TAG_TYPE` expands to the tag *type*, for use as an explicit
//     template argument (e.g. `rotation2d_f32<IPU_TAG_TYPE>(...)`).
//   - `IPU_DISPATCH_TAG` expands to a default-constructed tag *value*, for
//     overloads that take the tag as a trailing function argument.
// When `__IPU__` is defined both macros refer to `ipu::HardwareTag`;
// otherwise both refer to `ipu::ModelTag`.
// NOTE(review): `__IPU__` is presumably predefined by the compiler when
// targeting IPU hardware -- confirm against the toolchain documentation.
#ifdef __IPU__
#define IPU_TAG_TYPE ipu::HardwareTag
#define IPU_DISPATCH_TAG (ipu::HardwareTag{})
#else
#define IPU_TAG_TYPE ipu::ModelTag
#define IPU_DISPATCH_TAG (ipu::ModelTag{})
#endif

Expand Down
4 changes: 2 additions & 2 deletions tessellate_ipu/core/vertex/tile_small_dot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class [[poplar::constraint("elem(*inrow0) != elem(*outrow0)",
T2* outrow0_ptr = reinterpret_cast<T2*>(outrow0.data()) + wstart;
T2* outrow1_ptr = reinterpret_cast<T2*>(outrow1.data()) + wstart;

rotation2d_f32(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, outrow1_ptr,
wsize, IPU_DISPATCH_TAG);
rotation2d_f32<IPU_TAG_TYPE>(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, outrow1_ptr,
wsize);
return true;
}
};
80 changes: 47 additions & 33 deletions tessellate_ipu/core/vertex/tile_small_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ inline void axplusby_f32_v0(float a, float b, const float2 *x, const float2 *y,
/**
* @brief z = a*x + b*y float32 implementation using `rpt` loop and `f32v2axpy`
*/
inline void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y,
template <class IpuTag, std::enable_if_t<IpuTag::model, bool> = true>
inline void axplusby_f32_v2(float a, float b, const float2 *x, const float2 *y,
float2 *z, rptsize_t nblocks) {
// Necessary if using unsigned `nblocks`.
// __builtin_assume(nblocks < 4096);
Expand Down Expand Up @@ -63,56 +64,65 @@ inline void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y,
* @brief z = a*x + b*y float32 implementation fully optimized in inline
* assembly.
*/
template <class IpuTag, std::enable_if_t<IpuTag::hardware, bool> = true>
inline void axplusby_f32_v2(float a, float b, const float2 *x, const float2 *y,
float2 *z, rptsize_t nblocks) {
// Necessary if using unsigned `nblocks`.
// __builtin_assume(nblocks < 4096);
using T2 = float2;
const T2 av = {a, a};
T2 av = {a, a};
// Using TAS register for the scalar `b`.
__ipu_and_ipumodel_tas tas;
tas.put(b);

T2 xv, yv, zv;
uint2 tapaddr;

// Inline assembly loop in order to use `ldst64pace` instruction.
// Note: requires "unrolling" the beginning of the `f32v2axpy` pipeline.
asm volatile(
R"(
ld64step $a0:1, $m15, %[xptr]+=, 1
ld64step $a4:5, $m15, %[yptr]+=, 1
ld64step %[xv], $m15, %[xptr]+=, 1
ld64step %[yv], $m15, %[yptr]+=, 1
{
ld64step $a0:1, $m15, %[xptr]+=, 1
f32v2mul $a2:3, $a0:1, %[avec]
ld64step %[xv], $m15, %[xptr]+=, 1
f32v2mul %[zv], %[xv], %[av]
}
{
ld64step $a4:5, $m15, %[yptr]+=, 1
f32v2axpy $a2:3, $a4:5, $a2:3
ld64step %[yv], $m15, %[yptr]+=, 1
f32v2axpy %[zv], %[yv], %[zv]
}
{
ld64step $a0:1, $m15, %[xptr]+=, 1
f32v2mul $a2:3, $a0:1, %[avec]
ld64step %[xv], $m15, %[xptr]+=, 1
f32v2mul %[zv], %[xv], %[av]
}
{
ld64step $a4:5, $m15, %[yptr]+=, 1
f32v2axpy $a2:3, $a4:5, $a2:3
ld64step %[yv], $m15, %[yptr]+=, 1
f32v2axpy %[zv], %[yv], %[zv]
}
tapack $m4:5, %[xptr], $mzero, %[zptr]
tapack %[tapaddr], %[xptr], $mzero, %[zptr]
.align 8
{
rpt %[nb], 1
fnop
}
{
ldst64pace $a0:1, $a2:3, $m4:5+=, $mzero, 0
f32v2mul $a2:3, $a0:1, %[avec]
ldst64pace %[xv], %[zv], %[tapaddr]+=, $mzero, 0
f32v2mul %[zv], %[xv], %[av]
}
{
ld64step $a4:5, $m15, %[yptr]+=, 1
f32v2axpy $a2:3, $a4:5, $a2:3
ld64step %[yv], $m15, %[yptr]+=, 1
f32v2axpy %[zv], %[yv], %[zv]
}
)"
: [ xptr ] "+r"(x), [ yptr ] "+r"(y)
: [ nb ] "r"(nblocks), [ zptr ] "r"(z), [ avec ] "r"(av)
: "$a0:1", "$a2:3", "$a4:5", "$m4", "$m5");
: [ xptr ] "+r"(x), [ yptr ] "+r"(y), [ av ] "+r"(av), [ xv ] "=r"(xv),
[ yv ] "=r"(yv), [ zv ] "=r"(zv), [ tapaddr ] "+r"(tapaddr),
[ nb ] "+r"(nblocks)
: [ zptr ] "r"(z)
// : "r0");
// : "$a0:1", "$a2:3", "$a4:5", "$m4", "$m5");
// : "$m4", "$m5");
:);
}

/**
Expand All @@ -121,19 +131,23 @@ inline void axplusby_f32_v2(float a, float b, const float2 *x, const float2 *y,
* Note: input rows are separated, allowing more flexibility
* for functions/vertices using this base compute method.
*/
inline void rotation2d_f32(float2 cs, const float2 *inrow0,
const float2 *inrow1, float2 *outrow0,
float2 *outrow1, rptsize_t nblocks, ipu::ModelTag) {
template <class IpuTag>
void rotation2d_f32(float2 cs, const float2 *inrow0, const float2 *inrow1,
float2 *outrow0, float2 *outrow1, rptsize_t nblocks) {
// TODO: investigate using IPU AMP unit?
axplusby_f32_v2(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks);
axplusby_f32_v2(cs[0], cs[1], inrow1, inrow0, outrow1, nblocks);
axplusby_f32_v2<IpuTag>(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks);
axplusby_f32_v2<IpuTag>(cs[0], cs[1], inrow1, inrow0, outrow1, nblocks);
}

inline void rotation2d_f32(float2 cs, const float2 *inrow0,
const float2 *inrow1, float2 *outrow0,
float2 *outrow1, rptsize_t nblocks,
ipu::HardwareTag) {
// Using same implementation as IPU model for now.
rotation2d_f32(cs, inrow0, inrow1, outrow0, outrow1, nblocks,
ipu::ModelTag{});
}
// void rotation2d_f32(float2 cs, const float2 *inrow0,
// const float2 *inrow1, float2 *outrow0,
// float2 *outrow1, rptsize_t nblocks,
// ipu::HardwareTag) {
// // Using same implementation as IPU model for now.
// axplusby_f32_v2<ipu::HardwareTag>(cs[0], -cs[1], inrow0, inrow1, outrow0,
// nblocks); axplusby_f32_v2<ipu::HardwareTag>(cs[0], cs[1], inrow1, inrow0,
// outrow1, nblocks);

// // rotation2d_f32(cs, inrow0, inrow1, outrow0, outrow1, nblocks,
// // ipu::ModelTag{});
// }
2 changes: 1 addition & 1 deletion tests/lax/test_tile_lax_small_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ def compute_fn(cs, row0, row1):
start, end = np.asarray(start)[0], np.asarray(end)[0]
hw_cycle_count = end[0] - start[0]
# Observed on IPU Mk2 hardware: ~1418 cycles.
assert hw_cycle_count <= 1500
assert hw_cycle_count <= 150

0 comments on commit f60fcac

Please sign in to comment.