[TVM PyTorch Integration] libstdc++ CXX11 ABI Compatibility & boolean tensor support (#12232)

* first commit

* rename

* cmake

* deprecated

* newline

* config

* config

* typo

* skip tvm_class

* rename

* delete ptr

* delete ptr

* save progress

* boolean support

* cmake file

* polish code

* compile config

* improving the codes

* format

* doc&errormsg

* zero-cost copy

* one step

* to ndarray

* extra output

* delete extra codes

* update test

* boolean support

* strong test

* decrease memory copy

* polish

* reformat

* polish

* remove redundant import

Co-authored-by: juda <yzhou@octoml.ai>
juda authored Aug 17, 2022
1 parent d2f9f25 commit 073304d
Showing 10 changed files with 844 additions and 279 deletions.
7 changes: 6 additions & 1 deletion apps/pt_tvmdsoop/tests/test_as_torch.py
@@ -17,6 +17,8 @@
# specific language governing permissions and limitations
# under the License.
"""Test script for tvm torch module"""
import tempfile

import numpy as np

import torch
@@ -190,7 +192,10 @@ def test_tvmscript_torch_gpu():
q1 = torch.arange(8, device=cuda0).type(torch.float32)
q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0)

ModuleGPU(q1, q2)
with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
torch.save(ModuleGPU, tmp.name)
loaded_mod = torch.load(tmp.name)
loaded_mod(q1, q2)

tvm.testing.assert_allclose(q2.cpu().numpy(), (q1 + 1).cpu().numpy(), atol=1e-5, rtol=1e-5)

129 changes: 129 additions & 0 deletions apps/pt_tvmdsoop/tests/test_boolean_tensor.py
@@ -0,0 +1,129 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test script for boolean tensor support"""
import tempfile

import torch

import tvm
import tvm.testing
from tvm.contrib.torch import as_torch, optimize_torch
from tvm.script import tir as T


def negate(x):
return x.logical_not()


def sum_up_tensor(x):
return x.size(dim=0) - torch.sum(x.int())


def tensor_boolean_operation(x):
arr1 = (x + 0.3).floor().bool()
arr2 = (~((x + 0.7).int().bool())).bool()
ret = ((arr1 & arr2).byte() + 0.5).half()
return ~(ret.bool())


def test_bool_tensor_negate():
input = torch.ones(1, dtype=torch.bool)
optimized_negate = optimize_torch(
negate,
input,
)
with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
torch.save(optimized_negate, tmp.name)
loaded_mod = torch.load(tmp.name)
output = loaded_mod(negate(input))
tvm.testing.assert_allclose(input.numpy(), output.numpy(), atol=1e-5, rtol=1e-5)


def test_sum_up_tensor():
x = torch.randint(0, 2, (16,))
y = x.bool()
optimized_func = optimize_torch(
sum_up_tensor,
(y,),
)
ret1 = (x[x == 0]).size(dim=0)
ret2 = optimized_func(y).numpy()
tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)


def test_tensor_boolean_operation():
input = torch.rand(200)
model = optimize_torch(
tensor_boolean_operation,
input,
)
ret1 = tensor_boolean_operation(input)
ret2 = model(input)
tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)


@as_torch
@T.prim_func
def negate_tvmscript(
X: T.Buffer[(8, 8), "bool"],
Y: T.Buffer[(8, 8), "float32"],
Z: T.Buffer[(8, 8), "bool"],
U: T.Buffer[(8, 8), "float32"],
) -> None:
for i, j in T.grid(8, 8):
with T.block():
if Y[i, j] > 0.0:
Z[i, j] = X[i, j]
U[i, j] = Y[i, j]
else:
Z[i, j] = not X[i, j]
U[i, j] = 0.0 - Y[i, j]


def negate_vanilla(x, y):
z = torch.zeros(8, 8).bool()
for i in range(8):
for j in range(8):
if y[i, j] > 0:
z[i, j] = x[i, j]
else:
z[i, j] = ~x[i, j]
return z


def test_tvmscript_torch_decorator():
q1 = (torch.rand(8, 8) + 0.5).int().bool()
q2 = torch.rand(8, 8) - 0.5
q3 = torch.zeros(8, 8).bool()
q4 = torch.zeros(8, 8)

std1 = negate_vanilla(q1, q2)
std2 = torch.abs(q2)

negate_tvmscript(q1, q2, q3, q4)

tvm.testing.assert_allclose(std1.numpy(), q3.numpy(), atol=1e-5, rtol=1e-5)
tvm.testing.assert_allclose(std2.numpy(), q4.numpy(), atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
test_tvmscript_torch_decorator()
test_bool_tensor_negate()
test_sum_up_tensor()
test_tensor_boolean_operation()
68 changes: 53 additions & 15 deletions cmake/modules/contrib/PT_TVMDSOOP.cmake
@@ -6,7 +6,7 @@
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
@@ -17,42 +17,80 @@

if(NOT USE_PT_TVMDSOOP STREQUAL "OFF")
find_package(PythonInterp REQUIRED)

execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.__path__[0].strip())"
OUTPUT_VARIABLE PT_PATH
RESULT_VARIABLE PT_STATUS)
if (NOT ${PT_STATUS} EQUAL 0)

if(NOT ${PT_STATUS} EQUAL 0)
message(FATAL_ERROR "Fail to get pytorch path")
endif()

string(REGEX REPLACE "\n" "" PT_PATH "${PT_PATH}")
message(STATUS "PyTorch path: ${PT_PATH}")

set(PT_COMPILE_FLAGS_STR "-I${PT_PATH}/include -D_GLIBCXX_USE_CXX11_ABI=0")
execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch;print(torch.compiled_with_cxx11_abi())"
OUTPUT_VARIABLE PT_CXX_FLAG
RESULT_VARIABLE PT_STATUS)

string(REGEX REPLACE "\n" "" PT_CXX_FLAG "${PT_CXX_FLAG}")
message(STATUS "Found TORCH_BUILT_WITH_CXX_ABI=${PT_CXX_FLAG} ")

if(${PT_CXX_FLAG} STREQUAL "False")
set(CXX_ABI_ENABLED 0)
else()
set(CXX_ABI_ENABLED 1)
endif()

set_property(
SOURCE
${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc
APPEND PROPERTY
COMPILE_OPTIONS
"-D_GLIBCXX_USE_CXX11_ABI=${CXX_ABI_ENABLED}"
"-I${PT_PATH}/include"
)

set_property(
SOURCE
${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/tvm_class.cc
APPEND PROPERTY
COMPILE_OPTIONS
"-I${PT_PATH}/include"
)

set(PT_LINK_FLAGS_STR "-L${PT_PATH}/lib -l:libtorch.so -l:libtorch_python.so")

if(NOT USE_CUDA STREQUAL "OFF")
add_definitions(-DPT_TVMDSOOP_ENABLE_GPU)
endif()


string(REGEX REPLACE "\n" " " PT_FLAGS "${PT_COMPILE_FLAGS} ${PT_LINK_FLAGS}")
separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND ${PT_COMPILE_FLAGS_STR})
separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND)
separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR})

# This old version is deprecated and will be removed after TVM 0.11
set(LIBRARY_OLD_NAME pt_tvmdsoop)

set(LIBRARY_NAME pt_tvmdsoop)
tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/**/*.cc)
add_library(${LIBRARY_NAME} SHARED ${PTTVM_SRCS})
# This new library is built for the PyTorch integration and solves the C++ ABI incompatibility issue
set(LIBRARY_NEW_NAME pt_tvmdsoop_new)
tvm_file_glob(GLOB_RECURSE PTTVM_TORCH ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/*.cc)

tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/*.cc)

add_library(${LIBRARY_OLD_NAME} SHARED ${PTTVM_SRCS})
add_library(${LIBRARY_NEW_NAME} SHARED ${PTTVM_TORCH})
set(PTTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR})

if (NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON")
add_dependencies(${LIBRARY_NAME} tvm)
if(NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON")
add_dependencies(${LIBRARY_OLD_NAME} tvm)
add_dependencies(${LIBRARY_NEW_NAME} tvm)
endif()

target_compile_options(${LIBRARY_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
target_link_libraries(${LIBRARY_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
target_compile_definitions(${LIBRARY_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
target_compile_options(${LIBRARY_OLD_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
target_link_libraries(${LIBRARY_OLD_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
target_compile_definitions(${LIBRARY_OLD_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)

target_compile_options(${LIBRARY_NEW_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
target_link_libraries(${LIBRARY_NEW_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
target_compile_definitions(${LIBRARY_NEW_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
endif()
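
As a side note, the ABI flag that drives the per-source COMPILE_OPTIONS above can be probed directly from Python. A minimal sketch, assuming only an installed torch wheel (this snippet is illustrative, not part of the commit):

import torch

# torch.compiled_with_cxx11_abi() reports whether libtorch was built with the
# new libstdc++ ABI; the CMake logic maps False -> 0 and True -> 1.
abi_enabled = 1 if torch.compiled_with_cxx11_abi() else 0
print(f"-D_GLIBCXX_USE_CXX11_ABI={abi_enabled}")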

25 changes: 21 additions & 4 deletions python/tvm/contrib/torch/__init__.py
@@ -18,11 +18,12 @@
"""Module container of Pytorch custom class"""
import os
import platform
import warnings
import torch
from tvm._ffi import libinfo


def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
def _load_platform_specific_library(lib_name):
system = platform.system()
if system == "Darwin":
lib_file_name = lib_name + ".dylib"
@@ -33,11 +34,27 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
lib_path = libinfo.find_lib_path()[0]
lib_dir = os.path.dirname(lib_path)
lib_file_path = os.path.join(lib_dir, lib_file_name)
torch.classes.load_library(lib_file_path)
try:
torch.classes.load_library(lib_file_path)
except OSError as err:
errmsg = str(err)
if errmsg.find("undefined symbol") != -1:
reason = " ".join(
(
"Got undefined symbol error,",
"which might be due to the CXXABI incompatibility.",
)
)
else:
reason = errmsg
warnings.warn(
f"The library {lib_name} is not built successfully. {reason}",
RuntimeWarning,
)


_load_platform_specific_library()

_load_platform_specific_library("libpt_tvmdsoop")
_load_platform_specific_library("libpt_tvmdsoop_new")

from . import module

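With this change the loader warns instead of raising when a library variant fails to load (for example, an undefined-symbol error from a CXX11 ABI mismatch). A minimal sketch of observing that behavior, assuming a fresh interpreter in which one variant is missing (illustrative, not part of the commit):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import tvm.contrib.torch  # runs _load_platform_specific_library for both variants
for w in caught:
    if issubclass(w.category, RuntimeWarning):
        print("library load skipped:", w.message)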
17 changes: 17 additions & 0 deletions python/tvm/contrib/torch/module.py
@@ -16,7 +16,9 @@
# under the License.
# pylint: disable=invalid-name
"""Module container of PyTorch custom class"""
import warnings
from typing import List

import torch


@@ -29,6 +31,11 @@ def shape_repr(cls, input_shapes):
return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)

def __init__(self, num_inputs, num_outputs, device=None):
warnings.warn(
"This module will be removed at TVM version 0.11",
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.dummy_param = torch.nn.Parameter(torch.empty(0))
self.engine = None
@@ -67,6 +74,11 @@ def shape_repr(cls, input_shapes):
return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)

def __init__(self, num_inputs, num_outputs, device=None):
warnings.warn(
"This module will be removed at TVM version 0.11",
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.dummy_param = torch.nn.Parameter(torch.empty(0))
self.engine = None
@@ -113,6 +125,11 @@ class TraceTvmModule(torch.nn.Module):
"""

def __init__(self, tvm_module):
warnings.warn(
"This module will be removed at TVM version 0.11",
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.tvm_module = tvm_module

21 changes: 21 additions & 0 deletions python/tvm/contrib/torch/pytorch_tvm.py
@@ -19,6 +19,7 @@
# pylint: disable=redefined-builtin
"""`compile` api that convert torch module to torch tvm module"""
import os
import warnings
import tvm
import tvm.testing
from tvm import relay, autotvm
@@ -183,6 +184,16 @@ def load_tvm(self, export_dir):

def build_pytorch_module(self, num_inputs, num_outputs, input_infos=None):
"""Build pytorch module containing TVM Graph Module"""
warnings.warn(
" ".join(
(
"This function will be removed at TVM version 0.11,",
"we suggest users to use `optimized_torch` for tuning Torch modules instead.",
)
),
DeprecationWarning,
stacklevel=2,
)
assert self.export_dir, "you must build_tvm or load_tvm before"
input_infos = input_infos or self.input_infos
assert input_infos
@@ -224,6 +235,16 @@ def compile(script_module, option):
pytorch_tvm_module = compile(script_module, option)
pytorch_tvm_module("model_tvm.pt")
"""
warnings.warn(
" ".join(
(
"This function will be removed at TVM version 0.11,",
"we suggest users to use `optimized_torch` for tuning Torch modules instead.",
)
),
DeprecationWarning,
stacklevel=2,
)
input_infos = option["input_infos"]
default_dtype = option.get("default_dtype", "float32")
export_dir = option.get("export_dir", "pytorch_compiled")
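Both deprecation messages steer users toward optimize_torch. A minimal sketch of the suggested replacement, with the call signature assumed from the tests in apps/pt_tvmdsoop/tests (illustrative, not part of the commit; `double` is a hypothetical example function):

import torch
from tvm.contrib.torch import optimize_torch

def double(x):
    return x * 2

example = torch.rand(8)
tuned = optimize_torch(double, example)  # tune with TVM; returns a torch-callable module
torch.save(tuned, "tuned_double.pt")  # round-trips via torch.save/torch.load, as the tests exercise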
(Diffs for the remaining changed files are not shown.)
