Skip to content

Commit

Permalink
gpu: jit: conv: work around MSVC bug
Browse files Browse the repository at this point in the history
  • Loading branch information
echeresh committed Jan 5, 2023
1 parent 23576f9 commit 4024775
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions src/gpu/jit/conv/config.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
* Copyright 2021-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -1309,14 +1309,29 @@ class bmnk_dim_helper_t {
static const char *bwd_w_n_dims[] = {"oc", nullptr};
static const char *bwd_w_k_dims[] = {"mb", "od", "oh", "ow", nullptr};

const char **b_dims = prb_.pick_by_dir<const char **>(
fwd_b_dims, bwd_d_b_dims, bwd_w_b_dims);
const char **m_dims = prb_.pick_by_dir<const char **>(
fwd_m_dims, bwd_d_m_dims, bwd_w_m_dims);
const char **n_dims = prb_.pick_by_dir<const char **>(
fwd_n_dims, bwd_d_n_dims, bwd_w_n_dims);
const char **k_dims = prb_.pick_by_dir<const char **>(
fwd_k_dims, bwd_d_k_dims, bwd_w_k_dims);
// XXX: Do not use pick_by_dir() to work around MSVC compiler bug.
const char **b_dims = nullptr;
const char **m_dims = nullptr;
const char **n_dims = nullptr;
const char **k_dims = nullptr;
if (prb_.is_fwd) {
b_dims = fwd_b_dims;
m_dims = fwd_m_dims;
n_dims = fwd_n_dims;
k_dims = fwd_k_dims;
} else if (prb_.is_bwd_d) {
b_dims = bwd_d_b_dims;
m_dims = bwd_d_m_dims;
n_dims = bwd_d_n_dims;
k_dims = bwd_d_k_dims;
} else if (prb_.is_bwd_w) {
b_dims = bwd_w_b_dims;
m_dims = bwd_w_m_dims;
n_dims = bwd_w_n_dims;
k_dims = bwd_w_k_dims;
} else {
ir_error_not_expected();
}

if (contains(b_dims, dim_name)) return 'b';
if (contains(m_dims, dim_name)) return 'm';
Expand Down

0 comments on commit 4024775

Please sign in to comment.