Revert "Revert "[Vulkan] Add cooperative matrix support""
This reverts commit 2397989.
mei-ye committed May 10, 2023
1 parent 2397989 commit 2c6b33e
Showing 31 changed files with 1,254 additions and 46 deletions.
10 changes: 7 additions & 3 deletions include/tvm/meta_schedule/task_scheduler.h
@@ -175,6 +175,7 @@ class TaskSchedulerNode : public runtime::Object {
* \param measure_callbacks The callbacks to be called after each measurement
* \param database The database used in tuning
* \param cost_model The cost model used in tuning
* \param min_design_space The minimum size of the design space used in tuning
*/
virtual void Tune(Array<TuneContext> tasks, //
Array<FloatImm> task_weights, //
@@ -185,7 +186,8 @@ class TaskSchedulerNode : public runtime::Object {
Runner runner, //
Array<MeasureCallback> measure_callbacks, //
Optional<Database> database, //
Optional<CostModel> cost_model);
Optional<CostModel> cost_model, //
int min_design_space);
/*!
* \brief Terminate a task
* \param task_id The id of the task to be terminated
@@ -228,7 +230,9 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
Runner runner, //
Array<MeasureCallback> measure_callbacks, //
Optional<Database> database, //
Optional<CostModel> cost_model)>;
Optional<CostModel> cost_model, //
int min_design_space)>;

/*! \brief The packed function to the `NextTaskId` function. */
FNextTaskId f_next_task_id;
@@ -249,7 +253,7 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
void Tune(Array<TuneContext> tasks, Array<FloatImm> task_weights, int max_trials_global,
int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner,
Array<MeasureCallback> measure_callbacks, Optional<Database> database,
Optional<CostModel> cost_model) final;
Optional<CostModel> cost_model, int min_design_space) final;

static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler";
TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode);
4 changes: 4 additions & 0 deletions python/tvm/meta_schedule/relay_integration.py
@@ -275,6 +275,7 @@ def tune_relay(
num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
instruments: Optional[Sequence[PassInstrument]] = None,
min_design_space: int = 1,
) -> Database:
"""Tune a Relay program.
@@ -328,6 +329,8 @@ def tune_relay(
The list of disabled passes during tasks extraction
instruments : Optional[Sequence[PassInstrument]]
The list of pass instrument implementations.
min_design_space : int
The minimum size of the design space.
Returns
-------
@@ -363,6 +366,7 @@ def tune_relay(
measure_callbacks=measure_callbacks,
task_scheduler=task_scheduler,
module_equality=module_equality,
min_design_space=min_design_space,
)


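For context, a minimal usage sketch of the extended entry point; mod and params are assumed to come from a Relay frontend import, and the target string and trial count are illustrative:

from tvm import meta_schedule as ms
from tvm.target import Target

# Hedged sketch: mod and params are assumed to be produced by a Relay
# frontend (e.g. relay.frontend.from_onnx); numbers are illustrative.
database = ms.relay_integration.tune_relay(
    mod=mod,
    params=params,
    target=Target("vulkan -from_device=0"),
    work_dir="./tune_relay_work",
    max_trials_global=128,
    min_design_space=1,  # new parameter; the default keeps prior behavior
)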
6 changes: 6 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
@@ -102,6 +102,7 @@ def tune(
measure_callbacks: List[MeasureCallback],
database: Optional[Database],
cost_model: Optional[CostModel],
min_design_space: int = 1,
) -> None:
"""Auto-tuning.
@@ -127,6 +128,8 @@ def tune(
The database.
cost_model : Optional[CostModel]
The cost model.
min_design_space : int
The minimum size of the design space.
"""
task_weights = [float(w) for w in task_weights]
_ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member
@@ -141,6 +144,7 @@ def tune(
measure_callbacks,
database,
cost_model,
min_design_space,
)

def terminate_task(self, task_id: int) -> None:
@@ -243,6 +247,7 @@ def tune(
measure_callbacks: List[MeasureCallback],
database: Optional[Database],
cost_model: Optional[CostModel],
min_design_space: int = 1,
) -> None:
"""Auto-tuning."""
# Using self._outer to replace the self pointer
@@ -257,6 +262,7 @@ def tune(
measure_callbacks,
database,
cost_model,
min_design_space,
)

def next_task_id(self) -> int:
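A hedged sketch of the call site this signature change affects; the task contexts, builder, runner, database, and cost model are assumed to be constructed elsewhere:

from tvm import meta_schedule as ms

scheduler = ms.task_scheduler.TaskScheduler.create("gradient")
scheduler.tune(
    tasks=tasks,                # List[TuneContext], assumed prepared elsewhere
    task_weights=task_weights,  # one float per task
    max_trials_global=64,
    max_trials_per_task=16,
    num_trials_per_iter=16,
    builder=builder,
    runner=runner,
    measure_callbacks=[],
    database=database,
    cost_model=cost_model,
    min_design_space=1,         # defaults to 1, so existing callers are unaffected
)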
4 changes: 4 additions & 0 deletions python/tvm/meta_schedule/tir_integration.py
@@ -58,6 +58,7 @@ def tune_tir( # pylint: disable=too-many-locals
seed: Optional[int] = None,
module_equality: str = "structural",
special_space: Optional[Mapping[str, SpaceGenerator.SpaceGeneratorType]] = None,
min_design_space: int = 1,
) -> Database:
"""Tune a TIR function or an IRModule of TIR functions.
@@ -99,6 +100,8 @@ def tune_tir( # pylint: disable=too-many-locals
A string to specify the module equality testing and hashing method.
special_space : Optional[Mapping[str, SpaceGenerator.SpaceGeneratorType]]
A mapping from task name to a special space generator for that task.
min_design_space : int
The minimum size of the design space.
Returns
-------
@@ -154,6 +157,7 @@ def tune_tir( # pylint: disable=too-many-locals
measure_callbacks=measure_callbacks,
task_scheduler=task_scheduler,
module_equality=module_equality,
min_design_space=min_design_space,
)


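A minimal sketch of the TIR entry point with the new keyword, assuming MyModule is an IRModule containing a TIR PrimFunc; the target tag and trial count are illustrative:

from tvm import meta_schedule as ms

database = ms.tir_integration.tune_tir(
    mod=MyModule,                      # assumed IRModule with a TIR PrimFunc
    target="nvidia/geforce-rtx-3070",  # illustrative target tag
    work_dir="./tune_tir_work",
    max_trials_global=64,
    min_design_space=1,                # new parameter; 1 keeps prior behavior
)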
5 changes: 4 additions & 1 deletion python/tvm/meta_schedule/tune.py
@@ -41,6 +41,7 @@ def tune_tasks(
measure_callbacks: MeasureCallback.CallbackListType = "default",
task_scheduler: TaskScheduler.TaskSchedulerType = "gradient",
module_equality: str = "structural",
min_design_space: int = 1,
) -> Database:
"""Tune a list of tasks. Using a task scheduler.
@@ -73,14 +74,15 @@ def tune_tasks(
module_equality : Optional[str]
A string to specify the module equality testing and hashing method.
It must be one of the following:
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality
testing and hashing.
- "anchor-block": Apply equality testing and hashing on the anchor block extracted from
a given module. The "ignore-ndarray" variant is used for the extracted blocks or in
case no anchor block is found. For the definition of the anchor block, see
tir/analysis/analysis.py.
min_design_space : int
The minimum size of the design space.
Returns
-------
@@ -126,5 +128,6 @@ def tune_tasks(
measure_callbacks=measure_callbacks,
database=database,
cost_model=cost_model,
min_design_space=min_design_space,
)
return database
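tune_tasks is the shared driver beneath tune_relay and tune_tir; a hedged sketch of calling it directly, assuming the tasks and weights were extracted beforehand:

from tvm import meta_schedule as ms

# Hedged sketch: tasks and task_weights are assumed to come from task
# extraction (e.g. ms.relay_integration.extracted_tasks_to_tune_contexts).
database = ms.tune.tune_tasks(
    tasks=tasks,
    task_weights=task_weights,
    work_dir="./tune_tasks_work",
    max_trials_global=64,
    module_equality="ignore-ndarray",
    min_design_space=1,  # threaded through to TaskScheduler.tune above
)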
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -145,9 +145,13 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
kernel_layout = attrs.kernel_layout
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be a positive value")

if groups == 1:
if layout == "NCHW":
assert kernel_layout == "OIHW"
do_im2col = topi.nn.use_im2col(
data, kernel, stride_h, stride_w, dilation_h, dilation_w, padding
)
if (
(target.kind.name in ["cuda", "vulkan", "rocm"])
and data.dtype in ("int8", "uint8")
@@ -159,6 +163,14 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_int8),
name="conv2d_nchw_int8.cuda",
)
elif do_im2col:
assert data.dtype == kernel.dtype
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw_mma),
naive_schedule,
name="conv2d_nchw_mma.cuda",
plevel=15,
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw),
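The new branch registers the MMA implementation at plevel=15, which outranks the generic conv2d_nchw path (plevel defaults to 10) whenever the im2col check passes; a hedged sketch of that priority mechanism:

from tvm import topi
from tvm.relay.op import op as _op
from tvm.relay.op.strategy.generic import (
    naive_schedule,
    wrap_compute_conv2d,
    wrap_topi_schedule,
)

# Hedged sketch mirroring the diff above: when several implementations
# are applicable, Relay selects the one with the highest plevel.
strategy = _op.OpStrategy()
strategy.add_implementation(
    wrap_compute_conv2d(topi.cuda.conv2d_nchw),
    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
    name="conv2d_nchw.cuda",  # generic path, default plevel=10
)
strategy.add_implementation(
    wrap_compute_conv2d(topi.cuda.conv2d_nchw_mma),
    naive_schedule,           # scheduling is left to meta schedule
    name="conv2d_nchw_mma.cuda",
    plevel=15,                # preferred when the im2col check passes
)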
21 changes: 21 additions & 0 deletions python/tvm/target/target.py
@@ -173,11 +173,25 @@ def max_num_threads(self):
"""Returns the max_num_threads from the target if it exists."""
return int(self.attrs["max_num_threads"])

@property
def max_block_size_x(self):
"""Returns the max block size in x-dimension from the target if it exists."""
return int(self.attrs["max_block_size_x"])

@property
def max_block_size_y(self):
"""Returns the max block size in y-dimension from the target if it exists."""
return int(self.attrs["max_block_size_y"])

@property
def thread_warp_size(self):
"""Returns the thread_warp_size from the target if it exists."""
return int(self.attrs["thread_warp_size"])

@property
def max_shared_memory_per_block(self):
"""Returns the max_shared_memory_per_block from the target if it exists."""
return int(self.attrs["max_shared_memory_per_block"])

@property
def max_function_args(self):
return int(self.attrs.get("max_function_args", -1))
@@ -219,6 +233,13 @@ def supports_integer_dot_product(self):
def libs(self):
return list(self.attrs.get("libs", []))

@property
def supports_cooperative_matrix(self):
"""Returns whether the target supports cooperative matrix operations."""
return bool(self.attrs.get("supports_cooperative_matrix", False))

@property
def features(self):
return TargetFeatures(self)
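A hedged sketch of reading the new attributes from Python; whether a given attribute is present depends on how the target was constructed, and the flag spellings below are assumptions based on the property names:

from tvm.target import Target

tgt = Target("vulkan -supports_cooperative_matrix=1 -thread_warp_size=32")
print(tgt.supports_cooperative_matrix)  # True; False when the attr is absent
print(tgt.thread_warp_size)             # 32
# Note: max_block_size_x/y and max_shared_memory_per_block index
# self.attrs directly, so they raise KeyError on targets that lack them.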