Skip to content

Commit

Permalink
[CINN]Enhance CacheKey hash logic by considering input dtypes (Paddle…
Browse files Browse the repository at this point in the history
…Paddle#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>
  • Loading branch information
2 people authored and cxxly committed Mar 13, 2023
1 parent 4a48497 commit 7e66b47
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 9 deletions.
24 changes: 23 additions & 1 deletion python/paddle/fluid/tests/unittests/autograd/test_primapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
import autograd.scipy as ascipy
import config
import numpy as np
import parameterized as param
import utils

import paddle
from paddle.incubate.autograd import primx
from paddle.fluid import core
from paddle.incubate.autograd import primapi, primx


@utils.place(config.DEVICES)
Expand Down Expand Up @@ -1034,5 +1036,25 @@ def actual():
np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol)


class TestToPrim(unittest.TestCase):
def setUp(self):
paddle.enable_static()
core._set_prim_forward_enabled(True)

def tearDown(self):
core._set_prim_forward_enabled(False)
paddle.disable_static()

@param.parameterized((('dropout',),))
def test_exclude(self, exclude):
program = paddle.static.Program()
with paddle.static.program_guard(program):
x = paddle.rand((1,))
y = paddle.nn.functional.dropout(x)
primapi.to_prim(program, exclude)
ops = tuple(op.type for op in program.block(0).ops)
self.assertTrue(all(tuple(op in ops for op in exclude)))


if __name__ == '__main__':
unittest.main()
19 changes: 12 additions & 7 deletions python/paddle/incubate/autograd/primapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,16 @@ def grad(outputs, inputs, grad_outputs=None):


@framework.static_only
def to_prim(blocks):
"""Search nonbasic ops which have be registered composite rules and replace them with primitive ops."""
def to_prim(blocks, exclude=frozenset()):
"""Search nonbasic ops which have be registered composite rules and replace them with primitive ops.
Args:
exclude(frozenset): The Operators that will be exclude in lowering.
"""
if not core._is_fwd_prim_enabled():
return
if isinstance(blocks, paddle.fluid.framework.Block):
logging.info("Atomize composite op to primitive ops begin.")
logging.debug("Atomize composite op to primitive ops begin.")
main_program = blocks.program
elif isinstance(blocks, typing.Sequence):
for item in blocks:
Expand All @@ -236,8 +240,9 @@ def to_prim(blocks):
f"Expect block or sequence of blocks, but got {type(blocks)}."
)
with framework.program_guard(main_program):
print("Lowering composite forward ops begin...")
primx._lower_composite(blocks, prim_config["forward_blacklist"])
logging.debug("Lowering composite forward ops begin...")
primx._lower_composite(
blocks, prim_config["forward_blacklist"] | exclude
)
replace_ops = prim_config["composite_ops_record"]
print(f"Lowering composite forward ops finish: {replace_ops}")
return
logging.debug(f"Lowering composite forward ops finish: {replace_ops}")
2 changes: 1 addition & 1 deletion python/paddle/incubate/autograd/primx.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def expand_nested_list(xs):
block._sync_with_cpp()


def _lower_composite(block, blacklist=[]):
def _lower_composite(block, blacklist=frozenset()):
# Some functions which are only used in _lower.
def bind(args, to_bind, value_table):
for i in range(len(args)):
Expand Down

0 comments on commit 7e66b47

Please sign in to comment.