-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementing optimized 2d rotation transform vertex. (#46)
This PR is starting a collection of basic linalg C++ functions, well optimized, which can be re-used in more complex C++ vertices. This PR is starting with `axplusby` and `rotation2d` functions. Note: investigate additional optimizations to reduce the `rpt` loop in `axplusby` from 3 cycles to 2 cycles.
- Loading branch information
Showing
7 changed files
with
265 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
// Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
#include "tile_small_dot.hpp" | ||
|
||
#include <poplar/HalfFloat.hpp> | ||
#include <poplar/Vertex.hpp> | ||
|
||
using namespace poplar; | ||
|
||
/** | ||
* @brief 2d rotation vertex. | ||
*/ | ||
class Rotation2dVertex : public MultiVertex { | ||
public: | ||
using T = float; | ||
using T2 = float2; | ||
// Using `uint16` seems to be generating more efficient loops? | ||
using IndexType = unsigned short; | ||
|
||
static constexpr size_t MIN_ALIGN = 8; | ||
|
||
Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>> | ||
cs; // (2,) rotation cosinus/sinus values | ||
Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>> | ||
inrow0; // (N,) first input row vector | ||
Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>> | ||
inrow1; // (N,) second input row vector | ||
|
||
Input<Vector<IndexType, poplar::VectorLayout::ONE_PTR>> | ||
worker_offsets; // (7,) number threads + 1. | ||
|
||
Output<Vector<T, poplar::VectorLayout::ONE_PTR>> | ||
outrow0; // (N,) first input row vector | ||
Output<Vector<T, poplar::VectorLayout::ONE_PTR>> | ||
outrow1; // (N,) first input row vector | ||
|
||
bool compute(unsigned wid) { | ||
// vectorized offsets. | ||
const IndexType wstart = worker_offsets[wid]; | ||
const IndexType wend = worker_offsets[wid + 1]; | ||
const IndexType wsize = wend - wstart; | ||
|
||
// Vertex inputs/outputs assuring proper alignment. | ||
const T2* inrow0_ptr = reinterpret_cast<const T2*>(inrow0.data()) + wstart; | ||
const T2* inrow1_ptr = reinterpret_cast<const T2*>(inrow1.data()) + wstart; | ||
const T2* cs_ptr = reinterpret_cast<const T2*>(cs.data()); | ||
T2* outrow0_ptr = reinterpret_cast<T2*>(outrow0.data()) + wstart; | ||
T2* outrow1_ptr = reinterpret_cast<T2*>(outrow1.data()) + wstart; | ||
|
||
rotation2d_f32(cs_ptr[0], inrow0_ptr, inrow1_ptr, outrow0_ptr, outrow1_ptr, | ||
wsize, IPU_DISPATCH_TAG); | ||
return true; | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
// Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
#include "intrinsics_utils.hpp" | ||
|
||
/** | ||
* @brief z = a*x + b*y float32 implementation. | ||
* | ||
* where x, y, z are 1D arrays and a, b are scalars. | ||
* Implementation compatible with IPU model and hardware. | ||
* | ||
* Requires input arrays with size % 2 == 0 | ||
*/ | ||
inline void axplusby_f32_v0(float a, float b, const float2 *x, const float2 *y, | ||
float2 *z, rptsize_t nblocks) { | ||
using T2 = float2; | ||
const T2 av = {a, a}; | ||
const T2 bv = {b, b}; | ||
// Sub-optimal vectorized implementation. | ||
for (unsigned idx = 0; idx < nblocks; ++idx) { | ||
const T2 xv = ipu::load_postinc(&x, 1); | ||
const T2 yv = ipu::load_postinc(&y, 1); | ||
const T2 zv = av * xv + bv * yv; | ||
ipu::store_postinc(&z, zv, 1); | ||
} | ||
} | ||
|
||
inline void axplusby_f32_v1(float a, float b, const float2 *x, const float2 *y, | ||
float2 *z, rptsize_t nblocks) { | ||
// Necessary if using unsigned `nblocks`. | ||
// __builtin_assume(nblocks < 4096); | ||
using T2 = float2; | ||
const T2 av = {a, a}; | ||
// Using TAS register for one of the scalar. | ||
__ipu_and_ipumodel_tas tas; | ||
tas.put(b); | ||
|
||
T2 res, xv, yv, zv, tmp; | ||
|
||
xv = ipu::load_postinc(&x, 1); | ||
yv = ipu::load_postinc(&y, 1); | ||
res = xv * av; | ||
for (unsigned idx = 0; idx != nblocks; ++idx) { | ||
// Pseudo dual-issuing of instructions. | ||
// popc should be able to generate an optimal rpt loop. | ||
{ | ||
xv = ipu::load_postinc(&x, 1); | ||
// TODO: fix ordering of arguments in `f32v2axpy`. | ||
tmp = tas.f32v2axpy(res, yv); | ||
} | ||
{ | ||
yv = ipu::load_postinc(&y, 1); | ||
// TODO: fix ordering of arguments in `f32v2axpy`. | ||
zv = tas.f32v2axpy(tmp, tmp); | ||
} | ||
{ | ||
ipu::store_postinc(&z, zv, 1); | ||
res = xv * av; | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* @brief Apply 2d rotation transform (float). | ||
* | ||
* Note: input rows are separated, allowing more flexibility | ||
* for functions/vertices using this base compute method. | ||
*/ | ||
inline void rotation2d_f32(float2 cs, const float2 *inrow0, | ||
const float2 *inrow1, float2 *outrow0, | ||
float2 *outrow1, rptsize_t nblocks, ipu::ModelTag) { | ||
axplusby_f32_v1(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks); | ||
axplusby_f32_v1(cs[1], cs[0], inrow0, inrow1, outrow1, nblocks); | ||
} | ||
|
||
inline void rotation2d_f32(float2 cs, const float2 *inrow0, | ||
const float2 *inrow1, float2 *outrow0, | ||
float2 *outrow1, rptsize_t nblocks, | ||
ipu::HardwareTag) { | ||
// Using same implementation as IPU model for now. | ||
rotation2d_f32(cs, inrow0, inrow1, outrow0, outrow1, nblocks, | ||
ipu::ModelTag{}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
import os | ||
from typing import Any, Dict | ||
|
||
import numpy as np | ||
from jax.core import ShapedArray | ||
|
||
from tessellate_ipu.core import declare_ipu_tile_primitive | ||
from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets | ||
|
||
|
||
def get_small_dot_vertex_gp_filename() -> str: | ||
return os.path.join(os.path.dirname(__file__), "../core", "vertex", "tile_small_dot.cpp") | ||
|
||
|
||
@declare_ipu_tile_primitive("Rotation2dVertex", gp_filename=get_small_dot_vertex_gp_filename()) | ||
def rotation2d_p(cs: ShapedArray, inrow0: ShapedArray, inrow1: ShapedArray): | ||
"""2d rotation apply primitive. | ||
Specific optimization on IPU backend compared to `dot_general_p` primitive. | ||
In particular, allows passing the 2 rows of the (2, N) input as separate arrays (in some | ||
applications, contiguous storage may not be possible). | ||
Args: | ||
cs: Cos/sin 2d rotation entries. | ||
inrow0: First row (N,) | ||
inrow1: Second row (N,) | ||
Returns: | ||
outrow0: First output row (N,) | ||
outrow1: Second output row (N,) | ||
""" | ||
N = inrow0.size | ||
assert N % 2 == 0 | ||
assert inrow0 == inrow1 | ||
assert cs.dtype == inrow0.dtype | ||
assert cs.dtype == inrow1.dtype | ||
assert inrow0.dtype == np.float32 | ||
|
||
outputs = { | ||
"outrow0": inrow0, | ||
"outrow1": inrow1, | ||
} | ||
constants = {"worker_offsets": make_ipu_vector1d_worker_offsets(N, vector_size=2, wdtype=np.uint16)} | ||
temps: Dict[str, Any] = {} | ||
perf_estimate = 100 | ||
return outputs, constants, temps, perf_estimate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from functools import partial | ||
|
||
import chex | ||
import jax | ||
import numpy as np | ||
import numpy.testing as npt | ||
import pytest | ||
|
||
from tessellate_ipu import ipu_cycle_count, tile_map, tile_put_replicated | ||
from tessellate_ipu.lax.tile_lax_small_dot import rotation2d_p | ||
|
||
|
||
@pytest.mark.ipu_hardware | ||
class IpuTileRotation2dHwTests(chex.TestCase): | ||
def setUp(self): | ||
super().setUp() | ||
np.random.seed(42) | ||
|
||
def test__tile_map__rotation2d_primitive__proper_result_and_cycle_count(self): | ||
N = 512 | ||
tiles = (0,) | ||
indata = np.random.randn(2, N).astype(np.float32) | ||
cs = np.random.randn(2).astype(np.float32) | ||
rot2d = np.array([[cs[0], -cs[1]], [cs[1], cs[0]]]).astype(np.float32) | ||
|
||
def compute_fn(cs, row0, row1): | ||
cs = tile_put_replicated(cs, tiles) | ||
row0 = tile_put_replicated(row0, tiles) | ||
row1 = tile_put_replicated(row1, tiles) | ||
# Benchmark the raw 2d rotation vertex. | ||
cs, row0, row1, start = ipu_cycle_count(cs, row0, row1) | ||
outrow0, outrow1 = tile_map(rotation2d_p, cs, row0, row1) # type:ignore | ||
outrow0, outrow1, end = ipu_cycle_count(outrow0, outrow1) | ||
|
||
return outrow0, outrow1, start, end | ||
|
||
compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn) | ||
outrow0, outrow1, start, end = compute_fn_ipu(cs, indata[0], indata[1]) | ||
|
||
# Checking getting the proper result! | ||
expected_out = rot2d @ indata | ||
npt.assert_array_almost_equal(np.ravel(outrow0), expected_out[0], decimal=6) | ||
npt.assert_array_almost_equal(np.ravel(outrow1), expected_out[1], decimal=6) | ||
# Hardware cycle count bound. | ||
start, end = np.asarray(start)[0], np.asarray(end)[0] | ||
hw_cycle_count = end[0] - start[0] | ||
# Observe on IPU Mk2 hw ~1916 cycles. | ||
assert hw_cycle_count <= 2000 |