Skip to content

Commit

Permalink
xe: conv: fix small GRF mode heuristic handling
Browse files Browse the repository at this point in the history
  • Loading branch information
echeresh committed Nov 5, 2024
1 parent 7740c75 commit e595e59
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 43 deletions.
3 changes: 2 additions & 1 deletion src/gpu/intel/jit/conv/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1760,13 +1760,14 @@ status_t init_cfg(conv_config_t &cfg, const primitive_t *prim) {
static std::mutex tune_mutex;
std::unique_lock<std::mutex> lock(tune_mutex, std::defer_lock_t());
if (cfg.tiler().is_tuning_mode()) lock.lock();
while (cfg.tiler().can_move_next()) {
while (cfg.tiler().is_valid()) {
auto try_cfg = cfg;
auto status = try_init_cfg(try_cfg);
if (status == status::success) {
cfg = std::move(try_cfg);
return status::success;
}
cfg.tiler().move_next(cfg);
}
return status::runtime_error;
}
Expand Down
5 changes: 3 additions & 2 deletions src/gpu/intel/jit/conv/gen_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,11 @@ class gen_convolution_t {
if (data.zp_pd) zp_dst = layout_t(data.zp_pd->impl()->dst_md(), false);

if (primitive->cache_blob()) {
tiler->set_cur_index(primitive->version() - 1);
tiler->set_cur_version(primitive->version());
}

for (int try_iter = 0; try_iter < max_tries; try_iter++) {
if (try_iter != 0) tiler->move_next(cfg);
try {
cfg = data.pd_cfg;
cfg.set_pd(
Expand Down Expand Up @@ -259,7 +260,7 @@ class gen_convolution_t {
if (!tmp_kernels[i]) return status::runtime_error;
}
ok = true;
primitive->set_version(tiler->cur_index());
primitive->set_version(tiler->cur_version());
kernels_ = std::move(tmp_kernels);
break;
} catch (ngen::out_of_registers_exception &err) {
Expand Down
91 changes: 60 additions & 31 deletions src/gpu/intel/jit/conv/tiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1255,12 +1255,9 @@ class conv_tuner_t {
maybe_rescore();
}

// True while the generator still points at an untried parameter set.
bool is_valid() const { return params_gen_.is_valid(); }

// Advance to the next candidate parameter set; may step past the end,
// after which is_valid() returns false.
void move_next() { params_gen_.move_next(); }

int cur_index() const { return params_gen_.cur_index(); }

Expand Down Expand Up @@ -1423,6 +1420,13 @@ std::unordered_map<const impl::primitive_t *, conv_tuner_t::primitive_info_t>

std::mutex conv_tuner_t::mutex_;

// Per-parameter-set policy for selecting the GRF (register file) mode.
// Packed into the low bit of the kernel version (see pack_version()).
enum class grf_mode_policy_t {
    // Try 128 GRF mode based on heuristics.
    try_small_grf = 0,
    // Use default_regs().
    _default = 1
};

class conv_tiler_impl_t {
public:
conv_tiler_impl_t() = default;
Expand All @@ -1445,30 +1449,46 @@ class conv_tiler_impl_t {

bool is_tuning_mode() const { return tuner_; }

bool can_move_next() const {
if (is_tuning_mode()) return tuner_->can_move_next();
return params_gen_.can_move_next();
bool is_valid() const {
if (is_tuning_mode()) return tuner_->is_valid();
return params_gen_.is_valid();
}

int cur_index() const {
if (is_tuning_mode()) return tuner_->cur_index();
return params_gen_.cur_index();
void move_next(const conv_config_t &cfg) {
if (is_tuning_mode()) {
tuner_->move_next();
return;
}
if (grf_mode_policy_ == grf_mode_policy_t::try_small_grf
&& cfg.regs() != default_regs(cfg)) {
grf_mode_policy_ = grf_mode_policy_t::_default;
return;
}
grf_mode_policy_ = grf_mode_policy_t::try_small_grf;
params_gen_.move_next();
}

void set_cur_index(int idx) {
// Version id of the current state: the parameter-set index combined with
// the active GRF mode policy.
int32_t cur_version() const {
    const int idx
            = is_tuning_mode() ? tuner_->cur_index() : params_gen_.cur_index();
    return pack_version(idx, grf_mode_policy_);
}

// Restore iteration state from a version id previously returned by
// cur_version() (used when re-creating kernels from a cache blob).
void set_cur_version(int32_t version) {
    ir_assert(!is_tuning_mode());
    int idx;
    unpack_version(version, idx, grf_mode_policy_);
    params_gen_.set_cur_index(idx);
}

// Apply the current candidate parameters to cfg. The caller positions the
// iterator via move_next(); this function does not advance it.
void set_params(conv_config_t &cfg) {
    init_regs(cfg);
    if (is_tuning_mode()) {
        tuner_->set_params(cfg);
    } else {
        params_gen_.set_params(cfg);
        // Only apply the small-GRF heuristic when the policy allows it;
        // the retry path (see move_next()) disables it for one attempt.
        if (grf_mode_policy_ == grf_mode_policy_t::try_small_grf)
            maybe_try_small_grf(cfg);
    }
}

Expand All @@ -1492,6 +1512,16 @@ class conv_tiler_impl_t {
}

private:
// Encode (index, policy) into a single version id: the policy occupies the
// least-significant bit, the index the remaining bits.
static int32_t pack_version(int idx, grf_mode_policy_t policy) {
    const int policy_bit = static_cast<int>(policy);
    return idx * 2 + policy_bit;
}

// Inverse of pack_version(): split a version id back into the parameter
// index and the GRF mode policy.
static void unpack_version(
        int32_t version, int &idx, grf_mode_policy_t &policy) {
    policy = static_cast<grf_mode_policy_t>(version % 2);
    idx = version / 2;
}

void init(const conv_config_t &cfg) {
if (cfg.loop_dims().is_overridden()
|| cfg.thread_group_dims().is_overridden()
Expand Down Expand Up @@ -1557,6 +1587,8 @@ class conv_tiler_impl_t {
}

void maybe_try_small_grf(conv_config_t &cfg) {
if (cfg.regs() == 128 || cfg.exec_cfg_param().is_overridden("regs"))
return;
auto try_cfg = cfg;
init_walk_order(try_cfg);
init_kernel_grid(try_cfg);
Expand All @@ -1568,14 +1600,7 @@ class conv_tiler_impl_t {
try_cfg.exec_cfg(), kg_elems, tg_elems);
int wave_util = conv_config_t::get_wave_utilization(
cfg.exec_cfg(), kg_elems, tg_elems);
if (wave_util > 90 && new_wave_util >= wave_util && !try_small_grf_
&& cfg.regs() > 128
&& !cfg.exec_cfg_param().is_overridden("regs")) {
cfg.set_regs(128);
try_small_grf_ = true;
} else {
try_small_grf_ = false;
}
if (wave_util > 90 && new_wave_util >= wave_util) cfg.set_regs(128);
}

void print_info(double init_time_ms) {
Expand All @@ -1589,7 +1614,7 @@ class conv_tiler_impl_t {
params_generator_t params_gen_;
conv_tuner_t *tuner_ = nullptr;
int grf_usage_limit_ = 0;
bool try_small_grf_ = false;
grf_mode_policy_t grf_mode_policy_ = grf_mode_policy_t::try_small_grf;
};

conv_tiler_t::conv_tiler_t(const conv_config_t &cfg)
Expand All @@ -1603,16 +1628,20 @@ bool conv_tiler_t::is_tuning_mode() const {
return impl_->is_tuning_mode();
}

bool conv_tiler_t::can_move_next() const {
return impl_->can_move_next();
bool conv_tiler_t::is_valid() const {
return impl_->is_valid();
}

// Advance to the next candidate; cfg supplies the outcome of the last
// attempt (e.g. the GRF count actually used) to the policy logic.
void conv_tiler_t::move_next(const conv_config_t &cfg) {
impl_->move_next(cfg);
}

int conv_tiler_t::cur_index() const {
return impl_->cur_index();
int32_t conv_tiler_t::cur_version() const {
return impl_->cur_version();
}

// Restore the tiler to the state captured by a previous cur_version().
void conv_tiler_t::set_cur_version(int32_t version) {
    impl_->set_cur_version(version);
}

void conv_tiler_t::set_params(conv_config_t &cfg) {
Expand Down
7 changes: 4 additions & 3 deletions src/gpu/intel/jit/conv/tiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ class conv_tiler_t {
void set_tuner(conv_tuner_t *tuner);
int configs() const;
bool is_tuning_mode() const;
bool can_move_next() const;
int cur_index() const;
void set_cur_index(int idx);
bool is_valid() const;
void move_next(const conv_config_t &cfg);
int32_t cur_version() const;
void set_cur_version(int32_t idx);
void set_params(conv_config_t &cfg);
void notify_out_of_registers(const conv_config_t &cfg);
bool is_grf_limit_ok(const conv_config_t &cfg) const;
Expand Down
9 changes: 3 additions & 6 deletions src/gpu/intel/jit/ir/blocking.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,12 +548,9 @@ class params_generator_t {

bool is_empty() const { return params_vec_.empty(); }

// True while cur_idx_ still refers to an existing parameter set.
bool is_valid() const { return cur_idx_ < configs(); }

// Advance to the next parameter set; may step past the end, which makes
// is_valid() return false.
void move_next() { cur_idx_++; }

int cur_index() const { return cur_idx_; }

Expand Down Expand Up @@ -614,7 +611,7 @@ class params_generator_t {
blocking_checker_t &chk, int tune_level, int simd_size);

std::vector<blocking_params_t> params_vec_;
int cur_idx_ = -1;
int cur_idx_ = 0;
};

enum class tiler_mode_t {
Expand Down

0 comments on commit e595e59

Please sign in to comment.