Please note that this is a work in progress.
For improved performance, you'll need to build PyTorch on top of this PR: pytorch/pytorch#20284
```bash
cd pytorch
git fetch origin pull/20284/head:tvm_dev
git checkout tvm_dev
python setup.py install
```
Otherwise, install the latest nightly build of PyTorch.
Then, build this repo:

```bash
# Make sure the right llvm-config is in your PATH
python setup.py install
```

To run the tests:

```bash
python setup.py test
```
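As a quick sanity check that the extension built against your PyTorch install (a minimal sketch, not part of the repo's test suite):

```python
import torch
import torch_tvm  # an ImportError here means the extension did not build

print(torch.__version__)
torch_tvm.enable()   # registers the TVM backend with the JIT
torch_tvm.disable()
```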
This package transparently hooks into PyTorch's JIT, so the same tooling is applicable (see `@torch.jit.script`, `torch.jit.trace`, and `graph_for`). See below for an example.
```python
import torch
import torch_tvm

torch_tvm.enable()

# The following function will be compiled with TVM
@torch.jit.script
def my_func(a, b, c):
    return a * b + c
```
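To check that TVM actually handled the function, you can inspect the optimized graph with `graph_for`. Continuing the example above (the input shapes here are arbitrary assumptions):

```python
# Example inputs; the shapes are arbitrary.
a, b, c = (torch.rand(4, 4) for _ in range(3))

# graph_for returns the graph specialized for these inputs; regions
# offloaded to TVM should appear as a fused subgraph node (the exact
# node name depends on how the backend is registered).
print(my_func.graph_for(a, b, c))
```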
To disable the JIT hooks, use `torch_tvm.disable()`.
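The tracing path goes through the same hook; a minimal sketch (the function name and shapes below are assumptions, not from this README):

```python
import torch
import torch_tvm

torch_tvm.enable()

def mul_add(a, b, c):
    return a * b + c

# torch.jit.trace records the ops executed for example inputs; the
# resulting graph is optimized (and lowered to TVM) just like a
# scripted function.
inputs = tuple(torch.rand(8, 8) for _ in range(3))  # arbitrary shapes
traced = torch.jit.trace(mul_add, inputs)
print(traced.graph_for(*inputs))

torch_tvm.disable()  # restore stock JIT behavior
```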
- `register.cpp`: Sets up pybind bindings and invokes the registration of a TVM backend.
- `compiler.{h,cpp}`: Main logic to compile a PyTorch JIT graph with TVM.
- `operators.{h,cpp}`: Location of mapping from JIT IR to TVM operators.
- Add the ability to register translations from opaque op names to TVM (as is done in `operators.cpp`) from Python
- Zero-copy `set_input`
- Bail-out mechanism (invoke PyTorch JIT fallback)
- Threadpool integration
- Allocator integration
- Operator translation
  - Add
  - Multiply
  - Convolution
  - BatchNorm
  - Relu
  - AveragePool
  - MaxPool
  - Linear
- Tensor manipulation
  - Reshape
  - Views
- Tooling
  - Model coverage checks
  - Benchmarks for master