[Cherry-Pick][JITLayer] Enable OneDNN on CPU and Fix zero shape (#47428) #47436

Merged: 1 commit, merged on Oct 29, 2022
paddle/fluid/inference/api/analysis_predictor.cc (5 changes: 3 additions & 2 deletions)
```diff
@@ -171,8 +171,9 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt,
   // NOTE(Aurelius84): Some kernels support zero shape input
   // without memory holder, we should skip enforce logic.
   bool has_zero_dim = (phi::product(ddim) == 0);
-  if (has_zero_dim) {
-    VLOG(3) << "Found zero dim from input with ddim: " << ddim;
+  VLOG(3) << "Found zero dim: " << has_zero_dim
+          << " from input with ddim: " << ddim;
+  if (!has_zero_dim) {
     PADDLE_ENFORCE_NOT_NULL(
         input_ptr,
         paddle::platform::errors::Fatal(
```
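Per the NOTE above, some kernels accept zero-shape inputs that carry no memory holder, so the hunk flips the guard: the `PADDLE_ENFORCE_NOT_NULL` check now runs only for non-empty inputs instead of only for empty ones. A standalone sketch of that invariant (plain C++, not the Paddle API; the shape and names are illustrative):

```cpp
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // A shape containing a 0 describes an empty tensor: numel == 0.
  std::vector<int64_t> shape = {8, 0, 16};
  int64_t numel = std::accumulate(shape.begin(), shape.end(),
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t>());
  bool has_zero_dim = (numel == 0);

  // An empty tensor owns no allocation, so a null data pointer is legal;
  // the null check must apply only when the tensor is non-empty.
  const void *data = nullptr;
  if (!has_zero_dim && data == nullptr) {
    std::puts("error: non-empty input without a memory holder");
    return 1;
  }
  std::puts("zero-shape input accepted without a memory holder");
  return 0;
}
```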
paddle/fluid/jit/engine/predictor_engine.cc (68 changes: 37 additions & 31 deletions)
```diff
@@ -34,12 +34,16 @@ PredictorEngine::PredictorEngine(const std::shared_ptr<FunctionInfo> &info,
   utils::ShareParamsIntoScope(info_->ParamNames(), params_dict, scope_.get());
   VLOG(6) << framework::GenScopeTreeDebugInfo(scope_.get());

   // TODO(Aurelius84): Expose AnalysisConfig to user.
   AnalysisConfig config;
   config.SetProgFile(info->ProgramFilePath());
   if (platform::is_gpu_place(place_)) {
     config.EnableUseGpu(100, place_.GetDeviceId());
   } else if (platform::is_cpu_place(place_)) {
+    config.DisableGpu();
+    config.EnableMKLDNN();
+    config.EnableMkldnnInt8();
+    config.SetMkldnnCacheCapacity(0);
   }
   config.SetSkipLoadParams(true);
   config.SetApplyOptim(true);
```
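For context, the constructor now enables OneDNN whenever the place is a CPU. A hedged sketch of the equivalent user-side setup (the include path, helper name, and model path are illustrative; the config calls are the ones added above):

```cpp
#include <string>

#include "paddle_inference_api.h"  // public Paddle Inference header

// Build an AnalysisConfig for CPU execution with OneDNN enabled, mirroring
// the branch added in this PR.
paddle::AnalysisConfig MakeCpuOneDnnConfig(const std::string &prog_file) {
  paddle::AnalysisConfig config;
  config.SetProgFile(prog_file);     // program file (e.g. *.pdmodel)
  config.DisableGpu();               // force the CPU execution path
  config.EnableMKLDNN();             // enable OneDNN (MKL-DNN) CPU kernels
  config.EnableMkldnnInt8();         // additionally allow INT8 OneDNN kernels
  config.SetMkldnnCacheCapacity(0);  // same cache-capacity value the PR sets
  return config;
}
```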
```diff
@@ -59,10 +63,6 @@ std::vector<Tensor> PredictorEngine::operator()(

 std::vector<DenseTensor> PredictorEngine::operator()(
     const std::vector<DenseTensor> &inputs) {
-  for (auto t : inputs) {
-    VLOG(1) << "inputs is init: " << t.initialized();
-  }
-
   std::vector<PaddleTensor> pt_inputs;
   std::vector<PaddleTensor> pt_outputs;
   for (auto &t : inputs) {
@@ -84,22 +84,23 @@ std::vector<DenseTensor> PredictorEngine::operator()(

 static PaddleTensor DenseTensorToPaddleTensor(DenseTensor *t) {
   PaddleTensor pt;
-
-  if (framework::TransToProtoVarType(t->dtype()) ==
-      framework::proto::VarType::INT32) {
-    pt.data.Reset(t->data(), t->numel() * sizeof(int32_t));
-    pt.dtype = PaddleDType::INT32;
-  } else if (framework::TransToProtoVarType(t->dtype()) ==
-             framework::proto::VarType::INT64) {
-    pt.data.Reset(t->data(), t->numel() * sizeof(int64_t));
-    pt.dtype = PaddleDType::INT64;
-  } else if (framework::TransToProtoVarType(t->dtype()) ==
-             framework::proto::VarType::FP32) {
-    pt.data.Reset(t->data(), t->numel() * sizeof(float));
-    pt.dtype = PaddleDType::FLOAT32;
-  } else {
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "Unsupported tensor date type. Now only supports INT64, FP32, INT32."));
+  switch (framework::TransToProtoVarType(t->dtype())) {
+    case framework::proto::VarType::INT32: {
+      pt.data.Reset(t->data(), t->numel() * sizeof(int32_t));
+      pt.dtype = PaddleDType::INT32;
+    } break;
+    case framework::proto::VarType::INT64: {
+      pt.data.Reset(t->data(), t->numel() * sizeof(int64_t));
+      pt.dtype = PaddleDType::INT64;
+    } break;
+    case framework::proto::VarType::FP32: {
+      pt.data.Reset(t->data(), t->numel() * sizeof(float));
+      pt.dtype = PaddleDType::FLOAT32;
+    } break;
+    default:
+      PADDLE_THROW(
+          platform::errors::Unimplemented("Unsupported tensor date type. Now "
+                                          "only supports INT64, FP32, INT32."));
   }
   pt.shape = phi::vectorize<int>(t->dims());
   return pt;
```
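The hunk above and the one below replace repeated dtype comparisons with a single switch, so each supported type is dispatched in one place and unsupported types fail through one default branch. A minimal standalone analog of the pattern (not the Paddle API; the enum and names are illustrative):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <stdexcept>

enum class DType { INT32, INT64, FP32, FP16 };

// Switch-based dispatch: one case per supported dtype, one failure path.
size_t ElementSize(DType dtype) {
  switch (dtype) {
    case DType::INT32:
      return sizeof(int32_t);
    case DType::INT64:
      return sizeof(int64_t);
    case DType::FP32:
      return sizeof(float);
    default:  // e.g. FP16 here: unsupported types all land in one place
      throw std::invalid_argument("unsupported tensor data type");
  }
}

int main() {
  std::printf("INT64 element size: %zu\n", ElementSize(DType::INT64));
  return 0;
}
```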
```diff
@@ -110,17 +111,22 @@ static bool PaddleTensorToDenseTensor(const PaddleTensor &pt,
                                       const platform::Place &place) {
   framework::DDim ddim = phi::make_ddim(pt.shape);
   void *input_ptr;
-  if (pt.dtype == PaddleDType::INT64) {
-    input_ptr = t->mutable_data<int64_t>(ddim, place);
-  } else if (pt.dtype == PaddleDType::FLOAT32) {
-    input_ptr = t->mutable_data<float>(ddim, place);
-  } else if (pt.dtype == PaddleDType::INT32) {
-    input_ptr = t->mutable_data<int32_t>(ddim, place);
-  } else if (pt.dtype == PaddleDType::FLOAT16) {
-    input_ptr = t->mutable_data<float16>(ddim, place);
-  } else {
-    LOG(ERROR) << "unsupported feed type " << pt.dtype;
-    return false;
+  switch (pt.dtype) {
+    case PaddleDType::INT64:
+      input_ptr = t->mutable_data<int64_t>(ddim, place);
+      break;
+    case PaddleDType::FLOAT32:
+      input_ptr = t->mutable_data<float>(ddim, place);
+      break;
+    case PaddleDType::INT32:
+      input_ptr = t->mutable_data<int32_t>(ddim, place);
+      break;
+    case PaddleDType::FLOAT16:
+      input_ptr = t->mutable_data<float16>(ddim, place);
+      break;
+    default:
+      LOG(ERROR) << "unsupported feed type " << pt.dtype;
+      return false;
   }

   PADDLE_ENFORCE_NOT_NULL(
```