Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT] update SIR export check, add _reconstruct for TensorDtypeVariable #61822

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,20 @@ def get_py_type(self):
return paddle.pir.core.DataType
return super().get_py_type()

def _reconstruct(self, codegen: PyCodeGen):
# dtype of paddle.Tensor is hashable, we can just load it as const var
if use_pir_api() and isinstance(
self.value, paddle.base.core.VarDesc.VarType
):
assert (
self.value in paddle.pir.core.vartype_to_datatype
), f"Unknow dtype {self.value}"
codegen.gen_load_const(
paddle.pir.core.vartype_to_datatype[self.value]
)
else:
codegen.gen_load_const(self.value)

@property
def main_info(self) -> dict[str, Any]:
return {
Expand Down
42 changes: 19 additions & 23 deletions python/paddle/jit/sot/symbolic/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
# limitations under the License.

import os
import sys
from itertools import chain

import paddle
from paddle.utils import flatten

from ..utils import ConstTypes, ExportError, NameGenerator
from ..utils import ConstTypes, ExportError, NameGenerator, get_api_fullname
from .statement_ir import Symbol


Expand Down Expand Up @@ -138,22 +137,27 @@ def gen_py_codes(self):
self.create_tail()
return self.roots_to_string()

def is_exportable_type(self, value):
if (
isinstance(value, (ConstTypes, Symbol, paddle.dtype))
or value is Ellipsis # NOINT
):
return True
if isinstance(value, slice):
return (
self.is_exportable_type(value.start)
and self.is_exportable_type(value.stop)
and self.is_exportable_type(value.step)
)
return False

def check_exportable(self):
for stmt in self.SIR.statements:
for inp in flatten(stmt.inputs):
if not isinstance(inp, ConstTypes) and not isinstance(
inp, Symbol
):
if not self.is_exportable_type(inp):
raise ExportError(
f"Not support create python file with input: {inp}"
)
for out in flatten(stmt.outputs):
if not isinstance(out, ConstTypes) and not isinstance(
out, Symbol
):
raise ExportError(
f"Not support create python file with output: {out}"
)

def create_header(self):
self.new_root(
Expand Down Expand Up @@ -341,20 +345,12 @@ def create_stmt_line(self, stmt):
return getattr(self, "create_" + stmt.type + "_stmt")(stmt)

def create_api_stmt(self, stmt):
def get_api_str(api):
api_name = api.__name__
module_str = api.__module__
while len(module_str) > 0:
module = sys.modules[module_str]
if hasattr(module, api_name):
return module_str + "." + api_name
module_str = module_str.rpartition(".")[0]
raise ExportError(f"Can not find module of {api}")

args, kwargs = stmt.inputs
input_str = self.create_input_string(args, kwargs)
api = stmt.api
api_str = get_api_str(api)
api_str = get_api_fullname(api)
if api_str is None:
raise ExportError(f"Can not find module of {api}")
if isinstance(stmt.outputs, Symbol):
return [f"{self.name_gener(stmt.outputs)} = {api_str}({input_str})"]
else:
Expand Down
9 changes: 5 additions & 4 deletions python/paddle/jit/sot/symbolic/statement_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from paddle.jit.utils import OrderedSet
from paddle.utils import flatten, map_structure

from ..utils import NameGenerator, Singleton, flatten_extend
from ..utils import NameGenerator, Singleton, flatten_extend, get_api_fullname


class Reference: # to unify weak_ref and strong_ref
Expand Down Expand Up @@ -135,9 +135,10 @@ def __init__(
outputs: list[Symbol],
stacks: list[str],
):
super().__init__(
"api", api.__module__ + "." + api.__name__, inputs, outputs, stacks
)
fullname = get_api_fullname(api)
if fullname is None:
fullname = "paddle." + api.__name__
super().__init__("api", fullname, inputs, outputs, stacks)
self.api = api


Expand Down
1 change: 1 addition & 0 deletions python/paddle/jit/sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
execute_time,
flatten,
flatten_extend,
get_api_fullname,
get_unbound_method,
hashable,
in_paddle_module,
Expand Down
12 changes: 12 additions & 0 deletions python/paddle/jit/sot/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import builtins
import inspect
import sys
import time
import types
import weakref
Expand Down Expand Up @@ -511,3 +512,14 @@ def clear(self):
self.step_record.clear()
self.current_code = None
self.current_step = -1


def get_api_fullname(api):
api_name = api.__name__
module_str = api.__module__
while len(module_str) > 0:
module = sys.modules[module_str]
if hasattr(module, api_name):
return module_str + "." + api_name
module_str = module_str.rpartition(".")[0]
return None
1 change: 1 addition & 0 deletions test/sot/skip_files_py312
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@
./test_specialization.py
./test_str_format.py
./test_tensor_dtype_in_guard.py
./test_dtype.py
./test_builtin_bool.py
16 changes: 16 additions & 0 deletions test/sot/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def tensor_dtype_guard(x):
return x + 1


def reconstruct_dtype():
x = paddle.to_tensor(1)
z = x.dtype
if x > 0:
y = paddle.to_tensor(1, dtype=z)
else:
y = 1
return y


class TestTensorAstype(TestCaseBase):
@run_in_both_default_and_pir
def test_tensor_astype(self):
Expand All @@ -54,5 +64,11 @@ def test_tensor_dtype_guard(self):
self.assertEqual(ctx.translate_count, 2)


class TestDtypeReconstruct(TestCaseBase):
@run_in_both_default_and_pir
def test_dtype_reconstruct(self):
self.assert_results(reconstruct_dtype)


if __name__ == "__main__":
unittest.main()