[PT] Set NCHW as default Layout for type translator (#62263)
* [PT] Set NCHW as default Layout for type translator

* fix randint

* fix typo

* fix delt
Aurelius84 authored Mar 2, 2024
1 parent eabf863 commit 6f608ca
Showing 2 changed files with 44 additions and 47 deletions.
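To make the behavioral change concrete before the diffs: the translator used to stamp DataLayout::UNDEFINED on every translated dense-tensor type, and after this commit it stamps DataLayout::NCHW. Below is a minimal standalone C++ sketch of that default; the struct and function names are hypothetical stand-ins, not Paddle's real paddle::dialect::DenseTensorType / pir API, which additionally needs an IrContext and a dtype.

// Standalone sketch, NOT Paddle's real API: DenseTensorTypeInfo and
// MakeTranslatedType are hypothetical stand-ins mirroring the shape of
// DenseTensorType::get(ctx, dtype, dim, layout, lod, offset).
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

enum class DataLayout { UNDEFINED, NCHW };
using LoD = std::vector<std::vector<size_t>>;

struct DenseTensorTypeInfo {
  std::vector<int64_t> dims;
  DataLayout layout;
  LoD lod;
  size_t offset;
};

DenseTensorTypeInfo MakeTranslatedType(const std::vector<int64_t>& dims) {
  const DataLayout layout = DataLayout::NCHW;  // was DataLayout::UNDEFINED
  const LoD lod = {};
  const size_t offset = 0;
  return DenseTensorTypeInfo{dims, layout, lod, offset};
}

int main() {
  const auto t = MakeTranslatedType({8, 3, 224, 224});
  std::cout << (t.layout == DataLayout::NCHW ? "NCHW" : "UNDEFINED") << "\n";
}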
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -2746,7 +2746,7 @@ struct RandIntOpTranscriber : public OpTranscriber {
     paddle::dialect::DenseTensorTypeStorage::Dim dim =
         common::make_ddim(var->GetShape());
     paddle::dialect::DenseTensorTypeStorage::DataLayout layout =
-        paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED;
+        paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW;
     paddle::dialect::DenseTensorTypeStorage::LoD lod = {};
     size_t offset = 0;
     pir::Type translated_var_type = paddle::dialect::DenseTensorType::get(
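The one-line change above is the "fix randint" item from the commit message: RandIntOpTranscriber constructs its output DenseTensorType by hand rather than going through TypeTranslator, so its layout default has to be switched from UNDEFINED to NCHW separately from the type_translator.cc change below.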
89 changes: 43 additions & 46 deletions paddle/fluid/ir_adaptor/translator/type_translator.cc
@@ -30,8 +30,48 @@
 using DenseTensorType = paddle::dialect::DenseTensorType;
 using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage;
 using SelectedRowsType = paddle::dialect::SelectedRowsType;
 using SelectedRowsTypeStorage = paddle::dialect::SelectedRowsTypeStorage;
+using DataLayout = DenseTensorTypeStorage::DataLayout;
+using LoD = DenseTensorTypeStorage::LoD;
 
 TypeTranslator::TypeTranslator() {
+  const auto& HandleTensor = [&](pir::IrContext* ctx,
+                                 const VarDesc& var_desc) -> pir::Type {
+    VLOG(10) << "[vartype translating]"
+             << "[" << var_desc.Name() << "] from LOD_TENSOR";
+    const pir::Type dtype =
+        this->operator[](var_desc.GetDataType())(ctx, var_desc);
+    const auto dim = common::make_ddim(var_desc.GetShape());
+    const auto layout = DataLayout::NCHW;
+    const LoD lod = {};
+    const size_t offset = 0;
+    return DenseTensorType::get(ctx, dtype, dim, layout, lod, offset);
+  };
+  const auto& HandleTensorArray = [&](pir::IrContext* ctx,
+                                      const VarDesc& var_desc) -> pir::Type {
+    VLOG(10) << "[vartype translating]"
+             << "[" << var_desc.Name() << "] from LOD_TENSOR_ARRAY";
+    const pir::Type dtype =
+        this->operator[](var_desc.GetDataType())(ctx, var_desc);
+    const auto dims = common::make_ddim(var_desc.GetShape());
+    const auto layout = DataLayout::NCHW;
+    return paddle::dialect::DenseTensorArrayType::get(ctx, dtype, dims, layout);
+  };
+
+  const auto& HandleSelectedRows = [&](pir::IrContext* ctx,
+                                       const VarDesc& var_desc) -> pir::Type {
+    VLOG(10) << "[vartype translating]"
+             << "[" << var_desc.Name() << "] from SELECTED_ROWS";
+    const pir::Type dtype =
+        this->operator[](var_desc.GetDataType())(ctx, var_desc);
+    const auto dim = common::make_ddim(var_desc.GetShape());
+    const auto layout = DataLayout::NCHW;
+    const LoD lod = {};
+    const size_t offset = 0;
+    pir::Type SelectedRows =
+        SelectedRowsType::get(ctx, dtype, dim, layout, lod, offset);
+    return SelectedRows;
+  };
+
   handlers = {
       {VarType::BOOL,
        [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
@@ -81,52 +121,9 @@ TypeTranslator::TypeTranslator() {
        [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
          return pir::Complex128Type::get(ctx);
        }},
-      {VarType::LOD_TENSOR,
-       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
-         VLOG(10) << "[vartype translating]"
-                  << "[" << var_desc.Name() << "] from LOD_TENSOR";
-
-         pir::Type dtype =
-             this->operator[](var_desc.GetDataType())(ctx, var_desc);
-         DenseTensorTypeStorage::Dim dim =
-             common::make_ddim(var_desc.GetShape());
-         DenseTensorTypeStorage::DataLayout layout =
-             DenseTensorTypeStorage::DataLayout::UNDEFINED;
-         DenseTensorTypeStorage::LoD lod = {};
-         size_t offset = 0;
-         return DenseTensorType::get(ctx, dtype, dim, layout, lod, offset);
-       }},
-      {VarType::LOD_TENSOR_ARRAY,
-       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
-         VLOG(10) << "[vartype translating]"
-                  << "[" << var_desc.Name() << "] from LOD_TENSOR_ARRAY";
-         pir::Type dtype =
-             this->operator[](var_desc.GetDataType())(ctx, var_desc);
-         phi::DDim dims = common::make_ddim(var_desc.GetShape());
-         DenseTensorTypeStorage::DataLayout layout =
-             DenseTensorTypeStorage::DataLayout::UNDEFINED;
-
-         return paddle::dialect::DenseTensorArrayType::get(
-             ctx, dtype, dims, layout);
-       }},
-      {VarType::SELECTED_ROWS,
-       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
-         VLOG(10) << "[vartype translating]"
-                  << "[" << var_desc.Name() << "] from SELECTED_ROWS";
-
-         pir::Type dtype =
-             this->operator[](var_desc.GetDataType())(ctx, var_desc);
-
-         SelectedRowsTypeStorage::Dim dim =
-             common::make_ddim(var_desc.GetShape());
-         SelectedRowsTypeStorage::DataLayout layout =
-             SelectedRowsTypeStorage::DataLayout::UNDEFINED;
-         SelectedRowsTypeStorage::LoD lod = {};
-         size_t offset = 0;
-         pir::Type SelectedRows =
-             SelectedRowsType::get(ctx, dtype, dim, layout, lod, offset);
-         return SelectedRows;
-       }},
+      {VarType::LOD_TENSOR, HandleTensor},
+      {VarType::LOD_TENSOR_ARRAY, HandleTensorArray},
+      {VarType::SELECTED_ROWS, HandleSelectedRows},
   };
 }
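The refactor above replaces three near-duplicate inline lambdas with named handlers defined once and registered in the handlers map. Below is a simplified standalone sketch of that dispatch-table pattern; the types here are hypothetical stand-ins, while the real table is keyed by VarType and its handlers return pir::Type.

// Standalone sketch, NOT Paddle's real API: Handler and VarType here are
// simplified stand-ins for the TypeTranslator::handlers dispatch table.
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

enum class VarType { LOD_TENSOR, LOD_TENSOR_ARRAY, SELECTED_ROWS };
using Handler = std::function<std::string(const std::string&)>;

int main() {
  // Define each handler once, then register it by key, instead of
  // writing a separate inline lambda per map entry.
  const Handler HandleTensor = [](const std::string& name) {
    return name + ": DenseTensorType(layout=NCHW)";
  };
  const Handler HandleSelectedRows = [](const std::string& name) {
    return name + ": SelectedRowsType(layout=NCHW)";
  };

  std::unordered_map<VarType, Handler> handlers = {
      {VarType::LOD_TENSOR, HandleTensor},
      {VarType::SELECTED_ROWS, HandleSelectedRows},
  };
  std::cout << handlers[VarType::LOD_TENSOR]("x") << "\n";
}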

