Skip to content

Commit

Permalink
[AMP][Custom Device] add fp16 op detection for custom device (#56053)
Browse files Browse the repository at this point in the history
* [AMP] add fp16 op detection for custom device

* resolve conflicts
  • Loading branch information
jinyouzhi authored Aug 14, 2023
1 parent 066097e commit 9d40da3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
15 changes: 9 additions & 6 deletions paddle/fluid/imperative/amp_auto_cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,16 @@ OpSupportedInfos(const std::string& place,
{"CPU", &platform::is_cpu_place},
{"XPU", &platform::is_xpu_place},
{"CUSTOM_DEVICE", &platform::is_custom_place},
#ifdef PADDLE_WITH_CUSTOM_DEVICE
{query_place, &platform::is_custom_place},
#endif
};
PADDLE_ENFORCE_NE(
is_target_place.count(query_place),
0,
platform::errors::InvalidArgument(
"The argument `place` should be 'GPU', 'CPU', 'XPU', but got '%s'.",
place));
PADDLE_ENFORCE_NE(is_target_place.count(query_place),
0,
platform::errors::InvalidArgument(
"The argument `place` should be 'GPU', 'CPU', 'XPU' or "
"other Custom Device, but got '%s'.",
place));

std::unordered_set<std::string> all_ops;
const auto& op_info = framework::OpInfoMap::Instance().map();
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/static/amp/fp16_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _get_sys_unsupported_list(dtype):
elif isinstance(
paddle.framework._current_expected_place(), paddle.CustomPlace
):
device = 'CUSTOM_DEVICE'
device = paddle.framework._current_expected_place().get_device_type()
else:
device = 'GPU'
all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
Expand Down

0 comments on commit 9d40da3

Please sign in to comment.