Replace framework::proto::VarType part #64388
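
This PR is one step in the incremental migration off the legacy framework::proto::VarType protobuf enum and onto phi::DataType: every hunk below drops a framework::TransToProtoVarType round trip and compares the tensor's dtype() directly. The recurring rewrite, extracted from the diffs (tensor stands for any phi::DenseTensor):

// before: convert to the protobuf enum, then compare
if (framework::TransToProtoVarType(tensor->dtype()) ==
    framework::proto::VarType::FP32) { ... }

// after: compare phi::DataType values directly
if (tensor->dtype() == phi::DataType::FLOAT32) { ... }

The enum correspondences exercised in this PR are FP32 -> FLOAT32, BF16 -> BFLOAT16, and INT8/INT16/INT32/INT64 -> the same names.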

Merged: 2 commits, May 20, 2024
6 changes: 2 additions & 4 deletions paddle/fluid/operators/fused/onednn/fusion_lstm_onednn_op.cc

@@ -398,15 +398,13 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> {
     std::shared_ptr<dnnl::memory> h0_memory_p, weight_h_memory_p,
         weight_x_memory_p;
 
-    if (framework::TransToProtoVarType(weight_h->dtype()) ==
-        paddle::framework::proto::VarType_Type_FP32) {
+    if (weight_h->dtype() == phi::DataType::FLOAT32) {
       h0_memory_p = handler.template AcquireH0Memory<float>(h0);
       weight_x_memory_p =
           handler.template AcquireWeightXMemory<float>(weight_x);
       weight_h_memory_p =
           handler.template AcquireWeightHMemory<float>(weight_h);
-    } else if (framework::TransToProtoVarType(weight_h->dtype()) ==
-               paddle::framework::proto::VarType_Type_BF16) {
+    } else if (weight_h->dtype() == phi::DataType::BFLOAT16) {
       h0_memory_p = handler.template AcquireH0Memory<phi::dtype::bfloat16>(h0);
       weight_x_memory_p =
           handler.template AcquireWeightXMemory<phi::dtype::bfloat16>(weight_x);
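
The hunk above is the common dispatch idiom in these kernels: a runtime dtype tag selects a compile-time template parameter. A minimal standalone sketch of that idiom, where DataType and bfloat16_t are stand-ins for phi::DataType and phi::dtype::bfloat16, and AcquireWeightMemory is a hypothetical helper:

#include <cstdint>
#include <iostream>

enum class DataType { FLOAT32, BFLOAT16 };
struct bfloat16_t { uint16_t bits; };  // 2-byte stand-in for phi::dtype::bfloat16

// Hypothetical helper: the template parameter fixes the element type the
// memory descriptors would be built around.
template <typename T>
void AcquireWeightMemory() {
  std::cout << "element size: " << sizeof(T) << " bytes\n";
}

// Runtime dtype picks the compile-time instantiation, mirroring the
// FLOAT32 / BFLOAT16 branches in the hunk above.
void DispatchByDtype(DataType dtype) {
  if (dtype == DataType::FLOAT32) {
    AcquireWeightMemory<float>();
  } else if (dtype == DataType::BFLOAT16) {
    AcquireWeightMemory<bfloat16_t>();
  }
}

int main() {
  DispatchByDtype(DataType::FLOAT32);   // element size: 4 bytes
  DispatchByDtype(DataType::BFLOAT16);  // element size: 2 bytes
}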
14 changes: 5 additions & 9 deletions paddle/fluid/operators/fused/resnet_basic_block_op.cc

@@ -225,26 +225,22 @@ class ResNetBasicBlockOp : public framework::OperatorWithKernel {
 
     // By default, the type of the scale, bias, mean,
     // and var tensors should be float when input tensor's dtype is float16.
-    auto bn_param_type = framework::proto::VarType::FP32;
+    auto bn_param_type = phi::DataType::FLOAT32;
     PADDLE_ENFORCE_EQ(
         bn_param_type,
-        framework::TransToProtoVarType(
-            ctx.Input<phi::DenseTensor>("Scale1")->dtype()),
+        ctx.Input<phi::DenseTensor>("Scale1")->dtype(),
         phi::errors::InvalidArgument("Scale input should be of float type"));
     PADDLE_ENFORCE_EQ(
         bn_param_type,
-        framework::TransToProtoVarType(
-            ctx.Input<phi::DenseTensor>("Bias1")->dtype()),
+        ctx.Input<phi::DenseTensor>("Bias1")->dtype(),
         phi::errors::InvalidArgument("Bias input should be of float type"));
     PADDLE_ENFORCE_EQ(
         bn_param_type,
-        framework::TransToProtoVarType(
-            ctx.Input<phi::DenseTensor>("Scale2")->dtype()),
+        ctx.Input<phi::DenseTensor>("Scale2")->dtype(),
         phi::errors::InvalidArgument("Scale input should be of float type"));
     PADDLE_ENFORCE_EQ(
         bn_param_type,
-        framework::TransToProtoVarType(
-            ctx.Input<phi::DenseTensor>("Bias2")->dtype()),
+        ctx.Input<phi::DenseTensor>("Bias2")->dtype(),
         phi::errors::InvalidArgument("Bias input should be of float type"));
 
     return phi::KernelKey(input_data_type, ctx.GetPlace());
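
These checks pin the batch-norm parameter dtype to FLOAT32 even when the op's main input is FLOAT16, and after the rewrite both sides of each comparison are already phi::DataType, so no conversion wrapper is needed (resnet_unit_op.cc below repeats the same pattern for ScaleX/BiasX). A minimal standalone sketch of the check, with EnforceEq standing in for PADDLE_ENFORCE_EQ and DataType for phi::DataType:

#include <stdexcept>
#include <string>

enum class DataType { FLOAT16, FLOAT32 };  // stand-in for phi::DataType

// Stand-in for PADDLE_ENFORCE_EQ(a, b, error): fail loudly when a != b.
void EnforceEq(DataType a, DataType b, const std::string& msg) {
  if (a != b) throw std::invalid_argument(msg);
}

// Mirrors the checks above: scale/bias parameters must be FLOAT32.
void CheckBnParam(DataType param_dtype) {
  const auto bn_param_type = DataType::FLOAT32;
  EnforceEq(bn_param_type, param_dtype,
            "Scale/Bias input should be of float type");
}

int main() {
  CheckBnParam(DataType::FLOAT32);  // passes
  // CheckBnParam(DataType::FLOAT16);  // would throw
}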
8 changes: 3 additions & 5 deletions paddle/fluid/operators/fused/resnet_unit_op.cc

@@ -206,17 +206,15 @@ class ResNetUnitOp : public framework::OperatorWithKernel {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
     // By default, the type of the scale, bias, mean,
     // and var tensors should be float when input tensor's dtype is float16.
-    auto bn_param_type = framework::proto::VarType::FP32;
+    auto bn_param_type = phi::DataType::FLOAT32;
 
     PADDLE_ENFORCE_EQ(
         bn_param_type,
-        framework::TransToProtoVarType(
-            ctx.Input<phi::DenseTensor>("ScaleX")->dtype()),
+        ctx.Input<phi::DenseTensor>("ScaleX")->dtype(),
         phi::errors::InvalidArgument("Scale input should be of float type"));
     PADDLE_ENFORCE_EQ(
         bn_param_type,
-        framework::TransToProtoVarType(
-            ctx.Input<phi::DenseTensor>("BiasX")->dtype()),
+        ctx.Input<phi::DenseTensor>("BiasX")->dtype(),
         phi::errors::InvalidArgument("Bias input should be of float type"));
     return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
3 changes: 1 addition & 2 deletions paddle/fluid/operators/lod_tensor_to_array_op.cc

@@ -80,8 +80,7 @@ struct LoDTensorToArrayFunctor {
     LoDTensorToArrayFunctorImpl<DeviceContext> func;
     func.prev_functor_ = this;
     func.dev_ctx_ = dev_ctx;
-    framework::VisitDataType(framework::TransToProtoVarType(input_.dtype()),
-                             func);
+    phi::VisitDataType(input_.dtype(), func);
   }
 };
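
Here the functor no longer needs its dtype translated to the proto enum: phi::VisitDataType takes phi::DataType directly. A standalone sketch of the visitor pattern, assuming the phi version dispatches to visitor.template apply<T>() as the framework version did (DataType is a stand-in; only two types shown, where the real dispatcher covers the full dtype list):

#include <cstdint>
#include <iostream>

enum class DataType { FLOAT32, INT64 };  // stand-in for phi::DataType

// Maps a runtime dtype to visitor.template apply<T>() with the matching
// static type.
template <typename Visitor>
void VisitDataType(DataType dtype, Visitor visitor) {
  switch (dtype) {
    case DataType::FLOAT32: visitor.template apply<float>(); break;
    case DataType::INT64:   visitor.template apply<int64_t>(); break;
  }
}

struct PrintElementSize {
  template <typename T>
  void apply() const { std::cout << sizeof(T) << " bytes\n"; }
};

int main() {
  VisitDataType(DataType::FLOAT32, PrintElementSize{});  // 4 bytes
  VisitDataType(DataType::INT64, PrintElementSize{});    // 8 bytes
}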
15 changes: 7 additions & 8 deletions paddle/fluid/operators/lookup_table_op.h

@@ -90,8 +90,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
       int64_t row_width = table_t.value().dims()[1];
       const auto *table = table_t.value().data<T>();
       auto *output = output_t->mutable_data<T>(context.GetPlace());
-      auto input_data_type =
-          framework::TransToProtoVarType(table_t.value().dtype());
+      auto input_data_type = table_t.value().dtype();
       for (int64_t i = 0; i < ids_numel; ++i) {
         if (padding_idx != kNoPadding && ids[i] == padding_idx) {
           memset(output + i * row_width, 0, row_width * sizeof(T));

@@ -107,9 +106,9 @@
           auto id_index = table_t.GetIndexFromId(ids[i]);
 
           if (id_index != -1) {
-            if (input_data_type == framework::proto::VarType::INT8 ||
-                input_data_type == framework::proto::VarType::INT16 ||
-                input_data_type == framework::proto::VarType::BF16) {
+            if (input_data_type == phi::DataType::INT8 ||
+                input_data_type == phi::DataType::INT16 ||
+                input_data_type == phi::DataType::BFLOAT16) {
              memcpy(output + i * row_width,
                     table + id_index * row_width,
                     row_width * sizeof(T));

@@ -140,9 +139,9 @@
                 "the input key should be exists. But received %d.",
                 id_index));
 
-          if (input_data_type == framework::proto::VarType::INT8 ||
-              input_data_type == framework::proto::VarType::INT16 ||
-              input_data_type == framework::proto::VarType::BF16) {
+          if (input_data_type == phi::DataType::INT8 ||
+              input_data_type == phi::DataType::INT16 ||
+              input_data_type == phi::DataType::BFLOAT16) {
            memcpy(output + i * row_width,
                   table + id_index * row_width,
                   row_width * sizeof(T));
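
In both hunks the INT8/INT16/BFLOAT16 rows are copied with a raw memcpy, while other dtypes take a branch the diff elides (presumably a typed or BLAS-backed copy). A standalone sketch of the dtype-gated row copy, with DataType standing in for phi::DataType and a plain loop standing in for the elided branch:

#include <cstdint>
#include <cstring>

enum class DataType { INT8, INT16, BFLOAT16, FLOAT32 };  // stand-in

// Copies one row of width row_width; the three listed dtypes take the
// raw-memcpy path, everything else takes the typed-copy stand-in.
template <typename T>
void CopyRow(DataType dtype, T* dst, const T* src, int64_t row_width) {
  if (dtype == DataType::INT8 || dtype == DataType::INT16 ||
      dtype == DataType::BFLOAT16) {
    std::memcpy(dst, src, row_width * sizeof(T));
  } else {
    for (int64_t j = 0; j < row_width; ++j) dst[j] = src[j];  // typed copy
  }
}

int main() {
  int8_t src[4] = {1, 2, 3, 4}, dst[4] = {};
  CopyRow(DataType::INT8, dst, src, 4);
}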
17 changes: 8 additions & 9 deletions paddle/fluid/operators/optimizers/sparse_momentum_op.h

@@ -304,11 +304,11 @@ class SparseMomentumOpKernel : public framework::OpKernel<T> {
     const bool multi_precision = ctx.Attr<bool>("multi_precision");
     bool use_nesterov = ctx.Attr<bool>("use_nesterov");
     auto index = ctx.Input<phi::DenseTensor>("Index");
-    const auto& index_type = framework::TransToProtoVarType(index->dtype());
+    const auto& index_type = index->dtype();
     if (multi_precision) {
       if (use_nesterov) {
         auto update_method = UseNesterov<MPDType>();
-        if (index_type == framework::proto::VarType::INT32) {
+        if (index_type == phi::DataType::INT32) {
           InnerCompute<MPDType, int, UseNesterov<MPDType>>(
               ctx, multi_precision, update_method);
         } else {

@@ -317,7 +317,7 @@
         }
       } else {
         auto update_method = NoNesterov<MPDType>();
-        if (index_type == framework::proto::VarType::INT32) {
+        if (index_type == phi::DataType::INT32) {
           InnerCompute<MPDType, int, NoNesterov<MPDType>>(
               ctx, multi_precision, update_method);
         } else {

@@ -328,7 +328,7 @@
     } else {
       if (use_nesterov) {
         auto update_method = UseNesterov<T>();
-        if (index_type == framework::proto::VarType::INT32) {
+        if (index_type == phi::DataType::INT32) {
           InnerCompute<T, int, UseNesterov<T>>(
               ctx, multi_precision, update_method);
         } else {

@@ -337,7 +337,7 @@
       } else {
         auto update_method = NoNesterov<T>();
-        if (index_type == framework::proto::VarType::INT32) {
+        if (index_type == phi::DataType::INT32) {
           InnerCompute<T, int, NoNesterov<T>>(
               ctx, multi_precision, update_method);
         } else {

@@ -371,11 +371,10 @@
       phi::DenseTensor cpu_axis;
       const phi::DenseTensor* axis_tensor = ctx.Input<phi::DenseTensor>("Axis");
       framework::TensorCopy(*axis_tensor, phi::CPUPlace(), &cpu_axis);
-      const auto& axis_type =
-          framework::TransToProtoVarType(axis_tensor->dtype());
-      if (axis_type == framework::proto::VarType::INT32) {
+      const auto& axis_type = axis_tensor->dtype();
+      if (axis_type == phi::DataType::INT32) {
         axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
-      } else if (axis_type == framework::proto::VarType::INT64) {
+      } else if (axis_type == phi::DataType::INT64) {
         axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
       }
     } else {
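
The last hunk reads a scalar axis out of a tensor that may hold int32 or int64, after copying it to the CPU. A standalone sketch of that read, where ScalarTensor is a hypothetical stand-in for a CPU-resident phi::DenseTensor holding one element:

#include <cstdint>

enum class DataType { INT32, INT64 };  // stand-in for phi::DataType

// Hypothetical stand-in for a one-element CPU tensor.
struct ScalarTensor {
  DataType dtype;
  const void* data;
};

// Mirrors the axis read above: the dtype decides which typed pointer to
// dereference before narrowing to int.
int ReadAxis(const ScalarTensor& t) {
  if (t.dtype == DataType::INT32) {
    return static_cast<int>(*static_cast<const int32_t*>(t.data));
  }
  return static_cast<int>(*static_cast<const int64_t*>(t.data));  // INT64
}

int main() {
  int64_t raw = 2;
  ScalarTensor axis{DataType::INT64, &raw};
  return ReadAxis(axis) == 2 ? 0 : 1;
}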
15 changes: 5 additions & 10 deletions paddle/fluid/operators/uniform_random_op.h

@@ -33,8 +33,7 @@ namespace operators {
 
 inline std::vector<int64_t> GetNewDataFromShapeTensor(
     const phi::DenseTensor* new_data_tensor) {
-  if (framework::TransToProtoVarType(new_data_tensor->dtype()) ==
-      framework::proto::VarType::INT64) {
+  if (new_data_tensor->dtype() == phi::DataType::INT64) {
     auto* new_data = new_data_tensor->data<int64_t>();
     phi::DenseTensor cpu_starts_tensor;
     if (new_data_tensor->place().GetType() == phi::AllocationType::GPU) {

@@ -45,8 +44,7 @@
     std::vector<int64_t> vec_new_data(new_data,
                                       new_data + new_data_tensor->numel());
     return vec_new_data;
-  } else if (framework::TransToProtoVarType(new_data_tensor->dtype()) ==
-             framework::proto::VarType::INT32) {
+  } else if (new_data_tensor->dtype() == phi::DataType::INT32) {
     auto* new_data = new_data_tensor->data<int32_t>();
     std::vector<int64_t> vec_new_data;
     phi::DenseTensor cpu_starts_tensor;

@@ -81,17 +79,15 @@ inline std::vector<int64_t> GetNewDataFromShapeTensorList(
                                        "But received tensor's dim=%s.",
                                        tensor->dims()));
 
-    if (framework::TransToProtoVarType(tensor->dtype()) ==
-        framework::proto::VarType::INT32) {
+    if (tensor->dtype() == phi::DataType::INT32) {
       if (tensor->place().GetType() == phi::AllocationType::GPU) {
         phi::DenseTensor temp;
         paddle::framework::TensorCopySync(*tensor, phi::CPUPlace(), &temp);
         vec_new_shape.push_back(static_cast<int64_t>(*temp.data<int32_t>()));
       } else {
         vec_new_shape.push_back(static_cast<int64_t>(*tensor->data<int32_t>()));
       }
-    } else if (framework::TransToProtoVarType(tensor->dtype()) ==
-               framework::proto::VarType::INT64) {
+    } else if (tensor->dtype() == phi::DataType::INT64) {
       if (tensor->place().GetType() == phi::AllocationType::GPU) {
         phi::DenseTensor temp;
         paddle::framework::TensorCopySync(*tensor, phi::CPUPlace(), &temp);

@@ -105,8 +101,7 @@
                                          "But got "
                                          "unsupport dtype: %s.",
                                          i,
-                                         paddle::framework::DataTypeToString(
-                                             framework::TransToProtoVarType(tensor->dtype()))));
+                                         phi::DataTypeToString(tensor->dtype())));
     }
   }
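
Both helpers widen a shape tensor's int32 or int64 extents into a std::vector<int64_t>, and the error path now formats the dtype with phi::DataTypeToString instead of the proto round trip. A standalone sketch of the dtype handling, where DataType and DataTypeToString are simplified stand-ins and the GPU-to-CPU copies from the real helpers are omitted:

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

enum class DataType { INT32, INT64, FLOAT32 };  // stand-in for phi::DataType

// Simplified stand-in for phi::DataTypeToString, used only for the error.
std::string DataTypeToString(DataType dt) {
  switch (dt) {
    case DataType::INT32: return "int32";
    case DataType::INT64: return "int64";
    default: return "unsupported";
  }
}

// Mirrors GetNewDataFromShapeTensor's dtype handling: int64 extents are
// copied through, int32 extents are widened, anything else is rejected.
std::vector<int64_t> ShapeFromTensor(DataType dtype, const void* data,
                                     int64_t numel) {
  if (dtype == DataType::INT64) {
    auto* p = static_cast<const int64_t*>(data);
    return std::vector<int64_t>(p, p + numel);
  }
  if (dtype == DataType::INT32) {
    auto* p = static_cast<const int32_t*>(data);
    return std::vector<int64_t>(p, p + numel);  // element-wise widening
  }
  throw std::invalid_argument("unsupported dtype: " + DataTypeToString(dtype));
}

int main() {
  int32_t dims[3] = {2, 3, 4};
  auto shape = ShapeFromTensor(DataType::INT32, dims, 3);
  return shape[2] == 4 ? 0 : 1;
}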