Skip to content

Commit

Permalink
[security] refine _get_program_cache_key (PaddlePaddle#61827) (Paddle…
Browse files Browse the repository at this point in the history
…Paddle#61896)

* security, refine _get_program_cache_key
  • Loading branch information
wanghuancoder authored Feb 21, 2024
1 parent 39010bf commit b6a38d0
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions python/paddle/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,15 +682,19 @@ def _get_varname_from_block(block):
)


def _get_program_cache_key(feed, fetch_list):
def _get_feed_fetch_var_names(feed, fetch_list):
feed_var_names = []
if isinstance(feed, dict):
feed_var_names = list(feed.keys())
elif isinstance(feed, (list, tuple)):
for i, each in enumerate(feed):
feed_var_names += list(each.keys())
fetch_var_names = list(map(_to_name_str, fetch_list))
return str(feed_var_names + fetch_var_names)
return feed_var_names + fetch_var_names


def _get_program_cache_key(feed, fetch_list):
return str(_get_feed_fetch_var_names(feed, fetch_list))


def _as_lodtensor(data, place, dtype=None):
Expand Down Expand Up @@ -1026,7 +1030,7 @@ def _get_program_and_executor(self, cached_data):

if enable_inplace or enable_addto:
# inplace should skip feed and fetch var
skip_var_names = eval(_get_program_cache_key(feed, fetch_list))
skip_var_names = _get_feed_fetch_var_names(feed, fetch_list)
_apply_inplace_addto_pass(
program, enable_inplace, enable_addto, skip_var_names
)
Expand Down

0 comments on commit b6a38d0

Please sign in to comment.