From 2397989023b4b17e75e7208ed306d25534f5962d Mon Sep 17 00:00:00 2001 From: Mei Ye Date: Wed, 10 May 2023 06:01:32 +0000 Subject: [PATCH] Revert "[Vulkan] Add cooperative matrix support" This reverts commit 34ceb8cc210a0dbcbac347b2b1508164ac23a66b. --- include/tvm/meta_schedule/task_scheduler.h | 10 +- python/tvm/meta_schedule/relay_integration.py | 4 - .../task_scheduler/task_scheduler.py | 6 - python/tvm/meta_schedule/tir_integration.py | 4 - python/tvm/meta_schedule/tune.py | 5 +- python/tvm/relay/op/strategy/cuda.py | 12 - python/tvm/target/target.py | 21 - python/tvm/topi/cuda/conv2d.py | 365 +----------------- python/tvm/topi/cuda/conv2d_alter_op.py | 50 +-- python/tvm/topi/nn/__init__.py | 1 - python/tvm/topi/nn/utils.py | 128 ------ src/arith/iter_affine_map.cc | 10 - src/meta_schedule/postproc/verify_gpu_code.cc | 2 +- .../search_strategy/replay_func.cc | 10 - .../task_scheduler/gradient_based.cc | 4 +- .../task_scheduler/task_scheduler.cc | 29 +- src/relay/backend/te_compiler.cc | 1 - src/relay/backend/utils.cc | 4 +- src/relay/backend/utils.h | 10 - src/runtime/vulkan/vulkan_device.cc | 5 +- src/runtime/vulkan/vulkan_device.h | 1 - src/runtime/vulkan/vulkan_device_api.cc | 4 - src/target/spirv/codegen_spirv.cc | 235 +---------- src/target/spirv/codegen_spirv.h | 16 - src/target/spirv/ir_builder.cc | 68 +--- src/target/spirv/ir_builder.h | 25 +- src/target/spirv/spirv_support.cc | 4 - src/target/spirv/spirv_support.h | 14 - src/target/target_kind.cc | 1 - src/tir/transforms/storage_rewrite.cc | 3 - tests/python/unittest/test_wmma.py | 248 ------------ 31 files changed, 46 insertions(+), 1254 deletions(-) delete mode 100644 tests/python/unittest/test_wmma.py diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index a4cfeacec02db..f4fc491286dda 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -175,7 +175,6 @@ class TaskSchedulerNode : public runtime::Object { * \param measure_callbacks The callbacks to be called after each measurement * \param database The database used in tuning * \param cost_model The cost model used in tuning - * \param min_design_space The minimum design space used in tuning */ virtual void Tune(Array tasks, // Array task_weights, // @@ -186,8 +185,7 @@ class TaskSchedulerNode : public runtime::Object { Runner runner, // Array measure_callbacks, // Optional database, // - Optional cost_model, // - int min_design_space); + Optional cost_model); /*! * \brief Terminate a task * \param task_id The id of the task to be terminated @@ -230,9 +228,7 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { Runner runner, // Array measure_callbacks, // Optional database, // - Optional cost_model, - - int min_design_space)>; + Optional cost_model)>; /*! \brief The packed function to the `NextTaskId` function. 
*/ FNextTaskId f_next_task_id; @@ -253,7 +249,7 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { void Tune(Array tasks, Array task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, Array measure_callbacks, Optional database, - Optional cost_model, int min_design_space) final; + Optional cost_model) final; static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler"; TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode); diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 5367c430f7c84..41d3f9d12ebc1 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -275,7 +275,6 @@ def tune_relay( num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical", disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, instruments: Optional[Sequence[PassInstrument]] = None, - min_design_space: int = 1, ) -> Database: """Tune a Relay program. @@ -329,8 +328,6 @@ def tune_relay( The list of disabled passes during tasks extraction instruments : Optional[Sequence[PassInstrument]] The list of pass instrument implementations. - min_design_space : int - The minimum design space. Returns ------- @@ -366,7 +363,6 @@ def tune_relay( measure_callbacks=measure_callbacks, task_scheduler=task_scheduler, module_equality=module_equality, - min_design_space=min_design_space, ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index af1abb6f13011..d56d944474e9a 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -102,7 +102,6 @@ def tune( measure_callbacks: List[MeasureCallback], database: Optional[Database], cost_model: Optional[CostModel], - min_design_space: int = 1, ) -> None: """Auto-tuning. @@ -128,8 +127,6 @@ def tune( The database. cost_model : Optional[CostModel] The cost model. - min_design_space : int - THe minimum size of design space. """ task_weights = [float(w) for w in task_weights] _ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member @@ -144,7 +141,6 @@ def tune( measure_callbacks, database, cost_model, - min_design_space, ) def terminate_task(self, task_id: int) -> None: @@ -247,7 +243,6 @@ def tune( measure_callbacks: List[MeasureCallback], database: Optional[Database], cost_model: Optional[CostModel], - min_design_space: int = 1, ) -> None: """Auto-tuning.""" # Using self._outer to replace the self pointer @@ -262,7 +257,6 @@ def tune( measure_callbacks, database, cost_model, - min_design_space, ) def next_task_id(self) -> int: diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py index 367f06a192fab..5f6f82bf148bc 100644 --- a/python/tvm/meta_schedule/tir_integration.py +++ b/python/tvm/meta_schedule/tir_integration.py @@ -58,7 +58,6 @@ def tune_tir( # pylint: disable=too-many-locals seed: Optional[int] = None, module_equality: str = "structural", special_space: Optional[Mapping[str, SpaceGenerator.SpaceGeneratorType]] = None, - min_design_space: int = 1, ) -> Database: """Tune a TIR function or an IRModule of TIR functions. @@ -100,8 +99,6 @@ def tune_tir( # pylint: disable=too-many-locals A string to specify the module equality testing and hashing method. 
special_space : Optional[Mapping[str, SpaceGenerator.SpaceGeneratorType]] A mapping from task name to a special space generator for that task. - min_design_space : int - The minimum design space. Returns ------- @@ -157,7 +154,6 @@ def tune_tir( # pylint: disable=too-many-locals measure_callbacks=measure_callbacks, task_scheduler=task_scheduler, module_equality=module_equality, - min_design_space=min_design_space, ) diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index d535d6f187bcc..132f446a52525 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -41,7 +41,6 @@ def tune_tasks( measure_callbacks: MeasureCallback.CallbackListType = "default", task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", module_equality: str = "structural", - min_design_space: int = 1, ) -> Database: """Tune a list of tasks. Using a task scheduler. @@ -74,6 +73,7 @@ def tune_tasks( module_equality : Optional[str] A string to specify the module equality testing and hashing method. It must be one of the followings: + - "structural": Use StructuralEqual/Hash - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality testing and hashing. @@ -81,8 +81,6 @@ def tune_tasks( a given module. The "ignore-ndarray" varint is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see tir/analysis/analysis.py. - min_design_space : int - Minimum design space. Returns ------- @@ -128,6 +126,5 @@ def tune_tasks( measure_callbacks=measure_callbacks, database=database, cost_model=cost_model, - min_design_space=min_design_space, ) return database diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 5c86cfa948bc0..65573321f76cd 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -145,13 +145,9 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): kernel_layout = attrs.kernel_layout if dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") - if groups == 1: if layout == "NCHW": assert kernel_layout == "OIHW" - do_im2col = topi.nn.use_im2col( - data, kernel, stride_h, stride_w, dilation_h, dilation_w, padding - ) if ( (target.kind.name in ["cuda", "vulkan", "rocm"]) and data.dtype in ("int8", "uint8") @@ -163,14 +159,6 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_int8), name="conv2d_nchw_int8.cuda", ) - elif do_im2col: - assert data.dtype == kernel.dtype - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nchw_mma), - naive_schedule, - name="conv2d_nchw_mma.cuda", - plevel=15, - ) else: strategy.add_implementation( wrap_compute_conv2d(topi.cuda.conv2d_nchw), diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index fce9f3e6becca..06e1776965c23 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -173,25 +173,11 @@ def max_num_threads(self): """Returns the max_num_threads from the target if it exists.""" return int(self.attrs["max_num_threads"]) - @property - def max_block_size_x(self): - """Returns the max block size in x-dimension from the target if it exists.""" - return int(self.attrs["max_block_size_x"]) - - @property - def max_block_size_y(self): - """Returns the max block size in y-dimension from the target if it exists.""" - return int(self.attrs["max_block_size_y"]) - @property def thread_warp_size(self): """Returns the 
thread_warp_size from the target if it exists.""" return int(self.attrs["thread_warp_size"]) - @property - def max_shared_memory_per_block(self): - return int(self.attrs["max_shared_memory_per_block"]) - @property def max_function_args(self): return int(self.attrs.get("max_function_args", -1)) @@ -233,13 +219,6 @@ def supports_integer_dot_product(self): def libs(self): return list(self.attrs.get("libs", [])) - @property - def supports_cooperative_matrix(self): - if self.attrs.get("supports_cooperative_matrix", []): - return bool(self.attrs["supports_cooperative_matrix"]) - else: - return False - @property def features(self): return TargetFeatures(self) diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py index f2eb1468e28bb..bce032040dcd9 100644 --- a/python/tvm/topi/cuda/conv2d.py +++ b/python/tvm/topi/cuda/conv2d.py @@ -16,31 +16,15 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Compute definition for conv2d with cuda backend""" -import tvm -from tvm.tir.schedule import BlockRV, Schedule -from tvm._ffi import register_func from tvm import te from tvm import autotvm from tvm.autotvm.task.space import OtherOptionEntity from tvm.contrib import cudnn -from tvm.tir.tensor_intrin.cuda import ( - WMMA_LOAD_16x16x16_F16_A_INTRIN, - WMMA_LOAD_16x16x16_F16_B_INTRIN, - WMMA_SYNC_16x16x16_f16f16f32_INTRIN, - WMMA_FILL_16x16x16_F32_INTRIN, - WMMA_STORE_16x16x16_F32_SHARED_INTRIN, - WMMA_SYNC_16x16x16_f16f16f16_INTRIN, - WMMA_FILL_16x16x16_F16_INTRIN, - WMMA_STORE_16x16x16_F16_SHARED_INTRIN, -) - from .. import nn, generic -from ..nn.utils import get_pad_tuple, get_output_shape -from ..nn.pad import pad +from ..nn.utils import get_pad_tuple from ..utils import get_const_tuple, traverse_inline from .conv2d_direct import schedule_direct_cuda -from ..transform import reshape @autotvm.register_topi_compute("conv2d_nchw.cuda") @@ -164,350 +148,3 @@ def conv2d_backward_weight_cudnn( conv_dtype=conv_dtype, groups=groups, ) - - -@autotvm.register_topi_compute("conv2d_nchw_mma.cuda") -def conv2d_nchw_mma(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"): - """Compute conv2d nchw using im2col""" - assert data.dtype == "float16" - out_channels, in_channels, kernel_h, kernel_w = get_const_tuple(kernel.shape) - - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - assert dilation_h == 1 and dilation_w == 1 - - if isinstance(strides, int): - stride_h = stride_w = strides - else: - stride_h, stride_w = strides - - batch_size, _, P, Q = get_output_shape( - data, kernel, stride_h, stride_w, dilation_h, dilation_w, padding - ) - assert batch_size == 1 - - pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w)) - pad_before = [0, 0, pad_top, pad_left] - pad_after = [0, 0, pad_down, pad_right] - - if all([v == 0 for v in pad_before]) and all([v == 0 for v in pad_after]): - pad_data = data - else: - pad_data = pad(data, pad_before, pad_after, name="pad_data") - - M = out_channels - K = in_channels * kernel_h * kernel_w - N = batch_size * P * Q - - if kernel_h * kernel_w == 1: - ck = te.reduce_axis((0, K), name="k") - A = reshape(kernel, (M, K)) - B = reshape(pad_data, (K, N)) - C = te.compute( - (batch_size, out_channels, P, Q), - lambda b, o, h, w: te.sum( - A[o, ck].astype(out_dtype) * B[ck, h * Q + w].astype(out_dtype), - axis=[ck], - ), - name="conv2d_nchw_mma", - attrs={ - "schedule_rule": "conv2d_nchw_mma", - }, - ) - else: - # Convert the kernel of 
(O,I,H,W) to (N,K) format i.e (OC,IC*KH*KW) - A = te.compute( - (M, K), - lambda x, y: kernel[ - x, (y // (kernel_h * kernel_w)), (y // kernel_w) % kernel_h, y % kernel_w - ], - name="T_reshape", - ) - - # Convert the data of (N,C,H,W) to (K,M) format i.e (IC*KH*KW,OH*OW) - B = te.compute( - (K, N), - lambda y, x: pad_data[ - 0, - y // (kernel_h * kernel_w), - stride_h * (x // Q) + ((y // kernel_w) % kernel_h), - stride_w * (x % Q) + y % kernel_w, - ], - name="T_reshape_1", - ) - - # Apply GEMM operation. The result will be of (N,O,H,W) format - ck = te.reduce_axis((0, K), name="k") - C = te.compute( - (batch_size, out_channels, P, Q), - lambda b, o, h, w: te.sum( - A[o, ck].astype(out_dtype) - * B[ck, (b * in_channels * P * Q) + h * Q + w].astype(out_dtype), - axis=[ck], - ), - name="conv2d_nchw_mma", - attrs={ - "schedule_rule": "conv2d_nchw_mma", - }, - ) - return C - - -def schedule_rule_conv2d_nchw_mma(sch: Schedule, block: BlockRV): - """Create the schedule for conv2d nchw im2col""" - k_inner = 16 - - target = tvm.target.Target.current(allow_none=False) - - i_factors = [] - j_factors = [] - k_factors = [] - - # comment out the following line enables sampling - # i_factors, j_factors, k_factors = [16, 8, 2, 1, 1], [1, 32, 2, 4, 1], [64, 4, 1] - do_sample = False - if len(i_factors) == 0 and len(j_factors) == 0 and len(k_factors) == 0: - do_sample = True - - shared_scope = "shared" - warp_size = target.thread_warp_size - vector_size = 4 - b_transposed = False - block = sch.get_block("conv2d_nchw_mma") - write_buf = sch.get_sref(block).stmt.writes - output_type = write_buf[0].buffer.dtype - - if output_type == "float32": - out_bytes_per_ele = 4 - else: - out_bytes_per_ele = 2 - - read_buf = sch.get_sref(sch.get_block("T_reshape_1")).stmt.reads - in_channels = read_buf[0].buffer.shape[-3] - - data_type = read_buf[0].buffer.dtype - bytes_per_ele = 4 - - if data_type == "float16": - bytes_per_ele = 2 - else: - raise ValueError("Unsupported data type" % data_type) - - read_buf = sch.get_sref(sch.get_block("T_reshape")).stmt.reads - kernel_height = read_buf[0].buffer.shape[-2] - kernel_width = read_buf[0].buffer.shape[-1] - k_dim = in_channels * kernel_height * kernel_width - - loops = sch.get_loops(block) - i3_extent = sch.get_sref(loops[-2]).stmt.extent - - sch.transform_block_layout( - block, lambda i0, i1, i2, i3, i4: ((i0 * in_channels + i1), (i2 * i3_extent + i3), i4) - ) - block1 = sch.reindex(block=block, buffer=("write", 0)) - - i, j, k = sch.get_loops(block) - i, i_tc = sch.split(i, factors=[None, 16]) - j, j_tc = sch.split(j, factors=[None, 16]) - k, k_tc = sch.split(k, factors=[None, k_inner]) - - sch.reorder(i, j, k, i_tc, j_tc, k_tc) - block_inner = sch.blockize(i_tc) - block_outer, block_inner = block_inner, block - - max_local = 65536 - tile_row = 16 - tile_col = 16 - max_num_threads = target.max_num_threads - max_shared_mem = target.max_shared_memory_per_block - max_block_size_x = target.max_block_size_x - max_block_size_y = target.max_block_size_y - - if do_sample: - while True: - # sample i-axis - factors = sch.sample_perfect_tile(i, n=5) - i_factors = [sch.get(e) for e in factors] - - # Local memory (register) constraint - j3_j4_max = ( - (max_local - i_factors[3] * i_factors[4] * tile_row * tile_col) - // (tile_row * tile_col) - // (1 + i_factors[3] * i_factors[4]) - ) - - # max_num_threads constraint - j2_max = max_num_threads // warp_size // i_factors[2] - - # max_block_size constraint - j0_max = max_block_size_x // i_factors[0] - j1_max = max_block_size_y // 
i_factors[1] - - factors = sch.sample_perfect_tile(j, n=5) - j_factors = [sch.get(e) for e in factors] - - if ( - j_factors[0] > j0_max - or j_factors[1] > j1_max - or j_factors[2] > j2_max - or j_factors[3] * j_factors[4] > j3_j4_max - ): - continue - - # Calculate the shared mem required for the staging buffer. - # In the compact_buffer pass, the size of the final buffer - # will be determined based on below calculation - # Buffer[(i2 - 1) * (tile_row * i3 * i4) + tile_row, (j2 - 1) - # * (tile_col * j3 * j4) + tile_col] - x_dim = (i_factors[2] - 1) * (tile_row * i_factors[3] * i_factors[4]) + tile_row - y_dim = (j_factors[2] - 1) * (tile_col * j_factors[3] * j_factors[4]) + tile_col - max_out_shared_mem = x_dim * y_dim * out_bytes_per_ele - - # calculate the remaining shared memory after allocating staging buffer - rem_shared_mem = max_shared_mem - max_out_shared_mem - if rem_shared_mem <= 0: - continue - - # max_shared_memory_per_block constraint - matrix_A_and_B_shared_mem = ( - i_factors[2] * i_factors[3] * i_factors[4] * tile_row - + j_factors[2] * j_factors[3] * j_factors[4] * tile_col - ) * bytes_per_ele - - k0_min = (k_dim * matrix_A_and_B_shared_mem + rem_shared_mem - 1) // rem_shared_mem - k1_k2_max = tile_row * tile_col // k0_min - - factors = sch.sample_perfect_tile(k, n=3) - k_factors = [sch.get(e) for e in factors] - - if k_factors[0] >= k0_min and k_factors[1] * k_factors[2] <= k1_k2_max: - break - - num_ty = i_factors[2] * j_factors[2] - i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors) - j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors) - k0, k1, k2 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3, k2, i4, j4) - - block_idx = sch.fuse(i0, j0) - block_idy = sch.fuse(i1, j1) - thread_idy = sch.fuse(i2, j2) - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - - def fetch_to_shared(block, idx, ndim): - block_read = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_read, k0) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - if idx == 0: - # Calculate the size of data that will be copied into the shared buffer for buffer 0 - copy_size = i_factors[3] * i_factors[4] * 16 * k_factors[1] * k_factors[2] * k_inner - else: - # Calculate the size of data that will be copied into the shared buffer for buffer 1 - copy_size = j_factors[3] * j_factors[4] * 16 * k_factors[1] * k_factors[2] * k_inner - - # Instead of interleaving, copy the data buffer into the shared buffer in contiguous blocks. - # By doing this, the compiler will generate ds_load_b128 instructions instead of the default - # ds_load_u16/ds_load_u64, improving performance. - factors = num_ty * warp_size * vector_size - loop_size = copy_size // factors - - if loop_size > 1: - _, f_1, f_2, f_3, f_4 = sch.split( - fused, factors=[None, num_ty, warp_size, loop_size, vector_size] - ) - sch.vectorize(f_4) - else: - # Once the copying has been divided into threads in x - # and y dimensions, see if the remaining buffer can be vectorized. - v_size = copy_size // (num_ty * warp_size) - if v_size >= vector_size: - _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) - sch.vectorize(f_3) - else: - # Do not vectorize because the amount of space left over - # after threading the copy is less than vector_size. 
- _, f_1, f_2 = sch.split(fused, factors=[None, num_ty, warp_size]) - sch.bind(f_2, "threadIdx.x") - sch.bind(f_1, "threadIdx.y") - - return block_read - - fetch_to_shared(block_outer, 0, 2) - fetch_to_shared(block_outer, 1, 2) - - c_warp_scope = "wmma.accumulator" - a_warp_scope = "wmma.matrix_a" - b_warp_scope = "wmma.matrix_b" - - A_warp = sch.cache_read(block_outer, 0, a_warp_scope) - B_warp = sch.cache_read(block_outer, 1, b_warp_scope) - - sch.compute_at(A_warp, k1) - sch.compute_at(B_warp, k1) - - C_warp = sch.cache_write(block_outer, 0, c_warp_scope) - sch.reverse_compute_at(C_warp, thread_idy) - - ii, jj = sch.get_loops(C_warp)[-2:] - io, ii = sch.split(ii, factors=[None, 16]) - jo, ji = sch.split(jj, factors=[None, 16]) - sch.reorder(io, jo, ii, ji) - - sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3]) - block_init_c = sch.get_block("conv2d_nchw_mma_init") - - def tile_wmma_fragment(block_read, height, width): - i, j = sch.get_loops(block_read)[-2:] - i0, i1 = sch.split(i, factors=[None, height]) - j0, j1 = sch.split(j, factors=[None, width]) - sch.reorder(i0, j0, i1, j1) - return i1 - - loop_a = tile_wmma_fragment(A_warp, 16, k_inner) - - if b_transposed: - loop_b = tile_wmma_fragment(B_warp, 16, k_inner) - else: - loop_b = tile_wmma_fragment(B_warp, k_inner, 16) - - sch.reverse_compute_at(block1, jo) - fused = sch.fuse(*sch.get_loops(block1)[-2:]) - _, f_2, f_3 = sch.split(fused, factors=[None, warp_size, vector_size]) - sch.bind(f_2, "threadIdx.x") - sch.vectorize(f_3) - - sch.set_scope( - sch.get_block("conv2d_nchw_mma_reindex_wmma.accumulator"), - buffer_index=0, - storage_scope="shared", - ) - - sch.tensorize(loop_a, WMMA_LOAD_16x16x16_F16_A_INTRIN) - sch.tensorize(loop_b, WMMA_LOAD_16x16x16_F16_B_INTRIN) - - intrin = WMMA_SYNC_16x16x16_f16f16f32_INTRIN - if output_type == "float16": - intrin = WMMA_SYNC_16x16x16_f16f16f16_INTRIN - sch.tensorize(sch.get_loops(block_inner)[-3], intrin) - - intrin = WMMA_FILL_16x16x16_F32_INTRIN - if output_type == "float16": - intrin = WMMA_FILL_16x16x16_F16_INTRIN - sch.tensorize(sch.get_loops(block_init_c)[-2], intrin) - - intrin = WMMA_STORE_16x16x16_F32_SHARED_INTRIN - if output_type == "float16": - intrin = WMMA_STORE_16x16x16_F16_SHARED_INTRIN - sch.tensorize(sch.get_loops(C_warp)[-2], intrin) - - # print(sch.mod.script()) - - return [sch] - - -register_func("meta_schedule.vulkan.conv2d_nchw_mma", schedule_rule_conv2d_nchw_mma) diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index c3160274398ea..93512ca07d9ea 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -368,8 +368,6 @@ def _conv2d_legalize(attrs, inputs, arg_types): dilation = attrs.get_int_tuple("dilation") if not (dilation[0] == 1 and dilation[1] == 1): return None - padding = attrs.get_int_tuple("padding") - stride = attrs.get_int_tuple("strides") # No legalization for depthwise convolutions yet. 
groups = attrs.get_int("groups") @@ -472,11 +470,11 @@ def _conv2d_legalize(attrs, inputs, arg_types): return _pad_conv2d_HWNC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor) elif data_dtype in ["float16"]: - if isinstance(data_tensor.shape[0], tvm.tir.expr.Any): - # Skip legalize when the batch size is dynamic - return None - if data_layout == "NHWC" and kernel_layout == "HWIO": + if isinstance(data_tensor.shape[0], tvm.tir.expr.Any): + # Skip legalize when the batch size is dynamic + return None + batch = data_tensor.shape[0].value in_channel = data_tensor.shape[3].value out_channel = kernel_tensor.shape[3].value @@ -501,46 +499,6 @@ def _conv2d_legalize(attrs, inputs, arg_types): logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops) return _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor) - elif data_layout == "NCHW" and kernel_layout == "OIHW": - if not nn.use_im2col( - data_tensor, - kernel_tensor, - stride[0], - stride[1], - dilation[0], - dilation[1], - padding, - True, - ): - return None - oc_modified = False - in_channel = data_tensor.shape[1].value - out_channel = kernel_tensor.shape[0].value - - # Pad input channel - if in_channel % 16 != 0: - new_in_channel = ((in_channel + 16) // 16) * 16 - diff = new_in_channel - in_channel - pad_width = ((0, 0), (0, diff), (0, 0), (0, 0)) - data = relay.nn.pad(data, pad_width=pad_width) - kernel = relay.nn.pad(kernel, pad_width=pad_width) - - # Pad output channel - new_out_channel = out_channel - if out_channel % 16 != 0: - new_out_channel = ((out_channel + 16) // 16) * 16 - diff = new_out_channel - out_channel - kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0, 0), (0, 0))) - oc_modified = True - - if oc_modified: - new_attrs["channels"] = new_out_channel - out = tvm.relay.nn.conv2d(data, kernel, **new_attrs) - original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape) - else: - out = relay.nn.conv2d(data, kernel, **new_attrs) - return out elif data_dtype in ["int4", "uint4"]: if data_layout == "NHWC" and kernel_layout == "HWIO": diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py index 395a8a2008d3f..d65c5c45c7e08 100644 --- a/python/tvm/topi/nn/__init__.py +++ b/python/tvm/topi/nn/__init__.py @@ -55,4 +55,3 @@ from .batch_to_space_nd import * from .loss import * from .lstm import * -from .utils import * diff --git a/python/tvm/topi/nn/utils.py b/python/tvm/topi/nn/utils.py index dda19828d0962..01e1c1ab5444b 100644 --- a/python/tvm/topi/nn/utils.py +++ b/python/tvm/topi/nn/utils.py @@ -19,7 +19,6 @@ from __future__ import absolute_import import tvm -from tvm.meta_schedule import is_meta_schedule_enabled from ..utils import get_const_int @@ -308,130 +307,3 @@ def get_pad_tuple1d(padding, kernel): raise ValueError("Unknown padding option %s" % padding) pad_left = (pad_w + 1) // 2 return pad_left, pad_w - pad_left - - -def use_im2col( - data, - kernel, - stride_h, - stride_w, - dilation_h, - dilation_w, - padding, - patch_channel=False, - M=16, - K=16, - N=16, -): - """Whether to use im2col wmma implementation for conv2d with NCHW layout - - Parameters - ---------- - data : Tensor - data - - kernel: Tensor - kernel - - stride_h : int - stride in height dimension - - stride_w : int - stride in width dimension - - dilation_h : int - dilation in height dimension - - dilation_w : int - dilation in width dimension - - padding : int or str - padding size, or ['VALID', 
'SAME'] - - patch_channel : bool - whether to patch channel dimension - - M : int - row dimension of matrix A fragment - - K : int - reduction dimension of matrix A or B fragment - - N : int - column dimension of matrix B fragment - - Returns - ------- - True or False - - """ - if not is_meta_schedule_enabled() or data.dtype != "float16": - return False - - target = tvm.target.Target.current(allow_none=False) - if not target.kind.name in ["vulkan"] or not target.supports_cooperative_matrix: - return False - - out_channels, in_channels, kernel_h, kernel_w = kernel.shape - batch_size, output_channels, output_height, output_weight = get_output_shape( - data, kernel, stride_h, stride_w, dilation_h, dilation_w, padding - ) - - if isinstance(batch_size, tvm.tir.expr.Any): - return False - - if (output_channels % M != 0) and not patch_channel: - return False - if (in_channels * kernel_h * kernel_w) % K != 0 and not patch_channel: - return False - - if ( - (batch_size == 1) - and (dilation_h * dilation_w == 1) - and ((batch_size * output_height * output_weight) % N == 0) - ): - return True - - return False - - -def get_output_shape(data, kernel, stride_h, stride_w, dilation_h, dilation_w, padding): - """Get output shape of conv2d with NCHW layout - - Parameters - ---------- - data : Tensor - data - - kernel : Tensor - kernel - - stride_h : int - stride in height dimension - - stride_w : int - stride in width dimension - - dilation_h : int - dilation in height dimension - - dilation_w : int - dilation in width dimension - - padding : int or str - padding size, or ['VALID', 'SAME'] - - Returns - ------- - Shape of output - - """ - batch_size, in_channels, height, width = data.shape - out_channels, in_channels, kernel_h, kernel_w = kernel.shape - pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w)) - pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w)) - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - output_height = (height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1 - output_width = (width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1 - return [batch_size, out_channels, output_height, output_width] diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 81f83432d1619..ed2c40da72a1d 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -325,7 +325,6 @@ class IterMapRewriter : public ExprMutator { PrimExpr VisitExpr_(const MulNode* op) final; PrimExpr VisitExpr_(const FloorDivNode* op) final; PrimExpr VisitExpr_(const FloorModNode* op) final; - PrimExpr VisitExpr_(const CastNode* op) final; private: /* \brief Preprocessing common to both FloorDiv and FloorMod @@ -1544,15 +1543,6 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { } } -PrimExpr IterMapRewriter::VisitExpr_(const CastNode* op) { - PrimExpr value = this->DirectMutate(op->value); - ICHECK(value->IsInstance()); - const auto* node = value.as(); - IterSplitExpr new_expr = GetRef(node); - new_expr.CopyOnWrite()->dtype = op->dtype; - return new_expr; -} - IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend) { if (dividend->IsInstance()) { auto split = Downcast(dividend); diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 5e396a38d8c62..a19e6ea3fe232 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ 
b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -122,7 +122,7 @@ class VerifyGPUCodeNode : public PostprocNode { this->target_ = context->target.value(); this->target_constraints_ = Map{ {"max_shared_memory_per_block", Extract(this->target_, "max_shared_memory_per_block")}, - {"max_threads_per_block", Extract(this->target_, "max_num_threads")}, + {"max_threads_per_block", Extract(this->target_, "max_threads_per_block")}, {"max_vthread", Integer(8)}, {"max_vector_bytes", Integer(16)}, }; diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 7dbfef5c8a138..7bb4a02ab299b 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -59,8 +59,6 @@ class ReplayFuncNode : public SearchStrategyNode { Optional space_generator_ = NullOpt; /*! \brief The state of the search strategy. */ std::unique_ptr state_ = nullptr; - /*! \brief The tuning context of the search strategy. */ - const TuneContextNode* context_{nullptr}; void VisitAttrs(tvm::AttrVisitor* v) {} @@ -81,7 +79,6 @@ class ReplayFuncNode : public SearchStrategyNode { this->mod_ = ctx->mod; this->space_generator_ = ctx->space_generator; this->state_.reset(); - this->context_ = ctx.get(); } void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, @@ -126,11 +123,6 @@ inline Optional> ReplayFuncNode::State::GenerateMeasureC Array result; IRModule mod = self->mod_.value(); Array postprocs = self->space_generator_.value()->postprocs.value_or({}); - using tvm::runtime::Registry; - const TuneContextNode* ctx = self->context_; - const auto* f_enter = Registry::Get("target.TargetEnterScope"); - (*f_enter)(ctx->target); - for (int i = st; i < ed; i++) { for (;;) { Array schs = self->space_generator_.value()->GenerateDesignSpace(mod); @@ -151,8 +143,6 @@ inline Optional> ReplayFuncNode::State::GenerateMeasureC } } } - const auto* f_exit = Registry::Get("target.TargetExitScope"); - (*f_exit)(ctx->target); return result; } diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index 14314fc3c3eb0..5b261eec32a4e 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -47,13 +47,13 @@ class GradientBasedNode final : public TaskSchedulerNode { void Tune(Array tasks, Array task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, Array measure_callbacks, Optional database, - Optional cost_model, int min_design_space) final { + Optional cost_model) final { int n_tasks = tasks.size(); round_robin_rounds_ = 0; best_latency_history_.resize(n_tasks, std::vector()); TaskSchedulerNode::Tune(tasks, task_weights, max_trials_global, max_trials_per_task, num_trials_per_iter, builder, runner, measure_callbacks, database, - cost_model, min_design_space); + cost_model); } int NextTaskId() final { diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 5e677e5c0f15a..404ee01983c5a 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -34,10 +34,6 @@ TaskRecord::TaskRecord(TuneContext ctx, double task_weight) { << "ValueError: Require `context.search_strategy`, but it is not defined"; TVM_PY_LOG(INFO, ctx->logger) << "\n" << ctx->mod; ctx->Initialize(); - using tvm::runtime::Registry; - const auto* 
f_enter = Registry::Get("target.TargetEnterScope"); - (*f_enter)(ctx->target); - n->flop = std::max(1.0, tir::EstimateTIRFlops(ctx->mod.value())); this->data_ = std::move(n); } @@ -148,7 +144,7 @@ void TaskSchedulerNode::Tune(Array ctxs, Array task_weigh int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, Array measure_callbacks, Optional database, - Optional cost_model, int min_design_space) { + Optional cost_model) { CHECK_EQ(ctxs.size(), task_weights.size()) << "ValueError: `task_weights` must have the same " "length as `ctxs`"; int n_tasks = this->remaining_tasks_ = ctxs.size(); @@ -163,20 +159,8 @@ void TaskSchedulerNode::Tune(Array ctxs, Array task_weigh TVM_PY_LOG(INFO, this->logger) << "Initializing Task #" << i << ": " << ctx->task_name; TVM_PY_LOG(INFO, ctx->logger) << "Initializing Task #" << i << ": " << ctx->task_name; this->tasks_.push_back(TaskRecord(ctx, weight)); - Array design_spaces; - int sample = 0; - while (sample < min_design_space) { - Array cur_design_spaces = - ctx->space_generator.value()->GenerateDesignSpace(ctx->mod.value()); - unsigned int size = cur_design_spaces.size(); - CHECK(size > 0) << "ValueError: Empty design spaces"; - for (unsigned int i = 0; i < size; ++i) design_spaces.push_back(cur_design_spaces[i]); - sample += size; - } - - using tvm::runtime::Registry; - const auto* f_exit = Registry::Get("target.TargetExitScope"); - (*f_exit)(ctx->target); + Array design_spaces = + ctx->space_generator.value()->GenerateDesignSpace(ctx->mod.value()); TVM_PY_LOG(INFO, ctx->logger) << "Total " << design_spaces.size() << " design space(s) generated"; for (int i = 0, n = design_spaces.size(); i < n; ++i) { @@ -366,15 +350,14 @@ void PyTaskSchedulerNode::Tune(Array tasks, Array task_we int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, Array measure_callbacks, - Optional database, Optional cost_model, - int min_design_space) { + Optional database, Optional cost_model) { if (f_tune == nullptr) { TaskSchedulerNode::Tune(tasks, task_weights, max_trials_global, max_trials_per_task, num_trials_per_iter, builder, runner, measure_callbacks, database, - cost_model, min_design_space); + cost_model); } else { f_tune(tasks, task_weights, max_trials_global, max_trials_per_task, num_trials_per_iter, - builder, runner, measure_callbacks, database, cost_model, min_design_space); + builder, runner, measure_callbacks, database, cost_model); } } diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index df02430b16ea8..8165954749095 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -552,7 +552,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule_dispatch", Integer); TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.tir_converter", String); -TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_int32_const", Bool); TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() { return TECompiler::Global(); diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 1d81767cd9f21..f009bda9cd98e 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -417,9 +417,7 @@ Optional DefaultTIRConverterImpl(const Array& args, return NullOpt; } } - - PrimFunc func = te::CreatePrimFuncWithConstants( - args, constants, 
UseInt32Const() ? DataType::Int(32) : DataType::Int(64)); + PrimFunc func = te::CreatePrimFuncWithConstants(args, constants, DataType::Int(64)); bool dynamic_loop_extent = false; tir::PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void { if (const auto* loop = obj.as()) { diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index f3dc4f3eabf25..acaea425d1789 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -641,16 +641,6 @@ inline int UseMetaScheduleDispatch() { .value() ->value; } - -/* - * \brief Return whether int32 type is used for constants. - */ -inline bool UseInt32Const() { - return transform::PassContext::Current() - ->GetConfig("relay.backend.use_int32_const", Bool(false)) - .value(); -} - /*! * \brief Method in TECompiler to convert TE compute to scheduleable TIR * \param args The arguments of the TE compute diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index dfc8034c85c3c..b3e017d03418e 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -134,8 +134,6 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, supports_integer_dot_product = device.HasExtension("VK_KHR_shader_integer_dot_product"); - supports_cooperative_matrix = device.HasExtension("VK_NV_cooperative_matrix"); - // The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically // needed, since it will be set so long at least one queue has // VK_QUEUE_COMPUTE_BIT. Including it to avoid potential future @@ -437,8 +435,7 @@ std::vector VulkanDevice::SelectEnabledExtensions() const { "VK_KHR_get_memory_requirements2", "VK_KHR_dedicated_allocation", "VK_KHR_spirv_1_4", - "VK_KHR_shader_integer_dot_product", - "VK_NV_cooperative_matrix"}; + "VK_KHR_shader_integer_dot_product"}; uint32_t device_extension_prop_count; VULKAN_CALL(vkEnumerateDeviceExtensionProperties(physical_device_, nullptr, diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index 296483a6b1042..59ebf430e6e64 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -88,7 +88,6 @@ struct VulkanDeviceProperties { bool supports_push_descriptor{false}; bool supports_dedicated_allocation{false}; bool supports_integer_dot_product{false}; - bool supports_cooperative_matrix{false}; uint32_t supported_subgroup_operations{0}; uint32_t max_num_threads{1}; uint32_t thread_warp_size{1}; diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 1087415256025..93f017a5aa667 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -241,10 +241,6 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, *rv = prop.supports_integer_dot_product; } - if (property == "supports_cooperative_matrix") { - *rv = prop.supports_cooperative_matrix; - } - if (property == "device_name") { *rv = prop.device_name; } diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 703e051670f78..e3ef5acb83314 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -130,7 +130,6 @@ void CodeGenSPIRV::InitFuncState() { builder_.reset(new spirv::IRBuilder(spirv_support_)); builder_->InitHeader(); shared_memory_bytes_used_ = 0; - fragment_info_.clear(); } spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) { @@ -396,135 +395,6 @@ spirv::Value 
CodeGenSPIRV::VisitExpr_(const CallNode* op) { LOG(FATAL) << "SPIR-V shader cannot make extern calls. Graph contains extern \"" << Downcast(op->args[0]) << "\""; return spirv::Value(); - } else if (op->op.same_as(builtin::tvm_fill_fragment())) { - ICHECK_EQ(op->args.size(), 6U); - const VarNode* buffer_node = op->args[0].as(); - ICHECK(buffer_node && fragment_info_.count(buffer_node)); - DataType ele_dtype = GetElementDataType(buffer_node); - ICHECK(ele_dtype.is_float()) << "Only floating point fragment accumulator is supported"; - spirv::SType ele_stype = builder_->GetSType(ele_dtype); - spirv::SType& fragment_type = fragment_info_[buffer_node].stype; - double init = static_cast(Downcast(op->args[5])->value); - PrimExpr prim_index = op->args[4]; - spirv::Value init_val = builder_->GetCompositeConst(ele_stype, fragment_type, init); - spirv::SType ptr_type = - builder_->GetPointerType(fragment_type, fragment_info_[buffer_node].sclass); - spirv::Value index = MakeValue(prim_index); - ICHECK(var_map_.count(buffer_node)); - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], index); - builder_->MakeInst(spv::OpStore, ptr, init_val, spv::MemoryAccessMaskNone); - return spirv::Value(); - - } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { - ICHECK_EQ(op->args.size(), 8U); - const VarNode* buffer_node = op->args[0].as(); - ICHECK(buffer_node && fragment_info_.count(buffer_node)); - spirv::SType& fragment_type = fragment_info_[buffer_node].stype; - PrimExpr dst_index = op->args[4]; - PrimExpr src_ptr_expr = op->args[5]; - int stride = static_cast(Downcast(op->args[6])->value); - auto type_int = builder_->GetSType(DataType::Int(32)); - spirv::Value stride_val = builder_->IntImm(type_int, stride); - std::string layout = (op->args[7].as())->value; - spirv::SType dst_ptr_type = - builder_->GetPointerType(fragment_type, fragment_info_[buffer_node].sclass); - spirv::Value dst_ptr = - builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], MakeValue(dst_index)); - const CallNode* call_node = src_ptr_expr.as(); - ICHECK(call_node && call_node->op.same_as(builtin::address_of())); - const BufferLoadNode* load = call_node->args[0].as(); - Var src_buffer_var = load->buffer->data; - const VarNode* src_buffer_node = src_buffer_var.get(); - PrimExpr src_index = load->indices[0]; - DataType src_ele_dtype = GetElementDataType(src_buffer_node); - spirv::SType src_ele_stype = builder_->GetSType(src_ele_dtype); - spirv::Value src_buffer_val = MakeValue(src_buffer_var); - spirv::SType src_ptr_type = - builder_->GetPointerType(src_ele_stype, src_buffer_val.stype.storage_class); - ICHECK(var_map_.count(src_buffer_node)); - spirv::Value src_ptr = - builder_->StructArrayAccess(src_ptr_type, var_map_[src_buffer_node], MakeValue(src_index)); - spirv::SType type_bool = builder_->GetSType(DataType::UInt(1)); - spirv::Value t_val = builder_->UIntImm(type_bool, 1); - spirv::Value f_val = builder_->UIntImm(type_bool, 0); - spirv::Value loaded = - builder_->MakeValue(spv::OpCooperativeMatrixLoadNV, fragment_type, src_ptr, stride_val, - (layout != "row_major") ? 
t_val : f_val); - builder_->MakeInst(spv::OpStore, dst_ptr, loaded, spv::MemoryAccessMaskNone); - return spirv::Value(); - } else if (op->op.same_as(builtin::tvm_mma_sync())) { - const VarNode* buffer_d = op->args[0].as(); - const VarNode* buffer_a = op->args[2].as(); - const VarNode* buffer_b = op->args[4].as(); - const VarNode* buffer_c = op->args[6].as(); - PrimExpr index_d = op->args[1]; - PrimExpr index_a = op->args[3]; - PrimExpr index_b = op->args[5]; - tvm::tir::ExprDeepEqual expr_equal; - PrimExpr index_c = op->args[7]; - bool is_equal = ((buffer_d == buffer_c) && expr_equal(index_d, index_c)); - spirv::SType& fragment_type_d = fragment_info_[buffer_d].stype; - spirv::SType& fragment_type_a = fragment_info_[buffer_a].stype; - spirv::SType& fragment_type_b = fragment_info_[buffer_b].stype; - spirv::SType& fragment_type_c = fragment_info_[buffer_c].stype; - spv::StorageClass storage = fragment_info_[buffer_d].sclass; - spirv::SType ptr_type_d = builder_->GetPointerType(fragment_type_d, storage); - spirv::SType ptr_type_a = builder_->GetPointerType(fragment_type_a, storage); - spirv::SType ptr_type_b = builder_->GetPointerType(fragment_type_b, storage); - spirv::SType ptr_type_c = builder_->GetPointerType(fragment_type_c, storage); - spirv::Value ptr_d = - builder_->StructArrayAccess(ptr_type_d, var_map_[buffer_d], MakeValue(index_d)); - spirv::Value ptr_a = - builder_->StructArrayAccess(ptr_type_a, var_map_[buffer_a], MakeValue(index_a)); - spirv::Value ptr_b = - builder_->StructArrayAccess(ptr_type_b, var_map_[buffer_b], MakeValue(index_b)); - spirv::Value ptr_c = - is_equal ? ptr_d - : builder_->StructArrayAccess(ptr_type_c, var_map_[buffer_c], MakeValue(index_c)); - uint32_t mask = spv::MemoryAccessMaskNone; - spirv::Value loaded_a = builder_->MakeValue(spv::OpLoad, fragment_type_a, ptr_a, mask); - spirv::Value loaded_b = builder_->MakeValue(spv::OpLoad, fragment_type_b, ptr_b, mask); - spirv::Value loaded_c = builder_->MakeValue(spv::OpLoad, fragment_type_c, ptr_c, mask); - spirv::Value result = builder_->MakeValue(spv::OpCooperativeMatrixMulAddNV, fragment_type_d, - loaded_a, loaded_b, loaded_c); - builder_->MakeInst(spv::OpStore, ptr_d, result, spv::MemoryAccessMaskNone); - return spirv::Value(); - } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { - ICHECK_EQ(op->args.size(), 8U); - const VarNode* buffer_node = op->args[0].as(); - PrimExpr index = op->args[4]; - PrimExpr buffer_ptr = op->args[5]; - int stride = static_cast(Downcast(op->args[6])->value); - auto type_int = builder_->GetSType(DataType::Int(32)); - spirv::Value stride_val = builder_->IntImm(type_int, stride); - std::string layout = (op->args[7].as())->value; - const CallNode* call_node = buffer_ptr.as(); - ICHECK(call_node && call_node->op.same_as(builtin::address_of())); - const BufferLoadNode* load = call_node->args[0].as(); - Var dst_buffer_var = load->buffer->data; - const VarNode* dst_buffer_node = dst_buffer_var.get(); - PrimExpr dst_index = load->indices[0]; - DataType dst_ele_dtype = GetElementDataType(dst_buffer_node); - spirv::SType dst_ele_stype = builder_->GetSType(dst_ele_dtype); - spirv::Value dst_buffer_val = MakeValue(dst_buffer_var); - spirv::SType dst_ptr_type = - builder_->GetPointerType(dst_ele_stype, dst_buffer_val.stype.storage_class); - ICHECK(var_map_.count(dst_buffer_node)); - spirv::Value dst_ptr = - builder_->StructArrayAccess(dst_ptr_type, var_map_[dst_buffer_node], MakeValue(dst_index)); - spirv::SType& fragment_type = fragment_info_[buffer_node].stype; - spv::StorageClass 
storage = fragment_info_[buffer_node].sclass; - spirv::SType ptr_type = builder_->GetPointerType(fragment_type, storage); - spirv::Value ptr = - builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index)); - uint32_t mask = spv::MemoryAccessMaskNone; - spirv::Value loaded = builder_->MakeValue(spv::OpLoad, fragment_type, ptr, mask); - spirv::SType type_bool = builder_->GetSType(DataType::UInt(1)); - spirv::Value t_val = builder_->UIntImm(type_bool, 1); - spirv::Value f_val = builder_->UIntImm(type_bool, 0); - builder_->MakeInst(spv::OpCooperativeMatrixStoreNV, dst_ptr, loaded, stride_val, - (layout != "row_major") ? t_val : f_val); - return spirv::Value(); } else { LOG(FATAL) << "Unresolved call " << op->op; } @@ -788,44 +658,22 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; spirv::Value buf; - const std::string scope = GetPtrStorageScope(op->buffer_var); - auto storage_scope = runtime::StorageScope::Create(scope); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); spirv::SType etype = builder_->GetSType(op->dtype); - runtime::StorageRank rank = storage_scope.rank; - spv::StorageClass storage_class; - const VarNode* var_node = (op->buffer_var).get(); - - switch (rank) { - case runtime::StorageRank::kWMMAMatrixA: - case runtime::StorageRank::kWMMAMatrixB: - case runtime::StorageRank::kWMMAAccumulator: { - ICHECK(fragment_info_.count(var_node)); - fragment_info_[var_node].scope = scope; - etype = GetFragmentSType(var_node, op->dtype); - storage_class = spv::StorageClassFunction; - fragment_info_[var_node].sclass = storage_class; - ICHECK(fragment_info_.count(var_node)); - std::pair dim = GetWmmaFragmentSize(var_node); - int64_t size = dim.first * dim.second; - buf = builder_->Allocate(etype, static_cast(constant_size) / size, storage_class); - } break; - case runtime::StorageRank::kLocal: { - storage_class = spv::StorageClassFunction; - buf = builder_->Allocate(etype, static_cast(constant_size), storage_class); - } break; - case runtime::StorageRank::kShared: { - storage_class = spv::StorageClassWorkgroup; - // Shared memory - // Aligned on 4-byte boundary - int32_t aligned_constant_size = ((constant_size + 3) & ~0x3); - buf = builder_->Allocate(etype, static_cast(aligned_constant_size), storage_class); - - size_t num_bytes = - op->dtype.bytes() * op->dtype.lanes() * static_cast(aligned_constant_size); - shared_memory_bytes_used_ += num_bytes; - } break; - default: - LOG(FATAL) << "Can only allocate shared or local memory inside kernel"; + if (storage_scope.rank == runtime::StorageRank::kLocal) { + buf = + builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassFunction); + } else if (storage_scope.rank == runtime::StorageRank::kShared) { + // Shared memory + // Aligned on 4-byte boundary + int32_t aligned_constant_size = ((constant_size + 3) & ~0x3); + buf = builder_->Allocate(etype, static_cast(aligned_constant_size), + spv::StorageClassWorkgroup); + + size_t num_bytes = op->dtype.bytes() * op->dtype.lanes() * static_cast(constant_size); + shared_memory_bytes_used_ += num_bytes; + } else { + LOG(FATAL) << "Can only allocate shared or local memory inside kernel"; } builder_->SetName(buf, op->buffer_var->name_hint); @@ -853,13 +701,6 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { const VarNode* v = op->node.as(); ICHECK(v); storage_info_[v].is_volatile = true; - } else if (op->attr_key == 
tir::attr::buffer_bind_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - } else if (op->attr_key == tir::attr::fragment_shape) { - const VarNode* buffer = op->node.as(); - const StringImmNode* shape_str = op->value.as(); - fragment_info_[buffer] = {shape_str->value}; } this->VisitStmt(op->body); } @@ -885,51 +726,5 @@ void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } -int CodeGenSPIRV::Stoi(const std::string& str) { - try { - return std::stoi(str); - } catch (std::invalid_argument& e) { - LOG(FATAL) << "Cannot convert \"" << str << "\" to int"; - throw; - } -} - -std::pair CodeGenSPIRV::GetWmmaFragmentSize(const VarNode* variable) { - const std::string& scope = fragment_info_[variable].scope; - std::string& shape_str = fragment_info_.at(variable).shape; - size_t m, n, k; - size_t last_pos = 0, pos = 0; - pos = shape_str.find(", ", last_pos); - m = Stoi(shape_str.substr(last_pos, pos - last_pos)); - last_pos = pos + 2; - pos = shape_str.find(", ", last_pos); - n = Stoi(shape_str.substr(last_pos, pos - last_pos)); - last_pos = pos + 2; - k = Stoi(shape_str.substr(last_pos, shape_str.length() - last_pos)); - if (scope == "wmma.matrix_a") { - return std::pair(m, k); - } else if (scope == "wmma.matrix_b") { - return std::pair(k, n); - } else if (scope == "wmma.accumulator") { - return std::pair(m, n); - } - return std::pair(0, 0); -} - -spirv::SType CodeGenSPIRV::GetFragmentSType(const VarNode* buffer, const DataType& dtype) { - ICHECK(fragment_info_.count(buffer)); - std::pair dim = GetWmmaFragmentSize(buffer); - int64_t size = dim.first * dim.second; - spirv::SType stype = builder_->GetSType(dtype.with_lanes(size), dim.first, dim.second); - fragment_info_[buffer].stype = stype; - return stype; -} - -DataType CodeGenSPIRV::GetElementDataType(const VarNode* buffer) { - auto it = storage_info_.find(buffer); - ICHECK(it != storage_info_.end()); - return it->second.element_type; -} - } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index df44b236177a4..08b9db0ee5398 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -34,7 +34,6 @@ #include #include #include -#include #include #include "../../runtime/thread_storage_scope.h" @@ -172,14 +171,6 @@ class CodeGenSPIRV : public ExprFunctor, element_type_known = true; } }; - - struct FragmentInfo { - std::string shape; - std::string scope; - spirv::SType stype; - spv::StorageClass sclass; - }; - // Reset the state so it works for a new function. void InitFuncState(); // Get the thread index @@ -188,11 +179,6 @@ class CodeGenSPIRV : public ExprFunctor, spirv::Value CreateStorageSync(const CallNode* op); void Scalarize(const PrimExpr& e, std::function f); - int Stoi(const std::string& str); - std::pair GetWmmaFragmentSize(const VarNode* variable); - spirv::SType GetFragmentSType(const VarNode* buffer, const DataType& dtype); - DataType GetElementDataType(const VarNode* buffer); - // SPIRV-related capabilities of the target SPIRVSupport spirv_support_; @@ -232,8 +218,6 @@ class CodeGenSPIRV : public ExprFunctor, // Running total of the number of bytes of shared memory used. 
// Checked against the max_shared_memory_per_group size_t shared_memory_bytes_used_{0}; - - std::unordered_map fragment_info_; }; } // namespace codegen diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 545e677af9f28..46c9c5869c79d 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -60,11 +60,6 @@ void IRBuilder::InitHeader() { } #endif - if (spirv_support_.supports_cooperative_matrix) { - capabilities_used_.insert(spv::CapabilityCooperativeMatrixNV); - extensions_used_.insert("SPV_NV_cooperative_matrix"); - } - // memory model ib_.Begin(spv::OpMemoryModel) .AddSeq(spv::AddressingModelLogical, spv::MemoryModelGLSL450) @@ -79,7 +74,6 @@ void IRBuilder::InitPreDefs() { t_bool_ = DeclareType(DataType::UInt(1)); t_fp32_ = DeclareType(DataType::Float(32)); const_i32_zero_ = IntImm(t_int32_, 0); - // declare void, and void functions t_void_.id = id_counter_++; ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_); @@ -112,7 +106,7 @@ std::vector IRBuilder::Finalize() { return data; } -SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) { +SType IRBuilder::GetSType(const DataType& dtype) { if (dtype == DataType::Int(32)) { return t_int32_; } else if (dtype == DataType::UInt(1)) { @@ -122,22 +116,15 @@ SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) { } else if (dtype == DataType::UInt(32)) { return t_uint32_; } - uint64_t type_key; + uint32_t type_key; type_key = static_cast(dtype.code()); type_key |= static_cast(dtype.bits()) << 8U; - if (row * col == 0) { - ICHECK((row == 0) && (col == 0)); - type_key |= static_cast(dtype.lanes()) << 16U; - } else { - type_key |= static_cast(row) << 32U; - type_key |= static_cast(col) << 40U; - } - + type_key |= static_cast(dtype.lanes()) << 16U; auto it = pod_type_tbl_.find(type_key); if (it != pod_type_tbl_.end()) { return it->second; } - SType t = DeclareType(dtype, row, col); + SType t = DeclareType(dtype); pod_type_tbl_[type_key] = t; return t; } @@ -234,13 +221,7 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { return GetConst_(dtype, &data); } else { ICHECK_EQ(dtype.type.bits(), 16); - float fvalue = static_cast(value); - uint32_t* ptr = reinterpret_cast(&fvalue); - uint64_t data = ptr[0]; - if (data == 0) - return GetConst_(dtype, &data); - else - return Cast(dtype, FloatImm(GetSType(DataType::Float(32)), value)); + return Cast(dtype, FloatImm(GetSType(DataType::Float(32)), value)); } } @@ -494,7 +475,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { return ret; } -SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) { +SType IRBuilder::DeclareType(const DataType& dtype) { AddCapabilityFor(dtype); if (dtype.lanes() == 1) { @@ -519,18 +500,7 @@ SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) t.id = id_counter_++; t.type = dtype; SType base_type = GetSType(dtype.element_of()); - - if (row * col == 0) { - ICHECK((row == 0) && (col == 0)); - ib_.Begin(spv::OpTypeVector).AddSeq(t, base_type, dtype.lanes()).Commit(&global_); - } else { - Value v_row = GetSpecConst(GetSType(DataType::UInt(32)), row); - Value v_col = GetSpecConst(GetSType(DataType::UInt(32)), col); - Value scope = UIntImm(GetSType(DataType::UInt(32)), spv::ScopeSubgroup); - ib_.Begin(spv::OpTypeCooperativeMatrixNV) - .AddSeq(t, base_type, scope, v_row, v_col) - .Commit(&global_); - } + ib_.Begin(spv::OpTypeVector).AddSeq(t, base_type, dtype.lanes()).Commit(&global_); 
return t; } } @@ -757,30 +727,6 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { } } -Value IRBuilder::GetCompositeConst(const SType& ele_stype, const SType& composite_stype, - const double dval) { - auto key = std::make_pair(composite_stype.id, dval); - auto it = composite_const_tbl_.find(key); - if (it != composite_const_tbl_.end()) { - return it->second; - } - spirv::Value const_val = FloatImm(ele_stype, dval); - Value new_val = NewValue(composite_stype, kNormal); - ib_.Begin(spv::OpConstantComposite).AddSeq(composite_stype, new_val, const_val); - ib_.Commit(&global_); - composite_const_tbl_[key] = new_val; - return new_val; -} - -Value IRBuilder::GetSpecConst(const SType& dtype, uint64_t value) { - ICHECK_LE(dtype.type.bits(), 32); - Value ret = NewValue(dtype, kSpecConst); - ib_.Begin(spv::OpSpecConstant).AddSeq(dtype, ret); - ib_.Add(static_cast(value)); - ib_.Commit(&global_); - return ret; -} - #define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b) { \ ICHECK_EQ(a.stype.id, b.stype.id); \ diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index e92e8364ee1bf..d642484532f99 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -65,8 +65,7 @@ enum ValueKind { kPushConstantPtr, kFunction, kExtInst, - kUniformPtr, - kSpecConst, + kUniformPtr }; /*! \brief Represent the SPIRV Value */ @@ -444,7 +443,7 @@ class IRBuilder { * \param dtype The data type. * \return The corresponding spirv type. */ - SType GetSType(const tvm::DataType& dtype, uint32_t row = 0, uint32_t col = 0); + SType GetSType(const tvm::DataType& dtype); /*! * \brief Get the pointer type that points to value_type * \param value_type. @@ -593,19 +592,6 @@ class IRBuilder { Value GT(Value a, Value b); Value GE(Value a, Value b); Value Select(Value cond, Value a, Value b); - /* - * \brief Get composite constant - * \param ele_stype The value type of elements in the composite. - * \param composite_type The value type of the composite. - * \param dval The initial value for all elements in the composite. - */ - Value GetCompositeConst(const SType& ele_stype, const SType& composite_stype, double dval); - /* - * Get specialization constant - * \param dtype The content value type - * \param value The default value - */ - Value GetSpecConst(const SType& dtype, uint64_t value); private: /*! @@ -654,9 +640,8 @@ class IRBuilder { // get constant given value encoded in uint64_t Value GetConst_(const SType& dtype, const uint64_t* pvalue); - // declare type - SType DeclareType(const DataType& dtype, uint32_t row = 0, uint32_t col = 0); + SType DeclareType(const DataType& dtype); // Declare the appropriate SPIR-V capabilities and extensions to use // this data type. @@ -711,15 +696,13 @@ class IRBuilder { /*! \brief whether push constant is defined */ Value push_const_; /*! \brief map from type code to the type */ - std::unordered_map pod_type_tbl_; + std::unordered_map pod_type_tbl_; /*! \brief map from value to array type */ std::map, SType> struct_array_type_tbl_; /*! \brief map from value to its pointer type */ std::map, SType> pointer_type_tbl_; /*! \brief map from constant int to its value */ std::map, Value> const_tbl_; - /*! \brief map from floating point composite constant to its value */ - std::map, Value> composite_const_tbl_; /*! 
\brief map from name of a ExtInstImport to its value */ std::map ext_inst_tbl_; diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index 4833d43e3e851..81b5cd8b8a6aa 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -101,10 +101,6 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) { } } } - // Check whether cooperative matrix is enabled in the target string. - if (target->GetAttr("supports_cooperative_matrix")) { - supports_cooperative_matrix = target->GetAttr("supports_cooperative_matrix").value(); - } } } // namespace codegen diff --git a/src/target/spirv/spirv_support.h b/src/target/spirv/spirv_support.h index 83f92595112e0..6365e576b8cf1 100644 --- a/src/target/spirv/spirv_support.h +++ b/src/target/spirv/spirv_support.h @@ -276,20 +276,6 @@ struct SPIRVSupport { * attempting to perform integer dot product. */ bool supports_integer_dot_product{false}; - - /*! - * \brief Whether the driver supports operations involving cooperative matrix. - * - * Vulkan extension: VK_NV_cooperative_matrix - * SPV Extension name: SPV_NV_cooperative_matrix - * SPV Capability: spv::CapabilityCooperativeMatrixNV - * - * If support is present, can perform cooperative matrix operations. If - * support is not present, codegen will throw exception on - * attempting to perform cooperative matrix. - */ - - bool supports_cooperative_matrix{false}; }; } // namespace codegen diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 3c4e885ef9b52..3a555e304cb03 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -378,7 +378,6 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("supports_push_descriptor") .add_attr_option("supports_dedicated_allocation") .add_attr_option("supports_integer_dot_product") - .add_attr_option("supports_cooperative_matrix") .add_attr_option("supported_subgroup_operations") // Physical device limits .add_attr_option("max_num_threads", Integer(256)) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index ff838cf04e089..240b16aa5b1fd 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -926,9 +926,6 @@ class StoragePlanRewriter : public StmtExprMutator { StorageEntry* e = it->second; if (e->attach_scope_ != attach_scope) continue; if (e->scope != scope) continue; - // TODO(Anurag Kumar Vulisha): Fix the reuse of buffers with different dtype ex., float16 vs - // float32 - if (e->elem_type != op->dtype.element_of()) continue; // when not divided, no reuse, eg, float4 vs float3 if (e->bits_offset % op_elem_bits != 0) continue; e->const_nbits = std::max(const_nbits, e->const_nbits); diff --git a/tests/python/unittest/test_wmma.py b/tests/python/unittest/test_wmma.py deleted file mode 100644 index 7052be0e24b31..0000000000000 --- a/tests/python/unittest/test_wmma.py +++ /dev/null @@ -1,248 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=missing-docstring -import tvm -from tvm import te -import tvm.testing -import numpy as np -from tvm import relay -from tvm.contrib import graph_executor -import os -from os import path as osp -import argparse -import sys -from types import MappingProxyType - -from tvm.tir.tensor_intrin.cuda import ( - LDMATRIX_16x16_A_INTRIN, - LDMATRIX_16x16_B_INTRIN, - LDMATRIX_16x16_B_TRANS_INTRIN, - LDMATRIX_16x32_A_INTRIN, - LDMATRIX_32x16_B_INTRIN, - LDMATRIX_16x32_B_TRANS_INTRIN, - MMA_f16f16f32_INTRIN, - MMA_f16f16f32_TRANS_INTRIN, - MMA_f16f16f16_INTRIN, - MMA_f16f16f16_TRANS_INTRIN, - MMA_i8i8i32_INTRIN, - MMA_i8i8i32_TRANS_INTRIN, - MMA_fill_16x16_f32_INTRIN, - MMA_fill_16x16_f16_INTRIN, - MMA_fill_16x16_i32_INTRIN, - MMA_store_16x16_f32_global_INTRIN, - MMA_store_16x16_f16_global_INTRIN, - MMA_store_16x16_i32_global_INTRIN, - shared_16x16_to_ldmatrix_32x8_layout, - shared_32x16_to_ldmatrix_32x16_layout, - shared_16x32_to_ldmatrix_32x16_layout, - WMMA_LOAD_16x16x16_F16_A_INTRIN, - WMMA_LOAD_16x16x16_F16_B_INTRIN, - WMMA_SYNC_16x16x16_f16f16f32_INTRIN, - WMMA_FILL_16x16x16_F32_INTRIN, - WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, -) - - -@tvm.testing.requires_vulkan -def test_wmma( - apply_log, - work_dir, - report_time, - add_cooperative, - batch_size=1, - in_channels=4096, - height=64, - width=64, - out_channels=4096, - kernel_h=1, - kernel_w=1, - in_dtype="float16", - out_dtype="float32", - padding=[0, 0, 0, 0], -): - # compute as conv2d - if sys.platform == "win32": - target_host = "llvm" - else: - target_host = "llvm -mtriple=x86_64-linux-gnu" - - old_target_str = "vulkan -from_device=0" - target_str = old_target_str - if add_cooperative: - target_str += " -supports_cooperative_matrix=1" - old_target = tvm.target.Target(old_target_str, host=target_host) - target = tvm.target.Target(target_str, host=target_host) - - if not target.supports_cooperative_matrix: - return - - data_shape = (batch_size, in_channels, height, width) - kernel_shape = (out_channels, in_channels, kernel_h, kernel_w) - - data = relay.var("data", shape=data_shape, dtype=in_dtype) - kernel = relay.var("kernel", shape=kernel_shape, dtype=in_dtype) - conv = relay.nn.conv2d( - data, - kernel, - data_layout="NCHW", - kernel_layout="OIHW", - padding=padding, - strides=[1, 1], - out_dtype=out_dtype, - channels=out_channels, - kernel_size=(kernel_h, kernel_w), - ) - - func = relay.Function([data, kernel], conv) - mod = tvm.IRModule.from_expr(func) - mod = tvm.relay.transform.InferType()(mod) - kernel_np = np.random.uniform(size=kernel_shape).astype(in_dtype) - - mod_params = {"kernel": kernel_np} - - from tvm.meta_schedule.relay_integration import extract_tasks - - if not apply_log: - tasks = extract_tasks( - mod, - target_str, - mod_params, - pass_config=MappingProxyType( - { - "relay.backend.use_meta_schedule": True, - "relay.backend.use_meta_schedule_dispatch": True, - "relay.backend.tir_converter": "default", - "relay.backend.use_int32_const": True, - } - ), - ) - - from tvm import meta_schedule as ms - - workload = osp.join(work_dir, "database_workload.json") - record = osp.join(work_dir, 
"database_tuning_record.json") - # dev = tvm.device(old_target_str, 0) - dev = tvm.vulkan(0) - data_np = np.random.uniform(size=data_shape).astype(in_dtype) - database = ms.database.JSONDatabase(workload, record) - - if apply_log: - assert osp.exists(workload) and osp.exists(record) - - else: - link_params = True - executor = relay.backend.Executor("graph", {"link-params": link_params}) - runner = ms.runner.LocalRunner( - evaluator_config=ms.runner.EvaluatorConfig( - number=3, - repeat=1, - min_repeat_ms=300, - enable_cpu_cache_flush="llvm" in str(target_str), - ) - ) - tasks, task_weights = ms.relay_integration.extracted_tasks_to_tune_contexts( - extracted_tasks=tasks, - work_dir=work_dir, - strategy="evolutionary", - ) - database = ms.tune.tune_tasks( - tasks=tasks, - task_weights=task_weights, - work_dir=work_dir, - max_trials_global=128, - max_trials_per_task=128, - num_trials_per_iter=128, - runner=runner, - database=database, - min_design_space=50, - ) - - lib = ms.relay_integration.compile_relay( - database=database, - mod=mod, - target=target_str, - params=mod_params, - backend="graph", - pass_config=MappingProxyType( - { - "relay.backend.use_meta_schedule": True, - "relay.backend.use_meta_schedule_dispatch": True, - "relay.backend.tir_converter": "default", - "relay.backend.use_int32_const": True, - } - ), - ) - - m = graph_executor.GraphModule(lib["default"](dev)) - m.set_input("data", data_np) - m.run() - output_0 = m.get_output(0).asnumpy() - if report_time: - ftimer = m.module.time_evaluator("run", dev, number=10, repeat=2) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "With cooperative matrix, TVM graph runtime mean inference time (std dev): %.2f ms (%.2f ms)" - % (np.mean(prof_res), np.std(prof_res)) - ) - - with tvm.transform.PassContext(opt_level=3): - lib = relay.build(mod, old_target, params=mod_params) - m = graph_executor.GraphModule(lib["default"](dev)) - m.set_input("data", data_np) - m.run() - output_1 = m.get_output(0).asnumpy() - if report_time: - ftimer = m.module.time_evaluator("run", dev, number=10, repeat=2) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Without cooperative matrix, TVM graph runtime mean inference time (std dev): %.2f ms (%.2f ms)" - % (np.mean(prof_res), np.std(prof_res)) - ) - - tvm.testing.assert_allclose(output_0, output_1, rtol=1e-3) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--apply-log", action="store_true", help="If set, apply tunig log") - parser.add_argument("--log-dir", type=str, default=".", help="Dir of logs") - parser.add_argument("--report-time", action="store_true", help="If set, print execution time") - parser.add_argument( - "--add-cooperative", - action="store_true", - help="If set, add cooperative matrix target string", - ) - - args = parser.parse_args() - # test_wmma(args.apply_log, args.log_dir, args.report_time, args.add_cooperative, batch_size=1, in_channels=8, height=2, width=2, out_channels=8, kernel_h=1, kernel_w=1, padding=[0, 0, 0, 0],) - # test_wmma(args.apply_log, args.log_dir, args.report_time, args.add_cooperative, batch_size=1, in_channels=8, height=16, width=16, out_channels=8, kernel_h=2, kernel_w=2, padding=[1, 1, 1, 1],) - test_wmma( - args.apply_log, - args.log_dir, - args.report_time, - args.add_cooperative, - batch_size=1, - in_channels=16, - height=16, - width=16, - out_channels=16, - kernel_h=1, - kernel_w=1, - padding=[0, 0, 0, 0], - ) - # test_wmma(args.apply_log, args.log_dir, 
args.report_time, args.add_cooperative, batch_size=1, in_channels=16, height=16, width=16, out_channels=16, kernel_h=2, kernel_w=2, padding=[1, 1, 1, 1],)
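
For context on how the reverted path was driven end to end: the deleted tests/python/unittest/test_wmma.py opted in by appending "-supports_cooperative_matrix=1" to a Vulkan target string and returned early when the target did not report the capability. The sketch below only mirrors that pattern; it assumes a checkout that still carries the original commit (this revert unregisters the supports_cooperative_matrix target attribute, so constructing such a target afterwards is rejected) and a machine with a Vulkan device so that -from_device=0 can query it. The attrs.get() lookup is an illustrative substitution for the removed Target property, not something this patch adds.

    # Sketch of the opt-in pattern used by the deleted test_wmma.py.
    # Assumes a pre-revert tree (supports_cooperative_matrix still registered)
    # and a reachable Vulkan device for -from_device=0.
    import tvm

    target_host = "llvm -mtriple=x86_64-linux-gnu"
    target = tvm.target.Target(
        "vulkan -from_device=0 -supports_cooperative_matrix=1",
        host=target_host,
    )

    # The deleted test read the flag back off the target and skipped when unset;
    # attrs.get() stands in here for the removed Python property.
    if not target.attrs.get("supports_cooperative_matrix", False):
        print("cooperative matrix not supported on this target; skipping")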