Skip to content

Commit

Permalink
[PyTorch]Add PyTorchTVM: compile torchscript to tvm and export as pyt…
Browse files Browse the repository at this point in the history
…orch_op (#8777)

* add pt_op

* add compile api

* perf: support set_output_zero_copy

* fix: cpu device_id mismatch

* fix: pt_class test script

* refactor: unify namespace to tvm.contrib.torch

* add ASF header

* build: set pt tvmdsoop default off

* build: remove unset_log_macros.h

* refactor: change header order

* refactor: fix python code format

* style: resolve pylint issues

* style: add blank line

* style: fix pylint invalid_name

* trigger CI

* test: add more test scripts

* style: add empty lines

* test: update test for trace tvm module

* style: fix linting issues

* style: remove single quote

* style: disable pylint invalid-name

* trigger CI

* trigger CI

Co-authored-by: kongroo <imjcqt@gmail.com>
  • Loading branch information
Meteorix and kongroo authored Nov 5, 2021
1 parent 048994b commit e7024fb
Show file tree
Hide file tree
Showing 17 changed files with 2,072 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ tvm_option(USE_MICRO "Build with Micro TVM support" OFF)
tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF)
tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF)
tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF)
tvm_option(USE_PT_TVMDSOOP "Build with PyTorch TVMDSOOp" OFF)
tvm_option(USE_FALLBACK_STL_MAP "Use TVM's POD compatible Map" OFF)
tvm_option(USE_ETHOSN "Build with Arm Ethos-N" OFF)
tvm_option(USE_CMSISNN "Build with Arm CMSIS-NN" OFF)
Expand Down Expand Up @@ -441,6 +442,7 @@ include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/HybridDump.cmake)
include(cmake/modules/contrib/TFLite.cmake)
include(cmake/modules/contrib/TF_TVMDSOOP.cmake)
include(cmake/modules/contrib/PT_TVMDSOOP.cmake)
include(cmake/modules/contrib/CoreML.cmake)
include(cmake/modules/contrib/BNNS.cmake)
include(cmake/modules/contrib/ONNX.cmake)
Expand Down
34 changes: 34 additions & 0 deletions apps/pt_tvmdsoop/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
cmake_minimum_required(VERSION 3.2)
project(pt_tvmdsoop C CXX)

set(BUILD_PT_TVMDSOOP_ONLY ON)
set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT})
set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build)

include_directories(SYSTEM ${TVM_ROOT}/3rdparty/dlpack/include/)
include_directories(SYSTEM ${TVM_ROOT}/3rdparty/dmlc-core/include/)
include_directories(${TVM_ROOT}/include)

link_directories(${TVM_ROOT}/build)

include(${TVM_ROOT}/cmake/utils/Utils.cmake)
include(${TVM_ROOT}/cmake/utils/FindCUDA.cmake)
include(${TVM_ROOT}/cmake/modules/CUDA.cmake)

include(${TVM_ROOT}/cmake/modules/contrib/PT_TVMDSOOP.cmake)
46 changes: 46 additions & 0 deletions apps/pt_tvmdsoop/prepare_and_test_pt_tvm_class.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

TVM_ROOT=$(cd $(dirname $0)/../..; pwd)
echo "TVM_ROOT=${TVM_ROOT}"

export PYTHONPATH=${TVM_ROOT}/python

if [ ! -f $TVM_ROOT/build/libtvm.so ]; then
echo "$TVM_ROOT/build/libtvm.so missing"
exit 1
fi

if [ ! -f $TVM_ROOT/build/libtvm_runtime.so ]; then
echo "$TVM_ROOT/build/libtvm_runtime.so missing"
exit 1
fi

python3 -c "import tvm; print(tvm.runtime.enabled('gpu'))" | grep -e 1

if [ "$?" -eq 0 ]; then
echo "Build PT_TVMDSOOP with gpu support and execute tests"
CMAKE_OPTIONS="-DUSE_CUDA=ON -DUSE_CUDNN=ON -DPython3_EXECUTABLE=python3 -DTVM_ROOT=${TVM_ROOT}"
mkdir -p build
cd build; cmake .. ${CMAKE_OPTIONS} && make
cp *.so $TVM_ROOT/build/
cd ..

LD_LIBRARY_PATH=${TVM_ROOT}/build:./build:$LD_LIBRARY_PATH python3 -m pytest -v ./tests
fi

68 changes: 68 additions & 0 deletions apps/pt_tvmdsoop/tests/test_torch_compile_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test script for torch module"""
import torch
import time
import tvm
from tvm.contrib.torch import compile


class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor):
return x * x


model = Model()
x = torch.rand([1, 3, 224, 224])
model_jit = torch.jit.trace(model, x)
print(model_jit.graph)

print("run torchscript...")
for i in range(20):
t = time.time()
model_jit(x)
print(time.time() - t)


option = {
"input_infos": [
("x", (1, 3, 224, 224)),
],
"default_dtype": "float16",
"export_dir": "pytorch_compiled",
"num_outputs": 1,
"tuning_n_trials": 1, # set zero to skip tuning
"tuning_log_file": "tuning.log",
"target": "llvm",
"device": tvm.cpu(),
}

pytorch_tvm_module = compile(model_jit, option)
torch.jit.script(pytorch_tvm_module).save("model_tvm.pt")


print("Run PyTorch...")
for i in range(20):
t = time.time()
outputs = pytorch_tvm_module.forward([x.cpu()])
print(1000 * (time.time() - t))
print(outputs[0].shape)
63 changes: 63 additions & 0 deletions apps/pt_tvmdsoop/tests/test_torch_compile_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test script for torch module"""
import torch
import time
from torchvision.models import resnet50
import tvm
from tvm.contrib.torch import compile


