diff --git a/README.md b/README.md
index 9a72dae882..3ea89ebdee 100644
--- a/README.md
+++ b/README.md
@@ -90,6 +90,7 @@ Within the following table, we summarized the current NNI capabilities, we are g
Cifar10-pytorch
Scikit-learn
EfficientNet
+ Kernel Tuning
More...
@@ -170,7 +171,7 @@ Within the following table, we summarized the current NNI capabilities, we are g
Kubeflow
FrameworkController on K8S (AKS etc.)
- - DLWorkspace (aka. DLTS)
+
diff --git a/docs/en_US/TrialExample/OpEvoExamples.md b/docs/en_US/TrialExample/OpEvoExamples.md
new file mode 100644
index 0000000000..5db331f1bf
--- /dev/null
+++ b/docs/en_US/TrialExample/OpEvoExamples.md
@@ -0,0 +1,85 @@
+# Tuning Tensor Operators on NNI
+
+## Overview
+
+A growing range of applications demands efficient training and inference of deep neural networks (DNNs) on diverse hardware platforms, from cloud servers to embedded devices. Moreover, computational graph-level optimizations of DNNs, such as tensor operator fusion, may introduce new tensor operators. Manually optimized tensor operators provided by hardware-specific libraries therefore fall short in supporting new hardware platforms or new operators, so automatically optimizing tensor operators across diverse hardware platforms is essential for deploying deep learning technologies at scale in real-world applications.
+
+Tensor operator optimization is essentially a combinatorial optimization problem. The objective function is the performance of a tensor operator on a specific hardware platform, which should be maximized with respect to the hyper-parameters of the corresponding device code, such as how to tile a matrix or whether to unroll a loop. This example illustrates how to automatically tune tensor operators with NNI. Three tuning algorithms, OpEvo, G-BFS and N-A2C, are provided. Please refer to [OpEvo: An Evolutionary Method for Tensor Operator Optimization](https://arxiv.org/abs/2006.05664) for a detailed explanation of these algorithms.
+
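+For example, tiling a loop of length 960 into two levels corresponds to a `factor`-type parameter in the search spaces used below. The snippet that follows is a minimal illustrative sketch (plain Python, not part of the example code) of what such a parameter declares and what one candidate assignment looks like; the key names are borrowed from the batched matrix multiplication search space shipped with this example.
+
+```python
+# Each 'factor' entry [product, num] denotes every way of writing `product`
+# as an ordered product of `num` positive integers, i.e. a tiling of that loop.
+search_space = {
+    "B": {"_type": "factor", "_value": [960, 2]},  # split the batch loop into 2 tiles
+    "X": {"_type": "factor", "_value": [128, 4]},  # split a spatial loop into 4 tiles
+}
+
+# One candidate a tuner may propose: 60 * 16 == 960 and 4 * 2 * 4 * 4 == 128.
+candidate = {"B": [60, 16], "X": [4, 2, 4, 4]}
+```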
+
+## Environment Setup
+
+We prepared a dockerfile for setting up the experiment environment. Before starting, please make sure the Docker daemon is running and the driver of your GPU accelerator is properly installed. Enter the example folder `examples/trials/systems/opevo` and run the command below to build a Docker image from the dockerfile and start a container.
+```bash
+# if you are using Nvidia GPU
+make cuda-env
+# if you are using AMD GPU
+make rocm-env
+```
+
+## Run Experiments
+
+Three representative kinds of tensor operators, **matrix multiplication**, **batched matrix multiplication** and **2D convolution**, are chosen from BERT and AlexNet and tuned with NNI. The `Trial` code for all tensor operators is `/root/compiler_auto_tune_stable.py`, and the `Search Space` files and `config` files for each tuning algorithm are located in `/root/experiments/`, categorized by tensor operator. Here `/root` refers to the root directory of the container.
+
+For tuning the operators of matrix multiplication, please run the commands below from `/root`:
+```bash
+# (N, K) x (K, M) represents a matrix of shape (N, K) multiplies a matrix of shape (K, M)
+
+# (512, 1024) x (1024, 1024)
+# tuning with opevo
+nnictl create --config experiments/mm/N512K1024M1024/config_opevo.yml
+# tuning with g-bfs
+nnictl create --config experiments/mm/N512K1024M1024/config_gbfs.yml
+# tuning with n-a2c
+nnictl create --config experiments/mm/N512K1024M1024/config_na2c.yml
+
+# (512, 1024) x (1024, 4096)
+# tuning with opevo
+nnictl create --config experiments/mm/N512K1024M4096/config_opevo.yml
+# tuning with g-bfs
+nnictl create --config experiments/mm/N512K1024M4096/config_gbfs.yml
+# tuning with n-a2c
+nnictl create --config experiments/mm/N512K1024M4096/config_na2c.yml
+
+# (512, 4096) x (4096, 1024)
+# tuning with opevo
+nnictl create --config experiments/mm/N512K4096M1024/config_opevo.yml
+# tuning with g-bfs
+nnictl create --config experiments/mm/N512K4096M1024/config_gbfs.yml
+# tuning with n-a2c
+nnictl create --config experiments/mm/N512K4096M1024/config_na2c.yml
+```
+
+For tuning the operators of batched matrix multiplication, please run the commands below from `/root`:
+```bash
+# a batch of 960 matrices of shape (128, 128) multiplies a batch of 960 matrices of shape (128, 64)
+nnictl create --config experiments/bmm/B960N128K128M64PNN/config_opevo.yml
+# a batch of 960 matrices of shape (128, 128) is transposed first and then multiplies a batch of 960 matrices of shape (128, 64)
+nnictl create --config experiments/bmm/B960N128K128M64PTN/config_opevo.yml
+# a batch of 960 matrices of shape (128, 64) is transposed first and then right-multiplies a batch of 960 matrices of shape (128, 64)
+nnictl create --config experiments/bmm/B960N128K64M128PNT/config_opevo.yml
+```
+
+For tuning the operators of 2D convolution, please run the commands below from `/root`:
+```bash
+# image tensor of shape (512, 3, 227, 227) convolves with kernel tensor of shape (64, 3, 11, 11) with stride 4 and padding 0
+nnictl create --config experiments/conv/N512C3HW227F64K11ST4PD0/config_opevo.yml
+# image tensor of shape (512, 64, 27, 27) convolves with kernel tensor of shape (192, 64, 5, 5) with stride 1 and padding 2
+nnictl create --config experiments/conv/N512C64HW27F192K5ST1PD2/config_opevo.yml
+```
+
+Please note that G-BFS and N-A2C cannot tune the operators of batched matrix multiplication and 2D convolution, since the search spaces of these operators contain parameter types that the two algorithms do not support.
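+
+This restriction is enforced in the tuner code itself: G-BFS and N-A2C accept only `factor`-type parameters and reject everything else while parsing the search space. Below is a minimal sketch of that check, simplified from the `Configuration` constructor in `/root/algorithms/gbfs.py` (the real code wraps each value in a `Factor` object):
+
+```python
+def check_search_space(search_space):
+    """Reject any search space that contains non-'factor' parameters."""
+    params = {}
+    for key, spec in search_space.items():
+        if spec['_type'] == 'factor':
+            params[key] = spec['_value']
+        else:
+            raise RuntimeError(
+                "G_BFS Tuner doesn't support this kind of parameter: "
+                + str(spec['_type']))
+    return params
+```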
+
+## Citing OpEvo
+
+If you use OpEvo in your research, please consider citing the paper as follows:
+```
+@misc{gao2020opevo,
+ title={OpEvo: An Evolutionary Method for Tensor Operator Optimization},
+ author={Xiaotian Gao and Cui Wei and Lintao Zhang and Mao Yang},
+ year={2020},
+ eprint={2006.05664},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG}
+}
+```
diff --git a/docs/en_US/examples.rst b/docs/en_US/examples.rst
index 77ea9733b4..c44f081997 100644
--- a/docs/en_US/examples.rst
+++ b/docs/en_US/examples.rst
@@ -11,5 +11,6 @@ Examples
EvolutionSQuAD<./TrialExample/SquadEvolutionExamples>
GBDT<./TrialExample/GbdtExample>
RocksDB <./TrialExample/RocksdbExamples>
+ OpEvo <./TrialExample/OpEvoExamples>
KDExample <./TrialExample/KDExample>
EfficientNet <./TrialExample/EfficientNet>
diff --git a/examples/trials/systems/opevo/Dockerfile b/examples/trials/systems/opevo/Dockerfile
new file mode 100644
index 0000000000..a1c589c368
--- /dev/null
+++ b/examples/trials/systems/opevo/Dockerfile
@@ -0,0 +1,42 @@
+FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04
+
+ENV PYTHONDONTWRITEBYTECODE 1
+ENV HIP_PLATFORM hcc
+ENV PATH $PATH:/opt/rocm/bin:/usr/local/nvidia/lib64/bin
+ENV TVM_HOME=/opt/tvm
+ENV PYTHONPATH=/usr/local/rocm/src:$TVM_HOME/python:$TVM_HOME/topi/python:$TVM_HOME/nnvm/python
+ENV HSA_USERPTR_FOR_PAGED_MEM=0
+
+RUN env > /etc/environment
+
+RUN apt-get update && apt-get install -y --no-install-recommends git ca-certificates \
+ python3-pip python3-wheel python3-setuptools python3-dev python3-pytest \
+ vim less netcat-openbsd inetutils-ping curl patch iproute2 \
+ g++ libpci3 libnuma-dev make cmake file openssh-server kmod gdb libopenmpi-dev openmpi-bin \
+ autoconf automake autotools-dev libtool multiarch-support \
+ && rm -rf /var/lib/apt/lists/*
+
+RUN curl -sL http://repo.radeon.com/rocm/apt/debian/rocm.gpg.key | apt-key add - && \
+ printf "deb [arch=amd64] http://repo.radeon.com/rocm/apt/3.3/ xenial main" | tee /etc/apt/sources.list.d/rocm_hip.list && \
+ apt update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+ rocm-dev zlib1g-dev unzip librdmacm-dev rocblas hipsparse rccl rocfft rocrand miopen-hip && apt-get clean && rm -rf /var/lib/apt/lists/*
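+# Replace the static cudart archive with a symlink to the shared library so that
+# builds requesting cudart_static actually link against libcudart.so.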
+RUN ln -sf libcudart.so /usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudart_static.a
+
+RUN pip3 install tornado psutil xgboost==0.80 numpy decorator attrs && rm -rf ~/.cache
+RUN git clone https://github.com/dmlc/tvm $TVM_HOME
+
+RUN cd $TVM_HOME && git checkout v0.6 && git submodule init && git submodule update && \
+ mkdir -p build && cd build && cp ../cmake/config.cmake . && \
+ sed -i 's/LLVM ON/LLVM OFF/g' config.cmake && sed -i 's/CUDA OFF/CUDA ON/g' config.cmake && \
+ cmake .. && make -j16
+
+RUN pip3 install nni==1.5 && rm -rf ~/.cache
+RUN pip3 install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html && rm -rf ~/.cache
+
+ADD tvm_patches/tvm_v0.6.patch $TVM_HOME/tvm_v0.6.patch
+ADD tvm_patches/libcuda.so.1 $TVM_HOME/build
+RUN ln -sf libcuda.so.1 $TVM_HOME/build/libcudart.so.10.0
+RUN cd $TVM_HOME && git apply tvm_v0.6.patch && cd build && make -j16
+
+ADD src /root/
+
diff --git a/examples/trials/systems/opevo/Makefile b/examples/trials/systems/opevo/Makefile
new file mode 100644
index 0000000000..ba7ea78c18
--- /dev/null
+++ b/examples/trials/systems/opevo/Makefile
@@ -0,0 +1,14 @@
+rocm-env: build
+ docker run -it --rm --privileged -v /:/host -w /root \
+ -e BACKEND=c-rocm -p 8080:8080 \
+ tvm4nni bash || true
+
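+# cuda-env mounts the host driver libraries (libcuda.so.1 and the directory
+# containing nvidia-fatbinaryloader, located via ldd) into the container.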
+cuda-env: build
+ docker run -it --rm --privileged -v /:/host -w /root \
+ -e BACKEND=c-cuda -p 8080:8080 \
+ -v /usr/lib/x86_64-linux-gnu/libcuda.so.1:/usr/lib/x86_64-linux-gnu/libcuda.so.1 \
+ -v $(shell dirname `ldd /usr/lib/x86_64-linux-gnu/libcuda.so.1 | grep nvidia-fatbinaryloader | awk '{print $$3}'`):/usr/local/nvidia/lib64 \
+ tvm4nni bash || true
+
+build:
+ docker build -t tvm4nni --network=host .
diff --git a/examples/trials/systems/opevo/src/algorithms/gbfs.py b/examples/trials/systems/opevo/src/algorithms/gbfs.py
new file mode 100644
index 0000000000..5a4a91e21b
--- /dev/null
+++ b/examples/trials/systems/opevo/src/algorithms/gbfs.py
@@ -0,0 +1,278 @@
+import math
+import random
+import logging
+import copy
+
+import nni
+from nni.tuner import Tuner
+
+
+class Factor(object):
+ """factor type parameter
+ """
+ def __init__(self, value):
+ self.product, self.num = value
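+        # start from the degenerate partition: the whole product in the first slot, ones elsewhere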
+ self.partition = [1] * self.num
+ self.partition[0] = self.product
+
+ def pick_out(self):
+ return self.partition
+
+ def step(self, action):
+ tmp = copy.deepcopy(self)
+ tmp.partition[action[0]] = int(tmp.partition[action[0]] / action[2])
+ tmp.partition[action[1]] = int(tmp.partition[action[1]] * action[2])
+
+ return tmp
+
+ def get_actions(self):
+ actions = []
+ prime_factors = self._get_prime_factors(self.product, False)
+ for i in range(self.num):
+ for j in range(self.num):
+ if i != j:
+ for k in range(len(prime_factors)):
+ action = [i]
+ action.append(j)
+ action.append(prime_factors[k])
+ if self.partition[action[0]] % action[2] == 0:
+ actions.append(action)
+ return actions
+
+ def __repr__(self):
+ string = "["
+ for factor in self.partition:
+ string += factor.__repr__() + " "
+ string = string[:-1] + "]"
+
+ return string
+
+ def _get_prime_factors(self, n, repeat=True):
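+        # prime factors of n: each prime appears once when repeat=False, otherwise with multiplicity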
+ prime_factors = []
+
+ while n % 2 == 0:
+ if 2 not in prime_factors:
+ prime_factors.append(2)
+ elif repeat:
+ prime_factors.append(2)
+            n //= 2
+
+ for i in range(3, int(math.sqrt(n)) + 1, 2):
+ while n % i == 0:
+ if i not in prime_factors:
+ prime_factors.append(i)
+ elif repeat:
+ prime_factors.append(i)
+                n //= i
+
+ if n > 2:
+ prime_factors.append(int(n))
+
+ return prime_factors
+
+
+class Configuration(object):
+ """Configuration class
+ """
+ def __init__(self, search_space):
+ self.params = {}
+ for key in search_space.keys():
+ if search_space[key]['_type'] == 'factor':
+ self.params[key] = \
+ Factor(search_space[key]['_value'])
+ else:
+ raise RuntimeError(
+ "G_BFS Tuner doesn't support this kind of parameter: "
+ + str(search_space[key]['_type'])
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.__dict__ == other.__dict__
+ else:
+ return False
+
+ def __repr__(self):
+ string = ""
+        for key, value in self.params.items():
+            string += key + ': ' + value.__repr__() + '\n'
+
+ return string
+
+ def pick_out(self):
+ output = {}
+ for key in self.params.keys():
+ output[key] = self.params[key].pick_out()
+
+ return output
+
+ def step(self, action):
+ config = copy.deepcopy(self)
+ config.params[action[0]] = config.params[action[0]].step(action[1])
+
+ return config
+
+ def get_actions(self):
+ actions = []
+ for key, value in self.params.items():
+ subactions = value.get_actions()
+ for subaction in subactions:
+ action = [key]
+ action.append(subaction)
+ actions.append(action)
+
+ return actions
+
+
+class Population(object):
+ """Population class
+ """
+
+ def __init__(self, opt_mode, search_space, num_samples):
+ self.opt_mode = opt_mode
+ self.search_space = search_space
+ self.num_samples = num_samples
+
+ self.queue = []
+ self.population = []
+ self.fitness = []
+
+ def append(self, individual, fitness):
+ if self.opt_mode == "minimize":
+ fitness = -1 * fitness
+
+ self.population.append(individual)
+ self.queue.insert(0, individual)
+ self.fitness.insert(0, fitness)
+
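+        # sink the newly inserted individual until the queue is sorted by fitness in descending order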
+ i = 0
+ while (i < len(self.fitness) - 1
+ and self.fitness[i] < self.fitness[i + 1]):
+ self.fitness[i], self.fitness[i + 1] = \
+ self.fitness[i + 1], self.fitness[i]
+ self.queue[i], self.queue[i + 1] = \
+ self.queue[i + 1], self.queue[i]
+ i += 1
+
+ def generate(self):
+ if not self.fitness and not self.population:
+ return [Configuration(self.search_space)]
+ elif not self.fitness and self.population:
+ return []
+ else:
+ self.fitness.pop(0)
+ config = self.queue.pop(0)
+
+ action_space = config.get_actions()
+ num = len(action_space)
+ if num > self.num_samples:
+ indices = random.sample(range(num), self.num_samples)
+ else:
+ indices = range(num)
+
+ res = []
+ for idx in indices:
+ tmp = config.step(action_space[idx])
+ if tmp not in self.population:
+ res.append(tmp)
+
+ return res
+
+
+class G_BFS(Tuner):
+ """G-BFS Tuner
+ Based on paper Compiler-Level Matrix Multiplication Optimization for Deep Learning
+
+ Parameters
+ ----------
+ optimize_mode: str, 'maximize' or 'minimize'
+ num_samples: int,
+ The random selection parameter rho
+ """
+ def __init__(self, optimize_mode="maximize", num_samples=5):
+ self.logger = logging.getLogger(
+ self.__module__ + "." + self.__class__.__name__)
+ self.logger.setLevel('DEBUG')
+
+ self.opt_mode = optimize_mode
+ self.num_samples = num_samples
+
+ self.request_list = []
+ self.serve_list = []
+ self.wait_dict = {}
+
+ def update_search_space(self, search_space):
+        """Set up the population according to the received search space.
+
+ Override of the abstract method in :class:`~nni.tuner.Tuner`.
+ """
+ if not isinstance(search_space, dict):
+ self.logger.info("The format of search space is not a dict.")
+ raise RuntimeError("The format of search space is not a dict.")
+
+ self.population = \
+ Population(self.opt_mode, search_space, self.num_samples)
+
+ if not self.serve_list:
+ self.serve_list = self.population.generate()
+
+ def generate_multiple_parameters(self, parameter_id_list, **kwargs):
+ """Returns multiple sets of trial (hyper-)parameters,
+ as iterable of serializable objects.
+ """
+ result = []
+ self.send_trial_callback = kwargs['st_callback']
+ for parameter_id in parameter_id_list:
+ had_exception = False
+ try:
+ self.logger.debug("generating param for %s", parameter_id)
+ res = self.generate_parameters(parameter_id, **kwargs)
+ except nni.NoMoreTrialError:
+ had_exception = True
+ if not had_exception:
+ result.append(res)
+ return result
+
+ def generate_parameters(self, parameter_id, **kwargs):
+ """Method which provides one set of hyper-parameters.
+
+ Override of the abstract method in :class:`~nni.tuner.Tuner`.
+ """
+ if self.serve_list:
+ self.wait_dict[parameter_id] = self.serve_list.pop()
+ return self.wait_dict[parameter_id].pick_out()
+ else:
+ self.request_list.append(parameter_id)
+ raise nni.NoMoreTrialError('no more parameters now.')
+
+ def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
+ """Method invoked when a trial reports its final result.
+
+ Override of the abstract method in :class:`~nni.tuner.Tuner`.
+ """
+ if isinstance(value, dict):
+ value = value['default']
+
+ self.population.append(self.wait_dict[parameter_id], value)
+ del self.wait_dict[parameter_id]
+
+ if not self.serve_list and not self.wait_dict:
+ self.serve_list = self.population.generate()
+ if not self.serve_list:
+ raise RuntimeError("Tuner stopped since no candidates")
+
+ while self.request_list and self.serve_list:
+ param_id = self.request_list[0]
+ self.wait_dict[param_id] = self.serve_list.pop()
+ self.send_trial_callback(
+ param_id, self.wait_dict[param_id].pick_out())
+ self.request_list.pop(0)
+
+ def trial_end(self, parameter_id, success, **kwargs):
+ """Method invoked when a trial is completed or terminated.
+
+ Override of the abstract method in :class:`~nni.tuner.Tuner`.
+ """
+ if not success:
+ self.population.append(self.wait_dict[parameter_id], 0.0)
+ del self.wait_dict[parameter_id]
diff --git a/examples/trials/systems/opevo/src/algorithms/na2c.py b/examples/trials/systems/opevo/src/algorithms/na2c.py
new file mode 100644
index 0000000000..7c51a36e40
--- /dev/null
+++ b/examples/trials/systems/opevo/src/algorithms/na2c.py
@@ -0,0 +1,401 @@
+import math
+import random
+import logging
+import copy
+
+import torch
+from torch import optim
+from torch import nn
+import torch.nn.functional as F
+import numpy as np
+
+import nni
+from nni.tuner import Tuner
+
+
+class Factor(object):
+ """factor type parameter
+ """
+ def __init__(self, value):
+ self.product, self.num = value
+ self.partition = [1] * self.num
+ self.partition[0] = self.product
+
+ def pick_out(self):
+ return self.partition
+
+ def step(self, action):
+ if self.partition[action[0]] % action[2] == 0:
+ self.partition[action[0]] /= action[2]
+ self.partition[action[1]] *= action[2]
+ status = True
+ else:
+ status = False
+
+ return status
+
+ def get_actions(self):
+ actions = []
+ prime_factors = self._get_prime_factors(self.product, False)
+ for i in range(self.num):
+ for j in range(self.num):
+ if i != j:
+ for k in range(len(prime_factors)):
+ action = [i]
+ action.append(j)
+ action.append(prime_factors[k])
+ actions.append(action)
+
+ return actions
+
+ def __repr__(self):
+ return self.partition.__repr__()
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.__dict__ == other.__dict__
+ else:
+ return False
+
+ def _get_prime_factors(self, n, repeat=True):
+ prime_factors = []
+
+ while n % 2 == 0:
+ if 2 not in prime_factors:
+ prime_factors.append(2)
+ elif repeat:
+ prime_factors.append(2)
+            n //= 2
+
+ for i in range(3, int(math.sqrt(n)) + 1, 2):
+ while n % i == 0:
+ if i not in prime_factors:
+ prime_factors.append(i)
+ elif repeat:
+ prime_factors.append(i)
+                n //= i
+
+ if n > 2:
+ prime_factors.append(int(n))
+
+ return prime_factors
+
+
+class Configuration(object):
+ """Configuration class
+ """
+ def __init__(self, search_space):
+ self.params = {}
+ self.key_order = []
+ for key in search_space.keys():
+ if search_space[key]['_type'] == 'factor':
+ self.key_order.append(key)
+ self.params[key] = \
+ Factor(search_space[key]['_value'])
+ else:
+ raise RuntimeError(
+ "N_A2C Tuner doesn't support this kind of parameter: "
+ + str(search_space[key]['_type'])
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.__dict__ == other.__dict__
+ else:
+ return False
+
+ def __repr__(self):
+ string = ""
+ for key, value in self.params.items():
+ string += key + ': ' + value.__repr__() + ' '
+
+ return string
+
+ def pick_out(self):
+ output = {}
+ for key in self.params.keys():
+ output[key] = self.params[key].pick_out()
+
+ return output
+
+ def step(self, action):
+ config = copy.deepcopy(self)
+ status = config.params[action[0]].step(action[1])
+
+ return status, config
+
+ def get_actions(self):
+ actions = []
+ for key, value in self.params.items():
+ subactions = value.get_actions()
+ for subaction in subactions:
+ action = [key]
+ action.append(subaction)
+ actions.append(action)
+
+ return actions
+
+ def to_torch(self):
+ states = []
+ for key in self.key_order:
+ state = torch.tensor(self.params[key].partition).float() / \
+ self.params[key].product - 0.5
+ states.append(state)
+
+ return torch.cat(states).float()
+
+
+class ActorCritic(nn.Module):
+ def __init__(self, num_states, num_actions, hidden_size):
+ super(ActorCritic, self).__init__()
+
+ self.num_actions = num_actions
+ self.fc = nn.Linear(num_states, hidden_size)
+ self.critic_linear2 = nn.Linear(hidden_size, 1)
+ self.actor_linear2 = nn.Linear(hidden_size, num_actions)
+
+ def forward(self, state):
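+        # a shared hidden layer feeds two heads: a scalar state value (critic)
+        # and a softmax distribution over actions (actor)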
+ x = F.relu(self.fc(state))
+ value = self.critic_linear2(x)
+        policy_dist = F.softmax(self.actor_linear2(x), dim=-1)
+
+ return value, policy_dist
+
+
+class Population(object):
+ """Population class
+ """
+ def __init__(self, search_space, opt_mode, n_states, n_steps,
+ hidden_size, lr):
+ self.search_space = search_space
+ self.opt_mode = opt_mode
+ self.n_states = n_states
+ self.n_steps = n_steps
+ self.hidden_size = hidden_size
+ self.lr = lr
+
+ self.config = Configuration(search_space)
+ self.max_reward = 0.0
+
+ self.action_space = self.config.get_actions()
+ self.dim_actions = len(self.action_space)
+ self.dim_states = len(self.config.to_torch())
+ self.log_probs = []
+ self.values = []
+ self.rewards = []
+
+ self.population = []
+
+ self.actor_critic = ActorCritic(
+ self.dim_states, self.dim_actions, self.hidden_size
+ )
+ self.ac_optimizer = optim.Adam(
+ self.actor_critic.parameters(), lr=self.lr
+ )
+
+ def append(self, individual, fitness):
+ if self.opt_mode == "minimize":
+ fitness = -1 * fitness
+
+ self.population.append(individual)
+
+ if self.max_reward < fitness:
+ self.max_reward = fitness
+ self.config = individual
+
+        if individual in self.collect:
+            # index into the batch captured in generate() so the reward stays
+            # aligned with the log_probs and values recorded for that candidate
+            idx = self.batch.index(individual)
+            self.waiting_rewards[idx] = fitness
+            self.collect.remove(individual)
+ else:
+ raise RuntimeError("Received unexpected trials.")
+
+ if not self.collect:
+ self.rewards.extend(self.waiting_rewards)
+
+ self.ac_optimizer.zero_grad()
+ gradient_loss = 0
+ value_loss = 0
+ for i in range(len(self.values)):
+ advantage = self.rewards[i] - self.values[i]
+ gradient_loss += self.log_probs[i] * advantage
+ value_loss += torch.pow(advantage, 2)
+ loss = gradient_loss + value_loss
+ loss.backward()
+ self.ac_optimizer.step()
+
+ self.rewards = []
+ self.values = []
+ self.log_probs = []
+ self.collect = []
+
+ def generate(self):
+ self.collect = []
+ while len(self.collect) < self.n_states:
+ config = self.config
+ for i in range(self.n_steps):
+ value, policy_dist = self.actor_critic(config.to_torch())
+ dist = policy_dist.detach().numpy()
+
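+                # epsilon-greedy: with probability 0.1 take a uniformly random action instead of sampling the policy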
+ if random.uniform(0, 1) < 0.1:
+ action = random.choice(range(self.dim_actions))
+ else:
+ action = np.random.choice(
+ self.dim_actions, p=np.squeeze(dist))
+
+ log_prob = torch.log(policy_dist.squeeze(0)[action])
+ # entropy = -np.sum(np.mean(dist) * np.log(dist))
+ flag, new_config = config.step(self.action_space[action])
+
+ if (flag and new_config not in self.population
+ and new_config not in self.collect):
+ self.collect.append(new_config)
+ self.log_probs.append(log_prob)
+ self.values.append(value)
+
+ config = new_config
+ # print([math.exp(float(i)) for i in self.log_probs])
+
+        self.batch = list(self.collect)
+        self.waiting_rewards = [0.0] * len(self.collect)
+        return copy.deepcopy(self.collect)
+
+
+class N_A2C(Tuner):
+ """N-A2C Tuner
+ Based on paper Compiler-Level Matrix Multiplication Optimization for Deep Learning
+
+ Parameters
+ ----------
+ optimize_mode: str, 'maximize' or 'minimize'
+ n_states: int,
+ The maximum search steps Tau
+    n_steps: int,
+        number of steps used to train the policy and critic networks in each iteration
+ hidden_size: int,
+        hidden size of the policy and critic networks
+ lr: float,
+ learning rate of the policy and critic networks
+ """
+
+ def __init__(self,
+ optimize_mode="maximize",
+ n_states=6,
+ n_steps=3,
+ hidden_size=128,
+ lr=1e-3):
+ self.logger = logging.getLogger(
+ self.__module__ + "." + self.__class__.__name__)
+ self.logger.setLevel('DEBUG')
+
+ self.opt_mode = optimize_mode
+ self.n_states = n_states
+ self.n_steps = n_steps
+        self.hidden_size = hidden_size
+ self.lr = lr
+
+ self.request_list = []
+ self.serve_list = []
+ self.wait_dict = {}
+
+ def update_search_space(self, search_space):
+        """Set up the population according to the received search space.
+
+ Override of the abstract method in :class:`~nni.tuner.Tuner`.
+ """
+ if not isinstance(search_space, dict):
+ self.logger.info("The format of search space is not a dict.")
+ raise RuntimeError("The format of search space is not a dict.")
+
+ self.population = \
+ Population(
+ search_space,
+ self.opt_mode,
+ self.n_states,
+ self.n_steps,
+ self.hidden_size,
+ self.lr
+ )
+
+ if not self.serve_list:
+ self.serve_list = self.population.generate()
+
+ def generate_multiple_parameters(self, parameter_id_list, **kwargs):
+ """Returns multiple sets of trial (hyper-)parameters,
+ as iterable of serializable objects.
+ """
+ result = []
+ self.send_trial_callback = kwargs['st_callback']
+ for parameter_id in parameter_id_list:
+ had_exception = False
+ try:
+ self.logger.debug("generating param for %s", parameter_id)
+ res = self.generate_parameters(parameter_id, **kwargs)
+ except nni.NoMoreTrialError:
+ had_exception = True
+ if not had_exception:
+ result.append(res)
+ return result
+
+ def generate_parameters(self, parameter_id, **kwargs):
+ """Method which provides one set of hyper-parameters.
+
+ Override of the abstract method in :class:`~nni.tuner.Tuner`.
+ """
+ if self.serve_list:
+ self.wait_dict[parameter_id] = self.serve_list.pop()
+ return self.wait_dict[parameter_id].pick_out()
+ else:
+ self.request_list.append(parameter_id)
+ raise nni.NoMoreTrialError('no more parameters now.')
+
+ def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
+ """Method invoked when a trial reports its final result.
+
+ Override of the abstract method in :class:`~nni.tuner.Tuner`.
+ """
+ if isinstance(value, dict):
+ value = value['default']
+
+ self.population.append(self.wait_dict[parameter_id], value)
+ del self.wait_dict[parameter_id]
+
+ if not self.serve_list and not self.wait_dict:
+ self.serve_list = self.population.generate()
+ if not self.serve_list:
+ raise RuntimeError("Tuner stopped since no candidates")
+
+ while self.request_list and self.serve_list:
+ param_id = self.request_list[0]
+ self.wait_dict[param_id] = self.serve_list.pop()
+ self.send_trial_callback(
+ param_id, self.wait_dict[param_id].pick_out())
+ self.request_list.pop(0)
+
+ # print('request_list: ' + str(len(self.request_list)))
+ # print('serve_list: ' + str(len(self.serve_list)))
+ # print('wait_dict: ' + str(len(self.wait_dict.keys())))
+
+ def trial_end(self, parameter_id, success, **kwargs):
+ """Method invoked when a trial is completed or terminated.
+
+ Override of the abstract method in :class:`~nni.tuner.Tuner`.
+ """
+ if not success:
+ self.population.append(self.wait_dict[parameter_id], 0.0)
+ del self.wait_dict[parameter_id]
+
+ if not self.serve_list and not self.wait_dict:
+ self.serve_list = self.population.generate()
+ if not self.serve_list:
+ raise RuntimeError("Tuner stopped since no candidates")
+
+ while self.request_list and self.serve_list:
+ param_id = self.request_list[0]
+ self.wait_dict[param_id] = self.serve_list.pop()
+ self.send_trial_callback(
+ param_id, self.wait_dict[param_id].pick_out())
+ self.request_list.pop(0)
+
+ # print('trial_end request_list: ' + str(len(self.request_list)))
+ # print('trial_end serve_list: ' + str(len(self.serve_list)))
+ # print('trial_end wait_dict: ' + str(len(self.wait_dict.keys())))
diff --git a/examples/trials/systems/opevo/src/algorithms/opevo.py b/examples/trials/systems/opevo/src/algorithms/opevo.py
new file mode 100644
index 0000000000..31346ab873
--- /dev/null
+++ b/examples/trials/systems/opevo/src/algorithms/opevo.py
@@ -0,0 +1,455 @@
+import math
+import logging
+import copy
+import random
+import numpy as np
+from itertools import permutations, combinations
+
+import nni
+from nni.tuner import Tuner
+
+
+class Parameter(object):
+ """Base class for all types of parameters
+ """
+ def mutate(self):
+ raise NotImplementedError
+
+ def reset(self):
+ raise NotImplementedError
+
+ def pick_out(self):
+ raise NotImplementedError
+
+ def get_cardinality(self):
+ raise NotImplementedError
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.__dict__ == other.__dict__
+ else:
+ return False
+
+
+class Choice(Parameter):
+ """choice type parameter
+ """
+ def __init__(self, choices, mutate_rate):
+ self.choices = choices
+ self.value = random.choice(self.choices)
+ self.mutate_rate = mutate_rate
+
+ def get_cardinality(self):
+ return len(self.choices)
+
+ def reset(self):
+ self.value = random.choice(self.choices)
+
+ def mutate(self):
+ child = copy.deepcopy(self)
+ while random.uniform(0, 1) < child.mutate_rate:
+ choices = copy.deepcopy(child.choices)
+ choices.remove(child.value)
+ if choices:
+ child.value = random.choice(choices)
+ else:
+ break
+
+ return child
+
+ def pick_out(self):
+ return self.value
+
+
+class Discrete(Parameter):
+ """choice type parameter
+ """
+ def __init__(self, numbers, mutate_rate):
+ numbers.sort()
+ self.numbers = numbers
+ self.value = random.choice(self.numbers)
+ self.mutate_rate = mutate_rate
+
+ def get_cardinality(self):
+ return len(self.numbers)
+
+ def reset(self):
+ self.value = random.choice(self.numbers)
+
+ def mutate(self):
+ child = copy.deepcopy(self)
+ while random.uniform(0, 1) < child.mutate_rate:
+ idx = child.numbers.index(child.value)
+ if idx == 0 and idx + 1 < len(child.numbers):
+ child.value = child.numbers[idx + 1]
+ elif idx + 1 == len(child.numbers) and idx - 1 >= 0:
+ child.value = child.numbers[idx - 1]
+ elif idx == 0 and idx + 1 == len(child.numbers):
+ break
+ else:
+ shift = random.choice([-1, 1])
+ child.value = child.numbers[idx + shift]
+
+ return child
+
+ def pick_out(self):
+ return self.value
+
+
+class Factor(Parameter):
+ """factor type parameter
+ """
+ def __init__(self, value, mutate_rate):
+ self.product, self.num = value
+ self.mutate_rate = mutate_rate
+ self.all_partitions = self._get_all_partitions(self.product, self.num)
+ self.partition = random.choice(self.all_partitions)
+
+ def reset(self):
+ self.partition = random.choice(self.all_partitions)
+
+ def get_cardinality(self):
+ return len(self.all_partitions)
+
+ def mutate(self):
+ child = copy.deepcopy(self)
+ while random.uniform(0, 1) < self.mutate_rate:
+ action = random.choice(child._get_actions())
+ child._step(action)
+
+ return child
+
+ def pick_out(self):
+ return self.partition
+
+ def _step(self, action):
+ self.partition[action[0]] = int(self.partition[action[0]] / action[2])
+ self.partition[action[1]] = int(self.partition[action[1]] * action[2])
+
+ def _get_actions(self):
+ actions = []
+ prime_factors = self._get_prime_factors(self.product, False)
+ for i in range(self.num):
+ for j in range(self.num):
+ if i != j:
+ for k in range(len(prime_factors)):
+ action = [i]
+ action.append(j)
+ action.append(prime_factors[k])
+ if self.partition[action[0]] % action[2] == 0:
+ actions.append(action)
+ return actions
+
+ def __repr__(self):
+ string = "["
+ for factor in self.partition:
+ string += factor.__repr__() + " "
+ string = string[:-1] + "]"
+
+ return string
+
+ def _get_all_partitions(self, product, num):
+ # get all prime factors with repetition
+ prime_factors = self._get_prime_factors(product)
+
+ # group all prime factors
+ groups = {}
+ for prime_factor in prime_factors:
+ if prime_factor in groups.keys():
+ groups[prime_factor] += 1
+ else:
+ groups[prime_factor] = 1
+
+ # partition each group
+ for key, value in groups.items():
+ partitions = []
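+            # stars and bars: each choice of num-1 separator positions among
+            # value+num-1 slots splits this prime's exponent into num ordered parts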
+ for comb in combinations(range(value + num - 1), num - 1):
+ # print(comb)
+ partition = []
+ start_idx = -1
+ for idx in comb:
+ partition.append(key**(idx - start_idx - 1))
+ start_idx = idx
+ partition.append(key**(value + num - 2 - start_idx))
+ partitions.append(partition)
+ groups[key] = partitions
+
+ # generate partitions
+ partitions = []
+
+ def part(groups, mul=[]):
+ if not groups:
+ partition = [1] * num
+ for i in range(num):
+ for m in mul:
+ partition[i] *= m[i]
+ partitions.append(partition)
+
+ for key, group in groups.items():
+ for partition in group:
+ mul.append(partition)
+ tmp = copy.deepcopy(groups)
+ del tmp[key]
+ part(tmp, mul)
+ mul.pop()
+ break
+
+ part(groups)
+ return partitions
+
+ def _get_prime_factors(self, n, repeat=True):
+ prime_factors = []
+
+ while n % 2 == 0:
+ if 2 not in prime_factors:
+ prime_factors.append(2)
+ elif repeat:
+ prime_factors.append(2)
+            n //= 2
+
+ for i in range(3, int(math.sqrt(n)) + 1, 2):
+ while n % i == 0:
+ if i not in prime_factors:
+ prime_factors.append(i)
+ elif repeat:
+ prime_factors.append(i)
+                n //= i
+
+ if n > 2:
+ prime_factors.append(int(n))
+
+ return prime_factors
+
+
+class Individual(object):
+ """Individual class
+ """
+ def __init__(self, search_space, mutate_rate):
+ self.params = {}
+ for key in search_space.keys():
+ if search_space[key]['_type'] == 'choice':
+ self.params[key] = \
+ Choice(search_space[key]['_value'], mutate_rate)
+ elif search_space[key]['_type'] == 'discrete':
+ self.params[key] = \
+ Discrete(search_space[key]['_value'], mutate_rate)
+ elif search_space[key]['_type'] == 'factor':
+ self.params[key] = \
+ Factor(search_space[key]['_value'], mutate_rate)
+ else:
+ raise RuntimeError(
+ "OpEvo Tuner doesn't support this kind of parameter: "
+ + str(search_space[key]['_type'])
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.__dict__ == other.__dict__
+ else:
+ return False
+
+ def __repr__(self):
+ string = ""
+        for key, value in self.params.items():
+            string += key + ': ' + value.__repr__() + '\n'
+
+ return string
+
+ def mutate(self):
+ child = copy.deepcopy(self)
+ for key in child.params.keys():
+ child.params[key] = child.params[key].mutate()
+
+ return child
+
+ def reset(self):
+ for key in self.params.keys():
+ self.params[key].reset()
+
+ return self
+
+ def pick_out(self):
+ output = {}
+ for key in self.params.keys():
+ output[key] = self.params[key].pick_out()
+
+ return output
+
+
+class Population(object):
+ """Population class
+ """
+
+ def __init__(self, search_space, mutate_rate, opt_mode='maximize'):
+ self.search_space = search_space
+ self.mutate_rate = mutate_rate
+ self.opt_mode = opt_mode
+ self.population = []
+ self.fitness = []
+
+ self.individual = Individual(self.search_space, self.mutate_rate)
+ self.volume = 1
+ for key, value in self.individual.params.items():
+ self.volume *= self.individual.params[key].get_cardinality()
+
+ def append(self, individual, fitness):
+ if self.opt_mode == "minimize":
+ fitness = -1 * fitness
+
+ self.population.insert(0, individual)
+ self.fitness.insert(0, fitness)
+
+ i = 0
+ while (i < len(self.fitness) - 1
+ and self.fitness[i] < self.fitness[i + 1]):
+ self.fitness[i], self.fitness[i + 1] = \
+ self.fitness[i + 1], self.fitness[i]
+ self.population[i], self.population[i + 1] = \
+ self.population[i + 1], self.population[i]
+ i += 1
+
+ def get_offspring(self, parents_size, offspring_size):
+ children = []
+ if len(self.fitness) < parents_size:
+ for _ in range(offspring_size):
+ child = copy.deepcopy(self.individual.reset())
+ while child in self.population or child in children:
+ child = child.mutate()
+ children.append(child)
+ elif self.fitness[0] < 1e-3:
+ for _ in range(offspring_size):
+ child = copy.deepcopy(self.individual.reset())
+ while child in self.population or child in children:
+ child = child.mutate()
+ children.append(child)
+ else:
+ prob = np.array(self.fitness[:parents_size]) / \
+ np.sum(self.fitness[:parents_size])
+
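+            # fitness-proportional (roulette-wheel) selection: each parameter of the
+            # child is inherited from one of the top parents_size parents, then mutated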
+ for _ in range(offspring_size):
+ child = copy.deepcopy(self.population[0])
+ for key in child.params.keys():
+ idx = np.random.choice(range(parents_size), p=prob)
+ child.params[key] = self.population[idx].params[key]
+ child = child.mutate()
+ while child in self.population or child in children:
+ child = child.mutate()
+ children.append(child)
+
+ return children
+
+
+class OpEvo(Tuner):
+ """OpEvo Tuner
+
+ Parameters
+ ----------
+ optimize_mode: str, 'maximize' or 'minimize'
+ parents_size: int
+ offspring_size: int
+        parents_size and offspring_size govern the diversity of the evolutionary
+        process. OpEvo with a large parents_size and offspring_size tends to escape
+        suboptima but sacrifices data efficiency, while one with a smaller
+        parents_size and offspring_size converges faster but is more prone to suboptima.
+    mutate_rate: float, (0, 1)
+        Mutation rate q, ranging from 0 to 1, which trades off exploration against
+        exploitation: OpEvo leans toward exploitation as q approaches 0, and toward
+        exploration as q approaches 1.
+ """
+
+ def __init__(self,
+ optimize_mode="maximize",
+ parents_size=20,
+ offspring_size=20,
+ mutate_rate=0.5):
+ self.logger = logging.getLogger(
+ self.__module__ + "." + self.__class__.__name__)
+ self.logger.setLevel('DEBUG')
+
+ self.optimize_mode = optimize_mode
+ self.parents_size = parents_size
+ self.offspring_size = offspring_size
+ self.mutate_rate = mutate_rate
+
+ self.request_list = []
+ self.serve_list = []
+ self.wait_dict = {}
+
+ def update_search_space(self, search_space):
+        """Set up the population according to the received search space.
+
+ Override of the abstract method in :class:`~nni.tuner.Tuner`.
+ """
+ if not isinstance(search_space, dict):
+ self.logger.info("The format of search space is not a dict.")
+ raise RuntimeError("The format of search space is not a dict.")
+
+ self.population = Population(search_space,
+ self.mutate_rate,
+ self.optimize_mode)
+ self.logger.debug('Total search space volume: '
+ + str(self.population.volume))
+
+ if not self.serve_list:
+ self.serve_list = self.population.get_offspring(
+ self.parents_size, self.offspring_size)
+
+ def generate_multiple_parameters(self, parameter_id_list, **kwargs):
+ """Returns multiple sets of trial (hyper-)parameters,
+ as iterable of serializable objects.
+ """
+ result = []
+ self.send_trial_callback = kwargs['st_callback']
+ for parameter_id in parameter_id_list:
+ had_exception = False
+ try:
+ self.logger.debug("generating param for %s", parameter_id)
+ res = self.generate_parameters(parameter_id, **kwargs)
+ except nni.NoMoreTrialError:
+ had_exception = True
+ if not had_exception:
+ result.append(res)
+ return result
+
+ def generate_parameters(self, parameter_id, **kwargs):
+ """Method which provides one set of hyper-parameters.
+
+ Override of the abstract method in :class:`~nni.tuner.Tuner`.
+ """
+ if self.serve_list:
+ self.wait_dict[parameter_id] = self.serve_list.pop()
+ return self.wait_dict[parameter_id].pick_out()
+ else:
+ self.request_list.append(parameter_id)
+ raise nni.NoMoreTrialError('no more parameters now.')
+
+ def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
+ """Method invoked when a trial reports its final result.
+
+ Override of the abstract method in :class:`~nni.tuner.Tuner`.
+ """
+ if isinstance(value, dict):
+ value = value['default']
+
+ self.population.append(self.wait_dict[parameter_id], value)
+ del self.wait_dict[parameter_id]
+
+ if not self.serve_list:
+ self.serve_list = self.population.get_offspring(
+ self.parents_size, self.offspring_size)
+
+ while self.request_list and self.serve_list:
+ param_id = self.request_list[0]
+ self.wait_dict[param_id] = self.serve_list.pop()
+ self.send_trial_callback(
+ param_id, self.wait_dict[param_id].pick_out())
+ self.request_list.pop(0)
+
+ def trial_end(self, parameter_id, success, **kwargs):
+ """Method invoked when a trial is completed or terminated.
+
+ Override of the abstract method in :class:`~nni.tuner.Tuner`.
+ """
+ if not success:
+ self.population.append(self.wait_dict[parameter_id], 0.0)
+ del self.wait_dict[parameter_id]
diff --git a/examples/trials/systems/opevo/src/compiler_auto_tune_stable.py b/examples/trials/systems/opevo/src/compiler_auto_tune_stable.py
new file mode 100644
index 0000000000..68a235c60c
--- /dev/null
+++ b/examples/trials/systems/opevo/src/compiler_auto_tune_stable.py
@@ -0,0 +1,399 @@
+#!/usr/bin/env python3
+
+## TODO: optimize c-mcpu metric; early-stop handler; fp16/int8; Kill pyRPC;
+
+import numpy as np
+import tvm
+import logging
+import math
+import re
+import sys, time, subprocess, os, random, hashlib
+from tvm import autotvm
+import topi
+import json
+from topi.util import get_const_tuple
+import importlib
+from tvm.autotvm.task.dispatcher import ApplyConfig
+from tvm.autotvm.task import ConfigEntity
+from threading import Timer
+
+backend = os.environ.get('BACKEND', 'c-cuda')
+
+def system_lock(key_ids):
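+    # a simple cross-process device lock: bind a localhost TCP port (9050 + key_id)
+    # and hold it until the returned unlock_fd() is called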
+ import socket, time
+ occupied_sock = None
+ while not occupied_sock:
+ for key_id in key_ids:
+ try:
+ sock = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind(('127.0.0.1', 9050 + key_id))
+ sock.listen(1)
+ occupied_sock = (sock, key_id)
+ break
+ except:
+ try:
+ sock.shutdown(socket.SHUT_RDWR)
+ sock.close()
+ except:
+ sock.close()
+ if occupied_sock:
+ break
+ # print('still waiting ..')
+ time.sleep(0.2)
+
+ # print('Using key_id = %d' % occupied_sock[1])
+ sock = occupied_sock[0]
+
+ def unlock_fd():
+ try:
+ sock.shutdown(socket.SHUT_RDWR)
+ sock.close()
+ except:
+ sock.close()
+ return unlock_fd, occupied_sock[1]
+
+def show_search_space(config_space, printable):
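+    # translate TVM's autotvm config space into an NNI search space:
+    # SplitSpace knobs become 'factor' parameters, OtherOptionSpace knobs become 'choice'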
+ search_space = {}
+ for _, name in enumerate(config_space.space_map):
+ curr = config_space.space_map[name]
+ if (curr.__class__ == tvm.autotvm.task.space.SplitSpace):
+ search_space[name] = {"_type": "factor", "_value": [curr.product, curr.num_output]}
+ elif (curr.__class__ == tvm.autotvm.task.space.OtherOptionSpace):
+ search_space[name] = {"_type": "choice", "_value": [x.val for x in curr.entities]}
+ else:
+ raise Exception("Cannot recognize search space type: %s" % (config_space.space_map[name].__class__))
+ json_space = json.dumps(search_space)
+ print("\n>> Search Space = %s" % json_space)
+ if printable:
+ print("\n>> Writing Search Space to './search_space.json'..")
+ with open("search_space.json", "w") as fp:
+ fp.write(json_space)
+ print("\n>> Done")
+ sys.exit(0)
+
+def get_tuning_parallelism():
+ if 'DEV_NUM' in os.environ:
+ dev_num = int(os.environ['DEV_NUM'])
+ else:
+ if backend in ['c-rocm', '#rocm']:
+ devices = subprocess.getoutput('/opt/rocm/bin/rocm_agent_enumerator | grep -v gfx000').split()
+ if not devices:
+                raise Exception("No valid rocm device found.")
+ dev_num = len(devices)
+ elif backend in ['c-cuda', '#cuda']:
+ devices = subprocess.getoutput('ls /dev/nvidia[0-9]* 2>/dev/null').split()
+ if not devices:
+                raise Exception("No valid cuda device found.")
+ dev_num = len(devices)
+ else:
+ raise Exception("Unrecognized backend: %s" % backend)
+    print(' >> Tuning parallelism = %d' % dev_num)
+ return dev_num
+
+def local_get_dir_file(rel_file, dir_sid=None):
+ if not dir_sid:
+ dir_sid = os.environ['DIR_SID'] if 'DIR_SID' in os.environ else '_'
+ dir_space = '/tmp/tvm_autotvm_engine'
+ os.system('mkdir -p "%s/%s"' % (dir_space, dir_sid))
+ return "%s/%s/%s" % (dir_space, dir_sid, rel_file)
+
+def run_process_with_timeout(args, timeout=None, envs=None):
+ try:
+ proc = subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, env=envs)
+ retcode = proc.wait(timeout=timeout)
+ return retcode == 0
+ except subprocess.TimeoutExpired:
+ print('Timed out - killing', proc.pid)
+ proc.kill()
+ return False
+
+def parse_launch_bounds(code):
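+    # parse the '// [thread_extent]' annotations emitted by TVM and prepend
+    # __launch_bounds__(threads_per_block) to every __global__ kernel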
+ func_arr = code.split('extern "C" __global__ ')
+ for i in range(1, len(func_arr)):
+ axis_map = dict()
+ lines = func_arr[i].split('\n')
+ for it in lines:
+ if it.startswith(' // [thread_extent] '):
+ words = it.split(' ')
+ nthread = int(words[-1])
+ axis = words[-3]
+ if axis in axis_map:
+ if axis_map[axis] != nthread:
+ assert(False)
+ else:
+ axis_map[axis] = nthread
+ block_bound = axis_map.get('threadIdx.x', 1) * axis_map.get('threadIdx.y', 1) * axis_map.get('threadIdx.z', 1)
+ func_arr[i] = 'extern "C" __global__ __launch_bounds__(%d) %s' % (block_bound, func_arr[i])
+
+ code = ''.join(func_arr)
+ return code
+
+def translate_code(code):
+ if backend == 'c-rocm':
+ code = parse_launch_bounds(code)
+        code = '#include <hip/hip_runtime.h>\n#include <hip/hip_fp16.h>\n\n' + code.replace('(__shared__ float4*)', '(float4*)').replace('#include <cuda_fp16.h>', '').replace('typedef unsigned long long uint64_t;', '')
+ elif backend in ['#cuda', 'c-cuda']:
+ code = parse_launch_bounds(code)
+        code = '#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n\n' + code
+ else:
+ raise Exception("Unrecognized backend: %s" % backend)
+ return code
+
+@tvm.register_func
+def tvm_callback_backend_proc(code):
+ native_code = translate_code(code)
+ # Compile code
+ module_data = None
+ if backend == 'c-rocm':
+ gcn_arch = subprocess.getoutput('/opt/rocm/bin/rocm_agent_enumerator | sort | uniq | grep -v gfx000 | tail -n 1').strip()
+ if not gcn_arch:
+ raise RuntimeError("Compilation error: no valid gcn_arch gpu detected!")
+ temp_code = local_get_dir_file("my_kernel.cc")
+ temp_cobj = local_get_dir_file("my_kernel.hsaco")
+ args = ['/opt/rocm/bin/lpl', temp_code, '-t=' + gcn_arch, '-f="-Wno-ignored-attributes -D__HIP_PLATFORM_HCC__=1"', '-o', temp_cobj]
+ elif backend in ['#cuda', 'c-cuda']:
+ temp_code = local_get_dir_file("my_kernel.cu")
+ temp_cobj = local_get_dir_file("my_kernel.ptx")
+ args = ['/usr/local/cuda/bin/nvcc', temp_code, '--ptx', '-O3', '-o', temp_cobj]
+ else:
+ raise Exception("Unrecognized backend: %s" % backend)
+ with open(temp_code, 'w') as fp:
+ fp.write(native_code)
+ print('[Build @%x]' % os.getpid(), ' '.join(args))
+ if not run_process_with_timeout(args, 10):
+ raise Exception("Compilation failed or time limit exceeded")
+ if module_data is None:
+ module_data = bytearray(open(temp_cobj, "rb").read())
+ return module_data
+
+def run_config_entity(params_given, dir_sid, expected_timecost='inf', tune_slot_id=0):
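+    # evaluate one candidate config by re-running this script in a subprocess with
+    # CONFIG/DIR_SID set; the child writes its measured time cost to result.txt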
+ dir_sid = str(dir_sid)
+ result_file = local_get_dir_file('result.txt', dir_sid)
+ try:
+ os.remove(result_file)
+ except:
+ pass
+ config_str = json.dumps(params_given)
+ envs = os.environ.copy()
+ envs['CONFIG'] = config_str
+ envs['DIR_SID'] = dir_sid
+ envs['CUDA_VISIBLE_DEVICES'] = str(tune_slot_id)
+ print(" >> Try param_entity on sid = %s: config = %s, slot_id = %d" % (dir_sid, config_str, tune_slot_id))
+ try:
+ assert(True == run_process_with_timeout(["python%d" % sys.version_info.major] + sys.argv, envs=envs))
+ result = float(open(result_file, 'r').read().strip())
+ except:
+ result = float('inf')
+ print(" >> Try param_entity on sid = %s: result = `%.6f`" % (dir_sid, result))
+ return result
+
+def compute_gflops(flop, t):
+ return flop / (t * 1e3) / 1e6
+
+def search_op_config(code_only=False):
+ tvm_target = 'cuda'
+ logging.getLogger('autotvm').setLevel(logging.DEBUG)
+ logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
+
+ default_tune_op = importlib.import_module('templates.' + (os.environ['OP']))
+ print(' >> Backend = %s, Python PID = %s, Task = %s;' % (backend, os.getpid(), default_tune_op.__name__))
+
+ task = autotvm.task.create(default_tune_op.get_template_op, args=(), target=tvm_target)
+ op_attributes = default_tune_op.op_attributes
+ op_summary = '_'.join([k + str(op_attributes[k]) for k in op_attributes])
+
+ def json_to_config(json_dict):
+ config = ConfigEntity.from_json_dict({"i": -1, "t": "", "c": None, "e": json_dict})
+ return config
+
+ def config_to_json(config):
+ jobj = config.to_json_dict()['e']
+ json_dict = dict()
+ for i in range(len(jobj)):
+ assert(jobj[i][1] in ['sp', 'ot'])
+ json_dict[jobj[i][0]] = jobj[i][2]
+ return json_dict
+
+ num_trials = int(os.environ['STEP']) if 'STEP' in os.environ else 0
+
+ if 'CONFIG' in os.environ:
+ params_given = json.loads(os.environ['CONFIG'])
+ print("====>> [Current Config Option]", os.environ['CONFIG'])
+
+ trial_config = []
+ for key in params_given:
+ trial_config.append([key, "sp" if type(params_given[key]) is list else "ot", params_given[key]])
+ best_config = json_to_config(trial_config)
+
+ elif 'NNI_TRIAL_JOB_ID' in os.environ:
+ show_search_space(task.config_space, os.environ['NNI_TRIAL_JOB_ID'] == '@')
+ import nni
+ params_given = nni.get_next_parameter()
+ if params_given is None:
+            raise RuntimeError("Failed to get the next parameter from NNI.")
+ local_dir_id = os.environ['NNI_TRIAL_JOB_ID']
+ t = run_config_entity(params_given, local_dir_id)
+ gflops = compute_gflops(task.flop, t)
+ print('[TVM-engine] Final entity result is: %g' % gflops)
+ try:
+ nni.report_final_result(gflops)
+ except:
+ print('[TVM-engine] (not reporting final result to NNI.)')
+ exit(0)
+
+ elif num_trials > 0:
+ n_parallel = 16 if 'BATCH' not in os.environ else int(os.environ['BATCH'])
+ measure_option = autotvm.measure_option(
+ builder=autotvm.LocalBuilder(n_parallel=n_parallel),
+ runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
+ )
+ # if DO_TUNING:
+ tuner = autotvm.tuner.XGBTuner(task, num_threads=8)
+
+ from concurrent.futures import ThreadPoolExecutor
+ thread_pool = ThreadPoolExecutor(max_workers=n_parallel)
+
+        dev_num = get_tuning_parallelism()
+
+ def parse_configs(task, configs):
+ results = []
+ futures = []
+ expected_timecost = 'inf'
+ for i in range(len(configs)):
+ futures.append(thread_pool.submit(run_config_entity, config_to_json(configs[i]), i, expected_timecost, i % dev_num))
+ for i in range(len(configs)):
+ t = futures[i].result()
+ if t < tuner.task.best_config[0]:
+ tuner.task.best_config = (t, configs[i])
+ results.append(autotvm.measure.MeasureResult(costs=(t,), error_no=0, all_cost=i, timestamp=time.time()))
+ return results
+
+ tuner.task.best_config = (float('inf'), None)
+ tuner.parse_configs = parse_configs
+
+ tuner.tune(n_trial=num_trials, measure_option=measure_option, callbacks=[])
+ assert(not math.isinf(tuner.task.best_config[0]))
+ best_config = tuner.task.best_config[1]
+ print('\n[Best Config]', json.dumps(config_to_json(best_config)))
+ else:
+ best_config = task.config_space
+
+ with ApplyConfig(best_config):
+ with tvm.target.create(tvm_target):
+ s, arg_bufs = default_tune_op.get_template_op()
+ lower_source = str(tvm.lower(s, arg_bufs, simple_mode=True))
+
+ # Verify Source Code
+ assert(len(('\n' + lower_source).split('\nproduce ')) == 2)
+ lower_file = local_get_dir_file('my_kernel.lower')
+ with open(lower_file, 'w') as fp:
+ fp.write(lower_source)
+
+ max_threads_per_block = tvm.ndarray.gpu(0).max_threads_per_block
+ max_shared_memory_per_block = tvm.ndarray.gpu(0).max_shared_memory_per_block
+
+ thread_extents = subprocess.getoutput("cat '%s' | grep '^ *// attr.*iter_var.*thread_extent'" % (lower_file)).split('\n')
+ reserved_axes = dict({'threadIdx.x': None, 'threadIdx.y': None, 'threadIdx.z': None, 'blockIdx.x': None, 'blockIdx.y': None, 'blockIdx.z': None})
+ for line in thread_extents:
+ thread_name = line.split('[iter_var(')[-1].split(',')[0]
+ if thread_name in reserved_axes:
+ thread_val = int(line.split('thread_extent = ')[-1])
+ if reserved_axes[thread_name] is not None:
+ if reserved_axes[thread_name] != thread_val:
+ assert(False)
+ else:
+ reserved_axes[thread_name] = thread_val
+ else:
+ raise Exception("Invalid thread_axis name: %s" % thread_name)
+
+ num_threads = 1
+ for thread_name in ['threadIdx.x', 'threadIdx.y', 'threadIdx.z']:
+ if reserved_axes[thread_name] is not None:
+ num_threads *= reserved_axes[thread_name]
+ if num_threads > max_threads_per_block:
+ raise Exception("Invalid kernel code: using num_threads %d > max_threads_per_block %d" % (num_threads, max_threads_per_block))
+
+ allocate_shared = subprocess.getoutput("cat '%s' | grep 'allocate .*shared\[.*\]'" % (lower_file)).split('\n')
+ shared_memory_in_bytes = 0
+ for line in allocate_shared:
+ if not line:
+ continue
+ parts = line.split('[')
+ assert(len(parts) == 2)
+ parts = parts[1].split(' * ')
+ assert(len(parts) == 2)
+ assert(parts[1][-1] == ']')
+ allocate_type = parts[0]
+ allocate_val = int(parts[1][:-1])
+ if allocate_type in ['float32']:
+ shared_memory_in_bytes += allocate_val * 4
+ else:
+ raise Exception("Unrecognized shared memory data type: %s" % allocate_type)
+ if shared_memory_in_bytes > max_shared_memory_per_block:
+ raise Exception("Invalid kernel code: using shared_memory_in_bytes %d > max_shared_memory_per_block %d" % (shared_memory_in_bytes, max_shared_memory_per_block))
+
+ func = tvm.build(s, arg_bufs, tvm_target, name='template_op')
+
+ assert(len(func.imported_modules) == 1)
+ device_source = translate_code(func.imported_modules[0].get_source())
+
+ if code_only:
+ return device_source
+
+ if lower_source and device_source:
+ tune_slot_id = 0 if 'CUDA_VISIBLE_DEVICES' not in os.environ else int(os.environ['CUDA_VISIBLE_DEVICES'])
+ exec_fd, _ = system_lock([tune_slot_id])
+ gpu_id = 0
+ ctx = tvm.context(tvm_target, gpu_id)
+ tensors, outs = [], []
+ for arg in arg_bufs:
+ shape = [int(x) for x in arg.shape]
+ is_output = arg.op.__class__ != tvm.tensor.PlaceholderOp
+ from tvm._ffi.ndarray import empty
+ td = empty(shape, arg.dtype, ctx)
+ if is_output:
+ outs.append(td)
+ tensors.append(td)
+
+ def timeout_handler():
+ print("Error: Timeout during Kernel warmup")
+ os._exit(1)
+
+ my_timer = Timer(10, timeout_handler, [])
+ my_timer.start()
+ # Warmup
+ func(*tensors)
+ tvm.ndarray.gpu(gpu_id).sync()
+ # Estimate
+ t_start = time.time()
+ func(*tensors)
+ tvm.ndarray.gpu(gpu_id).sync()
+ t_diff = time.time() - t_start
+ my_timer.cancel()
+ del my_timer
+
+ num_runs = max(3, min(100, math.floor(1.0 / t_diff)))
+ timeout_seconds = math.ceil((num_runs + 5) * t_diff)
+ my_timer = Timer(timeout_seconds, timeout_handler, [])
+ my_timer.start()
+ timer_f = func.time_evaluator(func.entry_name, ctx, number=num_runs)
+ t = timer_f(*tensors).mean
+ my_timer.cancel()
+ exec_fd()
+
+ gflops = compute_gflops(task.flop, t)
+ print("[TVM-engine] Average time cost of %d runs = %g ms, %g gflops." % (num_runs, t * 1e3, gflops))
+
+ with open(local_get_dir_file('result.txt'), 'w') as fp:
+ fp.write(str(t))
+
+
+if __name__ == '__main__':
+ try:
+ search_op_config()
+ except SystemExit:
+ sys.exit(0)
+ except:
+ import traceback
+ traceback.print_exc()
diff --git a/examples/trials/systems/opevo/src/experiments/bmm/B960N128K128M64PNN/config_opevo.yml b/examples/trials/systems/opevo/src/experiments/bmm/B960N128K128M64PNN/config_opevo.yml
new file mode 100644
index 0000000000..5e2854e6d2
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/bmm/B960N128K128M64PNN/config_opevo.yml
@@ -0,0 +1,25 @@
+authorName: default
+experimentName: BatchMatMul_B960N128K128M64PNN_OPEVO
+trialConcurrency: 8
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: opevo.py
+ className: OpEvo
+  # Any parameter needed by your tuner class's __init__ constructor
+  # can be specified in this optional classArgs field, for example
+ classArgs:
+ optimize_mode: maximize
+ parents_size: 8
+ offspring_size: 8
+ mutate_rate: 0.5
+trial:
+ command: OP=batch_matmul B=960 N=128 K=128 M=64 P=NN ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/bmm/B960N128K128M64PNN/search_space.json b/examples/trials/systems/opevo/src/experiments/bmm/B960N128K128M64PNN/search_space.json
new file mode 100644
index 0000000000..ef533e6ece
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/bmm/B960N128K128M64PNN/search_space.json
@@ -0,0 +1 @@
+{"B": {"_type": "factor", "_value": [960, 2]}, "K": {"_type": "factor", "_value": [128, 3]}, "X": {"_type": "factor", "_value": [128, 4]}, "Y": {"_type": "factor", "_value": [64, 4]}}
\ No newline at end of file
diff --git a/examples/trials/systems/opevo/src/experiments/bmm/B960N128K128M64PTN/config_opevo.yml b/examples/trials/systems/opevo/src/experiments/bmm/B960N128K128M64PTN/config_opevo.yml
new file mode 100644
index 0000000000..a84b80a70d
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/bmm/B960N128K128M64PTN/config_opevo.yml
@@ -0,0 +1,25 @@
+authorName: default
+experimentName: BatchMatMul_B960N128K128M64PTN_OPEVO
+trialConcurrency: 8
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: opevo.py
+ className: OpEvo
+  # Any parameter needed by your tuner class's __init__ constructor
+  # can be specified in this optional classArgs field, for example
+ classArgs:
+ optimize_mode: maximize
+ parents_size: 8
+ offspring_size: 8
+ mutate_rate: 0.5
+trial:
+ command: OP=batch_matmul B=960 N=128 K=128 M=64 P=TN ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/bmm/B960N128K128M64PTN/search_space.json b/examples/trials/systems/opevo/src/experiments/bmm/B960N128K128M64PTN/search_space.json
new file mode 100644
index 0000000000..ef533e6ece
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/bmm/B960N128K128M64PTN/search_space.json
@@ -0,0 +1 @@
+{"B": {"_type": "factor", "_value": [960, 2]}, "K": {"_type": "factor", "_value": [128, 3]}, "X": {"_type": "factor", "_value": [128, 4]}, "Y": {"_type": "factor", "_value": [64, 4]}}
\ No newline at end of file
diff --git a/examples/trials/systems/opevo/src/experiments/bmm/B960N128K64M128PNT/config_opevo.yml b/examples/trials/systems/opevo/src/experiments/bmm/B960N128K64M128PNT/config_opevo.yml
new file mode 100644
index 0000000000..36e04b4da3
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/bmm/B960N128K64M128PNT/config_opevo.yml
@@ -0,0 +1,25 @@
+authorName: default
+experimentName: BatchMatMul_B960N128K64M128PNT_OPEVO
+trialConcurrency: 8
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: opevo.py
+ className: OpEvo
+ # Any parameters that need to be passed to your tuner's __init__ constructor
+ # can be specified in this optional classArgs field, for example:
+ classArgs:
+ optimize_mode: maximize
+ parents_size: 8
+ offspring_size: 8
+ mutate_rate: 0.5
+trial:
+ command: OP=batch_matmul B=960 N=128 K=64 M=128 P=NT ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/bmm/B960N128K64M128PNT/search_space.json b/examples/trials/systems/opevo/src/experiments/bmm/B960N128K64M128PNT/search_space.json
new file mode 100644
index 0000000000..a34020c349
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/bmm/B960N128K64M128PNT/search_space.json
@@ -0,0 +1 @@
+{"B": {"_type": "factor", "_value": [960, 2]}, "K": {"_type": "factor", "_value": [64, 3]}, "X": {"_type": "factor", "_value": [128, 4]}, "Y": {"_type": "factor", "_value": [128, 4]}}
\ No newline at end of file
diff --git a/examples/trials/systems/opevo/src/experiments/conv/N512C3HW227F64K11ST4PD0/config_opevo.yml b/examples/trials/systems/opevo/src/experiments/conv/N512C3HW227F64K11ST4PD0/config_opevo.yml
new file mode 100644
index 0000000000..d1b88e883c
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/conv/N512C3HW227F64K11ST4PD0/config_opevo.yml
@@ -0,0 +1,25 @@
+authorName: default
+experimentName: Conv_N512C3HW227F64K11ST4PD0_OPEVO
+trialConcurrency: 8
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: opevo.py
+ className: OpEvo
+ # Any parameters that need to be passed to your tuner's __init__ constructor
+ # can be specified in this optional classArgs field, for example:
+ classArgs:
+ optimize_mode: maximize
+ parents_size: 8
+ offspring_size: 8
+ mutate_rate: 0.5
+trial:
+ command: OP=convfwd_direct N=512 C=3 H=227 W=227 F=64 K=11 ST=4 PD=0 ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/conv/N512C3HW227F64K11ST4PD0/search_space.json b/examples/trials/systems/opevo/src/experiments/conv/N512C3HW227F64K11ST4PD0/search_space.json
new file mode 100644
index 0000000000..7f818119b9
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/conv/N512C3HW227F64K11ST4PD0/search_space.json
@@ -0,0 +1 @@
+{"tile_f": {"_type": "factor", "_value": [64, 4]}, "tile_y": {"_type": "factor", "_value": [55, 4]}, "tile_x": {"_type": "factor", "_value": [55, 4]}, "tile_rc": {"_type": "factor", "_value": [3, 2]}, "tile_ry": {"_type": "factor", "_value": [11, 2]}, "tile_rx": {"_type": "factor", "_value": [11, 2]}, "auto_unroll_max_step": {"_type": "discrete", "_value": [0, 125, 256]}, "unroll_explicit": {"_type": "choice", "_value": [0, 1]}}
diff --git a/examples/trials/systems/opevo/src/experiments/conv/N512C64HW27F192K5ST1PD2/config_opevo.yml b/examples/trials/systems/opevo/src/experiments/conv/N512C64HW27F192K5ST1PD2/config_opevo.yml
new file mode 100644
index 0000000000..1f84af8d67
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/conv/N512C64HW27F192K5ST1PD2/config_opevo.yml
@@ -0,0 +1,25 @@
+authorName: default
+experimentName: Conv_N512C64HW27F192K5ST1PD2_OPEVO
+trialConcurrency: 8
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: opevo.py
+ className: OpEvo
+ # Any parameters that need to be passed to your tuner's __init__ constructor
+ # can be specified in this optional classArgs field, for example:
+ classArgs:
+ optimize_mode: maximize
+ parents_size: 8
+ offspring_size: 8
+ mutate_rate: 0.5
+trial:
+ command: OP=convfwd_direct N=512 C=64 H=27 W=27 F=192 K=5 ST=1 PD=2 ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/conv/N512C64HW27F192K5ST1PD2/search_space.json b/examples/trials/systems/opevo/src/experiments/conv/N512C64HW27F192K5ST1PD2/search_space.json
new file mode 100644
index 0000000000..ee2e1fc2ba
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/conv/N512C64HW27F192K5ST1PD2/search_space.json
@@ -0,0 +1 @@
+{"tile_f": {"_type": "factor", "_value": [192, 4]}, "tile_y": {"_type": "factor", "_value": [27, 4]}, "tile_x": {"_type": "factor", "_value": [27, 4]}, "tile_rc": {"_type": "factor", "_value": [64, 2]}, "tile_ry": {"_type": "factor", "_value": [5, 2]}, "tile_rx": {"_type": "factor", "_value": [5, 2]}, "auto_unroll_max_step": {"_type": "discrete", "_value": [0, 125, 256]}, "unroll_explicit": {"_type": "choice", "_value": [0, 1]}}
diff --git a/examples/trials/systems/opevo/src/experiments/mm/N512K1024M1024/config_gbfs.yml b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M1024/config_gbfs.yml
new file mode 100644
index 0000000000..b24388fc59
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M1024/config_gbfs.yml
@@ -0,0 +1,23 @@
+authorName: default
+experimentName: MatMul_N512K1024M1024_GBFS
+trialConcurrency: 5
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: gbfs.py
+ className: G_BFS
+ # Any parameters that need to be passed to your tuner's __init__ constructor
+ # can be specified in this optional classArgs field, for example:
+ classArgs:
+ optimize_mode: maximize
+ num_samples: 5
+trial:
+ command: OP=matmul N=512 K=1024 M=1024 P=NN ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/mm/N512K1024M1024/config_na2c.yml b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M1024/config_na2c.yml
new file mode 100644
index 0000000000..e7eb2df09e
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M1024/config_na2c.yml
@@ -0,0 +1,22 @@
+authorName: default
+experimentName: MatMul_N512K1024M1024_NA2C
+trialConcurrency: 6
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: na2c.py
+ className: N_A2C
+ # Any parameters that need to be passed to your tuner's __init__ constructor
+ # can be specified in this optional classArgs field, for example:
+ classArgs:
+ optimize_mode: maximize
+trial:
+ command: OP=matmul N=512 K=1024 M=1024 P=NN ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/mm/N512K1024M1024/config_opevo.yml b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M1024/config_opevo.yml
new file mode 100644
index 0000000000..bf04dbbf6b
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M1024/config_opevo.yml
@@ -0,0 +1,25 @@
+authorName: default
+experimentName: MatMul_N512K1024M1024_OPEVO
+trialConcurrency: 8
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: opevo.py
+ className: OpEvo
+ # Any parameters that need to be passed to your tuner's __init__ constructor
+ # can be specified in this optional classArgs field, for example:
+ classArgs:
+ optimize_mode: maximize
+ parents_size: 8
+ offspring_size: 8
+ mutate_rate: 0.5
+trial:
+ command: OP=matmul N=512 K=1024 M=1024 P=NN ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/mm/N512K1024M1024/search_space.json b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M1024/search_space.json
new file mode 100644
index 0000000000..af331ca6cb
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M1024/search_space.json
@@ -0,0 +1 @@
+{"K": {"_type": "factor", "_value": [1024, 3]}, "X": {"_type": "factor", "_value": [512, 4]}, "Y": {"_type": "factor", "_value": [1024, 4]}}
diff --git a/examples/trials/systems/opevo/src/experiments/mm/N512K1024M4096/config_gbfs.yml b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M4096/config_gbfs.yml
new file mode 100644
index 0000000000..1a26abf2ec
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M4096/config_gbfs.yml
@@ -0,0 +1,23 @@
+authorName: default
+experimentName: MatMul_N512K1024M4096_GBFS
+trialConcurrency: 5
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: gbfs.py
+ className: G_BFS
+ # Any parameters that need to be passed to your tuner's __init__ constructor
+ # can be specified in this optional classArgs field, for example:
+ classArgs:
+ optimize_mode: maximize
+ num_samples: 5
+trial:
+ command: OP=matmul N=512 K=1024 M=4096 P=NN ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/mm/N512K1024M4096/config_na2c.yml b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M4096/config_na2c.yml
new file mode 100644
index 0000000000..6defbf2072
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M4096/config_na2c.yml
@@ -0,0 +1,22 @@
+authorName: default
+experimentName: MatMul_N512K1024M4096_NA2C
+trialConcurrency: 6
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: na2c.py
+ className: N_A2C
+ # Any parameters that need to be passed to your tuner's __init__ constructor
+ # can be specified in this optional classArgs field, for example:
+ classArgs:
+ optimize_mode: maximize
+trial:
+ command: OP=matmul N=512 K=1024 M=4096 P=NN ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/mm/N512K1024M4096/config_opevo.yml b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M4096/config_opevo.yml
new file mode 100644
index 0000000000..6bb7b83065
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M4096/config_opevo.yml
@@ -0,0 +1,25 @@
+authorName: default
+experimentName: MatMul_N512K1024M4096_OPEVO
+trialConcurrency: 8
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: opevo.py
+ className: OpEvo
+ # Any parameters that need to be passed to your tuner's __init__ constructor
+ # can be specified in this optional classArgs field, for example:
+ classArgs:
+ optimize_mode: maximize
+ parents_size: 8
+ offspring_size: 8
+ mutate_rate: 0.5
+trial:
+ command: OP=matmul N=512 K=1024 M=4096 P=NN ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/mm/N512K1024M4096/search_space.json b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M4096/search_space.json
new file mode 100644
index 0000000000..fd7d449107
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/mm/N512K1024M4096/search_space.json
@@ -0,0 +1 @@
+{"K": {"_type": "factor", "_value": [1024, 3]}, "X": {"_type": "factor", "_value": [512, 4]}, "Y": {"_type": "factor", "_value": [4096, 4]}}
\ No newline at end of file
diff --git a/examples/trials/systems/opevo/src/experiments/mm/N512K4096M1024/config_gbfs.yml b/examples/trials/systems/opevo/src/experiments/mm/N512K4096M1024/config_gbfs.yml
new file mode 100644
index 0000000000..1de78a51f9
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/mm/N512K4096M1024/config_gbfs.yml
@@ -0,0 +1,23 @@
+authorName: default
+experimentName: MatMul_N512K4096M1024_GBFS
+trialConcurrency: 5
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: gbfs.py
+ className: G_BFS
+ # Any parameters that need to be passed to your tuner's __init__ constructor
+ # can be specified in this optional classArgs field, for example:
+ classArgs:
+ optimize_mode: maximize
+ num_samples: 5
+trial:
+ command: OP=matmul N=512 K=4096 M=1024 P=NN ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/mm/N512K4096M1024/config_na2c.yml b/examples/trials/systems/opevo/src/experiments/mm/N512K4096M1024/config_na2c.yml
new file mode 100644
index 0000000000..2ba096ce60
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/mm/N512K4096M1024/config_na2c.yml
@@ -0,0 +1,22 @@
+authorName: default
+experimentName: MatMul_N512K4096M1024_NA2C
+trialConcurrency: 6
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: na2c.py
+ className: N_A2C
+ # Any parameters that need to be passed to your tuner's __init__ constructor
+ # can be specified in this optional classArgs field, for example:
+ classArgs:
+ optimize_mode: maximize
+trial:
+ command: OP=matmul N=512 K=4096 M=1024 P=NN ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/mm/N512K4096M1024/config_opevo.yml b/examples/trials/systems/opevo/src/experiments/mm/N512K4096M1024/config_opevo.yml
new file mode 100644
index 0000000000..eb07e78b48
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/mm/N512K4096M1024/config_opevo.yml
@@ -0,0 +1,25 @@
+authorName: default
+experimentName: MatMul_N512K4096M1024_OPEVO
+trialConcurrency: 8
+maxExecDuration: 24h
+maxTrialNum: 512
+#choice: local, remote, pai
+trainingServicePlatform: local
+searchSpacePath: search_space.json
+#choice: true, false
+useAnnotation: false
+tuner:
+ codeDir: /root/algorithms/
+ classFileName: opevo.py
+ className: OpEvo
+ # Any parameters that need to be passed to your tuner's __init__ constructor
+ # can be specified in this optional classArgs field, for example:
+ classArgs:
+ optimize_mode: maximize
+ parents_size: 8
+ offspring_size: 8
+ mutate_rate: 0.5
+trial:
+ command: OP=matmul N=512 K=4096 M=1024 P=NN ./run.sh
+ codeDir: /root
+ # gpuNum: 0
diff --git a/examples/trials/systems/opevo/src/experiments/mm/N512K4096M1024/search_space.json b/examples/trials/systems/opevo/src/experiments/mm/N512K4096M1024/search_space.json
new file mode 100644
index 0000000000..3dce726458
--- /dev/null
+++ b/examples/trials/systems/opevo/src/experiments/mm/N512K4096M1024/search_space.json
@@ -0,0 +1 @@
+{"K": {"_type": "factor", "_value": [4096, 3]}, "X": {"_type": "factor", "_value": [512, 4]}, "Y": {"_type": "factor", "_value": [1024, 4]}}
\ No newline at end of file
diff --git a/examples/trials/systems/opevo/src/run.sh b/examples/trials/systems/opevo/src/run.sh
new file mode 100644
index 0000000000..3b4c484ed0
--- /dev/null
+++ b/examples/trials/systems/opevo/src/run.sh
@@ -0,0 +1,25 @@
+#!/bin/bash -e
+
+cd $(dirname $0)
+
+export BACKEND=${BACKEND:-c-cuda}
+
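+# Map the user-facing "c-cuda" alias onto the "#cuda" backend name expected
+# by the tuning script.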
+if [[ "${BACKEND}" == "c-cuda" ]]; then
+ export BACKEND="#cuda"
+fi
+
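+# Select runtime libraries per backend: the NVIDIA driver libraries for CUDA,
+# or the locally built TVM runtime otherwise (e.g. ROCm).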
+if [[ "${BACKEND}" != "#cuda" ]]; then
+ export LD_LIBRARY_PATH=/opt/tvm/build
+else
+ export LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+fi
+
+export HIP_PLATFORM=hcc
+export HSA_USERPTR_FOR_PAGED_MEM=0
+export PYTHONDONTWRITEBYTECODE=1
+export PYTHONPATH=/opt/tvm/python:/opt/tvm/topi/python:/opt/tvm/nnvm/python:/usr/local/rocm/src
+
+ldconfig
+
+time OP=${OP:-matmul} S=${S:-0} python3 ./compiler_auto_tune_stable.py "$@"
+
diff --git a/examples/trials/systems/opevo/src/templates/batch_matmul.py b/examples/trials/systems/opevo/src/templates/batch_matmul.py
new file mode 100644
index 0000000000..4c84fbf56a
--- /dev/null
+++ b/examples/trials/systems/opevo/src/templates/batch_matmul.py
@@ -0,0 +1,119 @@
+import numpy as np
+import tvm
+import logging
+import sys, time, subprocess
+from tvm import autotvm
+import topi
+import json
+from topi.util import get_const_tuple
+import os
+
+
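+# Operator shapes come from environment variables set by the trial command in
+# config_*.yml (e.g. OP=batch_matmul B=960 N=128 K=128 M=64 P=NN); the values
+# below are fallbacks for standalone runs.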
+op_attributes = {
+ "B": int(os.environ['B']) if 'B' in os.environ else 6,
+ "N": int(os.environ['N']) if 'N' in os.environ else 1024,
+ "K": int(os.environ['K']) if 'K' in os.environ else 64,
+ "M": int(os.environ['M']) if 'M' in os.environ else 4096,
+ "P": os.environ['P'] if 'P' in os.environ else "NN",
+}
+
+@autotvm.template
+def get_template_op(**kwargs):
+ batch = op_attributes["B"]
+ M = op_attributes["N"]
+ K = op_attributes["K"]
+ N = op_attributes["M"]
+ pose = op_attributes["P"]
+
+ if pose == 'NN':
+ A = tvm.placeholder((batch, M, K), name='A', dtype="float32")
+ B = tvm.placeholder((batch, K, N), name='B', dtype="float32")
+ k = tvm.reduce_axis((0, K), name='k')
+ C = tvm.compute((batch, M, N), lambda b, i, j: tvm.sum(
+ A[b, i, k] * B[b, k, j], axis=k), name='C')
+ elif pose == 'NT':
+ A = tvm.placeholder((batch, M, K), name='A', dtype="float32")
+ B = tvm.placeholder((batch, N, K), name='B', dtype="float32")
+ k = tvm.reduce_axis((0, K), name='k')
+ C = tvm.compute((batch, M, N), lambda b, i, j: tvm.sum(
+ A[b, i, k] * B[b, j, k], axis=k), name='C')
+ elif pose == 'TN':
+ A = tvm.placeholder((batch, K, M), name='A', dtype="float32")
+ B = tvm.placeholder((batch, K, N), name='B', dtype="float32")
+ k = tvm.reduce_axis((0, K), name='k')
+ C = tvm.compute((batch, M, N), lambda b, i, j: tvm.sum(
+ A[b, k, i] * B[b, k, j], axis=k), name='C')
+ elif pose == 'TT':
+ A = tvm.placeholder((batch, K, M), name='A', dtype="float32")
+ B = tvm.placeholder((batch, N, K), name='B', dtype="float32")
+ k = tvm.reduce_axis((0, K), name='k')
+ C = tvm.compute((batch, M, N), lambda b, i, j: tvm.sum(
+ A[b, k, i] * B[b, j, k], axis=k), name='C')
+ else:
+ raise ValueError("unsupported transpose mode: " + pose)
+
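+ # Stage both operands through shared memory and registers, and accumulate
+ # the output tile in a local cache before writing it back to global memory.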
+ cfg = autotvm.get_config()
+ s = tvm.create_schedule(C.op)
+ AA = s.cache_read(A, "shared", [C])
+ AL = s.cache_read(AA, "local", [C])
+ BB = s.cache_read(B, "shared", [C])
+ BL = s.cache_read(BB, "local", [C])
+ CC = s.cache_write(C, "local")
+
+ b, y, x = C.op.axis
+ k = CC.op.reduce_axis[0]
+
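+ # The B/K/X/Y split knobs defined below correspond one-to-one with the
+ # "factor" entries in this experiment's search_space.json.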
+ cfg.define_split('B', cfg.axis(b), num_outputs=2)
+ bo, bi = cfg['B'].apply(s, C, b)
+
+ cfg.define_split('K', cfg.axis(k), num_outputs=3)
+ ko, kt, ki = cfg['K'].apply(s, CC, k)
+
+ block_x = tvm.thread_axis('blockIdx.x')
+ block_y = tvm.thread_axis('blockIdx.y')
+ block_z = tvm.thread_axis('blockIdx.z')
+ thread_x = tvm.thread_axis('threadIdx.x')
+ thread_y = tvm.thread_axis('threadIdx.y')
+ thread_z = tvm.thread_axis('threadIdx.z')
+
+ cfg.define_split('X', cfg.axis(y), num_outputs=4)
+ cfg.define_split('Y', cfg.axis(x), num_outputs=4)
+
+ by, tyz, ty, yi = cfg['X'].apply(s, C, y)
+ bx, txz, tx, xi = cfg['Y'].apply(s, C, x)
+
+ s[C].bind(bo, block_z)
+ s[C].bind(by, block_y)
+ s[C].bind(bx, block_x)
+ s[C].bind(tyz, tvm.thread_axis('vthread'))
+ s[C].bind(txz, tvm.thread_axis('vthread'))
+ s[C].bind(bi, thread_z)
+ s[C].bind(ty, thread_y)
+ s[C].bind(tx, thread_x)
+ s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
+
+ s[CC].compute_at(s[C], tx)
+
+ bo, yo, xo = CC.op.axis
+ s[CC].reorder(ko, kt, yo, xo, ki)
+ s[CC].unroll(kt)
+
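+ # Keep per-thread operand tiles in registers under the unrolled kt loop;
+ # double buffering overlaps the next tile's load with computation.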
+ for stage in [AL, BL]:
+ s[stage].compute_at(s[CC], kt)
+ s[stage].double_buffer()
+
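+ # Cooperative fetching: the fused copy loop is split across the thread
+ # block and the innermost dimension is vectorized into 4-wide loads.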
+ for stage in [AA, BB]:
+ s[stage].compute_at(s[CC], ko)
+
+ fused = s[stage].fuse(*s[stage].op.axis)
+ ty, tx = s[stage].split(fused, nparts=cfg['X'].size[2])
+ tx, xi = s[stage].split(tx, nparts=cfg['Y'].size[2])
+ _, xi = s[stage].split(xi, factor=4)
+
+ s[stage].bind(ty, thread_y)
+ s[stage].bind(tx, thread_x)
+ s[stage].vectorize(xi)
+ s[stage].double_buffer()
+
+ cfg.add_flop(batch * M * K * N * 2.0)
+ return s, [A, B, C]
diff --git a/examples/trials/systems/opevo/src/templates/convfwd_direct.py b/examples/trials/systems/opevo/src/templates/convfwd_direct.py
new file mode 100644
index 0000000000..5f40abff80
--- /dev/null
+++ b/examples/trials/systems/opevo/src/templates/convfwd_direct.py
@@ -0,0 +1,130 @@
+import numpy as np
+import tvm
+import logging
+import sys, time, subprocess
+from tvm import autotvm
+import topi
+import json
+from topi.util import get_const_tuple
+import os
+
+
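+# Convolution shapes come from environment variables set by the trial command
+# in config_*.yml (e.g. OP=convfwd_direct N=512 C=3 H=227 W=227 F=64 K=11
+# ST=4 PD=0); the values below are fallbacks for standalone runs.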
+op_attributes = {
+ "N": int(os.environ['N']) if 'N' in os.environ else 64,
+ "C": int(os.environ['C']) if 'C' in os.environ else 3,
+ "H": int(os.environ['H']) if 'H' in os.environ else 229,
+ "W": int(os.environ['W']) if 'W' in os.environ else 229,
+ "F": int(os.environ['F']) if 'F' in os.environ else 32,
+ "K": int(os.environ['K']) if 'K' in os.environ else 5,
+ "ST": int(os.environ['ST']) if 'ST' in os.environ else 1,
+ "PD": int(os.environ['PD']) if 'PD' in os.environ else 2,
+}
+
+
+@autotvm.template
+def get_template_op(**kwargs):
+ N = op_attributes["N"]
+ CI = op_attributes["C"]
+ H = op_attributes["H"]
+ W = op_attributes["W"]
+ CO = op_attributes["F"]
+ KH = KW = op_attributes["K"]
+ stride = op_attributes["ST"]
+ padding = op_attributes["PD"]
+ dilation = 1
+
+ data = tvm.placeholder((N, CI, H, W), name='data')
+ kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
+ conv = topi.nn.conv2d_nchw(
+ data, kernel, (stride, stride), (padding, padding), dilation=dilation, out_dtype='float32')
+ s = tvm.create_schedule([conv.op])
+ cfg = autotvm.get_config()
+
+ ##### space definition begin #####
+ n, f, y, x = s[conv].op.axis
+ rc, ry, rx = s[conv].op.reduce_axis
+ cfg.define_split("tile_f", f, num_outputs=4)
+ cfg.define_split("tile_y", y, num_outputs=4)
+ cfg.define_split("tile_x", x, num_outputs=4)
+ cfg.define_split("tile_rc", rc, num_outputs=2)
+ cfg.define_split("tile_ry", ry, num_outputs=2)
+ cfg.define_split("tile_rx", rx, num_outputs=2)
+ cfg.define_knob("auto_unroll_max_step", [0, 125, 256])
+
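+ # LLVM-based backends (nvptx, rocm) only support explicit unrolling,
+ # so the knob is fixed to 1 for them.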
+ target = tvm.target.current_target()
+ if target.target_name in ['nvptx', 'rocm']:
+ cfg.define_knob("unroll_explicit", [1])
+ else:
+ cfg.define_knob("unroll_explicit", [0, 1])
+
+ pad_data, kernel = s[conv].op.input_tensors
+
+ s[pad_data].compute_inline()
+ if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
+ s[kernel].compute_inline()
+
+ if conv.op in s.outputs:
+ output = conv
+ OL = s.cache_write(conv, 'local')
+ else:
+ output = s.outputs[0].output(0)
+ s[conv].set_scope('local')
+ OL = conv
+
+ # create cache stage
+ AA = s.cache_read(pad_data, 'shared', [OL])
+ WW = s.cache_read(kernel, 'shared', [OL])
+
+ # tile and bind spatial axes
+ n, f, y, x = s[output].op.axis
+ kernel_scope, n = s[output].split(n, nparts=1)
+
+ bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+ by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+ bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+ bf = s[output].fuse(n, bf)
+ s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
+ s[output].bind(by, tvm.thread_axis("blockIdx.y"))
+ s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
+ s[output].bind(vf, tvm.thread_axis("vthread"))
+ s[output].bind(vy, tvm.thread_axis("vthread"))
+ s[output].bind(vx, tvm.thread_axis("vthread"))
+ s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
+ s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+ s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+ s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
+ s[OL].compute_at(s[output], tx)
+
+ # tile reduction axes
+ n, f, y, x = s[OL].op.axis
+ rc, ry, rx = s[OL].op.reduce_axis
+ rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+ ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
+ rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
+ s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
+
+ s[AA].compute_at(s[OL], rxo)
+ s[WW].compute_at(s[OL], rxo)
+
+ # cooperative fetching
+ for load in [AA, WW]:
+ n, f, y, x = s[load].op.axis
+ fused = s[load].fuse(n, f, y, x)
+ tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
+ ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
+ tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
+ s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
+ s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
+ s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
+
+ # unroll
+ s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
+ s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+
+ N, CO, OH, OW = get_const_tuple(output.shape)
+ _, CI, KH, KW = get_const_tuple(kernel.shape)
+
+ cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
+ return s, [data, kernel, conv]
diff --git a/examples/trials/systems/opevo/src/templates/matmul.py b/examples/trials/systems/opevo/src/templates/matmul.py
new file mode 100644
index 0000000000..7cbd75c458
--- /dev/null
+++ b/examples/trials/systems/opevo/src/templates/matmul.py
@@ -0,0 +1,111 @@
+import numpy as np
+import tvm
+import logging
+import sys, time, subprocess
+from tvm import autotvm
+import topi
+import json
+from topi.util import get_const_tuple
+import os
+
+
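+# Matrix shapes come from environment variables set by the trial command in
+# config_*.yml (e.g. OP=matmul N=512 K=1024 M=1024 P=NN); the values below
+# are fallbacks for standalone runs.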
+op_attributes = {
+ "N": int(os.environ['N']) if 'N' in os.environ else 1024,
+ "K": int(os.environ['K']) if 'K' in os.environ else 64,
+ "M": int(os.environ['M']) if 'M' in os.environ else 4096,
+ "P": os.environ['P'] if 'P' in os.environ else "NN",
+}
+
+@autotvm.template
+def get_template_op(**kwargs):
+ batch = op_attributes["N"]
+ in_dim = op_attributes["K"]
+ out_dim = op_attributes["M"]
+ pose = op_attributes["P"]
+
+ if pose == 'NN':
+ A = tvm.placeholder((batch, in_dim), name='A', dtype="float32")
+ B = tvm.placeholder((in_dim, out_dim), name='B', dtype="float32")
+ k = tvm.reduce_axis((0, in_dim), name='k')
+ C = tvm.compute((batch, out_dim), lambda i, j: tvm.sum(
+ A[i, k] * B[k, j], axis=k), name='C')
+ elif pose == 'NT':
+ A = tvm.placeholder((batch, in_dim), name='A', dtype="float32")
+ B = tvm.placeholder((out_dim, in_dim), name='B', dtype="float32")
+ k = tvm.reduce_axis((0, in_dim), name='k')
+ C = tvm.compute((batch, out_dim), lambda i, j: tvm.sum(
+ A[i, k] * B[j, k], axis=k), name='C')
+ elif pose == 'TN':
+ A = tvm.placeholder((in_dim, batch), name='A', dtype="float32")
+ B = tvm.placeholder((in_dim, out_dim), name='B', dtype="float32")
+ k = tvm.reduce_axis((0, in_dim), name='k')
+ C = tvm.compute((batch, out_dim), lambda i, j: tvm.sum(
+ A[k, i] * B[k, j], axis=k), name='C')
+ elif pose == 'TT':
+ A = tvm.placeholder((in_dim, batch), name='A', dtype="float32")
+ B = tvm.placeholder((out_dim, in_dim), name='B', dtype="float32")
+ k = tvm.reduce_axis((0, in_dim), name='k')
+ C = tvm.compute((batch, out_dim), lambda i, j: tvm.sum(
+ A[k, i] * B[j, k], axis=k), name='C')
+ else:
+ raise ValueError("unsupported transpose mode: " + pose)
+
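+ # The K/X/Y split knobs defined below correspond one-to-one with the
+ # "factor" entries in this experiment's search_space.json.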
+ cfg = autotvm.get_config()
+ s = tvm.create_schedule(C.op)
+
+ cfg.add_flop(batch * in_dim * out_dim * 2.0)
+
+ AA = s.cache_read(A, "shared", [C])
+ AL = s.cache_read(AA, "local", [C])
+ BB = s.cache_read(B, "shared", [C])
+ BL = s.cache_read(BB, "local", [C])
+ CC = s.cache_write(C, "local")
+
+ y, x = C.op.axis
+ k = CC.op.reduce_axis[0]
+
+ cfg.define_split('K', cfg.axis(k), num_outputs=3)
+ cfg.define_split('X', cfg.axis(y), num_outputs=4)
+ cfg.define_split('Y', cfg.axis(x), num_outputs=4)
+
+ ko, kt, ki = cfg['K'].apply(s, CC, k)
+
+ block_x = tvm.thread_axis('blockIdx.x')
+ block_y = tvm.thread_axis('blockIdx.y')
+ thread_x = tvm.thread_axis('threadIdx.x')
+ thread_y = tvm.thread_axis('threadIdx.y')
+
+ by, tyz, ty, yi = cfg['X'].apply(s, C, y)
+ bx, txz, tx, xi = cfg['Y'].apply(s, C, x)
+
+ s[C].bind(by, block_y)
+ s[C].bind(bx, block_x)
+ s[C].bind(tyz, tvm.thread_axis('vthread'))
+ s[C].bind(txz, tvm.thread_axis('vthread'))
+ s[C].bind(ty, thread_y)
+ s[C].bind(tx, thread_x)
+ s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
+
+ s[CC].compute_at(s[C], tx)
+
+ yo, xo = CC.op.axis
+ s[CC].reorder(ko, kt, yo, xo, ki)
+ s[CC].unroll(kt)
+
+ for stage in [AL, BL]:
+ s[stage].compute_at(s[CC], kt)
+
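+ # Cooperative fetching: the fused copy loop is split across the thread
+ # block and the innermost dimension is vectorized into 4-wide loads.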
+ for stage in [AA, BB]:
+ s[stage].compute_at(s[CC], ko)
+
+ fused = s[stage].fuse(*s[stage].op.axis)
+ ty, tx = s[stage].split(fused, nparts=cfg['X'].size[2])
+ tx, xi = s[stage].split(tx, nparts=cfg['Y'].size[2])
+ _, xi = s[stage].split(xi, factor=4)
+
+ s[stage].bind(ty, thread_y)
+ s[stage].bind(tx, thread_x)
+ s[stage].vectorize(xi)
+ s[stage].double_buffer()
+
+ return s, [A, B, C]
diff --git a/examples/trials/systems/opevo/tvm_patches/libcuda.so.1 b/examples/trials/systems/opevo/tvm_patches/libcuda.so.1
new file mode 100644
index 0000000000..8417946346
Binary files /dev/null and b/examples/trials/systems/opevo/tvm_patches/libcuda.so.1 differ
diff --git a/examples/trials/systems/opevo/tvm_patches/tvm_v0.6.patch b/examples/trials/systems/opevo/tvm_patches/tvm_v0.6.patch
new file mode 100644
index 0000000000..e290b230e1
--- /dev/null
+++ b/examples/trials/systems/opevo/tvm_patches/tvm_v0.6.patch
@@ -0,0 +1,84 @@
+diff --git a/python/tvm/autotvm/tuner/tuner.py b/python/tvm/autotvm/tuner/tuner.py
+index 76d088f4c..7ed4ff02a 100644
+--- a/python/tvm/autotvm/tuner/tuner.py
++++ b/python/tvm/autotvm/tuner/tuner.py
+@@ -122,7 +122,7 @@ class Tuner(object):
+ configs = self.next_batch(min(n_parallel, n_trial - i))
+
+ inputs = [MeasureInput(self.task.target, self.task, config) for config in configs]
+- results = measure_batch(inputs)
++ results = self.parse_configs(self.task, configs) if hasattr(self, 'parse_configs') else measure_batch(inputs)
+
+ # keep best config
+ for k, (inp, res) in enumerate(zip(inputs, results)):
+diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
+index eab542dd3..2f1a11303 100644
+--- a/src/codegen/codegen_c.cc
++++ b/src/codegen/codegen_c.cc
+@@ -808,6 +808,7 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
+ IterVar iv = Downcast<IterVar>(op->node);
+ if (iv->thread_tag.length() != 0) {
+ if (!var_idmap_.count(iv->var.get())) {
++ this->currentOp = op;
+ BindThreadIndex(iv);
+ }
+ }
+diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h
+index 8701cda1e..7d3d56ddc 100644
+--- a/src/codegen/codegen_c.h
++++ b/src/codegen/codegen_c.h
+@@ -174,6 +174,8 @@ class CodeGenC :
+ // Get a cast type from to
+ virtual std::string CastFromTo(std::string value, Type from, Type target);
+
++ const AttrStmt* currentOp;
++
+ protected:
+ // Print reference to struct location
+ std::string GetStructRef(
+diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
+index 6656fa077..a4f0f962d 100644
+--- a/src/codegen/codegen_cuda.cc
++++ b/src/codegen/codegen_cuda.cc
+@@ -106,6 +106,9 @@ void CodeGenCUDA::BindThreadIndex(const IterVar& iv) {
+ CHECK(!var_idmap_.count(iv->var.get()));
+ var_idmap_[iv->var.get()] =
+ CastFromTo(iv->thread_tag, UInt(32), iv->var.type());
++ int nthread = static_cast<int>(this->currentOp->value.as<IntImm>()->value);
++ if (iv->thread_tag.find("threadIdx.") == 0 || iv->thread_tag.find("blockIdx.") == 0)
++ this->stream << " // [thread_extent] " << iv->thread_tag << " = " << nthread << "\n";
+ }
+
+ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
+diff --git a/src/codegen/opt/build_cuda_on.cc b/src/codegen/opt/build_cuda_on.cc
+index 1992ac5d9..9b0ff4cd9 100644
+--- a/src/codegen/opt/build_cuda_on.cc
++++ b/src/codegen/opt/build_cuda_on.cc
+@@ -137,6 +137,9 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
+ cg.AddFunction(f);
+ }
+ std::string code = cg.Finish();
++ const auto* backendproc = Registry::Get("tvm_callback_backend_proc");
++ if (backendproc)
++ return CUDAModuleCreate((*backendproc)(code).operator std::string(), "cubin", ExtractFuncInfo(funcs), code);
+
+ if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
+ code = (*f)(code).operator std::string();
+diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc
+index 220d4378c..cc435d138 100644
+--- a/src/lang/expr_operator.cc
++++ b/src/lang/expr_operator.cc
+@@ -208,11 +208,11 @@ Expr operator%(Expr a, Expr b) {
+
+ // TODO(tqchen): switch to floordiv
+ Expr indexdiv(Expr a, Expr b) {
+- return floordiv(a, b);
++ return truncdiv(a, b);
+ }
+
+ Expr indexmod(Expr a, Expr b) {
+- return floormod(a, b);
++ return truncmod(a, b);
+ }
+
+ Expr floordiv(Expr a, Expr b) {