model = resnet50().half().cuda()
x = torch.rand([1, 3, 224, 224]).half().cuda()
model_jit = torch.jit.trace(model, x)
print(model_jit.graph)

print("run torchscript...")
for i in range(20):
t = time.time()
model_jit(x)
torch.cuda.synchronize()
print(time.time() - t)


option = {
"input_infos": [
("x", (1, 3, 224, 224)),
],
"default_dtype": "float16",
"export_dir": "pytorch_compiled",
"num_outputs": 1,
"tuning_n_trials": 1, # set zero to skip tuning
"tuning_log_file": "tuning.log",
"target": "cuda",
"device": tvm.cuda(0),
}

pytorch_tvm_module = compile(model_jit, option)
torch.jit.script(pytorch_tvm_module).save("model_tvm.pt")


print("Run PyTorch...")
for i in range(20):
t = time.time()
outputs = pytorch_tvm_module.forward([x])
torch.cuda.synchronize()
print(1000 * (time.time() - t))
print(outputs[0].shape)
129 changes: 129 additions & 0 deletions apps/pt_tvmdsoop/tests/test_torch_graph_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test script for torch module"""
import tempfile
import os
import logging
import torch
import numpy as np
import tvm
import tvm.testing
from tvm import te, relay
import tvm.contrib.torch
from tvm.contrib import graph_runtime

TVM_ASSETS = ["mod.so", "graph.json", "params"]


def test_use_pt_graph_module():
"""main test function"""

def build_export_graph(device):
"""relay build & export graph"""
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(1, 5))
z = relay.add(x, y)
z = relay.exp(z)
func = relay.Function([x, y], z)
x_data = np.random.rand(10, 5).astype("float32")
y_data = np.random.rand(1, 5).astype("float32")
params = {"y": y_data}

pt_device = torch.device(device)
if pt_device.type == "cuda":
target = "cuda"
ctx = tvm.cuda(pt_device.index)
else:
target = "llvm"
ctx = tvm.cpu(0)

graph, lib, params = relay.build(tvm.IRModule.from_expr(func), target=target, params=params)
mod = graph_runtime.create(graph, lib, device=ctx)
mod.set_input(**params)
mod.set_input(x=x_data)
mod.run()
res = mod.get_output(0).asnumpy()
ref_res = np.exp(y_data + x_data)
tvm.testing.assert_allclose(res, ref_res, atol=1e-5, rtol=1e-5)

# export to tempdir
export_dir = tempfile.mkdtemp("tvm_export")
lib.export_library(os.path.join(export_dir, TVM_ASSETS[0]))
with open(os.path.join(export_dir, TVM_ASSETS[1]), "w") as fout:
fout.write(graph)
with open(os.path.join(export_dir, TVM_ASSETS[2]), "wb") as fout:
fout.write(relay.save_param_dict(params))

return export_dir

def test_pt_run(device, trace=True, to_device=None):
"""test add lib with Pytorch wrapper"""
print("\n############## Test on device:", device, "#################")
export_dir = build_export_graph(device)
engine = tvm.contrib.torch.GraphModule(num_inputs=2, num_outputs=1).to(device)

x = np.random.rand(10, 5).astype("float32")
y = np.random.rand(1, 5).astype("float32")

expect = np.exp(y + x)

def get_inputs_by_device(device):
inps = [torch.Tensor(x), torch.Tensor(y)]
if device == "cpu":
return inps
else:
device_type, device_id = device.split(":")
assert device_type == "cuda"
return [inp.cuda(int(device_id)) for inp in inps]

assets = [os.path.join(export_dir, i) for i in TVM_ASSETS]
engine.init((x.shape, y.shape), *assets)

outputs = engine.forward(get_inputs_by_device(device))
tvm.testing.assert_allclose(outputs[0].cpu(), expect, atol=1e-5, rtol=1e-5)

if trace:
print("\n################ Test trace and load #################")
scripted = torch.jit.script(engine)
scripted_dir = tempfile.mkdtemp("scripted")
scripted_path = os.path.join(scripted_dir, "model.pt")
scripted.save(scripted_path)
loaded = torch.jit.load(scripted_path)
outputs = loaded.forward(get_inputs_by_device(device))
tvm.testing.assert_allclose(outputs[0].cpu(), expect, atol=1e-5, rtol=1e-5)
del scripted
del loaded

if to_device:
print(
"\n################ Test move from [{}] to [{}] #################".format(
device, to_device
)
)
engine = engine.to(to_device)
outputs = engine.forward(get_inputs_by_device(to_device))
tvm.testing.assert_allclose(outputs[0].cpu(), expect, atol=1e-5, rtol=1e-5)
del engine

test_pt_run(device="cuda:0", trace=True, to_device="cuda:1")
test_pt_run(device="cpu", trace=True)


if __name__ == "__main__":
test_use_pt_graph_module()
Loading

0 comments on commit e7024fb

Please sign in to comment.